update how oauth is handled

This commit is contained in:
maxDorninger
2025-09-13 18:42:34 +02:00
parent 1c81083e23
commit 8f5cc9329c
10 changed files with 89 additions and 189 deletions

View File

@@ -8,6 +8,7 @@ class OpenIdConfig(BaseSettings):
client_secret: str = ""
configuration_endpoint: str = ""
enabled: bool = False
name: str = "OAuth2"
class AuthConfig(BaseSettings):
@@ -17,7 +18,7 @@ class AuthConfig(BaseSettings):
session_lifetime: int = 60 * 60 * 24
admin_emails: list[str] = []
email_password_resets: bool = False
openid_connect: dict[str, OpenIdConfig] = {}
openid_connect: OpenIdConfig = OpenIdConfig()
@property
def jwt_signing_key(self):

View File

@@ -1,139 +0,0 @@
from media_manager.auth.users import (
SECRET,
openid_cookie_auth_backend as backend,
get_user_manager,
openid_clients as oauth_clients,
)
import jwt
from fastapi import APIRouter, Depends, HTTPException, Request, status
from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
from pydantic import BaseModel
from fastapi_users import models
from fastapi_users.authentication import Strategy
from fastapi_users.exceptions import UserAlreadyExists
from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt
from fastapi_users.manager import BaseUserManager
from fastapi_users.router.common import ErrorCode
STATE_TOKEN_AUDIENCE = "fastapi-users:oauth-state"
class OAuth2AuthorizeResponse(BaseModel):
authorization_url: str
def generate_state_token(
data: dict[str, str], secret: SecretType, lifetime_seconds: int = 3600
) -> str:
data["aud"] = STATE_TOKEN_AUDIENCE
return generate_jwt(data, secret, lifetime_seconds)
router = APIRouter(prefix="/auth/oauth")
def get_authorize_callback(provider_name: str):
if provider_name not in oauth_clients:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
return OAuth2AuthorizeCallback(
oauth_clients[provider_name],
route_name="oauth:callback",
redirect_url=str(router.url_path_for("oauth:callback")),
)
@router.get(
"/{openid_provider_name}/authorize",
response_model=OAuth2AuthorizeResponse,
)
async def authorize(
request: Request,
openid_provider_name: str,
) -> OAuth2AuthorizeResponse:
if openid_provider_name not in oauth_clients:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
oauth_client = oauth_clients[openid_provider_name]
authorize_redirect_url = str(request.url_for("oauth:callback"))
state = generate_state_token({"provider_name": openid_provider_name}, SECRET)
authorization_url = await oauth_client.get_authorization_url(
authorize_redirect_url,
state,
["openid", "profile", "email"],
)
return OAuth2AuthorizeResponse(authorization_url=authorization_url)
@router.get(
"/callback",
name="oauth:callback",
)
async def callback(
request: Request,
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
access_token_state=None,
) -> None:
state_from_query = request.query_params.get("state")
if not state_from_query:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="Missing state token"
)
try:
state_data = decode_jwt(state_from_query, SECRET, [STATE_TOKEN_AUDIENCE])
except jwt.DecodeError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid state token"
)
openid_provider_name = state_data.get("provider_name")
if not openid_provider_name or openid_provider_name not in oauth_clients:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid provider name in state token",
)
authorize_callback = get_authorize_callback(openid_provider_name)
token, state = await authorize_callback(request)
oauth_client = oauth_clients[openid_provider_name]
account_id, account_email = await oauth_client.get_id_email(token["access_token"])
if account_email is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ErrorCode.OAUTH_NOT_AVAILABLE_EMAIL,
)
try:
user = await user_manager.oauth_callback(
oauth_client.name,
token["access_token"],
account_id,
account_email,
token.get("expires_at"),
token.get("refresh_token"),
request,
associate_by_email=True,
is_verified_by_default=True,
)
except UserAlreadyExists:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ErrorCode.OAUTH_USER_ALREADY_EXISTS,
)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ErrorCode.LOGIN_BAD_CREDENTIALS,
)
# Authenticate
response = await backend.login(strategy, user)
await user_manager.on_after_login(user, request, response)
return response

