fix(): business logic fixes and code refactoring
This commit is contained in:
@@ -4,7 +4,7 @@ from typing import Any
|
||||
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db import transaction
|
||||
from django.db.models import QuerySet
|
||||
from django.db.models import Case, QuerySet, Value, When
|
||||
from django.utils import timezone
|
||||
|
||||
from apps.experiments.models import (
|
||||
@@ -21,23 +21,38 @@ from apps.guardrails.models import (
|
||||
GuardrailAction,
|
||||
GuardrailTrigger,
|
||||
)
|
||||
from apps.metrics.models import MetricDefinition
|
||||
from apps.metrics.models import MetricDefinition, MetricDirection
|
||||
from apps.notifications.services import (
|
||||
NotificationPayload,
|
||||
notification_enqueue,
|
||||
)
|
||||
from apps.reports.services import calculate_metric_value
|
||||
|
||||
# Numeric rank per guardrail action: ROLLBACK maps to 0 and PAUSE to 1, i.e.
# the rank is simply the action's position in the tuple below.
# NOTE(review): presumably a lower rank means "handle first / more severe";
# no use of this mapping is visible in this part of the file, so confirm
# against its callers before relying on that reading.
_ACTION_SEVERITY = {
    action: rank
    for rank, action in enumerate(
        (
            GuardrailAction.ROLLBACK,
            GuardrailAction.PAUSE,
        )
    )
}
|
||||
|
||||
|
||||
@transaction.atomic
|
||||
def guardrail_create(
|
||||
*,
|
||||
experiment: Experiment,
|
||||
metric: MetricDefinition,
|
||||
threshold: Any,
|
||||
threshold: Decimal,
|
||||
observation_window_minutes: int = 60,
|
||||
action: str = GuardrailAction.PAUSE,
|
||||
) -> Guardrail:
|
||||
if experiment.status in STARTED_STATUSES:
|
||||
raise ValidationError(
|
||||
{
|
||||
"experiment": (
|
||||
"Guardrails cannot be added after the experiment "
|
||||
"has been started "
|
||||
f"(status: '{experiment.status}')."
|
||||
)
|
||||
}
|
||||
)
|
||||
guardrail = Guardrail(
|
||||
experiment=experiment,
|
||||
metric=metric,
|
||||
@@ -135,9 +150,22 @@ def _calculate_guardrail_metric(
|
||||
if not values:
|
||||
return None
|
||||
|
||||
direction = guardrail.metric.direction
|
||||
if direction == MetricDirection.HIGHER_IS_BETTER:
|
||||
return min(values)
|
||||
return max(values)
|
||||
|
||||
|
||||
def _is_threshold_breached(
    actual_value: Decimal,
    threshold: Decimal,
    direction: str,
) -> bool:
    """Tell whether ``actual_value`` violates ``threshold`` for a metric direction.

    Args:
        actual_value: The measured metric value.
        threshold: The guardrail limit to compare against.
        direction: The metric's direction; compared against
            ``MetricDirection.HIGHER_IS_BETTER``.

    Returns:
        For ``HIGHER_IS_BETTER`` metrics, ``True`` when the value is strictly
        below the threshold. For every other direction, ``True`` when the
        value is strictly above it. A value equal to the threshold is never
        a breach.
    """
    higher_is_better = direction == MetricDirection.HIGHER_IS_BETTER
    return (
        actual_value < threshold
        if higher_is_better
        else actual_value > threshold
    )
|
||||
|
||||
|
||||
@transaction.atomic
|
||||
def _execute_guardrail_action(
|
||||
guardrail: Guardrail,
|
||||
@@ -146,6 +174,12 @@ def _execute_guardrail_action(
|
||||
) -> GuardrailTrigger:
|
||||
now = timezone.now()
|
||||
|
||||
experiment = Experiment.objects.select_for_update().get(pk=experiment.pk)
|
||||
if experiment.status != ExperimentStatus.RUNNING:
|
||||
return None
|
||||
|
||||
from_status = experiment.status
|
||||
|
||||
trigger = GuardrailTrigger.objects.create(
|
||||
guardrail=guardrail,
|
||||
experiment=experiment,
|
||||
@@ -175,7 +209,7 @@ def _execute_guardrail_action(
|
||||
"threshold": str(guardrail.threshold),
|
||||
"actual_value": str(actual_value),
|
||||
"action": guardrail.action,
|
||||
"from_status": ExperimentStatus.RUNNING,
|
||||
"from_status": from_status,
|
||||
"to_status": ExperimentStatus.PAUSED,
|
||||
},
|
||||
)
|
||||
@@ -213,7 +247,7 @@ def _execute_guardrail_action(
|
||||
"threshold": str(guardrail.threshold),
|
||||
"actual_value": str(actual_value),
|
||||
"action": guardrail.action,
|
||||
"from_status": ExperimentStatus.RUNNING,
|
||||
"from_status": from_status,
|
||||
"to_status": ExperimentStatus.COMPLETED,
|
||||
"control_variant_id": (
|
||||
str(control_variant.pk) if control_variant else None
|
||||
@@ -252,10 +286,21 @@ def check_experiment_guardrails(
|
||||
if experiment.status != ExperimentStatus.RUNNING:
|
||||
return []
|
||||
|
||||
guardrails = Guardrail.objects.filter(
|
||||
experiment=experiment,
|
||||
is_active=True,
|
||||
).select_related("metric")
|
||||
guardrails = (
|
||||
Guardrail.objects.filter(
|
||||
experiment=experiment,
|
||||
is_active=True,
|
||||
)
|
||||
.select_related("metric")
|
||||
.annotate(
|
||||
action_order=Case(
|
||||
When(action=GuardrailAction.ROLLBACK, then=Value(0)),
|
||||
When(action=GuardrailAction.PAUSE, then=Value(1)),
|
||||
default=Value(2),
|
||||
),
|
||||
)
|
||||
.order_by("action_order", "-created_at")
|
||||
)
|
||||
|
||||
triggers: list[GuardrailTrigger] = []
|
||||
|
||||
@@ -264,16 +309,19 @@ def check_experiment_guardrails(
|
||||
if actual_value is None:
|
||||
continue
|
||||
|
||||
if actual_value > guardrail.threshold:
|
||||
experiment.refresh_from_db()
|
||||
if experiment.status != ExperimentStatus.RUNNING:
|
||||
break
|
||||
|
||||
if _is_threshold_breached(
|
||||
actual_value,
|
||||
guardrail.threshold,
|
||||
guardrail.metric.direction,
|
||||
):
|
||||
trigger = _execute_guardrail_action(
|
||||
guardrail,
|
||||
experiment,
|
||||
actual_value,
|
||||
)
|
||||
if trigger is None:
|
||||
break
|
||||
|
||||
triggers.append(trigger)
|
||||
|
||||
if guardrail.action in {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from decimal import Decimal
|
||||
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.test import TestCase
|
||||
from django.utils import timezone
|
||||
|
||||
@@ -27,6 +28,9 @@ from apps.guardrails.services import (
|
||||
check_all_running_experiments,
|
||||
check_experiment_guardrails,
|
||||
guardrail_create,
|
||||
guardrail_delete,
|
||||
guardrail_list,
|
||||
guardrail_update,
|
||||
)
|
||||
from apps.metrics.models import MetricDirection, MetricType
|
||||
from apps.metrics.services import metric_definition_create
|
||||
@@ -482,3 +486,100 @@ class CheckAllRunningTest(TestCase):
|
||||
results = check_all_running_experiments()
|
||||
self.assertEqual(results["triggered"], 1)
|
||||
self.assertGreater(len(results["triggers"]), 0)
|
||||
|
||||
|
||||
class GuardrailServiceTest(TestCase):
    """Tests for the guardrail create / list / update / delete services.

    ``setUp`` builds an experiment via ``make_experiment`` and a ratio metric
    declared as ``LOWER_IS_BETTER``; every test attaches guardrails to that
    pair.
    """

    def setUp(self) -> None:
        self.experiment = make_experiment(suffix="_gr")
        self.metric = metric_definition_create(
            key="gr_error_rate",
            name="Error Rate",
            metric_type=MetricType.RATIO,
            direction=MetricDirection.LOWER_IS_BETTER,
            calculation_rule={
                "type": "ratio",
                "numerator_event": "error",
                "denominator_event": "exposure",
            },
        )

    def _create_default_guardrail(self) -> Guardrail:
        """Create a guardrail on the fixture experiment with threshold 0.05.

        All other arguments of ``guardrail_create`` are left at their
        defaults.
        """
        return guardrail_create(
            experiment=self.experiment,
            metric=self.metric,
            threshold=Decimal("0.05"),
        )

    def _create_guardrail_then_start(self, approver_suffix: str) -> Guardrail:
        """Create a guardrail, then submit, approve and start the experiment.

        The guardrail is created before the experiment is started, and the
        steps run in the same order the individual tests used to spell out
        inline.

        Args:
            approver_suffix: Passed to ``make_approver`` so each test gets
                its own approver.

        Returns:
            The guardrail that was created before the experiment started.
        """
        review_settings_update(
            default_min_approvals=1, allow_any_approver=True
        )
        approver = make_approver(approver_suffix)
        add_two_variants(self.experiment)
        guardrail = self._create_default_guardrail()
        exp = experiment_submit_for_review(
            experiment=self.experiment,
            user=self.experiment.owner,
        )
        exp = experiment_approve(experiment=exp, approver=approver)
        experiment_start(experiment=exp, user=self.experiment.owner)
        return guardrail

    def test_create_guardrail(self) -> None:
        # Explicit arguments must be stored on the returned guardrail.
        g = guardrail_create(
            experiment=self.experiment,
            metric=self.metric,
            threshold=Decimal("0.05"),
            observation_window_minutes=30,
            action=GuardrailAction.PAUSE,
        )
        self.assertEqual(g.threshold, Decimal("0.05"))
        self.assertEqual(g.action, GuardrailAction.PAUSE)

    def test_list_guardrails(self) -> None:
        self._create_default_guardrail()
        grs = guardrail_list(self.experiment)
        self.assertEqual(grs.count(), 1)

    def test_update_guardrail_in_draft(self) -> None:
        g = self._create_default_guardrail()
        updated = guardrail_update(
            guardrail=g,
            threshold=Decimal("0.10"),
        )
        self.assertEqual(updated.threshold, Decimal("0.10"))

    def test_reject_update_after_start(self) -> None:
        # Once the experiment has been started, updates must be refused.
        g = self._create_guardrail_then_start("_gu")
        with self.assertRaises(ValidationError):
            guardrail_update(guardrail=g, threshold=Decimal("0.10"))

    def test_delete_guardrail_in_draft(self) -> None:
        g = self._create_default_guardrail()
        guardrail_delete(guardrail=g)
        self.assertEqual(Guardrail.objects.count(), 0)

    def test_reject_delete_after_start(self) -> None:
        # Once the experiment has been started, deletion must be refused.
        g = self._create_guardrail_then_start("_gd")
        with self.assertRaises(ValidationError):
            guardrail_delete(guardrail=g)
|
||||
|
||||
Reference in New Issue
Block a user