diff --git a/docs/en/datasets/detect/index.md b/docs/en/datasets/detect/index.md index 430853db62de..2342d553b203 100644 --- a/docs/en/datasets/detect/index.md +++ b/docs/en/datasets/detect/index.md @@ -74,6 +74,7 @@ Here is a list of the supported datasets and a brief description for each: - [**Argoverse**](argoverse.md): A collection of sensor data collected from autonomous vehicles. It contains 3D tracking annotations for car objects. - [**COCO**](coco.md): Common Objects in Context (COCO) is a large-scale object detection, segmentation, and captioning dataset with 80 object categories. +- [**LVIS**](lvis.md): LVIS is a large-scale object detection, segmentation, and captioning dataset with 1203 object categories. - [**COCO8**](coco8.md): A smaller subset of the COCO dataset, COCO8 is more lightweight and faster to train. - [**GlobalWheat2020**](globalwheat2020.md): A dataset containing images of wheat heads for the Global Wheat Challenge 2020. - [**Objects365**](objects365.md): A large-scale object detection dataset with 365 object categories and 600k images, aimed at advancing object detection research. diff --git a/docs/en/datasets/detect/lvis.md b/docs/en/datasets/detect/lvis.md new file mode 100644 index 000000000000..e196c61d533e --- /dev/null +++ b/docs/en/datasets/detect/lvis.md @@ -0,0 +1,96 @@ +--- +comments: true +description: Learn how LVIS, a leading dataset for object detection and segmentation, integrates with Ultralytics. Discover ways to use it for training YOLO models. +keywords: Ultralytics, LVIS dataset, object detection, YOLO, YOLO model training, image segmentation, computer vision, deep learning models +--- + +# LVIS Dataset + +The [LVIS](https://www.lvisdataset.org/dataset) dataset is a large-scale, fine-grained vocabulary-level annotation dataset developed and released by Facebook AI Research (FAIR). It is primarily used as a research benchmark for object detection and instance segmentation with a large vocabulary of categories, aiming to drive further advancements in computer vision field. + +## Key Features + +- LVIS contains 160k images and 2M instance annotations for object detection, segmentation, and captioning tasks. +- The dataset comprises 1203 object categories, including common objects like cars, bicycles, and animals, as well as more specific categories such as umbrellas, handbags, and sports equipment. +- Annotations include object bounding boxes, segmentation masks, and captions for each image. +- LVIS provides standardized evaluation metrics like mean Average Precision (mAP) for object detection, and mean Average Recall (mAR) for segmentation tasks, making it suitable for comparing model performance. +- LVIS uses the exactly the same images as [COCO](./coco.md) dataset, but with different splits and different annotations. + +## Dataset Structure + +The LVIS dataset is split into three subsets: + +1. **Train**: This subset contains 100k images for training object detection, segmentation, and captioning models. +2. **Val**: This subset has 20k images used for validation purposes during model training. +3. **Minival**: This subset is exactly the same as COCO val2017 set which has 5k images used for validation purposes during model training. +4. **Test**: This subset consists of 20k images used for testing and benchmarking the trained models. Ground truth annotations for this subset are not publicly available, and the results are submitted to the [LVIS evaluation server](https://eval.ai/web/challenges/challenge-page/675/overview) for performance evaluation. + + +## Applications + +The LVIS dataset is widely used for training and evaluating deep learning models in object detection (such as YOLO, Faster R-CNN, and SSD), instance segmentation (such as Mask R-CNN). The dataset's diverse set of object categories, large number of annotated images, and standardized evaluation metrics make it an essential resource for computer vision researchers and practitioners. + +## Dataset YAML + +A YAML (Yet Another Markup Language) file is used to define the dataset configuration. It contains information about the dataset's paths, classes, and other relevant information. In the case of the LVIS dataset, the `lvis.yaml` file is maintained at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/lvis.yaml](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/lvis.yaml). + +!!! Example "ultralytics/cfg/datasets/lvis.yaml" + + ```yaml + --8<-- "ultralytics/cfg/datasets/lvis.yaml" + ``` + +## Usage + +To train a YOLOv8n model on the LVIS dataset for 100 epochs with an image size of 640, you can use the following code snippets. For a comprehensive list of available arguments, refer to the model [Training](../../modes/train.md) page. + +!!! Example "Train Example" + + === "Python" + + ```python + from ultralytics import YOLO + + # Load a model + model = YOLO('yolov8n.pt') # load a pretrained model (recommended for training) + + # Train the model + results = model.train(data='lvis.yaml', epochs=100, imgsz=640) + ``` + + === "CLI" + + ```bash + # Start training from a pretrained *.pt model + yolo detect train data=lvis.yaml model=yolov8n.pt epochs=100 imgsz=640 + ``` + +## Sample Images and Annotations + +The LVIS dataset contains a diverse set of images with various object categories and complex scenes. Here are some examples of images from the dataset, along with their corresponding annotations: + +![Dataset sample image](https://private-user-images.githubusercontent.com/61612323/316485965-a88c2e62-58d0-4f67-bc69-1418e42175e9.jpg?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MTEzNjcyNjYsIm5iZiI6MTcxMTM2Njk2NiwicGF0aCI6Ii82MTYxMjMyMy8zMTY0ODU5NjUtYTg4YzJlNjItNThkMC00ZjY3LWJjNjktMTQxOGU0MjE3NWU5LmpwZz9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNDAzMjUlMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjQwMzI1VDExNDI0NlomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPWZmMTVlNzE5MTBkOTZmNDQwNzJjNWQzYzM2NmEyMGMxODQ4ZDEyMjYwYmMyY2JjZDU5YzBmMDIyZGEwMGEwZDAmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0JmFjdG9yX2lkPTAma2V5X2lkPTAmcmVwb19pZD0wIn0.7thukPdnJKYuBmTk1ROUyqxxV3Ix5GeNLqyi4wSDYvA) + + +- **Mosaiced Image**: This image demonstrates a training batch composed of mosaiced dataset images. Mosaicing is a technique used during training that combines multiple images into a single image to increase the variety of objects and scenes within each training batch. This helps improve the model's ability to generalize to different object sizes, aspect ratios, and contexts. + +The example showcases the variety and complexity of the images in the LVIS dataset and the benefits of using mosaicing during the training process. + +## Citations and Acknowledgments + +If you use the LVIS dataset in your research or development work, please cite the following paper: + +!!! Quote "" + + === "BibTeX" + + ```bibtex + @inproceedings{gupta2019lvis, + title={{LVIS}: A Dataset for Large Vocabulary Instance Segmentation}, + author={Gupta, Agrim and Dollar, Piotr and Girshick, Ross}, + booktitle={Proceedings of the {IEEE} Conference on Computer Vision and Pattern Recognition}, + year={2019} + } + ``` + +We would like to acknowledge the LVIS Consortium for creating and maintaining this valuable resource for the computer vision community. For more information about the LVIS dataset and its creators, visit the [LVIS dataset website](https://www.lvisdataset.org/dataset). diff --git a/docs/en/datasets/index.md b/docs/en/datasets/index.md index 7470113ea4f6..1ac05fc8b107 100644 --- a/docs/en/datasets/index.md +++ b/docs/en/datasets/index.md @@ -36,6 +36,7 @@ Bounding box object detection is a computer vision technique that involves detec - [Argoverse](detect/argoverse.md): A dataset containing 3D tracking and motion forecasting data from urban environments with rich annotations. - [COCO](detect/coco.md): A large-scale dataset designed for object detection, segmentation, and captioning with over 200K labeled images. +- [LVIS](lvis.md): A large-scale object detection, segmentation, and captioning dataset with 1203 object categories. - [COCO8](detect/coco8.md): Contains the first 4 images from COCO train and COCO val, suitable for quick tests. - [Global Wheat 2020](detect/globalwheat2020.md): A dataset of wheat head images collected from around the world for object detection and localization tasks. - [Objects365](detect/objects365.md): A high-quality, large-scale dataset for object detection with 365 object categories and over 600K annotated images. diff --git a/docs/en/models/fast-sam.md b/docs/en/models/fast-sam.md index e35899268e9d..784f70345640 100644 --- a/docs/en/models/fast-sam.md +++ b/docs/en/models/fast-sam.md @@ -147,7 +147,7 @@ FastSAM is also available directly from the [https://github.com/CASIA-IVA-Lab/Fa 4. Install the CLIP model: ```shell - pip install git+https://github.com/openai/CLIP.git + pip install git+https://github.com/ultralytics/CLIP.git ``` ### Example Usage diff --git a/docs/en/models/yolo-world.md b/docs/en/models/yolo-world.md index 116d62dfbccb..f8d400e238be 100644 --- a/docs/en/models/yolo-world.md +++ b/docs/en/models/yolo-world.md @@ -64,6 +64,39 @@ This section details the models available with their specific pre-trained weight The YOLO-World models are easy to integrate into your Python applications. Ultralytics provides user-friendly Python API and CLI commands to streamline development. +### Train Usage + +!!! Tip "Tip" + + We strongly recommend to use `yolov8-worldv2` model for custom training, because it supports deterministic training and also easy to export other formats i.e onnx/tensorrt. + +Object detection is straightforward with the `train` method, as illustrated below: + +!!! Example + + === "Python" + PyTorch pretrained `*.pt` models as well as configuration `*.yaml` files can be passed to the `YOLOWorld()` class to create a model instance in python: + + ```python + from ultralytics import YOLOWorld + + # Load a pretrained YOLOv8s-worldv2 model + model = YOLOWorld('yolov8s-worldv2.pt') + + # Train the model on the COCO8 example dataset for 100 epochs + results = model.train(data='coco8.yaml', epochs=100, imgsz=640) + + # Run inference with the YOLOv8n model on the 'bus.jpg' image + results = model('path/to/bus.jpg') + ``` + + === "CLI" + + ```bash + # Load a pretrained YOLOv8s-worldv2 model and train it on the COCO8 example dataset for 100 epochs + yolo train model=yolov8s-worldv2.yaml data=coco8.yaml epochs=100 imgsz=640 + ``` + ### Predict Usage Object detection is straightforward with the `predict` method, as illustrated below: @@ -196,6 +229,59 @@ You can also save a model after setting custom classes. By doing this you create This approach provides a powerful means of customizing state-of-the-art object detection models for specific tasks, making advanced AI more accessible and applicable to a broader range of practical applications. +## Reproduce official results from scratch(Experimental) + +### Prepare datasets + +- Train data + +| Dataset | Type | Samples | Boxes | Annotation Files | +|-------------------------------------------------------------------|-----------|---------|-------|--------------------------------------------------------------------------------------------------------------------------------------------| +| [Objects365v1](https://opendatalab.com/OpenDataLab/Objects365_v1) | Detection | 609k | 9621k | [objects365_train.json](https://opendatalab.com/OpenDataLab/Objects365_v1) | +| [GQA](https://nlp.stanford.edu/data/gqa/images.zip) | Grounding | 621k | 3681k | [final_mixed_train_no_coco.json](https://huggingface.co/GLIPModel/GLIP/blob/main/mdetr_annotations/final_mixed_train_no_coco.json) | +| [Flickr30k](https://shannon.cs.illinois.edu/DenotationGraph/) | Grounding | 149k | 641k | [final_flickr_separateGT_train.json](https://huggingface.co/GLIPModel/GLIP/blob/main/mdetr_annotations/final_flickr_separateGT_train.json) | + +- Val data + +| Dataset | Type | Annotation Files | +|---------------------------------------------------------------------------------------------------------|-----------|--------------------------------------------------------------------------------------------------------| +| [LVIS minival](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/lvis.yaml) | Detection | [minival.txt](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/lvis.yaml) | + +### Launch training from scratch + +!!! Note + + `WorldTrainerFromScratch` is highly customized to allow training yolo-world models on both detection datasets and grounding datasets simultaneously. More details please checkout [ultralytics.model.yolo.world.train_world.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/world/train_world.py). + +!!! Example + + === "Python" + + ```python + from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch + from ultralytics import YOLOWorld + + data = dict( + train=dict( + yolo_data=["Objects365.yaml"], + grounding_data=[ + dict( + img_path="../datasets/flickr30k/images", + json_file="../datasets/flickr30k/final_flickr_separateGT_train.json", + ), + dict( + img_path="../datasets/GQA/images", + json_file="../datasets/GQA/final_mixed_train_no_coco.json", + ), + ], + ), + val=dict(yolo_data=["lvis.yaml"]), + ) + model = YOLOWorld("yolov8s-worldv2.yaml") + model.train(data=data, batch=128, epochs=100, trainer=WorldTrainerFromScratch) + + ``` + ## Citations and Acknowledgements We extend our gratitude to the [Tencent AILab Computer Vision Center](https://ai.tencent.com/) for their pioneering work in real-time open-vocabulary object detection with YOLO-World: diff --git a/docs/en/reference/data/augment.md b/docs/en/reference/data/augment.md index 1d4099fca1a4..b28fd08ae4d6 100644 --- a/docs/en/reference/data/augment.md +++ b/docs/en/reference/data/augment.md @@ -59,6 +59,10 @@ keywords: Ultralytics, Data Augmentation, BaseTransform, MixUp, RandomHSV, Lette

