feat(guardrails): added guardrails business logic
This commit is contained in:
@@ -0,0 +1,5 @@
|
||||
from django.apps import AppConfig
|
||||
|
||||
|
||||
class GuardrailsConfig(AppConfig):
    """Django application configuration for the ``apps.guardrails`` package."""

    # Dotted import path Django uses to locate this application.
    name = "apps.guardrails"
|
||||
@@ -0,0 +1,65 @@
|
||||
# Generated by Django 5.2.11 on 2026-02-14 09:55
|
||||
|
||||
import django.db.models.deletion
|
||||
import uuid
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
    """Initial schema for the guardrails app.

    Creates the ``Guardrail`` and ``GuardrailTrigger`` models and adds one
    composite index to each of them.
    """

    initial = True

    # Both referenced apps must have their initial tables before the
    # foreign keys below can be created.
    dependencies = [
        ("experiments", "0001_initial"),
        ("metrics", "0001_initial"),
    ]

    operations = [
        migrations.CreateModel(
            name="Guardrail",
            fields=[
                (
                    "id",
                    models.UUIDField(
                        default=uuid.uuid4,
                        editable=False,
                        primary_key=True,
                        serialize=False,
                    ),
                ),
                (
                    "threshold",
                    models.DecimalField(
                        decimal_places=4,
                        max_digits=10,
                        verbose_name="threshold",
                    ),
                ),
                (
                    "observation_window_minutes",
                    models.PositiveIntegerField(
                        default=60,
                        verbose_name="observation window (minutes)",
                    ),
                ),
                (
                    "action",
                    models.CharField(
                        choices=[
                            ("pause", "Pause experiment"),
                            ("rollback", "Rollback to control"),
                        ],
                        max_length=20,
                        verbose_name="action on trigger",
                    ),
                ),
                (
                    "is_active",
                    models.BooleanField(
                        default=True,
                        verbose_name="is active",
                    ),
                ),
                (
                    "created_at",
                    models.DateTimeField(
                        auto_now_add=True,
                        verbose_name="created at",
                    ),
                ),
                (
                    "updated_at",
                    models.DateTimeField(
                        auto_now=True,
                        verbose_name="updated at",
                    ),
                ),
                (
                    "experiment",
                    models.ForeignKey(
                        on_delete=django.db.models.deletion.CASCADE,
                        related_name="guardrails",
                        to="experiments.experiment",
                        verbose_name="experiment",
                    ),
                ),
                (
                    "metric",
                    models.ForeignKey(
                        on_delete=django.db.models.deletion.PROTECT,
                        related_name="guardrail_usages",
                        to="metrics.metricdefinition",
                        verbose_name="metric",
                    ),
                ),
            ],
            options={
                "verbose_name": "guardrail",
                "verbose_name_plural": "guardrails",
                "ordering": ["-created_at"],
            },
        ),
        migrations.CreateModel(
            name="GuardrailTrigger",
            fields=[
                (
                    "id",
                    models.UUIDField(
                        default=uuid.uuid4,
                        editable=False,
                        primary_key=True,
                        serialize=False,
                    ),
                ),
                (
                    "metric_key",
                    models.CharField(
                        max_length=100,
                        verbose_name="metric key",
                    ),
                ),
                (
                    "threshold",
                    models.DecimalField(
                        decimal_places=4,
                        max_digits=10,
                        verbose_name="threshold",
                    ),
                ),
                (
                    "actual_value",
                    models.DecimalField(
                        decimal_places=4,
                        max_digits=10,
                        verbose_name="actual value",
                    ),
                ),
                (
                    "observation_window_minutes",
                    models.PositiveIntegerField(
                        verbose_name="observation window (minutes)",
                    ),
                ),
                (
                    "action",
                    models.CharField(
                        max_length=20,
                        verbose_name="action taken",
                    ),
                ),
                (
                    "triggered_at",
                    models.DateTimeField(verbose_name="triggered at"),
                ),
                (
                    "created_at",
                    models.DateTimeField(
                        auto_now_add=True,
                        verbose_name="created at",
                    ),
                ),
                (
                    "experiment",
                    models.ForeignKey(
                        on_delete=django.db.models.deletion.CASCADE,
                        related_name="guardrail_triggers",
                        to="experiments.experiment",
                        verbose_name="experiment",
                    ),
                ),
                (
                    "guardrail",
                    models.ForeignKey(
                        on_delete=django.db.models.deletion.CASCADE,
                        related_name="triggers",
                        to="guardrails.guardrail",
                        verbose_name="guardrail",
                    ),
                ),
            ],
            options={
                "verbose_name": "guardrail trigger",
                "verbose_name_plural": "guardrail triggers",
                "ordering": ["-triggered_at"],
            },
        ),
        migrations.AddIndex(
            model_name="guardrail",
            index=models.Index(
                fields=["experiment", "is_active"],
                name="idx_guardrail_exp_active",
            ),
        ),
        migrations.AddIndex(
            model_name="guardrailtrigger",
            index=models.Index(
                fields=["experiment", "-triggered_at"],
                name="idx_trigger_exp_time",
            ),
        ),
    ]
