fix(): business logic fixes and code refactoring

This commit is contained in:
ITQ
2026-02-24 09:58:07 +03:00
parent e51b74a133
commit 16b48fee40
18 changed files with 307 additions and 140 deletions
+3 -1
View File
@@ -1,9 +1,11 @@
from typing import Any
from ninja import Field, Schema from ninja import Field, Schema
class DecisionIn(Schema): class DecisionIn(Schema):
subject_id: str subject_id: str
subject_attributes: dict[str, str | int | float | bool] = Field( subject_attributes: dict[str, Any] = Field(
default_factory=dict, default_factory=dict,
) )
flags: list[str] = Field( flags: list[str] = Field(
+3 -2
View File
@@ -45,6 +45,7 @@ def create_event_type(
name=payload.name, name=payload.name,
display_name=payload.display_name, display_name=payload.display_name,
description=payload.description, description=payload.description,
is_exposure=payload.is_exposure,
requires_exposure=payload.requires_exposure, requires_exposure=payload.requires_exposure,
required_fields=payload.required_fields, required_fields=payload.required_fields,
) )
@@ -95,7 +96,7 @@ def update_event_type(
et = event_type_get(event_type_id) et = event_type_get(event_type_id)
if not et: if not et:
raise Http404 raise Http404
fields = payload.dict(exclude_unset=True) fields = payload.model_dump(exclude_unset=True)
updated = event_type_update(event_type=et, **fields) updated = event_type_update(event_type=et, **fields)
return HTTPStatus.OK, EventTypeOut.model_validate(updated) return HTTPStatus.OK, EventTypeOut.model_validate(updated)
@@ -115,7 +116,7 @@ def ingest_events(
request: HttpRequest, request: HttpRequest,
payload: EventsBatchIn, payload: EventsBatchIn,
) -> tuple[int, EventsBatchOut]: ) -> tuple[int, EventsBatchOut]:
events_data = [e.dict() for e in payload.events] events_data = [e.model_dump() for e in payload.events]
batch = process_events_batch(events_data) batch = process_events_batch(events_data)
EVENTS_INGESTED.labels(status="accepted").inc(batch.accepted) EVENTS_INGESTED.labels(status="accepted").inc(batch.accepted)
EVENTS_INGESTED.labels(status="duplicate").inc(batch.duplicates) EVENTS_INGESTED.labels(status="duplicate").inc(batch.duplicates)
+1 -1
View File
@@ -153,7 +153,7 @@ def update_guardrail(
raise Http404 from Guardrail.DoesNotExist raise Http404 from Guardrail.DoesNotExist
g = guardrail_update( g = guardrail_update(
guardrail=g, guardrail=g,
**payload.dict(exclude_unset=True), **payload.model_dump(exclude_unset=True),
) )
g = Guardrail.objects.select_related("metric").get(pk=g.pk) g = Guardrail.objects.select_related("metric").get(pk=g.pk)
return HTTPStatus.OK, GuardrailOut.from_guardrail(g) return HTTPStatus.OK, GuardrailOut.from_guardrail(g)
+1 -1
View File
@@ -110,7 +110,7 @@ def update_learning(
learning = learning_update( learning = learning_update(
learning=learning, learning=learning,
user=request.auth, user=request.auth,
**payload.dict(exclude_unset=True), **payload.model_dump(exclude_unset=True),
) )
learning = learning_get(learning.pk) learning = learning_get(learning.pk)
return HTTPStatus.OK, LearningOut.from_learning(learning) return HTTPStatus.OK, LearningOut.from_learning(learning)
+1 -1
View File
@@ -91,7 +91,7 @@ def update_metric(
raise Http404 raise Http404
metric = metric_definition_update( metric = metric_definition_update(
metric=metric, metric=metric,
**payload.dict(exclude_unset=True), **payload.model_dump(exclude_unset=True),
) )
return HTTPStatus.OK, MetricDefinitionOut.model_validate(metric) return HTTPStatus.OK, MetricDefinitionOut.model_validate(metric)
@@ -93,7 +93,7 @@ def update_channel(
raise Http404 raise Http404
ch = channel_update( ch = channel_update(
channel=ch, channel=ch,
**payload.dict(exclude_unset=True), **payload.model_dump(exclude_unset=True),
) )
return HTTPStatus.OK, ChannelOut.model_validate(ch) return HTTPStatus.OK, ChannelOut.model_validate(ch)
@@ -178,7 +178,7 @@ def update_rule(
raise Http404 from NotificationRule.DoesNotExist raise Http404 from NotificationRule.DoesNotExist
r = rule_update( r = rule_update(
rule=r, rule=r,
**payload.dict(exclude_unset=True), **payload.model_dump(exclude_unset=True),
) )
r = NotificationRule.objects.select_related("channel", "experiment").get( r = NotificationRule.objects.select_related("channel", "experiment").get(
pk=r.pk pk=r.pk
+13 -4
View File
@@ -5,11 +5,12 @@ from django.http import Http404, HttpRequest
from django.utils.dateparse import parse_datetime from django.utils.dateparse import parse_datetime
from ninja import Router from ninja import Router
from api.v1.auth.endpoints import jwt_bearer
from api.v1.reports.schemas import ExperimentReportOut from api.v1.reports.schemas import ExperimentReportOut
from apps.experiments.models import Experiment from apps.experiments.models import Experiment
from apps.reports.services import build_experiment_report from apps.reports.services import build_experiment_report
router = Router(tags=["reports"]) router = Router(tags=["reports"], auth=jwt_bearer)
@router.get( @router.get(
@@ -26,10 +27,18 @@ def get_experiment_report(
try: try:
experiment = Experiment.objects.get(pk=experiment_id) experiment = Experiment.objects.get(pk=experiment_id)
except Experiment.DoesNotExist: except Experiment.DoesNotExist:
raise Http404 from Experiment.DoesNotExist raise Http404 from None
parsed_start = parse_datetime(start_date) if start_date else None parsed_start = None
parsed_end = parse_datetime(end_date) if end_date else None parsed_end = None
if start_date:
parsed_start = parse_datetime(start_date)
if parsed_start is None:
raise Http404
if end_date:
parsed_end = parse_datetime(end_date)
if parsed_end is None:
raise Http404
report_data = build_experiment_report( report_data = build_experiment_report(
experiment=experiment, experiment=experiment,
@@ -13,11 +13,14 @@ from apps.metrics.services import (
experiment_metric_add, experiment_metric_add,
metric_definition_create, metric_definition_create,
) )
from apps.users.tests.helpers import auth_header, make_user
class ExperimentReportAPITest(TestCase): class ExperimentReportAPITest(TestCase):
@override @override
def setUp(self) -> None: def setUp(self) -> None:
self.user = make_user(username="rapi_user", email="rapi@test.local")
self.auth = auth_header(self.user)
self.client = Client() self.client = Client()
self.exposure_type = make_exposure_type() self.exposure_type = make_exposure_type()
self.click_type = make_event_type( self.click_type = make_event_type(
@@ -75,14 +78,14 @@ class ExperimentReportAPITest(TestCase):
] ]
) )
def test_get_report_no_auth_required(self) -> None: def test_report_requires_auth(self) -> None:
resp = self.client.get( resp = self.client.get(
reverse( reverse(
"api-1:get_experiment_report", "api-1:get_experiment_report",
args=[self.experiment.pk], args=[self.experiment.pk],
), ),
) )
self.assertEqual(resp.status_code, 200) self.assertEqual(resp.status_code, 401)
def test_report_structure(self) -> None: def test_report_structure(self) -> None:
self._create_data() self._create_data()
@@ -91,6 +94,7 @@ class ExperimentReportAPITest(TestCase):
"api-1:get_experiment_report", "api-1:get_experiment_report",
args=[self.experiment.pk], args=[self.experiment.pk],
), ),
HTTP_AUTHORIZATION=self.auth,
) )
self.assertEqual(resp.status_code, 200) self.assertEqual(resp.status_code, 200)
data = resp.json() data = resp.json()
@@ -113,6 +117,7 @@ class ExperimentReportAPITest(TestCase):
"api-1:get_experiment_report", "api-1:get_experiment_report",
args=[self.experiment.pk], args=[self.experiment.pk],
), ),
HTTP_AUTHORIZATION=self.auth,
) )
data = resp.json() data = resp.json()
treatment = next( treatment = next(
@@ -138,6 +143,7 @@ class ExperimentReportAPITest(TestCase):
args=[self.experiment.pk], args=[self.experiment.pk],
), ),
{"start_date": start, "end_date": end}, {"start_date": start, "end_date": end},
HTTP_AUTHORIZATION=self.auth,
) )
self.assertEqual(resp.status_code, 200) self.assertEqual(resp.status_code, 200)
data = resp.json() data = resp.json()
@@ -150,5 +156,6 @@ class ExperimentReportAPITest(TestCase):
"api-1:get_experiment_report", "api-1:get_experiment_report",
args=[uuid.uuid4()], args=[uuid.uuid4()],
), ),
HTTP_AUTHORIZATION=self.auth,
) )
self.assertEqual(resp.status_code, 404) self.assertEqual(resp.status_code, 404)
+1 -1
View File
@@ -46,7 +46,7 @@ def _hash_subject(subject_id: str, experiment_id: str, salt: str) -> Decimal:
def _select_variant( def _select_variant(
variants: list[Variant], hash_value: float variants: list[Variant], hash_value: Decimal
) -> Variant | None: ) -> Variant | None:
cumulative = Decimal(0) cumulative = Decimal(0)
for variant in sorted(variants, key=lambda v: v.name): for variant in sorted(variants, key=lambda v: v.name):
+1 -1
View File
@@ -42,7 +42,7 @@ def _notify(
NotificationPayload( NotificationPayload(
title=f"{event_type.replace('_', ' ').title()}", title=f"{event_type.replace('_', ' ').title()}",
body=( body=(
f"Experiment '{experiment.name}' " f"Experiment '{experiment.name}' - "
f"{event_type.replace('_', ' ')}." f"{event_type.replace('_', ' ')}."
), ),
event_type=event_type, event_type=event_type,
+60 -12
View File
@@ -4,7 +4,7 @@ from typing import Any
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.db import transaction from django.db import transaction
from django.db.models import QuerySet from django.db.models import Case, QuerySet, Value, When
from django.utils import timezone from django.utils import timezone
from apps.experiments.models import ( from apps.experiments.models import (
@@ -21,23 +21,38 @@ from apps.guardrails.models import (
GuardrailAction, GuardrailAction,
GuardrailTrigger, GuardrailTrigger,
) )
from apps.metrics.models import MetricDefinition from apps.metrics.models import MetricDefinition, MetricDirection
from apps.notifications.services import ( from apps.notifications.services import (
NotificationPayload, NotificationPayload,
notification_enqueue, notification_enqueue,
) )
from apps.reports.services import calculate_metric_value from apps.reports.services import calculate_metric_value
_ACTION_SEVERITY = {
GuardrailAction.ROLLBACK: 0,
GuardrailAction.PAUSE: 1,
}
@transaction.atomic @transaction.atomic
def guardrail_create( def guardrail_create(
*, *,
experiment: Experiment, experiment: Experiment,
metric: MetricDefinition, metric: MetricDefinition,
threshold: Any, threshold: Decimal,
observation_window_minutes: int = 60, observation_window_minutes: int = 60,
action: str = GuardrailAction.PAUSE, action: str = GuardrailAction.PAUSE,
) -> Guardrail: ) -> Guardrail:
if experiment.status in STARTED_STATUSES:
raise ValidationError(
{
"experiment": (
"Guardrails cannot be added after the experiment "
"has been started "
f"(status: '{experiment.status}')."
)
}
)
guardrail = Guardrail( guardrail = Guardrail(
experiment=experiment, experiment=experiment,
metric=metric, metric=metric,
@@ -135,9 +150,22 @@ def _calculate_guardrail_metric(
if not values: if not values:
return None return None
direction = guardrail.metric.direction
if direction == MetricDirection.HIGHER_IS_BETTER:
return min(values)
return max(values) return max(values)
def _is_threshold_breached(
actual_value: Decimal,
threshold: Decimal,
direction: str,
) -> bool:
if direction == MetricDirection.HIGHER_IS_BETTER:
return actual_value < threshold
return actual_value > threshold
@transaction.atomic @transaction.atomic
def _execute_guardrail_action( def _execute_guardrail_action(
guardrail: Guardrail, guardrail: Guardrail,
@@ -146,6 +174,12 @@ def _execute_guardrail_action(
) -> GuardrailTrigger: ) -> GuardrailTrigger:
now = timezone.now() now = timezone.now()
experiment = Experiment.objects.select_for_update().get(pk=experiment.pk)
if experiment.status != ExperimentStatus.RUNNING:
return None
from_status = experiment.status
trigger = GuardrailTrigger.objects.create( trigger = GuardrailTrigger.objects.create(
guardrail=guardrail, guardrail=guardrail,
experiment=experiment, experiment=experiment,
@@ -175,7 +209,7 @@ def _execute_guardrail_action(
"threshold": str(guardrail.threshold), "threshold": str(guardrail.threshold),
"actual_value": str(actual_value), "actual_value": str(actual_value),
"action": guardrail.action, "action": guardrail.action,
"from_status": ExperimentStatus.RUNNING, "from_status": from_status,
"to_status": ExperimentStatus.PAUSED, "to_status": ExperimentStatus.PAUSED,
}, },
) )
@@ -213,7 +247,7 @@ def _execute_guardrail_action(
"threshold": str(guardrail.threshold), "threshold": str(guardrail.threshold),
"actual_value": str(actual_value), "actual_value": str(actual_value),
"action": guardrail.action, "action": guardrail.action,
"from_status": ExperimentStatus.RUNNING, "from_status": from_status,
"to_status": ExperimentStatus.COMPLETED, "to_status": ExperimentStatus.COMPLETED,
"control_variant_id": ( "control_variant_id": (
str(control_variant.pk) if control_variant else None str(control_variant.pk) if control_variant else None
@@ -252,10 +286,21 @@ def check_experiment_guardrails(
if experiment.status != ExperimentStatus.RUNNING: if experiment.status != ExperimentStatus.RUNNING:
return [] return []
guardrails = Guardrail.objects.filter( guardrails = (
Guardrail.objects.filter(
experiment=experiment, experiment=experiment,
is_active=True, is_active=True,
).select_related("metric") )
.select_related("metric")
.annotate(
action_order=Case(
When(action=GuardrailAction.ROLLBACK, then=Value(0)),
When(action=GuardrailAction.PAUSE, then=Value(1)),
default=Value(2),
),
)
.order_by("action_order", "-created_at")
)
triggers: list[GuardrailTrigger] = [] triggers: list[GuardrailTrigger] = []
@@ -264,16 +309,19 @@ def check_experiment_guardrails(
if actual_value is None: if actual_value is None:
continue continue
if actual_value > guardrail.threshold: if _is_threshold_breached(
experiment.refresh_from_db() actual_value,
if experiment.status != ExperimentStatus.RUNNING: guardrail.threshold,
break guardrail.metric.direction,
):
trigger = _execute_guardrail_action( trigger = _execute_guardrail_action(
guardrail, guardrail,
experiment, experiment,
actual_value, actual_value,
) )
if trigger is None:
break
triggers.append(trigger) triggers.append(trigger)
if guardrail.action in { if guardrail.action in {
@@ -1,5 +1,6 @@
from decimal import Decimal from decimal import Decimal
from django.core.exceptions import ValidationError
from django.test import TestCase from django.test import TestCase
from django.utils import timezone from django.utils import timezone
@@ -27,6 +28,9 @@ from apps.guardrails.services import (
check_all_running_experiments, check_all_running_experiments,
check_experiment_guardrails, check_experiment_guardrails,
guardrail_create, guardrail_create,
guardrail_delete,
guardrail_list,
guardrail_update,
) )
from apps.metrics.models import MetricDirection, MetricType from apps.metrics.models import MetricDirection, MetricType
from apps.metrics.services import metric_definition_create from apps.metrics.services import metric_definition_create
@@ -482,3 +486,100 @@ class CheckAllRunningTest(TestCase):
results = check_all_running_experiments() results = check_all_running_experiments()
self.assertEqual(results["triggered"], 1) self.assertEqual(results["triggered"], 1)
self.assertGreater(len(results["triggers"]), 0) self.assertGreater(len(results["triggers"]), 0)
class GuardrailServiceTest(TestCase):
def setUp(self) -> None:
self.experiment = make_experiment(suffix="_gr")
self.metric = metric_definition_create(
key="gr_error_rate",
name="Error Rate",
metric_type=MetricType.RATIO,
direction=MetricDirection.LOWER_IS_BETTER,
calculation_rule={
"type": "ratio",
"numerator_event": "error",
"denominator_event": "exposure",
},
)
def test_create_guardrail(self) -> None:
g = guardrail_create(
experiment=self.experiment,
metric=self.metric,
threshold=Decimal("0.05"),
observation_window_minutes=30,
action=GuardrailAction.PAUSE,
)
self.assertEqual(g.threshold, Decimal("0.05"))
self.assertEqual(g.action, GuardrailAction.PAUSE)
def test_list_guardrails(self) -> None:
guardrail_create(
experiment=self.experiment,
metric=self.metric,
threshold=Decimal("0.05"),
)
grs = guardrail_list(self.experiment)
self.assertEqual(grs.count(), 1)
def test_update_guardrail_in_draft(self) -> None:
g = guardrail_create(
experiment=self.experiment,
metric=self.metric,
threshold=Decimal("0.05"),
)
updated = guardrail_update(
guardrail=g,
threshold=Decimal("0.10"),
)
self.assertEqual(updated.threshold, Decimal("0.10"))
def test_reject_update_after_start(self) -> None:
review_settings_update(
default_min_approvals=1, allow_any_approver=True
)
approver = make_approver("_gu")
add_two_variants(self.experiment)
g = guardrail_create(
experiment=self.experiment,
metric=self.metric,
threshold=Decimal("0.05"),
)
exp = experiment_submit_for_review(
experiment=self.experiment,
user=self.experiment.owner,
)
exp = experiment_approve(experiment=exp, approver=approver)
experiment_start(experiment=exp, user=self.experiment.owner)
with self.assertRaises(ValidationError):
guardrail_update(guardrail=g, threshold=Decimal("0.10"))
def test_delete_guardrail_in_draft(self) -> None:
g = guardrail_create(
experiment=self.experiment,
metric=self.metric,
threshold=Decimal("0.05"),
)
guardrail_delete(guardrail=g)
self.assertEqual(Guardrail.objects.count(), 0)
def test_reject_delete_after_start(self) -> None:
review_settings_update(
default_min_approvals=1, allow_any_approver=True
)
approver = make_approver("_gd")
add_two_variants(self.experiment)
g = guardrail_create(
experiment=self.experiment,
metric=self.metric,
threshold=Decimal("0.05"),
)
exp = experiment_submit_for_review(
experiment=self.experiment,
user=self.experiment.owner,
)
exp = experiment_approve(experiment=exp, approver=approver)
experiment_start(experiment=exp, user=self.experiment.owner)
with self.assertRaises(ValidationError):
guardrail_delete(guardrail=g)
+32
View File
@@ -5,6 +5,7 @@ from django.core.exceptions import ValidationError
from django.db import transaction from django.db import transaction
from django.db.models import QuerySet from django.db.models import QuerySet
from apps.experiments.models import STARTED_STATUSES
from apps.metrics.models import ( from apps.metrics.models import (
ExperimentMetric, ExperimentMetric,
MetricDefinition, MetricDefinition,
@@ -41,6 +42,17 @@ def _validate_calculation_rule(
) )
} }
) )
valid = VALID_RULE_FIELDS.get(metric_type, set())
extra = set(rule.keys()) - valid
if extra:
raise ValidationError(
{
"calculation_rule": (
f"Unknown fields for '{metric_type}': "
f"{', '.join(sorted(extra))}."
)
}
)
@transaction.atomic @transaction.atomic
@@ -103,6 +115,16 @@ def experiment_metric_add(
metric: MetricDefinition, metric: MetricDefinition,
is_primary: bool = False, is_primary: bool = False,
) -> ExperimentMetric: ) -> ExperimentMetric:
if experiment.status in STARTED_STATUSES:
raise ValidationError(
{
"experiment": (
"Metrics cannot be modified after the experiment "
"has been started "
f"(status: '{experiment.status}')."
)
}
)
if is_primary: if is_primary:
experiment.experiment_metrics.filter(is_primary=True).update( experiment.experiment_metrics.filter(is_primary=True).update(
is_primary=False, is_primary=False,
@@ -122,6 +144,16 @@ def experiment_metric_remove(
experiment: Any, experiment: Any,
metric: MetricDefinition, metric: MetricDefinition,
) -> None: ) -> None:
if experiment.status in STARTED_STATUSES:
raise ValidationError(
{
"experiment": (
"Metrics cannot be modified after the experiment "
"has been started "
f"(status: '{experiment.status}')."
)
}
)
deleted, _ = ExperimentMetric.objects.filter( deleted, _ = ExperimentMetric.objects.filter(
experiment=experiment, experiment=experiment,
metric=metric, metric=metric,
+2 -104
View File
@@ -9,13 +9,8 @@ from apps.experiments.services import (
experiment_submit_for_review, experiment_submit_for_review,
) )
from apps.experiments.tests.helpers import add_two_variants, make_experiment from apps.experiments.tests.helpers import add_two_variants, make_experiment
from apps.guardrails.models import Guardrail, GuardrailAction from apps.guardrails.models import GuardrailAction
from apps.guardrails.services import ( from apps.guardrails.services import guardrail_create
guardrail_create,
guardrail_delete,
guardrail_list,
guardrail_update,
)
from apps.metrics.models import ExperimentMetric, MetricDirection, MetricType from apps.metrics.models import ExperimentMetric, MetricDirection, MetricType
from apps.metrics.services import ( from apps.metrics.services import (
experiment_metric_add, experiment_metric_add,
@@ -246,100 +241,3 @@ class ExperimentMetricTest(TestCase):
em1.refresh_from_db() em1.refresh_from_db()
self.assertFalse(em1.is_primary) self.assertFalse(em1.is_primary)
self.assertTrue(em2.is_primary) self.assertTrue(em2.is_primary)
class GuardrailServiceTest(TestCase):
def setUp(self) -> None:
self.experiment = make_experiment(suffix="_gr")
self.metric = metric_definition_create(
key="gr_error_rate",
name="Error Rate",
metric_type=MetricType.RATIO,
direction=MetricDirection.LOWER_IS_BETTER,
calculation_rule={
"type": "ratio",
"numerator_event": "error",
"denominator_event": "exposure",
},
)
def test_create_guardrail(self) -> None:
g = guardrail_create(
experiment=self.experiment,
metric=self.metric,
threshold=Decimal("0.05"),
observation_window_minutes=30,
action=GuardrailAction.PAUSE,
)
self.assertEqual(g.threshold, Decimal("0.05"))
self.assertEqual(g.action, GuardrailAction.PAUSE)
def test_list_guardrails(self) -> None:
guardrail_create(
experiment=self.experiment,
metric=self.metric,
threshold=Decimal("0.05"),
)
grs = guardrail_list(self.experiment)
self.assertEqual(grs.count(), 1)
def test_update_guardrail_in_draft(self) -> None:
g = guardrail_create(
experiment=self.experiment,
metric=self.metric,
threshold=Decimal("0.05"),
)
updated = guardrail_update(
guardrail=g,
threshold=Decimal("0.10"),
)
self.assertEqual(updated.threshold, Decimal("0.10"))
def test_reject_update_after_start(self) -> None:
review_settings_update(
default_min_approvals=1, allow_any_approver=True
)
approver = make_approver("_gu")
add_two_variants(self.experiment)
exp = experiment_submit_for_review(
experiment=self.experiment,
user=self.experiment.owner,
)
exp = experiment_approve(experiment=exp, approver=approver)
experiment_start(experiment=exp, user=self.experiment.owner)
g = guardrail_create(
experiment=self.experiment,
metric=self.metric,
threshold=Decimal("0.05"),
)
with self.assertRaises(ValidationError):
guardrail_update(guardrail=g, threshold=Decimal("0.10"))
def test_delete_guardrail_in_draft(self) -> None:
g = guardrail_create(
experiment=self.experiment,
metric=self.metric,
threshold=Decimal("0.05"),
)
guardrail_delete(guardrail=g)
self.assertEqual(Guardrail.objects.count(), 0)
def test_reject_delete_after_start(self) -> None:
review_settings_update(
default_min_approvals=1, allow_any_approver=True
)
approver = make_approver("_gd")
add_two_variants(self.experiment)
exp = experiment_submit_for_review(
experiment=self.experiment,
user=self.experiment.owner,
)
exp = experiment_approve(experiment=exp, approver=approver)
experiment_start(experiment=exp, user=self.experiment.owner)
g = guardrail_create(
experiment=self.experiment,
metric=self.metric,
threshold=Decimal("0.05"),
)
with self.assertRaises(ValidationError):
guardrail_delete(guardrail=g)
+2 -2
View File
@@ -129,7 +129,7 @@ def calculate_metric_value(
if not decision_ids: if not decision_ids:
return None return None
metric_type = rule.get("type", metric.metric_type) metric_type = metric.metric_type
if metric_type == MetricType.RATIO: if metric_type == MetricType.RATIO:
numerator = _count_events( numerator = _count_events(
@@ -145,7 +145,7 @@ def calculate_metric_value(
end_date, end_date,
) )
if denominator == 0: if denominator == 0:
return Decimal(0) return None
return Decimal(str(round(numerator / denominator, 6))) return Decimal(str(round(numerator / denominator, 6)))
if metric_type == MetricType.COUNT: if metric_type == MetricType.COUNT:
@@ -212,6 +212,60 @@ class CalculateMetricValueTest(TestCase):
) )
self.assertIsNone(value) self.assertIsNone(value)
def test_percentile_metric(self) -> None:
metric = metric_definition_create(
key="rpt_p95_latency",
name="P95 Latency",
metric_type=MetricType.PERCENTILE,
calculation_rule={
"type": "percentile",
"event": "page_loaded",
"property": "latency_ms",
"percentile": 95,
},
)
self._create_decision_and_exposure(
"dec_pct1",
"u1",
self.v_treatment,
)
for i, latency in enumerate(range(10, 110, 10)):
self._send_event(
f"evt_pct_{i}",
"page_loaded",
"dec_pct1",
"u1",
properties={"latency_ms": latency},
)
value = calculate_metric_value(
metric=metric,
experiment_id=self.experiment.pk,
variant_id=self.v_treatment.pk,
)
self.assertIsNotNone(value)
self.assertGreaterEqual(value, Decimal("90"))
self.assertLessEqual(value, Decimal("100"))
def test_percentile_no_data_returns_none(self) -> None:
metric = metric_definition_create(
key="rpt_p50_empty",
name="P50 Empty",
metric_type=MetricType.PERCENTILE,
calculation_rule={
"type": "percentile",
"event": "page_loaded",
"property": "latency_ms",
"percentile": 50,
},
)
value = calculate_metric_value(
metric=metric,
experiment_id=self.experiment.pk,
variant_id=self.v_control.pk,
)
self.assertIsNone(value)
class BuildExperimentReportTest(TestCase): class BuildExperimentReportTest(TestCase):
def setUp(self) -> None: def setUp(self) -> None:
@@ -21,6 +21,7 @@ from apps.metrics.services import (
) )
from apps.reviews.services import review_settings_update from apps.reviews.services import review_settings_update
from apps.reviews.tests.helpers import make_approver, make_experimenter from apps.reviews.tests.helpers import make_approver, make_experimenter
from apps.users.tests.helpers import auth_header
class APIContractFlowTest(TestCase): class APIContractFlowTest(TestCase):
@@ -35,6 +36,7 @@ class APIContractFlowTest(TestCase):
) )
owner = make_experimenter("_api") owner = make_experimenter("_api")
self.auth = auth_header(owner)
approver = make_approver("_api") approver = make_approver("_api")
self.experiment = make_experiment( self.experiment = make_experiment(
@@ -136,6 +138,7 @@ class APIContractFlowTest(TestCase):
"api-1:get_experiment_report", "api-1:get_experiment_report",
args=[self.experiment.pk], args=[self.experiment.pk],
), ),
HTTP_AUTHORIZATION=self.auth,
) )
self.assertEqual(report_resp.status_code, 200) self.assertEqual(report_resp.status_code, 200)
report = report_resp.json() report = report_resp.json()
@@ -227,5 +230,6 @@ class APIContractFlowTest(TestCase):
"api-1:get_experiment_report", "api-1:get_experiment_report",
args=[uuid.uuid4()], args=[uuid.uuid4()],
), ),
HTTP_AUTHORIZATION=self.auth,
) )
self.assertEqual(resp.status_code, 404) self.assertEqual(resp.status_code, 404)
@@ -36,11 +36,15 @@ def _start_experiment(owner, approver, suffix, traffic=Decimal("100.00")):
traffic_allocation=traffic, traffic_allocation=traffic,
) )
add_two_variants(experiment) add_two_variants(experiment)
experiment = experiment_submit_for_review( return experiment
experiment=experiment, user=owner
def _submit_and_start(experiment, approver):
exp = experiment_submit_for_review(
experiment=experiment, user=experiment.owner
) )
experiment = experiment_approve(experiment=experiment, approver=approver) exp = experiment_approve(experiment=exp, approver=approver)
return experiment_start(experiment=experiment, user=owner) return experiment_start(experiment=exp, user=experiment.owner)
class GuardrailPauseIntegrationTest(TestCase): class GuardrailPauseIntegrationTest(TestCase):
@@ -82,6 +86,8 @@ class GuardrailPauseIntegrationTest(TestCase):
action=GuardrailAction.PAUSE, action=GuardrailAction.PAUSE,
) )
self.experiment = _submit_and_start(self.experiment, self.approver)
def test_guardrail_pauses_experiment_on_threshold_breach(self) -> None: def test_guardrail_pauses_experiment_on_threshold_breach(self) -> None:
now = timezone.now().isoformat() now = timezone.now().isoformat()
cache.clear() cache.clear()
@@ -186,6 +192,8 @@ class GuardrailRollbackIntegrationTest(TestCase):
action=GuardrailAction.ROLLBACK, action=GuardrailAction.ROLLBACK,
) )
self.experiment = _submit_and_start(self.experiment, self.approver)
def test_rollback_completes_experiment_with_control_winner(self) -> None: def test_rollback_completes_experiment_with_control_winner(self) -> None:
now = timezone.now().isoformat() now = timezone.now().isoformat()
cache.clear() cache.clear()
@@ -255,6 +263,9 @@ class GuardrailCheckAllTest(TestCase):
action=GuardrailAction.PAUSE, action=GuardrailAction.PAUSE,
) )
self.exp1 = _submit_and_start(self.exp1, self.approver)
self.exp2 = _submit_and_start(self.exp2, self.approver)
make_exposure_type(name="gca_exposure") make_exposure_type(name="gca_exposure")
make_event_type(name="gca_error", display_name="Error") make_event_type(name="gca_error", display_name="Error")