From 9ff6f3092cd997519213f9eabf5e30a47a085ebf Mon Sep 17 00:00:00 2001 From: Andrew Myers Date: Tue, 22 Aug 2023 18:23:03 -0400 Subject: [PATCH] 114 fix batch save (#119) * Updates validation routine when editing Collections and Batches * Validates GUID strings by looking them up in the database. * Adds those records to the DB session Starlette-admin uses for saving records. * Adds initial tests for CRUD on Batches, however skipping for now until we figure out how to auth users for tests. Also, * Removes rule enforcing single quotes to double quotes. --------- Co-authored-by: Harpo Harbert --- chowda/utils.py | 59 +++++++++++++++++++++++--------- chowda/views.py | 30 +++++++++++++--- pdm.lock | 13 +++++++ pyproject.toml | 5 ++- templates/forms/media_files.html | 4 +-- tests/test_batches.py | 57 ++++++++++++++++++++++++++++++ 6 files changed, 142 insertions(+), 26 deletions(-) create mode 100644 tests/test_batches.py diff --git a/chowda/utils.py b/chowda/utils.py index be84a13c..e5f6f34c 100644 --- a/chowda/utils.py +++ b/chowda/utils.py @@ -3,6 +3,8 @@ from pydantic import BaseModel from sqlalchemy.dialects.postgresql import insert +from starlette.requests import Request + def upsert( model: BaseModel, @@ -40,27 +42,50 @@ def chunks_sequential(lst, n): yield lst[si : si + (d + 1 if i < r else d)] -def validate_media_files(data: Dict[str, Any]): - """Validate that the media_files are valid GUIDs and exist in the database""" - from sqlmodel import Session, select +# def validate_media_files(view: ModelView, request: Request, data: Dict[str, Any]): +def validate_media_file_guids(request: Request, data: Dict[str, Any]): + """ + 1) Validates MediaFile GUIDs by fetching the MediaFile objects from the database, + 2) Replaces the GUID strings with the found objects in the `data` dict + 3) Adds the found objects to request.state.session which is the db session used by + Starlette-admin when saving. + NOTE: Starlette-admin does not provide a clean way to trigger validation errors when + related objects cannot be found because it does not provide an out-of-box feature + for end users to enter foreign keys as strings in order to related them to other + objects. But that's exactly what we need to do here: enter GUIDs as strings to + relate to Batch and Collection objects. + """ + from sqlmodel import Session, select from chowda.db import engine from chowda.models import MediaFile + from starlette_admin.exceptions import FormValidationError - data['media_files'] = data['media_files'].split('\r\n') - data['media_files'] = [guid.strip() for guid in data['media_files'] if guid] media_files = [] - errors = [] + with Session(engine) as db: - for guid in data['media_files']: - results = db.exec(select(MediaFile).where(MediaFile.guid == guid)).all() - if not results: - errors.append(guid) - else: - assert len(results) == 1, 'Multiple MediaFiles with same GUID' - media_files.append(results[0]) - if errors: - from starlette_admin.exceptions import FormValidationError - - raise FormValidationError({'media_files': errors}) + # Get all MediaFiles objects for the GUIDs in data['media_files'] + media_files = db.exec( + select(MediaFile).where(MediaFile.guid.in_(data['media_files'])) + ).all() + + # Any value in data['media_files'] that does not have a corresponding MediaFile + # object is invalid, so add it to the errors + valid_guids = [media_file.guid for media_file in media_files] + invalid_guids = [ + guid for guid in data['media_files'] if guid not in valid_guids + ] + + if len(invalid_guids): + raise FormValidationError({'media_files': invalid_guids}) + + # Replace GUID strings with MediaFile objects in `data` dict so they will get added + # the parent object. data['media_files'] = media_files + + # Add MediaFile objects to the DB session Starlette admin uses for persistence. + # This is a bit of a hack to play nice with starlette-admin, but without it, an + # error is thrown if starlette-admin tries to add a validated MediaFile object + # to a parent object when that MediaFile is already there. + for media_file in data['media_files']: + request.state.session.add(media_file) diff --git a/chowda/views.py b/chowda/views.py index 02972ba6..5a7ccecc 100644 --- a/chowda/views.py +++ b/chowda/views.py @@ -9,7 +9,9 @@ from starlette.requests import Request from starlette.responses import Response from starlette.templating import Jinja2Templates -from starlette_admin import CustomView, IntegerField, TextAreaField, action +from starlette.datastructures import FormData +from starlette_admin import CustomView, action +from starlette_admin.fields import IntegerField, TextAreaField from starlette_admin._types import RequestAction from starlette_admin.contrib.sqlmodel import ModelView from starlette_admin.exceptions import ActionFailed @@ -17,7 +19,7 @@ from chowda.auth.utils import get_user from chowda.db import engine from chowda.models import Batch -from chowda.utils import validate_media_files +from chowda.utils import validate_media_file_guids @dataclass @@ -28,9 +30,16 @@ class MediaFilesGuidsField(TextAreaField): form_template: str = 'forms/media_files.html' display_template: str = 'displays/media_files.html' + async def parse_form_data( + self, request: Request, form_data: FormData, action: RequestAction + ) -> Any: + """Maps a string of GUID to a list""" + return form_data.get(self.id).split() + async def serialize_value( self, request: Request, value: Any, action: RequestAction ) -> Any: + """Maps a Collection's MediaFile objects to a list of GUIDs""" return [media_file.guid for media_file in value] @@ -38,6 +47,9 @@ async def serialize_value( class MediaFileCount(IntegerField): """A field that displays the number of MediaFiles in a collection or batch""" + exclude_from_create: bool = True + exclude_from_edit: bool = True + render_function_key: str = 'media_file_count' display_template: str = 'displays/media_file_count.html' @@ -89,9 +101,17 @@ class CollectionView(BaseModelView): exclude_from_edit=True, exclude_from_create=True, ), - 'media_files', # default view + # 'media_files', # default view + MediaFilesGuidsField( + 'media_files', + id='media_files', + label='GUIDs', + ), ] + async def validate(self, request: Request, data: Dict[str, Any]): + validate_media_file_guids(request, data) + class BatchView(BaseModelView): exclude_fields_from_create: ClassVar[list[Any]] = [Batch.id] @@ -121,7 +141,7 @@ class BatchView(BaseModelView): ] async def validate(self, request: Request, data: Dict[str, Any]): - validate_media_files(data) + validate_media_file_guids(request, data) async def is_action_allowed(self, request: Request, name: str) -> bool: if name == 'start_batch': @@ -154,6 +174,8 @@ async def start_batch(self, request: Request, pks: List[Any]) -> str: class MediaFileView(BaseModelView): + pk_attr = 'guid' + fields: ClassVar[list[Any]] = [ 'guid', 'collections', diff --git a/pdm.lock b/pdm.lock index bb2e11de..1b25cdf2 100644 --- a/pdm.lock +++ b/pdm.lock @@ -2478,6 +2478,19 @@ files = [ {file = "pytest-7.4.0.tar.gz", hash = "sha256:b4bf8c45bd59934ed84001ad51e11b4ee40d40a1229d2c79f9c592b0a3f6bd8a"}, ] +[[package]] +name = "pytest-asyncio" +version = "0.21.1" +requires_python = ">=3.7" +summary = "Pytest support for asyncio" +dependencies = [ + "pytest>=7.0.0", +] +files = [ + {file = "pytest-asyncio-0.21.1.tar.gz", hash = "sha256:40a7eae6dded22c7b604986855ea48400ab15b069ae38116e8c01238e9eeb64d"}, + {file = "pytest_asyncio-0.21.1-py3-none-any.whl", hash = "sha256:8666c1c8ac02631d7c51ba282e0c69a8a452b211ffedf2599099845da5c5c37b"}, +] + [[package]] name = "pytest-cov" version = "4.1.0" diff --git a/pyproject.toml b/pyproject.toml index d2ddfbe4..3a3a97c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ test = [ "pytest-vcr~=1.0", "urllib3~=1.26", "pytest-mock~=3.11", + "pytest-asyncio~=0.21.1", "trio~=0.22", ] locust = [ @@ -67,11 +68,9 @@ omit = ['tests/*'] [tool.pytest.ini_options] testpaths = ['tests', 'docs'] -[tool.ruff.flake8-quotes] -inline-quotes = 'single' - [tool.ruff] extend-exclude = ['migrations'] +ignore = ['Q000'] select = [ 'B', # flake8-bugbear 'C4', # flake8-comprehensions diff --git a/templates/forms/media_files.html b/templates/forms/media_files.html index 8273d998..93723464 100644 --- a/templates/forms/media_files.html +++ b/templates/forms/media_files.html @@ -1,5 +1,5 @@
-