|
||||
@@ -0,0 +1,131 @@
|
||||
from typing import override
|
||||
|
||||
from django.db import models
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from apps.core.models import BaseModel
|
||||
|
||||
|
||||
class GuardrailAction(models.TextChoices):
    """Choices for the ``action`` field of a guardrail."""

    PAUSE = ("pause", _("Pause experiment"))
    ROLLBACK = ("rollback", _("Rollback to control"))
|
||||
|
||||
|
||||
class Guardrail(BaseModel):
    """A threshold on a metric, attached to one experiment.

    Each guardrail names the metric to watch, the threshold value, the
    length of the observation window in minutes and the action to take.
    """

    # Owning experiment; guardrails are deleted together with it.
    experiment = models.ForeignKey(
        "experiments.Experiment",
        verbose_name=_("experiment"),
        related_name="guardrails",
        on_delete=models.CASCADE,
    )
    # Watched metric; PROTECT blocks deleting a metric that is still in use.
    metric = models.ForeignKey(
        "metrics.MetricDefinition",
        verbose_name=_("metric"),
        related_name="guardrail_usages",
        on_delete=models.PROTECT,
    )
    threshold = models.DecimalField(
        verbose_name=_("threshold"),
        max_digits=10,
        decimal_places=4,
    )
    observation_window_minutes = models.PositiveIntegerField(
        verbose_name=_("observation window (minutes)"),
        default=60,
    )
    action = models.CharField(
        verbose_name=_("action on trigger"),
        choices=GuardrailAction.choices,
        max_length=20,
    )
    is_active = models.BooleanField(
        verbose_name=_("is active"),
        default=True,
    )
    created_at = models.DateTimeField(
        verbose_name=_("created at"),
        auto_now_add=True,
    )
    updated_at = models.DateTimeField(
        verbose_name=_("updated at"),
        auto_now=True,
    )

    class Meta:
        verbose_name = _("guardrail")
        verbose_name_plural = _("guardrails")
        # Newest guardrails first.
        ordering = ["-created_at"]
        indexes = [
            models.Index(
                name="idx_guardrail_exp_active",
                fields=["experiment", "is_active"],
            ),
        ]

    @override
    def __str__(self) -> str:
        """Return a short description with metric key, threshold and action."""
        metric_key = self.metric.key
        return f"Guardrail({metric_key} > {self.threshold}, action={self.action})"
|
||||
|
||||
|
||||
class GuardrailTrigger(BaseModel):
    """Record of one guardrail firing on an experiment.

    Besides linking to the guardrail and experiment, the row stores its own
    copy of the metric key, threshold, observation window and action,
    together with the measured value and the trigger time.
    """

    guardrail = models.ForeignKey(
        Guardrail,
        verbose_name=_("guardrail"),
        related_name="triggers",
        on_delete=models.CASCADE,
    )
    experiment = models.ForeignKey(
        "experiments.Experiment",
        verbose_name=_("experiment"),
        related_name="guardrail_triggers",
        on_delete=models.CASCADE,
    )
    metric_key = models.CharField(
        verbose_name=_("metric key"),
        max_length=100,
    )
    threshold = models.DecimalField(
        verbose_name=_("threshold"),
        max_digits=10,
        decimal_places=4,
    )
    actual_value = models.DecimalField(
        verbose_name=_("actual value"),
        max_digits=10,
        decimal_places=4,
    )
    observation_window_minutes = models.PositiveIntegerField(
        verbose_name=_("observation window (minutes)"),
    )
    # Free-form here (no ``choices``), unlike ``Guardrail.action``.
    action = models.CharField(
        verbose_name=_("action taken"),
        max_length=20,
    )
    # Set explicitly by the caller; ``created_at`` is filled automatically.
    triggered_at = models.DateTimeField(
        verbose_name=_("triggered at"),
    )
    created_at = models.DateTimeField(
        verbose_name=_("created at"),
        auto_now_add=True,
    )

    class Meta:
        verbose_name = _("guardrail trigger")
        verbose_name_plural = _("guardrail triggers")
        # Most recent triggers first.
        ordering = ["-triggered_at"]
        indexes = [
            models.Index(
                name="idx_trigger_exp_time",
                fields=["experiment", "-triggered_at"],
            ),
        ]

    @override
    def __str__(self) -> str:
        """Return a short description with metric key, value and threshold."""
        return f"Trigger({self.metric_key}: {self.actual_value} > {self.threshold})"
