diff --git a/media_manager/auth/oauth.py b/media_manager/auth/oauth.py new file mode 100644 index 0000000..a161039 --- /dev/null +++ b/media_manager/auth/oauth.py @@ -0,0 +1,154 @@ +from typing import Optional + +import jwt +from fastapi import APIRouter, Depends, HTTPException, Query, Request, status +from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback +from httpx_oauth.oauth2 import BaseOAuth2, OAuth2Token +from pydantic import BaseModel + +from fastapi_users import models, schemas +from fastapi_users.authentication import AuthenticationBackend, Authenticator, Strategy +from fastapi_users.exceptions import UserAlreadyExists +from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt +from fastapi_users.manager import BaseUserManager, UserManagerDependency +from fastapi_users.router.common import ErrorCode, ErrorModel + +STATE_TOKEN_AUDIENCE = "fastapi-users:oauth-state" + + +class OAuth2AuthorizeResponse(BaseModel): + authorization_url: str + + +def generate_state_token( + data: dict[str, str], secret: SecretType, lifetime_seconds: int = 3600 +) -> str: + data["aud"] = STATE_TOKEN_AUDIENCE + return generate_jwt(data, secret, lifetime_seconds) + + +def get_oauth_router( + oauth_client: BaseOAuth2, + backend: AuthenticationBackend[models.UP, models.ID], + get_user_manager: UserManagerDependency[models.UP, models.ID], + state_secret: SecretType, + redirect_url: Optional[str] = None, + associate_by_email: bool = False, + is_verified_by_default: bool = False, +) -> APIRouter: + """Generate a router with the OAuth routes.""" + router = APIRouter() + callback_route_name = f"oauth:{oauth_client.name}.{backend.name}.callback" + + if redirect_url is not None: + oauth2_authorize_callback = OAuth2AuthorizeCallback( + oauth_client, + redirect_url=redirect_url, + ) + else: + oauth2_authorize_callback = OAuth2AuthorizeCallback( + oauth_client, + route_name=callback_route_name, + ) + + @router.get( + "/authorize", + name=f"oauth:{oauth_client.name}.{backend.name}.authorize", + response_model=OAuth2AuthorizeResponse, + ) + async def authorize( + request: Request, scopes: list[str] = ["openid", "profile", "email"] + ) -> OAuth2AuthorizeResponse: + if redirect_url is not None: + authorize_redirect_url = redirect_url + else: + authorize_redirect_url = str(request.url_for(callback_route_name)) + + state_data: dict[str, str] = {} + state = generate_state_token(state_data, state_secret) + authorization_url = await oauth_client.get_authorization_url( + authorize_redirect_url, + state, + scopes, + ) + + return OAuth2AuthorizeResponse(authorization_url=authorization_url) + + @router.get( + "/callback", + name=callback_route_name, + description="The response varies based on the authentication backend used.", + responses={ + status.HTTP_400_BAD_REQUEST: { + "model": ErrorModel, + "content": { + "application/json": { + "examples": { + "INVALID_STATE_TOKEN": { + "summary": "Invalid state token.", + "value": None, + }, + ErrorCode.LOGIN_BAD_CREDENTIALS: { + "summary": "User is inactive.", + "value": {"detail": ErrorCode.LOGIN_BAD_CREDENTIALS}, + }, + } + } + }, + }, + }, + ) + async def callback( + request: Request, + access_token_state: tuple[OAuth2Token, str] = Depends( + oauth2_authorize_callback + ), + user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), + strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy), + ): + token, state = access_token_state + account_id, account_email = await oauth_client.get_id_email( + token["access_token"] + ) + + if account_email is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ErrorCode.OAUTH_NOT_AVAILABLE_EMAIL, + ) + + try: + decode_jwt(state, state_secret, [STATE_TOKEN_AUDIENCE]) + except jwt.DecodeError: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) + + try: + user = await user_manager.oauth_callback( + oauth_client.name, + token["access_token"], + account_id, + account_email, + token.get("expires_at"), + token.get("refresh_token"), + request, + associate_by_email=associate_by_email, + is_verified_by_default=is_verified_by_default, + ) + except UserAlreadyExists: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ErrorCode.OAUTH_USER_ALREADY_EXISTS, + ) + + if not user.is_active: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ErrorCode.LOGIN_BAD_CREDENTIALS, + ) + + # Authenticate + response = await backend.login(strategy, user) + await user_manager.on_after_login(user, request, response) + return response + + return router diff --git a/media_manager/main.py b/media_manager/main.py index e3b0a73..7c790a4 100644 --- a/media_manager/main.py +++ b/media_manager/main.py @@ -47,7 +47,6 @@ log = logging.getLogger(__name__) from media_manager.database import init_db import media_manager.tv.router as tv_router import media_manager.torrent.router as torrent_router - init_db() log.info("Database initialized") @@ -92,6 +91,8 @@ from media_manager.auth.users import SECRET as AUTH_USERS_SECRET from media_manager.auth.router import users_router as custom_users_router from media_manager.auth.router import auth_metadata_router from media_manager.auth.schemas import UserCreate, UserRead, UserUpdate +from media_manager.auth.oauth import get_oauth_router + from media_manager.auth.users import ( bearer_auth_backend, fastapi_users, @@ -139,10 +140,11 @@ app.include_router( # OAuth2 Routers if openid_client is not None: app.include_router( - fastapi_users.get_oauth_router( - openid_client, - openid_cookie_auth_backend, - AUTH_USERS_SECRET, + get_oauth_router( + oauth_client=openid_client, + backend=openid_cookie_auth_backend, + get_user_manager=fastapi_users.get_user_manager, + state_secret=AUTH_USERS_SECRET, associate_by_email=True, is_verified_by_default=True, ),