Files
Lotty/src/backend/apps/guardrails/services.py
T

337 lines
9.7 KiB
Python

from datetime import timedelta
from decimal import Decimal
from typing import Any
from django.core.cache import cache
from django.core.exceptions import ValidationError
from django.db import transaction
from django.db.models import Case, QuerySet, Value, When
from django.utils import timezone
from apps.experiments.models import (
Experiment,
ExperimentLog,
ExperimentOutcome,
ExperimentStatus,
LogType,
OutcomeType,
)
from apps.guardrails.models import (
METRIC_DECIMAL_PLACES,
Guardrail,
GuardrailAction,
GuardrailTrigger,
)
from apps.metrics.models import MetricDefinition, MetricDirection
from apps.notifications.services import (
NotificationPayload,
notification_enqueue,
)
from apps.reports.services import calculate_metric_value
@transaction.atomic
def guardrail_create(
    *,
    experiment: Experiment,
    metric: MetricDefinition,
    threshold: Decimal,
    observation_window_minutes: int = 60,
    action: str = GuardrailAction.PAUSE,
) -> Guardrail:
    """Build a guardrail for ``experiment`` and persist it in a transaction.

    Args:
        experiment: Experiment the guardrail is attached to.
        metric: Metric definition whose value is monitored.
        threshold: Boundary value the metric is compared against.
        observation_window_minutes: Length, in minutes, of the trailing
            window over which the metric is evaluated (default 60).
        action: Action to take when the guardrail fires
            (default ``GuardrailAction.PAUSE``).

    Returns:
        The saved ``Guardrail`` instance.
    """
    new_guardrail = Guardrail(
        experiment=experiment,
        metric=metric,
        threshold=threshold,
        observation_window_minutes=observation_window_minutes,
        action=action,
    )
    new_guardrail.save()
    return new_guardrail
def guardrail_update(
    *,
    guardrail: Guardrail,
    **fields: Any,
) -> Guardrail:
    """Apply a partial update to ``guardrail`` and save it.

    Only ``threshold``, ``observation_window_minutes``, ``action`` and
    ``is_active`` may be changed. All keys are validated before any
    attribute is assigned, so a rejected call leaves the guardrail's fields
    untouched. Keyword arguments whose value is ``None`` are skipped.

    Args:
        guardrail: The guardrail to modify.
        **fields: Field name -> new value.

    Returns:
        The same (saved) ``guardrail`` instance.

    Raises:
        ValidationError: Keyed by the first non-updatable field name, in
            argument order.
    """
    # Reload the parent experiment's status before anything else.
    # NOTE(review): the refreshed status is not read in this function;
    # presumably model-level save/validation depends on it - confirm.
    guardrail.experiment.refresh_from_db(fields=["status"])

    updatable = frozenset(
        ("threshold", "observation_window_minutes", "action", "is_active")
    )
    rejected = next((name for name in fields if name not in updatable), None)
    if rejected is not None:
        raise ValidationError(
            {rejected: f"Field '{rejected}' cannot be updated."}
        )

    changes = {
        name: value for name, value in fields.items() if value is not None
    }
    for name, value in changes.items():
        setattr(guardrail, name, value)
    guardrail.save()
    return guardrail
def guardrail_delete(*, guardrail: Guardrail) -> None:
    """Delete ``guardrail`` after reloading its experiment's status.

    NOTE(review): the refreshed status is not read here; presumably the
    model's ``delete`` relies on it - confirm.
    """
    parent_experiment = guardrail.experiment
    parent_experiment.refresh_from_db(fields=["status"])
    guardrail.delete()
def guardrail_list(
    experiment: Experiment,
) -> QuerySet[Guardrail]:
    """Return a lazy queryset of ``experiment``'s guardrails.

    The related metric is joined in the same query (``select_related``).
    """
    with_metric = Guardrail.objects.select_related("metric")
    return with_metric.filter(experiment=experiment)
