fix(): business logic fixes and code refactoring

This commit is contained in:
ITQ
2026-02-24 09:58:07 +03:00
parent e51b74a133
commit 16b48fee40
18 changed files with 307 additions and 140 deletions
+62 -14
View File
@@ -4,7 +4,7 @@ from typing import Any
from django.core.exceptions import ValidationError
from django.db import transaction
from django.db.models import QuerySet
from django.db.models import Case, QuerySet, Value, When
from django.utils import timezone
from apps.experiments.models import (
@@ -21,23 +21,38 @@ from apps.guardrails.models import (
GuardrailAction,
GuardrailTrigger,
)
from apps.metrics.models import MetricDefinition
from apps.metrics.models import MetricDefinition, MetricDirection
from apps.notifications.services import (
NotificationPayload,
notification_enqueue,
)
from apps.reports.services import calculate_metric_value
_ACTION_SEVERITY = {
GuardrailAction.ROLLBACK: 0,
GuardrailAction.PAUSE: 1,
}
@transaction.atomic
def guardrail_create(
*,
experiment: Experiment,
metric: MetricDefinition,
threshold: Any,
threshold: Decimal,
observation_window_minutes: int = 60,
action: str = GuardrailAction.PAUSE,
) -> Guardrail:
if experiment.status in STARTED_STATUSES:
raise ValidationError(
{
"experiment": (
"Guardrails cannot be added after the experiment "
"has been started "
f"(status: '{experiment.status}')."
)
}
)
guardrail = Guardrail(
experiment=experiment,
metric=metric,
@@ -135,9 +150,22 @@ def _calculate_guardrail_metric(
if not values:
return None
direction = guardrail.metric.direction
if direction == MetricDirection.HIGHER_IS_BETTER:
return min(values)
return max(values)
def _is_threshold_breached(
actual_value: Decimal,
threshold: Decimal,
direction: str,
) -> bool:
if direction == MetricDirection.HIGHER_IS_BETTER:
return actual_value < threshold
return actual_value > threshold
@transaction.atomic
def _execute_guardrail_action(
guardrail: Guardrail,
@@ -146,6 +174,12 @@ def _execute_guardrail_action(
) -> GuardrailTrigger:
now = timezone.now()
experiment = Experiment.objects.select_for_update().get(pk=experiment.pk)
if experiment.status != ExperimentStatus.RUNNING:
return None
from_status = experiment.status
trigger = GuardrailTrigger.objects.create(
guardrail=guardrail,
experiment=experiment,
@@ -175,7 +209,7 @@ def _execute_guardrail_action(
"threshold": str(guardrail.threshold),
"actual_value": str(actual_value),
"action": guardrail.action,
"from_status": ExperimentStatus.RUNNING,
"from_status": from_status,
"to_status": ExperimentStatus.PAUSED,
},
)
@@ -213,7 +247,7 @@ def _execute_guardrail_action(
"threshold": str(guardrail.threshold),
"actual_value": str(actual_value),
"action": guardrail.action,
"from_status": ExperimentStatus.RUNNING,
"from_status": from_status,
"to_status": ExperimentStatus.COMPLETED,
"control_variant_id": (
str(control_variant.pk) if control_variant else None
@@ -252,10 +286,21 @@ def check_experiment_guardrails(
if experiment.status != ExperimentStatus.RUNNING:
return []
guardrails = Guardrail.objects.filter(
experiment=experiment,
is_active=True,
).select_related("metric")
guardrails = (
Guardrail.objects.filter(
experiment=experiment,
is_active=True,
)
.select_related("metric")
.annotate(
action_order=Case(
When(action=GuardrailAction.ROLLBACK, then=Value(0)),
When(action=GuardrailAction.PAUSE, then=Value(1)),
default=Value(2),
),
)
.order_by("action_order", "-created_at")
)
triggers: list[GuardrailTrigger] = []
@@ -264,16 +309,19 @@ def check_experiment_guardrails(
if actual_value is None:
continue
if actual_value > guardrail.threshold:
experiment.refresh_from_db()
if experiment.status != ExperimentStatus.RUNNING:
break
if _is_threshold_breached(
actual_value,
guardrail.threshold,
guardrail.metric.direction,
):
trigger = _execute_guardrail_action(
guardrail,
experiment,
actual_value,
)
if trigger is None:
break
triggers.append(trigger)
if guardrail.action in {
@@ -1,5 +1,6 @@
from decimal import Decimal
from django.core.exceptions import ValidationError
from django.test import TestCase
from django.utils import timezone
@@ -27,6 +28,9 @@ from apps.guardrails.services import (
check_all_running_experiments,
check_experiment_guardrails,
guardrail_create,
guardrail_delete,
guardrail_list,
guardrail_update,
)
from apps.metrics.models import MetricDirection, MetricType
from apps.metrics.services import metric_definition_create
@@ -482,3 +486,100 @@ class CheckAllRunningTest(TestCase):
results = check_all_running_experiments()
self.assertEqual(results["triggered"], 1)
self.assertGreater(len(results["triggers"]), 0)
class GuardrailServiceTest(TestCase):
def setUp(self) -> None:
self.experiment = make_experiment(suffix="_gr")
self.metric = metric_definition_create(
key="gr_error_rate",
name="Error Rate",
metric_type=MetricType.RATIO,
direction=MetricDirection.LOWER_IS_BETTER,
calculation_rule={
"type": "ratio",
"numerator_event": "error",
"denominator_event": "exposure",
},
)
def test_create_guardrail(self) -> None:
g = guardrail_create(
experiment=self.experiment,
metric=self.metric,
threshold=Decimal("0.05"),
observation_window_minutes=30,
action=GuardrailAction.PAUSE,
)
self.assertEqual(g.threshold, Decimal("0.05"))
self.assertEqual(g.action, GuardrailAction.PAUSE)
def test_list_guardrails(self) -> None:
guardrail_create(
experiment=self.experiment,
metric=self.metric,
threshold=Decimal("0.05"),
)
grs = guardrail_list(self.experiment)
self.assertEqual(grs.count(), 1)
def test_update_guardrail_in_draft(self) -> None:
g = guardrail_create(
experiment=self.experiment,
metric=self.metric,
threshold=Decimal("0.05"),
)
updated = guardrail_update(
guardrail=g,
threshold=Decimal("0.10"),
)
self.assertEqual(updated.threshold, Decimal("0.10"))
def test_reject_update_after_start(self) -> None:
review_settings_update(
default_min_approvals=1, allow_any_approver=True
)
approver = make_approver("_gu")
add_two_variants(self.experiment)
g = guardrail_create(
experiment=self.experiment,
metric=self.metric,
threshold=Decimal("0.05"),
)
exp = experiment_submit_for_review(
experiment=self.experiment,
user=self.experiment.owner,
)
exp = experiment_approve(experiment=exp, approver=approver)
experiment_start(experiment=exp, user=self.experiment.owner)
with self.assertRaises(ValidationError):
guardrail_update(guardrail=g, threshold=Decimal("0.10"))
def test_delete_guardrail_in_draft(self) -> None:
g = guardrail_create(
experiment=self.experiment,
metric=self.metric,
threshold=Decimal("0.05"),
)
guardrail_delete(guardrail=g)
self.assertEqual(Guardrail.objects.count(), 0)
def test_reject_delete_after_start(self) -> None:
review_settings_update(
default_min_approvals=1, allow_any_approver=True
)
approver = make_approver("_gd")
add_two_variants(self.experiment)
g = guardrail_create(
experiment=self.experiment,
metric=self.metric,
threshold=Decimal("0.05"),
)
exp = experiment_submit_for_review(
experiment=self.experiment,
user=self.experiment.owner,
)
exp = experiment_approve(experiment=exp, approver=approver)
experiment_start(experiment=exp, user=self.experiment.owner)
with self.assertRaises(ValidationError):
guardrail_delete(guardrail=g)