feat(guardrails): added guardrails API
This commit is contained in:
@@ -0,0 +1,6 @@
|
||||
from django.apps import AppConfig
|
||||
|
||||
|
||||
class GuardrailsApiConfig(AppConfig):
|
||||
name = "api.v1.guardrails"
|
||||
label = "api_v1_guardrails"
|
||||
@@ -0,0 +1,180 @@
|
||||
from http import HTTPStatus
|
||||
from uuid import UUID
|
||||
|
||||
from django.http import Http404, HttpRequest
|
||||
from ninja import Router
|
||||
|
||||
from api.v1.guardrails.schemas import (
|
||||
ExperimentGuardrailCheckResultOut,
|
||||
GuardrailCheckResultOut,
|
||||
GuardrailCreateIn,
|
||||
GuardrailOut,
|
||||
GuardrailTriggerOut,
|
||||
GuardrailUpdateIn,
|
||||
)
|
||||
from apps.experiments.models import Experiment
|
||||
from apps.guardrails.models import Guardrail
|
||||
from apps.guardrails.services import (
|
||||
check_all_running_experiments,
|
||||
check_experiment_guardrails,
|
||||
guardrail_create,
|
||||
guardrail_delete,
|
||||
guardrail_list,
|
||||
guardrail_trigger_list,
|
||||
guardrail_update,
|
||||
)
|
||||
from apps.metrics.services import metric_definition_get
|
||||
from apps.users.auth.bearer import jwt_bearer
|
||||
|
||||
router = Router(tags=["guardrails"], auth=jwt_bearer)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/guardrails/check-all",
|
||||
response={HTTPStatus.OK: GuardrailCheckResultOut},
|
||||
summary="Check guardrails across all running experiments",
|
||||
)
|
||||
def check_all_guardrails(
|
||||
request: HttpRequest,
|
||||
) -> tuple[int, GuardrailCheckResultOut]:
|
||||
results = check_all_running_experiments()
|
||||
return HTTPStatus.OK, GuardrailCheckResultOut(**results)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/experiments/{experiment_id}/guardrails/check",
|
||||
response={HTTPStatus.OK: ExperimentGuardrailCheckResultOut},
|
||||
summary="Manually trigger guardrail check for experiment",
|
||||
)
|
||||
def check_guardrails_for_experiment(
|
||||
request: HttpRequest,
|
||||
experiment_id: UUID,
|
||||
) -> tuple[int, ExperimentGuardrailCheckResultOut]:
|
||||
try:
|
||||
experiment = (
|
||||
Experiment.objects.select_related("flag")
|
||||
.prefetch_related(
|
||||
"variants",
|
||||
"guardrails__metric",
|
||||
)
|
||||
.get(pk=experiment_id)
|
||||
)
|
||||
except Experiment.DoesNotExist:
|
||||
raise Http404 from Experiment.DoesNotExist
|
||||
triggers = check_experiment_guardrails(experiment)
|
||||
return HTTPStatus.OK, ExperimentGuardrailCheckResultOut(
|
||||
experiment_id=str(experiment.pk),
|
||||
triggered=len(triggers),
|
||||
triggers=list(triggers),
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/experiments/{experiment_id}/guardrails",
|
||||
response={HTTPStatus.OK: list[GuardrailOut]},
|
||||
summary="List experiment guardrails",
|
||||
)
|
||||
def list_guardrails(
|
||||
request: HttpRequest,
|
||||
experiment_id: UUID,
|
||||
) -> tuple[int, list[GuardrailOut]]:
|
||||
try:
|
||||
experiment = Experiment.objects.get(pk=experiment_id)
|
||||
except Experiment.DoesNotExist:
|
||||
raise Http404 from Experiment.DoesNotExist
|
||||
grs = guardrail_list(experiment)
|
||||
return HTTPStatus.OK, [GuardrailOut.from_guardrail(g) for g in grs]
|
||||
|
||||
|
||||
@router.post(
|
||||
"/experiments/{experiment_id}/guardrails",
|
||||
response={HTTPStatus.CREATED: GuardrailOut},
|
||||
summary="Create a guardrail for experiment",
|
||||
)
|
||||
def create_guardrail(
|
||||
request: HttpRequest,
|
||||
experiment_id: UUID,
|
||||
payload: GuardrailCreateIn,
|
||||
) -> tuple[int, GuardrailOut]:
|
||||
try:
|
||||
experiment = Experiment.objects.get(pk=experiment_id)
|
||||
except Experiment.DoesNotExist:
|
||||
raise Http404 from Experiment.DoesNotExist
|
||||
metric = metric_definition_get(payload.metric_id)
|
||||
if not metric:
|
||||
raise Http404
|
||||
g = guardrail_create(
|
||||
experiment=experiment,
|
||||
metric=metric,
|
||||
threshold=payload.threshold,
|
||||
observation_window_minutes=payload.observation_window_minutes,
|
||||
action=payload.action,
|
||||
)
|
||||
g = Guardrail.objects.select_related("metric").get(pk=g.pk)
|
||||
return HTTPStatus.CREATED, GuardrailOut.from_guardrail(g)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/experiments/{experiment_id}/guardrail-triggers",
|
||||
response={HTTPStatus.OK: list[GuardrailTriggerOut]},
|
||||
summary="List guardrail triggers (audit)",
|
||||
)
|
||||
def list_guardrail_triggers(
|
||||
request: HttpRequest,
|
||||
experiment_id: UUID,
|
||||
) -> tuple[int, list[GuardrailTriggerOut]]:
|
||||
try:
|
||||
experiment = Experiment.objects.get(pk=experiment_id)
|
||||
except Experiment.DoesNotExist:
|
||||
raise Http404 from Experiment.DoesNotExist
|
||||
triggers = guardrail_trigger_list(experiment)
|
||||
return HTTPStatus.OK, [
|
||||
GuardrailTriggerOut.model_validate(trigger) for trigger in triggers
|
||||
]
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/experiments/{experiment_id}/guardrails/{guardrail_id}",
|
||||
response={HTTPStatus.OK: GuardrailOut},
|
||||
summary="Update a guardrail",
|
||||
)
|
||||
def update_guardrail(
|
||||
request: HttpRequest,
|
||||
experiment_id: UUID,
|
||||
guardrail_id: UUID,
|
||||
payload: GuardrailUpdateIn,
|
||||
) -> tuple[int, GuardrailOut]:
|
||||
try:
|
||||
g = Guardrail.objects.select_related("metric", "experiment").get(
|
||||
pk=guardrail_id,
|
||||
experiment_id=experiment_id,
|
||||
)
|
||||
except Guardrail.DoesNotExist:
|
||||
raise Http404 from Guardrail.DoesNotExist
|
||||
g = guardrail_update(
|
||||
guardrail=g,
|
||||
**payload.dict(exclude_unset=True),
|
||||
)
|
||||
g = Guardrail.objects.select_related("metric").get(pk=g.pk)
|
||||
return HTTPStatus.OK, GuardrailOut.from_guardrail(g)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/experiments/{experiment_id}/guardrails/{guardrail_id}",
|
||||
response={HTTPStatus.NO_CONTENT: None},
|
||||
summary="Delete a guardrail",
|
||||
)
|
||||
def delete_guardrail_endpoint(
|
||||
request: HttpRequest,
|
||||
experiment_id: UUID,
|
||||
guardrail_id: UUID,
|
||||
) -> tuple[int, None]:
|
||||
try:
|
||||
g = Guardrail.objects.select_related("experiment").get(
|
||||
pk=guardrail_id,
|
||||
experiment_id=experiment_id,
|
||||
)
|
||||
except Guardrail.DoesNotExist:
|
||||
raise Http404 from Guardrail.DoesNotExist
|
||||
guardrail_delete(guardrail=g)
|
||||
return HTTPStatus.NO_CONTENT, None
|
||||
@@ -0,0 +1,100 @@
|
||||
from decimal import Decimal
|
||||
from typing import ClassVar
|
||||
from uuid import UUID
|
||||
|
||||
from ninja import Field, ModelSchema, Schema
|
||||
|
||||
from apps.guardrails.models import (
|
||||
Guardrail,
|
||||
GuardrailAction,
|
||||
GuardrailTrigger,
|
||||
)
|
||||
|
||||
|
||||
class GuardrailCreateIn(Schema):
|
||||
metric_id: UUID
|
||||
threshold: Decimal
|
||||
observation_window_minutes: int = Field(60, gt=0)
|
||||
action: GuardrailAction = GuardrailAction.PAUSE
|
||||
|
||||
|
||||
class GuardrailUpdateIn(Schema):
|
||||
threshold: Decimal | None = None
|
||||
observation_window_minutes: int | None = Field(None, gt=0)
|
||||
action: GuardrailAction | None = None
|
||||
is_active: bool | None = None
|
||||
|
||||
|
||||
class MetricBriefOut(Schema):
|
||||
id: UUID
|
||||
key: str
|
||||
name: str
|
||||
|
||||
|
||||
class GuardrailOut(ModelSchema):
|
||||
metric: MetricBriefOut
|
||||
|
||||
class Meta:
|
||||
model = Guardrail
|
||||
fields: ClassVar[tuple[str, ...]] = (
|
||||
Guardrail.id.field.name,
|
||||
Guardrail.threshold.field.name,
|
||||
Guardrail.observation_window_minutes.field.name,
|
||||
Guardrail.action.field.name,
|
||||
Guardrail.is_active.field.name,
|
||||
Guardrail.created_at.field.name,
|
||||
Guardrail.updated_at.field.name,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_guardrail(cls, g: Guardrail) -> "GuardrailOut":
|
||||
return cls(
|
||||
id=g.pk,
|
||||
metric=MetricBriefOut(
|
||||
id=g.metric.pk,
|
||||
key=g.metric.key,
|
||||
name=g.metric.name,
|
||||
),
|
||||
threshold=g.threshold,
|
||||
observation_window_minutes=g.observation_window_minutes,
|
||||
action=g.action,
|
||||
is_active=g.is_active,
|
||||
created_at=g.created_at,
|
||||
updated_at=g.updated_at,
|
||||
)
|
||||
|
||||
|
||||
class GuardrailTriggerOut(ModelSchema):
|
||||
class Meta:
|
||||
model = GuardrailTrigger
|
||||
fields: ClassVar[tuple[str, ...]] = (
|
||||
GuardrailTrigger.id.field.name,
|
||||
GuardrailTrigger.metric_key.field.name,
|
||||
GuardrailTrigger.threshold.field.name,
|
||||
GuardrailTrigger.actual_value.field.name,
|
||||
GuardrailTrigger.observation_window_minutes.field.name,
|
||||
GuardrailTrigger.action.field.name,
|
||||
GuardrailTrigger.triggered_at.field.name,
|
||||
GuardrailTrigger.created_at.field.name,
|
||||
)
|
||||
|
||||
|
||||
class GuardrailCheckTriggerOut(Schema):
|
||||
experiment_id: str
|
||||
experiment_name: str
|
||||
metric_key: str
|
||||
threshold: str
|
||||
actual_value: str
|
||||
action: str
|
||||
|
||||
|
||||
class GuardrailCheckResultOut(Schema):
|
||||
checked: int
|
||||
triggered: int
|
||||
triggers: list[GuardrailCheckTriggerOut]
|
||||
|
||||
|
||||
class ExperimentGuardrailCheckResultOut(Schema):
|
||||
experiment_id: str
|
||||
triggered: int
|
||||
triggers: list[GuardrailTriggerOut]
|
||||
@@ -0,0 +1,280 @@
|
||||
import json
|
||||
import uuid
|
||||
from typing import override
|
||||
|
||||
from django.test import Client, TestCase
|
||||
from django.urls import reverse
|
||||
|
||||
from apps.experiments.tests.helpers import make_experiment
|
||||
from apps.guardrails.models import GuardrailAction
|
||||
from apps.metrics.models import MetricType
|
||||
from apps.metrics.services import metric_definition_create
|
||||
from apps.reviews.tests.helpers import make_admin
|
||||
from apps.users.tests.helpers import auth_header
|
||||
|
||||
|
||||
class GuardrailCRUDAPITest(TestCase):
|
||||
@override
|
||||
def setUp(self) -> None:
|
||||
self.client = Client()
|
||||
self.admin = make_admin("_gcrud")
|
||||
self.auth = auth_header(self.admin)
|
||||
self.experiment = make_experiment(suffix="_gcrud")
|
||||
self.metric = metric_definition_create(
|
||||
key="gcrud_errors",
|
||||
name="Error Rate",
|
||||
metric_type=MetricType.RATIO,
|
||||
calculation_rule={
|
||||
"type": "ratio",
|
||||
"numerator_event": "error",
|
||||
"denominator_event": "request",
|
||||
},
|
||||
)
|
||||
|
||||
def _create_guardrail(self, **overrides):
|
||||
payload = {
|
||||
"metric_id": str(self.metric.pk),
|
||||
"threshold": "0.05",
|
||||
"observation_window_minutes": 30,
|
||||
"action": GuardrailAction.PAUSE,
|
||||
}
|
||||
payload.update(overrides)
|
||||
return self.client.post(
|
||||
reverse(
|
||||
"api-1:create_guardrail",
|
||||
args=[self.experiment.pk],
|
||||
),
|
||||
data=json.dumps(payload),
|
||||
content_type="application/json",
|
||||
HTTP_AUTHORIZATION=self.auth,
|
||||
)
|
||||
|
||||
def test_create_guardrail(self) -> None:
|
||||
resp = self._create_guardrail()
|
||||
self.assertEqual(resp.status_code, 201)
|
||||
data = resp.json()
|
||||
self.assertEqual(data["metric"]["key"], "gcrud_errors")
|
||||
self.assertEqual(float(data["threshold"]), 0.05)
|
||||
self.assertEqual(data["observation_window_minutes"], 30)
|
||||
self.assertEqual(data["action"], GuardrailAction.PAUSE)
|
||||
self.assertTrue(data["is_active"])
|
||||
|
||||
def test_create_guardrail_with_rollback_action(self) -> None:
|
||||
resp = self._create_guardrail(action=GuardrailAction.ROLLBACK)
|
||||
self.assertEqual(resp.status_code, 201)
|
||||
self.assertEqual(resp.json()["action"], GuardrailAction.ROLLBACK)
|
||||
|
||||
def test_create_guardrail_nonexistent_experiment(self) -> None:
|
||||
resp = self.client.post(
|
||||
reverse(
|
||||
"api-1:create_guardrail",
|
||||
args=[uuid.uuid4()],
|
||||
),
|
||||
data=json.dumps(
|
||||
{
|
||||
"metric_id": str(self.metric.pk),
|
||||
"threshold": "0.05",
|
||||
}
|
||||
),
|
||||
content_type="application/json",
|
||||
HTTP_AUTHORIZATION=self.auth,
|
||||
)
|
||||
self.assertEqual(resp.status_code, 404)
|
||||
|
||||
def test_create_guardrail_nonexistent_metric(self) -> None:
|
||||
resp = self.client.post(
|
||||
reverse(
|
||||
"api-1:create_guardrail",
|
||||
args=[self.experiment.pk],
|
||||
),
|
||||
data=json.dumps(
|
||||
{
|
||||
"metric_id": str(uuid.uuid4()),
|
||||
"threshold": "0.05",
|
||||
}
|
||||
),
|
||||
content_type="application/json",
|
||||
HTTP_AUTHORIZATION=self.auth,
|
||||
)
|
||||
self.assertEqual(resp.status_code, 404)
|
||||
|
||||
def test_list_guardrails(self) -> None:
|
||||
self._create_guardrail()
|
||||
resp = self.client.get(
|
||||
reverse(
|
||||
"api-1:list_guardrails",
|
||||
args=[self.experiment.pk],
|
||||
),
|
||||
HTTP_AUTHORIZATION=self.auth,
|
||||
)
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
data = resp.json()
|
||||
self.assertEqual(len(data), 1)
|
||||
self.assertEqual(data[0]["metric"]["key"], "gcrud_errors")
|
||||
|
||||
def test_list_guardrails_empty(self) -> None:
|
||||
resp = self.client.get(
|
||||
reverse(
|
||||
"api-1:list_guardrails",
|
||||
args=[self.experiment.pk],
|
||||
),
|
||||
HTTP_AUTHORIZATION=self.auth,
|
||||
)
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
self.assertEqual(resp.json(), [])
|
||||
|
||||
def test_list_guardrails_nonexistent_experiment(self) -> None:
|
||||
resp = self.client.get(
|
||||
reverse(
|
||||
"api-1:list_guardrails",
|
||||
args=[uuid.uuid4()],
|
||||
),
|
||||
HTTP_AUTHORIZATION=self.auth,
|
||||
)
|
||||
self.assertEqual(resp.status_code, 404)
|
||||
|
||||
def test_update_guardrail(self) -> None:
|
||||
create_resp = self._create_guardrail()
|
||||
guardrail_id = create_resp.json()["id"]
|
||||
resp = self.client.patch(
|
||||
reverse(
|
||||
"api-1:update_guardrail",
|
||||
args=[self.experiment.pk, guardrail_id],
|
||||
),
|
||||
data=json.dumps({"threshold": "0.10"}),
|
||||
content_type="application/json",
|
||||
HTTP_AUTHORIZATION=self.auth,
|
||||
)
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
self.assertEqual(float(resp.json()["threshold"]), 0.10)
|
||||
|
||||
def test_update_guardrail_action(self) -> None:
|
||||
create_resp = self._create_guardrail()
|
||||
guardrail_id = create_resp.json()["id"]
|
||||
resp = self.client.patch(
|
||||
reverse(
|
||||
"api-1:update_guardrail",
|
||||
args=[self.experiment.pk, guardrail_id],
|
||||
),
|
||||
data=json.dumps({"action": GuardrailAction.ROLLBACK}),
|
||||
content_type="application/json",
|
||||
HTTP_AUTHORIZATION=self.auth,
|
||||
)
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
self.assertEqual(resp.json()["action"], GuardrailAction.ROLLBACK)
|
||||
|
||||
def test_update_guardrail_nonexistent(self) -> None:
|
||||
resp = self.client.patch(
|
||||
reverse(
|
||||
"api-1:update_guardrail",
|
||||
args=[self.experiment.pk, uuid.uuid4()],
|
||||
),
|
||||
data=json.dumps({"threshold": "0.10"}),
|
||||
content_type="application/json",
|
||||
HTTP_AUTHORIZATION=self.auth,
|
||||
)
|
||||
self.assertEqual(resp.status_code, 404)
|
||||
|
||||
def test_delete_guardrail(self) -> None:
|
||||
create_resp = self._create_guardrail()
|
||||
guardrail_id = create_resp.json()["id"]
|
||||
resp = self.client.delete(
|
||||
reverse(
|
||||
"api-1:delete_guardrail_endpoint",
|
||||
args=[self.experiment.pk, guardrail_id],
|
||||
),
|
||||
HTTP_AUTHORIZATION=self.auth,
|
||||
)
|
||||
self.assertEqual(resp.status_code, 204)
|
||||
list_resp = self.client.get(
|
||||
reverse(
|
||||
"api-1:list_guardrails",
|
||||
args=[self.experiment.pk],
|
||||
),
|
||||
HTTP_AUTHORIZATION=self.auth,
|
||||
)
|
||||
self.assertEqual(list_resp.json(), [])
|
||||
|
||||
def test_delete_guardrail_nonexistent(self) -> None:
|
||||
resp = self.client.delete(
|
||||
reverse(
|
||||
"api-1:delete_guardrail_endpoint",
|
||||
args=[self.experiment.pk, uuid.uuid4()],
|
||||
),
|
||||
HTTP_AUTHORIZATION=self.auth,
|
||||
)
|
||||
self.assertEqual(resp.status_code, 404)
|
||||
|
||||
|
||||
class GuardrailTriggerAPITest(TestCase):
|
||||
@override
|
||||
def setUp(self) -> None:
|
||||
self.client = Client()
|
||||
self.admin = make_admin("_gtrig")
|
||||
self.auth = auth_header(self.admin)
|
||||
self.experiment = make_experiment(suffix="_gtrig")
|
||||
|
||||
def test_list_triggers_empty(self) -> None:
|
||||
resp = self.client.get(
|
||||
reverse(
|
||||
"api-1:list_guardrail_triggers",
|
||||
args=[self.experiment.pk],
|
||||
),
|
||||
HTTP_AUTHORIZATION=self.auth,
|
||||
)
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
self.assertEqual(resp.json(), [])
|
||||
|
||||
def test_list_triggers_nonexistent_experiment(self) -> None:
|
||||
resp = self.client.get(
|
||||
reverse(
|
||||
"api-1:list_guardrail_triggers",
|
||||
args=[uuid.uuid4()],
|
||||
),
|
||||
HTTP_AUTHORIZATION=self.auth,
|
||||
)
|
||||
self.assertEqual(resp.status_code, 404)
|
||||
|
||||
|
||||
class GuardrailCheckAPITest(TestCase):
|
||||
@override
|
||||
def setUp(self) -> None:
|
||||
self.client = Client()
|
||||
self.admin = make_admin("_gchk")
|
||||
self.auth = auth_header(self.admin)
|
||||
self.experiment = make_experiment(suffix="_gchk")
|
||||
|
||||
def test_check_experiment_guardrails_not_running(self) -> None:
|
||||
resp = self.client.post(
|
||||
reverse(
|
||||
"api-1:check_guardrails_for_experiment",
|
||||
args=[self.experiment.pk],
|
||||
),
|
||||
HTTP_AUTHORIZATION=self.auth,
|
||||
)
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
data = resp.json()
|
||||
self.assertEqual(data["experiment_id"], str(self.experiment.pk))
|
||||
self.assertEqual(data["triggered"], 0)
|
||||
self.assertEqual(data["triggers"], [])
|
||||
|
||||
def test_check_experiment_guardrails_nonexistent(self) -> None:
|
||||
resp = self.client.post(
|
||||
reverse(
|
||||
"api-1:check_guardrails_for_experiment",
|
||||
args=[uuid.uuid4()],
|
||||
),
|
||||
HTTP_AUTHORIZATION=self.auth,
|
||||
)
|
||||
self.assertEqual(resp.status_code, 404)
|
||||
|
||||
def test_check_all_guardrails(self) -> None:
|
||||
resp = self.client.post(
|
||||
reverse("api-1:check_all_guardrails"),
|
||||
HTTP_AUTHORIZATION=self.auth,
|
||||
)
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
data = resp.json()
|
||||
self.assertEqual(data["checked"], 0)
|
||||
self.assertEqual(data["triggered"], 0)
|
||||
self.assertEqual(data["triggers"], [])
|
||||
Reference in New Issue
Block a user