New LetterBox(size) CenterCrop(size), ToTensor() transforms (#9213) #9213

Merged
glenn-jocher merged 13 commits into master from new/LetterBox on Aug 30, 2022

Conversation

glenn-jocher
Member

@glenn-jocher glenn-jocher commented Aug 29, 2022

YOLOv5 LetterBox class for image preprocessing, i.e. T.Compose([T.ToTensor(), LetterBox(size)])

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>

@AyushExel

🛠️ PR Summary

Made with ❤️ by Ultralytics Actions

🌟 Summary

Enhancements to image preprocessing in YOLOv5 with new transformation classes and modified dataloader behavior.

📊 Key Changes

  • 📝 Introduced LetterBox, CenterCrop, and ToTensor custom preprocessing classes in augmentations.py.
  • 👁️‍🗨️ Modified the behavior of default transformations to use new custom classes instead of torchvision transforms.
  • 🔄 Updated dataloaders.py to incorporate the new transformation classes and streamlined image color conversion when transforms are applied.
  • 🗂️ Cache improvements for storing and retrieving processed images, reducing I/O operations and potentially speeding up training.

🎯 Purpose & Impact

  • 💡 Purpose: To provide YOLOv5 with more flexible image preprocessing options, specifically tailored to the YOLO architecture, and improve performance through better caching mechanisms.
  • 🚀 Impact: Expect increased efficiency in data loading, a potential boost in training speed, and easier customization of image preprocessing steps for YOLOv5 users.
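
The three transforms named in the summary are callable classes that can be chained. The block below is only a rough sketch of that idea, not the code merged in augmentations.py. It assumes NumPy HWC uint8 images in BGR order, uses a hypothetical nearest-neighbour `resize_nearest` helper in place of `cv2.resize`, and uses a hypothetical `ToFloatCHW` class that returns a NumPy array where the PR's `ToTensor` would return a torch tensor. The `compose` helper and the padding value 114 are likewise assumptions made for illustration.

```python
import numpy as np


def resize_nearest(im, h, w):
    """Hypothetical stand-in for cv2.resize: nearest-neighbour resize of an HWC array."""
    rows = np.arange(h) * im.shape[0] // h  # source row for each output row
    cols = np.arange(w) * im.shape[1] // w  # source column for each output column
    return im[rows[:, None], cols]


class LetterBox:
    """Resize to fit inside (h, w) keeping aspect ratio, then pad to exactly (h, w)."""

    def __init__(self, size=(640, 640), pad_value=114):  # pad_value is an assumption
        self.h, self.w = (size, size) if isinstance(size, int) else size
        self.pad_value = pad_value

    def __call__(self, im):  # im: HWC uint8 array
        imh, imw = im.shape[:2]
        r = min(self.h / imh, self.w / imw)  # scale so the image fits the target
        h, w = round(imh * r), round(imw * r)
        top, left = (self.h - h) // 2, (self.w - w) // 2  # center the resized image
        out = np.full((self.h, self.w, im.shape[2]), self.pad_value, dtype=im.dtype)
        out[top:top + h, left:left + w] = resize_nearest(im, h, w)
        return out


class CenterCrop:
    """Crop the largest centered square, then resize it to (h, w)."""

    def __init__(self, size=224):
        self.h, self.w = (size, size) if isinstance(size, int) else size

    def __call__(self, im):  # im: HWC uint8 array
        imh, imw = im.shape[:2]
        m = min(imh, imw)  # side of the largest square that fits
        top, left = (imh - m) // 2, (imw - m) // 2
        return resize_nearest(im[top:top + m, left:left + m], self.h, self.w)


class ToFloatCHW:
    """NumPy stand-in for ToTensor: HWC BGR uint8 -> CHW RGB float32 in [0, 1]."""

    def __call__(self, im):
        im = np.ascontiguousarray(im.transpose(2, 0, 1)[::-1])  # HWC to CHW, BGR to RGB
        return im.astype(np.float32) / 255.0


def compose(*transforms):
    """Chain callables left to right, similar in spirit to T.Compose."""
    def run(im):
        for t in transforms:
            im = t(im)
        return im
    return run


im0 = np.zeros((100, 200, 3), dtype=np.uint8)  # synthetic BGR image
im0[..., 0] = 255  # pure blue
boxed = LetterBox(64)(im0)
x = compose(CenterCrop(32), ToFloatCHW())(im0)
print(boxed.shape, x.shape, x.dtype)
```

Doing the crop or letterbox on the uint8 array before converting to float keeps the per-pixel work on the smaller, cheaper representation, which is one plausible reason for ordering the transforms this way.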

@glenn-jocher glenn-jocher self-assigned this Aug 29, 2022
@glenn-jocher glenn-jocher changed the title New LetterBox transform New LetterBox(size) transform Aug 29, 2022
@glenn-jocher
Member Author

@AyushExel this replicates our detection pre-processing as a transform. It's maybe 10% faster than the current cls preprocessing. Our existing detection preprocessing is way faster than either, though:

    TF = T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
    with Profile() as dt:
        for _ in range(1000):
            im = TF(cv2.cvtColor(im0, cv2.COLOR_BGR2RGB))
    print(dt.t)

    TF = T.Compose([T.ToTensor(), LetterBox(size)])
    with Profile() as dt:
        for _ in range(1000):
            im = TF(cv2.cvtColor(im0, cv2.COLOR_BGR2RGB))
    print(dt.t)

    with Profile() as dt:
        for _ in range(1000):
            im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0]  # padded resize
            im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
            im = np.ascontiguousarray(im)  # contiguous
            im = torch.from_numpy(im).to('cpu')
            im = im.half() if False else im.float()  # uint8 to fp16/32
            im /= 255  # 0 - 255 to 0.0 - 1.0
    print(dt.t)

    # Seconds per 1000 iterations, in the order printed:
    # 5.506530046463013  (torchvision Resize + CenterCrop + Normalize)
    # 4.960566997528076  (ToTensor + LetterBox)
    # 1.5980091094970703 (existing detection preprocessing)
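
The benchmark above relies on a `Profile` context manager from the YOLOv5 utilities that accumulates elapsed time in `.t`. For readers who want to reproduce this kind of comparison outside the repository, here is a minimal stand-in using only the standard library. The names `SimpleProfile` and `benchmark` are hypothetical, and this sketch measures CPU wall-clock time only, with no GPU synchronization.

```python
import contextlib
import time


class SimpleProfile(contextlib.ContextDecorator):
    """Hypothetical stand-in for Profile: accumulates wall-clock seconds in .t."""

    def __init__(self, t=0.0):
        self.t = t  # accumulated time across all uses
        self.dt = 0.0  # duration of the most recent use

    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc, tb):
        self.dt = time.perf_counter() - self.start
        self.t += self.dt
        return False  # do not suppress exceptions


def benchmark(fn, arg, n=1000):
    """Call fn(arg) n times and return the total elapsed seconds."""
    with SimpleProfile() as dt:
        for _ in range(n):
            fn(arg)
    return dt.t


elapsed = benchmark(sorted, [3, 1, 2], n=1000)
print(f"{elapsed:.6f} s for 1000 calls")
```

Any preprocessing callable can be passed as `fn`, so the three pipelines can be timed on the same input image with identical loop overhead.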

@AyushExel
Contributor

AyushExel commented Aug 30, 2022

@glenn-jocher this sounds good. Still nowhere close to detection pre-processing speed. Are you planning to include it in the 6.3 release?

glenn-jocher and others added 3 commits August 30, 2022 10:52
@glenn-jocher
Member Author

Just experimenting right now to see what effect the different preprocessing has on imagewoof.

@glenn-jocher
Member Author

glenn-jocher commented Aug 30, 2022

@AyushExel wow, this is super awesome. If I replace the torchvision transforms with the ones I made, our ImageNet val accuracies stay the same, but total validation time drops from 1:36 to 1:07, about 30% less time (roughly 1.4x the throughput)!

Master

T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])

validating: 100% 391/391 [01:36<00:00,  4.07it/s]
                   Class      Images    top1_acc    top5_acc
                     all       50000       0.759       0.928

PR

T.Compose([CenterCrop(size), ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])

validating: 100% 391/391 [01:07<00:00,  5.83it/s]
                   Class      Images    top1_acc    top5_acc
                     all       50000       0.759       0.929
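
As a quick check on the numbers in the two progress bars above, the snippet below converts the reported times to seconds and computes both the fraction of time saved and the throughput ratio. The helper names `parse_mmss` and `compare` are hypothetical and exist only for this arithmetic.

```python
def parse_mmss(s):
    """Convert an 'M:SS' string to seconds."""
    minutes, seconds = s.split(":")
    return int(minutes) * 60 + int(seconds)


def compare(before_s, after_s):
    """Return the fraction of time saved and the speedup factor."""
    return {
        "time_saved": 1 - after_s / before_s,
        "speedup": before_s / after_s,
    }


r = compare(parse_mmss("1:36"), parse_mmss("1:07"))
print(f"{r['time_saved']:.1%} less time, {r['speedup']:.2f}x throughput")
# prints: 30.2% less time, 1.43x throughput

# The iterations-per-second figures give the same ratio within rounding.
print(f"{5.83 / 4.07:.2f}x from it/s")
# prints: 1.43x from it/s
```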

@AyushExel
Contributor

@glenn-jocher okay, this is great! But let's not merge it before testing on ImageNet. 92% accuracy is too high, so maybe it's not the right dataset to confirm performance variation.

@glenn-jocher
Member Author

glenn-jocher commented Aug 30, 2022

@AyushExel oh, that comparison is already on ImageNet. It's using the official models on the 50,000-image ImageNet val set, i.e.

!bash yolov5-glenn/data/scripts/get_imagenet.sh --val  # download ImageNet val split (6.3G - 50000 images)
!python yolov5/classify/val.py --weights yolov5m-cls.pt --data ../datasets/imagenet --img 224  # validate

@AyushExel
Contributor

@glenn-jocher ohh okay!! Then this is perfect already. PR is getting better accuracy.

@glenn-jocher
Member Author

glenn-jocher commented Aug 30, 2022

Ok PR is all good. Tested on Imagewoof and Imagenette training. Same accuracy but about 10-20% faster training times, i.e. 24 min vs 29 min.


EDIT: The deterministic=True reproducibility is helping a lot in these comparison tests.

@glenn-jocher glenn-jocher changed the title New LetterBox(size) transform New LetterBox(size) CenterCrop(size), ToTensor() transforms (#9213) Aug 30, 2022
@glenn-jocher glenn-jocher merged commit 6e7a7ae into master Aug 30, 2022
@glenn-jocher glenn-jocher deleted the new/LetterBox branch August 30, 2022 13:17
ctjanuhowski pushed a commit to ctjanuhowski/yolov5 that referenced this pull request Sep 8, 2022
…tralytics#9213)

Development

Successfully merging this pull request may close these issues.

Consistent and Efficient Preprocessing for Classification and Detection Model
2 participants