View File

@@ -1,16 +1,40 @@
from fastapi import APIRouter, Depends
from fastapi import status
from fastapi_users.router import get_oauth_router
from sqlalchemy import select
from media_manager.config import AllEncompassingConfig
from media_manager.auth.db import User
from media_manager.auth.schemas import UserRead
from media_manager.auth.users import current_superuser
from media_manager.auth.schemas import UserRead, AuthMetadata
from media_manager.auth.users import (
current_superuser,
openid_client,
openid_cookie_auth_backend,
SECRET,
get_user_manager,
)
from media_manager.database import DbSessionDependency
users_router = APIRouter()
auth_metadata_router = APIRouter()
oauth_config = AllEncompassingConfig().auth.openid_connect
def get_openid_router():
if openid_client:
return get_oauth_router(
oauth_client=openid_client,
backend=openid_cookie_auth_backend,
get_user_manager=get_user_manager,
state_secret=SECRET,
associate_by_email=True,
is_verified_by_default=True,
redirect_url=None,
)
else:
return None
openid_config = AllEncompassingConfig().auth.openid_connect
@users_router.get(
@@ -25,11 +49,8 @@ def get_all_users(db: DbSessionDependency) -> list[UserRead]:
@auth_metadata_router.get("/auth/metadata", status_code=status.HTTP_200_OK)
def get_auth_metadata() -> dict:
if oauth_config:
provider_names = [
name for name, config in oauth_config.items() if config.enabled
]
return {"oauth_providers": provider_names}
def get_auth_metadata() -> AuthMetadata:
if openid_config.enabled:
return AuthMetadata(oauth_providers=[openid_config.name])
else:
return {"oauth_providers": []}
return AuthMetadata(oauth_providers=[])

View File

@@ -1,6 +1,7 @@
import uuid
from fastapi_users import schemas
from pydantic import BaseModel
class UserRead(schemas.BaseUser[uuid.UUID]):
@@ -13,3 +14,7 @@ class UserCreate(schemas.BaseUserCreate):
class UserUpdate(schemas.BaseUserUpdate):
pass
class AuthMetadata(BaseModel):
oauth_providers: list[str]

View File

@@ -28,21 +28,16 @@ config = AllEncompassingConfig().auth
SECRET = config.token_secret
LIFETIME = config.session_lifetime
openid_clients: dict[str, OpenID] = {}
if config.openid_connect:
log.info(f"got openid-config: {config.openid_connect}")
for name, openid_config in config.openid_connect.items():
if openid_config.enabled:
log.info(f"Discovered OIDC provider: {name}")
client = OpenID(
base_scopes=["openid", "email", "profile"],
client_id=openid_config.client_id,
client_secret=openid_config.client_secret,
name=name,
openid_configuration_endpoint=openid_config.configuration_endpoint,
)
client.base_scopes = ["openid", "email", "profile"]
openid_clients[name] = client
openid_client: OpenID | None = None
if config.openid_connect.enabled:
log.info(f"Configured OIDC provider: {config.openid_connect.name}")
openid_client = OpenID(
base_scopes=["openid", "email", "profile"],
client_id=config.openid_connect.client_id,
client_secret=config.openid_connect.client_secret,
name=config.openid_connect.name,
openid_configuration_endpoint=config.openid_connect.configuration_endpoint,
)
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):

View File

