"""Integration tests for experiment guardrails.

Each test case builds an experiment with a ratio metric and a guardrail,
moves it through submit -> approve -> start, feeds events through
``process_events_batch`` and then checks what the guardrail services did.
"""

from decimal import Decimal
from typing import override

from django.core.cache import cache
from django.test import TestCase
from django.utils import timezone

from apps.decision.services import decide_for_flag
from apps.events.services import process_events_batch
from apps.events.tests.helpers import make_event_type, make_exposure_type
from apps.experiments.models import ExperimentOutcome, ExperimentStatus
from apps.experiments.services import (
    experiment_approve,
    experiment_start,
    experiment_submit_for_review,
)
from apps.experiments.tests.helpers import add_two_variants, make_experiment
from apps.guardrails.models import GuardrailAction, GuardrailTrigger
from apps.guardrails.services import (
    check_all_running_experiments,
    check_experiment_guardrails,
    guardrail_create,
)
from apps.metrics.services import (
    experiment_metric_add,
    metric_definition_create,
)
from apps.reviews.services import review_settings_update
from apps.reviews.tests.helpers import make_approver, make_experimenter


def _start_experiment(owner, approver, suffix, traffic=Decimal("100.00")):
    """Create an experiment with two variants and return it.

    Despite the name, this helper does not start the experiment; callers
    do that later through ``_submit_and_start``.

    Args:
        owner: Owner passed to ``make_experiment``.
        approver: Unused; kept so existing call sites keep working.
        suffix: Suffix passed to ``make_experiment``.
        traffic: Value passed as ``traffic_allocation``.

    Returns:
        The experiment returned by ``make_experiment`` after
        ``add_two_variants`` has been called on it.
    """
    experiment = make_experiment(
        owner=owner,
        suffix=suffix,
        traffic_allocation=traffic,
    )
    add_two_variants(experiment)
    return experiment


def _submit_and_start(experiment, approver):
    """Submit ``experiment`` for review, approve it and start it.

    Submission and start are performed as ``experiment.owner``; approval
    is performed as ``approver``.

    Returns:
        Whatever ``experiment_start`` returns (the tests treat it as the
        refreshed experiment instance).
    """
    submitted = experiment_submit_for_review(
        experiment=experiment, user=experiment.owner
    )
    approved = experiment_approve(experiment=submitted, approver=approver)
    return experiment_start(experiment=approved, user=experiment.owner)


def _make_event(event_id, event_type, decision, subject_id, timestamp):
    """Build one event payload dict for ``process_events_batch``.

    ``decision`` is the mapping returned by ``decide_for_flag``; only its
    ``"decision_id"`` entry is read. A fresh empty ``properties`` dict is
    created on every call so payloads never share mutable state.
    """
    return {
        "event_id": event_id,
        "event_type": event_type,
        "decision_id": decision["decision_id"],
        "subject_id": subject_id,
        "timestamp": timestamp,
        "properties": {},
    }


class GuardrailPauseIntegrationTest(TestCase):
    """Guardrail with ``GuardrailAction.PAUSE`` on an error-rate metric."""

    @override
    def setUp(self) -> None:
        """Create and start an experiment guarded by a PAUSE guardrail."""
        cache.clear()
        # Presumably lets a single approval from any approver suffice, so
        # ``_submit_and_start`` can approve in one step -- confirm against
        # the reviews service.
        review_settings_update(
            default_min_approvals=1,
            allow_any_approver=True,
        )
        self.owner = make_experimenter("_gpi")
        self.approver = make_approver("_gpi")
        self.experiment = _start_experiment(self.owner, self.approver, "_gpi")

        self.error_metric = metric_definition_create(
            key="error_rate_gpi",
            name="Error Rate",
            metric_type="ratio",
            direction="lower_is_better",
            calculation_rule={
                "numerator_event": "gpi_error",
                "denominator_event": "gpi_exposure",
            },
        )
        experiment_metric_add(
            experiment=self.experiment,
            metric=self.error_metric,
        )
        make_exposure_type(name="gpi_exposure")
        make_event_type(name="gpi_error", display_name="Error")

        guardrail_create(
            experiment=self.experiment,
            metric=self.error_metric,
            threshold=Decimal("0.10"),
            observation_window_minutes=60,
            action=GuardrailAction.PAUSE,
        )
        self.experiment = _submit_and_start(self.experiment, self.approver)

    def test_guardrail_pauses_experiment_on_threshold_breach(self) -> None:
        """One exposure plus one error must trigger and pause."""
        now = timezone.now().isoformat()
        cache.clear()
        decision = decide_for_flag("flag_gpi", "user_gp1", {})
        self.assertEqual(decision["reason"], "experiment_assigned")

        process_events_batch(
            [
                _make_event(
                    "gpi_exp_1", "gpi_exposure", decision, "user_gp1", now
                ),
                _make_event(
                    "gpi_err_1", "gpi_error", decision, "user_gp1", now
                ),
            ]
        )

        triggers = check_experiment_guardrails(self.experiment)
        # assertGreater reports the actual count on failure, unlike
        # assertTrue(len(...) > 0).
        self.assertGreater(len(triggers), 0)

        self.experiment.refresh_from_db()
        self.assertEqual(self.experiment.status, ExperimentStatus.PAUSED)
        self.assertTrue(
            GuardrailTrigger.objects.filter(
                experiment=self.experiment
            ).exists()
        )

    def test_no_trigger_when_metric_below_threshold(self) -> None:
        """Ten exposures and no errors must leave the experiment running."""
        now = timezone.now().isoformat()

        decisions = []
        for i in range(10):
            # Cleared before every decision; presumably so each decision
            # is computed fresh rather than served from cache -- confirm.
            cache.clear()
            decisions.append(decide_for_flag("flag_gpi", f"user_ok_{i}", {}))

        process_events_batch(
            [
                _make_event(
                    f"gpi_ok_exp_{i}",
                    "gpi_exposure",
                    decision,
                    f"user_ok_{i}",
                    now,
                )
                for i, decision in enumerate(decisions)
            ]
        )

        triggers = check_experiment_guardrails(self.experiment)
        self.assertEqual(len(triggers), 0)

        self.experiment.refresh_from_db()
        self.assertEqual(self.experiment.status, ExperimentStatus.RUNNING)


