from datetime import timedelta
from decimal import Decimal
from typing import Any

from django.core.exceptions import ValidationError
from django.db import transaction
from django.db.models import Case, QuerySet, Value, When
from django.utils import timezone

from apps.experiments.models import (
    STARTED_STATUSES,
    Experiment,
    ExperimentLog,
    ExperimentOutcome,
    ExperimentStatus,
    LogType,
    OutcomeType,
)
from apps.guardrails.models import (
    Guardrail,
    GuardrailAction,
    GuardrailTrigger,
)
from apps.metrics.models import MetricDefinition, MetricDirection
from apps.notifications.services import (
    NotificationPayload,
    notification_enqueue,
)
from apps.reports.services import calculate_metric_value

# Fields that ``guardrail_update`` is allowed to change.
_UPDATABLE_GUARDRAIL_FIELDS = frozenset(
    {"threshold", "observation_window_minutes", "action", "is_active"}
)

# Actions that change the experiment's status. Once one of them fires,
# ``check_experiment_guardrails`` stops evaluating further guardrails.
_TERMINAL_GUARDRAIL_ACTIONS = frozenset(
    {GuardrailAction.PAUSE, GuardrailAction.ROLLBACK}
)


def _ensure_guardrails_editable(experiment: Experiment, verb: str) -> None:
    """Raise ``ValidationError`` if *experiment* has already been started.

    *verb* completes the message "Guardrails cannot be <verb> after the
    experiment has been started", so each caller keeps its own wording.
    """
    if experiment.status in STARTED_STATUSES:
        raise ValidationError(
            {
                "experiment": (
                    f"Guardrails cannot be {verb} after the experiment "
                    "has been started "
                    f"(status: '{experiment.status}')."
                )
            }
        )


@transaction.atomic
def guardrail_create(
    *,
    experiment: Experiment,
    metric: MetricDefinition,
    threshold: Decimal,
    observation_window_minutes: int = 60,
    action: str = GuardrailAction.PAUSE,
) -> Guardrail:
    """Create a guardrail on an experiment that has not been started yet.

    Args:
        experiment: Experiment the guardrail belongs to. Its ``status`` is
            re-read from the database before the check.
        metric: Metric the guardrail watches.
        threshold: Value the metric must cross in its "bad" direction
            (below for higher-is-better metrics, above otherwise) to trigger.
        observation_window_minutes: Length of the trailing window the
            metric is evaluated over.
        action: Action executed when the guardrail triggers.

    Returns:
        The saved ``Guardrail``.

    Raises:
        ValidationError: If the experiment's status is in
            ``STARTED_STATUSES``.
    """
    # Re-read status so a stale in-memory experiment cannot bypass the
    # check (guardrail_update / guardrail_delete below do the same).
    experiment.refresh_from_db(fields=["status"])
    _ensure_guardrails_editable(experiment, "added")

    guardrail = Guardrail(
        experiment=experiment,
        metric=metric,
        threshold=threshold,
        observation_window_minutes=observation_window_minutes,
        action=action,
    )
    guardrail.save()
    return guardrail


@transaction.atomic
def guardrail_update(
    *,
    guardrail: Guardrail,
    **fields: Any,
) -> Guardrail:
    """Update selected fields of a guardrail on a not-yet-started experiment.

    Only ``threshold``, ``observation_window_minutes``, ``action`` and
    ``is_active`` may be passed. Keyword arguments whose value is ``None``
    are skipped, i.e. treated as "leave unchanged".

    Raises:
        ValidationError: If the experiment has been started, or if any
            keyword is not an updatable field (every offending key is
            reported).
    """
    guardrail.experiment.refresh_from_db(fields=["status"])
    _ensure_guardrails_editable(guardrail.experiment, "modified")

    disallowed = [
        key for key in fields if key not in _UPDATABLE_GUARDRAIL_FIELDS
    ]
    if disallowed:
        raise ValidationError(
            {key: f"Field '{key}' cannot be updated." for key in disallowed}
        )

    for key, value in fields.items():
        if value is not None:
            setattr(guardrail, key, value)
    guardrail.save()
    return guardrail


@transaction.atomic
def guardrail_delete(*, guardrail: Guardrail) -> None:
    """Delete a guardrail whose experiment has not been started yet.

    Raises:
        ValidationError: If the experiment's (freshly re-read) status is in
            ``STARTED_STATUSES``.
    """
    guardrail.experiment.refresh_from_db(fields=["status"])
    _ensure_guardrails_editable(guardrail.experiment, "deleted")
    guardrail.delete()


def guardrail_list(
    experiment: Experiment,
) -> QuerySet[Guardrail]:
    """Return the experiment's guardrails with their metric pre-joined."""
    return Guardrail.objects.filter(experiment=experiment).select_related(
        "metric"
    )


