diff --git a/backend/endpoints/responses/platform.py b/backend/endpoints/responses/platform.py index 5164a343f..0e16d3512 100644 --- a/backend/endpoints/responses/platform.py +++ b/backend/endpoints/responses/platform.py @@ -8,12 +8,12 @@ class PlatformSchema(BaseModel): id: int slug: str fs_slug: str + name: str + rom_count: int igdb_id: Optional[int] = None sgdb_id: Optional[int] = None moby_id: Optional[int] = None - name: str logo_path: Optional[str] = "" - rom_count: int firmware: list[FirmwareSchema] = Field(default_factory=list) class Config: diff --git a/backend/endpoints/responses/rom.py b/backend/endpoints/responses/rom.py index 72308842c..85c5e64d0 100644 --- a/backend/endpoints/responses/rom.py +++ b/backend/endpoints/responses/rom.py @@ -7,7 +7,6 @@ from fastapi import Request from fastapi.responses import StreamingResponse from handler.socket_handler import socket_handler -from handler.database import db_user_handler from handler.metadata.igdb_handler import IGDBMetadata from handler.metadata.moby_handler import MobyMetadata from pydantic import BaseModel, computed_field, Field @@ -35,15 +34,11 @@ class RomNoteSchema(BaseModel): last_edited_at: datetime raw_markdown: str is_public: bool + user__username: str class Config: from_attributes = True - @computed_field - @property - def user__username(self) -> str: - return db_user_handler.get_user(self.user_id).username - @classmethod def for_user(cls, db_rom: Rom, user_id: int) -> list["RomNoteSchema"]: return [ diff --git a/backend/handler/database/platforms_handler.py b/backend/handler/database/platforms_handler.py index ff613a6cb..d198b1011 100644 --- a/backend/handler/database/platforms_handler.py +++ b/backend/handler/database/platforms_handler.py @@ -1,5 +1,6 @@ +import functools from sqlalchemy import delete, or_, select -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, Query, selectinload from decorators.database import begin_session from models.platform import Platform @@ -8,19 +9,39 @@ from .base_handler import DBBaseHandler +def with_roms(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + session = kwargs.get("session") + if session is None: + raise ValueError("session is required") + + kwargs["query"] = select(Platform).options( + selectinload(Platform.roms).load_only(Rom.id) + ) + return func(*args, **kwargs) + + return wrapper + + class DBPlatformsHandler(DBBaseHandler): @begin_session + @with_roms def add_platform( - self, platform: Platform, session: Session = None + self, platform: Platform, query: Query = None, session: Session = None ) -> Platform | None: - return session.merge(platform) + platform = session.merge(platform) + session.flush() + + return session.scalar(query.filter_by(id=platform.id).limit(1)) @begin_session + @with_roms def get_platforms( - self, id: int = None, session: Session = None + self, id: int = None, query: Query = None, session: Session = None ) -> list[Platform] | Platform | None: return ( - session.scalar(select(Platform).filter_by(id=id).limit(1)) + session.scalar(query.filter_by(id=id).limit(1)) if id else ( session.scalars(select(Platform).order_by(Platform.name.asc())) @@ -30,10 +51,11 @@ def get_platforms( ) @begin_session + @with_roms def get_platform_by_fs_slug( - self, fs_slug: str, session: Session = None + self, fs_slug: str, query: Query = None, session: Session = None ) -> Platform | None: - return session.scalar(select(Platform).filter_by(fs_slug=fs_slug).limit(1)) + return session.scalar(query.filter_by(fs_slug=fs_slug).limit(1)) @begin_session def delete_platform(self, id: int, session: Session = None) -> int: diff --git a/backend/handler/database/roms_handler.py b/backend/handler/database/roms_handler.py index 3ad6e38fd..5e2c27400 100644 --- a/backend/handler/database/roms_handler.py +++ b/backend/handler/database/roms_handler.py @@ -1,11 +1,30 @@ +import functools from decorators.database import begin_session from models.rom import Rom, RomNote from sqlalchemy import and_, delete, func, select, update, or_, Select -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, Query, selectinload from .base_handler import DBBaseHandler +def with_assets(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + session = kwargs.get("session") + if session is None: + raise ValueError("session is required") + + kwargs["query"] = select(Rom).options( + selectinload(Rom.saves), + selectinload(Rom.states), + selectinload(Rom.screenshots), + selectinload(Rom.notes), + ) + return func(*args, **kwargs) + + return wrapper + + class DBRomsHandler(DBBaseHandler): def _filter(self, data: Select[Rom], platform_id: int | None, search_term: str): if platform_id: @@ -33,10 +52,15 @@ def _order(self, data: Select[Rom], order_by: str, order_dir: str): return data.order_by(_column.asc()) @begin_session - def add_rom(self, rom: Rom, session: Session = None): - return session.merge(rom) + @with_assets + def add_rom(self, rom: Rom, query: Query = None, session: Session = None) -> Rom: + rom = session.merge(rom) + session.flush() + + return session.scalar(query.filter_by(id=rom.id).limit(1)) @begin_session + @with_assets def get_roms( self, id: int = None, @@ -44,10 +68,11 @@ def get_roms( search_term: str = "", order_by: str = "name", order_dir: str = "asc", + query: Query = None, session: Session = None, - ): + ) -> list[Rom] | Rom | None: return ( - session.get(Rom, id) + session.scalar(query.filter_by(id=id).limit(1)) if id else self._order( self._filter(select(Rom), platform_id, search_term), @@ -57,31 +82,38 @@ def get_roms( ) @begin_session + @with_assets def get_rom_by_filename( - self, platform_id: int, file_name: str, session: Session = None - ): - return session.scalars( - select(Rom).filter_by(platform_id=platform_id, file_name=file_name).limit(1) - ).first() + self, + platform_id: int, + file_name: str, + query: Query = None, + session: Session = None, + ) -> Rom | None: + return session.scalar( + query.filter_by(platform_id=platform_id, file_name=file_name).limit(1) + ) @begin_session + @with_assets def get_rom_by_filename_no_tags( - self, file_name_no_tags: str, session: Session = None - ): - return session.scalars( - select(Rom).filter_by(file_name_no_tags=file_name_no_tags).limit(1) - ).first() + self, file_name_no_tags: str, query: Query = None, session: Session = None + ) -> Rom | None: + return session.scalar( + query.filter_by(file_name_no_tags=file_name_no_tags).limit(1) + ) @begin_session + @with_assets def get_rom_by_filename_no_ext( - self, file_name_no_ext: str, session: Session = None - ): - return session.scalars( - select(Rom).filter_by(file_name_no_ext=file_name_no_ext).limit(1) - ).first() + self, file_name_no_ext: str, query: Query = None, session: Session = None + ) -> Rom | None: + return session.scalar( + query.filter_by(file_name_no_ext=file_name_no_ext).limit(1) + ) @begin_session - def update_rom(self, id: int, data: dict, session: Session = None): + def update_rom(self, id: int, data: dict, session: Session = None) -> Rom: return session.execute( update(Rom) .where(Rom.id == id) @@ -90,7 +122,7 @@ def update_rom(self, id: int, data: dict, session: Session = None): ) @begin_session - def delete_rom(self, id: int, session: Session = None): + def delete_rom(self, id: int, session: Session = None) -> Rom: return session.execute( delete(Rom) .where(Rom.id == id) @@ -98,7 +130,9 @@ def delete_rom(self, id: int, session: Session = None): ) @begin_session - def purge_roms(self, platform_id: int, roms: list[str], session: Session = None): + def purge_roms( + self, platform_id: int, roms: list[str], session: Session = None + ) -> int: return session.execute( delete(Rom) .where(and_(Rom.platform_id == platform_id, Rom.file_name.not_in(roms))) @@ -106,17 +140,21 @@ def purge_roms(self, platform_id: int, roms: list[str], session: Session = None) ) @begin_session - def get_rom_note(self, rom_id: int, user_id: int, session: Session = None): - return session.scalars( + def get_rom_note( + self, rom_id: int, user_id: int, session: Session = None + ) -> RomNote | None: + return session.scalar( select(RomNote).filter_by(rom_id=rom_id, user_id=user_id).limit(1) - ).first() + ) @begin_session - def add_rom_note(self, rom_id: int, user_id: int, session: Session = None): + def add_rom_note( + self, rom_id: int, user_id: int, session: Session = None + ) -> RomNote: return session.merge(RomNote(rom_id=rom_id, user_id=user_id)) @begin_session - def update_rom_note(self, id: int, data: dict, session: Session = None): + def update_rom_note(self, id: int, data: dict, session: Session = None) -> RomNote: return session.execute( update(RomNote) .where(RomNote.id == id) diff --git a/backend/handler/database/users_handler.py b/backend/handler/database/users_handler.py index 6a4644f1f..28d6acb8b 100644 --- a/backend/handler/database/users_handler.py +++ b/backend/handler/database/users_handler.py @@ -13,9 +13,7 @@ def add_user(self, user: User, session: Session = None): @begin_session def get_user_by_username(self, username: str, session: Session = None): - return session.scalars( - select(User).filter_by(username=username).limit(1) - ).first() + return session.scalar(select(User).filter_by(username=username).limit(1)) @begin_session def get_user(self, id: int, session: Session = None): @@ -30,6 +28,10 @@ def update_user(self, id: int, data: dict, session: Session = None): .execution_options(synchronize_session="evaluate") ) + @begin_session + def get_users(self, session: Session = None): + return session.scalars(select(User)).all() + @begin_session def delete_user(self, id: int, session: Session = None): return session.execute( @@ -38,10 +40,6 @@ def delete_user(self, id: int, session: Session = None): .execution_options(synchronize_session="evaluate") ) - @begin_session - def get_users(self, session: Session = None): - return session.scalars(select(User)).all() - @begin_session def get_admin_users(self, session: Session = None): return session.scalars(select(User).filter_by(role=Role.ADMIN)).all() diff --git a/backend/logger/logger.py b/backend/logger/logger.py index 72a4ec10a..53adae87d 100644 --- a/backend/logger/logger.py +++ b/backend/logger/logger.py @@ -7,10 +7,15 @@ log = logging.getLogger("romm") log.setLevel(logging.DEBUG) +# Set up sqlachemy logger +# sql_log = logging.getLogger("sqlalchemy.engine") +# sql_log.setLevel(logging.DEBUG) + # Define stdout handler stdout_handler = logging.StreamHandler(sys.stdout) stdout_handler.setFormatter(StdoutFormatter()) log.addHandler(stdout_handler) +# sql_log.addHandler(stdout_handler) # Hush passlib warnings logging.getLogger("passlib").setLevel(logging.ERROR) diff --git a/backend/models/assets.py b/backend/models/assets.py index 4a95b6c6b..a1596a4e1 100644 --- a/backend/models/assets.py +++ b/backend/models/assets.py @@ -52,8 +52,8 @@ class Save(RomAsset): emulator = Column(String(length=50), nullable=True) - rom = relationship("Rom", lazy="selectin", back_populates="saves") - user = relationship("User", lazy="selectin", back_populates="saves") + rom = relationship("Rom", lazy="joined", back_populates="saves") + user = relationship("User", lazy="joined", back_populates="saves") @cached_property def screenshot(self) -> Optional["Screenshot"]: @@ -73,8 +73,8 @@ class State(RomAsset): emulator = Column(String(length=50), nullable=True) - rom = relationship("Rom", lazy="selectin", back_populates="states") - user = relationship("User", lazy="selectin", back_populates="states") + rom = relationship("Rom", lazy="joined", back_populates="states") + user = relationship("User", lazy="joined", back_populates="states") @cached_property def screenshot(self) -> Optional["Screenshot"]: @@ -92,5 +92,5 @@ class Screenshot(RomAsset): __tablename__ = "screenshots" __table_args__ = {"extend_existing": True} - rom = relationship("Rom", lazy="selectin", back_populates="screenshots") - user = relationship("User", lazy="selectin", back_populates="screenshots") + rom = relationship("Rom", lazy="joined", back_populates="screenshots") + user = relationship("User", lazy="joined", back_populates="screenshots") diff --git a/backend/models/platform.py b/backend/models/platform.py index 60ca152c5..7f2ee4eb7 100644 --- a/backend/models/platform.py +++ b/backend/models/platform.py @@ -18,6 +18,7 @@ class Platform(BaseModel): name: str = Column(String(length=400)) logo_path: str = Column(String(length=1000), default="") + roms: Mapped[set[Rom]] = relationship("Rom", back_populates="platform") firmware: Mapped[set[Firmware]] = relationship( "Firmware", lazy="selectin", back_populates="platform" ) diff --git a/backend/models/rom.py b/backend/models/rom.py index a74441883..987fe94be 100644 --- a/backend/models/rom.py +++ b/backend/models/rom.py @@ -70,18 +70,13 @@ class Rom(BaseModel): saves: Mapped[list[Save]] = relationship( "Save", - lazy="selectin", back_populates="rom", ) - states: Mapped[list[State]] = relationship( - "State", lazy="selectin", back_populates="rom" - ) + states: Mapped[list[State]] = relationship("State", back_populates="rom") screenshots: Mapped[list[Screenshot]] = relationship( - "Screenshot", lazy="selectin", back_populates="rom" - ) - notes: Mapped[list["RomNote"]] = relationship( - "RomNote", lazy="selectin", back_populates="rom" + "Screenshot", back_populates="rom" ) + notes: Mapped[list["RomNote"]] = relationship("RomNote", back_populates="rom") @property def platform_slug(self) -> str: @@ -189,5 +184,9 @@ class RomNote(BaseModel): nullable=False, ) - rom = relationship("Rom", back_populates="notes") - user = relationship("User", back_populates="notes") + rom = relationship("Rom", lazy="joined", back_populates="notes") + user = relationship("User", lazy="joined", back_populates="notes") + + @property + def user__username(self) -> str: + return self.user.username diff --git a/backend/models/user.py b/backend/models/user.py index 05ff405cf..629a36fc8 100644 --- a/backend/models/user.py +++ b/backend/models/user.py @@ -32,18 +32,13 @@ class User(BaseModel, SimpleUser): saves: Mapped[list[Save]] = relationship( "Save", - lazy="selectin", back_populates="user", ) - states: Mapped[list[State]] = relationship( - "State", lazy="selectin", back_populates="user" - ) + states: Mapped[list[State]] = relationship("State", back_populates="user") screenshots: Mapped[list[Screenshot]] = relationship( - "Screenshot", lazy="selectin", back_populates="user" - ) - notes: Mapped[list[RomNote]] = relationship( - "RomNote", lazy="selectin", back_populates="user" + "Screenshot", back_populates="user" ) + notes: Mapped[list[RomNote]] = relationship("RomNote", back_populates="user") @property def oauth_scopes(self): diff --git a/frontend/src/App.vue b/frontend/src/App.vue index 7f3e6bd9c..7457c7439 100644 --- a/frontend/src/App.vue +++ b/frontend/src/App.vue @@ -44,7 +44,7 @@ socket.on( socket.on("scan:scanning_rom", (rom: Rom) => { scanningStore.set(true); - if (romsStore.platform.name === rom.platform_name) { + if (romsStore.platformID === rom.platform_id) { romsStore.add([rom]); romsStore.setFiltered( isFiltered ? romsStore.filteredRoms : romsStore.allRoms, diff --git a/frontend/src/components/Gallery/AppBar/AdminMenu.vue b/frontend/src/components/Gallery/AppBar/AdminMenu.vue index c1bda8842..449f45d3c 100644 --- a/frontend/src/components/Gallery/AppBar/AdminMenu.vue +++ b/frontend/src/components/Gallery/AppBar/AdminMenu.vue @@ -7,8 +7,8 @@ import DeleteBtn from "@/components/Gallery/AppBar/DeleteBtn.vue";