"""Service-level tests for experiment guardrails.

Covers the PAUSE and ROLLBACK actions of ``check_experiment_guardrails`` and
the ``check_all_running_experiments`` sweep. Each scenario feeds decisions,
exposures and error events through the public event services, then checks
the resulting experiment status, trigger records, outcomes and audit logs.
"""

from decimal import Decimal

from django.test import TestCase
from django.utils import timezone

from apps.events.services import decision_create, process_events_batch
from apps.events.tests.helpers import make_event_type, make_exposure_type
from apps.experiments.models import (
    ExperimentLog,
    ExperimentOutcome,
    ExperimentStatus,
    LogType,
    OutcomeType,
)
from apps.experiments.services import (
    experiment_approve,
    experiment_start,
    experiment_submit_for_review,
)
from apps.experiments.tests.helpers import add_two_variants, make_experiment
from apps.guardrails.models import (
    Guardrail,
    GuardrailAction,
    GuardrailTrigger,
)
from apps.guardrails.services import (
    check_all_running_experiments,
    check_experiment_guardrails,
    guardrail_create,
)
from apps.metrics.models import MetricDirection, MetricType
from apps.metrics.services import metric_definition_create
from apps.reviews.services import review_settings_update
from apps.reviews.tests.helpers import make_approver


def _start_experiment(experiment, approver):
    """Submit *experiment* for review, approve it, then start it.

    Returns whatever ``experiment_start`` returns (the started experiment).
    """
    exp = experiment_submit_for_review(
        experiment=experiment,
        user=experiment.owner,
    )
    exp = experiment_approve(experiment=exp, approver=approver)
    return experiment_start(experiment=exp, user=experiment.owner)


def _allow_single_approval():
    """Configure review settings so a single approval from anyone suffices."""
    review_settings_update(
        default_min_approvals=1,
        allow_any_approver=True,
    )


def _create_error_rate_metric(key, error_event):
    """Create a lower-is-better ratio metric: *error_event* per exposure."""
    return metric_definition_create(
        key=key,
        name="Error Rate",
        metric_type=MetricType.RATIO,
        direction=MetricDirection.LOWER_IS_BETTER,
        calculation_rule={
            "type": "ratio",
            "numerator_event": error_event,
            "denominator_event": "exposure",
        },
    )


def _send_event(event_id, event_type, decision_id, subject_id, timestamp):
    """Ingest a single event through ``process_events_batch``."""
    process_events_batch(
        [
            {
                "event_id": event_id,
                "event_type": event_type,
                "decision_id": decision_id,
                "subject_id": subject_id,
                "timestamp": timestamp.isoformat(),
                "properties": {},
            }
        ]
    )


def _expose_subject(
    experiment,
    flag_key,
    decision_id,
    subject_id,
    variant,
    timestamp,
    exposure_event_id=None,
):
    """Record an experiment decision for *subject_id* and its exposure event.

    The exposure event id defaults to ``f"exp_{decision_id}"``.
    """
    decision_create(
        decision_id=decision_id,
        flag_key=flag_key,
        subject_id=subject_id,
        experiment_id=str(experiment.pk),
        variant_id=str(variant.pk),
        value=variant.value,
        reason="experiment",
    )
    if exposure_event_id is None:
        exposure_event_id = f"exp_{decision_id}"
    _send_event(
        exposure_event_id,
        "exposure",
        decision_id,
        subject_id,
        timestamp,
    )


class _GuardrailTrafficMixin:
    """Event-seeding helpers shared by the single-experiment guardrail tests.

    Subclasses set ``FLAG_KEY`` and ``ERROR_EVENT`` and, in ``setUp``, must
    define ``self.experiment``, ``self.v_treatment`` and ``self.now``.
    Not a ``TestCase`` itself so it is never collected as a test.
    """

    FLAG_KEY = ""
    ERROR_EVENT = ""

    def _create_decision_and_exposure(self, decision_id, subject_id, variant):
        """Record a decision on *variant* for *subject_id* plus its exposure."""
        _expose_subject(
            self.experiment,
            self.FLAG_KEY,
            decision_id,
            subject_id,
            variant,
            self.now,
        )

    def _send_error(self, event_id, decision_id, subject_id):
        """Record one ``ERROR_EVENT`` for an existing decision."""
        _send_event(
            event_id,
            self.ERROR_EVENT,
            decision_id,
            subject_id,
            self.now,
        )

    def _seed_failing_traffic(self, tag, count):
        """Expose *count* treatment subjects, then log one error for each.

        Decision ids are ``dec_{tag}_{i}``, error ids ``err_{tag}_{i}`` and
        subjects ``u{i}``; every exposure is paired with exactly one error.
        """
        for i in range(count):
            self._create_decision_and_exposure(
                f"dec_{tag}_{i}",
                f"u{i}",
                self.v_treatment,
            )
        for i in range(count):
            self._send_error(f"err_{tag}_{i}", f"dec_{tag}_{i}", f"u{i}")


