You've already forked RekomenciBackend
feat(): routes and entrypoint for ml
This commit is contained in:
@@ -58,6 +58,7 @@ dev = [
|
|||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
web_api_cli = "template_project.web_api.entry_point:main"
|
web_api_cli = "template_project.web_api.entry_point:main"
|
||||||
|
ml_api_cli = "template_project.ml.entry_point:main"
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
testpaths = ["tests"]
|
testpaths = ["tests"]
|
||||||
|
|||||||
@@ -0,0 +1,37 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from tomllib import loads
|
||||||
|
from typing import dataclass_transform
|
||||||
|
|
||||||
|
from adaptix import Retort
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass_transform(frozen_default=True)
|
||||||
|
def to_configuration[ClsT](cls: type[ClsT]) -> type[ClsT]:
|
||||||
|
return dataclass(frozen=True, slots=True, repr=False)(cls)
|
||||||
|
|
||||||
|
|
||||||
|
@to_configuration
|
||||||
|
class ServerConfiguration:
|
||||||
|
host: str
|
||||||
|
port: int
|
||||||
|
access_log: bool
|
||||||
|
|
||||||
|
@property
|
||||||
|
def url(self) -> str:
|
||||||
|
return f"http://{self.host}:{self.port}"
|
||||||
|
|
||||||
|
|
||||||
|
@to_configuration
|
||||||
|
class Configuration:
|
||||||
|
server: ServerConfiguration
|
||||||
|
|
||||||
|
|
||||||
|
retort = Retort()
|
||||||
|
|
||||||
|
|
||||||
|
def load_configuration(path: Path) -> Configuration:
|
||||||
|
with path.open("r", encoding="utf-8") as config:
|
||||||
|
data = loads(config.read())
|
||||||
|
|
||||||
|
return retort.load(data, Configuration)
|
||||||
@@ -0,0 +1,107 @@
|
|||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Final
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
|
from dishka import AsyncContainer
|
||||||
|
from dishka.integrations.fastapi import setup_dishka
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from template_project.ml.configuration import load_configuration
|
||||||
|
from template_project.ml.ioc.make import make_ioc
|
||||||
|
from template_project.ml.routes import embedding
|
||||||
|
|
||||||
|
LOG_CONFIG: Final = {
|
||||||
|
"version": 1,
|
||||||
|
"disable_existing_loggers": False,
|
||||||
|
"formatters": {
|
||||||
|
"default": {
|
||||||
|
"format": "%(asctime)s [%(levelname)s] [%(name)s] %(message)s",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"handlers": {
|
||||||
|
"console": {
|
||||||
|
"class": "logging.StreamHandler",
|
||||||
|
"formatter": "default",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"root": {
|
||||||
|
"level": "DEBUG",
|
||||||
|
"handlers": ["console"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def make_asgi_application(
|
||||||
|
ioc: AsyncContainer,
|
||||||
|
) -> FastAPI:
|
||||||
|
app = FastAPI(
|
||||||
|
docs_url="/docs",
|
||||||
|
title="ML Service",
|
||||||
|
description="ML Service API",
|
||||||
|
version="1.0.0",
|
||||||
|
openapi_url="/openapi.json",
|
||||||
|
)
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
app.include_router(embedding.router)
|
||||||
|
|
||||||
|
setup_dishka(container=ioc, app=app)
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
async def _main(
|
||||||
|
configuration_path: Path,
|
||||||
|
) -> None:
|
||||||
|
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("httpcore").setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
configuration = load_configuration(configuration_path)
|
||||||
|
ioc = make_ioc(configuration)
|
||||||
|
asgi_application = make_asgi_application(ioc)
|
||||||
|
|
||||||
|
config = uvicorn.Config(
|
||||||
|
app=asgi_application,
|
||||||
|
host=configuration.server.host,
|
||||||
|
port=configuration.server.port,
|
||||||
|
log_config=LOG_CONFIG,
|
||||||
|
access_log=configuration.server.access_log,
|
||||||
|
)
|
||||||
|
server = uvicorn.Server(config)
|
||||||
|
try:
|
||||||
|
await server.serve()
|
||||||
|
finally:
|
||||||
|
await ioc.close()
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
if sys.platform == "win32":
|
||||||
|
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||||
|
|
||||||
|
arg_parser = argparse.ArgumentParser()
|
||||||
|
arg_parser.add_argument("configuration", default=None)
|
||||||
|
|
||||||
|
args = arg_parser.parse_args()
|
||||||
|
configuration_path = args.configuration or os.getenv("CONFIGURATION_PATH")
|
||||||
|
|
||||||
|
if configuration_path is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"pass the path to the config or specify it in the environment variables `CONFIGURATION_PATH`",
|
||||||
|
)
|
||||||
|
|
||||||
|
asyncio.run(_main(Path(configuration_path)))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
from dishka import BaseScope, Provider, Scope, provide
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
from template_project.adapters.embedding.minilm_embedder import MiniLMEmbedder
|
||||||
|
from template_project.application.common.embedding import Embedder
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingProvider(Provider):
|
||||||
|
scope: BaseScope | None = Scope.APP
|
||||||
|
|
||||||
|
@provide(scope=Scope.APP)
|
||||||
|
def embedder(self, model: SentenceTransformer) -> Embedder:
|
||||||
|
return MiniLMEmbedder(model=model)
|
||||||
@@ -0,0 +1,19 @@
|
|||||||
|
from dishka import STRICT_VALIDATION, AsyncContainer, make_async_container
|
||||||
|
from dishka.integrations.fastapi import FastapiProvider
|
||||||
|
|
||||||
|
from template_project.ml.configuration import Configuration, ServerConfiguration
|
||||||
|
from template_project.ml.ioc.embedding import EmbeddingProvider
|
||||||
|
from template_project.ml.ioc.model import ModelProvider
|
||||||
|
|
||||||
|
|
||||||
|
def make_ioc(configuration: Configuration) -> AsyncContainer:
|
||||||
|
return make_async_container(
|
||||||
|
ModelProvider(),
|
||||||
|
EmbeddingProvider(),
|
||||||
|
FastapiProvider(),
|
||||||
|
validation_settings=STRICT_VALIDATION,
|
||||||
|
context={
|
||||||
|
ServerConfiguration: configuration.server,
|
||||||
|
Configuration: configuration,
|
||||||
|
},
|
||||||
|
)
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
from dishka import BaseScope, Provider, Scope, provide
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
|
||||||
|
class ModelProvider(Provider):
|
||||||
|
scope: BaseScope | None = Scope.APP
|
||||||
|
|
||||||
|
@provide(scope=Scope.APP)
|
||||||
|
def sentence_transformer_model(self) -> SentenceTransformer:
|
||||||
|
return SentenceTransformer("all-MiniLM-L6-v2")
|
||||||
@@ -0,0 +1,44 @@
|
|||||||
|
from dishka import FromDishka
|
||||||
|
from dishka.integrations.fastapi import DishkaRoute
|
||||||
|
from fastapi import APIRouter
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from template_project.application.common.embedding import Embedder
|
||||||
|
|
||||||
|
router = APIRouter(route_class=DishkaRoute, tags=["Embedding"])
|
||||||
|
|
||||||
|
|
||||||
|
class GetEmbeddingRequest(BaseModel):
|
||||||
|
text: str = Field(
|
||||||
|
..., min_length=1, description="Text to encode", examples=["python backend developer with django"]
|
||||||
|
)
|
||||||
|
|
||||||
|
model_config = {"json_schema_extra": {"example": {"text": "python backend developer with django"}}}
|
||||||
|
|
||||||
|
|
||||||
|
class GetEmbeddingResponse(BaseModel):
|
||||||
|
embedding: list[float] = Field(..., description="Embedding vector")
|
||||||
|
|
||||||
|
model_config = {
|
||||||
|
"json_schema_extra": {
|
||||||
|
"example": {
|
||||||
|
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/get_embedding",
|
||||||
|
summary="Get embedding",
|
||||||
|
description="Encode text into embedding vector",
|
||||||
|
responses={
|
||||||
|
200: {"description": "Embedding generated successfully", "model": GetEmbeddingResponse},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def get_embedding(
|
||||||
|
request: GetEmbeddingRequest,
|
||||||
|
embedder: FromDishka[Embedder],
|
||||||
|
) -> GetEmbeddingResponse:
|
||||||
|
embedding = await embedder.encode(request.text)
|
||||||
|
return GetEmbeddingResponse(embedding=embedding)
|
||||||
Reference in New Issue
Block a user