"""Metric calculation and report assembly for experiment variants.

Metric values are derived from ``Exposure`` rows of a given experiment
variant and from the attributed ``Event`` rows that share a ``decision_id``
with those exposures.
"""

from datetime import datetime
from decimal import Decimal
from uuid import UUID

from django.db import connection
from django.db.models import (
    Aggregate,
    Avg,
    Case,
    Count,
    F,
    FloatField,
    QuerySet,
    Subquery,
    Value,
    When,
)
from django.db.models.fields.json import KeyTextTransform
from django.db.models.functions import Cast

from apps.events.models import Event, Exposure
from apps.experiments.models import Experiment
from apps.metrics.models import (
    ExperimentMetric,
    MetricDefinition,
    MetricType,
)


class PercentileCont(Aggregate):
    """SQL ``PERCENTILE_CONT(p) WITHIN GROUP (ORDER BY expr)`` aggregate.

    ``percentile`` is the fraction passed to ``PERCENTILE_CONT``. It is
    interpolated directly into the SQL template (not bound as a query
    parameter), so callers must pass a trusted numeric value.
    """

    function = "PERCENTILE_CONT"
    template = (
        "%(function)s(%(percentile)s) WITHIN GROUP (ORDER BY %(expressions)s)"
    )
    allow_distinct = False
    output_field = FloatField()

    def __init__(self, expression, percentile, **extra):
        super().__init__(
            expression,
            percentile=percentile,
            **extra,
        )


def _filter_by_period(
    qs: QuerySet,
    start_date: datetime | None = None,
    end_date: datetime | None = None,
) -> QuerySet:
    """Restrict *qs* to ``start_date <= timestamp < end_date``.

    Either bound may be ``None``, in which case that side is left open.
    """
    if start_date is not None:
        qs = qs.filter(timestamp__gte=start_date)
    if end_date is not None:
        qs = qs.filter(timestamp__lt=end_date)
    return qs


def _exposure_queryset(
    experiment_id: UUID,
    variant_id: UUID,
    start_date: datetime | None = None,
    end_date: datetime | None = None,
) -> QuerySet[Exposure]:
    """Return exposures of one variant, optionally limited to a period."""
    qs = Exposure.objects.filter(
        experiment_id=experiment_id,
        variant_id=variant_id,
    )
    return _filter_by_period(qs, start_date, end_date)


def _exposure_decision_ids_subquery(
    exposure_qs: QuerySet[Exposure],
):
    """Wrap the ``decision_id`` values of *exposure_qs* in a ``Subquery``."""
    return Subquery(exposure_qs.values("decision_id"))


def _events_queryset(
    *,
    exposure_qs: QuerySet[Exposure],
    event_type_name: str,
    start_date: datetime | None = None,
    end_date: datetime | None = None,
) -> QuerySet[Event]:
    """Return attributed events of a type linked to the given exposures.

    Events are matched to exposures through ``decision_id`` and must have
    ``is_attributed=True``. The optional period applies to the event
    timestamp, not the exposure timestamp.
    """
    qs = Event.objects.filter(
        decision_id__in=_exposure_decision_ids_subquery(exposure_qs),
        event_type__name=event_type_name,
        is_attributed=True,
    )
    return _filter_by_period(qs, start_date, end_date)


def _numeric_property_expression(property_field: str):
    """Build an expression yielding ``properties[property_field]`` as float.

    On PostgreSQL the value is cast only when its text form matches a plain
    decimal pattern (optional sign, digits, optional fraction); anything else
    yields NULL -- presumably so that non-numeric text does not make the cast
    fail. On other backends the JSON key is cast directly.
    """
    if connection.vendor == "postgresql":
        key_text = KeyTextTransform(property_field, "properties")
        pattern = r"^-?(?:\d+(?:\.\d+)?|\.\d+)$"
        return Case(
            When(
                **{f"properties__{property_field}__regex": pattern},
                then=Cast(key_text, FloatField()),
            ),
            default=Value(None),
            output_field=FloatField(),
        )
    return Cast(F(f"properties__{property_field}"), FloatField())


