diff --git a/app/__init__.py b/app/__init__.py index e69de29..70ccd92 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -0,0 +1,10 @@ +from sqlalchemy import create_engine +from sqlalchemy.engine import Engine +from sqlalchemy.orm import sessionmaker + +from app.config import Config + +engine: Engine = create_engine(Config.SQLALCHEMY_DATABASE_URI) + +Session = sessionmaker(autocommit=False, autoflush=False, bind=engine) +session = Session() diff --git a/app/bot.py b/app/bot.py index 48e0e9a..3f63623 100644 --- a/app/bot.py +++ b/app/bot.py @@ -2,34 +2,29 @@ __all__ = ("main",) from typing import Optional -from aiogram import Bot, Dispatcher, types +from aiogram import Bot, Dispatcher from aiogram.enums import ParseMode -from aiogram.filters import CommandStart -from aiogram.types import Message -from aiogram.utils.markdown import hbold + from app.config import Config - - -dp = Dispatcher() - - -@dp.message(CommandStart()) -async def command_start_handler(message: Message) -> None: - await message.answer(f"Hello, {hbold(message.from_user.full_name)}!") - - -@dp.message() -async def echo_handler(message: types.Message) -> None: - try: - await message.send_copy(chat_id=message.chat.id) - except TypeError: - await message.answer("Nice try!") +from app.handlers import help_command, profile_command, start_command +from app.middlewares.throttling import ThrottlingMiddleware async def main() -> None: + dp = Dispatcher() + bot_token: Optional[str] = Config.BOT_TOKEN if bot_token is not None: bot = Bot(bot_token, parse_mode=ParseMode.HTML) + + dp.message.middleware(ThrottlingMiddleware(0.5)) + dp.include_routers( + start_command.router, + profile_command.router, + help_command.router, + ) + + await bot.delete_webhook(drop_pending_updates=True) await dp.start_polling(bot) else: exit("BOT_TOKEN is not set") diff --git a/app/config.py b/app/config.py index 36691bb..034d344 100644 --- a/app/config.py +++ b/app/config.py @@ -8,5 +8,8 @@ load_dotenv() class Config: - BOT_TOKEN = os.getenv("BOT_TOKEN") - SQLALCHEMY_DATABASE_URI = os.getenv("SQLALCHEMY_DATABASE_URI") + BOT_TOKEN = os.getenv("BOT_TOKEN", "") + SQLALCHEMY_DATABASE_URI = os.getenv( + "SQLALCHEMY_DATABASE_URI", + "sqlite:///database.db", + ) diff --git a/app/filters/user_filter.py b/app/filters/user_filter.py new file mode 100644 index 0000000..af01c86 --- /dev/null +++ b/app/filters/user_filter.py @@ -0,0 +1,27 @@ +__all__ = ("Unregistered", "Registered", "RegisteredCallback") + +from aiogram.filters import Filter +from aiogram.types import CallbackQuery, Message + +from app.models.user import User + + +class Unregistered(Filter): + async def __call__(self, message: Message) -> bool: + if message.from_user is None: + return False + + return not User.user_by_telegram_id_exist(message.from_user.id) + + +class Registered(Filter): + async def __call__(self, message: Message) -> bool: + if message.from_user is None: + return False + + return User.user_by_telegram_id_exist(message.from_user.id) + + +class RegisteredCallback(Filter): + async def __call__(self, callback: CallbackQuery) -> bool: + return User.user_by_telegram_id_exist(callback.from_user.id) diff --git a/app/handlers/help_command.py b/app/handlers/help_command.py new file mode 100644 index 0000000..ab6dff9 --- /dev/null +++ b/app/handlers/help_command.py @@ -0,0 +1,16 @@ +__all__ = () + +from aiogram import Router +from aiogram.filters import Command +from aiogram.types import Message + +from app import messages +from app.filters.user_filter import Registered + + +router = Router(name="help_command") + + +@router.message(Command("help"), Registered()) +async def command_help_handler(message: Message) -> None: + await message.answer(messages.HELP_MESSAGE) diff --git a/app/handlers/profile_command.py b/app/handlers/profile_command.py new file mode 100644 index 0000000..80fec4e --- /dev/null +++ b/app/handlers/profile_command.py @@ -0,0 +1,31 @@ +# type: ignore +__all__ = () + +from aiogram import Router +from aiogram.filters import Command +from aiogram.types import Message + +from app import messages +from app.filters.user_filter import Registered +from app.keyboards.profile import get +from app.models.user import User + + +router = Router(name="profile_command") + + +@router.message(Command("profile"), Registered()) +async def command_profile_handler(message: Message) -> None: + user = User().get_user_by_telegram_id(message.from_user.id) + + await message.answer( + messages.PROFILE.format( + username=user.username, + age=user.age, + bio=user.bio if user.bio else messages.NOT_SET, + sex=user.sex.capitalize(), + country=user.country, + city=user.city, + ), + reply_markup=get(), + ) diff --git a/app/handlers/start_command.py b/app/handlers/start_command.py new file mode 100644 index 0000000..be04bc2 --- /dev/null +++ b/app/handlers/start_command.py @@ -0,0 +1,179 @@ +# type: ignore +__all__ = () + +from aiogram import F, Router +from aiogram.filters import CommandStart +from aiogram.fsm.context import FSMContext +from aiogram.types import Message, ReplyKeyboardRemove + +from app import messages, session +from app.keyboards.builders import profile +from app.models.user import User +from app.utils.states import RegistrationForm + + +router = Router(name="start_command") + + +@router.message(CommandStart()) +async def command_start_handler(message: Message, state: FSMContext) -> None: + if ( + User.get_user_by_telegram_id( + telegram_id=message.from_user.id, + ) + is not None + ): + await message.answer( + messages.WELCOME_AGAIN_MESSAGE.format( + name=message.from_user.full_name, + ), + ) + else: + await message.answer( + messages.WELCOME_MESSAGE.format( + name=message.from_user.full_name, + ), + ) + + await state.set_state(RegistrationForm.username) + await message.answer(messages.INPUT_USERNAME) + + +@router.message(RegistrationForm.username, F.text) +async def username_handler(message: Message, state: FSMContext) -> None: + username = message.text.strip() + + try: + validated_username = User().validate_username( + key="username", + value=username, + ) + except AssertionError as e: + await message.answer(str(e)) + return + + await state.update_data(username=validated_username) + await state.set_state(RegistrationForm.age) + + await message.answer( + messages.INPUT_CALLBACK.format( + key="username", + value=validated_username, + ), + ) + await message.answer(messages.INPUT_AGE) + + +@router.message(RegistrationForm.age, F.text) +async def age_handler(message: Message, state: FSMContext) -> None: + age = message.text.strip() + + try: + validated_age = User().validate_age(key="age", value=age) + except AssertionError as e: + await message.answer(str(e)) + return + + await state.update_data(age=validated_age) + await state.set_state(RegistrationForm.sex) + + await message.answer( + messages.INPUT_CALLBACK.format(key="age", value=validated_age), + ) + await message.answer( + messages.INPUT_SEX, + reply_markup=profile(["Male", "Female"]), + ) + + +@router.message(RegistrationForm.sex, F.text) +async def sex_handler(message: Message, state: FSMContext) -> None: + sex = message.text.strip().lower() + + if sex not in ["male", "female"]: + await message.answer(messages.VALIDATION_ERROR_MESSAGE) + return + + await state.update_data(sex=sex) + await state.set_state(RegistrationForm.bio) + + await message.answer( + messages.INPUT_CALLBACK.format(key="sex", value=sex), + reply_markup=ReplyKeyboardRemove(), + ) + await message.answer(messages.INPUT_BIO) + + +@router.message(RegistrationForm.bio, F.text) +async def bio_handler(message: Message, state: FSMContext) -> None: + bio = message.text.strip() + + if bio == "/skip": + await state.update_data(bio=None) + await state.set_state(RegistrationForm.location) + + await message.answer(messages.INPUT_BIO_SKIPPED) + await message.answer(messages.INPUT_LOCATION) + else: + try: + validated_bio = User().validate_bio(key="bio", value=bio) + except AssertionError as e: + await message.answer(str(e)) + return + + await state.update_data(bio=validated_bio) + await state.set_state(RegistrationForm.location) + + await message.answer( + messages.INPUT_CALLBACK.format(key="bio", value=validated_bio), + ) + await message.answer(messages.INPUT_LOCATION) + + +@router.message(RegistrationForm.location, F.text) +async def location_handler(message: Message, state: FSMContext) -> None: + location = message.text.strip().split(", ") + + if len(location) != 2: + await message.answer(messages.VALIDATION_ERROR_MESSAGE) + return + + country, city = location + + try: + validated_country = User().validate_country( + key="country", + value=country, + ) + except AssertionError as e: + await message.answer(str(e)) + return + + try: + validated_city = User().validate_city( + city=city, + country=validated_country, + ) + except AssertionError as e: + await message.answer(str(e)) + return + + await state.update_data(location=[validated_country, validated_city]) + data = await state.get_data() + await state.clear() + + await message.answer( + messages.INPUT_CALLBACK.format( + key="location", + value=", ".join([validated_country, validated_city]), + ), + ) + + data["telegram_id"] = message.from_user.id + data["country"] = data["location"][0] + data["city"] = data["location"][1] + del data["location"] + session.add(User(**data)) + session.commit() + + await message.answer(messages.REGISTERED_MESSAGE) diff --git a/app/keyboards/builders.py b/app/keyboards/builders.py new file mode 100644 index 0000000..b7538e5 --- /dev/null +++ b/app/keyboards/builders.py @@ -0,0 +1,13 @@ +__all__ = ("profile",) + +from aiogram.utils.keyboard import ReplyKeyboardBuilder + + +def profile(text: str | list): + builder = ReplyKeyboardBuilder() + + if isinstance(text, str): + text = [text] + + [builder.button(text=txt) for txt in text] + return builder.as_markup(resize_keyboard=True) diff --git a/app/keyboards/profile.py b/app/keyboards/profile.py new file mode 100644 index 0000000..3728912 --- /dev/null +++ b/app/keyboards/profile.py @@ -0,0 +1,37 @@ +__all__ = ("get",) + +from aiogram import types +from aiogram.utils.keyboard import InlineKeyboardBuilder + + +def get(): + builder = InlineKeyboardBuilder() + + builder.row( + types.InlineKeyboardButton( + text="πŸ‘€ Change username", + callback_data="profile_change_username", + ), + types.InlineKeyboardButton( + text="πŸ”’ Change age", + callback_data="profile_change_age", + ), + ) + builder.row( + types.InlineKeyboardButton( + text="ℹ️ Change bio", + callback_data="profile_change_bio", + ), + types.InlineKeyboardButton( + text="πŸ“ Change sex", + callback_data="profile_change_sex", + ), + ) + builder.row( + types.InlineKeyboardButton( + text="πŸ—ΊοΈ Change location", + callback_data="profile_change_location", + ), + ) + + return builder.as_markup() diff --git a/app/messages.py b/app/messages.py new file mode 100644 index 0000000..77e0a3c --- /dev/null +++ b/app/messages.py @@ -0,0 +1,28 @@ +# flake8: noqa + +WELCOME_MESSAGE = "Hello, {name}! Welcome to the ✈️ Travel Agent bot! Let's start our journey by filling out some information about you." +WELCOME_AGAIN_MESSAGE = "Hello, {name}! Welcome back to the ✈️ Travel Agent bot! If you get lost, you can always call the /help command for assistance." + +HELP_MESSAGE = "Help message text." + +REGISTERED_MESSAGE = "You have successfully registered. Welcome to the ✈️ Travel Agent bot! \nYou can view and edit your profile using the /profile command." + +INPUT_USERNAME = "Enter your username (this will be used to interact with other users):\nAllowed characters: a-z, A-Z, 0-9, _\nLength: 5-20 characters" +INPUT_AGE = "Enter your age:\nRange: 13-120" +INPUT_SEX = "Enter your sex:\nOptions: Male or Female" +INPUT_BIO = "Enter your bio (enter /skip if you want to skip this step):\nMaximum length: 100 characters" +INPUT_BIO_SKIPPED = "Sure. You can always fill it later." +INPUT_LOCATION = "Enter your location in this format:\nFormat: country, city\nExample: Russia, Moscow" +INPUT_CALLBACK = "All right, your {key} is set to: {value}" +VALIDATION_ERROR_MESSAGE = "Invalid input. Please try again." + +PROFILE = ( + "Your profile:\n\n" + "\tUsername: {username}\n" + "\tAge: {age}\n" + "\tSex: {sex}\n" + "\tBio: {bio}\n" + "\tCountry: {country}\n" + "\tCity: {city}" +) +NOT_SET = "Not set" diff --git a/app/middlewares/throttling.py b/app/middlewares/throttling.py new file mode 100644 index 0000000..5da372e --- /dev/null +++ b/app/middlewares/throttling.py @@ -0,0 +1,26 @@ +__all__ = ("ThrottlingMiddleware",) + +from typing import Any, Awaitable, Callable, Dict + +from aiogram import BaseMiddleware +from aiogram.types import Message +from cachetools import TTLCache # type: ignore + + +class ThrottlingMiddleware(BaseMiddleware): + + def __init__(self, time_limit: int | float = 2) -> None: + self.limit = TTLCache(maxsize=10_000, ttl=time_limit) + + async def __call__( + self, + handler: Callable[[Message, Dict[str, Any]], Awaitable[Any]], + event: Message, # type: ignore + data: Dict[str, Any], + ) -> Any | None: + if event.chat.id in self.limit: + return None + + self.limit[event.chat.id] = None + + return await handler(event, data) diff --git a/app/migrations/README b/app/migrations/README deleted file mode 100644 index 98e4f9c..0000000 --- a/app/migrations/README +++ /dev/null @@ -1 +0,0 @@ -Generic single-database configuration. \ No newline at end of file diff --git a/app/migrations/env.py b/app/migrations/env.py index 4177dec..b05ea8a 100644 --- a/app/migrations/env.py +++ b/app/migrations/env.py @@ -4,11 +4,12 @@ from logging.config import fileConfig import os from alembic import context -from app.models import Base from dotenv import load_dotenv from sqlalchemy import engine_from_config from sqlalchemy import pool +from app.models.user import Base + load_dotenv() diff --git a/app/migrations/versions/5896f08fbd61_added_user_model.py b/app/migrations/versions/5896f08fbd61_added_user_model.py new file mode 100644 index 0000000..c6769c4 --- /dev/null +++ b/app/migrations/versions/5896f08fbd61_added_user_model.py @@ -0,0 +1,42 @@ +"""Added User model + +Revision ID: 5896f08fbd61 +Revises: +Create Date: 2024-03-19 23:25:50.458639 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '5896f08fbd61' +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + 'users', + sa.Column('telegram_id', sa.Integer(), nullable=False), + sa.Column('username', sa.String(length=20), nullable=False), + sa.Column('age', sa.SmallInteger(), nullable=False), + sa.Column('bio', sa.String(length=100), nullable=True), + sa.Column('sex', sa.String(length=6), nullable=True), + sa.Column('country', sa.Text(), nullable=False), + sa.Column('city', sa.Text(), nullable=False), + sa.PrimaryKeyConstraint('telegram_id'), + sa.UniqueConstraint('username'), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('users') + # ### end Alembic commands ### diff --git a/app/models.py b/app/models.py deleted file mode 100644 index 357cd94..0000000 --- a/app/models.py +++ /dev/null @@ -1,19 +0,0 @@ -__all__ = ("User",) - -from typing import Any - -from sqlalchemy import Column, Integer, SmallInteger, String -from sqlalchemy.ext.declarative import declarative_base - - -Base: Any = declarative_base() - - -class User(Base): - __tablename__ = "users" - - telegram_id: Column[int] = Column(Integer, primary_key=True) - age: Column[int] = Column(SmallInteger, nullable=False) - bio: Column[str] = Column(String(100), nullable=True) - country: Column[str] = Column(String(100), nullable=False) - city: Column[str] = Column(String(100), nullable=False) diff --git a/app/models/user.py b/app/models/user.py new file mode 100644 index 0000000..af1e7d9 --- /dev/null +++ b/app/models/user.py @@ -0,0 +1,89 @@ +__all__ = ("User",) + +import re +from typing import Any + +import sqlalchemy as sa +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import validates + +from app import session +from app.utils import geo + + +Base: Any = declarative_base() + + +class User(Base): + __tablename__ = "users" + + telegram_id = sa.Column(sa.Integer, primary_key=True) + username = sa.Column(sa.String(20), nullable=False, unique=True) + age = sa.Column(sa.SmallInteger, nullable=False) + bio = sa.Column(sa.String(100), nullable=True) + sex = sa.Column(sa.String(6), nullable=True) + country = sa.Column(sa.Text, nullable=False) + city = sa.Column(sa.Text, nullable=False) + + @validates("username") + def validate_username(self, key, value): + regex_pattern = re.compile(r"^[a-zA-Z0-9_]{5,20}$") + + assert len(value) <= 20, "Username must be 20 characters or fewer." + assert len(value) >= 5, "Username must be at least 5 characters." + assert ( + re.match(regex_pattern, value) is not None + ), "a-z, A-Z, 0-9, _ only allowed in username." + + return value + + @validates("age") + def validate_age(self, key, value): + assert str(value).isnumeric(), "Invalid input. Please try again." + value = int(value) + assert value >= 13, "You must be at least 13 years old." + assert value <= 120, "You must be less than 120 years old." + + return value + + @validates("bio") + def validate_bio(self, key, value): + if value is not None: + assert len(value) <= 100, "Bio must be 100 characters or fewer." + + return value + + @validates("country") + def validate_country(self, key, value): + verdict, normalized_value = geo.validate_country( + value, + ) + + assert verdict, "There is no such country." + + return normalized_value + + def validate_city(self, city, country): + verdict, normalized_value = geo.validate_city( + city, + country, + ) + + assert verdict, "There is no such city in selected country." + + return normalized_value + + @classmethod + def get_user_by_telegram_id(cls, telegram_id): + return ( + session.query(cls).filter(cls.telegram_id == telegram_id).first() + ) + + @classmethod + def user_by_telegram_id_exist(cls, telegram_id): + return ( + cls.get_user_by_telegram_id( + telegram_id=telegram_id, + ) + is not None + ) diff --git a/app/utils/__init__.py b/app/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/utils/geo.py b/app/utils/geo.py new file mode 100644 index 0000000..8574f6f --- /dev/null +++ b/app/utils/geo.py @@ -0,0 +1,66 @@ +# type: ignore +__all__ = ("validate_country", "validate_city") + +from geopy.exc import GeocoderTimedOut +from geopy.geocoders import Nominatim + + +def validate_country(country: str): + geolocator = Nominatim(user_agent="travel_agent_bot") + + for _ in range(3): + try: + geocode = geolocator.geocode( + country, + featuretype="country", + ) + break + except GeocoderTimedOut: + continue + else: + return False, None + + if not geocode: + return False, None + + is_loc_country = geocode.raw.get( + "type", None, + ) == "administrative" + + if is_loc_country: + normalized_country = geocode.raw.get("name", "Invalid country") + return True, normalized_country + + return False, None + + +def validate_city(city: str, country: str): + geolocator = Nominatim(user_agent="travel_agent_bot") + + location_name = f"{country}, {city}" + valid_list = ["city", "town", "administrative"] + + for _ in range(3): + try: + geocode = geolocator.geocode( + location_name, + featuretype="city", + ) + break + except GeocoderTimedOut: + continue + else: + return False, None + + if not geocode: + return False, None + + check_in_valid = geocode.raw.get( + "type", None, + ) in valid_list + + if geocode and check_in_valid: + normalized_country = geocode.raw.get("name", "Invalid city") + return True, normalized_country + + return False, None diff --git a/app/utils/states.py b/app/utils/states.py new file mode 100644 index 0000000..6eb8429 --- /dev/null +++ b/app/utils/states.py @@ -0,0 +1,15 @@ +__all__ = ("RegistrationForm",) + +from aiogram.fsm.state import State, StatesGroup + + +class RegistrationForm(StatesGroup): + username = State() + age = State() + bio = State() + sex = State() + location = State() + + +class UserAltering(StatesGroup): + new_value = State() diff --git a/requirements/lints.txt b/requirements/lints.txt index 16319a5..7a2dc33 100644 --- a/requirements/lints.txt +++ b/requirements/lints.txt @@ -13,7 +13,6 @@ flake8-import-order flake8-print flake8-quotes flake8-return -flake8-type-ignore flake8-use-pathlib flake8_implicit_str_concat pep8-naming diff --git a/requirements/prod.txt b/requirements/prod.txt index f7959a9..575c28b 100644 --- a/requirements/prod.txt +++ b/requirements/prod.txt @@ -1,5 +1,7 @@ aiogram==3.4.1 alembic==1.13.1 +cachetools==5.3.3 +geopy==2.4.1 psycopg2-binary==2.9.9 python-dotenv==1.0.1 sqlalchemy==2.0.28