chore(notifications): fixed validation for notification providers

This commit is contained in:
ITQ
2026-02-24 18:02:38 +03:00
parent 2e974e6148
commit a1bc15bdac
2 changed files with 197 additions and 14 deletions
+96 -11
View File
@@ -3,7 +3,8 @@ from dataclasses import dataclass, field
from typing import Any from typing import Any
import requests import requests
from django.core.mail import send_mail from django.core.exceptions import ValidationError
from django.core.mail import EmailMessage, get_connection
from django.db import transaction from django.db import transaction
from django.db.models import Q, QuerySet from django.db.models import Q, QuerySet
from django.utils import timezone from django.utils import timezone
@@ -18,6 +19,21 @@ from apps.notifications.models import (
logger = logging.getLogger("lotty") logger = logging.getLogger("lotty")
REQUIRED_CHANNEL_CONFIG_FIELDS: dict[str, tuple[str, ...]] = {
ChannelType.TELEGRAM: ("bot_token", "chat_id"),
ChannelType.SMTP: ("recipient",),
}
SMTP_CONNECTION_OPTION_KEYS: tuple[str, ...] = (
"host",
"port",
"username",
"password",
"use_tls",
"use_ssl",
"timeout",
)
DEFAULT_SMTP_FROM_EMAIL = "lotty@lotty.local"
@dataclass(frozen=True) @dataclass(frozen=True)
class NotificationPayload: class NotificationPayload:
@@ -29,6 +45,38 @@ class NotificationPayload:
extra: dict[str, Any] = field(default_factory=dict) extra: dict[str, Any] = field(default_factory=dict)
def _required_str_config(
config: dict[str, Any],
field: str,
channel_type: str,
) -> str:
value = config.get(field)
if not isinstance(value, str) or not value.strip():
raise ValidationError(
{
"config": (
f"Channel '{channel_type}' requires non-empty '{field}'."
)
}
)
return value.strip()
def _validate_channel_config(
*,
channel_type: str,
config: dict[str, Any],
) -> None:
if not isinstance(config, dict):
raise ValidationError(
{"config": "Channel config must be a JSON object."}
)
for required_field in REQUIRED_CHANNEL_CONFIG_FIELDS.get(
channel_type, ()
):
_required_str_config(config, required_field, channel_type)
@transaction.atomic @transaction.atomic
def channel_create( def channel_create(
*, *,
@@ -36,10 +84,15 @@ def channel_create(
name: str, name: str,
config: dict[str, Any] | None = None, config: dict[str, Any] | None = None,
) -> NotificationChannel: ) -> NotificationChannel:
normalized_config = config or {}
_validate_channel_config(
channel_type=channel_type,
config=normalized_config,
)
channel = NotificationChannel( channel = NotificationChannel(
channel_type=channel_type, channel_type=channel_type,
name=name, name=name,
config=config or {}, config=normalized_config,
) )
channel.save() channel.save()
return channel return channel
@@ -54,6 +107,15 @@ def channel_update(
for key in fields: for key in fields:
if key not in allowed: if key not in allowed:
raise ValueError(f"Field '{key}' cannot be updated.") raise ValueError(f"Field '{key}' cannot be updated.")
next_config = (
fields["config"]
if "config" in fields and fields["config"] is not None
else None
)
_validate_channel_config(
channel_type=channel.channel_type,
config=channel.config if next_config is None else next_config,
)
for key, value in fields.items(): for key, value in fields.items():
if value is not None: if value is not None:
setattr(channel, key, value) setattr(channel, key, value)
@@ -237,28 +299,51 @@ def _send_telegram(config: dict[str, Any], payload: dict[str, Any]) -> None:
def _send_smtp(config: dict[str, Any], payload: dict[str, Any]) -> None: def _send_smtp(config: dict[str, Any], payload: dict[str, Any]) -> None:
recipient = config.get("recipient", "") recipient = config.get("recipient", "")
from_email = config.get("from_email", "lotty@lotty.local") from_email = config.get("from_email", DEFAULT_SMTP_FROM_EMAIL)
if not recipient: if not isinstance(recipient, str) or not recipient.strip():
raise ValueError("SMTP config requires 'recipient'.") raise ValueError("SMTP config requires 'recipient'.")
recipient = recipient.strip()
if not isinstance(from_email, str) or not from_email.strip():
from_email = DEFAULT_SMTP_FROM_EMAIL
connection_options = {
key: config.get(key) for key in SMTP_CONNECTION_OPTION_KEYS
}
runtime_options: dict[str, Any] = {}
for key, value in connection_options.items():
if value is None:
continue
if isinstance(value, str) and not value:
continue
runtime_options[key] = value
connection = get_connection(
fail_silently=False,
**runtime_options,
)
subject = payload.get("title", "Lotty Notification") subject = payload.get("title", "Lotty Notification")
body = payload.get("body", "") body = payload.get("body", "")
if payload.get("experiment_name"): if payload.get("experiment_name"):
body += f"\n\nExperiment: {payload['experiment_name']}" body += f"\n\nExperiment: {payload['experiment_name']}"
send_mail( email = EmailMessage(
subject=subject, subject=subject,
message=body, body=body,
from_email=from_email, from_email=from_email,
recipient_list=[recipient], to=[recipient],
fail_silently=False, connection=connection,
) )
email.send(fail_silently=False)
def flush_pending_notifications() -> dict[str, int]: def flush_pending_notifications() -> dict[str, int]:
pending = NotificationLog.objects.filter( pending = (
status=NotificationStatus.PENDING, NotificationLog.objects.filter(
).select_related("channel").order_by("created_at") status=NotificationStatus.PENDING,
)
.select_related("channel")
.order_by("created_at")
)
senders = { senders = {
ChannelType.TELEGRAM: _send_telegram, ChannelType.TELEGRAM: _send_telegram,
@@ -1,6 +1,7 @@
from typing import Any, override from typing import Any, override
from unittest.mock import patch from unittest.mock import patch
from django.core.exceptions import ValidationError
from django.test import TestCase from django.test import TestCase
from apps.experiments.tests.helpers import make_experiment from apps.experiments.tests.helpers import make_experiment
@@ -42,6 +43,7 @@ class ChannelCRUDTest(TestCase):
ch = channel_create( ch = channel_create(
channel_type=ChannelType.SMTP, channel_type=ChannelType.SMTP,
name="Old Name", name="Old Name",
config={"recipient": "ops@lotty.local"},
) )
ch = channel_update(channel=ch, name="New Name") ch = channel_update(channel=ch, name="New Name")
self.assertEqual(ch.name, "New Name") self.assertEqual(ch.name, "New Name")
@@ -50,6 +52,7 @@ class ChannelCRUDTest(TestCase):
ch = channel_create( ch = channel_create(
channel_type=ChannelType.SMTP, channel_type=ChannelType.SMTP,
name="X", name="X",
config={"recipient": "ops@lotty.local"},
) )
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
channel_update(channel=ch, channel_type="telegram") channel_update(channel=ch, channel_type="telegram")
@@ -58,16 +61,53 @@ class ChannelCRUDTest(TestCase):
ch = channel_create( ch = channel_create(
channel_type=ChannelType.TELEGRAM, channel_type=ChannelType.TELEGRAM,
name="Delete Me", name="Delete Me",
config={"bot_token": "tok", "chat_id": "123"},
) )
pk = ch.pk pk = ch.pk
channel_delete(channel=ch) channel_delete(channel=ch)
self.assertIsNone(channel_get(pk)) self.assertIsNone(channel_get(pk))
def test_list_channels(self) -> None: def test_list_channels(self) -> None:
channel_create(channel_type=ChannelType.TELEGRAM, name="A") channel_create(
channel_create(channel_type=ChannelType.SMTP, name="B") channel_type=ChannelType.TELEGRAM,
name="A",
config={"bot_token": "tok", "chat_id": "123"},
)
channel_create(
channel_type=ChannelType.SMTP,
name="B",
config={"recipient": "team@lotty.local"},
)
self.assertEqual(channel_list().count(), 2) self.assertEqual(channel_list().count(), 2)
def test_create_telegram_channel_requires_config(self) -> None:
with self.assertRaises(ValidationError):
channel_create(
channel_type=ChannelType.TELEGRAM,
name="Broken",
config={},
)
def test_create_smtp_channel_requires_recipient(self) -> None:
with self.assertRaises(ValidationError):
channel_create(
channel_type=ChannelType.SMTP,
name="Broken",
config={},
)
def test_update_channel_rejects_invalid_provider_config(self) -> None:
channel = channel_create(
channel_type=ChannelType.TELEGRAM,
name="Broken Update",
config={"bot_token": "tok", "chat_id": "123"},
)
with self.assertRaises(ValidationError):
channel_update(
channel=channel,
config={"bot_token": "tok"},
)
class RuleCRUDTest(TestCase): class RuleCRUDTest(TestCase):
@override @override
@@ -247,7 +287,10 @@ class NotificationEnqueueTest(TestCase):
NotificationEventType.EXPERIMENT_RESUMED, NotificationEventType.EXPERIMENT_RESUMED,
NotificationPayload( NotificationPayload(
title="Experiment Resumed", title="Experiment Resumed",
body=f"Experiment '{self.experiment.name}' - experiment resumed.", body=(
f"Experiment '{self.experiment.name}' -"
" experiment resumed."
),
event_type=NotificationEventType.EXPERIMENT_RESUMED, event_type=NotificationEventType.EXPERIMENT_RESUMED,
experiment_id=str(self.experiment.pk), experiment_id=str(self.experiment.pk),
experiment_name=self.experiment.name, experiment_name=self.experiment.name,
@@ -352,3 +395,58 @@ class FlushNotificationsTest(TestCase):
results = flush_pending_notifications() results = flush_pending_notifications()
self.assertEqual(results["sent"], 1) self.assertEqual(results["sent"], 1)
mock_send.assert_called_once() mock_send.assert_called_once()
class SMTPDeliveryTest(TestCase):
@patch("apps.notifications.services.EmailMessage")
@patch("apps.notifications.services.get_connection")
def test_send_smtp_uses_runtime_connection_options(
self,
mock_get_connection: Any,
mock_email_cls: Any,
) -> None:
channel = channel_create(
channel_type=ChannelType.SMTP,
name="SMTP Runtime",
config={
"recipient": "team@lotty.local",
"from_email": "alerts@lotty.local",
"host": "smtp.lotty.local",
"port": 2525,
"username": "mailer",
"password": "secret",
"use_tls": True,
"timeout": 7,
},
)
rule = rule_create(
event_type=NotificationEventType.EXPERIMENT_COMPLETED,
channel=channel,
)
email_instance = mock_email_cls.return_value
NotificationLog.objects.create(
rule=rule,
channel=channel,
event_type=NotificationEventType.EXPERIMENT_COMPLETED,
event_key="smtp:override:1",
payload={"title": "Done", "body": "Completed"},
status=NotificationStatus.PENDING,
)
results = flush_pending_notifications()
self.assertEqual(results["sent"], 1)
mock_get_connection.assert_called_once_with(
fail_silently=False,
host="smtp.lotty.local",
port=2525,
username="mailer",
password="secret",
use_tls=True,
timeout=7,
)
mock_email_cls.assert_called_once()
_, kwargs = mock_email_cls.call_args
self.assertEqual(kwargs["from_email"], "alerts@lotty.local")
self.assertEqual(kwargs["to"], ["team@lotty.local"])
email_instance.send.assert_called_once_with(fail_silently=False)