fix(): fix e2e ml tests, handle no entries for vacancies

This commit is contained in:
gitgernit
2025-11-23 11:44:08 +03:00
parent a879da4ed5
commit b15282baef
8 changed files with 40 additions and 27 deletions
+6 -8
View File
@@ -83,14 +83,12 @@ services:
profiles: profiles:
- ml - ml
ports: ports:
- 8081:8080 - name: web
# ports: target: 8081
# - name: web published: 13562
# target: 8081 host_ip: 127.0.0.1
# published: 13562 protocol: tcp
# host_ip: 127.0.0.1 app_protocol: http
# protocol: tcp
# app_protocol: http
restart: unless-stopped restart: unless-stopped
shm_size: 4mb shm_size: 4mb
volumes: volumes:
+3
View File
@@ -232,4 +232,7 @@ omit = [
'*/__about__.py', '*/__about__.py',
'*/__main__.py', '*/__main__.py',
'*/__init__.py', '*/__init__.py',
'src/dataset/*',
'src/template_project/ml/*',
'src/template_project/application/resume/interactors/predict_model.py',
] ]
@@ -22,7 +22,7 @@ from sqlalchemy.orm import registry
from template_project.application.access_token.entity import AccessToken from template_project.application.access_token.entity import AccessToken
from template_project.application.auth_identity.entity import AuthIdentity, AuthMethod from template_project.application.auth_identity.entity import AuthIdentity, AuthMethod
from template_project.application.common.enums import EducationGrade, ExperienceType, ExperienceType from template_project.application.common.enums import EducationGrade, ExperienceType
from template_project.application.notification_device.entity import NotificationDevice from template_project.application.notification_device.entity import NotificationDevice
from template_project.application.resume.entity import ( from template_project.application.resume.entity import (
Resume, Resume,
@@ -57,6 +57,7 @@ class MlApiGateway:
timeout=100, timeout=100,
) )
response.raise_for_status()
response_json = response.json() response_json = response.json()
return GenerateResumePredictionResponse( return GenerateResumePredictionResponse(
salary_from=Decimal(str(response_json["salary_from"])), salary_from=Decimal(str(response_json["salary_from"])),
@@ -46,7 +46,7 @@ class PredictModelInteractor:
def _predict_salary(self, vacancies: list[VacancyInput], resume_skills: list[str]) -> tuple[Decimal, Decimal]: def _predict_salary(self, vacancies: list[VacancyInput], resume_skills: list[str]) -> tuple[Decimal, Decimal]:
if not vacancies: if not vacancies:
return Decimal(50000), Decimal(80000) return Decimal(0), Decimal(0)
vacancy_weights: list[float] = [] vacancy_weights: list[float] = []
for vacancy in vacancies: for vacancy in vacancies:
@@ -56,7 +56,7 @@ class PredictModelInteractor:
total_weight = sum(vacancy_weights) total_weight = sum(vacancy_weights)
if total_weight == 0: if total_weight == 0:
return Decimal(50000), Decimal(80000) return Decimal(0), Decimal(0)
weighted_from_sum = Decimal(0) weighted_from_sum = Decimal(0)
weighted_to_sum = Decimal(0) weighted_to_sum = Decimal(0)
@@ -143,6 +143,9 @@ class PredictModelInteractor:
if skill in candidate_skills if skill in candidate_skills
} }
if not candidate_skills:
return []
frequencies = [skill_frequencies[skill] for skill in candidate_skills] frequencies = [skill_frequencies[skill] for skill in candidate_skills]
avg_salaries = [float(skill_avg_salaries[skill]) for skill in candidate_skills] avg_salaries = [float(skill_avg_salaries[skill]) for skill in candidate_skills]
+12 -4
View File
@@ -3,6 +3,8 @@ import asyncio
import logging import logging
import os import os
import sys import sys
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from pathlib import Path from pathlib import Path
from typing import Final from typing import Final
@@ -11,6 +13,7 @@ from dishka import AsyncContainer
from dishka.integrations.fastapi import setup_dishka from dishka.integrations.fastapi import setup_dishka
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from sentence_transformers import SentenceTransformer
from template_project.ml.configuration import load_configuration from template_project.ml.configuration import load_configuration
from template_project.ml.ioc.make import make_ioc from template_project.ml.ioc.make import make_ioc
@@ -37,10 +40,18 @@ LOG_CONFIG: Final = {
} }
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
await app.state.dishka_container.get(SentenceTransformer)
yield
await app.state.dishka_container.close()
def make_asgi_application( def make_asgi_application(
ioc: AsyncContainer, ioc: AsyncContainer,
) -> FastAPI: ) -> FastAPI:
app = FastAPI( app = FastAPI(
lifespan=lifespan,
docs_url="/docs", docs_url="/docs",
title="ML Service", title="ML Service",
description="ML Service API", description="ML Service API",
@@ -81,10 +92,7 @@ async def _main(
access_log=configuration.server.access_log, access_log=configuration.server.access_log,
) )
server = uvicorn.Server(config) server = uvicorn.Server(config)
try: await server.serve()
await server.serve()
finally:
await ioc.close()
def main() -> None: def main() -> None:
+1 -1
View File
@@ -42,7 +42,7 @@ class PredictSalaryRequestModel(BaseModel):
key_skills: list[str] = Field( key_skills: list[str] = Field(
description="List of key skills from resume", examples=[["Python", "FastAPI", "PostgreSQL"]] description="List of key skills from resume", examples=[["Python", "FastAPI", "PostgreSQL"]]
) )
vacancies: list[VacancyInputModel] = Field(description="List of relevant vacancies", examples=[[]]) vacancies: list[VacancyInputModel] = Field(description="List of relevant vacancies", examples=[[]], min_length=0)
model_config = { model_config = {
"json_schema_extra": { "json_schema_extra": {
+11 -11
View File
@@ -1,7 +1,6 @@
from typing import Final from typing import Final
import pytest from dirty_equals import IsDict, IsOneOf, IsPartialDict, IsUUID
from dirty_equals import IsDict, IsPartialDict, IsUUID
from dishka import FromDishka from dishka import FromDishka
from uuid_utils.compat import uuid7 from uuid_utils.compat import uuid7
@@ -17,7 +16,6 @@ from tests.web_api.test_api_gateway import TestApiGateway
DEFAULT_PASSWORD: Final = "Sup3rSecret" # noqa: S105 DEFAULT_PASSWORD: Final = "Sup3rSecret" # noqa: S105
@pytest.mark.skip(reason="Requires ML service")
@inject @inject
async def test_success_add_resume( async def test_success_add_resume(
unique_email: str, unique_email: str,
@@ -43,7 +41,6 @@ async def test_success_add_resume(
) )
@pytest.mark.skip(reason="Requires ML service")
@inject @inject
async def test_unauthorized_add_resume( async def test_unauthorized_add_resume(
unique_email: str, unique_email: str,
@@ -63,7 +60,6 @@ async def test_unauthorized_add_resume(
assert is_unauthorized_response(response) assert is_unauthorized_response(response)
@pytest.mark.skip(reason="Requires ML service")
@inject @inject
async def test_success_get_resume( async def test_success_get_resume(
unique_email: str, unique_email: str,
@@ -89,7 +85,8 @@ async def test_success_get_resume(
resume_id=response.json()["resume_id"], resume_id=response.json()["resume_id"],
) )
assert is_success_response(response) assert is_success_response(response)
assert response.json() == IsPartialDict( json_response = response.json()
assert json_response == IsPartialDict(
position="Position", position="Position",
location="Moscow", location="Moscow",
about_me="About me", about_me="About me",
@@ -98,11 +95,17 @@ async def test_success_get_resume(
experience=[], experience=[],
education=[], education=[],
projects=[], projects=[],
prediction=None, )
assert json_response["prediction"] == IsOneOf(
None,
IsPartialDict(
from_salary="0",
to_salary="0",
recommended_skills=[],
),
) )
@pytest.mark.skip(reason="Requires ML service")
@inject @inject
async def test_unauthorized_get_resume( async def test_unauthorized_get_resume(
unique_email: str, unique_email: str,
@@ -148,7 +151,6 @@ async def test_not_found_get_resume(
assert is_not_found_response(response) assert is_not_found_response(response)
@pytest.mark.skip(reason="Requires ML service")
@inject @inject
async def test_success_edit_resume( async def test_success_edit_resume(
unique_email: str, unique_email: str,
@@ -191,7 +193,6 @@ async def test_success_edit_resume(
) )
@pytest.mark.skip(reason="Requires ML service")
@inject @inject
async def test_unauthorized_edit_resume( async def test_unauthorized_edit_resume(
unique_email: str, unique_email: str,
@@ -245,7 +246,6 @@ async def test_not_found_edit_resume(
assert is_not_found_response(response) assert is_not_found_response(response)
@pytest.mark.skip(reason="Requires ML service")
@inject @inject
async def test_forbidden_edit_resume( async def test_forbidden_edit_resume(
unique_email: str, unique_email: str,