Skip to content

Commit

Permalink
Make LabeledDataSerializer.validated_data usable (#8144)
Browse files Browse the repository at this point in the history
Currently, the views that use `LabeledDataSerializer` for input do
something unconventional: they create the serializer with the input
data, call `is_valid()`, but then use the original data instead of the
`validated_data` member. I believe this is because `validated_data` in
this case is unusable because of the `source` attributes on some of the
fields in the nested serializers.

For example, in `LabeledImageSerializer`, the `attributes` field has a
source of `labeledimageattributeval_set`. This works well when
serializing `LabeledImage` objects, but when you're deserializing, this
creates a dict with a `labeledimageattributeval_set` key. Such objects
are incompatible with functions like `patch_job_data`, which expect the
`attributes` key instead.

In the current code, using `data` instead of `validated_data` seems to
work okay-ish. It _is_ a bit confusing, though, because it's
unconventional. For example, the `default` values set in serializer
fields are effectively useless, because they're only filled in
`validated_data`.

However, I'm currently working on a feature where a
`LabeledDataSerializer` is incorporated into another serializer, and
this problem means that I can't use `validated_data` for the parent
serializer either, and that means I can't implement custom
`to_internal_value` or `create` methods. So I'd much rather fix this.

While we could do it by making `patch_job_data` and others accept
`labeledimageattributeval_set`, this seems counterproductive. The name
`attributes` is much easier to read & understand. So instead, change the
models so that the attributes of an annotation object can be accessed
via `.attributes` and the shapes of a track via `.tracks`. That way, the
`source` attributes become unnecessary. This fixes the problem _and_
makes the code clearer at the same time.
  • Loading branch information
SpecLad committed Jul 11, 2024
1 parent ad1bcd5 commit 0317871
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 96 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,23 +87,23 @@ def get_track_count():
self._db_obj.labeledtrack_set.exclude(source=SourceType.FILE)
.values(
"id",
"trackedshape__id",
"trackedshape__frame",
"trackedshape__type",
"trackedshape__outside",
"shape__id",
"shape__frame",
"shape__type",
"shape__outside",
)
.order_by("id", "trackedshape__frame")
.order_by("id", "shape__frame")
.iterator(chunk_size=2000)
)

db_tracks = merge_table_rows(
rows=db_tracks,
keys_for_merge={
"shapes": [
"trackedshape__id",
"trackedshape__frame",
"trackedshape__type",
"trackedshape__outside",
"shape__id",
"shape__frame",
"shape__type",
"shape__outside",
],
},
field_id="id",
Expand Down
120 changes: 59 additions & 61 deletions cvat/apps/dataset_manager/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def _sync_frames(self, tracks, parent_track):

if min_frame < parent_track.frame:
# parent track cannot have a frame greater than the frame of the child track
parent_tracked_shape = parent_track.trackedshape_set.first()
parent_tracked_shape = parent_track.shapes.first()
parent_track.frame = min_frame
parent_tracked_shape.frame = min_frame

Expand Down Expand Up @@ -501,25 +501,25 @@ def _init_tags_from_db(self):
'label_id',
'group',
'source',
'labeledimageattributeval__spec_id',
'labeledimageattributeval__value',
'labeledimageattributeval__id',
'attribute__spec_id',
'attribute__value',
'attribute__id',
).order_by('frame').iterator(chunk_size=2000)

