fix(): business logic fixes and code refactoring
This commit is contained in:
@@ -5,6 +5,7 @@ from django.core.exceptions import ValidationError
|
||||
from django.db import transaction
|
||||
from django.db.models import QuerySet
|
||||
|
||||
from apps.experiments.models import STARTED_STATUSES
|
||||
from apps.metrics.models import (
|
||||
ExperimentMetric,
|
||||
MetricDefinition,
|
||||
@@ -41,6 +42,17 @@ def _validate_calculation_rule(
|
||||
)
|
||||
}
|
||||
)
|
||||
valid = VALID_RULE_FIELDS.get(metric_type, set())
|
||||
extra = set(rule.keys()) - valid
|
||||
if extra:
|
||||
raise ValidationError(
|
||||
{
|
||||
"calculation_rule": (
|
||||
f"Unknown fields for '{metric_type}': "
|
||||
f"{', '.join(sorted(extra))}."
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@transaction.atomic
|
||||
@@ -103,6 +115,16 @@ def experiment_metric_add(
|
||||
metric: MetricDefinition,
|
||||
is_primary: bool = False,
|
||||
) -> ExperimentMetric:
|
||||
if experiment.status in STARTED_STATUSES:
|
||||
raise ValidationError(
|
||||
{
|
||||
"experiment": (
|
||||
"Metrics cannot be modified after the experiment "
|
||||
"has been started "
|
||||
f"(status: '{experiment.status}')."
|
||||
)
|
||||
}
|
||||
)
|
||||
if is_primary:
|
||||
experiment.experiment_metrics.filter(is_primary=True).update(
|
||||
is_primary=False,
|
||||
@@ -122,6 +144,16 @@ def experiment_metric_remove(
|
||||
experiment: Any,
|
||||
metric: MetricDefinition,
|
||||
) -> None:
|
||||
if experiment.status in STARTED_STATUSES:
|
||||
raise ValidationError(
|
||||
{
|
||||
"experiment": (
|
||||
"Metrics cannot be modified after the experiment "
|
||||
"has been started "
|
||||
f"(status: '{experiment.status}')."
|
||||
)
|
||||
}
|
||||
)
|
||||
deleted, _ = ExperimentMetric.objects.filter(
|
||||
experiment=experiment,
|
||||
metric=metric,
|
||||
|
||||
@@ -9,13 +9,8 @@ from apps.experiments.services import (
|
||||
experiment_submit_for_review,
|
||||
)
|
||||
from apps.experiments.tests.helpers import add_two_variants, make_experiment
|
||||
from apps.guardrails.models import Guardrail, GuardrailAction
|
||||
from apps.guardrails.services import (
|
||||
guardrail_create,
|
||||
guardrail_delete,
|
||||
guardrail_list,
|
||||
guardrail_update,
|
||||
)
|
||||
from apps.guardrails.models import GuardrailAction
|
||||
from apps.guardrails.services import guardrail_create
|
||||
from apps.metrics.models import ExperimentMetric, MetricDirection, MetricType
|
||||
from apps.metrics.services import (
|
||||
experiment_metric_add,
|
||||
@@ -246,100 +241,3 @@ class ExperimentMetricTest(TestCase):
|
||||
em1.refresh_from_db()
|
||||
self.assertFalse(em1.is_primary)
|
||||
self.assertTrue(em2.is_primary)
|
||||
|
||||
|
||||
class GuardrailServiceTest(TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.experiment = make_experiment(suffix="_gr")
|
||||
self.metric = metric_definition_create(
|
||||
key="gr_error_rate",
|
||||
name="Error Rate",
|
||||
metric_type=MetricType.RATIO,
|
||||
direction=MetricDirection.LOWER_IS_BETTER,
|
||||
calculation_rule={
|
||||
"type": "ratio",
|
||||
"numerator_event": "error",
|
||||
"denominator_event": "exposure",
|
||||
},
|
||||
)
|
||||
|
||||
def test_create_guardrail(self) -> None:
|
||||
g = guardrail_create(
|
||||
experiment=self.experiment,
|
||||
metric=self.metric,
|
||||
threshold=Decimal("0.05"),
|
||||
observation_window_minutes=30,
|
||||
action=GuardrailAction.PAUSE,
|
||||
)
|
||||
self.assertEqual(g.threshold, Decimal("0.05"))
|
||||
self.assertEqual(g.action, GuardrailAction.PAUSE)
|
||||
|
||||
def test_list_guardrails(self) -> None:
|
||||
guardrail_create(
|
||||
experiment=self.experiment,
|
||||
metric=self.metric,
|
||||
threshold=Decimal("0.05"),
|
||||
)
|
||||
grs = guardrail_list(self.experiment)
|
||||
self.assertEqual(grs.count(), 1)
|
||||
|
||||
def test_update_guardrail_in_draft(self) -> None:
|
||||
g = guardrail_create(
|
||||
experiment=self.experiment,
|
||||
metric=self.metric,
|
||||
threshold=Decimal("0.05"),
|
||||
)
|
||||
updated = guardrail_update(
|
||||
guardrail=g,
|
||||
threshold=Decimal("0.10"),
|
||||
)
|
||||
self.assertEqual(updated.threshold, Decimal("0.10"))
|
||||
|
||||
def test_reject_update_after_start(self) -> None:
|
||||
review_settings_update(
|
||||
default_min_approvals=1, allow_any_approver=True
|
||||
)
|
||||
approver = make_approver("_gu")
|
||||
add_two_variants(self.experiment)
|
||||
exp = experiment_submit_for_review(
|
||||
experiment=self.experiment,
|
||||
user=self.experiment.owner,
|
||||
)
|
||||
exp = experiment_approve(experiment=exp, approver=approver)
|
||||
experiment_start(experiment=exp, user=self.experiment.owner)
|
||||
g = guardrail_create(
|
||||
experiment=self.experiment,
|
||||
metric=self.metric,
|
||||
threshold=Decimal("0.05"),
|
||||
)
|
||||
with self.assertRaises(ValidationError):
|
||||
guardrail_update(guardrail=g, threshold=Decimal("0.10"))
|
||||
|
||||
def test_delete_guardrail_in_draft(self) -> None:
|
||||
g = guardrail_create(
|
||||
experiment=self.experiment,
|
||||
metric=self.metric,
|
||||
threshold=Decimal("0.05"),
|
||||
)
|
||||
guardrail_delete(guardrail=g)
|
||||
self.assertEqual(Guardrail.objects.count(), 0)
|
||||
|
||||
def test_reject_delete_after_start(self) -> None:
|
||||
review_settings_update(
|
||||
default_min_approvals=1, allow_any_approver=True
|
||||
)
|
||||
approver = make_approver("_gd")
|
||||
add_two_variants(self.experiment)
|
||||
exp = experiment_submit_for_review(
|
||||
experiment=self.experiment,
|
||||
user=self.experiment.owner,
|
||||
)
|
||||
exp = experiment_approve(experiment=exp, approver=approver)
|
||||
experiment_start(experiment=exp, user=self.experiment.owner)
|
||||
g = guardrail_create(
|
||||
experiment=self.experiment,
|
||||
metric=self.metric,
|
||||
threshold=Decimal("0.05"),
|
||||
)
|
||||
with self.assertRaises(ValidationError):
|
||||
guardrail_delete(guardrail=g)
|
||||
|
||||
Reference in New Issue
Block a user