+## ::: ultralytics.data.augment.RandomLoadText + +

+ ## ::: ultralytics.data.augment.ClassifyLetterBox

diff --git a/docs/en/reference/data/build.md b/docs/en/reference/data/build.md index 811c11d4e8b1..3c80e1edfba1 100644 --- a/docs/en/reference/data/build.md +++ b/docs/en/reference/data/build.md @@ -27,6 +27,10 @@ keywords: Ultralytics, YOLO v3, Data build, DataLoader, InfiniteDataLoader, seed

+## ::: ultralytics.data.build.build_grounding + +

+ ## ::: ultralytics.data.build.build_dataloader

diff --git a/docs/en/reference/data/dataset.md b/docs/en/reference/data/dataset.md index 242a054a98f2..7c8a9263174b 100644 --- a/docs/en/reference/data/dataset.md +++ b/docs/en/reference/data/dataset.md @@ -19,14 +19,18 @@ keywords: Ultralytics, YOLO, YOLODataset, SemanticDataset, data handling, data m

-## ::: ultralytics.data.dataset.SemanticDataset +## ::: ultralytics.data.dataset.YOLOMultiModalDataset

-## ::: ultralytics.data.dataset.load_dataset_cache_file +## ::: ultralytics.data.dataset.GroundingDataset

-## ::: ultralytics.data.dataset.save_dataset_cache_file +## ::: ultralytics.data.dataset.YOLOConcatDataset + +

+ +## ::: ultralytics.data.dataset.SemanticDataset

diff --git a/docs/en/reference/data/utils.md b/docs/en/reference/data/utils.md index af06ce857bd5..7ec7782617aa 100644 --- a/docs/en/reference/data/utils.md +++ b/docs/en/reference/data/utils.md @@ -66,3 +66,11 @@ keywords: Ultralytics, data utils, YOLO, img2label_paths, exif_size, polygon2mas ## ::: ultralytics.data.utils.autosplit

+ +## ::: ultralytics.data.utils.load_dataset_cache_file + +

+ +## ::: ultralytics.data.utils.save_dataset_cache_file + +

diff --git a/docs/en/reference/models/yolo/world/train.md b/docs/en/reference/models/yolo/world/train.md new file mode 100644 index 000000000000..4de8474bad64 --- /dev/null +++ b/docs/en/reference/models/yolo/world/train.md @@ -0,0 +1,15 @@ +# Reference for `ultralytics/models/yolo/world/train.py` + +!!! Note + + This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/world/train.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/world/train.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/yolo/world/train.py) 🛠️. Thank you 🙏! + +

+ +## ::: ultralytics.models.yolo.world.train.WorldTrainer + +

+ +## ::: ultralytics.models.yolo.world.train.on_pretrain_routine_end + +

diff --git a/docs/en/reference/models/yolo/world/train_world.md b/docs/en/reference/models/yolo/world/train_world.md new file mode 100644 index 000000000000..c5028fd53733 --- /dev/null +++ b/docs/en/reference/models/yolo/world/train_world.md @@ -0,0 +1,11 @@ +# Reference for `ultralytics/models/yolo/world/train_world.py` + +!!! Note + + This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/world/train_world.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/world/train_world.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/yolo/world/train_world.py) 🛠️. Thank you 🙏! + +

+ +## ::: ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch + +

