337 lines
9.7 KiB
Python
337 lines
9.7 KiB
Python
from datetime import timedelta
|
|
from decimal import Decimal
|
|
from typing import Any
|
|
|
|
from django.core.cache import cache
|
|
from django.core.exceptions import ValidationError
|
|
from django.db import transaction
|
|
from django.db.models import Case, QuerySet, Value, When
|
|
from django.utils import timezone
|
|
|
|
from apps.experiments.models import (
|
|
Experiment,
|
|
ExperimentLog,
|
|
ExperimentOutcome,
|
|
ExperimentStatus,
|
|
LogType,
|
|
OutcomeType,
|
|
)
|
|
from apps.guardrails.models import (
|
|
METRIC_DECIMAL_PLACES,
|
|
Guardrail,
|
|
GuardrailAction,
|
|
GuardrailTrigger,
|
|
)
|
|
from apps.metrics.models import MetricDefinition, MetricDirection
|
|
from apps.notifications.services import (
|
|
NotificationPayload,
|
|
notification_enqueue,
|
|
)
|
|
from apps.reports.services import calculate_metric_value
|
|
|
|
|
|
@transaction.atomic
def guardrail_create(
    *,
    experiment: Experiment,
    metric: MetricDefinition,
    threshold: Decimal,
    observation_window_minutes: int = 60,
    action: str = GuardrailAction.PAUSE,
) -> Guardrail:
    """Create and persist a guardrail for an experiment.

    The save runs inside a database transaction.

    Args:
        experiment: Experiment the guardrail is attached to.
        metric: Metric definition the guardrail watches.
        threshold: Threshold value stored on the guardrail.
        observation_window_minutes: Length of the observation window, in
            minutes.
        action: Action stored on the guardrail; defaults to
            ``GuardrailAction.PAUSE``.

    Returns:
        The saved ``Guardrail`` instance.
    """
    attributes = {
        "experiment": experiment,
        "metric": metric,
        "threshold": threshold,
        "observation_window_minutes": observation_window_minutes,
        "action": action,
    }
    new_guardrail = Guardrail(**attributes)
    new_guardrail.save()
    return new_guardrail
|
|
|
|
|
|
def guardrail_update(
    *,
    guardrail: Guardrail,
    **fields: Any,
) -> Guardrail:
    """Apply a partial update to a guardrail and save it.

    Only ``threshold``, ``observation_window_minutes``, ``action`` and
    ``is_active`` may be passed. Keyword arguments whose value is ``None``
    are ignored, so a field cannot be set to ``None`` through this function.

    Args:
        guardrail: Guardrail to modify.
        **fields: New values keyed by field name.

    Returns:
        The same ``guardrail`` instance after saving.

    Raises:
        ValidationError: If a keyword is not one of the updatable fields;
            the error is keyed by the first offending name.
    """
    # NOTE(review): the refreshed status is not inspected in this function --
    # confirm whether a status guard is expected here.
    guardrail.experiment.refresh_from_db(fields=["status"])

    updatable = frozenset(
        ("threshold", "observation_window_minutes", "action", "is_active")
    )

    # Validate every name before touching the instance.
    rejected = next((name for name in fields if name not in updatable), None)
    if rejected is not None:
        raise ValidationError({rejected: f"Field '{rejected}' cannot be updated."})

    for name, new_value in fields.items():
        if new_value is None:
            continue
        setattr(guardrail, name, new_value)

    guardrail.save()
    return guardrail
|
|
|
|
|
|
def guardrail_delete(*, guardrail: Guardrail) -> None:
    """Delete a guardrail.

    Args:
        guardrail: Guardrail to delete.
    """
    # NOTE(review): the refreshed status is not inspected before deleting --
    # confirm whether a status guard is expected here.
    owning_experiment = guardrail.experiment
    owning_experiment.refresh_from_db(fields=["status"])
    guardrail.delete()
