# --- src/backend/apps/decision/__init__.py (new, empty file) ---

# --- src/backend/apps/decision/apps.py ---
from django.apps import AppConfig


class DecisionConfig(AppConfig):
    """Django application config for the decision app."""

    name = "apps.decision"


# --- src/backend/apps/decision/services.py ---
import hashlib
import logging
import uuid
from datetime import timedelta
from decimal import Decimal

from django.core.cache import cache
from django.utils import timezone
from prometheus_client import Counter

from apps.events.models import Decision
from apps.events.services import decision_create
from apps.experiments.models import Experiment, ExperimentStatus, Variant
from apps.experiments.selectors import active_experiment_for_flag
from apps.flags.models import FeatureFlag
from apps.flags.selectors import feature_flag_get_by_key
from libs.dsl import evaluate
from libs.dsl.exceptions import EvaluationError, LexerError, ParserError

logger = logging.getLogger("lotty")

DECIDE_REQUESTS = Counter(
    "lotty_decide_requests_total",
    "Total number of flag decision requests",
    ["reason"],
)

# Timeouts handed to ``cache.set`` (Django cache timeouts are in seconds).
FLAG_CACHE_TTL = 300
EXPERIMENT_CACHE_TTL = 60
MAX_CONCURRENT_EXPERIMENTS = 3
COOLDOWN_DAYS = 7

# ``cache.get`` returns None on a miss, so a "looked up, found nothing"
# result is stored under this sentinel to tell the two cases apart.
_CACHE_NONE = "__none__"

# Statuses this module treats as "experiment is over".
_FINISHED_STATUSES = (
    ExperimentStatus.COMPLETED,
    ExperimentStatus.ARCHIVED,
)


def _hash_subject(subject_id: str, experiment_id: str, salt: str) -> Decimal:
    """Map a subject deterministically onto a bucket in ``[0, 100)``.

    The first 8 bytes of ``sha256("subject:experiment:salt")`` are reduced
    modulo 10000 and divided by 100, so the result is a ``Decimal`` with a
    resolution of 0.01. Different ``salt`` values give independent buckets
    for the same subject/experiment pair.
    """
    hash_input = f"{subject_id}:{experiment_id}:{salt}".encode()
    hash_bytes = hashlib.sha256(hash_input).digest()
    hash_int = int.from_bytes(hash_bytes[:8], byteorder="big")
    return Decimal(hash_int % 10000) / Decimal(100)


def _select_variant(
    variants: list[Variant], hash_value: Decimal
) -> Variant | None:
    """Pick the variant whose cumulative weight band contains ``hash_value``.

    Variants are walked in name order so the bands do not depend on the
    order the caller (or the database) happened to return them in.

    Returns ``None`` for an empty list. If ``hash_value`` is not below the
    total weight (e.g. all weights are zero) the last variant *in name
    order* is returned, keeping the fallback consistent with the walk.
    """
    if not variants:
        return None
    ordered = sorted(variants, key=lambda v: v.name)
    cumulative = Decimal(0)
    for variant in ordered:
        cumulative += variant.weight
        if hash_value < cumulative:
            return variant
    return ordered[-1]


def _persist_decision(result: dict, subject_id: str) -> None:
    """Store a decision result dict via ``decision_create``.

    A ``None`` value is stored as the empty string; any other value is
    stringified.
    """
    decision_create(
        decision_id=result["decision_id"],
        flag_key=result["flag"],
        subject_id=subject_id,
        experiment_id=result.get("experiment_id"),
        variant_id=result.get("variant_id"),
        value=str(result["value"]) if result["value"] is not None else "",
        reason=result["reason"],
    )


def _cached_flag_get(flag_key: str) -> FeatureFlag | None:
    """Return the flag for ``flag_key``, cached for ``FLAG_CACHE_TTL``.

    Missing flags are cached too (under a sentinel) so repeated requests
    for an unknown key do not hit the selector every time.
    """
    cache_key = f"flag:{flag_key}"
    cached = cache.get(cache_key)
    if cached is not None:
        return cached if cached != _CACHE_NONE else None
    flag = feature_flag_get_by_key(flag_key)
    cache.set(cache_key, flag or _CACHE_NONE, FLAG_CACHE_TTL)
    return flag


