chore(notifications): fixed validation for notification providers
This commit is contained in:
@@ -3,7 +3,8 @@ from dataclasses import dataclass, field
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from django.core.mail import send_mail
|
from django.core.exceptions import ValidationError
|
||||||
|
from django.core.mail import EmailMessage, get_connection
|
||||||
from django.db import transaction
|
from django.db import transaction
|
||||||
from django.db.models import Q, QuerySet
|
from django.db.models import Q, QuerySet
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
@@ -18,6 +19,21 @@ from apps.notifications.models import (
|
|||||||
|
|
||||||
logger = logging.getLogger("lotty")
|
logger = logging.getLogger("lotty")
|
||||||
|
|
||||||
|
REQUIRED_CHANNEL_CONFIG_FIELDS: dict[str, tuple[str, ...]] = {
|
||||||
|
ChannelType.TELEGRAM: ("bot_token", "chat_id"),
|
||||||
|
ChannelType.SMTP: ("recipient",),
|
||||||
|
}
|
||||||
|
SMTP_CONNECTION_OPTION_KEYS: tuple[str, ...] = (
|
||||||
|
"host",
|
||||||
|
"port",
|
||||||
|
"username",
|
||||||
|
"password",
|
||||||
|
"use_tls",
|
||||||
|
"use_ssl",
|
||||||
|
"timeout",
|
||||||
|
)
|
||||||
|
DEFAULT_SMTP_FROM_EMAIL = "lotty@lotty.local"
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class NotificationPayload:
|
class NotificationPayload:
|
||||||
@@ -29,6 +45,38 @@ class NotificationPayload:
|
|||||||
extra: dict[str, Any] = field(default_factory=dict)
|
extra: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
def _required_str_config(
|
||||||
|
config: dict[str, Any],
|
||||||
|
field: str,
|
||||||
|
channel_type: str,
|
||||||
|
) -> str:
|
||||||
|
value = config.get(field)
|
||||||
|
if not isinstance(value, str) or not value.strip():
|
||||||
|
raise ValidationError(
|
||||||
|
{
|
||||||
|
"config": (
|
||||||
|
f"Channel '{channel_type}' requires non-empty '{field}'."
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return value.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_channel_config(
|
||||||
|
*,
|
||||||
|
channel_type: str,
|
||||||
|
config: dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
if not isinstance(config, dict):
|
||||||
|
raise ValidationError(
|
||||||
|
{"config": "Channel config must be a JSON object."}
|
||||||
|
)
|
||||||
|
for required_field in REQUIRED_CHANNEL_CONFIG_FIELDS.get(
|
||||||
|
channel_type, ()
|
||||||
|
):
|
||||||
|
_required_str_config(config, required_field, channel_type)
|
||||||
|
|
||||||
|
|
||||||
@transaction.atomic
|
@transaction.atomic
|
||||||
def channel_create(
|
def channel_create(
|
||||||
*,
|
*,
|
||||||
@@ -36,10 +84,15 @@ def channel_create(
|
|||||||
name: str,
|
name: str,
|
||||||
config: dict[str, Any] | None = None,
|
config: dict[str, Any] | None = None,
|
||||||
) -> NotificationChannel:
|
) -> NotificationChannel:
|
||||||
|
normalized_config = config or {}
|
||||||
|
_validate_channel_config(
|
||||||
|
channel_type=channel_type,
|
||||||
|
config=normalized_config,
|
||||||
|
)
|
||||||
channel = NotificationChannel(
|
channel = NotificationChannel(
|
||||||
channel_type=channel_type,
|
channel_type=channel_type,
|
||||||
name=name,
|
name=name,
|
||||||
config=config or {},
|
config=normalized_config,
|
||||||
)
|
)
|
||||||
channel.save()
|
channel.save()
|
||||||
return channel
|
return channel
|
||||||
@@ -54,6 +107,15 @@ def channel_update(
|
|||||||
for key in fields:
|
for key in fields:
|
||||||
if key not in allowed:
|
if key not in allowed:
|
||||||
raise ValueError(f"Field '{key}' cannot be updated.")
|
raise ValueError(f"Field '{key}' cannot be updated.")
|
||||||
|
next_config = (
|
||||||
|
fields["config"]
|
||||||
|
if "config" in fields and fields["config"] is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
_validate_channel_config(
|
||||||
|
channel_type=channel.channel_type,
|
||||||
|
config=channel.config if next_config is None else next_config,
|
||||||
|
)
|
||||||
for key, value in fields.items():
|
for key, value in fields.items():
|
||||||
if value is not None:
|
if value is not None:
|
||||||
setattr(channel, key, value)
|
setattr(channel, key, value)
|
||||||
@@ -237,28 +299,51 @@ def _send_telegram(config: dict[str, Any], payload: dict[str, Any]) -> None:
|
|||||||
|
|
||||||
def _send_smtp(config: dict[str, Any], payload: dict[str, Any]) -> None:
|
def _send_smtp(config: dict[str, Any], payload: dict[str, Any]) -> None:
|
||||||
recipient = config.get("recipient", "")
|
recipient = config.get("recipient", "")
|
||||||
from_email = config.get("from_email", "lotty@lotty.local")
|
from_email = config.get("from_email", DEFAULT_SMTP_FROM_EMAIL)
|
||||||
if not recipient:
|
if not isinstance(recipient, str) or not recipient.strip():
|
||||||
raise ValueError("SMTP config requires 'recipient'.")
|
raise ValueError("SMTP config requires 'recipient'.")
|
||||||
|
recipient = recipient.strip()
|
||||||
|
if not isinstance(from_email, str) or not from_email.strip():
|
||||||
|
from_email = DEFAULT_SMTP_FROM_EMAIL
|
||||||
|
|
||||||
|
connection_options = {
|
||||||
|
key: config.get(key) for key in SMTP_CONNECTION_OPTION_KEYS
|
||||||
|
}
|
||||||
|
runtime_options: dict[str, Any] = {}
|
||||||
|
for key, value in connection_options.items():
|
||||||
|
if value is None:
|
||||||
|
continue
|
||||||
|
if isinstance(value, str) and not value:
|
||||||
|
continue
|
||||||
|
runtime_options[key] = value
|
||||||
|
connection = get_connection(
|
||||||
|
fail_silently=False,
|
||||||
|
**runtime_options,
|
||||||
|
)
|
||||||
|
|
||||||
subject = payload.get("title", "Lotty Notification")
|
subject = payload.get("title", "Lotty Notification")
|
||||||
body = payload.get("body", "")
|
body = payload.get("body", "")
|
||||||
if payload.get("experiment_name"):
|
if payload.get("experiment_name"):
|
||||||
body += f"\n\nExperiment: {payload['experiment_name']}"
|
body += f"\n\nExperiment: {payload['experiment_name']}"
|
||||||
|
|
||||||
send_mail(
|
email = EmailMessage(
|
||||||
subject=subject,
|
subject=subject,
|
||||||
message=body,
|
body=body,
|
||||||
from_email=from_email,
|
from_email=from_email,
|
||||||
recipient_list=[recipient],
|
to=[recipient],
|
||||||
fail_silently=False,
|
connection=connection,
|
||||||
)
|
)
|
||||||
|
email.send(fail_silently=False)
|
||||||
|
|
||||||
|
|
||||||
def flush_pending_notifications() -> dict[str, int]:
|
def flush_pending_notifications() -> dict[str, int]:
|
||||||
pending = NotificationLog.objects.filter(
|
pending = (
|
||||||
status=NotificationStatus.PENDING,
|
NotificationLog.objects.filter(
|
||||||
).select_related("channel").order_by("created_at")
|
status=NotificationStatus.PENDING,
|
||||||
|
)
|
||||||
|
.select_related("channel")
|
||||||
|
.order_by("created_at")
|
||||||
|
)
|
||||||
|
|
||||||
senders = {
|
senders = {
|
||||||
ChannelType.TELEGRAM: _send_telegram,
|
ChannelType.TELEGRAM: _send_telegram,
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from typing import Any, override
|
from typing import Any, override
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from django.core.exceptions import ValidationError
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
|
||||||
from apps.experiments.tests.helpers import make_experiment
|
from apps.experiments.tests.helpers import make_experiment
|
||||||
@@ -42,6 +43,7 @@ class ChannelCRUDTest(TestCase):
|
|||||||
ch = channel_create(
|
ch = channel_create(
|
||||||
channel_type=ChannelType.SMTP,
|
channel_type=ChannelType.SMTP,
|
||||||
name="Old Name",
|
name="Old Name",
|
||||||
|
config={"recipient": "ops@lotty.local"},
|
||||||
)
|
)
|
||||||
ch = channel_update(channel=ch, name="New Name")
|
ch = channel_update(channel=ch, name="New Name")
|
||||||
self.assertEqual(ch.name, "New Name")
|
self.assertEqual(ch.name, "New Name")
|
||||||
@@ -50,6 +52,7 @@ class ChannelCRUDTest(TestCase):
|
|||||||
ch = channel_create(
|
ch = channel_create(
|
||||||
channel_type=ChannelType.SMTP,
|
channel_type=ChannelType.SMTP,
|
||||||
name="X",
|
name="X",
|
||||||
|
config={"recipient": "ops@lotty.local"},
|
||||||
)
|
)
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
channel_update(channel=ch, channel_type="telegram")
|
channel_update(channel=ch, channel_type="telegram")
|
||||||
@@ -58,16 +61,53 @@ class ChannelCRUDTest(TestCase):
|
|||||||
ch = channel_create(
|
ch = channel_create(
|
||||||
channel_type=ChannelType.TELEGRAM,
|
channel_type=ChannelType.TELEGRAM,
|
||||||
name="Delete Me",
|
name="Delete Me",
|
||||||
|
config={"bot_token": "tok", "chat_id": "123"},
|
||||||
)
|
)
|
||||||
pk = ch.pk
|
pk = ch.pk
|
||||||
channel_delete(channel=ch)
|
channel_delete(channel=ch)
|
||||||
self.assertIsNone(channel_get(pk))
|
self.assertIsNone(channel_get(pk))
|
||||||
|
|
||||||
def test_list_channels(self) -> None:
|
def test_list_channels(self) -> None:
|
||||||
channel_create(channel_type=ChannelType.TELEGRAM, name="A")
|
channel_create(
|
||||||
channel_create(channel_type=ChannelType.SMTP, name="B")
|
channel_type=ChannelType.TELEGRAM,
|
||||||
|
name="A",
|
||||||
|
config={"bot_token": "tok", "chat_id": "123"},
|
||||||
|
)
|
||||||
|
channel_create(
|
||||||
|
channel_type=ChannelType.SMTP,
|
||||||
|
name="B",
|
||||||
|
config={"recipient": "team@lotty.local"},
|
||||||
|
)
|
||||||
self.assertEqual(channel_list().count(), 2)
|
self.assertEqual(channel_list().count(), 2)
|
||||||
|
|
||||||
|
def test_create_telegram_channel_requires_config(self) -> None:
|
||||||
|
with self.assertRaises(ValidationError):
|
||||||
|
channel_create(
|
||||||
|
channel_type=ChannelType.TELEGRAM,
|
||||||
|
name="Broken",
|
||||||
|
config={},
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_create_smtp_channel_requires_recipient(self) -> None:
|
||||||
|
with self.assertRaises(ValidationError):
|
||||||
|
channel_create(
|
||||||
|
channel_type=ChannelType.SMTP,
|
||||||
|
name="Broken",
|
||||||
|
config={},
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_update_channel_rejects_invalid_provider_config(self) -> None:
|
||||||
|
channel = channel_create(
|
||||||
|
channel_type=ChannelType.TELEGRAM,
|
||||||
|
name="Broken Update",
|
||||||
|
config={"bot_token": "tok", "chat_id": "123"},
|
||||||
|
)
|
||||||
|
with self.assertRaises(ValidationError):
|
||||||
|
channel_update(
|
||||||
|
channel=channel,
|
||||||
|
config={"bot_token": "tok"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RuleCRUDTest(TestCase):
|
class RuleCRUDTest(TestCase):
|
||||||
@override
|
@override
|
||||||
@@ -247,7 +287,10 @@ class NotificationEnqueueTest(TestCase):
|
|||||||
NotificationEventType.EXPERIMENT_RESUMED,
|
NotificationEventType.EXPERIMENT_RESUMED,
|
||||||
NotificationPayload(
|
NotificationPayload(
|
||||||
title="Experiment Resumed",
|
title="Experiment Resumed",
|
||||||
body=f"Experiment '{self.experiment.name}' - experiment resumed.",
|
body=(
|
||||||
|
f"Experiment '{self.experiment.name}' -"
|
||||||
|
" experiment resumed."
|
||||||
|
),
|
||||||
event_type=NotificationEventType.EXPERIMENT_RESUMED,
|
event_type=NotificationEventType.EXPERIMENT_RESUMED,
|
||||||
experiment_id=str(self.experiment.pk),
|
experiment_id=str(self.experiment.pk),
|
||||||
experiment_name=self.experiment.name,
|
experiment_name=self.experiment.name,
|
||||||
@@ -352,3 +395,58 @@ class FlushNotificationsTest(TestCase):
|
|||||||
results = flush_pending_notifications()
|
results = flush_pending_notifications()
|
||||||
self.assertEqual(results["sent"], 1)
|
self.assertEqual(results["sent"], 1)
|
||||||
mock_send.assert_called_once()
|
mock_send.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
class SMTPDeliveryTest(TestCase):
|
||||||
|
@patch("apps.notifications.services.EmailMessage")
|
||||||
|
@patch("apps.notifications.services.get_connection")
|
||||||
|
def test_send_smtp_uses_runtime_connection_options(
|
||||||
|
self,
|
||||||
|
mock_get_connection: Any,
|
||||||
|
mock_email_cls: Any,
|
||||||
|
) -> None:
|
||||||
|
channel = channel_create(
|
||||||
|
channel_type=ChannelType.SMTP,
|
||||||
|
name="SMTP Runtime",
|
||||||
|
config={
|
||||||
|
"recipient": "team@lotty.local",
|
||||||
|
"from_email": "alerts@lotty.local",
|
||||||
|
"host": "smtp.lotty.local",
|
||||||
|
"port": 2525,
|
||||||
|
"username": "mailer",
|
||||||
|
"password": "secret",
|
||||||
|
"use_tls": True,
|
||||||
|
"timeout": 7,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
rule = rule_create(
|
||||||
|
event_type=NotificationEventType.EXPERIMENT_COMPLETED,
|
||||||
|
channel=channel,
|
||||||
|
)
|
||||||
|
email_instance = mock_email_cls.return_value
|
||||||
|
NotificationLog.objects.create(
|
||||||
|
rule=rule,
|
||||||
|
channel=channel,
|
||||||
|
event_type=NotificationEventType.EXPERIMENT_COMPLETED,
|
||||||
|
event_key="smtp:override:1",
|
||||||
|
payload={"title": "Done", "body": "Completed"},
|
||||||
|
status=NotificationStatus.PENDING,
|
||||||
|
)
|
||||||
|
|
||||||
|
results = flush_pending_notifications()
|
||||||
|
|
||||||
|
self.assertEqual(results["sent"], 1)
|
||||||
|
mock_get_connection.assert_called_once_with(
|
||||||
|
fail_silently=False,
|
||||||
|
host="smtp.lotty.local",
|
||||||
|
port=2525,
|
||||||
|
username="mailer",
|
||||||
|
password="secret",
|
||||||
|
use_tls=True,
|
||||||
|
timeout=7,
|
||||||
|
)
|
||||||
|
mock_email_cls.assert_called_once()
|
||||||
|
_, kwargs = mock_email_cls.call_args
|
||||||
|
self.assertEqual(kwargs["from_email"], "alerts@lotty.local")
|
||||||
|
self.assertEqual(kwargs["to"], ["team@lotty.local"])
|
||||||
|
email_instance.send.assert_called_once_with(fail_silently=False)
|
||||||
|
|||||||
Reference in New Issue
Block a user