Merge pull request #15 from uja-dev-practices/feature/backend-v5

Feature/backend v5
This commit is contained in:
Mireya Cueto Garrido
2026-05-08 12:14:39 +02:00
committed by GitHub
37 changed files with 1376 additions and 280 deletions
+7
View File
@@ -22,6 +22,13 @@ ENV/
# FastAPI / Uvicorn # FastAPI / Uvicorn
*.pid *.pid
# Test / type checkers
.pytest_cache/
.mypy_cache/
.ruff_cache/
.coverage
htmlcov/
# --- NODE FRONTEND --- # --- NODE FRONTEND ---
node_modules/ node_modules/
dist/ dist/
+20
View File
@@ -0,0 +1,20 @@
.env
.env.*
!.env.example
__pycache__/
*.pyc
*.pyo
*.pyd
*.log
*.sqlite3
*.db
.pytest_cache/
.mypy_cache/
.ruff_cache/
.venv/
venv/
.git/
.gitignore
README.md
docs/
tests/
+69 -7
View File
@@ -1,19 +1,81 @@
ORCID_CLIENT_ID=123412341234 # ============================================================
ORCID_CLIENT_SECRET=123412341234 # ENVIRONMENT
# ============================================================
API_KEY_NAME=X-API-Key ENVIRONMENT=development
API_KEY_VALUE=123412341234 DEBUG=false
# ============================================================
# DATABASE / CACHE
# ============================================================
DATABASE_URL=postgresql://postgres:postgres@db:5432/orcid_db DATABASE_URL=postgresql://postgres:postgres@db:5432/orcid_db
REDIS_URL=redis://redis:6379/0 REDIS_URL=redis://redis:6379/0
# ============================================================
# BASE URL (uso interno del scheduler)
# ============================================================
BASE_URL=http://localhost:8000/api BASE_URL=http://localhost:8000/api
# ============================================================
# CORS — lista blanca estricta separada por comas
# Nunca uses "*" si allow_credentials=true.
# ============================================================
CORS_ALLOWED_ORIGINS=http://localhost:5173
# ============================================================
# Trusted Hosts — anti Host-header injection (en prod, sé explícito)
# ============================================================
TRUSTED_HOSTS=*
# ============================================================
# JWT (login ORCID) # JWT (login ORCID)
JWT_SECRET=change_me # Genera un secreto fuerte: `openssl rand -base64 64`
# ============================================================
JWT_SECRET=change_me_to_a_long_random_value_at_least_32_chars
JWT_ALGORITHM=HS256 JWT_ALGORITHM=HS256
JWT_EXPIRES_MINUTES=720 JWT_EXPIRES_MINUTES=720
JWT_ISSUER=orcid-sword-backend
JWT_AUDIENCE=orcid-sword-frontend
# ============================================================
# API key máquina-a-máquina (scheduler interno)
# Genera con: `python -c "import secrets;print(secrets.token_urlsafe(48))"`
# ============================================================
API_KEY_NAME=X-API-Key
API_KEY_VALUE=replace_with_a_strong_random_value_min_24_chars
# ============================================================
# ORCID OAuth 3-legged (authorization code) # ORCID OAuth 3-legged (authorization code)
# Debe coincidir exactamente con el redirect URI configurado en tu app ORCID. # ============================================================
ORCID_CLIENT_ID=APP-XXXXXXXXXXXXXXXX
ORCID_CLIENT_SECRET=replace_me
ORCID_REDIRECT_URI=http://localhost:8000/api/auth/orcid/callback ORCID_REDIRECT_URI=http://localhost:8000/api/auth/orcid/callback
ORCID_OAUTH_STATE_ENABLED=true
# ============================================================
# Rate limits (formato slowapi: "<n>/<window>")
# ============================================================
RATE_LIMIT_DEFAULT=60/minute
RATE_LIMIT_AUTH=10/minute
RATE_LIMIT_SEARCH_ANON=5/minute
RATE_LIMIT_SEARCH_AUTH=30/minute
RATE_LIMIT_EXPORT=20/minute
RATE_LIMIT_SYNC=5/minute
# ============================================================
# Tope de tamaños (anti DoS)
# ============================================================
MAX_ORCID_BATCH=25
MAX_PUB_IDS_BATCH=500
MAX_REQUEST_BODY_BYTES=1048576
# ============================================================
# Documentación interactiva (deshabilita en producción si no es necesaria)
# ============================================================
DOCS_ENABLED=true
# ============================================================
# HSTS
# ============================================================
SECURITY_HSTS_SECONDS=31536000
SECURITY_HSTS_INCLUDE_SUBDOMAINS=true
SECURITY_HSTS_PRELOAD=false
+29 -3
View File
@@ -1,10 +1,36 @@
FROM python:3.12-slim FROM python:3.12-slim AS base
ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1 \
PIP_NO_CACHE_DIR=1 \
PIP_DISABLE_PIP_VERSION_CHECK=1
RUN apt-get update \
&& apt-get install -y --no-install-recommends curl \
&& rm -rf /var/lib/apt/lists/*
RUN groupadd --system --gid 1001 app \
&& useradd --system --uid 1001 --gid app --home /app --shell /usr/sbin/nologin app
WORKDIR /app WORKDIR /app
COPY requirements.txt . COPY requirements.txt ./
RUN pip install --no-cache-dir -r requirements.txt RUN pip install --no-cache-dir -r requirements.txt
COPY app ./app COPY app ./app
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] RUN chown -R app:app /app
USER app
EXPOSE 8000
HEALTHCHECK --interval=30s --timeout=5s --start-period=15s --retries=3 \
CMD curl -fsS http://127.0.0.1:8000/health || exit 1
CMD ["uvicorn", "app.main:app", \
"--host", "0.0.0.0", \
"--port", "8000", \
"--proxy-headers", \
"--forwarded-allow-ips", "*", \
"--no-server-header"]
+74 -38
View File
@@ -1,64 +1,68 @@
import logging
import httpx import httpx
import os from fastapi import APIRouter, Depends, HTTPException, Request, status
from pathlib import Path from fastapi.responses import JSONResponse, RedirectResponse
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import RedirectResponse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.rate_limit import limiter
from app.db.models import Researcher from app.db.models import Researcher
from app.db.session import get_db from app.db.session import get_db
from app.schema.auth import OrcidLoginResponseSchema from app.schema.auth import OrcidLoginResponseSchema
from app.security.jwt import create_access_token from app.security.jwt import create_access_token
from app.security.oauth_state import (
attach_state_cookie,
clear_state_cookie,
generate_state,
validate_state,
)
from app.services.orcid_client import ORCIDClient from app.services.orcid_client import ORCIDClient
from app.utils.orcid_validator import is_valid_orcid from app.utils.orcid_validator import is_valid_orcid
# Asegura que al ejecutar `uvicorn` local también se carga `backend/.env`.
_ENV_PATH = Path(__file__).resolve().parents[2] / ".env"
load_dotenv(dotenv_path=_ENV_PATH, override=False)
router = APIRouter(prefix="/auth", tags=["auth"]) router = APIRouter(prefix="/auth", tags=["auth"])
logger = logging.getLogger("app.auth")
def _extract_display_name(record: dict) -> str | None: def _extract_display_name(record: dict) -> str | None:
person = (record or {}).get("person") or {} person = (record or {}).get("person") or {}
name = person.get("name") or {} name = person.get("name") or {}
given = ((name.get("given-names") or {}).get("value")) if isinstance(name.get("given-names"), dict) else None given_obj = name.get("given-names")
family = ((name.get("family-name") or {}).get("value")) if isinstance(name.get("family-name"), dict) else None family_obj = name.get("family-name")
full = " ".join([p for p in [given, family] if p]) given = given_obj.get("value") if isinstance(given_obj, dict) else None
family = family_obj.get("value") if isinstance(family_obj, dict) else None
full = " ".join(p for p in [given, family] if p)
return full or None return full or None
def _orcid_redirect_uri() -> str:
    """Return the OAuth redirect URI configured in ``settings.ORCID_REDIRECT_URI``."""
    # Must match exactly the redirect URI registered for the ORCID integration.
    redirect_uri = settings.ORCID_REDIRECT_URI
    return redirect_uri
def _complete_oauth_login(*, code: str, db: Session) -> OrcidLoginResponseSchema: def _complete_oauth_login(*, code: str, db: Session) -> OrcidLoginResponseSchema:
""" """
Completa el login OAuth: 1) Intercambia el `code` con ORCID (server-side).
1) intercambio del `code` en ORCID (server-side) 2) Crea/actualiza el investigador.
2) crea/actualiza el investigador 3) Emite el JWT propio.
3) emite nuestro JWT
""" """
if not code: if not code or len(code) > 256:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Missing ORCID authorization code") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid ORCID authorization code")
client = ORCIDClient() client = ORCIDClient()
redirect_uri = _orcid_redirect_uri()
try: try:
token_data = client.exchange_authorization_code(code=code, redirect_uri=redirect_uri) token_data = client.exchange_authorization_code(code=code, redirect_uri=_orcid_redirect_uri())
except httpx.HTTPStatusError as exc: except httpx.HTTPStatusError as exc:
logger.warning("ORCID token exchange failed: %s", exc.response.status_code)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, status_code=status.HTTP_502_BAD_GATEWAY,
detail=f"ORCID token error ({exc.response.status_code})", detail="ORCID token exchange failed",
) ) from exc
except httpx.TimeoutException: except httpx.TimeoutException as exc:
raise HTTPException(status_code=status.HTTP_504_GATEWAY_TIMEOUT, detail="ORCID timeout") raise HTTPException(status_code=status.HTTP_504_GATEWAY_TIMEOUT, detail="ORCID timeout") from exc
except Exception: except Exception as exc:
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="ORCID unavailable") logger.exception("Unexpected error during ORCID token exchange")
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="ORCID unavailable") from exc
orcid_id = (token_data.get("orcid") or "").strip() orcid_id = (token_data.get("orcid") or "").strip()
if not is_valid_orcid(orcid_id): if not is_valid_orcid(orcid_id):
@@ -66,7 +70,6 @@ def _complete_oauth_login(*, code: str, db: Session) -> OrcidLoginResponseSchema
display_name = token_data.get("name") display_name = token_data.get("name")
if not display_name: if not display_name:
# Fallback si ORCID no devuelve `name` en el token response.
try: try:
record = client.fetch_record(orcid_id) record = client.fetch_record(orcid_id)
display_name = _extract_display_name(record) display_name = _extract_display_name(record)
@@ -89,21 +92,54 @@ def _complete_oauth_login(*, code: str, db: Session) -> OrcidLoginResponseSchema
return OrcidLoginResponseSchema(access_token=token) return OrcidLoginResponseSchema(access_token=token)
def complete_oauth_login_response(
    *, request: Request, code: str, state: str | None, db: Session
) -> JSONResponse:
    """
    Validate the OAuth ``state``, complete the login and clear the state cookie.

    The ``JSONResponse`` is built here (instead of returning the schema) so the
    state cookie can be cleared on the very response sent back to the client.

    Args:
        request: Incoming callback request, handed to ``validate_state``.
        code: Authorization code received from ORCID.
        state: ``state`` query parameter echoed back by ORCID, if any.
        db: Database session used by ``_complete_oauth_login``.

    Returns:
        JSON response carrying the serialized login payload.

    Raises:
        HTTPException: Propagated from ``_complete_oauth_login``; presumably
            also from ``validate_state`` on a missing or mismatched state
            (confirm in ``app.security.oauth_state``).
    """
    # State is checked first so a request with a bad state never triggers
    # the code exchange.
    validate_state(request, state)
    payload = _complete_oauth_login(code=code, db=db)
    # mode="json" coerces every field to a JSON-safe primitive; JSONResponse
    # relies on the standard JSON encoder and would fail on richer types.
    response = JSONResponse(content=payload.model_dump(mode="json"))
    clear_state_cookie(response)
    return response
# ---------------------------------------------------------
# ENDPOINT 1: Start the 3-legged OAuth flow towards ORCID
# ---------------------------------------------------------
@router.get("/orcid/authorize")
@limiter.limit(settings.RATE_LIMIT_AUTH)
def authorize_orcid(request: Request):
    """
    Redirect the caller to the ORCID authorization URL.

    When ``settings.ORCID_OAUTH_STATE_ENABLED`` is set, a fresh ``state`` is
    generated, sent to ORCID and attached to the redirect as a cookie so the
    callback can validate it (anti-CSRF).
    """
    client = ORCIDClient()

    state = None
    if settings.ORCID_OAUTH_STATE_ENABLED:
        state = generate_state()

    # Only the authenticated iD is needed, hence the minimal "/authenticate" scope.
    target_url = client.build_authorize_url(
        redirect_uri=_orcid_redirect_uri(),
        scope="/authenticate",
        state=state,
    )

    redirect = RedirectResponse(target_url)
    if state:
        attach_state_cookie(redirect, state)
    return redirect
# ---------------------------------------------------------
# ENDPOINT 2: 3-legged OAuth callback from ORCID
# ---------------------------------------------------------
@router.get("/orcid/callback", response_model=OrcidLoginResponseSchema)
@limiter.limit(settings.RATE_LIMIT_AUTH)
def orcid_callback(
    request: Request,
    code: str,
    state: str | None = None,
    db: Session = Depends(get_db),
):
    """
    Handle the redirect back from ORCID.

    Delegates to ``complete_oauth_login_response``, which validates ``state``,
    exchanges ``code`` and returns the JSON response with the access token.
    """
    login_response = complete_oauth_login_response(
        request=request,
        code=code,
        state=state,
        db=db,
    )
    return login_response
+120 -79
View File
@@ -1,167 +1,208 @@
from fastapi import APIRouter, Depends, HTTPException from typing import Iterable, List
from fastapi.responses import Response
from sqlalchemy.orm import Session
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Request
from fastapi.responses import Response
from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.rate_limit import limiter
from app.db.models import Publication, PublicationDownload, Researcher
from app.db.session import get_db from app.db.session import get_db
from app.db.models import Publication, Researcher, PublicationDownload
from app.security.api_key import get_api_key_optional from app.security.api_key import get_api_key_optional
from app.security.jwt import get_optional_current_researcher from app.security.jwt import get_optional_current_researcher
from app.services.sword_generator import SWORDGenerator from app.services.sword_generator import SWORDGenerator
from app.services.zip_generator import ZIPGenerator from app.services.zip_generator import ZIPGenerator
from app.utils.orcid_validator import ORCID_PATTERN, is_valid_orcid
router = APIRouter(prefix="/export") router = APIRouter(prefix="/export")
def _ensure_credentials(api_key: str | None, current: Researcher | None) -> None:
    """
    Require at least one credential: an API key or a logged-in researcher.

    Raises:
        HTTPException: 401 when both ``api_key`` and ``current`` are missing.
    """
    if api_key or current:
        return
    raise HTTPException(status_code=401, detail="Authentication required")
def _record_downloads(db: Session, current: Researcher, pubs: Iterable[Publication]) -> None:
    """
    Record that ``current`` downloaded each publication in ``pubs``.

    Existing (researcher_id, publication_id) markers are resolved with a
    single query; only the missing ones are inserted and committed.

    Args:
        db: Active database session.
        current: Authenticated researcher performing the download.
        pubs: Publications included in the export.

    Raises:
        Exception: Whatever the insert/commit raises, re-raised after the
            session has been rolled back.
    """
    # dict.fromkeys de-duplicates while preserving order, so a publication
    # repeated in `pubs` can never yield two markers in the same batch.
    pub_ids = list(dict.fromkeys(p.id for p in pubs))
    if not pub_ids:
        return

    already_recorded = {
        row[0]
        for row in (
            db.query(PublicationDownload.publication_id)
            .filter(
                PublicationDownload.researcher_id == current.id,
                PublicationDownload.publication_id.in_(pub_ids),
            )
            .all()
        )
    }

    missing = [
        PublicationDownload(researcher_id=current.id, publication_id=pid)
        for pid in pub_ids
        if pid not in already_recorded
    ]
    if not missing:
        return

    try:
        db.add_all(missing)
        db.commit()
    except Exception:
        # A failed flush/commit leaves the session unusable until it is
        # rolled back; restore it before propagating the error.
        db.rollback()
        raise
def _validate_pub_ids(pub_ids: List[UUID]) -> List[UUID]:
    """
    Enforce the batch-size cap on a list of publication IDs.

    Returns:
        ``pub_ids`` unchanged when it holds at most
        ``settings.MAX_PUB_IDS_BATCH`` items.

    Raises:
        HTTPException: 413 when the list exceeds the cap.
    """
    limit = settings.MAX_PUB_IDS_BATCH
    if len(pub_ids) <= limit:
        return pub_ids
    raise HTTPException(status_code=413, detail="Too many publication IDs")
def _raise_clear_error_if_researcher_id_was_used(db: Session, pub_ids: List[UUID]) -> None:
    """
    Give an explicit hint when a researcher UUID was sent as a publication ID.

    Only acts when exactly one ID was supplied; if that ID matches a
    researcher, the error points the client to the per-researcher endpoints.

    Raises:
        HTTPException: 400 when the single ID belongs to a researcher.
    """
    if len(pub_ids) != 1:
        return

    (candidate_id,) = pub_ids
    researcher = db.query(Researcher).filter(Researcher.id == candidate_id).first()
    if not researcher:
        return

    message = (
        "The provided UUID belongs to a researcher, not a publication. "
        "Use publication IDs for this endpoint, or call "
        f"/api/export/sword/researcher/{researcher.orcid_id} "
        f"(or /api/export/zip/researcher/{researcher.orcid_id})."
    )
    raise HTTPException(status_code=400, detail=message)
# ---------------------------------------------------------
# ENDPOINT 1: SWORD export for multiple publications
# ---------------------------------------------------------
@router.post("/sword/publications")
@limiter.limit(settings.RATE_LIMIT_EXPORT)
async def export_multiple_sword(
    request: Request,
    pub_ids: List[UUID] = Body(..., min_length=1, max_length=settings.MAX_PUB_IDS_BATCH),
    db: Session = Depends(get_db),
    api_key: str | None = Depends(get_api_key_optional),
    current: Researcher | None = Depends(get_optional_current_researcher),
):
    """
    Export the requested publications as a single SWORD XML feed.

    Requires an API key or a logged-in researcher. Responds 404 when none of
    the IDs matches a publication (400 when the single ID sent is actually a
    researcher's). Downloads are recorded only for a logged-in researcher.
    """
    # NOTE(review): the handler is async but the Session calls below look
    # synchronous; other routers use plain `def` — confirm this is intended.
    _ensure_credentials(api_key, current)
    _validate_pub_ids(pub_ids)

    publications = db.query(Publication).filter(Publication.id.in_(pub_ids)).all()
    if not publications:
        _raise_clear_error_if_researcher_id_was_used(db, pub_ids)
        raise HTTPException(status_code=404, detail="No publications found")

    # The feed is attributed to the owner of the first publication returned.
    first_publication = publications[0]
    owner = db.query(Researcher).filter_by(id=first_publication.researcher_id).first()
    feed = SWORDGenerator.generate_feed_xml(owner, publications)

    if current:
        _record_downloads(db, current, publications)

    return Response(content=feed, media_type="application/xml")
# ---------------------------------------------------------
# ENDPOINT 2: SWORD export by researcher
# ---------------------------------------------------------
@router.get("/sword/researcher/{orcid_id}")
@limiter.limit(settings.RATE_LIMIT_EXPORT)
async def export_researcher_sword(
    request: Request,
    orcid_id: str = Path(min_length=19, max_length=19, pattern=ORCID_PATTERN),
    db: Session = Depends(get_db),
    api_key: str | None = Depends(get_api_key_optional),
    current: Researcher | None = Depends(get_optional_current_researcher),
):
    """
    Export every stored publication of a researcher as a SWORD XML feed.

    Requires an API key or a logged-in researcher. Responds 400 for an
    invalid ORCID iD and 404 when the researcher or their publications are
    not found. Downloads are recorded only for a logged-in researcher.
    """
    # NOTE(review): the handler is async but the Session calls below look
    # synchronous; other routers use plain `def` — confirm this is intended.
    _ensure_credentials(api_key, current)

    # The path pattern only constrains the shape; is_valid_orcid has the final say.
    if not is_valid_orcid(orcid_id):
        raise HTTPException(status_code=400, detail="Invalid ORCID iD")

    owner = db.query(Researcher).filter_by(orcid_id=orcid_id).first()
    if not owner:
        raise HTTPException(status_code=404, detail="Researcher not found")

    publications = db.query(Publication).filter_by(researcher_id=owner.id).all()
    if not publications:
        raise HTTPException(status_code=404, detail="No publications found for this researcher")

    feed = SWORDGenerator.generate_feed_xml(owner, publications)

    if current:
        _record_downloads(db, current, publications)

    return Response(content=feed, media_type="application/xml")
# ---------------------------------------------------------
# ENDPOINT 3: ZIP export for multiple publications
# ---------------------------------------------------------
@router.post("/zip/publications")
@limiter.limit(settings.RATE_LIMIT_EXPORT)
async def export_multiple_zip(
    request: Request,
    pub_ids: List[UUID] = Body(..., min_length=1, max_length=settings.MAX_PUB_IDS_BATCH),
    db: Session = Depends(get_db),
    api_key: str | None = Depends(get_api_key_optional),
    current: Researcher | None = Depends(get_optional_current_researcher),
):
    """
    Export the requested publications as a ZIP archive.

    Requires an API key or a logged-in researcher. Responds 404 when none of
    the IDs matches a publication (400 when the single ID sent is actually a
    researcher's). Downloads are recorded only for a logged-in researcher.
    """
    # NOTE(review): the handler is async but the Session calls below look
    # synchronous; other routers use plain `def` — confirm this is intended.
    _ensure_credentials(api_key, current)
    _validate_pub_ids(pub_ids)

    publications = db.query(Publication).filter(Publication.id.in_(pub_ids)).all()
    if not publications:
        _raise_clear_error_if_researcher_id_was_used(db, pub_ids)
        raise HTTPException(status_code=404, detail="No publications found")

    # The archive is attributed to the owner of the first publication returned.
    first_publication = publications[0]
    owner = db.query(Researcher).filter_by(id=first_publication.researcher_id).first()
    archive = ZIPGenerator.generate_zip(owner, publications)

    if current:
        _record_downloads(db, current, publications)

    return Response(content=archive, media_type="application/zip")
# ---------------------------------------------------------
# ENDPOINT 4: ZIP export by researcher
# ---------------------------------------------------------
@router.get("/zip/researcher/{orcid_id}")
@limiter.limit(settings.RATE_LIMIT_EXPORT)
async def export_researcher_zip(
    request: Request,
    orcid_id: str = Path(min_length=19, max_length=19, pattern=ORCID_PATTERN),
    db: Session = Depends(get_db),
    api_key: str | None = Depends(get_api_key_optional),
    current: Researcher | None = Depends(get_optional_current_researcher),
):
    """
    Export every stored publication of a researcher as a ZIP archive.

    Requires an API key or a logged-in researcher. Responds 400 for an
    invalid ORCID iD and 404 when the researcher or their publications are
    not found. Downloads are recorded only for a logged-in researcher.
    """
    # NOTE(review): the handler is async but the Session calls below look
    # synchronous; other routers use plain `def` — confirm this is intended.
    _ensure_credentials(api_key, current)

    # The path pattern only constrains the shape; is_valid_orcid has the final say.
    if not is_valid_orcid(orcid_id):
        raise HTTPException(status_code=400, detail="Invalid ORCID iD")

    owner = db.query(Researcher).filter_by(orcid_id=orcid_id).first()
    if not owner:
        raise HTTPException(status_code=404, detail="Researcher not found")

    publications = db.query(Publication).filter_by(researcher_id=owner.id).all()
    if not publications:
        raise HTTPException(status_code=404, detail="No publications found for this researcher")

    archive = ZIPGenerator.generate_zip(owner, publications)

    if current:
        _record_downloads(db, current, publications)

    return Response(content=archive, media_type="application/zip")
+52 -36
View File
@@ -2,11 +2,14 @@ from datetime import datetime
from typing import List from typing import List
import httpx import httpx
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException, Path, Request
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.db.models import Publication, Researcher from app.core.config import settings
from app.core.rate_limit import limiter
from app.db.models import Publication, PublicationDownload, Researcher
from app.db.session import get_db from app.db.session import get_db
from app.schema.publication import PublicationSchema
from app.schema.researcher import ( from app.schema.researcher import (
ResearcherBatchSearchRequestSchema, ResearcherBatchSearchRequestSchema,
ResearcherBatchSearchResponseSchema, ResearcherBatchSearchResponseSchema,
@@ -14,18 +17,15 @@ from app.schema.researcher import (
ResearcherStatsSchema, ResearcherStatsSchema,
ResearcherWithPublicationsSchema, ResearcherWithPublicationsSchema,
) )
from app.services.normalizer import PublicationNormalizer
from app.services.orcid_client import get_display_name, get_works_summary, get_work_detail
from app.schema.publication import PublicationSchema
from app.db.models import PublicationDownload
from app.security.jwt import get_optional_current_researcher from app.security.jwt import get_optional_current_researcher
from app.services.normalizer import PublicationNormalizer
from app.services.orcid_client import get_display_name, get_work_detail, get_works_summary
from app.utils.orcid_validator import ORCID_PATTERN, is_valid_orcid
router = APIRouter(prefix="/researchers", tags=["researchers"]) router = APIRouter(prefix="/researchers", tags=["researchers"])
# ---------------------------------------------------------
# Función auxiliar: detectar si una publicación ha cambiado
# ---------------------------------------------------------
def publication_changed(existing: Publication, data: dict) -> bool: def publication_changed(existing: Publication, data: dict) -> bool:
fields = [ fields = [
"title", "subtitle", "type", "journal", "title", "subtitle", "type", "journal",
@@ -33,18 +33,13 @@ def publication_changed(existing: Publication, data: dict) -> bool:
"doi", "url", "short_description", "doi", "url", "short_description",
"citation_type", "citation_value", "citation_type", "citation_value",
"language_code", "country", "language_code", "country",
"external_ids", "contributors" "external_ids", "contributors",
] ]
return any(getattr(existing, f) != data[f] for f in fields)
for f in fields:
if getattr(existing, f) != data[f]:
return True
return False
def build_researcher_stats(publications: list) -> ResearcherStatsSchema: def build_researcher_stats(publications: list) -> ResearcherStatsSchema:
publication_types: dict[str, int] = {} publication_types: dict[str, int] = {}
for publication in publications: for publication in publications:
pub_type = getattr(publication, "type", None) or "unknown" pub_type = getattr(publication, "type", None) or "unknown"
publication_types[pub_type] = publication_types.get(pub_type, 0) + 1 publication_types[pub_type] = publication_types.get(pub_type, 0) + 1
@@ -98,7 +93,7 @@ def _upsert_researcher_publications(
"doi", "url", "short_description", "doi", "url", "short_description",
"citation_type", "citation_value", "citation_type", "citation_value",
"language_code", "country", "language_code", "country",
"external_ids", "contributors" "external_ids", "contributors",
]: ]:
setattr(existing, field, data[field]) setattr(existing, field, data[field])
existing.last_modified = datetime.utcnow() existing.last_modified = datetime.utcnow()
@@ -142,12 +137,17 @@ def _decorate_downloaded_by_me(
out: List[PublicationSchema] = [] out: List[PublicationSchema] = []
for p in publications: for p in publications:
out.append( out.append(
PublicationSchema.model_validate(p).model_copy(update={"downloaded_by_me": p.id in downloaded_ids}) PublicationSchema.model_validate(p).model_copy(
update={"downloaded_by_me": p.id in downloaded_ids}
)
) )
return out return out
def build_search_response(orcid_id: str, db: Session, current: Researcher | None) -> ResearcherWithPublicationsSchema: def build_search_response(orcid_id: str, db: Session, current: Researcher | None) -> ResearcherWithPublicationsSchema:
if not is_valid_orcid(orcid_id):
raise HTTPException(status_code=400, detail="Invalid ORCID iD")
researcher = db.query(Researcher).filter(Researcher.orcid_id == orcid_id).first() researcher = db.query(Researcher).filter(Researcher.orcid_id == orcid_id).first()
if not researcher: if not researcher:
researcher = Researcher( researcher = Researcher(
@@ -159,10 +159,6 @@ def build_search_response(orcid_id: str, db: Session, current: Researcher | None
db.add(researcher) db.add(researcher)
db.flush() db.flush()
# Si todavía no conocemos el nombre del investigador (por ejemplo, recién
# creado al sincronizarse desde el buscador), lo resolvemos contra el
# endpoint `/record` público de ORCID. No tocamos un nombre ya existente
# para no pisar valores establecidos por el flujo de autenticación.
if not researcher.name: if not researcher.name:
display_name = get_display_name(orcid_id) display_name = get_display_name(orcid_id)
if display_name: if display_name:
@@ -185,10 +181,18 @@ def build_search_response(orcid_id: str, db: Session, current: Researcher | None
# --------------------------------------------------------- # ---------------------------------------------------------
# ENDPOINT 1: SEARCH + SYNC (sin contadores) # ENDPOINT 1: SEARCH + SYNC
# --------------------------------------------------------- # ---------------------------------------------------------
@router.post("/search", response_model=ResearcherBatchSearchResponseSchema, response_model_exclude_none=True)
@router.post(
"/search",
response_model=ResearcherBatchSearchResponseSchema,
response_model_exclude_none=True,
)
@limiter.limit(settings.RATE_LIMIT_SEARCH_ANON)
def search_and_sync_researchers( def search_and_sync_researchers(
request: Request,
payload: ResearcherBatchSearchRequestSchema, payload: ResearcherBatchSearchRequestSchema,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current: Researcher | None = Depends(get_optional_current_researcher), current: Researcher | None = Depends(get_optional_current_researcher),
@@ -196,26 +200,33 @@ def search_and_sync_researchers(
results: List[ResearcherWithPublicationsSchema] = [] results: List[ResearcherWithPublicationsSchema] = []
errors: List[ResearcherSearchErrorSchema] = [] errors: List[ResearcherSearchErrorSchema] = []
# Evita llamadas duplicadas a ORCID conservando el orden de entrada.
unique_orcid_ids = list(dict.fromkeys(payload.orcid_ids)) unique_orcid_ids = list(dict.fromkeys(payload.orcid_ids))
for orcid_id in unique_orcid_ids: for orcid_id in unique_orcid_ids:
try: try:
results.append(build_search_response(orcid_id, db, current)) results.append(build_search_response(orcid_id, db, current))
except HTTPException as exc:
db.rollback()
errors.append(
ResearcherSearchErrorSchema(
orcid_id=orcid_id,
detail=str(exc.detail),
)
)
except httpx.HTTPStatusError as exc: except httpx.HTTPStatusError as exc:
db.rollback() db.rollback()
errors.append( errors.append(
ResearcherSearchErrorSchema( ResearcherSearchErrorSchema(
orcid_id=orcid_id, orcid_id=orcid_id,
detail=f"ORCID devolvió {exc.response.status_code} para {orcid_id}.", detail=f"ORCID returned {exc.response.status_code}",
) )
) )
except Exception as exc: except Exception:
db.rollback() db.rollback()
errors.append( errors.append(
ResearcherSearchErrorSchema( ResearcherSearchErrorSchema(
orcid_id=orcid_id, orcid_id=orcid_id,
detail=str(exc), detail="Unexpected error while processing ORCID iD",
) )
) )
@@ -228,14 +239,24 @@ def search_and_sync_researchers(
# --------------------------------------------------------- # ---------------------------------------------------------
# ENDPOINT 2: SYNC COMPLETO (con contadores + status) # ENDPOINT 2: SYNC COMPLETO (requiere autenticación)
# --------------------------------------------------------- # ---------------------------------------------------------
@router.post("/{orcid_id}/sync", response_model=ResearcherWithPublicationsSchema, response_model_exclude_none=True)
@router.post(
"/{orcid_id}/sync",
response_model=ResearcherWithPublicationsSchema,
response_model_exclude_none=True,
)
@limiter.limit(settings.RATE_LIMIT_SYNC)
def sync_researcher( def sync_researcher(
orcid_id: str, request: Request,
orcid_id: str = Path(min_length=19, max_length=19, pattern=ORCID_PATTERN),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current: Researcher | None = Depends(get_optional_current_researcher), current: Researcher | None = Depends(get_optional_current_researcher),
): ):
if not is_valid_orcid(orcid_id):
raise HTTPException(status_code=400, detail="Invalid ORCID iD")
researcher = db.query(Researcher).filter_by(orcid_id=orcid_id).first() researcher = db.query(Researcher).filter_by(orcid_id=orcid_id).first()
if not researcher: if not researcher:
raise HTTPException(status_code=404, detail="Researcher not found") raise HTTPException(status_code=404, detail="Researcher not found")
@@ -244,7 +265,6 @@ def sync_researcher(
groups = works.get("group", []) groups = works.get("group", [])
publications_output = [] publications_output = []
new_count = 0 new_count = 0
updated_count = 0 updated_count = 0
unchanged_count = 0 unchanged_count = 0
@@ -277,21 +297,17 @@ def sync_researcher(
if existing: if existing:
if publication_changed(existing, data): if publication_changed(existing, data):
# updated
for field in data: for field in data:
setattr(existing, field, data[field]) setattr(existing, field, data[field])
existing.last_modified = datetime.utcnow() existing.last_modified = datetime.utcnow()
existing.status = "updated" existing.status = "updated"
updated_count += 1 updated_count += 1
else: else:
# unchanged
existing.status = "unchanged" existing.status = "unchanged"
unchanged_count += 1 unchanged_count += 1
pub = existing pub = existing
else: else:
# new
pub = Publication( pub = Publication(
researcher_id=researcher.id, researcher_id=researcher.id,
**data, **data,
View File
+35
View File
@@ -0,0 +1,35 @@
"""
Middleware que limita el tamaño máximo del cuerpo de la petición.
Evita ataques de agotamiento de memoria/CPU enviando bodies enormes a
endpoints POST. Se aplica antes de que FastAPI deserialice el JSON.
"""
from __future__ import annotations
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
class BodySizeLimitMiddleware(BaseHTTPMiddleware):
    """Reject requests whose declared ``Content-Length`` exceeds a byte limit.

    Only the ``Content-Length`` header is inspected; the body itself is never
    read here. Requests that carry no ``Content-Length`` header are forwarded
    unchanged, so this check alone does not cap bodies sent without one.
    """

    def __init__(self, app, *, max_bytes: int):
        """Store the limit.

        Args:
            app: Downstream ASGI application.
            max_bytes: Largest accepted ``Content-Length`` value, in bytes.
        """
        super().__init__(app)
        self._max_bytes = max_bytes

    async def dispatch(self, request: Request, call_next) -> Response:
        """Validate ``Content-Length`` and either short-circuit or forward.

        Returns:
            A 400 JSON response when the header is not a plain non-negative
            decimal integer, a 413 JSON response when it exceeds the limit,
            otherwise whatever the downstream application returns.
        """
        declared = request.headers.get("content-length")
        if declared is not None:
            declared = declared.strip()
            try:
                # ``int()`` on its own also accepts "-1", "+5" or "1_000",
                # none of which is a valid Content-Length; insist on ASCII
                # digits before converting.
                if not (declared.isascii() and declared.isdigit()):
                    raise ValueError("malformed Content-Length")
                # ``int()`` can still raise ValueError for absurdly long digit
                # strings, which is why the conversion stays inside the try.
                too_large = int(declared) > self._max_bytes
            except ValueError:
                return JSONResponse(
                    status_code=400,
                    content={"detail": "Invalid Content-Length header"},
                )
            if too_large:
                return JSONResponse(
                    status_code=413,
                    content={"detail": "Request body too large"},
                )

        return await call_next(request)
+183
View File
@@ -0,0 +1,183 @@
"""
Configuración tipada y validada del backend.
Centraliza la lectura de variables de entorno, valida secretos críticos al
arranque y evita fallbacks inseguros (p. ej. JWT_SECRET="change_me") en
entornos productivos.
"""
from __future__ import annotations
import os
from functools import lru_cache
from pathlib import Path
from typing import List, Literal
from urllib.parse import urlparse
from dotenv import load_dotenv
from pydantic import Field, field_validator, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
_ENV_PATH = Path(__file__).resolve().parents[2] / ".env"
load_dotenv(dotenv_path=_ENV_PATH, override=False)
def _split_csv(value: str | List[str] | None) -> List[str]:
    """Normalise a comma-separated setting into a list of entries.

    ``None`` yields an empty list. A list is taken item by item (each item is
    converted with ``str``); any other value is split on commas. Entries are
    stripped of surrounding whitespace, blank entries are dropped, and trailing
    slashes are removed from the entries that remain.
    """
    if value is None:
        return []

    raw_items = value if isinstance(value, list) else value.split(",")

    entries: List[str] = []
    for raw in raw_items:
        text = str(raw).strip()
        if not text:
            continue
        # The blank check happens before the slash removal, so an entry made
        # only of slashes is kept as an empty string.
        entries.append(text.rstrip("/"))
    return entries
class Settings(BaseSettings):
    """Typed application settings loaded from the environment / ``.env`` file.

    ``DATABASE_URL``, ``JWT_SECRET``, ``API_KEY_VALUE``, ``ORCID_CLIENT_ID``
    and ``ORCID_CLIENT_SECRET`` have no default and must always be provided.

    When ``ENVIRONMENT`` is ``"production"`` the validator additionally
    requires:
    - ``JWT_SECRET`` is not a known weak value and has at least 32 characters.
    - ``CORS_ALLOWED_ORIGINS`` is non-empty and does not contain ``"*"``.
    - ``API_KEY_VALUE`` has at least 24 characters.
    - ``TRUSTED_HOSTS`` does not contain the ``"*"`` wildcard.
    """

    model_config = SettingsConfigDict(
        env_file=str(_ENV_PATH),
        env_file_encoding="utf-8",
        extra="ignore",
        case_sensitive=False,
    )

    # --- Environment ---
    ENVIRONMENT: Literal["development", "staging", "production"] = "development"
    DEBUG: bool = False

    # --- Database / cache / base URL ---
    DATABASE_URL: str = Field(...)
    REDIS_URL: str | None = None
    BASE_URL: str = "http://localhost:8000/api"

    # --- JWT ---
    JWT_SECRET: str = Field(...)
    JWT_ALGORITHM: str = "HS256"
    JWT_EXPIRES_MINUTES: int = 720
    JWT_ISSUER: str = "orcid-sword-backend"
    JWT_AUDIENCE: str = "orcid-sword-frontend"

    # --- API key ---
    API_KEY_NAME: str = "X-API-Key"
    API_KEY_VALUE: str = Field(...)

    # --- ORCID OAuth ---
    ORCID_CLIENT_ID: str = Field(...)
    ORCID_CLIENT_SECRET: str = Field(...)
    ORCID_REDIRECT_URI: str = "http://localhost:8000/api/auth/orcid/callback"
    ORCID_OAUTH_STATE_ENABLED: bool = True
    ORCID_OAUTH_STATE_COOKIE: str = "orcid_oauth_state"
    ORCID_OAUTH_STATE_TTL_SECONDS: int = 600

    # --- CORS / trusted hosts (comma-separated strings; see the properties) ---
    CORS_ALLOWED_ORIGINS: str = ""
    TRUSTED_HOSTS: str = "*"

    # --- Rate limits ---
    RATE_LIMIT_DEFAULT: str = "60/minute"
    RATE_LIMIT_AUTH: str = "10/minute"
    RATE_LIMIT_SEARCH_ANON: str = "5/minute"
    RATE_LIMIT_SEARCH_AUTH: str = "30/minute"
    RATE_LIMIT_EXPORT: str = "20/minute"
    RATE_LIMIT_SYNC: str = "5/minute"

    # --- Request size caps ---
    MAX_ORCID_BATCH: int = 25
    MAX_PUB_IDS_BATCH: int = 500
    MAX_REQUEST_BODY_BYTES: int = 1_048_576  # 1 MiB

    # --- API docs ---
    DOCS_ENABLED: bool = True

    # --- HSTS ---
    SECURITY_HSTS_SECONDS: int = 31_536_000
    SECURITY_HSTS_INCLUDE_SUBDOMAINS: bool = True
    SECURITY_HSTS_PRELOAD: bool = False

    @model_validator(mode="after")
    def _validate_security(self) -> "Settings":
        """Fail fast on insecure configuration.

        Raises:
            ValueError: If a production-only requirement is not met, or if any
                configured CORS origin is not an ``http``/``https`` URL with a
                host part.
        """
        cors_origins = self.cors_allowed_origins
        trusted_hosts = self.trusted_hosts

        if self.ENVIRONMENT == "production":
            weak = {"change_me", "changeme", "secret", "password", ""}
            if self.JWT_SECRET.strip().lower() in weak:
                raise ValueError(
                    "JWT_SECRET es débil o está sin configurar. "
                    "Define un secreto aleatorio fuerte (>= 32 bytes)."
                )
            if len(self.JWT_SECRET) < 32:
                raise ValueError(
                    "JWT_SECRET debe tener al menos 32 caracteres en producción."
                )
            if "*" in cors_origins:
                raise ValueError(
                    "CORS_ALLOWED_ORIGINS no puede contener '*' en producción."
                )
            if not cors_origins:
                raise ValueError(
                    "CORS_ALLOWED_ORIGINS debe definirse explícitamente en producción."
                )
            if not self.API_KEY_VALUE or len(self.API_KEY_VALUE) < 24:
                raise ValueError(
                    "API_KEY_VALUE debe tener al menos 24 caracteres en producción."
                )
            # Reject the wildcard wherever it appears: a list such as
            # ["*", "example.com"] still contains the allow-any entry, which an
            # exact comparison against ["*"] would let through.
            if "*" in trusted_hosts:
                raise ValueError(
                    "TRUSTED_HOSTS debe definirse explícitamente en producción."
                )

        # NOTE(review): assumed to run in every environment, not only in
        # production; the original indentation was not recoverable -- confirm.
        for origin in cors_origins:
            parsed = urlparse(origin)
            if parsed.scheme not in {"http", "https"} or not parsed.netloc:
                raise ValueError(f"Origen CORS inválido: {origin!r}")

        return self

    @property
    def is_production(self) -> bool:
        """Whether ``ENVIRONMENT`` is ``"production"``."""
        return self.ENVIRONMENT == "production"

    @property
    def cors_allowed_origins(self) -> List[str]:
        """``CORS_ALLOWED_ORIGINS`` parsed into a list by ``_split_csv``."""
        return _split_csv(self.CORS_ALLOWED_ORIGINS)

    @property
    def trusted_hosts(self) -> List[str]:
        """``TRUSTED_HOSTS`` parsed into a list; ``["*"]`` when it is empty."""
        parsed = _split_csv(self.TRUSTED_HOSTS)
        return parsed or ["*"]

    @property
    def docs_url(self) -> str | None:
        """``"/docs"`` when ``DOCS_ENABLED`` is true, otherwise ``None``."""
        return "/docs" if self.DOCS_ENABLED else None

    @property
    def redoc_url(self) -> str | None:
        """``"/redoc"`` when ``DOCS_ENABLED`` is true, otherwise ``None``."""
        return "/redoc" if self.DOCS_ENABLED else None

    @property
    def openapi_url(self) -> str | None:
        """``"/openapi.json"`` when ``DOCS_ENABLED`` is true, otherwise ``None``."""
        return "/openapi.json" if self.DOCS_ENABLED else None
@lru_cache(maxsize=1)
def get_settings() -> Settings:
    """Build the ``Settings`` object once and reuse it.

    The ``lru_cache`` wrapper memoises the result, so later calls return the
    same instance until ``get_settings.cache_clear()`` is called.
    """
    instance = Settings()  # type: ignore[call-arg]
    return instance
settings = get_settings()
def reload_settings_for_tests() -> Settings:
    """Test helper: drop the cached settings and rebuild them.

    Clears the ``get_settings`` cache, rebinds this module's ``settings`` name
    to a freshly built instance and returns it. Names that other modules bound
    earlier with ``from ... import settings`` are not rebound by this call.
    """
    global settings
    get_settings.cache_clear()
    settings = get_settings()
    return settings
__all__ = ["Settings", "get_settings", "reload_settings_for_tests", "settings"]
+67
View File
@@ -0,0 +1,67 @@
"""
Manejadores de errores que NO filtran información sensible.
- En producción, las excepciones no controladas devuelven un mensaje genérico.
- En desarrollo, se incluye `type` para depurar (sin trazas).
- Errores de validación se devuelven con 422 estándar de FastAPI.
"""
from __future__ import annotations
import logging
import uuid
from fastapi import HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from sqlalchemy.exc import SQLAlchemyError
from app.core.config import settings
logger = logging.getLogger("app.error")
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """Render an ``HTTPException`` as ``{"detail": ...}`` with its own status.

    Headers attached to the exception, if any, are forwarded on the response.
    """
    extra_headers = getattr(exc, "headers", None)
    body = {"detail": exc.detail}
    return JSONResponse(
        status_code=exc.status_code,
        content=body,
        headers=extra_headers,
    )
async def validation_exception_handler(
    request: Request, exc: RequestValidationError
) -> JSONResponse:
    """Return a 422 listing validation errors with a reduced set of fields.

    Only ``loc``, ``msg`` and ``type`` of each error are copied into the
    response; every other key of the error entries is left out.
    """
    exposed_keys = ("loc", "msg", "type")
    sanitized = [
        {key: issue.get(key) for key in exposed_keys}
        for issue in exc.errors()
    ]
    return JSONResponse(status_code=422, content={"detail": sanitized})
async def sqlalchemy_exception_handler(
    request: Request, exc: SQLAlchemyError
) -> JSONResponse:
    """Log a database error and answer with a generic 500.

    A random ``error_id`` is written to the log together with the request
    method and path, and echoed in the response body so the two can be
    correlated. No detail from the exception itself is put in the response.
    """
    error_id = str(uuid.uuid4())
    logger.exception("DB error [%s] on %s %s", error_id, request.method, request.url.path)
    body = {"detail": "Database error", "error_id": error_id}
    return JSONResponse(status_code=500, content=body)
async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """Log an unexpected exception and answer with a generic 500.

    The response carries a random ``error_id`` that also appears in the log.
    The exception class name is added under ``type`` only when ``DEBUG`` is
    on and the environment is not production.
    """
    error_id = str(uuid.uuid4())
    logger.exception(
        "Unhandled error [%s] on %s %s", error_id, request.method, request.url.path
    )

    body: dict = {"detail": "Internal server error", "error_id": error_id}
    expose_type = settings.DEBUG and not settings.is_production
    if expose_type:
        body["type"] = type(exc).__name__
    return JSONResponse(status_code=500, content=body)
+28
View File
@@ -0,0 +1,28 @@
"""
Minimal logging configuration.
- Formats records with timestamp, level and logger name.
- Uses the DEBUG level when `settings.DEBUG` is true, and INFO otherwise.
- Raises noisy third-party loggers to WARNING so request headers are not leaked.
"""
from __future__ import annotations
import logging
from app.core.config import settings
_LOG_FORMAT = "%(asctime)s %(levelname)s %(name)s :: %(message)s"
def configure_logging() -> None:
    """Configure the root logger and adjust a few named loggers.

    The root level is DEBUG when ``settings.DEBUG`` is true and INFO
    otherwise. ``httpx``, ``httpcore``, ``sqlalchemy.engine.Engine`` and
    ``uvicorn.access`` are set to WARNING; ``uvicorn.error`` follows the root
    level.
    """
    root_level = logging.INFO
    if settings.DEBUG:
        root_level = logging.DEBUG
    logging.basicConfig(level=root_level, format=_LOG_FORMAT)

    overrides = {
        "httpx": logging.WARNING,
        "httpcore": logging.WARNING,
        "sqlalchemy.engine.Engine": logging.WARNING,
        "uvicorn.error": root_level,
        "uvicorn.access": logging.WARNING,
    }
    for logger_name, logger_level in overrides.items():
        logging.getLogger(logger_name).setLevel(logger_level)
+60
View File
@@ -0,0 +1,60 @@
"""
Rate limiting basado en SlowAPI.
- Usa Redis como backend si `REDIS_URL` está definido (compartido entre workers).
- Cae a memoria local en desarrollo si Redis no está disponible.
- Identifica al cliente por IP y, cuando hay JWT, también por `sub` (orcid_id),
para que un atacante autenticado no comparta cupo con su IP.
"""
from __future__ import annotations
from typing import Optional
from slowapi import Limiter
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from starlette.requests import Request
from starlette.responses import JSONResponse
from app.core.config import settings
def _key_func(request: Request) -> str:
    """Return the rate-limit bucket key for a request.

    - ``user:<orcid_id>`` (or ``user:<id>`` when ``orcid_id`` is missing or
      empty) if ``request.state.researcher`` is set.
    - ``ip:<remote address>`` otherwise.
    """
    # NOTE(review): what populates ``request.state.researcher`` is not visible
    # in this module -- confirm it is set before the limiter evaluates the key.
    user = getattr(request.state, "researcher", None)
    if user is None:
        return f"ip:{get_remote_address(request)}"

    identity = getattr(user, "orcid_id", None) or user.id
    return f"user:{identity}"
def _build_limiter() -> Limiter:
    """Create the shared ``Limiter``.

    ``settings.REDIS_URL`` is passed as the storage URI (it may be ``None``),
    ``settings.RATE_LIMIT_DEFAULT`` is the only default limit, rate-limit
    response headers are disabled and the fixed-window strategy is used.
    """
    backend_uri: Optional[str] = settings.REDIS_URL
    options = {
        "key_func": _key_func,
        "default_limits": [settings.RATE_LIMIT_DEFAULT],
        "storage_uri": backend_uri,
        "headers_enabled": False,
        "strategy": "fixed-window",
    }
    return Limiter(**options)
limiter = _build_limiter()
def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
    """Answer every rate-limit violation with the same 429 response.

    The body is a fixed message and ``Retry-After`` is always ``60``; nothing
    from the exception is included in the response.
    """
    body = {"detail": "Too many requests, slow down."}
    retry_headers = {"Retry-After": "60"}
    return JSONResponse(status_code=429, content=body, headers=retry_headers)
+75
View File
@@ -0,0 +1,75 @@
from __future__ import annotations
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from app.core.config import Settings
_DOCS_PATHS = ("/docs", "/redoc", "/openapi.json")
_BASE_CSP = (
"default-src 'none'; "
"frame-ancestors 'none'; "
"base-uri 'none'; "
"form-action 'none'"
)
_SWAGGER_CSP = (
"default-src 'self'; "
"img-src 'self' data: https://fastapi.tiangolo.com; "
"script-src 'self' https://cdn.jsdelivr.net 'unsafe-inline'; "
"style-src 'self' https://cdn.jsdelivr.net 'unsafe-inline'; "
"font-src 'self' data: https://cdn.jsdelivr.net; "
"connect-src 'self'; "
"frame-ancestors 'none'; "
"base-uri 'self'; "
"form-action 'self'"
)
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Add security headers to every response.

    Headers are added with ``setdefault``, so a value already present on the
    response is kept. ``Server`` and ``X-Powered-By`` are removed when present.
    """

    def __init__(self, app, settings: Settings):
        """Keep a reference to the settings used for the HSTS header."""
        super().__init__(app)
        self._settings = settings

    async def dispatch(self, request: Request, call_next) -> Response:
        """Run the downstream app, then decorate its response."""
        response: Response = await call_next(request)
        headers = response.headers
        cfg = self._settings

        fixed_defaults = (
            ("X-Content-Type-Options", "nosniff"),
            ("X-Frame-Options", "DENY"),
            ("Referrer-Policy", "strict-origin-when-cross-origin"),
            (
                "Permissions-Policy",
                "geolocation=(), microphone=(), camera=(), payment=(), usb=(), "
                "accelerometer=(), gyroscope=(), magnetometer=(), interest-cohort=()",
            ),
            ("Cross-Origin-Opener-Policy", "same-origin"),
            ("Cross-Origin-Resource-Policy", "same-site"),
            ("X-Permitted-Cross-Domain-Policies", "none"),
        )
        for header_name, header_value in fixed_defaults:
            headers.setdefault(header_name, header_value)

        # The interactive docs get a more permissive policy than the API.
        csp = _SWAGGER_CSP if request.url.path in _DOCS_PATHS else _BASE_CSP
        headers.setdefault("Content-Security-Policy", csp)

        if request.url.scheme == "https" or cfg.is_production:
            directives = [f"max-age={cfg.SECURITY_HSTS_SECONDS}"]
            if cfg.SECURITY_HSTS_INCLUDE_SUBDOMAINS:
                directives.append("includeSubDomains")
            if cfg.SECURITY_HSTS_PRELOAD:
                directives.append("preload")
            headers.setdefault("Strict-Transport-Security", "; ".join(directives))

        # `MutableHeaders` has no `.pop()`, so test membership and delete.
        for banner in ("server", "x-powered-by"):
            if banner in headers:
                del headers[banner]

        return response
+4
View File
@@ -1,3 +1,7 @@
from sqlalchemy.orm import declarative_base from sqlalchemy.orm import declarative_base
# ---------------------------------------------------------
# Base de datos
# ---------------------------------------------------------
Base = declarative_base() Base = declarative_base()
+9
View File
@@ -6,6 +6,9 @@ from datetime import datetime
from app.db.session import Base from app.db.session import Base
# ---------------------------------------------------------
# Modelo de investigador
# ---------------------------------------------------------
class Researcher(Base): class Researcher(Base):
__tablename__ = "researchers" __tablename__ = "researchers"
@@ -18,6 +21,9 @@ class Researcher(Base):
publications = relationship("Publication", back_populates="researcher", cascade="all, delete-orphan") publications = relationship("Publication", back_populates="researcher", cascade="all, delete-orphan")
# ---------------------------------------------------------
# Modelo de publicación
# ---------------------------------------------------------
class Publication(Base): class Publication(Base):
__tablename__ = "publications" __tablename__ = "publications"
@@ -65,6 +71,9 @@ class Publication(Base):
# Legacy: descargado global (deprecado). Mantener por compatibilidad de DB. # Legacy: descargado global (deprecado). Mantener por compatibilidad de DB.
downloaded = Column(Boolean, nullable=False, default=False) downloaded = Column(Boolean, nullable=False, default=False)
# ---------------------------------------------------------
# Modelo de descarga de publicación
# ---------------------------------------------------------
class PublicationDownload(Base): class PublicationDownload(Base):
""" """
@@ -1,8 +1,16 @@
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.db.models import Publication from app.db.models import Publication
# ---------------------------------------------------------
# Repositorio de publicaciones
# ---------------------------------------------------------
class PublicationRepository: class PublicationRepository:
# ---------------------------------------------------------
# Función auxiliar: obtener publicación por put_code
# ---------------------------------------------------------
@staticmethod @staticmethod
def get_by_put_code(db: Session, researcher_id: str, put_code: int): def get_by_put_code(db: Session, researcher_id: str, put_code: int):
""" """
@@ -17,6 +25,10 @@ class PublicationRepository:
.first() .first()
) )
# ---------------------------------------------------------
# Función auxiliar: crear una nueva publicación
# ---------------------------------------------------------
@staticmethod @staticmethod
def create(db: Session, researcher_id: str, data: dict): def create(db: Session, researcher_id: str, data: dict):
""" """
@@ -37,6 +49,10 @@ class PublicationRepository:
db.refresh(pub) db.refresh(pub)
return pub return pub
# ---------------------------------------------------------
# Función auxiliar: actualizar una publicación existente
# ---------------------------------------------------------
@staticmethod @staticmethod
def update(db: Session, publication: Publication, data: dict): def update(db: Session, publication: Publication, data: dict):
""" """
@@ -53,6 +69,10 @@ class PublicationRepository:
db.refresh(publication) db.refresh(publication)
return publication return publication
# ---------------------------------------------------------
# Función auxiliar: listar publicaciones de un investigador
# ---------------------------------------------------------
@staticmethod @staticmethod
def list_by_researcher(db: Session, researcher_id: str): def list_by_researcher(db: Session, researcher_id: str):
""" """
@@ -2,13 +2,24 @@ from sqlalchemy.orm import Session
from app.db.models import Researcher from app.db.models import Researcher
from sqlalchemy.sql import func from sqlalchemy.sql import func
# ---------------------------------------------------------
# Repositorio de investigadores
# ---------------------------------------------------------
class ResearcherRepository: class ResearcherRepository:
# ---------------------------------------------------------
# Función auxiliar: obtener investigador por ORCID ID
# ---------------------------------------------------------
@staticmethod @staticmethod
def get_by_orcid(db: Session, orcid_id: str): def get_by_orcid(db: Session, orcid_id: str):
return db.query(Researcher).filter(Researcher.orcid_id == orcid_id).first() return db.query(Researcher).filter(Researcher.orcid_id == orcid_id).first()
# ---------------------------------------------------------
# Función auxiliar: crear un nuevo investigador
# ---------------------------------------------------------
@staticmethod @staticmethod
def create(db: Session, orcid_id: str, name: str = None): def create(db: Session, orcid_id: str, name: str = None):
researcher = Researcher(orcid_id=orcid_id, name=name) researcher = Researcher(orcid_id=orcid_id, name=name)
@@ -17,6 +28,10 @@ class ResearcherRepository:
db.refresh(researcher) db.refresh(researcher)
return researcher return researcher
# ---------------------------------------------------------
# Función auxiliar: actualizar la última sincronización
# ---------------------------------------------------------
@staticmethod @staticmethod
def update_last_sync(db: Session, researcher: Researcher): def update_last_sync(db: Session, researcher: Researcher):
researcher.last_sync_at = func.now() researcher.last_sync_at = func.now()
@@ -2,9 +2,16 @@ from sqlalchemy.orm import Session
from app.db.models import SyncJob from app.db.models import SyncJob
from sqlalchemy.sql import func from sqlalchemy.sql import func
# ---------------------------------------------------------
# Repositorio de trabajos de sincronización
# ---------------------------------------------------------
class SyncJobRepository: class SyncJobRepository:
# ---------------------------------------------------------
# Función auxiliar: iniciar un nuevo trabajo de sincronización
# ---------------------------------------------------------
@staticmethod @staticmethod
def start_job(db: Session, researcher_id: str): def start_job(db: Session, researcher_id: str):
job = SyncJob( job = SyncJob(
@@ -17,6 +24,10 @@ class SyncJobRepository:
db.refresh(job) db.refresh(job)
return job return job
# ---------------------------------------------------------
# Función auxiliar: finalizar un trabajo de sincronización
# ---------------------------------------------------------
@staticmethod @staticmethod
def finish_job(db: Session, job: SyncJob, new_records: int, updated_records: int): def finish_job(db: Session, job: SyncJob, new_records: int, updated_records: int):
job.status = "finished" job.status = "finished"
+10
View File
@@ -9,6 +9,7 @@ load_dotenv()
# ----------------------------- # -----------------------------
# DATABASE URL # DATABASE URL
# ----------------------------- # -----------------------------
DATABASE_URL = os.getenv("DATABASE_URL") DATABASE_URL = os.getenv("DATABASE_URL")
engine = create_engine( engine = create_engine(
@@ -29,6 +30,7 @@ Base = declarative_base()
# ----------------------------- # -----------------------------
# DB SESSION DEPENDENCY # DB SESSION DEPENDENCY
# ----------------------------- # -----------------------------
def get_db(): def get_db():
db = SessionLocal() db = SessionLocal()
try: try:
@@ -40,17 +42,25 @@ def get_db():
# ----------------------------- # -----------------------------
# INIT DB (CREA TABLAS) # INIT DB (CREA TABLAS)
# ----------------------------- # -----------------------------
def init_db(): def init_db():
# Importa modelos para que SQLAlchemy los registre # Importa modelos para que SQLAlchemy los registre
import app.db.models # noqa import app.db.models # noqa
# Crea todas las tablas si no existen # Crea todas las tablas si no existen
Base.metadata.create_all(bind=engine) Base.metadata.create_all(bind=engine)
# Pequeñas migraciones "best-effort" para entornos sin Alembic. # Pequeñas migraciones "best-effort" para entornos sin Alembic.
# (create_all no altera tablas existentes) # (create_all no altera tablas existentes)
_ensure_columns() _ensure_columns()
# ---------------------------------------------------------
# Función auxiliar: asegurar columnas existentes
# ---------------------------------------------------------
def _ensure_columns(): def _ensure_columns():
insp = inspect(engine) insp = inspect(engine)
+119 -33
View File
@@ -1,68 +1,154 @@
from fastapi import Depends, FastAPI """
Entry point del backend FastAPI.
Aplica un perfil de seguridad por defecto:
- Configuración tipada (Pydantic Settings) que falla rápido en producción.
- TrustedHostMiddleware (anti Host-header injection).
- CORS con lista blanca estricta (sin `*`).
- Body size limit (anti DoS por payload).
- Cabeceras de seguridad HTTP.
- Rate limiting (slowapi) con backend Redis si está configurado.
- Error handlers que NO filtran trazas ni internals.
"""
from __future__ import annotations
import logging
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from slowapi.errors import RateLimitExceeded
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.db.session import init_db from app.api.auth import complete_oauth_login_response, router as auth_router
from app.db.session import get_db
from app.api.researchers import router as researchers_router
from app.api.export import router as export_router from app.api.export import router as export_router
from app.api.auth import router as auth_router from app.api.researchers import router as researchers_router
from app.api.auth import _complete_oauth_login from app.core.body_size import BodySizeLimitMiddleware
from app.schema.auth import OrcidLoginResponseSchema from app.core.config import settings
from app.core.error_handlers import (
http_exception_handler,
sqlalchemy_exception_handler,
unhandled_exception_handler,
validation_exception_handler,
)
from app.core.logging_config import configure_logging
from app.core.rate_limit import limiter, rate_limit_exceeded_handler
from app.core.security_headers import SecurityHeadersMiddleware
from app.db.session import get_db, init_db
from app.scheduler.sync_scheduler import start_scheduler from app.scheduler.sync_scheduler import start_scheduler
from app.schema.auth import OrcidLoginResponseSchema
configure_logging()
logger = logging.getLogger("app.main")
# ---------------------------------------------------------
# Crear instancia principal de FastAPI
# ---------------------------------------------------------
app = FastAPI( app = FastAPI(
title="ORCID SWORD Backend", title="ORCID SWORD Backend",
description="Backend para sincronización ORCID y exportación SWORD", description="Backend para sincronización ORCID y exportación SWORD",
version="1.0.0" version="1.0.0",
docs_url=settings.docs_url,
redoc_url=settings.redoc_url,
openapi_url=settings.openapi_url,
) )
# --------------------------------------------------------- # ---------------------------------------------------------
# Crear tablas al iniciar la aplicación # Middlewares (orden importa: el último añadido es el más externo)
# --------------------------------------------------------- # ---------------------------------------------------------
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, rate_limit_exceeded_handler)
app.add_middleware(SecurityHeadersMiddleware, settings=settings)
app.add_middleware(
BodySizeLimitMiddleware,
max_bytes=settings.MAX_REQUEST_BODY_BYTES,
)
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_allowed_origins,
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
allow_headers=[
"Authorization",
"Content-Type",
"Accept",
"Origin",
"X-Requested-With",
settings.API_KEY_NAME,
],
expose_headers=["Content-Disposition", "X-RateLimit-Remaining", "X-RateLimit-Reset"],
max_age=600,
)
app.add_middleware(
TrustedHostMiddleware,
allowed_hosts=settings.trusted_hosts,
)
# ---------------------------------------------------------
# Exception handlers
# ---------------------------------------------------------
app.add_exception_handler(HTTPException, http_exception_handler)
app.add_exception_handler(RequestValidationError, validation_exception_handler)
app.add_exception_handler(SQLAlchemyError, sqlalchemy_exception_handler)
app.add_exception_handler(Exception, unhandled_exception_handler)
# ---------------------------------------------------------
# Lifecycle
# ---------------------------------------------------------
@app.on_event("startup") @app.on_event("startup")
def startup_event(): def on_startup() -> None:
init_db() # 🔥 CREA TABLAS init_db()
start_scheduler() # 🔥 INICIA SCHEDULER start_scheduler()
logger.info(
"Backend ready (env=%s, docs=%s)",
settings.ENVIRONMENT,
bool(settings.DOCS_ENABLED),
)
# --------------------------------------------------------- # ---------------------------------------------------------
# Healthcheck # Healthcheck
# --------------------------------------------------------- # ---------------------------------------------------------
@app.get("/health") @app.get("/health")
def health(): def health() -> dict[str, str]:
return {"status": "ok"} return {"status": "ok"}
# ---------------------------------------------------------
# Alias del callback OAuth (mismo flujo, mismo endurecimiento)
# ---------------------------------------------------------
@app.get("/callback", response_model=OrcidLoginResponseSchema) @app.get("/callback", response_model=OrcidLoginResponseSchema)
def oauth_callback_root(code: str, db: Session = Depends(get_db)): def oauth_callback_root(
request: Request,
code: str,
state: str | None = None,
db: Session = Depends(get_db),
):
""" """
Alias para probar redirect URIs como `https://127.0.0.1/callback` en local. Alias para integraciones que registran un redirect_uri tipo
Intercambia el code con ORCID y emite el JWT. `https://<host>/callback` en ORCID.
""" """
return _complete_oauth_login(code=code, db=db) return complete_oauth_login_response(request=request, code=code, state=state, db=db)
# --------------------------------------------------------- # ---------------------------------------------------------
# Registrar routers # Routers
# --------------------------------------------------------- # ---------------------------------------------------------
app.include_router(researchers_router, prefix="/api") app.include_router(researchers_router, prefix="/api")
app.include_router(export_router, prefix="/api") app.include_router(export_router, prefix="/api")
app.include_router(auth_router, prefix="/api") app.include_router(auth_router, prefix="/api")
# ---------------------------------------------------------
# CORS
# ---------------------------------------------------------
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # en producción limitar
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
+10
View File
@@ -9,9 +9,16 @@ import os
# Cargar variables del .env # Cargar variables del .env
load_dotenv() load_dotenv()
# ---------------------------------------------------------
# Variables de entorno
# ---------------------------------------------------------
API_KEY = os.getenv("API_KEY_VALUE") API_KEY = os.getenv("API_KEY_VALUE")
BASE_URL = os.getenv("BASE_URL") BASE_URL = os.getenv("BASE_URL")
# ---------------------------------------------------------
# Función auxiliar: ejecutar sincronización mensual
# ---------------------------------------------------------
def run_monthly_sync(): def run_monthly_sync():
db = SessionLocal() db = SessionLocal()
@@ -36,6 +43,9 @@ def run_monthly_sync():
db.close() db.close()
# ---------------------------------------------------------
# Función auxiliar: iniciar el scheduler
# ---------------------------------------------------------
def start_scheduler(): def start_scheduler():
scheduler = BackgroundScheduler() scheduler = BackgroundScheduler()
+6
View File
@@ -1,11 +1,17 @@
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
# ---------------------------------------------------------
# Modelo de solicitud de login OAuth
# ---------------------------------------------------------
class OrcidLoginRequestSchema(BaseModel):
    """Request body for the ORCID OAuth login exchange."""

    # `code` is the authorization code returned by ORCID OAuth after the user signs in.
    # Exchanging it for tokens must happen server-side.
    code: str = Field(..., examples=["Q70Y3A"])
# ---------------------------------------------------------
# Modelo de respuesta de login OAuth
# ---------------------------------------------------------
class OrcidLoginResponseSchema(BaseModel): class OrcidLoginResponseSchema(BaseModel):
access_token: str access_token: str
+23
View File
@@ -0,0 +1,23 @@
"""
Schemas de los endpoints de export.
El backend recibe `pub_ids` como UUIDs en formato string. Pydantic ya los
valida y convierte; aquí además aplicamos un tope de tamaño para impedir
peticiones gigantes.
"""
from __future__ import annotations
from typing import List
from uuid import UUID
from pydantic import BaseModel, Field
from app.core.config import settings
class PublicationIdsRequestSchema(BaseModel):
    """Request body carrying the publication ids to export."""

    # Required, non-empty list of UUIDs; the upper bound on its size comes
    # from `settings.MAX_PUB_IDS_BATCH`.
    pub_ids: List[UUID] = Field(min_length=1, max_length=settings.MAX_PUB_IDS_BATCH)
+4
View File
@@ -3,6 +3,10 @@ from uuid import UUID
from typing import Optional, List, Any from typing import Optional, List, Any
from datetime import datetime from datetime import datetime
# ---------------------------------------------------------
# Modelo de publicación
# ---------------------------------------------------------
class PublicationSchema(BaseModel): class PublicationSchema(BaseModel):
id: UUID id: UUID
put_code: int | None = None put_code: int | None = None
+29 -6
View File
@@ -1,13 +1,18 @@
from pydantic import BaseModel, Field
from uuid import UUID
from typing import Optional, List, Dict
from datetime import datetime from datetime import datetime
from typing import Dict, List, Optional
from uuid import UUID
from pydantic import BaseModel, Field, field_validator
from app.core.config import settings
from app.schema.publication import PublicationSchema from app.schema.publication import PublicationSchema
from app.utils.orcid_validator import ORCID_PATTERN, is_valid_orcid
class ResearcherSchema(BaseModel): class ResearcherSchema(BaseModel):
id: UUID id: UUID
orcid_id: str orcid_id: str = Field(min_length=19, max_length=19, pattern=ORCID_PATTERN)
name: Optional[str] name: Optional[str] = Field(default=None, max_length=255)
authenticated: bool authenticated: bool
last_sync_at: Optional[datetime] last_sync_at: Optional[datetime]
@@ -33,7 +38,25 @@ class ResearcherWithPublicationsSchema(BaseModel):
class ResearcherBatchSearchRequestSchema(BaseModel):
    """Request body for a batch search by ORCID iD."""

    # Required, non-empty list; the upper bound comes from settings.MAX_ORCID_BATCH.
    orcid_ids: List[str] = Field(
        min_length=1,
        max_length=settings.MAX_ORCID_BATCH,
    )

    @field_validator("orcid_ids")
    @classmethod
    def _validate_each(cls, value: List[str]) -> List[str]:
        """
        Validate every ORCID iD and drop duplicates.

        Returns:
            The ids in first-seen order with repeated entries removed.

        Raises:
            ValueError: If an entry is not a string or is not a valid ORCID iD.
        """
        for candidate in value:
            if not isinstance(candidate, str):
                raise ValueError("ORCID iD debe ser string")
            if not is_valid_orcid(candidate):
                raise ValueError(f"ORCID iD inválido: {candidate}")
        # dict preserves insertion order, so this is an order-preserving de-dup.
        return list(dict.fromkeys(value))
class ResearcherSearchErrorSchema(BaseModel): class ResearcherSearchErrorSchema(BaseModel):
+34 -25
View File
@@ -1,43 +1,52 @@
import os """
from dotenv import load_dotenv Autenticación por API key (uso máquina-a-máquina, p. ej. el scheduler interno).
Endurecimiento:
- Comparación constante en tiempo (`hmac.compare_digest`) para evitar timing attacks.
- No se loggea el valor de la cabecera bajo ninguna circunstancia.
- Se separa este mecanismo del JWT de usuario; la API key NO debe usarse como
prueba de identidad de un investigador.
"""
from __future__ import annotations
import hmac
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status
from fastapi.security import APIKeyHeader from fastapi.security import APIKeyHeader
# Cargar variables del .env from app.core.config import settings
load_dotenv()
API_KEY_NAME = os.getenv("API_KEY_NAME")
API_KEY_VALUE = os.getenv("API_KEY_VALUE")
if not API_KEY_NAME:
raise RuntimeError("ERROR: La variable API_KEY_NAME no está definida en el .env")
if not API_KEY_VALUE:
raise RuntimeError("ERROR: La variable API_KEY_VALUE no está definida en el .env")
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
def get_api_key(api_key: str = Depends(api_key_header)): api_key_header = APIKeyHeader(name=settings.API_KEY_NAME, auto_error=False)
if api_key != API_KEY_VALUE:
def _is_valid_key(provided: str | None) -> bool:
    """
    Tell whether `provided` equals the configured API key.

    Returns False when either the provided value or the configured value is
    missing/empty. The comparison itself uses `hmac.compare_digest` on the
    UTF-8 encoded values.
    """
    if not provided:
        return False
    expected = settings.API_KEY_VALUE
    if not expected:
        return False
    candidate_bytes = provided.encode("utf-8")
    expected_bytes = expected.encode("utf-8")
    return hmac.compare_digest(candidate_bytes, expected_bytes)
def get_api_key(api_key: str | None = Depends(api_key_header)) -> str:
    """
    FastAPI dependency that requires a valid API key header.

    Args:
        api_key: Header value extracted by `api_key_header` (None if absent).

    Returns:
        The API key that was provided, once it has been validated.

    Raises:
        HTTPException: 401 when the header is missing or does not match.
    """
    # The explicit None check narrows the type for the return statement;
    # `_is_valid_key` would reject None anyway.
    if api_key is None or not _is_valid_key(api_key):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or missing API key",
        )
    return api_key
def get_api_key_optional(api_key: str | None = Depends(api_key_header)) -> str | None:
    """
    FastAPI dependency for endpoints where the API key is optional.

    - No header present: returns None.
    - Header present and valid: returns the key.
    - Header present but invalid: raises 401.

    Raises:
        HTTPException: 401 when a key is supplied but does not validate.
    """
    if api_key is None:
        return None

    if not _is_valid_key(api_key):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
        )

    return api_key
+94 -31
View File
@@ -1,75 +1,138 @@
import os """
Emisión y verificación de JWT.
Endurecimiento aplicado:
- Sin fallback de secreto débil: si la configuración no es válida, falla al arranque.
- `iss` y `aud` obligatorios.
- `nbf` (not-before) y `iat` validados.
- `typ=access` para evitar mezclar tipos de token.
- Algoritmo fijo (no se acepta "none" ni cambios por payload).
- Errores opacos: nunca se expone el motivo del fallo de verificación al cliente.
"""
from __future__ import annotations
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Any from typing import Any
from uuid import uuid4
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jose import JWTError, jwt from jose import JWTError, jwt
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from dotenv import load_dotenv
from app.core.config import settings
from app.db.models import Researcher from app.db.models import Researcher
from app.db.session import get_db from app.db.session import get_db
from app.utils.orcid_validator import is_valid_orcid
load_dotenv()
_bearer = HTTPBearer(auto_error=False) _bearer = HTTPBearer(auto_error=False)
def _settings() -> tuple[str, str, int]:
# Fallback de desarrollo para evitar 500 por configuración ausente.
secret = os.getenv("JWT_SECRET") or "change_me"
algorithm = os.getenv("JWT_ALGORITHM") or "HS256"
expires_minutes = int(os.getenv("JWT_EXPIRES_MINUTES") or "720")
return secret, algorithm, expires_minutes
def create_access_token(*, subject: str, extra: dict[str, Any] | None = None) -> str:
    """
    Issue a signed access token for a researcher.

    The token is signed with `settings.JWT_SECRET` using
    `settings.JWT_ALGORITHM` and carries the claims
    iss, aud, sub, iat, nbf, exp, jti and typ ("access").

    Args:
        subject: ORCID iD of the researcher; it becomes the `sub` claim.
        extra: Optional additional claims. Keys that collide with the claims
            listed above are ignored. The dict is not modified.

    Returns:
        The encoded JWT.

    Raises:
        ValueError: If `subject` is not a valid ORCID iD.
    """
    if not is_valid_orcid(subject):
        raise ValueError("subject must be a valid ORCID iD")

    now = datetime.now(timezone.utc)
    issued_at = int(now.timestamp())
    payload: dict[str, Any] = {
        "iss": settings.JWT_ISSUER,
        "aud": settings.JWT_AUDIENCE,
        "sub": subject,
        "iat": issued_at,
        "nbf": issued_at,
        "exp": int((now + timedelta(minutes=settings.JWT_EXPIRES_MINUTES)).timestamp()),
        "jti": uuid4().hex,
        "typ": "access",
    }
    if extra:
        reserved = ("iss", "aud", "sub", "iat", "nbf", "exp", "jti", "typ")
        # Filter into the payload instead of popping from `extra`, so the
        # caller's dict is left untouched.
        payload.update(
            {claim: value for claim, value in extra.items() if claim not in reserved}
        )

    return jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
def _decode_token(token: str) -> dict[str, Any]:
    """
    Decode and verify a JWT against the configured secret, algorithm,
    audience and issuer.

    The iat, nbf, exp, aud and iss claims are all required.

    Raises:
        HTTPException: 401 with a generic message whenever the library
            reports a `JWTError`.
    """
    required_claims = {
        "require_iat": True,
        "require_nbf": True,
        "require_exp": True,
        "require_aud": True,
        "require_iss": True,
    }
    try:
        claims = jwt.decode(
            token,
            settings.JWT_SECRET,
            algorithms=[settings.JWT_ALGORITHM],
            audience=settings.JWT_AUDIENCE,
            issuer=settings.JWT_ISSUER,
            options=required_claims,
        )
    except JWTError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired token",
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc
    return claims
def _bearer_unauthorized(detail: str) -> HTTPException:
    """Build the 401 response used for every bearer-token failure."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


def get_current_researcher(
    request: Request,
    creds: HTTPAuthorizationCredentials | None = Depends(_bearer),
    db: Session = Depends(get_db),
) -> Researcher:
    """
    FastAPI dependency that resolves the researcher behind a Bearer token.

    Steps: require credentials, decode the token, check `typ == "access"`,
    check that `sub` is a valid ORCID iD, then load the researcher and make
    sure it is flagged as authenticated. On success the researcher is also
    stored on `request.state.researcher`.

    Raises:
        HTTPException: 401 when any of the checks above fails.
    """
    if not creds or not creds.credentials:
        raise _bearer_unauthorized("Missing bearer token")

    payload = _decode_token(creds.credentials)

    if payload.get("typ") != "access":
        raise _bearer_unauthorized("Invalid token type")

    orcid_id = payload.get("sub")
    if not isinstance(orcid_id, str) or not is_valid_orcid(orcid_id):
        raise _bearer_unauthorized("Invalid token subject")

    researcher = db.query(Researcher).filter(Researcher.orcid_id == orcid_id).first()
    if not researcher or not researcher.authenticated:
        raise _bearer_unauthorized("Researcher not authenticated")

    request.state.researcher = researcher
    return researcher
def get_optional_current_researcher(
    request: Request,
    creds: HTTPAuthorizationCredentials | None = Depends(_bearer),
    db: Session = Depends(get_db),
) -> Researcher | None:
    """
    FastAPI dependency for endpoints where authentication is optional.

    Returns None when no Bearer credentials are present. When credentials
    are present they are handed to `get_current_researcher`, so an invalid
    token results in a 401 rather than an anonymous request.
    """
    if not creds or not creds.credentials:
        return None
    return get_current_researcher(request=request, creds=creds, db=db)
+76
View File
@@ -0,0 +1,76 @@
"""
OAuth state anti-CSRF para el flujo de login con ORCID.
El parámetro `state` se genera en `/auth/orcid/authorize`, se guarda en una
cookie HttpOnly + SameSite=Lax con TTL corto, y se valida en el callback.
Si el `state` falta, no coincide o ha expirado, el login se rechaza.
"""
from __future__ import annotations
import hmac
import secrets
from datetime import datetime, timezone
from fastapi import HTTPException, status
from starlette.requests import Request
from starlette.responses import Response
from app.core.config import settings
# Number of random bytes requested for each OAuth `state` value.
_STATE_BYTES = 32


def generate_state() -> str:
    """Return a fresh URL-safe random token built from `_STATE_BYTES` random bytes."""
    token = secrets.token_urlsafe(_STATE_BYTES)
    return token
def attach_state_cookie(response: Response, state: str) -> None:
    """
    Store the OAuth `state` value in a cookie on `response`.

    The cookie is HttpOnly, SameSite=Lax and scoped to path "/". Its name and
    max-age come from `settings`, and the Secure flag follows
    `settings.is_production`. The response is modified in place; nothing is
    returned.
    """
    response.set_cookie(
        key=settings.ORCID_OAUTH_STATE_COOKIE,
        value=state,
        max_age=settings.ORCID_OAUTH_STATE_TTL_SECONDS,
        secure=settings.is_production,
        httponly=True,
        samesite="lax",
        path="/",
    )
def clear_state_cookie(response: Response) -> None:
    """Delete the OAuth state cookie (path "/") from `response`."""
    cookie_name = settings.ORCID_OAUTH_STATE_COOKIE
    response.delete_cookie(key=cookie_name, path="/")
def validate_state(request: Request, received_state: str | None) -> None:
    """
    Check the `state` received in the callback against the one stored in the cookie.

    Does nothing when `settings.ORCID_OAUTH_STATE_ENABLED` is falsy.

    Raises:
        HTTPException: 400 when either value is missing/empty, or when the
            two values differ (compared with `hmac.compare_digest`).
    """
    if settings.ORCID_OAUTH_STATE_ENABLED:
        stored_state = request.cookies.get(settings.ORCID_OAUTH_STATE_COOKIE)
        if not (stored_state and received_state):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="OAuth state missing",
            )

        states_match = hmac.compare_digest(
            stored_state.encode("utf-8"),
            received_state.encode("utf-8"),
        )
        if not states_match:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="OAuth state mismatch",
            )
def now_ts() -> int:
    """Return the current UTC time as whole seconds since the Unix epoch."""
    current = datetime.now(tz=timezone.utc)
    return int(current.timestamp())
+6
View File
@@ -1,5 +1,8 @@
from typing import List from typing import List
# ---------------------------------------------------------
# Función auxiliar: obtener valor de un diccionario
# ---------------------------------------------------------
def _get(d: dict | None, *keys, default=None): def _get(d: dict | None, *keys, default=None):
cur = d or {} cur = d or {}
@@ -11,6 +14,9 @@ def _get(d: dict | None, *keys, default=None):
return default return default
return cur return cur
# ---------------------------------------------------------
# Clase de normalización de publicaciones
# ---------------------------------------------------------
class PublicationNormalizer: class PublicationNormalizer:
@staticmethod @staticmethod
+10
View File
@@ -14,8 +14,14 @@ BASE_URL_SANDBOX = "https://pub.sandbox.orcid.org/v3.0"
# TOKEN_URL_PROD = "https://orcid.org/oauth/token" # TOKEN_URL_PROD = "https://orcid.org/oauth/token"
# BASE_URL_PROD = "https://pub.orcid.org/v3.0" # BASE_URL_PROD = "https://pub.orcid.org/v3.0"
# ---------------------------------------------------------
# Clase de cliente de ORCID
# ---------------------------------------------------------
class ORCIDClient: class ORCIDClient:
# ---------------------------------------------------------
# Función auxiliar: inicializar el cliente de ORCID
# ---------------------------------------------------------
def __init__(self): def __init__(self):
# Asegura que al ejecutar `uvicorn` local también se carga `backend/.env`. # Asegura que al ejecutar `uvicorn` local también se carga `backend/.env`.
# (En docker `ORCID_REDIRECT_URI` y secretos llegan por env_file, así que esto no molesta.) # (En docker `ORCID_REDIRECT_URI` y secretos llegan por env_file, así que esto no molesta.)
@@ -115,6 +121,10 @@ class ORCIDClient:
params["state"] = state params["state"] = state
return f"{self.authorization_url}?{urllib.parse.urlencode(params)}" return f"{self.authorization_url}?{urllib.parse.urlencode(params)}"
# ---------------------------------------------------------
# Función auxiliar: intercambiar código de autorización
# ---------------------------------------------------------
def exchange_authorization_code( def exchange_authorization_code(
self, self,
*, *,
+3
View File
@@ -6,6 +6,9 @@ ATOM_NS = "http://www.w3.org/2005/Atom"
DC_NS = "http://purl.org/dc/elements/1.1/" DC_NS = "http://purl.org/dc/elements/1.1/"
EXTRA_NS = "http://example.org/orcid-extra" # namespace para campos extendidos EXTRA_NS = "http://example.org/orcid-extra" # namespace para campos extendidos
# ---------------------------------------------------------
# Clase de generador de feed SWORD
# ---------------------------------------------------------
class SWORDGenerator: class SWORDGenerator:
+15
View File
@@ -8,12 +8,23 @@ from app.db.repositories.researcher_repository import ResearcherRepository
from app.db.repositories.publication_repository import PublicationRepository from app.db.repositories.publication_repository import PublicationRepository
from app.db.repositories.syncjob_repository import SyncJobRepository from app.db.repositories.syncjob_repository import SyncJobRepository
# ---------------------------------------------------------
# Clase de servicio de sincronización
# ---------------------------------------------------------
class SyncService: class SyncService:
# ---------------------------------------------------------
# Función auxiliar: inicializar el servicio de sincronización
# ---------------------------------------------------------
def __init__(self): def __init__(self):
self.orcid_client = ORCIDClient() self.orcid_client = ORCIDClient()
# ---------------------------------------------------------
# Función auxiliar: sincronizar las publicaciones de un investigador
# ---------------------------------------------------------
def sync_researcher(self, db: Session, orcid_id: str): def sync_researcher(self, db: Session, orcid_id: str):
""" """
Sincroniza las publicaciones de un investigador con manejo robusto de errores. Sincroniza las publicaciones de un investigador con manejo robusto de errores.
@@ -109,6 +120,10 @@ class SyncService:
"total": new_records + updated_records "total": new_records + updated_records
} }
# ---------------------------------------------------------
# Función auxiliar: sincronizar y obtener investigador + publicaciones
# ---------------------------------------------------------
def sync_and_get_full(self, db: Session, orcid_id: str): def sync_and_get_full(self, db: Session, orcid_id: str):
""" """
Sincroniza (si es necesario) y devuelve investigador + publicaciones. Sincroniza (si es necesario) y devuelve investigador + publicaciones.
+5 -1
View File
@@ -7,12 +7,16 @@ from xml.etree.ElementTree import Element, SubElement, tostring
from app.db.models import Publication, Researcher from app.db.models import Publication, Researcher
from app.services.sword_generator import SWORDGenerator from app.services.sword_generator import SWORDGenerator
# ---------------------------------------------------------
# Clase de generador de ZIP
# ---------------------------------------------------------
class ZIPGenerator: class ZIPGenerator:
# --------------------------------------------------------- # ---------------------------------------------------------
# MANIFEST.TXT — más completo # Función auxiliar: generar manifest.txt
# --------------------------------------------------------- # ---------------------------------------------------------
@staticmethod @staticmethod
def generate_manifest(researcher, publications): def generate_manifest(researcher, publications):
lines = [ lines = [
+15 -4
View File
@@ -2,27 +2,38 @@ import re
ORCID_REGEX = re.compile(r"^\d{4}-\d{4}-\d{4}-\d{3}[0-9X]$") ORCID_REGEX = re.compile(r"^\d{4}-\d{4}-\d{4}-\d{3}[0-9X]$")
ORCID_PATTERN = r"^\d{4}-\d{4}-\d{4}-\d{3}[0-9X]$"
def is_valid_orcid(orcid_id: str) -> bool:
def is_valid_orcid(orcid_id: str | None) -> bool:
""" """
Valida un ORCID ID: Valida un ORCID ID:
- Formato: 0000-0000-0000-0000 - Formato: 0000-0000-0000-0000
- Dígito de control según ISO 7064 Mod 11-2 - Dígito de control según ISO 7064 Mod 11-2
""" """
if not isinstance(orcid_id, str):
return False
if not ORCID_REGEX.match(orcid_id): if not ORCID_REGEX.match(orcid_id):
return False return False
# Quitar guiones
digits = orcid_id.replace("-", "") digits = orcid_id.replace("-", "")
total = 0 total = 0
# Los primeros 15 dígitos
for char in digits[:-1]: for char in digits[:-1]:
total = (total + int(char)) * 2 total = (total + int(char)) * 2
# Resto
remainder = total % 11 remainder = total % 11
result = (12 - remainder) % 11 result = (12 - remainder) % 11
check_digit = "X" if result == 10 else str(result) check_digit = "X" if result == 10 else str(result)
return digits[-1] == check_digit return digits[-1] == check_digit
def assert_valid_orcid(orcid_id: str) -> str:
    """
    Return `orcid_id` unchanged when it is a valid ORCID iD.

    Handy as a Pydantic validator.

    Raises:
        ValueError: If `is_valid_orcid` rejects the value.
    """
    if is_valid_orcid(orcid_id):
        return orcid_id
    raise ValueError("ORCID iD inválido")
+5 -3
View File
@@ -1,14 +1,16 @@
fastapi fastapi
uvicorn uvicorn[standard]
sqlalchemy sqlalchemy
psycopg2-binary psycopg2-binary
httpx httpx
pydantic pydantic
pydantic-settings
python-dotenv python-dotenv
lxml lxml
apscheduler defusedxml
APScheduler==3.10.4
authlib authlib
redis redis
APScheduler==3.10.4
requests requests
python-jose[cryptography] python-jose[cryptography]
slowapi
+30 -10
View File
@@ -3,9 +3,9 @@ services:
backend: backend:
build: ./backend build: ./backend
container_name: orcid-backend container_name: orcid-backend
restart: always restart: unless-stopped
ports: ports:
- "8000:8000" - "127.0.0.1:8000:8000"
env_file: env_file:
- ./backend/.env - ./backend/.env
environment: environment:
@@ -17,28 +17,43 @@ services:
condition: service_healthy condition: service_healthy
redis: redis:
condition: service_started condition: service_started
read_only: true
tmpfs:
- /tmp
cap_drop:
- ALL
security_opt:
- no-new-privileges:true
healthcheck:
test: ["CMD", "curl", "-fsS", "http://127.0.0.1:8000/health"]
interval: 30s
timeout: 5s
retries: 3
start_period: 15s
frontend: frontend:
build: ./frontend build: ./frontend
container_name: orcid-frontend container_name: orcid-frontend
restart: always restart: unless-stopped
ports: ports:
- "5173:5173" - "127.0.0.1:5173:5173"
depends_on: depends_on:
- backend - backend
env_file: env_file:
- ./frontend/.env - ./frontend/.env
security_opt:
- no-new-privileges:true
db: db:
image: postgres:16 image: postgres:16
container_name: orcid-postgres container_name: orcid-postgres
restart: always restart: unless-stopped
environment: environment:
POSTGRES_USER: postgres POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres POSTGRES_PASSWORD: postgres
POSTGRES_DB: orcid_db POSTGRES_DB: orcid_db
ports: expose:
- "5432:5432" - "5432"
volumes: volumes:
- postgres_data:/var/lib/postgresql/data - postgres_data:/var/lib/postgresql/data
healthcheck: healthcheck:
@@ -46,13 +61,18 @@ services:
interval: 2s interval: 2s
timeout: 3s timeout: 3s
retries: 20 retries: 20
security_opt:
- no-new-privileges:true
redis: redis:
image: redis:7 image: redis:7
container_name: orcid-redis container_name: orcid-redis
restart: always restart: unless-stopped
ports: command: ["redis-server", "--save", "60", "1", "--loglevel", "warning"]
- "6379:6379" expose:
- "6379"
security_opt:
- no-new-privileges:true
volumes: volumes:
postgres_data: postgres_data: