Skip to content

Commit

Permalink
Refactor argparser printing to print_args() (ultralytics#4850)
Browse files Browse the repository at this point in the history
* Refactor argparser printing to `print_args()`

* Cleanup
  • Loading branch information
glenn-jocher committed Sep 18, 2021
1 parent 0e642a3 commit 87cec3d
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 18 deletions.
12 changes: 6 additions & 6 deletions detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
sys.path.append(FILE.parents[0].as_posix()) # add yolov5/ to path

from models.experimental import attempt_load
from utils.datasets import LoadStreams, LoadImages
from utils.general import check_img_size, check_imshow, check_requirements, check_suffix, colorstr, is_ascii, \
non_max_suppression, apply_classifier, scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, \
save_one_box
from utils.datasets import LoadImages, LoadStreams
from utils.general import apply_classifier, check_img_size, check_imshow, check_requirements, check_suffix, colorstr, \
increment_path, is_ascii, non_max_suppression, print_args, save_one_box, scale_coords, set_logging, \
strip_optimizer, xyxy2xywh
from utils.plots import Annotator, colors
from utils.torch_utils import select_device, load_classifier, time_sync
from utils.torch_utils import load_classifier, select_device, time_sync


@torch.no_grad()
Expand Down Expand Up @@ -279,11 +279,11 @@ def parse_opt():
parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
opt = parser.parse_args()
opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand
print_args(FILE.stem, opt)
return opt


def main(opt):
print(colorstr('detect: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
check_requirements(exclude=('tensorboard', 'thop'))
run(**vars(opt))

Expand Down
5 changes: 3 additions & 2 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@
from models.yolo import Detect
from utils.activations import SiLU
from utils.datasets import LoadImages
from utils.general import colorstr, check_dataset, check_img_size, check_requirements, file_size, set_logging, url2file
from utils.general import colorstr, check_dataset, check_img_size, check_requirements, file_size, print_args, \
set_logging, url2file
from utils.torch_utils import select_device


Expand Down Expand Up @@ -322,12 +323,12 @@ def parse_opt():
default=['torchscript', 'onnx'],
help='available formats are (torchscript, onnx, coreml, saved_model, pb, tflite, tfjs)')
opt = parser.parse_args()
print_args(FILE.stem, opt)
return opt


def main(opt):
set_logging()
print(colorstr('export: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
run(**vars(opt))


Expand Down
6 changes: 3 additions & 3 deletions models/tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@
from tensorflow import keras

from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, autopad, C3
from models.experimental import MixConv2d, CrossConv, attempt_load
from models.experimental import CrossConv, MixConv2d, attempt_load
from models.yolo import Detect
from utils.general import colorstr, make_divisible, set_logging
from utils.general import make_divisible, print_args, set_logging
from utils.activations import SiLU

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -434,12 +434,12 @@ def parse_opt():
parser.add_argument('--dynamic', action='store_true', help='dynamic batch size')
opt = parser.parse_args()
opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand
print_args(FILE.stem, opt)
return opt


def main(opt):
set_logging()
print(colorstr('tf.py: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
run(**vars(opt))


Expand Down
7 changes: 3 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from utils.datasets import create_dataloader
from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
strip_optimizer, get_latest_run, check_dataset, check_git_status, check_img_size, check_requirements, \
check_file, check_yaml, check_suffix, print_mutation, set_logging, one_cycle, colorstr, methods
check_file, check_yaml, check_suffix, print_args, print_mutation, set_logging, one_cycle, colorstr, methods
from utils.downloads import attempt_download
from utils.loss import ComputeLoss
from utils.plots import plot_labels, plot_evolve
Expand Down Expand Up @@ -470,9 +470,8 @@ def parse_opt(known=False):

def main(opt, callbacks=Callbacks()):
# Checks
set_logging(RANK)
if RANK in [-1, 0]:
print(colorstr('train: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
print_args(FILE.stem, opt)
check_git_status()
check_requirements(requirements=FILE.parent / 'requirements.txt', exclude=['thop'])

Expand Down Expand Up @@ -508,7 +507,7 @@ def main(opt, callbacks=Callbacks()):
if not opt.evolve:
train(opt.hyp, opt, device, callbacks)
if WORLD_SIZE > 1 and RANK == 0:
_ = [print('Destroying process group... ', end=''), dist.destroy_process_group(), print('Done.')]
_ = LOGGER.info('Destroying process group... ', end=''), dist.destroy_process_group(), LOGGER.info('Done.')

# Evolve hyperparameters (optional)
else:
Expand Down
5 changes: 5 additions & 0 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ def set_logging(rank=-1, verbose=True):
level=logging.INFO if (verbose and rank in [-1, 0]) else logging.WARN)


def print_args(name, opt):
# Print argparser arguments
print(colorstr(f'{name}: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))


def init_seeds(seed=0):
# Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
# cudnn seed 0 settings are slower and more reproducible, else faster and less reproducible
Expand Down
6 changes: 3 additions & 3 deletions val.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from utils.datasets import create_dataloader
from utils.general import coco80_to_coco91_class, check_dataset, check_img_size, check_requirements, \
check_suffix, check_yaml, box_iou, non_max_suppression, scale_coords, xyxy2xywh, xywh2xyxy, set_logging, \
increment_path, colorstr
increment_path, colorstr, print_args
from utils.metrics import ap_per_class, ConfusionMatrix
from utils.plots import output_to_target, plot_images, plot_val_study
from utils.torch_utils import select_device, time_sync
Expand Down Expand Up @@ -295,7 +295,7 @@ def run(data,


def parse_opt():
parser = argparse.ArgumentParser(prog='val.py')
parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, default='data/coco128.yaml', help='dataset.yaml path')
parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
parser.add_argument('--batch-size', type=int, default=32, help='batch size')
Expand All @@ -319,12 +319,12 @@ def parse_opt():
opt.save_json |= opt.data.endswith('coco.yaml')
opt.save_txt |= opt.save_hybrid
opt.data = check_yaml(opt.data) # check YAML
print_args(FILE.stem, opt)
return opt


def main(opt):
set_logging()
print(colorstr('val: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
check_requirements(requirements=FILE.parent / 'requirements.txt', exclude=('tensorboard', 'thop'))

if opt.task in ('train', 'val', 'test'): # run normally
Expand Down

0 comments on commit 87cec3d

Please sign in to comment.