# diff --git a/src/backend/api/v1/guardrails/__init__.py b/src/backend/api/v1/guardrails/__init__.py
# new file mode 100644
# index 0000000..e69de29
# diff --git a/src/backend/api/v1/guardrails/apps.py b/src/backend/api/v1/guardrails/apps.py
# new file mode 100644
# index 0000000..27782f1
# --- /dev/null
# +++ b/src/backend/api/v1/guardrails/apps.py
# @@ -0,0 +1,6 @@
from django.apps import AppConfig


class GuardrailsApiConfig(AppConfig):
    """Django app configuration for the ``api.v1.guardrails`` package."""

    name = "api.v1.guardrails"
    # Explicit label: Django would otherwise derive "guardrails" from the
    # last component of ``name``. NOTE(review): presumably chosen so it does
    # not clash with the ``apps.guardrails`` app -- confirm.
    label = "api_v1_guardrails"


# diff --git a/src/backend/api/v1/guardrails/endpoints.py b/src/backend/api/v1/guardrails/endpoints.py
# new file mode 100644
# index 0000000..1c93229
# --- /dev/null
# +++ b/src/backend/api/v1/guardrails/endpoints.py
# @@ -0,0 +1,180 @@
from http import HTTPStatus
from uuid import UUID

from django.http import Http404, HttpRequest
from ninja import Router

from api.v1.guardrails.schemas import (
    ExperimentGuardrailCheckResultOut,
    GuardrailCheckResultOut,
    GuardrailCreateIn,
    GuardrailOut,
    GuardrailTriggerOut,
    GuardrailUpdateIn,
)
from apps.experiments.models import Experiment
from apps.guardrails.models import Guardrail
from apps.guardrails.services import (
    check_all_running_experiments,
    check_experiment_guardrails,
    guardrail_create,
    guardrail_delete,
    guardrail_list,
    guardrail_trigger_list,
    guardrail_update,
)
from apps.metrics.services import metric_definition_get
from apps.users.auth.bearer import jwt_bearer

# Every route registered on this router requires JWT bearer authentication.
router = Router(tags=["guardrails"], auth=jwt_bearer)


@router.post(
    "/guardrails/check-all",
    response={HTTPStatus.OK: GuardrailCheckResultOut},
    summary="Check guardrails across all running experiments",
)
def check_all_guardrails(
    request: HttpRequest,
) -> tuple[int, GuardrailCheckResultOut]:
    """Run the guardrail check across all running experiments.

    Delegates to ``check_all_running_experiments`` and wraps the mapping it
    returns in ``GuardrailCheckResultOut``.
    """
    results = check_all_running_experiments()
    return HTTPStatus.OK, GuardrailCheckResultOut(**results)


@router.post(
    "/experiments/{experiment_id}/guardrails/check",
    response={HTTPStatus.OK: ExperimentGuardrailCheckResultOut},
    summary="Manually trigger guardrail check for experiment",
)
def check_guardrails_for_experiment(
    request: HttpRequest,
    experiment_id: UUID,
) -> tuple[int, ExperimentGuardrailCheckResultOut]:
    """Check the guardrails of a single experiment on demand.

    Responds with the triggers produced by this check and their count.

    Raises:
        Http404: No experiment exists with ``experiment_id``.
    """
    try:
        # Load the flag, variants and guardrail metrics together with the
        # experiment before handing it to the guardrail service.
        experiment = (
            Experiment.objects.select_related("flag")
            .prefetch_related(
                "variants",
                "guardrails__metric",
            )
            .get(pk=experiment_id)
        )
    except Experiment.DoesNotExist as exc:
        # Chain from the caught instance so the original lookup error stays
        # attached as the cause (chaining from the class loses it).
        raise Http404 from exc
    # Materialise once so the count and the payload describe the same items,
    # whatever iterable the service returns.
    triggers = list(check_experiment_guardrails(experiment))
    return HTTPStatus.OK, ExperimentGuardrailCheckResultOut(
        experiment_id=str(experiment.pk),
        triggered=len(triggers),
        triggers=triggers,
    )


