From 453c8912f9a499f9d2f5d85dbd73f60c3a432eea Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Wed, 5 Jun 2024 21:05:06 +0300 Subject: [PATCH] Fix Dataset.get when there are transforms in the dataset (#45) * Fix dataset get when there are transforms * Extend test * Update changelog * Help linter --- CHANGELOG.md | 2 + datumaro/components/dataset.py | 5 ++- tests/test_dataset.py | 71 ++++++++++++++++++++++++++++++++++ 3 files changed, 76 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 641b7d7fb4..10926f522f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -94,6 +94,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 () - Incorrect computation of binary mask bbox (missed 1 pixel of the size) () +- `Dataset.get()` could ignore existing transforms in the dataset + () ### Security - TBD diff --git a/datumaro/components/dataset.py b/datumaro/components/dataset.py index d5309dadf1..d8eda39124 100644 --- a/datumaro/components/dataset.py +++ b/datumaro/components/dataset.py @@ -419,8 +419,9 @@ def _update_status(item_id, new_status: ItemStatus): source = self._source or DatasetItemStorageDatasetView( self._storage, categories=self._categories, media_type=media_type ) - transform = None + transform = None + old_ids = None if self._transforms: transform = _StackedTransform(source, self._transforms) if transform.is_local: @@ -588,7 +589,7 @@ def get(self, id, subset=None) -> Optional[DatasetItem]: item = self._storage.get(id, subset) if item is None and not self.is_cache_initialized(): - if self._source.get.__func__ == Extractor.get: + if self._source.get.__func__ == Extractor.get or self._transforms: # can be improved if IDataset is ABC self.init_cache() item = self._storage.get(id, subset) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 61137d5b7e..3f4f64ad2b 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1267,6 +1267,77 @@ def __iter__(self): self.assertEqual(iter_called, 2) + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_can_get_item_after_local_transforms(self): + iter_called = 0 + + class TestExtractor(Extractor): + def __iter__(self): + nonlocal iter_called + iter_called += 1 + yield from [ + DatasetItem(1), + DatasetItem(2), + DatasetItem(3), + DatasetItem(4), + ] + + dataset = Dataset.from_extractors(TestExtractor()) + + class TestTransform(ItemTransform): + def transform_item(self, item): + return self.wrap_item(item, id=int(item.id) + 1) + + dataset.transform(TestTransform) + dataset.transform(TestTransform) + + self.assertEqual(iter_called, 0) + + self.assertIsNone(dataset.get(1)) + self.assertIsNone(dataset.get(2)) + self.assertIsNotNone(dataset.get(3)) + self.assertIsNotNone(dataset.get(4)) + self.assertIsNotNone(dataset.get(5)) + self.assertIsNotNone(dataset.get(6)) + + self.assertEqual(iter_called, 1) + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_can_get_item_after_nonlocal_transforms(self): + iter_called = 0 + + class TestExtractor(Extractor): + def __iter__(self): + nonlocal iter_called + iter_called += 1 + yield from [ + DatasetItem(1), + DatasetItem(2), + DatasetItem(3), + DatasetItem(4), + ] + + dataset = Dataset.from_extractors(TestExtractor()) + + class TestTransform(Transform): + def __iter__(self): + for item in self._extractor: + yield self.wrap_item(item, id=int(item.id) + 1) + + dataset.transform(TestTransform) + dataset.transform(TestTransform) + + self.assertEqual(iter_called, 0) + + self.assertIsNone(dataset.get(1)) + self.assertIsNone(dataset.get(2)) + self.assertIsNotNone(dataset.get(3)) + self.assertIsNotNone(dataset.get(4)) + self.assertIsNotNone(dataset.get(5)) + self.assertIsNotNone(dataset.get(6)) + + self.assertEqual(iter_called, 2) + @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_can_get_subsets_after_local_transforms(self): iter_called = 0