def guardrail_trigger_list(
    experiment: Experiment,
) -> QuerySet[GuardrailTrigger]:
    """Return a lazy queryset of trigger records for ``experiment``.

    The originating guardrail and its metric are joined in the same query.
    """
    with_relations = GuardrailTrigger.objects.select_related(
        "guardrail", "guardrail__metric"
    )
    return with_relations.filter(experiment=experiment)
def _calculate_guardrail_metric(
    guardrail: Guardrail,
    experiment: Experiment,
) -> Decimal | None:
    """Return the worst per-variant value of the guardrail's metric.

    The metric is computed separately for every variant of ``experiment``
    (control included) over the window of ``observation_window_minutes``
    ending now. Variants for which ``calculate_metric_value`` returns
    ``None`` are ignored.

    "Worst" depends on the metric direction: the minimum for
    ``HIGHER_IS_BETTER`` metrics, the maximum otherwise.

    Returns:
        The worst value, or ``None`` when no variant produced a value.
    """
    window_end = timezone.now()
    window_start = window_end - timedelta(
        minutes=guardrail.observation_window_minutes,
    )

    per_variant = (
        calculate_metric_value(
            metric=guardrail.metric,
            experiment_id=experiment.pk,
            variant_id=variant.pk,
            event_start_date=window_start,
            event_end_date=window_end,
        )
        for variant in experiment.variants.all()
    )
    observed = [value for value in per_variant if value is not None]
    if not observed:
        return None

    higher_is_better = (
        guardrail.metric.direction == MetricDirection.HIGHER_IS_BETTER
    )
    pick_worst = min if higher_is_better else max
    return pick_worst(observed)
def _is_threshold_breached(
    actual_value: Decimal,
    threshold: Decimal,
    direction: str,
) -> bool:
    """Tell whether ``actual_value`` lies on the wrong side of ``threshold``.

    For ``HIGHER_IS_BETTER`` metrics a value strictly below the threshold is
    a breach; for any other direction a value strictly above it is. A value
    equal to the threshold never counts as a breach.
    """
    higher_is_better = direction == MetricDirection.HIGHER_IS_BETTER
    if higher_is_better:
        return threshold > actual_value
    return threshold < actual_value
def _invalidate_active_experiment_cache(experiment: Experiment) -> None:
    """Delete the ``active_exp:<flag_id>`` cache key once the transaction commits."""
    cache_key = f"active_exp:{experiment.flag_id}"
    # Deleting immediately (before commit) lets any reader that repopulates
    # this key in the meantime cache the old, still-committed RUNNING row.
    # on_commit defers the delete until the status change is visible; when
    # no transaction is active, Django runs the callback immediately.
    transaction.on_commit(lambda: cache.delete(cache_key))


def _guardrail_log_metadata(
    guardrail: Guardrail,
    value: Decimal,
    from_status: str,
    to_status: str,
) -> dict[str, Any]:
    """Build the metadata common to every guardrail ``ExperimentLog`` entry."""
    return {
        "guardrail_id": str(guardrail.pk),
        "metric_key": guardrail.metric.key,
        "threshold": str(guardrail.threshold),
        "actual_value": str(value),
        "action": guardrail.action,
        "from_status": from_status,
        "to_status": to_status,
    }


