From 8e18115434ad232f8846f742d0ec1418253a8e0d Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 19 Nov 2022 03:21:56 +0100 Subject: [PATCH] Add git info to cls, seg checkpoints --- classify/train.py | 3 ++- segment/train.py | 9 ++------- train.py | 2 +- 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/classify/train.py b/classify/train.py index 4422ca26b0ae..5faef08e876c 100644 --- a/classify/train.py +++ b/classify/train.py @@ -40,7 +40,7 @@ from models.experimental import attempt_load from models.yolo import ClassificationModel, DetectionModel from utils.dataloaders import create_classification_dataloader -from utils.general import (DATASETS_DIR, LOGGER, TQDM_BAR_FORMAT, WorkingDirectory, check_git_status, +from utils.general import (DATASETS_DIR, GIT, LOGGER, TQDM_BAR_FORMAT, WorkingDirectory, check_git_status, check_requirements, colorstr, download, increment_path, init_seeds, print_args, yaml_save) from utils.loggers import GenericLogger from utils.plots import imshow_cls @@ -237,6 +237,7 @@ def train(opt, device): 'updates': ema.updates, 'optimizer': None, # optimizer.state_dict(), 'opt': vars(opt), + 'git': GIT, # {remote, branch, commit} if a git repo 'date': datetime.now().isoformat()} # Save last, best and delete diff --git a/segment/train.py b/segment/train.py index 2a0793d1aa3e..5d9ed78f527c 100644 --- a/segment/train.py +++ b/segment/train.py @@ -46,7 +46,7 @@ from utils.autobatch import check_train_batch_size from utils.callbacks import Callbacks from utils.downloads import attempt_download, is_url -from utils.general import (LOGGER, TQDM_BAR_FORMAT, check_amp, check_dataset, check_file, check_git_status, +from utils.general import (GIT, LOGGER, TQDM_BAR_FORMAT, check_amp, check_dataset, check_file, check_git_status, check_img_size, check_requirements, check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, one_cycle, print_args, print_mutation, strip_optimizer, yaml_save) @@ -390,6 +390,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio 'updates': ema.updates, 'optimizer': optimizer.state_dict(), 'opt': vars(opt), + 'git': GIT, # {remote, branch, commit} if a git repo 'date': datetime.now().isoformat()} # Save last, best and delete @@ -498,12 +499,6 @@ def parse_opt(known=False): parser.add_argument('--mask-ratio', type=int, default=4, help='Downsample the truth masks to saving memory') parser.add_argument('--no-overlap', action='store_true', help='Overlap masks train faster at slightly less mAP') - # Weights & Biases arguments - # parser.add_argument('--entity', default=None, help='W&B: Entity') - # parser.add_argument('--upload_dataset', nargs='?', const=True, default=False, help='W&B: Upload data, "val" option') - # parser.add_argument('--bbox_interval', type=int, default=-1, help='W&B: Set bounding-box image logging interval') - # parser.add_argument('--artifact_alias', type=str, default='latest', help='W&B: Version of dataset artifact to use') - return parser.parse_known_args()[0] if known else parser.parse_args() diff --git a/train.py b/train.py index 6fa33f47d100..1ea5c5bbeddd 100644 --- a/train.py +++ b/train.py @@ -376,7 +376,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio 'updates': ema.updates, 'optimizer': optimizer.state_dict(), 'opt': vars(opt), - 'git': GIT, + 'git': GIT, # {remote, branch, commit} if a git repo 'date': datetime.now().isoformat()} # Save last, best and delete