You've already forked RekomenciBackend
fix(): fix e2e ml tests, handle no entries for vacancies
This commit is contained in:
+6
-8
@@ -83,14 +83,12 @@ services:
|
||||
profiles:
|
||||
- ml
|
||||
ports:
|
||||
- 8081:8080
|
||||
# ports:
|
||||
# - name: web
|
||||
# target: 8081
|
||||
# published: 13562
|
||||
# host_ip: 127.0.0.1
|
||||
# protocol: tcp
|
||||
# app_protocol: http
|
||||
- name: web
|
||||
target: 8081
|
||||
published: 13562
|
||||
host_ip: 127.0.0.1
|
||||
protocol: tcp
|
||||
app_protocol: http
|
||||
restart: unless-stopped
|
||||
shm_size: 4mb
|
||||
volumes:
|
||||
|
||||
@@ -232,4 +232,7 @@ omit = [
|
||||
'*/__about__.py',
|
||||
'*/__main__.py',
|
||||
'*/__init__.py',
|
||||
'src/dataset/*',
|
||||
'src/template_project/ml/*',
|
||||
'src/template_project/application/resume/interactors/predict_model.py',
|
||||
]
|
||||
|
||||
@@ -22,7 +22,7 @@ from sqlalchemy.orm import registry
|
||||
|
||||
from template_project.application.access_token.entity import AccessToken
|
||||
from template_project.application.auth_identity.entity import AuthIdentity, AuthMethod
|
||||
from template_project.application.common.enums import EducationGrade, ExperienceType, ExperienceType
|
||||
from template_project.application.common.enums import EducationGrade, ExperienceType
|
||||
from template_project.application.notification_device.entity import NotificationDevice
|
||||
from template_project.application.resume.entity import (
|
||||
Resume,
|
||||
|
||||
@@ -57,6 +57,7 @@ class MlApiGateway:
|
||||
timeout=100,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
return GenerateResumePredictionResponse(
|
||||
salary_from=Decimal(str(response_json["salary_from"])),
|
||||
|
||||
@@ -46,7 +46,7 @@ class PredictModelInteractor:
|
||||
|
||||
def _predict_salary(self, vacancies: list[VacancyInput], resume_skills: list[str]) -> tuple[Decimal, Decimal]:
|
||||
if not vacancies:
|
||||
return Decimal(50000), Decimal(80000)
|
||||
return Decimal(0), Decimal(0)
|
||||
|
||||
vacancy_weights: list[float] = []
|
||||
for vacancy in vacancies:
|
||||
@@ -56,7 +56,7 @@ class PredictModelInteractor:
|
||||
|
||||
total_weight = sum(vacancy_weights)
|
||||
if total_weight == 0:
|
||||
return Decimal(50000), Decimal(80000)
|
||||
return Decimal(0), Decimal(0)
|
||||
|
||||
weighted_from_sum = Decimal(0)
|
||||
weighted_to_sum = Decimal(0)
|
||||
@@ -143,6 +143,9 @@ class PredictModelInteractor:
|
||||
if skill in candidate_skills
|
||||
}
|
||||
|
||||
if not candidate_skills:
|
||||
return []
|
||||
|
||||
frequencies = [skill_frequencies[skill] for skill in candidate_skills]
|
||||
avg_salaries = [float(skill_avg_salaries[skill]) for skill in candidate_skills]
|
||||
|
||||
|
||||
@@ -3,6 +3,8 @@ import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import Final
|
||||
|
||||
@@ -11,6 +13,7 @@ from dishka import AsyncContainer
|
||||
from dishka.integrations.fastapi import setup_dishka
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
from template_project.ml.configuration import load_configuration
|
||||
from template_project.ml.ioc.make import make_ioc
|
||||
@@ -37,10 +40,18 @@ LOG_CONFIG: Final = {
|
||||
}
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
|
||||
await app.state.dishka_container.get(SentenceTransformer)
|
||||
yield
|
||||
await app.state.dishka_container.close()
|
||||
|
||||
|
||||
def make_asgi_application(
|
||||
ioc: AsyncContainer,
|
||||
) -> FastAPI:
|
||||
app = FastAPI(
|
||||
lifespan=lifespan,
|
||||
docs_url="/docs",
|
||||
title="ML Service",
|
||||
description="ML Service API",
|
||||
@@ -81,10 +92,7 @@ async def _main(
|
||||
access_log=configuration.server.access_log,
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
try:
|
||||
await server.serve()
|
||||
finally:
|
||||
await ioc.close()
|
||||
await server.serve()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
|
||||
@@ -42,7 +42,7 @@ class PredictSalaryRequestModel(BaseModel):
|
||||
key_skills: list[str] = Field(
|
||||
description="List of key skills from resume", examples=[["Python", "FastAPI", "PostgreSQL"]]
|
||||
)
|
||||
vacancies: list[VacancyInputModel] = Field(description="List of relevant vacancies", examples=[[]])
|
||||
vacancies: list[VacancyInputModel] = Field(description="List of relevant vacancies", examples=[[]], min_length=0)
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from typing import Final
|
||||
|
||||
import pytest
|
||||
from dirty_equals import IsDict, IsPartialDict, IsUUID
|
||||
from dirty_equals import IsDict, IsOneOf, IsPartialDict, IsUUID
|
||||
from dishka import FromDishka
|
||||
from uuid_utils.compat import uuid7
|
||||
|
||||
@@ -17,7 +16,6 @@ from tests.web_api.test_api_gateway import TestApiGateway
|
||||
DEFAULT_PASSWORD: Final = "Sup3rSecret" # noqa: S105
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Requires ML service")
|
||||
@inject
|
||||
async def test_success_add_resume(
|
||||
unique_email: str,
|
||||
@@ -43,7 +41,6 @@ async def test_success_add_resume(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Requires ML service")
|
||||
@inject
|
||||
async def test_unauthorized_add_resume(
|
||||
unique_email: str,
|
||||
@@ -63,7 +60,6 @@ async def test_unauthorized_add_resume(
|
||||
assert is_unauthorized_response(response)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Requires ML service")
|
||||
@inject
|
||||
async def test_success_get_resume(
|
||||
unique_email: str,
|
||||
@@ -89,7 +85,8 @@ async def test_success_get_resume(
|
||||
resume_id=response.json()["resume_id"],
|
||||
)
|
||||
assert is_success_response(response)
|
||||
assert response.json() == IsPartialDict(
|
||||
json_response = response.json()
|
||||
assert json_response == IsPartialDict(
|
||||
position="Position",
|
||||
location="Moscow",
|
||||
about_me="About me",
|
||||
@@ -98,11 +95,17 @@ async def test_success_get_resume(
|
||||
experience=[],
|
||||
education=[],
|
||||
projects=[],
|
||||
prediction=None,
|
||||
)
|
||||
assert json_response["prediction"] == IsOneOf(
|
||||
None,
|
||||
IsPartialDict(
|
||||
from_salary="0",
|
||||
to_salary="0",
|
||||
recommended_skills=[],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Requires ML service")
|
||||
@inject
|
||||
async def test_unauthorized_get_resume(
|
||||
unique_email: str,
|
||||
@@ -148,7 +151,6 @@ async def test_not_found_get_resume(
|
||||
assert is_not_found_response(response)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Requires ML service")
|
||||
@inject
|
||||
async def test_success_edit_resume(
|
||||
unique_email: str,
|
||||
@@ -191,7 +193,6 @@ async def test_success_edit_resume(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Requires ML service")
|
||||
@inject
|
||||
async def test_unauthorized_edit_resume(
|
||||
unique_email: str,
|
||||
@@ -245,7 +246,6 @@ async def test_not_found_edit_resume(
|
||||
assert is_not_found_response(response)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Requires ML service")
|
||||
@inject
|
||||
async def test_forbidden_edit_resume(
|
||||
unique_email: str,
|
||||
|
||||
Reference in New Issue
Block a user