Skip to content

Commit

Permalink
114 fix batch save (#119)
Browse files Browse the repository at this point in the history
* Updates validation routine when editing Collections and Batches
   * Validates GUID strings by looking them up in the database.
   * Adds those records to the DB session Starlette-admin uses for saving records. 
* Adds initial tests for CRUD on Batches, however skipping for now until we figure out how to auth users for tests.

Also,
* Removes rule enforcing single quotes to double quotes.

---------

Co-authored-by: Harpo Harbert <ryan_harbert@wgbh.org>
  • Loading branch information
afred and mrharpo authored Aug 22, 2023
1 parent 5e357c5 commit 9ff6f30
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 26 deletions.
59 changes: 42 additions & 17 deletions chowda/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from pydantic import BaseModel
from sqlalchemy.dialects.postgresql import insert

from starlette.requests import Request


def upsert(
model: BaseModel,
Expand Down Expand Up @@ -40,27 +42,50 @@ def chunks_sequential(lst, n):
yield lst[si : si + (d + 1 if i < r else d)]


def validate_media_files(data: Dict[str, Any]):
"""Validate that the media_files are valid GUIDs and exist in the database"""
from sqlmodel import Session, select
# def validate_media_files(view: ModelView, request: Request, data: Dict[str, Any]):
def validate_media_file_guids(request: Request, data: Dict[str, Any]):
"""
1) Validates MediaFile GUIDs by fetching the MediaFile objects from the database,
2) Replaces the GUID strings with the found objects in the `data` dict
3) Adds the found objects to request.state.session which is the db session used by
Starlette-admin when saving.
NOTE: Starlette-admin does not provide a clean way to trigger validation errors when
related objects cannot be found because it does not provide an out-of-box feature
for end users to enter foreign keys as strings in order to related them to other
objects. But that's exactly what we need to do here: enter GUIDs as strings to
relate to Batch and Collection objects.
"""
from sqlmodel import Session, select
from chowda.db import engine
from chowda.models import MediaFile
from starlette_admin.exceptions import FormValidationError

data['media_files'] = data['media_files'].split('\r\n')
data['media_files'] = [guid.strip() for guid in data['media_files'] if guid]
media_files = []
errors = []

with Session(engine) as db:
for guid in data['media_files']:
results = db.exec(select(MediaFile).where(MediaFile.guid == guid)).all()
if not results:
errors.append(guid)
else:
assert len(results) == 1, 'Multiple MediaFiles with same GUID'
media_files.append(results[0])
if errors:
from starlette_admin.exceptions import FormValidationError

raise FormValidationError({'media_files': errors})
# Get all MediaFiles objects for the GUIDs in data['media_files']
media_files = db.exec(
select(MediaFile).where(MediaFile.guid.in_(data['media_files']))
).all()

# Any value in data['media_files'] that does not have a corresponding MediaFile
# object is invalid, so add it to the errors
valid_guids = [media_file.guid for media_file in media_files]
invalid_guids = [
guid for guid in data['media_files'] if guid not in valid_guids
]

if len(invalid_guids):
raise FormValidationError({'media_files': invalid_guids})

# Replace GUID strings with MediaFile objects in `data` dict so they will get added
# the parent object.
data['media_files'] = media_files

# Add MediaFile objects to the DB session Starlette admin uses for persistence.
# This is a bit of a hack to play nice with starlette-admin, but without it, an
# error is thrown if starlette-admin tries to add a validated MediaFile object
# to a parent object when that MediaFile is already there.
for media_file in data['media_files']:
request.state.session.add(media_file)
30 changes: 26 additions & 4 deletions chowda/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,17 @@
from starlette.requests import Request
from starlette.responses import Response
from starlette.templating import Jinja2Templates
from starlette_admin import CustomView, IntegerField, TextAreaField, action
from starlette.datastructures import FormData
from starlette_admin import CustomView, action
from starlette_admin.fields import IntegerField, TextAreaField
from starlette_admin._types import RequestAction
from starlette_admin.contrib.sqlmodel import ModelView
from starlette_admin.exceptions import ActionFailed

from chowda.auth.utils import get_user
from chowda.db import engine
from chowda.models import Batch
from chowda.utils import validate_media_files
from chowda.utils import validate_media_file_guids


@dataclass
Expand All @@ -28,16 +30,26 @@ class MediaFilesGuidsField(TextAreaField):
form_template: str = 'forms/media_files.html'
display_template: str = 'displays/media_files.html'

async def parse_form_data(
self, request: Request, form_data: FormData, action: RequestAction
) -> Any:
"""Maps a string of GUID to a list"""
return form_data.get(self.id).split()

async def serialize_value(
self, request: Request, value: Any, action: RequestAction
) -> Any:
"""Maps a Collection's MediaFile objects to a list of GUIDs"""
return [media_file.guid for media_file in value]


@dataclass
class MediaFileCount(IntegerField):
"""A field that displays the number of MediaFiles in a collection or batch"""

exclude_from_create: bool = True
exclude_from_edit: bool = True

render_function_key: str = 'media_file_count'
display_template: str = 'displays/media_file_count.html'

Expand Down Expand Up @@ -89,9 +101,17 @@ class CollectionView(BaseModelView):
exclude_from_edit=True,
exclude_from_create=True,
),
'media_files', # default view
# 'media_files', # default view
MediaFilesGuidsField(
'media_files',
id='media_files',
label='GUIDs',
),
]

