Skip to content

Commit

Permalink
Merge pull request #662 from zurdi15/3.0-rc-5
Browse files Browse the repository at this point in the history
Even more fixes for 3.0
  • Loading branch information
zurdi15 committed Feb 16, 2024
2 parents 94c62fa + 6778dc3 commit 3701c97
Show file tree
Hide file tree
Showing 24 changed files with 173 additions and 137 deletions.
1 change: 1 addition & 0 deletions backend/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
ROMM_AUTH_SECRET_KEY: Final = os.environ.get(
"ROMM_AUTH_SECRET_KEY", secrets.token_hex(32)
)
DISABLE_CSRF_PROTECTION = os.environ.get("DISABLE_CSRF_PROTECTION", "false") == "true"

# TASKS
ENABLE_RESCAN_ON_FILESYSTEM_CHANGE: Final = (
Expand Down
23 changes: 8 additions & 15 deletions backend/endpoints/auth.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import secrets
from datetime import timedelta
from typing import Annotated, Final

Expand All @@ -9,7 +8,6 @@
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.security.http import HTTPBasic
from handler import auth_handler, oauth_handler
from handler.redis_handler import cache

ACCESS_TOKEN_EXPIRE_MINUTES: Final = 30
REFRESH_TOKEN_EXPIRE_DAYS: Final = 7
Expand Down Expand Up @@ -45,7 +43,9 @@ async def token(form_data: Annotated[OAuth2RequestForm, Depends()]) -> TokenResp
status_code=status.HTTP_400_BAD_REQUEST, detail="Missing refresh token"
)