|
|
|
|
|
|
def guardrail_list(
    experiment: Experiment,
) -> QuerySet[Guardrail]:
    """Return the guardrails of ``experiment`` with ``metric`` joined in.

    Args:
        experiment: Experiment whose guardrails are listed.

    Returns:
        A lazy queryset of ``Guardrail`` rows using ``select_related("metric")``.
    """
    for_experiment = Guardrail.objects.filter(experiment=experiment)
    return for_experiment.select_related("metric")
|
|
|
|
|
|
def guardrail_trigger_list(
    experiment: Experiment,
) -> QuerySet[GuardrailTrigger]:
    """Return the guardrail triggers recorded for ``experiment``.

    Args:
        experiment: Experiment whose triggers are listed.

    Returns:
        A lazy queryset of ``GuardrailTrigger`` rows with the guardrail and
        the guardrail's metric joined in via ``select_related``.
    """
    for_experiment = GuardrailTrigger.objects.filter(experiment=experiment)
    return for_experiment.select_related("guardrail", "guardrail__metric")
|
|
|
|
|
|
def _calculate_guardrail_metric(
    guardrail: Guardrail,
    experiment: Experiment,
) -> Decimal | None:
    """Return the worst per-variant metric value inside the guardrail window.

    The window ends now and spans ``guardrail.observation_window_minutes``
    minutes. The metric is computed once per variant of ``experiment``;
    variants for which ``calculate_metric_value`` yields ``None`` are skipped.

    Args:
        guardrail: Guardrail supplying the metric and the window length.
        experiment: Experiment whose variants are evaluated.

    Returns:
        The minimum value when the metric direction is
        ``MetricDirection.HIGHER_IS_BETTER``, otherwise the maximum; ``None``
        when no variant produced a value.
    """
    window_end = timezone.now()
    window_start = window_end - timedelta(
        minutes=guardrail.observation_window_minutes,
    )

    per_variant = (
        calculate_metric_value(
            metric=guardrail.metric,
            experiment_id=experiment.pk,
            variant_id=variant.pk,
            event_start_date=window_start,
            event_end_date=window_end,
        )
        for variant in experiment.variants.all()
    )
    observed: list[Decimal] = [
        value for value in per_variant if value is not None
    ]
    if not observed:
        return None

    # "Worst" is the lowest value when higher is better, else the highest.
    higher_is_better = (
        guardrail.metric.direction == MetricDirection.HIGHER_IS_BETTER
    )
    pick_worst = min if higher_is_better else max
    return pick_worst(observed)
|
|
|
|
|
|
def _is_threshold_breached(
    actual_value: Decimal,
    threshold: Decimal,
    direction: str,
) -> bool:
    """Tell whether ``actual_value`` is on the wrong side of ``threshold``.

    Args:
        actual_value: Observed metric value.
        threshold: Guardrail threshold to compare against.
        direction: Metric direction; compared to
            ``MetricDirection.HIGHER_IS_BETTER``.

    Returns:
        For higher-is-better metrics, ``True`` when the value is strictly
        below the threshold; for any other direction, ``True`` when it is
        strictly above. A value equal to the threshold is never a breach.
    """
    higher_is_better = direction == MetricDirection.HIGHER_IS_BETTER
    return (
        actual_value < threshold
        if higher_is_better
        else actual_value > threshold
    )