@router.get(
    "/experiments/{experiment_id}/guardrails",
    response={HTTPStatus.OK: list[GuardrailOut]},
    summary="List experiment guardrails",
)
def list_guardrails(
    request: HttpRequest,
    experiment_id: UUID,
) -> tuple[int, list[GuardrailOut]]:
    """List the guardrails configured for an experiment.

    Raises:
        Http404: No experiment exists with ``experiment_id``.
    """
    try:
        experiment = Experiment.objects.get(pk=experiment_id)
    except Experiment.DoesNotExist as exc:
        raise Http404 from exc
    grs = guardrail_list(experiment)
    return HTTPStatus.OK, [GuardrailOut.from_guardrail(g) for g in grs]


@router.post(
    "/experiments/{experiment_id}/guardrails",
    response={HTTPStatus.CREATED: GuardrailOut},
    summary="Create a guardrail for experiment",
)
def create_guardrail(
    request: HttpRequest,
    experiment_id: UUID,
    payload: GuardrailCreateIn,
) -> tuple[int, GuardrailOut]:
    """Create a guardrail on an experiment for the given metric.

    Raises:
        Http404: The experiment does not exist, or ``payload.metric_id``
            does not resolve to a metric definition.
    """
    try:
        experiment = Experiment.objects.get(pk=experiment_id)
    except Experiment.DoesNotExist as exc:
        raise Http404 from exc
    metric = metric_definition_get(payload.metric_id)
    if not metric:
        raise Http404
    g = guardrail_create(
        experiment=experiment,
        metric=metric,
        threshold=payload.threshold,
        observation_window_minutes=payload.observation_window_minutes,
        action=payload.action,
    )
    # Re-read the row with its metric joined before serialising, since
    # GuardrailOut.from_guardrail reads g.metric.
    g = Guardrail.objects.select_related("metric").get(pk=g.pk)
    return HTTPStatus.CREATED, GuardrailOut.from_guardrail(g)
@router.get(
    "/experiments/{experiment_id}/guardrail-triggers",
    response={HTTPStatus.OK: list[GuardrailTriggerOut]},
    summary="List guardrail triggers (audit)",
)
def list_guardrail_triggers(
    request: HttpRequest,
    experiment_id: UUID,
) -> tuple[int, list[GuardrailTriggerOut]]:
    """List the guardrail triggers recorded for an experiment.

    Raises:
        Http404: No experiment exists with ``experiment_id``.
    """
    try:
        experiment = Experiment.objects.get(pk=experiment_id)
    except Experiment.DoesNotExist as exc:
        # Chain from the caught instance so the original lookup error stays
        # attached as the cause (chaining from the class loses it).
        raise Http404 from exc
    triggers = guardrail_trigger_list(experiment)
    return HTTPStatus.OK, [
        GuardrailTriggerOut.model_validate(trigger) for trigger in triggers
    ]


@router.patch(
    "/experiments/{experiment_id}/guardrails/{guardrail_id}",
    response={HTTPStatus.OK: GuardrailOut},
    summary="Update a guardrail",
)
def update_guardrail(
    request: HttpRequest,
    experiment_id: UUID,
    guardrail_id: UUID,
    payload: GuardrailUpdateIn,
) -> tuple[int, GuardrailOut]:
    """Partially update a guardrail of an experiment.

    Only fields present in the request body are applied. A field sent
    explicitly as ``null`` counts as present and is forwarded as ``None``.

    Raises:
        Http404: No guardrail with ``guardrail_id`` belongs to the
            experiment ``experiment_id``.
    """
    try:
        # Filtering on both ids means a guardrail that belongs to another
        # experiment is reported as not found.
        g = Guardrail.objects.select_related("metric", "experiment").get(
            pk=guardrail_id,
            experiment_id=experiment_id,
        )
    except Guardrail.DoesNotExist as exc:
        raise Http404 from exc
    g = guardrail_update(
        guardrail=g,
        # model_dump is the pydantic v2 spelling, matching the
        # model_validate call used elsewhere in this module.
        **payload.model_dump(exclude_unset=True),
    )
    # Re-read the row with its metric joined before serialising, since
    # GuardrailOut.from_guardrail reads g.metric.
    g = Guardrail.objects.select_related("metric").get(pk=g.pk)
    return HTTPStatus.OK, GuardrailOut.from_guardrail(g)