user, payload = await oauth_handler.get_current_active_user_from_bearer_token(token)
user, payload = await oauth_handler.get_current_active_user_from_bearer_token(
token
)
if payload.get("type") != "refresh":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
Expand All @@ -54,6 +54,7 @@ async def token(form_data: Annotated[OAuth2RequestForm, Depends()]) -> TokenResp
access_token = oauth_handler.create_oauth_token(
data={
"sub": user.username,
"iss": "romm:oauth",
"scopes": payload.get("scopes"),
"type": "access",
},
Expand Down Expand Up @@ -105,6 +106,7 @@ async def token(form_data: Annotated[OAuth2RequestForm, Depends()]) -> TokenResp
access_token = oauth_handler.create_oauth_token(
data={
"sub": user.username,
"iss": "romm:oauth",
"scopes": " ".join(form_data.scopes),
"type": "access",
},
Expand All @@ -114,6 +116,7 @@ async def token(form_data: Annotated[OAuth2RequestForm, Depends()]) -> TokenResp
refresh_token = oauth_handler.create_oauth_token(
data={
"sub": user.username,
"iss": "romm:oauth",
"scopes": " ".join(form_data.scopes),
"type": "refresh",
},
Expand Down Expand Up @@ -151,9 +154,7 @@ def login(request: Request, credentials=Depends(HTTPBasic())) -> MessageResponse
if not user.enabled:
raise DisabledException

# Generate unique session key and store in cache
request.session["session_id"] = secrets.token_hex(16)
cache.set(f'romm:{request.session["session_id"]}', user.username) # type: ignore[attr-defined]
request.session.update({"iss": "romm:auth", "sub": user.username})

return {"msg": "Successfully logged in"}

Expand All @@ -169,14 +170,6 @@ def logout(request: Request) -> MessageResponse:
MessageResponse: Standard message response
"""

# Check if session key already stored in cache
session_id = request.session.get("session_id")
if not session_id:
return {"msg": "Already logged out"}

if not request.user.is_authenticated:
return {"msg": "Already logged out"}

auth_handler.clear_session(request)
request.session.clear()

return {"msg": "Successfully logged out"}
2 changes: 2 additions & 0 deletions backend/endpoints/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
def access_token(admin_user): # noqa
data = {
"sub": admin_user.username,
"iss": "romm:oauth",
"scopes": " ".join(admin_user.oauth_scopes),
"type": "access",
}
Expand All @@ -23,6 +24,7 @@ def access_token(admin_user): # noqa
def refresh_token(admin_user): # noqa
data = {
"sub": admin_user.username,
"iss": "romm:oauth",
"scopes": " ".join(admin_user.oauth_scopes),
"type": "refresh",
}
Expand Down
1 change: 0 additions & 1 deletion backend/endpoints/user.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
from pathlib import Path
from typing import Annotated

Expand Down
26 changes: 11 additions & 15 deletions backend/handler/auth_handler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
ROMM_AUTH_USERNAME,
)
from exceptions.auth_exceptions import OAuthCredentialsException
from fastapi import HTTPException, Request, status
from handler.redis_handler import cache
from fastapi import HTTPException, status
from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlalchemy.exc import IntegrityError
Expand Down Expand Up @@ -52,12 +51,6 @@ def verify_password(self, plain_password, hashed_password):
def get_password_hash(self, password):
return self.pwd_context.hash(password)

def clear_session(self, req: HTTPConnection | Request):
session_id = req.session.get("session_id")
if session_id:
cache.delete(f"romm:{session_id}") # type: ignore[attr-defined]
req.session["session_id"] = None

def authenticate_user(self, username: str, password: str):
from handler import db_user_handler

Expand All @@ -72,28 +65,27 @@ def authenticate_user(self, username: str, password: str):

async def get_current_active_user_from_session(self, conn: HTTPConnection):
from handler import db_user_handler

# Check if session key already stored in cache
session_id = conn.session.get("session_id")
if not session_id:

issuer = conn.session.get('iss')
if not issuer or issuer != 'romm:auth':
return None

username = cache.get(f"romm:{session_id}") # type: ignore[attr-defined]
username = conn.session.get('sub')
if not username:
return None

# Key exists therefore user is probably authenticated
user = db_user_handler.get_user_by_username(username)
if user is None:
self.clear_session(conn)
conn.session.clear()

raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="User not found",
)

if not user.enabled:
self.clear_session(conn)
conn.session.clear()

raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
Expand Down Expand Up @@ -140,6 +132,10 @@ async def get_current_active_user_from_bearer_token(self, token: str):
payload = jwt.decode(token, ROMM_AUTH_SECRET_KEY, algorithms=[ALGORITHM])
except JWTError:
raise OAuthCredentialsException

issuer = payload.get('iss')
if not issuer or issuer != 'romm:oauth':
return None

username = payload.get("sub")
if username is None:
Expand Down
63 changes: 19 additions & 44 deletions backend/handler/auth_handler/tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,24 @@
from handler import auth_handler, oauth_handler, db_user_handler
from handler.auth_handler import WRITE_SCOPES
from handler.auth_handler.hybrid_auth import HybridAuthBackend
from handler.redis_handler import cache


def test_verify_password():
assert auth_handler.verify_password("password", auth_handler.get_password_hash("password"))
assert not auth_handler.verify_password("password", auth_handler.get_password_hash("notpassword"))


def test_authenticate_user(admin_user):
def test_authenticate_user(admin_user: User):
current_user = auth_handler.authenticate_user("test_admin", "test_admin_password")

assert current_user
assert current_user.id == admin_user.id


async def test_get_current_active_user_from_session(editor_user):
session_id = "test_session_id"
cache.set(f"romm:{session_id}", editor_user.username)

async def test_get_current_active_user_from_session(editor_user: User):
class MockConnection:
def __init__(self):
self.session = {"session_id": session_id}
self.session = {"iss": "romm:auth", "sub": editor_user.username}

conn = MockConnection()
current_user = await auth_handler.get_current_active_user_from_session(conn)
Expand All @@ -37,28 +33,10 @@ def __init__(self):
assert current_user.id == editor_user.id


async def test_get_current_active_user_from_session_bad_session_key(editor_user):
cache.set("romm:test_session_id", editor_user.username)

async def test_get_current_active_user_from_session_bad_username(editor_user: User):
class MockConnection:
def __init__(self):
self.session = {"session_id": "not_real_test_session_id"}
self.headers = {}

conn = MockConnection()
current_user = await auth_handler.get_current_active_user_from_session(conn)

assert not current_user


async def test_get_current_active_user_from_session_bad_username(editor_user):
session_id = "test_session_id"
cache.set(f"romm:{session_id}", "not_real_username")

class MockConnection:
def __init__(self):
self.session = {"session_id": session_id}
self.headers = {}
self.session = {"iss": "romm:auth", "sub": "not_real_username"}

conn = MockConnection()

Expand All @@ -69,13 +47,10 @@ def __init__(self):
assert e.detail == "User not found"


async def test_get_current_active_user_from_session_disabled_user(editor_user):
session_id = "test_session_id"
cache.set(f"romm:{session_id}", editor_user.username)

async def test_get_current_active_user_from_session_disabled_user(editor_user: User):
class MockConnection:
def __init__(self):
self.session = {"session_id": session_id}
self.session = {"iss": "romm:auth", "sub": editor_user.username}
self.headers = {}

conn = MockConnection()
Expand Down Expand Up @@ -105,13 +80,10 @@ def test_create_default_admin_user():
assert len(users) == 1


async def test_hybrid_auth_backend_session(editor_user):
session_id = "test_session_id"
cache.set(f"romm:{session_id}", editor_user.username)

async def test_hybrid_auth_backend_session(editor_user: User):
class MockConnection:
def __init__(self):
self.session = {"session_id": session_id}
self.session = {"iss": "romm:auth", "sub": editor_user.username}

backend = HybridAuthBackend()
conn = MockConnection()
Expand All @@ -123,7 +95,7 @@ def __init__(self):
assert creds.scopes == WRITE_SCOPES


async def test_hybrid_auth_backend_empty_session_and_headers(editor_user):
async def test_hybrid_auth_backend_empty_session_and_headers(editor_user: User):
class MockConnection:
def __init__(self):
self.session = {}
Expand All @@ -138,10 +110,11 @@ def __init__(self):
assert creds.scopes == []


async def test_hybrid_auth_backend_bearer_auth_header(editor_user):
async def test_hybrid_auth_backend_bearer_auth_header(editor_user: User):
access_token = oauth_handler.create_oauth_token(
data={
"sub": editor_user.username,
"iss": "romm:oauth",
"scopes": " ".join(editor_user.oauth_scopes),
"type": "access",
},
Expand All @@ -161,7 +134,7 @@ def __init__(self):
assert set(creds.scopes).issubset(editor_user.oauth_scopes)


async def test_hybrid_auth_backend_bearer_invalid_token(editor_user):
async def test_hybrid_auth_backend_bearer_invalid_token(editor_user: User):
class MockConnection:
def __init__(self):
self.session = {}
Expand All @@ -174,7 +147,7 @@ def __init__(self):
await backend.authenticate(conn)


async def test_hybrid_auth_backend_basic_auth_header(editor_user):
async def test_hybrid_auth_backend_basic_auth_header(editor_user: User):
token = b64encode("test_editor:test_editor_password".encode()).decode()

class MockConnection:
Expand All @@ -192,7 +165,7 @@ def __init__(self):
assert set(creds.scopes).issubset(editor_user.oauth_scopes)


async def test_hybrid_auth_backend_basic_auth_header_unencoded(editor_user):
async def test_hybrid_auth_backend_basic_auth_header_unencoded(editor_user: User):
class MockConnection:
def __init__(self):
self.session = {}
Expand Down Expand Up @@ -220,10 +193,11 @@ def __init__(self):
assert creds.scopes == []


async def test_hybrid_auth_backend_with_refresh_token(editor_user):
async def test_hybrid_auth_backend_with_refresh_token(editor_user: User):
refresh_token = oauth_handler.create_oauth_token(
data={
"sub": editor_user.username,
"iss": "romm:oauth",
"scopes": " ".join(editor_user.oauth_scopes),
"type": "refresh",
},
Expand All @@ -243,11 +217,12 @@ def __init__(self):
assert creds.scopes == []


async def test_hybrid_auth_backend_scope_subset(editor_user):
async def test_hybrid_auth_backend_scope_subset(editor_user: User):
scopes = editor_user.oauth_scopes[:3]
access_token = oauth_handler.create_oauth_token(
data={
"sub": editor_user.username,
"iss": "romm:oauth",
"scopes": " ".join(scopes),
"type": "access",
},
Expand Down
5 changes: 4 additions & 1 deletion backend/handler/auth_handler/tests/test_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ async def test_get_current_active_user_from_bearer_token(admin_user):
token = oauth_handler.create_oauth_token(
data={
"sub": admin_user.username,
"iss": "romm:oauth",
"scopes": " ".join(admin_user.oauth_scopes),
"type": "access",
},
Expand All @@ -24,6 +25,7 @@ async def test_get_current_active_user_from_bearer_token(admin_user):

assert user.id == admin_user.id
assert payload["sub"] == admin_user.username
assert payload["iss"] == "romm:oauth"
assert set(payload["scopes"].split()).issubset(admin_user.oauth_scopes)
assert payload["type"] == "access"

Expand All @@ -34,7 +36,7 @@ async def test_get_current_active_user_from_bearer_token_invalid_token():


async def test_get_current_active_user_from_bearer_token_invalid_user():
token = oauth_handler.create_oauth_token(data={"sub": "invalid_user"})
token = oauth_handler.create_oauth_token(data={"sub": "invalid_user", "iss": "romm:oauth"})

with pytest.raises(HTTPException):
await oauth_handler.get_current_active_user_from_bearer_token(token)
Expand All @@ -44,6 +46,7 @@ async def test_get_current_active_user_from_bearer_token_disabled_user(admin_use
token = oauth_handler.create_oauth_token(
data={
"sub": admin_user.username,
"iss": "romm:oauth",
"scopes": " ".join(admin_user.oauth_scopes),
"type": "access",
},
Expand Down
2 changes: 1 addition & 1 deletion backend/handler/gh_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def check_new_version(self) -> str:

try:
response = requests.get(
"https://github.com/gitapi/repos/zurdi15/romm/releases/latest", timeout=0.5
"https://github.com/gitapi/repos/zurdi15/romm/releases/latest", timeout=5
)
except ReadTimeout:
log.warning("Couldn't check last RomM version.")
Expand Down
Loading

0 comments on commit 3701c97

Please sign in to comment.