Files
RekomenciBackend/src/template_project/ml/entry_point.py
T
2025-11-22 19:06:02 +03:00

110 lines
2.7 KiB
Python

import argparse
import asyncio
import logging
import os
import sys
from pathlib import Path
from typing import Final
import uvicorn
from dishka import AsyncContainer
from dishka.integrations.fastapi import setup_dishka
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from template_project.ml.configuration import load_configuration
from template_project.ml.ioc.make import make_ioc
from template_project.ml.routes import embed, healthcheck, predict
# Logging configuration for the service; ``_main`` hands it to uvicorn as
# ``log_config``. Every record goes to a single console handler, and the root
# logger lets everything from DEBUG upward through.
LOG_CONFIG: Final = {
    "version": 1,
    # Loggers that already exist when this config is applied stay enabled.
    "disable_existing_loggers": False,
    "formatters": {
        "default": {"format": "%(asctime)s [%(levelname)s] [%(name)s] %(message)s"},
    },
    "handlers": {
        "console": {"class": "logging.StreamHandler", "formatter": "default"},
    },
    "root": {"level": "DEBUG", "handlers": ["console"]},
}
def make_asgi_application(
    ioc: AsyncContainer,
) -> FastAPI:
    """Build the FastAPI application for the ML service.

    The application gets a CORS middleware, the healthcheck, embed and predict
    routers (mounted in that order), and is connected to the given dishka
    container through ``setup_dishka``.

    Args:
        ioc: Async dishka container passed to ``setup_dishka``.

    Returns:
        The fully configured FastAPI application.
    """
    application = FastAPI(
        title="ML Service",
        description="ML Service API",
        version="1.0.0",
        docs_url="/docs",
        openapi_url="/openapi.json",
    )

    # NOTE(review): every origin, method and header is allowed *and*
    # credentials are enabled. If this service is reachable from browsers on
    # untrusted origins, consider restricting allow_origins.
    cors_options = {
        "allow_origins": ["*"],
        "allow_credentials": True,
        "allow_methods": ["*"],
        "allow_headers": ["*"],
    }
    application.add_middleware(CORSMiddleware, **cors_options)

    # Registration order matches the original: healthcheck, embed, predict.
    for route_module in (healthcheck, embed, predict):
        application.include_router(route_module.router)

    setup_dishka(container=ioc, app=application)
    return application
async def _main(
    configuration_path: Path,
) -> None:
    """Load the configuration, build the ASGI app and serve it with uvicorn.

    The dishka container is always closed afterwards. This also holds when
    building the application, the uvicorn config or the server fails, not
    only when ``serve()`` returns or raises.

    Args:
        configuration_path: Path forwarded to ``load_configuration``.
    """
    # Drop records below WARNING from httpx/httpcore; the root logger in
    # LOG_CONFIG is set to DEBUG, so they would otherwise be emitted.
    logging.getLogger("httpx").setLevel(logging.WARNING)
    logging.getLogger("httpcore").setLevel(logging.WARNING)

    configuration = load_configuration(configuration_path)
    ioc = make_ioc(configuration)
    # The try starts right after the container exists. Previously a failure
    # while building the app/config/server skipped ioc.close() entirely.
    try:
        asgi_application = make_asgi_application(ioc)
        config = uvicorn.Config(
            app=asgi_application,
            host=configuration.server.host,
            port=configuration.server.port,
            log_config=LOG_CONFIG,
            access_log=configuration.server.access_log,
        )
        server = uvicorn.Server(config)
        await server.serve()
    finally:
        await ioc.close()
def main() -> None:
    """Command-line entry point for the ML service.

    The configuration path comes from the optional positional argument. If it
    is omitted or empty, the ``CONFIGURATION_PATH`` environment variable is
    used instead.

    Raises:
        RuntimeError: If neither the argument nor the environment variable
            provides a configuration path.
    """
    if sys.platform == "win32":
        # Force the selector-based event loop policy on Windows.
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
    arg_parser = argparse.ArgumentParser()
    # nargs="?" is what makes a positional optional. With only default=None,
    # argparse exits on a missing argument, so the env-var fallback below
    # could never be reached.
    arg_parser.add_argument(
        "configuration",
        nargs="?",
        default=None,
        help="path to the configuration file (falls back to $CONFIGURATION_PATH)",
    )
    args = arg_parser.parse_args()
    configuration_path = args.configuration or os.getenv("CONFIGURATION_PATH")
    if configuration_path is None:
        raise RuntimeError(
            "pass the path to the config or specify it in the environment variables `CONFIGURATION_PATH`",
        )
    asyncio.run(_main(Path(configuration_path)))
# Run the CLI entry point only when this module is executed directly, not on import.
if __name__ == "__main__":
    main()