db_tags = merge_table_rows(
rows=db_tags,
keys_for_merge={
"labeledimageattributeval_set": [
'labeledimageattributeval__spec_id',
'labeledimageattributeval__value',
'labeledimageattributeval__id',
"attributes": [
'attribute__spec_id',
'attribute__value',
'attribute__id',
],
},
field_id='id',
)

for db_tag in db_tags:
self._extend_attributes(db_tag.labeledimageattributeval_set,
self._extend_attributes(db_tag.attributes,
self.db_attributes[db_tag.label_id]["all"].values())

serializer = serializers.LabeledImageSerializerFromDB(db_tags, many=True)
Expand All @@ -541,18 +541,18 @@ def _init_shapes_from_db(self):
'rotation',
'points',
'parent',
'labeledshapeattributeval__spec_id',
'labeledshapeattributeval__value',
'labeledshapeattributeval__id',
'attribute__spec_id',
'attribute__value',
'attribute__id',
).order_by('frame').iterator(chunk_size=2000)

db_shapes = merge_table_rows(
rows=db_shapes,
keys_for_merge={
'labeledshapeattributeval_set': [
'labeledshapeattributeval__spec_id',
'labeledshapeattributeval__value',
'labeledshapeattributeval__id',
'attributes': [
'attribute__spec_id',
'attribute__value',
'attribute__id',
],
},
field_id='id',
Expand All @@ -561,7 +561,7 @@ def _init_shapes_from_db(self):
shapes = {}
elements = {}
for db_shape in db_shapes:
self._extend_attributes(db_shape.labeledshapeattributeval_set,
self._extend_attributes(db_shape.attributes,
self.db_attributes[db_shape.label_id]["all"].values())

if db_shape.parent is None:
Expand All @@ -588,42 +588,42 @@ def _init_tracks_from_db(self):
"group",
"source",
"parent",
"labeledtrackattributeval__spec_id",
"labeledtrackattributeval__value",
"labeledtrackattributeval__id",
"trackedshape__type",
"trackedshape__occluded",
"trackedshape__z_order",
"trackedshape__rotation",
"trackedshape__points",
"trackedshape__id",
"trackedshape__frame",
"trackedshape__outside",
"trackedshape__trackedshapeattributeval__spec_id",
"trackedshape__trackedshapeattributeval__value",
"trackedshape__trackedshapeattributeval__id",
).order_by('id', 'trackedshape__frame').iterator(chunk_size=2000)
"attribute__spec_id",
"attribute__value",
"attribute__id",
"shape__type",
"shape__occluded",
"shape__z_order",
"shape__rotation",
"shape__points",
"shape__id",
"shape__frame",
"shape__outside",
"shape__attribute__spec_id",
"shape__attribute__value",
"shape__attribute__id",
).order_by('id', 'shape__frame').iterator(chunk_size=2000)

db_tracks = merge_table_rows(
rows=db_tracks,
keys_for_merge={
"labeledtrackattributeval_set": [
"labeledtrackattributeval__spec_id",
"labeledtrackattributeval__value",
"labeledtrackattributeval__id",
"attributes": [
"attribute__spec_id",
"attribute__value",
"attribute__id",
],
"trackedshape_set":[
"trackedshape__type",
"trackedshape__occluded",
"trackedshape__z_order",
"trackedshape__points",
"trackedshape__rotation",
"trackedshape__id",
"trackedshape__frame",
"trackedshape__outside",
"trackedshape__trackedshapeattributeval__spec_id",
"trackedshape__trackedshapeattributeval__value",
"trackedshape__trackedshapeattributeval__id",
"shapes":[
"shape__type",
"shape__occluded",
"shape__z_order",
"shape__points",
"shape__rotation",
"shape__id",
"shape__frame",
"shape__outside",
"shape__attribute__spec_id",
"shape__attribute__value",
"shape__attribute__id",
],
},
field_id="id",
Expand All @@ -632,29 +632,27 @@ def _init_tracks_from_db(self):
tracks = {}
elements = {}
for db_track in db_tracks:
db_track["trackedshape_set"] = merge_table_rows(db_track["trackedshape_set"], {
'trackedshapeattributeval_set': [
'trackedshapeattributeval__value',
'trackedshapeattributeval__spec_id',
'trackedshapeattributeval__id',
db_track["shapes"] = merge_table_rows(db_track["shapes"], {
'attributes': [
'attribute__value',
'attribute__spec_id',
'attribute__id',
]
}, 'id')

# A result table can consist many equal rows for track/shape attributes
# We need filter unique attributes manually
db_track["labeledtrackattributeval_set"] = list(set(db_track["labeledtrackattributeval_set"]))
self._extend_attributes(db_track.labeledtrackattributeval_set,
db_track["attributes"] = list(set(db_track["attributes"]))
self._extend_attributes(db_track.attributes,
self.db_attributes[db_track.label_id]["immutable"].values())

default_attribute_values = self.db_attributes[db_track.label_id]["mutable"].values()
for db_shape in db_track["trackedshape_set"]:
db_shape["trackedshapeattributeval_set"] = list(
set(db_shape["trackedshapeattributeval_set"])
)
for db_shape in db_track["shapes"]:
db_shape["attributes"] = list(set(db_shape["attributes"]))
# in case of trackedshapes need to interpolate attriute values and extend it
# by previous shape attribute values (not default values)
self._extend_attributes(db_shape["trackedshapeattributeval_set"], default_attribute_values)
default_attribute_values = db_shape["trackedshapeattributeval_set"]
self._extend_attributes(db_shape["attributes"], default_attribute_values)
default_attribute_values = db_shape["attributes"]

if db_track.parent is None:
db_track.elements = []
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Generated by Django 4.2.13 on 2024-07-09 11:08

from django.db import migrations, models
import django.db.models.deletion


class Migration(migrations.Migration):

dependencies = [
("engine", "0078_alter_cloudstorage_credentials"),
]

operations = [
migrations.AlterField(
model_name="labeledimageattributeval",
name="image",
field=models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="attributes",
related_query_name="attribute",
to="engine.labeledimage",
),
),
migrations.AlterField(
model_name="labeledshapeattributeval",
name="shape",
field=models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="attributes",
related_query_name="attribute",
to="engine.labeledshape",
),
),
migrations.AlterField(
model_name="labeledtrackattributeval",
name="track",
field=models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="attributes",
related_query_name="attribute",
to="engine.labeledtrack",
),
),
migrations.AlterField(
model_name="trackedshapeattributeval",
name="shape",
field=models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="attributes",
related_query_name="attribute",
to="engine.trackedshape",
),
),
]
15 changes: 10 additions & 5 deletions cvat/apps/engine/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,27 +935,32 @@ class LabeledImage(Annotation):
pass

class LabeledImageAttributeVal(AttributeVal):
image = models.ForeignKey(LabeledImage, on_delete=models.CASCADE)
image = models.ForeignKey(LabeledImage, on_delete=models.CASCADE,
related_name='attributes', related_query_name='attribute')

class LabeledShape(Annotation, Shape):
parent = models.ForeignKey('self', on_delete=models.CASCADE, null=True, related_name='elements')

class LabeledShapeAttributeVal(AttributeVal):
shape = models.ForeignKey(LabeledShape, on_delete=models.CASCADE)
shape = models.ForeignKey(LabeledShape, on_delete=models.CASCADE,
related_name='attributes', related_query_name='attribute')

class LabeledTrack(Annotation):
parent = models.ForeignKey('self', on_delete=models.CASCADE, null=True, related_name='elements')

class LabeledTrackAttributeVal(AttributeVal):
track = models.ForeignKey(LabeledTrack, on_delete=models.CASCADE)
track = models.ForeignKey(LabeledTrack, on_delete=models.CASCADE,
related_name='attributes', related_query_name='attribute')

class TrackedShape(Shape):
id = models.BigAutoField(primary_key=True)
track = models.ForeignKey(LabeledTrack, on_delete=models.CASCADE)
track = models.ForeignKey(LabeledTrack, on_delete=models.CASCADE,
related_name='shapes', related_query_name='shape')
frame = models.PositiveIntegerField()

class TrackedShapeAttributeVal(AttributeVal):
shape = models.ForeignKey(TrackedShape, on_delete=models.CASCADE)
shape = models.ForeignKey(TrackedShape, on_delete=models.CASCADE,
related_name='attributes', related_query_name='attribute')

class Profile(models.Model):
user = models.OneToOneField(User, on_delete=models.CASCADE)
Expand Down
28 changes: 11 additions & 17 deletions cvat/apps/engine/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1504,8 +1504,7 @@ class AnnotationSerializer(serializers.Serializer):
source = serializers.CharField(default='manual')

class LabeledImageSerializer(AnnotationSerializer):
attributes = AttributeValSerializer(many=True,
source="labeledimageattributeval_set", default=[])
attributes = AttributeValSerializer(many=True, default=[])

class OptimizedFloatListField(serializers.ListField):
'''Default ListField is extremely slow when try to process long lists of points'''
Expand Down Expand Up @@ -1541,8 +1540,7 @@ class ShapeSerializer(serializers.Serializer):
)

class SubLabeledShapeSerializer(ShapeSerializer, AnnotationSerializer):
attributes = AttributeValSerializer(many=True,
source="labeledshapeattributeval_set", default=[])
attributes = AttributeValSerializer(many=True, default=[])

class LabeledShapeSerializer(SubLabeledShapeSerializer):
elements = SubLabeledShapeSerializer(many=True, required=False)
Expand All @@ -1562,7 +1560,7 @@ class LabeledImageSerializerFromDB(serializers.BaseSerializer):
def to_representation(self, instance):
def convert_tag(tag):
result = _convert_annotation(tag, ['id', 'label_id', 'frame', 'group', 'source'])
result['attributes'] = _convert_attributes(tag['labeledimageattributeval_set'])
result['attributes'] = _convert_attributes(tag['attributes'])
return result

return convert_tag(instance)
Expand All @@ -1576,7 +1574,7 @@ def convert_shape(shape):
'id', 'label_id', 'type', 'frame', 'group', 'source',
'occluded', 'outside', 'z_order', 'rotation', 'points',
])
result['attributes'] = _convert_attributes(shape['labeledshapeattributeval_set'])
result['attributes'] = _convert_attributes(shape['attributes'])
if shape.get('elements', None) is not None and shape['parent'] is None:
result['elements'] = [convert_shape(element) for element in shape['elements']]
return result
Expand All @@ -1590,14 +1588,13 @@ def to_representation(self, instance):
def convert_track(track):
shape_keys = [
'id', 'type', 'frame', 'occluded', 'outside', 'z_order',
'rotation', 'points', 'trackedshapeattributeval_set',
'rotation', 'points', 'attributes',
]
result = _convert_annotation(track, ['id', 'label_id', 'frame', 'group', 'source'])
result['shapes'] = [_convert_annotation(shape, shape_keys) for shape in track['trackedshape_set']]
result['attributes'] = _convert_attributes(track['labeledtrackattributeval_set'])
result['shapes'] = [_convert_annotation(shape, shape_keys) for shape in track['shapes']]
result['attributes'] = _convert_attributes(track['attributes'])
for shape in result['shapes']:
shape['attributes'] = _convert_attributes(shape['trackedshapeattributeval_set'])
shape.pop('trackedshapeattributeval_set', None)
shape['attributes'] = _convert_attributes(shape['attributes'])
if track.get('elements', None) is not None and track['parent'] is None:
result['elements'] = [convert_track(element) for element in track['elements']]
return result
Expand All @@ -1607,14 +1604,11 @@ def convert_track(track):
class TrackedShapeSerializer(ShapeSerializer):
id = serializers.IntegerField(default=None, allow_null=True)
frame = serializers.IntegerField(min_value=0)
attributes = AttributeValSerializer(many=True,
source="trackedshapeattributeval_set", default=[])
attributes = AttributeValSerializer(many=True, default=[])

class SubLabeledTrackSerializer(AnnotationSerializer):
shapes = TrackedShapeSerializer(many=True, allow_empty=True,
source="trackedshape_set")
attributes = AttributeValSerializer(many=True,
source="labeledtrackattributeval_set", default=[])
shapes = TrackedShapeSerializer(many=True, allow_empty=True)
attributes = AttributeValSerializer(many=True, default=[])

class LabeledTrackSerializer(SubLabeledTrackSerializer):
elements = SubLabeledTrackSerializer(many=True, required=False)
Expand Down
Loading

0 comments on commit 0317871

Please sign in to comment.