From 099fecc218a4b67b20e4abd0b328bd0e235545e3 Mon Sep 17 00:00:00 2001 From: gitgernit Date: Sat, 22 Nov 2025 11:33:15 +0300 Subject: [PATCH] feat(): routes and entrypoint for ml --- pyproject.toml | 1 + src/template_project/ml/__init__.py | 0 src/template_project/ml/configuration.py | 37 +++++++ src/template_project/ml/entry_point.py | 107 ++++++++++++++++++++ src/template_project/ml/ioc/__init__.py | 0 src/template_project/ml/ioc/embedding.py | 13 +++ src/template_project/ml/ioc/make.py | 19 ++++ src/template_project/ml/ioc/model.py | 10 ++ src/template_project/ml/routes/__init__.py | 0 src/template_project/ml/routes/embedding.py | 44 ++++++++ 10 files changed, 231 insertions(+) create mode 100644 src/template_project/ml/__init__.py create mode 100644 src/template_project/ml/configuration.py create mode 100644 src/template_project/ml/entry_point.py create mode 100644 src/template_project/ml/ioc/__init__.py create mode 100644 src/template_project/ml/ioc/embedding.py create mode 100644 src/template_project/ml/ioc/make.py create mode 100644 src/template_project/ml/ioc/model.py create mode 100644 src/template_project/ml/routes/__init__.py create mode 100644 src/template_project/ml/routes/embedding.py diff --git a/pyproject.toml b/pyproject.toml index b8fa3c1..2b4d410 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,7 @@ dev = [ [project.scripts] web_api_cli = "template_project.web_api.entry_point:main" +ml_api_cli = "template_project.ml.entry_point:main" [tool.pytest.ini_options] testpaths = ["tests"] diff --git a/src/template_project/ml/__init__.py b/src/template_project/ml/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/template_project/ml/configuration.py b/src/template_project/ml/configuration.py new file mode 100644 index 0000000..ad88ca3 --- /dev/null +++ b/src/template_project/ml/configuration.py @@ -0,0 +1,37 @@ +from dataclasses import dataclass +from pathlib import Path +from tomllib import loads +from typing import 
dataclass_transform
+
+from adaptix import Retort
+
+
+@dataclass_transform(frozen_default=True)
+def to_configuration[ClsT](cls: type[ClsT]) -> type[ClsT]:  # frozen, slotted dataclass without a generated repr
+    return dataclass(frozen=True, slots=True, repr=False)(cls)
+
+
+@to_configuration
+class ServerConfiguration:
+    host: str
+    port: int
+    access_log: bool
+
+    @property
+    def url(self) -> str:
+        return f"http://{self.host}:{self.port}"
+
+
+@to_configuration
+class Configuration:
+    server: ServerConfiguration
+
+
+retort = Retort()
+
+
+def load_configuration(path: Path) -> Configuration:  # parse the TOML file at `path` into a Configuration
+    with path.open("r", encoding="utf-8") as config:
+        data = loads(config.read())
+
+    return retort.load(data, Configuration)
diff --git a/src/template_project/ml/entry_point.py b/src/template_project/ml/entry_point.py
new file mode 100644
index 0000000..39888dd
--- /dev/null
+++ b/src/template_project/ml/entry_point.py
@@ -0,0 +1,131 @@
+import argparse
+import asyncio
+import logging
+import os
+import sys
+from pathlib import Path
+from typing import Final
+
+import uvicorn
+from dishka import AsyncContainer
+from dishka.integrations.fastapi import setup_dishka
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+
+from template_project.ml.configuration import load_configuration
+from template_project.ml.ioc.make import make_ioc
+from template_project.ml.routes import embedding
+
+# Logging configuration passed to uvicorn: one console handler on the root logger.
+LOG_CONFIG: Final = {
+    "version": 1,
+    "disable_existing_loggers": False,
+    "formatters": {
+        "default": {
+            "format": "%(asctime)s [%(levelname)s] [%(name)s] %(message)s",
+        },
+    },
+    "handlers": {
+        "console": {
+            "class": "logging.StreamHandler",
+            "formatter": "default",
+        },
+    },
+    "root": {
+        "level": "DEBUG",
+        "handlers": ["console"],
+    },
+}
+
+
+def make_asgi_application(
+    ioc: AsyncContainer,
+) -> FastAPI:
+    """Build the FastAPI application and attach the dishka container to it.
+
+    The CORS middleware and the embedding router are registered first, then
+    the container is wired in with ``setup_dishka``.
+    """
+    app = FastAPI(
+        docs_url="/docs",
+        title="ML Service",
+        description="ML Service API",
+        version="1.0.0",
+        openapi_url="/openapi.json",
+    )
+    # NOTE(review): wildcard origins combined with allow_credentials=True is
+    # very permissive; list explicit origins if cookies or auth headers are
+    # ever used by clients of this service -- confirm the intended policy.
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=["*"],
+        allow_credentials=True,
+        allow_methods=["*"],
+        allow_headers=["*"],
+    )
+    app.include_router(embedding.router)
+
+    setup_dishka(container=ioc, app=app)
+
+    return app
+
+
+async def _main(
+    configuration_path: Path,
+) -> None:
+    """Load the configuration, build the application and serve it with uvicorn.
+
+    The container is closed when the server stops, even if serving raises.
+    """
+    # The root logger in LOG_CONFIG is DEBUG; keep httpx/httpcore at WARNING.
+    logging.getLogger("httpx").setLevel(logging.WARNING)
+    logging.getLogger("httpcore").setLevel(logging.WARNING)
+
+    configuration = load_configuration(configuration_path)
+    ioc = make_ioc(configuration)
+    asgi_application = make_asgi_application(ioc)
+
+    config = uvicorn.Config(
+        app=asgi_application,
+        host=configuration.server.host,
+        port=configuration.server.port,
+        log_config=LOG_CONFIG,
+        access_log=configuration.server.access_log,
+    )
+    server = uvicorn.Server(config)
+    try:
+        await server.serve()
+    finally:
+        await ioc.close()
+
+
+def main() -> None:
+    """Command-line entry point of the ML service.
+
+    The configuration path is read from the optional positional argument,
+    falling back to the ``CONFIGURATION_PATH`` environment variable.
+
+    Raises:
+        RuntimeError: if neither source provides a non-empty path.
+    """
+    if sys.platform == "win32":
+        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
+
+    arg_parser = argparse.ArgumentParser()
+    # nargs="?" makes the positional optional; without it argparse exits with
+    # an error before the environment-variable fallback below can be reached.
+    arg_parser.add_argument("configuration", nargs="?", default=None)
+
+    args = arg_parser.parse_args()
+    configuration_path = args.configuration or os.getenv("CONFIGURATION_PATH")
+
+    if not configuration_path:
+        raise RuntimeError(
+            "pass the path to the config or specify it in the environment variables `CONFIGURATION_PATH`",
+        )
+
+    asyncio.run(_main(Path(configuration_path)))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/template_project/ml/ioc/__init__.py b/src/template_project/ml/ioc/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/template_project/ml/ioc/embedding.py b/src/template_project/ml/ioc/embedding.py
new file mode 100644
index 0000000..92a7c5f
--- /dev/null
+++ b/src/template_project/ml/ioc/embedding.py
@@ -0,0 +1,13 @@
+from dishka import BaseScope, Provider, Scope, provide
+from sentence_transformers import SentenceTransformer
+
+from template_project.adapters.embedding.minilm_embedder import MiniLMEmbedder
+from template_project.application.common.embedding import
Embedder
+
+
+class EmbeddingProvider(Provider):  # exposes MiniLMEmbedder as the Embedder
+    scope: BaseScope | None = Scope.APP
+
+    @provide(scope=Scope.APP)
+    def embedder(self, model: SentenceTransformer) -> Embedder:
+        return MiniLMEmbedder(model=model)
diff --git a/src/template_project/ml/ioc/make.py b/src/template_project/ml/ioc/make.py
new file mode 100644
index 0000000..06215f3
--- /dev/null
+++ b/src/template_project/ml/ioc/make.py
@@ -0,0 +1,24 @@
+from dishka import STRICT_VALIDATION, AsyncContainer, make_async_container
+from dishka.integrations.fastapi import FastapiProvider
+
+from template_project.ml.configuration import Configuration, ServerConfiguration
+from template_project.ml.ioc.embedding import EmbeddingProvider
+from template_project.ml.ioc.model import ModelProvider
+
+
+def make_ioc(configuration: Configuration) -> AsyncContainer:
+    """Assemble the async dishka container used by the ML service.
+
+    The whole configuration and its ``server`` section are handed to the
+    container as context values keyed by their types.
+    """
+    providers = (ModelProvider(), EmbeddingProvider(), FastapiProvider())
+    context_values = {
+        ServerConfiguration: configuration.server,
+        Configuration: configuration,
+    }
+    return make_async_container(
+        *providers,
+        validation_settings=STRICT_VALIDATION,
+        context=context_values,
+    )
diff --git a/src/template_project/ml/ioc/model.py b/src/template_project/ml/ioc/model.py
new file mode 100644
index 0000000..4eda055
--- /dev/null
+++ b/src/template_project/ml/ioc/model.py
@@ -0,0 +1,14 @@
+from dishka import BaseScope, Provider, Scope, provide
+from sentence_transformers import SentenceTransformer
+
+
+class ModelProvider(Provider):
+    """Provides the sentence-transformers model at application scope."""
+
+    scope: BaseScope | None = Scope.APP
+
+    @provide(scope=Scope.APP)
+    def sentence_transformer_model(self) -> SentenceTransformer:
+        """Instantiate the ``all-MiniLM-L6-v2`` model."""
+        model_name = "all-MiniLM-L6-v2"
+        return SentenceTransformer(model_name)
diff --git a/src/template_project/ml/routes/__init__.py b/src/template_project/ml/routes/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/template_project/ml/routes/embedding.py b/src/template_project/ml/routes/embedding.py
new file mode 100644
index 0000000..523bf22
--- /dev/null
+++ b/src/template_project/ml/routes/embedding.py
@@ -0,0 +1,44 @@
+from dishka import FromDishka
+from dishka.integrations.fastapi import DishkaRoute
+from fastapi import APIRouter
+from pydantic import BaseModel, Field
+
+from template_project.application.common.embedding import Embedder
+
+router = APIRouter(route_class=DishkaRoute, tags=["Embedding"])
+
+
+# Request body of POST /get_embedding: the non-empty text to encode.
+class GetEmbeddingRequest(BaseModel):
+    text: str = Field(
+        ...,
+        min_length=1,
+        description="Text to encode",
+        examples=["python backend developer with django"],
+    )
+
+    model_config = {"json_schema_extra": {"example": {"text": "python backend developer with django"}}}
+
+
+# Response body of POST /get_embedding: the embedding as a list of floats.
+class GetEmbeddingResponse(BaseModel):
+    embedding: list[float] = Field(..., description="Embedding vector")
+
+    model_config = {"json_schema_extra": {"example": {"embedding": [0.1, 0.2, 0.3, 0.4, 0.5]}}}
+
+
+@router.post(
+    "/get_embedding",
+    summary="Get embedding",
+    description="Encode text into embedding vector",
+    responses={
+        200: {"description": "Embedding generated successfully", "model": GetEmbeddingResponse},
+    },
+)
+async def get_embedding(
+    request: GetEmbeddingRequest,
+    embedder: FromDishka[Embedder],
+) -> GetEmbeddingResponse:
+    # Encode the request text with the injected embedder and wrap the vector.
+    vector = await embedder.encode(request.text)
+    return GetEmbeddingResponse(embedding=vector)