@transaction.atomic
def _execute_guardrail_action(
    guardrail: Guardrail,
    experiment: Experiment,
    actual_value: Decimal,
) -> GuardrailTrigger | None:
    """Record a guardrail breach and apply the guardrail's action.

    Runs in one transaction. The experiment row is re-read with
    ``SELECT ... FOR UPDATE``. If it is no longer ``RUNNING``, nothing is
    written and ``None`` is returned.

    Otherwise a ``GuardrailTrigger`` is stored, and:

    * ``PAUSE``: the experiment becomes ``PAUSED``;
    * ``ROLLBACK``: the experiment becomes ``COMPLETED`` and an
      ``ExperimentOutcome`` of type ``ROLLBACK`` is recorded with the
      control variant (if any) as winner.

    For both actions an ``ExperimentLog`` entry is written, and the
    ``active_exp:<flag_id>`` cache key is invalidated after commit. A
    ``guardrail_triggered`` notification is enqueued for every trigger.

    Args:
        guardrail: Guardrail whose threshold was breached.
        experiment: Experiment to act on (re-fetched, so it may be stale).
        actual_value: Observed metric value; rounded to
            ``METRIC_DECIMAL_PLACES`` for storage and all messages.

    Returns:
        The created trigger, or ``None`` if the experiment was not running.
    """
    now = timezone.now()
    # Lock the row and re-check: another worker may already have paused or
    # completed this experiment since the caller loaded it.
    experiment = Experiment.objects.select_for_update().get(pk=experiment.pk)
    if experiment.status != ExperimentStatus.RUNNING:
        return None

    from_status = experiment.status
    metric_key = guardrail.metric.key
    # Use the stored (rounded) value everywhere so the log, outcome and
    # notification match the trigger record.
    rounded_value = round(actual_value, METRIC_DECIMAL_PLACES)

    trigger = GuardrailTrigger.objects.create(
        guardrail=guardrail,
        experiment=experiment,
        metric_key=metric_key,
        threshold=guardrail.threshold,
        actual_value=rounded_value,
        observation_window_minutes=guardrail.observation_window_minutes,
        action=guardrail.action,
        triggered_at=now,
    )
    summary = (
        f"Guardrail triggered: {metric_key} "
        f"= {rounded_value} (threshold: {guardrail.threshold}). "
    )

    if guardrail.action == GuardrailAction.PAUSE:
        experiment.status = ExperimentStatus.PAUSED
        experiment.save(update_fields=["status", "updated_at"])
        _invalidate_active_experiment_cache(experiment)
        ExperimentLog.objects.create(
            experiment=experiment,
            log_type=LogType.GUARDRAIL_TRIGGERED,
            comment=f"{summary}Experiment paused.",
            metadata=_guardrail_log_metadata(
                guardrail,
                rounded_value,
                from_status,
                ExperimentStatus.PAUSED,
            ),
        )
    elif guardrail.action == GuardrailAction.ROLLBACK:
        control_variant = experiment.variants.filter(
            is_control=True,
        ).first()
        experiment.status = ExperimentStatus.COMPLETED
        experiment.save(update_fields=["status", "updated_at"])
        _invalidate_active_experiment_cache(experiment)
        ExperimentOutcome.objects.create(
            experiment=experiment,
            outcome=OutcomeType.ROLLBACK,
            winning_variant=control_variant,
            # "breached" rather than "exceeded": for HIGHER_IS_BETTER
            # metrics the value fell *below* the threshold.
            rationale=(
                f"Automatic rollback: guardrail {metric_key} "
                f"= {rounded_value} breached threshold "
                f"{guardrail.threshold}."
            ),
            decided_by=None,
        )
        metadata = _guardrail_log_metadata(
            guardrail,
            rounded_value,
            from_status,
            ExperimentStatus.COMPLETED,
        )
        metadata["control_variant_id"] = (
            str(control_variant.pk) if control_variant else None
        )
        ExperimentLog.objects.create(
            experiment=experiment,
            log_type=LogType.GUARDRAIL_TRIGGERED,
            comment=f"{summary}Experiment rolled back to control.",
            metadata=metadata,
        )

    # NOTE(review): enqueued inside the transaction; if notification_enqueue
    # dispatches to an external queue rather than writing to the DB,
    # consider deferring it with transaction.on_commit as well.
    notification_enqueue(
        "guardrail_triggered",
        NotificationPayload(
            title="Guardrail Triggered",
            body=(
                f"Guardrail triggered on experiment '{experiment.name}': "
                f"{metric_key} = {rounded_value} "
                f"(threshold: {guardrail.threshold}). "
                f"Action: {guardrail.action}."
            ),
            event_type="guardrail_triggered",
            experiment_id=str(experiment.pk),
            experiment_name=experiment.name,
            extra={
                "metric_key": metric_key,
                "threshold": str(guardrail.threshold),
                "actual_value": str(rounded_value),
                "action": guardrail.action,
            },
        ),
    )
    return trigger
