
How many and which GPUs are needed for training #1

Open
Liuyang829 opened this issue Jul 8, 2021 · 7 comments

@Liuyang829

Thanks for your great work! When I try to reproduce your code, I run into CUDA out-of-memory errors. Could you tell me which GPUs, and how many, your experiments were run on when training ResNet50 and ResNet18?

@yikaiw
Owner

yikaiw commented Jul 8, 2021

Thanks for your kind words. We use 8 V100 GPUs to train ResNet50 and ResNet18. Actually, 4 GTX1080 GPUs are enough for ResNet18.

@Liuyang829
Author

I tried training ResNet18 on 2 RTX3090 GPUs. It takes nearly 1.5 hours per epoch, which comes to over 7 days for 120 epochs. That is longer than our group can afford. Could you also tell me how long training takes for ResNet18 and ResNet50 in your setup? Thank you!

@yikaiw
Owner

yikaiw commented Jul 8, 2021

On 8 V100 GPUs, ResNet50 needs 4 days, and ResNet18 only needs 1 day. On 4 GTX1080 GPUs, ResNet18 needs about 2 days. Note that the ImageNet data should be stored on a solid-state drive (SSD), which speeds up training considerably (roughly 2x).

@Liuyang829
Author

Thank you very much!

@Liuyang829 reopened this Sep 27, 2021
@Liuyang829
Author

I still have some questions about the dataloader in the code. Why don't you apply transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) as is usual in other ImageNet classification tasks?
While modifying the RS-Net code, I ran into a problem: around the 8th or 10th epoch, the cross-entropy loss turns to nan. I am wondering whether the lack of data normalization causes this.

Also, could you please release the code for the "Tested at New Resolutions" experiment in your ablation study?

Thank you very much!

@yikaiw
Owner

yikaiw commented Sep 30, 2021

Hi, thanks for your interest.

Since we find that the performance without normalization already reaches SOTA, we do not apply it. The provided code does not produce a nan cross-entropy loss. If you modify the code and the loss becomes nan during training, try reducing the initial learning rate.
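For reference, here is a minimal sketch of the standard torchvision normalization being discussed, as it is typically composed into an ImageNet validation transform. This is illustrative only, not part of the released code, and whether adding it back affects the nan loss in a modified pipeline has not been verified:

# Standard ImageNet preprocessing with channel-wise normalization (illustrative sketch)
import torchvision.transforms as transforms

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),  # scales pixel values to [0, 1]
    normalize,              # then standardizes with the usual ImageNet statistics
])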

The code for testing at new resolutions is provided below; it basically calibrates the BNs according to the new resolution:

from __future__ import print_function

import os, sys, argparse
import warnings, random, shutil, time
from tqdm import tqdm
import pickle
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

import models.imagenet as customized_models
from utils import AverageMeter, mkdir_p
from utils.dataloaders import *
from tensorboardX import SummaryWriter

warnings.filterwarnings('ignore')

default_model_names = sorted(name for name in models.__dict__
    if name.islower() and not name.startswith("__")
    and callable(models.__dict__[name]))

customized_models_names = sorted(name for name in customized_models.__dict__
    if name.islower() and not name.startswith("__")
    and callable(customized_models.__dict__[name]))

for name in customized_models.__dict__:
    if name.islower() and not name.startswith("__") and callable(customized_models.__dict__[name]):
        models.__dict__[name] = customized_models.__dict__[name]

model_names = default_model_names + customized_models_names

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('-d', '--data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' +
                        ' | '.join(model_names) +
                        ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N',
                    help='mini-batch size (default: 256), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)',
                    dest='weight_decay')
parser.add_argument('-p', '--print-freq', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--world-size', default=-1, type=int,
                    help='number of nodes for distributed training')
parser.add_argument('--rank', default=-1, type=int,
                    help='node rank for distributed training')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
                    help='distributed backend')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')

parser.add_argument('--lr-decay', type=str, default='step',
                    help='mode for learning rate decay')
parser.add_argument('--step', type=int, default=30,
                    help='interval for learning rate decay in step mode')
parser.add_argument('--schedule', type=int, nargs='+', default=[30, 60, 85, 95, 105],
                    help='decrease learning rate at these epochs.')
parser.add_argument('--gamma', type=float, default=0.1,
                    help='LR is multiplied by gamma on schedule.')
parser.add_argument('--warmup', action='store_true',
                    help='set lower initial learning rate to warm up the training')

parser.add_argument('--cardinality', type=int, default=32, help='ResNeXt model cardinality (group).')
parser.add_argument('--base-width', type=int, default=4, help='ResNeXt model base width (number of channels in each group).')
parser.add_argument('--groups', type=int, default=3, help='ShuffleNet model groups')
parser.add_argument('--extent', type=int, default=0, help='GENet model spatial extent ratio')
parser.add_argument('--theta', dest='theta', action='store_true', help='GENet model parameterising the gather function')
parser.add_argument('--excite', dest='excite', action='store_true', help='GENet model combining the excite operator')

parser.add_argument('--sizes', type=int, nargs='+', default=[224, 192, 160, 128, 96],
                    help='input resolutions.')
parser.add_argument('--delta_size', type=int, default=0)
parser.add_argument('--cal-bn', action='store_true')
parser.add_argument('--kd', action='store_true',
                    help='build losses of knowledge distillation across scales')
parser.add_argument('-t', '--kd-type', metavar='KD_TYPE', default='topdown',
                    choices=['topdown', 'direct'])
parser.add_argument('--save-dicts', action='store_true')

args = parser.parse_args()
n_sizes = len(args.sizes)
assert args.delta_size >= 0
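# Shift every configured resolution by delta_size so the model is evaluated at new, unseen resolutions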
for i in range(n_sizes):
    args.sizes[i] += args.delta_size


def main():
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. This will turn on the CUDNN deterministic setting, which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting from checkpoints.')

    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        if args.arch.startswith('resnext'):
            model = models.__dict__[args.arch](
                    baseWidth=args.base_width,
                    cardinality=args.cardinality,
                )
        elif args.arch.startswith('shufflenetv1'):
            model = models.__dict__[args.arch](
                    groups=args.groups
                )
        elif args.arch.startswith('ge_resnet'):
            model = models.__dict__[args.arch](
                    extent=args.extent,
                    theta=args.theta,
                    excite=args.excite
                )
        elif args.arch.startswith('parallel') or args.arch.startswith('meta'):
            model = models.__dict__[args.arch](num_parallel=n_sizes)
        else:
            model = models.__dict__[args.arch]()
    
    if args.kd:
        alpha = nn.Parameter(torch.ones(n_sizes, requires_grad=True))
        model.register_parameter('alpha', alpha)

    if not args.distributed:
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    else:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    if os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        state_dict = checkpoint['state_dict']
        if args.save_dicts:
            d = dict(state_dict)
            for key in d.keys():
                d[key] = d[key].cpu().numpy()
            np.save('dict.npy', d)
            print('dict saved')
            return
        if args.delta_size != 0 and args.cal_bn:
            state_dict = cal_bn(args.delta_size, state_dict)
        if not args.kd and 'module.alpha' in state_dict:
            del state_dict['module.alpha']
        model.load_state_dict(state_dict, strict=False)
        print('# parameters:', sum(param.numel() for param in model.parameters()))
        print("=> loaded checkpoint '{}' (epoch {})"
                .format(args.resume, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))
        return

    cudnn.benchmark = True

    get_train_loader, get_val_loader = get_pytorch_train_loader, get_pytorch_val_loader
    val_loader, val_loader_len = get_val_loader(args.data, args.batch_size, args.sizes, workers=args.workers)
    validate(val_loader, val_loader_len, model, criterion)
    summary()

    return


def cal_bn(delta_size, state_dict):
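    # Calibrate the BN statistics for the shifted resolution by linearly interpolating
    # between the BN parameters of adjacent resolution branches (bn_i and bn_{i-1}).
    # alpha = delta_size / 32 is the interpolation weight, since the trained
    # resolutions are spaced 32 pixels apart (224, 192, 160, 128, 96).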
    statistics = ['weight', 'bias', 'running_mean', 'running_var']
    state_dict_copy = state_dict.copy()
    alpha = delta_size / 32
    for key in state_dict.keys():
        for s in statistics:
            for i in range(1, 5):
                if 'bn_%d' % i in key and s in key:
                    key_ = key.replace('bn_%d' % i, 'bn_%d' % (i - 1))
                    state_dict_copy[key] = state_dict[key_] * alpha + state_dict[key] * (1 - alpha)
    return state_dict_copy


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0)
            res.append(correct_k.mul_(100.0 / batch_size))
        res.append(correct)

        return res


def validate(val_loader, val_loader_len, model, criterion):
    top1 = [AverageMeter() for _ in range(n_sizes)]
    top5 = [AverageMeter() for _ in range(n_sizes)]
    top1_res = [[]] * n_sizes
    times = []

    # switch to evaluate mode
    model.eval()

    for i, (input, target) in tqdm(enumerate(val_loader), total=val_loader_len):
        target = target.cuda(non_blocking=True)

        with torch.no_grad():
            # compute output
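            # input is expected to be a list of tensors (one per resolution); the
            # model returns a list of logits, one entry per resolution branch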
            t = time.time()
            output = model(input)
            times.append(time.time() - t)
            # print('current step time:', times[-1], flush=True)
         
            for j in range(n_sizes):
                # measure accuracy and record loss
                acc1, acc5, correct = accuracy(output[j], target, topk=(1, 5))
                top1[j].update(acc1.item(), input[0].size(0))
                top5[j].update(acc5.item(), input[0].size(0))
                correct1 = correct[:1].cpu().numpy().tolist()
                top1_res[j] = top1_res[j] + correct1[0]
    print('mean step time:', np.mean(times))

    for j, size in enumerate(args.sizes):
        top1_avg, top5_avg = top1[j].avg, top5[j].avg
        print('size%03d: top1 %.2f, top5 %.2f' % (size, top1_avg, top5_avg))

    with open('top1_val_resnet18_shared_kd.bin','wb') as fp:
        pickle.dump(top1_res,fp)

    return [round(t.avg, 1) for t in top1], [round(t.avg, 1) for t in top5]


def summary():
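    # For every ordered pair of resolutions (i, j), report the fraction of validation
    # samples that resolution i gets wrong while resolution j gets right, i.e. how
    # complementary the predictions of the two resolution branches are.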
    with open('top1_val_resnet18_shared_kd.bin','rb') as fp:
        top1_res = pickle.load(fp)

    K = len(top1_res)
    N = len(top1_res[0])

    for i in range(K):
        for j in range(K):
            if i==j:
                continue
            cor_i = np.array(top1_res[i]).astype(np.float32)
            cor_j = np.array(top1_res[j]).astype(np.float32)
            cor_i = 1.0 - cor_i
            _ij = np.multiply(cor_i,cor_j)
            print("(%d,%d) = %f" % (i, j, _ij.sum() / N))


if __name__ == '__main__':
    main()
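A typical way to run it would look like the following (the script name, architecture name, and checkpoint path are placeholders; substitute the ones from the repository):

python test_new_resolution.py -d /path/to/imagenet -a <your_arch> \
    --resume /path/to/checkpoint.pth.tar \
    --sizes 224 192 160 128 96 --delta_size 16 --cal-bn

With --delta_size 16 and --cal-bn, the model is evaluated at resolutions 240, 208, 176, 144, and 112, and the BN statistics are interpolated accordingly before validation.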

@Liuyang829
Author

Thank you very much!
