You've already forked Travel-Agent
257 lines
6.6 KiB
Python
257 lines
6.6 KiB
Python
__all__ = ("Travel", "Location")
|
|
|
|
import datetime
|
|
|
|
import sqlalchemy as sa
|
|
from sqlalchemy.orm import relationship, validates
|
|
|
|
from app import messages, session
|
|
from app.models import Base
|
|
from app.models.user import User
|
|
from app.utils.geo import get_location_by_name
|
|
|
|
|
|
association_table = sa.Table(
|
|
"user_travel_association",
|
|
Base.metadata,
|
|
sa.Column(
|
|
"user_id",
|
|
sa.BigInteger,
|
|
sa.ForeignKey("users.telegram_id", ondelete="CASCADE"),
|
|
),
|
|
sa.Column(
|
|
"travel_id",
|
|
sa.Integer,
|
|
sa.ForeignKey("travels.id", ondelete="CASCADE"),
|
|
),
|
|
)
|
|
|
|
|
|
class Travel(Base):
|
|
__tablename__ = "travels"
|
|
|
|
id = sa.Column( # noqa: A003
|
|
sa.Integer,
|
|
unique=True,
|
|
primary_key=True,
|
|
nullable=False,
|
|
autoincrement=True,
|
|
index=True,
|
|
)
|
|
title = sa.Column(
|
|
sa.String(50),
|
|
nullable=False,
|
|
unique=True,
|
|
)
|
|
description = sa.Column(
|
|
sa.String(100),
|
|
nullable=True,
|
|
)
|
|
|
|
author_id = sa.Column(
|
|
sa.BigInteger,
|
|
sa.ForeignKey(User.telegram_id, ondelete="CASCADE"),
|
|
nullable=False,
|
|
)
|
|
|
|
users = relationship(
|
|
User,
|
|
secondary=association_table,
|
|
backref="travels",
|
|
)
|
|
|
|
locations = relationship("Location", backref="travel")
|
|
notes = relationship("Note", backref="travel")
|
|
|
|
@validates("title")
|
|
def validate_title(self, key, value):
|
|
assert len(value) <= 30, "Title must be 30 characters or fewer."
|
|
assert "👑" not in value, "👑 is not allowed symbol."
|
|
|
|
if session.query(Travel).filter(Travel.title == value).first():
|
|
raise AssertionError("This title is already taken.")
|
|
|
|
return value
|
|
|
|
@validates("description")
|
|
def validate_description(self, key, value):
|
|
if value is not None:
|
|
assert (
|
|
len(value) <= 100
|
|
), "Description must be 100 characters or fewer."
|
|
|
|
return value
|
|
|
|
def get_travel_text(self):
|
|
return messages.TRAVEL_DETAIL.format(
|
|
travel_id=self.id,
|
|
title=self.title,
|
|
description=(
|
|
self.description if self.description else messages.NOT_SET
|
|
),
|
|
)
|
|
|
|
@classmethod
|
|
def get_notes(cls, author_id, travel, public=True):
|
|
return (
|
|
session.query(Note)
|
|
.filter(
|
|
Note.author_id == author_id,
|
|
Note.travel_id == travel.id,
|
|
Note.public == public,
|
|
)
|
|
.all()
|
|
)
|
|
|
|
@classmethod
|
|
def get_travel_by_id(cls, travel_id):
|
|
return session.query(Travel).filter(Travel.id == travel_id).first()
|
|
|
|
@classmethod
|
|
def get_travel_queryset_by_id(cls, travel_id):
|
|
return session.query(Travel).filter(Travel.id == travel_id)
|
|
|
|
@classmethod
|
|
def get_sorted_locations(cls, travel, asc=True):
|
|
return sorted(
|
|
travel.locations,
|
|
key=lambda location: location.date_end,
|
|
reverse=asc,
|
|
)
|
|
|
|
|
|
class Location(Base):
|
|
__tablename__ = "locations"
|
|
|
|
id = sa.Column( # noqa: A003
|
|
sa.Integer,
|
|
unique=True,
|
|
primary_key=True,
|
|
autoincrement=True,
|
|
index=True,
|
|
)
|
|
location = sa.Column(sa.Text, nullable=False)
|
|
date_start = sa.Column(
|
|
sa.DateTime(timezone=True),
|
|
nullable=False,
|
|
)
|
|
date_end = sa.Column(
|
|
sa.DateTime(timezone=True),
|
|
nullable=False,
|
|
)
|
|
|
|
travel_id = sa.Column(
|
|
sa.Integer,
|
|
sa.ForeignKey("travels.id", ondelete="CASCADE"),
|
|
nullable=False,
|
|
)
|
|
|
|
def validate_location(self, key, value):
|
|
geocode = get_location_by_name(value)
|
|
|
|
assert geocode[0], "Invalid location."
|
|
|
|
return geocode[1].raw["display_name"]
|
|
|
|
def validate_date_start(self, key, value):
|
|
try:
|
|
value_datetime = datetime.datetime.strptime(
|
|
value,
|
|
"%Y-%m-%d %H:%M",
|
|
)
|
|
value_datetime = value_datetime.replace(tzinfo=datetime.UTC)
|
|
except ValueError:
|
|
raise AssertionError("Invalid datetime format.")
|
|
|
|
assert value_datetime >= datetime.datetime.now(
|
|
datetime.UTC,
|
|
), "Invalid datetime."
|
|
|
|
return value_datetime
|
|
|
|
def validate_date_end(self, key, value):
|
|
try:
|
|
value_datetime = datetime.datetime.strptime(
|
|
value,
|
|
"%Y-%m-%d %H:%M",
|
|
)
|
|
value_datetime = value_datetime.replace(tzinfo=datetime.UTC)
|
|
except ValueError:
|
|
raise AssertionError("Invalid datetime format.")
|
|
|
|
assert value_datetime >= datetime.datetime.now(
|
|
datetime.UTC,
|
|
), "Invalid datetime."
|
|
|
|
return value_datetime
|
|
|
|
def get_location_text(self):
|
|
return messages.LOCATION_DETAIL.format(
|
|
location=self.location,
|
|
date_start=datetime.datetime.strftime(
|
|
self.date_start,
|
|
"%Y-%m-%d %H:%M",
|
|
),
|
|
date_end=datetime.datetime.strftime(
|
|
self.date_end,
|
|
"%Y-%m-%d %H:%M",
|
|
),
|
|
)
|
|
|
|
@classmethod
|
|
def get_location_by_id(cls, location_id):
|
|
return (
|
|
session.query(Location).filter(Location.id == location_id).first()
|
|
)
|
|
|
|
@classmethod
|
|
def get_location_queryset_by_id(cls, location_id):
|
|
return session.query(Location).filter(Location.id == location_id)
|
|
|
|
|
|
class Note(Base):
|
|
__tablename__ = "notes"
|
|
|
|
id = sa.Column( # noqa: A003
|
|
sa.Integer,
|
|
unique=True,
|
|
primary_key=True,
|
|
autoincrement=True,
|
|
index=True,
|
|
)
|
|
|
|
file_id = sa.Column(sa.Text, nullable=False)
|
|
file_name = sa.Column(sa.Text, nullable=False)
|
|
file_type = sa.Column(sa.Text, nullable=False)
|
|
public: sa.Column[bool] = sa.Column(
|
|
sa.Boolean(),
|
|
nullable=False,
|
|
default=False,
|
|
)
|
|
|
|
author_id = sa.Column(
|
|
sa.BigInteger,
|
|
sa.ForeignKey(User.telegram_id, ondelete="CASCADE"),
|
|
nullable=False,
|
|
)
|
|
travel_id = sa.Column(
|
|
sa.Integer,
|
|
sa.ForeignKey("travels.id", ondelete="CASCADE"),
|
|
nullable=False,
|
|
)
|
|
|
|
def get_note_text(self):
|
|
return messages.NOTE_DETAIL.format(
|
|
file_name=self.file_name,
|
|
file_type=self.file_type,
|
|
public="Yes" if self.public else "No",
|
|
)
|
|
|
|
@classmethod
|
|
def get_note_by_id(cls, note_id):
|
|
return session.query(Note).filter(Note.id == note_id).first()
|
|
|
|
@classmethod
|
|
def get_note_queryset_by_id(cls, note_id):
|
|
return session.query(Note).filter(Note.id == note_id)
|