|
||||
@@ -0,0 +1,318 @@
|
||||
from datetime import timedelta
|
||||
from decimal import Decimal
|
||||
from typing import Any
|
||||
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db import transaction
|
||||
from django.db.models import QuerySet
|
||||
from django.utils import timezone
|
||||
|
||||
from apps.experiments.models import (
|
||||
STARTED_STATUSES,
|
||||
Experiment,
|
||||
ExperimentLog,
|
||||
ExperimentOutcome,
|
||||
ExperimentStatus,
|
||||
LogType,
|
||||
OutcomeType,
|
||||
)
|
||||
from apps.guardrails.models import (
|
||||
Guardrail,
|
||||
GuardrailAction,
|
||||
GuardrailTrigger,
|
||||
)
|
||||
from apps.metrics.models import MetricDefinition
|
||||
from apps.notifications.services import (
|
||||
NotificationPayload,
|
||||
notification_enqueue,
|
||||
)
|
||||
from apps.reports.services import calculate_metric_value
|
||||
|
||||
|
||||
@transaction.atomic
def guardrail_create(
    *,
    experiment: Experiment,
    metric: MetricDefinition,
    threshold: Any,
    observation_window_minutes: int = 60,
    action: str = GuardrailAction.PAUSE,
) -> Guardrail:
    """Create, validate and persist a guardrail for ``experiment``.

    Args:
        experiment: Experiment the guardrail belongs to.
        metric: Metric definition the guardrail watches.
        threshold: Value stored in the guardrail's decimal ``threshold`` field.
        observation_window_minutes: Length of the observation window.
        action: One of the ``GuardrailAction`` values.

    Returns:
        The saved ``Guardrail`` instance.

    Raises:
        ValidationError: If any field fails model validation, for example an
            ``action`` that is not one of the declared choices.
    """
    guardrail = Guardrail(
        experiment=experiment,
        metric=metric,
        threshold=threshold,
        observation_window_minutes=observation_window_minutes,
        action=action,
    )
    # ``save()`` alone does not run field validation, so an unknown action
    # or an over-long decimal would otherwise reach the database unchecked.
    guardrail.full_clean()
    guardrail.save()
    return guardrail
|
||||
|
||||
|
||||
def guardrail_update(
    *,
    guardrail: Guardrail,
    **fields: Any,
) -> Guardrail:
    """Apply a partial update to ``guardrail`` and save it.

    Only ``threshold``, ``observation_window_minutes``, ``action`` and
    ``is_active`` may be changed. A value of ``None`` means "leave this
    field as it is".

    Args:
        guardrail: Guardrail to modify.
        **fields: New values keyed by field name.

    Returns:
        The same guardrail instance after saving.

    Raises:
        ValidationError: If the experiment's status is in
            ``STARTED_STATUSES``, if a field outside the allowed set is
            passed, or if the resulting values fail model validation.
    """
    experiment_status = guardrail.experiment.status
    if experiment_status in STARTED_STATUSES:
        raise ValidationError(
            {
                "experiment": (
                    "Guardrails cannot be modified after the experiment "
                    "has been started "
                    f"(status: '{experiment_status}')."
                )
            }
        )

    allowed = {
        "threshold",
        "observation_window_minutes",
        "action",
        "is_active",
    }
    # Report every rejected field at once so the caller can fix them together.
    rejected = [key for key in fields if key not in allowed]
    if rejected:
        raise ValidationError(
            {key: f"Field '{key}' cannot be updated." for key in rejected}
        )

    for key, value in fields.items():
        if value is not None:
            setattr(guardrail, key, value)

    # ``save()`` alone does not run field validation (choices, decimal
    # precision), so validate explicitly before persisting.
    guardrail.full_clean()
    guardrail.save()
    return guardrail
|
||||
|
||||
|
||||
def guardrail_delete(*, guardrail: Guardrail) -> None:
    """Delete ``guardrail`` unless its experiment has already been started.

    Raises:
        ValidationError: If the experiment's status is in
            ``STARTED_STATUSES``.
    """
    experiment_status = guardrail.experiment.status
    if experiment_status not in STARTED_STATUSES:
        guardrail.delete()
        return

    message = (
        "Guardrails cannot be deleted after the experiment "
        "has been started "
        f"(status: '{experiment_status}')."
    )
    raise ValidationError({"experiment": message})
|
||||
|
||||
|
||||
def guardrail_list(
    experiment: Experiment,
) -> QuerySet[Guardrail]:
    """Return the guardrails of ``experiment`` with ``metric`` joined in."""
    with_metric = Guardrail.objects.select_related("metric")
    return with_metric.filter(experiment=experiment)
