refactor(); project refactor

This commit is contained in:
ITQ
2026-02-23 11:46:52 +03:00
parent 85923f11fc
commit ca0c456862
16 changed files with 198 additions and 194 deletions
+9 -3
View File
@@ -4,7 +4,11 @@ from uuid import UUID
from ninja import Field, ModelSchema, Schema from ninja import Field, ModelSchema, Schema
from apps.conflicts.models import ConflictDomain, ConflictPolicy, ExperimentConflictDomain from apps.conflicts.models import (
ConflictDomain,
ConflictPolicy,
ExperimentConflictDomain,
)
class ConflictDomainOut(ModelSchema): class ConflictDomainOut(ModelSchema):
@@ -54,7 +58,8 @@ class MembershipOut(Schema):
@classmethod @classmethod
def from_membership( def from_membership(
cls, m: ExperimentConflictDomain, cls,
m: ExperimentConflictDomain,
) -> "MembershipOut": ) -> "MembershipOut":
return cls( return cls(
id=m.pk, id=m.pk,
@@ -76,7 +81,8 @@ class ExperimentDomainOut(Schema):
@classmethod @classmethod
def from_membership( def from_membership(
cls, m: ExperimentConflictDomain, cls,
m: ExperimentConflictDomain,
) -> "ExperimentDomainOut": ) -> "ExperimentDomainOut":
return cls( return cls(
id=m.pk, id=m.pk,
@@ -1,13 +1,18 @@
import json import json
import uuid
from typing import override from typing import override
from django.test import Client, TestCase from django.test import Client, TestCase
from django.urls import reverse from django.urls import reverse
from apps.conflicts.tests.helpers import make_domain from apps.conflicts.tests.helpers import make_domain
from apps.experiments.tests.helpers import add_two_variants, make_flag
from apps.experiments.services import experiment_create from apps.experiments.services import experiment_create
from apps.reviews.tests.helpers import make_admin, make_experimenter, make_viewer from apps.experiments.tests.helpers import add_two_variants, make_flag
from apps.reviews.tests.helpers import (
make_admin,
make_experimenter,
make_viewer,
)
from apps.users.tests.helpers import auth_header from apps.users.tests.helpers import auth_header
@@ -23,12 +28,14 @@ class ConflictDomainAPITest(TestCase):
def test_create_domain(self) -> None: def test_create_domain(self) -> None:
resp = self.client.post( resp = self.client.post(
reverse("api-1:create_domain"), reverse("api-1:create_domain"),
data=json.dumps({ data=json.dumps(
"name": "checkout", {
"description": "Checkout zone", "name": "checkout",
"policy": "mutual_exclusion", "description": "Checkout zone",
"max_concurrent": 1, "policy": "mutual_exclusion",
}), "max_concurrent": 1,
}
),
content_type="application/json", content_type="application/json",
HTTP_AUTHORIZATION=self.admin_auth, HTTP_AUTHORIZATION=self.admin_auth,
) )
@@ -86,8 +93,6 @@ class ConflictDomainAPITest(TestCase):
self.assertEqual(resp.status_code, 204) self.assertEqual(resp.status_code, 204)
def test_get_nonexistent_domain(self) -> None: def test_get_nonexistent_domain(self) -> None:
import uuid
resp = self.client.get( resp = self.client.get(
reverse("api-1:get_domain", args=[uuid.uuid4()]), reverse("api-1:get_domain", args=[uuid.uuid4()]),
HTTP_AUTHORIZATION=self.admin_auth, HTTP_AUTHORIZATION=self.admin_auth,
@@ -114,10 +119,12 @@ class DomainExperimentAPITest(TestCase):
def test_add_experiment_to_domain(self) -> None: def test_add_experiment_to_domain(self) -> None:
resp = self.client.post( resp = self.client.post(
reverse("api-1:add_experiment_to_domain", args=[self.domain.pk]), reverse("api-1:add_experiment_to_domain", args=[self.domain.pk]),
data=json.dumps({ data=json.dumps(
"experiment_id": str(self.exp.pk), {
"priority": 5, "experiment_id": str(self.exp.pk),
}), "priority": 5,
}
),
content_type="application/json", content_type="application/json",
HTTP_AUTHORIZATION=self.admin_auth, HTTP_AUTHORIZATION=self.admin_auth,
) )
@@ -134,9 +141,7 @@ class DomainExperimentAPITest(TestCase):
HTTP_AUTHORIZATION=self.admin_auth, HTTP_AUTHORIZATION=self.admin_auth,
) )
resp = self.client.get( resp = self.client.get(
reverse( reverse("api-1:list_experiment_domains", args=[self.exp.pk]),
"api-1:list_experiment_domains", args=[self.exp.pk]
),
HTTP_AUTHORIZATION=self.admin_auth, HTTP_AUTHORIZATION=self.admin_auth,
) )
self.assertEqual(resp.status_code, 200) self.assertEqual(resp.status_code, 200)
+3 -1
View File
@@ -214,7 +214,9 @@ def update_variant(
) -> tuple[HTTPStatus, VariantOut]: ) -> tuple[HTTPStatus, VariantOut]:
user = _get_user(request) user = _get_user(request)
v = get_object_or_404( v = get_object_or_404(
Variant.objects.select_related("experiment__flag", "experiment__owner"), Variant.objects.select_related(
"experiment__flag", "experiment__owner"
),
pk=variant_id, pk=variant_id,
experiment_id=experiment_id, experiment_id=experiment_id,
) )
+23 -19
View File
@@ -36,11 +36,13 @@ def create_learning(
payload: LearningCreateIn, payload: LearningCreateIn,
) -> tuple[int, LearningOut]: ) -> tuple[int, LearningOut]:
try: try:
experiment = Experiment.objects.select_related("flag").get(pk=payload.experiment_id) experiment = Experiment.objects.select_related("flag").get(
pk=payload.experiment_id
)
except Experiment.DoesNotExist: except Experiment.DoesNotExist:
raise Http404 from Experiment.DoesNotExist raise Http404 from Experiment.DoesNotExist
l = learning_create( learning = learning_create(
experiment=experiment, experiment=experiment,
hypothesis=payload.hypothesis, hypothesis=payload.hypothesis,
findings=payload.findings, findings=payload.findings,
@@ -48,8 +50,8 @@ def create_learning(
context_summary=payload.context_summary, context_summary=payload.context_summary,
user=request.auth, user=request.auth,
) )
l = learning_get(l.pk) learning = learning_get(learning.pk)
return HTTPStatus.CREATED, LearningOut.from_learning(l) return HTTPStatus.CREATED, LearningOut.from_learning(learning)
@router.get( @router.get(
@@ -72,7 +74,9 @@ def list_learnings(
tag=tag, tag=tag,
search=search, search=search,
) )
return HTTPStatus.OK, [LearningOut.from_learning(l) for l in learnings] return HTTPStatus.OK, [
LearningOut.from_learning(learning) for learning in learnings
]
@router.get( @router.get(
@@ -84,10 +88,10 @@ def get_learning(
request: HttpRequest, request: HttpRequest,
learning_id: UUID, learning_id: UUID,
) -> tuple[int, LearningOut]: ) -> tuple[int, LearningOut]:
l = learning_get(learning_id) learning = learning_get(learning_id)
if not l: if not learning:
raise Http404 raise Http404
return HTTPStatus.OK, LearningOut.from_learning(l) return HTTPStatus.OK, LearningOut.from_learning(learning)
@router.patch( @router.patch(
@@ -100,16 +104,16 @@ def update_learning(
learning_id: UUID, learning_id: UUID,
payload: LearningUpdateIn, payload: LearningUpdateIn,
) -> tuple[int, LearningOut]: ) -> tuple[int, LearningOut]:
l = learning_get(learning_id) learning = learning_get(learning_id)
if not l: if not learning:
raise Http404 raise Http404
l = learning_update( learning = learning_update(
learning=l, learning=learning,
user=request.auth, user=request.auth,
**payload.dict(exclude_unset=True), **payload.dict(exclude_unset=True),
) )
l = learning_get(l.pk) learning = learning_get(learning.pk)
return HTTPStatus.OK, LearningOut.from_learning(l) return HTTPStatus.OK, LearningOut.from_learning(learning)
@router.delete( @router.delete(
@@ -121,10 +125,10 @@ def delete_learning(
request: HttpRequest, request: HttpRequest,
learning_id: UUID, learning_id: UUID,
) -> tuple[int, None]: ) -> tuple[int, None]:
l = learning_get(learning_id) learning = learning_get(learning_id)
if not l: if not learning:
raise Http404 raise Http404
learning_delete(learning=l) learning_delete(learning=learning)
return HTTPStatus.NO_CONTENT, None return HTTPStatus.NO_CONTENT, None
@@ -137,8 +141,8 @@ def list_learning_edits(
request: HttpRequest, request: HttpRequest,
learning_id: UUID, learning_id: UUID,
) -> tuple[int, list[EditOut]]: ) -> tuple[int, list[EditOut]]:
l = learning_get(learning_id) learning = learning_get(learning_id)
if not l: if not learning:
raise Http404 raise Http404
edits = learning_edit_list(learning_id) edits = learning_edit_list(learning_id)
return HTTPStatus.OK, [EditOut.from_edit(e) for e in edits] return HTTPStatus.OK, [EditOut.from_edit(e) for e in edits]
+15 -16
View File
@@ -1,4 +1,3 @@
from datetime import datetime
from typing import Any, ClassVar from typing import Any, ClassVar
from uuid import UUID from uuid import UUID
@@ -51,26 +50,26 @@ class LearningOut(ModelSchema):
) )
@classmethod @classmethod
def from_learning(cls, l: Learning) -> "LearningOut": def from_learning(cls, learning: Learning) -> "LearningOut":
created_by_out = None created_by_out = None
if l.created_by: if learning.created_by:
created_by_out = CreatedByOut( created_by_out = CreatedByOut(
id=l.created_by.pk, id=learning.created_by.pk,
username=l.created_by.username, username=learning.created_by.username,
) )
return cls( return cls(
id=l.pk, id=learning.pk,
hypothesis=l.hypothesis, hypothesis=learning.hypothesis,
findings=l.findings, findings=learning.findings,
tags=l.tags, tags=learning.tags,
context_summary=l.context_summary, context_summary=learning.context_summary,
created_at=l.created_at, created_at=learning.created_at,
updated_at=l.updated_at, updated_at=learning.updated_at,
experiment=ExperimentBriefOut( experiment=ExperimentBriefOut(
id=l.experiment.pk, id=learning.experiment.pk,
name=l.experiment.name, name=learning.experiment.name,
status=l.experiment.status, status=learning.experiment.status,
flag_key=l.experiment.flag.key, flag_key=learning.experiment.flag.key,
), ),
created_by=created_by_out, created_by=created_by_out,
) )
+6 -4
View File
@@ -133,9 +133,7 @@ def validate_domain_conflicts(experiment: Experiment) -> None:
active_count = active.count() active_count = active.count()
if active_count >= domain.max_concurrent: if active_count >= domain.max_concurrent:
active_names = ", ".join( active_names = ", ".join(m.experiment.name for m in active[:3])
m.experiment.name for m in active[:3]
)
raise ValidationError( raise ValidationError(
{ {
"conflict_domain": ( "conflict_domain": (
@@ -174,7 +172,11 @@ def resolve_domain_conflict(
if domain.policy == ConflictPolicy.PRIORITY: if domain.policy == ConflictPolicy.PRIORITY:
current = next( current = next(
(m for m in active_memberships if str(m.experiment_id) == str(experiment_id)), (
m
for m in active_memberships
if str(m.experiment_id) == str(experiment_id)
),
None, None,
) )
if not current: if not current:
@@ -15,14 +15,12 @@ from apps.conflicts.selectors import (
experiment_conflict_domains, experiment_conflict_domains,
) )
from apps.conflicts.services import ( from apps.conflicts.services import (
conflict_domain_create,
conflict_domain_delete, conflict_domain_delete,
conflict_domain_update, conflict_domain_update,
experiment_add_to_domain, experiment_add_to_domain,
experiment_remove_from_domain, experiment_remove_from_domain,
experiment_update_domain_priority, experiment_update_domain_priority,
resolve_domain_conflict, resolve_domain_conflict,
validate_domain_conflicts,
) )
from apps.conflicts.tests.helpers import make_domain from apps.conflicts.tests.helpers import make_domain
from apps.experiments.models import ExperimentStatus from apps.experiments.models import ExperimentStatus
@@ -35,6 +33,7 @@ from apps.experiments.services import (
from apps.experiments.tests.helpers import add_two_variants, make_flag from apps.experiments.tests.helpers import add_two_variants, make_flag
from apps.reviews.services import review_settings_update from apps.reviews.services import review_settings_update
from apps.reviews.tests.helpers import make_approver, make_experimenter from apps.reviews.tests.helpers import make_approver, make_experimenter
from config.errors import ConflictError
class ConflictDomainCRUDTest(TestCase): class ConflictDomainCRUDTest(TestCase):
@@ -55,7 +54,7 @@ class ConflictDomainCRUDTest(TestCase):
def test_duplicate_name_fails(self) -> None: def test_duplicate_name_fails(self) -> None:
make_domain(suffix="_dup") make_domain(suffix="_dup")
with self.assertRaises(Exception): with self.assertRaises(ConflictError):
make_domain(suffix="_dup") make_domain(suffix="_dup")
def test_updates_domain(self) -> None: def test_updates_domain(self) -> None:
@@ -128,7 +127,7 @@ class ExperimentDomainMembershipTest(TestCase):
def test_duplicate_membership_fails(self) -> None: def test_duplicate_membership_fails(self) -> None:
exp = self._make_ready_experiment("_dup2") exp = self._make_ready_experiment("_dup2")
experiment_add_to_domain(experiment=exp, domain=self.domain) experiment_add_to_domain(experiment=exp, domain=self.domain)
with self.assertRaises(Exception): with self.assertRaises(ConflictError):
experiment_add_to_domain(experiment=exp, domain=self.domain) experiment_add_to_domain(experiment=exp, domain=self.domain)
def test_remove_experiment_from_domain(self) -> None: def test_remove_experiment_from_domain(self) -> None:
@@ -307,12 +306,8 @@ class ResolveDomainConflictTest(TestCase):
) )
exp_low = self._make_and_start("_pr1", domain, priority=1) exp_low = self._make_and_start("_pr1", domain, priority=1)
exp_high = self._make_and_start("_pr2", domain, priority=10) exp_high = self._make_and_start("_pr2", domain, priority=10)
self.assertTrue( self.assertTrue(resolve_domain_conflict(exp_high.pk, domain.pk, "u1"))
resolve_domain_conflict(exp_high.pk, domain.pk, "u1") self.assertFalse(resolve_domain_conflict(exp_low.pk, domain.pk, "u1"))
)
self.assertFalse(
resolve_domain_conflict(exp_low.pk, domain.pk, "u1")
)
def test_priority_tie_first_created_wins(self) -> None: def test_priority_tie_first_created_wins(self) -> None:
domain = make_domain( domain = make_domain(
@@ -322,16 +317,10 @@ class ResolveDomainConflictTest(TestCase):
) )
exp1 = self._make_and_start("_tie1", domain, priority=5) exp1 = self._make_and_start("_tie1", domain, priority=5)
exp2 = self._make_and_start("_tie2", domain, priority=5) exp2 = self._make_and_start("_tie2", domain, priority=5)
self.assertTrue( self.assertTrue(resolve_domain_conflict(exp1.pk, domain.pk, "u1"))
resolve_domain_conflict(exp1.pk, domain.pk, "u1") self.assertFalse(resolve_domain_conflict(exp2.pk, domain.pk, "u1"))
)
self.assertFalse(
resolve_domain_conflict(exp2.pk, domain.pk, "u1")
)
def test_single_experiment_always_wins(self) -> None: def test_single_experiment_always_wins(self) -> None:
domain = make_domain(suffix="_single") domain = make_domain(suffix="_single")
exp = self._make_and_start("_s", domain) exp = self._make_and_start("_s", domain)
self.assertTrue( self.assertTrue(resolve_domain_conflict(exp.pk, domain.pk, "u1"))
resolve_domain_conflict(exp.pk, domain.pk, "u1")
)
+2 -3
View File
@@ -8,6 +8,7 @@ from django.core.cache import cache
from django.utils import timezone from django.utils import timezone
from prometheus_client import Counter from prometheus_client import Counter
from apps.conflicts.models import ExperimentConflictDomain
from apps.conflicts.services import resolve_domain_conflict from apps.conflicts.services import resolve_domain_conflict
from apps.events.models import Decision from apps.events.models import Decision
from apps.events.services import decision_create from apps.events.services import decision_create
@@ -145,8 +146,6 @@ def _check_participation_limits(
def _check_domain_conflicts(experiment: Experiment) -> bool: def _check_domain_conflicts(experiment: Experiment) -> bool:
from apps.conflicts.models import ExperimentConflictDomain
memberships = ExperimentConflictDomain.objects.filter( memberships = ExperimentConflictDomain.objects.filter(
experiment=experiment, experiment=experiment,
).select_related("conflict_domain") ).select_related("conflict_domain")
@@ -271,7 +270,7 @@ def decide_for_flag(
"variant", "variant",
) )
total_weight = sum(v.weight for v in variants) total_weight = sum(v.weight for v in variants)
normalized_hash = variant_hash * total_weight / Decimal("100") normalized_hash = variant_hash * total_weight / Decimal(100)
selected = _select_variant(variants, normalized_hash) selected = _select_variant(variants, normalized_hash)
DECIDE_REQUESTS.labels(reason="experiment_assigned").inc() DECIDE_REQUESTS.labels(reason="experiment_assigned").inc()
+10 -3
View File
@@ -4,6 +4,7 @@ from typing import Any
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.db import transaction from django.db import transaction
from apps.conflicts.services import validate_domain_conflicts
from apps.experiments.models import ( from apps.experiments.models import (
ACTIVE_STATUSES, ACTIVE_STATUSES,
ALLOWED_TRANSITIONS, ALLOWED_TRANSITIONS,
@@ -16,7 +17,6 @@ from apps.experiments.models import (
OutcomeType, OutcomeType,
Variant, Variant,
) )
from apps.conflicts.services import validate_domain_conflicts
from apps.flags.models import FeatureFlag, validate_value_for_type from apps.flags.models import FeatureFlag, validate_value_for_type
from apps.notifications.services import ( from apps.notifications.services import (
NotificationPayload, NotificationPayload,
@@ -30,12 +30,19 @@ from apps.users.models import User
from config.errors import ForbiddenError from config.errors import ForbiddenError
def _notify(event_type: str, experiment: Experiment, extra: dict[str, Any] | None = None) -> None: def _notify(
event_type: str,
experiment: Experiment,
extra: dict[str, Any] | None = None,
) -> None:
notification_enqueue( notification_enqueue(
event_type, event_type,
NotificationPayload( NotificationPayload(
title=f"{event_type.replace('_', ' ').title()}", title=f"{event_type.replace('_', ' ').title()}",
body=f"Experiment '{experiment.name}'{event_type.replace('_', ' ')}.", body=(
f"Experiment '{experiment.name}'"
f"{event_type.replace('_', ' ')}."
),
event_type=event_type, event_type=event_type,
experiment_id=str(experiment.pk), experiment_id=str(experiment.pk),
experiment_name=experiment.name, experiment_name=experiment.name,
@@ -14,7 +14,6 @@ from apps.experiments.models import (
OutcomeType, OutcomeType,
) )
from apps.experiments.services import ( from apps.experiments.services import (
ensure_owner_or_admin,
experiment_approve, experiment_approve,
experiment_archive, experiment_archive,
experiment_complete, experiment_complete,
@@ -530,9 +529,7 @@ class OwnershipPermissionTest(TestCase):
def test_other_experimenter_cannot_submit_for_review(self) -> None: def test_other_experimenter_cannot_submit_for_review(self) -> None:
with self.assertRaises(ForbiddenError): with self.assertRaises(ForbiddenError):
experiment_submit_for_review( experiment_submit_for_review(experiment=self.exp, user=self.other)
experiment=self.exp, user=self.other
)
def test_admin_can_submit_for_review(self) -> None: def test_admin_can_submit_for_review(self) -> None:
exp = experiment_submit_for_review( exp = experiment_submit_for_review(
+3 -1
View File
@@ -59,7 +59,9 @@ class Learning(BaseModel):
ordering = ["-created_at"] ordering = ["-created_at"]
indexes = [ indexes = [
GinIndex(fields=["search_vector"], name="idx_learning_search"), GinIndex(fields=["search_vector"], name="idx_learning_search"),
models.Index(fields=["experiment"], name="idx_learning_experiment"), models.Index(
fields=["experiment"], name="idx_learning_experiment"
),
] ]
@override @override
+49 -29
View File
@@ -1,11 +1,20 @@
import logging import logging
import operator
from typing import Any from typing import Any
from uuid import UUID from uuid import UUID
from django.contrib.postgres.search import (
SearchQuery,
SearchRank,
SearchVector,
)
from django.db import connection, transaction from django.db import connection, transaction
from django.db.models import Q, QuerySet from django.db.models import Q, QuerySet
from apps.experiments.models import Experiment, ExperimentOutcome, ExperimentStatus from apps.experiments.models import (
Experiment,
ExperimentOutcome,
)
from apps.guardrails.models import GuardrailTrigger from apps.guardrails.models import GuardrailTrigger
from apps.learnings.models import Learning, LearningEdit from apps.learnings.models import Learning, LearningEdit
from apps.metrics.models import ExperimentMetric from apps.metrics.models import ExperimentMetric
@@ -111,11 +120,12 @@ def learning_list(
qs = qs.filter(tags__icontains=tag) qs = qs.filter(tags__icontains=tag)
if search is not None: if search is not None:
if _is_postgres(): if _is_postgres():
from django.contrib.postgres.search import SearchQuery, SearchRank
query = SearchQuery(search, config="english") query = SearchQuery(search, config="english")
qs = qs.filter(search_vector=query).annotate( qs = (
rank=SearchRank("search_vector", query) qs.filter(search_vector=query)
).order_by("-rank") .annotate(rank=SearchRank("search_vector", query))
.order_by("-rank")
)
else: else:
qs = qs.filter( qs = qs.filter(
Q(hypothesis__icontains=search) Q(hypothesis__icontains=search)
@@ -138,7 +148,9 @@ def find_similar_learnings(
limit: int = 5, limit: int = 5,
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
try: try:
experiment = Experiment.objects.select_related("flag").get(pk=experiment_id) experiment = Experiment.objects.select_related("flag").get(
pk=experiment_id
)
except Experiment.DoesNotExist: except Experiment.DoesNotExist:
return [] return []
@@ -159,13 +171,17 @@ def find_similar_learnings(
experiment_id=experiment_id experiment_id=experiment_id
).exists() ).exists()
candidates = Learning.objects.select_related( candidates = (
"experiment__flag", Learning.objects.select_related(
"experiment__owner", "experiment__flag",
"experiment__outcome", "experiment__owner",
).exclude( "experiment__outcome",
experiment_id=experiment_id, )
).order_by("-created_at")[:100] .exclude(
experiment_id=experiment_id,
)
.order_by("-created_at")[:100]
)
scored: list[tuple[float, Learning]] = [] scored: list[tuple[float, Learning]] = []
for candidate in candidates: for candidate in candidates:
@@ -179,7 +195,7 @@ def find_similar_learnings(
if score > 0: if score > 0:
scored.append((score, candidate)) scored.append((score, candidate))
scored.sort(key=lambda x: x[0], reverse=True) scored.sort(key=operator.itemgetter(0), reverse=True)
results: list[dict[str, Any]] = [] results: list[dict[str, Any]] = []
for score, candidate in scored[:limit]: for score, candidate in scored[:limit]:
outcome_data = None outcome_data = None
@@ -189,7 +205,9 @@ def find_similar_learnings(
outcome_data = { outcome_data = {
"outcome": exp_outcome.outcome, "outcome": exp_outcome.outcome,
"rationale": exp_outcome.rationale, "rationale": exp_outcome.rationale,
"winning_variant": str(exp_outcome.winning_variant_id) if exp_outcome.winning_variant_id else None, "winning_variant": str(exp_outcome.winning_variant_id)
if exp_outcome.winning_variant_id
else None,
} }
except ExperimentOutcome.DoesNotExist: except ExperimentOutcome.DoesNotExist:
pass pass
@@ -198,19 +216,21 @@ def find_similar_learnings(
experiment_id=candidate.experiment_id experiment_id=candidate.experiment_id
).count() ).count()
results.append({ results.append(
"learning_id": str(candidate.pk), {
"experiment_id": str(candidate.experiment_id), "learning_id": str(candidate.pk),
"experiment_name": candidate.experiment.name, "experiment_id": str(candidate.experiment_id),
"flag_key": candidate.experiment.flag.key, "experiment_name": candidate.experiment.name,
"hypothesis": candidate.hypothesis, "flag_key": candidate.experiment.flag.key,
"findings": candidate.findings, "hypothesis": candidate.hypothesis,
"tags": candidate.tags, "findings": candidate.findings,
"outcome": outcome_data, "tags": candidate.tags,
"had_guardrail_triggers": trigger_count > 0, "outcome": outcome_data,
"guardrail_trigger_count": trigger_count, "had_guardrail_triggers": trigger_count > 0,
"similarity_score": round(score, 3), "guardrail_trigger_count": trigger_count,
}) "similarity_score": round(score, 3),
}
)
return results return results
@@ -258,7 +278,7 @@ def _jaccard_score(
def _update_search_vector(learning: Learning) -> None: def _update_search_vector(learning: Learning) -> None:
if not _is_postgres(): if not _is_postgres():
return return
from django.contrib.postgres.search import SearchVector
Learning.objects.filter(pk=learning.pk).update( Learning.objects.filter(pk=learning.pk).update(
search_vector=( search_vector=(
SearchVector("hypothesis", weight="A", config="english") SearchVector("hypothesis", weight="A", config="english")
@@ -1,12 +1,12 @@
import uuid
from typing import override from typing import override
from django.test import TestCase from django.test import TestCase
from apps.experiments.models import ExperimentOutcome, ExperimentStatus, OutcomeType from apps.experiments.tests.helpers import (
from apps.experiments.services import experiment_complete make_experiment,
from apps.experiments.tests.helpers import add_two_variants, make_experiment, make_flag make_flag,
from apps.guardrails.models import Guardrail, GuardrailTrigger )
from apps.learnings.models import Learning, LearningEdit
from apps.learnings.services import ( from apps.learnings.services import (
find_similar_learnings, find_similar_learnings,
learning_create, learning_create,
@@ -16,9 +16,9 @@ from apps.learnings.services import (
learning_list, learning_list,
learning_update, learning_update,
) )
from apps.metrics.models import ExperimentMetric, MetricDefinition, MetricType
from apps.reviews.tests.helpers import make_admin from apps.reviews.tests.helpers import make_admin
from apps.users.tests.helpers import make_user from apps.users.tests.helpers import make_user
from config.errors import ConflictError
class LearningCRUDTest(TestCase): class LearningCRUDTest(TestCase):
@@ -28,7 +28,7 @@ class LearningCRUDTest(TestCase):
self.exp = make_experiment(suffix="_lcrud", owner=self.user) self.exp = make_experiment(suffix="_lcrud", owner=self.user)
def test_create_learning(self) -> None: def test_create_learning(self) -> None:
l = learning_create( learning = learning_create(
experiment=self.exp, experiment=self.exp,
hypothesis="Changing button color increases CTR", hypothesis="Changing button color increases CTR",
findings="Blue variant showed +5% CTR improvement", findings="Blue variant showed +5% CTR improvement",
@@ -36,21 +36,25 @@ class LearningCRUDTest(TestCase):
context_summary="Checkout page button color test", context_summary="Checkout page button color test",
user=self.user, user=self.user,
) )
self.assertEqual(l.experiment, self.exp) self.assertEqual(learning.experiment, self.exp)
self.assertEqual(l.hypothesis, "Changing button color increases CTR") self.assertEqual(
self.assertEqual(l.findings, "Blue variant showed +5% CTR improvement") learning.hypothesis, "Changing button color increases CTR"
self.assertEqual(l.tags, ["ui", "ctr", "button"]) )
self.assertEqual(l.created_by, self.user) self.assertEqual(
learning.findings, "Blue variant showed +5% CTR improvement"
)
self.assertEqual(learning.tags, ["ui", "ctr", "button"])
self.assertEqual(learning.created_by, self.user)
def test_create_learning_minimal(self) -> None: def test_create_learning_minimal(self) -> None:
l = learning_create( learning = learning_create(
experiment=self.exp, experiment=self.exp,
hypothesis="Test hypothesis", hypothesis="Test hypothesis",
findings="Test findings", findings="Test findings",
) )
self.assertEqual(l.tags, []) self.assertEqual(learning.tags, [])
self.assertEqual(l.context_summary, "") self.assertEqual(learning.context_summary, "")
self.assertIsNone(l.created_by) self.assertIsNone(learning.created_by)
def test_create_duplicate_learning_fails(self) -> None: def test_create_duplicate_learning_fails(self) -> None:
learning_create( learning_create(
@@ -58,7 +62,7 @@ class LearningCRUDTest(TestCase):
hypothesis="First", hypothesis="First",
findings="First findings", findings="First findings",
) )
with self.assertRaises(Exception): with self.assertRaises(ConflictError):
learning_create( learning_create(
experiment=self.exp, experiment=self.exp,
hypothesis="Second", hypothesis="Second",
@@ -66,28 +70,27 @@ class LearningCRUDTest(TestCase):
) )
def test_get_learning(self) -> None: def test_get_learning(self) -> None:
l = learning_create( learning = learning_create(
experiment=self.exp, experiment=self.exp,
hypothesis="Test", hypothesis="Test",
findings="Test findings", findings="Test findings",
) )
fetched = learning_get(l.pk) fetched = learning_get(learning.pk)
self.assertIsNotNone(fetched) self.assertIsNotNone(fetched)
self.assertEqual(fetched.pk, l.pk) self.assertEqual(fetched.pk, learning.pk)
def test_get_nonexistent_learning(self) -> None: def test_get_nonexistent_learning(self) -> None:
import uuid
result = learning_get(uuid.uuid4()) result = learning_get(uuid.uuid4())
self.assertIsNone(result) self.assertIsNone(result)
def test_delete_learning(self) -> None: def test_delete_learning(self) -> None:
l = learning_create( learning = learning_create(
experiment=self.exp, experiment=self.exp,
hypothesis="Test", hypothesis="Test",
findings="Test findings", findings="Test findings",
) )
learning_delete(learning=l) learning_delete(learning=learning)
self.assertIsNone(learning_get(l.pk)) self.assertIsNone(learning_get(learning.pk))
class LearningUpdateTest(TestCase): class LearningUpdateTest(TestCase):
@@ -108,12 +111,12 @@ class LearningUpdateTest(TestCase):
username="editor_lupd", username="editor_lupd",
email="editor_lupd@lotty.local", email="editor_lupd@lotty.local",
) )
l = learning_update( learning = learning_update(
learning=self.learning, learning=self.learning,
user=editor, user=editor,
hypothesis="Updated hypothesis", hypothesis="Updated hypothesis",
) )
self.assertEqual(l.hypothesis, "Updated hypothesis") self.assertEqual(learning.hypothesis, "Updated hypothesis")
def test_update_creates_audit_trail(self) -> None: def test_update_creates_audit_trail(self) -> None:
editor = make_user( editor = make_user(
@@ -193,7 +196,7 @@ class LearningListTest(TestCase):
def test_filter_by_tag(self) -> None: def test_filter_by_tag(self) -> None:
learnings = learning_list(tag="ui") learnings = learning_list(tag="ui")
self.assertTrue(all("ui" in l.tags for l in learnings)) self.assertTrue(all("ui" in learning.tags for learning in learnings))
def test_filter_by_tag_no_match(self) -> None: def test_filter_by_tag_no_match(self) -> None:
learnings = learning_list(tag="nonexistent") learnings = learning_list(tag="nonexistent")
@@ -214,7 +217,9 @@ class SimilarLearningsTest(TestCase):
self.user = make_admin("_sim") self.user = make_admin("_sim")
self.flag = make_flag(suffix="_sim1") self.flag = make_flag(suffix="_sim1")
self.exp1 = make_experiment(flag=self.flag, owner=self.user, suffix="_sim1") self.exp1 = make_experiment(
flag=self.flag, owner=self.user, suffix="_sim1"
)
self.learning1 = learning_create( self.learning1 = learning_create(
experiment=self.exp1, experiment=self.exp1,
hypothesis="Test button color", hypothesis="Test button color",
@@ -223,7 +228,9 @@ class SimilarLearningsTest(TestCase):
) )
flag2 = make_flag(suffix="_sim2") flag2 = make_flag(suffix="_sim2")
self.exp2 = make_experiment(flag=flag2, owner=self.user, suffix="_sim2") self.exp2 = make_experiment(
flag=flag2, owner=self.user, suffix="_sim2"
)
self.learning2 = learning_create( self.learning2 = learning_create(
experiment=self.exp2, experiment=self.exp2,
hypothesis="Test font size", hypothesis="Test font size",
@@ -262,7 +269,6 @@ class SimilarLearningsTest(TestCase):
self.assertNotIn(str(self.exp1.pk), result_exp_ids) self.assertNotIn(str(self.exp1.pk), result_exp_ids)
def test_find_similar_nonexistent_experiment(self) -> None: def test_find_similar_nonexistent_experiment(self) -> None:
import uuid
results = find_similar_learnings(experiment_id=uuid.uuid4()) results = find_similar_learnings(experiment_id=uuid.uuid4())
self.assertEqual(results, []) self.assertEqual(results, [])
+3 -1
View File
@@ -41,7 +41,9 @@ class NotificationChannel(BaseModel):
default=dict, default=dict,
blank=True, blank=True,
verbose_name=_("configuration"), verbose_name=_("configuration"),
help_text=_("Provider-specific settings (tokens, chat IDs, SMTP host, etc.)"), help_text=_(
"Provider-specific settings (tokens, chat IDs, SMTP host, etc.)"
),
) )
is_active = models.BooleanField( is_active = models.BooleanField(
default=True, default=True,
+5 -41
View File
@@ -5,13 +5,12 @@ from typing import Any
import requests import requests
from django.core.mail import send_mail from django.core.mail import send_mail
from django.db import transaction from django.db import transaction
from django.db.models import QuerySet from django.db.models import Q, QuerySet
from django.utils import timezone from django.utils import timezone
from apps.notifications.models import ( from apps.notifications.models import (
ChannelType, ChannelType,
NotificationChannel, NotificationChannel,
NotificationEventType,
NotificationLog, NotificationLog,
NotificationRule, NotificationRule,
NotificationStatus, NotificationStatus,
@@ -30,11 +29,6 @@ class NotificationPayload:
extra: dict[str, Any] = field(default_factory=dict) extra: dict[str, Any] = field(default_factory=dict)
# ---------------------------------------------------------------------------
# Channel CRUD
# ---------------------------------------------------------------------------
@transaction.atomic @transaction.atomic
def channel_create( def channel_create(
*, *,
@@ -82,11 +76,6 @@ def channel_get(channel_id: Any) -> NotificationChannel | None:
return None return None
# ---------------------------------------------------------------------------
# Rule CRUD
# ---------------------------------------------------------------------------
@transaction.atomic @transaction.atomic
def rule_create( def rule_create(
*, *,
@@ -130,11 +119,6 @@ def rule_list(channel_id: Any | None = None) -> QuerySet[NotificationRule]:
return qs return qs
# ---------------------------------------------------------------------------
# Log selectors
# ---------------------------------------------------------------------------
def log_list( def log_list(
*, *,
status: str | None = None, status: str | None = None,
@@ -146,11 +130,6 @@ def log_list(
return qs[:limit] return qs[:limit]
# ---------------------------------------------------------------------------
# Notification enqueue (called from integration points)
# ---------------------------------------------------------------------------
def notification_enqueue( def notification_enqueue(
event_type: str, event_type: str,
payload: NotificationPayload, payload: NotificationPayload,
@@ -163,8 +142,7 @@ def notification_enqueue(
if payload.experiment_id: if payload.experiment_id:
rules = rules.filter( rules = rules.filter(
models_Q(experiment__isnull=True) Q(experiment__isnull=True) | Q(experiment_id=payload.experiment_id)
| models_Q(experiment_id=payload.experiment_id)
) )
else: else:
rules = rules.filter(experiment__isnull=True) rules = rules.filter(experiment__isnull=True)
@@ -203,17 +181,6 @@ def _build_event_key(event_type: str, payload: NotificationPayload) -> str:
return f"{event_type}:{payload.experiment_id}:{bucket}" return f"{event_type}:{payload.experiment_id}:{bucket}"
def models_Q(**kwargs):
from django.db.models import Q
return Q(**kwargs)
# ---------------------------------------------------------------------------
# Senders
# ---------------------------------------------------------------------------
def _send_telegram(config: dict[str, Any], payload: dict[str, Any]) -> None: def _send_telegram(config: dict[str, Any], payload: dict[str, Any]) -> None:
bot_token = config.get("bot_token", "") bot_token = config.get("bot_token", "")
chat_id = config.get("chat_id", "") chat_id = config.get("chat_id", "")
@@ -260,11 +227,6 @@ def _send_smtp(config: dict[str, Any], payload: dict[str, Any]) -> None:
) )
# ---------------------------------------------------------------------------
# Flush pending (called from Celery task)
# ---------------------------------------------------------------------------
def flush_pending_notifications() -> dict[str, int]: def flush_pending_notifications() -> dict[str, int]:
pending = NotificationLog.objects.filter( pending = NotificationLog.objects.filter(
status=NotificationStatus.PENDING, status=NotificationStatus.PENDING,
@@ -288,7 +250,9 @@ def flush_pending_notifications() -> dict[str, int]:
sender = senders.get(log.channel.channel_type) sender = senders.get(log.channel.channel_type)
if not sender: if not sender:
log.status = NotificationStatus.FAILED log.status = NotificationStatus.FAILED
log.error = f"No sender for channel type '{log.channel.channel_type}'." log.error = (
f"No sender for channel type '{log.channel.channel_type}'."
)
log.save(update_fields=["status", "error"]) log.save(update_fields=["status", "error"])
results["failed"] += 1 results["failed"] += 1
continue continue
@@ -6,10 +6,8 @@ from django.test import TestCase
from apps.experiments.tests.helpers import make_experiment from apps.experiments.tests.helpers import make_experiment
from apps.notifications.models import ( from apps.notifications.models import (
ChannelType, ChannelType,
NotificationChannel,
NotificationEventType, NotificationEventType,
NotificationLog, NotificationLog,
NotificationRule,
NotificationStatus, NotificationStatus,
) )
from apps.notifications.services import ( from apps.notifications.services import (
@@ -85,7 +83,9 @@ class RuleCRUDTest(TestCase):
event_type=NotificationEventType.GUARDRAIL_TRIGGERED, event_type=NotificationEventType.GUARDRAIL_TRIGGERED,
channel=self.channel, channel=self.channel,
) )
self.assertEqual(r.event_type, NotificationEventType.GUARDRAIL_TRIGGERED) self.assertEqual(
r.event_type, NotificationEventType.GUARDRAIL_TRIGGERED
)
self.assertIsNone(r.experiment) self.assertIsNone(r.experiment)
self.assertTrue(r.is_active) self.assertTrue(r.is_active)