feat(): routes and entrypoint for ml

This commit is contained in:
gitgernit
2025-11-22 11:33:15 +03:00
parent 8c76413c3b
commit 099fecc218
10 changed files with 231 additions and 0 deletions
View File
+37
View File
@@ -0,0 +1,37 @@
from dataclasses import dataclass
from pathlib import Path
from tomllib import loads
from typing import dataclass_transform
from adaptix import Retort
@dataclass_transform(frozen_default=True)
def to_configuration[ClsT](cls: type[ClsT]) -> type[ClsT]:
return dataclass(frozen=True, slots=True, repr=False)(cls)
@to_configuration
class ServerConfiguration:
host: str
port: int
access_log: bool
@property
def url(self) -> str:
return f"http://{self.host}:{self.port}"
@to_configuration
class Configuration:
server: ServerConfiguration
retort = Retort()
def load_configuration(path: Path) -> Configuration:
with path.open("r", encoding="utf-8") as config:
data = loads(config.read())
return retort.load(data, Configuration)
+107
View File
@@ -0,0 +1,107 @@
import argparse
import asyncio
import logging
import os
import sys
from pathlib import Path
from typing import Final
import uvicorn
from dishka import AsyncContainer
from dishka.integrations.fastapi import setup_dishka
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from template_project.ml.configuration import load_configuration
from template_project.ml.ioc.make import make_ioc
from template_project.ml.routes import embedding
LOG_CONFIG: Final = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"default": {
"format": "%(asctime)s [%(levelname)s] [%(name)s] %(message)s",
},
},
"handlers": {
"console": {
"class": "logging.StreamHandler",
"formatter": "default",
},
},
"root": {
"level": "DEBUG",
"handlers": ["console"],
},
}
def make_asgi_application(
ioc: AsyncContainer,
) -> FastAPI:
app = FastAPI(
docs_url="/docs",
title="ML Service",
description="ML Service API",
version="1.0.0",
openapi_url="/openapi.json",
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.include_router(embedding.router)
setup_dishka(container=ioc, app=app)
return app
async def _main(
configuration_path: Path,
) -> None:
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING)
configuration = load_configuration(configuration_path)
ioc = make_ioc(configuration)
asgi_application = make_asgi_application(ioc)
config = uvicorn.Config(
app=asgi_application,
host=configuration.server.host,
port=configuration.server.port,
log_config=LOG_CONFIG,
access_log=configuration.server.access_log,
)
server = uvicorn.Server(config)
try:
await server.serve()
finally:
await ioc.close()
def main() -> None:
if sys.platform == "win32":
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument("configuration", default=None)
args = arg_parser.parse_args()
configuration_path = args.configuration or os.getenv("CONFIGURATION_PATH")
if configuration_path is None:
raise RuntimeError(
"pass the path to the config or specify it in the environment variables `CONFIGURATION_PATH`",
)
asyncio.run(_main(Path(configuration_path)))
if __name__ == "__main__":
main()
+13
View File
@@ -0,0 +1,13 @@
from dishka import BaseScope, Provider, Scope, provide
from sentence_transformers import SentenceTransformer
from template_project.adapters.embedding.minilm_embedder import MiniLMEmbedder
from template_project.application.common.embedding import Embedder
class EmbeddingProvider(Provider):
scope: BaseScope | None = Scope.APP
@provide(scope=Scope.APP)
def embedder(self, model: SentenceTransformer) -> Embedder:
return MiniLMEmbedder(model=model)
+19
View File
@@ -0,0 +1,19 @@
from dishka import STRICT_VALIDATION, AsyncContainer, make_async_container
from dishka.integrations.fastapi import FastapiProvider
from template_project.ml.configuration import Configuration, ServerConfiguration
from template_project.ml.ioc.embedding import EmbeddingProvider
from template_project.ml.ioc.model import ModelProvider
def make_ioc(configuration: Configuration) -> AsyncContainer:
return make_async_container(
ModelProvider(),
EmbeddingProvider(),
FastapiProvider(),
validation_settings=STRICT_VALIDATION,
context={
ServerConfiguration: configuration.server,
Configuration: configuration,
},
)
+10
View File
@@ -0,0 +1,10 @@
from dishka import BaseScope, Provider, Scope, provide
from sentence_transformers import SentenceTransformer
class ModelProvider(Provider):
scope: BaseScope | None = Scope.APP
@provide(scope=Scope.APP)
def sentence_transformer_model(self) -> SentenceTransformer:
return SentenceTransformer("all-MiniLM-L6-v2")
@@ -0,0 +1,44 @@
from dishka import FromDishka
from dishka.integrations.fastapi import DishkaRoute
from fastapi import APIRouter
from pydantic import BaseModel, Field
from template_project.application.common.embedding import Embedder
router = APIRouter(route_class=DishkaRoute, tags=["Embedding"])
class GetEmbeddingRequest(BaseModel):
text: str = Field(
..., min_length=1, description="Text to encode", examples=["python backend developer with django"]
)
model_config = {"json_schema_extra": {"example": {"text": "python backend developer with django"}}}
class GetEmbeddingResponse(BaseModel):
embedding: list[float] = Field(..., description="Embedding vector")
model_config = {
"json_schema_extra": {
"example": {
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
}
}
}
@router.post(
"/get_embedding",
summary="Get embedding",
description="Encode text into embedding vector",
responses={
200: {"description": "Embedding generated successfully", "model": GetEmbeddingResponse},
},
)
async def get_embedding(
request: GetEmbeddingRequest,
embedder: FromDishka[Embedder],
) -> GetEmbeddingResponse:
embedding = await embedder.encode(request.text)
return GetEmbeddingResponse(embedding=embedding)