diff --git a/README.md b/README.md index 1d6b4e153d82..7763d174f92b 100644 --- a/README.md +++ b/README.md @@ -160,46 +160,31 @@ python train.py --data coco.yaml --cfg yolov5n.yaml --weights '' --batch-size 12 -##
Environments
- -Get started in seconds with our verified environments. Click each icon below for details. - -
- - - - - - - - - - - - - - -
##
Integrations
+ +
+ + + - + - + - +
-|Deci ⭐ NEW|ClearML ⭐ NEW|Roboflow|Weights & Biases -|:-:|:-:|:-:|:-:| -|Automatically compile and quantize YOLOv5 for better inference performance in one click at [Deci](https://bit.ly/yolov5-deci-platform)|Automatically track, visualize and even remotely train YOLOv5 using [ClearML](https://cutt.ly/yolov5-readme-clearml) (open-source!)|Label and export your custom datasets directly to YOLOv5 for training with [Roboflow](https://roboflow.com/?ref=ultralytics) |Automatically track and visualize all your YOLOv5 training runs in the cloud with [Weights & Biases](https://wandb.ai/site?utm_campaign=repo_yolo_readme) +|Comet ⭐ NEW|Deci ⭐ NEW|ClearML ⭐ NEW|Roboflow|Weights & Biases +|:-:|:-:|:-:|:-:|:-:| +|Visualize model metrics and predictions and upload models and datasets in realtime with [Comet](https://www.comet.com/site/?ref=yolov5&utm_source=yolov5&utm_medium=affilliate&utm_campaign=yolov5_comet_integration)|Automatically compile and quantize YOLOv5 for better inference performance in one click at [Deci](https://bit.ly/yolov5-deci-platform)|Automatically track, visualize and even remotely train YOLOv5 using [ClearML](https://cutt.ly/yolov5-readme-clearml) (open-source!)|Label and export your custom datasets directly to YOLOv5 for training with [Roboflow](https://roboflow.com/?ref=ultralytics) |Automatically track and visualize all your YOLOv5 training runs in the cloud with [Weights & Biases](https://wandb.ai/site?utm_campaign=repo_yolo_readme) ##
Why YOLOv5
@@ -323,6 +308,28 @@ python export.py --weights yolov5s-cls.pt resnet50.pt efficientnet_b0.pt --inclu +##
Environments
+ +Get started in seconds with our verified environments. Click each icon below for details. + +
+ + + + + + + + + + + + + + +
+ + ##
Contribute
We love your input! We want to make contributing to YOLOv5 as easy and transparent as possible. Please see our [Contributing Guide](CONTRIBUTING.md) to get started, and fill out the [YOLOv5 Survey](https://ultralytics.com/survey?utm_source=github&utm_medium=social&utm_campaign=Survey) to send us feedback on your experiences. Thank you to all our contributors! diff --git a/train.py b/train.py index 29293aa612cf..e16c17c499f0 100644 --- a/train.py +++ b/train.py @@ -52,6 +52,7 @@ init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods, one_cycle, print_args, print_mutation, strip_optimizer, yaml_save) from utils.loggers import Loggers +from utils.loggers.comet.comet_utils import check_comet_resume from utils.loggers.wandb.wandb_utils import check_wandb_resume from utils.loss import ComputeLoss from utils.metrics import fitness @@ -330,7 +331,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB) pbar.set_description(('%11s' * 2 + '%11.4g' * 5) % (f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1])) - callbacks.run('on_train_batch_end', model, ni, imgs, targets, paths) + callbacks.run('on_train_batch_end', model, ni, imgs, targets, paths, list(mloss)) if callbacks.stop_training: return # end batch ------------------------------------------------------------------------------------------------ @@ -465,11 +466,11 @@ def parse_opt(known=False): parser.add_argument('--seed', type=int, default=0, help='Global training seed') parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify') - # Weights & Biases arguments - parser.add_argument('--entity', default=None, help='W&B: Entity') - parser.add_argument('--upload_dataset', nargs='?', const=True, default=False, help='W&B: Upload data, "val" option') - parser.add_argument('--bbox_interval', type=int, default=-1, help='W&B: Set bounding-box image logging interval') - parser.add_argument('--artifact_alias', type=str, default='latest', help='W&B: Version of dataset artifact to use') + # Logger arguments + parser.add_argument('--entity', default=None, help='Entity') + parser.add_argument('--upload_dataset', nargs='?', const=True, default=False, help='Upload data, "val" option') + parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval') + parser.add_argument('--artifact_alias', type=str, default='latest', help='Version of dataset artifact to use') return parser.parse_known_args()[0] if known else parser.parse_args() @@ -481,8 +482,8 @@ def main(opt, callbacks=Callbacks()): check_git_status() check_requirements() - # Resume - if opt.resume and not (check_wandb_resume(opt) or opt.evolve): # resume from specified or most recent last.pt + # Resume (from specified or most recent last.pt) + if opt.resume and not check_wandb_resume(opt) and not check_comet_resume(opt) and not opt.evolve: last = Path(check_file(opt.resume) if isinstance(opt.resume, str) else get_latest_run()) opt_yaml = last.parent.parent / 'opt.yaml' # train options yaml opt_data = opt.data # original dataset diff --git a/tutorial.ipynb b/tutorial.ipynb index 12840063b1f1..957437b2be6d 100644 --- a/tutorial.ipynb +++ b/tutorial.ipynb @@ -413,7 +413,7 @@ "import utils\n", "display = utils.notebook_init() # checks" ], - "execution_count": 1, + "execution_count": null, "outputs": [ { "output_type": "stream", @@ -465,7 +465,7 @@
"!python detect.py --weights yolov5s.pt --img 640 --conf 0.25 --source data/images\n", "# display.Image(filename='runs/detect/exp/zidane.jpg', width=600)" ], - "execution_count": 2, + "execution_count": null, "outputs": [ { "output_type": "stream", @@ -535,7 +535,7 @@ "torch.hub.download_url_to_file('https://ultralytics.com/assets/coco2017val.zip', 'tmp.zip') # download (780M - 5000 images)\n", "!unzip -q tmp.zip -d ../datasets && rm tmp.zip # unzip" ], - "execution_count": 3, + "execution_count": null, "outputs": [ { "output_type": "display_data", @@ -566,7 +566,7 @@ "# Validate YOLOv5s on COCO val\n", "!python val.py --weights yolov5s.pt --data coco.yaml --img 640 --half" ], - "execution_count": 4, + "execution_count": null, "outputs": [ { "output_type": "stream", @@ -653,11 +653,14 @@ "cell_type": "code", "source": [ "#@title Select YOLOv5 🚀 logger {run: 'auto'}\n", - "logger = 'TensorBoard' #@param ['TensorBoard', 'ClearML', 'W&B']\n", + "logger = 'TensorBoard' #@param ['TensorBoard', 'Comet', 'ClearML', 'W&B']\n", "\n", "if logger == 'TensorBoard':\n", " %load_ext tensorboard\n", " %tensorboard --logdir runs/train\n", + "elif logger == 'Comet':\n", + " %pip install -q comet_ml\n", + " import comet_ml; comet_ml.init()\n", "elif logger == 'ClearML':\n", " %pip install -q clearml && clearml-init\n", "elif logger == 'W&B':\n", @@ -683,7 +686,7 @@ "# Train YOLOv5s on COCO128 for 3 epochs\n", "!python train.py --img 640 --batch 16 --epochs 3 --data coco128.yaml --weights yolov5s.pt --cache" ], - "execution_count": 5, + "execution_count": null, "outputs": [ { "output_type": "stream", @@ -857,6 +860,28 @@ "# 4. Visualize" ] }, + { + "cell_type": "markdown", + "source": [ + "## Comet Logging and Visualization 🌟 NEW\n", + "[Comet](https://www.comet.com/site/?ref=yolov5&utm_source=yolov5&utm_medium=affilliate&utm_campaign=yolov5_comet_integration) is now fully integrated with YOLOv5. Track and visualize model metrics in real time, save your hyperparameters, datasets, and model checkpoints, and visualize your model predictions with [Comet Custom Panels](https://www.comet.com/docs/v2/guides/comet-dashboard/code-panels/about-panels/?ref=yolov5&utm_source=yolov5&utm_medium=affilliate&utm_campaign=yolov5_comet_integration)! Comet makes sure you never lose track of your work and makes it easy to share results and collaborate across teams of all sizes! \n", + "\n", + "Getting started is easy:\n", + "```shell\n", + "pip install comet_ml # 1. install\n", + "export COMET_API_KEY= # 2. paste API key\n", + "python train.py --img 640 --epochs 3 --data coco128.yaml --weights yolov5s.pt # 3. train\n", + "```\n", + "\n", + "To learn more about all of the supported Comet features for this integration, check out the [Comet Tutorial](https://github.com/ultralytics/yolov5/tree/master/utils/loggers/comet). If you'd like to learn more about Comet, head over to our [documentation](https://www.comet.com/docs/v2/?ref=yolov5&utm_source=yolov5&utm_medium=affilliate&utm_campaign=yolov5_comet_integration). 
Get started by trying out the Comet Colab Notebook:\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1RG0WOQyxlDlo5Km8GogJpIEJlg_5lyYO?usp=sharing)\n", + "\n", + "\"yolo-ui\"" + ], + "metadata": { + "id": "nWOsI5wJR1o3" + } + }, { "cell_type": "markdown", "source": [ @@ -1096,4 +1121,4 @@ "outputs": [] } ] -} \ No newline at end of file +} diff --git a/utils/loggers/__init__.py b/utils/loggers/__init__.py index 3aee35844f52..f29debb76907 100644 --- a/utils/loggers/__init__.py +++ b/utils/loggers/__init__.py @@ -17,7 +17,7 @@ from utils.plots import plot_images, plot_labels, plot_results from utils.torch_utils import de_parallel -LOGGERS = ('csv', 'tb', 'wandb', 'clearml') # *.csv, TensorBoard, Weights & Biases, ClearML +LOGGERS = ('csv', 'tb', 'wandb', 'clearml', 'comet') # *.csv, TensorBoard, Weights & Biases, ClearML RANK = int(os.getenv('RANK', -1)) try: @@ -41,6 +41,18 @@ except (ImportError, AssertionError): clearml = None +try: + if RANK not in [0, -1]: + comet_ml = None + else: + import comet_ml + + assert hasattr(comet_ml, '__version__') # verify package import not local dir + from utils.loggers.comet import CometLogger + +except (ModuleNotFoundError, ImportError, AssertionError): + comet_ml = None + class Loggers(): # YOLOv5 Loggers class @@ -80,7 +92,10 @@ def __init__(self, save_dir=None, weights=None, opt=None, hyp=None, logger=None, prefix = colorstr('ClearML: ') s = f"{prefix}run 'pip install clearml' to automatically track, visualize and remotely train YOLOv5 🚀 in ClearML" self.logger.info(s) - + if not comet_ml: + prefix = colorstr('Comet: ') + s = f"{prefix}run 'pip install comet_ml' to automatically track and visualize YOLOv5 🚀 runs in Comet" + self.logger.info(s) # TensorBoard s = self.save_dir if 'tb' in self.include and not self.opt.evolve: @@ -107,6 +122,18 @@ def __init__(self, save_dir=None, weights=None, opt=None, hyp=None, logger=None, else: self.clearml = None + # Comet + if comet_ml and 'comet' in self.include: + if isinstance(self.opt.resume, str) and self.opt.resume.startswith("comet://"): + run_id = self.opt.resume.split("/")[-1] + self.comet_logger = CometLogger(self.opt, self.hyp, run_id=run_id) + + else: + self.comet_logger = CometLogger(self.opt, self.hyp) + + else: + self.comet_logger = None + @property def remote_dataset(self): # Get data_dict if custom dataset artifact link is provided @@ -115,12 +142,18 @@ def remote_dataset(self): data_dict = self.clearml.data_dict if self.wandb: data_dict = self.wandb.data_dict + if self.comet_logger: + data_dict = self.comet_logger.data_dict return data_dict def on_train_start(self): - # Callback runs on train start - pass + if self.comet_logger: + self.comet_logger.on_train_start() + + def on_pretrain_routine_start(self): + if self.comet_logger: + self.comet_logger.on_pretrain_routine_start() def on_pretrain_routine_end(self, labels, names): # Callback runs on pre-train routine end @@ -131,8 +164,11 @@ def on_pretrain_routine_end(self, labels, names): self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]}) # if self.clearml: # pass # ClearML saves these images automatically using hooks + if self.comet_logger: + self.comet_logger.on_pretrain_routine_end(paths) - def on_train_batch_end(self, model, ni, imgs, targets, paths): + def on_train_batch_end(self, model, ni, imgs, targets, paths, vals): + log_dict = dict(zip(self.keys[0:3], vals)) # Callback runs on train batch end # ni: number integrated batches (since train 
start) if self.plots: @@ -148,11 +184,21 @@ def on_train_batch_end(self, model, ni, imgs, targets, paths): if self.clearml: self.clearml.log_debug_samples(files, title='Mosaics') + if self.comet_logger: + self.comet_logger.on_train_batch_end(log_dict, step=ni) + def on_train_epoch_end(self, epoch): # Callback runs on train epoch end if self.wandb: self.wandb.current_epoch = epoch + 1 + if self.comet_logger: + self.comet_logger.on_train_epoch_end(epoch) + + def on_val_start(self): + if self.comet_logger: + self.comet_logger.on_val_start() + def on_val_image_end(self, pred, predn, path, names, im): # Callback runs on val image end if self.wandb: @@ -160,7 +206,11 @@ def on_val_image_end(self, pred, predn, path, names, im): if self.clearml: self.clearml.log_image_with_boxes(path, pred, names, im) - def on_val_end(self): + def on_val_batch_end(self, batch_i, im, targets, paths, shapes, out): + if self.comet_logger: + self.comet_logger.on_val_batch_end(batch_i, im, targets, paths, shapes, out) + + def on_val_end(self, nt, tp, fp, p, r, f1, ap, ap50, ap_class, confusion_matrix): # Callback runs on val end if self.wandb or self.clearml: files = sorted(self.save_dir.glob('val*.jpg')) @@ -169,6 +219,9 @@ def on_val_end(self): if self.clearml: self.clearml.log_debug_samples(files, title='Validation') + if self.comet_logger: + self.comet_logger.on_val_end(nt, tp, fp, p, r, f1, ap, ap50, ap_class, confusion_matrix) + def on_fit_epoch_end(self, vals, epoch, best_fitness, fi): # Callback runs at the end of each fit (train+val) epoch x = dict(zip(self.keys, vals)) @@ -199,6 +252,9 @@ def on_fit_epoch_end(self, vals, epoch, best_fitness, fi): self.clearml.current_epoch_logged_images = set() # reset epoch image limit self.clearml.current_epoch += 1 + if self.comet_logger: + self.comet_logger.on_fit_epoch_end(x, epoch=epoch) + def on_model_save(self, last, epoch, final_epoch, best_fitness, fi): # Callback runs on model save event if (epoch + 1) % self.opt.save_period == 0 and not final_epoch and self.opt.save_period != -1: @@ -209,6 +265,9 @@ def on_model_save(self, last, epoch, final_epoch, best_fitness, fi): model_name='Latest Model', auto_delete_file=False) + if self.comet_logger: + self.comet_logger.on_model_save(last, epoch, final_epoch, best_fitness, fi) + def on_train_end(self, last, best, epoch, results): # Callback runs on training end, i.e. saving best model if self.plots: @@ -237,10 +296,16 @@ def on_train_end(self, last, best, epoch, results): name='Best Model', auto_delete_file=False) + if self.comet_logger: + final_results = dict(zip(self.keys[3:10], results)) + self.comet_logger.on_train_end(files, self.save_dir, last, best, epoch, final_results) + def on_params_update(self, params: dict): # Update hyperparams or configs of the experiment if self.wandb: self.wandb.wandb_run.config.update(params, allow_val_change=True) + if self.comet_logger: + self.comet_logger.on_params_update(params) class GenericLogger: diff --git a/utils/loggers/comet/README.md b/utils/loggers/comet/README.md new file mode 100644 index 000000000000..7b0b8e0e2f09 --- /dev/null +++ b/utils/loggers/comet/README.md @@ -0,0 +1,256 @@ + + +# YOLOv5 with Comet + +This guide will cover how to use YOLOv5 with [Comet](https://www.comet.com/site/?ref=yolov5&utm_source=yolov5&utm_medium=affilliate&utm_campaign=yolov5_comet_integration) + +# About Comet + +Comet builds tools that help data scientists, engineers, and team leaders accelerate and optimize machine learning and deep learning models. 
+
+Track and visualize model metrics in real time, save your hyperparameters, datasets, and model checkpoints, and visualize your model predictions with [Comet Custom Panels](https://www.comet.com/examples/comet-example-yolov5?shareable=YcwMiJaZSXfcEXpGOHDD12vA1&ref=yolov5&utm_source=yolov5&utm_medium=affilliate&utm_campaign=yolov5_comet_integration)!
+Comet makes sure you never lose track of your work and makes it easy to share results and collaborate across teams of all sizes!
+
+# Getting Started
+
+## Install Comet
+
+```shell
+pip install comet_ml
+```
+
+## Configure Comet Credentials
+
+There are two ways to configure Comet with YOLOv5.
+
+You can either set your credentials through environment variables:
+
+**Environment Variables**
+
+```shell
+export COMET_API_KEY=<Your Comet API Key>
+export COMET_PROJECT_NAME=<Your Comet Project Name> # This will default to 'yolov5'
+```
+
+Or create a `.comet.config` file in your working directory and set your credentials there.
+
+**Comet Configuration File**
+
+```
+[comet]
+api_key=<Your Comet API Key>
+project_name=<Your Comet Project Name> # This will default to 'yolov5'
+```
+
+## Run the Training Script
+
+```shell
+# Train YOLOv5s on COCO128 for 5 epochs
+python train.py --img 640 --batch 16 --epochs 5 --data coco128.yaml --weights yolov5s.pt
+```
+
+That's it! Comet will automatically log your hyperparameters, command line arguments, training and validation metrics. You can visualize and analyze your runs in the Comet UI.
+
+yolo-ui
+
+# Try out an Example!
+Check out an example of a [completed run here](https://www.comet.com/examples/comet-example-yolov5/a0e29e0e9b984e4a822db2a62d0cb357?experiment-tab=chart&showOutliers=true&smoothing=0&transformY=smoothing&xAxis=step&ref=yolov5&utm_source=yolov5&utm_medium=affilliate&utm_campaign=yolov5_comet_integration)
+
+Or better yet, try it out yourself in this Colab Notebook
+
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1RG0WOQyxlDlo5Km8GogJpIEJlg_5lyYO?usp=sharing)
+
+# Log automatically
+
+By default, Comet will log the following items:
+
+## Metrics
+- Box Loss, Object Loss, Classification Loss for the training and validation data
+- mAP_0.5, mAP_0.5:0.95 metrics for the validation data
+- Precision and Recall for the validation data
+
+## Parameters
+
+- Model Hyperparameters
+- All parameters passed through the command line options
+
+## Visualizations
+
+- Confusion Matrix of the model predictions on the validation data
+- Plots for the PR and F1 curves across all classes
+- Correlogram of the Class Labels
+
+# Configure Comet Logging
+
+Comet can be configured to log additional data either through command line flags passed to the training script
+or through environment variables.
+
+```shell
+export COMET_MODE=online # Set whether to run Comet in 'online' or 'offline' mode. Defaults to online
+export COMET_MODEL_NAME=<your model name> # Set the name for the saved model. Defaults to yolov5
+export COMET_LOG_CONFUSION_MATRIX=false # Set to disable logging a Comet Confusion Matrix. Defaults to true
+export COMET_MAX_IMAGE_UPLOADS=<number of allowed images to upload to Comet> # Controls how many total image predictions to log to Comet. Defaults to 100.
+export COMET_LOG_PER_CLASS_METRICS=true # Set to log evaluation metrics for each detected class at the end of training. Defaults to false
+export COMET_DEFAULT_CHECKPOINT_FILENAME=<your checkpoint filename> # Set this if you would like to resume training from a different checkpoint. Defaults to 'last.pt'
+export COMET_LOG_BATCH_LEVEL_METRICS=true # Set this if you would like to log training metrics at the batch level. Defaults to false.
+export COMET_LOG_PREDICTIONS=true # Set this to false to disable logging model predictions
+```
+
+## Logging Checkpoints with Comet
+
+Logging Models to Comet is disabled by default. To enable it, pass the `save-period` argument to the training script. This will save the
+logged checkpoints to Comet based on the interval value provided by `save-period`.
+
+```shell
+python train.py \
+--img 640 \
+--batch 16 \
+--epochs 5 \
+--data coco128.yaml \
+--weights yolov5s.pt \
+--save-period 1
+```
+
+## Logging Model Predictions
+
+By default, model predictions (images, ground truth labels and bounding boxes) will be logged to Comet.
+
+You can control the frequency of logged predictions and the associated images by passing the `bbox_interval` command line argument. Predictions can be visualized using Comet's Object Detection Custom Panel. This frequency corresponds to every Nth batch of data per epoch. In the example below, we are logging every 2nd batch of data for each epoch.
+
+**Note:** The YOLOv5 validation dataloader will default to a batch size of 32, so you will have to set the logging frequency accordingly.
+
+Here is an [example project using the Panel](https://www.comet.com/examples/comet-example-yolov5?shareable=YcwMiJaZSXfcEXpGOHDD12vA1&ref=yolov5&utm_source=yolov5&utm_medium=affilliate&utm_campaign=yolov5_comet_integration).
+
+```shell
+python train.py \
+--img 640 \
+--batch 16 \
+--epochs 5 \
+--data coco128.yaml \
+--weights yolov5s.pt \
+--bbox_interval 2
+```
+
+### Controlling the number of Prediction Images logged to Comet
+
+When logging predictions from YOLOv5, Comet will log the images associated with each set of predictions. By default, a maximum of 100 validation images are logged. You can increase or decrease this number using the `COMET_MAX_IMAGE_UPLOADS` environment variable.
+
+```shell
+env COMET_MAX_IMAGE_UPLOADS=200 python train.py \
+--img 640 \
+--batch 16 \
+--epochs 5 \
+--data coco128.yaml \
+--weights yolov5s.pt \
+--bbox_interval 1
+```
+
+### Logging Class Level Metrics
+
+Use the `COMET_LOG_PER_CLASS_METRICS` environment variable to log mAP, precision, recall, f1 for each class.
+
+```shell
+env COMET_LOG_PER_CLASS_METRICS=true python train.py \
+--img 640 \
+--batch 16 \
+--epochs 5 \
+--data coco128.yaml \
+--weights yolov5s.pt
+```
+
+## Uploading a Dataset to Comet Artifacts
+
+If you would like to store your data using [Comet Artifacts](https://www.comet.com/docs/v2/guides/data-management/using-artifacts/#learn-more?ref=yolov5&utm_source=yolov5&utm_medium=affilliate&utm_campaign=yolov5_comet_integration), you can do so using the `upload_dataset` flag.
+
+The dataset should be organized in the way described in the [YOLOv5 documentation](https://docs.ultralytics.com/tutorials/train-custom-datasets/#3-organize-directories). The dataset config `yaml` file must follow the same format as that of the `coco128.yaml` file.
+
+```shell
+python train.py \
+--img 640 \
+--batch 16 \
+--epochs 5 \
+--data coco128.yaml \
+--weights yolov5s.pt \
+--upload_dataset
+```
+
+You can find the uploaded dataset in the Artifacts tab in your Comet Workspace.
+artifact-1
+
+You can preview the data directly in the Comet UI.
+artifact-2
+
+Artifacts are versioned and also support adding metadata about the dataset. Comet will automatically log the metadata from your dataset `yaml` file.
+artifact-3
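+
+Outside of the training script, a previously uploaded artifact can be inspected with the same `comet_ml` calls that `utils/loggers/comet/__init__.py` relies on (`get_artifact`, `download` and `metadata`). The snippet below is only a sketch: it assumes your Comet credentials are already configured, and the workspace, artifact name and download directory are placeholders.
+
+```python
+import comet_ml
+
+experiment = comet_ml.Experiment(project_name="yolov5")
+artifact = experiment.get_artifact("my-workspace/yolov5-dataset:latest")  # placeholder "<workspace>/<name>:<alias>"
+artifact.download("datasets/yolov5-dataset")  # dataset files are written here
+print(artifact.metadata)  # metadata captured from the dataset yaml at upload time
+experiment.end()
+```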
+
+### Using a saved Artifact
+
+If you would like to use a dataset from Comet Artifacts, set the `path` variable in your dataset `yaml` file to point to the following Artifact resource URL.
+
+```
+# contents of artifact.yaml file
+path: "comet://<workspace name>/<artifact name>:<artifact version or alias>"
+```
+Then pass this file to your training script in the following way:
+
+```shell
+python train.py \
+--img 640 \
+--batch 16 \
+--epochs 5 \
+--data artifact.yaml \
+--weights yolov5s.pt
+```
+
+Artifacts also allow you to track the lineage of data as it flows through your Experimentation workflow. Here you can see a graph that shows you all the experiments that have used your uploaded dataset.
+artifact-4
+
+## Resuming a Training Run
+
+If your training run is interrupted for any reason, e.g. disrupted internet connection, you can resume the run using the `resume` flag and the Comet Run Path.
+
+The Run Path has the following format: `comet://<your workspace name>/<your project name>/<experiment id>`.
+
+This will restore the run to its state before the interruption, which includes restoring the model from a checkpoint, restoring all hyperparameters and training arguments, and downloading Comet dataset Artifacts if they were used in the original run. The resumed run will continue logging to the existing Experiment in the Comet UI.
+
+```shell
+python train.py \
+--resume "comet://<your run path>"
+```
+
+## Hyperparameter Search with the Comet Optimizer
+
+YOLOv5 is also integrated with Comet's Optimizer, making it simple to visualize hyperparameter sweeps in the Comet UI.
+
+### Configuring an Optimizer Sweep
+
+To configure the Comet Optimizer, you will have to create a JSON file with the information about the sweep. An example file has been provided in `utils/loggers/comet/optimizer_config.json`.
+
+```shell
+python utils/loggers/comet/hpo.py \
+  --comet_optimizer_config "utils/loggers/comet/optimizer_config.json"
+```
+
+The `hpo.py` script accepts the same arguments as `train.py`. If you wish to pass additional arguments to your sweep, simply add them after
+the script.
+
+```shell
+python utils/loggers/comet/hpo.py \
+  --comet_optimizer_config "utils/loggers/comet/optimizer_config.json" \
+  --save-period 1 \
+  --bbox_interval 1
+```
+
+### Running a Sweep in Parallel
+
+```shell
+comet optimizer -j <number of parallel workers> utils/loggers/comet/hpo.py \
+  utils/loggers/comet/optimizer_config.json
+```
+
+### Visualizing Results
+
+Comet provides a number of ways to visualize the results of your sweep; one way to pull the raw sweep metrics yourself is sketched below.
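+
+The snippet below pulls the sweep objective (`metrics/mAP_0.5`, as set in `optimizer_config.json`) for a single experiment through `comet_ml.API()`, the same entry point `utils/loggers/comet/comet_utils.py` uses when resuming runs. It is a sketch only: the run path is a placeholder, and the exact field names returned by `get_metrics` can vary between comet_ml versions.
+
+```python
+import comet_ml
+
+api = comet_ml.API()  # requires COMET_API_KEY to be configured
+experiment = api.get("my-workspace/yolov5/a0e29e0e9b98")  # placeholder run path
+for point in experiment.get_metrics("metrics/mAP_0.5"):
+    print(point.get("step"), point.get("metricValue"))
+```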
Take a look at a [project with a completed sweep here](https://www.comet.com/examples/comet-example-yolov5/view/PrlArHGuuhDTKC1UuBmTtOSXD/panels?ref=yolov5&utm_source=yolov5&utm_medium=affilliate&utm_campaign=yolov5_comet_integration) + +hyperparameter-yolo \ No newline at end of file diff --git a/utils/loggers/comet/__init__.py b/utils/loggers/comet/__init__.py new file mode 100644 index 000000000000..b168687dd7b2 --- /dev/null +++ b/utils/loggers/comet/__init__.py @@ -0,0 +1,496 @@ +import glob +import json +import logging +import os +import sys +from pathlib import Path + +logger = logging.getLogger(__name__) + +FILE = Path(__file__).resolve() +ROOT = FILE.parents[3] # YOLOv5 root directory +if str(ROOT) not in sys.path: + sys.path.append(str(ROOT)) # add ROOT to PATH + +try: + import comet_ml + + # Project Configuration + config = comet_ml.config.get_config() + COMET_PROJECT_NAME = config.get_string(os.getenv("COMET_PROJECT_NAME"), "comet.project_name", default="yolov5") +except (ModuleNotFoundError, ImportError): + comet_ml = None + COMET_PROJECT_NAME = None + +import torch +import torchvision.transforms as T +import yaml + +from utils.dataloaders import img2label_paths +from utils.general import check_dataset, scale_coords, xywh2xyxy +from utils.metrics import box_iou + +COMET_PREFIX = "comet://" + +COMET_MODE = os.getenv("COMET_MODE", "online") + +# Model Saving Settings +COMET_MODEL_NAME = os.getenv("COMET_MODEL_NAME", "yolov5") + +# Dataset Artifact Settings +COMET_UPLOAD_DATASET = os.getenv("COMET_UPLOAD_DATASET", "false").lower() == "true" + +# Evaluation Settings +COMET_LOG_CONFUSION_MATRIX = os.getenv("COMET_LOG_CONFUSION_MATRIX", "true").lower() == "true" +COMET_LOG_PREDICTIONS = os.getenv("COMET_LOG_PREDICTIONS", "true").lower() == "true" +COMET_MAX_IMAGE_UPLOADS = int(os.getenv("COMET_MAX_IMAGE_UPLOADS", 100)) + +# Confusion Matrix Settings +CONF_THRES = float(os.getenv("CONF_THRES", 0.001)) +IOU_THRES = float(os.getenv("IOU_THRES", 0.6)) + +# Batch Logging Settings +COMET_LOG_BATCH_METRICS = os.getenv("COMET_LOG_BATCH_METRICS", "false").lower() == "true" +COMET_BATCH_LOGGING_INTERVAL = os.getenv("COMET_BATCH_LOGGING_INTERVAL", 1) +COMET_PREDICTION_LOGGING_INTERVAL = os.getenv("COMET_PREDICTION_LOGGING_INTERVAL", 1) +COMET_LOG_PER_CLASS_METRICS = os.getenv("COMET_LOG_PER_CLASS_METRICS", "false").lower() == "true" + +RANK = int(os.getenv("RANK", -1)) + +to_pil = T.ToPILImage() + + +class CometLogger: + """Log metrics, parameters, source code, models and much more + with Comet + """ + + def __init__(self, opt, hyp, run_id=None, job_type="Training", **experiment_kwargs) -> None: + self.job_type = job_type + self.opt = opt + self.hyp = hyp + + # Comet Flags + self.comet_mode = COMET_MODE + + self.save_model = opt.save_period > -1 + self.model_name = COMET_MODEL_NAME + + # Batch Logging Settings + self.log_batch_metrics = COMET_LOG_BATCH_METRICS + self.comet_log_batch_interval = COMET_BATCH_LOGGING_INTERVAL + + # Dataset Artifact Settings + self.upload_dataset = self.opt.upload_dataset if self.opt.upload_dataset else COMET_UPLOAD_DATASET + self.resume = self.opt.resume + + # Default parameters to pass to Experiment objects + self.default_experiment_kwargs = { + "log_code": False, + "log_env_gpu": True, + "log_env_cpu": True, + "project_name": COMET_PROJECT_NAME,} + self.default_experiment_kwargs.update(experiment_kwargs) + self.experiment = self._get_experiment(self.comet_mode, run_id) + + self.data_dict = self.check_dataset(self.opt.data) + self.class_names = 
self.data_dict["names"] + self.num_classes = self.data_dict["nc"] + + self.logged_images_count = 0 + self.max_images = COMET_MAX_IMAGE_UPLOADS + + if run_id is None: + self.experiment.log_other("Created from", "YOLOv5") + if not isinstance(self.experiment, comet_ml.OfflineExperiment): + workspace, project_name, experiment_id = self.experiment.url.split("/")[-3:] + self.experiment.log_other( + "Run Path", + f"{workspace}/{project_name}/{experiment_id}", + ) + self.log_parameters(vars(opt)) + self.log_parameters(self.opt.hyp) + self.log_asset_data( + self.opt.hyp, + name="hyperparameters.json", + metadata={"type": "hyp-config-file"}, + ) + self.log_asset( + f"{self.opt.save_dir}/opt.yaml", + metadata={"type": "opt-config-file"}, + ) + + self.comet_log_confusion_matrix = COMET_LOG_CONFUSION_MATRIX + + if hasattr(self.opt, "conf_thres"): + self.conf_thres = self.opt.conf_thres + else: + self.conf_thres = CONF_THRES + if hasattr(self.opt, "iou_thres"): + self.iou_thres = self.opt.iou_thres + else: + self.iou_thres = IOU_THRES + + self.comet_log_predictions = COMET_LOG_PREDICTIONS + if self.opt.bbox_interval == -1: + self.comet_log_prediction_interval = 1 if self.opt.epochs < 10 else self.opt.epochs // 10 + else: + self.comet_log_prediction_interval = self.opt.bbox_interval + + if self.comet_log_predictions: + self.metadata_dict = {} + + self.comet_log_per_class_metrics = COMET_LOG_PER_CLASS_METRICS + + self.experiment.log_others({ + "comet_mode": COMET_MODE, + "comet_max_image_uploads": COMET_MAX_IMAGE_UPLOADS, + "comet_log_per_class_metrics": COMET_LOG_PER_CLASS_METRICS, + "comet_log_batch_metrics": COMET_LOG_BATCH_METRICS, + "comet_log_confusion_matrix": COMET_LOG_CONFUSION_MATRIX, + "comet_model_name": COMET_MODEL_NAME,}) + + # Check if running the Experiment with the Comet Optimizer + if hasattr(self.opt, "comet_optimizer_id"): + self.experiment.log_other("optimizer_id", self.opt.comet_optimizer_id) + self.experiment.log_other("optimizer_objective", self.opt.comet_optimizer_objective) + self.experiment.log_other("optimizer_metric", self.opt.comet_optimizer_metric) + self.experiment.log_other("optimizer_parameters", json.dumps(self.hyp)) + + def _get_experiment(self, mode, experiment_id=None): + if mode == "offline": + if experiment_id is not None: + return comet_ml.ExistingOfflineExperiment( + previous_experiment=experiment_id, + **self.default_experiment_kwargs, + ) + + return comet_ml.OfflineExperiment(**self.default_experiment_kwargs,) + + else: + try: + if experiment_id is not None: + return comet_ml.ExistingExperiment( + previous_experiment=experiment_id, + **self.default_experiment_kwargs, + ) + + return comet_ml.Experiment(**self.default_experiment_kwargs) + + except ValueError: + logger.warning("COMET WARNING: " + "Comet credentials have not been set. " + "Comet will default to offline logging. 
" + "Please set your credentials to enable online logging.") + return self._get_experiment("offline", experiment_id) + + return + + def log_metrics(self, log_dict, **kwargs): + self.experiment.log_metrics(log_dict, **kwargs) + + def log_parameters(self, log_dict, **kwargs): + self.experiment.log_parameters(log_dict, **kwargs) + + def log_asset(self, asset_path, **kwargs): + self.experiment.log_asset(asset_path, **kwargs) + + def log_asset_data(self, asset, **kwargs): + self.experiment.log_asset_data(asset, **kwargs) + + def log_image(self, img, **kwargs): + self.experiment.log_image(img, **kwargs) + + def log_model(self, path, opt, epoch, fitness_score, best_model=False): + if not self.save_model: + return + + model_metadata = { + "fitness_score": fitness_score[-1], + "epochs_trained": epoch + 1, + "save_period": opt.save_period, + "total_epochs": opt.epochs,} + + model_files = glob.glob(f"{path}/*.pt") + for model_path in model_files: + name = Path(model_path).name + + self.experiment.log_model( + self.model_name, + file_or_folder=model_path, + file_name=name, + metadata=model_metadata, + overwrite=True, + ) + + def check_dataset(self, data_file): + with open(data_file) as f: + data_config = yaml.safe_load(f) + + if data_config['path'].startswith(COMET_PREFIX): + path = data_config['path'].replace(COMET_PREFIX, "") + data_dict = self.download_dataset_artifact(path) + + return data_dict + + self.log_asset(self.opt.data, metadata={"type": "data-config-file"}) + + return check_dataset(data_file) + + def log_predictions(self, image, labelsn, path, shape, predn): + if self.logged_images_count >= self.max_images: + return + detections = predn[predn[:, 4] > self.conf_thres] + iou = box_iou(labelsn[:, 1:], detections[:, :4]) + mask, _ = torch.where(iou > self.iou_thres) + if len(mask) == 0: + return + + filtered_detections = detections[mask] + filtered_labels = labelsn[mask] + + processed_image = (image * 255).to(torch.uint8) + + image_id = path.split("/")[-1].split(".")[0] + image_name = f"{image_id}_curr_epoch_{self.experiment.curr_epoch}" + self.log_image(to_pil(processed_image), name=image_name) + + metadata = [] + for cls, *xyxy in filtered_labels.tolist(): + metadata.append({ + "label": f"{self.class_names[int(cls)]}-gt", + "score": 100, + "box": { + "x": xyxy[0], + "y": xyxy[1], + "x2": xyxy[2], + "y2": xyxy[3]},}) + for *xyxy, conf, cls in filtered_detections.tolist(): + metadata.append({ + "label": f"{self.class_names[int(cls)]}", + "score": conf * 100, + "box": { + "x": xyxy[0], + "y": xyxy[1], + "x2": xyxy[2], + "y2": xyxy[3]},}) + + self.metadata_dict[image_name] = metadata + self.logged_images_count += 1 + + return + + def preprocess_prediction(self, image, labels, shape, pred): + nl, _ = labels.shape[0], pred.shape[0] + + # Predictions + if self.opt.single_cls: + pred[:, 5] = 0 + + predn = pred.clone() + scale_coords(image.shape[1:], predn[:, :4], shape[0], shape[1]) + + labelsn = None + if nl: + tbox = xywh2xyxy(labels[:, 1:5]) # target boxes + scale_coords(image.shape[1:], tbox, shape[0], shape[1]) # native-space labels + labelsn = torch.cat((labels[:, 0:1], tbox), 1) # native-space labels + scale_coords(image.shape[1:], predn[:, :4], shape[0], shape[1]) # native-space pred + + return predn, labelsn + + def add_assets_to_artifact(self, artifact, path, asset_path, split): + img_paths = sorted(glob.glob(f"{asset_path}/*")) + label_paths = img2label_paths(img_paths) + + for image_file, label_file in zip(img_paths, label_paths): + image_logical_path, label_logical_path = map(lambda x: 
os.path.relpath(x, path), [image_file, label_file]) + + try: + artifact.add(image_file, logical_path=image_logical_path, metadata={"split": split}) + artifact.add(label_file, logical_path=label_logical_path, metadata={"split": split}) + except ValueError as e: + logger.error('COMET ERROR: Error adding file to Artifact. Skipping file.') + logger.error(f"COMET ERROR: {e}") + continue + + return artifact + + def upload_dataset_artifact(self): + dataset_name = self.data_dict.get("dataset_name", "yolov5-dataset") + path = str((ROOT / Path(self.data_dict["path"])).resolve()) + + metadata = self.data_dict.copy() + for key in ["train", "val", "test"]: + split_path = metadata.get(key) + if split_path is not None: + metadata[key] = split_path.replace(path, "") + + artifact = comet_ml.Artifact(name=dataset_name, artifact_type="dataset", metadata=metadata) + for key in metadata.keys(): + if key in ["train", "val", "test"]: + if isinstance(self.upload_dataset, str) and (key != self.upload_dataset): + continue + + asset_path = self.data_dict.get(key) + if asset_path is not None: + artifact = self.add_assets_to_artifact(artifact, path, asset_path, key) + + self.experiment.log_artifact(artifact) + + return + + def download_dataset_artifact(self, artifact_path): + logged_artifact = self.experiment.get_artifact(artifact_path) + artifact_save_dir = str(Path(self.opt.save_dir) / logged_artifact.name) + logged_artifact.download(artifact_save_dir) + + metadata = logged_artifact.metadata + data_dict = metadata.copy() + data_dict["path"] = artifact_save_dir + data_dict["names"] = {int(k): v for k, v in metadata.get("names").items()} + + data_dict = self.update_data_paths(data_dict) + return data_dict + + def update_data_paths(self, data_dict): + path = data_dict.get("path", "") + + for split in ["train", "val", "test"]: + if data_dict.get(split): + split_path = data_dict.get(split) + data_dict[split] = (f"{path}/{split_path}" if isinstance(split, str) else [ + f"{path}/{x}" for x in split_path]) + + return data_dict + + def on_pretrain_routine_end(self, paths): + if self.opt.resume: + return + + for path in paths: + self.log_asset(str(path)) + + if self.upload_dataset: + if not self.resume: + self.upload_dataset_artifact() + + return + + def on_train_start(self): + self.log_parameters(self.hyp) + + def on_train_epoch_start(self): + return + + def on_train_epoch_end(self, epoch): + self.experiment.curr_epoch = epoch + + return + + def on_train_batch_start(self): + return + + def on_train_batch_end(self, log_dict, step): + self.experiment.curr_step = step + if self.log_batch_metrics and (step % self.comet_log_batch_interval == 0): + self.log_metrics(log_dict, step=step) + + return + + def on_train_end(self, files, save_dir, last, best, epoch, results): + if self.comet_log_predictions: + curr_epoch = self.experiment.curr_epoch + self.experiment.log_asset_data(self.metadata_dict, "image-metadata.json", epoch=curr_epoch) + + for f in files: + self.log_asset(f, metadata={"epoch": epoch}) + self.log_asset(f"{save_dir}/results.csv", metadata={"epoch": epoch}) + + if not self.opt.evolve: + model_path = str(best if best.exists() else last) + name = Path(model_path).name + if self.save_model: + self.experiment.log_model( + self.model_name, + file_or_folder=model_path, + file_name=name, + overwrite=True, + ) + + # Check if running Experiment with Comet Optimizer + if hasattr(self.opt, 'comet_optimizer_id'): + metric = results.get(self.opt.comet_optimizer_metric) + self.experiment.log_other('optimizer_metric_value', metric) + 
+ self.finish_run() + + def on_val_start(self): + return + + def on_val_batch_start(self): + return + + def on_val_batch_end(self, batch_i, images, targets, paths, shapes, outputs): + if not (self.comet_log_predictions and ((batch_i + 1) % self.comet_log_prediction_interval == 0)): + return + + for si, pred in enumerate(outputs): + if len(pred) == 0: + continue + + image = images[si] + labels = targets[targets[:, 0] == si, 1:] + shape = shapes[si] + path = paths[si] + predn, labelsn = self.preprocess_prediction(image, labels, shape, pred) + if labelsn is not None: + self.log_predictions(image, labelsn, path, shape, predn) + + return + + def on_val_end(self, nt, tp, fp, p, r, f1, ap, ap50, ap_class, confusion_matrix): + if self.comet_log_per_class_metrics: + if self.num_classes > 1: + for i, c in enumerate(ap_class): + class_name = self.class_names[c] + self.experiment.log_metrics( + { + 'mAP@.5': ap50[i], + 'mAP@.5:.95': ap[i], + 'precision': p[i], + 'recall': r[i], + 'f1': f1[i], + 'true_positives': tp[i], + 'false_positives': fp[i], + 'support': nt[c]}, + prefix=class_name) + + if self.comet_log_confusion_matrix: + epoch = self.experiment.curr_epoch + class_names = list(self.class_names.values()) + class_names.append("background") + num_classes = len(class_names) + + self.experiment.log_confusion_matrix( + matrix=confusion_matrix.matrix, + max_categories=num_classes, + labels=class_names, + epoch=epoch, + column_label='Actual Category', + row_label='Predicted Category', + file_name=f"confusion-matrix-epoch-{epoch}.json", + ) + + def on_fit_epoch_end(self, result, epoch): + self.log_metrics(result, epoch=epoch) + + def on_model_save(self, last, epoch, final_epoch, best_fitness, fi): + if ((epoch + 1) % self.opt.save_period == 0 and not final_epoch) and self.opt.save_period != -1: + self.log_model(last.parent, self.opt, epoch, fi, best_model=best_fitness == fi) + + def on_params_update(self, params): + self.log_parameters(params) + + def finish_run(self): + self.experiment.end() diff --git a/utils/loggers/comet/comet_utils.py b/utils/loggers/comet/comet_utils.py new file mode 100644 index 000000000000..3cbd45156b57 --- /dev/null +++ b/utils/loggers/comet/comet_utils.py @@ -0,0 +1,150 @@ +import logging +import os +from urllib.parse import urlparse + +try: + import comet_ml +except (ModuleNotFoundError, ImportError): + comet_ml = None + +import yaml + +logger = logging.getLogger(__name__) + +COMET_PREFIX = "comet://" +COMET_MODEL_NAME = os.getenv("COMET_MODEL_NAME", "yolov5") +COMET_DEFAULT_CHECKPOINT_FILENAME = os.getenv("COMET_DEFAULT_CHECKPOINT_FILENAME", "last.pt") + + +def download_model_checkpoint(opt, experiment): + model_dir = f"{opt.project}/{experiment.name}" + os.makedirs(model_dir, exist_ok=True) + + model_name = COMET_MODEL_NAME + model_asset_list = experiment.get_model_asset_list(model_name) + + if len(model_asset_list) == 0: + logger.error(f"COMET ERROR: No checkpoints found for model name : {model_name}") + return + + model_asset_list = sorted( + model_asset_list, + key=lambda x: x["step"], + reverse=True, + ) + logged_checkpoint_map = {asset["fileName"]: asset["assetId"] for asset in model_asset_list} + + resource_url = urlparse(opt.weights) + checkpoint_filename = resource_url.query + + if checkpoint_filename: + asset_id = logged_checkpoint_map.get(checkpoint_filename) + else: + asset_id = logged_checkpoint_map.get(COMET_DEFAULT_CHECKPOINT_FILENAME) + checkpoint_filename = COMET_DEFAULT_CHECKPOINT_FILENAME + + if asset_id is None: + logger.error(f"COMET ERROR: Checkpoint 
{checkpoint_filename} not found in the given Experiment") + return + + try: + logger.info(f"COMET INFO: Downloading checkpoint {checkpoint_filename}") + asset_filename = checkpoint_filename + + model_binary = experiment.get_asset(asset_id, return_type="binary", stream=False) + model_download_path = f"{model_dir}/{asset_filename}" + with open(model_download_path, "wb") as f: + f.write(model_binary) + + opt.weights = model_download_path + + except Exception as e: + logger.warning("COMET WARNING: Unable to download checkpoint from Comet") + logger.exception(e) + + +def set_opt_parameters(opt, experiment): + """Update the opts Namespace with parameters + from Comet's ExistingExperiment when resuming a run + + Args: + opt (argparse.Namespace): Namespace of command line options + experiment (comet_ml.APIExperiment): Comet API Experiment object + """ + asset_list = experiment.get_asset_list() + resume_string = opt.resume + + for asset in asset_list: + if asset["fileName"] == "opt.yaml": + asset_id = asset["assetId"] + asset_binary = experiment.get_asset(asset_id, return_type="binary", stream=False) + opt_dict = yaml.safe_load(asset_binary) + for key, value in opt_dict.items(): + setattr(opt, key, value) + opt.resume = resume_string + + # Save hyperparameters to YAML file + # Necessary to pass checks in training script + save_dir = f"{opt.project}/{experiment.name}" + os.makedirs(save_dir, exist_ok=True) + + hyp_yaml_path = f"{save_dir}/hyp.yaml" + with open(hyp_yaml_path, "w") as f: + yaml.dump(opt.hyp, f) + opt.hyp = hyp_yaml_path + + +def check_comet_weights(opt): + """Downloads model weights from Comet and updates the + weights path to point to saved weights location + + Args: + opt (argparse.Namespace): Command Line arguments passed + to YOLOv5 training script + + Returns: + None/bool: Return True if weights are successfully downloaded + else return None + """ + if comet_ml is None: + return + + if isinstance(opt.weights, str): + if opt.weights.startswith(COMET_PREFIX): + api = comet_ml.API() + resource = urlparse(opt.weights) + experiment_path = f"{resource.netloc}{resource.path}" + experiment = api.get(experiment_path) + download_model_checkpoint(opt, experiment) + return True + + return None + + +def check_comet_resume(opt): + """Restores run parameters to its original state based on the model checkpoint + and logged Experiment parameters. 
+ + Args: + opt (argparse.Namespace): Command Line arguments passed + to YOLOv5 training script + + Returns: + None/bool: Return True if the run is restored successfully + else return None + """ + if comet_ml is None: + return + + if isinstance(opt.resume, str): + if opt.resume.startswith(COMET_PREFIX): + api = comet_ml.API() + resource = urlparse(opt.resume) + experiment_path = f"{resource.netloc}{resource.path}" + experiment = api.get(experiment_path) + set_opt_parameters(opt, experiment) + download_model_checkpoint(opt, experiment) + + return True + + return None diff --git a/utils/loggers/comet/hpo.py b/utils/loggers/comet/hpo.py new file mode 100644 index 000000000000..eab4df9978cf --- /dev/null +++ b/utils/loggers/comet/hpo.py @@ -0,0 +1,118 @@ +import argparse +import json +import logging +import os +import sys +from pathlib import Path + +import comet_ml + +logger = logging.getLogger(__name__) + +FILE = Path(__file__).resolve() +ROOT = FILE.parents[3] # YOLOv5 root directory +if str(ROOT) not in sys.path: + sys.path.append(str(ROOT)) # add ROOT to PATH + +from train import parse_opt, train +from utils.callbacks import Callbacks +from utils.general import increment_path +from utils.torch_utils import select_device + +# Project Configuration +config = comet_ml.config.get_config() +COMET_PROJECT_NAME = config.get_string(os.getenv("COMET_PROJECT_NAME"), "comet.project_name", default="yolov5") + + +def get_args(known=False): + parser = argparse.ArgumentParser() + parser.add_argument('--weights', type=str, default=ROOT / 'yolov5s.pt', help='initial weights path') + parser.add_argument('--cfg', type=str, default='', help='model.yaml path') + parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path') + parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch-low.yaml', help='hyperparameters path') + parser.add_argument('--epochs', type=int, default=300, help='total training epochs') + parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs, -1 for autobatch') + parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)') + parser.add_argument('--rect', action='store_true', help='rectangular training') + parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training') + parser.add_argument('--nosave', action='store_true', help='only save final checkpoint') + parser.add_argument('--noval', action='store_true', help='only validate final epoch') + parser.add_argument('--noautoanchor', action='store_true', help='disable AutoAnchor') + parser.add_argument('--noplots', action='store_true', help='save no plot files') + parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations') + parser.add_argument('--bucket', type=str, default='', help='gsutil bucket') + parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"') + parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training') + parser.add_argument('--device', default='', help='cuda device, i.e. 
0 or 0,1,2,3 or cpu') + parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%') + parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class') + parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'AdamW'], default='SGD', help='optimizer') + parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode') + parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)') + parser.add_argument('--project', default=ROOT / 'runs/train', help='save to project/name') + parser.add_argument('--name', default='exp', help='save to project/name') + parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') + parser.add_argument('--quad', action='store_true', help='quad dataloader') + parser.add_argument('--cos-lr', action='store_true', help='cosine LR scheduler') + parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon') + parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)') + parser.add_argument('--freeze', nargs='+', type=int, default=[0], help='Freeze layers: backbone=10, first3=0 1 2') + parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)') + parser.add_argument('--seed', type=int, default=0, help='Global training seed') + parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify') + + # Weights & Biases arguments + parser.add_argument('--entity', default=None, help='W&B: Entity') + parser.add_argument('--upload_dataset', nargs='?', const=True, default=False, help='W&B: Upload data, "val" option') + parser.add_argument('--bbox_interval', type=int, default=-1, help='W&B: Set bounding-box image logging interval') + parser.add_argument('--artifact_alias', type=str, default='latest', help='W&B: Version of dataset artifact to use') + + # Comet Arguments + parser.add_argument("--comet_optimizer_config", type=str, help="Comet: Path to a Comet Optimizer Config File.") + parser.add_argument("--comet_optimizer_id", type=str, help="Comet: ID of the Comet Optimizer sweep.") + parser.add_argument("--comet_optimizer_objective", type=str, help="Comet: Set to 'minimize' or 'maximize'.") + parser.add_argument("--comet_optimizer_metric", type=str, help="Comet: Metric to Optimize.") + parser.add_argument("--comet_optimizer_workers", + type=int, + default=1, + help="Comet: Number of Parallel Workers to use with the Comet Optimizer.") + + return parser.parse_known_args()[0] if known else parser.parse_args() + + +def run(parameters, opt): + hyp_dict = {k: v for k, v in parameters.items() if k not in ["epochs", "batch_size"]} + + opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok or opt.evolve)) + opt.batch_size = parameters.get("batch_size") + opt.epochs = parameters.get("epochs") + + device = select_device(opt.device, batch_size=opt.batch_size) + train(hyp_dict, opt, device, callbacks=Callbacks()) + + +if __name__ == "__main__": + opt = get_args(known=True) + + opt.weights = str(opt.weights) + opt.cfg = str(opt.cfg) + opt.data = str(opt.data) + opt.project = str(opt.project) + + optimizer_id = os.getenv("COMET_OPTIMIZER_ID") + if optimizer_id is None: + with open(opt.comet_optimizer_config) as f: + optimizer_config = json.load(f) + optimizer = 
comet_ml.Optimizer(optimizer_config) + else: + optimizer = comet_ml.Optimizer(optimizer_id) + + opt.comet_optimizer_id = optimizer.id + status = optimizer.status() + + opt.comet_optimizer_objective = status["spec"]["objective"] + opt.comet_optimizer_metric = status["spec"]["metric"] + + logger.info("COMET INFO: Starting Hyperparameter Sweep") + for parameter in optimizer.get_parameters(): + run(parameter["parameters"], opt) diff --git a/utils/loggers/comet/optimizer_config.json b/utils/loggers/comet/optimizer_config.json new file mode 100644 index 000000000000..83ddddab6f20 --- /dev/null +++ b/utils/loggers/comet/optimizer_config.json @@ -0,0 +1,209 @@ +{ + "algorithm": "random", + "parameters": { + "anchor_t": { + "type": "discrete", + "values": [ + 2, + 8 + ] + }, + "batch_size": { + "type": "discrete", + "values": [ + 16, + 32, + 64 + ] + }, + "box": { + "type": "discrete", + "values": [ + 0.02, + 0.2 + ] + }, + "cls": { + "type": "discrete", + "values": [ + 0.2 + ] + }, + "cls_pw": { + "type": "discrete", + "values": [ + 0.5 + ] + }, + "copy_paste": { + "type": "discrete", + "values": [ + 1 + ] + }, + "degrees": { + "type": "discrete", + "values": [ + 0, + 45 + ] + }, + "epochs": { + "type": "discrete", + "values": [ + 5 + ] + }, + "fl_gamma": { + "type": "discrete", + "values": [ + 0 + ] + }, + "fliplr": { + "type": "discrete", + "values": [ + 0 + ] + }, + "flipud": { + "type": "discrete", + "values": [ + 0 + ] + }, + "hsv_h": { + "type": "discrete", + "values": [ + 0 + ] + }, + "hsv_s": { + "type": "discrete", + "values": [ + 0 + ] + }, + "hsv_v": { + "type": "discrete", + "values": [ + 0 + ] + }, + "iou_t": { + "type": "discrete", + "values": [ + 0.7 + ] + }, + "lr0": { + "type": "discrete", + "values": [ + 1e-05, + 0.1 + ] + }, + "lrf": { + "type": "discrete", + "values": [ + 0.01, + 1 + ] + }, + "mixup": { + "type": "discrete", + "values": [ + 1 + ] + }, + "momentum": { + "type": "discrete", + "values": [ + 0.6 + ] + }, + "mosaic": { + "type": "discrete", + "values": [ + 0 + ] + }, + "obj": { + "type": "discrete", + "values": [ + 0.2 + ] + }, + "obj_pw": { + "type": "discrete", + "values": [ + 0.5 + ] + }, + "optimizer": { + "type": "categorical", + "values": [ + "SGD", + "Adam", + "AdamW" + ] + }, + "perspective": { + "type": "discrete", + "values": [ + 0 + ] + }, + "scale": { + "type": "discrete", + "values": [ + 0 + ] + }, + "shear": { + "type": "discrete", + "values": [ + 0 + ] + }, + "translate": { + "type": "discrete", + "values": [ + 0 + ] + }, + "warmup_bias_lr": { + "type": "discrete", + "values": [ + 0, + 0.2 + ] + }, + "warmup_epochs": { + "type": "discrete", + "values": [ + 5 + ] + }, + "warmup_momentum": { + "type": "discrete", + "values": [ + 0, + 0.95 + ] + }, + "weight_decay": { + "type": "discrete", + "values": [ + 0, + 0.001 + ] + } + }, + "spec": { + "maxCombo": 0, + "metric": "metrics/mAP_0.5", + "objective": "maximize" + }, + "trials": 1 +} diff --git a/val.py b/val.py index 5427ee7b3619..665d92f9286d 100644 --- a/val.py +++ b/val.py @@ -259,7 +259,7 @@ def run( plot_images(im, targets, paths, save_dir / f'val_batch{batch_i}_labels.jpg', names) # labels plot_images(im, output_to_target(out), paths, save_dir / f'val_batch{batch_i}_pred.jpg', names) # pred - callbacks.run('on_val_batch_end') + callbacks.run('on_val_batch_end', batch_i, im, targets, paths, shapes, out) # Compute metrics stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)] # to numpy @@ -289,7 +289,7 @@ def run( # Plots if plots: confusion_matrix.plot(save_dir=save_dir, 
names=list(names.values())) - callbacks.run('on_val_end') + callbacks.run('on_val_end', nt, tp, fp, p, r, f1, ap, ap50, ap_class, confusion_matrix) # Save JSON if save_json and len(jdict):
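Note that this hunk widens the `on_val_end` callback signature, so any custom callback registered for that hook must now accept the extra arguments. Below is a minimal sketch of such a consumer; the callback name and printout are illustrative only, and it assumes `utils.callbacks.Callbacks.register_action` keeps its current signature.

```python
from utils.callbacks import Callbacks

def log_val_summary(nt, tp, fp, p, r, f1, ap, ap50, ap_class, confusion_matrix):
    # nt: per-class target counts; ap50/ap: per-class AP@0.5 and AP@0.5:0.95 arrays
    print(f"validated {int(nt.sum())} targets, mean AP@0.5 = {ap50.mean():.3f}")

callbacks = Callbacks()
callbacks.register_action('on_val_end', name='log_val_summary', callback=log_val_summary)
# val.run(data='coco128.yaml', weights='yolov5s.pt', callbacks=callbacks) would then
# invoke log_val_summary with the arguments shown in the hunk above.
```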