def _cached_active_experiment(flag_pk):
    """Return the active experiment for a flag, cached for
    ``EXPERIMENT_CACHE_TTL``; "no experiment" is cached as well.
    """
    cache_key = f"active_exp:{flag_pk}"
    cached = cache.get(cache_key)
    if cached is not None:
        return cached if cached != _CACHE_NONE else None
    experiment = active_experiment_for_flag(flag_pk)
    cache.set(
        cache_key,
        experiment or _CACHE_NONE,
        EXPERIMENT_CACHE_TTL,
    )
    return experiment


def _check_targeting(
    targeting_rules: str,
    subject_attributes: dict,
) -> bool:
    """Evaluate an experiment's targeting rules against the subject.

    Empty or whitespace-only rules match everyone. Rules that fail to lex,
    parse or evaluate are logged and treated as a mismatch, so a broken
    rule never puts a subject into an experiment.
    """
    if not targeting_rules or not targeting_rules.strip():
        return True
    try:
        return evaluate(targeting_rules, subject_attributes)
    except (EvaluationError, LexerError, ParserError):
        logger.warning(
            "targeting_rules_evaluation_error",
            extra={"rules": targeting_rules},
            exc_info=True,
        )
        return False


def _check_participation_limits(
    subject_id: str,
    experiment_pk: object,
) -> bool:
    """Return True when the subject may take part in ``experiment_pk``.

    Two rules are applied to the subject's assignments in *other*
    experiments:

    * concurrency - the subject must be assigned to fewer than
      ``MAX_CONCURRENT_EXPERIMENTS`` experiments that are not finished.
      Completed/archived experiments are excluded here; otherwise a
      subject would be locked out permanently after their third
      experiment, regardless of the cooldown below.
    * cooldown - the subject must not have been assigned, within the last
      ``COOLDOWN_DAYS`` days, to an experiment that is now finished.
    """
    other_assignments = Decision.objects.filter(
        subject_id=subject_id,
        reason="experiment_assigned",
        experiment_id__isnull=False,
    ).exclude(experiment_id=experiment_pk)
    finished_ids = Experiment.objects.filter(
        status__in=_FINISHED_STATUSES,
    ).values("pk")

    concurrent_count = (
        other_assignments.exclude(experiment_id__in=finished_ids)
        .values("experiment_id")
        .distinct()
        .count()
    )
    if concurrent_count >= MAX_CONCURRENT_EXPERIMENTS:
        return False

    cutoff = timezone.now() - timedelta(days=COOLDOWN_DAYS)
    in_cooldown = other_assignments.filter(
        created_at__gte=cutoff,
        experiment_id__in=finished_ids,
    ).exists()
    return not in_cooldown


def _record_decision(
    *,
    flag_key: str,
    subject_id: str,
    value: object,
    reason: str,
    experiment_id: str | None = None,
    variant_id: str | None = None,
) -> dict:
    """Count, build, persist and return one decision result."""
    DECIDE_REQUESTS.labels(reason=reason).inc()
    result = {
        "flag": flag_key,
        "value": value,
        "decision_id": str(uuid.uuid4()),
        "experiment_id": experiment_id,
        "variant_id": variant_id,
        "reason": reason,
    }
    _persist_decision(result, subject_id)
    return result


