Files
Lotty/src/backend/apps/notifications/services.py
T
2026-02-24 18:30:01 +03:00

392 lines
11 KiB
Python

import logging
from dataclasses import dataclass, field
from typing import Any
import requests
from django.core.exceptions import ValidationError
from django.core.mail import EmailMessage, get_connection
from django.db import transaction
from django.db.models import Q, QuerySet
from django.utils import timezone
from apps.notifications.models import (
ChannelType,
NotificationChannel,
NotificationLog,
NotificationRule,
NotificationStatus,
)
# Records from this module are emitted under the "lotty" logger name.
logger = logging.getLogger("lotty")

# For each channel type, the config keys that must hold non-blank strings.
# Channel types that are not listed here have no required keys.
REQUIRED_CHANNEL_CONFIG_FIELDS: dict[str, tuple[str, ...]] = {
    ChannelType.TELEGRAM: ("bot_token", "chat_id"),
    ChannelType.SMTP: ("recipient",),
}

# Channel-config keys that, when set, are forwarded as keyword arguments to
# Django's ``get_connection()`` for SMTP delivery.
SMTP_CONNECTION_OPTION_KEYS: tuple[str, ...] = (
    "host", "port", "username", "password",
    "use_tls", "use_ssl", "timeout",
)

# Sender address used when a channel config has no usable ``from_email``.
DEFAULT_SMTP_FROM_EMAIL = "lotty@lotty.local"
@dataclass(frozen=True)
class NotificationPayload:
    """Immutable description of one notification-worthy event.

    ``title`` and ``body`` are the human-readable message parts and
    ``event_type`` names the kind of event. The optional ``experiment_*``
    fields tie the event to an experiment. ``extra`` carries arbitrary
    additional data that is stored alongside the notification log entry.
    """

    title: str
    body: str
    event_type: str
    # An empty string means the event is not tied to a specific experiment.
    experiment_id: str = field(default="")
    experiment_name: str = field(default="")
    # Each instance gets its own dict. ``frozen`` only prevents rebinding the
    # attribute; it does not stop mutation of the dict's contents.
    extra: dict[str, Any] = field(default_factory=dict)
def _required_str_config(
config: dict[str, Any],
field: str,
channel_type: str,
) -> str:
value = config.get(field)
if not isinstance(value, str) or not value.strip():
raise ValidationError(
{
"config": (
f"Channel '{channel_type}' requires non-empty '{field}'."
)
}
)
return value.strip()
def _validate_channel_config(
    *,
    channel_type: str,
    config: dict[str, Any],
) -> None:
    """Check that ``config`` is acceptable for a ``channel_type`` channel.

    The config must be a dict, and every key listed for ``channel_type`` in
    ``REQUIRED_CHANNEL_CONFIG_FIELDS`` must hold a non-blank string.

    Raises:
        ValidationError: On the first problem found.
    """
    if isinstance(config, dict):
        required_keys = REQUIRED_CHANNEL_CONFIG_FIELDS.get(channel_type, ())
        for key in required_keys:
            _required_str_config(config, key, channel_type)
        return
    raise ValidationError({"config": "Channel config must be a JSON object."})
@transaction.atomic
def channel_create(
    *,
    channel_type: str,
    name: str,
    config: dict[str, Any] | None = None,
) -> NotificationChannel:
    """Validate and persist a new notification channel.

    Args:
        channel_type: Channel type; determines which config keys are required.
        name: Display name for the channel.
        config: Channel-specific settings. ``None`` (the default) means an
            empty config. Any other value, including falsy ones such as
            ``[]``, is validated as given.

    Returns:
        The saved ``NotificationChannel``.

    Raises:
        ValidationError: If ``config`` is not a dict or lacks a required
            non-blank string field for ``channel_type``.
    """
    # Only ``None`` means "no config", matching ``channel_update``. A falsy
    # non-dict value like ``[]`` or ``""`` must reach the validator and be
    # rejected, not be silently replaced with ``{}``.
    normalized_config = {} if config is None else config
    _validate_channel_config(
        channel_type=channel_type,
        config=normalized_config,
    )
    channel = NotificationChannel(
        channel_type=channel_type,
        name=name,
        config=normalized_config,
    )
    channel.save()
    return channel