async def validate(self, request: Request, data: Dict[str, Any]):
validate_media_file_guids(request, data)


class BatchView(BaseModelView):
exclude_fields_from_create: ClassVar[list[Any]] = [Batch.id]
Expand Down Expand Up @@ -121,7 +141,7 @@ class BatchView(BaseModelView):
]

async def validate(self, request: Request, data: Dict[str, Any]):
validate_media_files(data)
validate_media_file_guids(request, data)

async def is_action_allowed(self, request: Request, name: str) -> bool:
if name == 'start_batch':
Expand Down Expand Up @@ -154,6 +174,8 @@ async def start_batch(self, request: Request, pks: List[Any]) -> str:


class MediaFileView(BaseModelView):
pk_attr = 'guid'

fields: ClassVar[list[Any]] = [
'guid',
'collections',
Expand Down
13 changes: 13 additions & 0 deletions pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ test = [
"pytest-vcr~=1.0",
"urllib3~=1.26",
"pytest-mock~=3.11",
"pytest-asyncio~=0.21.1",
"trio~=0.22",
]
locust = [
Expand All @@ -67,11 +68,9 @@ omit = ['tests/*']
[tool.pytest.ini_options]
testpaths = ['tests', 'docs']

[tool.ruff.flake8-quotes]
inline-quotes = 'single'

[tool.ruff]
extend-exclude = ['migrations']
ignore = ['Q000']
select = [
'B', # flake8-bugbear
'C4', # flake8-comprehensions
Expand Down
4 changes: 2 additions & 2 deletions templates/forms/media_files.html
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
<div class="{% if error %}is-invalid{% endif %}">
<textarea class="form-control" id="{{ field.id }}" name="{{ field.id }}">
<textarea class="form-control" id="{{ field.id }}" name="{{ field.id }}" rows="20">
{%- if data -%}
{% for d in data -%}
{{- d }}&#10;
Expand All @@ -19,4 +19,4 @@ <h3>Invalid GUIDS</h3>
<h6>{{ e }}</h6>
{% endfor %}
</div>
{% endif %}
{% endif %}
57 changes: 57 additions & 0 deletions tests/test_batches.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from fastapi import FastAPI
from fastapi.testclient import TestClient
from httpx import AsyncClient
from pytest import fixture, mark
from pytest_asyncio import fixture as async_fixture
from sqlmodel import Session, select

from chowda.db import engine
from chowda.models import Batch


@fixture
def app():
from chowda.app import app

return app


@fixture
def client(app: FastAPI):
return TestClient(app)


@async_fixture
async def async_client(app: FastAPI):
async with AsyncClient(app=app, base_url='http://testserver') as c:
yield c


@fixture
def session():
return Session(engine)


@fixture
def unsaved_batch():
from tests.factories import BatchFactory

return BatchFactory.build()


@mark.anyio
@mark.skip("Redirecting due to user being unauthenticated in tests")
async def test_create_batch(
async_client: AsyncClient,
session: Session,
unsaved_batch: Batch,
):
response = await async_client.post(
'/admin/batch/create',
data=unsaved_batch.dict(),
follow_redirects=False,
)
assert response.status_code == 303
stmt = select(Batch).where(Batch.name == unsaved_batch.name)
batch = session.execute(stmt).scalar_one()
assert batch is not None

0 comments on commit 9ff6f30

Please sign in to comment.