"""Service functions for experiment guardrails.

Covers guardrail CRUD, listing of guardrails and their triggers, and the
periodic evaluation of active guardrails against running experiments.
"""

from datetime import timedelta
from decimal import Decimal
from typing import Any

from django.core.exceptions import ValidationError
from django.db import transaction
from django.db.models import QuerySet
from django.utils import timezone

from apps.experiments.models import (
    STARTED_STATUSES,
    Experiment,
    ExperimentLog,
    ExperimentOutcome,
    ExperimentStatus,
    LogType,
    OutcomeType,
)
from apps.guardrails.models import (
    Guardrail,
    GuardrailAction,
    GuardrailTrigger,
)
from apps.metrics.models import MetricDefinition
from apps.notifications.services import (
    NotificationPayload,
    notification_enqueue,
)
from apps.reports.services import calculate_metric_value

# Actions that take the experiment out of RUNNING; once one of them has been
# executed there is no point in evaluating the remaining guardrails.
_TERMINAL_ACTIONS = frozenset(
    {
        GuardrailAction.PAUSE,
        GuardrailAction.ROLLBACK,
    }
)

# Fields that ``guardrail_update`` is allowed to change.
_UPDATABLE_FIELDS = frozenset(
    {
        "threshold",
        "observation_window_minutes",
        "action",
        "is_active",
    }
)


@transaction.atomic
def guardrail_create(
    *,
    experiment: Experiment,
    metric: MetricDefinition,
    threshold: Any,
    observation_window_minutes: int = 60,
    action: str = GuardrailAction.PAUSE,
) -> Guardrail:
    """Create and persist a guardrail for ``experiment``.

    Args:
        experiment: Experiment the guardrail protects.
        metric: Metric whose value is compared against ``threshold``.
        threshold: Value the metric must not exceed.
        observation_window_minutes: Length of the look-back window used when
            the metric is evaluated.
        action: Action to take when the guardrail is breached.

    Returns:
        The saved ``Guardrail`` instance.

    Note:
        Unlike ``guardrail_update`` / ``guardrail_delete``, this function
        does not check the experiment's status.
    """
    guardrail = Guardrail(
        experiment=experiment,
        metric=metric,
        threshold=threshold,
        observation_window_minutes=observation_window_minutes,
        action=action,
    )
    guardrail.save()
    return guardrail


def guardrail_update(
    *,
    guardrail: Guardrail,
    **fields: Any,
) -> Guardrail:
    """Update the editable fields of ``guardrail`` and save it.

    Args:
        guardrail: Guardrail to modify.
        **fields: New values keyed by field name. Only ``threshold``,
            ``observation_window_minutes``, ``action`` and ``is_active`` are
            accepted. A value of ``None`` leaves that field unchanged.

    Returns:
        The saved ``Guardrail`` instance.

    Raises:
        ValidationError: If the experiment's status is in
            ``STARTED_STATUSES``, or if any key in ``fields`` is not an
            updatable field.
    """
    if guardrail.experiment.status in STARTED_STATUSES:
        raise ValidationError(
            {
                "experiment": (
                    "Guardrails cannot be modified after the experiment "
                    "has been started "
                    f"(status: '{guardrail.experiment.status}')."
                )
            }
        )

    # Validate every key before touching the instance so a rejected call
    # never leaves the in-memory object partially modified.
    for key in fields:
        if key not in _UPDATABLE_FIELDS:
            raise ValidationError({key: f"Field '{key}' cannot be updated."})

    for key, value in fields.items():
        if value is not None:
            setattr(guardrail, key, value)

    guardrail.save()
    return guardrail


def guardrail_delete(*, guardrail: Guardrail) -> None:
    """Delete ``guardrail``.

    Raises:
        ValidationError: If the experiment's status is in
            ``STARTED_STATUSES``.
    """
    if guardrail.experiment.status in STARTED_STATUSES:
        raise ValidationError(
            {
                "experiment": (
                    "Guardrails cannot be deleted after the experiment "
                    "has been started "
                    f"(status: '{guardrail.experiment.status}')."
                )
            }
        )
    guardrail.delete()


def guardrail_list(
    experiment: Experiment,
) -> QuerySet[Guardrail]:
    """Return the guardrails of ``experiment`` with their metric joined."""
    return Guardrail.objects.filter(experiment=experiment).select_related(
        "metric"
    )