|
||||
|
||||
|
||||
def guardrail_trigger_list(
    experiment: Experiment,
) -> QuerySet[GuardrailTrigger]:
    """Return the triggers of ``experiment`` with guardrail and metric joined in."""
    with_relations = GuardrailTrigger.objects.select_related(
        "guardrail",
        "guardrail__metric",
    )
    return with_relations.filter(experiment=experiment)
|
||||
|
||||
|
||||
def _calculate_guardrail_metric(
    guardrail: Guardrail,
    experiment: Experiment,
) -> Decimal | None:
    """Return the highest per-variant value of the guardrail's metric.

    The metric is computed for every variant of ``experiment`` over the
    window that ends now and spans ``guardrail.observation_window_minutes``.

    Returns:
        The maximum of the values that were not ``None``, or ``None`` when
        no variant produced a value.
    """
    window_end = timezone.now()
    window_start = window_end - timedelta(
        minutes=guardrail.observation_window_minutes,
    )

    per_variant = (
        calculate_metric_value(
            metric=guardrail.metric,
            experiment_id=experiment.pk,
            variant_id=variant.pk,
            start_date=window_start,
            end_date=window_end,
        )
        for variant in experiment.variants.all()
    )
    # Variants without a computable value are ignored.
    observed: list[Decimal] = [
        value for value in per_variant if value is not None
    ]

    return max(observed) if observed else None
|
||||
|
||||
|
||||
@transaction.atomic
def _execute_guardrail_action(
    guardrail: Guardrail,
    experiment: Experiment,
    actual_value: Decimal,
) -> GuardrailTrigger:
    """Record a guardrail trigger and apply the guardrail's action.

    Always creates a ``GuardrailTrigger`` row and enqueues a
    ``guardrail_triggered`` notification. In between:

    * ``PAUSE`` sets the experiment to ``PAUSED`` and writes an audit log.
    * ``ROLLBACK`` sets the experiment to ``COMPLETED``, creates a
      ``ROLLBACK`` outcome pointing at the first control variant (if any)
      and writes an audit log.
    * Any other action value changes nothing else.

    Returns:
        The newly created trigger row.
    """
    metric_key = guardrail.metric.key
    threshold = guardrail.threshold
    action = guardrail.action

    trigger = GuardrailTrigger.objects.create(
        guardrail=guardrail,
        experiment=experiment,
        metric_key=metric_key,
        threshold=threshold,
        actual_value=actual_value,
        observation_window_minutes=guardrail.observation_window_minutes,
        action=action,
        triggered_at=timezone.now(),
    )

    # Text and metadata shared by the audit log entries of both actions.
    reading = f"{metric_key} = {actual_value} (threshold: {threshold}). "
    log_prefix = f"Guardrail triggered: {reading}"
    audit_fields: dict[str, Any] = {
        "guardrail_id": str(guardrail.pk),
        "metric_key": metric_key,
        "threshold": str(threshold),
        "actual_value": str(actual_value),
        "action": action,
        "from_status": ExperimentStatus.RUNNING,
    }

    if action == GuardrailAction.PAUSE:
        experiment.status = ExperimentStatus.PAUSED
        experiment.save(update_fields=["status", "updated_at"])

        ExperimentLog.objects.create(
            experiment=experiment,
            log_type=LogType.GUARDRAIL_TRIGGERED,
            comment=log_prefix + "Experiment paused.",
            metadata={
                **audit_fields,
                "to_status": ExperimentStatus.PAUSED,
            },
        )
    elif action == GuardrailAction.ROLLBACK:
        # May be None when the experiment has no control variant.
        control_variant = experiment.variants.filter(is_control=True).first()
        control_variant_id = (
            str(control_variant.pk) if control_variant else None
        )

        experiment.status = ExperimentStatus.COMPLETED
        experiment.save(update_fields=["status", "updated_at"])

        ExperimentOutcome.objects.create(
            experiment=experiment,
            outcome=OutcomeType.ROLLBACK,
            winning_variant=control_variant,
            rationale=(
                f"Automatic rollback: guardrail {metric_key} "
                f"= {actual_value} exceeded threshold {threshold}."
            ),
            decided_by=None,
        )

        ExperimentLog.objects.create(
            experiment=experiment,
            log_type=LogType.GUARDRAIL_TRIGGERED,
            comment=log_prefix + "Experiment rolled back to control.",
            metadata={
                **audit_fields,
                "to_status": ExperimentStatus.COMPLETED,
                "control_variant_id": control_variant_id,
            },
        )

    payload = NotificationPayload(
        title="Guardrail Triggered",
        body=(
            f"Guardrail triggered on experiment '{experiment.name}': "
            f"{reading}"
            f"Action: {action}."
        ),
        event_type="guardrail_triggered",
        experiment_id=str(experiment.pk),
        experiment_name=experiment.name,
        extra={
            "metric_key": metric_key,
            "threshold": str(threshold),
            "actual_value": str(actual_value),
            "action": action,
        },
    )
    notification_enqueue("guardrail_triggered", payload)

    return trigger