class GuardrailCheckPauseTest(_GuardrailTrafficMixin, TestCase):
    """A PAUSE guardrail with a 0.05 error-rate threshold over 60 minutes."""

    FLAG_KEY = "flag_gcp"
    ERROR_EVENT = "error_occurred"

    def setUp(self) -> None:
        _allow_single_approval()
        self.approver = make_approver("_gp")
        self.exposure_type = make_exposure_type()
        self.error_type = make_event_type(
            name=self.ERROR_EVENT,
            display_name="Error",
            requires_exposure=False,
        )
        self.experiment = make_experiment(suffix="_gcp")
        self.v_control, self.v_treatment = add_two_variants(self.experiment)
        self.error_rate_metric = _create_error_rate_metric(
            "gcp_error_rate",
            self.ERROR_EVENT,
        )
        guardrail_create(
            experiment=self.experiment,
            metric=self.error_rate_metric,
            threshold=Decimal("0.05"),
            observation_window_minutes=60,
            action=GuardrailAction.PAUSE,
        )
        self.experiment = _start_experiment(self.experiment, self.approver)
        self.now = timezone.now()

    def test_no_trigger_when_below_threshold(self) -> None:
        # 1 error over 40 exposures = 0.025, strictly below the 0.05
        # threshold. A ratio exactly equal to the threshold (e.g. 1/20) would
        # make this test depend on whether the check uses ``>`` or ``>=``.
        for i in range(40):
            self._create_decision_and_exposure(
                f"dec_ok_{i}",
                f"u{i}",
                self.v_treatment,
            )
        self._send_error("err_ok_0", "dec_ok_0", "u0")

        triggers = check_experiment_guardrails(self.experiment)

        self.assertEqual(len(triggers), 0)
        self.experiment.refresh_from_db()
        self.assertEqual(self.experiment.status, ExperimentStatus.RUNNING)

    def test_trigger_pause_when_above_threshold(self) -> None:
        self._seed_failing_traffic("err", 10)

        triggers = check_experiment_guardrails(self.experiment)

        self.assertEqual(len(triggers), 1)
        self.experiment.refresh_from_db()
        self.assertEqual(self.experiment.status, ExperimentStatus.PAUSED)
        self.assertEqual(triggers[0].action, GuardrailAction.PAUSE)
        self.assertEqual(triggers[0].metric_key, "gcp_error_rate")

    def test_trigger_audit_log_created(self) -> None:
        self._seed_failing_traffic("al", 5)

        check_experiment_guardrails(self.experiment)

        log = ExperimentLog.objects.filter(
            experiment=self.experiment,
            log_type=LogType.GUARDRAIL_TRIGGERED,
        ).first()
        self.assertIsNotNone(log)
        self.assertIn("gcp_error_rate", log.comment)
        self.assertIn("threshold", log.metadata)
        self.assertIn("actual_value", log.metadata)

    def test_trigger_record_created(self) -> None:
        self._seed_failing_traffic("tr", 5)

        check_experiment_guardrails(self.experiment)

        trigger = GuardrailTrigger.objects.filter(
            experiment=self.experiment,
        ).first()
        self.assertIsNotNone(trigger)
        self.assertEqual(trigger.metric_key, "gcp_error_rate")
        self.assertEqual(trigger.threshold, Decimal("0.05"))
        self.assertGreater(trigger.actual_value, Decimal("0.05"))
        self.assertEqual(trigger.action, GuardrailAction.PAUSE)
        self.assertIsNotNone(trigger.triggered_at)

    def test_no_trigger_for_non_running_experiment(self) -> None:
        # Seed traffic that breaches the guardrail while the experiment is
        # still running, so the non-running status is the only thing that
        # can prevent a trigger (without data the test would pass vacuously).
        self._seed_failing_traffic("nr", 5)
        self.experiment.status = ExperimentStatus.PAUSED
        self.experiment.save(update_fields=["status"])

        triggers = check_experiment_guardrails(self.experiment)

        self.assertEqual(len(triggers), 0)
        self.assertFalse(
            GuardrailTrigger.objects.filter(
                experiment=self.experiment,
            ).exists(),
        )
        self.experiment.refresh_from_db()
        self.assertEqual(self.experiment.status, ExperimentStatus.PAUSED)

    def test_no_trigger_when_no_data(self) -> None:
        triggers = check_experiment_guardrails(self.experiment)

        self.assertEqual(len(triggers), 0)

    def test_inactive_guardrail_skipped(self) -> None:
        Guardrail.objects.filter(experiment=self.experiment).update(
            is_active=False,
        )
        self._seed_failing_traffic("ia", 5)

        triggers = check_experiment_guardrails(self.experiment)

        self.assertEqual(len(triggers), 0)
        self.assertFalse(
            GuardrailTrigger.objects.filter(
                experiment=self.experiment,
            ).exists(),
        )
        self.experiment.refresh_from_db()
        self.assertEqual(self.experiment.status, ExperimentStatus.RUNNING)


