392 lines
11 KiB
Python
392 lines
11 KiB
Python
import logging
|
|
from dataclasses import dataclass, field
|
|
from typing import Any
|
|
|
|
import requests
|
|
from django.core.exceptions import ValidationError
|
|
from django.core.mail import EmailMessage, get_connection
|
|
from django.db import transaction
|
|
from django.db.models import Q, QuerySet
|
|
from django.utils import timezone
|
|
|
|
from apps.notifications.models import (
|
|
ChannelType,
|
|
NotificationChannel,
|
|
NotificationLog,
|
|
NotificationRule,
|
|
NotificationStatus,
|
|
)
|
|
|
|
logger = logging.getLogger("lotty")

# Config keys that each channel type must carry as non-empty strings.
# Channel types missing from this mapping have no required keys.
REQUIRED_CHANNEL_CONFIG_FIELDS: dict[str, tuple[str, ...]] = {
    ChannelType.TELEGRAM: ("bot_token", "chat_id"),
    ChannelType.SMTP: ("recipient",),
}

# Optional connection settings read from an SMTP channel's config. Those that
# are present and not None / empty-string are passed to ``get_connection``.
SMTP_CONNECTION_OPTION_KEYS: tuple[str, ...] = (
    "host", "port", "username", "password", "use_tls", "use_ssl", "timeout",
)

# Sender address used when an SMTP config has no usable ``from_email``.
DEFAULT_SMTP_FROM_EMAIL = "lotty@lotty.local"
|
|
|
|
|
|
@dataclass(frozen=True)
class NotificationPayload:
    """Immutable description of one event to be fanned out as notifications.

    ``title``, ``body`` and ``event_type`` are required. ``experiment_id`` and
    ``experiment_name`` default to empty strings. ``extra`` defaults to a new
    empty dict for each instance. ``notification_enqueue`` copies title, body,
    experiment id/name and extra into each log's stored payload.
    """

    title: str
    body: str
    event_type: str
    experiment_id: str = ""
    experiment_name: str = ""
    # ``frozen`` blocks reassigning this attribute, not mutating the dict itself.
    extra: dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
def _required_str_config(
|
|
config: dict[str, Any],
|
|
field: str,
|
|
channel_type: str,
|
|
) -> str:
|
|
value = config.get(field)
|
|
if not isinstance(value, str) or not value.strip():
|
|
raise ValidationError(
|
|
{
|
|
"config": (
|
|
f"Channel '{channel_type}' requires non-empty '{field}'."
|
|
)
|
|
}
|
|
)
|
|
return value.strip()
|
|
|
|
|
|
def _validate_channel_config(
    *,
    channel_type: str,
    config: dict[str, Any],
) -> None:
    """Check that *config* is a dict holding every key *channel_type* requires.

    Required keys come from ``REQUIRED_CHANNEL_CONFIG_FIELDS``. Channel types
    not listed there only have to pass the dict check.

    Raises:
        ValidationError: if *config* is not a dict, or (via
            ``_required_str_config``) a required key is missing or blank.
    """
    if isinstance(config, dict):
        required_keys = REQUIRED_CHANNEL_CONFIG_FIELDS.get(channel_type, ())
        for key_name in required_keys:
            _required_str_config(config, key_name, channel_type)
        return
    raise ValidationError({"config": "Channel config must be a JSON object."})
|
|
|
|
|
|
@transaction.atomic
def channel_create(
    *,
    channel_type: str,
    name: str,
    config: dict[str, Any] | None = None,
) -> NotificationChannel:
    """Validate and persist a new notification channel.

    A falsy *config* (including ``None``) is replaced with an empty dict. The
    config is validated against *channel_type* before the row is saved.

    Returns:
        The saved ``NotificationChannel``.

    Raises:
        ValidationError: if the config fails ``_validate_channel_config``.
    """
    cfg: dict[str, Any] = config if config else {}
    _validate_channel_config(channel_type=channel_type, config=cfg)
    new_channel = NotificationChannel(
        channel_type=channel_type,
        name=name,
        config=cfg,
    )
    new_channel.save()
    return new_channel