def channel_update(
    *,
    channel: NotificationChannel,
    **fields: Any,
) -> NotificationChannel:
    """Apply a partial update to ``channel`` and save it.

    Only ``name``, ``config`` and ``is_active`` may be passed. A keyword
    whose value is ``None`` means "leave unchanged". The config that will
    result (the new one, or the existing one if none is given) is validated
    before any attribute is modified.

    Raises:
        ValueError: If an unsupported field name is passed.
        ValidationError: If the resulting config is invalid.
    """
    updatable = frozenset({"name", "config", "is_active"})
    unexpected = [key for key in fields if key not in updatable]
    if unexpected:
        raise ValueError(f"Field '{unexpected[0]}' cannot be updated.")

    incoming_config = fields.get("config")
    effective_config = (
        channel.config if incoming_config is None else incoming_config
    )
    _validate_channel_config(
        channel_type=channel.channel_type,
        config=effective_config,
    )

    provided = {key: value for key, value in fields.items() if value is not None}
    for attr, new_value in provided.items():
        setattr(channel, attr, new_value)
    channel.save()
    return channel
def channel_delete(*, channel: NotificationChannel) -> None:
    """Delete ``channel`` by calling the model instance's ``delete()``."""
    channel.delete()
def channel_list() -> QuerySet[NotificationChannel]:
    """Return an unfiltered queryset over all notification channels."""
    manager = NotificationChannel.objects
    return manager.all()
def channel_get(channel_id: Any) -> NotificationChannel | None:
    """Return the channel whose primary key is ``channel_id``, or ``None``.

    Only ``DoesNotExist`` is translated to ``None``. Any other lookup error
    (for example, from a malformed id) propagates to the caller.
    """
    try:
        found = NotificationChannel.objects.get(pk=channel_id)
    except NotificationChannel.DoesNotExist:
        found = None
    return found
@transaction.atomic
def rule_create(
    *,
    event_type: str,
    channel: NotificationChannel,
    experiment: Any | None = None,
    rate_limit_window_seconds: int = 60,
    rate_limit_max_notifications: int = 1,
) -> NotificationRule:
    """Persist and return a new notification rule.

    Args:
        event_type: Event type the rule reacts to.
        channel: Channel that receives matching notifications.
        experiment: Optional experiment scope. ``None`` makes the rule
            match events whether or not they carry an experiment id.
        rate_limit_window_seconds: Length of the rate-limit time bucket.
        rate_limit_max_notifications: Maximum number of PENDING/SENT logs
            allowed for the same event key and channel within one bucket.
    """
    attributes = {
        "event_type": event_type,
        "channel": channel,
        "experiment": experiment,
        "rate_limit_window_seconds": rate_limit_window_seconds,
        "rate_limit_max_notifications": rate_limit_max_notifications,
    }
    rule = NotificationRule(**attributes)
    rule.save()
    return rule
def rule_update(
    *,
    rule: NotificationRule,
    **fields: Any,
) -> NotificationRule:
    """Apply a partial update to ``rule`` and save it.

    Updatable fields are ``event_type``, ``is_active``,
    ``rate_limit_window_seconds`` and ``rate_limit_max_notifications``.
    Keywords whose value is ``None`` are ignored.

    Raises:
        ValueError: If an unsupported field name is passed.
    """
    updatable = frozenset(
        {
            "event_type",
            "is_active",
            "rate_limit_window_seconds",
            "rate_limit_max_notifications",
        }
    )
    rejected = next((key for key in fields if key not in updatable), None)
    if rejected is not None:
        raise ValueError(f"Field '{rejected}' cannot be updated.")

    changes = {key: value for key, value in fields.items() if value is not None}
    for attr, new_value in changes.items():
        setattr(rule, attr, new_value)
    rule.save()
    return rule
def rule_delete(*, rule: NotificationRule) -> None:
    """Delete ``rule`` by calling the model instance's ``delete()``."""
    rule.delete()
def rule_list(channel_id: Any | None = None) -> QuerySet[NotificationRule]:
    """Return rules with ``channel`` and ``experiment`` joined in the query.

    When ``channel_id`` is given, only rules for that channel are returned.
    """
    rules = NotificationRule.objects.select_related("channel", "experiment")
    if channel_id is None:
        return rules.all()
    return rules.filter(channel_id=channel_id)
