From cdf104af8ed5b40f1fc0099e0d01542a5199eafb Mon Sep 17 00:00:00 2001
From: ITQ
Date: Mon, 23 Feb 2026 10:56:10 +0300
Subject: [PATCH] feat(guardrails): added guardrails business logic

---
Reviewer notes (this section is ignored when the patch is applied):
- services.py: check_all_running_experiments isolates per-experiment
  failures and reports them under "errors"; audit logs record the real
  from_status; the duplicated "experiment started" validation moved to a
  helper; docstrings added.
- The blob ids on the index lines of services.py, tasks.py and
  tests/test_guardrails.py are stale because the patch was edited by hand.

 src/backend/apps/guardrails/__init__.py       |   0
 src/backend/apps/guardrails/apps.py           |   5 +
 .../guardrails/migrations/0001_initial.py     |  65 +++
 .../apps/guardrails/migrations/__init__.py    |   0
 src/backend/apps/guardrails/models.py         | 131 +++++
 src/backend/apps/guardrails/services.py       | 385 ++++++++++++++
 src/backend/apps/guardrails/tasks.py          |  22 +
 src/backend/apps/guardrails/tests/__init__.py |   0
 .../apps/guardrails/tests/test_guardrails.py  | 501 ++++++++++++++++++
 9 files changed, 1109 insertions(+)
 create mode 100644 src/backend/apps/guardrails/__init__.py
 create mode 100644 src/backend/apps/guardrails/apps.py
 create mode 100644 src/backend/apps/guardrails/migrations/0001_initial.py
 create mode 100644 src/backend/apps/guardrails/migrations/__init__.py
 create mode 100644 src/backend/apps/guardrails/models.py
 create mode 100644 src/backend/apps/guardrails/services.py
 create mode 100644 src/backend/apps/guardrails/tasks.py
 create mode 100644 src/backend/apps/guardrails/tests/__init__.py
 create mode 100644 src/backend/apps/guardrails/tests/test_guardrails.py