def guardrail_trigger_list(
    experiment: Experiment,
) -> QuerySet[GuardrailTrigger]:
    """Return the experiment's guardrail triggers with guardrail and metric
    pre-joined."""
    return GuardrailTrigger.objects.filter(
        experiment=experiment
    ).select_related("guardrail", "guardrail__metric")


def _calculate_guardrail_metric(
    guardrail: Guardrail,
    experiment: Experiment,
) -> Decimal | None:
    """Return the worst per-variant value of the guardrail's metric.

    Each variant's value is computed over the trailing
    ``observation_window_minutes`` window ending now. Variants for which
    ``calculate_metric_value`` returns ``None`` are ignored. "Worst" is the
    minimum for higher-is-better metrics and the maximum otherwise.

    Returns:
        The worst value, or ``None`` when no variant produced a value.
    """
    now = timezone.now()
    window_start = now - timedelta(
        minutes=guardrail.observation_window_minutes,
    )

    per_variant = (
        calculate_metric_value(
            metric=guardrail.metric,
            experiment_id=experiment.pk,
            variant_id=variant.pk,
            start_date=window_start,
            end_date=now,
        )
        for variant in experiment.variants.all()
    )
    values = [value for value in per_variant if value is not None]
    if not values:
        return None

    if guardrail.metric.direction == MetricDirection.HIGHER_IS_BETTER:
        return min(values)
    return max(values)


def _is_threshold_breached(
    actual_value: Decimal,
    threshold: Decimal,
    direction: str,
) -> bool:
    """Return True if *actual_value* is on the bad side of *threshold*.

    For higher-is-better metrics that means strictly below the threshold;
    for any other direction, strictly above it.
    """
    if direction == MetricDirection.HIGHER_IS_BETTER:
        return actual_value < threshold
    return actual_value > threshold


@transaction.atomic
def _execute_guardrail_action(
    guardrail: Guardrail,
    experiment: Experiment,
    actual_value: Decimal,
) -> GuardrailTrigger | None:
    """Record a guardrail trigger and apply its action to the experiment.

    The experiment row is locked with ``select_for_update`` and re-read. If
    it is no longer running, nothing is recorded and ``None`` is returned.

    Otherwise a ``GuardrailTrigger`` is created and, depending on the
    guardrail's action:

    * PAUSE: the experiment is set to PAUSED and a log entry is written.
    * ROLLBACK: the experiment is set to COMPLETED, a ROLLBACK outcome naming
      the control variant (if any) is created, and a log entry is written.
    * any other action: only the trigger is recorded.

    A "guardrail_triggered" notification is enqueued in every case.

    Returns:
        The created trigger, or ``None`` if the experiment was not running.
    """
    now = timezone.now()
    # Lock and re-read: a concurrent check or an earlier guardrail may have
    # already changed this experiment's status.
    experiment = Experiment.objects.select_for_update().get(pk=experiment.pk)
    if experiment.status != ExperimentStatus.RUNNING:
        return None

    from_status = experiment.status
    metric_key = guardrail.metric.key
    trigger = GuardrailTrigger.objects.create(
        guardrail=guardrail,
        experiment=experiment,
        metric_key=metric_key,
        threshold=guardrail.threshold,
        actual_value=actual_value,
        observation_window_minutes=guardrail.observation_window_minutes,
        action=guardrail.action,
        triggered_at=now,
    )

    # Shared by the PAUSE and ROLLBACK log entries.
    log_summary = (
        f"Guardrail triggered: {metric_key} "
        f"= {actual_value} (threshold: {guardrail.threshold}). "
    )
    log_metadata = {
        "guardrail_id": str(guardrail.pk),
        "metric_key": metric_key,
        "threshold": str(guardrail.threshold),
        "actual_value": str(actual_value),
        "action": guardrail.action,
        "from_status": from_status,
    }

    if guardrail.action == GuardrailAction.PAUSE:
        experiment.status = ExperimentStatus.PAUSED
        experiment.save(update_fields=["status", "updated_at"])
        ExperimentLog.objects.create(
            experiment=experiment,
            log_type=LogType.GUARDRAIL_TRIGGERED,
            comment=log_summary + "Experiment paused.",
            metadata={
                **log_metadata,
                "to_status": ExperimentStatus.PAUSED,
            },
        )
    elif guardrail.action == GuardrailAction.ROLLBACK:
        control_variant = experiment.variants.filter(
            is_control=True,
        ).first()
        experiment.status = ExperimentStatus.COMPLETED
        experiment.save(update_fields=["status", "updated_at"])
        ExperimentOutcome.objects.create(
            experiment=experiment,
            outcome=OutcomeType.ROLLBACK,
            winning_variant=control_variant,
            # "breached" rather than "exceeded": for higher-is-better
            # metrics the trigger condition is falling below the threshold.
            rationale=(
                f"Automatic rollback: guardrail {metric_key} "
                f"= {actual_value} breached threshold {guardrail.threshold}."
            ),
            decided_by=None,
        )
        ExperimentLog.objects.create(
            experiment=experiment,
            log_type=LogType.GUARDRAIL_TRIGGERED,
            comment=log_summary + "Experiment rolled back to control.",
            metadata={
                **log_metadata,
                "to_status": ExperimentStatus.COMPLETED,
                "control_variant_id": (
                    str(control_variant.pk) if control_variant else None
                ),
            },
        )

    # NOTE(review): this runs inside the atomic block. If
    # notification_enqueue talks to an external broker (not visible here),
    # consider transaction.on_commit so a rolled-back trigger does not
    # notify. If it writes to the DB (outbox style), the current placement
    # is correct. Confirm before changing.
    notification_enqueue(
        "guardrail_triggered",
        NotificationPayload(
            title="Guardrail Triggered",
            body=(
                f"Guardrail triggered on experiment '{experiment.name}': "
                f"{metric_key} = {actual_value} "
                f"(threshold: {guardrail.threshold}). "
                f"Action: {guardrail.action}."
            ),
            event_type="guardrail_triggered",
            experiment_id=str(experiment.pk),
            experiment_name=experiment.name,
            extra={
                "metric_key": metric_key,
                "threshold": str(guardrail.threshold),
                "actual_value": str(actual_value),
                "action": guardrail.action,
            },
        ),
    )

    return trigger


