feat(): prediction pipeline

This commit is contained in:
gitgernit
2025-11-23 04:11:52 +03:00
parent 2e6214a5ec
commit d1c7641698
25 changed files with 224 additions and 244 deletions
+1 -1
View File
@@ -24,4 +24,4 @@ access_key = ""
secret_key = ""
[ml_api]
url = "http://localhost:90"
url = "http://ml:8081"
-49
View File
@@ -1,49 +0,0 @@
import json
import logging
import urllib.parse
from collections.abc import Sequence
from pathlib import Path
from typing import Final
from adaptix import DebugTrail, NameStyle, Retort, name_mapping
from dataset.data_structures import DataSetLine, Salary
from dataset.upload_key_skills import upload_key_kills
DATASET_PATH: Final = Path("hh_ru_vacancies.jsonlines")
BASE_URL: Final = "https://team-39-alpha-gm5qjkou.hack.prodcontest.ru"
UPLOAD_KEY_SKILLS: Final = urllib.parse.urljoin(BASE_URL, "key_skills")
logger = logging.getLogger(__name__)
def parse_dataset(file_path: Path) -> Sequence[DataSetLine]:
retort = Retort(
recipe=[
name_mapping(Salary, name_style=NameStyle.CAMEL),
],
debug_trail=DebugTrail.DISABLE,
strict_coercion=False,
)
raw_lines = []
with file_path.open("r", encoding="utf-8") as f:
raw_lines = map(json.loads, f.readlines())
return retort.load(raw_lines, Sequence[DataSetLine])
def main() -> None:
logging.basicConfig(level=logging.INFO)
logger.info("Parsing dataset...")
dataset = parse_dataset(DATASET_PATH)
upload_key_kills(dataset, UPLOAD_KEY_SKILLS)
logger.info("finished script")
if __name__ == "__main__":
main()
+1 -1
View File
@@ -41,7 +41,7 @@ async def main() -> None:
ml_container = make_ml_ioc(ml_configuration)
csv_path = Path("filtered_vacancies.csv")
max_records = 51
max_records = 1000
try:
async with backend_container() as backend_request_container, ml_container() as ml_request_container:
-37
View File
@@ -1,37 +0,0 @@
import logging
from collections.abc import Sequence
from requests import Session
from dataset.data_structures import DataSetLine
logger = logging.getLogger(__name__)
def upload_key_kills(
dataset: Sequence[DataSetLine],
upload_endpoint: str,
max_upload_count: int | None = None
) -> None:
session = Session()
key_skills = []
for count, line in enumerate(dataset):
if max_upload_count is not None and count >= max_upload_count:
break
key_skills.extend(line.key_skills)
logger.info("Upload skills %r", key_skills)
response = session.post(
upload_endpoint,
json={
"key_skills": key_skills
}
)
if response.status_code != 200:
logger.warning("Doesn't upload skills. Status code %r", response.status_code)
else:
logger.info("Upload skills %r. Status code %r", key_skills, response.status_code)
key_skills = []
@@ -26,7 +26,6 @@ from template_project.application.resume.entity import (
ResumePrediction,
ResumeProject,
)
from template_project.application.resume.errors import ResumeNotFoundError
from template_project.application.user.entity import UserId
@@ -35,12 +34,8 @@ class DefaultResumeDataGateway(ResumeDataGateway):
self._session = session
@override
async def load(self, resume_id: ResumeId) -> Resume:
resume = await self._session.get(Resume, resume_id)
if resume is None:
raise ResumeNotFoundError(resume_id=resume_id)
return resume
async def load_by_resume_id(self, resume_id: ResumeId) -> Resume | None:
return await self._session.get(Resume, resume_id)
@override
async def list_by_user_id(self, user_id: UserId, limit: int, offset: int) -> Sequence[Resume]:
@@ -64,11 +59,15 @@ class DefaultResumeDataGateway(ResumeDataGateway):
async def get_history(self, resume_id: ResumeId) -> Sequence[Resume]:
# TODO: N+1
history: list[Resume] = []
current_resume = await self.load(resume_id)
current_resume = await self.load_by_resume_id(resume_id)
if current_resume is None:
return history
history.append(current_resume)
while current_resume.down_resume_id is not None:
current_resume = await self.load(current_resume.down_resume_id)
current_resume = await self.load_by_resume_id(current_resume.down_resume_id)
if current_resume is None:
break
history.append(current_resume)
return history
@@ -70,6 +70,34 @@ class StringArrayType(TypeDecorator[list[str]]):
return []
class ExperienceTypeType(TypeDecorator[ExperienceType]):
impl: Any = String
cache_ok: bool | None = True
@override
def process_bind_param(self, value: Any, dialect: Any) -> Any:
if value is None:
return None
if isinstance(value, ExperienceType):
return value.value
if isinstance(value, str):
return value
return None
@override
def process_result_value(self, value: Any, dialect: Any) -> ExperienceType:
if value is None:
raise ValueError("experience_type cannot be None")
if isinstance(value, ExperienceType):
return value
if isinstance(value, str):
try:
return ExperienceType(value)
except ValueError:
raise ValueError(f"Invalid experience_type value: {value}")
raise ValueError(f"Cannot convert {type(value)} to ExperienceType")
user_table: Final = Table(
"users",
meta_data,
@@ -138,7 +166,7 @@ resume_table: Final = Table(
Column("location", String, nullable=False),
Column("about_me", String, nullable=False),
Column("key_skills", StringArrayType(), nullable=False, server_default=text("'[]'::jsonb")),
Column("experience_type", String, nullable=False),
Column("experience_type", ExperienceTypeType(), nullable=False),
Column("down_resume_id", UUID, ForeignKey("resume.id", ondelete="CASCADE"), nullable=True, default=None),
Column("up_resume_id", UUID, ForeignKey("resume.id", ondelete="CASCADE"), nullable=True, default=None),
)
@@ -241,6 +269,7 @@ mapper_registry.map_imperatively(
resume_table,
properties={
"key_skills": resume_table.c.key_skills,
"experience_type": resume_table.c.experience_type,
},
)
mapper_registry.map_imperatively(ResumeEmbedding, resume_embedding_table)
@@ -20,6 +20,8 @@ class DefaultVacancyDataGateway(VacancyDataGateway):
select(Vacancy, label("resume_similarity", vacancy_embedding_table.c.vector.cosine_distance(vector)))
.join(VacancyEmbedding, vacancy_embedding_table.c.vacancy_id == vacancy_table.c.id)
.where(vacancy_embedding_table.c.vector.cosine_distance(vector) > 0.5)
.order_by(vacancy_embedding_table.c.vector.cosine_distance(vector).asc())
.limit(100)
)
result = await self._session.execute(statement)
return [
@@ -1,16 +1,8 @@
from typing import Final, override
from typing import override
from template_project.adapters.ml_api_gateway import MlApiGateway
from template_project.application.common.enums import ExperienceType
from template_project.application.resume.vector_generator import ResumeEmbeddingVectorGenerator
EMBEDDING_TEXT_TEMPLATE: Final = """
Позиция: {position}
Опыт: {experience_type}
Ключевые навыки: {key_skills}
Описание: {about_me}
"""
class DefaultResumeEmbeddingVectorGenerator(ResumeEmbeddingVectorGenerator):
def __init__(self, ml_api_gateway: MlApiGateway) -> None:
@@ -19,17 +11,6 @@ class DefaultResumeEmbeddingVectorGenerator(ResumeEmbeddingVectorGenerator):
@override
async def generate(
self,
position: str,
about_me: str,
experience_type: ExperienceType,
key_skills: list[str],
text: str,
) -> list[float]:
text = EMBEDDING_TEXT_TEMPLATE.format_map(
{
"position": position,
"experience_type": experience_type,
"key_skills": ", ".join(key_skills),
"about_me": about_me,
}
)
return await self._ml_api_gateway.generate_embedding(text)
@@ -15,20 +15,20 @@ class DefaultResumePredictionGenerator(ResumePredictionGenerator):
async def generate(
self,
resume: Resume,
suituble_vacancies: Sequence[SuitableVacancy],
suitable_vacancies: Sequence[SuitableVacancy],
) -> ResumePrediction:
response = await self._ml_api_gateway.generate_resume_prediction(
resume_id=resume.id,
key_skills=resume.key_skills,
suituble_vacancies=[
suitable_vacancies=[
SuitableVacancyDs(
vacancy_id=suituble_vacancy.vacancy.id,
from_salary=suituble_vacancy.vacancy.from_salary,
to_salary=suituble_vacancy.vacancy.to_salary,
key_skills=suituble_vacancy.vacancy.key_skills,
resume_similarity=suituble_vacancy.resume_similarity,
vacancy_id=suitable_vacancy.vacancy.id,
from_salary=suitable_vacancy.vacancy.from_salary,
to_salary=suitable_vacancy.vacancy.to_salary,
key_skills=suitable_vacancy.vacancy.key_skills,
resume_similarity=suitable_vacancy.resume_similarity,
)
for suituble_vacancy in suituble_vacancies
for suitable_vacancy in suitable_vacancies
],
)
return ResumePrediction.factory(
+12 -11
View File
@@ -36,28 +36,29 @@ class MlApiGateway:
self,
resume_id: ResumeId,
key_skills: list[str],
suituble_vacancies: Sequence[SuitableVacancyDs],
suitable_vacancies: Sequence[SuitableVacancyDs],
) -> GenerateResumePredictionResponse:
response = await self._client.post(
"/predict_salary",
"/predict",
json={
"resume_id": resume_id,
"resume_id": str(resume_id),
"key_skills": key_skills,
"vacancies": [
{
"vacancy_id": suituble_vacancy.vacancy_id,
"from_salary": suituble_vacancy.from_salary,
"to_salary": suituble_vacancy.to_salary,
"key_skills": suituble_vacancy.key_skills,
"resume_similarity": suituble_vacancy.resume_similarity,
} for suituble_vacancy in suituble_vacancies
"vacancy_id": str(suitable_vacancy.vacancy_id),
"from_salary": str(suitable_vacancy.from_salary),
"to_salary": str(suitable_vacancy.to_salary),
"key_skills": suitable_vacancy.key_skills,
"resume_similarity": suitable_vacancy.resume_similarity,
}
for suitable_vacancy in suitable_vacancies
],
},
)
response_json = response.json()
return GenerateResumePredictionResponse(
salary_from=response_json["salary_from"],
salary_to=response_json["salary_to"],
salary_from=Decimal(str(response_json["salary_from"])),
salary_to=Decimal(str(response_json["salary_to"])),
recommended_skills=response_json["recommended_skills"],
)
@@ -15,7 +15,7 @@ from template_project.application.user.entity import UserId
class ResumeDataGateway(Protocol):
@abstractmethod
async def load(self, resume_id: ResumeId) -> Resume:
async def load_by_resume_id(self, resume_id: ResumeId) -> Resume | None:
raise NotImplementedError
@abstractmethod
@@ -10,7 +10,6 @@ from template_project.application.resume.entity import (
ResumeId,
ResumeProject,
)
from template_project.application.resume.interactors.resume_embedding import ResumeEmbeddingInteractor
@to_data_structure
@@ -38,8 +37,6 @@ class ProjectInput:
class AddResumeInteractor:
unit_of_work: UnitOfWork
identity_provider: IdentityProvider
# TODO: переделать в фоновую таску
resume_embedding_interactor: ResumeEmbeddingInteractor
async def execute(
self,
@@ -97,8 +94,6 @@ class AddResumeInteractor:
)
await self.unit_of_work.add(resume_project)
await self.resume_embedding_interactor.run(resume)
await self.unit_of_work.commit()
return resume.id
@@ -16,7 +16,7 @@ from template_project.application.resume.entity import (
ResumeId,
ResumeProject,
)
from template_project.application.resume.errors import ResumeDoesBelongUserError
from template_project.application.resume.errors import ResumeDoesBelongUserError, ResumeNotFoundError
@to_data_structure
@@ -96,7 +96,9 @@ class EditResumeInteractor:
projects: list[ProjectInput] | None = None,
) -> EditResumeResponse:
user = await self.identity_provider.get_current_user()
old_resume = await self.resume_data_gateway.load(resume_id)
old_resume = await self.resume_data_gateway.load_by_resume_id(resume_id)
if old_resume is None:
raise ResumeNotFoundError(resume_id=resume_id)
if old_resume.user_id != user.id:
raise ResumeDoesBelongUserError
@@ -12,7 +12,7 @@ from template_project.application.resume.data_gateway import (
ResumeProjectDataGateway,
)
from template_project.application.resume.entity import ResumeId
from template_project.application.resume.errors import ResumeDoesBelongUserError
from template_project.application.resume.errors import ResumeDoesBelongUserError, ResumeNotFoundError
@to_data_structure
@@ -72,7 +72,9 @@ class GetResumeInteractor:
) -> GetResumeResponse:
user = await self.identity_provider.get_current_user()
resume = await self.resume_data_gateway.load(resume_id)
resume = await self.resume_data_gateway.load_by_resume_id(resume_id)
if resume is None:
raise ResumeNotFoundError(resume_id=resume_id)
if resume.user_id != user.id:
raise ResumeDoesBelongUserError
@@ -167,7 +169,9 @@ class GetResumeHistoryInteractor:
async def execute(self, resume_id: ResumeId) -> list[ResumeListItemResponse]:
user = await self.identity_provider.get_current_user()
resume = await self.resume_data_gateway.load(resume_id)
resume = await self.resume_data_gateway.load_by_resume_id(resume_id)
if resume is None:
raise ResumeNotFoundError(resume_id=resume_id)
if resume.user_id != user.id:
raise ResumeDoesBelongUserError
@@ -33,7 +33,7 @@ class PredictSalaryResponse:
@to_interactor
class PredictSalaryInteractor:
class PredictModelInteractor:
async def execute(self, request: PredictSalaryRequest) -> PredictSalaryResponse:
salary_from, salary_to = self._predict_salary(request.vacancies, request.key_skills)
recommended_skills = self._recommend_skills(request.vacancies, request.key_skills)
@@ -0,0 +1,102 @@
from typing import Final
from Levenshtein import ratio
from template_project.application.common.data_structure import to_data_structure
from template_project.application.common.interactor import to_interactor
from template_project.application.common.unit_of_work import UnitOfWork
from template_project.application.resume.data_gateway import ResumeDataGateway
from template_project.application.resume.entity import Resume, ResumeEmbedding, ResumeId
from template_project.application.resume.resume_prediction_generator import ResumePredictionGenerator
from template_project.application.resume.vector_generator import ResumeEmbeddingVectorGenerator
from template_project.application.vacancy.data_gateway import VacancyDataGateway
from template_project.application.vacancy.data_structure import SuitableVacancy
EMBEDDING_TEXT_TEMPLATE: Final = """
Позиция: {position}
Опыт: {experience_type}
Ключевые навыки: {key_skills}
Описание: {about_me}
"""
def _calculate_skills_matching(resume_skills: list[str], vacancy_skills: list[str]) -> float:
count_skills = 0
ratio_skill_sum = 0.0
for resume_key_skill in resume_skills:
for vacancy_key_skill in vacancy_skills:
ratio_skill = ratio(resume_key_skill, vacancy_key_skill)
if ratio_skill != 0:
count_skills += 1
ratio_skill_sum += ratio_skill
try:
return ratio_skill_sum / count_skills
except ZeroDivisionError:
return 0.0
def _filter_and_sort_vacancies(
resume: Resume,
suitable_vacancies: list[SuitableVacancy],
limit: int = 50,
) -> list[SuitableVacancy]:
def is_suitable(vacancy: SuitableVacancy) -> bool:
experience_match = resume.experience_type == vacancy.vacancy.experience_type
skills_matching = _calculate_skills_matching(resume.key_skills, vacancy.vacancy.key_skills)
skills_match = skills_matching >= 0.5
return experience_match and skills_match
filtered = [v for v in suitable_vacancies if is_suitable(v)]
if len(filtered) >= limit:
filtered.sort(key=lambda v: v.resume_similarity, reverse=True)
return filtered[:limit]
remaining = [v for v in suitable_vacancies if v not in filtered]
remaining.sort(key=lambda v: v.resume_similarity, reverse=True)
total_needed = limit - len(filtered)
return filtered + remaining[:total_needed]
@to_data_structure
class PredictResumeRequest:
resume_id: ResumeId
@to_interactor
class ResumePredictionInteractor:
unit_of_work: UnitOfWork
resume_data_gateway: ResumeDataGateway
vacancy_data_gateway: VacancyDataGateway
vector_generator: ResumeEmbeddingVectorGenerator
resume_prediction_generator: ResumePredictionGenerator
async def execute(self, request: PredictResumeRequest) -> None:
resume = await self.resume_data_gateway.load_by_resume_id(request.resume_id)
if resume is None:
return
embedding_text = EMBEDDING_TEXT_TEMPLATE.format_map({
"position": resume.position,
"experience_type": resume.experience_type.value,
"key_skills": ", ".join(resume.key_skills),
"about_me": resume.about_me,
})
vector = await self.vector_generator.generate(embedding_text)
resume_embedding = ResumeEmbedding.factory(
resume_id=resume.id,
vector=vector,
)
suitable_vacancies_list = list(await self.vacancy_data_gateway.get_suitable(resume_embedding.vector))
suitable_vacancies_filtered = _filter_and_sort_vacancies(resume, suitable_vacancies_list, limit=50)
resume_prediction = await self.resume_prediction_generator.generate(
resume=resume,
suitable_vacancies=suitable_vacancies_filtered,
)
await self.unit_of_work.add(resume_embedding, resume_prediction)
await self.unit_of_work.commit()
@@ -1,76 +0,0 @@
from collections.abc import Callable
from Levenshtein import ratio
from template_project.application.common.unit_of_work import UnitOfWork
from template_project.application.resume.entity import Resume, ResumeEmbedding
from template_project.application.resume.resume_prediction_generator import ResumePredictionGenerator
from template_project.application.resume.vector_generator import ResumeEmbeddingVectorGenerator
from template_project.application.vacancy.data_gateway import VacancyDataGateway
from template_project.application.vacancy.data_structure import SuitableVacancy
def suitable_vacancies_key(
resume: Resume,
) -> Callable[[SuitableVacancy], tuple[bool, bool]]:
def wrapper(suitable_vacancy: SuitableVacancy) -> tuple[bool, bool]:
count_skills = 0
ratio_skill_sum = 0.0
for resum_key_skill in resume.key_skills:
for suitable_resume_key_skill in suitable_vacancy.vacancy.key_skills:
ratio_skill = ratio(resum_key_skill, suitable_resume_key_skill)
if ratio_skill != 0:
count_skills += 1
ratio_skill_sum += ratio_skill
try:
matching_skills = ratio_skill_sum / count_skills
except ZeroDivisionError:
matching_skills = 0
return resume.experience_type == suitable_vacancy.vacancy.experience_type, matching_skills >= 50
return wrapper
class ResumeEmbeddingInteractor:
def __init__(
self,
unit_of_work: UnitOfWork,
vacancy_data_gateway: VacancyDataGateway,
vector_generator: ResumeEmbeddingVectorGenerator,
resume_prediction_generator: ResumePredictionGenerator,
) -> None:
self.unit_of_work = unit_of_work
self.vector_generator = vector_generator
self.vacancy_data_gateway = vacancy_data_gateway
self.resume_prediction_generator = resume_prediction_generator
async def run(
self,
resume: Resume,
) -> None:
vector = await self.vector_generator.generate(
position=resume.position,
about_me=resume.about_me,
key_skills=resume.key_skills,
experience_type=resume.experience_type,
)
resume_embedding = ResumeEmbedding.factory(
resume_id=resume.id,
vector=vector,
)
suitable_vacancies = await self.vacancy_data_gateway.get_suitable(resume_embedding.vector)
suitable_vacancies_filtered = sorted(
suitable_vacancies,
key=suitable_vacancies_key(resume),
)[:50]
resume_prediction = await self.resume_prediction_generator.generate(
resume=resume,
suituble_vacancies=suitable_vacancies_filtered,
)
await self.unit_of_work.add(resume_embedding, resume_prediction)
await self.unit_of_work.commit()
@@ -11,6 +11,6 @@ class ResumePredictionGenerator(Protocol):
async def generate(
self,
resume: Resume,
suituble_vacancies: Sequence[SuitableVacancy],
suitable_vacancies: Sequence[SuitableVacancy],
) -> ResumePrediction:
raise NotImplementedError
@@ -1,15 +1,10 @@
from abc import abstractmethod
from template_project.application.common.enums import ExperienceType
class ResumeEmbeddingVectorGenerator:
@abstractmethod
async def generate(
self,
position: str,
about_me: str,
experience_type: ExperienceType,
key_skills: list[str],
text: str,
) -> list[float]:
raise NotImplementedError
+2 -2
View File
@@ -1,11 +1,11 @@
from dishka import BaseScope, Provider, Scope, provide_all
from template_project.application.resume.interactors.predict_salary import PredictSalaryInteractor
from template_project.application.resume.interactors.predict_model import PredictModelInteractor
class InteractorProvider(Provider):
scope: BaseScope | None = Scope.REQUEST
interactors = provide_all(
PredictSalaryInteractor,
PredictModelInteractor,
)
+3 -3
View File
@@ -6,8 +6,8 @@ from fastapi import APIRouter
from pydantic import BaseModel, Field
from template_project.application.resume.entity import ResumeId
from template_project.application.resume.interactors.predict_salary import (
PredictSalaryInteractor,
from template_project.application.resume.interactors.predict_model import (
PredictModelInteractor,
PredictSalaryRequest,
VacancyInput,
)
@@ -94,7 +94,7 @@ class PredictSalaryResponseModel(BaseModel):
)
async def predict(
request: PredictSalaryRequestModel,
interactor: FromDishka[PredictSalaryInteractor],
interactor: FromDishka[PredictModelInteractor],
) -> PredictSalaryResponseModel:
vacancy_inputs = [
VacancyInput(
@@ -2,7 +2,7 @@ from collections.abc import AsyncIterable
from aioboto3.session import Session
from dishka import Provider, Scope, provide
from httpx import AsyncClient
from httpx import AsyncClient, Timeout
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from template_project.adapters.ml_api_gateway import MlApiGateway
@@ -40,5 +40,6 @@ class ConnectionProvider(Provider):
@provide(scope=Scope.APP)
async def ml_api_gateway(self, config: MlApiConfiguration) -> AsyncIterable[MlApiGateway]:
async with AsyncClient(base_url=config.url) as client:
timeout = Timeout(30.0, read=30.0)
async with AsyncClient(base_url=config.url, timeout=timeout) as client:
yield MlApiGateway(client)
@@ -13,7 +13,7 @@ from template_project.application.resume.interactors.get import (
GetResumeInteractor,
GetResumeListInteractor,
)
from template_project.application.resume.interactors.resume_embedding import ResumeEmbeddingInteractor
from template_project.application.resume.interactors.prediction_pipeline import ResumePredictionInteractor
from template_project.application.user.profile.interactors.get_profile import GetProfileInteractor
from template_project.application.user.profile.interactors.patch_profile import PatchProfileInteractor
@@ -33,5 +33,5 @@ class InteractorProvider(Provider):
GetResumeHistoryInteractor,
AddResumeInteractor,
EditResumeInteractor,
ResumeEmbeddingInteractor,
ResumePredictionInteractor,
)
+33 -2
View File
@@ -2,9 +2,9 @@ from decimal import Decimal
from http import HTTPStatus
from typing import Annotated
from dishka import FromDishka
from dishka import AsyncContainer, FromDishka
from dishka.integrations.fastapi import DishkaRoute
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query, Request
from fastapi.security import HTTPBearer
from pydantic import BaseModel, Field
@@ -28,6 +28,10 @@ from template_project.application.resume.interactors.get import (
GetResumeInteractor,
GetResumeListInteractor,
)
from template_project.application.resume.interactors.prediction_pipeline import (
PredictResumeRequest,
ResumePredictionInteractor,
)
security = HTTPBearer()
router = APIRouter(route_class=DishkaRoute, tags=["Resume"], dependencies=[Depends(security)])
@@ -296,6 +300,8 @@ class GetResumeHistoryResponse(BaseModel):
)
async def create_resume(
request: CreateResumeRequest,
background_tasks: BackgroundTasks,
fastapi_request: Request,
interactor: FromDishka[AddResumeInteractor],
) -> CreateResumeResponse:
experience = (
@@ -332,6 +338,18 @@ async def create_resume(
education=education,
projects=projects,
)
async def run_prediction(resume_id: ResumeId, container: AsyncContainer) -> None:
async with container() as request_container:
prediction_interactor = await request_container.get(ResumePredictionInteractor)
await prediction_interactor.execute(PredictResumeRequest(resume_id=resume_id))
background_tasks.add_task(
run_prediction,
interactor_response,
fastapi_request.app.state.dishka_container,
)
return CreateResumeResponse(
resume_id=interactor_response,
)
@@ -590,6 +608,8 @@ async def get_resume_history(
async def patch_resume(
resume_id: ResumeId,
request: PatchResumeRequest,
background_tasks: BackgroundTasks,
fastapi_request: Request,
interactor: FromDishka[EditResumeInteractor],
) -> PatchResumeResponse:
try:
@@ -628,6 +648,17 @@ async def patch_resume(
education=education,
projects=projects,
)
async def run_prediction(resume_id: ResumeId, container: AsyncContainer) -> None:
async with container() as request_container:
prediction_interactor = await request_container.get(ResumePredictionInteractor)
await prediction_interactor.execute(PredictResumeRequest(resume_id=resume_id))
background_tasks.add_task(
run_prediction,
interactor_response.id,
fastapi_request.app.state.dishka_container,
)
except ResumeDoesBelongUserError as error:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,