diff --git a/docs/mkdocs_github_authors.yaml b/docs/mkdocs_github_authors.yaml index 5734987c2372..8c1c476ade21 100644 --- a/docs/mkdocs_github_authors.yaml +++ b/docs/mkdocs_github_authors.yaml @@ -18,6 +18,7 @@ chr043416@gmail.com: RizwanMunawar glenn.jocher@ultralytics.com: glenn-jocher muhammadrizwanmunawar123@gmail.com: RizwanMunawar not.committed.yet: null +plashchynski@gmail.com: plashchynski priytosh.revolution@live.com: priytosh-tripathi shuizhuyuanluo@126.com: null xinwang614@gmail.com: GreatV diff --git a/mkdocs.yml b/mkdocs.yml index da0a9310ae98..fcbe490b0ed7 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -240,6 +240,7 @@ nav: - datasets/detect/index.md - Argoverse: datasets/detect/argoverse.md - COCO: datasets/detect/coco.md + - LVIS: datasets/detect/lvis.md - COCO8: datasets/detect/coco8.md - GlobalWheat2020: datasets/detect/globalwheat2020.md - Objects365: datasets/detect/objects365.md @@ -492,6 +493,9 @@ nav: - predict: reference/models/yolo/segment/predict.md - train: reference/models/yolo/segment/train.md - val: reference/models/yolo/segment/val.md + - world: + - train: reference/models/yolo/world/train.md + - train_world: reference/models/yolo/world/train_world.md - nn: - autobackend: reference/nn/autobackend.md - modules: diff --git a/tests/test_python.py b/tests/test_python.py index 4301f72e8ace..fcda95389910 100644 --- a/tests/test_python.py +++ b/tests/test_python.py @@ -643,3 +643,29 @@ def test_yolo_world(): model = YOLO("yolov8s-world.pt") # no YOLOv8n-world model yet model.set_classes(["tree", "window"]) model(ASSETS / "bus.jpg", conf=0.01) + + # Training from yaml + model = YOLO("yolov8s-worldv2.yaml") # no YOLOv8n-world model yet + model.train(data="coco8.yaml", epochs=2, imgsz=32, cache="disk", batch=-1, close_mosaic=1, name="yolo-world") + + model = YOLO("yolov8s-worldv2.pt") # no YOLOv8n-world model yet + # val + model.val(data="coco8.yaml", imgsz=32, save_txt=True, save_json=True) + # Training from pretrain + model.train(data="coco8.yaml", epochs=2, imgsz=32, cache="disk", batch=-1, close_mosaic=1, name="yolo-world") + + # test WorWorldTrainerFromScratch + from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch + + model = YOLO("yolov8s-worldv2.yaml") # no YOLOv8n-world model yet + data = dict(train=dict(yolo_data=["coco8.yaml"]), val=dict(yolo_data=["coco8.yaml"])) + model.train( + data=data, + epochs=2, + imgsz=32, + cache="disk", + batch=-1, + close_mosaic=1, + name="yolo-world", + trainer=WorldTrainerFromScratch, + ) diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index d7c24ca9c433..07232b441455 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = "8.1.38" +__version__ = "8.1.39" from ultralytics.data.explorer.explorer import Explorer from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld diff --git a/ultralytics/cfg/datasets/lvis.yaml b/ultralytics/cfg/datasets/lvis.yaml new file mode 100644 index 000000000000..98149ed60431 --- /dev/null +++ b/ultralytics/cfg/datasets/lvis.yaml @@ -0,0 +1,1239 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +# LVIS dataset http://www.lvisdataset.org by Facebook AI Research. +# Documentation: https://docs.ultralytics.com/datasets/detect/lvis/ +# Example usage: yolo train data=lvis.yaml +# parent +# ├── ultralytics +# └── datasets +# └── lvis ← downloads here (20.1 GB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] +path: ../datasets/lvis # dataset root dir +train: train.txt # train images (relative to 'path') 100170 images +val: val.txt # val images (relative to 'path') 19809 images +minival: minival.txt # minval images (relative to 'path') 5000 images + +names: + 0: aerosol can/spray can + 1: air conditioner + 2: airplane/aeroplane + 3: alarm clock + 4: alcohol/alcoholic beverage + 5: alligator/gator + 6: almond + 7: ambulance + 8: amplifier + 9: anklet/ankle bracelet + 10: antenna/aerial/transmitting aerial + 11: apple + 12: applesauce + 13: apricot + 14: apron + 15: aquarium/fish tank + 16: arctic/arctic type of shoe/galosh/golosh/rubber/rubber type of shoe/gumshoe + 17: armband + 18: armchair + 19: armoire + 20: armor/armour + 21: artichoke + 22: trash can/garbage can/wastebin/dustbin/trash barrel/trash bin + 23: ashtray + 24: asparagus + 25: atomizer/atomiser/spray/sprayer/nebulizer/nebuliser + 26: avocado + 27: award/accolade + 28: awning + 29: ax/axe + 30: baboon + 31: baby buggy/baby carriage/perambulator/pram/stroller + 32: basketball backboard + 33: backpack/knapsack/packsack/rucksack/haversack + 34: handbag/purse/pocketbook + 35: suitcase/baggage/luggage + 36: bagel/beigel + 37: bagpipe + 38: baguet/baguette + 39: bait/lure + 40: ball + 41: ballet skirt/tutu + 42: balloon + 43: bamboo + 44: banana + 45: Band Aid + 46: bandage + 47: bandanna/bandana + 48: banjo + 49: banner/streamer + 50: barbell + 51: barge + 52: barrel/cask + 53: barrette + 54: barrow/garden cart/lawn cart/wheelbarrow + 55: baseball base + 56: baseball + 57: baseball bat + 58: baseball cap/jockey cap/golf cap + 59: baseball glove/baseball mitt + 60: basket/handbasket + 61: basketball + 62: bass horn/sousaphone/tuba + 63: bat/bat animal + 64: bath mat + 65: bath towel + 66: bathrobe + 67: bathtub/bathing tub + 68: batter/batter food + 69: battery + 70: beachball + 71: bead + 72: bean curd/tofu + 73: beanbag + 74: beanie/beany + 75: bear + 76: bed + 77: bedpan + 78: bedspread/bedcover/bed covering/counterpane/spread + 79: cow + 80: beef/beef food/boeuf/boeuf food + 81: beeper/pager + 82: beer bottle + 83: beer can + 84: beetle + 85: bell + 86: bell pepper/capsicum + 87: belt + 88: belt buckle + 89: bench + 90: beret + 91: bib + 92: Bible + 93: bicycle/bike/bike bicycle + 94: visor/vizor + 95: billboard + 96: binder/ring-binder + 97: binoculars/field glasses/opera glasses + 98: bird + 99: birdfeeder + 100: birdbath + 101: birdcage + 102: birdhouse + 103: birthday cake + 104: birthday card + 105: pirate flag + 106: black sheep + 107: blackberry + 108: blackboard/chalkboard + 109: blanket + 110: blazer/sport jacket/sport coat/sports jacket/sports coat + 111: blender/liquidizer/liquidiser + 112: blimp + 113: blinker/flasher + 114: blouse + 115: blueberry + 116: gameboard + 117: boat/ship/ship boat + 118: bob/bobber/bobfloat + 119: bobbin/spool/reel + 120: bobby pin/hairgrip + 121: boiled egg/coddled egg + 122: bolo tie/bolo/bola tie/bola + 123: deadbolt + 124: bolt + 125: bonnet + 126: book + 127: bookcase + 128: booklet/brochure/leaflet/pamphlet + 129: bookmark/bookmarker + 130: boom microphone/microphone boom + 131: boot + 132: bottle + 133: bottle opener + 134: bouquet + 135: bow/bow weapon + 136: bow/bow decorative ribbons + 137: bow-tie/bowtie + 138: bowl + 139: pipe bowl + 140: bowler hat/bowler/derby hat/derby/plug hat + 141: bowling ball + 142: box + 143: boxing glove + 144: suspenders + 145: bracelet/bangle + 146: brass plaque + 147: brassiere/bra/bandeau + 148: bread-bin/breadbox + 149: bread + 150: breechcloth/breechclout/loincloth + 151: bridal gown/wedding gown/wedding dress + 152: briefcase + 153: broccoli + 154: broach + 155: broom + 156: brownie + 157: brussels sprouts + 158: bubble gum + 159: bucket/pail + 160: horse buggy + 161: horned cow + 162: bulldog + 163: bulldozer/dozer + 164: bullet train + 165: bulletin board/notice board + 166: bulletproof vest + 167: bullhorn/megaphone + 168: bun/roll + 169: bunk bed + 170: buoy + 171: burrito + 172: bus/bus vehicle/autobus/charabanc/double-decker/motorbus/motorcoach + 173: business card + 174: butter + 175: butterfly + 176: button + 177: cab/cab taxi/taxi/taxicab + 178: cabana + 179: cabin car/caboose + 180: cabinet + 181: locker/storage locker + 182: cake + 183: calculator + 184: calendar + 185: calf + 186: camcorder + 187: camel + 188: camera + 189: camera lens + 190: camper/camper vehicle/camping bus/motor home + 191: can/tin can + 192: can opener/tin opener + 193: candle/candlestick + 194: candle holder + 195: candy bar + 196: candy cane + 197: walking cane + 198: canister/canister + 199: canoe + 200: cantaloup/cantaloupe + 201: canteen + 202: cap/cap headwear + 203: bottle cap/cap/cap container lid + 204: cape + 205: cappuccino/coffee cappuccino + 206: car/car automobile/auto/auto automobile/automobile + 207: railcar/railcar part of a train/railway car/railway car part of a train/railroad + car/railroad car part of a train + 208: elevator car + 209: car battery/automobile battery + 210: identity card + 211: card + 212: cardigan + 213: cargo ship/cargo vessel + 214: carnation + 215: horse carriage + 216: carrot + 217: tote bag + 218: cart + 219: carton + 220: cash register/register/register for cash transactions + 221: casserole + 222: cassette + 223: cast/plaster cast/plaster bandage + 224: cat + 225: cauliflower + 226: cayenne/cayenne spice/cayenne pepper/cayenne pepper spice/red pepper/red pepper + spice + 227: CD player + 228: celery + 229: cellular telephone/cellular phone/cellphone/mobile phone/smart phone + 230: chain mail/ring mail/chain armor/chain armour/ring armor/ring armour + 231: chair + 232: chaise longue/chaise/daybed + 233: chalice + 234: chandelier + 235: chap + 236: checkbook/chequebook + 237: checkerboard + 238: cherry + 239: chessboard + 240: chicken/chicken animal + 241: chickpea/garbanzo + 242: chili/chili vegetable/chili pepper/chili pepper vegetable/chilli/chilli vegetable/chilly/chilly + vegetable/chile/chile vegetable + 243: chime/gong + 244: chinaware + 245: crisp/crisp potato chip/potato chip + 246: poker chip + 247: chocolate bar + 248: chocolate cake + 249: chocolate milk + 250: chocolate mousse + 251: choker/collar/neckband + 252: chopping board/cutting board/chopping block + 253: chopstick + 254: Christmas tree + 255: slide + 256: cider/cyder + 257: cigar box + 258: cigarette + 259: cigarette case/cigarette pack + 260: cistern/water tank + 261: clarinet + 262: clasp + 263: cleansing agent/cleanser/cleaner + 264: cleat/cleat for securing rope + 265: clementine + 266: clip + 267: clipboard + 268: clippers/clippers for plants + 269: cloak + 270: clock/timepiece/timekeeper + 271: clock tower + 272: clothes hamper/laundry basket/clothes basket + 273: clothespin/clothes peg + 274: clutch bag + 275: coaster + 276: coat + 277: coat hanger/clothes hanger/dress hanger + 278: coatrack/hatrack + 279: cock/rooster + 280: cockroach + 281: cocoa/cocoa beverage/hot chocolate/hot chocolate beverage/drinking chocolate + 282: coconut/cocoanut + 283: coffee maker/coffee machine + 284: coffee table/cocktail table + 285: coffeepot + 286: coil + 287: coin + 288: colander/cullender + 289: coleslaw/slaw + 290: coloring material/colouring material + 291: combination lock + 292: pacifier/teething ring + 293: comic book + 294: compass + 295: computer keyboard/keyboard/keyboard computer + 296: condiment + 297: cone/traffic cone + 298: control/controller + 299: convertible/convertible automobile + 300: sofa bed + 301: cooker + 302: cookie/cooky/biscuit/biscuit cookie + 303: cooking utensil + 304: cooler/cooler for food/ice chest + 305: cork/cork bottle plug/bottle cork + 306: corkboard + 307: corkscrew/bottle screw + 308: edible corn/corn/maize + 309: cornbread + 310: cornet/horn/trumpet + 311: cornice/valance/valance board/pelmet + 312: cornmeal + 313: corset/girdle + 314: costume + 315: cougar/puma/catamount/mountain lion/panther + 316: coverall + 317: cowbell + 318: cowboy hat/ten-gallon hat + 319: crab/crab animal + 320: crabmeat + 321: cracker + 322: crape/crepe/French pancake + 323: crate + 324: crayon/wax crayon + 325: cream pitcher + 326: crescent roll/croissant + 327: crib/cot + 328: crock pot/earthenware jar + 329: crossbar + 330: crouton + 331: crow + 332: crowbar/wrecking bar/pry bar + 333: crown + 334: crucifix + 335: cruise ship/cruise liner + 336: police cruiser/patrol car/police car/squad car + 337: crumb + 338: crutch + 339: cub/cub animal + 340: cube/square block + 341: cucumber/cuke + 342: cufflink + 343: cup + 344: trophy cup + 345: cupboard/closet + 346: cupcake + 347: hair curler/hair roller/hair crimper + 348: curling iron + 349: curtain/drapery + 350: cushion + 351: cylinder + 352: cymbal + 353: dagger + 354: dalmatian + 355: dartboard + 356: date/date fruit + 357: deck chair/beach chair + 358: deer/cervid + 359: dental floss/floss + 360: desk + 361: detergent + 362: diaper + 363: diary/journal + 364: die/dice + 365: dinghy/dory/rowboat + 366: dining table + 367: tux/tuxedo + 368: dish + 369: dish antenna + 370: dishrag/dishcloth + 371: dishtowel/tea towel + 372: dishwasher/dishwashing machine + 373: dishwasher detergent/dishwashing detergent/dishwashing liquid/dishsoap + 374: dispenser + 375: diving board + 376: Dixie cup/paper cup + 377: dog + 378: dog collar + 379: doll + 380: dollar/dollar bill/one dollar bill + 381: dollhouse/doll's house + 382: dolphin + 383: domestic ass/donkey + 384: doorknob/doorhandle + 385: doormat/welcome mat + 386: doughnut/donut + 387: dove + 388: dragonfly + 389: drawer + 390: underdrawers/boxers/boxershorts + 391: dress/frock + 392: dress hat/high hat/opera hat/silk hat/top hat + 393: dress suit + 394: dresser + 395: drill + 396: drone + 397: dropper/eye dropper + 398: drum/drum musical instrument + 399: drumstick + 400: duck + 401: duckling + 402: duct tape + 403: duffel bag/duffle bag/duffel/duffle + 404: dumbbell + 405: dumpster + 406: dustpan + 407: eagle + 408: earphone/earpiece/headphone + 409: earplug + 410: earring + 411: easel + 412: eclair + 413: eel + 414: egg/eggs + 415: egg roll/spring roll + 416: egg yolk/yolk/yolk egg + 417: eggbeater/eggwhisk + 418: eggplant/aubergine + 419: electric chair + 420: refrigerator + 421: elephant + 422: elk/moose + 423: envelope + 424: eraser + 425: escargot + 426: eyepatch + 427: falcon + 428: fan + 429: faucet/spigot/tap + 430: fedora + 431: ferret + 432: Ferris wheel + 433: ferry/ferryboat + 434: fig/fig fruit + 435: fighter jet/fighter aircraft/attack aircraft + 436: figurine + 437: file cabinet/filing cabinet + 438: file/file tool + 439: fire alarm/smoke alarm + 440: fire engine/fire truck + 441: fire extinguisher/extinguisher + 442: fire hose + 443: fireplace + 444: fireplug/fire hydrant/hydrant + 445: first-aid kit + 446: fish + 447: fish/fish food + 448: fishbowl/goldfish bowl + 449: fishing rod/fishing pole + 450: flag + 451: flagpole/flagstaff + 452: flamingo + 453: flannel + 454: flap + 455: flash/flashbulb + 456: flashlight/torch + 457: fleece + 458: flip-flop/flip-flop sandal + 459: flipper/flipper footwear/fin/fin footwear + 460: flower arrangement/floral arrangement + 461: flute glass/champagne flute + 462: foal + 463: folding chair + 464: food processor + 465: football/football American + 466: football helmet + 467: footstool/footrest + 468: fork + 469: forklift + 470: freight car + 471: French toast + 472: freshener/air freshener + 473: frisbee + 474: frog/toad/toad frog + 475: fruit juice + 476: frying pan/frypan/skillet + 477: fudge + 478: funnel + 479: futon + 480: gag/muzzle + 481: garbage + 482: garbage truck + 483: garden hose + 484: gargle/mouthwash + 485: gargoyle + 486: garlic/ail + 487: gasmask/respirator/gas helmet + 488: gazelle + 489: gelatin/jelly + 490: gemstone + 491: generator + 492: giant panda/panda/panda bear + 493: gift wrap + 494: ginger/gingerroot + 495: giraffe + 496: cincture/sash/waistband/waistcloth + 497: glass/glass drink container/drinking glass + 498: globe + 499: glove + 500: goat + 501: goggles + 502: goldfish + 503: golf club/golf-club + 504: golfcart + 505: gondola/gondola boat + 506: goose + 507: gorilla + 508: gourd + 509: grape + 510: grater + 511: gravestone/headstone/tombstone + 512: gravy boat/gravy holder + 513: green bean + 514: green onion/spring onion/scallion + 515: griddle + 516: grill/grille/grillwork/radiator grille + 517: grits/hominy grits + 518: grizzly/grizzly bear + 519: grocery bag + 520: guitar + 521: gull/seagull + 522: gun + 523: hairbrush + 524: hairnet + 525: hairpin + 526: halter top + 527: ham/jambon/gammon + 528: hamburger/beefburger/burger + 529: hammer + 530: hammock + 531: hamper + 532: hamster + 533: hair dryer + 534: hand glass/hand mirror + 535: hand towel/face towel + 536: handcart/pushcart/hand truck + 537: handcuff + 538: handkerchief + 539: handle/grip/handgrip + 540: handsaw/carpenter's saw + 541: hardback book/hardcover book + 542: harmonium/organ/organ musical instrument/reed organ/reed organ musical instrument + 543: hat + 544: hatbox + 545: veil + 546: headband + 547: headboard + 548: headlight/headlamp + 549: headscarf + 550: headset + 551: headstall/headstall for horses/headpiece/headpiece for horses + 552: heart + 553: heater/warmer + 554: helicopter + 555: helmet + 556: heron + 557: highchair/feeding chair + 558: hinge + 559: hippopotamus + 560: hockey stick + 561: hog/pig + 562: home plate/home plate baseball/home base/home base baseball + 563: honey + 564: fume hood/exhaust hood + 565: hook + 566: hookah/narghile/nargileh/sheesha/shisha/water pipe + 567: hornet + 568: horse + 569: hose/hosepipe + 570: hot-air balloon + 571: hotplate + 572: hot sauce + 573: hourglass + 574: houseboat + 575: hummingbird + 576: hummus/humus/hommos/hoummos/humous + 577: polar bear + 578: icecream + 579: popsicle + 580: ice maker + 581: ice pack/ice bag + 582: ice skate + 583: igniter/ignitor/lighter + 584: inhaler/inhalator + 585: iPod + 586: iron/iron for clothing/smoothing iron/smoothing iron for clothing + 587: ironing board + 588: jacket + 589: jam + 590: jar + 591: jean/blue jean/denim + 592: jeep/landrover + 593: jelly bean/jelly egg + 594: jersey/T-shirt/tee shirt + 595: jet plane/jet-propelled plane + 596: jewel/gem/precious stone + 597: jewelry/jewellery + 598: joystick + 599: jumpsuit + 600: kayak + 601: keg + 602: kennel/doghouse + 603: kettle/boiler + 604: key + 605: keycard + 606: kilt + 607: kimono + 608: kitchen sink + 609: kitchen table + 610: kite + 611: kitten/kitty + 612: kiwi fruit + 613: knee pad + 614: knife + 615: knitting needle + 616: knob + 617: knocker/knocker on a door/doorknocker + 618: koala/koala bear + 619: lab coat/laboratory coat + 620: ladder + 621: ladle + 622: ladybug/ladybeetle/ladybird beetle + 623: lamb/lamb animal + 624: lamb-chop/lambchop + 625: lamp + 626: lamppost + 627: lampshade + 628: lantern + 629: lanyard/laniard + 630: laptop computer/notebook computer + 631: lasagna/lasagne + 632: latch + 633: lawn mower + 634: leather + 635: legging/legging clothing/leging/leging clothing/leg covering + 636: Lego/Lego set + 637: legume + 638: lemon + 639: lemonade + 640: lettuce + 641: license plate/numberplate + 642: life buoy/lifesaver/life belt/life ring + 643: life jacket/life vest + 644: lightbulb + 645: lightning rod/lightning conductor + 646: lime + 647: limousine + 648: lion + 649: lip balm + 650: liquor/spirits/hard liquor/liqueur/cordial + 651: lizard + 652: log + 653: lollipop + 654: speaker/speaker stereo equipment + 655: loveseat + 656: machine gun + 657: magazine + 658: magnet + 659: mail slot + 660: mailbox/mailbox at home/letter box/letter box at home + 661: mallard + 662: mallet + 663: mammoth + 664: manatee + 665: mandarin orange + 666: manager/through + 667: manhole + 668: map + 669: marker + 670: martini + 671: mascot + 672: mashed potato + 673: masher + 674: mask/facemask + 675: mast + 676: mat/mat gym equipment/gym mat + 677: matchbox + 678: mattress + 679: measuring cup + 680: measuring stick/ruler/ruler measuring stick/measuring rod + 681: meatball + 682: medicine + 683: melon + 684: microphone + 685: microscope + 686: microwave oven + 687: milestone/milepost + 688: milk + 689: milk can + 690: milkshake + 691: minivan + 692: mint candy + 693: mirror + 694: mitten + 695: mixer/mixer kitchen tool/stand mixer + 696: money + 697: monitor/monitor computer equipment + 698: monkey + 699: motor + 700: motor scooter/scooter + 701: motor vehicle/automotive vehicle + 702: motorcycle + 703: mound/mound baseball/pitcher's mound + 704: mouse/mouse computer equipment/computer mouse + 705: mousepad + 706: muffin + 707: mug + 708: mushroom + 709: music stool/piano stool + 710: musical instrument/instrument/instrument musical + 711: nailfile + 712: napkin/table napkin/serviette + 713: neckerchief + 714: necklace + 715: necktie/tie/tie necktie + 716: needle + 717: nest + 718: newspaper/paper/paper newspaper + 719: newsstand + 720: nightshirt/nightwear/sleepwear/nightclothes + 721: nosebag/nosebag for animals/feedbag + 722: noseband/noseband for animals/nosepiece/nosepiece for animals + 723: notebook + 724: notepad + 725: nut + 726: nutcracker + 727: oar + 728: octopus/octopus food + 729: octopus/octopus animal + 730: oil lamp/kerosene lamp/kerosine lamp + 731: olive oil + 732: omelet/omelette + 733: onion + 734: orange/orange fruit + 735: orange juice + 736: ostrich + 737: ottoman/pouf/pouffe/hassock + 738: oven + 739: overalls/overalls clothing + 740: owl + 741: packet + 742: inkpad/inking pad/stamp pad + 743: pad + 744: paddle/boat paddle + 745: padlock + 746: paintbrush + 747: painting + 748: pajamas/pyjamas + 749: palette/pallet + 750: pan/pan for cooking/cooking pan + 751: pan/pan metal container + 752: pancake + 753: pantyhose + 754: papaya + 755: paper plate + 756: paper towel + 757: paperback book/paper-back book/softback book/soft-cover book + 758: paperweight + 759: parachute + 760: parakeet/parrakeet/parroket/paraquet/paroquet/parroquet + 761: parasail/parasail sports + 762: parasol/sunshade + 763: parchment + 764: parka/anorak + 765: parking meter + 766: parrot + 767: passenger car/passenger car part of a train/coach/coach part of a train + 768: passenger ship + 769: passport + 770: pastry + 771: patty/patty food + 772: pea/pea food + 773: peach + 774: peanut butter + 775: pear + 776: peeler/peeler tool for fruit and vegetables + 777: wooden leg/pegleg + 778: pegboard + 779: pelican + 780: pen + 781: pencil + 782: pencil box/pencil case + 783: pencil sharpener + 784: pendulum + 785: penguin + 786: pennant + 787: penny/penny coin + 788: pepper/peppercorn + 789: pepper mill/pepper grinder + 790: perfume + 791: persimmon + 792: person/baby/child/boy/girl/man/woman/human + 793: pet + 794: pew/pew church bench/church bench + 795: phonebook/telephone book/telephone directory + 796: phonograph record/phonograph recording/record/record phonograph recording + 797: piano + 798: pickle + 799: pickup truck + 800: pie + 801: pigeon + 802: piggy bank/penny bank + 803: pillow + 804: pin/pin non jewelry + 805: pineapple + 806: pinecone + 807: ping-pong ball + 808: pinwheel + 809: tobacco pipe + 810: pipe/piping + 811: pistol/handgun + 812: pita/pita bread/pocket bread + 813: pitcher/pitcher vessel for liquid/ewer + 814: pitchfork + 815: pizza + 816: place mat + 817: plate + 818: platter + 819: playpen + 820: pliers/plyers + 821: plow/plow farm equipment/plough/plough farm equipment + 822: plume + 823: pocket watch + 824: pocketknife + 825: poker/poker fire stirring tool/stove poker/fire hook + 826: pole/post + 827: polo shirt/sport shirt + 828: poncho + 829: pony + 830: pool table/billiard table/snooker table + 831: pop/pop soda/soda/soda pop/tonic/soft drink + 832: postbox/postbox public/mailbox/mailbox public + 833: postcard/postal card/mailing-card + 834: poster/placard + 835: pot + 836: flowerpot + 837: potato + 838: potholder + 839: pottery/clayware + 840: pouch + 841: power shovel/excavator/digger + 842: prawn/shrimp + 843: pretzel + 844: printer/printing machine + 845: projectile/projectile weapon/missile + 846: projector + 847: propeller/propellor + 848: prune + 849: pudding + 850: puffer/puffer fish/pufferfish/blowfish/globefish + 851: puffin + 852: pug-dog + 853: pumpkin + 854: puncher + 855: puppet/marionette + 856: puppy + 857: quesadilla + 858: quiche + 859: quilt/comforter + 860: rabbit + 861: race car/racing car + 862: racket/racquet + 863: radar + 864: radiator + 865: radio receiver/radio set/radio/tuner/tuner radio + 866: radish/daikon + 867: raft + 868: rag doll + 869: raincoat/waterproof jacket + 870: ram/ram animal + 871: raspberry + 872: rat + 873: razorblade + 874: reamer/reamer juicer/juicer/juice reamer + 875: rearview mirror + 876: receipt + 877: recliner/reclining chair/lounger/lounger chair + 878: record player/phonograph/phonograph record player/turntable + 879: reflector + 880: remote control + 881: rhinoceros + 882: rib/rib food + 883: rifle + 884: ring + 885: river boat + 886: road map + 887: robe + 888: rocking chair + 889: rodent + 890: roller skate + 891: Rollerblade + 892: rolling pin + 893: root beer + 894: router/router computer equipment + 895: rubber band/elastic band + 896: runner/runner carpet + 897: plastic bag/paper bag + 898: saddle/saddle on an animal + 899: saddle blanket/saddlecloth/horse blanket + 900: saddlebag + 901: safety pin + 902: sail + 903: salad + 904: salad plate/salad bowl + 905: salami + 906: salmon/salmon fish + 907: salmon/salmon food + 908: salsa + 909: saltshaker + 910: sandal/sandal type of shoe + 911: sandwich + 912: satchel + 913: saucepan + 914: saucer + 915: sausage + 916: sawhorse/sawbuck + 917: saxophone + 918: scale/scale measuring instrument + 919: scarecrow/strawman + 920: scarf + 921: school bus + 922: scissors + 923: scoreboard + 924: scraper + 925: screwdriver + 926: scrubbing brush + 927: sculpture + 928: seabird/seafowl + 929: seahorse + 930: seaplane/hydroplane + 931: seashell + 932: sewing machine + 933: shaker + 934: shampoo + 935: shark + 936: sharpener + 937: Sharpie + 938: shaver/shaver electric/electric shaver/electric razor + 939: shaving cream/shaving soap + 940: shawl + 941: shears + 942: sheep + 943: shepherd dog/sheepdog + 944: sherbert/sherbet + 945: shield + 946: shirt + 947: shoe/sneaker/sneaker type of shoe/tennis shoe + 948: shopping bag + 949: shopping cart + 950: short pants/shorts/shorts clothing/trunks/trunks clothing + 951: shot glass + 952: shoulder bag + 953: shovel + 954: shower head + 955: shower cap + 956: shower curtain + 957: shredder/shredder for paper + 958: signboard + 959: silo + 960: sink + 961: skateboard + 962: skewer + 963: ski + 964: ski boot + 965: ski parka/ski jacket + 966: ski pole + 967: skirt + 968: skullcap + 969: sled/sledge/sleigh + 970: sleeping bag + 971: sling/sling bandage/triangular bandage + 972: slipper/slipper footwear/carpet slipper/carpet slipper footwear + 973: smoothie + 974: snake/serpent + 975: snowboard + 976: snowman + 977: snowmobile + 978: soap + 979: soccer ball + 980: sock + 981: sofa/couch/lounge + 982: softball + 983: solar array/solar battery/solar panel + 984: sombrero + 985: soup + 986: soup bowl + 987: soupspoon + 988: sour cream/soured cream + 989: soya milk/soybean milk/soymilk + 990: space shuttle + 991: sparkler/sparkler fireworks + 992: spatula + 993: spear/lance + 994: spectacles/specs/eyeglasses/glasses + 995: spice rack + 996: spider + 997: crawfish/crayfish + 998: sponge + 999: spoon + 1000: sportswear/athletic wear/activewear + 1001: spotlight + 1002: squid/squid food/calamari/calamary + 1003: squirrel + 1004: stagecoach + 1005: stapler/stapler stapling machine + 1006: starfish/sea star + 1007: statue/statue sculpture + 1008: steak/steak food + 1009: steak knife + 1010: steering wheel + 1011: stepladder + 1012: step stool + 1013: stereo/stereo sound system + 1014: stew + 1015: stirrer + 1016: stirrup + 1017: stool + 1018: stop sign + 1019: brake light + 1020: stove/kitchen stove/range/range kitchen appliance/kitchen range/cooking stove + 1021: strainer + 1022: strap + 1023: straw/straw for drinking/drinking straw + 1024: strawberry + 1025: street sign + 1026: streetlight/street lamp + 1027: string cheese + 1028: stylus + 1029: subwoofer + 1030: sugar bowl + 1031: sugarcane/sugarcane plant + 1032: suit/suit clothing + 1033: sunflower + 1034: sunglasses + 1035: sunhat + 1036: surfboard + 1037: sushi + 1038: mop + 1039: sweat pants + 1040: sweatband + 1041: sweater + 1042: sweatshirt + 1043: sweet potato + 1044: swimsuit/swimwear/bathing suit/swimming costume/bathing costume/swimming trunks/bathing + trunks + 1045: sword + 1046: syringe + 1047: Tabasco sauce + 1048: table-tennis table/ping-pong table + 1049: table + 1050: table lamp + 1051: tablecloth + 1052: tachometer + 1053: taco + 1054: tag + 1055: taillight/rear light + 1056: tambourine + 1057: army tank/armored combat vehicle/armoured combat vehicle + 1058: tank/tank storage vessel/storage tank + 1059: tank top/tank top clothing + 1060: tape/tape sticky cloth or paper + 1061: tape measure/measuring tape + 1062: tapestry + 1063: tarp + 1064: tartan/plaid + 1065: tassel + 1066: tea bag + 1067: teacup + 1068: teakettle + 1069: teapot + 1070: teddy bear + 1071: telephone/phone/telephone set + 1072: telephone booth/phone booth/call box/telephone box/telephone kiosk + 1073: telephone pole/telegraph pole/telegraph post + 1074: telephoto lens/zoom lens + 1075: television camera/tv camera + 1076: television set/tv/tv set + 1077: tennis ball + 1078: tennis racket + 1079: tequila + 1080: thermometer + 1081: thermos bottle + 1082: thermostat + 1083: thimble + 1084: thread/yarn + 1085: thumbtack/drawing pin/pushpin + 1086: tiara + 1087: tiger + 1088: tights/tights clothing/leotards + 1089: timer/stopwatch + 1090: tinfoil + 1091: tinsel + 1092: tissue paper + 1093: toast/toast food + 1094: toaster + 1095: toaster oven + 1096: toilet + 1097: toilet tissue/toilet paper/bathroom tissue + 1098: tomato + 1099: tongs + 1100: toolbox + 1101: toothbrush + 1102: toothpaste + 1103: toothpick + 1104: cover + 1105: tortilla + 1106: tow truck + 1107: towel + 1108: towel rack/towel rail/towel bar + 1109: toy + 1110: tractor/tractor farm equipment + 1111: traffic light + 1112: dirt bike + 1113: trailer truck/tractor trailer/trucking rig/articulated lorry/semi truck + 1114: train/train railroad vehicle/railroad train + 1115: trampoline + 1116: tray + 1117: trench coat + 1118: triangle/triangle musical instrument + 1119: tricycle + 1120: tripod + 1121: trousers/pants/pants clothing + 1122: truck + 1123: truffle/truffle chocolate/chocolate truffle + 1124: trunk + 1125: vat + 1126: turban + 1127: turkey/turkey food + 1128: turnip + 1129: turtle + 1130: turtleneck/turtleneck clothing/polo-neck + 1131: typewriter + 1132: umbrella + 1133: underwear/underclothes/underclothing/underpants + 1134: unicycle + 1135: urinal + 1136: urn + 1137: vacuum cleaner + 1138: vase + 1139: vending machine + 1140: vent/blowhole/air vent + 1141: vest/waistcoat + 1142: videotape + 1143: vinegar + 1144: violin/fiddle + 1145: vodka + 1146: volleyball + 1147: vulture + 1148: waffle + 1149: waffle iron + 1150: wagon + 1151: wagon wheel + 1152: walking stick + 1153: wall clock + 1154: wall socket/wall plug/electric outlet/electrical outlet/outlet/electric receptacle + 1155: wallet/billfold + 1156: walrus + 1157: wardrobe + 1158: washbasin/basin/basin for washing/washbowl/washstand/handbasin + 1159: automatic washer/washing machine + 1160: watch/wristwatch + 1161: water bottle + 1162: water cooler + 1163: water faucet/water tap/tap/tap water faucet + 1164: water heater/hot-water heater + 1165: water jug + 1166: water gun/squirt gun + 1167: water scooter/sea scooter/jet ski + 1168: water ski + 1169: water tower + 1170: watering can + 1171: watermelon + 1172: weathervane/vane/vane weathervane/wind vane + 1173: webcam + 1174: wedding cake/bridecake + 1175: wedding ring/wedding band + 1176: wet suit + 1177: wheel + 1178: wheelchair + 1179: whipped cream + 1180: whistle + 1181: wig + 1182: wind chime + 1183: windmill + 1184: window box/window box for plants + 1185: windshield wiper/windscreen wiper/wiper/wiper for windshield or screen + 1186: windsock/air sock/air-sleeve/wind sleeve/wind cone + 1187: wine bottle + 1188: wine bucket/wine cooler + 1189: wineglass + 1190: blinder/blinder for horses + 1191: wok + 1192: wolf + 1193: wooden spoon + 1194: wreath + 1195: wrench/spanner + 1196: wristband + 1197: wristlet/wrist band + 1198: yacht + 1199: yogurt/yoghurt/yoghourt + 1200: yoke/yoke animal equipment + 1201: zebra + 1202: zucchini/courgette + +# Download script/URL (optional) +download: | + from ultralytics.utils.downloads import download + from pathlib import Path + + # Download labels + dir = Path(yaml['path']) # dataset root dir + url = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/' + urls = [url + 'lvis-labels-segments.zip'] # labels + download(urls, dir=dir.parent) + # Download data + urls = ['http://images.cocodataset.org/zips/train2017.zip', # 19G, 118k images + 'http://images.cocodataset.org/zips/val2017.zip', # 1G, 5k images + 'http://images.cocodataset.org/zips/test2017.zip'] # 7G, 41k images (optional) + download(urls, dir=dir / 'images', threads=3) diff --git a/ultralytics/data/__init__.py b/ultralytics/data/__init__.py index 9f91ce97f1ed..fba2aeb0a38c 100644 --- a/ultralytics/data/__init__.py +++ b/ultralytics/data/__init__.py @@ -1,15 +1,31 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license from .base import BaseDataset -from .build import build_dataloader, build_yolo_dataset, load_inference_source -from .dataset import ClassificationDataset, SemanticDataset, YOLODataset +from .build import ( + build_dataloader, + build_yolo_dataset, + build_grounding, + load_inference_source, +) +from .dataset import ( + ClassificationDataset, + SemanticDataset, + YOLODataset, + YOLOMultiModalDataset, + GroundingDataset, + YOLOConcatDataset, +) __all__ = ( "BaseDataset", "ClassificationDataset", "SemanticDataset", "YOLODataset", + "YOLOMultiModalDataset", + "YOLOConcatDataset", + "GroundingDataset", "build_yolo_dataset", + "build_grounding", "build_dataloader", "load_inference_source", ) diff --git a/ultralytics/data/augment.py b/ultralytics/data/augment.py index 9d7141654ffa..c72fa077e110 100644 --- a/ultralytics/data/augment.py +++ b/ultralytics/data/augment.py @@ -3,6 +3,7 @@ import math import random from copy import deepcopy +from typing import Tuple, Union import cv2 import numpy as np @@ -66,7 +67,7 @@ class Compose: def __init__(self, transforms): """Initializes the Compose object with a list of transforms.""" - self.transforms = transforms + self.transforms = transforms if isinstance(transforms, list) else [transforms] def __call__(self, data): """Applies a series of transformations to input data.""" @@ -78,6 +79,29 @@ def append(self, transform): """Appends a new transform to the existing list of transforms.""" self.transforms.append(transform) + def insert(self, index, transform): + """Inserts a new transform to the existing list of transforms.""" + self.transforms.insert(index, transform) + + def __getitem__(self, index: Union[list, int]) -> "Compose": + """Retrieve a specific transform or a set of transforms using indexing.""" + assert isinstance(index, (int, list)), f"The indices should be either list or int type but got {type(index)}" + index = [index] if isinstance(index, int) else index + return Compose([self.transforms[i] for i in index]) + + def __setitem__(self, index: Union[list, int], value: Union[list, int]) -> None: + """Retrieve a specific transform or a set of transforms using indexing.""" + assert isinstance(index, (int, list)), f"The indices should be either list or int type but got {type(index)}" + if isinstance(index, list): + assert isinstance( + value, list + ), f"The indices should be the same type as values, but got {type(index)} and {type(value)}" + if isinstance(index, int): + index, value = [index], [value] + for i, v in zip(index, value): + assert i < len(self.transforms), f"list index {i} out of range {len(self.transforms)}." + self.transforms[i] = v + def tolist(self): """Converts the list of transforms to a standard Python list.""" return self.transforms @@ -118,6 +142,8 @@ def __call__(self, labels): mix_labels[i] = self.pre_transform(data) labels["mix_labels"] = mix_labels + # Update cls and texts + labels = self._update_label_text(labels) # Mosaic or MixUp labels = self._mix_transform(labels) labels.pop("mix_labels", None) @@ -131,6 +157,22 @@ def get_indexes(self): """Gets a list of shuffled indexes for mosaic augmentation.""" raise NotImplementedError + def _update_label_text(self, labels): + """Update label text.""" + if "texts" not in labels: + return labels + + mix_texts = sum([labels["texts"]] + [x["texts"] for x in labels["mix_labels"]], []) + mix_texts = list({tuple(x) for x in mix_texts}) + text2id = {text: i for i, text in enumerate(mix_texts)} + + for label in [labels] + labels["mix_labels"]: + for i, l in enumerate(label["cls"].squeeze(-1).tolist()): + text = label["texts"][int(l)] + label["cls"][i] = text2id[tuple(text)] + label["texts"] = mix_texts + return labels + class Mosaic(BaseMixTransform): """ @@ -320,6 +362,8 @@ def _cat_labels(self, mosaic_labels): final_labels["instances"].clip(imgsz, imgsz) good = final_labels["instances"].remove_zero_area_boxes() final_labels["cls"] = final_labels["cls"][good] + if "texts" in mosaic_labels[0]: + final_labels["texts"] = mosaic_labels[0]["texts"] return final_labels @@ -970,6 +1014,83 @@ def _format_segments(self, instances, cls, w, h): return masks, instances, cls +class RandomLoadText: + """ + Randomly sample positive texts and negative texts and update the class indices accordingly to the number of samples. + + Attributes: + prompt_format (str): Format for prompt. Default is '{}'. + neg_samples (tuple[int]): A ranger to randomly sample negative texts, Default is (80, 80). + max_samples (int): The max number of different text samples in one image, Default is 80. + padding (bool): Whether to pad texts to max_samples. Default is False. + padding_value (str): The padding text. Default is "". + """ + + def __init__( + self, + prompt_format: str = "{}", + neg_samples: Tuple[int, int] = (80, 80), + max_samples: int = 80, + padding: bool = False, + padding_value: str = "", + ) -> None: + """Initializes the RandomLoadText class with given parameters.""" + self.prompt_format = prompt_format + self.neg_samples = neg_samples + self.max_samples = max_samples + self.padding = padding + self.padding_value = padding_value + + def __call__(self, labels: dict) -> dict: + """Return updated classes and texts.""" + assert "texts" in labels, "No texts found in labels." + class_texts = labels["texts"] + num_classes = len(class_texts) + cls = np.asarray(labels.pop("cls"), dtype=int) + pos_labels = np.unique(cls).tolist() + + if len(pos_labels) > self.max_samples: + pos_labels = set(random.sample(pos_labels, k=self.max_samples)) + + neg_samples = min(min(num_classes, self.max_samples) - len(pos_labels), random.randint(*self.neg_samples)) + neg_labels = [] + for i in range(num_classes): + if i not in pos_labels: + neg_labels.append(i) + neg_labels = random.sample(neg_labels, k=neg_samples) + + sampled_labels = pos_labels + neg_labels + random.shuffle(sampled_labels) + + label2ids = {label: i for i, label in enumerate(sampled_labels)} + valid_idx = np.zeros(len(labels["instances"]), dtype=bool) + new_cls = [] + for i, label in enumerate(cls.squeeze(-1).tolist()): + if label not in label2ids: + continue + valid_idx[i] = True + new_cls.append([label2ids[label]]) + labels["instances"] = labels["instances"][valid_idx] + labels["cls"] = np.array(new_cls) + + # Randomly select one prompt when there's more than one prompts + texts = [] + for label in sampled_labels: + prompts = class_texts[label] + assert len(prompts) > 0 + prompt = self.prompt_format.format(prompts[random.randrange(len(prompts))]) + texts.append(prompt) + + if self.padding: + valid_labels = len(pos_labels) + len(neg_labels) + num_padding = self.max_samples - valid_labels + if num_padding > 0: + texts += [self.padding_value] * num_padding + + labels["texts"] = texts + return labels + + def v8_transforms(dataset, imgsz, hyp, stretch=False): """Convert images to a size suitable for YOLOv8 training.""" pre_transform = Compose( diff --git a/ultralytics/data/build.py b/ultralytics/data/build.py index 6bfb48f33908..768f46a8a7cb 100644 --- a/ultralytics/data/build.py +++ b/ultralytics/data/build.py @@ -22,7 +22,7 @@ from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS from ultralytics.utils import RANK, colorstr from ultralytics.utils.checks import check_file -from .dataset import YOLODataset +from .dataset import YOLODataset, YOLOMultiModalDataset, GroundingDataset from .utils import PIN_MEMORY @@ -82,9 +82,10 @@ def seed_worker(worker_id): # noqa random.seed(worker_seed) -def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32): +def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32, multi_modal=False): """Build YOLO Dataset.""" - return YOLODataset( + dataset = YOLOMultiModalDataset if multi_modal else YOLODataset + return dataset( img_path=img_path, imgsz=cfg.imgsz, batch_size=batch, @@ -103,6 +104,27 @@ def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, str ) +def build_grounding(cfg, img_path, json_file, batch, mode="train", rect=False, stride=32): + """Build YOLO Dataset.""" + return GroundingDataset( + img_path=img_path, + json_file=json_file, + imgsz=cfg.imgsz, + batch_size=batch, + augment=mode == "train", # augmentation + hyp=cfg, # TODO: probably add a get_hyps_from_cfg function + rect=cfg.rect or rect, # rectangular batches + cache=cfg.cache or None, + single_cls=cfg.single_cls or False, + stride=int(stride), + pad=0.0 if mode == "train" else 0.5, + prefix=colorstr(f"{mode}: "), + task=cfg.task, + classes=cfg.classes, + fraction=cfg.fraction if mode == "train" else 1.0, + ) + + def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1): """Return an InfiniteDataLoader or DataLoader for training or validation set.""" batch = min(batch, len(dataset)) diff --git a/ultralytics/data/converter.py b/ultralytics/data/converter.py index eff4dac162cc..62370f814398 100644 --- a/ultralytics/data/converter.py +++ b/ultralytics/data/converter.py @@ -219,6 +219,7 @@ def convert_coco( use_segments=False, use_keypoints=False, cls91to80=True, + lvis=False, ): """ Converts COCO dataset annotations to a YOLO annotation format suitable for training YOLO models. @@ -229,12 +230,14 @@ def convert_coco( use_segments (bool, optional): Whether to include segmentation masks in the output. use_keypoints (bool, optional): Whether to include keypoint annotations in the output. cls91to80 (bool, optional): Whether to map 91 COCO class IDs to the corresponding 80 COCO class IDs. + lvis (bool, optional): Whether to convert data in lvis dataset way. Example: ```python from ultralytics.data.converter import convert_coco convert_coco('../datasets/coco/annotations/', use_segments=True, use_keypoints=False, cls91to80=True) + convert_coco('../datasets/lvis/annotations/', use_segments=True, use_keypoints=False, cls91to80=False, lvis=True) ``` Output: @@ -251,8 +254,14 @@ def convert_coco( # Import json for json_file in sorted(Path(labels_dir).resolve().glob("*.json")): - fn = Path(save_dir) / "labels" / json_file.stem.replace("instances_", "") # folder name + lname = "" if lvis else json_file.stem.replace("instances_", "") + fn = Path(save_dir) / "labels" / lname # folder name fn.mkdir(parents=True, exist_ok=True) + if lvis: + # NOTE: create folders for both train and val in advance, + # since LVIS val set contains images from COCO 2017 train in addition to the COCO 2017 val split. + (fn / "train2017").mkdir(parents=True, exist_ok=True) + (fn / "val2017").mkdir(parents=True, exist_ok=True) with open(json_file) as f: data = json.load(f) @@ -263,16 +272,20 @@ def convert_coco( for ann in data["annotations"]: imgToAnns[ann["image_id"]].append(ann) + image_txt = [] # Write labels file for img_id, anns in TQDM(imgToAnns.items(), desc=f"Annotations {json_file}"): img = images[f"{img_id:d}"] - h, w, f = img["height"], img["width"], img["file_name"] + h, w = img["height"], img["width"] + f = str(Path(img["coco_url"]).relative_to("http://images.cocodataset.org")) if lvis else img["file_name"] + if lvis: + image_txt.append(str(Path("./images") / f)) bboxes = [] segments = [] keypoints = [] for ann in anns: - if ann["iscrowd"]: + if ann.get("iscrowd", False): continue # The COCO box format is [top left x, top left y, width, height] box = np.array(ann["bbox"], dtype=np.float64) @@ -314,7 +327,12 @@ def convert_coco( ) # cls, box or segments file.write(("%g " * len(line)).rstrip() % line + "\n") - LOGGER.info(f"COCO data converted successfully.\nResults saved to {save_dir.resolve()}") + if lvis: + with open((Path(save_dir) / json_file.name.replace("lvis_v1_", "").replace(".json", ".txt")), "a") as f: + for l in image_txt: + f.write(f"{l}\n") + + LOGGER.info(f"{'LVIS' if lvis else 'COCO'} data converted successfully.\nResults saved to {save_dir.resolve()}") def convert_dota_to_yolo_obb(dota_root_path: str): diff --git a/ultralytics/data/dataset.py b/ultralytics/data/dataset.py index 76379104f7a8..7acf7689a2e3 100644 --- a/ultralytics/data/dataset.py +++ b/ultralytics/data/dataset.py @@ -1,20 +1,41 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license import contextlib from itertools import repeat +from collections import defaultdict from multiprocessing.pool import ThreadPool from pathlib import Path import cv2 +import json import numpy as np import torch import torchvision from PIL import Image -from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr, is_dir_writeable +from torch.utils.data import ConcatDataset +from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr from ultralytics.utils.ops import resample_segments -from .augment import Compose, Format, Instances, LetterBox, classify_augmentations, classify_transforms, v8_transforms +from .augment import ( + Compose, + Format, + Instances, + LetterBox, + RandomLoadText, + classify_augmentations, + classify_transforms, + v8_transforms, +) from .base import BaseDataset -from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label +from .utils import ( + HELP_URL, + LOGGER, + get_hash, + img2label_paths, + verify_image, + verify_image_label, + load_dataset_cache_file, + save_dataset_cache_file, +) # Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8 DATASET_CACHE_VERSION = "1.0.3" @@ -105,7 +126,7 @@ def cache_labels(self, path=Path("./labels.cache")): x["hash"] = get_hash(self.label_files + self.im_files) x["results"] = nf, nm, ne, nc, len(self.im_files) x["msgs"] = msgs # warnings - save_dataset_cache_file(self.prefix, path, x) + save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION) return x def get_labels(self): @@ -339,31 +360,125 @@ def verify_images(self): x["hash"] = get_hash([x[0] for x in self.samples]) x["results"] = nf, nc, len(samples), samples x["msgs"] = msgs # warnings - save_dataset_cache_file(self.prefix, path, x) + save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION) return samples -def load_dataset_cache_file(path): - """Load an Ultralytics *.cache dictionary from path.""" - import gc +class YOLOMultiModalDataset(YOLODataset): + """ + Dataset class for loading object detection and/or segmentation labels in YOLO format. + + Args: + data (dict, optional): A dataset YAML dictionary. Defaults to None. + task (str): An explicit arg to point current task, Defaults to 'detect'. + + Returns: + (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model. + """ + + def __init__(self, *args, data=None, task="detect", **kwargs): + """Initializes a dataset object for object detection tasks with optional specifications.""" + super().__init__(*args, data=data, task=task, **kwargs) + + def update_labels_info(self, label): + """Add texts information for multi modal model training.""" + labels = super().update_labels_info(label) + # NOTE: some categories are concatenated with its synonyms by `/`. + labels["texts"] = [v.split("/") for _, v in self.data["names"].items()] + return labels + + def build_transforms(self, hyp=None): + """Enhances data transformations with optional text augmentation for multi-modal training.""" + transforms = super().build_transforms(hyp) + if self.augment: + # NOTE: hard-coded the args for now. + transforms.insert(-1, RandomLoadText(max_samples=min(self.data["nc"], 80), padding=True)) + return transforms + + +class GroundingDataset(YOLODataset): + def __init__(self, *args, task="detect", json_file, **kwargs): + """Initializes a GroundingDataset for object detection, loading annotations from a specified JSON file.""" + assert task == "detect", "`GroundingDataset` only support `detect` task for now!" + self.json_file = json_file + super().__init__(*args, task=task, data={}, **kwargs) + + def get_img_files(self, img_path): + """The image files would be read in `get_labels` function, return empty list here.""" + return [] + + def get_labels(self): + """Loads annotations from a JSON file, filters, and normalizes bounding boxes for each image.""" + labels = [] + LOGGER.info("Loading annotation file...") + with open(self.json_file, "r") as f: + annotations = json.load(f) + images = {f'{x["id"]:d}': x for x in annotations["images"]} + imgToAnns = defaultdict(list) + for ann in annotations["annotations"]: + imgToAnns[ann["image_id"]].append(ann) + for img_id, anns in TQDM(imgToAnns.items(), desc=f"Reading annotations {self.json_file}"): + img = images[f"{img_id:d}"] + h, w, f = img["height"], img["width"], img["file_name"] + im_file = Path(self.img_path) / f + if not im_file.exists(): + continue + self.im_files.append(str(im_file)) + bboxes = [] + cat2id = {} + texts = [] + for ann in anns: + if ann["iscrowd"]: + continue + box = np.array(ann["bbox"], dtype=np.float32) + box[:2] += box[2:] / 2 + box[[0, 2]] /= float(w) + box[[1, 3]] /= float(h) + if box[2] <= 0 or box[3] <= 0: + continue + + cat_name = " ".join([img["caption"][t[0] : t[1]] for t in ann["tokens_positive"]]) + if cat_name not in cat2id: + cat2id[cat_name] = len(cat2id) + texts.append([cat_name]) + cls = cat2id[cat_name] # class + box = [cls] + box.tolist() + if box not in bboxes: + bboxes.append(box) + lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32) + labels.append( + dict( + im_file=im_file, + shape=(h, w), + cls=lb[:, 0:1], # n, 1 + bboxes=lb[:, 1:], # n, 4 + normalized=True, + bbox_format="xywh", + texts=texts, + ) + ) + return labels + + def build_transforms(self, hyp=None): + """Configures augmentations for training with optional text loading; `hyp` adjusts augmentation intensity.""" + transforms = super().build_transforms(hyp) + if self.augment: + # NOTE: hard-coded the args for now. + transforms.insert(-1, RandomLoadText(max_samples=80, padding=True)) + return transforms - gc.disable() # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585 - cache = np.load(str(path), allow_pickle=True).item() # load dict - gc.enable() - return cache +class YOLOConcatDataset(ConcatDataset): + """ + Dataset as a concatenation of multiple datasets. -def save_dataset_cache_file(prefix, path, x): - """Save an Ultralytics dataset *.cache dictionary x to path.""" - x["version"] = DATASET_CACHE_VERSION # add cache version - if is_dir_writeable(path.parent): - if path.exists(): - path.unlink() # remove *.cache file if exists - np.save(str(path), x) # save cache for next time - path.with_suffix(".cache.npy").rename(path) # remove .npy suffix - LOGGER.info(f"{prefix}New cache created: {path}") - else: - LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.") + This class is useful to assemble different existing datasets. + """ + + @staticmethod + def collate_fn(batch): + """Collates data samples into batches.""" + return YOLODataset.collate_fn(batch) # TODO: support semantic segmentation diff --git a/ultralytics/data/utils.py b/ultralytics/data/utils.py index c0a07736830b..1ad926c9d50b 100644 --- a/ultralytics/data/utils.py +++ b/ultralytics/data/utils.py @@ -29,6 +29,7 @@ emojis, yaml_load, yaml_save, + is_dir_writeable, ) from ultralytics.utils.checks import check_file, check_font, is_ascii from ultralytics.utils.downloads import download, safe_download, unzip_file @@ -303,7 +304,7 @@ def check_det_dataset(dataset, autodownload=True): # Set paths data["path"] = path # download scripts - for k in "train", "val", "test": + for k in "train", "val", "test", "minival": if data.get(k): # prepend path if isinstance(data[k], str): x = (path / data[k]).resolve() @@ -649,3 +650,26 @@ def autosplit(path=DATASETS_DIR / "coco8/images", weights=(0.9, 0.1, 0.0), annot if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label with open(path.parent / txt[i], "a") as f: f.write(f"./{img.relative_to(path.parent).as_posix()}" + "\n") # add image to txt file + + +def load_dataset_cache_file(path): + """Load an Ultralytics *.cache dictionary from path.""" + import gc + + gc.disable() # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585 + cache = np.load(str(path), allow_pickle=True).item() # load dict + gc.enable() + return cache + + +def save_dataset_cache_file(prefix, path, x, version): + """Save an Ultralytics dataset *.cache dictionary x to path.""" + x["version"] = version # add cache version + if is_dir_writeable(path.parent): + if path.exists(): + path.unlink() # remove *.cache file if exists + np.save(str(path), x) # save cache for next time + path.with_suffix(".cache.npy").rename(path) # remove .npy suffix + LOGGER.info(f"{prefix}New cache created: {path}") + else: + LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.") diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py index 8b5e47ccfcdb..b8d917ce8ced 100644 --- a/ultralytics/engine/trainer.py +++ b/ultralytics/engine/trainer.py @@ -126,22 +126,7 @@ def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): # Model and Dataset self.model = check_model_file_from_stem(self.args.model) # add suffix, i.e. yolov8n -> yolov8n.pt - try: - if self.args.task == "classify": - self.data = check_cls_dataset(self.args.data) - elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in ( - "detect", - "segment", - "pose", - "obb", - ): - self.data = check_det_dataset(self.args.data) - if "yaml_file" in self.data: - self.args.data = self.data["yaml_file"] # for validating 'yolo train data=url.zip' usage - except Exception as e: - raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e - - self.trainset, self.testset = self.get_dataset(self.data) + self.trainset, self.testset = self.get_dataset() self.ema = None # Optimization utils init @@ -509,13 +494,27 @@ def save_model(self): if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0): (self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt) # save epoch, i.e. 'epoch3.pt' - @staticmethod - def get_dataset(data): + def get_dataset(self): """ Get train, val path from data dict if it exists. Returns None if data format is not recognized. """ + try: + if self.args.task == "classify": + data = check_cls_dataset(self.args.data) + elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in ( + "detect", + "segment", + "pose", + "obb", + ): + data = check_det_dataset(self.args.data) + if "yaml_file" in data: + self.args.data = data["yaml_file"] # for validating 'yolo train data=url.zip' usage + except Exception as e: + raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e + self.data = data return data["train"], data.get("val") or data.get("test") def setup_model(self): @@ -666,8 +665,8 @@ def resume_training(self, ckpt): if ckpt is None: return best_fitness = 0.0 - start_epoch = ckpt["epoch"] + 1 - if ckpt["optimizer"] is not None: + start_epoch = ckpt.get("epoch", -1) + 1 + if ckpt.get("optimizer", None) is not None: self.optimizer.load_state_dict(ckpt["optimizer"]) # optimizer best_fitness = ckpt["best_fitness"] if self.ema and ckpt.get("ema"): diff --git a/ultralytics/models/fastsam/prompt.py b/ultralytics/models/fastsam/prompt.py index f7bf5add8908..544938a5b42a 100644 --- a/ultralytics/models/fastsam/prompt.py +++ b/ultralytics/models/fastsam/prompt.py @@ -35,7 +35,7 @@ def __init__(self, source, results, device="cuda") -> None: except ImportError: from ultralytics.utils.checks import check_requirements - check_requirements("git+https://github.com/openai/CLIP.git") + check_requirements("git+https://github.com/ultralytics/CLIP.git") import clip self.clip = clip diff --git a/ultralytics/models/yolo/__init__.py b/ultralytics/models/yolo/__init__.py index 7b1a59770652..8d9aedfecb83 100644 --- a/ultralytics/models/yolo/__init__.py +++ b/ultralytics/models/yolo/__init__.py @@ -1,7 +1,7 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -from ultralytics.models.yolo import classify, detect, obb, pose, segment +from ultralytics.models.yolo import classify, detect, obb, pose, segment, world from .model import YOLO, YOLOWorld -__all__ = "classify", "segment", "detect", "pose", "obb", "YOLO", "YOLOWorld" +__all__ = "classify", "segment", "detect", "pose", "obb", "world", "YOLO", "YOLOWorld" diff --git a/ultralytics/models/yolo/detect/val.py b/ultralytics/models/yolo/detect/val.py index 8226cd694cc9..58fcac4d4cde 100644 --- a/ultralytics/models/yolo/detect/val.py +++ b/ultralytics/models/yolo/detect/val.py @@ -33,6 +33,7 @@ def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callba super().__init__(dataloader, save_dir, pbar, args, _callbacks) self.nt_per_class = None self.is_coco = False + self.is_lvis = False self.class_map = None self.args.task = "detect" self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot) @@ -66,8 +67,9 @@ def init_metrics(self, model): """Initialize evaluation metrics for YOLO.""" val = self.data.get(self.args.split, "") # validation path self.is_coco = isinstance(val, str) and "coco" in val and val.endswith(f"{os.sep}val2017.txt") # is COCO - self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1000)) - self.args.save_json |= self.is_coco and not self.training # run on final val if training COCO + self.is_lvis = isinstance(val, str) and "lvis" in val and not self.is_coco # is LVIS + self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(len(model.names))) + self.args.save_json |= (self.is_coco or self.is_lvis) and not self.training # run on final val if training COCO self.names = model.names self.nc = len(model.names) self.metrics.names = self.names @@ -266,7 +268,8 @@ def pred_to_json(self, predn, filename): self.jdict.append( { "image_id": image_id, - "category_id": self.class_map[int(p[5])], + "category_id": self.class_map[int(p[5])] + + (1 if self.is_lvis else 0), # index starts from 1 if it's lvis "bbox": [round(x, 3) for x in b], "score": round(p[4], 5), } @@ -274,26 +277,42 @@ def pred_to_json(self, predn, filename): def eval_json(self, stats): """Evaluates YOLO output in JSON format and returns performance statistics.""" - if self.args.save_json and self.is_coco and len(self.jdict): - anno_json = self.data["path"] / "annotations/instances_val2017.json" # annotations + if self.args.save_json and (self.is_coco or self.is_lvis) and len(self.jdict): pred_json = self.save_dir / "predictions.json" # predictions - LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...") + anno_json = ( + self.data["path"] + / "annotations" + / ("instances_val2017.json" if self.is_coco else f"lvis_v1_{self.args.split}.json") + ) # annotations + pkg = "pycocotools" if self.is_coco else "lvis" + LOGGER.info(f"\nEvaluating {pkg} mAP using {pred_json} and {anno_json}...") try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb - check_requirements("pycocotools>=2.0.6") - from pycocotools.coco import COCO # noqa - from pycocotools.cocoeval import COCOeval # noqa - - for x in anno_json, pred_json: + for x in pred_json, anno_json: assert x.is_file(), f"{x} file not found" - anno = COCO(str(anno_json)) # init annotations api - pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path) - eval = COCOeval(anno, pred, "bbox") + check_requirements("pycocotools>=2.0.6" if self.is_coco else "lvis>=0.5.3") if self.is_coco: - eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval + from pycocotools.coco import COCO # noqa + from pycocotools.cocoeval import COCOeval # noqa + + anno = COCO(str(anno_json)) # init annotations api + pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path) + eval = COCOeval(anno, pred, "bbox") + else: + from lvis import LVIS, LVISEval + + anno = LVIS(str(anno_json)) # init annotations api + pred = anno._load_json(str(pred_json)) # init predictions api (must pass string, not Path) + eval = LVISEval(anno, pred, "bbox") + eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval eval.evaluate() eval.accumulate() eval.summarize() - stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[:2] # update mAP50-95 and mAP50 + if self.is_lvis: + eval.print_results() # explicitly call print_results + # update mAP50-95 and mAP50 + stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = ( + eval.stats[:2] if self.is_coco else [eval.results["AP50"], eval.results["AP"]] + ) except Exception as e: - LOGGER.warning(f"pycocotools unable to run: {e}") + LOGGER.warning(f"{pkg} unable to run: {e}") return stats diff --git a/ultralytics/models/yolo/model.py b/ultralytics/models/yolo/model.py index 1bd38d3be22b..18accac49997 100644 --- a/ultralytics/models/yolo/model.py +++ b/ultralytics/models/yolo/model.py @@ -83,6 +83,7 @@ def task_map(self): "model": WorldModel, "validator": yolo.detect.DetectionValidator, "predictor": yolo.detect.DetectionPredictor, + "trainer": yolo.world.WorldTrainer, } } diff --git a/ultralytics/models/yolo/world/__init__.py b/ultralytics/models/yolo/world/__init__.py new file mode 100644 index 000000000000..1d401999cdf6 --- /dev/null +++ b/ultralytics/models/yolo/world/__init__.py @@ -0,0 +1,5 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +from .train import WorldTrainer + +__all__ = ["WorldTrainer"] diff --git a/ultralytics/models/yolo/world/train.py b/ultralytics/models/yolo/world/train.py new file mode 100644 index 000000000000..38cd4cf608c6 --- /dev/null +++ b/ultralytics/models/yolo/world/train.py @@ -0,0 +1,91 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +from ultralytics.models import yolo +from ultralytics.nn.tasks import WorldModel +from ultralytics.utils import DEFAULT_CFG, RANK +from ultralytics.data import build_yolo_dataset +from ultralytics.utils.torch_utils import de_parallel +from ultralytics.utils.checks import check_requirements +import itertools + +try: + import clip +except ImportError: + check_requirements("git+https://github.com/ultralytics/CLIP.git") + import clip + + +def on_pretrain_routine_end(trainer): + """Callback.""" + if RANK in (-1, 0): + # NOTE: for evaluation + names = [name.split("/")[0] for name in list(trainer.test_loader.dataset.data["names"].values())] + de_parallel(trainer.ema.ema).set_classes(names, cache_clip_model=False) + device = next(trainer.model.parameters()).device + text_model, _ = clip.load("ViT-B/32", device=device) + for p in text_model.parameters(): + p.requires_grad_(False) + trainer.text_model = text_model + + +class WorldTrainer(yolo.detect.DetectionTrainer): + """ + A class to fine-tune a world model on a close-set dataset. + + Example: + ```python + from ultralytics.models.yolo.world import WorldModel + + args = dict(model='yolov8s-world.pt', data='coco8.yaml', epochs=3) + trainer = WorldTrainer(overrides=args) + trainer.train() + ``` + """ + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """Initialize a WorldTrainer object with given arguments.""" + if overrides is None: + overrides = {} + super().__init__(cfg, overrides, _callbacks) + + def get_model(self, cfg=None, weights=None, verbose=True): + """Return WorldModel initialized with specified config and weights.""" + # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`. + # NOTE: Following the official config, nc hard-coded to 80 for now. + model = WorldModel( + cfg["yaml_file"] if isinstance(cfg, dict) else cfg, + ch=3, + nc=min(self.data["nc"], 80), + verbose=verbose and RANK == -1, + ) + if weights: + model.load(weights) + self.add_callback("on_pretrain_routine_end", on_pretrain_routine_end) + + return model + + def build_dataset(self, img_path, mode="train", batch=None): + """ + Build YOLO Dataset. + + Args: + img_path (str): Path to the folder containing images. + mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode. + batch (int, optional): Size of batches, this is for `rect`. Defaults to None. + """ + gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32) + return build_yolo_dataset( + self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train" + ) + + def preprocess_batch(self, batch): + """Preprocesses a batch of images for YOLOWorld training, adjusting formatting and dimensions as needed.""" + batch = super().preprocess_batch(batch) + + # NOTE: add text features + texts = list(itertools.chain(*batch["texts"])) + text_token = clip.tokenize(texts).to(batch["img"].device) + txt_feats = self.text_model.encode_text(text_token).to(dtype=batch["img"].dtype) # torch.float32 + txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True) + batch["txt_feats"] = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1]) + return batch diff --git a/ultralytics/models/yolo/world/train_world.py b/ultralytics/models/yolo/world/train_world.py new file mode 100644 index 000000000000..d844e3c8c678 --- /dev/null +++ b/ultralytics/models/yolo/world/train_world.py @@ -0,0 +1,108 @@ +from ultralytics.data import build_yolo_dataset, build_grounding, YOLOConcatDataset +from ultralytics.data.utils import check_det_dataset +from ultralytics.models.yolo.world import WorldTrainer +from ultralytics.utils.torch_utils import de_parallel +from ultralytics.utils import DEFAULT_CFG + + +class WorldTrainerFromScratch(WorldTrainer): + """ + A class extending the WorldTrainer class for training a world model from scratch on open-set dataset. + + Example: + ```python + from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch + from ultralytics import YOLOWorld + + data = dict( + train=dict( + yolo_data=["Objects365.yaml"], + grounding_data=[ + dict( + img_path="../datasets/flickr30k/images", + json_file="../datasets/flickr30k/final_flickr_separateGT_train.json", + ), + dict( + img_path="../datasets/GQA/images", + json_file="../datasets/GQA/final_mixed_train_no_coco.json", + ), + ], + ), + val=dict(yolo_data=["lvis.yaml"]), + ) + + model = YOLOWorld("yolov8s-worldv2.yaml") + model.train(data=data, trainer=WorldTrainerFromScratch) + ``` + """ + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """Initialize a WorldTrainer object with given arguments.""" + if overrides is None: + overrides = {} + super().__init__(cfg, overrides, _callbacks) + + def build_dataset(self, img_path, mode="train", batch=None): + """ + Build YOLO Dataset. + + Args: + img_path (List[str] | str): Path to the folder containing images. + mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode. + batch (int, optional): Size of batches, this is for `rect`. Defaults to None. + """ + gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32) + if mode == "train": + dataset = [ + build_yolo_dataset(self.args, im_path, batch, self.data, stride=gs, multi_modal=True) + if isinstance(im_path, str) + else build_grounding(self.args, im_path["img_path"], im_path["json_file"], batch, stride=gs) + for im_path in img_path + ] + return YOLOConcatDataset(dataset) if len(dataset) > 1 else dataset[0] + else: + return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs) + + def get_dataset(self): + """ + Get train, val path from data dict if it exists. + + Returns None if data format is not recognized. + """ + final_data = dict() + data_yaml = self.args.data + assert data_yaml.get("train", False) # object365.yaml + assert data_yaml.get("val", False) # lvis.yaml + data = {k: [check_det_dataset(d) for d in v.get("yolo_data", [])] for k, v in data_yaml.items()} + assert len(data["val"]) == 1, f"Only support validating on 1 dataset for now, but got {len(data['val'])}." + val_split = "minival" if "lvis" in data["val"][0]["val"] else "val" + for d in data["val"]: + if d.get("minival") is None: # for lvis dataset + continue + d["minival"] = str(d["path"] / d["minival"]) + for s in ["train", "val"]: + final_data[s] = [d["train" if s == "train" else val_split] for d in data[s]] + # save grounding data if there's one + grounding_data = data_yaml[s].get("grounding_data") + if grounding_data is None: + continue + grounding_data = [grounding_data] if not isinstance(grounding_data, list) else grounding_data + for g in grounding_data: + assert isinstance(g, dict), f"Grounding data should be provided in dict format, but got {type(g)}" + final_data[s] += grounding_data + # NOTE: to make training work properly, set `nc` and `names` + final_data["nc"] = data["val"][0]["nc"] + final_data["names"] = data["val"][0]["names"] + self.data = final_data + return final_data["train"], final_data["val"][0] + + def plot_training_labels(self): + """DO NOT plot labels.""" + pass + + def final_eval(self): + """Performs final evaluation and validation for object detection YOLO-World model.""" + val = self.args.data["val"]["yolo_data"][0] + self.validator.args.data = val + self.validator.args.split = "minival" if isinstance(val, str) and "lvis" in val else "val" + return super().final_eval() diff --git a/ultralytics/nn/modules/block.py b/ultralytics/nn/modules/block.py index ca991f6e0905..a19e83b7f81e 100644 --- a/ultralytics/nn/modules/block.py +++ b/ultralytics/nn/modules/block.py @@ -519,7 +519,8 @@ class ContrastiveHead(nn.Module): def __init__(self): """Initializes ContrastiveHead with specified region-text similarity parameters.""" super().__init__() - self.bias = nn.Parameter(torch.zeros([])) + # NOTE: use -10.0 to keep the init cls loss consistency with other losses + self.bias = nn.Parameter(torch.tensor([-10.0])) self.logit_scale = nn.Parameter(torch.ones([]) * torch.tensor(1 / 0.07).log()) def forward(self, x, w): @@ -542,7 +543,8 @@ def __init__(self, embed_dims: int): """Initialize ContrastiveHead with region-text similarity parameters.""" super().__init__() self.norm = nn.BatchNorm2d(embed_dims) - self.bias = nn.Parameter(torch.zeros([])) + # NOTE: use -10.0 to keep the init cls loss consistency with other losses + self.bias = nn.Parameter(torch.tensor([-10.0])) # use -1.0 is more stable self.logit_scale = nn.Parameter(-1.0 * torch.ones([])) diff --git a/ultralytics/nn/modules/head.py b/ultralytics/nn/modules/head.py index 9cd794e4d6d3..13b4c7f44acf 100644 --- a/ultralytics/nn/modules/head.py +++ b/ultralytics/nn/modules/head.py @@ -250,6 +250,15 @@ def forward(self, x, text): y = torch.cat((dbox, cls.sigmoid()), 1) return y if self.export else (y, x) + def bias_init(self): + """Initialize Detect() biases, WARNING: requires stride availability.""" + m = self # self.model[-1] # Detect() module + # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1 + # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency + for a, b, s in zip(m.cv2, m.cv3, m.stride): # from + a[-1].bias.data[:] = 1.0 # box + # b[-1].bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img) + class RTDETRDecoder(nn.Module): """ diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py index f116ed2cf57a..9b746d7a6e21 100644 --- a/ultralytics/nn/tasks.py +++ b/ultralytics/nn/tasks.py @@ -564,28 +564,28 @@ def __init__(self, cfg="yolov8s-world.yaml", ch=3, nc=None, verbose=True): self.clip_model = None # CLIP model placeholder super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose) - def set_classes(self, text): - """Perform a forward pass with optional profiling, visualization, and embedding extraction.""" + def set_classes(self, text, batch=80, cache_clip_model=True): + """Set classes in advance so that model could do offline-inference without clip model.""" try: import clip except ImportError: - check_requirements("git+https://github.com/openai/CLIP.git") + check_requirements("git+https://github.com/ultralytics/CLIP.git") import clip - if not getattr(self, "clip_model", None): # for backwards compatibility of models lacking clip_model attribute + if ( + not getattr(self, "clip_model", None) and cache_clip_model + ): # for backwards compatibility of models lacking clip_model attribute self.clip_model = clip.load("ViT-B/32")[0] - device = next(self.clip_model.parameters()).device + model = self.clip_model if cache_clip_model else clip.load("ViT-B/32")[0] + device = next(model.parameters()).device text_token = clip.tokenize(text).to(device) - txt_feats = self.clip_model.encode_text(text_token).to(dtype=torch.float32) + txt_feats = [model.encode_text(token).detach() for token in text_token.split(batch)] + txt_feats = txt_feats[0] if len(txt_feats) == 1 else torch.cat(txt_feats, dim=0) txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True) - self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1]).detach() + self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1]) self.model[-1].nc = len(text) - def init_criterion(self): - """Initialize the loss criterion for the model.""" - raise NotImplementedError - - def predict(self, x, profile=False, visualize=False, augment=False, embed=None): + def predict(self, x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None): """ Perform a forward pass through the model. @@ -593,13 +593,14 @@ def predict(self, x, profile=False, visualize=False, augment=False, embed=None): x (torch.Tensor): The input tensor. profile (bool, optional): If True, profile the computation time for each layer. Defaults to False. visualize (bool, optional): If True, save feature maps for visualization. Defaults to False. + txt_feats (torch.Tensor): The text features, use it if it's given. Defaults to None. augment (bool, optional): If True, perform data augmentation during inference. Defaults to False. embed (list, optional): A list of feature vectors/embeddings to return. Returns: (torch.Tensor): Model's output tensor. """ - txt_feats = self.txt_feats.to(device=x.device, dtype=x.dtype) + txt_feats = (self.txt_feats if txt_feats is None else txt_feats).to(device=x.device, dtype=x.dtype) if len(txt_feats) != len(x): txt_feats = txt_feats.repeat(len(x), 1, 1) ori_txt_feats = txt_feats.clone() @@ -627,6 +628,21 @@ def predict(self, x, profile=False, visualize=False, augment=False, embed=None): return torch.unbind(torch.cat(embeddings, 1), dim=0) return x + def loss(self, batch, preds=None): + """ + Compute loss. + + Args: + batch (dict): Batch to compute loss on. + preds (torch.Tensor | List[torch.Tensor]): Predictions. + """ + if not hasattr(self, "criterion"): + self.criterion = self.init_criterion() + + if preds is None: + preds = self.forward(batch["img"], txt_feats=batch["txt_feats"]) + return self.criterion(preds, batch) + class Ensemble(nn.ModuleList): """Ensemble of models.""" diff --git a/ultralytics/utils/loss.py b/ultralytics/utils/loss.py index 360a292ab9d0..7484858946bc 100644 --- a/ultralytics/utils/loss.py +++ b/ultralytics/utils/loss.py @@ -157,7 +157,7 @@ def __init__(self, model): # model must be de-paralleled self.hyp = h self.stride = m.stride # model strides self.nc = m.nc # number of classes - self.no = m.no + self.no = m.nc + m.reg_max * 4 self.reg_max = m.reg_max self.device = device