Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix]Sync train api #115

Merged
merged 1 commit into from
Apr 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions mmrazor/apis/mmcls/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def train_mmcls_model(model,
train_dataset = dataset[0]
dataset[0] = split_dataset(train_dataset)

sampler_cfg = cfg.data.get('sampler', None)

# Difference from mmclassification.
# Build multi dataloaders according the splited datasets.
data_loaders = list()
Expand All @@ -79,7 +81,8 @@ def train_mmcls_model(model,
num_gpus=len(cfg.gpu_ids),
dist=distributed,
round_up=True,
seed=cfg.seed) for item_ds in dset
seed=cfg.seed,
sampler_cfg=sampler_cfg) for item_ds in dset
]
else:
data_loader = build_dataloader(
Expand All @@ -90,7 +93,8 @@ def train_mmcls_model(model,
num_gpus=len(cfg.gpu_ids),
dist=distributed,
round_up=True,
seed=cfg.seed)
seed=cfg.seed,
sampler_cfg=sampler_cfg)

data_loaders.append(data_loader)

Expand Down Expand Up @@ -120,6 +124,10 @@ def train_mmcls_model(model,
model = MMDataParallel(
model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
elif device == 'cpu':
warnings.warn(
'The argument `device` is deprecated. To use cpu to train, '
'please refers to https://mmclassification.readthedocs.io/en'
'/latest/getting_started.html#train-a-model')
model = model.cpu()
else:
raise ValueError(F'unsupported device name {device}.')
Expand Down
7 changes: 6 additions & 1 deletion mmrazor/apis/mmdet/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ def train_mmdet_model(model,
f'{cfg.data.imgs_per_gpu} in this experiments')
cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu

runner_type = 'EpochBasedRunner' if 'runner' not in cfg else cfg.runner[
'type']
data_loader = [
build_dataloader(
ds,
Expand All @@ -84,7 +86,10 @@ def train_mmdet_model(model,
# cfg.gpus will be ignored if distributed
num_gpus=len(cfg.gpu_ids),
dist=distributed,
seed=cfg.seed) for ds in dataset
seed=cfg.seed,
runner_type=runner_type,
persistent_workers=cfg.data.get('persistent_workers', False))
for ds in dataset
]

# put model on gpus
Expand Down