def decide_for_flag(
    flag_key: str,
    subject_id: str,
    subject_attributes: dict,
) -> dict:
    """Decide which value of ``flag_key`` the subject should see.

    The checks run in order and the first one that fails determines the
    ``reason``: ``flag_not_found``, ``no_active_experiment``,
    ``targeting_mismatch``, ``participation_limit``,
    ``outside_traffic_allocation``, ``no_variants``; otherwise
    ``experiment_assigned``. Every outcome increments ``DECIDE_REQUESTS``
    and is persisted.

    Returns:
        A dict with the keys ``flag``, ``value``, ``decision_id``,
        ``experiment_id``, ``variant_id`` and ``reason``. ``value`` is
        ``None`` for an unknown flag, the flag's default value for every
        other non-assigned outcome, and the variant's value on assignment.
    """
    flag = _cached_flag_get(flag_key)
    if not flag:
        return _record_decision(
            flag_key=flag_key,
            subject_id=subject_id,
            value=None,
            reason="flag_not_found",
        )

    experiment = _cached_active_experiment(flag.pk)
    if not experiment or experiment.status != ExperimentStatus.RUNNING:
        return _record_decision(
            flag_key=flag_key,
            subject_id=subject_id,
            value=flag.default_value,
            reason="no_active_experiment",
        )

    experiment_id = str(experiment.pk)

    def serve_default(reason: str) -> dict:
        # Subject reached the experiment but is not enrolled in it.
        return _record_decision(
            flag_key=flag_key,
            subject_id=subject_id,
            value=flag.default_value,
            reason=reason,
            experiment_id=experiment_id,
        )

    if not _check_targeting(experiment.targeting_rules, subject_attributes):
        return serve_default("targeting_mismatch")

    if not _check_participation_limits(subject_id, experiment.pk):
        return serve_default("participation_limit")

    # Separate salts keep "is the subject in the experiment at all" and
    # "which variant do they get" statistically independent.
    allocation_hash = _hash_subject(subject_id, experiment_id, "allocation")
    if allocation_hash >= experiment.traffic_allocation:
        return serve_default("outside_traffic_allocation")

    variants = list(experiment.variants.all())
    if not variants:
        return serve_default("no_variants")

    variant_hash = _hash_subject(subject_id, experiment_id, "variant")
    # Rescale the [0, 100) bucket onto [0, total_weight) so weights that do
    # not sum to 100 still split traffic in proportion.
    total_weight = sum(v.weight for v in variants)
    normalized_hash = variant_hash * total_weight / Decimal("100")
    selected = _select_variant(variants, normalized_hash)

    return _record_decision(
        flag_key=flag_key,
        subject_id=subject_id,
        value=selected.value if selected else flag.default_value,
        reason="experiment_assigned",
        experiment_id=experiment_id,
        variant_id=str(selected.pk) if selected else None,
    )


# --- src/backend/apps/decision/tests/__init__.py (new, empty file) ---

# --- src/backend/apps/decision/tests/test_decide.py ---
from decimal import Decimal
from typing import override

from django.core.cache import cache
from django.test import TestCase

from apps.decision.services import (
    MAX_CONCURRENT_EXPERIMENTS,
    _hash_subject,
    _select_variant,
    decide_for_flag,
)
from apps.experiments.models import Experiment, ExperimentStatus
from apps.experiments.services import variant_create
from apps.experiments.tests.helpers import make_experiment, make_flag
from apps.users.tests.helpers import make_user


def _add_control_and_treatment(experiment, weight: Decimal) -> None:
    """Attach a "control"/"ctrl" and a "treatment"/"treat" variant, both
    with the given weight, to ``experiment``.
    """
    variant_create(
        experiment=experiment,
        name="control",
        value="ctrl",
        weight=weight,
        is_control=True,
    )
    variant_create(
        experiment=experiment,
        name="treatment",
        value="treat",
        weight=weight,
    )


def _set_status(experiment, status) -> None:
    """Write ``status`` straight to the experiment's row via a queryset
    update.
    """
    Experiment.objects.filter(pk=experiment.pk).update(status=status)


def _running_two_arm_experiment(
    *,
    weight: Decimal = Decimal("50.00"),
    **experiment_kwargs,
):
    """Create an experiment with two equally weighted variants and mark it
    RUNNING. ``experiment_kwargs`` are passed to ``make_experiment`` as-is.
    """
    experiment = make_experiment(**experiment_kwargs)
    _add_control_and_treatment(experiment, weight)
    _set_status(experiment, ExperimentStatus.RUNNING)
    return experiment


class HashSubjectTest(TestCase):
    """Properties of the bucketing hash."""

    def test_deterministic(self) -> None:
        first = _hash_subject("user1", "exp1", "salt")
        second = _hash_subject("user1", "exp1", "salt")
        self.assertEqual(first, second)

    def test_different_inputs_differ(self) -> None:
        for_user1 = _hash_subject("user1", "exp1", "salt")
        for_user2 = _hash_subject("user2", "exp1", "salt")
        self.assertNotEqual(for_user1, for_user2)

    def test_range_0_to_100(self) -> None:
        for index in range(100):
            bucket = _hash_subject(f"u{index}", "exp", "s")
            self.assertGreaterEqual(bucket, 0)
            self.assertLess(bucket, 100)


class SelectVariantTest(TestCase):
    """Weight-band selection, exercised with lightweight stand-ins."""

    def test_selects_by_weight(self) -> None:
        class FakeVariant:
            def __init__(self, name, weight):
                self.name = name
                self.weight = Decimal(str(weight))

        pair = [FakeVariant("a", 50), FakeVariant("b", 50)]
        in_first_band = _select_variant(pair, Decimal(10))
        self.assertEqual(in_first_band.name, "a")
        in_second_band = _select_variant(pair, Decimal(60))
        self.assertEqual(in_second_band.name, "b")

    def test_empty_list_returns_none(self) -> None:
        self.assertIsNone(_select_variant([], Decimal(50)))


