Skip to content

Commit

Permalink
Add rgb-mean and rgb-std arguments (#1546)
Browse files Browse the repository at this point in the history
  • Loading branch information
anmarques committed May 5, 2023
1 parent 641c830 commit 48a0113
Showing 1 changed file with 24 additions and 0 deletions.
24 changes: 24 additions & 0 deletions src/sparseml/pytorch/torchvision/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,8 @@ def load_data(traindir, valdir, args):
traindir,
presets.ClassificationPresetTrain(
crop_size=train_crop_size,
mean=args.rgb_mean,
std=args.rgb_std,
interpolation=interpolation,
auto_augment_policy=auto_augment_policy,
random_erase_prob=random_erase_prob,
Expand All @@ -289,6 +291,8 @@ def load_data(traindir, valdir, args):
else:
preprocessing = presets.ClassificationPresetEval(
crop_size=val_crop_size,
mean=args.rgb_mean,
std=args.rgb_std,
resize_size=val_resize_size,
interpolation=interpolation,
)
Expand Down Expand Up @@ -1212,6 +1216,26 @@ def new_func(*args, **kwargs):
"Note: Will be read from the checkpoint if not specified"
),
)
@click.option(
"--rgb-mean",
nargs=3,
default=(0.485, 0.456, 0.406),
type=float,
help=(
"RGB mean values used to shift input RGB values; "
"Note: Will use ImageNet values if not specified."
),
)
@click.option(
"--rgb-std",
default=(0.229, 0.224, 0.225),
nargs=3,
type=float,
help=(
"RGB standard-deviation values used to normalize input RGB values; "
"Note: Will use ImageNet values if not specified."
),
)
@click.pass_context
def cli(ctx, **kwargs):
"""
Expand Down

0 comments on commit 48a0113

Please sign in to comment.