|
||||
|
||||
|
||||
def check_experiment_guardrails(
    experiment: Experiment,
) -> list[GuardrailTrigger]:
    """Evaluate the active guardrails of ``experiment`` and act on breaches.

    Nothing is checked unless the experiment is ``RUNNING``. A guardrail is
    breached when its computed metric value is strictly greater than its
    threshold. Evaluation stops after the first ``PAUSE`` or ``ROLLBACK``
    action, or as soon as the reloaded experiment is no longer running.

    Returns:
        The triggers created during this call (possibly empty).
    """
    fired: list[GuardrailTrigger] = []
    if experiment.status != ExperimentStatus.RUNNING:
        return fired

    active_guardrails = Guardrail.objects.filter(
        experiment=experiment,
        is_active=True,
    ).select_related("metric")
    stopping_actions = {GuardrailAction.PAUSE, GuardrailAction.ROLLBACK}

    for guardrail in active_guardrails:
        observed = _calculate_guardrail_metric(guardrail, experiment)
        if observed is None:
            continue
        if not (observed > guardrail.threshold):
            continue

        # Re-read the row so a status change made since the first check
        # is seen before acting.
        experiment.refresh_from_db()
        if experiment.status != ExperimentStatus.RUNNING:
            break

        fired.append(
            _execute_guardrail_action(guardrail, experiment, observed)
        )

        if guardrail.action in stopping_actions:
            break

    return fired
|
||||
|
||||
|
||||
def check_all_running_experiments() -> dict[str, Any]:
    """Run the guardrail check for every experiment with status ``RUNNING``.

    Returns:
        A dict with ``checked`` (experiments visited), ``triggered``
        (experiments with at least one trigger) and ``triggers`` (one
        summary dict of string values per trigger).
    """
    running_experiments = (
        Experiment.objects.filter(status=ExperimentStatus.RUNNING)
        .select_related("flag")
        .prefetch_related("variants", "guardrails__metric")
    )

    checked_count = 0
    triggered_count = 0
    trigger_summaries: list[dict[str, Any]] = []

    for experiment in running_experiments:
        checked_count += 1
        fired = check_experiment_guardrails(experiment)
        if not fired:
            continue

        triggered_count += 1
        trigger_summaries.extend(
            {
                "experiment_id": str(experiment.pk),
                "experiment_name": experiment.name,
                "metric_key": fired_trigger.metric_key,
                "threshold": str(fired_trigger.threshold),
                "actual_value": str(fired_trigger.actual_value),
                "action": fired_trigger.action,
            }
            for fired_trigger in fired
        )

    return {
        "checked": checked_count,
        "triggered": triggered_count,
        "triggers": trigger_summaries,
    }
|
||||
@@ -0,0 +1,20 @@
|
||||
import logging
|
||||
|
||||
from apps.guardrails.services import check_all_running_experiments
|
||||
from config.celery import app
|
||||
|
||||
logger = logging.getLogger("lotty")
|
||||
|
||||
|
||||
@app.task(bind=True, name="guardrails.check_all")
def check_all_experiment_guardrails_task(self):
    """Celery task: check guardrails of all running experiments.

    Logs the counters of the run and returns the full result dict from
    ``check_all_running_experiments``.
    """
    summary = check_all_running_experiments()
    log_fields = {
        "checked": summary["checked"],
        "triggered": summary["triggered"],
        "triggers_count": len(summary["triggers"]),
    }
    logger.info("guardrail_check_completed", extra=log_fields)
    return summary
|
||||
@@ -0,0 +1,484 @@
|
||||
from decimal import Decimal
|
||||
|
||||
from django.test import TestCase
|
||||
from django.utils import timezone
|
||||
|
||||
from apps.events.services import decision_create, process_events_batch
|
||||
from apps.events.tests.helpers import make_event_type, make_exposure_type
|
||||
from apps.experiments.models import (
|
||||
ExperimentLog,
|
||||
ExperimentOutcome,
|
||||
ExperimentStatus,
|
||||
LogType,
|
||||
OutcomeType,
|
||||
)
|
||||
from apps.experiments.services import (
|
||||
experiment_approve,
|
||||
experiment_start,
|
||||
experiment_submit_for_review,
|
||||
)
|
||||
from apps.experiments.tests.helpers import add_two_variants, make_experiment
|
||||
from apps.guardrails.models import (
|
||||
Guardrail,
|
||||
GuardrailAction,
|
||||
GuardrailTrigger,
|
||||
)
|
||||
from apps.guardrails.services import (
|
||||
check_all_running_experiments,
|
||||
check_experiment_guardrails,
|
||||
guardrail_create,
|
||||
)
|
||||
from apps.metrics.models import MetricDirection, MetricType
|
||||
from apps.metrics.services import metric_definition_create
|
||||
from apps.reviews.services import review_settings_update
|
||||
from apps.reviews.tests.helpers import make_approver
|
||||
|
||||
|
||||
def _start_experiment(experiment, approver):
    """Submit ``experiment`` for review, approve it and start it.

    Submission and start are performed as the experiment's owner; the
    approval is performed by ``approver``. Returns the started experiment.
    """
    owner = experiment.owner
    submitted = experiment_submit_for_review(experiment=experiment, user=owner)
    approved = experiment_approve(experiment=submitted, approver=approver)
    return experiment_start(experiment=approved, user=owner)