def log_list(
    *,
    status: str | None = None,
    limit: int = 100,
) -> QuerySet[NotificationLog]:
    """Return at most ``limit`` notification logs, optionally by status.

    A falsy ``status`` (``None`` or ``""``) disables status filtering.
    ``channel`` and ``rule`` are fetched in the same query.
    NOTE(review): no explicit ordering is applied here, so which rows fall
    inside the limit depends on the model's default ordering (not visible
    in this file). Confirm that this is intended.
    """
    logs = NotificationLog.objects.select_related("channel", "rule")
    selected = logs.filter(status=status) if status else logs.all()
    return selected[:limit]
def _log_payload_dict(payload: NotificationPayload) -> dict[str, Any]:
    """Build a fresh dict of the payload fields stored on a log entry."""
    return {
        "title": payload.title,
        "body": payload.body,
        "experiment_id": payload.experiment_id,
        "experiment_name": payload.experiment_name,
        "extra": payload.extra,
    }


def notification_enqueue(
    event_type: str,
    payload: NotificationPayload,
) -> list[NotificationLog]:
    """Create a PENDING log entry for every matching, non-rate-limited rule.

    A rule matches when it and its channel are active, its ``event_type``
    equals ``event_type``, and it is either unscoped (no experiment) or
    scoped to ``payload.experiment_id``. A matching rule is skipped when its
    channel already has ``rate_limit_max_notifications`` or more PENDING/SENT
    logs with the same event key. The key combines the event type, the
    experiment id and the current time bucket.

    Returns:
        The newly created ``NotificationLog`` rows, in rule iteration order.
    """
    experiment_scope = Q(experiment__isnull=True)
    if payload.experiment_id:
        experiment_scope |= Q(experiment_id=payload.experiment_id)

    matching_rules = (
        NotificationRule.objects.filter(
            event_type=event_type,
            is_active=True,
            channel__is_active=True,
        )
        .select_related("channel")
        .filter(experiment_scope)
    )

    blocking_statuses = (NotificationStatus.PENDING, NotificationStatus.SENT)
    created: list[NotificationLog] = []
    for rule in matching_rules:
        channel = rule.channel
        event_key = _build_event_key(
            event_type,
            payload,
            rule.rate_limit_window_seconds,
        )
        already_queued = NotificationLog.objects.filter(
            event_key=event_key,
            channel=channel,
            status__in=blocking_statuses,
        ).count()
        if already_queued < rule.rate_limit_max_notifications:
            created.append(
                NotificationLog.objects.create(
                    rule=rule,
                    channel=channel,
                    event_type=event_type,
                    event_key=event_key,
                    payload=_log_payload_dict(payload),
                    status=NotificationStatus.PENDING,
                )
            )
    return created
def _build_event_key(
    event_type: str,
    payload: NotificationPayload,
    window_seconds: int,
) -> str:
    """Return the rate-limit key for this event in the current time window.

    The key has the form ``"<event_type>:<experiment_id>:<bucket>"``.
    ``bucket`` is the current Unix time in whole seconds, integer-divided by
    the window length. Windows shorter than one second are treated as one.
    """
    window = window_seconds if window_seconds > 1 else 1
    now_seconds = int(timezone.now().timestamp())
    bucket = now_seconds // window
    return f"{event_type}:{payload.experiment_id}:{bucket}"