def guardrail_trigger_list(
    experiment: Experiment,
) -> QuerySet[GuardrailTrigger]:
    """Return the guardrail triggers recorded for ``experiment``.

    The guardrail and its metric are joined via ``select_related``.
    """
    return GuardrailTrigger.objects.filter(
        experiment=experiment
    ).select_related("guardrail", "guardrail__metric")


def _calculate_guardrail_metric(
    guardrail: Guardrail,
    experiment: Experiment,
) -> Decimal | None:
    """Return the highest per-variant metric value in the observation window.

    The window ends now and spans ``guardrail.observation_window_minutes``.
    The metric is computed separately for every variant of the experiment;
    variants for which ``calculate_metric_value`` returns ``None`` are
    ignored.

    Returns:
        The maximum value across variants, or ``None`` when no variant
        produced a value.
    """
    now = timezone.now()
    window_start = now - timedelta(
        minutes=guardrail.observation_window_minutes,
    )

    values = [
        value
        for variant in experiment.variants.all()
        if (
            value := calculate_metric_value(
                metric=guardrail.metric,
                experiment_id=experiment.pk,
                variant_id=variant.pk,
                start_date=window_start,
                end_date=now,
            )
        )
        is not None
    ]
    return max(values, default=None)


def _guardrail_log_metadata(
    guardrail: Guardrail,
    actual_value: Decimal,
    *,
    from_status: str,
    to_status: str,
) -> dict[str, Any]:
    """Build the metadata shared by every guardrail-triggered log entry."""
    return {
        "guardrail_id": str(guardrail.pk),
        "metric_key": guardrail.metric.key,
        "threshold": str(guardrail.threshold),
        "actual_value": str(actual_value),
        "action": guardrail.action,
        "from_status": from_status,
        "to_status": to_status,
    }


@transaction.atomic
def _execute_guardrail_action(
    guardrail: Guardrail,
    experiment: Experiment,
    actual_value: Decimal,
) -> GuardrailTrigger:
    """Record a guardrail breach and apply the guardrail's action.

    Always creates a ``GuardrailTrigger`` and enqueues a
    ``guardrail_triggered`` notification. In addition:

    * ``PAUSE``: sets the experiment to ``PAUSED`` and writes a log entry.
    * ``ROLLBACK``: sets the experiment to ``COMPLETED``, records a
      ``ROLLBACK`` outcome naming the control variant (if one exists) as the
      winner, and writes a log entry.
    * Any other action: no status change and no log entry.

    ``experiment`` is mutated in place when its status changes.

    Args:
        guardrail: The breached guardrail.
        experiment: Experiment the guardrail belongs to.
        actual_value: Metric value that exceeded the threshold.

    Returns:
        The newly created ``GuardrailTrigger``.
    """
    now = timezone.now()
    metric_key = guardrail.metric.key
    # Capture the status before it is mutated so the log reflects the real
    # transition instead of assuming the experiment was RUNNING. The only
    # caller in this module invokes this while the experiment is RUNNING.
    previous_status = experiment.status

    trigger = GuardrailTrigger.objects.create(
        guardrail=guardrail,
        experiment=experiment,
        metric_key=metric_key,
        threshold=guardrail.threshold,
        actual_value=actual_value,
        observation_window_minutes=guardrail.observation_window_minutes,
        action=guardrail.action,
        triggered_at=now,
    )

    # Common prefix of the log comment for both status-changing actions.
    summary = (
        f"Guardrail triggered: {metric_key} "
        f"= {actual_value} (threshold: {guardrail.threshold}). "
    )

    if guardrail.action == GuardrailAction.PAUSE:
        experiment.status = ExperimentStatus.PAUSED
        # NOTE(review): "updated_at" is listed explicitly, presumably so an
        # auto-updated timestamp is persisted with update_fields - confirm.
        experiment.save(update_fields=["status", "updated_at"])

        ExperimentLog.objects.create(
            experiment=experiment,
            log_type=LogType.GUARDRAIL_TRIGGERED,
            comment=summary + "Experiment paused.",
            metadata=_guardrail_log_metadata(
                guardrail,
                actual_value,
                from_status=previous_status,
                to_status=ExperimentStatus.PAUSED,
            ),
        )
    elif guardrail.action == GuardrailAction.ROLLBACK:
        # May be None when the experiment has no variant flagged as control.
        control_variant = experiment.variants.filter(
            is_control=True,
        ).first()

        experiment.status = ExperimentStatus.COMPLETED
        experiment.save(update_fields=["status", "updated_at"])

        ExperimentOutcome.objects.create(
            experiment=experiment,
            outcome=OutcomeType.ROLLBACK,
            winning_variant=control_variant,
            rationale=(
                f"Automatic rollback: guardrail {metric_key} "
                f"= {actual_value} exceeded threshold {guardrail.threshold}."
            ),
            decided_by=None,
        )

        metadata = _guardrail_log_metadata(
            guardrail,
            actual_value,
            from_status=previous_status,
            to_status=ExperimentStatus.COMPLETED,
        )
        metadata["control_variant_id"] = (
            str(control_variant.pk) if control_variant else None
        )
        ExperimentLog.objects.create(
            experiment=experiment,
            log_type=LogType.GUARDRAIL_TRIGGERED,
            comment=summary + "Experiment rolled back to control.",
            metadata=metadata,
        )

    # NOTE(review): this is enqueued inside the atomic block. If
    # notification_enqueue dispatches immediately rather than writing to the
    # database, a later rollback would not undo it - consider
    # transaction.on_commit after confirming how the queue is implemented.
    notification_enqueue(
        "guardrail_triggered",
        NotificationPayload(
            title="Guardrail Triggered",
            body=(
                f"Guardrail triggered on experiment '{experiment.name}': "
                f"{metric_key} = {actual_value} "
                f"(threshold: {guardrail.threshold}). "
                f"Action: {guardrail.action}."
            ),
            event_type="guardrail_triggered",
            experiment_id=str(experiment.pk),
            experiment_name=experiment.name,
            extra={
                "metric_key": metric_key,
                "threshold": str(guardrail.threshold),
                "actual_value": str(actual_value),
                "action": guardrail.action,
            },
        ),
    )

    return trigger