|
||||
|
||||
|
||||
class GuardrailCheckPauseTest(TestCase):
    """Guardrail checks on a running experiment with a PAUSE guardrail.

    The guardrail watches an error-rate ratio metric (``error_occurred``
    events over ``exposure`` events) with a threshold of 0.05.
    """

    def setUp(self) -> None:
        review_settings_update(default_min_approvals=1, allow_any_approver=True)
        self.approver = make_approver("_gp")

        self.exposure_type = make_exposure_type()
        self.error_type = make_event_type(
            name="error_occurred",
            display_name="Error",
            requires_exposure=False,
        )

        self.experiment = make_experiment(suffix="_gcp")
        self.v_control, self.v_treatment = add_two_variants(self.experiment)

        ratio_rule = {
            "type": "ratio",
            "numerator_event": "error_occurred",
            "denominator_event": "exposure",
        }
        self.error_rate_metric = metric_definition_create(
            key="gcp_error_rate",
            name="Error Rate",
            metric_type=MetricType.RATIO,
            direction=MetricDirection.LOWER_IS_BETTER,
            calculation_rule=ratio_rule,
        )

        # The guardrail is created before the experiment is started.
        guardrail_create(
            experiment=self.experiment,
            metric=self.error_rate_metric,
            threshold=Decimal("0.05"),
            observation_window_minutes=60,
            action=GuardrailAction.PAUSE,
        )

        self.experiment = _start_experiment(self.experiment, self.approver)
        self.now = timezone.now()

    def _event(self, event_id, event_type, decision_id, subject_id):
        """Build one event payload timestamped at ``self.now``."""
        return {
            "event_id": event_id,
            "event_type": event_type,
            "decision_id": decision_id,
            "subject_id": subject_id,
            "timestamp": self.now.isoformat(),
            "properties": {},
        }

    def _create_decision_and_exposure(self, decision_id, subject_id, variant):
        """Record a decision for ``variant`` followed by its exposure event."""
        decision_create(
            decision_id=decision_id,
            flag_key="flag_gcp",
            subject_id=subject_id,
            experiment_id=str(self.experiment.pk),
            variant_id=str(variant.pk),
            value=variant.value,
            reason="experiment",
        )
        exposure = self._event(
            f"exp_{decision_id}", "exposure", decision_id, subject_id
        )
        process_events_batch([exposure])

    def _send_error(self, event_id, decision_id, subject_id):
        """Send one ``error_occurred`` event for an existing decision."""
        error = self._event(event_id, "error_occurred", decision_id, subject_id)
        process_events_batch([error])

    def _expose_treatment(self, decision_prefix, count):
        """Expose subjects ``u0..u{count-1}`` to the treatment variant."""
        for i in range(count):
            self._create_decision_and_exposure(
                f"{decision_prefix}{i}", f"u{i}", self.v_treatment
            )

    def _fail_every_subject(self, error_prefix, decision_prefix, count):
        """Send one error for each of the first ``count`` subjects."""
        for i in range(count):
            self._send_error(
                f"{error_prefix}{i}", f"{decision_prefix}{i}", f"u{i}"
            )

    def test_no_trigger_when_below_threshold(self) -> None:
        # 1 error over 20 exposures is exactly 0.05, not above the threshold.
        self._expose_treatment("dec_ok_", 20)
        self._send_error("err_ok_0", "dec_ok_0", "u0")

        triggers = check_experiment_guardrails(self.experiment)

        self.assertEqual(len(triggers), 0)
        self.experiment.refresh_from_db()
        self.assertEqual(self.experiment.status, ExperimentStatus.RUNNING)

    def test_trigger_pause_when_above_threshold(self) -> None:
        self._expose_treatment("dec_err_", 10)
        self._fail_every_subject("err_", "dec_err_", 10)

        triggers = check_experiment_guardrails(self.experiment)

        self.assertEqual(len(triggers), 1)
        self.experiment.refresh_from_db()
        self.assertEqual(self.experiment.status, ExperimentStatus.PAUSED)
        first = triggers[0]
        self.assertEqual(first.action, GuardrailAction.PAUSE)
        self.assertEqual(first.metric_key, "gcp_error_rate")

    def test_trigger_audit_log_created(self) -> None:
        self._expose_treatment("dec_al_", 5)
        self._fail_every_subject("err_al_", "dec_al_", 5)

        check_experiment_guardrails(self.experiment)

        log = ExperimentLog.objects.filter(
            experiment=self.experiment,
            log_type=LogType.GUARDRAIL_TRIGGERED,
        ).first()
        self.assertIsNotNone(log)
        self.assertIn("gcp_error_rate", log.comment)
        self.assertIn("threshold", log.metadata)
        self.assertIn("actual_value", log.metadata)

    def test_trigger_record_created(self) -> None:
        self._expose_treatment("dec_tr_", 5)
        self._fail_every_subject("err_tr_", "dec_tr_", 5)

        check_experiment_guardrails(self.experiment)

        trigger = GuardrailTrigger.objects.filter(
            experiment=self.experiment,
        ).first()
        self.assertIsNotNone(trigger)
        self.assertEqual(trigger.metric_key, "gcp_error_rate")
        self.assertEqual(trigger.threshold, Decimal("0.05"))
        self.assertGreater(trigger.actual_value, Decimal("0.05"))
        self.assertEqual(trigger.action, GuardrailAction.PAUSE)
        self.assertIsNotNone(trigger.triggered_at)

    def test_no_trigger_for_non_running_experiment(self) -> None:
        self.experiment.status = ExperimentStatus.PAUSED
        self.experiment.save(update_fields=["status"])

        triggers = check_experiment_guardrails(self.experiment)
        self.assertEqual(len(triggers), 0)

    def test_no_trigger_when_no_data(self) -> None:
        triggers = check_experiment_guardrails(self.experiment)
        self.assertEqual(len(triggers), 0)

    def test_inactive_guardrail_skipped(self) -> None:
        Guardrail.objects.filter(experiment=self.experiment).update(
            is_active=False,
        )

        self._expose_treatment("dec_ia_", 5)
        self._fail_every_subject("err_ia_", "dec_ia_", 5)

        triggers = check_experiment_guardrails(self.experiment)
        self.assertEqual(len(triggers), 0)
        self.experiment.refresh_from_db()
        self.assertEqual(self.experiment.status, ExperimentStatus.RUNNING)