@@ -69,7 +69,7 @@ from fastapi.staticfiles import StaticFiles # noqa: E402
from media_manager.auth.router import users_router as custom_users_router # noqa: E402
from media_manager.auth.router import auth_metadata_router # noqa: E402
from media_manager.auth.schemas import UserCreate, UserRead, UserUpdate # noqa: E402
from media_manager.auth.oauth import router as openid_router # noqa: E402
from media_manager.auth.router import get_openid_router # noqa: E402
from media_manager.auth.users import ( # noqa: E402
bearer_auth_backend,
@@ -268,8 +268,7 @@ api_app.include_router(
# ----------------------------
api_app.include_router(auth_metadata_router, tags=["openid"])
api_app.include_router(openid_router, tags=["openid"])
api_app.include_router(get_openid_router(), tags=["openid"], prefix="/auth/oauth")
api_app.include_router(tv_router.router, prefix="/tv", tags=["tv"])
api_app.include_router(torrent_router.router, prefix="/torrent", tags=["torrent"])

View File

@@ -248,15 +248,15 @@ export interface paths {
patch?: never;
trace?: never;
};
'/api/v1/auth/oauth/{openid_provider_name}/authorize': {
'/api/v1/auth/oauth/authorize': {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
/** Authorize */
get: operations['authorize_api_v1_auth_oauth__openid_provider_name__authorize_get'];
/** Oauth:Openid Connect.Cookie.Authorize */
get: operations['oauth_OpenID_Connect_cookie_authorize_api_v1_auth_oauth_authorize_get'];
put?: never;
post?: never;
delete?: never;
@@ -272,8 +272,11 @@ export interface paths {
path?: never;
cookie?: never;
};
/** Oauth:Callback */
get: operations['oauth_callback_api_v1_auth_oauth_callback_get'];
/**
* Oauth:Openid Connect.Cookie.Callback
* @description The response varies based on the authentication backend used.
*/
get: operations['oauth_OpenID_Connect_cookie_callback_api_v1_auth_oauth_callback_get'];
put?: never;
post?: never;
delete?: never;
@@ -969,6 +972,11 @@ export interface paths {
export type webhooks = Record<string, never>;
export interface components {
schemas: {
/** AuthMetadata */
AuthMetadata: {
/** Oauth Providers */
oauth_providers: string[];
};
/** BearerResponse */
BearerResponse: {
/** Access Token */
@@ -2299,20 +2307,18 @@ export interface operations {
[name: string]: unknown;
};
content: {
'application/json': {
[key: string]: unknown;
};
'application/json': components['schemas']['AuthMetadata'];
};
};
};
};
authorize_api_v1_auth_oauth__openid_provider_name__authorize_get: {
oauth_OpenID_Connect_cookie_authorize_api_v1_auth_oauth_authorize_get: {
parameters: {
query?: never;
header?: never;
path: {
openid_provider_name: string;
query?: {
scopes?: string[];
};
header?: never;
path?: never;
cookie?: never;
};
requestBody?: never;
@@ -2337,10 +2343,13 @@ export interface operations {
};
};
};
oauth_callback_api_v1_auth_oauth_callback_get: {
oauth_OpenID_Connect_cookie_callback_api_v1_auth_oauth_callback_get: {
parameters: {
query?: {
access_token_state?: unknown;
code?: string | null;
code_verifier?: string | null;
state?: string | null;
error?: string | null;
};
header?: never;
path?: never;
@@ -2357,6 +2366,15 @@ export interface operations {
'application/json': unknown;
};
};
/** @description Bad Request */
400: {
headers: {
[name: string]: unknown;
};
content: {
'application/json': components['schemas']['ErrorModel'];
};
};
/** @description Validation Error */
422: {
headers: {

View File

@@ -116,7 +116,7 @@
Or continue with
</span>
</div>
<Button class="mt-2 w-full" onclick={() => handleOauth(name)} variant="outline"
<Button class="mt-2 w-full" onclick={() => handleOauth()} variant="outline"
>Login with {name}</Button
>
{/each}

View File

@@ -120,7 +120,7 @@
Or continue with
</span>
</div>
<Button class="mt-2 w-full" onclick={() => handleOauth(name)} variant="outline"
<Button class="mt-2 w-full" onclick={() => handleOauth()} variant="outline"
>Login with {name}</Button
>
{/each}

View File

@@ -63,11 +63,11 @@ export async function handleLogout() {
await goto(base + '/login');
}
export async function handleOauth(oauth_name: string) {
const { error, data } = await client.GET(`/api/v1/auth/oauth/{openid_provider_name}/authorize`, {
export async function handleOauth() {
const { error, data } = await client.GET(`/api/v1/auth/oauth/authorize`, {
params: {
path: {
openid_provider_name: oauth_name
query: {
scopes: ['openid', 'email', 'profile']
}
}
});