# Lotty/src/backend/apps/guardrails/tests/test_guardrails.py
# (repository viewer listing: 485 lines, 16 KiB, Python)

from decimal import Decimal
from django.test import TestCase
from django.utils import timezone
from apps.events.services import decision_create, process_events_batch
from apps.events.tests.helpers import make_event_type, make_exposure_type
from apps.experiments.models import (
ExperimentLog,
ExperimentOutcome,
ExperimentStatus,
LogType,
OutcomeType,
)
from apps.experiments.services import (
experiment_approve,
experiment_start,
experiment_submit_for_review,
)
from apps.experiments.tests.helpers import add_two_variants, make_experiment
from apps.guardrails.models import (
Guardrail,
GuardrailAction,
GuardrailTrigger,
)
from apps.guardrails.services import (
check_all_running_experiments,
check_experiment_guardrails,
guardrail_create,
)
from apps.metrics.models import MetricDirection, MetricType
from apps.metrics.services import metric_definition_create
from apps.reviews.services import review_settings_update
from apps.reviews.tests.helpers import make_approver
def _start_experiment(experiment, approver):
    """Submit, approve and start ``experiment``; return the started experiment.

    The experiment's owner submits it for review, ``approver`` approves the
    submitted experiment, and the owner then starts the approved one.

    Args:
        experiment: A freshly created experiment (e.g. from ``make_experiment``).
        approver: The user who approves the review.

    Returns:
        The value returned by ``experiment_start``. Tests in this module
        store it and expect its status to be RUNNING.
    """
    submitted = experiment_submit_for_review(
        experiment=experiment,
        user=experiment.owner,
    )
    approved = experiment_approve(experiment=submitted, approver=approver)
    return experiment_start(experiment=approved, user=experiment.owner)