|
||||
|
||||
|
||||
class GuardrailCheckRollbackTest(TestCase):
    """Guardrail checks on a running experiment with a ROLLBACK guardrail.

    The guardrail watches an error-rate ratio metric (``rb_error`` events
    over ``exposure`` events) with a threshold of 0.10.
    """

    def setUp(self) -> None:
        review_settings_update(default_min_approvals=1, allow_any_approver=True)
        self.approver = make_approver("_grb")

        self.exposure_type = make_exposure_type()
        self.error_type = make_event_type(
            name="rb_error",
            display_name="Error",
            requires_exposure=False,
        )

        self.experiment = make_experiment(suffix="_grb")
        self.v_control, self.v_treatment = add_two_variants(self.experiment)

        ratio_rule = {
            "type": "ratio",
            "numerator_event": "rb_error",
            "denominator_event": "exposure",
        }
        self.error_rate_metric = metric_definition_create(
            key="grb_error_rate",
            name="Error Rate",
            metric_type=MetricType.RATIO,
            direction=MetricDirection.LOWER_IS_BETTER,
            calculation_rule=ratio_rule,
        )

        # The guardrail is created before the experiment is started.
        guardrail_create(
            experiment=self.experiment,
            metric=self.error_rate_metric,
            threshold=Decimal("0.10"),
            observation_window_minutes=60,
            action=GuardrailAction.ROLLBACK,
        )

        self.experiment = _start_experiment(self.experiment, self.approver)
        self.now = timezone.now()

    def _event(self, event_id, event_type, decision_id, subject_id):
        """Build one event payload timestamped at ``self.now``."""
        return {
            "event_id": event_id,
            "event_type": event_type,
            "decision_id": decision_id,
            "subject_id": subject_id,
            "timestamp": self.now.isoformat(),
            "properties": {},
        }

    def _create_decision_and_exposure(self, decision_id, subject_id, variant):
        """Record a decision for ``variant`` followed by its exposure event."""
        decision_create(
            decision_id=decision_id,
            flag_key="flag_grb",
            subject_id=subject_id,
            experiment_id=str(self.experiment.pk),
            variant_id=str(variant.pk),
            value=variant.value,
            reason="experiment",
        )
        exposure = self._event(
            f"exp_{decision_id}", "exposure", decision_id, subject_id
        )
        process_events_batch([exposure])

    def _seed_all_failing(self, decision_prefix, error_prefix, count=5):
        """Expose ``count`` treatment subjects, then send one error for each."""
        for i in range(count):
            self._create_decision_and_exposure(
                f"{decision_prefix}{i}", f"u{i}", self.v_treatment
            )
        for i in range(count):
            error = self._event(
                f"{error_prefix}{i}",
                "rb_error",
                f"{decision_prefix}{i}",
                f"u{i}",
            )
            process_events_batch([error])

    def test_rollback_completes_experiment(self) -> None:
        self._seed_all_failing("dec_rb_", "err_rb_")

        triggers = check_experiment_guardrails(self.experiment)

        self.assertEqual(len(triggers), 1)
        self.experiment.refresh_from_db()
        self.assertEqual(self.experiment.status, ExperimentStatus.COMPLETED)
        self.assertEqual(triggers[0].action, GuardrailAction.ROLLBACK)

    def test_rollback_creates_outcome(self) -> None:
        self._seed_all_failing("dec_rbo_", "err_rbo_")

        check_experiment_guardrails(self.experiment)

        outcome = ExperimentOutcome.objects.filter(
            experiment=self.experiment,
        ).first()
        self.assertIsNotNone(outcome)
        self.assertEqual(outcome.outcome, OutcomeType.ROLLBACK)
        self.assertEqual(outcome.winning_variant, self.v_control)
        self.assertIsNone(outcome.decided_by)
        self.assertIn("guardrail", outcome.rationale.lower())

    def test_rollback_audit_log(self) -> None:
        self._seed_all_failing("dec_rba_", "err_rba_")

        check_experiment_guardrails(self.experiment)

        log = ExperimentLog.objects.filter(
            experiment=self.experiment,
            log_type=LogType.GUARDRAIL_TRIGGERED,
        ).first()
        self.assertIsNotNone(log)
        self.assertEqual(log.metadata["action"], GuardrailAction.ROLLBACK)
        self.assertEqual(
            log.metadata["to_status"],
            ExperimentStatus.COMPLETED,
        )