diff --git a/src/backend/apps/guardrails/__init__.py b/src/backend/apps/guardrails/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/backend/apps/guardrails/apps.py b/src/backend/apps/guardrails/apps.py
new file mode 100644
index 0000000..353ad86
--- /dev/null
+++ b/src/backend/apps/guardrails/apps.py
@@ -0,0 +1,5 @@
+from django.apps import AppConfig
+
+
+class GuardrailsConfig(AppConfig):
+    name = "apps.guardrails"
diff --git a/src/backend/apps/guardrails/migrations/0001_initial.py b/src/backend/apps/guardrails/migrations/0001_initial.py
new file mode 100644
index 0000000..4ce0312
--- /dev/null
+++ b/src/backend/apps/guardrails/migrations/0001_initial.py
@@ -0,0 +1,65 @@
+# Generated by Django 5.2.11 on 2026-02-14 09:55
+
+import django.db.models.deletion
+import uuid
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    initial = True
+
+    dependencies = [
+        ('experiments', '0001_initial'),
+        ('metrics', '0001_initial'),
+    ]
+
+    operations = [
+        migrations.CreateModel(
+            name='Guardrail',
+            fields=[
+                ('id', models.UUIDField(default=uuid.uuid4, editable=False, primary_key=True, serialize=False)),
+                ('threshold', models.DecimalField(decimal_places=4, max_digits=10, verbose_name='threshold')),
+                ('observation_window_minutes', models.PositiveIntegerField(default=60, verbose_name='observation window (minutes)')),
+                ('action', models.CharField(choices=[('pause', 'Pause experiment'), ('rollback', 'Rollback to control')], max_length=20, verbose_name='action on trigger')),
+                ('is_active', models.BooleanField(default=True, verbose_name='is active')),
+                ('created_at', models.DateTimeField(auto_now_add=True, verbose_name='created at')),
+                ('updated_at', models.DateTimeField(auto_now=True, verbose_name='updated at')),
+                ('experiment', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='guardrails', to='experiments.experiment', verbose_name='experiment')),
+                ('metric', models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, related_name='guardrail_usages', to='metrics.metricdefinition', verbose_name='metric')),
+            ],
+            options={
+                'verbose_name': 'guardrail',
+                'verbose_name_plural': 'guardrails',
+                'ordering': ['-created_at'],
+            },
+        ),
+        migrations.CreateModel(
+            name='GuardrailTrigger',
+            fields=[
+                ('id', models.UUIDField(default=uuid.uuid4, editable=False, primary_key=True, serialize=False)),
+                ('metric_key', models.CharField(max_length=100, verbose_name='metric key')),
+                ('threshold', models.DecimalField(decimal_places=4, max_digits=10, verbose_name='threshold')),
+                ('actual_value', models.DecimalField(decimal_places=4, max_digits=10, verbose_name='actual value')),
+                ('observation_window_minutes', models.PositiveIntegerField(verbose_name='observation window (minutes)')),
+                ('action', models.CharField(max_length=20, verbose_name='action taken')),
+                ('triggered_at', models.DateTimeField(verbose_name='triggered at')),
+                ('created_at', models.DateTimeField(auto_now_add=True, verbose_name='created at')),
+                ('experiment', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='guardrail_triggers', to='experiments.experiment', verbose_name='experiment')),
+                ('guardrail', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='triggers', to='guardrails.guardrail', verbose_name='guardrail')),
+            ],
+            options={
+                'verbose_name': 'guardrail trigger',
+                'verbose_name_plural': 'guardrail triggers',
+                'ordering': ['-triggered_at'],
+            },
+        ),
+        migrations.AddIndex(
+            model_name='guardrail',
+            index=models.Index(fields=['experiment', 'is_active'], name='idx_guardrail_exp_active'),
+        ),
+        migrations.AddIndex(
+            model_name='guardrailtrigger',
+            index=models.Index(fields=['experiment', '-triggered_at'], name='idx_trigger_exp_time'),
+        ),
+    ]
diff --git a/src/backend/apps/guardrails/migrations/__init__.py b/src/backend/apps/guardrails/migrations/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/backend/apps/guardrails/models.py b/src/backend/apps/guardrails/models.py
new file mode 100644
index 0000000..8c10557
--- /dev/null
+++ b/src/backend/apps/guardrails/models.py
@@ -0,0 +1,131 @@
+from typing import override
+
+from django.db import models
+from django.utils.translation import gettext_lazy as _
+
+from apps.core.models import BaseModel
+
+
+class GuardrailAction(models.TextChoices):
+    PAUSE = "pause", _("Pause experiment")
+    ROLLBACK = "rollback", _("Rollback to control")
+
+
+class Guardrail(BaseModel):
+    experiment = models.ForeignKey(
+        "experiments.Experiment",
+        on_delete=models.CASCADE,
+        related_name="guardrails",
+        verbose_name=_("experiment"),
+    )
+    metric = models.ForeignKey(
+        "metrics.MetricDefinition",
+        on_delete=models.PROTECT,
+        related_name="guardrail_usages",
+        verbose_name=_("metric"),
+    )
+    threshold = models.DecimalField(
+        max_digits=10,
+        decimal_places=4,
+        verbose_name=_("threshold"),
+    )
+    observation_window_minutes = models.PositiveIntegerField(
+        default=60,
+        verbose_name=_("observation window (minutes)"),
+    )
+    action = models.CharField(
+        max_length=20,
+        choices=GuardrailAction.choices,
+        verbose_name=_("action on trigger"),
+    )
+    is_active = models.BooleanField(
+        default=True,
+        verbose_name=_("is active"),
+    )
+    created_at = models.DateTimeField(
+        auto_now_add=True,
+        verbose_name=_("created at"),
+    )
+    updated_at = models.DateTimeField(
+        auto_now=True,
+        verbose_name=_("updated at"),
+    )
+
+    class Meta:
+        verbose_name = _("guardrail")
+        verbose_name_plural = _("guardrails")
+        ordering = ["-created_at"]
+        indexes = [
+            models.Index(
+                fields=["experiment", "is_active"],
+                name="idx_guardrail_exp_active",
+            ),
+        ]
+
+    @override
+    def __str__(self) -> str:
+        return (
+            f"Guardrail({self.metric.key} > {self.threshold}, "
+            f"action={self.action})"
+        )
+
+
+class GuardrailTrigger(BaseModel):
+    guardrail = models.ForeignKey(
+        Guardrail,
+        on_delete=models.CASCADE,
+        related_name="triggers",
+        verbose_name=_("guardrail"),
+    )
+    experiment = models.ForeignKey(
+        "experiments.Experiment",
+        on_delete=models.CASCADE,
+        related_name="guardrail_triggers",
+        verbose_name=_("experiment"),
+    )
+    metric_key = models.CharField(
+        max_length=100,
+        verbose_name=_("metric key"),
+    )
+    threshold = models.DecimalField(
+        max_digits=10,
+        decimal_places=4,
+        verbose_name=_("threshold"),
+    )
+    actual_value = models.DecimalField(
+        max_digits=10,
+        decimal_places=4,
+        verbose_name=_("actual value"),
+    )
+    observation_window_minutes = models.PositiveIntegerField(
+        verbose_name=_("observation window (minutes)"),
+    )
+    action = models.CharField(
+        max_length=20,
+        verbose_name=_("action taken"),
+    )
+    triggered_at = models.DateTimeField(
+        verbose_name=_("triggered at"),
+    )
+    created_at = models.DateTimeField(
+        auto_now_add=True,
+        verbose_name=_("created at"),
+    )
+
+    class Meta:
+        verbose_name = _("guardrail trigger")
+        verbose_name_plural = _("guardrail triggers")
+        ordering = ["-triggered_at"]
+        indexes = [
+            models.Index(
+                fields=["experiment", "-triggered_at"],
+                name="idx_trigger_exp_time",
+            ),
+        ]
+
+    @override
+    def __str__(self) -> str:
+        return (
+            f"Trigger({self.metric_key}: "
+            f"{self.actual_value} > {self.threshold})"
+        )
diff --git a/src/backend/apps/guardrails/services.py b/src/backend/apps/guardrails/services.py
new file mode 100644
index 0000000..1f52113
--- /dev/null
+++ b/src/backend/apps/guardrails/services.py
@@ -0,0 +1,385 @@
+import logging
+from datetime import timedelta
+from decimal import Decimal
+from typing import Any
+
+from django.core.exceptions import ValidationError
+from django.db import transaction
+from django.db.models import QuerySet
+from django.utils import timezone
+
+from apps.experiments.models import (
+    STARTED_STATUSES,
+    Experiment,
+    ExperimentLog,
+    ExperimentOutcome,
+    ExperimentStatus,
+    LogType,
+    OutcomeType,
+)
+from apps.guardrails.models import (
+    Guardrail,
+    GuardrailAction,
+    GuardrailTrigger,
+)
+from apps.metrics.models import MetricDefinition
+from apps.notifications.services import (
+    NotificationPayload,
+    notification_enqueue,
+)
+from apps.reports.services import calculate_metric_value
+
+logger = logging.getLogger("lotty")
+
+
+def _ensure_experiment_not_started(guardrail: Guardrail, verb: str) -> None:
+    """Raise ``ValidationError`` if the guardrail's experiment has started.
+
+    ``verb`` is the past participle used in the message ("modified",
+    "deleted").
+    """
+    status = guardrail.experiment.status
+    if status in STARTED_STATUSES:
+        raise ValidationError(
+            {
+                "experiment": (
+                    f"Guardrails cannot be {verb} after the experiment "
+                    "has been started "
+                    f"(status: '{status}')."
+                )
+            }
+        )
+
+
+@transaction.atomic
+def guardrail_create(
+    *,
+    experiment: Experiment,
+    metric: MetricDefinition,
+    threshold: Any,
+    observation_window_minutes: int = 60,
+    action: str = GuardrailAction.PAUSE,
+) -> Guardrail:
+    """Create and save a guardrail watching ``metric`` on ``experiment``."""
+    guardrail = Guardrail(
+        experiment=experiment,
+        metric=metric,
+        threshold=threshold,
+        observation_window_minutes=observation_window_minutes,
+        action=action,
+    )
+    guardrail.save()
+    return guardrail
+
+
+def guardrail_update(
+    *,
+    guardrail: Guardrail,
+    **fields: Any,
+) -> Guardrail:
+    """Update the editable fields of ``guardrail`` and save it.
+
+    Only ``threshold``, ``observation_window_minutes``, ``action`` and
+    ``is_active`` are accepted; a value of ``None`` leaves the field as is.
+
+    Raises:
+        ValidationError: if the experiment has already been started or an
+            unsupported field name is passed.
+    """
+    _ensure_experiment_not_started(guardrail, "modified")
+    allowed = {
+        "threshold",
+        "observation_window_minutes",
+        "action",
+        "is_active",
+    }
+    for key in fields:
+        if key not in allowed:
+            raise ValidationError({key: f"Field '{key}' cannot be updated."})
+    for key, value in fields.items():
+        if value is not None:
+            setattr(guardrail, key, value)
+    guardrail.save()
+    return guardrail
+
+
+def guardrail_delete(*, guardrail: Guardrail) -> None:
+    """Delete ``guardrail`` unless its experiment has already been started.
+
+    Raises:
+        ValidationError: if the experiment has already been started.
+    """
+    _ensure_experiment_not_started(guardrail, "deleted")
+    guardrail.delete()
+
+
+def guardrail_list(
+    experiment: Experiment,
+) -> QuerySet[Guardrail]:
+    """Return the experiment's guardrails with their metric preloaded."""
+    return Guardrail.objects.filter(experiment=experiment).select_related(
+        "metric"
+    )
+
+
+def guardrail_trigger_list(
+    experiment: Experiment,
+) -> QuerySet[GuardrailTrigger]:
+    """Return the experiment's triggers with guardrail and metric loaded."""
+    return GuardrailTrigger.objects.filter(
+        experiment=experiment
+    ).select_related("guardrail", "guardrail__metric")
+
+
+def _calculate_guardrail_metric(
+    guardrail: Guardrail,
+    experiment: Experiment,
+) -> Decimal | None:
+    """Return the highest per-variant metric value in the guardrail window.
+
+    The metric is computed for every variant of ``experiment`` over the last
+    ``observation_window_minutes``; ``None`` is returned when no variant
+    yields a value.
+    """
+    now = timezone.now()
+    window_start = now - timedelta(
+        minutes=guardrail.observation_window_minutes,
+    )
+
+    variants = experiment.variants.all()
+    values: list[Decimal] = []
+    for variant in variants:
+        val = calculate_metric_value(
+            metric=guardrail.metric,
+            experiment_id=experiment.pk,
+            variant_id=variant.pk,
+            start_date=window_start,
+            end_date=now,
+        )
+        if val is not None:
+            values.append(val)
+
+    if not values:
+        return None
+
+    return max(values)
+
+
+@transaction.atomic
+def _execute_guardrail_action(
+    guardrail: Guardrail,
+    experiment: Experiment,
+    actual_value: Decimal,
+) -> GuardrailTrigger:
+    """Record a trigger and apply the guardrail's action in one transaction.
+
+    ``pause`` moves the experiment to PAUSED; ``rollback`` moves it to
+    COMPLETED and stores a ROLLBACK outcome naming the control variant.
+    Both write an audit log entry, and a notification is enqueued for every
+    trigger.
+    """
+    now = timezone.now()
+    # Read the status before it is overwritten so the audit log records the
+    # transition that really happened instead of assuming RUNNING.
+    from_status = experiment.status
+
+    trigger = GuardrailTrigger.objects.create(
+        guardrail=guardrail,
+        experiment=experiment,
+        metric_key=guardrail.metric.key,
+        threshold=guardrail.threshold,
+        actual_value=actual_value,
+        observation_window_minutes=guardrail.observation_window_minutes,
+        action=guardrail.action,
+        triggered_at=now,
+    )
+
+    if guardrail.action == GuardrailAction.PAUSE:
+        experiment.status = ExperimentStatus.PAUSED
+        experiment.save(update_fields=["status", "updated_at"])
+
+        ExperimentLog.objects.create(
+            experiment=experiment,
+            log_type=LogType.GUARDRAIL_TRIGGERED,
+            comment=(
+                f"Guardrail triggered: {guardrail.metric.key} "
+                f"= {actual_value} (threshold: {guardrail.threshold}). "
+                f"Experiment paused."
+            ),
+            metadata={
+                "guardrail_id": str(guardrail.pk),
+                "metric_key": guardrail.metric.key,
+                "threshold": str(guardrail.threshold),
+                "actual_value": str(actual_value),
+                "action": guardrail.action,
+                "from_status": from_status,
+                "to_status": ExperimentStatus.PAUSED,
+            },
+        )
+
+    elif guardrail.action == GuardrailAction.ROLLBACK:
+        control_variant = experiment.variants.filter(
+            is_control=True,
+        ).first()
+
+        experiment.status = ExperimentStatus.COMPLETED
+        experiment.save(update_fields=["status", "updated_at"])
+
+        ExperimentOutcome.objects.create(
+            experiment=experiment,
+            outcome=OutcomeType.ROLLBACK,
+            winning_variant=control_variant,
+            rationale=(
+                f"Automatic rollback: guardrail {guardrail.metric.key} "
+                f"= {actual_value} exceeded threshold {guardrail.threshold}."
+            ),
+            decided_by=None,
+        )
+
+        ExperimentLog.objects.create(
+            experiment=experiment,
+            log_type=LogType.GUARDRAIL_TRIGGERED,
+            comment=(
+                f"Guardrail triggered: {guardrail.metric.key} "
+                f"= {actual_value} (threshold: {guardrail.threshold}). "
+                f"Experiment rolled back to control."
+            ),
+            metadata={
+                "guardrail_id": str(guardrail.pk),
+                "metric_key": guardrail.metric.key,
+                "threshold": str(guardrail.threshold),
+                "actual_value": str(actual_value),
+                "action": guardrail.action,
+                "from_status": from_status,
+                "to_status": ExperimentStatus.COMPLETED,
+                "control_variant_id": (
+                    str(control_variant.pk) if control_variant else None
+                ),
+            },
+        )
+
+    notification_enqueue(
+        "guardrail_triggered",
+        NotificationPayload(
+            title="Guardrail Triggered",
+            body=(
+                f"Guardrail triggered on experiment '{experiment.name}': "
+                f"{guardrail.metric.key} = {actual_value} "
+                f"(threshold: {guardrail.threshold}). "
+                f"Action: {guardrail.action}."
+            ),
+            event_type="guardrail_triggered",
+            experiment_id=str(experiment.pk),
+            experiment_name=experiment.name,
+            extra={
+                "metric_key": guardrail.metric.key,
+                "threshold": str(guardrail.threshold),
+                "actual_value": str(actual_value),
+                "action": guardrail.action,
+            },
+        ),
+    )
+
+    return trigger
+
+
+def check_experiment_guardrails(
+    experiment: Experiment,
+) -> list[GuardrailTrigger]:
+    """Evaluate the active guardrails of one running experiment.
+
+    Returns the triggers that fired; the list is empty when the experiment
+    is not RUNNING or no guardrail value exceeds its threshold. Checking
+    stops after the first guardrail that pauses or rolls back.
+    """
+    if experiment.status != ExperimentStatus.RUNNING:
+        return []
+
+    guardrails = Guardrail.objects.filter(
+        experiment=experiment,
+        is_active=True,
+    ).select_related("metric")
+
+    triggers: list[GuardrailTrigger] = []
+
+    for guardrail in guardrails:
+        actual_value = _calculate_guardrail_metric(guardrail, experiment)
+        if actual_value is None:
+            continue
+
+        if actual_value > guardrail.threshold:
+            # Re-read the status: it may have changed since the experiment
+            # was loaded, and a non-running experiment must not be acted on.
+            experiment.refresh_from_db()
+            if experiment.status != ExperimentStatus.RUNNING:
+                break
+
+            trigger = _execute_guardrail_action(
+                guardrail,
+                experiment,
+                actual_value,
+            )
+            triggers.append(trigger)
+
+            if guardrail.action in {
+                GuardrailAction.PAUSE,
+                GuardrailAction.ROLLBACK,
+            }:
+                break
+
+    return triggers
+
+
+def check_all_running_experiments() -> dict[str, Any]:
+    """Run guardrail checks for every experiment in RUNNING status.
+
+    A failure while checking one experiment is logged and counted under
+    ``errors``; the remaining experiments are still checked.
+
+    Returns:
+        A dict with ``checked``, ``triggered`` and ``errors`` counters and a
+        ``triggers`` list describing every trigger that fired.
+    """
+    # Guardrails are queried again inside check_experiment_guardrails, so
+    # prefetching "guardrails__metric" here would only add unused queries.
+    running = (
+        Experiment.objects.filter(status=ExperimentStatus.RUNNING)
+        .select_related("flag")
+        .prefetch_related("variants")
+    )
+
+    results: dict[str, Any] = {
+        "checked": 0,
+        "triggered": 0,
+        "errors": 0,
+        "triggers": [],
+    }
+
+    for experiment in running:
+        results["checked"] += 1
+        try:
+            triggers = check_experiment_guardrails(experiment)
+        except Exception:
+            # Guardrails are a safety net: a failure on one experiment must
+            # not leave the other running experiments unchecked.
+            results["errors"] += 1
+            logger.exception(
+                "guardrail_check_failed",
+                extra={"experiment_id": str(experiment.pk)},
+            )
+            continue
+        if triggers:
+            results["triggered"] += 1
+            for t in triggers:
+                results["triggers"].append(
+                    {
+                        "experiment_id": str(experiment.pk),
+                        "experiment_name": experiment.name,
+                        "metric_key": t.metric_key,
+                        "threshold": str(t.threshold),
+                        "actual_value": str(t.actual_value),
+                        "action": t.action,
+                    }
+                )
+
+    return results
diff --git a/src/backend/apps/guardrails/tasks.py b/src/backend/apps/guardrails/tasks.py
new file mode 100644
index 0000000..9226349
--- /dev/null
+++ b/src/backend/apps/guardrails/tasks.py
@@ -0,0 +1,22 @@
+import logging
+
+from apps.guardrails.services import check_all_running_experiments
+from config.celery import app
+
+logger = logging.getLogger("lotty")
+
+
+@app.task(bind=True, name="guardrails.check_all")
+def check_all_experiment_guardrails_task(self):
+    """Check guardrails of all running experiments and log a summary."""
+    results = check_all_running_experiments()
+    logger.info(
+        "guardrail_check_completed",
+        extra={
+            "checked": results["checked"],
+            "triggered": results["triggered"],
+            "errors": results["errors"],
+            "triggers_count": len(results["triggers"]),
+        },
+    )
+    return results
diff --git a/src/backend/apps/guardrails/tests/__init__.py b/src/backend/apps/guardrails/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/backend/apps/guardrails/tests/test_guardrails.py b/src/backend/apps/guardrails/tests/test_guardrails.py
new file mode 100644
index 0000000..ea6e8f9
--- /dev/null
+++ b/src/backend/apps/guardrails/tests/test_guardrails.py
@@ -0,0 +1,501 @@
+from decimal import Decimal
+from unittest import mock
+
+from django.test import TestCase
+from django.utils import timezone
+
+from apps.events.services import decision_create, process_events_batch
+from apps.events.tests.helpers import make_event_type, make_exposure_type
+from apps.experiments.models import (
+    ExperimentLog,
+    ExperimentOutcome,
+    ExperimentStatus,
+    LogType,
+    OutcomeType,
+)
+from apps.experiments.services import (
+    experiment_approve,
+    experiment_start,
+    experiment_submit_for_review,
+)
+from apps.experiments.tests.helpers import add_two_variants, make_experiment
+from apps.guardrails.models import (
+    Guardrail,
+    GuardrailAction,
+    GuardrailTrigger,
+)
+from apps.guardrails.services import (
+    check_all_running_experiments,
+    check_experiment_guardrails,
+    guardrail_create,
+)
+from apps.metrics.models import MetricDirection, MetricType
+from apps.metrics.services import metric_definition_create
+from apps.reviews.services import review_settings_update
+from apps.reviews.tests.helpers import make_approver
+
+
+def _start_experiment(experiment, approver):
+    exp = experiment_submit_for_review(
+        experiment=experiment,
+        user=experiment.owner,
+    )
+    exp = experiment_approve(experiment=exp, approver=approver)
+    return experiment_start(experiment=exp, user=experiment.owner)
+
+
+class GuardrailCheckPauseTest(TestCase):
+    def setUp(self) -> None:
+        review_settings_update(
+            default_min_approvals=1,
+            allow_any_approver=True,
+        )
+        self.approver = make_approver("_gp")
+
+        self.exposure_type = make_exposure_type()
+        self.error_type = make_event_type(
+            name="error_occurred",
+            display_name="Error",
+            requires_exposure=False,
+        )
+
+        self.experiment = make_experiment(suffix="_gcp")
+        self.v_control, self.v_treatment = add_two_variants(self.experiment)
+
+        self.error_rate_metric = metric_definition_create(
+            key="gcp_error_rate",
+            name="Error Rate",
+            metric_type=MetricType.RATIO,
+            direction=MetricDirection.LOWER_IS_BETTER,
+            calculation_rule={
+                "type": "ratio",
+                "numerator_event": "error_occurred",
+                "denominator_event": "exposure",
+            },
+        )
+
+        guardrail_create(
+            experiment=self.experiment,
+            metric=self.error_rate_metric,
+            threshold=Decimal("0.05"),
+            observation_window_minutes=60,
+            action=GuardrailAction.PAUSE,
+        )
+
+        self.experiment = _start_experiment(self.experiment, self.approver)
+        self.now = timezone.now()
+
+    def _create_decision_and_exposure(self, decision_id, subject_id, variant):
+        decision_create(
+            decision_id=decision_id,
+            flag_key="flag_gcp",
+            subject_id=subject_id,
+            experiment_id=str(self.experiment.pk),
+            variant_id=str(variant.pk),
+            value=variant.value,
+            reason="experiment",
+        )
+        process_events_batch(
+            [
+                {
+                    "event_id": f"exp_{decision_id}",
+                    "event_type": "exposure",
+                    "decision_id": decision_id,
+                    "subject_id": subject_id,
+                    "timestamp": self.now.isoformat(),
+                    "properties": {},
+                }
+            ]
+        )
+
+    def _send_error(self, event_id, decision_id, subject_id):
+        process_events_batch(
+            [
+                {
+                    "event_id": event_id,
+                    "event_type": "error_occurred",
+                    "decision_id": decision_id,
+                    "subject_id": subject_id,
+                    "timestamp": self.now.isoformat(),
+                    "properties": {},
+                }
+            ]
+        )
+
+    def test_no_trigger_when_below_threshold(self) -> None:
+        for i in range(20):
+            self._create_decision_and_exposure(
+                f"dec_ok_{i}",
+                f"u{i}",
+                self.v_treatment,
+            )
+        self._send_error("err_ok_0", "dec_ok_0", "u0")
+
+        triggers = check_experiment_guardrails(self.experiment)
+
+        self.assertEqual(len(triggers), 0)
+        self.experiment.refresh_from_db()
+        self.assertEqual(self.experiment.status, ExperimentStatus.RUNNING)
+
+    def test_trigger_pause_when_above_threshold(self) -> None:
+        for i in range(10):
+            self._create_decision_and_exposure(
+                f"dec_err_{i}",
+                f"u{i}",
+                self.v_treatment,
+            )
+        for i in range(10):
+            self._send_error(f"err_{i}", f"dec_err_{i}", f"u{i}")
+
+        triggers = check_experiment_guardrails(self.experiment)
+
+        self.assertEqual(len(triggers), 1)
+        self.experiment.refresh_from_db()
+        self.assertEqual(self.experiment.status, ExperimentStatus.PAUSED)
+        self.assertEqual(triggers[0].action, GuardrailAction.PAUSE)
+        self.assertEqual(triggers[0].metric_key, "gcp_error_rate")
+
+    def test_trigger_audit_log_created(self) -> None:
+        for i in range(5):
+            self._create_decision_and_exposure(
+                f"dec_al_{i}",
+                f"u{i}",
+                self.v_treatment,
+            )
+        for i in range(5):
+            self._send_error(f"err_al_{i}", f"dec_al_{i}", f"u{i}")
+
+        check_experiment_guardrails(self.experiment)
+
+        log = ExperimentLog.objects.filter(
+            experiment=self.experiment,
+            log_type=LogType.GUARDRAIL_TRIGGERED,
+        ).first()
+        self.assertIsNotNone(log)
+        self.assertIn("gcp_error_rate", log.comment)
+        self.assertIn("threshold", log.metadata)
+        self.assertIn("actual_value", log.metadata)
+
+    def test_trigger_record_created(self) -> None:
+        for i in range(5):
+            self._create_decision_and_exposure(
+                f"dec_tr_{i}",
+                f"u{i}",
+                self.v_treatment,
+            )
+        for i in range(5):
+            self._send_error(f"err_tr_{i}", f"dec_tr_{i}", f"u{i}")
+
+        check_experiment_guardrails(self.experiment)
+
+        trigger = GuardrailTrigger.objects.filter(
+            experiment=self.experiment,
+        ).first()
+        self.assertIsNotNone(trigger)
+        self.assertEqual(trigger.metric_key, "gcp_error_rate")
+        self.assertEqual(trigger.threshold, Decimal("0.05"))
+        self.assertGreater(trigger.actual_value, Decimal("0.05"))
+        self.assertEqual(trigger.action, GuardrailAction.PAUSE)
+        self.assertIsNotNone(trigger.triggered_at)
+
+    def test_no_trigger_for_non_running_experiment(self) -> None:
+        self.experiment.status = ExperimentStatus.PAUSED
+        self.experiment.save(update_fields=["status"])
+
+        triggers = check_experiment_guardrails(self.experiment)
+        self.assertEqual(len(triggers), 0)
+
+    def test_no_trigger_when_no_data(self) -> None:
+        triggers = check_experiment_guardrails(self.experiment)
+        self.assertEqual(len(triggers), 0)
+
+    def test_inactive_guardrail_skipped(self) -> None:
+        Guardrail.objects.filter(experiment=self.experiment).update(
+            is_active=False,
+        )
+
+        for i in range(5):
+            self._create_decision_and_exposure(
+                f"dec_ia_{i}",
+                f"u{i}",
+                self.v_treatment,
+            )
+        for i in range(5):
+            self._send_error(f"err_ia_{i}", f"dec_ia_{i}", f"u{i}")
+
+        triggers = check_experiment_guardrails(self.experiment)
+        self.assertEqual(len(triggers), 0)
+        self.experiment.refresh_from_db()
+        self.assertEqual(self.experiment.status, ExperimentStatus.RUNNING)
+
+
+class GuardrailCheckRollbackTest(TestCase):
+    def setUp(self) -> None:
+        review_settings_update(
+            default_min_approvals=1,
+            allow_any_approver=True,
+        )
+        self.approver = make_approver("_grb")
+
+        self.exposure_type = make_exposure_type()
+        self.error_type = make_event_type(
+            name="rb_error",
+            display_name="Error",
+            requires_exposure=False,
+        )
+
+        self.experiment = make_experiment(suffix="_grb")
+        self.v_control, self.v_treatment = add_two_variants(self.experiment)
+
+        self.error_rate_metric = metric_definition_create(
+            key="grb_error_rate",
+            name="Error Rate",
+            metric_type=MetricType.RATIO,
+            direction=MetricDirection.LOWER_IS_BETTER,
+            calculation_rule={
+                "type": "ratio",
+                "numerator_event": "rb_error",
+                "denominator_event": "exposure",
+            },
+        )
+
+        guardrail_create(
+            experiment=self.experiment,
+            metric=self.error_rate_metric,
+            threshold=Decimal("0.10"),
+            observation_window_minutes=60,
+            action=GuardrailAction.ROLLBACK,
+        )
+
+        self.experiment = _start_experiment(self.experiment, self.approver)
+        self.now = timezone.now()
+
+    def _create_decision_and_exposure(self, decision_id, subject_id, variant):
+        decision_create(
+            decision_id=decision_id,
+            flag_key="flag_grb",
+            subject_id=subject_id,
+            experiment_id=str(self.experiment.pk),
+            variant_id=str(variant.pk),
+            value=variant.value,
+            reason="experiment",
+        )
+        process_events_batch(
+            [
+                {
+                    "event_id": f"exp_{decision_id}",
+                    "event_type": "exposure",
+                    "decision_id": decision_id,
+                    "subject_id": subject_id,
+                    "timestamp": self.now.isoformat(),
+                    "properties": {},
+                }
+            ]
+        )
+
+    def test_rollback_completes_experiment(self) -> None:
+        for i in range(5):
+            self._create_decision_and_exposure(
+                f"dec_rb_{i}",
+                f"u{i}",
+                self.v_treatment,
+            )
+        for i in range(5):
+            process_events_batch(
+                [
+                    {
+                        "event_id": f"err_rb_{i}",
+                        "event_type": "rb_error",
+                        "decision_id": f"dec_rb_{i}",
+                        "subject_id": f"u{i}",
+                        "timestamp": self.now.isoformat(),
+                        "properties": {},
+                    }
+                ]
+            )
+
+        triggers = check_experiment_guardrails(self.experiment)
+
+        self.assertEqual(len(triggers), 1)
+        self.experiment.refresh_from_db()
+        self.assertEqual(self.experiment.status, ExperimentStatus.COMPLETED)
+        self.assertEqual(triggers[0].action, GuardrailAction.ROLLBACK)
+
+    def test_rollback_creates_outcome(self) -> None:
+        for i in range(5):
+            self._create_decision_and_exposure(
+                f"dec_rbo_{i}",
+                f"u{i}",
+                self.v_treatment,
+            )
+        for i in range(5):
+            process_events_batch(
+                [
+                    {
+                        "event_id": f"err_rbo_{i}",
+                        "event_type": "rb_error",
+                        "decision_id": f"dec_rbo_{i}",
+                        "subject_id": f"u{i}",
+                        "timestamp": self.now.isoformat(),
+                        "properties": {},
+                    }
+                ]
+            )
+
+        check_experiment_guardrails(self.experiment)
+
+        outcome = ExperimentOutcome.objects.filter(
+            experiment=self.experiment,
+        ).first()
+        self.assertIsNotNone(outcome)
+        self.assertEqual(outcome.outcome, OutcomeType.ROLLBACK)
+        self.assertEqual(outcome.winning_variant, self.v_control)
+        self.assertIsNone(outcome.decided_by)
+        self.assertIn("guardrail", outcome.rationale.lower())
+
+    def test_rollback_audit_log(self) -> None:
+        for i in range(5):
+            self._create_decision_and_exposure(
+                f"dec_rba_{i}",
+                f"u{i}",
+                self.v_treatment,
+            )
+        for i in range(5):
+            process_events_batch(
+                [
+                    {
+                        "event_id": f"err_rba_{i}",
+                        "event_type": "rb_error",
+                        "decision_id": f"dec_rba_{i}",
+                        "subject_id": f"u{i}",
+                        "timestamp": self.now.isoformat(),
+                        "properties": {},
+                    }
+                ]
+            )
+
+        check_experiment_guardrails(self.experiment)
+
+        log = ExperimentLog.objects.filter(
+            experiment=self.experiment,
+            log_type=LogType.GUARDRAIL_TRIGGERED,
+        ).first()
+        self.assertIsNotNone(log)
+        self.assertEqual(log.metadata["action"], GuardrailAction.ROLLBACK)
+        self.assertEqual(
+            log.metadata["to_status"],
+            ExperimentStatus.COMPLETED,
+        )
+
+
+class CheckAllRunningTest(TestCase):
+    def setUp(self) -> None:
+        review_settings_update(
+            default_min_approvals=1,
+            allow_any_approver=True,
+        )
+        self.approver = make_approver("_all")
+
+        self.exposure_type = make_exposure_type()
+        self.error_type = make_event_type(
+            name="all_error",
+            display_name="Error",
+            requires_exposure=False,
+        )
+
+        self.metric = metric_definition_create(
+            key="all_error_rate",
+            name="Error Rate",
+            metric_type=MetricType.RATIO,
+            direction=MetricDirection.LOWER_IS_BETTER,
+            calculation_rule={
+                "type": "ratio",
+                "numerator_event": "all_error",
+                "denominator_event": "exposure",
+            },
+        )
+
+    def test_check_all_running(self) -> None:
+        exp1 = make_experiment(suffix="_all1")
+        add_two_variants(exp1)
+        guardrail_create(
+            experiment=exp1,
+            metric=self.metric,
+            threshold=Decimal("0.05"),
+            action=GuardrailAction.PAUSE,
+        )
+        _start_experiment(exp1, self.approver)
+
+        exp2 = make_experiment(suffix="_all2")
+        add_two_variants(exp2)
+        _start_experiment(exp2, self.approver)
+
+        results = check_all_running_experiments()
+        self.assertEqual(results["checked"], 2)
+
+    def test_check_all_with_trigger(self) -> None:
+        exp = make_experiment(suffix="_allt")
+        _v_ctrl, v_treat = add_two_variants(exp)
+        guardrail_create(
+            experiment=exp,
+            metric=self.metric,
+            threshold=Decimal("0.05"),
+            action=GuardrailAction.PAUSE,
+        )
+        exp = _start_experiment(exp, self.approver)
+
+        now = timezone.now()
+        for i in range(5):
+            decision_create(
+                decision_id=f"dec_allt_{i}",
+                flag_key="flag_allt",
+                subject_id=f"u{i}",
+                experiment_id=str(exp.pk),
+                variant_id=str(v_treat.pk),
+                value=v_treat.value,
+                reason="experiment",
+            )
+            process_events_batch(
+                [
+                    {
+                        "event_id": f"exp_allt_{i}",
+                        "event_type": "exposure",
+                        "decision_id": f"dec_allt_{i}",
+                        "subject_id": f"u{i}",
+                        "timestamp": now.isoformat(),
+                        "properties": {},
+                    }
+                ]
+            )
+            process_events_batch(
+                [
+                    {
+                        "event_id": f"err_allt_{i}",
+                        "event_type": "all_error",
+                        "decision_id": f"dec_allt_{i}",
+                        "subject_id": f"u{i}",
+                        "timestamp": now.isoformat(),
+                        "properties": {},
+                    }
+                ]
+            )
+
+        results = check_all_running_experiments()
+        self.assertEqual(results["triggered"], 1)
+        self.assertGreater(len(results["triggers"]), 0)
+
+    def test_check_all_isolates_failures(self) -> None:
+        for suffix in ("_allf1", "_allf2"):
+            exp = make_experiment(suffix=suffix)
+            add_two_variants(exp)
+            _start_experiment(exp, self.approver)
+
+        with mock.patch(
+            "apps.guardrails.services.check_experiment_guardrails",
+            side_effect=[RuntimeError("boom"), []],
+        ):
+            results = check_all_running_experiments()
+
+        self.assertEqual(results["checked"], 2)
+        self.assertEqual(results["errors"], 1)
+        self.assertEqual(results["triggered"], 0)