class GuardrailCheckPauseTest(TestCase):
    """Guardrail checks for an experiment whose guardrail action is PAUSE.

    The guardrail watches the ``gcp_error_rate`` ratio metric
    (``error_occurred`` events / ``exposure`` events) with a 0.05 threshold
    over a 60-minute window. All seeded traffic goes to the treatment variant.
    """

    def setUp(self) -> None:
        review_settings_update(
            default_min_approvals=1,
            allow_any_approver=True,
        )
        self.approver = make_approver("_gp")
        self.exposure_type = make_exposure_type()
        self.error_type = make_event_type(
            name="error_occurred",
            display_name="Error",
            requires_exposure=False,
        )
        self.experiment = make_experiment(suffix="_gcp")
        self.v_control, self.v_treatment = add_two_variants(self.experiment)
        self.error_rate_metric = metric_definition_create(
            key="gcp_error_rate",
            name="Error Rate",
            metric_type=MetricType.RATIO,
            direction=MetricDirection.LOWER_IS_BETTER,
            calculation_rule={
                "type": "ratio",
                "numerator_event": "error_occurred",
                "denominator_event": "exposure",
            },
        )
        guardrail_create(
            experiment=self.experiment,
            metric=self.error_rate_metric,
            threshold=Decimal("0.05"),
            observation_window_minutes=60,
            action=GuardrailAction.PAUSE,
        )
        self.experiment = _start_experiment(self.experiment, self.approver)
        self.now = timezone.now()

    def _create_decision_and_exposure(self, decision_id, subject_id, variant):
        """Record a flag decision for ``subject_id`` and ingest its exposure."""
        decision_create(
            decision_id=decision_id,
            flag_key="flag_gcp",
            subject_id=subject_id,
            experiment_id=str(self.experiment.pk),
            variant_id=str(variant.pk),
            value=variant.value,
            reason="experiment",
        )
        process_events_batch(
            [
                {
                    "event_id": f"exp_{decision_id}",
                    "event_type": "exposure",
                    "decision_id": decision_id,
                    "subject_id": subject_id,
                    "timestamp": self.now.isoformat(),
                    "properties": {},
                }
            ]
        )

    def _send_error(self, event_id, decision_id, subject_id):
        """Ingest one ``error_occurred`` event tied to ``decision_id``."""
        process_events_batch(
            [
                {
                    "event_id": event_id,
                    "event_type": "error_occurred",
                    "decision_id": decision_id,
                    "subject_id": subject_id,
                    "timestamp": self.now.isoformat(),
                    "properties": {},
                }
            ]
        )

    def _seed_traffic(self, prefix, *, exposures, errors):
        """Expose ``exposures`` treatment subjects, then error the first ``errors``.

        Decision ids are ``dec_{prefix}_{i}``, subject ids ``u{i}`` and error
        event ids ``err_{prefix}_{i}``, so a per-test prefix keeps ids unique.
        The observed error rate is therefore ``errors / exposures``.
        """
        for i in range(exposures):
            self._create_decision_and_exposure(
                f"dec_{prefix}_{i}",
                f"u{i}",
                self.v_treatment,
            )
        for i in range(errors):
            self._send_error(f"err_{prefix}_{i}", f"dec_{prefix}_{i}", f"u{i}")

    def _assert_no_trigger_records(self) -> None:
        """Assert no ``GuardrailTrigger`` row exists for this experiment."""
        self.assertFalse(
            GuardrailTrigger.objects.filter(experiment=self.experiment).exists()
        )

    def test_no_trigger_when_below_threshold(self) -> None:
        # 1 error / 40 exposures = 0.025, clearly under the 0.05 threshold.
        self._seed_traffic("ok", exposures=40, errors=1)
        triggers = check_experiment_guardrails(self.experiment)
        self.assertEqual(len(triggers), 0)
        self.experiment.refresh_from_db()
        self.assertEqual(self.experiment.status, ExperimentStatus.RUNNING)

    def test_no_trigger_at_exact_threshold(self) -> None:
        # Boundary case: 1 error / 20 exposures == 0.05, exactly the
        # threshold. No trigger is expected, i.e. the guardrail fires only
        # when the observed value is strictly above the threshold.
        self._seed_traffic("eq", exposures=20, errors=1)
        triggers = check_experiment_guardrails(self.experiment)
        self.assertEqual(len(triggers), 0)
        self.experiment.refresh_from_db()
        self.assertEqual(self.experiment.status, ExperimentStatus.RUNNING)

    def test_trigger_pause_when_above_threshold(self) -> None:
        # Every exposed subject errors: rate 1.0 > 0.05.
        self._seed_traffic("err", exposures=10, errors=10)
        triggers = check_experiment_guardrails(self.experiment)
        self.assertEqual(len(triggers), 1)
        self.experiment.refresh_from_db()
        self.assertEqual(self.experiment.status, ExperimentStatus.PAUSED)
        self.assertEqual(triggers[0].action, GuardrailAction.PAUSE)
        self.assertEqual(triggers[0].metric_key, "gcp_error_rate")

    def test_trigger_audit_log_created(self) -> None:
        self._seed_traffic("al", exposures=5, errors=5)
        check_experiment_guardrails(self.experiment)
        log = ExperimentLog.objects.filter(
            experiment=self.experiment,
            log_type=LogType.GUARDRAIL_TRIGGERED,
        ).first()
        self.assertIsNotNone(log)
        self.assertIn("gcp_error_rate", log.comment)
        self.assertIn("threshold", log.metadata)
        self.assertIn("actual_value", log.metadata)

    def test_trigger_record_created(self) -> None:
        self._seed_traffic("tr", exposures=5, errors=5)
        check_experiment_guardrails(self.experiment)
        trigger = GuardrailTrigger.objects.filter(
            experiment=self.experiment,
        ).first()
        self.assertIsNotNone(trigger)
        self.assertEqual(trigger.metric_key, "gcp_error_rate")
        self.assertEqual(trigger.threshold, Decimal("0.05"))
        self.assertGreater(trigger.actual_value, Decimal("0.05"))
        self.assertEqual(trigger.action, GuardrailAction.PAUSE)
        self.assertIsNotNone(trigger.triggered_at)

    def test_no_trigger_for_non_running_experiment(self) -> None:
        # Seed traffic that would breach the guardrail (same 5/5 volume that
        # triggers in the tests above) while the experiment is still
        # running, then pause it. With data present, the only thing that can
        # prevent a trigger is the status check itself.
        self._seed_traffic("nr", exposures=5, errors=5)
        self.experiment.status = ExperimentStatus.PAUSED
        self.experiment.save(update_fields=["status"])
        triggers = check_experiment_guardrails(self.experiment)
        self.assertEqual(len(triggers), 0)
        self._assert_no_trigger_records()
        self.experiment.refresh_from_db()
        self.assertEqual(self.experiment.status, ExperimentStatus.PAUSED)

    def test_no_trigger_when_no_data(self) -> None:
        triggers = check_experiment_guardrails(self.experiment)
        self.assertEqual(len(triggers), 0)

    def test_inactive_guardrail_skipped(self) -> None:
        Guardrail.objects.filter(experiment=self.experiment).update(
            is_active=False,
        )
        # Breaching traffic: an active guardrail would fire here.
        self._seed_traffic("ia", exposures=5, errors=5)
        triggers = check_experiment_guardrails(self.experiment)
        self.assertEqual(len(triggers), 0)
        self._assert_no_trigger_records()
        self.experiment.refresh_from_db()
        self.assertEqual(self.experiment.status, ExperimentStatus.RUNNING)