|
||||
|
||||
|
||||
class CheckAllRunningTest(TestCase):
    """Tests for the sweep over all running experiments."""

    def setUp(self) -> None:
        review_settings_update(default_min_approvals=1, allow_any_approver=True)
        self.approver = make_approver("_all")

        self.exposure_type = make_exposure_type()
        self.error_type = make_event_type(
            name="all_error",
            display_name="Error",
            requires_exposure=False,
        )

        ratio_rule = {
            "type": "ratio",
            "numerator_event": "all_error",
            "denominator_event": "exposure",
        }
        self.metric = metric_definition_create(
            key="all_error_rate",
            name="Error Rate",
            metric_type=MetricType.RATIO,
            direction=MetricDirection.LOWER_IS_BETTER,
            calculation_rule=ratio_rule,
        )

    def test_check_all_running(self) -> None:
        # First experiment has a guardrail, second has none; both run.
        with_guardrail = make_experiment(suffix="_all1")
        add_two_variants(with_guardrail)
        guardrail_create(
            experiment=with_guardrail,
            metric=self.metric,
            threshold=Decimal("0.05"),
            action=GuardrailAction.PAUSE,
        )
        _start_experiment(with_guardrail, self.approver)

        without_guardrail = make_experiment(suffix="_all2")
        add_two_variants(without_guardrail)
        _start_experiment(without_guardrail, self.approver)

        results = check_all_running_experiments()
        self.assertEqual(results["checked"], 2)

    def test_check_all_with_trigger(self) -> None:
        exp = make_experiment(suffix="_allt")
        _v_ctrl, v_treat = add_two_variants(exp)
        guardrail_create(
            experiment=exp,
            metric=self.metric,
            threshold=Decimal("0.05"),
            action=GuardrailAction.PAUSE,
        )
        exp = _start_experiment(exp, self.approver)

        now = timezone.now()

        def event(event_id, event_type, decision_id, subject_id):
            return {
                "event_id": event_id,
                "event_type": event_type,
                "decision_id": decision_id,
                "subject_id": subject_id,
                "timestamp": now.isoformat(),
                "properties": {},
            }

        # Every exposed subject also reports an error.
        for i in range(5):
            decision_id = f"dec_allt_{i}"
            subject_id = f"u{i}"
            decision_create(
                decision_id=decision_id,
                flag_key="flag_allt",
                subject_id=subject_id,
                experiment_id=str(exp.pk),
                variant_id=str(v_treat.pk),
                value=v_treat.value,
                reason="experiment",
            )
            process_events_batch(
                [event(f"exp_allt_{i}", "exposure", decision_id, subject_id)]
            )
            process_events_batch(
                [event(f"err_allt_{i}", "all_error", decision_id, subject_id)]
            )

        results = check_all_running_experiments()
        self.assertEqual(results["triggered"], 1)
        self.assertGreater(len(results["triggers"]), 0)
|
||||
Reference in New Issue
Block a user