chore(notifications): fixed validation for notification providers
This commit is contained in:
@@ -3,7 +3,8 @@ from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from django.core.mail import send_mail
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.core.mail import EmailMessage, get_connection
|
||||
from django.db import transaction
|
||||
from django.db.models import Q, QuerySet
|
||||
from django.utils import timezone
|
||||
@@ -18,6 +19,21 @@ from apps.notifications.models import (
|
||||
|
||||
logger = logging.getLogger("lotty")
|
||||
|
||||
REQUIRED_CHANNEL_CONFIG_FIELDS: dict[str, tuple[str, ...]] = {
|
||||
ChannelType.TELEGRAM: ("bot_token", "chat_id"),
|
||||
ChannelType.SMTP: ("recipient",),
|
||||
}
|
||||
SMTP_CONNECTION_OPTION_KEYS: tuple[str, ...] = (
|
||||
"host",
|
||||
"port",
|
||||
"username",
|
||||
"password",
|
||||
"use_tls",
|
||||
"use_ssl",
|
||||
"timeout",
|
||||
)
|
||||
DEFAULT_SMTP_FROM_EMAIL = "lotty@lotty.local"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class NotificationPayload:
|
||||
@@ -29,6 +45,38 @@ class NotificationPayload:
|
||||
extra: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
def _required_str_config(
|
||||
config: dict[str, Any],
|
||||
field: str,
|
||||
channel_type: str,
|
||||
) -> str:
|
||||
value = config.get(field)
|
||||
if not isinstance(value, str) or not value.strip():
|
||||
raise ValidationError(
|
||||
{
|
||||
"config": (
|
||||
f"Channel '{channel_type}' requires non-empty '{field}'."
|
||||
)
|
||||
}
|
||||
)
|
||||
return value.strip()
|
||||
|
||||
|
||||
def _validate_channel_config(
|
||||
*,
|
||||
channel_type: str,
|
||||
config: dict[str, Any],
|
||||
) -> None:
|
||||
if not isinstance(config, dict):
|
||||
raise ValidationError(
|
||||
{"config": "Channel config must be a JSON object."}
|
||||
)
|
||||
for required_field in REQUIRED_CHANNEL_CONFIG_FIELDS.get(
|
||||
channel_type, ()
|
||||
):
|
||||
_required_str_config(config, required_field, channel_type)
|
||||
|
||||
|
||||
@transaction.atomic
|
||||
def channel_create(
|
||||
*,
|
||||
@@ -36,10 +84,15 @@ def channel_create(
|
||||
name: str,
|
||||
config: dict[str, Any] | None = None,
|
||||
) -> NotificationChannel:
|
||||
normalized_config = config or {}
|
||||
_validate_channel_config(
|
||||
channel_type=channel_type,
|
||||
config=normalized_config,
|
||||
)
|
||||
channel = NotificationChannel(
|
||||
channel_type=channel_type,
|
||||
name=name,
|
||||
config=config or {},
|
||||
config=normalized_config,
|
||||
)
|
||||
channel.save()
|
||||
return channel
|
||||
@@ -54,6 +107,15 @@ def channel_update(
|
||||
for key in fields:
|
||||
if key not in allowed:
|
||||
raise ValueError(f"Field '{key}' cannot be updated.")
|
||||
next_config = (
|
||||
fields["config"]
|
||||
if "config" in fields and fields["config"] is not None
|
||||
else None
|
||||
)
|
||||
_validate_channel_config(
|
||||
channel_type=channel.channel_type,
|
||||
config=channel.config if next_config is None else next_config,
|
||||
)
|
||||
for key, value in fields.items():
|
||||
if value is not None:
|
||||
setattr(channel, key, value)
|
||||
@@ -237,28 +299,51 @@ def _send_telegram(config: dict[str, Any], payload: dict[str, Any]) -> None:
|
||||
|
||||
def _send_smtp(config: dict[str, Any], payload: dict[str, Any]) -> None:
|
||||
recipient = config.get("recipient", "")
|
||||
from_email = config.get("from_email", "lotty@lotty.local")
|
||||
if not recipient:
|
||||
from_email = config.get("from_email", DEFAULT_SMTP_FROM_EMAIL)
|
||||
if not isinstance(recipient, str) or not recipient.strip():
|
||||
raise ValueError("SMTP config requires 'recipient'.")
|
||||
recipient = recipient.strip()
|
||||
if not isinstance(from_email, str) or not from_email.strip():
|
||||
from_email = DEFAULT_SMTP_FROM_EMAIL
|
||||
|
||||
connection_options = {
|
||||
key: config.get(key) for key in SMTP_CONNECTION_OPTION_KEYS
|
||||
}
|
||||
runtime_options: dict[str, Any] = {}
|
||||
for key, value in connection_options.items():
|
||||
if value is None:
|
||||
continue
|
||||
if isinstance(value, str) and not value:
|
||||
continue
|
||||
runtime_options[key] = value
|
||||
connection = get_connection(
|
||||
fail_silently=False,
|
||||
**runtime_options,
|
||||
)
|
||||
|
||||
subject = payload.get("title", "Lotty Notification")
|
||||
body = payload.get("body", "")
|
||||
if payload.get("experiment_name"):
|
||||
body += f"\n\nExperiment: {payload['experiment_name']}"
|
||||
|
||||
send_mail(
|
||||
email = EmailMessage(
|
||||
subject=subject,
|
||||
message=body,
|
||||
body=body,
|
||||
from_email=from_email,
|
||||
recipient_list=[recipient],
|
||||
fail_silently=False,
|
||||
to=[recipient],
|
||||
connection=connection,
|
||||
)
|
||||
email.send(fail_silently=False)
|
||||
|
||||
|
||||
def flush_pending_notifications() -> dict[str, int]:
|
||||
pending = NotificationLog.objects.filter(
|
||||
status=NotificationStatus.PENDING,
|
||||
).select_related("channel").order_by("created_at")
|
||||
pending = (
|
||||
NotificationLog.objects.filter(
|
||||
status=NotificationStatus.PENDING,
|
||||
)
|
||||
.select_related("channel")
|
||||
.order_by("created_at")
|
||||
)
|
||||
|
||||
senders = {
|
||||
ChannelType.TELEGRAM: _send_telegram,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import Any, override
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.test import TestCase
|
||||
|
||||
from apps.experiments.tests.helpers import make_experiment
|
||||
@@ -42,6 +43,7 @@ class ChannelCRUDTest(TestCase):
|
||||
ch = channel_create(
|
||||
channel_type=ChannelType.SMTP,
|
||||
name="Old Name",
|
||||
config={"recipient": "ops@lotty.local"},
|
||||
)
|
||||
ch = channel_update(channel=ch, name="New Name")
|
||||
self.assertEqual(ch.name, "New Name")
|
||||
@@ -50,6 +52,7 @@ class ChannelCRUDTest(TestCase):
|
||||
ch = channel_create(
|
||||
channel_type=ChannelType.SMTP,
|
||||
name="X",
|
||||
config={"recipient": "ops@lotty.local"},
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
channel_update(channel=ch, channel_type="telegram")
|
||||
@@ -58,16 +61,53 @@ class ChannelCRUDTest(TestCase):
|
||||
ch = channel_create(
|
||||
channel_type=ChannelType.TELEGRAM,
|
||||
name="Delete Me",
|
||||
config={"bot_token": "tok", "chat_id": "123"},
|
||||
)
|
||||
pk = ch.pk
|
||||
channel_delete(channel=ch)
|
||||
self.assertIsNone(channel_get(pk))
|
||||
|
||||
def test_list_channels(self) -> None:
|
||||
channel_create(channel_type=ChannelType.TELEGRAM, name="A")
|
||||
channel_create(channel_type=ChannelType.SMTP, name="B")
|
||||
channel_create(
|
||||
channel_type=ChannelType.TELEGRAM,
|
||||
name="A",
|
||||
config={"bot_token": "tok", "chat_id": "123"},
|
||||
)
|
||||
channel_create(
|
||||
channel_type=ChannelType.SMTP,
|
||||
name="B",
|
||||
config={"recipient": "team@lotty.local"},
|
||||
)
|
||||
self.assertEqual(channel_list().count(), 2)
|
||||
|
||||
def test_create_telegram_channel_requires_config(self) -> None:
|
||||
with self.assertRaises(ValidationError):
|
||||
channel_create(
|
||||
channel_type=ChannelType.TELEGRAM,
|
||||
name="Broken",
|
||||
config={},
|
||||
)
|
||||
|
||||
def test_create_smtp_channel_requires_recipient(self) -> None:
|
||||
with self.assertRaises(ValidationError):
|
||||
channel_create(
|
||||
channel_type=ChannelType.SMTP,
|
||||
name="Broken",
|
||||
config={},
|
||||
)
|
||||
|
||||
def test_update_channel_rejects_invalid_provider_config(self) -> None:
|
||||
channel = channel_create(
|
||||
channel_type=ChannelType.TELEGRAM,
|
||||
name="Broken Update",
|
||||
config={"bot_token": "tok", "chat_id": "123"},
|
||||
)
|
||||
with self.assertRaises(ValidationError):
|
||||
channel_update(
|
||||
channel=channel,
|
||||
config={"bot_token": "tok"},
|
||||
)
|
||||
|
||||
|
||||
class RuleCRUDTest(TestCase):
|
||||
@override
|
||||
@@ -247,7 +287,10 @@ class NotificationEnqueueTest(TestCase):
|
||||
NotificationEventType.EXPERIMENT_RESUMED,
|
||||
NotificationPayload(
|
||||
title="Experiment Resumed",
|
||||
body=f"Experiment '{self.experiment.name}' - experiment resumed.",
|
||||
body=(
|
||||
f"Experiment '{self.experiment.name}' -"
|
||||
" experiment resumed."
|
||||
),
|
||||
event_type=NotificationEventType.EXPERIMENT_RESUMED,
|
||||
experiment_id=str(self.experiment.pk),
|
||||
experiment_name=self.experiment.name,
|
||||
@@ -352,3 +395,58 @@ class FlushNotificationsTest(TestCase):
|
||||
results = flush_pending_notifications()
|
||||
self.assertEqual(results["sent"], 1)
|
||||
mock_send.assert_called_once()
|
||||
|
||||
|
||||
class SMTPDeliveryTest(TestCase):
|
||||
@patch("apps.notifications.services.EmailMessage")
|
||||
@patch("apps.notifications.services.get_connection")
|
||||
def test_send_smtp_uses_runtime_connection_options(
|
||||
self,
|
||||
mock_get_connection: Any,
|
||||
mock_email_cls: Any,
|
||||
) -> None:
|
||||
channel = channel_create(
|
||||
channel_type=ChannelType.SMTP,
|
||||
name="SMTP Runtime",
|
||||
config={
|
||||
"recipient": "team@lotty.local",
|
||||
"from_email": "alerts@lotty.local",
|
||||
"host": "smtp.lotty.local",
|
||||
"port": 2525,
|
||||
"username": "mailer",
|
||||
"password": "secret",
|
||||
"use_tls": True,
|
||||
"timeout": 7,
|
||||
},
|
||||
)
|
||||
rule = rule_create(
|
||||
event_type=NotificationEventType.EXPERIMENT_COMPLETED,
|
||||
channel=channel,
|
||||
)
|
||||
email_instance = mock_email_cls.return_value
|
||||
NotificationLog.objects.create(
|
||||
rule=rule,
|
||||
channel=channel,
|
||||
event_type=NotificationEventType.EXPERIMENT_COMPLETED,
|
||||
event_key="smtp:override:1",
|
||||
payload={"title": "Done", "body": "Completed"},
|
||||
status=NotificationStatus.PENDING,
|
||||
)
|
||||
|
||||
results = flush_pending_notifications()
|
||||
|
||||
self.assertEqual(results["sent"], 1)
|
||||
mock_get_connection.assert_called_once_with(
|
||||
fail_silently=False,
|
||||
host="smtp.lotty.local",
|
||||
port=2525,
|
||||
username="mailer",
|
||||
password="secret",
|
||||
use_tls=True,
|
||||
timeout=7,
|
||||
)
|
||||
mock_email_cls.assert_called_once()
|
||||
_, kwargs = mock_email_cls.call_args
|
||||
self.assertEqual(kwargs["from_email"], "alerts@lotty.local")
|
||||
self.assertEqual(kwargs["to"], ["team@lotty.local"])
|
||||
email_instance.send.assert_called_once_with(fail_silently=False)
|
||||
|
||||
Reference in New Issue
Block a user