class GuardrailCheckRollbackTest(TestCase):
    """Guardrail checks for an experiment whose guardrail action is ROLLBACK.

    The guardrail watches ``grb_error_rate`` (``rb_error`` events /
    ``exposure`` events) with a 0.10 threshold over a 60-minute window.
    Each test seeds 5 treatment exposures that all error, giving rate 1.0.
    """

    def setUp(self) -> None:
        review_settings_update(
            default_min_approvals=1,
            allow_any_approver=True,
        )
        self.approver = make_approver("_grb")
        self.exposure_type = make_exposure_type()
        self.error_type = make_event_type(
            name="rb_error",
            display_name="Error",
            requires_exposure=False,
        )
        self.experiment = make_experiment(suffix="_grb")
        self.v_control, self.v_treatment = add_two_variants(self.experiment)
        self.error_rate_metric = metric_definition_create(
            key="grb_error_rate",
            name="Error Rate",
            metric_type=MetricType.RATIO,
            direction=MetricDirection.LOWER_IS_BETTER,
            calculation_rule={
                "type": "ratio",
                "numerator_event": "rb_error",
                "denominator_event": "exposure",
            },
        )
        guardrail_create(
            experiment=self.experiment,
            metric=self.error_rate_metric,
            threshold=Decimal("0.10"),
            observation_window_minutes=60,
            action=GuardrailAction.ROLLBACK,
        )
        self.experiment = _start_experiment(self.experiment, self.approver)
        self.now = timezone.now()

    def _create_decision_and_exposure(self, decision_id, subject_id, variant):
        """Record a flag decision for ``subject_id`` and ingest its exposure."""
        decision_create(
            decision_id=decision_id,
            flag_key="flag_grb",
            subject_id=subject_id,
            experiment_id=str(self.experiment.pk),
            variant_id=str(variant.pk),
            value=variant.value,
            reason="experiment",
        )
        process_events_batch(
            [
                {
                    "event_id": f"exp_{decision_id}",
                    "event_type": "exposure",
                    "decision_id": decision_id,
                    "subject_id": subject_id,
                    "timestamp": self.now.isoformat(),
                    "properties": {},
                }
            ]
        )

    def _send_error(self, event_id, decision_id, subject_id):
        """Ingest one ``rb_error`` event tied to ``decision_id``."""
        process_events_batch(
            [
                {
                    "event_id": event_id,
                    "event_type": "rb_error",
                    "decision_id": decision_id,
                    "subject_id": subject_id,
                    "timestamp": self.now.isoformat(),
                    "properties": {},
                }
            ]
        )

    def _seed_failing_traffic(self, prefix, count=5):
        """Expose ``count`` treatment subjects, then send one error for each.

        Ids follow ``dec_{prefix}_{i}`` / ``err_{prefix}_{i}`` / ``u{i}``.
        The resulting error rate (1.0) is far above the 0.10 threshold.
        """
        for i in range(count):
            self._create_decision_and_exposure(
                f"dec_{prefix}_{i}",
                f"u{i}",
                self.v_treatment,
            )
        for i in range(count):
            self._send_error(f"err_{prefix}_{i}", f"dec_{prefix}_{i}", f"u{i}")

    def test_rollback_completes_experiment(self) -> None:
        self._seed_failing_traffic("rb")
        triggers = check_experiment_guardrails(self.experiment)
        self.assertEqual(len(triggers), 1)
        self.experiment.refresh_from_db()
        self.assertEqual(self.experiment.status, ExperimentStatus.COMPLETED)
        self.assertEqual(triggers[0].action, GuardrailAction.ROLLBACK)

    def test_rollback_creates_outcome(self) -> None:
        self._seed_failing_traffic("rbo")
        check_experiment_guardrails(self.experiment)
        outcome = ExperimentOutcome.objects.filter(
            experiment=self.experiment,
        ).first()
        self.assertIsNotNone(outcome)
        self.assertEqual(outcome.outcome, OutcomeType.ROLLBACK)
        # A rollback keeps the control variant, and no human decided it.
        self.assertEqual(outcome.winning_variant, self.v_control)
        self.assertIsNone(outcome.decided_by)
        self.assertIn("guardrail", outcome.rationale.lower())

    def test_rollback_audit_log(self) -> None:
        self._seed_failing_traffic("rba")
        check_experiment_guardrails(self.experiment)
        log = ExperimentLog.objects.filter(
            experiment=self.experiment,
            log_type=LogType.GUARDRAIL_TRIGGERED,
        ).first()
        self.assertIsNotNone(log)
        self.assertEqual(log.metadata["action"], GuardrailAction.ROLLBACK)
        self.assertEqual(
            log.metadata["to_status"],
            ExperimentStatus.COMPLETED,
        )