@router.delete(
    "/experiments/{experiment_id}/guardrails/{guardrail_id}",
    response={HTTPStatus.NO_CONTENT: None},
    summary="Delete a guardrail",
)
def delete_guardrail_endpoint(
    request: HttpRequest,
    experiment_id: UUID,
    guardrail_id: UUID,
) -> tuple[int, None]:
    """Delete a guardrail of an experiment; responds with no content.

    Raises:
        Http404: No guardrail with ``guardrail_id`` belongs to the
            experiment ``experiment_id``.
    """
    try:
        g = Guardrail.objects.select_related("experiment").get(
            pk=guardrail_id,
            experiment_id=experiment_id,
        )
    except Guardrail.DoesNotExist as exc:
        raise Http404 from exc
    guardrail_delete(guardrail=g)
    return HTTPStatus.NO_CONTENT, None


# diff --git a/src/backend/api/v1/guardrails/schemas.py b/src/backend/api/v1/guardrails/schemas.py
# new file mode 100644
# index 0000000..c4626f4
# --- /dev/null
# +++ b/src/backend/api/v1/guardrails/schemas.py
# @@ -0,0 +1,100 @@
from decimal import Decimal
from typing import ClassVar
from uuid import UUID

from ninja import Field, ModelSchema, Schema

from apps.guardrails.models import (
    Guardrail,
    GuardrailAction,
    GuardrailTrigger,
)


# Request body for creating a guardrail. ``metric_id`` and ``threshold`` are
# required; the observation window defaults to 60 and must be greater than 0;
# the action defaults to ``GuardrailAction.PAUSE``.
class GuardrailCreateIn(Schema):
    metric_id: UUID
    threshold: Decimal
    observation_window_minutes: int = Field(60, gt=0)
    action: GuardrailAction = GuardrailAction.PAUSE


# Request body for a partial guardrail update. Every field is optional and
# defaults to None; a supplied observation window must be greater than 0.
class GuardrailUpdateIn(Schema):
    threshold: Decimal | None = None
    observation_window_minutes: int | None = Field(None, gt=0)
    action: GuardrailAction | None = None
    is_active: bool | None = None


# Compact representation of a metric, nested inside GuardrailOut.
class MetricBriefOut(Schema):
    id: UUID
    key: str
    name: str


# Response schema for a guardrail: the selected model fields plus a nested
# brief of the guardrail's metric.
class GuardrailOut(ModelSchema):
    metric: MetricBriefOut

    class Meta:
        model = Guardrail
        # Field names are read from the model's field descriptors rather
        # than spelled out as string literals.
        fields: ClassVar[tuple[str, ...]] = (
            Guardrail.id.field.name,
            Guardrail.threshold.field.name,
            Guardrail.observation_window_minutes.field.name,
            Guardrail.action.field.name,
            Guardrail.is_active.field.name,
            Guardrail.created_at.field.name,
            Guardrail.updated_at.field.name,
        )

    @classmethod
    def from_guardrail(cls, g: Guardrail) -> "GuardrailOut":
        """Build the response schema from a ``Guardrail`` model instance.

        Reads ``g.metric`` to fill the nested ``MetricBriefOut``, so the
        related metric is accessed for every guardrail passed in.
        """
        return cls(
            id=g.pk,
            metric=MetricBriefOut(
                id=g.metric.pk,
                key=g.metric.key,
                name=g.metric.name,
            ),
            threshold=g.threshold,
            observation_window_minutes=g.observation_window_minutes,
            action=g.action,
            is_active=g.is_active,
            created_at=g.created_at,
            updated_at=g.updated_at,
        )