def _count_events(
    *,
    exposure_qs: QuerySet[Exposure],
    event_type_name: str,
    start_date: datetime | None = None,
    end_date: datetime | None = None,
) -> int:
    """Count attributed events of a type linked to the given exposures."""
    return _events_queryset(
        exposure_qs=exposure_qs,
        event_type_name=event_type_name,
        start_date=start_date,
        end_date=end_date,
    ).count()


def _average_property(
    *,
    exposure_qs: QuerySet[Exposure],
    event_type_name: str,
    property_field: str,
    start_date: datetime | None = None,
    end_date: datetime | None = None,
) -> Decimal | None:
    """Average a numeric event property over the matching events.

    Returns ``None`` when the aggregate yields no value (no matching events,
    or no event with a usable numeric value).
    """
    qs = _events_queryset(
        exposure_qs=exposure_qs,
        event_type_name=event_type_name,
        start_date=start_date,
        end_date=end_date,
    ).annotate(
        numeric_value=_numeric_property_expression(property_field),
    )
    value = qs.aggregate(value=Avg("numeric_value"))["value"]
    if value is None:
        return None
    return Decimal(str(value))


def _percentile_property(
    *,
    exposure_qs: QuerySet[Exposure],
    event_type_name: str,
    property_field: str,
    percentile: int,
    start_date: datetime | None = None,
    end_date: datetime | None = None,
) -> Decimal | None:
    """Compute a percentile of a numeric event property.

    On PostgreSQL the database computes ``PERCENTILE_CONT``. On other
    backends the same continuous percentile is reproduced by linear
    interpolation between the two neighbouring ordered values, so results
    agree across backends.

    Args:
        percentile: Percentile on a 0-100 scale.

    Returns:
        The percentile as ``Decimal``, or ``None`` when no event has a
        usable numeric value.

    Raises:
        ValueError: If *percentile* is outside ``[0, 100]``.
    """
    if not 0 <= percentile <= 100:
        raise ValueError(
            f"percentile must be between 0 and 100, got {percentile!r}"
        )

    qs = (
        _events_queryset(
            exposure_qs=exposure_qs,
            event_type_name=event_type_name,
            start_date=start_date,
            end_date=end_date,
        )
        .annotate(
            numeric_value=_numeric_property_expression(property_field),
        )
        .exclude(numeric_value__isnull=True)
    )

    if connection.vendor == "postgresql":
        value = qs.aggregate(
            value=PercentileCont(
                "numeric_value",
                Decimal(percentile) / Decimal(100),
            )
        )["value"]
        if value is None:
            return None
        return Decimal(str(value))

    total = qs.aggregate(total=Count("pk"))["total"]
    if not total:
        return None

    # PERCENTILE_CONT semantics: the target sits at zero-based position
    # p * (N - 1) in the ordered values; interpolate between its neighbours.
    position = (total - 1) * percentile / 100
    lower = int(position)
    upper = min(lower + 1, total - 1)
    ordered = qs.order_by("numeric_value").values_list(
        "numeric_value",
        flat=True,
    )
    window = list(ordered[lower : upper + 1])
    if not window:
        # Rows vanished between the count and the read.
        return None
    low = float(window[0])
    high = float(window[-1])
    value = low + (high - low) * (position - lower)
    return Decimal(str(value))