def _escape_markdown(text: str) -> str:
for ch in r"\_*[]()~`>#+-=|{}.!":
text = text.replace(ch, f"\\{ch}")
return text
def _send_telegram(config: dict[str, Any], payload: dict[str, Any]) -> None:
bot_token = config.get("bot_token", "")
chat_id = config.get("chat_id", "")
if not bot_token or not chat_id:
raise ValueError("Telegram config requires 'bot_token' and 'chat_id'.")
title = _escape_markdown(payload["title"])
body = _escape_markdown(payload["body"])
text = f"*{title}*\n\n{body}"
if payload.get("experiment_name"):
name = _escape_markdown(payload["experiment_name"])
text += f"\n\nExperiment: {name}"
api_url = config.get(
"api_url",
f"https://api.telegram.org/bot{bot_token}",
)
response = requests.post(
f"{api_url}/sendMessage",
json={
"chat_id": chat_id,
"text": text,
"parse_mode": "MarkdownV2",
},
timeout=10,
)
response.raise_for_status()
def _send_smtp(config: dict[str, Any], payload: dict[str, Any]) -> None:
    """Email ``payload`` to the channel's recipient through Django's mail API.

    Args:
        config: Channel config. ``recipient`` is required. ``from_email`` is
            optional and falls back to ``DEFAULT_SMTP_FROM_EMAIL`` unless it
            is a non-blank string. Each key in ``SMTP_CONNECTION_OPTION_KEYS``
            whose value is neither ``None`` nor ``""`` is passed to
            ``get_connection``.
        payload: Stored log payload. ``title`` becomes the subject and
            ``body`` the text, with ``experiment_name`` appended if truthy.

    Raises:
        ValueError: If ``recipient`` is missing, not a string, or blank.
    """
    raw_recipient = config.get("recipient", "")
    sender_address = config.get("from_email", DEFAULT_SMTP_FROM_EMAIL)
    if not (isinstance(raw_recipient, str) and raw_recipient.strip()):
        raise ValueError("SMTP config requires 'recipient'.")
    to_address = raw_recipient.strip()
    if not (isinstance(sender_address, str) and sender_address.strip()):
        sender_address = DEFAULT_SMTP_FROM_EMAIL

    # None and "" both mean "not configured": let the backend use its default.
    candidate_options = (
        (key, config.get(key)) for key in SMTP_CONNECTION_OPTION_KEYS
    )
    runtime_options: dict[str, Any] = {
        key: option
        for key, option in candidate_options
        if option is not None and not (isinstance(option, str) and not option)
    }
    connection = get_connection(fail_silently=False, **runtime_options)

    message_subject = payload.get("title", "Lotty Notification")
    message_text = payload.get("body", "")
    experiment_name = payload.get("experiment_name")
    if experiment_name:
        message_text += f"\n\nExperiment: {experiment_name}"

    message = EmailMessage(
        subject=message_subject,
        body=message_text,
        from_email=sender_address,
        to=[to_address],
        connection=connection,
    )
    message.send(fail_silently=False)
def _mark_failed(log: NotificationLog, error: str) -> None:
    """Save ``log`` with status FAILED and ``error`` as the stored reason."""
    log.status = NotificationStatus.FAILED
    log.error = error
    log.save(update_fields=["status", "error"])


def flush_pending_notifications() -> dict[str, int]:
    """Try to deliver every PENDING notification log, oldest first.

    A log is marked SENT (and ``sent_at`` is set) when its sender returns
    normally. It is marked FAILED with an ``error`` message when its channel
    is missing or inactive, when no sender exists for the channel type, or
    when the sender raises. Sender exceptions are also logged.

    Returns:
        Outcome counts for this run under the keys ``"sent"`` and
        ``"failed"``.
    """
    senders = {
        ChannelType.TELEGRAM: _send_telegram,
        ChannelType.SMTP: _send_smtp,
    }
    results = {"sent": 0, "failed": 0}
    queue = (
        NotificationLog.objects.filter(status=NotificationStatus.PENDING)
        .select_related("channel")
        .order_by("created_at")
    )
    for log in queue:
        channel = log.channel
        if not channel or not channel.is_active:
            _mark_failed(log, "Channel is inactive or missing.")
            results["failed"] += 1
            continue

        sender = senders.get(channel.channel_type)
        if not sender:
            _mark_failed(
                log,
                f"No sender for channel type '{channel.channel_type}'.",
            )
            results["failed"] += 1
            continue

        try:
            sender(channel.config, log.payload)
            log.status = NotificationStatus.SENT
            log.sent_at = timezone.now()
            log.save(update_fields=["status", "sent_at"])
            results["sent"] += 1
        except Exception as exc:
            logger.exception(
                "notification_send_failed",
                extra={
                    "log_id": str(log.pk),
                    "channel": channel.name,
                    "error": str(exc),
                },
            )
            # Truncate so an arbitrarily long exception message is not stored.
            _mark_failed(log, str(exc)[:1000])
            results["failed"] += 1
    return results