52 lines
1.1 KiB
Python
52 lines
1.1 KiB
Python
import datetime
|
|
import typing
|
|
|
|
import jwt
|
|
import sqlmodel
|
|
|
|
from app.core.config import config
|
|
import app.models
|
|
from app.models.user import User
|
|
|
|
|
|
def generate_token(
    payload: dict[typing.Any, typing.Any],
    expires_delta: datetime.timedelta | None = None,
) -> str:
    """Encode ``payload`` as a signed JWT carrying an ``exp`` claim.

    Args:
        payload: Claims to embed. The dict is copied first, so the caller's
            object is not mutated; any ``'exp'`` key it contains is
            overwritten.
        expires_delta: Lifetime of the token. ``None`` falls back to 15
            minutes. An explicit ``timedelta(0)`` is honoured (the token is
            immediately expired) instead of silently becoming 15 minutes.

    Returns:
        The token produced by ``jwt.encode`` using ``config.JWT_SECRET_KEY``
        and ``config.JWT_ALGORITHM``.
    """
    if expires_delta is None:
        expires_delta = datetime.timedelta(minutes=15)

    # Use an aware UTC timestamp. PyJWT turns a datetime ``exp`` into a Unix
    # timestamp via ``utctimetuple()``, which treats a *naive* datetime as if
    # it were already UTC. ``datetime.now()`` returns naive local time, so on a
    # non-UTC host the expiry would be shifted by the host's UTC offset.
    expire = datetime.datetime.now(datetime.timezone.utc) + expires_delta

    to_encode = payload.copy()
    to_encode['exp'] = expire
    return jwt.encode(
        to_encode, config.JWT_SECRET_KEY, algorithm=config.JWT_ALGORITHM
    )
|
|
|
|
|
|
def user_by_token(token: str) -> User | None:
    """Resolve the ``User`` referenced by a JWT, or return ``None``.

    The token is verified with ``config.JWT_SECRET_KEY`` and
    ``config.JWT_ALGORITHM``. Its payload is expected to carry a ``'user_id'``
    claim holding a non-negative integer id, given either as an ``int`` or as
    a string of ASCII decimal digits.

    Args:
        token: The encoded JWT.

    Returns:
        The ``User`` returned by ``session.get``, or ``None`` when the token
        fails verification (any ``jwt.InvalidTokenError``), the ``'user_id'``
        claim is missing or malformed, or no row matches.
    """
    try:
        payload: dict = jwt.decode(
            token, config.JWT_SECRET_KEY, algorithms=[config.JWT_ALGORITHM]
        )
    except jwt.InvalidTokenError:
        # Base class for PyJWT's verification errors; InvalidSignatureError
        # is a subclass, so listing it separately was redundant.
        return None

    raw_user_id = payload.get('user_id')
    # A missing claim (None) or an int claim previously crashed on
    # ``.isdigit()`` with AttributeError. bool is rejected explicitly because
    # it is a subclass of int.
    if isinstance(raw_user_id, bool) or not isinstance(raw_user_id, (int, str)):
        return None

    user_id = str(raw_user_id)
    # str.isdigit() also accepts characters such as '²' that cannot be parsed
    # as an integer id; require plain ASCII decimal digits instead.
    if not (user_id.isascii() and user_id.isdecimal()):
        return None

    # NOTE(review): ``app.core.db`` is reached through the ``app`` package
    # bound by ``import app.models``; none of this file's imports load
    # ``app.core.db`` explicitly, so this relies on something else having
    # imported it first — confirm, or import it at module level.
    with sqlmodel.Session(app.core.db.engine) as session:
        return session.get(User, user_id)
|