class CheckAllRunningTest(TestCase):
    """``check_all_running_experiments`` sweeping every running experiment.

    The shared metric ``all_error_rate`` is ``all_error`` events divided by
    ``exposure`` events. Each test attaches it as a PAUSE guardrail
    (threshold 0.05) where needed.
    """

    def setUp(self) -> None:
        review_settings_update(
            default_min_approvals=1,
            allow_any_approver=True,
        )
        self.approver = make_approver("_all")
        self.exposure_type = make_exposure_type()
        self.error_type = make_event_type(
            name="all_error",
            display_name="Error",
            requires_exposure=False,
        )
        self.metric = metric_definition_create(
            key="all_error_rate",
            name="Error Rate",
            metric_type=MetricType.RATIO,
            direction=MetricDirection.LOWER_IS_BETTER,
            calculation_rule={
                "type": "ratio",
                "numerator_event": "all_error",
                "denominator_event": "exposure",
            },
        )

    def test_check_all_running(self) -> None:
        # exp1 has a guardrail and exp2 has none; both are running.
        exp1 = make_experiment(suffix="_all1")
        add_two_variants(exp1)
        guardrail_create(
            experiment=exp1,
            metric=self.metric,
            threshold=Decimal("0.05"),
            action=GuardrailAction.PAUSE,
        )
        _start_experiment(exp1, self.approver)
        exp2 = make_experiment(suffix="_all2")
        add_two_variants(exp2)
        _start_experiment(exp2, self.approver)
        results = check_all_running_experiments()
        self.assertEqual(results["checked"], 2)
        # No traffic was recorded, so nothing can have fired.
        self.assertEqual(results["triggered"], 0)

    def test_check_all_with_trigger(self) -> None:
        exp = make_experiment(suffix="_allt")
        _v_ctrl, v_treat = add_two_variants(exp)
        guardrail_create(
            experiment=exp,
            metric=self.metric,
            threshold=Decimal("0.05"),
            action=GuardrailAction.PAUSE,
        )
        exp = _start_experiment(exp, self.approver)
        now = timezone.now()
        # Every exposed treatment subject also errors: rate 1.0 > 0.05.
        for i in range(5):
            decision_create(
                decision_id=f"dec_allt_{i}",
                flag_key="flag_allt",
                subject_id=f"u{i}",
                experiment_id=str(exp.pk),
                variant_id=str(v_treat.pk),
                value=v_treat.value,
                reason="experiment",
            )
            process_events_batch(
                [
                    {
                        "event_id": f"exp_allt_{i}",
                        "event_type": "exposure",
                        "decision_id": f"dec_allt_{i}",
                        "subject_id": f"u{i}",
                        "timestamp": now.isoformat(),
                        "properties": {},
                    }
                ]
            )
            process_events_batch(
                [
                    {
                        "event_id": f"err_allt_{i}",
                        "event_type": "all_error",
                        "decision_id": f"dec_allt_{i}",
                        "subject_id": f"u{i}",
                        "timestamp": now.isoformat(),
                        "properties": {},
                    }
                ]
            )
        results = check_all_running_experiments()
        self.assertEqual(results["triggered"], 1)
        self.assertGreater(len(results["triggers"]), 0)
        # The PAUSE action must actually be applied, not just reported.
        exp.refresh_from_db()
        self.assertEqual(exp.status, ExperimentStatus.PAUSED)