fix(): fixed bugs with cache invalidation, notifications and guardrails

This commit is contained in:
ITQ
2026-02-24 13:16:29 +03:00
parent b27254e2fb
commit 7bf3ccee5c
14 changed files with 290 additions and 82 deletions
+10 -2
View File
@@ -86,6 +86,14 @@ def _extract_django_rejected_value(
return None return None
def _extract_django_issue(
error: django.core.exceptions.ValidationError,
) -> str:
if error.messages:
return str(error.messages[0])
return str(error.message)
def handle_django_validation_error( def handle_django_validation_error(
request: HttpRequest, request: HttpRequest,
exc: django.core.exceptions.ValidationError, exc: django.core.exceptions.ValidationError,
@@ -99,7 +107,7 @@ def handle_django_validation_error(
field_errors_data.extend( field_errors_data.extend(
{ {
"field": field, "field": field,
"issue": str(error.message), "issue": _extract_django_issue(error),
"rejected_value": _extract_django_rejected_value(error), "rejected_value": _extract_django_rejected_value(error),
} }
for error in errors for error in errors
@@ -108,7 +116,7 @@ def handle_django_validation_error(
field_errors_data.extend( field_errors_data.extend(
{ {
"field": "non_field_error", "field": "non_field_error",
"issue": str(error.message), "issue": _extract_django_issue(error),
"rejected_value": _extract_django_rejected_value(error), "rejected_value": _extract_django_rejected_value(error),
} }
for error in exc.error_list for error in exc.error_list
@@ -159,3 +159,14 @@ class ExperimentReportAPITest(TestCase):
HTTP_AUTHORIZATION=self.auth, HTTP_AUTHORIZATION=self.auth,
) )
self.assertEqual(resp.status_code, 404) self.assertEqual(resp.status_code, 404)
def test_report_invalid_start_date_returns_422(self) -> None:
resp = self.client.get(
reverse(
"api-1:get_experiment_report",
args=[self.experiment.pk],
),
{"start_date": "not-a-date"},
HTTP_AUTHORIZATION=self.auth,
)
self.assertEqual(resp.status_code, 422)
+5 -1
View File
@@ -1,6 +1,7 @@
from decimal import Decimal from decimal import Decimal
from typing import Any from typing import Any
from django.core.cache import cache
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.db import transaction from django.db import transaction
@@ -305,6 +306,7 @@ def _transition(
) )
experiment.status = new_status experiment.status = new_status
experiment.save(update_fields=["status", "updated_at"]) experiment.save(update_fields=["status", "updated_at"])
cache.delete(f"active_exp:{experiment.flag_id}")
ExperimentLog.objects.create( ExperimentLog.objects.create(
experiment=experiment, experiment=experiment,
log_type=log_type, log_type=log_type,
@@ -464,11 +466,13 @@ def experiment_resume(*, experiment: Experiment, user: User) -> Experiment:
ensure_owner_or_admin(experiment, user) ensure_owner_or_admin(experiment, user)
_validate_no_active_flag_conflict(experiment) _validate_no_active_flag_conflict(experiment)
validate_domain_conflicts(experiment) validate_domain_conflicts(experiment)
return _transition( experiment = _transition(
experiment, experiment,
ExperimentStatus.RUNNING, ExperimentStatus.RUNNING,
user, user,
) )
_notify("experiment_resumed", experiment)
return experiment
@transaction.atomic @transaction.atomic
@@ -349,6 +349,26 @@ class RejectAndReqChangesTest(TestCase):
exp = experiment_reopen(experiment=exp, user=self.experimenter) exp = experiment_reopen(experiment=exp, user=self.experimenter)
self.assertEqual(exp.status, ExperimentStatus.DRAFT) self.assertEqual(exp.status, ExperimentStatus.DRAFT)
def test_unauthorized_reject_raises(self) -> None:
outsider = make_approver("_rr_out")
with self.assertRaises(ValidationError) as ctx:
experiment_reject(
experiment=self.exp,
user=outsider,
comment="nope",
)
self.assertIn("user", ctx.exception.message_dict)
def test_unauthorized_request_changes_raises(self) -> None:
outsider = make_approver("_rr_rc")
with self.assertRaises(ValidationError) as ctx:
experiment_request_changes(
experiment=self.exp,
user=outsider,
comment="nope",
)
self.assertIn("user", ctx.exception.message_dict)
class LifecycleFlowTest(TestCase): class LifecycleFlowTest(TestCase):
@override @override
+5 -3
View File
@@ -5,6 +5,8 @@ from django.utils.translation import gettext_lazy as _
from apps.core.models import BaseModel from apps.core.models import BaseModel
METRIC_DECIMAL_PLACES = 4
class GuardrailAction(models.TextChoices): class GuardrailAction(models.TextChoices):
PAUSE = "pause", _("Pause experiment") PAUSE = "pause", _("Pause experiment")
@@ -26,7 +28,7 @@ class Guardrail(BaseModel):
) )
threshold = models.DecimalField( threshold = models.DecimalField(
max_digits=10, max_digits=10,
decimal_places=4, decimal_places=METRIC_DECIMAL_PLACES,
verbose_name=_("threshold"), verbose_name=_("threshold"),
) )
observation_window_minutes = models.PositiveIntegerField( observation_window_minutes = models.PositiveIntegerField(
@@ -89,12 +91,12 @@ class GuardrailTrigger(BaseModel):
) )
threshold = models.DecimalField( threshold = models.DecimalField(
max_digits=10, max_digits=10,
decimal_places=4, decimal_places=METRIC_DECIMAL_PLACES,
verbose_name=_("threshold"), verbose_name=_("threshold"),
) )
actual_value = models.DecimalField( actual_value = models.DecimalField(
max_digits=10, max_digits=10,
decimal_places=4, decimal_places=METRIC_DECIMAL_PLACES,
verbose_name=_("actual value"), verbose_name=_("actual value"),
) )
observation_window_minutes = models.PositiveIntegerField( observation_window_minutes = models.PositiveIntegerField(
+7 -34
View File
@@ -2,13 +2,13 @@ from datetime import timedelta
from decimal import Decimal from decimal import Decimal
from typing import Any from typing import Any
from django.core.cache import cache
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.db import transaction from django.db import transaction
from django.db.models import Case, QuerySet, Value, When from django.db.models import Case, QuerySet, Value, When
from django.utils import timezone from django.utils import timezone
from apps.experiments.models import ( from apps.experiments.models import (
STARTED_STATUSES,
Experiment, Experiment,
ExperimentLog, ExperimentLog,
ExperimentOutcome, ExperimentOutcome,
@@ -17,6 +17,7 @@ from apps.experiments.models import (
OutcomeType, OutcomeType,
) )
from apps.guardrails.models import ( from apps.guardrails.models import (
METRIC_DECIMAL_PLACES,
Guardrail, Guardrail,
GuardrailAction, GuardrailAction,
GuardrailTrigger, GuardrailTrigger,
@@ -38,16 +39,6 @@ def guardrail_create(
observation_window_minutes: int = 60, observation_window_minutes: int = 60,
action: str = GuardrailAction.PAUSE, action: str = GuardrailAction.PAUSE,
) -> Guardrail: ) -> Guardrail:
if experiment.status in STARTED_STATUSES:
raise ValidationError(
{
"experiment": (
"Guardrails cannot be added after the experiment "
"has been started "
f"(status: '{experiment.status}')."
)
}
)
guardrail = Guardrail( guardrail = Guardrail(
experiment=experiment, experiment=experiment,
metric=metric, metric=metric,
@@ -65,16 +56,6 @@ def guardrail_update(
**fields: Any, **fields: Any,
) -> Guardrail: ) -> Guardrail:
guardrail.experiment.refresh_from_db(fields=["status"]) guardrail.experiment.refresh_from_db(fields=["status"])
if guardrail.experiment.status in STARTED_STATUSES:
raise ValidationError(
{
"experiment": (
"Guardrails cannot be modified after the experiment "
"has been started "
f"(status: '{guardrail.experiment.status}')."
)
}
)
allowed = { allowed = {
"threshold", "threshold",
"observation_window_minutes", "observation_window_minutes",
@@ -93,16 +74,6 @@ def guardrail_update(
def guardrail_delete(*, guardrail: Guardrail) -> None: def guardrail_delete(*, guardrail: Guardrail) -> None:
guardrail.experiment.refresh_from_db(fields=["status"]) guardrail.experiment.refresh_from_db(fields=["status"])
if guardrail.experiment.status in STARTED_STATUSES:
raise ValidationError(
{
"experiment": (
"Guardrails cannot be deleted after the experiment "
"has been started "
f"(status: '{guardrail.experiment.status}')."
)
}
)
guardrail.delete() guardrail.delete()
@@ -138,8 +109,8 @@ def _calculate_guardrail_metric(
metric=guardrail.metric, metric=guardrail.metric,
experiment_id=experiment.pk, experiment_id=experiment.pk,
variant_id=variant.pk, variant_id=variant.pk,
start_date=window_start, event_start_date=window_start,
end_date=now, event_end_date=now,
) )
if val is not None: if val is not None:
values.append(val) values.append(val)
@@ -182,7 +153,7 @@ def _execute_guardrail_action(
experiment=experiment, experiment=experiment,
metric_key=guardrail.metric.key, metric_key=guardrail.metric.key,
threshold=guardrail.threshold, threshold=guardrail.threshold,
actual_value=actual_value, actual_value=round(actual_value, METRIC_DECIMAL_PLACES),
observation_window_minutes=guardrail.observation_window_minutes, observation_window_minutes=guardrail.observation_window_minutes,
action=guardrail.action, action=guardrail.action,
triggered_at=now, triggered_at=now,
@@ -191,6 +162,7 @@ def _execute_guardrail_action(
if guardrail.action == GuardrailAction.PAUSE: if guardrail.action == GuardrailAction.PAUSE:
experiment.status = ExperimentStatus.PAUSED experiment.status = ExperimentStatus.PAUSED
experiment.save(update_fields=["status", "updated_at"]) experiment.save(update_fields=["status", "updated_at"])
cache.delete(f"active_exp:{experiment.flag_id}")
ExperimentLog.objects.create( ExperimentLog.objects.create(
experiment=experiment, experiment=experiment,
@@ -218,6 +190,7 @@ def _execute_guardrail_action(
experiment.status = ExperimentStatus.COMPLETED experiment.status = ExperimentStatus.COMPLETED
experiment.save(update_fields=["status", "updated_at"]) experiment.save(update_fields=["status", "updated_at"])
cache.delete(f"active_exp:{experiment.flag_id}")
ExperimentOutcome.objects.create( ExperimentOutcome.objects.create(
experiment=experiment, experiment=experiment,
@@ -1,6 +1,5 @@
from decimal import Decimal from decimal import Decimal
from django.core.exceptions import ValidationError
from django.test import TestCase from django.test import TestCase
from django.utils import timezone from django.utils import timezone
@@ -488,6 +487,119 @@ class CheckAllRunningTest(TestCase):
self.assertGreater(len(results["triggers"]), 0) self.assertGreater(len(results["triggers"]), 0)
class GuardrailHigherIsBetterTest(TestCase):
def setUp(self) -> None:
review_settings_update(
default_min_approvals=1,
allow_any_approver=True,
)
self.approver = make_approver("_hib")
self.exposure_type = make_exposure_type()
self.purchase_type = make_event_type(
name="purchase",
display_name="Purchase",
requires_exposure=False,
)
self.experiment = make_experiment(suffix="_hib")
self.v_control, self.v_treatment = add_two_variants(self.experiment)
self.conversion_metric = metric_definition_create(
key="hib_conversion_rate",
name="Conversion Rate",
metric_type=MetricType.RATIO,
direction=MetricDirection.HIGHER_IS_BETTER,
calculation_rule={
"type": "ratio",
"numerator_event": "purchase",
"denominator_event": "exposure",
},
)
guardrail_create(
experiment=self.experiment,
metric=self.conversion_metric,
threshold=Decimal("0.50"),
observation_window_minutes=60,
action=GuardrailAction.PAUSE,
)
self.experiment = _start_experiment(self.experiment, self.approver)
self.now = timezone.now()
def _create_decision_and_exposure(self, decision_id, subject_id, variant):
decision_create(
decision_id=decision_id,
flag_key="flag_hib",
subject_id=subject_id,
experiment_id=str(self.experiment.pk),
variant_id=str(variant.pk),
value=variant.value,
reason="experiment",
)
process_events_batch(
[
{
"event_id": f"exp_{decision_id}",
"event_type": "exposure",
"decision_id": decision_id,
"subject_id": subject_id,
"timestamp": self.now.isoformat(),
"properties": {},
}
]
)
def _send_purchase(self, event_id, decision_id, subject_id):
process_events_batch(
[
{
"event_id": event_id,
"event_type": "purchase",
"decision_id": decision_id,
"subject_id": subject_id,
"timestamp": self.now.isoformat(),
"properties": {},
}
]
)
def test_trigger_when_conversion_below_threshold(self) -> None:
for i in range(3):
self._create_decision_and_exposure(
f"dec_hib_{i}",
f"u{i}",
self.v_treatment,
)
self._send_purchase("pur_hib_0", "dec_hib_0", "u0")
triggers = check_experiment_guardrails(self.experiment)
self.assertEqual(len(triggers), 1)
self.experiment.refresh_from_db()
self.assertEqual(self.experiment.status, ExperimentStatus.PAUSED)
self.assertLess(triggers[0].actual_value, Decimal("0.50"))
def test_no_trigger_when_conversion_above_threshold(self) -> None:
for i in range(3):
self._create_decision_and_exposure(
f"dec_hib_ok_{i}",
f"u{i}",
self.v_treatment,
)
for i in range(3):
self._send_purchase(
f"pur_hib_ok_{i}", f"dec_hib_ok_{i}", f"u{i}"
)
triggers = check_experiment_guardrails(self.experiment)
self.assertEqual(len(triggers), 0)
self.experiment.refresh_from_db()
self.assertEqual(self.experiment.status, ExperimentStatus.RUNNING)
class GuardrailServiceTest(TestCase): class GuardrailServiceTest(TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.experiment = make_experiment(suffix="_gr") self.experiment = make_experiment(suffix="_gr")
@@ -535,7 +647,7 @@ class GuardrailServiceTest(TestCase):
) )
self.assertEqual(updated.threshold, Decimal("0.10")) self.assertEqual(updated.threshold, Decimal("0.10"))
def test_reject_update_after_start(self) -> None: def test_update_guardrail_after_start(self) -> None:
review_settings_update( review_settings_update(
default_min_approvals=1, allow_any_approver=True default_min_approvals=1, allow_any_approver=True
) )
@@ -552,8 +664,8 @@ class GuardrailServiceTest(TestCase):
) )
exp = experiment_approve(experiment=exp, approver=approver) exp = experiment_approve(experiment=exp, approver=approver)
experiment_start(experiment=exp, user=self.experiment.owner) experiment_start(experiment=exp, user=self.experiment.owner)
with self.assertRaises(ValidationError): updated = guardrail_update(guardrail=g, threshold=Decimal("0.10"))
guardrail_update(guardrail=g, threshold=Decimal("0.10")) self.assertEqual(updated.threshold, Decimal("0.10"))
def test_delete_guardrail_in_draft(self) -> None: def test_delete_guardrail_in_draft(self) -> None:
g = guardrail_create( g = guardrail_create(
@@ -564,7 +676,7 @@ class GuardrailServiceTest(TestCase):
guardrail_delete(guardrail=g) guardrail_delete(guardrail=g)
self.assertEqual(Guardrail.objects.count(), 0) self.assertEqual(Guardrail.objects.count(), 0)
def test_reject_delete_after_start(self) -> None: def test_delete_guardrail_after_start(self) -> None:
review_settings_update( review_settings_update(
default_min_approvals=1, allow_any_approver=True default_min_approvals=1, allow_any_approver=True
) )
@@ -581,5 +693,5 @@ class GuardrailServiceTest(TestCase):
) )
exp = experiment_approve(experiment=exp, approver=approver) exp = experiment_approve(experiment=exp, approver=approver)
experiment_start(experiment=exp, user=self.experiment.owner) experiment_start(experiment=exp, user=self.experiment.owner)
with self.assertRaises(ValidationError): guardrail_delete(guardrail=g)
guardrail_delete(guardrail=g) self.assertEqual(Guardrail.objects.count(), 0)
-21
View File
@@ -5,7 +5,6 @@ from django.core.exceptions import ValidationError
from django.db import transaction from django.db import transaction
from django.db.models import QuerySet from django.db.models import QuerySet
from apps.experiments.models import STARTED_STATUSES
from apps.metrics.models import ( from apps.metrics.models import (
ExperimentMetric, ExperimentMetric,
MetricDefinition, MetricDefinition,
@@ -131,16 +130,6 @@ def experiment_metric_add(
metric: MetricDefinition, metric: MetricDefinition,
is_primary: bool = False, is_primary: bool = False,
) -> ExperimentMetric: ) -> ExperimentMetric:
if experiment.status in STARTED_STATUSES:
raise ValidationError(
{
"experiment": (
"Metrics cannot be modified after the experiment "
"has been started "
f"(status: '{experiment.status}')."
)
}
)
if is_primary: if is_primary:
experiment.experiment_metrics.filter(is_primary=True).update( experiment.experiment_metrics.filter(is_primary=True).update(
is_primary=False, is_primary=False,
@@ -160,16 +149,6 @@ def experiment_metric_remove(
experiment: Any, experiment: Any,
metric: MetricDefinition, metric: MetricDefinition,
) -> None: ) -> None:
if experiment.status in STARTED_STATUSES:
raise ValidationError(
{
"experiment": (
"Metrics cannot be modified after the experiment "
"has been started "
f"(status: '{experiment.status}')."
)
}
)
deleted, _ = ExperimentMetric.objects.filter( deleted, _ = ExperimentMetric.objects.filter(
experiment=experiment, experiment=experiment,
metric=metric, metric=metric,
@@ -103,6 +103,46 @@ class MetricDefinitionCreateTest(TestCase):
calculation_rule={"type": "count", "event": "x"}, calculation_rule={"type": "count", "event": "x"},
) )
def test_reject_percentile_out_of_range(self) -> None:
with self.assertRaises(ValidationError):
metric_definition_create(
key="bad_pct",
name="Bad Percentile",
metric_type=MetricType.PERCENTILE,
calculation_rule={
"type": "percentile",
"event": "page_loaded",
"property": "latency_ms",
"percentile": 200,
},
)
def test_reject_rule_type_mismatch(self) -> None:
with self.assertRaises(ValidationError):
metric_definition_create(
key="mismatch",
name="Mismatch",
metric_type=MetricType.RATIO,
calculation_rule={
"type": "count",
"numerator_event": "click",
"denominator_event": "exposure",
},
)
def test_reject_extra_rule_fields(self) -> None:
with self.assertRaises(ValidationError):
metric_definition_create(
key="extra_field",
name="Extra",
metric_type=MetricType.COUNT,
calculation_rule={
"type": "count",
"event": "click",
"unknown": "value",
},
)
class MetricDefinitionUpdateTest(TestCase): class MetricDefinitionUpdateTest(TestCase):
def test_update_name_and_description(self) -> None: def test_update_name_and_description(self) -> None:
@@ -0,0 +1,23 @@
# Generated by Django 5.2.11 on 2026-02-24 09:40
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('notifications', '0003_notificationrule_rate_limit_max_notifications_and_more'),
]
operations = [
migrations.AlterField(
model_name='notificationlog',
name='event_type',
field=models.CharField(choices=[('experiment_started', 'Experiment started'), ('experiment_paused', 'Experiment paused'), ('experiment_resumed', 'Experiment resumed'), ('experiment_completed', 'Experiment completed'), ('guardrail_triggered', 'Guardrail triggered'), ('review_requested', 'Review requested'), ('review_approved', 'Review approved'), ('review_rejected', 'Review rejected')], max_length=30, verbose_name='event type'),
),
migrations.AlterField(
model_name='notificationrule',
name='event_type',
field=models.CharField(choices=[('experiment_started', 'Experiment started'), ('experiment_paused', 'Experiment paused'), ('experiment_resumed', 'Experiment resumed'), ('experiment_completed', 'Experiment completed'), ('guardrail_triggered', 'Guardrail triggered'), ('review_requested', 'Review requested'), ('review_approved', 'Review approved'), ('review_rejected', 'Review rejected')], max_length=30, verbose_name='event type'),
),
]
+1
View File
@@ -14,6 +14,7 @@ class ChannelType(models.TextChoices):
class NotificationEventType(models.TextChoices): class NotificationEventType(models.TextChoices):
EXPERIMENT_STARTED = "experiment_started", _("Experiment started") EXPERIMENT_STARTED = "experiment_started", _("Experiment started")
EXPERIMENT_PAUSED = "experiment_paused", _("Experiment paused") EXPERIMENT_PAUSED = "experiment_paused", _("Experiment paused")
EXPERIMENT_RESUMED = "experiment_resumed", _("Experiment resumed")
EXPERIMENT_COMPLETED = "experiment_completed", _("Experiment completed") EXPERIMENT_COMPLETED = "experiment_completed", _("Experiment completed")
GUARDRAIL_TRIGGERED = "guardrail_triggered", _("Guardrail triggered") GUARDRAIL_TRIGGERED = "guardrail_triggered", _("Guardrail triggered")
REVIEW_REQUESTED = "review_requested", _("Review requested") REVIEW_REQUESTED = "review_requested", _("Review requested")
+13 -4
View File
@@ -200,15 +200,24 @@ def _build_event_key(
return f"{event_type}:{payload.experiment_id}:{bucket}" return f"{event_type}:{payload.experiment_id}:{bucket}"
def _escape_markdown(text: str) -> str:
for ch in r"\_*[]()~`>#+-=|{}.!":
text = text.replace(ch, f"\\{ch}")
return text
def _send_telegram(config: dict[str, Any], payload: dict[str, Any]) -> None: def _send_telegram(config: dict[str, Any], payload: dict[str, Any]) -> None:
bot_token = config.get("bot_token", "") bot_token = config.get("bot_token", "")
chat_id = config.get("chat_id", "") chat_id = config.get("chat_id", "")
if not bot_token or not chat_id: if not bot_token or not chat_id:
raise ValueError("Telegram config requires 'bot_token' and 'chat_id'.") raise ValueError("Telegram config requires 'bot_token' and 'chat_id'.")
text = f"*{payload['title']}*\n\n{payload['body']}" title = _escape_markdown(payload["title"])
body = _escape_markdown(payload["body"])
text = f"*{title}*\n\n{body}"
if payload.get("experiment_name"): if payload.get("experiment_name"):
text += f"\n\nExperiment: {payload['experiment_name']}" name = _escape_markdown(payload["experiment_name"])
text += f"\n\nExperiment: {name}"
api_url = config.get( api_url = config.get(
"api_url", "api_url",
@@ -219,7 +228,7 @@ def _send_telegram(config: dict[str, Any], payload: dict[str, Any]) -> None:
json={ json={
"chat_id": chat_id, "chat_id": chat_id,
"text": text, "text": text,
"parse_mode": "Markdown", "parse_mode": "MarkdownV2",
}, },
timeout=10, timeout=10,
) )
@@ -249,7 +258,7 @@ def _send_smtp(config: dict[str, Any], payload: dict[str, Any]) -> None:
def flush_pending_notifications() -> dict[str, int]: def flush_pending_notifications() -> dict[str, int]:
pending = NotificationLog.objects.filter( pending = NotificationLog.objects.filter(
status=NotificationStatus.PENDING, status=NotificationStatus.PENDING,
).select_related("channel") ).select_related("channel").order_by("created_at")
senders = { senders = {
ChannelType.TELEGRAM: _send_telegram, ChannelType.TELEGRAM: _send_telegram,
@@ -238,6 +238,27 @@ class NotificationEnqueueTest(TestCase):
) )
self.assertEqual(len(logs), 0) self.assertEqual(len(logs), 0)
def test_enqueue_experiment_resumed(self) -> None:
rule_create(
event_type=NotificationEventType.EXPERIMENT_RESUMED,
channel=self.channel,
)
logs = notification_enqueue(
NotificationEventType.EXPERIMENT_RESUMED,
NotificationPayload(
title="Experiment Resumed",
body=f"Experiment '{self.experiment.name}' - experiment resumed.",
event_type=NotificationEventType.EXPERIMENT_RESUMED,
experiment_id=str(self.experiment.pk),
experiment_name=self.experiment.name,
),
)
self.assertEqual(len(logs), 1)
self.assertEqual(
logs[0].event_type, NotificationEventType.EXPERIMENT_RESUMED
)
self.assertEqual(logs[0].status, NotificationStatus.PENDING)
class FlushNotificationsTest(TestCase): class FlushNotificationsTest(TestCase):
@override @override
+15 -10
View File
@@ -117,6 +117,8 @@ def calculate_metric_value(
variant_id: UUID, variant_id: UUID,
start_date: datetime | None = None, start_date: datetime | None = None,
end_date: datetime | None = None, end_date: datetime | None = None,
event_start_date: datetime | None = None,
event_end_date: datetime | None = None,
) -> Decimal | None: ) -> Decimal | None:
rule = metric.calculation_rule rule = metric.calculation_rule
decision_ids = _exposure_decision_ids( decision_ids = _exposure_decision_ids(
@@ -129,20 +131,23 @@ def calculate_metric_value(
if not decision_ids: if not decision_ids:
return None return None
ev_start = event_start_date or start_date
ev_end = event_end_date or end_date
metric_type = metric.metric_type metric_type = metric.metric_type
if metric_type == MetricType.RATIO: if metric_type == MetricType.RATIO:
numerator = _count_events( numerator = _count_events(
decision_ids, decision_ids,
rule["numerator_event"], rule["numerator_event"],
start_date, ev_start,
end_date, ev_end,
) )
denominator = _count_events( denominator = _count_events(
decision_ids, decision_ids,
rule["denominator_event"], rule["denominator_event"],
start_date, ev_start,
end_date, ev_end,
) )
if denominator == 0: if denominator == 0:
return None return None
@@ -152,8 +157,8 @@ def calculate_metric_value(
count = _count_events( count = _count_events(
decision_ids, decision_ids,
rule["event"], rule["event"],
start_date, ev_start,
end_date, ev_end,
) )
return Decimal(str(count)) return Decimal(str(count))
@@ -162,8 +167,8 @@ def calculate_metric_value(
decision_ids, decision_ids,
rule["event"], rule["event"],
rule["property"], rule["property"],
start_date, ev_start,
end_date, ev_end,
) )
if metric_type == MetricType.PERCENTILE: if metric_type == MetricType.PERCENTILE:
@@ -172,8 +177,8 @@ def calculate_metric_value(
rule["event"], rule["event"],
rule["property"], rule["property"],
rule.get("percentile", 95), rule.get("percentile", 95),
start_date, ev_start,
end_date, ev_end,
) )
return None return None