class DecideForFlagTest(TestCase):
    """End-to-end outcomes of ``decide_for_flag`` for a single flag."""

    @override
    def setUp(self) -> None:
        cache.clear()
        self.flag = make_flag(suffix="_dec", default="default_val")
        self.owner = make_user(
            username="dec_owner", email="dec_owner@lotty.local"
        )

    def test_flag_not_found(self) -> None:
        outcome = decide_for_flag("nonexistent", "u1", {})
        self.assertEqual(outcome["reason"], "flag_not_found")
        self.assertIsNone(outcome["value"])

    def test_no_active_experiment(self) -> None:
        outcome = decide_for_flag(self.flag.key, "u1", {})
        self.assertEqual(outcome["reason"], "no_active_experiment")
        self.assertEqual(outcome["value"], "default_val")

    def test_running_experiment_assigns_variant(self) -> None:
        experiment = _running_two_arm_experiment(
            flag=self.flag,
            owner=self.owner,
            suffix="_da",
            traffic_allocation=Decimal("100.00"),
        )
        outcome = decide_for_flag(self.flag.key, "user123", {})
        self.assertEqual(outcome["reason"], "experiment_assigned")
        self.assertIn(outcome["value"], ("ctrl", "treat"))
        self.assertIsNotNone(outcome["variant_id"])
        self.assertEqual(outcome["experiment_id"], str(experiment.pk))

    def test_paused_experiment_returns_default(self) -> None:
        experiment = make_experiment(
            flag=self.flag,
            owner=self.owner,
            suffix="_dp",
        )
        _set_status(experiment, ExperimentStatus.PAUSED)
        outcome = decide_for_flag(self.flag.key, "u1", {})
        self.assertEqual(outcome["reason"], "no_active_experiment")

    def test_deterministic_assignment(self) -> None:
        _running_two_arm_experiment(
            flag=self.flag,
            owner=self.owner,
            suffix="_dd",
            traffic_allocation=Decimal("100.00"),
        )
        first = decide_for_flag(self.flag.key, "stable_user", {})
        second = decide_for_flag(self.flag.key, "stable_user", {})
        self.assertEqual(first["value"], second["value"])
        self.assertEqual(first["variant_id"], second["variant_id"])


class TargetingRulesTest(TestCase):
    """Targeting-rule evaluation as seen through ``decide_for_flag``."""

    @override
    def setUp(self) -> None:
        cache.clear()
        self.flag = make_flag(suffix="_tgt", default="default_val")
        self.owner = make_user(
            username="tgt_owner", email="tgt_owner@lotty.local"
        )

    def _make_running_experiment(self, suffix, targeting_rules=""):
        # Full traffic, two 50/50 variants, on the shared flag.
        return _running_two_arm_experiment(
            flag=self.flag,
            owner=self.owner,
            suffix=suffix,
            traffic_allocation=Decimal("100.00"),
            targeting_rules=targeting_rules,
        )

    def test_targeting_pass_assigns_variant(self) -> None:
        self._make_running_experiment(
            "_tp",
            targeting_rules='country == "US"',
        )
        outcome = decide_for_flag(
            self.flag.key,
            "user1",
            {"country": "US"},
        )
        self.assertEqual(outcome["reason"], "experiment_assigned")
        self.assertIn(outcome["value"], ("ctrl", "treat"))

    def test_targeting_fail_returns_default(self) -> None:
        self._make_running_experiment(
            "_tf",
            targeting_rules='country == "US"',
        )
        outcome = decide_for_flag(
            self.flag.key,
            "user1",
            {"country": "DE"},
        )
        self.assertEqual(outcome["reason"], "targeting_mismatch")
        self.assertEqual(outcome["value"], "default_val")

    def test_empty_targeting_rules_passes(self) -> None:
        self._make_running_experiment("_te", targeting_rules="")
        outcome = decide_for_flag(self.flag.key, "user1", {})
        self.assertNotEqual(outcome["reason"], "targeting_mismatch")

    def test_complex_targeting_and_condition(self) -> None:
        self._make_running_experiment(
            "_tc",
            targeting_rules='country == "US" AND age >= 18',
        )
        adult = decide_for_flag(
            self.flag.key,
            "user2",
            {"country": "US", "age": 25},
        )
        self.assertEqual(adult["reason"], "experiment_assigned")

        cache.clear()
        minor = decide_for_flag(
            self.flag.key,
            "user3",
            {"country": "US", "age": 15},
        )
        self.assertEqual(minor["reason"], "targeting_mismatch")

    def test_invalid_targeting_rules_returns_default(self) -> None:
        self._make_running_experiment(
            "_ti",
            targeting_rules="??? invalid syntax",
        )
        outcome = decide_for_flag(self.flag.key, "user1", {})
        self.assertEqual(outcome["reason"], "targeting_mismatch")
        self.assertEqual(outcome["value"], "default_val")

    def test_missing_attribute_fails_targeting(self) -> None:
        self._make_running_experiment(
            "_tm",
            targeting_rules='plan == "premium"',
        )
        outcome = decide_for_flag(self.flag.key, "user1", {})
        self.assertEqual(outcome["reason"], "targeting_mismatch")


