131 lines
4.0 KiB
Python
131 lines
4.0 KiB
Python
"""Authentication endpoints."""
|
|
from datetime import datetime
|
|
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select
|
|
from app.core.database import get_db
|
|
from app.core.security import verify_password, get_password_hash, create_access_token, create_refresh_token, decode_token
|
|
from app.models.user import User
|
|
from app.schemas.user import UserCreate, UserRead, Token
|
|
from app.services.audit import AuditService
|
|
|
|
# Router that the authentication endpoints in this module (/register, /login,
# /refresh) are registered on; presumably included by the application under
# an auth prefix — confirm in the app setup.
router: APIRouter = APIRouter()
|
|
|
|
@router.post("/register", response_model=Token)
async def register(
    user_in: UserCreate,
    request: Request,
    db: AsyncSession = Depends(get_db)
):
    """Register a new user and return an access/refresh token pair.

    The new user row is added and flushed (not committed here), a
    "user.register" audit entry is written, and tokens are issued for the
    new account.

    Args:
        user_in: Registration payload (email, password, full name).
        request: Incoming request; used for the client IP and User-Agent in
            the audit log.
        db: Async database session.

    Returns:
        Token with ``access_token``, ``refresh_token`` and
        ``token_type="bearer"``.

    Raises:
        HTTPException: 400 if the email is already registered.
    """
    # Reject duplicate emails before hashing the password.
    result = await db.execute(select(User).where(User.email == user_in.email))
    if result.scalar_one_or_none():
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered"
        )

    user = User(
        email=user_in.email,
        hashed_password=get_password_hash(user_in.password),
        full_name=user_in.full_name
    )
    db.add(user)
    # Flush so user.id is available for the audit entry and token claims.
    await db.flush()

    await AuditService.log(
        db,
        action="user.register",
        user_id=user.id,
        resource_type="user",
        resource_id=user.id,
        ip_address=request.client.host if request.client else None,
        user_agent=request.headers.get("user-agent")
    )

    # The /refresh endpoint resolves the user from the "user_id" claim, so
    # tokens issued here must carry it; otherwise a refresh token obtained at
    # registration can never be exchanged. "sub" is kept for any consumer
    # that already reads it.
    token_data = {"sub": str(user.id), "user_id": user.id, "email": user.email}
    # Pass copies in case either helper mutates its argument.
    access_token = create_access_token(dict(token_data))
    refresh_token = create_refresh_token(dict(token_data))

    return Token(
        access_token=access_token,
        refresh_token=refresh_token,
        token_type="bearer"
    )
|
|
|
|
@router.post("/login", response_model=Token)
async def login(
    email: str,
    password: str,
    request: Request,
    db: AsyncSession = Depends(get_db)
):
    """Authenticate with email and password and return an access/refresh token pair.

    On success, ``user.last_login`` is updated and a "user.login" audit
    entry is written.

    Args:
        email: Account email.
        password: Plain-text password, checked with ``verify_password``.
        request: Incoming request; used for the client IP and User-Agent in
            the audit log.
        db: Async database session.

    Returns:
        Token with ``access_token``, ``refresh_token`` and
        ``token_type="bearer"``.

    Raises:
        HTTPException: 401 if the email is unknown or the password is wrong
            (same message for both, so the response does not reveal which);
            403 if the credentials are valid but the account is inactive.
    """
    # NOTE(review): ``email`` and ``password`` are plain ``str`` parameters,
    # which FastAPI treats as query parameters, so credentials travel in the
    # URL and may end up in access logs. Moving them into a request body
    # (e.g. OAuth2PasswordRequestForm) would be safer but changes the API.
    result = await db.execute(select(User).where(User.email == email))
    user = result.scalar_one_or_none()

    if not user or not verify_password(password, user.hashed_password):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid email or password"
        )

    # Checked only after the password matches, so inactive status is not
    # disclosed to callers without valid credentials.
    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="User is inactive"
        )

    # Naive UTC timestamp; presumably matches a timezone-naive column —
    # confirm before switching to an aware datetime.
    user.last_login = datetime.utcnow()

    await AuditService.log(
        db,
        action="user.login",
        user_id=user.id,
        resource_type="user",
        resource_id=user.id,
        ip_address=request.client.host if request.client else None,
        user_agent=request.headers.get("user-agent")
    )

    # Same claim set as /register: "user_id" is what /refresh reads; "sub"
    # keeps tokens uniform for any consumer that reads the subject claim.
    token_data = {"sub": str(user.id), "user_id": user.id, "email": user.email}
    # Pass copies in case either helper mutates its argument.
    access_token = create_access_token(dict(token_data))
    refresh_token = create_refresh_token(dict(token_data))

    return Token(
        access_token=access_token,
        refresh_token=refresh_token,
        token_type="bearer"
    )
|
|
|
|
@router.post("/refresh", response_model=Token)
async def refresh_token(
    refresh_token: str,
    db: AsyncSession = Depends(get_db)
):
    """Exchange a valid refresh token for a new access/refresh token pair.

    Args:
        refresh_token: Encoded refresh token previously issued by this API.
        db: Async database session.

    Returns:
        Token with new ``access_token`` and ``refresh_token`` and
        ``token_type="bearer"``.

    Raises:
        HTTPException: 401 if the token does not decode, is not of type
            "refresh", carries no "user_id" claim, or refers to a user that
            does not exist or is inactive.
    """
    payload = decode_token(refresh_token)

    # Only refresh tokens may be exchanged; access tokens are rejected here.
    if not payload or payload.get("type") != "refresh":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid refresh token"
        )

    user_id = payload.get("user_id")
    # A token without the claim can never match a user; reject it up front
    # instead of querying for ``User.id == None``.
    if user_id is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid refresh token"
        )

    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()

    if not user or not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found or inactive"
        )

    # Same claim set as /register and /login.
    token_data = {"sub": str(user.id), "user_id": user.id, "email": user.email}
    # Pass copies in case either helper mutates its argument.
    new_access_token = create_access_token(dict(token_data))
    new_refresh_token = create_refresh_token(dict(token_data))

    return Token(
        access_token=new_access_token,
        refresh_token=new_refresh_token,
        token_type="bearer"
    )
|