|
|
|
|
|
|
def channel_update(
    *,
    channel: NotificationChannel,
    **fields: Any,
) -> NotificationChannel:
    """Apply whitelisted field changes to *channel* and save it.

    Only ``name``, ``config`` and ``is_active`` may be passed. Values of
    ``None`` are ignored. The effective config (the new ``config`` if it is
    given and not ``None``, otherwise the channel's current config) is
    validated against the channel's type before any attribute is set. An
    update that omits ``config`` therefore still fails if the stored config
    is invalid.

    Raises:
        ValueError: for the first keyword that is not updatable.
        ValidationError: if the effective config is invalid.
    """
    updatable = frozenset({"name", "config", "is_active"})
    rejected = next((key for key in fields if key not in updatable), None)
    if rejected is not None:
        raise ValueError(f"Field '{rejected}' cannot be updated.")

    proposed_config = fields.get("config")
    _validate_channel_config(
        channel_type=channel.channel_type,
        config=channel.config if proposed_config is None else proposed_config,
    )

    changes = {key: value for key, value in fields.items() if value is not None}
    for attr, value in changes.items():
        setattr(channel, attr, value)
    channel.save()
    return channel
|
|
|
|
|
|
def channel_delete(*, channel: NotificationChannel) -> None:
    """Delete *channel* by calling its model ``delete()`` method.

    The return value of ``delete()`` is discarded.
    """
    channel.delete()
|
|
|
|
|
|
def channel_list() -> QuerySet[NotificationChannel]:
    """Return an unfiltered, lazily evaluated queryset of all channels."""
    manager = NotificationChannel.objects
    return manager.all()
|
|
|
|
|
|
def channel_get(channel_id: Any) -> NotificationChannel | None:
    """Return the channel whose primary key is *channel_id*, or ``None``.

    Only ``NotificationChannel.DoesNotExist`` is caught. Any other exception
    raised by the lookup propagates to the caller.
    """
    found: NotificationChannel | None
    try:
        found = NotificationChannel.objects.get(pk=channel_id)
    except NotificationChannel.DoesNotExist:
        found = None
    return found
|
|
|
|
|
|
@transaction.atomic
def rule_create(
    *,
    event_type: str,
    channel: NotificationChannel,
    experiment: Any | None = None,
    rate_limit_window_seconds: int = 60,
    rate_limit_max_notifications: int = 1,
) -> NotificationRule:
    """Persist a routing rule that sends *event_type* events to *channel*.

    Args:
        event_type: Event name the rule matches.
        channel: Destination channel.
        experiment: Optional experiment scope. ``None`` makes the rule global.
        rate_limit_window_seconds: Length of the fixed time bucket used for
            de-duplication (default 60).
        rate_limit_max_notifications: Maximum PENDING/SENT logs per bucket
            and channel (default 1).

    No validation of the rate-limit values is performed here.

    Returns:
        The saved ``NotificationRule``.
    """
    attributes = {
        "event_type": event_type,
        "channel": channel,
        "experiment": experiment,
        "rate_limit_window_seconds": rate_limit_window_seconds,
        "rate_limit_max_notifications": rate_limit_max_notifications,
    }
    new_rule = NotificationRule(**attributes)
    new_rule.save()
    return new_rule
|
|
|
|
|
|
def rule_update(
    *,
    rule: NotificationRule,
    **fields: Any,
) -> NotificationRule:
    """Apply whitelisted field changes to *rule* and save it.

    Only ``event_type``, ``is_active``, ``rate_limit_window_seconds`` and
    ``rate_limit_max_notifications`` may be passed. Values of ``None`` are
    skipped. The new values are not validated here.

    Raises:
        ValueError: for the first keyword that is not updatable.
    """
    editable = frozenset(
        {
            "event_type",
            "is_active",
            "rate_limit_window_seconds",
            "rate_limit_max_notifications",
        }
    )
    rejected = next((key for key in fields if key not in editable), None)
    if rejected is not None:
        raise ValueError(f"Field '{rejected}' cannot be updated.")

    for attr, value in fields.items():
        if value is None:
            continue
        setattr(rule, attr, value)
    rule.save()
    return rule
|
|
|
|
|
|
def rule_delete(*, rule: NotificationRule) -> None:
    """Delete *rule* by calling its model ``delete()`` method.

    The return value of ``delete()`` is discarded.
    """
    rule.delete()
|
|
|
|
|
|
def rule_list(channel_id: Any | None = None) -> QuerySet[NotificationRule]:
    """Return rules with ``channel`` and ``experiment`` joined.

    When *channel_id* is not ``None``, only rules on that channel are kept.
    """
    base = NotificationRule.objects.select_related("channel", "experiment").all()
    return base if channel_id is None else base.filter(channel_id=channel_id)
|
|
|
|
|
|
def log_list(
    *,
    status: str | None = None,
    limit: int = 100,
) -> QuerySet[NotificationLog]:
    """Return up to *limit* notification logs with ``channel`` and ``rule`` joined.

    A falsy *status* disables status filtering.

    NOTE(review): no ``order_by`` is applied here, so which rows survive the
    ``[:limit]`` slice depends on NotificationLog's default ordering. Confirm
    the model defines one.
    """
    everything = NotificationLog.objects.select_related("channel", "rule").all()
    selected = everything.filter(status=status) if status else everything
    return selected[:limit]
|
|
|
|
|
|
def notification_enqueue(
    event_type: str,
    payload: NotificationPayload,
) -> list[NotificationLog]:
    """Create PENDING NotificationLog rows for every rule matching *event_type*.

    A rule matches when it is active, its channel is active, and it is either
    global (no experiment) or, when ``payload.experiment_id`` is set, scoped
    to that experiment. ``payload.event_type`` is not consulted. The
    *event_type* argument is used for matching and stored on each log.

    Rate limiting: for each rule an event key is built from the event type,
    the experiment id and the current fixed time bucket of the rule's
    ``rate_limit_window_seconds``. The rule is skipped when its channel
    already has at least ``rate_limit_max_notifications`` PENDING or SENT
    logs with that key. FAILED logs do not count. The count is per channel,
    not per rule, so rules on the same channel that produce the same key
    share one budget.

    NOTE(review): the count and the insert are separate queries with no
    locking, so concurrent calls can exceed the limit. Confirm acceptable.

    Returns:
        The newly created logs, in rule iteration order.
    """
    scope = Q(experiment__isnull=True)
    if payload.experiment_id:
        scope |= Q(experiment_id=payload.experiment_id)

    matching_rules = (
        NotificationRule.objects.filter(
            event_type=event_type,
            is_active=True,
            channel__is_active=True,
        )
        .select_related("channel")
        .filter(scope)
    )

    created: list[NotificationLog] = []
    for rule in matching_rules:
        target_channel = rule.channel
        key = _build_event_key(event_type, payload, rule.rate_limit_window_seconds)

        already_queued = NotificationLog.objects.filter(
            event_key=key,
            channel=target_channel,
            status__in=[NotificationStatus.PENDING, NotificationStatus.SENT],
        ).count()
        if already_queued >= rule.rate_limit_max_notifications:
            continue

        created.append(
            NotificationLog.objects.create(
                rule=rule,
                channel=target_channel,
                event_type=event_type,
                event_key=key,
                payload={
                    "title": payload.title,
                    "body": payload.body,
                    "experiment_id": payload.experiment_id,
                    "experiment_name": payload.experiment_name,
                    "extra": payload.extra,
                },
                status=NotificationStatus.PENDING,
            )
        )

    return created
|
|
|
|
|
|
def _build_event_key(
    event_type: str,
    payload: NotificationPayload,
    window_seconds: int,
) -> str:
    """Return the de-duplication key ``"<event_type>:<experiment_id>:<bucket>"``.

    ``bucket`` is the current Unix timestamp (truncated to whole seconds)
    integer-divided by the window length. All calls in the same fixed window
    therefore share a key. Windows below 1 are treated as 1.
    """
    window = max(window_seconds, 1)
    now_seconds = int(timezone.now().timestamp())
    return f"{event_type}:{payload.experiment_id}:{now_seconds // window}"
