From a1bc15bdac6068f46ba5ea6281fc51c56f3387ec Mon Sep 17 00:00:00 2001 From: ITQ Date: Tue, 24 Feb 2026 18:02:38 +0300 Subject: [PATCH] chore(notifications): fixed validation for notification providers --- src/backend/apps/notifications/services.py | 107 ++++++++++++++++-- .../notifications/tests/test_notifications.py | 104 ++++++++++++++++- 2 files changed, 197 insertions(+), 14 deletions(-) diff --git a/src/backend/apps/notifications/services.py b/src/backend/apps/notifications/services.py index f33ba1c..a71dd84 100644 --- a/src/backend/apps/notifications/services.py +++ b/src/backend/apps/notifications/services.py @@ -3,7 +3,8 @@ from dataclasses import dataclass, field from typing import Any import requests -from django.core.mail import send_mail +from django.core.exceptions import ValidationError +from django.core.mail import EmailMessage, get_connection from django.db import transaction from django.db.models import Q, QuerySet from django.utils import timezone @@ -18,6 +19,21 @@ from apps.notifications.models import ( logger = logging.getLogger("lotty") +REQUIRED_CHANNEL_CONFIG_FIELDS: dict[str, tuple[str, ...]] = { + ChannelType.TELEGRAM: ("bot_token", "chat_id"), + ChannelType.SMTP: ("recipient",), +} +SMTP_CONNECTION_OPTION_KEYS: tuple[str, ...] = ( + "host", + "port", + "username", + "password", + "use_tls", + "use_ssl", + "timeout", +) +DEFAULT_SMTP_FROM_EMAIL = "lotty@lotty.local" + @dataclass(frozen=True) class NotificationPayload: @@ -29,6 +45,38 @@ class NotificationPayload: extra: dict[str, Any] = field(default_factory=dict) +def _required_str_config( + config: dict[str, Any], + field: str, + channel_type: str, +) -> str: + value = config.get(field) + if not isinstance(value, str) or not value.strip(): + raise ValidationError( + { + "config": ( + f"Channel '{channel_type}' requires non-empty '{field}'." + ) + } + ) + return value.strip() + + +def _validate_channel_config( + *, + channel_type: str, + config: dict[str, Any], +) -> None: + if not isinstance(config, dict): + raise ValidationError( + {"config": "Channel config must be a JSON object."} + ) + for required_field in REQUIRED_CHANNEL_CONFIG_FIELDS.get( + channel_type, () + ): + _required_str_config(config, required_field, channel_type) + + @transaction.atomic def channel_create( *, @@ -36,10 +84,15 @@ def channel_create( name: str, config: dict[str, Any] | None = None, ) -> NotificationChannel: + normalized_config = config or {} + _validate_channel_config( + channel_type=channel_type, + config=normalized_config, + ) channel = NotificationChannel( channel_type=channel_type, name=name, - config=config or {}, + config=normalized_config, ) channel.save() return channel @@ -54,6 +107,15 @@ def channel_update( for key in fields: if key not in allowed: raise ValueError(f"Field '{key}' cannot be updated.") + next_config = ( + fields["config"] + if "config" in fields and fields["config"] is not None + else None + ) + _validate_channel_config( + channel_type=channel.channel_type, + config=channel.config if next_config is None else next_config, + ) for key, value in fields.items(): if value is not None: setattr(channel, key, value) @@ -237,28 +299,51 @@ def _send_telegram(config: dict[str, Any], payload: dict[str, Any]) -> None: def _send_smtp(config: dict[str, Any], payload: dict[str, Any]) -> None: recipient = config.get("recipient", "") - from_email = config.get("from_email", "lotty@lotty.local") - if not recipient: + from_email = config.get("from_email", DEFAULT_SMTP_FROM_EMAIL) + if not isinstance(recipient, str) or not recipient.strip(): raise ValueError("SMTP config requires 'recipient'.") + recipient = recipient.strip() + if not isinstance(from_email, str) or not from_email.strip(): + from_email = DEFAULT_SMTP_FROM_EMAIL + + connection_options = { + key: config.get(key) for key in SMTP_CONNECTION_OPTION_KEYS + } + runtime_options: dict[str, Any] = {} + for key, value in connection_options.items(): + if value is None: + continue + if isinstance(value, str) and not value: + continue + runtime_options[key] = value + connection = get_connection( + fail_silently=False, + **runtime_options, + ) subject = payload.get("title", "Lotty Notification") body = payload.get("body", "") if payload.get("experiment_name"): body += f"\n\nExperiment: {payload['experiment_name']}" - send_mail( + email = EmailMessage( subject=subject, - message=body, + body=body, from_email=from_email, - recipient_list=[recipient], - fail_silently=False, + to=[recipient], + connection=connection, ) + email.send(fail_silently=False) def flush_pending_notifications() -> dict[str, int]: - pending = NotificationLog.objects.filter( - status=NotificationStatus.PENDING, - ).select_related("channel").order_by("created_at") + pending = ( + NotificationLog.objects.filter( + status=NotificationStatus.PENDING, + ) + .select_related("channel") + .order_by("created_at") + ) senders = { ChannelType.TELEGRAM: _send_telegram, diff --git a/src/backend/apps/notifications/tests/test_notifications.py b/src/backend/apps/notifications/tests/test_notifications.py index 34bb284..c388f72 100644 --- a/src/backend/apps/notifications/tests/test_notifications.py +++ b/src/backend/apps/notifications/tests/test_notifications.py @@ -1,6 +1,7 @@ from typing import Any, override from unittest.mock import patch +from django.core.exceptions import ValidationError from django.test import TestCase from apps.experiments.tests.helpers import make_experiment @@ -42,6 +43,7 @@ class ChannelCRUDTest(TestCase): ch = channel_create( channel_type=ChannelType.SMTP, name="Old Name", + config={"recipient": "ops@lotty.local"}, ) ch = channel_update(channel=ch, name="New Name") self.assertEqual(ch.name, "New Name") @@ -50,6 +52,7 @@ class ChannelCRUDTest(TestCase): ch = channel_create( channel_type=ChannelType.SMTP, name="X", + config={"recipient": "ops@lotty.local"}, ) with self.assertRaises(ValueError): channel_update(channel=ch, channel_type="telegram") @@ -58,16 +61,53 @@ class ChannelCRUDTest(TestCase): ch = channel_create( channel_type=ChannelType.TELEGRAM, name="Delete Me", + config={"bot_token": "tok", "chat_id": "123"}, ) pk = ch.pk channel_delete(channel=ch) self.assertIsNone(channel_get(pk)) def test_list_channels(self) -> None: - channel_create(channel_type=ChannelType.TELEGRAM, name="A") - channel_create(channel_type=ChannelType.SMTP, name="B") + channel_create( + channel_type=ChannelType.TELEGRAM, + name="A", + config={"bot_token": "tok", "chat_id": "123"}, + ) + channel_create( + channel_type=ChannelType.SMTP, + name="B", + config={"recipient": "team@lotty.local"}, + ) self.assertEqual(channel_list().count(), 2) + def test_create_telegram_channel_requires_config(self) -> None: + with self.assertRaises(ValidationError): + channel_create( + channel_type=ChannelType.TELEGRAM, + name="Broken", + config={}, + ) + + def test_create_smtp_channel_requires_recipient(self) -> None: + with self.assertRaises(ValidationError): + channel_create( + channel_type=ChannelType.SMTP, + name="Broken", + config={}, + ) + + def test_update_channel_rejects_invalid_provider_config(self) -> None: + channel = channel_create( + channel_type=ChannelType.TELEGRAM, + name="Broken Update", + config={"bot_token": "tok", "chat_id": "123"}, + ) + with self.assertRaises(ValidationError): + channel_update( + channel=channel, + config={"bot_token": "tok"}, + ) + class RuleCRUDTest(TestCase): @override @@ -247,7 +287,10 @@ class NotificationEnqueueTest(TestCase): NotificationEventType.EXPERIMENT_RESUMED, NotificationPayload( title="Experiment Resumed", - body=f"Experiment '{self.experiment.name}' - experiment resumed.", + body=( + f"Experiment '{self.experiment.name}' -" + " experiment resumed." + ), event_type=NotificationEventType.EXPERIMENT_RESUMED, experiment_id=str(self.experiment.pk), experiment_name=self.experiment.name, @@ -352,3 +395,58 @@ class FlushNotificationsTest(TestCase): results = flush_pending_notifications() self.assertEqual(results["sent"], 1) mock_send.assert_called_once() + + +class SMTPDeliveryTest(TestCase): + @patch("apps.notifications.services.EmailMessage") + @patch("apps.notifications.services.get_connection") + def test_send_smtp_uses_runtime_connection_options( + self, + mock_get_connection: Any, + mock_email_cls: Any, + ) -> None: + channel = channel_create( + channel_type=ChannelType.SMTP, + name="SMTP Runtime", + config={ + "recipient": "team@lotty.local", + "from_email": "alerts@lotty.local", + "host": "smtp.lotty.local", + "port": 2525, + "username": "mailer", + "password": "secret", + "use_tls": True, + "timeout": 7, + }, + ) + rule = rule_create( + event_type=NotificationEventType.EXPERIMENT_COMPLETED, + channel=channel, + ) + email_instance = mock_email_cls.return_value + NotificationLog.objects.create( + rule=rule, + channel=channel, + event_type=NotificationEventType.EXPERIMENT_COMPLETED, + event_key="smtp:override:1", + payload={"title": "Done", "body": "Completed"}, + status=NotificationStatus.PENDING, + ) + + results = flush_pending_notifications() + + self.assertEqual(results["sent"], 1) + mock_get_connection.assert_called_once_with( + fail_silently=False, + host="smtp.lotty.local", + port=2525, + username="mailer", + password="secret", + use_tls=True, + timeout=7, + ) + mock_email_cls.assert_called_once() + _, kwargs = mock_email_cls.call_args + self.assertEqual(kwargs["from_email"], "alerts@lotty.local") + self.assertEqual(kwargs["to"], ["team@lotty.local"]) + email_instance.send.assert_called_once_with(fail_silently=False)