You've already forked RekomenciBackend
Merge branch 'ml'
This commit is contained in:
@@ -0,0 +1,15 @@
|
||||
from typing import cast, override
|
||||
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
from template_project.application.common.embedding import Embedder
|
||||
|
||||
|
||||
class MiniLMEmbedder(Embedder):
    """``Embedder`` implementation backed by a ``SentenceTransformer`` model.

    The model instance is supplied by the caller; this class only adapts its
    synchronous ``encode`` call to the asynchronous ``Embedder`` interface.
    """

    def __init__(self, model: SentenceTransformer) -> None:
        """Store the model used to produce embeddings."""
        self._model = model

    @override
    async def encode(self, text: str) -> list[float]:
        """Encode ``text`` and return the embedding as a list of floats.

        The underlying model call is synchronous, so it is executed in a
        worker thread; calling it directly inside this coroutine would block
        the event loop for the whole duration of the inference.
        """
        import asyncio  # local import: only this method needs it

        embedding = await asyncio.to_thread(self._model.encode, text)
        # ``tolist()`` turns the model output into plain Python values.
        return cast(list[float], embedding.tolist())
|
||||
@@ -0,0 +1,8 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class Embedder(Protocol):
    """Structural interface for objects that turn text into a vector."""

    @abstractmethod
    async def encode(self, text: str) -> list[float]:
        """Return the embedding of ``text`` as a list of floats.

        Implementations are expected to override this method; the body
        provided here always raises ``NotImplementedError``.
        """
        raise NotImplementedError
|
||||
@@ -0,0 +1,37 @@
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from tomllib import loads
|
||||
from typing import dataclass_transform
|
||||
|
||||
from adaptix import Retort
|
||||
|
||||
|
||||
@dataclass_transform(frozen_default=True)
|
||||
def to_configuration[ClsT](cls: type[ClsT]) -> type[ClsT]:
|
||||
return dataclass(frozen=True, slots=True, repr=False)(cls)
|
||||
|
||||
|
||||
@to_configuration
class ServerConfiguration:
    """Settings of the HTTP server section of the configuration."""

    # Host the server is started on.
    host: str
    # Port the server is started on.
    port: int
    # Forwarded to uvicorn as its ``access_log`` option.
    access_log: bool

    @property
    def url(self) -> str:
        """Return the ``http://<host>:<port>`` URL built from this section."""
        return f"http://{self.host}:{self.port}"
|
||||
|
||||
|
||||
@to_configuration
class Configuration:
    """Root configuration object of the ML service."""

    # Settings of the HTTP server.
    server: ServerConfiguration
|
||||
|
||||
|
||||
# Module-level adaptix Retort used to load parsed TOML data into ``Configuration``.
retort = Retort()
|
||||
|
||||
|
||||
def load_configuration(path: Path) -> Configuration:
    """Read the TOML file at ``path`` and load it into a ``Configuration``.

    The file is decoded as UTF-8 before being parsed.
    """
    raw_toml = path.read_text(encoding="utf-8")
    return retort.load(loads(raw_toml), Configuration)
|
||||
@@ -0,0 +1,109 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Final
|
||||
|
||||
import uvicorn
|
||||
from dishka import AsyncContainer
|
||||
from dishka.integrations.fastapi import setup_dishka
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from template_project.ml.configuration import load_configuration
|
||||
from template_project.ml.ioc.make import make_ioc
|
||||
from template_project.ml.routes import embedding, healthcheck, predict
|
||||
|
||||
# Logging configuration dictionary passed to uvicorn as ``log_config``.
LOG_CONFIG: Final = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        "default": {"format": "%(asctime)s [%(levelname)s] [%(name)s] %(message)s"},
    },
    "handlers": {
        "console": {"class": "logging.StreamHandler", "formatter": "default"},
    },
    # The root logger emits everything from DEBUG upwards to the console handler.
    "root": {"level": "DEBUG", "handlers": ["console"]},
}
|
||||
|
||||
|
||||
def make_asgi_application(
    ioc: AsyncContainer,
) -> FastAPI:
    """Build the FastAPI application of the ML service.

    Adds the CORS middleware, mounts the healthcheck, embedding and predict
    routers (in that order) and attaches the given dishka container.
    """
    application = FastAPI(
        title="ML Service",
        description="ML Service API",
        version="1.0.0",
        docs_url="/docs",
        openapi_url="/openapi.json",
    )

    # NOTE(review): wildcard origins combined with allow_credentials=True is
    # very permissive — confirm this is intended outside local development.
    application.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    for api_router in (healthcheck.router, embedding.router, predict.router):
        application.include_router(api_router)

    setup_dishka(container=ioc, app=application)
    return application
