fix(): fix e2e ml tests, handle no entries for vacancies

This commit is contained in:
gitgernit
2025-11-23 11:44:08 +03:00
parent a879da4ed5
commit b15282baef
8 changed files with 40 additions and 27 deletions
@@ -22,7 +22,7 @@ from sqlalchemy.orm import registry
from template_project.application.access_token.entity import AccessToken
from template_project.application.auth_identity.entity import AuthIdentity, AuthMethod
from template_project.application.common.enums import EducationGrade, ExperienceType, ExperienceType
from template_project.application.common.enums import EducationGrade, ExperienceType
from template_project.application.notification_device.entity import NotificationDevice
from template_project.application.resume.entity import (
Resume,
@@ -57,6 +57,7 @@ class MlApiGateway:
timeout=100,
)
response.raise_for_status()
response_json = response.json()
return GenerateResumePredictionResponse(
salary_from=Decimal(str(response_json["salary_from"])),
@@ -46,7 +46,7 @@ class PredictModelInteractor:
def _predict_salary(self, vacancies: list[VacancyInput], resume_skills: list[str]) -> tuple[Decimal, Decimal]:
if not vacancies:
return Decimal(50000), Decimal(80000)
return Decimal(0), Decimal(0)
vacancy_weights: list[float] = []
for vacancy in vacancies:
@@ -56,7 +56,7 @@ class PredictModelInteractor:
total_weight = sum(vacancy_weights)
if total_weight == 0:
return Decimal(50000), Decimal(80000)
return Decimal(0), Decimal(0)
weighted_from_sum = Decimal(0)
weighted_to_sum = Decimal(0)
@@ -143,6 +143,9 @@ class PredictModelInteractor:
if skill in candidate_skills
}
if not candidate_skills:
return []
frequencies = [skill_frequencies[skill] for skill in candidate_skills]
avg_salaries = [float(skill_avg_salaries[skill]) for skill in candidate_skills]
+12 -4
View File
@@ -3,6 +3,8 @@ import asyncio
import logging
import os
import sys
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Final
@@ -11,6 +13,7 @@ from dishka import AsyncContainer
from dishka.integrations.fastapi import setup_dishka
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from sentence_transformers import SentenceTransformer
from template_project.ml.configuration import load_configuration
from template_project.ml.ioc.make import make_ioc
@@ -37,10 +40,18 @@ LOG_CONFIG: Final = {
}
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
await app.state.dishka_container.get(SentenceTransformer)
yield
await app.state.dishka_container.close()
def make_asgi_application(
ioc: AsyncContainer,
) -> FastAPI:
app = FastAPI(
lifespan=lifespan,
docs_url="/docs",
title="ML Service",
description="ML Service API",
@@ -81,10 +92,7 @@ async def _main(
access_log=configuration.server.access_log,
)
server = uvicorn.Server(config)
try:
await server.serve()
finally:
await ioc.close()
await server.serve()
def main() -> None:
+1 -1
View File
@@ -42,7 +42,7 @@ class PredictSalaryRequestModel(BaseModel):
key_skills: list[str] = Field(
description="List of key skills from resume", examples=[["Python", "FastAPI", "PostgreSQL"]]
)
vacancies: list[VacancyInputModel] = Field(description="List of relevant vacancies", examples=[[]])
vacancies: list[VacancyInputModel] = Field(description="List of relevant vacancies", examples=[[]], min_length=0)
model_config = {
"json_schema_extra": {