Bringing together results on ddp on a single machine #702

Closed
brucemuller opened this issue Jan 17, 2020 · 38 comments · Fixed by #1017 or #2434

@brucemuller

I'd like to understand how to use ddp properly with multiple GPUs on a single machine as I'm unsure of how to bring results together using this method.

I'm using TensorBoard for logging

The problem seems to be that my code (below) runs on each of the three GPUs (with a third of the data each), but variables like "overall_correct" exist separately in each of the three processes, so only a third of the data gets logged. For example, my overall performance on a single GPU is 82%, but with this process on 3 GPUs it is a third of that. I know this is kind of a silly thing, but can someone explain how I should bring together the required validation/training statistics from the sub-processes using PyTorch Lightning?

My process is roughly:

model = MyModel(hparams)
tt_logger = TestTubeLogger(save_dir="path",name=expname)
trainer = Trainer(logger = tt_logger , gpus=3, distributed_backend='ddp' )
trainer.fit(model)

class MyModel(LightningModule):

    def __init__(self, hparams):
        super(MyModel, self).__init__()
        self.hparams = hparams
        self.resnet = ResNetEncoder(self.hparams)
        self.loss_meter_training = averageMeter()
        self.overall_correct = 0.0

    def training_step(self, batch, batch_i):
        ...
        self.loss_meter_training.update(float(total_loss))
        return {'loss': total_loss}

    def validation_step(self, batch, batch_nb):
        ...
        if something:
            self.overall_correct += 1.0  # count correct predictions
        return {'val_loss': total_loss}

    def validation_end(self, outputs):
        self.logger.experiment.add_scalar('epoch losses/training/total', self.loss_meter_training.avg, self.epoch_nb)
        self.logger.experiment.add_scalar('metrics/validation performance', self.overall_correct / 20000, self.epoch_nb)
        self.loss_meter_validation.reset()
        self.overall_correct = 0.0

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr)
        return optimizer

    @pl.data_loader
    def tng_dataloader(self):
        t_loader = PairLoader('dummy', self.hparams, split='training')
        dist_sampler = torch.utils.data.distributed.DistributedSampler(t_loader)
        trainloader = data.DataLoader(t_loader, batch_size=self.hparams.batch_size, sampler=dist_sampler, num_workers=12)
        return trainloader

    @pl.data_loader
    def val_dataloader(self):
        v_loader = PairLoader('dummy', self.hparams, split='validation')
        dist_sampler = torch.utils.data.distributed.DistributedSampler(v_loader)
        valloader = data.DataLoader(v_loader, batch_size=self.hparams.batch_size, sampler=dist_sampler, num_workers=12)
        return valloader
@brucemuller brucemuller added the question Further information is requested label Jan 17, 2020
@matthew-z
Contributor

matthew-z commented Jan 17, 2020

I don't know if pytorch-lightning has done anything special about it.

With pure PyTorch, you may use dist.all_gather to sync the validation score among workers.

For example, if you have 2 workers and each of them evaluated 2 examples, then you can use dist.all_gather to get the 4 scores and then compute the mean validation score.
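A minimal sketch of that idea with plain torch.distributed (not PL-specific; it assumes the process group is already initialised and every worker contributes the same number of scores):

import torch
import torch.distributed as dist

def mean_validation_score(local_scores):
    # each worker turns its local scores into a CUDA tensor ...
    t = torch.tensor(local_scores, dtype=torch.float, device="cuda")
    # ... gathers everyone's tensors ...
    gathered = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, t)
    # ... and computes the mean over all workers' examples
    return torch.cat(gathered).mean()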

@brucemuller
Author

@matthew-z thanks for your reply. That seems promising. I have also found this issue requesting similar behaviour: #243
I'm not sure how all_gather works yet. Should it be like:

    def validation_end(self, outputs):

        self.logger.experiment.add_scalar('epoch losses/training/total', self.loss_meter_training.avg, self.epoch_nb)
        self.logger.experiment.add_scalar('metrics/validation performance', self.overall_correct/20000, self.epoch_nb)
        self.loss_meter_validation.reset() 
        
        dist.all_gather(output_list, self.overall_correct)
        self.overall_correct = 0.0

If so, where would we put output_list when using pytorch lightning? Thanks for your help! I think it is something many new users will be confused about when it comes to ddp use

@matthew-z
Contributor

I didn't try it with PL, but in general you may try to sync data as follows with DDP:

def gather_list_and_concat(list_of_nums):
    tensor = torch.Tensor(list_of_nums).cuda()
    gather_t = [torch.ones_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(gather_t, tensor)
    return torch.cat(gather_t)

>>> res = gather_list_and_concat([1, 2])   # in node 1
>>> res = gather_list_and_concat([3, 4])  # in node 2
>>> res
torch.Tensor([1,2,3,4])

@sneiman
Contributor

sneiman commented Jan 19, 2020

Hi -
I have been working on the same issue: how to get the benefit of ddp with distributed sampling and still have accurate statistics. I expected to use reduce or gather from torch.distributed to pull things together, but after some investigation I found that this was not necessary. The key fact I found is that PL is not duplicating all the tensors in the model to each sub-process; they are in fact shared by all the sub-processes. I did not look at the source, but I expect only the parameters and buffers are being duplicated.

So there is no need to keep them in sync: they are all sharing the same tensor. You can create these in the model itself. I put them in a class that tracks a few other things and keep a reference to the instance in the model. Works the same.

There are still some challenges. You need to index the data by the batch number, but each sub-process's batch numbers start at 0. And they do not all have exactly the same number of batches, since the sampler has to adapt the total batch count when things don't divide evenly. What I did was add a dimension to capture the process rank (always 0-based). I also sized the storage on the batch count divided by the number of processes, rounding down, and did not worry about the few missed elements of data.

Of course, DistributedSampler will hand one or more sub-processes one or two more batches than you have allowed room for, so don't try to index a position in the tensor beyond the last valid index.
Computing means, std deviation etc. is no different, as torch doesn't care how many dimensions a tensor has when computing these.

The result statistics for me have been essentially identical to running on a single GPU, except for the time, which has gone from 15 minutes/epoch to 85 secs/epoch.

Hope this helps,

seth

@matthew-z
Contributor

I just read the source code, and actually PL doesn't do much modification to the original DDP beyond redirecting forward to training/validation/test_step.

In DDP, I don't think different sub-processes share the same tensor. Put a breakpoint in training_step and run the PyCharm debugger, and you will find that the loss values are different across sub-processes, and so are the validation outputs. The logger only records the rank-0 process.

@sneiman
Contributor

sneiman commented Jan 19, 2020

Sorry, my explanation was obviously insufficient. This definitely works; I just have to explain it better. I don't believe the logger's behavior is indicative of what happens with the tensor.

Try this: in the __init__() of your model class, create a test tensor with one value per GPU you are using. Also create a flag so that this is only done once per process:

    self.firstEpochFlag   = True
    self.testData         = torch.Tensor([0, 0, 0])  # 3 elements - 1 for each gpu used in this test

In on_epoch_start(), on the first epoch, write into self.testData and print it:

if self.firstEpochFlag:
    self.testData[self.trainer.proc_rank] = self.trainer.proc_rank+1
    print(f"testData {self.testData} {self.trainer.proc_rank}")
    self.firstEpochFlag = False

With 3 gpus you will see something like this:

testData tensor([0., 0., 3.]) 2 cpu
testData tensor([0., 2., 3.]) 1 cpu
testData tensor([1., 2., 3.]) 0 cpu

As you can see, the self.testData tensor is collecting ALL the values, and ALL the values are seen by EACH process. And that tensor is on the cpu. If there were one in each sub-process we would see:

testData tensor([0., 0., 3.]) 2 cuda:2
testData tensor([0., 2., 0.]) 1 cuda:1
testData tensor([1., 0., 0.]) 0 cuda:0

I can only conclude that there is only one copy of it and all the sub-processes can see it, or it is coalesced by ddp for us.

The key issue here is that you can use this to produce a process-rank-indexed tensor of result values that you store yourself at the end of training_step() and use to calculate statistics.

This works. Before I did this, I had the same problem you had: loss values only recorded for the 0 gpu, and therefore of questionable value. With this, the statistics make sense, and are essentially identical to what I get if I run on a single gpu.

I hope this is useful to you,

seth

FWIW - there are some behaviors here that bear thinking hard about. As already noted, the printed testData contents only make sense if there is only one copy, or if ddp merges this tensor for us. Interestingly, though, there must be 3 copies of firstEpochFlag, or the printing would only have been done once, since it is set to False the first time it is seen. The choices made in DDP about which elements of the model class are copied to sub-processes must be pretty well thought through.

Again, hope this helps. I can show you my entire model class if you like - there is a lot in it ...

seth

@sneiman
Contributor

sneiman commented Jan 20, 2020

I have verified that there is one testData tensor per process, so ddp must be coalescing them.

@williamFalcon
Contributor

Lightning only routes forward to training_step, validation_step. BUT, we should be calling dist.all_gather with the outputs of training_step and validation_step if we want full-batch metrics.

so, maybe someone wants to submit a PR for this?
@sneiman @brucemuller @matthew-z
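For illustration, a minimal sketch of that idea in the 0.6-era validation_end hook (assuming DDP is initialised and each rank sees the same number of validation batches, so a mean of per-rank means equals the global mean):

import torch
import torch.distributed as dist

def validation_end(self, outputs):
    # mean of this rank's validation losses
    local_mean = torch.stack([o['val_loss'] for o in outputs]).mean()
    # gather the per-rank means and average them into a full-dataset metric
    gathered = [torch.zeros_like(local_mean) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, local_mean)
    return {'val_loss': torch.stack(gathered).mean()}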

@sneiman
Contributor

sneiman commented Jan 24, 2020

My solution works for me, and has the benefit of allowing me to compute any statistics I like on the results. Of course, my model code has to know a little about doing certain things only on proc 0 ...

@jmarsil

jmarsil commented Feb 5, 2020

@sneiman I'm facing a similar issue. Could you show the entire model class?

Aside: when I try to pass a dictionary of tensors into progress_bar or even log, I keep getting the error below (training using 'ddp' on 4 GPUs). I also have a follow-up question: why can't we pass dictionaries of tensors like with 'dp'?

Here is the error:

[screenshot of the error traceback]

Let me know your thoughts!

@sneiman
Contributor

sneiman commented Feb 5, 2020

If I am not mistaken, your error above is because in v.item(), v has more than one value, so item() cannot convert it to a Python scalar. I do not use most of the metrics and progress bar features of pytorch-lightning, but I feel pretty certain that is your problem here.

I will post more of my code in the next post so you can see how I collect information. The overall idea is pretty simple: declare a multi-dimensional tensor to hold the information you want to collect, with one dimension per GPU, then have each GPU write its data into its own portion of the tensor as it runs. PL will do the reduce to coalesce them. You still have to be mindful of multiprocessing. Example to come.

@jmarsil

jmarsil commented Feb 5, 2020

Ahh, thanks for doing that. Right, that would make sense. Since each GPU has its own process, it probably passes a tensor of length n_gpus to the callback step, correct? In the meantime I will try to implement something and update you on how it goes.

@sneiman
Contributor

sneiman commented Feb 5, 2020

Let's see if this will help:

Within my model I have an instance of a class (cleverly named 'data') which holds the data loaders and does the work of breaking the data into train, validation and test sets. Because statistics and history change along with the data format, I keep all my history and statistics calculations in the same class. It gets tweaked whenever I change the data, or use the model in a way that affects the statistics or data collected.

In this class I create my history tensors, and related statistic variables:


class data():
    def __init__(self, trn_ratio=.6, tot_batch_size=16, num_gpus=1, back_end='dp', pct_to_use=1.0):

        self.tot_batch_size     = tot_batch_size
        self.num_gpus           = num_gpus
        self.gpu_batch_size     = self.tot_batch_size//self.num_gpus if back_end=='dp' else self.tot_batch_size

        # req'd transforms
        self.basictransform     = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        # load basic data set and generate subsets
        # get len and calc end index for each subset
        ds_name                 = f'GUAC_faces_square_w125_dataset_classes_{num_classes}'
        guac_dataset            = GUAC('./data/GUAC', ds_name, transform=self.basictransform, target_transform=None, labelmask=None)

        # create end points for each subset
        dst_end                 = int(len(guac_dataset)*pct_to_use)
        trn_end                 = int(trn_ratio*dst_end)
        val_end                 = int(((1.0-trn_ratio)/2.0)*dst_end)+trn_end

        # create subsets, get batch counts for each subset, max idx for each subset

        self.trn_batch_cnt      = len(self.trn_dataset)//self.gpu_batch_size
        self.val_batch_cnt      = len(self.val_dataset)//self.gpu_batch_size
        self.tst_batch_cnt      = len(self.tst_dataset)//self.gpu_batch_size

        self.max_trn_idx        = self.trn_batch_cnt//self.num_gpus
        self.max_val_idx        = self.val_batch_cnt//self.num_gpus
        self.max_tst_idx        = self.tst_batch_cnt//self.num_gpus

        ..... skipping stuff you don't need to see
        # HERE is what you are looking for ... I included the rest because the code refers to some of the calculated values above

        # just showing some of it and for training only to keep things simple ...
        # per batch loss, accuracy history - cleared every epoch
        # epoch history by batch
        self.trn_batch_loss     = torch.Tensor(self.num_gpus, self.max_trn_idx, self.gpu_batch_size).fill_(0.0)
        self.trn_batch_acc      = torch.Tensor(self.num_gpus, self.max_trn_idx, self.gpu_batch_size, 2).fill_(0.0)

        self.trn_cw_cor         = torch.Tensor(num_classes).fill_(0.0)
        self.trn_ch_cor         = torch.Tensor(num_classes).fill_(0.0)

        self.trn_avg_loss       = 0.0
        self.trn_avg_acc        = 0.0
        self.trn_avg_acc_w      = 0.0
        self.trn_avg_acc_h      = 0.0
        self.trn_acc_std_w      = 0.0
        self.trn_acc_std_h      = 0.0

    # this function is called at the end of every training step to collect the info
    # loss tensor, accuracy in width, accuracy in height, the gpu as an index into history, AND idx - one entry per batch
    # notice that I copy the information, so that I am not keeping tensors around across iterations
    # without this, and a few other things, you create a memory leak
    # whypr
    def cap_trn_batch(self, loss, acc_w, acc_h, gpu, idx):
        if idx<self.max_trn_idx:
            for i in range(len(acc_w)):
                self.trn_batch_loss[gpu][idx][i]   = loss.item()
                self.trn_batch_acc[gpu][idx][i,0]  = acc_w[i].item()
                self.trn_batch_acc[gpu][idx][i,1]  = acc_h[i].item()

    # function to clear the tensors - just fill them so we don't reallocate
    def clear_trn_results(self):
        self.trn_batch_loss.fill_(0.0)
        self.trn_batch_acc.fill_(0.0)

    # called once an epoch to generate statistics
    def calc_epoch_stats(self):

        self.trn_avg_loss   = self.trn_batch_loss.abs().mean()
        self.trn_avg_acc    = self.trn_batch_acc.abs().mean()
        self.trn_avg_acc_w  = self.trn_batch_acc[:, :, :, 0].mean()
        self.trn_avg_acc_h  = self.trn_batch_acc[:, :, :, 1].mean()
        self.trn_acc_std_w  = self.trn_batch_acc[:, :, :, 0].std()
        self.trn_acc_std_h  = self.trn_batch_acc[:, :, :, 1].std()

OK, that's it... let's look at training_step:

   # called once for each training batch
    def training_step(self, batch, batch_nb):

        # fwd, loss, acc: classes and ypr
        imgs, lab_w, lab_h  = batch[0], batch[1][:,0].long(), batch[1][:,1].long()
        lab_y               = batch[1][:,2]
        lab_p               = batch[1][:,3]
        lab_r               = batch[1][:,4]
        out_w, out_h, p     = self.forward(imgs)

        trn_loss_w          = self.crit_w(out_w,  lab_w)
        trn_loss_h          = self.crit_h(out_h,  lab_h)*1.5
        trn_loss_y          = self.crit_p(p[:,0], lab_y)*.75
        trn_loss_p          = self.crit_p(p[:,1], lab_p)*1.5
        trn_loss_r          = self.crit_p(p[:,2], lab_r)
        trn_loss            = trn_loss_w+trn_loss_h+trn_loss_y+trn_loss_p+trn_loss_r

        acc_w               = lab_w==out_w.argmax(1)
        acc_h               = lab_h==out_h.argmax(1)

        # track acc by class
        for i in range(len(acc_w)):
            self.data.trn_cw_cor[lab_w[i]] += acc_w[i].item()
            self.data.trn_ch_cor[lab_h[i]] += acc_h[i].item()

        # capture data - adjust gpu arg to be zero based to store in data arrays
        dev                 = batch[0].device.index
        self.data.cap_trn_batch(trn_loss.detach(), acc_w.detach(), acc_h.detach(), dev-self.base_gpu_num, batch_nb)



        del imgs, lab_w, lab_h, lab_y, lab_p, lab_r, out_w, out_h, p, acc_w, acc_h, trn_loss_w, trn_loss_h, trn_loss_y, trn_loss_p, trn_loss_r
        return {'loss': trn_loss}

I think you can see the idea here. Create history TENSORS as part of your model - being in an instance that is part of your model is just as good. Write functions to put data in them, clear them and to calculate statistics. Put the data in them after every step. Calculate and use the results as you need to - generally once an epoch. BE CAREFUL to detach those tensors before you make the cap call - or else you will have a memory leak. I also found that I needed to delete them to avoid memory leaks.

Because you are using distributed processing, you need to force the processes to all finish their calculations before you clear the data. I do it at the end of on_epoch_start:

        .......

        # barrier() causes all the processes to wait for each other
        # only needed for ddp because other approaches have a single process -
        # I think this would also be needed for ddp2, but I don't use it so have not allowed for it
        if self.back_end == 'ddp': torch.distributed.barrier()
        self.data.clear_all_results()

I think you can get it from here ....

@jmarsil

jmarsil commented Feb 5, 2020

Thanks for this @sneiman, the fix is making a lot more sense now. I will implement a version of this for my auto-segmentation pipeline and can share it when completed.

A question regarding that torch.distributed.barrier() call: is it called when validation_end runs? If it is, then theoretically you could log training metrics and save when you call that? That could work, but only if you need a validation step.

It would be great if lightning could do something like this automatically when the 'ddp' backend is used! I think engineering an internal solution like the one you presented would save a lot of headaches! Thoughts, @williamFalcon?

@sneiman
Contributor

sneiman commented Feb 5, 2020 via email

@jmarsil

jmarsil commented Feb 6, 2020

Interesting, thanks for the explanation! I'm working to develop a solution for 2D/3D segmentation. Have you run into hurdles getting your ddp script to actually train? Mine seems to freeze 1% into the first epoch.

@sneiman
Contributor

sneiman commented Feb 6, 2020 via email

@williamFalcon
Contributor

@brucemuller mind submitting a PR for this?

Ideally we bring results back for training_step and validation_step?

This has been on my radar for a bit but I haven't had time to fix it. Would love a PR!

@brucemuller
Author

@williamFalcon I think using @matthew-z's code worked for me:

def gather_list_and_concat(list_of_nums):
    tensor = torch.Tensor(list_of_nums).cuda()
    gather_t = [torch.ones_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(gather_t, tensor)
    return torch.cat(gather_t)

>>> res = gather_list_and_concat([1, 2])   # in node 1
>>> res = gather_list_and_concat([3, 4])   # in node 2
>>> res
torch.Tensor([1,2,3,4])

I'm not sure I have time for a PR atm but would like to. If I create a fork, would someone like to help make the changes (@sneiman @jmarsil)? I haven't followed everything in this thread, but would dist.all_gather be sufficient to keep pl light enough?

Side note: I used ddp on 4 GPUs recently and it was actually slower than using 1 GPU. Does anyone know how I can start understanding why, or whether there are any common reasons?

@matthew-z
Contributor

What do you mean by "slow"?

The speed (e.g., n iter/sec) shown in the progress bar is the speed of each node. n is usually lower in multi-GPU mode, as the real speed is n * world_size.

If you mean something like "it takes more time to train an epoch", then you need to figure out what the processes are waiting for, e.g. data loading, parameter syncing, or a very slow node.

@williamFalcon
Contributor

@brucemuller that could happen if you're not using a distributed sampler. We updated the docs with more details, check them out:

https://pytorch-lightning.readthedocs.io/en/latest/multi_gpu.html

@sneiman
Contributor

sneiman commented Feb 13, 2020


Hi all -

My experience suggests that an all_reduce() call is being applied already. I have not had any time to look for where in the code this is happening, but the reason my approach works is that the history tensors I have been creating are all identical after every epoch. I am literally doing nothing except:

1. declaring them when I create the model - in fact, I suspect this approach works because somewhere in pl or pytorch, ALL tensors that are attributes of the model class are identified and reduced
2. filling them in, indexed by a 0-based dev number, at every step - the values for each device are stored in their own dim 0 slot of the history tensor
3. calculating statistics as needed in on_epoch_end and test_end - of course this has to be aware that each device's statistics are stored at different dim 0 indices in the history tensor

I can imagine doing some things with decorators to make this a little less tedious - but not a lot more. Is there something I am missing about how this could be better?

@sneiman
Contributor

sneiman commented Feb 14, 2020

Wanted to share an issue/solution that has surfaced in trying to bring objects from ddp spawned sub processes back to the master process.

The issue is this: if you create a python multiprocessing Queue or SimpleQueue in the normal way, and attempt to use it in a process created by torch.multiprocessing.spawn(), as pytorch-lightning does, you will crash with a SIGSEGV.

The solution is to create the desired queue using the spawn context object for multiprocessing, but spawn the processes using the torch.multiprocessing module. Parroting the typical multiprocessing example:

import torch.multiprocessing as mp

def f(i, q):
    if q and not q.empty():
        print(f"{q.get()}")
        q.put(['goodbye'])

if __name__ == '__main__':
    # q       = mp.SimpleQueue()          # the normal way: DONT do this - leads to SIGSEGV

    # do this instead
    spawn_ctx = mp.get_context('spawn')   # get the spawn context
    q         = spawn_ctx.SimpleQueue()   # use it to create the Queue/SimpleQueue

    q.put(['hello'])
    p         = mp.spawn(f, (q,))         # but spawn it with torch.multiprocessing
    print(f"{q.get()}")

I have not tested, but I suspect the same is true for pipes and any other multiprocessing connection. Also note, I have tested this on Ubuntu 18.04, python 3.6.8, pytorch 1.4, pytorch-lightning 0.6 using a multi-core, multi-gpu machine, but not on a multi-node setup.

This happens with torch.multiprocessing.spawn() because the spawn call uses the same type of context to start the processes it launches. I think this is reasonable behavior, as the start method changes how sharing works between created threads, forks and processes - though I have not looked at that particular underlying code.

You cannot use spawn_ctx.spawn() because the context - which comes from the underlying python multiprocessing module - does not have a spawn method. This method is only part of torch.multiprocessing.

Still working on how to use this to get the stateful model and trainer back from a ddp job. There are a lot of complications because things have to be picklable to get through the queue. But at least there is a tunnel ...

For the curious, SIGSEGV is a kernel notice of a segmentation violation. Our spawned process's attempt to use q causes the process to use memory that it has no right to. The torch.multiprocessing.spawn call uses its own join and a wrapper to catch the exception and avoid a segfault.

@williamFalcon @Borda @brucemuller @matthew-z @jmarsil - I have collected a few tricks for using ddp: the history tensors above, this one, getting input from a sub-process, and perhaps a few others. Do you think it worthwhile to do a little write-up about them?

@williamFalcon
Contributor

@brucemuller @mikerossgithub @sneiman try master now? Just merged #2434, which does a reduce-mean op across all GPUs for anything you return from val_epoch_end.

@s-rog
Contributor

s-rog commented Jul 1, 2020

Possibly related? I got this error at validation end after upgrading to master from 0.8.1:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 20, in _wrap
    fn(i, *args)
  File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/distrib_data_parallel.py", line 538, in ddp_train
    self.run_pretrain_routine(model)
  File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 1158, in run_pretrain_routine
    self.train()
  File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/training_loop.py", line 370, in train
    self.run_training_epoch()
  File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/training_loop.py", line 470, in run_training_epoch
    self.run_evaluation(test_mode=False)
  File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 408, in run_evaluation
    eval_results = self._evaluate(self.model, dataloaders, max_batches, test_mode)
  File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 346, in _evaluate
    self.reduce_eval_ddp(eval_results)
  File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 363, in reduce_eval_ddp
    self.reduce_eval_ddp(v)
  File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 365, in reduce_eval_ddp
    dist.all_reduce(v, op=dist.reduce_op.SUM)
  File "/opt/conda/lib/python3.6/site-packages/torch/distributed/distributed_c10d.py", line 898, in all_reduce
    work = _default_pg.allreduce([tensor], opts)
RuntimeError: Tensors must be CUDA and dense

@williamFalcon
Contributor

what are you returning from your validation_epoch_end?

@s-rog
Contributor

s-rog commented Jul 1, 2020

one of the tensors returned for logging is a cpu tensor... now they have to be cuda tensors... intended? #2442

@Borda
Member

Borda commented Jul 1, 2020

one of the tensors returned for logging is a cpu tensor... now they have to be cuda tensors... intended? #2442

Logging does not like anything other than CPU, I guess...
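A workaround sketch for that error (assuming the 0.8-style dict return with a 'log' key): keep the tensors you return on the GPU so the cross-process reduction can run, and convert to CPU values only where you log manually.

def validation_epoch_end(self, outputs):
    # stacking GPU tensors keeps the result on the GPU, which dist.all_reduce needs;
    # use .item()/.cpu() only when handing values to a logger yourself
    val_loss = torch.stack([o['val_loss'] for o in outputs]).mean()
    return {'val_loss': val_loss, 'log': {'val_loss': val_loss}}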

@cattaneod

cattaneod commented Oct 14, 2020

I have a related question:
I'm trying to train a network for place recognition, so I need to gather the embeddings (the network's outputs) of all validation samples in a single process to create a KDTree.
Without PyTorch Lightning I solved it this way:

emb_list = []
for batch_idx, sample in enumerate(validation_dataloader):
    emb = model(sample)
    dist.barrier()
    out_emb = [torch.zeros_like(emb) for _ in range(world_size)]
    dist.all_gather(out_emb, emb)
    if rank == 0:
        interleaved_out = torch.empty((emb.shape[0]*world_size, emb.shape[1]),
                                                         device=emb.device, dtype=emb.dtype)
        for current_rank in range(world_size):
            interleaved_out[current_rank::world_size] = out_emb[current_rank]
        emb_list.append(interleaved_out.detach().clone())
if rank == 0:
    emb_list = torch.cat(emb_list)
    # Create KDTree and compute recall

The interleaved_out is needed because the distributed sampler distributes the dataset in the following way (assuming 3 GPUs):
GPU0 will process samples [0, 3, 6 , ...]
GPU1 will process samples [1, 4, 7 , ...]
GPU2 will process samples [2, 5, 8 , ...]

Is there any way to do the same in Pytorch Lightning?

@cattaneod

cattaneod commented Oct 15, 2020

I was able to achieve the same in PyTorch Lightning by calling dist.all_gather() inside validation_epoch_end; however, this way I can only use ddp training, and I lose some nice PyTorch Lightning features.

I think it would be nice to provide a hook that gathers all the validation_step outputs on one machine, regardless of the backend.
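For reference, a rough sketch of what that can look like (names are illustrative; it assumes DDP with an initialised process group and that every rank returns the same number of embeddings):

# inside the LightningModule (torch and torch.distributed as dist imported at module level)
def validation_step(self, batch, batch_idx):
    return {'emb': self(batch)}

def validation_epoch_end(self, outputs):
    emb = torch.cat([o['emb'] for o in outputs])                # this rank's embeddings
    gathered = [torch.zeros_like(emb) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, emb)
    if dist.get_rank() == 0:
        all_emb = torch.cat(gathered)                           # re-interleave as in the snippet above if ordering matters
        # build the KDTree and compute recall on rank 0 only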

@williamFalcon
Contributor

good point. this is a new feature, mind opening a new GH issue?

@cattaneod

good point. this is a new feature, mind opening a new GH issue?

Sure, I just created a new issue: #4175

@neergaard

Would this also support objects other than tensors from the validation_step and test_step methods, say lists of strings?
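Not a PL feature at the time of this thread, but plain PyTorch (1.8+) has torch.distributed.all_gather_object for arbitrary picklable objects such as lists of strings; a minimal sketch:

import torch.distributed as dist

def gather_strings(local_strings):
    # every rank contributes its own list; afterwards each rank holds all of them
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, local_strings)
    return [s for per_rank in gathered for s in per_rank]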

@Borda Borda added feature Is an improvement or enhancement and removed question Further information is requested labels Dec 23, 2020
@Arij-Aladel

I was able to achieve the same in PyTorch Lightning by calling dist.all_gather() inside validation_epoch_end; however, this way I can only use ddp training, and I lose some nice PyTorch Lightning features.

I think it would be nice to provide a hook that gathers all the validation_step outputs on one machine, regardless of the backend.

@cattaneod tell me please, does dist.all_gather() gather all results?

@cattaneod

I was able to achieve the same in PyTorch Lightning by calling dist.all_gather() inside validation_epoch_end; however, this way I can only use ddp training, and I lose some nice PyTorch Lightning features.
I think it would be nice to provide a hook that gathers all the validation_step outputs on one machine, regardless of the backend.

@cattaneod tell me please, does dist.all_gather() gather all results?

Yes, the code I provided in my previous comment works fine with DistributedDataParallel (I tested only on a single machine with multiple GPUs, though).
In that code, emb_list at the end contains all the outputs from all the processes.

@ZhiyuanChen

I didn't try it with PL, but in general you may try to sync data as follow with DDP:

def gather_list_and_concat(list_of_nums):
    tensor = torch.Tensor(list_of_nums).cuda()
    gather_t = [torch.ones_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(gather_t, tensor)
    return torch.cat(gather_t)

>>> res = gather_list_and_concat([1, 2])   # in node 1
>>> res = gather_list_and_concat([3, 4])  # in node 2
>>> res
torch.Tensor([1,2,3,4])

This works only when the tensor on every device has the same size. Otherwise an additional all_gather is required first to exchange the lengths, in order to get a proper result.
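A sketch of that length-exchange-and-pad approach with plain torch.distributed (1-D tensors for simplicity; assumes the process group is initialised):

import torch
import torch.distributed as dist

def all_gather_variable_length(t):
    world = dist.get_world_size()
    # 1) exchange lengths
    local_len = torch.tensor([t.numel()], device=t.device)
    lens = [torch.zeros_like(local_len) for _ in range(world)]
    dist.all_gather(lens, local_len)
    max_len = int(torch.stack(lens).max())
    # 2) pad to the common maximum and gather
    padded = torch.zeros(max_len, device=t.device, dtype=t.dtype)
    padded[:t.numel()] = t
    gathered = [torch.zeros_like(padded) for _ in range(world)]
    dist.all_gather(gathered, padded)
    # 3) trim each chunk back to its true length
    return torch.cat([g[:int(l)] for g, l in zip(gathered, lens)])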

@wwx13

wwx13 commented Sep 28, 2021

(quoting the gather_list_and_concat example, @ZhiyuanChen's note about equal tensor sizes, and @cattaneod's KDTree gathering code from earlier in the thread)

I wonder: what if the dataloaders in different processes have a different number of batches to iterate? Will dist.barrier() get stuck and never return?

@wwx13

wwx13 commented Sep 28, 2021

(quoting the same earlier comments again, along with my previous question about dist.barrier())

My code shows that it does get stuck at dist.barrier(). How can I solve this? Can anyone help me?
