"""Entry point that wires up and serves the ML Service HTTP API.

The configuration file path is taken from the optional positional CLI
argument, falling back to the ``CONFIGURATION_PATH`` environment variable.
"""

import argparse
import asyncio
import logging
import os
import sys
from pathlib import Path
from typing import Final

import uvicorn
from dishka import AsyncContainer
from dishka.integrations.fastapi import setup_dishka
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from template_project.ml.configuration import load_configuration
from template_project.ml.ioc.make import make_ioc
from template_project.ml.routes import embed, healthcheck, predict

# ``logging.config.dictConfig`` schema passed to uvicorn as ``log_config``.
# All records reaching the root logger go to one StreamHandler (stderr by
# default) with a timestamped format; the root level is DEBUG.
# ``disable_existing_loggers: False`` keeps loggers created before this config
# is applied (e.g. httpx/httpcore below) enabled.
LOG_CONFIG: Final = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        "default": {
            "format": "%(asctime)s [%(levelname)s] [%(name)s] %(message)s",
        },
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "formatter": "default",
        },
    },
    "root": {
        "level": "DEBUG",
        "handlers": ["console"],
    },
}


def make_asgi_application(
    ioc: AsyncContainer,
) -> FastAPI:
    """Build the FastAPI application and attach the DI container to it.

    Args:
        ioc: dishka async container used to resolve route dependencies.

    Returns:
        A FastAPI app with CORS middleware, the healthcheck/embed/predict
        routers, and dishka integration installed.
    """
    app = FastAPI(
        docs_url="/docs",
        title="ML Service",
        description="ML Service API",
        version="1.0.0",
        openapi_url="/openapi.json",
    )
    # NOTE(review): wildcard origins combined with allow_credentials=True is
    # very permissive (Starlette reflects the caller's Origin for credentialed
    # requests). Confirm this is intended outside local development.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    app.include_router(healthcheck.router)
    app.include_router(embed.router)
    app.include_router(predict.router)
    setup_dishka(container=ioc, app=app)
    return app


async def _main(
    configuration_path: Path,
) -> None:
    """Load configuration, build the app, and serve it until uvicorn exits.

    The DI container is always closed on the way out. That includes the case
    where building the application or the uvicorn config raises before the
    server starts.

    Args:
        configuration_path: path handed to ``load_configuration``.
    """
    # Raise these two loggers to WARNING so their lower-level records are
    # dropped even though the root logger in LOG_CONFIG is at DEBUG.
    logging.getLogger("httpx").setLevel(logging.WARNING)
    logging.getLogger("httpcore").setLevel(logging.WARNING)

    configuration = load_configuration(configuration_path)
    ioc = make_ioc(configuration)
    # The try starts right after the container exists. Failures in app or
    # config construction must still release it, not only failures in serve().
    try:
        asgi_application = make_asgi_application(ioc)
        config = uvicorn.Config(
            app=asgi_application,
            host=configuration.server.host,
            port=configuration.server.port,
            log_config=LOG_CONFIG,
            access_log=configuration.server.access_log,
        )
        server = uvicorn.Server(config)
        await server.serve()
    finally:
        await ioc.close()


def main() -> None:
    """Parse the command line and run the service.

    The configuration path comes from the optional positional argument or,
    if that is absent or empty, from the ``CONFIGURATION_PATH`` environment
    variable.

    Raises:
        RuntimeError: if neither source provides a non-empty path.
    """
    if sys.platform == "win32":
        # NOTE(review): presumably some dependency needs the selector loop
        # rather than the default Proactor loop on Windows. Confirm.
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

    arg_parser = argparse.ArgumentParser()
    # nargs="?" makes the positional optional. Without it, argparse exits when
    # the argument is missing, so the environment fallback below could never run.
    arg_parser.add_argument(
        "configuration",
        nargs="?",
        default=None,
        help="path to the configuration file (falls back to $CONFIGURATION_PATH)",
    )
    args = arg_parser.parse_args()

    configuration_path = args.configuration or os.getenv("CONFIGURATION_PATH")
    # Reject an empty env value too; Path("") would silently resolve to ".".
    if not configuration_path:
        raise RuntimeError(
            "pass the path to the config or specify it in the environment variables `CONFIGURATION_PATH`",
        )

    asyncio.run(_main(Path(configuration_path)))


if __name__ == "__main__":
    main()