|
|
|
|
|
|
@transaction.atomic
def _execute_guardrail_action(
    guardrail: Guardrail,
    experiment: Experiment,
    actual_value: Decimal,
) -> GuardrailTrigger | None:
    """Record a guardrail breach and apply the guardrail's action.

    Runs in a database transaction. The experiment row is re-read with
    ``select_for_update`` so the status check and the status change happen
    under a row lock.

    Behaviour by ``guardrail.action``:

    * ``GuardrailAction.PAUSE``: the experiment is set to
      ``ExperimentStatus.PAUSED`` and an ``ExperimentLog`` entry is written.
    * ``GuardrailAction.ROLLBACK``: the experiment is set to
      ``ExperimentStatus.COMPLETED``, an ``ExperimentOutcome`` with the
      control variant as winner is created (the winner is ``None`` when no
      variant has ``is_control=True``), and an ``ExperimentLog`` entry is
      written.
    * Any other action: only the trigger row and the notification are
      produced.

    In every case a ``GuardrailTrigger`` row is created and a
    ``guardrail_triggered`` notification is enqueued.

    The ``active_exp:<flag_id>`` cache entry is deleted only after the
    transaction commits. Deleting it while the transaction is still open
    would let a concurrent reader repopulate the cache from the old,
    still-visible row, and would also drop the entry when the transaction
    ends up rolling back.

    Args:
        guardrail: Guardrail whose threshold was breached.
        experiment: Experiment to act on; only its primary key is used to
            reload a locked copy.
        actual_value: Observed metric value that breached the threshold.

    Returns:
        The created ``GuardrailTrigger``, or ``None`` when the locked
        experiment is no longer ``ExperimentStatus.RUNNING`` (nothing is
        written in that case).
    """
    now = timezone.now()

    # Work on a freshly locked copy; the caller's instance may be stale.
    experiment = Experiment.objects.select_for_update().get(pk=experiment.pk)
    if experiment.status != ExperimentStatus.RUNNING:
        return None

    from_status = experiment.status
    metric_key = guardrail.metric.key

    trigger = GuardrailTrigger.objects.create(
        guardrail=guardrail,
        experiment=experiment,
        metric_key=metric_key,
        threshold=guardrail.threshold,
        actual_value=round(actual_value, METRIC_DECIMAL_PLACES),
        observation_window_minutes=guardrail.observation_window_minutes,
        action=guardrail.action,
        triggered_at=now,
    )

    # Text and metadata shared by the PAUSE and ROLLBACK log entries.
    breach_summary = (
        f"Guardrail triggered: {metric_key} "
        f"= {actual_value} (threshold: {guardrail.threshold}). "
    )
    base_metadata = {
        "guardrail_id": str(guardrail.pk),
        "metric_key": metric_key,
        "threshold": str(guardrail.threshold),
        "actual_value": str(actual_value),
        "action": guardrail.action,
        "from_status": from_status,
    }
    cache_key = f"active_exp:{experiment.flag_id}"

    if guardrail.action == GuardrailAction.PAUSE:
        experiment.status = ExperimentStatus.PAUSED
        experiment.save(update_fields=["status", "updated_at"])
        # Invalidate only once the new status is visible to other readers.
        transaction.on_commit(lambda: cache.delete(cache_key))

        ExperimentLog.objects.create(
            experiment=experiment,
            log_type=LogType.GUARDRAIL_TRIGGERED,
            comment=breach_summary + "Experiment paused.",
            metadata={
                **base_metadata,
                "to_status": ExperimentStatus.PAUSED,
            },
        )

    elif guardrail.action == GuardrailAction.ROLLBACK:
        control_variant = experiment.variants.filter(
            is_control=True,
        ).first()

        experiment.status = ExperimentStatus.COMPLETED
        experiment.save(update_fields=["status", "updated_at"])
        # Invalidate only once the new status is visible to other readers.
        transaction.on_commit(lambda: cache.delete(cache_key))

        ExperimentOutcome.objects.create(
            experiment=experiment,
            outcome=OutcomeType.ROLLBACK,
            winning_variant=control_variant,
            rationale=(
                f"Automatic rollback: guardrail {metric_key} "
                f"= {actual_value} exceeded threshold {guardrail.threshold}."
            ),
            decided_by=None,
        )

        ExperimentLog.objects.create(
            experiment=experiment,
            log_type=LogType.GUARDRAIL_TRIGGERED,
            comment=breach_summary + "Experiment rolled back to control.",
            metadata={
                **base_metadata,
                "to_status": ExperimentStatus.COMPLETED,
                "control_variant_id": (
                    str(control_variant.pk) if control_variant else None
                ),
            },
        )

    # NOTE(review): enqueued inside the transaction, as before. If
    # notification_enqueue dispatches outside the database, it may also
    # belong in an on_commit hook -- confirm against its implementation.
    notification_enqueue(
        "guardrail_triggered",
        NotificationPayload(
            title="Guardrail Triggered",
            body=(
                f"Guardrail triggered on experiment '{experiment.name}': "
                f"{metric_key} = {actual_value} "
                f"(threshold: {guardrail.threshold}). "
                f"Action: {guardrail.action}."
            ),
            event_type="guardrail_triggered",
            experiment_id=str(experiment.pk),
            experiment_name=experiment.name,
            extra={
                "metric_key": metric_key,
                "threshold": str(guardrail.threshold),
                "actual_value": str(actual_value),
                "action": guardrail.action,
            },
        ),
    )

    return trigger