|
||||
|
||||
|
||||
async def _main(
    configuration_path: Path,
) -> None:
    """Load the configuration, build the application and serve it.

    The IoC container is closed once the server stops, including when
    serving ends with an exception.
    """
    # The root logger runs at DEBUG; keep these two loggers at WARNING.
    for logger_name in ("httpx", "httpcore"):
        logging.getLogger(logger_name).setLevel(logging.WARNING)

    configuration = load_configuration(configuration_path)
    server_settings = configuration.server
    ioc = make_ioc(configuration)

    server = uvicorn.Server(
        uvicorn.Config(
            app=make_asgi_application(ioc),
            host=server_settings.host,
            port=server_settings.port,
            log_config=LOG_CONFIG,
            access_log=server_settings.access_log,
        ),
    )
    try:
        await server.serve()
    finally:
        await ioc.close()
|
||||
|
||||
|
||||
def main() -> None:
    """Resolve the configuration path and run the server.

    The path is taken from the optional positional ``configuration``
    command-line argument and falls back to the ``CONFIGURATION_PATH``
    environment variable.

    Raises:
        RuntimeError: If neither the argument nor the environment variable
            provides a non-empty path.
    """
    if sys.platform == "win32":
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

    arg_parser = argparse.ArgumentParser()
    # ``nargs="?"`` makes the positional optional. Without it argparse exits
    # with a usage error when the argument is omitted, so the environment
    # fallback below could never be reached.
    arg_parser.add_argument("configuration", nargs="?", default=None)

    args = arg_parser.parse_args()
    configuration_path = args.configuration or os.getenv("CONFIGURATION_PATH")

    # An empty string is treated like a missing value instead of silently
    # resolving to the current directory.
    if not configuration_path:
        raise RuntimeError(
            "pass the path to the config or specify it in the environment variables `CONFIGURATION_PATH`",
        )

    asyncio.run(_main(Path(configuration_path)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,39 @@
|
||||
from decimal import Decimal
|
||||
|
||||
from template_project.application.common.data_structure import to_data_structure
|
||||
from template_project.application.common.interactor import to_interactor
|
||||
from template_project.application.resume.entity import ResumeId
|
||||
|
||||
|
||||
@to_data_structure
class VacancyInput:
    """A single vacancy passed to the salary prediction interactor."""

    vacancy_id: str
    # Lower and upper bound of the vacancy's salary.
    from_salary: Decimal
    to_salary: Decimal
    key_skills: list[str]
    # Similarity score between this vacancy and the resume.
    resume_similarity: float
|
||||
|
||||
|
||||
@to_data_structure
class PredictSalaryRequest:
    """Input of ``PredictSalaryInteractor.execute``."""

    resume_id: ResumeId
    # Skills listed in the resume.
    key_skills: list[str]
    # Vacancies the prediction is based on.
    vacancies: list[VacancyInput]
|
||||
|
||||
|
||||
@to_data_structure
class PredictSalaryResponse:
    """Result of ``PredictSalaryInteractor.execute``."""

    # Lower and upper bound of the predicted salary.
    salary_from: Decimal
    salary_to: Decimal
    recommended_skills: list[str]
|
||||
|
||||
|
||||
@to_interactor
class PredictSalaryInteractor:
    """Interactor producing a salary range and recommended skills."""

    async def execute(self, request: PredictSalaryRequest) -> PredictSalaryResponse:
        """Return the salary prediction for ``request``.

        NOTE(review): the returned values are fixed and ``request`` is not
        read — this looks like a placeholder implementation; confirm.
        """
        lower_bound = Decimal("50000")
        upper_bound = Decimal("80000")
        suggested_skills = ["python", "django", "postgresql"]
        return PredictSalaryResponse(
            salary_from=lower_bound,
            salary_to=upper_bound,
            recommended_skills=suggested_skills,
        )
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
from dishka import BaseScope, Provider, Scope, provide
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
from template_project.adapters.embedding.minilm_embedder import MiniLMEmbedder
|
||||
from template_project.application.common.embedding import Embedder
|
||||
|
||||
|
||||
class EmbeddingProvider(Provider):
    """Dishka provider that supplies the ``Embedder`` at application scope."""

    scope: BaseScope | None = Scope.APP

    @provide(scope=Scope.APP)
    def embedder(self, model: SentenceTransformer) -> Embedder:
        """Wrap the injected ``SentenceTransformer`` in a ``MiniLMEmbedder``."""
        minilm_embedder = MiniLMEmbedder(model=model)
        return minilm_embedder
|
||||
@@ -0,0 +1,12 @@
|
||||
from dishka import BaseScope, Provider, Scope, provide_all
|
||||
|
||||
from template_project.ml.interactors.predict_salary import PredictSalaryInteractor
|
||||
|
||||
|
||||
class InteractorProvider(Provider):
    """Dishka provider that registers the interactors at request scope."""

    scope: BaseScope | None = Scope.REQUEST

    interactors = provide_all(PredictSalaryInteractor)
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
from dishka import STRICT_VALIDATION, AsyncContainer, make_async_container
|
||||
from dishka.integrations.fastapi import FastapiProvider
|
||||
|
||||
from template_project.ml.configuration import Configuration, ServerConfiguration
|
||||
from template_project.ml.ioc.embedding import EmbeddingProvider
|
||||
from template_project.ml.ioc.interactor import InteractorProvider
|
||||
from template_project.ml.ioc.model import ModelProvider
|
||||
|
||||
|
||||
def make_ioc(configuration: Configuration) -> AsyncContainer:
    """Create the async dishka container for the ML service.

    Registers the model, embedding, interactor and FastAPI providers and
    passes the configuration objects in as container context.
    """
    providers = (
        ModelProvider(),
        EmbeddingProvider(),
        InteractorProvider(),
        FastapiProvider(),
    )
    container_context = {
        ServerConfiguration: configuration.server,
        Configuration: configuration,
    }
    return make_async_container(
        *providers,
        validation_settings=STRICT_VALIDATION,
        context=container_context,
    )
|
||||
@@ -0,0 +1,10 @@
|
||||
from dishka import BaseScope, Provider, Scope, provide
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
|
||||
class ModelProvider(Provider):
    """Dishka provider that supplies the ``SentenceTransformer`` at application scope."""

    scope: BaseScope | None = Scope.APP

    @provide(scope=Scope.APP)
    def sentence_transformer_model(self) -> SentenceTransformer:
        """Instantiate the ``all-MiniLM-L6-v2`` sentence-transformers model."""
        model_name = "all-MiniLM-L6-v2"
        return SentenceTransformer(model_name)
|
||||
@@ -0,0 +1,44 @@
|
||||
from dishka import FromDishka
|
||||
from dishka.integrations.fastapi import DishkaRoute
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from template_project.application.common.embedding import Embedder
|
||||
|
||||
# Router for the embedding endpoint; DishkaRoute is used as the route class for FromDishka parameters.
router = APIRouter(route_class=DishkaRoute, tags=["Embedding"])
|
||||
|
||||
|
||||
class GetEmbeddingRequest(BaseModel):
    # Request body of POST /get_embedding: the non-empty text to encode.
    text: str = Field(
        ...,
        min_length=1,
        description="Text to encode",
        examples=["python backend developer with django"],
    )

    # Example payload attached to the generated JSON schema.
    model_config = {
        "json_schema_extra": {
            "example": {"text": "python backend developer with django"},
        },
    }
|
||||
|
||||
|
||||
class GetEmbeddingResponse(BaseModel):
    # Response body of POST /get_embedding: the embedding vector.
    embedding: list[float] = Field(..., description="Embedding vector")

    # Example payload attached to the generated JSON schema.
    model_config = {
        "json_schema_extra": {
            "example": {"embedding": [0.1, 0.2, 0.3, 0.4, 0.5]},
        },
    }
|
||||
|
||||
|
||||
@router.post(
    "/get_embedding",
    summary="Get embedding",
    description="Encode text into embedding vector",
    responses={
        200: {
            "description": "Embedding generated successfully",
            "model": GetEmbeddingResponse,
        },
    },
)
async def get_embedding(
    request: GetEmbeddingRequest,
    embedder: FromDishka[Embedder],
) -> GetEmbeddingResponse:
    """Encode the request text with the injected embedder and return the vector."""
    vector = await embedder.encode(request.text)
    return GetEmbeddingResponse(embedding=vector)
|
||||
@@ -0,0 +1,23 @@
|
||||
from dishka.integrations.fastapi import DishkaRoute
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Router for the healthcheck endpoint.
router = APIRouter(route_class=DishkaRoute, tags=["Health"])
|
||||
|
||||
|
||||
class HealthcheckResponse(BaseModel):
    # Response body of GET /healthcheck.
    ok: bool = Field(description="Service health status")

    # Example payload attached to the generated JSON schema.
    model_config = {
        "json_schema_extra": {
            "example": {"ok": True},
        },
    }
|
||||
|
||||
|
||||
@router.get(
    "/healthcheck",
    summary="Health check",
    description="Check if the service is running and healthy",
    responses={
        200: {
            "description": "Service is healthy",
            "model": HealthcheckResponse,
        },
    },
)
async def healthcheck() -> HealthcheckResponse:
    """Always report the service as healthy."""
    healthy = HealthcheckResponse(ok=True)
    return healthy
|
||||
@@ -0,0 +1,123 @@
|
||||
from decimal import Decimal
|
||||
|
||||
from dishka import FromDishka
|
||||
from dishka.integrations.fastapi import DishkaRoute
|
||||
from fastapi import APIRouter, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from template_project.application.resume.entity import ResumeId
|
||||
from template_project.ml.interactors.predict_salary import (
|
||||
PredictSalaryInteractor,
|
||||
PredictSalaryRequest,
|
||||
PredictSalaryResponse,
|
||||
VacancyInput,
|
||||
)
|
||||
|
||||
# Router for the salary prediction endpoint; DishkaRoute is used as the route class for FromDishka parameters.
router = APIRouter(route_class=DishkaRoute, tags=["Prediction"])
|
||||
|
||||
|
||||
class VacancyInputModel(BaseModel):
    # HTTP-layer representation of one vacancy in a salary prediction request.
    vacancy_id: str = Field(
        description="Vacancy ID",
        examples=["vacancy_123"],
    )
    from_salary: Decimal = Field(
        description="Minimum salary",
        examples=[Decimal(100000)],
    )
    to_salary: Decimal = Field(
        description="Maximum salary",
        examples=[Decimal(150000)],
    )
    key_skills: list[str] = Field(
        description="List of key skills",
        examples=[["Python", "FastAPI", "PostgreSQL"]],
    )
    # Constrained to the closed interval [0.0, 1.0].
    resume_similarity: float = Field(
        ge=0.0,
        le=1.0,
        description="Resume similarity score (0.0 to 1.0)",
        examples=[0.85],
    )

    # Example payload attached to the generated JSON schema.
    model_config = {
        "json_schema_extra": {
            "example": {
                "vacancy_id": "vacancy_123",
                "from_salary": "100000",
                "to_salary": "150000",
                "key_skills": ["Python", "FastAPI", "PostgreSQL"],
                "resume_similarity": 0.85,
            },
        },
    }
|
||||
|
||||
|
||||
class PredictSalaryRequestModel(BaseModel):
    # Request body of POST /predict_salary.
    resume_id: ResumeId = Field(
        description="Resume ID",
        examples=["01234567-89ab-cdef-0123-456789abcdef"],
    )
    key_skills: list[str] = Field(
        min_length=1,
        description="List of key skills from resume",
        examples=[["Python", "FastAPI", "PostgreSQL"]],
    )
    # The field requires at least one vacancy, so the documented example must
    # contain one as well; an empty list would fail this model's validation.
    vacancies: list[VacancyInputModel] = Field(
        min_length=1,
        description="List of relevant vacancies",
        examples=[
            [
                {
                    "vacancy_id": "vacancy_123",
                    "from_salary": "100000",
                    "to_salary": "150000",
                    "key_skills": ["Python", "FastAPI", "PostgreSQL", "Docker"],
                    "resume_similarity": 0.85,
                },
            ],
        ],
    )

    # Example payload attached to the generated JSON schema.
    model_config = {
        "json_schema_extra": {
            "example": {
                "resume_id": "01234567-89ab-cdef-0123-456789abcdef",
                "key_skills": ["Python", "FastAPI", "PostgreSQL"],
                "vacancies": [
                    {
                        "vacancy_id": "vacancy_123",
                        "from_salary": "100000",
                        "to_salary": "150000",
                        "key_skills": ["Python", "FastAPI", "PostgreSQL", "Docker"],
                        "resume_similarity": 0.85,
                    }
                ],
            }
        }
    }
|
||||
|
||||
|
||||
class PredictSalaryResponseModel(BaseModel):
    # Response body of POST /predict_salary.
    salary_from: Decimal = Field(
        description="Minimum predicted salary",
        examples=[Decimal(100000)],
    )
    salary_to: Decimal = Field(
        description="Maximum predicted salary",
        examples=[Decimal(150000)],
    )
    recommended_skills: list[str] = Field(
        description="Top 3 recommended skills",
        examples=[["Kubernetes", "Redis", "Docker"]],
    )

    # Example payload attached to the generated JSON schema.
    model_config = {
        "json_schema_extra": {
            "example": {
                "salary_from": "100000",
                "salary_to": "150000",
                "recommended_skills": ["Kubernetes", "Redis", "Docker"],
            },
        },
    }
|
||||
|
||||
|
||||
@router.post(
    "/predict_salary",
    summary="Predict salary",
    description="Predict salary range and recommend skills based on resume and relevant vacancies",
    responses={
        200: {
            "description": "Salary prediction generated successfully",
            "model": PredictSalaryResponseModel,
        },
    },
)
async def predict_salary(
    request: PredictSalaryRequestModel,
    interactor: FromDishka[PredictSalaryInteractor],
) -> PredictSalaryResponseModel:
    """Map the HTTP request to the interactor input, run it and map the result back."""
    interactor_request = PredictSalaryRequest(
        resume_id=request.resume_id,
        key_skills=request.key_skills,
        vacancies=[
            VacancyInput(
                vacancy_id=item.vacancy_id,
                from_salary=item.from_salary,
                to_salary=item.to_salary,
                key_skills=item.key_skills,
                resume_similarity=item.resume_similarity,
            )
            for item in request.vacancies
        ],
    )

    prediction = await interactor.execute(interactor_request)

    return PredictSalaryResponseModel(
        salary_from=prediction.salary_from,
        salary_to=prediction.salary_to,
        recommended_skills=prediction.recommended_skills,
    )
|
||||
|
||||
Reference in New Issue
Block a user