class GuardrailCheckRollbackTest(_GuardrailTrafficMixin, TestCase):
    """A ROLLBACK guardrail with a 0.10 error-rate threshold over 60 minutes."""

    FLAG_KEY = "flag_grb"
    ERROR_EVENT = "rb_error"

    def setUp(self) -> None:
        _allow_single_approval()
        self.approver = make_approver("_grb")
        self.exposure_type = make_exposure_type()
        self.error_type = make_event_type(
            name=self.ERROR_EVENT,
            display_name="Error",
            requires_exposure=False,
        )
        self.experiment = make_experiment(suffix="_grb")
        self.v_control, self.v_treatment = add_two_variants(self.experiment)
        self.error_rate_metric = _create_error_rate_metric(
            "grb_error_rate",
            self.ERROR_EVENT,
        )
        guardrail_create(
            experiment=self.experiment,
            metric=self.error_rate_metric,
            threshold=Decimal("0.10"),
            observation_window_minutes=60,
            action=GuardrailAction.ROLLBACK,
        )
        self.experiment = _start_experiment(self.experiment, self.approver)
        self.now = timezone.now()

    def test_rollback_completes_experiment(self) -> None:
        self._seed_failing_traffic("rb", 5)

        triggers = check_experiment_guardrails(self.experiment)

        self.assertEqual(len(triggers), 1)
        self.experiment.refresh_from_db()
        self.assertEqual(self.experiment.status, ExperimentStatus.COMPLETED)
        self.assertEqual(triggers[0].action, GuardrailAction.ROLLBACK)

    def test_rollback_creates_outcome(self) -> None:
        self._seed_failing_traffic("rbo", 5)

        check_experiment_guardrails(self.experiment)

        outcome = ExperimentOutcome.objects.filter(
            experiment=self.experiment,
        ).first()
        self.assertIsNotNone(outcome)
        self.assertEqual(outcome.outcome, OutcomeType.ROLLBACK)
        # The rollback is expected to fall back to the control variant.
        self.assertEqual(outcome.winning_variant, self.v_control)
        # Automatic (guardrail-driven) decisions have no human decider.
        self.assertIsNone(outcome.decided_by)
        self.assertIn("guardrail", outcome.rationale.lower())

    def test_rollback_audit_log(self) -> None:
        self._seed_failing_traffic("rba", 5)

        check_experiment_guardrails(self.experiment)

        log = ExperimentLog.objects.filter(
            experiment=self.experiment,
            log_type=LogType.GUARDRAIL_TRIGGERED,
        ).first()
        self.assertIsNotNone(log)
        self.assertEqual(log.metadata["action"], GuardrailAction.ROLLBACK)
        self.assertEqual(
            log.metadata["to_status"],
            ExperimentStatus.COMPLETED,
        )


class CheckAllRunningTest(TestCase):
    """``check_all_running_experiments`` over several running experiments."""

    ERROR_EVENT = "all_error"

    def setUp(self) -> None:
        _allow_single_approval()
        self.approver = make_approver("_all")
        self.exposure_type = make_exposure_type()
        self.error_type = make_event_type(
            name=self.ERROR_EVENT,
            display_name="Error",
            requires_exposure=False,
        )
        self.metric = _create_error_rate_metric(
            "all_error_rate",
            self.ERROR_EVENT,
        )

    def test_check_all_running(self) -> None:
        # One running experiment with a guardrail, one without: both count
        # as checked.
        exp1 = make_experiment(suffix="_all1")
        add_two_variants(exp1)
        guardrail_create(
            experiment=exp1,
            metric=self.metric,
            threshold=Decimal("0.05"),
            action=GuardrailAction.PAUSE,
        )
        _start_experiment(exp1, self.approver)

        exp2 = make_experiment(suffix="_all2")
        add_two_variants(exp2)
        _start_experiment(exp2, self.approver)

        results = check_all_running_experiments()

        self.assertEqual(results["checked"], 2)

    def test_check_all_with_trigger(self) -> None:
        exp = make_experiment(suffix="_allt")
        _v_ctrl, v_treat = add_two_variants(exp)
        guardrail_create(
            experiment=exp,
            metric=self.metric,
            threshold=Decimal("0.05"),
            action=GuardrailAction.PAUSE,
        )
        exp = _start_experiment(exp, self.approver)
        now = timezone.now()

        # Every exposed treatment subject produces one error.
        for i in range(5):
            _expose_subject(
                exp,
                "flag_allt",
                f"dec_allt_{i}",
                f"u{i}",
                v_treat,
                now,
                exposure_event_id=f"exp_allt_{i}",
            )
            _send_event(
                f"err_allt_{i}",
                self.ERROR_EVENT,
                f"dec_allt_{i}",
                f"u{i}",
                now,
            )

        results = check_all_running_experiments()

        self.assertEqual(results["triggered"], 1)
        self.assertGreater(len(results["triggers"]), 0)