class ParticipationLimitsTest(TestCase):
    """Concurrent-experiment cap and post-completion cooldown."""

    @override
    def setUp(self) -> None:
        cache.clear()
        self.owner = make_user(
            username="part_owner", email="part_owner@lotty.local"
        )

    def _make_running_experiment_with_variants(self, suffix):
        # Each call gets its own flag so experiments do not collide.
        flag = make_flag(suffix=suffix, default="default_val")
        experiment = _running_two_arm_experiment(
            flag=flag,
            owner=self.owner,
            suffix=suffix,
            traffic_allocation=Decimal("100.00"),
        )
        return flag, experiment

    def test_within_limits_assigns_variant(self) -> None:
        flag, _exp = self._make_running_experiment_with_variants("_pl1")
        outcome = decide_for_flag(flag.key, "user_new", {})
        self.assertEqual(outcome["reason"], "experiment_assigned")

    def test_exceeds_concurrent_limit_returns_default(self) -> None:
        subject_id = "user_busy"
        # Fill the subject's quota with assignments in distinct experiments.
        for index in range(MAX_CONCURRENT_EXPERIMENTS):
            flag_i, _exp_i = self._make_running_experiment_with_variants(
                f"_plx{index}",
            )
            cache.clear()
            decide_for_flag(flag_i.key, subject_id, {})

        cache.clear()
        flag_extra, _exp_extra = self._make_running_experiment_with_variants(
            "_plxtra",
        )
        outcome = decide_for_flag(flag_extra.key, subject_id, {})
        self.assertEqual(outcome["reason"], "participation_limit")
        self.assertEqual(outcome["value"], "default_val")

    def test_cooldown_after_completed_experiment(self) -> None:
        flag1, exp1 = self._make_running_experiment_with_variants("_plc1")
        subject_id = "user_cool"
        decide_for_flag(flag1.key, subject_id, {})
        _set_status(exp1, ExperimentStatus.COMPLETED)

        cache.clear()
        flag2, _exp2 = self._make_running_experiment_with_variants("_plc2")
        outcome = decide_for_flag(flag2.key, subject_id, {})
        self.assertEqual(outcome["reason"], "participation_limit")
        self.assertEqual(outcome["value"], "default_val")


class PartialTrafficVariantDistributionTest(TestCase):
    """Variant split stays roughly even when only part of the traffic is
    allocated and the weights do not sum to 100.
    """

    @override
    def setUp(self) -> None:
        cache.clear()
        self.owner = make_user(
            username="pt_owner", email="pt_owner@lotty.local"
        )

    def test_variant_distribution_with_partial_traffic(self) -> None:
        flag = make_flag(suffix="_pt", default="default_val")
        _running_two_arm_experiment(
            weight=Decimal("10.00"),
            flag=flag,
            owner=self.owner,
            suffix="_pt",
            traffic_allocation=Decimal("20.00"),
        )

        counts: dict[str, int] = {"ctrl": 0, "treat": 0}
        assigned = 0
        for index in range(2000):
            cache.clear()
            outcome = decide_for_flag(flag.key, f"user_{index}", {})
            if outcome["reason"] != "experiment_assigned":
                continue
            assigned += 1
            counts[outcome["value"]] += 1

        self.assertGreater(assigned, 0)
        total = counts["ctrl"] + counts["treat"]
        ctrl_share = counts["ctrl"] / total
        self.assertGreater(ctrl_share, 0.3)
        self.assertLess(ctrl_share, 0.7)