From ffd887dc39e98a15c66858781ac885747cc4f9e6 Mon Sep 17 00:00:00 2001 From: ITQ Date: Mon, 23 Feb 2026 23:29:03 +0300 Subject: [PATCH] chore(): small improvements --- src/backend/api/v1/notifications/endpoints.py | 2 ++ src/backend/api/v1/notifications/schemas.py | 8 +++++ .../tests/test_notifications_api.py | 22 +++++++++++- src/backend/apps/conflicts/services.py | 36 +++++++++++++++++-- .../apps/conflicts/tests/test_conflicts.py | 32 ++++++++++++----- src/backend/apps/decision/services.py | 9 +++-- .../migrations/0002_alter_featureflag_key.py | 19 ++++++++++ ...e_rate_limit_max_notifications_and_more.py | 23 ++++++++++++ src/backend/apps/notifications/models.py | 8 +++++ src/backend/apps/notifications/services.py | 31 ++++++++++++---- .../notifications/tests/test_notifications.py | 27 ++++++++++++++ 11 files changed, 197 insertions(+), 20 deletions(-) create mode 100644 src/backend/apps/flags/migrations/0002_alter_featureflag_key.py create mode 100644 src/backend/apps/notifications/migrations/0003_notificationrule_rate_limit_max_notifications_and_more.py diff --git a/src/backend/api/v1/notifications/endpoints.py b/src/backend/api/v1/notifications/endpoints.py index 92fc498..c775138 100644 --- a/src/backend/api/v1/notifications/endpoints.py +++ b/src/backend/api/v1/notifications/endpoints.py @@ -138,6 +138,8 @@ def create_rule( event_type=payload.event_type, channel=ch, experiment=experiment, + rate_limit_window_seconds=payload.rate_limit_window_seconds, + rate_limit_max_notifications=payload.rate_limit_max_notifications, ) r = NotificationRule.objects.select_related("channel", "experiment").get( pk=r.pk diff --git a/src/backend/api/v1/notifications/schemas.py b/src/backend/api/v1/notifications/schemas.py index 4ea4471..874e850 100644 --- a/src/backend/api/v1/notifications/schemas.py +++ b/src/backend/api/v1/notifications/schemas.py @@ -42,11 +42,15 @@ class RuleCreateIn(Schema): event_type: NotificationEventType channel_id: UUID experiment_id: UUID | None = None + rate_limit_window_seconds: int = 60 + rate_limit_max_notifications: int = 1 class RuleUpdateIn(Schema): event_type: NotificationEventType | None = None is_active: bool | None = None + rate_limit_window_seconds: int | None = None + rate_limit_max_notifications: int | None = None class ChannelBriefOut(Schema): @@ -70,6 +74,8 @@ class RuleOut(ModelSchema): NotificationRule.id.field.name, NotificationRule.event_type.field.name, NotificationRule.is_active.field.name, + NotificationRule.rate_limit_window_seconds.field.name, + NotificationRule.rate_limit_max_notifications.field.name, NotificationRule.created_at.field.name, NotificationRule.updated_at.field.name, ) @@ -92,6 +98,8 @@ class RuleOut(ModelSchema): ), experiment=experiment_brief, is_active=r.is_active, + rate_limit_window_seconds=r.rate_limit_window_seconds, + rate_limit_max_notifications=r.rate_limit_max_notifications, created_at=r.created_at, updated_at=r.updated_at, ) diff --git a/src/backend/api/v1/notifications/tests/test_notifications_api.py b/src/backend/api/v1/notifications/tests/test_notifications_api.py index 6179ad8..de83a03 100644 --- a/src/backend/api/v1/notifications/tests/test_notifications_api.py +++ b/src/backend/api/v1/notifications/tests/test_notifications_api.py @@ -148,6 +148,18 @@ class RuleAPITest(TestCase): ) self.assertEqual(data["channel"]["id"], str(self.channel.pk)) self.assertIsNone(data["experiment"]) + self.assertEqual(data["rate_limit_window_seconds"], 60) + self.assertEqual(data["rate_limit_max_notifications"], 1) + + def test_create_rule_with_custom_rate_limit(self) -> None: + resp = self._create_rule( + rate_limit_window_seconds=120, + rate_limit_max_notifications=3, + ) + self.assertEqual(resp.status_code, 201) + data = resp.json() + self.assertEqual(data["rate_limit_window_seconds"], 120) + self.assertEqual(data["rate_limit_max_notifications"], 3) def test_create_rule_nonexistent_channel(self) -> None: resp = self._create_rule(channel_id=str(uuid.uuid4())) @@ -167,12 +179,20 @@ class RuleAPITest(TestCase): rule_id = create_resp.json()["id"] resp = self.client.patch( reverse("api-1:update_rule", args=[rule_id]), - data=json.dumps({"is_active": False}), + data=json.dumps( + { + "is_active": False, + "rate_limit_window_seconds": 300, + "rate_limit_max_notifications": 5, + } + ), content_type="application/json", HTTP_AUTHORIZATION=self.auth, ) self.assertEqual(resp.status_code, 200) self.assertFalse(resp.json()["is_active"]) + self.assertEqual(resp.json()["rate_limit_window_seconds"], 300) + self.assertEqual(resp.json()["rate_limit_max_notifications"], 5) def test_update_rule_not_found(self) -> None: resp = self.client.patch( diff --git a/src/backend/apps/conflicts/services.py b/src/backend/apps/conflicts/services.py index 83643e1..2b77747 100644 --- a/src/backend/apps/conflicts/services.py +++ b/src/backend/apps/conflicts/services.py @@ -1,3 +1,4 @@ +import hashlib from uuid import UUID from django.core.exceptions import ValidationError @@ -12,6 +13,17 @@ from apps.conflicts.selectors import domain_active_experiments from apps.experiments.models import ACTIVE_STATUSES, Experiment +def _subject_winner_index( + *, + subject_id: str, + domain_id: UUID, + size: int, +) -> int: + seed = f"{domain_id}:{subject_id}".encode() + digest = hashlib.sha256(seed).digest() + return int.from_bytes(digest[:8], byteorder="big") % size + + @transaction.atomic def conflict_domain_create( *, @@ -167,7 +179,15 @@ def resolve_domain_conflict( return True if domain.policy == ConflictPolicy.MUTUAL_EXCLUSION: - winner = active_memberships[0] + if subject_id: + winner_idx = _subject_winner_index( + subject_id=subject_id, + domain_id=domain_id, + size=len(active_memberships), + ) + winner = active_memberships[winner_idx] + else: + winner = active_memberships[0] return str(winner.experiment_id) == str(experiment_id) if domain.policy == ConflictPolicy.PRIORITY: @@ -187,7 +207,19 @@ def resolve_domain_conflict( tied = [m for m in active_memberships if m.priority == top_priority] if len(tied) <= 1: return True - winner = min(tied, key=lambda m: m.experiment.created_at) + if subject_id: + ordered_tied = sorted( + tied, + key=lambda m: str(m.experiment_id), + ) + winner_idx = _subject_winner_index( + subject_id=subject_id, + domain_id=domain_id, + size=len(ordered_tied), + ) + winner = ordered_tied[winner_idx] + else: + winner = min(tied, key=lambda m: m.experiment.created_at) return str(winner.experiment_id) == str(experiment_id) return True diff --git a/src/backend/apps/conflicts/tests/test_conflicts.py b/src/backend/apps/conflicts/tests/test_conflicts.py index 21c0227..fa88632 100644 --- a/src/backend/apps/conflicts/tests/test_conflicts.py +++ b/src/backend/apps/conflicts/tests/test_conflicts.py @@ -285,7 +285,7 @@ class ResolveDomainConflictTest(TestCase): exp = experiment_approve(experiment=exp, approver=self.approver) return experiment_start(experiment=exp, user=self.experimenter) - def test_mutual_exclusion_winner_is_first(self) -> None: + def test_mutual_exclusion_deterministic_per_subject(self) -> None: domain = make_domain( suffix="_me", policy=ConflictPolicy.MUTUAL_EXCLUSION, @@ -293,10 +293,17 @@ class ResolveDomainConflictTest(TestCase): ) exp1 = self._make_and_start("_me1", domain) exp2 = self._make_and_start("_me2", domain) - winner = resolve_domain_conflict(exp1.pk, domain.pk, "u1") - self.assertTrue(winner) - loser = resolve_domain_conflict(exp2.pk, domain.pk, "u1") - self.assertFalse(loser) + exp1_u1 = resolve_domain_conflict(exp1.pk, domain.pk, "u1") + exp2_u1 = resolve_domain_conflict(exp2.pk, domain.pk, "u1") + self.assertNotEqual(exp1_u1, exp2_u1) + self.assertEqual( + exp1_u1, + resolve_domain_conflict(exp1.pk, domain.pk, "u1"), + ) + self.assertEqual( + exp2_u1, + resolve_domain_conflict(exp2.pk, domain.pk, "u1"), + ) def test_priority_higher_wins(self) -> None: domain = make_domain( @@ -309,7 +316,7 @@ class ResolveDomainConflictTest(TestCase): self.assertTrue(resolve_domain_conflict(exp_high.pk, domain.pk, "u1")) self.assertFalse(resolve_domain_conflict(exp_low.pk, domain.pk, "u1")) - def test_priority_tie_first_created_wins(self) -> None: + def test_priority_tie_deterministic_per_subject(self) -> None: domain = make_domain( suffix="_tie", policy=ConflictPolicy.PRIORITY, @@ -317,8 +324,17 @@ class ResolveDomainConflictTest(TestCase): ) exp1 = self._make_and_start("_tie1", domain, priority=5) exp2 = self._make_and_start("_tie2", domain, priority=5) - self.assertTrue(resolve_domain_conflict(exp1.pk, domain.pk, "u1")) - self.assertFalse(resolve_domain_conflict(exp2.pk, domain.pk, "u1")) + exp1_u1 = resolve_domain_conflict(exp1.pk, domain.pk, "u1") + exp2_u1 = resolve_domain_conflict(exp2.pk, domain.pk, "u1") + self.assertNotEqual(exp1_u1, exp2_u1) + self.assertEqual( + exp1_u1, + resolve_domain_conflict(exp1.pk, domain.pk, "u1"), + ) + self.assertEqual( + exp2_u1, + resolve_domain_conflict(exp2.pk, domain.pk, "u1"), + ) def test_single_experiment_always_wins(self) -> None: domain = make_domain(suffix="_single") diff --git a/src/backend/apps/decision/services.py b/src/backend/apps/decision/services.py index 1563ebf..bf19ba5 100644 --- a/src/backend/apps/decision/services.py +++ b/src/backend/apps/decision/services.py @@ -145,7 +145,10 @@ def _check_participation_limits( return not recent_completed -def _check_domain_conflicts(experiment: Experiment) -> bool: +def _check_domain_conflicts( + experiment: Experiment, + subject_id: str, +) -> bool: memberships = ExperimentConflictDomain.objects.filter( experiment=experiment, ).select_related("conflict_domain") @@ -154,7 +157,7 @@ def _check_domain_conflicts(experiment: Experiment) -> bool: if not resolve_domain_conflict( experiment_id=experiment.pk, domain_id=membership.conflict_domain_id, - subject_id="", + subject_id=subject_id, ): return False return True @@ -219,7 +222,7 @@ def decide_for_flag( _persist_decision(result, subject_id) return result - if not _check_domain_conflicts(experiment): + if not _check_domain_conflicts(experiment, subject_id): DECIDE_REQUESTS.labels(reason="domain_conflict").inc() result = { "flag": flag_key, diff --git a/src/backend/apps/flags/migrations/0002_alter_featureflag_key.py b/src/backend/apps/flags/migrations/0002_alter_featureflag_key.py new file mode 100644 index 0000000..642c80f --- /dev/null +++ b/src/backend/apps/flags/migrations/0002_alter_featureflag_key.py @@ -0,0 +1,19 @@ +# Generated by Django 5.2.11 on 2026-02-23 14:56 + +import django.core.validators +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('flags', '0001_initial'), + ] + + operations = [ + migrations.AlterField( + model_name='featureflag', + name='key', + field=models.CharField(help_text='Unique identifier for the feature flag', max_length=100, unique=True, validators=[django.core.validators.RegexValidator(message='Event type name must follow snake_case, camelCase, or PascalCase.', regex='^[A-Za-z][A-Za-z0-9_]*$')], verbose_name='key'), + ), + ] diff --git a/src/backend/apps/notifications/migrations/0003_notificationrule_rate_limit_max_notifications_and_more.py b/src/backend/apps/notifications/migrations/0003_notificationrule_rate_limit_max_notifications_and_more.py new file mode 100644 index 0000000..ecc9b31 --- /dev/null +++ b/src/backend/apps/notifications/migrations/0003_notificationrule_rate_limit_max_notifications_and_more.py @@ -0,0 +1,23 @@ +# Generated by Django 5.2.11 on 2026-02-23 14:56 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('notifications', '0002_alter_notificationchannel_config'), + ] + + operations = [ + migrations.AddField( + model_name='notificationrule', + name='rate_limit_max_notifications', + field=models.PositiveIntegerField(default=1, verbose_name='rate limit max notifications'), + ), + migrations.AddField( + model_name='notificationrule', + name='rate_limit_window_seconds', + field=models.PositiveIntegerField(default=60, verbose_name='rate limit window seconds'), + ), + ] diff --git a/src/backend/apps/notifications/models.py b/src/backend/apps/notifications/models.py index 5f44f0d..bae46e3 100644 --- a/src/backend/apps/notifications/models.py +++ b/src/backend/apps/notifications/models.py @@ -93,6 +93,14 @@ class NotificationRule(BaseModel): default=True, verbose_name=_("is active"), ) + rate_limit_window_seconds = models.PositiveIntegerField( + default=60, + verbose_name=_("rate limit window seconds"), + ) + rate_limit_max_notifications = models.PositiveIntegerField( + default=1, + verbose_name=_("rate limit max notifications"), + ) created_at = models.DateTimeField( auto_now_add=True, verbose_name=_("created at"), diff --git a/src/backend/apps/notifications/services.py b/src/backend/apps/notifications/services.py index 760becd..d91f59d 100644 --- a/src/backend/apps/notifications/services.py +++ b/src/backend/apps/notifications/services.py @@ -82,11 +82,15 @@ def rule_create( event_type: str, channel: NotificationChannel, experiment: Any | None = None, + rate_limit_window_seconds: int = 60, + rate_limit_max_notifications: int = 1, ) -> NotificationRule: rule = NotificationRule( event_type=event_type, channel=channel, experiment=experiment, + rate_limit_window_seconds=rate_limit_window_seconds, + rate_limit_max_notifications=rate_limit_max_notifications, ) rule.save() return rule @@ -97,7 +101,12 @@ def rule_update( rule: NotificationRule, **fields: Any, ) -> NotificationRule: - allowed = {"event_type", "is_active"} + allowed = { + "event_type", + "is_active", + "rate_limit_window_seconds", + "rate_limit_max_notifications", + } for key in fields: if key not in allowed: raise ValueError(f"Field '{key}' cannot be updated.") @@ -149,12 +158,17 @@ def notification_enqueue( logs: list[NotificationLog] = [] for rule in rules: - event_key = _build_event_key(event_type, payload) - if NotificationLog.objects.filter( + event_key = _build_event_key( + event_type, + payload, + rule.rate_limit_window_seconds, + ) + sent_or_pending = NotificationLog.objects.filter( event_key=event_key, channel=rule.channel, status__in=[NotificationStatus.PENDING, NotificationStatus.SENT], - ).exists(): + ).count() + if sent_or_pending >= rule.rate_limit_max_notifications: continue log = NotificationLog.objects.create( @@ -176,8 +190,13 @@ def notification_enqueue( return logs -def _build_event_key(event_type: str, payload: NotificationPayload) -> str: - bucket = int(timezone.now().timestamp()) // 60 +def _build_event_key( + event_type: str, + payload: NotificationPayload, + window_seconds: int, +) -> str: + normalized_window = max(window_seconds, 1) + bucket = int(timezone.now().timestamp()) // normalized_window return f"{event_type}:{payload.experiment_id}:{bucket}" diff --git a/src/backend/apps/notifications/tests/test_notifications.py b/src/backend/apps/notifications/tests/test_notifications.py index c61146a..b0093ad 100644 --- a/src/backend/apps/notifications/tests/test_notifications.py +++ b/src/backend/apps/notifications/tests/test_notifications.py @@ -164,6 +164,33 @@ class NotificationEnqueueTest(TestCase): self.assertEqual(len(logs_1), 1) self.assertEqual(len(logs_2), 0) + def test_enqueue_respects_rule_rate_limit(self) -> None: + rule_create( + event_type=NotificationEventType.GUARDRAIL_TRIGGERED, + channel=self.channel, + rate_limit_window_seconds=60, + rate_limit_max_notifications=2, + ) + payload = NotificationPayload( + title="Alert", + body="Error rate exceeded", + event_type=NotificationEventType.GUARDRAIL_TRIGGERED, + experiment_id=str(self.experiment.pk), + experiment_name=self.experiment.name, + ) + logs_1 = notification_enqueue( + NotificationEventType.GUARDRAIL_TRIGGERED, payload + ) + logs_2 = notification_enqueue( + NotificationEventType.GUARDRAIL_TRIGGERED, payload + ) + logs_3 = notification_enqueue( + NotificationEventType.GUARDRAIL_TRIGGERED, payload + ) + self.assertEqual(len(logs_1), 1) + self.assertEqual(len(logs_2), 1) + self.assertEqual(len(logs_3), 0) + def test_enqueue_no_matching_rules(self) -> None: logs = notification_enqueue( NotificationEventType.EXPERIMENT_STARTED,