# Response schema for one GuardrailTrigger record, exposing the model fields
# listed in Meta.fields.
class GuardrailTriggerOut(ModelSchema):
    class Meta:
        model = GuardrailTrigger
        fields: ClassVar[tuple[str, ...]] = (
            GuardrailTrigger.id.field.name,
            GuardrailTrigger.metric_key.field.name,
            GuardrailTrigger.threshold.field.name,
            GuardrailTrigger.actual_value.field.name,
            GuardrailTrigger.observation_window_minutes.field.name,
            GuardrailTrigger.action.field.name,
            GuardrailTrigger.triggered_at.field.name,
            GuardrailTrigger.created_at.field.name,
        )


# One trigger entry inside GuardrailCheckResultOut. All values, including
# threshold and actual_value, are carried as strings.
class GuardrailCheckTriggerOut(Schema):
    experiment_id: str
    experiment_name: str
    metric_key: str
    threshold: str
    actual_value: str
    action: str


# Response of the check-all endpoint: the ``checked`` and ``triggered``
# counts plus the list of triggers.
class GuardrailCheckResultOut(Schema):
    checked: int
    triggered: int
    triggers: list[GuardrailCheckTriggerOut]


# Response of the single-experiment check endpoint: the experiment id, the
# number of triggers and the triggers themselves.
class ExperimentGuardrailCheckResultOut(Schema):
    experiment_id: str
    triggered: int
    triggers: list[GuardrailTriggerOut]


# diff --git a/src/backend/api/v1/guardrails/tests/__init__.py b/src/backend/api/v1/guardrails/tests/__init__.py
# new file mode 100644
# index 0000000..e69de29
# diff --git a/src/backend/api/v1/guardrails/tests/test_guardrails_api.py b/src/backend/api/v1/guardrails/tests/test_guardrails_api.py
# new file mode 100644
# index 0000000..7a811b3
# --- /dev/null
# +++ b/src/backend/api/v1/guardrails/tests/test_guardrails_api.py
# @@ -0,0 +1,280 @@
import json
import uuid
from typing import override

from django.test import Client, TestCase
from django.urls import reverse

from apps.experiments.tests.helpers import make_experiment
from apps.guardrails.models import GuardrailAction
from apps.metrics.models import MetricType
from apps.metrics.services import metric_definition_create
from apps.reviews.tests.helpers import make_admin
from apps.users.tests.helpers import auth_header


