
Add utilities to plot datasets to Weights & Biases + Add callback to log validation predictions to Weights & Biases #1167

Merged

Conversation

soumik12345
Contributor

Changes made in this PR

Fixes #1136

👇 Here's a sample training code that works with the changes made in this PR

👇 Here's a Weights & Biases run that demonstrates the changes made in this PR:
https://wandb.ai/geekyrakshit/yolo-nas-integration-2/runs/3lcpxzti

Plot datasets to Weights & Biases

This PR adds a function `plot_detection_dataset_on_wandb()` that plots a `super_gradients.training.datasets.detection_datasets.DetectionDataset` to Weights & Biases as a `wandb.Table`, enabling interactive exploratory analysis of the dataset.

Here's a sample code:

```python
train_data = coco_detection_yolo_format_train(
    dataset_params={
        'data_dir': dataset_params['data_dir'],
        'images_dir': dataset_params['train_images_dir'],
        'labels_dir': dataset_params['train_labels_dir'],
        'classes': dataset_params['classes'],
    },
    dataloader_params={
        'batch_size': 16,
        'num_workers': 2,
    },
)

plot_detection_dataset_on_wandb(train_data.dataset, max_examples=20, dataset_name="Train-Dataset")
```

This would log the dataset to be visualized as a Table in a Weights & Biases dashboard.
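For intuition, here is a stripped-down, hypothetical sketch (not the actual super_gradients implementation) of how a `max_examples` cap over a dataset could work; `DummyDataset` and `collect_table_rows` are illustrative stand-ins for the real dataset and table-building logic.

```python
def collect_table_rows(dataset, max_examples):
    """Collect at most max_examples samples as rows for a table."""
    rows = []
    for idx in range(min(max_examples, len(dataset))):
        image, targets = dataset[idx]
        rows.append({"index": idx, "num_boxes": len(targets)})
    return rows

class DummyDataset:
    """Stand-in for a DetectionDataset: 100 samples, 0-2 fake boxes each."""
    def __len__(self):
        return 100
    def __getitem__(self, idx):
        return object(), [(0, 0, 10, 10)] * (idx % 3)

rows = collect_table_rows(DummyDataset(), max_examples=20)
print(len(rows))  # 20
```

Capping at `max_examples` keeps the logged table small even for datasets with tens of thousands of images.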

Here's a video of how to interact with the table UI on Weights & Biases 👇

Screen.Recording.2023-06-13.at.8.45.49.PM.mov

Log validation prediction as a table on Weights & Biases

This PR also includes a callback WandBDetectionValidationPredictionLoggerCallback that logs object detection predictions to a Weights & Biases Table with interactive bounding-box overlays during training on an epoch-wise basis.

Here's a sample code:

```python
train_params.update({
    "phase_callbacks": [
        WandBDetectionValidationPredictionLoggerCallback(class_names=labels),
    ]
})

trainer.train(model=net, training_params=train_params, train_loader=train_data, valid_loader=val_data)
```

This would log the validation predictions to be visualized as a Table in a Weights & Biases dashboard.
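Under the hood, Weights & Biases renders interactive overlays from per-box dictionaries attached to a `wandb.Image`. Here is a minimal sketch of building one such dictionary; the helper name and the pixel-domain choice are illustrative, while the dict layout follows the wandb bounding-box schema.

```python
def to_wandb_box(x_min, y_min, x_max, y_max, class_id, score, class_names):
    """Build one entry in the wandb box_data schema (pixel-coordinate domain)."""
    return {
        "position": {"minX": x_min, "minY": y_min, "maxX": x_max, "maxY": y_max},
        "class_id": int(class_id),
        "box_caption": f"{class_names[class_id]} {score:.2f}",
        "scores": {"confidence": float(score)},
        "domain": "pixel",
    }

class_names = ["person", "car"]
box = to_wandb_box(10, 20, 110, 220, class_id=1, score=0.87, class_names=class_names)
print(box["box_caption"])  # car 0.87
```

A list of such dicts would go under `boxes={"predictions": {"box_data": [...]}}` when constructing the `wandb.Image` for each table row.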

Here's a video of how to interact with the table UI on Weights & Biases 👇

Screen.Recording.2023-06-13.at.8.48.22.PM.mov

@soumik12345
Contributor Author

Hi @ofrimasad
Would love your feedback on this PR.

@kldarek

kldarek commented Jun 19, 2023

Hi @ofrimasad & team, we'd really appreciate feedback on the PR, thanks so much! Darek from W&B

@soumik12345
Contributor Author

Hi @BloodAxe @ofrimasad @shaydeci @Louis-Dupont
A gentle request for reviewing this PR 🙂

@BloodAxe
Collaborator

Thanks for your PR, this is a really nice addition. Pardon the delay in reviewing it; we were pretty busy recently preparing a new release.
We are here to help make the process of accepting this PR as smooth as possible.
Here is a checklist of things that I believe are missing or important to address, and which would allow other users to start using this callback easily:

  • The callback class should either use the predict API (all of our detection models support it), receive post_prediction_callback as an external argument, or call unwrap_model(context.net).get_post_prediction_callback() inside the callback method to obtain it. I lean towards the post_prediction_callback option for a number of reasons:
  1. The predict method may not be up to the task, because context.inputs already contains normalized tensors. @Louis-Dupont correct me if I'm wrong: what would happen if we pass a torch tensor to predict()? Would we double-apply normalization here?
  2. A predict call may perform layer fusion to optimize inference speed, which can make the model non-trainable, and that is not what we want during training. So fuse_model=False must be passed to predict().
  • It looks like the callback logs images on every validation batch and accumulates the generated images in RAM. I can easily see this causing OOM errors for datasets like COCO. Is it intentional? If it is meant to be used only in the validate_from_recipe regime, please clarify this in the docstring with the intended use-case scenario of the callback.
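The RAM concern above can be addressed with a hard cap on how many visualizations are kept in memory. A hypothetical sketch (the class and parameter names are illustrative, not the callback's actual API):

```python
class BoundedAccumulator:
    """Accumulate at most max_items visualizations, silently dropping the rest,
    so logging on every validation batch cannot grow RAM without bound."""

    def __init__(self, max_items):
        self.max_items = max_items
        self.items = []

    def add_batch(self, batch):
        # Only take as many items as still fit under the cap.
        remaining = self.max_items - len(self.items)
        if remaining > 0:
            self.items.extend(batch[:remaining])

acc = BoundedAccumulator(max_items=5)
for _ in range(10):            # e.g. 10 validation batches of 2 images each
    acc.add_batch(["img_a", "img_b"])
print(len(acc.items))  # 5
```

With a cap like this, memory use is bounded by `max_items` regardless of dataset size, which matters for large validation sets such as COCO.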

@soumik12345
Contributor Author

Hi @BloodAxe
Made the requested changes and apologies for the delay 😅

Also a couple of questions...

  1. Is it possible to visualize the dataset before applying augmentations?

  2. In the validation table, when the data is being logged at the end of the epochs, the bottom half of the images seem to have some kind of padding. Do you recommend some post-processing in order to avoid this? (such as slicing the bottom half of the image before plotting)

[screenshot: validation image whose bottom half is padding]

@BloodAxe
Collaborator

BloodAxe commented Aug 7, 2023

Is it possible to visualize the dataset before applying augmentations?

No

In the validation table, when the data is being logged at the end of the epochs, the bottom half of the images seem to have some kind of padding. Do you recommend some post-processing in order to avoid this? (such as slicing the bottom half of the image before plotting)

This is due to padding applied when preparing samples in the dataset. If your input images have an aspect ratio of 16:9 (for example) but the dataset transforms output images at 640x640 (aspect 1:1), then you get this issue. The right solution is to train at a resolution that matches (more or less) the aspect ratio of the images in the dataset. So I'm not sure whether any action is required here.
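To make the padding behavior concrete: a letterbox-style transform scales the image to fit the target size while preserving aspect ratio, then pads the remainder. A small sketch (the function name is illustrative, not a super_gradients API):

```python
def letterbox_padding(src_w, src_h, dst_w, dst_h):
    """Resize preserving aspect ratio, then pad to the target size.
    Returns (scaled_w, scaled_h, pad_x, pad_y)."""
    scale = min(dst_w / src_w, dst_h / src_h)
    scaled_w, scaled_h = round(src_w * scale), round(src_h * scale)
    return scaled_w, scaled_h, dst_w - scaled_w, dst_h - scaled_h

# A 1920x1080 (16:9) image fit into 640x640 (1:1) is scaled to 640x360,
# leaving 280 px of vertical padding, which is exactly the band seen
# at the bottom of the logged validation images.
result = letterbox_padding(1920, 1080, 640, 640)
print(result)  # (640, 360, 0, 280)
```

Matching the training resolution's aspect ratio to the source images drives `pad_x`/`pad_y` toward zero.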

Collaborator

@BloodAxe BloodAxe left a comment


LGTM

@BloodAxe BloodAxe merged commit 2d3004a into Deci-AI:master Aug 8, 2023
6 checks passed
@soumik12345 soumik12345 deleted the soumik12345/wandb-validation-logging branch August 11, 2023 11:28
Successfully merging this pull request may close these issues:
Add visualization of predicted bounding boxes on Weights & Biases