POC: Use dataclasses for model in- and outputs #2098

Draft
wants to merge 28 commits into main
Conversation

djdameln (Contributor)

📝 Description

  • Initial implementation of dataclasses for model inputs and outputs, meant to collect feedback (DO NOT MERGE).
  • Adds ImageBatch and VideoBatch dataclasses and base classes.
  • collate_fn had to be updated to allow collating the dataclasses as batches.
  • Backward compatibility is achieved by defining the __getitem__ and __setitem__ methods in the BatchItem base class and by adding properties for renamed attributes (e.g. mask -> gt_mask); see the sketch below.
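
To illustrate the idea, here is a minimal sketch of how the pieces could fit together (class and field names are illustrative, not necessarily the ones used in this PR):

from dataclasses import dataclass, fields
from typing import Any

import torch


@dataclass
class BatchItem:
    """Base class providing dict-style access for backward compatibility."""

    def __getitem__(self, key: str) -> Any:
        return getattr(self, key)

    def __setitem__(self, key: str, value: Any) -> None:
        setattr(self, key, value)


@dataclass
class ImageItem(BatchItem):
    """A single image sample; `mask` is kept as an alias for the renamed `gt_mask`."""

    image: torch.Tensor | None = None
    gt_label: torch.Tensor | None = None
    gt_mask: torch.Tensor | None = None

    @property
    def mask(self) -> torch.Tensor | None:  # legacy attribute name
        return self.gt_mask


@dataclass
class ImageBatch(BatchItem):
    """A collated batch of ImageItem instances."""

    image: torch.Tensor | None = None
    gt_label: torch.Tensor | None = None
    gt_mask: torch.Tensor | None = None


def collate_fn(items: list[ImageItem]) -> ImageBatch:
    """Stack the tensor fields of the individual items into a batch dataclass."""
    stacked = {}
    for field in fields(ImageItem):
        values = [getattr(item, field.name) for item in items]
        if all(isinstance(value, torch.Tensor) for value in values):
            stacked[field.name] = torch.stack(values)
    return ImageBatch(**stacked)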

✨ Changes

Select what type of change your PR is:

  • 🐞 Bug fix (non-breaking change which fixes an issue)
  • 🔨 Refactor (non-breaking change which refactors the code base)
  • 🚀 New feature (non-breaking change which adds functionality)
  • 💥 Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • 📚 Documentation update
  • 🔒 Security update

✅ Checklist

Before you submit your pull request, please make sure you have completed the following steps:

  • 📋 I have summarized my changes in the CHANGELOG and followed the guidelines for my type of change (skip for minor changes, documentation updates, and test enhancements).
  • 📚 I have made the necessary updates to the documentation (if applicable).
  • 🧪 I have written tests that support my changes and prove that my fix is effective or my feature works (if applicable).

For more information about code review checklists, see the Code Review Checklist.

@samet-akcay marked this pull request as draft on May 31, 2024 12:09
@ashwinvaidya17 (Collaborator) left a comment

Let's leave the visualizer refactor to a different PR. It might require a lot more changes. We can just add a TODO to the comments in the inferencers.
The rest looks good to me. Once we remove the commented-out code, we can push this as is.

@@ -83,7 +84,18 @@ def infer(args: Namespace) -> None:
     for filename in filenames:
         image = read_image(filename, as_tensor=True)
         predictions = inferencer.predict(image=image)
         output = visualizer.visualize_image(predictions)

+        # this is temporary until we update the visualizer to take the dataclass directly.
Collaborator

Maybe a TODO tag would be better

@ashwinvaidya17 self-requested a review on July 31, 2024 11:48
@ashwinvaidya17 (Collaborator) left a comment

I had a quick pass over this, but I need to revisit it as it is a big PR. I still need to look into Batch vs DatasetItems. Meanwhile, I have some comments. Mainly, with this new design I am wondering whether we even need InferenceModel.

@@ -82,7 +84,8 @@ def to_torch(
         ... )
         """
         transform = transform or self.transform or self.configure_transforms()
-        inference_model = InferenceModel(model=self.model, transform=transform)
+        post_processor = post_processor or self.post_processor
+        inference_model = InferenceModel(model=self.model, transform=transform, post_processor=post_processor)
Collaborator

If post_processor is going to be part of the lightning module, then do we need InferenceModel? Currently, AnomalyModule has a forward method that just calls self.model(batch). I don't think we use this anywhere. And InferenceModel has this as its forward method:

def forward(self, batch: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Transform the input batch and pass it through the model."""
    batch = self.transform(batch)
    predictions = self.model(batch)
    return self.post_processor(predictions)

I feel this logic can just move into AnomalyModule.forward; see the sketch below.
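
A rough sketch of what that could look like, assuming transform and post_processor become attributes of the lightning module (names are illustrative):

import torch
from lightning.pytorch import LightningModule


class AnomalyModule(LightningModule):
    # assumes self.model, self.transform and self.post_processor are set in __init__

    def forward(self, batch: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Transform the input batch, run the model, and post-process the raw predictions."""
        if self.transform is not None:
            batch = self.transform(batch)
        predictions = self.model(batch)
        return self.post_processor(predictions)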

@@ -178,16 +181,18 @@ def _outputs_to_cpu(self, output: STEP_OUTPUT) -> STEP_OUTPUT | dict[str, Any]:
         if isinstance(output, dict):
             for key, value in output.items():
                 output[key] = self._outputs_to_cpu(value)
+        elif isinstance(output, Batch):
+            output = output.__class__(**self._outputs_to_cpu(asdict(output)))
Collaborator

Minor design comment, but can we move to_cpu to the Batch class so that it returns a copy of itself? That way we would be able to just call output = output.to_cpu().

djdameln (Contributor Author)

This module is also deprecated with the new design, so you can ignore the changes in this file. So far I hadn't considered device handling in the new design, but adding a to_cpu() method could be a good idea!
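
A possible shape for such a method, sketched against a generic dataclass base (illustrative, not the exact Batch class from this PR):

from dataclasses import dataclass, fields, replace

import torch


@dataclass
class Batch:
    def to_cpu(self) -> "Batch":
        """Return a copy of the batch with all tensor fields moved to the CPU."""
        moved = {
            field.name: getattr(self, field.name).cpu()
            for field in fields(self)
            if isinstance(getattr(self, field.name), torch.Tensor)
        }
        return replace(self, **moved)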

@@ -27,98 +138,15 @@ def on_validation_batch_end(
self,
Collaborator

Do we need this callback, given that LightningModule inherits from ModelHooks, which already has on_validation_epoch_end and on_validation_batch_end?

djdameln (Contributor Author)

This module is deprecated with the new design. I left it here for legacy purposes until we decide how to handle backward compatibility, but you can ignore it for now.



class OneClassPostProcessor(PostProcessor):
    """Default post-processor for one-class anomaly detection."""
Collaborator

If I understand correctly, this class actually does normalization and thresholding, while the actual post-processing is done inside the Batch dataclass. It makes sense to name this post-processing, but it is now doing a different task compared to our original post-processing callback. I do prefer this class, though.

I am in favour of merging Normalization, post-processing, thresholding, and metrics computation in a single class as they all depend on each other and in specific order. This way, all the steps are visible in a single place. We will need to figure out how to make this configurable, and handle different thresholding/normalization approaches.

If I remember correctly, the current thresholding design was adopted to accommodate CDF normalization, which created another inference model at the end of training to collect stats on the validation set. We might have to handle such approaches in the future.

djdameln (Contributor Author) commented on Aug 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would argue that normalization and thresholding are both post-processing operations, since they are applied to the raw model predictions in order to enhance their interpretability.

In the previous design we had separate callbacks for normalization, thresholding and post-processing, which was confusing (normalization and thresholding can also be seen as post-processing steps) and prone to bugs (the different callbacks were dependent on each other's outputs).

The new design and terminology simplify this. All post-processing operations are now collected in a single class, so we no longer need to pass variables around between callbacks. For example, the image and pixel thresholds are now directly accessible during both thresholding and normalization.

Some of the simple operations that were previously part of the post-processing callback are now handled by the __post_init__ of the dataclass. I feel it is warranted to do this outside of the post-processor, because it is nothing more than computing some additional fields from the data that is already available (e.g. the image score, obtained by taking the max pixel value of the anomaly map); see the sketch below.
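
As a standalone illustration of that pattern (field names are examples only):

from dataclasses import dataclass

import torch


@dataclass
class PredictionBatch:
    anomaly_map: torch.Tensor | None = None
    pred_score: torch.Tensor | None = None

    def __post_init__(self) -> None:
        # derive the image-level score from the anomaly map if it was not provided
        if self.pred_score is None and self.anomaly_map is not None:
            self.pred_score = torch.amax(self.anomaly_map, dim=(-2, -1))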

> I am in favour of merging Normalization, post-processing, thresholding, and metrics computation in a single class as they all depend on each other and in specific order.

I personally feel that metric computation should still have its own separate callback. The metric computation is independent of any post-processing operations, as the computation steps are the same with or without post-processing. Including it in the post-processor would lead to code duplication and violate the single responsibility principle.

> We will need to figure out how to make this configurable, and handle different thresholding/normalization approaches.

The most straightforward way to configure this is to have the user define the post-processor class path and init args as part of the engine arguments. The engine can then set the post-processor attribute of the model or use the model's default one if nothing is specified by the user. This would be achievable with minor changes to AnomalyModule and Engine.

In general, I think we should keep our default post-processor class simple. In practice we will use min-max normalization and adaptive thresholding in 99% of use cases, so maybe we should not add unnecessary complexity to this class by supporting different thresholding/normalization approaches. When a user really needs alternative post-processing steps, they can easily implement a custom post-processor and enable it through the engine (see the sketch below).
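
For example, a user-defined post-processor hooked up through the engine could look roughly like this (the exact PostProcessor interface and the engine argument are assumptions, shown only to illustrate the configuration idea):

from anomalib.engine import Engine


class PercentilePostProcessor(PostProcessor):  # PostProcessor: base class introduced in this PR
    """Hypothetical post-processor that normalizes scores with percentile bounds."""

    def __init__(self, lower: float = 0.01, upper: float = 0.99) -> None:
        super().__init__()
        self.lower = lower
        self.upper = upper

    # ...override the relevant normalization/thresholding hooks here...


# hypothetical engine argument for selecting the post-processor
engine = Engine(post_processor=PercentilePostProcessor(lower=0.05))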