class GuardrailCRUDAPITest(TestCase):
    """Tests for the create, list, update and delete guardrail endpoints."""

    @override
    def setUp(self) -> None:
        # Each test gets an authenticated admin, one experiment and one
        # ratio metric ("gcrud_errors") to attach guardrails to.
        self.client = Client()
        self.admin = make_admin("_gcrud")
        self.auth = auth_header(self.admin)
        self.experiment = make_experiment(suffix="_gcrud")
        self.metric = metric_definition_create(
            key="gcrud_errors",
            name="Error Rate",
            metric_type=MetricType.RATIO,
            calculation_rule={
                "type": "ratio",
                "numerator_event": "error",
                "denominator_event": "request",
            },
        )

    def _create_guardrail(self, **overrides):
        """POST a guardrail for ``self.experiment`` and return the response.

        ``overrides`` replace or extend keys of the default JSON payload.
        """
        # NOTE(review): json.dumps on GuardrailAction.PAUSE presumably works
        # because GuardrailAction is a str-based enum -- confirm.
        payload = {
            "metric_id": str(self.metric.pk),
            "threshold": "0.05",
            "observation_window_minutes": 30,
            "action": GuardrailAction.PAUSE,
        }
        payload.update(overrides)
        return self.client.post(
            reverse(
                "api-1:create_guardrail",
                args=[self.experiment.pk],
            ),
            data=json.dumps(payload),
            content_type="application/json",
            HTTP_AUTHORIZATION=self.auth,
        )

    def test_create_guardrail(self) -> None:
        resp = self._create_guardrail()
        self.assertEqual(resp.status_code, 201)
        data = resp.json()
        self.assertEqual(data["metric"]["key"], "gcrud_errors")
        # Compared through float() so the check does not depend on how the
        # Decimal threshold is rendered in the JSON response.
        self.assertEqual(float(data["threshold"]), 0.05)
        self.assertEqual(data["observation_window_minutes"], 30)
        self.assertEqual(data["action"], GuardrailAction.PAUSE)
        self.assertTrue(data["is_active"])

    def test_create_guardrail_with_rollback_action(self) -> None:
        resp = self._create_guardrail(action=GuardrailAction.ROLLBACK)
        self.assertEqual(resp.status_code, 201)
        self.assertEqual(resp.json()["action"], GuardrailAction.ROLLBACK)

    def test_create_guardrail_nonexistent_experiment(self) -> None:
        # A random UUID stands in for an experiment that does not exist.
        resp = self.client.post(
            reverse(
                "api-1:create_guardrail",
                args=[uuid.uuid4()],
            ),
            data=json.dumps(
                {
                    "metric_id": str(self.metric.pk),
                    "threshold": "0.05",
                }
            ),
            content_type="application/json",
            HTTP_AUTHORIZATION=self.auth,
        )
        self.assertEqual(resp.status_code, 404)

    def test_create_guardrail_nonexistent_metric(self) -> None:
        # Valid experiment, but metric_id is a random UUID.
        resp = self.client.post(
            reverse(
                "api-1:create_guardrail",
                args=[self.experiment.pk],
            ),
            data=json.dumps(
                {
                    "metric_id": str(uuid.uuid4()),
                    "threshold": "0.05",
                }
            ),
            content_type="application/json",
            HTTP_AUTHORIZATION=self.auth,
        )
        self.assertEqual(resp.status_code, 404)

    def test_list_guardrails(self) -> None:
        self._create_guardrail()
        resp = self.client.get(
            reverse(
                "api-1:list_guardrails",
                args=[self.experiment.pk],
            ),
            HTTP_AUTHORIZATION=self.auth,
        )
        self.assertEqual(resp.status_code, 200)
        data = resp.json()
        self.assertEqual(len(data), 1)
        self.assertEqual(data[0]["metric"]["key"], "gcrud_errors")

    def test_list_guardrails_empty(self) -> None:
        resp = self.client.get(
            reverse(
                "api-1:list_guardrails",
                args=[self.experiment.pk],
            ),
            HTTP_AUTHORIZATION=self.auth,
        )
        self.assertEqual(resp.status_code, 200)
        self.assertEqual(resp.json(), [])

    def test_list_guardrails_nonexistent_experiment(self) -> None:
        resp = self.client.get(
            reverse(
                "api-1:list_guardrails",
                args=[uuid.uuid4()],
            ),
            HTTP_AUTHORIZATION=self.auth,
        )
        self.assertEqual(resp.status_code, 404)

    def test_update_guardrail(self) -> None:
        create_resp = self._create_guardrail()
        guardrail_id = create_resp.json()["id"]
        resp = self.client.patch(
            reverse(
                "api-1:update_guardrail",
                args=[self.experiment.pk, guardrail_id],
            ),
            data=json.dumps({"threshold": "0.10"}),
            content_type="application/json",
            HTTP_AUTHORIZATION=self.auth,
        )
        self.assertEqual(resp.status_code, 200)
        self.assertEqual(float(resp.json()["threshold"]), 0.10)

    def test_update_guardrail_action(self) -> None:
        create_resp = self._create_guardrail()
        guardrail_id = create_resp.json()["id"]
        resp = self.client.patch(
            reverse(
                "api-1:update_guardrail",
                args=[self.experiment.pk, guardrail_id],
            ),
            data=json.dumps({"action": GuardrailAction.ROLLBACK}),
            content_type="application/json",
            HTTP_AUTHORIZATION=self.auth,
        )
        self.assertEqual(resp.status_code, 200)
        self.assertEqual(resp.json()["action"], GuardrailAction.ROLLBACK)

    def test_update_guardrail_nonexistent(self) -> None:
        resp = self.client.patch(
            reverse(
                "api-1:update_guardrail",
                args=[self.experiment.pk, uuid.uuid4()],
            ),
            data=json.dumps({"threshold": "0.10"}),
            content_type="application/json",
            HTTP_AUTHORIZATION=self.auth,
        )
        self.assertEqual(resp.status_code, 404)

    def test_delete_guardrail(self) -> None:
        create_resp = self._create_guardrail()
        guardrail_id = create_resp.json()["id"]
        resp = self.client.delete(
            reverse(
                "api-1:delete_guardrail_endpoint",
                args=[self.experiment.pk, guardrail_id],
            ),
            HTTP_AUTHORIZATION=self.auth,
        )
        self.assertEqual(resp.status_code, 204)
        # Listing afterwards confirms the guardrail is no longer returned.
        list_resp = self.client.get(
            reverse(
                "api-1:list_guardrails",
                args=[self.experiment.pk],
            ),
            HTTP_AUTHORIZATION=self.auth,
        )
        self.assertEqual(list_resp.json(), [])

    def test_delete_guardrail_nonexistent(self) -> None:
        resp = self.client.delete(
            reverse(
                "api-1:delete_guardrail_endpoint",
                args=[self.experiment.pk, uuid.uuid4()],
            ),
            HTTP_AUTHORIZATION=self.auth,
        )
        self.assertEqual(resp.status_code, 404)


