
DeepLabv3 + ADE20k benchmark #107

Merged: 54 commits merged into mosaicml:dev on Jan 14, 2022
Conversation

@Landanjs (Contributor) commented Nov 29, 2021

Pull request to add the DeepLabv3 model and the ADE20k dataset as a new benchmark for semantic segmentation tasks

Some results with 4 seeds for each experiment:

| Experiment | Final | Best |
| --- | --- | --- |
| Train for 127 epochs | 44.888 +/- 0.304 | 44.959 +/- 0.307 |
| Train for 64 epochs | 45.051 +/- 0.328 | 45.323 +/- 0.284 |
| Use batch size 32 instead of 16 | 44.904 +/- 0.181 | 45.089 +/- 0.055 |
| Decoupled SGD (WD = 5e-6) | 45.54 +/- 0.364 | 45.669 +/- 0.284 |

Here is a convergence run.

mmsegmentation reports 44.08 mIoU and 45.00 mIoU for 64 and 127 epochs, respectively (here). I think these numbers are the final results for a single run.

Before Merging

  • Test for mean IoU
  • Test for each segmentation transformation
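For context, the mean IoU targeted by the first test can be computed from predicted and target label maps. Below is a minimal NumPy sketch of the quantity; this is a generic reference computation, not Composer's metric implementation, and the `ignore_index=-1` convention is an assumption mirroring the loss code discussed in this PR.

```python
import numpy as np

# Generic mean-IoU reference computation (not Composer's metric code).
# Classes absent from both prediction and target are skipped so that
# they do not drag the average down.
def mean_iou(pred: np.ndarray, target: np.ndarray,
             num_classes: int, ignore_index: int = -1) -> float:
    mask = target != ignore_index          # drop ignored pixels
    pred, target = pred[mask], target[mask]
    ious = []
    for c in range(num_classes):
        inter = np.logical_and(pred == c, target == c).sum()
        union = np.logical_or(pred == c, target == c).sum()
        if union > 0:
            ious.append(inter / union)
    return float(np.mean(ious))
```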

@hanlint hanlint added the release label Dec 2, 2021
@Landanjs Landanjs marked this pull request as ready for review December 17, 2021 22:53
@ravi-mosaicml (Contributor):

Would you be able to use monkeypatch to mock out the actual download during the tests (and instead have it return uninitialized weights)?

@Landanjs (Contributor, Author) commented Jan 4, 2022

The pretrained weight download is done within torchvision, so I thought the easiest way to avoid the long download was to manually change is_backbone_pretrained in test_hparams.py and test_model_registry.py when DeepLabV3 is used.
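For what it's worth, the kind of mocking suggested above could look like the sketch below. It uses `unittest.mock` rather than pytest's `monkeypatch` fixture so it is self-contained, and `WeightStore`/`build_model` are hypothetical stand-ins, not Composer's or torchvision's actual API.

```python
# Sketch of mocking out a pretrained-weight download in tests.
# WeightStore and build_model are hypothetical stand-ins, not
# Composer's or torchvision's real functions.
from unittest import mock


class WeightStore:
    @staticmethod
    def download(url: str):
        raise RuntimeError("no network access in tests")


def build_model(pretrained: bool) -> dict:
    state = {"weights": "random-init"}
    if pretrained:
        # A real builder would fetch a checkpoint here.
        state["weights"] = WeightStore.download("https://example.com/resnet50.pth")
    return state


def test_pretrained_model_without_download():
    # Patch the downloader so pretrained=True never touches the network
    # and instead returns uninitialized weights.
    with mock.patch.object(WeightStore, "download", return_value="uninitialized"):
        model = build_model(pretrained=True)
    assert model["weights"] == "uninitialized"
```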

@hanlint (Contributor) commented Jan 4, 2022

cc: @A-Jacobson or @coryMosaicML for a review here (vision team)

@Landanjs Landanjs linked an issue Jan 5, 2022 that may be closed by this pull request
Review threads (resolved):
  • composer/models/deeplabv3/deeplabv3.py
  • composer/datasets/ade20k.py
  • composer/models/loss.py
  • composer/yamls/models/deeplabv3_ade20k.yaml
  • examples/run_mosaic_trainer.py
@Landanjs (Contributor, Author) commented Jan 12, 2022

After some thought and discussion, I think it makes the most sense to make the initial benchmark as close as possible to the other benchmarks. This means reverting to the original batch size, dropping decoupled weight decay, and using a polynomial LR schedule instead of cosine decay.

@ravi-mosaicml and/or @A-Jacobson could y'all skim through these recent changes today or tomorrow?

  • Added code to use initializers in deeplabv3.py

  • Added PolynomialLR and PolynomialLRHparams to schedulers.py, along with the required scheduler tests

  • Added test_loss.py for mIoU tests (subsequently made a tests/models directory and added test_efficientnet.py)

  • Added test_segmentation_transforms.py for tests on the segmentation transforms. I could not quickly come up with good ways to test these, especially since some of them are random, and I'm not sure how rigorous these tests should be. Be gentle!

It would be amazing if I could get this in tomorrow, thank you!!
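As an aside, the polynomial decay rule referred to above is usually of the form lr = base_lr * (1 - t/T)^power. A minimal sketch follows; the power=0.9 default mirrors the common mmsegmentation setting and is an assumption here, and Composer's actual PolynomialLR may differ in details (warmup, minimum LR, step vs. epoch granularity).

```python
# Minimal sketch of polynomial LR decay as commonly used for semantic
# segmentation; not necessarily Composer's exact implementation.
def polynomial_lr(base_lr: float, step: int, max_steps: int,
                  power: float = 0.9) -> float:
    frac = min(step, max_steps) / max_steps
    return base_lr * (1.0 - frac) ** power

# The LR starts at base_lr and decays smoothly to 0 at max_steps.
```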

@ravi-mosaicml (Contributor) left a comment:

Overall LGTM. Two main things:

  1. I don't think many of the type ignores are required; I left comments on how to remove them. If they are required, please add the pyright error message as a comment next to the type ignore. I realize they can get annoying; happy to help debug these!

  2. Can you move the ImageNet normalization parameters out of the normalization_fn class and into the imagenet dataset file? The normalization function should be dataset-generic.
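The dataset-generic shape being suggested might look like this sketch (NumPy for illustration; the names and the placement of the ImageNet statistics are assumptions, not Composer's actual code):

```python
import numpy as np

# Sketch of the suggestion: per-channel statistics live with the dataset,
# and the normalization function itself is dataset-generic. Names are
# illustrative, not Composer's API.
IMAGENET_CHANNEL_MEAN = (0.485, 0.456, 0.406)  # would live in the imagenet dataset file
IMAGENET_CHANNEL_STD = (0.229, 0.224, 0.225)


def normalize(image: np.ndarray, mean, std) -> np.ndarray:
    """Normalize a CHW image with per-channel mean/std."""
    mean = np.asarray(mean, dtype=image.dtype).reshape(-1, 1, 1)
    std = np.asarray(std, dtype=image.dtype).reshape(-1, 1, 1)
    return (image - mean) / std
```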

composer/models/deeplabv3/deeplabv3.py:

```python
resnet.model_urls[backbone_arch] = "https://download.pytorch.org/models/resnet101-cd907fc2.pth"
else:
    raise ValueError(f"backbone_arch must be one of ['resnet50', 'resnet101'] not {backbone_arch}")
backbone = resnet.__dict__[backbone_arch](pretrained=is_backbone_pretrained,
```
Contributor:
I presume this works but am a bit confused, as I don't see resnet50 or resnet101 in https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py

Contributor (PR author):

Ah, I took this line from torchvision. The link you posted was for regnet.py; resnet.py is here: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

composer/models/deeplabv3/deeplabv3.py (two resolved review threads)
```python
self.val_ce = CrossEntropyLoss(ignore_index=-1)

def forward(self, batch: Batch):
    x = batch[0]  # type: ignore
```
Contributor:

Suggested change:

```diff
- x = batch[0]  # type: ignore
+ x = composer.utils.types.as_batch_pair(batch)[0]
```

Contributor (PR author):

I just changed the input type to BatchPair, but it feels weird for forward to take a BatchPair since the label shouldn't be necessary to run the forward pass. That's a discussion for another time...


```python
if np.random.randint(2):
    hue_factor = np.random.uniform(-self.hue, self.hue)
    image = TF.adjust_hue(image, hue_factor)  # type: ignore
```
Contributor:

ditto


```python
if contrast_mode == 0 and np.random.randint(2):
    contrast_factor = np.random.uniform(1 - self.contrast, 1 + self.contrast)
    image = TF.adjust_contrast(image, contrast_factor)  # type: ignore
```
Contributor:

ditto

```python
image = self.image_transforms(image)

if self.split in ['train', 'val']:
    return image, target  # type: ignore
```
Contributor:

ditto

```python
)
image_transforms = torch.nn.Sequential(
    PhotometricDistoration(brightness=32. / 255, contrast=0.5, saturation=0.5, hue=18. / 255),
    PadToSize(size=(self.final_size, self.final_size),
```
Contributor:

can you add a comment to where these magic values came from?

Contributor (PR author):

Added!
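For readers skimming the thread: a pad-to-fixed-size transform like the PadToSize referenced above can be sketched as follows. Bottom/right padding and a zero fill are assumptions here; the real Composer transform may pad differently (e.g. centered padding or a dataset-mean fill).

```python
import numpy as np

# Hedged sketch of a PadToSize-style transform: pad a CHW image on the
# bottom/right until it reaches (height, width). Not Composer's actual
# implementation.
def pad_to_size(image: np.ndarray, size: tuple, fill: float = 0.0) -> np.ndarray:
    _, h, w = image.shape
    pad_h = max(size[0] - h, 0)
    pad_w = max(size[1] - w, 0)
    return np.pad(image, ((0, 0), (0, pad_h), (0, pad_w)),
                  constant_values=fill)
```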

composer/datasets/ade20k.py
@Landanjs (Contributor, Author) commented Jan 14, 2022

Okay, I think I've addressed all of @ravi-mosaicml's comments, but let me know if more is needed! Otherwise, I will merge!

@A-Jacobson (Contributor) left a comment:

I'm good with merging this as is, though I would like to better formalize what constitutes a "baseline", since you've reached or surpassed mmseg's reported mIoU with about three different recipes.

@Landanjs Landanjs merged commit 1c769ee into mosaicml:dev Jan 14, 2022
@Landanjs Landanjs deleted the landan/ade20k_deeplabv3 branch January 14, 2022 18:38
@Landanjs Landanjs restored the landan/ade20k_deeplabv3 branch January 19, 2022 20:05
coryMosaicML pushed a commit to coryMosaicML/composer that referenced this pull request Feb 23, 2022
* First commit ade20k and deeplabv3

* Allow ImageNet pretrained backbone

* Allow background class to be ignored

* Remove cross entropy metric (temp)

* Add mmseg photometric augmentations

* Add option to sync bn

* Remove dropout and extra 3x3 conv

* Use new resnet50 weights

* Use 3x3 conv before classification

* Add dropout2d

* Remove dropout (again)

* Select pretrained model, ability to randomly initialize, pytorch's syncBN

* Fix LR schedule params

* Update with recent merge and add resnet101

* Refactor ade20k pt. 1

* Missed hflips

* Initial ignore_class refactor

* Remove initial resize for base size in random scale

* Remove cityscapes

* Polynomial LR schedule

* Add ade20k docstrings and some refactoring

* Another iteration on ade20k and partial deeplabv3 refactor

* Change permissions

* total -> train batch size

* Decoupled SGDW

* Cleanup ade20k code and docstrings

* Fix dataset test

* Move preprocessing and collate; add defaults

* Collate docstring, minor name changes, and ade20k synthetic dataset

* Add mosaicml copyright

* Fix formatting

* Format pt. 2

* Remove RANDOM_INT from synthetic datasets

* Format pt. 3

* Update yaml

* Monkeypatch model tests to skip pretrained weights

* mIoU -> MIoU

* PolyLR and no DWD

* Add initializers

* Only initialize head when using pretrained backbone

* Add PolynomialLR docstring and fix scheduler tests

* Add mIoU and seg transformation tests

* Reorder imports

* Reorder imports pt. 2...

* Address type ignores, move imagenet norm params, other comments

* Get tests to pass

* Formatting
Linked issue that may be closed by this pull request: Cityscapes + Deeplabv3 benchmark