
[metrics] Automatic reduction of metrics from several validation steps #1249

Closed
Laksh1997 opened this issue Mar 26, 2020 · 8 comments

Labels: discussion (In a discussion stage) · feature (Is an improvement or enhancement) · help wanted (Open to be worked on)

Comments

@Laksh1997

🚀 Feature

As discussed on Slack, it could be useful to implement this. More detail below.

Motivation

To avoid the user having to do this:

logits = torch.cat([x['logits'] for x in output])
labels = torch.cat([x['labels'] for x in output])
and so on ...

Pitch

Something like this:

    def collate_metrics(self, output):
        """
        Collate the output dicts from several validation steps.
        """
        collated_output = {}
        keys = output[0].keys()
        for key in keys:
            tensor_dim = output[0][key].dim()
            if tensor_dim > 0:
                # Concatenate non-scalar tensors along the batch dimension
                collated_output[key] = torch.cat([x[key] for x in output])
            else:
                # Reduce scalars by mean (stack avoids re-wrapping 0-dim tensors)
                collated_output[key] = torch.stack([x[key] for x in output]).mean()
        return collated_output
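For illustration, the proposed hook could be exercised like this as a free function (a minimal sketch; the step dicts and key names here are made up):

```python
import torch

def collate_metrics(output):
    """Collate a list of per-step dicts: concatenate tensors, average scalars."""
    collated_output = {}
    for key in output[0].keys():
        if output[0][key].dim() > 0:
            collated_output[key] = torch.cat([x[key] for x in output])
        else:
            collated_output[key] = torch.stack([x[key] for x in output]).mean()
    return collated_output

# Two fake validation steps, each returning logits and a scalar loss
steps = [
    {"logits": torch.zeros(4, 2), "loss": torch.tensor(1.0)},
    {"logits": torch.ones(4, 2), "loss": torch.tensor(3.0)},
]
out = collate_metrics(steps)
print(out["logits"].shape)  # torch.Size([8, 2])
print(out["loss"].item())   # 2.0
```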

Alternatives

Alternatively, the user can just add the method above to their LightningModule and call it themselves.

@Laksh1997 Laksh1997 added feature Is an improvement or enhancement help wanted Open to be worked on labels Mar 26, 2020
@awaelchli
Member

I think this is cool. Things that come to my mind:

  • concatenate or stack seems reasonable as a default collate, but I would do it without the mean; that is too specific.
  • let the user override the collate method
  • if there is a collate, it should not only apply to validation, but also to training_end, test_end, right? Then the question is do we let the user override each of these?
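The "let the user override the collate method" idea could be sketched roughly like this (hypothetical class names, not an actual Lightning API; the default concatenates or stacks and deliberately applies no mean):

```python
import torch

class CollateMixin:
    """Hypothetical mixin providing a default, user-overridable collate hook."""

    def collate(self, outputs):
        # Default: concatenate >0-dim tensors, stack scalars (no reduction).
        return {
            k: (torch.cat([o[k] for o in outputs])
                if outputs[0][k].dim() > 0
                else torch.stack([o[k] for o in outputs]))
            for k in outputs[0]
        }

class MyModule(CollateMixin):
    def collate(self, outputs):
        # User override: reuse the default, then reduce the stacked losses.
        out = super().collate(outputs)
        out["loss"] = out["loss"].mean()
        return out

steps = [{"loss": torch.tensor(1.0)}, {"loss": torch.tensor(3.0)}]
print(MyModule().collate(steps)["loss"].item())  # 2.0
```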

@oplatek
Contributor

oplatek commented Mar 26, 2020

I think that a collate_metrics function would need at least:

  • filtering: control over which keys it is applied to
  • to be usable in several places: end of epoch, end of training, etc.

The code will quickly become messy.
I quite like this approach #973 (comment)
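The filtering point above might look something like this (hypothetical signature; the `keys` parameter is an assumption, not an agreed design):

```python
import torch

def collate_metrics(outputs, keys=None):
    """Collate only the requested keys; default to all keys in the first dict."""
    keys = list(outputs[0]) if keys is None else keys
    collated = {}
    for key in keys:
        vals = [o[key] for o in outputs]
        # Concatenate non-scalars, mean-reduce scalars
        collated[key] = torch.cat(vals) if vals[0].dim() > 0 else torch.stack(vals).mean()
    return collated

steps = [
    {"logits": torch.zeros(2, 3), "loss": torch.tensor(0.5)},
    {"logits": torch.ones(2, 3), "loss": torch.tensor(1.5)},
]
out = collate_metrics(steps, keys=["logits"])
print(sorted(out))  # ['logits']
```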

@Borda
Member

Borda commented Mar 26, 2020

I guess that if we add metrics as classes, as discussed in #973, we could define a custom reduction method for each one, right?
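A per-metric reduction on a metric class could look roughly like this (a standalone sketch, not the actual #973 design; the `Metric` API here is invented for illustration):

```python
import torch

class Metric:
    """Hypothetical metric class that carries its own reduction function."""

    def __init__(self, reduction=torch.mean):
        self.reduction = reduction
        self._values = []

    def update(self, value):
        self._values.append(torch.as_tensor(value))

    def compute(self):
        # Apply this metric's own reduction over all recorded steps.
        return self.reduction(torch.stack(self._values))

# Two metrics with different reductions
mean_acc = Metric(reduction=torch.mean)
max_err = Metric(reduction=torch.max)
for acc, err in [(0.8, 0.1), (0.6, 0.3)]:
    mean_acc.update(acc)
    max_err.update(err)
print(mean_acc.compute().item())  # ≈ 0.7
print(max_err.compute().item())   # ≈ 0.3
```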

@Borda Borda added the discussion In a discussion stage label Mar 26, 2020
@Borda Borda added this to the 0.7.3 milestone Mar 26, 2020
@Borda Borda modified the milestones: 0.7.4, 0.7.5 Apr 24, 2020
@Borda Borda modified the milestones: 0.7.6, 0.8.0, 0.7.7 May 13, 2020
@Borda Borda modified the milestones: 0.7.7, 0.8.0, Metrics May 26, 2020
@Borda Borda modified the milestones: 0.8.0, 0.9.0 Jun 9, 2020
@Borda
Member

Borda commented Jun 11, 2020

cc: @justusschock @SkafteNicki

@justusschock
Member

@Borda I think we will integrate this into a planned automated metric computation that also allows a different collation per metric.

@edenlightning
Contributor

@justusschock can we close this issue? was it fixed?

@edenlightning edenlightning changed the title Automatic reduction of metrics from several validation steps [metrics] Automatic reduction of metrics from several validation steps Jun 18, 2020
@justusschock
Member

@edenlightning not yet. We haven't gotten to implementing accumulation for metrics. That will be V2 of metrics.

@Borda Borda removed this from the 0.9.0 milestone Jun 18, 2020
@Borda Borda added this to the 0.8.x milestone Jun 18, 2020
@Borda Borda modified the milestones: 0.8.x, 0.9.0 Aug 6, 2020
@edenlightning edenlightning modified the milestones: 0.9.0, 0.9.x Aug 18, 2020
@SkafteNicki
Member

With PR #3245 merged, this is now solved. Each metric now has an aggregated property that contains the aggregated metric value over the data seen so far. In practice, you can use it like this in Lightning:

def validation_step(self, batch, batch_idx):
    x, y = batch
    ypred = self(x)
    loss = self.loss_fn(ypred, y)
    val = self.metric(ypred, y)
    return loss # no need to return the value of the metric

def validation_epoch_end(self, validation_step_outputs):
    aggregated_metric = self.metric.aggregated
    return aggregated_metric
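Conceptually, the aggregated property behaves like a running accumulator over batches. A toy standalone sketch of that idea (this is NOT Lightning's implementation; the class and its internals are invented for illustration):

```python
import torch

class RunningAccuracy:
    """Toy accumulator illustrating an 'aggregated so far' property."""

    def __init__(self):
        self._correct = 0
        self._total = 0

    def __call__(self, preds, target):
        # Per-batch value, like the one computed in validation_step.
        self._correct += (preds == target).sum().item()
        self._total += target.numel()
        return (preds == target).float().mean()

    @property
    def aggregated(self):
        # Value over all data seen so far.
        return self._correct / self._total

metric = RunningAccuracy()
metric(torch.tensor([1, 0, 1]), torch.tensor([1, 1, 1]))  # 2/3 correct
metric(torch.tensor([0, 0, 0]), torch.tensor([0, 0, 1]))  # 2/3 correct
print(metric.aggregated)  # 4 correct out of 6 seen
```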

Closing this issue.