def check_experiment_guardrails(
    experiment: Experiment,
) -> list[GuardrailTrigger]:
    """Evaluate the active guardrails of ``experiment`` and act on breaches.

    Guardrails are visited in priority order: ``ROLLBACK`` first, then
    ``PAUSE``, then any other action. Ties are ordered newest-first by
    ``created_at``. Guardrails with no computable metric value are skipped.
    Evaluation stops once a ``PAUSE`` or ``ROLLBACK`` fires, or when the
    action reports that the experiment is no longer running (``None``).

    Returns:
        The triggers created, in execution order. The list is empty if
        ``experiment`` is not ``RUNNING``.
    """
    if experiment.status != ExperimentStatus.RUNNING:
        return []

    priority = Case(
        When(action=GuardrailAction.ROLLBACK, then=Value(0)),
        When(action=GuardrailAction.PAUSE, then=Value(1)),
        default=Value(2),
    )
    candidates = (
        Guardrail.objects.filter(experiment=experiment, is_active=True)
        .select_related("metric")
        .annotate(action_order=priority)
        .order_by("action_order", "-created_at")
    )
    stopping_actions = {GuardrailAction.PAUSE, GuardrailAction.ROLLBACK}

    fired: list[GuardrailTrigger] = []
    for guardrail in candidates:
        observed = _calculate_guardrail_metric(guardrail, experiment)
        if observed is None:
            continue
        breached = _is_threshold_breached(
            observed,
            guardrail.threshold,
            guardrail.metric.direction,
        )
        if not breached:
            continue
        trigger = _execute_guardrail_action(guardrail, experiment, observed)
        if trigger is None:
            # The experiment left RUNNING concurrently; nothing more to do.
            break
        fired.append(trigger)
        if guardrail.action in stopping_actions:
            break
    return fired
def check_all_running_experiments() -> dict[str, Any]:
    """Run guardrail checks for every ``RUNNING`` experiment.

    If checking one experiment raises, the error is logged and counted, and
    the sweep continues. A single broken experiment therefore cannot stop
    guardrails from being evaluated for all the others.

    Returns:
        A summary dict with:

        * ``"checked"``: running experiments visited;
        * ``"triggered"``: experiments with at least one trigger;
        * ``"triggers"``: one dict per trigger, containing the experiment
          id and name, metric key, threshold, actual value and action
          (numbers as strings);
        * ``"failed"``: experiments whose check raised an exception.
    """
    import logging

    logger = logging.getLogger(__name__)

    # Variants are prefetched because _calculate_guardrail_metric iterates
    # experiment.variants.all() on this same instance. Guardrails are not
    # prefetched, because check_experiment_guardrails runs its own filtered,
    # ordered query and would ignore the prefetch cache.
    running = (
        Experiment.objects.filter(status=ExperimentStatus.RUNNING)
        .select_related("flag")
        .prefetch_related("variants")
    )
    results: dict[str, Any] = {
        "checked": 0,
        "triggered": 0,
        "triggers": [],
        "failed": 0,
    }
    for experiment in running:
        results["checked"] += 1
        try:
            triggers = check_experiment_guardrails(experiment)
        except Exception:
            # Batch boundary: record the failure and keep protecting the
            # remaining experiments instead of aborting the whole sweep.
            results["failed"] += 1
            logger.exception(
                "Guardrail check failed for experiment %s", experiment.pk
            )
            continue
        if triggers:
            results["triggered"] += 1
        results["triggers"].extend(
            {
                "experiment_id": str(experiment.pk),
                "experiment_name": experiment.name,
                "metric_key": t.metric_key,
                "threshold": str(t.threshold),
                "actual_value": str(t.actual_value),
                "action": t.action,
            }
            for t in triggers
        )
    return results