You've already forked RekomenciBackend
110 lines
2.7 KiB
Python
110 lines
2.7 KiB
Python
import argparse
|
|
import asyncio
|
|
import logging
|
|
import os
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import Final
|
|
|
|
import uvicorn
|
|
from dishka import AsyncContainer
|
|
from dishka.integrations.fastapi import setup_dishka
|
|
from fastapi import FastAPI
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
from template_project.ml.configuration import load_configuration
|
|
from template_project.ml.ioc.make import make_ioc
|
|
from template_project.ml.routes import embed, healthcheck, predict
|
|
|
|
# logging.config.dictConfig schema that is passed to uvicorn as ``log_config``:
# a single console StreamHandler attached to the root logger.
LOG_CONFIG: Final = {
    "version": 1,
    # Loggers that already exist when this config is applied stay enabled.
    "disable_existing_loggers": False,
    "formatters": {
        "default": {"format": "%(asctime)s [%(levelname)s] [%(name)s] %(message)s"},
    },
    "handlers": {
        "console": {"class": "logging.StreamHandler", "formatter": "default"},
    },
    "root": {"level": "DEBUG", "handlers": ["console"]},
}
|
|
|
|
|
|
def make_asgi_application(
    ioc: AsyncContainer,
) -> FastAPI:
    """Build the FastAPI application for the ML service.

    Adds CORS middleware, mounts the healthcheck, embed and predict routers
    (in that order) and attaches the dishka container via ``setup_dishka``.

    Args:
        ioc: The dishka async container handed to ``setup_dishka``.

    Returns:
        The configured FastAPI instance.
    """
    application = FastAPI(
        title="ML Service",
        description="ML Service API",
        version="1.0.0",
        docs_url="/docs",
        openapi_url="/openapi.json",
    )

    # NOTE(review): wildcard origins combined with allow_credentials=True makes
    # Starlette reflect the request's Origin header back, i.e. credentialed
    # cross-site requests are accepted from any site (verify for the installed
    # Starlette version). Confirm this is intended for this service.
    application.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    for routes_module in (healthcheck, embed, predict):
        application.include_router(routes_module.router)

    setup_dishka(container=ioc, app=application)
    return application
|
|
|
|
|
|
async def _main(
    configuration_path: Path,
) -> None:
    """Load configuration, build the ASGI app and serve it with uvicorn.

    The dishka container is closed when the server stops. It is also closed
    if any step between creating the container and starting the server raises.

    Args:
        configuration_path: Path passed to ``load_configuration``.
    """
    # LOG_CONFIG puts the root logger at DEBUG; raise these two client
    # libraries to WARNING so their lower-level records are dropped.
    logging.getLogger("httpx").setLevel(logging.WARNING)
    logging.getLogger("httpcore").setLevel(logging.WARNING)

    configuration = load_configuration(configuration_path)
    ioc = make_ioc(configuration)
    # Everything after the container exists runs under try/finally, so a failure
    # while building the app, the uvicorn config (which applies LOG_CONFIG) or
    # the server object does not leave the container open.
    try:
        asgi_application = make_asgi_application(ioc)
        config = uvicorn.Config(
            app=asgi_application,
            host=configuration.server.host,
            port=configuration.server.port,
            log_config=LOG_CONFIG,
            access_log=configuration.server.access_log,
        )
        server = uvicorn.Server(config)
        await server.serve()
    finally:
        await ioc.close()
|
|
|
|
|
|
def main() -> None:
    """Command-line entry point for the ML service.

    The configuration path comes from the optional positional argument. If the
    argument is absent, it falls back to the ``CONFIGURATION_PATH`` environment
    variable.

    Raises:
        RuntimeError: If neither the argument nor the environment variable
            provides a non-empty configuration path.
    """
    if sys.platform == "win32":
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

    arg_parser = argparse.ArgumentParser()
    # nargs="?" makes the positional optional. Without it argparse exits with a
    # usage error before the CONFIGURATION_PATH fallback below is ever reached.
    arg_parser.add_argument(
        "configuration",
        nargs="?",
        default=None,
        help="path to the configuration file (defaults to $CONFIGURATION_PATH)",
    )
    args = arg_parser.parse_args()

    configuration_path = args.configuration or os.getenv("CONFIGURATION_PATH")
    # An empty value counts as missing; otherwise Path("") would silently mean ".".
    if not configuration_path:
        raise RuntimeError(
            "pass the path to the config or specify it in the environment variables `CONFIGURATION_PATH`",
        )

    asyncio.run(_main(Path(configuration_path)))


if __name__ == "__main__":
    main()
|