chore(): small improvements
This commit is contained in:
@@ -138,6 +138,8 @@ def create_rule(
|
||||
event_type=payload.event_type,
|
||||
channel=ch,
|
||||
experiment=experiment,
|
||||
rate_limit_window_seconds=payload.rate_limit_window_seconds,
|
||||
rate_limit_max_notifications=payload.rate_limit_max_notifications,
|
||||
)
|
||||
r = NotificationRule.objects.select_related("channel", "experiment").get(
|
||||
pk=r.pk
|
||||
|
||||
@@ -42,11 +42,15 @@ class RuleCreateIn(Schema):
|
||||
event_type: NotificationEventType
|
||||
channel_id: UUID
|
||||
experiment_id: UUID | None = None
|
||||
rate_limit_window_seconds: int = 60
|
||||
rate_limit_max_notifications: int = 1
|
||||
|
||||
|
||||
class RuleUpdateIn(Schema):
|
||||
event_type: NotificationEventType | None = None
|
||||
is_active: bool | None = None
|
||||
rate_limit_window_seconds: int | None = None
|
||||
rate_limit_max_notifications: int | None = None
|
||||
|
||||
|
||||
class ChannelBriefOut(Schema):
|
||||
@@ -70,6 +74,8 @@ class RuleOut(ModelSchema):
|
||||
NotificationRule.id.field.name,
|
||||
NotificationRule.event_type.field.name,
|
||||
NotificationRule.is_active.field.name,
|
||||
NotificationRule.rate_limit_window_seconds.field.name,
|
||||
NotificationRule.rate_limit_max_notifications.field.name,
|
||||
NotificationRule.created_at.field.name,
|
||||
NotificationRule.updated_at.field.name,
|
||||
)
|
||||
@@ -92,6 +98,8 @@ class RuleOut(ModelSchema):
|
||||
),
|
||||
experiment=experiment_brief,
|
||||
is_active=r.is_active,
|
||||
rate_limit_window_seconds=r.rate_limit_window_seconds,
|
||||
rate_limit_max_notifications=r.rate_limit_max_notifications,
|
||||
created_at=r.created_at,
|
||||
updated_at=r.updated_at,
|
||||
)
|
||||
|
||||
@@ -148,6 +148,18 @@ class RuleAPITest(TestCase):
|
||||
)
|
||||
self.assertEqual(data["channel"]["id"], str(self.channel.pk))
|
||||
self.assertIsNone(data["experiment"])
|
||||
self.assertEqual(data["rate_limit_window_seconds"], 60)
|
||||
self.assertEqual(data["rate_limit_max_notifications"], 1)
|
||||
|
||||
def test_create_rule_with_custom_rate_limit(self) -> None:
|
||||
resp = self._create_rule(
|
||||
rate_limit_window_seconds=120,
|
||||
rate_limit_max_notifications=3,
|
||||
)
|
||||
self.assertEqual(resp.status_code, 201)
|
||||
data = resp.json()
|
||||
self.assertEqual(data["rate_limit_window_seconds"], 120)
|
||||
self.assertEqual(data["rate_limit_max_notifications"], 3)
|
||||
|
||||
def test_create_rule_nonexistent_channel(self) -> None:
|
||||
resp = self._create_rule(channel_id=str(uuid.uuid4()))
|
||||
@@ -167,12 +179,20 @@ class RuleAPITest(TestCase):
|
||||
rule_id = create_resp.json()["id"]
|
||||
resp = self.client.patch(
|
||||
reverse("api-1:update_rule", args=[rule_id]),
|
||||
data=json.dumps({"is_active": False}),
|
||||
data=json.dumps(
|
||||
{
|
||||
"is_active": False,
|
||||
"rate_limit_window_seconds": 300,
|
||||
"rate_limit_max_notifications": 5,
|
||||
}
|
||||
),
|
||||
content_type="application/json",
|
||||
HTTP_AUTHORIZATION=self.auth,
|
||||
)
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
self.assertFalse(resp.json()["is_active"])
|
||||
self.assertEqual(resp.json()["rate_limit_window_seconds"], 300)
|
||||
self.assertEqual(resp.json()["rate_limit_max_notifications"], 5)
|
||||
|
||||
def test_update_rule_not_found(self) -> None:
|
||||
resp = self.client.patch(
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import hashlib
|
||||
from uuid import UUID
|
||||
|
||||
from django.core.exceptions import ValidationError
|
||||
@@ -12,6 +13,17 @@ from apps.conflicts.selectors import domain_active_experiments
|
||||
from apps.experiments.models import ACTIVE_STATUSES, Experiment
|
||||
|
||||
|
||||
def _subject_winner_index(
|
||||
*,
|
||||
subject_id: str,
|
||||
domain_id: UUID,
|
||||
size: int,
|
||||
) -> int:
|
||||
seed = f"{domain_id}:{subject_id}".encode()
|
||||
digest = hashlib.sha256(seed).digest()
|
||||
return int.from_bytes(digest[:8], byteorder="big") % size
|
||||
|
||||
|
||||
@transaction.atomic
|
||||
def conflict_domain_create(
|
||||
*,
|
||||
@@ -167,7 +179,15 @@ def resolve_domain_conflict(
|
||||
return True
|
||||
|
||||
if domain.policy == ConflictPolicy.MUTUAL_EXCLUSION:
|
||||
winner = active_memberships[0]
|
||||
if subject_id:
|
||||
winner_idx = _subject_winner_index(
|
||||
subject_id=subject_id,
|
||||
domain_id=domain_id,
|
||||
size=len(active_memberships),
|
||||
)
|
||||
winner = active_memberships[winner_idx]
|
||||
else:
|
||||
winner = active_memberships[0]
|
||||
return str(winner.experiment_id) == str(experiment_id)
|
||||
|
||||
if domain.policy == ConflictPolicy.PRIORITY:
|
||||
@@ -187,7 +207,19 @@ def resolve_domain_conflict(
|
||||
tied = [m for m in active_memberships if m.priority == top_priority]
|
||||
if len(tied) <= 1:
|
||||
return True
|
||||
winner = min(tied, key=lambda m: m.experiment.created_at)
|
||||
if subject_id:
|
||||
ordered_tied = sorted(
|
||||
tied,
|
||||
key=lambda m: str(m.experiment_id),
|
||||
)
|
||||
winner_idx = _subject_winner_index(
|
||||
subject_id=subject_id,
|
||||
domain_id=domain_id,
|
||||
size=len(ordered_tied),
|
||||
)
|
||||
winner = ordered_tied[winner_idx]
|
||||
else:
|
||||
winner = min(tied, key=lambda m: m.experiment.created_at)
|
||||
return str(winner.experiment_id) == str(experiment_id)
|
||||
|
||||
return True
|
||||
|
||||
@@ -285,7 +285,7 @@ class ResolveDomainConflictTest(TestCase):
|
||||
exp = experiment_approve(experiment=exp, approver=self.approver)
|
||||
return experiment_start(experiment=exp, user=self.experimenter)
|
||||
|
||||
def test_mutual_exclusion_winner_is_first(self) -> None:
|
||||
def test_mutual_exclusion_deterministic_per_subject(self) -> None:
|
||||
domain = make_domain(
|
||||
suffix="_me",
|
||||
policy=ConflictPolicy.MUTUAL_EXCLUSION,
|
||||
@@ -293,10 +293,17 @@ class ResolveDomainConflictTest(TestCase):
|
||||
)
|
||||
exp1 = self._make_and_start("_me1", domain)
|
||||
exp2 = self._make_and_start("_me2", domain)
|
||||
winner = resolve_domain_conflict(exp1.pk, domain.pk, "u1")
|
||||
self.assertTrue(winner)
|
||||
loser = resolve_domain_conflict(exp2.pk, domain.pk, "u1")
|
||||
self.assertFalse(loser)
|
||||
exp1_u1 = resolve_domain_conflict(exp1.pk, domain.pk, "u1")
|
||||
exp2_u1 = resolve_domain_conflict(exp2.pk, domain.pk, "u1")
|
||||
self.assertNotEqual(exp1_u1, exp2_u1)
|
||||
self.assertEqual(
|
||||
exp1_u1,
|
||||
resolve_domain_conflict(exp1.pk, domain.pk, "u1"),
|
||||
)
|
||||
self.assertEqual(
|
||||
exp2_u1,
|
||||
resolve_domain_conflict(exp2.pk, domain.pk, "u1"),
|
||||
)
|
||||
|
||||
def test_priority_higher_wins(self) -> None:
|
||||
domain = make_domain(
|
||||
@@ -309,7 +316,7 @@ class ResolveDomainConflictTest(TestCase):
|
||||
self.assertTrue(resolve_domain_conflict(exp_high.pk, domain.pk, "u1"))
|
||||
self.assertFalse(resolve_domain_conflict(exp_low.pk, domain.pk, "u1"))
|
||||
|
||||
def test_priority_tie_first_created_wins(self) -> None:
|
||||
def test_priority_tie_deterministic_per_subject(self) -> None:
|
||||
domain = make_domain(
|
||||
suffix="_tie",
|
||||
policy=ConflictPolicy.PRIORITY,
|
||||
@@ -317,8 +324,17 @@ class ResolveDomainConflictTest(TestCase):
|
||||
)
|
||||
exp1 = self._make_and_start("_tie1", domain, priority=5)
|
||||
exp2 = self._make_and_start("_tie2", domain, priority=5)
|
||||
self.assertTrue(resolve_domain_conflict(exp1.pk, domain.pk, "u1"))
|
||||
self.assertFalse(resolve_domain_conflict(exp2.pk, domain.pk, "u1"))
|
||||
exp1_u1 = resolve_domain_conflict(exp1.pk, domain.pk, "u1")
|
||||
exp2_u1 = resolve_domain_conflict(exp2.pk, domain.pk, "u1")
|
||||
self.assertNotEqual(exp1_u1, exp2_u1)
|
||||
self.assertEqual(
|
||||
exp1_u1,
|
||||
resolve_domain_conflict(exp1.pk, domain.pk, "u1"),
|
||||
)
|
||||
self.assertEqual(
|
||||
exp2_u1,
|
||||
resolve_domain_conflict(exp2.pk, domain.pk, "u1"),
|
||||
)
|
||||
|
||||
def test_single_experiment_always_wins(self) -> None:
|
||||
domain = make_domain(suffix="_single")
|
||||
|
||||
@@ -145,7 +145,10 @@ def _check_participation_limits(
|
||||
return not recent_completed
|
||||
|
||||
|
||||
def _check_domain_conflicts(experiment: Experiment) -> bool:
|
||||
def _check_domain_conflicts(
|
||||
experiment: Experiment,
|
||||
subject_id: str,
|
||||
) -> bool:
|
||||
memberships = ExperimentConflictDomain.objects.filter(
|
||||
experiment=experiment,
|
||||
).select_related("conflict_domain")
|
||||
@@ -154,7 +157,7 @@ def _check_domain_conflicts(experiment: Experiment) -> bool:
|
||||
if not resolve_domain_conflict(
|
||||
experiment_id=experiment.pk,
|
||||
domain_id=membership.conflict_domain_id,
|
||||
subject_id="",
|
||||
subject_id=subject_id,
|
||||
):
|
||||
return False
|
||||
return True
|
||||
@@ -219,7 +222,7 @@ def decide_for_flag(
|
||||
_persist_decision(result, subject_id)
|
||||
return result
|
||||
|
||||
if not _check_domain_conflicts(experiment):
|
||||
if not _check_domain_conflicts(experiment, subject_id):
|
||||
DECIDE_REQUESTS.labels(reason="domain_conflict").inc()
|
||||
result = {
|
||||
"flag": flag_key,
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
# Generated by Django 5.2.11 on 2026-02-23 14:56
|
||||
|
||||
import django.core.validators
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('flags', '0001_initial'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name='featureflag',
|
||||
name='key',
|
||||
field=models.CharField(help_text='Unique identifier for the feature flag', max_length=100, unique=True, validators=[django.core.validators.RegexValidator(message='Event type name must follow snake_case, camelCase, or PascalCase.', regex='^[A-Za-z][A-Za-z0-9_]*$')], verbose_name='key'),
|
||||
),
|
||||
]
|
||||
+23
@@ -0,0 +1,23 @@
|
||||
# Generated by Django 5.2.11 on 2026-02-23 14:56
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('notifications', '0002_alter_notificationchannel_config'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name='notificationrule',
|
||||
name='rate_limit_max_notifications',
|
||||
field=models.PositiveIntegerField(default=1, verbose_name='rate limit max notifications'),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name='notificationrule',
|
||||
name='rate_limit_window_seconds',
|
||||
field=models.PositiveIntegerField(default=60, verbose_name='rate limit window seconds'),
|
||||
),
|
||||
]
|
||||
@@ -93,6 +93,14 @@ class NotificationRule(BaseModel):
|
||||
default=True,
|
||||
verbose_name=_("is active"),
|
||||
)
|
||||
rate_limit_window_seconds = models.PositiveIntegerField(
|
||||
default=60,
|
||||
verbose_name=_("rate limit window seconds"),
|
||||
)
|
||||
rate_limit_max_notifications = models.PositiveIntegerField(
|
||||
default=1,
|
||||
verbose_name=_("rate limit max notifications"),
|
||||
)
|
||||
created_at = models.DateTimeField(
|
||||
auto_now_add=True,
|
||||
verbose_name=_("created at"),
|
||||
|
||||
@@ -82,11 +82,15 @@ def rule_create(
|
||||
event_type: str,
|
||||
channel: NotificationChannel,
|
||||
experiment: Any | None = None,
|
||||
rate_limit_window_seconds: int = 60,
|
||||
rate_limit_max_notifications: int = 1,
|
||||
) -> NotificationRule:
|
||||
rule = NotificationRule(
|
||||
event_type=event_type,
|
||||
channel=channel,
|
||||
experiment=experiment,
|
||||
rate_limit_window_seconds=rate_limit_window_seconds,
|
||||
rate_limit_max_notifications=rate_limit_max_notifications,
|
||||
)
|
||||
rule.save()
|
||||
return rule
|
||||
@@ -97,7 +101,12 @@ def rule_update(
|
||||
rule: NotificationRule,
|
||||
**fields: Any,
|
||||
) -> NotificationRule:
|
||||
allowed = {"event_type", "is_active"}
|
||||
allowed = {
|
||||
"event_type",
|
||||
"is_active",
|
||||
"rate_limit_window_seconds",
|
||||
"rate_limit_max_notifications",
|
||||
}
|
||||
for key in fields:
|
||||
if key not in allowed:
|
||||
raise ValueError(f"Field '{key}' cannot be updated.")
|
||||
@@ -149,12 +158,17 @@ def notification_enqueue(
|
||||
|
||||
logs: list[NotificationLog] = []
|
||||
for rule in rules:
|
||||
event_key = _build_event_key(event_type, payload)
|
||||
if NotificationLog.objects.filter(
|
||||
event_key = _build_event_key(
|
||||
event_type,
|
||||
payload,
|
||||
rule.rate_limit_window_seconds,
|
||||
)
|
||||
sent_or_pending = NotificationLog.objects.filter(
|
||||
event_key=event_key,
|
||||
channel=rule.channel,
|
||||
status__in=[NotificationStatus.PENDING, NotificationStatus.SENT],
|
||||
).exists():
|
||||
).count()
|
||||
if sent_or_pending >= rule.rate_limit_max_notifications:
|
||||
continue
|
||||
|
||||
log = NotificationLog.objects.create(
|
||||
@@ -176,8 +190,13 @@ def notification_enqueue(
|
||||
return logs
|
||||
|
||||
|
||||
def _build_event_key(event_type: str, payload: NotificationPayload) -> str:
|
||||
bucket = int(timezone.now().timestamp()) // 60
|
||||
def _build_event_key(
|
||||
event_type: str,
|
||||
payload: NotificationPayload,
|
||||
window_seconds: int,
|
||||
) -> str:
|
||||
normalized_window = max(window_seconds, 1)
|
||||
bucket = int(timezone.now().timestamp()) // normalized_window
|
||||
return f"{event_type}:{payload.experiment_id}:{bucket}"
|
||||
|
||||
|
||||
|
||||
@@ -164,6 +164,33 @@ class NotificationEnqueueTest(TestCase):
|
||||
self.assertEqual(len(logs_1), 1)
|
||||
self.assertEqual(len(logs_2), 0)
|
||||
|
||||
def test_enqueue_respects_rule_rate_limit(self) -> None:
|
||||
rule_create(
|
||||
event_type=NotificationEventType.GUARDRAIL_TRIGGERED,
|
||||
channel=self.channel,
|
||||
rate_limit_window_seconds=60,
|
||||
rate_limit_max_notifications=2,
|
||||
)
|
||||
payload = NotificationPayload(
|
||||
title="Alert",
|
||||
body="Error rate exceeded",
|
||||
event_type=NotificationEventType.GUARDRAIL_TRIGGERED,
|
||||
experiment_id=str(self.experiment.pk),
|
||||
experiment_name=self.experiment.name,
|
||||
)
|
||||
logs_1 = notification_enqueue(
|
||||
NotificationEventType.GUARDRAIL_TRIGGERED, payload
|
||||
)
|
||||
logs_2 = notification_enqueue(
|
||||
NotificationEventType.GUARDRAIL_TRIGGERED, payload
|
||||
)
|
||||
logs_3 = notification_enqueue(
|
||||
NotificationEventType.GUARDRAIL_TRIGGERED, payload
|
||||
)
|
||||
self.assertEqual(len(logs_1), 1)
|
||||
self.assertEqual(len(logs_2), 1)
|
||||
self.assertEqual(len(logs_3), 0)
|
||||
|
||||
def test_enqueue_no_matching_rules(self) -> None:
|
||||
logs = notification_enqueue(
|
||||
NotificationEventType.EXPERIMENT_STARTED,
|
||||
|
||||
Reference in New Issue
Block a user