Nuevos cambios en el backend
This commit is contained in:
@@ -0,0 +1,132 @@
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import Settings, get_settings
|
||||
from google.auth.transport import requests as google_requests
|
||||
from google.oauth2 import id_token as google_id_token
|
||||
|
||||
from app.core.errors import AppError, ConflictError, NotFoundError, UnauthorizedError
|
||||
from app.core.security import clean_text
|
||||
from app.db.session import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.user import UserLogin, UserRead, UserRegister
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
class AuthService:
|
||||
def __init__(self, db: Session, settings: Settings) -> None:
|
||||
self.db = db
|
||||
self.settings = settings
|
||||
|
||||
def register(self, payload: UserRegister) -> UserRead:
|
||||
email = payload.email.lower().strip()
|
||||
existing = self.db.scalar(select(User).where(User.email == email))
|
||||
if existing is not None:
|
||||
raise ConflictError("Email is already registered")
|
||||
|
||||
user = User(
|
||||
email=email,
|
||||
password_hash=pwd_context.hash(payload.password),
|
||||
full_name=clean_text(payload.full_name, max_length=200) if payload.full_name else None,
|
||||
)
|
||||
self.db.add(user)
|
||||
self.db.commit()
|
||||
self.db.refresh(user)
|
||||
return UserRead.model_validate(user)
|
||||
|
||||
def authenticate(self, payload: UserLogin) -> User:
|
||||
email = payload.email.lower().strip()
|
||||
user = self.db.scalar(select(User).where(User.email == email))
|
||||
if user is None or user.password_hash is None:
|
||||
raise UnauthorizedError("Invalid email or password")
|
||||
if not pwd_context.verify(payload.password, user.password_hash):
|
||||
raise UnauthorizedError("Invalid email or password")
|
||||
return user
|
||||
|
||||
def login_with_google(self, id_token_value: str) -> User:
|
||||
if not self.settings.google_client_id:
|
||||
raise AppError(
|
||||
message="Google login is not configured",
|
||||
status_code=503,
|
||||
code="google_not_configured",
|
||||
)
|
||||
|
||||
try:
|
||||
idinfo = google_id_token.verify_oauth2_token(
|
||||
id_token_value,
|
||||
google_requests.Request(),
|
||||
self.settings.google_client_id,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise UnauthorizedError("Invalid Google ID token") from exc
|
||||
|
||||
google_sub = idinfo.get("sub")
|
||||
email = (idinfo.get("email") or "").lower().strip()
|
||||
if not google_sub or not email:
|
||||
raise UnauthorizedError("Google token does not include required user information")
|
||||
if not idinfo.get("email_verified", False):
|
||||
raise UnauthorizedError("Google email is not verified")
|
||||
|
||||
user = self.db.scalar(select(User).where(User.google_sub == google_sub))
|
||||
if user is not None:
|
||||
return user
|
||||
|
||||
user = self.db.scalar(select(User).where(User.email == email))
|
||||
if user is not None:
|
||||
if user.google_sub is not None and user.google_sub != google_sub:
|
||||
raise ConflictError("Email is linked to another Google account")
|
||||
user.google_sub = google_sub
|
||||
if not user.full_name and idinfo.get("name"):
|
||||
user.full_name = clean_text(idinfo["name"], max_length=200)
|
||||
self.db.commit()
|
||||
self.db.refresh(user)
|
||||
return user
|
||||
|
||||
user = User(
|
||||
email=email,
|
||||
password_hash=None,
|
||||
google_sub=google_sub,
|
||||
full_name=clean_text(idinfo["name"], max_length=200) if idinfo.get("name") else None,
|
||||
)
|
||||
self.db.add(user)
|
||||
self.db.commit()
|
||||
self.db.refresh(user)
|
||||
return user
|
||||
|
||||
def get_user_by_id(self, user_id: uuid.UUID) -> User:
|
||||
user = self.db.get(User, user_id)
|
||||
if user is None:
|
||||
raise NotFoundError("User not found")
|
||||
return user
|
||||
|
||||
def create_access_token(self, user_id: uuid.UUID) -> str:
|
||||
expire = datetime.now(UTC) + timedelta(minutes=self.settings.jwt_expire_minutes)
|
||||
payload = {"sub": str(user_id), "exp": expire}
|
||||
return jwt.encode(payload, self.settings.jwt_secret_key, algorithm=self.settings.jwt_algorithm)
|
||||
|
||||
def decode_user_id(self, token: str) -> uuid.UUID:
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
self.settings.jwt_secret_key,
|
||||
algorithms=[self.settings.jwt_algorithm],
|
||||
)
|
||||
user_id = uuid.UUID(payload["sub"])
|
||||
except (JWTError, KeyError, ValueError) as exc:
|
||||
raise UnauthorizedError("Invalid or expired token") from exc
|
||||
return user_id
|
||||
|
||||
|
||||
def get_auth_service(
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
settings: Annotated[Settings, Depends(get_settings)],
|
||||
) -> AuthService:
|
||||
return AuthService(db, settings)
|
||||
@@ -3,10 +3,11 @@ import uuid
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.errors import NotFoundError
|
||||
from app.core.errors import ForbiddenError, NotFoundError
|
||||
from app.core.security import clean_text
|
||||
from app.models.exam import ExamTemplate, ExportFormat, ExportJob, ExportStatus, Question
|
||||
from app.schemas.exam import (
|
||||
ExamHistoryItem,
|
||||
ExamTemplateCreate,
|
||||
ExamTemplateRead,
|
||||
ExportResponse,
|
||||
@@ -35,8 +36,9 @@ class ExamService:
|
||||
self.parser = parser or AIQuestionParser()
|
||||
self.exporter = exporter or MoodleXMLExporter()
|
||||
|
||||
def create_template(self, payload: ExamTemplateCreate) -> ExamTemplateRead:
|
||||
def create_template(self, user_id: uuid.UUID, payload: ExamTemplateCreate) -> ExamTemplateRead:
|
||||
template = ExamTemplate(
|
||||
user_id=user_id,
|
||||
title=clean_text(payload.title, max_length=200),
|
||||
subject=clean_text(payload.subject, max_length=200),
|
||||
educational_level=clean_text(payload.educational_level, max_length=120),
|
||||
@@ -49,37 +51,67 @@ class ExamService:
|
||||
self.db.refresh(template)
|
||||
return self._template_read(template)
|
||||
|
||||
def list_templates(self) -> list[ExamTemplateRead]:
|
||||
templates = self.db.scalars(select(ExamTemplate).order_by(ExamTemplate.created_at.desc())).all()
|
||||
def list_templates(self, user_id: uuid.UUID) -> list[ExamTemplateRead]:
|
||||
templates = self.db.scalars(
|
||||
select(ExamTemplate)
|
||||
.where(ExamTemplate.user_id == user_id)
|
||||
.order_by(ExamTemplate.created_at.desc())
|
||||
).all()
|
||||
return [self._template_read(template) for template in templates]
|
||||
|
||||
def get_template(self, template_id: uuid.UUID) -> ExamTemplateRead:
|
||||
return self._template_read(self._get_template_or_404(template_id))
|
||||
def list_history(self, user_id: uuid.UUID) -> list[ExamHistoryItem]:
|
||||
templates = self.db.scalars(
|
||||
select(ExamTemplate)
|
||||
.where(ExamTemplate.user_id == user_id)
|
||||
.order_by(ExamTemplate.updated_at.desc())
|
||||
).all()
|
||||
history: list[ExamHistoryItem] = []
|
||||
for template in templates:
|
||||
export_jobs = sorted(template.export_jobs, key=lambda job: job.created_at, reverse=True)
|
||||
history.append(
|
||||
ExamHistoryItem(
|
||||
id=template.id,
|
||||
title=template.title,
|
||||
subject=template.subject,
|
||||
educational_level=template.educational_level,
|
||||
language=template.language,
|
||||
question_count=len(template.questions),
|
||||
export_count=len(export_jobs),
|
||||
last_export_at=export_jobs[0].created_at if export_jobs else None,
|
||||
created_at=template.created_at,
|
||||
updated_at=template.updated_at,
|
||||
)
|
||||
)
|
||||
return history
|
||||
|
||||
def build_prompt(self, template_id: uuid.UUID, topic_prompt: str) -> PromptResponse:
|
||||
template = self._get_template_or_404(template_id)
|
||||
def get_template(self, user_id: uuid.UUID, template_id: uuid.UUID) -> ExamTemplateRead:
|
||||
return self._template_read(self._get_user_template_or_404(user_id, template_id))
|
||||
|
||||
def build_prompt(self, user_id: uuid.UUID, template_id: uuid.UUID, topic_prompt: str) -> PromptResponse:
|
||||
template = self._get_user_template_or_404(user_id, template_id)
|
||||
prompt = self.prompt_builder.build_prompt(template, topic_prompt)
|
||||
return PromptResponse(template_id=template.id, prompt=prompt)
|
||||
|
||||
async def generate_with_llm(
|
||||
self,
|
||||
user_id: uuid.UUID,
|
||||
template_id: uuid.UUID,
|
||||
topic_prompt: str,
|
||||
llm_client: LLMClient,
|
||||
) -> ParsedQuestionsResponse:
|
||||
template = self._get_template_or_404(template_id)
|
||||
template = self._get_user_template_or_404(user_id, template_id)
|
||||
prompt = self.prompt_builder.build_prompt(template, topic_prompt)
|
||||
raw_output = await llm_client.generate(prompt)
|
||||
questions = self.parser.parse_json(raw_output)
|
||||
return self._persist_questions(template.id, questions)
|
||||
|
||||
def parse_and_persist(self, payload: ParseRequest) -> ParsedQuestionsResponse:
|
||||
self._get_template_or_404(payload.template_id)
|
||||
def parse_and_persist(self, user_id: uuid.UUID, payload: ParseRequest) -> ParsedQuestionsResponse:
|
||||
self._get_user_template_or_404(user_id, payload.template_id)
|
||||
questions = self.parser.parse(payload.raw_output, payload.input_format)
|
||||
return self._persist_questions(payload.template_id, questions)
|
||||
|
||||
def export(self, template_id: uuid.UUID, export_format: ExportFormat) -> ExportResponse:
|
||||
template = self._get_template_or_404(template_id)
|
||||
def export(self, user_id: uuid.UUID, template_id: uuid.UUID, export_format: ExportFormat) -> ExportResponse:
|
||||
template = self._get_user_template_or_404(user_id, template_id)
|
||||
questions = list(template.questions)
|
||||
if not questions:
|
||||
raise NotFoundError("Template does not contain questions to export")
|
||||
@@ -126,10 +158,12 @@ class ExamService:
|
||||
|
||||
return ParsedQuestionsResponse(questions=[QuestionRead.model_validate(question) for question in persisted])
|
||||
|
||||
def _get_template_or_404(self, template_id: uuid.UUID) -> ExamTemplate:
|
||||
def _get_user_template_or_404(self, user_id: uuid.UUID, template_id: uuid.UUID) -> ExamTemplate:
|
||||
template = self.db.get(ExamTemplate, template_id)
|
||||
if template is None:
|
||||
raise NotFoundError("Exam template not found")
|
||||
if template.user_id != user_id:
|
||||
raise ForbiddenError("You do not have access to this exam template")
|
||||
return template
|
||||
|
||||
def _template_read(self, template: ExamTemplate) -> ExamTemplateRead:
|
||||
|
||||
Reference in New Issue
Block a user