class GuardrailRollbackIntegrationTest(TestCase):
    """Guardrail with ``GuardrailAction.ROLLBACK`` on a crash-rate metric."""

    @override
    def setUp(self) -> None:
        """Create and start an experiment guarded by a ROLLBACK guardrail."""
        cache.clear()
        review_settings_update(
            default_min_approvals=1,
            allow_any_approver=True,
        )
        self.owner = make_experimenter("_gri")
        self.approver = make_approver("_gri")
        self.experiment = _start_experiment(self.owner, self.approver, "_gri")

        self.crash_metric = metric_definition_create(
            key="crash_rate_gri",
            name="Crash Rate",
            metric_type="ratio",
            direction="lower_is_better",
            calculation_rule={
                "numerator_event": "gri_crash",
                "denominator_event": "gri_exposure",
            },
        )
        experiment_metric_add(
            experiment=self.experiment,
            metric=self.crash_metric,
        )
        make_exposure_type(name="gri_exposure")
        make_event_type(name="gri_crash", display_name="Crash")

        # No observation_window_minutes here, so guardrail_create's own
        # default applies.
        guardrail_create(
            experiment=self.experiment,
            metric=self.crash_metric,
            threshold=Decimal("0.05"),
            action=GuardrailAction.ROLLBACK,
        )
        self.experiment = _submit_and_start(self.experiment, self.approver)

    def test_rollback_completes_experiment_with_control_winner(self) -> None:
        """One exposure plus one crash must complete with a rollback outcome.

        NOTE(review): the test name mentions a control winner, but only
        ``outcome.outcome == "rollback"`` is asserted.
        """
        now = timezone.now().isoformat()
        cache.clear()
        decision = decide_for_flag("flag_gri", "user_rb1", {})

        process_events_batch(
            [
                _make_event(
                    "gri_exp_1", "gri_exposure", decision, "user_rb1", now
                ),
                _make_event(
                    "gri_crash_1", "gri_crash", decision, "user_rb1", now
                ),
            ]
        )

        triggers = check_experiment_guardrails(self.experiment)
        self.assertGreater(len(triggers), 0)

        self.experiment.refresh_from_db()
        self.assertEqual(self.experiment.status, ExperimentStatus.COMPLETED)

        outcome = ExperimentOutcome.objects.get(experiment=self.experiment)
        self.assertEqual(outcome.outcome, "rollback")


class GuardrailCheckAllTest(TestCase):
    """``check_all_running_experiments`` across two running experiments."""

    @override
    def setUp(self) -> None:
        """Create and start two experiments sharing one guarded metric."""
        cache.clear()
        review_settings_update(
            default_min_approvals=1,
            allow_any_approver=True,
        )
        self.owner = make_experimenter("_gca")
        self.approver = make_approver("_gca")
        self.exp1 = _start_experiment(self.owner, self.approver, "_gca1")
        self.exp2 = _start_experiment(self.owner, self.approver, "_gca2")

        self.metric = metric_definition_create(
            key="err_gca",
            name="Error",
            metric_type="ratio",
            direction="lower_is_better",
            calculation_rule={
                "numerator_event": "gca_error",
                "denominator_event": "gca_exposure",
            },
        )
        for exp in (self.exp1, self.exp2):
            experiment_metric_add(experiment=exp, metric=self.metric)
            guardrail_create(
                experiment=exp,
                metric=self.metric,
                threshold=Decimal("0.10"),
                action=GuardrailAction.PAUSE,
            )

        self.exp1 = _submit_and_start(self.exp1, self.approver)
        self.exp2 = _submit_and_start(self.exp2, self.approver)

        # Unlike the other test cases, the event types are registered
        # after the experiments have been started; order kept as-is.
        make_exposure_type(name="gca_exposure")
        make_event_type(name="gca_error", display_name="Error")

    def test_check_all_processes_multiple_experiments(self) -> None:
        """Both experiments are checked and at least one guardrail fires."""
        now = timezone.now().isoformat()

        # Flag keys are assumed to follow the experiment suffixes
        # ("_gca1" -> "flag_gca1") -- confirm against make_experiment.
        for suffix in ("gca1", "gca2"):
            cache.clear()
            decision = decide_for_flag(f"flag_{suffix}", "user_ca", {})
            process_events_batch(
                [
                    _make_event(
                        f"{suffix}_exp",
                        "gca_exposure",
                        decision,
                        "user_ca",
                        now,
                    ),
                    _make_event(
                        f"{suffix}_err",
                        "gca_error",
                        decision,
                        "user_ca",
                        now,
                    ),
                ]
            )

        results = check_all_running_experiments()
        self.assertEqual(results["checked"], 2)
        self.assertGreaterEqual(results["triggered"], 1)