diff --git a/backend/config/__init__.py b/backend/config/__init__.py index f2320924d..9ad9d09f4 100644 --- a/backend/config/__init__.py +++ b/backend/config/__init__.py @@ -50,6 +50,7 @@ ROMM_AUTH_SECRET_KEY: Final = os.environ.get( "ROMM_AUTH_SECRET_KEY", secrets.token_hex(32) ) +DISABLE_CSRF_PROTECTION = os.environ.get("DISABLE_CSRF_PROTECTION", "false") == "true" # TASKS ENABLE_RESCAN_ON_FILESYSTEM_CHANGE: Final = ( diff --git a/backend/endpoints/auth.py b/backend/endpoints/auth.py index c9a058ced..ed6448569 100644 --- a/backend/endpoints/auth.py +++ b/backend/endpoints/auth.py @@ -1,4 +1,3 @@ -import secrets from datetime import timedelta from typing import Annotated, Final @@ -9,7 +8,6 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi.security.http import HTTPBasic from handler import auth_handler, oauth_handler -from handler.redis_handler import cache ACCESS_TOKEN_EXPIRE_MINUTES: Final = 30 REFRESH_TOKEN_EXPIRE_DAYS: Final = 7 @@ -45,7 +43,9 @@ async def token(form_data: Annotated[OAuth2RequestForm, Depends()]) -> TokenResp status_code=status.HTTP_400_BAD_REQUEST, detail="Missing refresh token" ) - user, payload = await oauth_handler.get_current_active_user_from_bearer_token(token) + user, payload = await oauth_handler.get_current_active_user_from_bearer_token( + token + ) if payload.get("type") != "refresh": raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token" @@ -54,6 +54,7 @@ async def token(form_data: Annotated[OAuth2RequestForm, Depends()]) -> TokenResp access_token = oauth_handler.create_oauth_token( data={ "sub": user.username, + "iss": "romm:oauth", "scopes": payload.get("scopes"), "type": "access", }, @@ -105,6 +106,7 @@ async def token(form_data: Annotated[OAuth2RequestForm, Depends()]) -> TokenResp access_token = oauth_handler.create_oauth_token( data={ "sub": user.username, + "iss": "romm:oauth", "scopes": " ".join(form_data.scopes), "type": "access", }, @@ -114,6 +116,7 @@ async def token(form_data: Annotated[OAuth2RequestForm, Depends()]) -> TokenResp refresh_token = oauth_handler.create_oauth_token( data={ "sub": user.username, + "iss": "romm:oauth", "scopes": " ".join(form_data.scopes), "type": "refresh", }, @@ -151,9 +154,7 @@ def login(request: Request, credentials=Depends(HTTPBasic())) -> MessageResponse if not user.enabled: raise DisabledException - # Generate unique session key and store in cache - request.session["session_id"] = secrets.token_hex(16) - cache.set(f'romm:{request.session["session_id"]}', user.username) # type: ignore[attr-defined] + request.session.update({"iss": "romm:auth", "sub": user.username}) return {"msg": "Successfully logged in"} @@ -169,14 +170,6 @@ def logout(request: Request) -> MessageResponse: MessageResponse: Standard message response """ - # Check if session key already stored in cache - session_id = request.session.get("session_id") - if not session_id: - return {"msg": "Already logged out"} - - if not request.user.is_authenticated: - return {"msg": "Already logged out"} - - auth_handler.clear_session(request) + request.session.clear() return {"msg": "Successfully logged out"} diff --git a/backend/endpoints/tests/conftest.py b/backend/endpoints/tests/conftest.py index 441b5e892..038377189 100644 --- a/backend/endpoints/tests/conftest.py +++ b/backend/endpoints/tests/conftest.py @@ -10,6 +10,7 @@ def access_token(admin_user): # noqa data = { "sub": admin_user.username, + "iss": "romm:oauth", "scopes": " ".join(admin_user.oauth_scopes), "type": "access", } @@ -23,6 +24,7 @@ def access_token(admin_user): # noqa def refresh_token(admin_user): # noqa data = { "sub": admin_user.username, + "iss": "romm:oauth", "scopes": " ".join(admin_user.oauth_scopes), "type": "refresh", } diff --git a/backend/endpoints/user.py b/backend/endpoints/user.py index 3dca29102..4f383e786 100644 --- a/backend/endpoints/user.py +++ b/backend/endpoints/user.py @@ -1,4 +1,3 @@ -import os from pathlib import Path from typing import Annotated diff --git a/backend/handler/auth_handler/__init__.py b/backend/handler/auth_handler/__init__.py index e49ef2f0f..4dd9af1a5 100644 --- a/backend/handler/auth_handler/__init__.py +++ b/backend/handler/auth_handler/__init__.py @@ -7,8 +7,7 @@ ROMM_AUTH_USERNAME, ) from exceptions.auth_exceptions import OAuthCredentialsException -from fastapi import HTTPException, Request, status -from handler.redis_handler import cache +from fastapi import HTTPException, status from jose import JWTError, jwt from passlib.context import CryptContext from sqlalchemy.exc import IntegrityError @@ -52,12 +51,6 @@ def verify_password(self, plain_password, hashed_password): def get_password_hash(self, password): return self.pwd_context.hash(password) - def clear_session(self, req: HTTPConnection | Request): - session_id = req.session.get("session_id") - if session_id: - cache.delete(f"romm:{session_id}") # type: ignore[attr-defined] - req.session["session_id"] = None - def authenticate_user(self, username: str, password: str): from handler import db_user_handler @@ -72,20 +65,19 @@ def authenticate_user(self, username: str, password: str): async def get_current_active_user_from_session(self, conn: HTTPConnection): from handler import db_user_handler - - # Check if session key already stored in cache - session_id = conn.session.get("session_id") - if not session_id: + + issuer = conn.session.get('iss') + if not issuer or issuer != 'romm:auth': return None - username = cache.get(f"romm:{session_id}") # type: ignore[attr-defined] + username = conn.session.get('sub') if not username: return None # Key exists therefore user is probably authenticated user = db_user_handler.get_user_by_username(username) if user is None: - self.clear_session(conn) + conn.session.clear() raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -93,7 +85,7 @@ async def get_current_active_user_from_session(self, conn: HTTPConnection): ) if not user.enabled: - self.clear_session(conn) + conn.session.clear() raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user" @@ -140,6 +132,10 @@ async def get_current_active_user_from_bearer_token(self, token: str): payload = jwt.decode(token, ROMM_AUTH_SECRET_KEY, algorithms=[ALGORITHM]) except JWTError: raise OAuthCredentialsException + + issuer = payload.get('iss') + if not issuer or issuer != 'romm:oauth': + return None username = payload.get("sub") if username is None: diff --git a/backend/handler/auth_handler/tests/test_auth.py b/backend/handler/auth_handler/tests/test_auth.py index 04821ed0e..8d38192f0 100644 --- a/backend/handler/auth_handler/tests/test_auth.py +++ b/backend/handler/auth_handler/tests/test_auth.py @@ -6,7 +6,6 @@ from handler import auth_handler, oauth_handler, db_user_handler from handler.auth_handler import WRITE_SCOPES from handler.auth_handler.hybrid_auth import HybridAuthBackend -from handler.redis_handler import cache def test_verify_password(): @@ -14,20 +13,17 @@ def test_verify_password(): assert not auth_handler.verify_password("password", auth_handler.get_password_hash("notpassword")) -def test_authenticate_user(admin_user): +def test_authenticate_user(admin_user: User): current_user = auth_handler.authenticate_user("test_admin", "test_admin_password") assert current_user assert current_user.id == admin_user.id -async def test_get_current_active_user_from_session(editor_user): - session_id = "test_session_id" - cache.set(f"romm:{session_id}", editor_user.username) - +async def test_get_current_active_user_from_session(editor_user: User): class MockConnection: def __init__(self): - self.session = {"session_id": session_id} + self.session = {"iss": "romm:auth", "sub": editor_user.username} conn = MockConnection() current_user = await auth_handler.get_current_active_user_from_session(conn) @@ -37,28 +33,10 @@ def __init__(self): assert current_user.id == editor_user.id -async def test_get_current_active_user_from_session_bad_session_key(editor_user): - cache.set("romm:test_session_id", editor_user.username) - +async def test_get_current_active_user_from_session_bad_username(editor_user: User): class MockConnection: def __init__(self): - self.session = {"session_id": "not_real_test_session_id"} - self.headers = {} - - conn = MockConnection() - current_user = await auth_handler.get_current_active_user_from_session(conn) - - assert not current_user - - -async def test_get_current_active_user_from_session_bad_username(editor_user): - session_id = "test_session_id" - cache.set(f"romm:{session_id}", "not_real_username") - - class MockConnection: - def __init__(self): - self.session = {"session_id": session_id} - self.headers = {} + self.session = {"iss": "romm:auth", "sub": "not_real_username"} conn = MockConnection() @@ -69,13 +47,10 @@ def __init__(self): assert e.detail == "User not found" -async def test_get_current_active_user_from_session_disabled_user(editor_user): - session_id = "test_session_id" - cache.set(f"romm:{session_id}", editor_user.username) - +async def test_get_current_active_user_from_session_disabled_user(editor_user: User): class MockConnection: def __init__(self): - self.session = {"session_id": session_id} + self.session = {"iss": "romm:auth", "sub": editor_user.username} self.headers = {} conn = MockConnection() @@ -105,13 +80,10 @@ def test_create_default_admin_user(): assert len(users) == 1 -async def test_hybrid_auth_backend_session(editor_user): - session_id = "test_session_id" - cache.set(f"romm:{session_id}", editor_user.username) - +async def test_hybrid_auth_backend_session(editor_user: User): class MockConnection: def __init__(self): - self.session = {"session_id": session_id} + self.session = {"iss": "romm:auth", "sub": editor_user.username} backend = HybridAuthBackend() conn = MockConnection() @@ -123,7 +95,7 @@ def __init__(self): assert creds.scopes == WRITE_SCOPES -async def test_hybrid_auth_backend_empty_session_and_headers(editor_user): +async def test_hybrid_auth_backend_empty_session_and_headers(editor_user: User): class MockConnection: def __init__(self): self.session = {} @@ -138,10 +110,11 @@ def __init__(self): assert creds.scopes == [] -async def test_hybrid_auth_backend_bearer_auth_header(editor_user): +async def test_hybrid_auth_backend_bearer_auth_header(editor_user: User): access_token = oauth_handler.create_oauth_token( data={ "sub": editor_user.username, + "iss": "romm:oauth", "scopes": " ".join(editor_user.oauth_scopes), "type": "access", }, @@ -161,7 +134,7 @@ def __init__(self): assert set(creds.scopes).issubset(editor_user.oauth_scopes) -async def test_hybrid_auth_backend_bearer_invalid_token(editor_user): +async def test_hybrid_auth_backend_bearer_invalid_token(editor_user: User): class MockConnection: def __init__(self): self.session = {} @@ -174,7 +147,7 @@ def __init__(self): await backend.authenticate(conn) -async def test_hybrid_auth_backend_basic_auth_header(editor_user): +async def test_hybrid_auth_backend_basic_auth_header(editor_user: User): token = b64encode("test_editor:test_editor_password".encode()).decode() class MockConnection: @@ -192,7 +165,7 @@ def __init__(self): assert set(creds.scopes).issubset(editor_user.oauth_scopes) -async def test_hybrid_auth_backend_basic_auth_header_unencoded(editor_user): +async def test_hybrid_auth_backend_basic_auth_header_unencoded(editor_user: User): class MockConnection: def __init__(self): self.session = {} @@ -220,10 +193,11 @@ def __init__(self): assert creds.scopes == [] -async def test_hybrid_auth_backend_with_refresh_token(editor_user): +async def test_hybrid_auth_backend_with_refresh_token(editor_user: User): refresh_token = oauth_handler.create_oauth_token( data={ "sub": editor_user.username, + "iss": "romm:oauth", "scopes": " ".join(editor_user.oauth_scopes), "type": "refresh", }, @@ -243,11 +217,12 @@ def __init__(self): assert creds.scopes == [] -async def test_hybrid_auth_backend_scope_subset(editor_user): +async def test_hybrid_auth_backend_scope_subset(editor_user: User): scopes = editor_user.oauth_scopes[:3] access_token = oauth_handler.create_oauth_token( data={ "sub": editor_user.username, + "iss": "romm:oauth", "scopes": " ".join(scopes), "type": "access", }, diff --git a/backend/handler/auth_handler/tests/test_oauth.py b/backend/handler/auth_handler/tests/test_oauth.py index b9e3b7566..9210a21d1 100644 --- a/backend/handler/auth_handler/tests/test_oauth.py +++ b/backend/handler/auth_handler/tests/test_oauth.py @@ -16,6 +16,7 @@ async def test_get_current_active_user_from_bearer_token(admin_user): token = oauth_handler.create_oauth_token( data={ "sub": admin_user.username, + "iss": "romm:oauth", "scopes": " ".join(admin_user.oauth_scopes), "type": "access", }, @@ -24,6 +25,7 @@ async def test_get_current_active_user_from_bearer_token(admin_user): assert user.id == admin_user.id assert payload["sub"] == admin_user.username + assert payload["iss"] == "romm:oauth" assert set(payload["scopes"].split()).issubset(admin_user.oauth_scopes) assert payload["type"] == "access" @@ -34,7 +36,7 @@ async def test_get_current_active_user_from_bearer_token_invalid_token(): async def test_get_current_active_user_from_bearer_token_invalid_user(): - token = oauth_handler.create_oauth_token(data={"sub": "invalid_user"}) + token = oauth_handler.create_oauth_token(data={"sub": "invalid_user", "iss": "romm:oauth"}) with pytest.raises(HTTPException): await oauth_handler.get_current_active_user_from_bearer_token(token) @@ -44,6 +46,7 @@ async def test_get_current_active_user_from_bearer_token_disabled_user(admin_use token = oauth_handler.create_oauth_token( data={ "sub": admin_user.username, + "iss": "romm:oauth", "scopes": " ".join(admin_user.oauth_scopes), "type": "access", }, diff --git a/backend/handler/gh_handler.py b/backend/handler/gh_handler.py index d8cf0a14f..3ddd7273e 100644 --- a/backend/handler/gh_handler.py +++ b/backend/handler/gh_handler.py @@ -35,7 +35,7 @@ def check_new_version(self) -> str: try: response = requests.get( - "https://github.com/gitapi/repos/zurdi15/romm/releases/latest", timeout=0.5 + "https://github.com/gitapi/repos/zurdi15/romm/releases/latest", timeout=5 ) except ReadTimeout: log.warning("Couldn't check last RomM version.") diff --git a/backend/main.py b/backend/main.py index 19903abb0..51a048952 100644 --- a/backend/main.py +++ b/backend/main.py @@ -3,7 +3,7 @@ import alembic.config import uvicorn -from config import DEV_HOST, DEV_PORT, ROMM_AUTH_SECRET_KEY +from config import DEV_HOST, DEV_PORT, ROMM_AUTH_SECRET_KEY, DISABLE_CSRF_PROTECTION from endpoints import ( auth, config, @@ -20,15 +20,16 @@ webrcade, screenshots, ) -import endpoints.sockets.scan # noqa +import endpoints.sockets.scan # noqa from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi_pagination import add_pagination from handler import auth_handler, db_user_handler, github_handler, socket_handler +from handler.auth_handler import ALGORITHM from handler.auth_handler.hybrid_auth import HybridAuthBackend from handler.auth_handler.middleware import CustomCSRFMiddleware from starlette.middleware.authentication import AuthenticationMiddleware -from starlette.middleware.sessions import SessionMiddleware +from starlette_authlib.middleware import AuthlibMiddleware as SessionMiddleware app = FastAPI(title="RomM API", version=github_handler.get_version()) @@ -40,7 +41,7 @@ allow_headers=["*"], ) -if "pytest" not in sys.modules: +if "pytest" not in sys.modules and not DISABLE_CSRF_PROTECTION: # CSRF protection (except endpoints listed in exempt_urls) app.add_middleware( CustomCSRFMiddleware, @@ -60,6 +61,7 @@ secret_key=ROMM_AUTH_SECRET_KEY, same_site="strict", https_only=False, + jwt_alg=ALGORITHM, ) app.include_router(heartbeat.router) diff --git a/frontend/src/App.vue b/frontend/src/App.vue index a9dfd6710..a0e32f149 100644 --- a/frontend/src/App.vue +++ b/frontend/src/App.vue @@ -1,14 +1,13 @@