Skip to content

Commit

Permalink
Add git info to cls, seg checkpoints (#10217)
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher authored Nov 19, 2022
1 parent 9286336 commit 0307954
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 9 deletions.
3 changes: 2 additions & 1 deletion classify/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from models.experimental import attempt_load
from models.yolo import ClassificationModel, DetectionModel
from utils.dataloaders import create_classification_dataloader
from utils.general import (DATASETS_DIR, LOGGER, TQDM_BAR_FORMAT, WorkingDirectory, check_git_status,
from utils.general import (DATASETS_DIR, GIT, LOGGER, TQDM_BAR_FORMAT, WorkingDirectory, check_git_status,
check_requirements, colorstr, download, increment_path, init_seeds, print_args, yaml_save)
from utils.loggers import GenericLogger
from utils.plots import imshow_cls
Expand Down Expand Up @@ -237,6 +237,7 @@ def train(opt, device):
'updates': ema.updates,
'optimizer': None, # optimizer.state_dict(),
'opt': vars(opt),
'git': GIT, # {remote, branch, commit} if a git repo
'date': datetime.now().isoformat()}

# Save last, best and delete
Expand Down
9 changes: 2 additions & 7 deletions segment/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from utils.autobatch import check_train_batch_size
from utils.callbacks import Callbacks
from utils.downloads import attempt_download, is_url
from utils.general import (LOGGER, TQDM_BAR_FORMAT, check_amp, check_dataset, check_file, check_git_status,
from utils.general import (GIT, LOGGER, TQDM_BAR_FORMAT, check_amp, check_dataset, check_file, check_git_status,
check_img_size, check_requirements, check_suffix, check_yaml, colorstr, get_latest_run,
increment_path, init_seeds, intersect_dicts, labels_to_class_weights,
labels_to_image_weights, one_cycle, print_args, print_mutation, strip_optimizer, yaml_save)
Expand Down Expand Up @@ -390,6 +390,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
'updates': ema.updates,
'optimizer': optimizer.state_dict(),
'opt': vars(opt),
'git': GIT, # {remote, branch, commit} if a git repo
'date': datetime.now().isoformat()}

# Save last, best and delete
Expand Down Expand Up @@ -498,12 +499,6 @@ def parse_opt(known=False):
parser.add_argument('--mask-ratio', type=int, default=4, help='Downsample the truth masks to saving memory')
parser.add_argument('--no-overlap', action='store_true', help='Overlap masks train faster at slightly less mAP')

# Weights & Biases arguments
# parser.add_argument('--entity', default=None, help='W&B: Entity')
# parser.add_argument('--upload_dataset', nargs='?', const=True, default=False, help='W&B: Upload data, "val" option')
# parser.add_argument('--bbox_interval', type=int, default=-1, help='W&B: Set bounding-box image logging interval')
# parser.add_argument('--artifact_alias', type=str, default='latest', help='W&B: Version of dataset artifact to use')

return parser.parse_known_args()[0] if known else parser.parse_args()


Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
'updates': ema.updates,
'optimizer': optimizer.state_dict(),
'opt': vars(opt),
'git': GIT,
'git': GIT, # {remote, branch, commit} if a git repo
'date': datetime.now().isoformat()}

# Save last, best and delete
Expand Down

0 comments on commit 0307954

Please sign in to comment.