class GuardrailTriggerAPITest(TestCase):
    """Tests for the guardrail-trigger listing endpoint."""

    @override
    def setUp(self) -> None:
        self.client = Client()
        self.admin = make_admin("_gtrig")
        self.auth = auth_header(self.admin)
        self.experiment = make_experiment(suffix="_gtrig")

    def test_list_triggers_empty(self) -> None:
        resp = self.client.get(
            reverse(
                "api-1:list_guardrail_triggers",
                args=[self.experiment.pk],
            ),
            HTTP_AUTHORIZATION=self.auth,
        )
        self.assertEqual(resp.status_code, 200)
        self.assertEqual(resp.json(), [])

    def test_list_triggers_nonexistent_experiment(self) -> None:
        resp = self.client.get(
            reverse(
                "api-1:list_guardrail_triggers",
                args=[uuid.uuid4()],
            ),
            HTTP_AUTHORIZATION=self.auth,
        )
        self.assertEqual(resp.status_code, 404)


class GuardrailCheckAPITest(TestCase):
    """Tests for the per-experiment and check-all guardrail endpoints."""

    @override
    def setUp(self) -> None:
        self.client = Client()
        self.admin = make_admin("_gchk")
        self.auth = auth_header(self.admin)
        self.experiment = make_experiment(suffix="_gchk")

    def test_check_experiment_guardrails_not_running(self) -> None:
        # NOTE(review): per the test name, make_experiment presumably yields
        # an experiment that is not running -- confirm against the helper.
        resp = self.client.post(
            reverse(
                "api-1:check_guardrails_for_experiment",
                args=[self.experiment.pk],
            ),
            HTTP_AUTHORIZATION=self.auth,
        )
        self.assertEqual(resp.status_code, 200)
        data = resp.json()
        self.assertEqual(data["experiment_id"], str(self.experiment.pk))
        self.assertEqual(data["triggered"], 0)
        self.assertEqual(data["triggers"], [])

    def test_check_experiment_guardrails_nonexistent(self) -> None:
        resp = self.client.post(
            reverse(
                "api-1:check_guardrails_for_experiment",
                args=[uuid.uuid4()],
            ),
            HTTP_AUTHORIZATION=self.auth,
        )
        self.assertEqual(resp.status_code, 404)

    def test_check_all_guardrails(self) -> None:
        # The experiment from setUp is expected not to be counted.
        resp = self.client.post(
            reverse("api-1:check_all_guardrails"),
            HTTP_AUTHORIZATION=self.auth,
        )
        self.assertEqual(resp.status_code, 200)
        data = resp.json()
        self.assertEqual(data["checked"], 0)
        self.assertEqual(data["triggered"], 0)
        self.assertEqual(data["triggers"], [])