diff --git a/solution/api/v1/business/schemas.py b/solution/api/v1/business/schemas.py index 25ad8f3..ad40d52 100644 --- a/solution/api/v1/business/schemas.py +++ b/solution/api/v1/business/schemas.py @@ -1,9 +1,9 @@ import datetime import uuid -from typing import ClassVar, Literal +from typing import Any, ClassVar, Literal from ninja import ModelSchema, Schema -from pydantic import Field, StrictInt +from pydantic import Field, StrictInt, field_validator from pydantic_extra_types.country import CountryAlpha2 from apps.business.models import Business @@ -64,6 +64,16 @@ class CreatePromocodeIn(ModelSchema): Promocode.promo_common.field.name, ] + @field_validator("target", mode="before") + def validate_target(cls, value: Any) -> Any: + if not isinstance(value, dict) and not isinstance( + value, + PromocodeTarget, + ): + err = "The 'target' field must be a valid object." + raise ValueError(err) + return value + class CreatePromocodeOut(Schema): id: uuid.UUID @@ -113,6 +123,17 @@ class PatchPromocodeIn(Schema): active_from: datetime.date | None = None active_until: datetime.date | None = None + @staticmethod + @field_validator("target", mode="before") + def validate_target(value: Any) -> Any: + if not isinstance(value, dict) and not isinstance( + value, + PromocodeTarget, + ): + err = "The 'target' field must be a valid object." + raise TypeError(err) + return value + class PromocodeStatsForCountry(Schema): country: str diff --git a/solution/api/v1/business/utils.py b/solution/api/v1/business/utils.py index 8555ada..2e0d55f 100644 --- a/solution/api/v1/business/utils.py +++ b/solution/api/v1/business/utils.py @@ -12,9 +12,7 @@ def map_promocode_to_schema(promocode: Promocode) -> schemas.PromocodeViewOut: target=schemas.PromocodeTargetViewOut( age_from=promocode.target.age_from, age_until=promocode.target.age_until, - country=promocode.target.country_raw - if promocode.target.country_raw - else None, + country=promocode.target.country_raw or None, categories=promocode.target.categories, ), max_count=promocode.max_count, diff --git a/solution/api/v1/user/views.py b/solution/api/v1/user/views.py index 8c106e8..1020b07 100644 --- a/solution/api/v1/user/views.py +++ b/solution/api/v1/user/views.py @@ -170,9 +170,7 @@ def feed( category_lower = filters.category.lower() def matches_category(promocode: Promocode) -> bool: - categories = ( - promocode.target.categories or [] - ) + categories = promocode.target.categories or [] return any( category.lower() == category_lower for category in categories )