mirror of
https://github.com/maxdorninger/MediaManager.git
synced 2026-04-17 15:13:24 +02:00
update how oauth is handled
This commit is contained in:
@@ -8,6 +8,7 @@ class OpenIdConfig(BaseSettings):
|
||||
client_secret: str = ""
|
||||
configuration_endpoint: str = ""
|
||||
enabled: bool = False
|
||||
name: str = "OAuth2"
|
||||
|
||||
|
||||
class AuthConfig(BaseSettings):
|
||||
@@ -17,7 +18,7 @@ class AuthConfig(BaseSettings):
|
||||
session_lifetime: int = 60 * 60 * 24
|
||||
admin_emails: list[str] = []
|
||||
email_password_resets: bool = False
|
||||
openid_connect: dict[str, OpenIdConfig] = {}
|
||||
openid_connect: OpenIdConfig = OpenIdConfig()
|
||||
|
||||
@property
|
||||
def jwt_signing_key(self):
|
||||
|
||||
@@ -1,139 +0,0 @@
|
||||
from media_manager.auth.users import (
|
||||
SECRET,
|
||||
openid_cookie_auth_backend as backend,
|
||||
get_user_manager,
|
||||
openid_clients as oauth_clients,
|
||||
)
|
||||
import jwt
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
|
||||
from pydantic import BaseModel
|
||||
|
||||
from fastapi_users import models
|
||||
from fastapi_users.authentication import Strategy
|
||||
from fastapi_users.exceptions import UserAlreadyExists
|
||||
from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt
|
||||
from fastapi_users.manager import BaseUserManager
|
||||
from fastapi_users.router.common import ErrorCode
|
||||
|
||||
STATE_TOKEN_AUDIENCE = "fastapi-users:oauth-state"
|
||||
|
||||
|
||||
class OAuth2AuthorizeResponse(BaseModel):
|
||||
authorization_url: str
|
||||
|
||||
|
||||
def generate_state_token(
|
||||
data: dict[str, str], secret: SecretType, lifetime_seconds: int = 3600
|
||||
) -> str:
|
||||
data["aud"] = STATE_TOKEN_AUDIENCE
|
||||
return generate_jwt(data, secret, lifetime_seconds)
|
||||
|
||||
|
||||
router = APIRouter(prefix="/auth/oauth")
|
||||
|
||||
|
||||
def get_authorize_callback(provider_name: str):
|
||||
if provider_name not in oauth_clients:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
|
||||
return OAuth2AuthorizeCallback(
|
||||
oauth_clients[provider_name],
|
||||
route_name="oauth:callback",
|
||||
redirect_url=str(router.url_path_for("oauth:callback")),
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{openid_provider_name}/authorize",
|
||||
response_model=OAuth2AuthorizeResponse,
|
||||
)
|
||||
async def authorize(
|
||||
request: Request,
|
||||
openid_provider_name: str,
|
||||
) -> OAuth2AuthorizeResponse:
|
||||
if openid_provider_name not in oauth_clients:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
|
||||
oauth_client = oauth_clients[openid_provider_name]
|
||||
|
||||
authorize_redirect_url = str(request.url_for("oauth:callback"))
|
||||
state = generate_state_token({"provider_name": openid_provider_name}, SECRET)
|
||||
authorization_url = await oauth_client.get_authorization_url(
|
||||
authorize_redirect_url,
|
||||
state,
|
||||
["openid", "profile", "email"],
|
||||
)
|
||||
|
||||
return OAuth2AuthorizeResponse(authorization_url=authorization_url)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/callback",
|
||||
name="oauth:callback",
|
||||
)
|
||||
async def callback(
|
||||
request: Request,
|
||||
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
|
||||
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
|
||||
access_token_state=None,
|
||||
) -> None:
|
||||
state_from_query = request.query_params.get("state")
|
||||
if not state_from_query:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail="Missing state token"
|
||||
)
|
||||
|
||||
try:
|
||||
state_data = decode_jwt(state_from_query, SECRET, [STATE_TOKEN_AUDIENCE])
|
||||
except jwt.DecodeError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid state token"
|
||||
)
|
||||
|
||||
openid_provider_name = state_data.get("provider_name")
|
||||
if not openid_provider_name or openid_provider_name not in oauth_clients:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid provider name in state token",
|
||||
)
|
||||
|
||||
authorize_callback = get_authorize_callback(openid_provider_name)
|
||||
token, state = await authorize_callback(request)
|
||||
|
||||
oauth_client = oauth_clients[openid_provider_name]
|
||||
|
||||
account_id, account_email = await oauth_client.get_id_email(token["access_token"])
|
||||
|
||||
if account_email is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ErrorCode.OAUTH_NOT_AVAILABLE_EMAIL,
|
||||
)
|
||||
|
||||
try:
|
||||
user = await user_manager.oauth_callback(
|
||||
oauth_client.name,
|
||||
token["access_token"],
|
||||
account_id,
|
||||
account_email,
|
||||
token.get("expires_at"),
|
||||
token.get("refresh_token"),
|
||||
request,
|
||||
associate_by_email=True,
|
||||
is_verified_by_default=True,
|
||||
)
|
||||
except UserAlreadyExists:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ErrorCode.OAUTH_USER_ALREADY_EXISTS,
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ErrorCode.LOGIN_BAD_CREDENTIALS,
|
||||
)
|
||||
|
||||
# Authenticate
|
||||
response = await backend.login(strategy, user)
|
||||
await user_manager.on_after_login(user, request, response)
|
||||
return response
|
||||
@@ -1,16 +1,40 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import status
|
||||
from fastapi_users.router import get_oauth_router
|
||||
from sqlalchemy import select
|
||||
|
||||
from media_manager.config import AllEncompassingConfig
|
||||
from media_manager.auth.db import User
|
||||
from media_manager.auth.schemas import UserRead
|
||||
from media_manager.auth.users import current_superuser
|
||||
from media_manager.auth.schemas import UserRead, AuthMetadata
|
||||
from media_manager.auth.users import (
|
||||
current_superuser,
|
||||
openid_client,
|
||||
openid_cookie_auth_backend,
|
||||
SECRET,
|
||||
get_user_manager,
|
||||
)
|
||||
from media_manager.database import DbSessionDependency
|
||||
|
||||
users_router = APIRouter()
|
||||
auth_metadata_router = APIRouter()
|
||||
oauth_config = AllEncompassingConfig().auth.openid_connect
|
||||
|
||||
|
||||
def get_openid_router():
|
||||
if openid_client:
|
||||
return get_oauth_router(
|
||||
oauth_client=openid_client,
|
||||
backend=openid_cookie_auth_backend,
|
||||
get_user_manager=get_user_manager,
|
||||
state_secret=SECRET,
|
||||
associate_by_email=True,
|
||||
is_verified_by_default=True,
|
||||
redirect_url=None,
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
openid_config = AllEncompassingConfig().auth.openid_connect
|
||||
|
||||
|
||||
@users_router.get(
|
||||
@@ -25,11 +49,8 @@ def get_all_users(db: DbSessionDependency) -> list[UserRead]:
|
||||
|
||||
|
||||
@auth_metadata_router.get("/auth/metadata", status_code=status.HTTP_200_OK)
|
||||
def get_auth_metadata() -> dict:
|
||||
if oauth_config:
|
||||
provider_names = [
|
||||
name for name, config in oauth_config.items() if config.enabled
|
||||
]
|
||||
return {"oauth_providers": provider_names}
|
||||
def get_auth_metadata() -> AuthMetadata:
|
||||
if openid_config.enabled:
|
||||
return AuthMetadata(oauth_providers=[openid_config.name])
|
||||
else:
|
||||
return {"oauth_providers": []}
|
||||
return AuthMetadata(oauth_providers=[])
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import uuid
|
||||
|
||||
from fastapi_users import schemas
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class UserRead(schemas.BaseUser[uuid.UUID]):
|
||||
@@ -13,3 +14,7 @@ class UserCreate(schemas.BaseUserCreate):
|
||||
|
||||
class UserUpdate(schemas.BaseUserUpdate):
|
||||
pass
|
||||
|
||||
|
||||
class AuthMetadata(BaseModel):
|
||||
oauth_providers: list[str]
|
||||
|
||||
@@ -28,21 +28,16 @@ config = AllEncompassingConfig().auth
|
||||
SECRET = config.token_secret
|
||||
LIFETIME = config.session_lifetime
|
||||
|
||||
openid_clients: dict[str, OpenID] = {}
|
||||
if config.openid_connect:
|
||||
log.info(f"got openid-config: {config.openid_connect}")
|
||||
for name, openid_config in config.openid_connect.items():
|
||||
if openid_config.enabled:
|
||||
log.info(f"Discovered OIDC provider: {name}")
|
||||
client = OpenID(
|
||||
base_scopes=["openid", "email", "profile"],
|
||||
client_id=openid_config.client_id,
|
||||
client_secret=openid_config.client_secret,
|
||||
name=name,
|
||||
openid_configuration_endpoint=openid_config.configuration_endpoint,
|
||||
)
|
||||
client.base_scopes = ["openid", "email", "profile"]
|
||||
openid_clients[name] = client
|
||||
openid_client: OpenID | None = None
|
||||
if config.openid_connect.enabled:
|
||||
log.info(f"Configured OIDC provider: {config.openid_connect.name}")
|
||||
openid_client = OpenID(
|
||||
base_scopes=["openid", "email", "profile"],
|
||||
client_id=config.openid_connect.client_id,
|
||||
client_secret=config.openid_connect.client_secret,
|
||||
name=config.openid_connect.name,
|
||||
openid_configuration_endpoint=config.openid_connect.configuration_endpoint,
|
||||
)
|
||||
|
||||
|
||||
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
@@ -69,7 +69,7 @@ from fastapi.staticfiles import StaticFiles # noqa: E402
|
||||
from media_manager.auth.router import users_router as custom_users_router # noqa: E402
|
||||
from media_manager.auth.router import auth_metadata_router # noqa: E402
|
||||
from media_manager.auth.schemas import UserCreate, UserRead, UserUpdate # noqa: E402
|
||||
from media_manager.auth.oauth import router as openid_router # noqa: E402
|
||||
from media_manager.auth.router import get_openid_router # noqa: E402
|
||||
|
||||
from media_manager.auth.users import ( # noqa: E402
|
||||
bearer_auth_backend,
|
||||
@@ -268,8 +268,7 @@ api_app.include_router(
|
||||
# ----------------------------
|
||||
|
||||
api_app.include_router(auth_metadata_router, tags=["openid"])
|
||||
|
||||
api_app.include_router(openid_router, tags=["openid"])
|
||||
api_app.include_router(get_openid_router(), tags=["openid"], prefix="/auth/oauth")
|
||||
|
||||
api_app.include_router(tv_router.router, prefix="/tv", tags=["tv"])
|
||||
api_app.include_router(torrent_router.router, prefix="/torrent", tags=["torrent"])
|
||||
|
||||
48
web/src/lib/api/api.d.ts
vendored
48
web/src/lib/api/api.d.ts
vendored
@@ -248,15 +248,15 @@ export interface paths {
|
||||
patch?: never;
|
||||
trace?: never;
|
||||
};
|
||||
'/api/v1/auth/oauth/{openid_provider_name}/authorize': {
|
||||
'/api/v1/auth/oauth/authorize': {
|
||||
parameters: {
|
||||
query?: never;
|
||||
header?: never;
|
||||
path?: never;
|
||||
cookie?: never;
|
||||
};
|
||||
/** Authorize */
|
||||
get: operations['authorize_api_v1_auth_oauth__openid_provider_name__authorize_get'];
|
||||
/** Oauth:Openid Connect.Cookie.Authorize */
|
||||
get: operations['oauth_OpenID_Connect_cookie_authorize_api_v1_auth_oauth_authorize_get'];
|
||||
put?: never;
|
||||
post?: never;
|
||||
delete?: never;
|
||||
@@ -272,8 +272,11 @@ export interface paths {
|
||||
path?: never;
|
||||
cookie?: never;
|
||||
};
|
||||
/** Oauth:Callback */
|
||||
get: operations['oauth_callback_api_v1_auth_oauth_callback_get'];
|
||||
/**
|
||||
* Oauth:Openid Connect.Cookie.Callback
|
||||
* @description The response varies based on the authentication backend used.
|
||||
*/
|
||||
get: operations['oauth_OpenID_Connect_cookie_callback_api_v1_auth_oauth_callback_get'];
|
||||
put?: never;
|
||||
post?: never;
|
||||
delete?: never;
|
||||
@@ -969,6 +972,11 @@ export interface paths {
|
||||
export type webhooks = Record<string, never>;
|
||||
export interface components {
|
||||
schemas: {
|
||||
/** AuthMetadata */
|
||||
AuthMetadata: {
|
||||
/** Oauth Providers */
|
||||
oauth_providers: string[];
|
||||
};
|
||||
/** BearerResponse */
|
||||
BearerResponse: {
|
||||
/** Access Token */
|
||||
@@ -2299,20 +2307,18 @@ export interface operations {
|
||||
[name: string]: unknown;
|
||||
};
|
||||
content: {
|
||||
'application/json': {
|
||||
[key: string]: unknown;
|
||||
};
|
||||
'application/json': components['schemas']['AuthMetadata'];
|
||||
};
|
||||
};
|
||||
};
|
||||
};
|
||||
authorize_api_v1_auth_oauth__openid_provider_name__authorize_get: {
|
||||
oauth_OpenID_Connect_cookie_authorize_api_v1_auth_oauth_authorize_get: {
|
||||
parameters: {
|
||||
query?: never;
|
||||
header?: never;
|
||||
path: {
|
||||
openid_provider_name: string;
|
||||
query?: {
|
||||
scopes?: string[];
|
||||
};
|
||||
header?: never;
|
||||
path?: never;
|
||||
cookie?: never;
|
||||
};
|
||||
requestBody?: never;
|
||||
@@ -2337,10 +2343,13 @@ export interface operations {
|
||||
};
|
||||
};
|
||||
};
|
||||
oauth_callback_api_v1_auth_oauth_callback_get: {
|
||||
oauth_OpenID_Connect_cookie_callback_api_v1_auth_oauth_callback_get: {
|
||||
parameters: {
|
||||
query?: {
|
||||
access_token_state?: unknown;
|
||||
code?: string | null;
|
||||
code_verifier?: string | null;
|
||||
state?: string | null;
|
||||
error?: string | null;
|
||||
};
|
||||
header?: never;
|
||||
path?: never;
|
||||
@@ -2357,6 +2366,15 @@ export interface operations {
|
||||
'application/json': unknown;
|
||||
};
|
||||
};
|
||||
/** @description Bad Request */
|
||||
400: {
|
||||
headers: {
|
||||
[name: string]: unknown;
|
||||
};
|
||||
content: {
|
||||
'application/json': components['schemas']['ErrorModel'];
|
||||
};
|
||||
};
|
||||
/** @description Validation Error */
|
||||
422: {
|
||||
headers: {
|
||||
|
||||
@@ -116,7 +116,7 @@
|
||||
Or continue with
|
||||
</span>
|
||||
</div>
|
||||
<Button class="mt-2 w-full" onclick={() => handleOauth(name)} variant="outline"
|
||||
<Button class="mt-2 w-full" onclick={() => handleOauth()} variant="outline"
|
||||
>Login with {name}</Button
|
||||
>
|
||||
{/each}
|
||||
|
||||
@@ -120,7 +120,7 @@
|
||||
Or continue with
|
||||
</span>
|
||||
</div>
|
||||
<Button class="mt-2 w-full" onclick={() => handleOauth(name)} variant="outline"
|
||||
<Button class="mt-2 w-full" onclick={() => handleOauth()} variant="outline"
|
||||
>Login with {name}</Button
|
||||
>
|
||||
{/each}
|
||||
|
||||
@@ -63,11 +63,11 @@ export async function handleLogout() {
|
||||
await goto(base + '/login');
|
||||
}
|
||||
|
||||
export async function handleOauth(oauth_name: string) {
|
||||
const { error, data } = await client.GET(`/api/v1/auth/oauth/{openid_provider_name}/authorize`, {
|
||||
export async function handleOauth() {
|
||||
const { error, data } = await client.GET(`/api/v1/auth/oauth/authorize`, {
|
||||
params: {
|
||||
path: {
|
||||
openid_provider_name: oauth_name
|
||||
query: {
|
||||
scopes: ['openid', 'email', 'profile']
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user