|
|
|
|
|
|
def _escape_markdown(text: str) -> str:
|
|
for ch in r"\_*[]()~`>#+-=|{}.!":
|
|
text = text.replace(ch, f"\\{ch}")
|
|
return text
|
|
|
|
|
|
def _send_telegram(config: dict[str, Any], payload: dict[str, Any]) -> None:
|
|
bot_token = config.get("bot_token", "")
|
|
chat_id = config.get("chat_id", "")
|
|
if not bot_token or not chat_id:
|
|
raise ValueError("Telegram config requires 'bot_token' and 'chat_id'.")
|
|
|
|
title = _escape_markdown(payload["title"])
|
|
body = _escape_markdown(payload["body"])
|
|
text = f"*{title}*\n\n{body}"
|
|
if payload.get("experiment_name"):
|
|
name = _escape_markdown(payload["experiment_name"])
|
|
text += f"\n\nExperiment: {name}"
|
|
|
|
api_url = config.get(
|
|
"api_url",
|
|
f"https://api.telegram.org/bot{bot_token}",
|
|
)
|
|
response = requests.post(
|
|
f"{api_url}/sendMessage",
|
|
json={
|
|
"chat_id": chat_id,
|
|
"text": text,
|
|
"parse_mode": "MarkdownV2",
|
|
},
|
|
timeout=10,
|
|
)
|
|
response.raise_for_status()
|
|
|
|
|
|
def _send_smtp(config: dict[str, Any], payload: dict[str, Any]) -> None:
    """Send one notification as an email through Django's mail API.

    Args:
        config: Channel config. ``recipient`` (a non-blank string, stripped
            before use) is required. ``from_email`` falls back to
            ``DEFAULT_SMTP_FROM_EMAIL`` when missing, non-string or blank.
            Keys listed in ``SMTP_CONNECTION_OPTION_KEYS`` are forwarded to
            ``get_connection`` unless they are ``None`` or an empty string.
        payload: Stored notification payload. ``title`` becomes the subject
            (default ``"Lotty Notification"``) and ``body`` the message body
            (default ``""``). A truthy ``experiment_name`` is appended.

    Raises:
        ValueError: if ``recipient`` is missing, non-string or blank.
        Exception: errors from building the connection or sending propagate,
            because both calls use ``fail_silently=False``.
    """
    recipient = config.get("recipient", "")
    sender_address = config.get("from_email", DEFAULT_SMTP_FROM_EMAIL)
    if not (isinstance(recipient, str) and recipient.strip()):
        raise ValueError("SMTP config requires 'recipient'.")
    recipient = recipient.strip()
    if not (isinstance(sender_address, str) and sender_address.strip()):
        sender_address = DEFAULT_SMTP_FROM_EMAIL

    backend_kwargs: dict[str, Any] = {}
    for option in SMTP_CONNECTION_OPTION_KEYS:
        option_value = config.get(option)
        # Skip unset options and empty strings; keep falsy non-strings such
        # as ``False`` or ``0``.
        if option_value is None or (isinstance(option_value, str) and not option_value):
            continue
        backend_kwargs[option] = option_value
    connection = get_connection(fail_silently=False, **backend_kwargs)

    subject = payload.get("title", "Lotty Notification")
    message_text = payload.get("body", "")
    experiment_name = payload.get("experiment_name")
    if experiment_name:
        message_text += f"\n\nExperiment: {experiment_name}"

    EmailMessage(
        subject=subject,
        body=message_text,
        from_email=sender_address,
        to=[recipient],
        connection=connection,
    ).send(fail_silently=False)
|
|
|
|
|
|
def flush_pending_notifications() -> dict[str, int]:
    """Attempt delivery of every PENDING NotificationLog, oldest first.

    Each processed log ends as SENT (with ``sent_at`` set) or FAILED (with
    ``error`` set). A log is failed without a send attempt when its channel
    is missing or inactive, or when the channel type has no registered
    sender. An exception raised by a sender, or by saving the SENT state, is
    logged and recorded, with the message truncated to 1000 characters.

    Returns:
        Counts ``{"sent": n, "failed": m}`` for this run.

    NOTE(review): rows are not claimed or locked before sending, so two
    concurrent flushes could both deliver the same log. Confirm only one
    runs at a time.
    """
    dispatch = {
        ChannelType.TELEGRAM: _send_telegram,
        ChannelType.SMTP: _send_smtp,
    }
    tally = {"sent": 0, "failed": 0}

    def mark_failed(entry: NotificationLog, reason: str) -> None:
        # Persist the FAILED state plus its reason and count it.
        entry.status = NotificationStatus.FAILED
        entry.error = reason
        entry.save(update_fields=["status", "error"])
        tally["failed"] += 1

    queue = (
        NotificationLog.objects.filter(status=NotificationStatus.PENDING)
        .select_related("channel")
        .order_by("created_at")
    )

    for entry in queue:
        target = entry.channel
        if not target or not target.is_active:
            mark_failed(entry, "Channel is inactive or missing.")
            continue

        deliver = dispatch.get(target.channel_type)
        if not deliver:
            mark_failed(entry, f"No sender for channel type '{target.channel_type}'.")
            continue

        try:
            deliver(target.config, entry.payload)
            entry.status = NotificationStatus.SENT
            entry.sent_at = timezone.now()
            entry.save(update_fields=["status", "sent_at"])
            tally["sent"] += 1
        except Exception as exc:
            logger.exception(
                "notification_send_failed",
                extra={
                    "log_id": str(entry.pk),
                    "channel": target.name,
                    "error": str(exc),
                },
            )
            mark_failed(entry, str(exc)[:1000])

    return tally
|