def check_experiment_guardrails(
    experiment: Experiment,
) -> list[GuardrailTrigger]:
    """Evaluate the active guardrails of a running experiment.

    A guardrail is breached when its metric value is strictly greater than
    its threshold. Guardrails whose metric cannot be computed are skipped.
    Evaluation stops as soon as a ``PAUSE`` or ``ROLLBACK`` action has been
    executed, or when the experiment is found to be no longer running.

    Args:
        experiment: Experiment to check. It is refreshed from the database
            before each action and may have its status changed in place.

    Returns:
        The triggers created during this check; empty when the experiment is
        not ``RUNNING`` or nothing was breached.
    """
    if experiment.status != ExperimentStatus.RUNNING:
        return []

    guardrails = Guardrail.objects.filter(
        experiment=experiment,
        is_active=True,
    ).select_related("metric")

    triggers: list[GuardrailTrigger] = []
    for guardrail in guardrails:
        actual_value = _calculate_guardrail_metric(guardrail, experiment)
        if actual_value is None:
            continue

        breached = actual_value > guardrail.threshold
        if not breached:
            continue

        # Re-read the status right before acting: it may have been changed
        # elsewhere since this check started.
        experiment.refresh_from_db()
        if experiment.status != ExperimentStatus.RUNNING:
            break

        triggers.append(
            _execute_guardrail_action(guardrail, experiment, actual_value)
        )

        if guardrail.action in _TERMINAL_ACTIONS:
            break

    return triggers


def check_all_running_experiments() -> dict[str, Any]:
    """Run the guardrail check for every experiment in ``RUNNING`` status.

    Returns:
        A summary dict with:

        * ``"checked"``: number of experiments examined.
        * ``"triggered"``: number of experiments with at least one trigger.
        * ``"triggers"``: one dict per trigger with ``experiment_id``,
          ``experiment_name``, ``metric_key``, ``threshold``,
          ``actual_value`` and ``action`` (numeric values as strings).
    """
    running = (
        Experiment.objects.filter(status=ExperimentStatus.RUNNING)
        .select_related("flag")
        .prefetch_related("variants", "guardrails__metric")
    )

    results: dict[str, Any] = {
        "checked": 0,
        "triggered": 0,
        "triggers": [],
    }

    for experiment in running:
        results["checked"] += 1

        triggers = check_experiment_guardrails(experiment)
        if not triggers:
            continue

        results["triggered"] += 1
        results["triggers"].extend(
            {
                "experiment_id": str(experiment.pk),
                "experiment_name": experiment.name,
                "metric_key": trigger.metric_key,
                "threshold": str(trigger.threshold),
                "actual_value": str(trigger.actual_value),
                "action": trigger.action,
            }
            for trigger in triggers
        )

    return results