|
|
|
|
|
|
def check_experiment_guardrails(
    experiment: Experiment,
) -> list[GuardrailTrigger]:
    """Evaluate the active guardrails of a running experiment.

    Guardrails are visited with ROLLBACK actions first, then PAUSE, then any
    other action, newest first within each group. For each guardrail the
    metric is computed, and the guardrail's action is executed when the
    threshold is breached. Evaluation stops as soon as a PAUSE or ROLLBACK
    guardrail fires, or when executing an action reports that the experiment
    is no longer running.

    Args:
        experiment: Experiment to check.

    Returns:
        The triggers created during this pass; empty when the experiment is
        not ``ExperimentStatus.RUNNING`` or nothing was breached.
    """
    fired: list[GuardrailTrigger] = []
    if experiment.status != ExperimentStatus.RUNNING:
        return fired

    # Lower number = evaluated earlier.
    priority = Case(
        When(action=GuardrailAction.ROLLBACK, then=Value(0)),
        When(action=GuardrailAction.PAUSE, then=Value(1)),
        default=Value(2),
    )
    active_guardrails = (
        Guardrail.objects.filter(experiment=experiment, is_active=True)
        .select_related("metric")
        .annotate(action_order=priority)
        .order_by("action_order", "-created_at")
    )

    # Actions after which no further guardrail is evaluated.
    terminal_actions = {GuardrailAction.PAUSE, GuardrailAction.ROLLBACK}

    for rail in active_guardrails:
        observed = _calculate_guardrail_metric(rail, experiment)
        if observed is None:
            continue

        breached = _is_threshold_breached(
            observed,
            rail.threshold,
            rail.metric.direction,
        )
        if not breached:
            continue

        new_trigger = _execute_guardrail_action(rail, experiment, observed)
        if new_trigger is None:
            # The experiment stopped running in the meantime.
            break

        fired.append(new_trigger)
        if rail.action in terminal_actions:
            break

    return fired
|
|
|
|
|
|
def check_all_running_experiments() -> dict[str, Any]:
    """Run the guardrail check for every running experiment.

    Returns:
        A summary dict with:

        * ``"checked"``: number of running experiments visited.
        * ``"triggered"``: number of experiments with at least one trigger.
        * ``"triggers"``: one dict per created trigger, holding
          ``experiment_id``, ``experiment_name``, ``metric_key``,
          ``threshold``, ``actual_value`` and ``action`` (ids and decimals
          rendered with ``str``).
    """
    # Guardrails are deliberately not prefetched: check_experiment_guardrails
    # runs its own filtered Guardrail query, so a prefetch here would never
    # be read.
    # NOTE(review): select_related("flag") is kept as-is; confirm whether
    # anything downstream reads experiment.flag from these instances.
    running = (
        Experiment.objects.filter(status=ExperimentStatus.RUNNING)
        .select_related("flag")
        .prefetch_related("variants")
    )

    results: dict[str, Any] = {
        "checked": 0,
        "triggered": 0,
        "triggers": [],
    }

    for experiment in running:
        results["checked"] += 1
        triggers = check_experiment_guardrails(experiment)
        if not triggers:
            continue

        results["triggered"] += 1
        results["triggers"].extend(
            {
                "experiment_id": str(experiment.pk),
                "experiment_name": experiment.name,
                "metric_key": t.metric_key,
                "threshold": str(t.threshold),
                "actual_value": str(t.actual_value),
                "action": t.action,
            }
            for t in triggers
        )

    return results
|