fix(): business logic fixes and code refactoring

This commit is contained in:
ITQ
2026-02-24 09:58:07 +03:00
parent e51b74a133
commit 16b48fee40
18 changed files with 307 additions and 140 deletions
+3 -1
View File
@@ -1,9 +1,11 @@
from typing import Any
from ninja import Field, Schema
class DecisionIn(Schema):
subject_id: str
subject_attributes: dict[str, str | int | float | bool] = Field(
subject_attributes: dict[str, Any] = Field(
default_factory=dict,
)
flags: list[str] = Field(
+3 -2
View File
@@ -45,6 +45,7 @@ def create_event_type(
name=payload.name,
display_name=payload.display_name,
description=payload.description,
is_exposure=payload.is_exposure,
requires_exposure=payload.requires_exposure,
required_fields=payload.required_fields,
)
@@ -95,7 +96,7 @@ def update_event_type(
et = event_type_get(event_type_id)
if not et:
raise Http404
fields = payload.dict(exclude_unset=True)
fields = payload.model_dump(exclude_unset=True)
updated = event_type_update(event_type=et, **fields)
return HTTPStatus.OK, EventTypeOut.model_validate(updated)
@@ -115,7 +116,7 @@ def ingest_events(
request: HttpRequest,
payload: EventsBatchIn,
) -> tuple[int, EventsBatchOut]:
events_data = [e.dict() for e in payload.events]
events_data = [e.model_dump() for e in payload.events]
batch = process_events_batch(events_data)
EVENTS_INGESTED.labels(status="accepted").inc(batch.accepted)
EVENTS_INGESTED.labels(status="duplicate").inc(batch.duplicates)
+1 -1
View File
@@ -153,7 +153,7 @@ def update_guardrail(
raise Http404 from Guardrail.DoesNotExist
g = guardrail_update(
guardrail=g,
**payload.dict(exclude_unset=True),
**payload.model_dump(exclude_unset=True),
)
g = Guardrail.objects.select_related("metric").get(pk=g.pk)
return HTTPStatus.OK, GuardrailOut.from_guardrail(g)
+1 -1
View File
@@ -110,7 +110,7 @@ def update_learning(
learning = learning_update(
learning=learning,
user=request.auth,
**payload.dict(exclude_unset=True),
**payload.model_dump(exclude_unset=True),
)
learning = learning_get(learning.pk)
return HTTPStatus.OK, LearningOut.from_learning(learning)
+1 -1
View File
@@ -91,7 +91,7 @@ def update_metric(
raise Http404
metric = metric_definition_update(
metric=metric,
**payload.dict(exclude_unset=True),
**payload.model_dump(exclude_unset=True),
)
return HTTPStatus.OK, MetricDefinitionOut.model_validate(metric)
@@ -93,7 +93,7 @@ def update_channel(
raise Http404
ch = channel_update(
channel=ch,
**payload.dict(exclude_unset=True),
**payload.model_dump(exclude_unset=True),
)
return HTTPStatus.OK, ChannelOut.model_validate(ch)
@@ -178,7 +178,7 @@ def update_rule(
raise Http404 from NotificationRule.DoesNotExist
r = rule_update(
rule=r,
**payload.dict(exclude_unset=True),
**payload.model_dump(exclude_unset=True),
)
r = NotificationRule.objects.select_related("channel", "experiment").get(
pk=r.pk
+13 -4
View File
@@ -5,11 +5,12 @@ from django.http import Http404, HttpRequest
from django.utils.dateparse import parse_datetime
from ninja import Router
from api.v1.auth.endpoints import jwt_bearer
from api.v1.reports.schemas import ExperimentReportOut
from apps.experiments.models import Experiment
from apps.reports.services import build_experiment_report
router = Router(tags=["reports"])
router = Router(tags=["reports"], auth=jwt_bearer)
@router.get(
@@ -26,10 +27,18 @@ def get_experiment_report(
try:
experiment = Experiment.objects.get(pk=experiment_id)
except Experiment.DoesNotExist:
raise Http404 from Experiment.DoesNotExist
raise Http404 from None
parsed_start = parse_datetime(start_date) if start_date else None
parsed_end = parse_datetime(end_date) if end_date else None
parsed_start = None
parsed_end = None
if start_date:
parsed_start = parse_datetime(start_date)
if parsed_start is None:
raise Http404
if end_date:
parsed_end = parse_datetime(end_date)
if parsed_end is None:
raise Http404
report_data = build_experiment_report(
experiment=experiment,
@@ -13,11 +13,14 @@ from apps.metrics.services import (
experiment_metric_add,
metric_definition_create,
)
from apps.users.tests.helpers import auth_header, make_user
class ExperimentReportAPITest(TestCase):
@override
def setUp(self) -> None:
self.user = make_user(username="rapi_user", email="rapi@test.local")
self.auth = auth_header(self.user)
self.client = Client()
self.exposure_type = make_exposure_type()
self.click_type = make_event_type(
@@ -75,14 +78,14 @@ class ExperimentReportAPITest(TestCase):
]
)
def test_get_report_no_auth_required(self) -> None:
def test_report_requires_auth(self) -> None:
resp = self.client.get(
reverse(
"api-1:get_experiment_report",
args=[self.experiment.pk],
),
)
self.assertEqual(resp.status_code, 200)
self.assertEqual(resp.status_code, 401)
def test_report_structure(self) -> None:
self._create_data()
@@ -91,6 +94,7 @@ class ExperimentReportAPITest(TestCase):
"api-1:get_experiment_report",
args=[self.experiment.pk],
),
HTTP_AUTHORIZATION=self.auth,
)
self.assertEqual(resp.status_code, 200)
data = resp.json()
@@ -113,6 +117,7 @@ class ExperimentReportAPITest(TestCase):
"api-1:get_experiment_report",
args=[self.experiment.pk],
),
HTTP_AUTHORIZATION=self.auth,
)
data = resp.json()
treatment = next(
@@ -138,6 +143,7 @@ class ExperimentReportAPITest(TestCase):
args=[self.experiment.pk],
),
{"start_date": start, "end_date": end},
HTTP_AUTHORIZATION=self.auth,
)
self.assertEqual(resp.status_code, 200)
data = resp.json()
@@ -150,5 +156,6 @@ class ExperimentReportAPITest(TestCase):
"api-1:get_experiment_report",
args=[uuid.uuid4()],
),
HTTP_AUTHORIZATION=self.auth,
)
self.assertEqual(resp.status_code, 404)
+1 -1
View File
@@ -46,7 +46,7 @@ def _hash_subject(subject_id: str, experiment_id: str, salt: str) -> Decimal:
def _select_variant(
variants: list[Variant], hash_value: float
variants: list[Variant], hash_value: Decimal
) -> Variant | None:
cumulative = Decimal(0)
for variant in sorted(variants, key=lambda v: v.name):
+1 -1
View File
@@ -42,7 +42,7 @@ def _notify(
NotificationPayload(
title=f"{event_type.replace('_', ' ').title()}",
body=(
f"Experiment '{experiment.name}' "
f"Experiment '{experiment.name}' - "
f"{event_type.replace('_', ' ')}."
),
event_type=event_type,
+62 -14
View File
@@ -4,7 +4,7 @@ from typing import Any
from django.core.exceptions import ValidationError
from django.db import transaction
from django.db.models import QuerySet
from django.db.models import Case, QuerySet, Value, When
from django.utils import timezone
from apps.experiments.models import (
@@ -21,23 +21,38 @@ from apps.guardrails.models import (
GuardrailAction,
GuardrailTrigger,
)
from apps.metrics.models import MetricDefinition
from apps.metrics.models import MetricDefinition, MetricDirection
from apps.notifications.services import (
NotificationPayload,
notification_enqueue,
)
from apps.reports.services import calculate_metric_value
_ACTION_SEVERITY = {
GuardrailAction.ROLLBACK: 0,
GuardrailAction.PAUSE: 1,
}
@transaction.atomic
def guardrail_create(
*,
experiment: Experiment,
metric: MetricDefinition,
threshold: Any,
threshold: Decimal,
observation_window_minutes: int = 60,
action: str = GuardrailAction.PAUSE,
) -> Guardrail:
if experiment.status in STARTED_STATUSES:
raise ValidationError(
{
"experiment": (
"Guardrails cannot be added after the experiment "
"has been started "
f"(status: '{experiment.status}')."
)
}
)
guardrail = Guardrail(
experiment=experiment,
metric=metric,
@@ -135,9 +150,22 @@ def _calculate_guardrail_metric(
if not values:
return None
direction = guardrail.metric.direction
if direction == MetricDirection.HIGHER_IS_BETTER:
return min(values)
return max(values)
def _is_threshold_breached(
actual_value: Decimal,
threshold: Decimal,
direction: str,
) -> bool:
if direction == MetricDirection.HIGHER_IS_BETTER:
return actual_value < threshold
return actual_value > threshold
@transaction.atomic
def _execute_guardrail_action(
guardrail: Guardrail,
@@ -146,6 +174,12 @@ def _execute_guardrail_action(
) -> GuardrailTrigger:
now = timezone.now()
experiment = Experiment.objects.select_for_update().get(pk=experiment.pk)
if experiment.status != ExperimentStatus.RUNNING:
return None
from_status = experiment.status
trigger = GuardrailTrigger.objects.create(
guardrail=guardrail,
experiment=experiment,
@@ -175,7 +209,7 @@ def _execute_guardrail_action(
"threshold": str(guardrail.threshold),
"actual_value": str(actual_value),
"action": guardrail.action,
"from_status": ExperimentStatus.RUNNING,
"from_status": from_status,
"to_status": ExperimentStatus.PAUSED,
},
)
@@ -213,7 +247,7 @@ def _execute_guardrail_action(
"threshold": str(guardrail.threshold),
"actual_value": str(actual_value),
"action": guardrail.action,
"from_status": ExperimentStatus.RUNNING,
"from_status": from_status,
"to_status": ExperimentStatus.COMPLETED,
"control_variant_id": (
str(control_variant.pk) if control_variant else None
@@ -252,10 +286,21 @@ def check_experiment_guardrails(
if experiment.status != ExperimentStatus.RUNNING:
return []
guardrails = Guardrail.objects.filter(
experiment=experiment,
is_active=True,
).select_related("metric")
guardrails = (
Guardrail.objects.filter(
experiment=experiment,
is_active=True,
)
.select_related("metric")
.annotate(
action_order=Case(
When(action=GuardrailAction.ROLLBACK, then=Value(0)),
When(action=GuardrailAction.PAUSE, then=Value(1)),
default=Value(2),
),
)
.order_by("action_order", "-created_at")
)
triggers: list[GuardrailTrigger] = []
@@ -264,16 +309,19 @@ def check_experiment_guardrails(
if actual_value is None:
continue
if actual_value > guardrail.threshold:
experiment.refresh_from_db()
if experiment.status != ExperimentStatus.RUNNING:
break
if _is_threshold_breached(
actual_value,
guardrail.threshold,
guardrail.metric.direction,
):
trigger = _execute_guardrail_action(
guardrail,
experiment,
actual_value,
)
if trigger is None:
break
triggers.append(trigger)
if guardrail.action in {
@@ -1,5 +1,6 @@
from decimal import Decimal
from django.core.exceptions import ValidationError
from django.test import TestCase
from django.utils import timezone
@@ -27,6 +28,9 @@ from apps.guardrails.services import (
check_all_running_experiments,
check_experiment_guardrails,
guardrail_create,
guardrail_delete,
guardrail_list,
guardrail_update,
)
from apps.metrics.models import MetricDirection, MetricType
from apps.metrics.services import metric_definition_create
@@ -482,3 +486,100 @@ class CheckAllRunningTest(TestCase):
results = check_all_running_experiments()
self.assertEqual(results["triggered"], 1)
self.assertGreater(len(results["triggers"]), 0)
class GuardrailServiceTest(TestCase):
def setUp(self) -> None:
self.experiment = make_experiment(suffix="_gr")
self.metric = metric_definition_create(
key="gr_error_rate",
name="Error Rate",
metric_type=MetricType.RATIO,
direction=MetricDirection.LOWER_IS_BETTER,
calculation_rule={
"type": "ratio",
"numerator_event": "error",
"denominator_event": "exposure",
},
)
def test_create_guardrail(self) -> None:
g = guardrail_create(
experiment=self.experiment,
metric=self.metric,
threshold=Decimal("0.05"),
observation_window_minutes=30,
action=GuardrailAction.PAUSE,
)
self.assertEqual(g.threshold, Decimal("0.05"))
self.assertEqual(g.action, GuardrailAction.PAUSE)
def test_list_guardrails(self) -> None:
guardrail_create(
experiment=self.experiment,
metric=self.metric,
threshold=Decimal("0.05"),
)
grs = guardrail_list(self.experiment)
self.assertEqual(grs.count(), 1)
def test_update_guardrail_in_draft(self) -> None:
g = guardrail_create(
experiment=self.experiment,
metric=self.metric,
threshold=Decimal("0.05"),
)
updated = guardrail_update(
guardrail=g,
threshold=Decimal("0.10"),
)
self.assertEqual(updated.threshold, Decimal("0.10"))
def test_reject_update_after_start(self) -> None:
review_settings_update(
default_min_approvals=1, allow_any_approver=True
)
approver = make_approver("_gu")
add_two_variants(self.experiment)
g = guardrail_create(
experiment=self.experiment,
metric=self.metric,
threshold=Decimal("0.05"),
)
exp = experiment_submit_for_review(
experiment=self.experiment,
user=self.experiment.owner,
)
exp = experiment_approve(experiment=exp, approver=approver)
experiment_start(experiment=exp, user=self.experiment.owner)
with self.assertRaises(ValidationError):
guardrail_update(guardrail=g, threshold=Decimal("0.10"))
def test_delete_guardrail_in_draft(self) -> None:
g = guardrail_create(
experiment=self.experiment,
metric=self.metric,
threshold=Decimal("0.05"),
)
guardrail_delete(guardrail=g)
self.assertEqual(Guardrail.objects.count(), 0)
def test_reject_delete_after_start(self) -> None:
review_settings_update(
default_min_approvals=1, allow_any_approver=True
)
approver = make_approver("_gd")
add_two_variants(self.experiment)
g = guardrail_create(
experiment=self.experiment,
metric=self.metric,
threshold=Decimal("0.05"),
)
exp = experiment_submit_for_review(
experiment=self.experiment,
user=self.experiment.owner,
)
exp = experiment_approve(experiment=exp, approver=approver)
experiment_start(experiment=exp, user=self.experiment.owner)
with self.assertRaises(ValidationError):
guardrail_delete(guardrail=g)
+32
View File
@@ -5,6 +5,7 @@ from django.core.exceptions import ValidationError
from django.db import transaction
from django.db.models import QuerySet
from apps.experiments.models import STARTED_STATUSES
from apps.metrics.models import (
ExperimentMetric,
MetricDefinition,
@@ -41,6 +42,17 @@ def _validate_calculation_rule(
)
}
)
valid = VALID_RULE_FIELDS.get(metric_type, set())
extra = set(rule.keys()) - valid
if extra:
raise ValidationError(
{
"calculation_rule": (
f"Unknown fields for '{metric_type}': "
f"{', '.join(sorted(extra))}."
)
}
)
@transaction.atomic
@@ -103,6 +115,16 @@ def experiment_metric_add(
metric: MetricDefinition,
is_primary: bool = False,
) -> ExperimentMetric:
if experiment.status in STARTED_STATUSES:
raise ValidationError(
{
"experiment": (
"Metrics cannot be modified after the experiment "
"has been started "
f"(status: '{experiment.status}')."
)
}
)
if is_primary:
experiment.experiment_metrics.filter(is_primary=True).update(
is_primary=False,
@@ -122,6 +144,16 @@ def experiment_metric_remove(
experiment: Any,
metric: MetricDefinition,
) -> None:
if experiment.status in STARTED_STATUSES:
raise ValidationError(
{
"experiment": (
"Metrics cannot be modified after the experiment "
"has been started "
f"(status: '{experiment.status}')."
)
}
)
deleted, _ = ExperimentMetric.objects.filter(
experiment=experiment,
metric=metric,
+2 -104
View File
@@ -9,13 +9,8 @@ from apps.experiments.services import (
experiment_submit_for_review,
)
from apps.experiments.tests.helpers import add_two_variants, make_experiment
from apps.guardrails.models import Guardrail, GuardrailAction
from apps.guardrails.services import (
guardrail_create,
guardrail_delete,
guardrail_list,
guardrail_update,
)
from apps.guardrails.models import GuardrailAction
from apps.guardrails.services import guardrail_create
from apps.metrics.models import ExperimentMetric, MetricDirection, MetricType
from apps.metrics.services import (
experiment_metric_add,
@@ -246,100 +241,3 @@ class ExperimentMetricTest(TestCase):
em1.refresh_from_db()
self.assertFalse(em1.is_primary)
self.assertTrue(em2.is_primary)
class GuardrailServiceTest(TestCase):
def setUp(self) -> None:
self.experiment = make_experiment(suffix="_gr")
self.metric = metric_definition_create(
key="gr_error_rate",
name="Error Rate",
metric_type=MetricType.RATIO,
direction=MetricDirection.LOWER_IS_BETTER,
calculation_rule={
"type": "ratio",
"numerator_event": "error",
"denominator_event": "exposure",
},
)
def test_create_guardrail(self) -> None:
g = guardrail_create(
experiment=self.experiment,
metric=self.metric,
threshold=Decimal("0.05"),
observation_window_minutes=30,
action=GuardrailAction.PAUSE,
)
self.assertEqual(g.threshold, Decimal("0.05"))
self.assertEqual(g.action, GuardrailAction.PAUSE)
def test_list_guardrails(self) -> None:
guardrail_create(
experiment=self.experiment,
metric=self.metric,
threshold=Decimal("0.05"),
)
grs = guardrail_list(self.experiment)
self.assertEqual(grs.count(), 1)
def test_update_guardrail_in_draft(self) -> None:
g = guardrail_create(
experiment=self.experiment,
metric=self.metric,
threshold=Decimal("0.05"),
)
updated = guardrail_update(
guardrail=g,
threshold=Decimal("0.10"),
)
self.assertEqual(updated.threshold, Decimal("0.10"))
def test_reject_update_after_start(self) -> None:
review_settings_update(
default_min_approvals=1, allow_any_approver=True
)
approver = make_approver("_gu")
add_two_variants(self.experiment)
exp = experiment_submit_for_review(
experiment=self.experiment,
user=self.experiment.owner,
)
exp = experiment_approve(experiment=exp, approver=approver)
experiment_start(experiment=exp, user=self.experiment.owner)
g = guardrail_create(
experiment=self.experiment,
metric=self.metric,
threshold=Decimal("0.05"),
)
with self.assertRaises(ValidationError):
guardrail_update(guardrail=g, threshold=Decimal("0.10"))
def test_delete_guardrail_in_draft(self) -> None:
g = guardrail_create(
experiment=self.experiment,
metric=self.metric,
threshold=Decimal("0.05"),
)
guardrail_delete(guardrail=g)
self.assertEqual(Guardrail.objects.count(), 0)
def test_reject_delete_after_start(self) -> None:
review_settings_update(
default_min_approvals=1, allow_any_approver=True
)
approver = make_approver("_gd")
add_two_variants(self.experiment)
exp = experiment_submit_for_review(
experiment=self.experiment,
user=self.experiment.owner,
)
exp = experiment_approve(experiment=exp, approver=approver)
experiment_start(experiment=exp, user=self.experiment.owner)
g = guardrail_create(
experiment=self.experiment,
metric=self.metric,
threshold=Decimal("0.05"),
)
with self.assertRaises(ValidationError):
guardrail_delete(guardrail=g)
+2 -2
View File
@@ -129,7 +129,7 @@ def calculate_metric_value(
if not decision_ids:
return None
metric_type = rule.get("type", metric.metric_type)
metric_type = metric.metric_type
if metric_type == MetricType.RATIO:
numerator = _count_events(
@@ -145,7 +145,7 @@ def calculate_metric_value(
end_date,
)
if denominator == 0:
return Decimal(0)
return None
return Decimal(str(round(numerator / denominator, 6)))
if metric_type == MetricType.COUNT:
@@ -212,6 +212,60 @@ class CalculateMetricValueTest(TestCase):
)
self.assertIsNone(value)
def test_percentile_metric(self) -> None:
metric = metric_definition_create(
key="rpt_p95_latency",
name="P95 Latency",
metric_type=MetricType.PERCENTILE,
calculation_rule={
"type": "percentile",
"event": "page_loaded",
"property": "latency_ms",
"percentile": 95,
},
)
self._create_decision_and_exposure(
"dec_pct1",
"u1",
self.v_treatment,
)
for i, latency in enumerate(range(10, 110, 10)):
self._send_event(
f"evt_pct_{i}",
"page_loaded",
"dec_pct1",
"u1",
properties={"latency_ms": latency},
)
value = calculate_metric_value(
metric=metric,
experiment_id=self.experiment.pk,
variant_id=self.v_treatment.pk,
)
self.assertIsNotNone(value)
self.assertGreaterEqual(value, Decimal("90"))
self.assertLessEqual(value, Decimal("100"))
def test_percentile_no_data_returns_none(self) -> None:
metric = metric_definition_create(
key="rpt_p50_empty",
name="P50 Empty",
metric_type=MetricType.PERCENTILE,
calculation_rule={
"type": "percentile",
"event": "page_loaded",
"property": "latency_ms",
"percentile": 50,
},
)
value = calculate_metric_value(
metric=metric,
experiment_id=self.experiment.pk,
variant_id=self.v_control.pk,
)
self.assertIsNone(value)
class BuildExperimentReportTest(TestCase):
def setUp(self) -> None:
@@ -21,6 +21,7 @@ from apps.metrics.services import (
)
from apps.reviews.services import review_settings_update
from apps.reviews.tests.helpers import make_approver, make_experimenter
from apps.users.tests.helpers import auth_header
class APIContractFlowTest(TestCase):
@@ -35,6 +36,7 @@ class APIContractFlowTest(TestCase):
)
owner = make_experimenter("_api")
self.auth = auth_header(owner)
approver = make_approver("_api")
self.experiment = make_experiment(
@@ -136,6 +138,7 @@ class APIContractFlowTest(TestCase):
"api-1:get_experiment_report",
args=[self.experiment.pk],
),
HTTP_AUTHORIZATION=self.auth,
)
self.assertEqual(report_resp.status_code, 200)
report = report_resp.json()
@@ -227,5 +230,6 @@ class APIContractFlowTest(TestCase):
"api-1:get_experiment_report",
args=[uuid.uuid4()],
),
HTTP_AUTHORIZATION=self.auth,
)
self.assertEqual(resp.status_code, 404)
@@ -36,11 +36,15 @@ def _start_experiment(owner, approver, suffix, traffic=Decimal("100.00")):
traffic_allocation=traffic,
)
add_two_variants(experiment)
experiment = experiment_submit_for_review(
experiment=experiment, user=owner
return experiment
def _submit_and_start(experiment, approver):
exp = experiment_submit_for_review(
experiment=experiment, user=experiment.owner
)
experiment = experiment_approve(experiment=experiment, approver=approver)
return experiment_start(experiment=experiment, user=owner)
exp = experiment_approve(experiment=exp, approver=approver)
return experiment_start(experiment=exp, user=experiment.owner)
class GuardrailPauseIntegrationTest(TestCase):
@@ -82,6 +86,8 @@ class GuardrailPauseIntegrationTest(TestCase):
action=GuardrailAction.PAUSE,
)
self.experiment = _submit_and_start(self.experiment, self.approver)
def test_guardrail_pauses_experiment_on_threshold_breach(self) -> None:
now = timezone.now().isoformat()
cache.clear()
@@ -186,6 +192,8 @@ class GuardrailRollbackIntegrationTest(TestCase):
action=GuardrailAction.ROLLBACK,
)
self.experiment = _submit_and_start(self.experiment, self.approver)
def test_rollback_completes_experiment_with_control_winner(self) -> None:
now = timezone.now().isoformat()
cache.clear()
@@ -255,6 +263,9 @@ class GuardrailCheckAllTest(TestCase):
action=GuardrailAction.PAUSE,
)
self.exp1 = _submit_and_start(self.exp1, self.approver)
self.exp2 = _submit_and_start(self.exp2, self.approver)
make_exposure_type(name="gca_exposure")
make_event_type(name="gca_error", display_name="Error")