def check_experiment_guardrails(
    experiment: Experiment,
) -> list[GuardrailTrigger]:
    """Evaluate a running experiment's active guardrails and act on breaches.

    Ordering and stopping rules:

    * Guardrails are evaluated ROLLBACK first, then PAUSE, then any other
      action; within the same action, newest (``-created_at``) first.
    * Guardrails whose metric has no value in the window are skipped.
    * Evaluation stops after a PAUSE or ROLLBACK fires, or as soon as
      ``_execute_guardrail_action`` reports the experiment is no longer
      running.

    Returns:
        The triggers created, in execution order. The list is empty if the
        experiment is not running.
    """
    if experiment.status != ExperimentStatus.RUNNING:
        return []

    guardrails = (
        Guardrail.objects.filter(
            experiment=experiment,
            is_active=True,
        )
        .select_related("metric")
        .annotate(
            action_order=Case(
                When(action=GuardrailAction.ROLLBACK, then=Value(0)),
                When(action=GuardrailAction.PAUSE, then=Value(1)),
                default=Value(2),
            ),
        )
        .order_by("action_order", "-created_at")
    )

    triggers: list[GuardrailTrigger] = []
    for guardrail in guardrails:
        actual_value = _calculate_guardrail_metric(guardrail, experiment)
        if actual_value is None:
            continue
        if not _is_threshold_breached(
            actual_value,
            guardrail.threshold,
            guardrail.metric.direction,
        ):
            continue

        trigger = _execute_guardrail_action(
            guardrail,
            experiment,
            actual_value,
        )
        if trigger is None:
            # The experiment left RUNNING under us; nothing more to do.
            break
        triggers.append(trigger)
        if guardrail.action in _TERMINAL_GUARDRAIL_ACTIONS:
            break

    return triggers


def check_all_running_experiments() -> dict[str, Any]:
    """Run ``check_experiment_guardrails`` on every running experiment.

    Returns:
        A summary dict with these keys:

        * ``"checked"``: number of experiments evaluated.
        * ``"triggered"``: number of experiments with at least one trigger.
        * ``"triggers"``: one dict per trigger, with experiment id/name,
          metric key, threshold, actual value (stringified) and action.
    """
    # Only ``variants`` is prefetched. check_experiment_guardrails re-queries
    # guardrails with its own filter/annotate, which would bypass a
    # ``guardrails__metric`` prefetch cache anyway.
    running = (
        Experiment.objects.filter(status=ExperimentStatus.RUNNING)
        .select_related("flag")
        .prefetch_related("variants")
    )

    results: dict[str, Any] = {
        "checked": 0,
        "triggered": 0,
        "triggers": [],
    }

    for experiment in running:
        results["checked"] += 1
        triggers = check_experiment_guardrails(experiment)
        if not triggers:
            continue
        results["triggered"] += 1
        results["triggers"].extend(
            {
                "experiment_id": str(experiment.pk),
                "experiment_name": experiment.name,
                "metric_key": t.metric_key,
                "threshold": str(t.threshold),
                "actual_value": str(t.actual_value),
                "action": t.action,
            }
            for t in triggers
        )

    return results