def calculate_metric_value(
    metric: MetricDefinition,
    experiment_id: UUID,
    variant_id: UUID,
    start_date: datetime | None = None,
    end_date: datetime | None = None,
    event_start_date: datetime | None = None,
    event_end_date: datetime | None = None,
) -> Decimal | None:
    """Compute one metric for one experiment variant.

    ``start_date``/``end_date`` bound the exposures; ``event_start_date``/
    ``event_end_date`` bound the events and default to the exposure bounds
    when not given.

    ``metric.calculation_rule`` is read by metric type:

    * RATIO: ``numerator_event`` and ``denominator_event``; the result is
      rounded to 6 decimal places.
    * COUNT: ``event``.
    * AVERAGE: ``event`` and ``property``.
    * PERCENTILE: ``event``, ``property`` and optional ``percentile``
      (defaults to 95).

    Returns:
        The metric value, or ``None`` when there are no exposures, when a
        ratio's denominator is zero, when no value can be computed, or when
        the metric type is not one of the above.
    """
    rule = metric.calculation_rule
    exposure_qs = _exposure_queryset(
        experiment_id,
        variant_id,
        start_date,
        end_date,
    )
    if not exposure_qs.exists():
        return None

    ev_start = event_start_date if event_start_date is not None else start_date
    ev_end = event_end_date if event_end_date is not None else end_date
    metric_type = metric.metric_type

    if metric_type == MetricType.RATIO:
        numerator = _count_events(
            exposure_qs=exposure_qs,
            event_type_name=rule["numerator_event"],
            start_date=ev_start,
            end_date=ev_end,
        )
        denominator = _count_events(
            exposure_qs=exposure_qs,
            event_type_name=rule["denominator_event"],
            start_date=ev_start,
            end_date=ev_end,
        )
        if denominator == 0:
            return None
        return Decimal(str(round(numerator / denominator, 6)))

    if metric_type == MetricType.COUNT:
        count = _count_events(
            exposure_qs=exposure_qs,
            event_type_name=rule["event"],
            start_date=ev_start,
            end_date=ev_end,
        )
        return Decimal(str(count))

    if metric_type == MetricType.AVERAGE:
        return _average_property(
            exposure_qs=exposure_qs,
            event_type_name=rule["event"],
            property_field=rule["property"],
            start_date=ev_start,
            end_date=ev_end,
        )

    if metric_type == MetricType.PERCENTILE:
        return _percentile_property(
            exposure_qs=exposure_qs,
            event_type_name=rule["event"],
            property_field=rule["property"],
            percentile=rule.get("percentile", 95),
            start_date=ev_start,
            end_date=ev_end,
        )

    return None


def _exposure_count_for_variant(
    experiment_id: UUID,
    variant_id: UUID,
    start_date: datetime | None = None,
    end_date: datetime | None = None,
) -> int:
    """Count exposures of one variant within the optional period."""
    return _exposure_queryset(
        experiment_id,
        variant_id,
        start_date,
        end_date,
    ).count()


def _unique_subjects_for_variant(
    experiment_id: UUID,
    variant_id: UUID,
    start_date: datetime | None = None,
    end_date: datetime | None = None,
) -> int:
    """Count distinct ``subject_id`` values among a variant's exposures."""
    exposures = _exposure_queryset(
        experiment_id,
        variant_id,
        start_date,
        end_date,
    )
    return exposures.values("subject_id").distinct().count()


def build_experiment_report(
    experiment: Experiment,
    start_date: datetime | None = None,
    end_date: datetime | None = None,
) -> dict:
    """Assemble a per-variant report of exposures and metric values.

    Metrics are listed primary-first, then by metric key. Each metric is
    computed with the report period as the exposure window; no separate
    event window is passed, so events use the same period.

    Returns:
        A dict with the experiment identity and status, the requested period
        as ISO strings (or ``None``), and one entry per variant holding its
        exposure count, unique subject count and metric results.
    """
    experiment_metrics = (
        ExperimentMetric.objects.filter(experiment=experiment)
        .select_related("metric")
        .order_by("-is_primary", "metric__key")
    )

    variant_reports = []
    for variant in experiment.variants.all():
        metric_results = [
            {
                "metric_key": em.metric.key,
                "metric_name": em.metric.name,
                "metric_type": em.metric.metric_type,
                "direction": em.metric.direction,
                "is_primary": em.is_primary,
                "value": calculate_metric_value(
                    metric=em.metric,
                    experiment_id=experiment.pk,
                    variant_id=variant.pk,
                    start_date=start_date,
                    end_date=end_date,
                ),
            }
            for em in experiment_metrics
        ]
        variant_reports.append(
            {
                "variant_id": variant.pk,
                "variant_name": variant.name,
                "is_control": variant.is_control,
                "weight": variant.weight,
                "exposures": _exposure_count_for_variant(
                    experiment.pk,
                    variant.pk,
                    start_date,
                    end_date,
                ),
                "unique_subjects": _unique_subjects_for_variant(
                    experiment.pk,
                    variant.pk,
                    start_date,
                    end_date,
                ),
                "metrics": metric_results,
            }
        )

    return {
        "experiment_id": experiment.pk,
        "experiment_name": experiment.name,
        "status": experiment.status,
        "period": {
            "start": start_date.isoformat() if start_date is not None else None,
            "end": end_date.isoformat() if end_date is not None else None,
        },
        "variants": variant_reports,
    }