Skip to content

Commit

Permalink
fixed ssim evaluation concatenation issue
Browse files Browse the repository at this point in the history
  • Loading branch information
z-fabian committed Aug 13, 2020
1 parent 955b6a6 commit 2b442ad
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions fastmri/mri_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,8 @@ def validation_epoch_end(self, val_logs):
# handle aggregation for distributed case with pytorch_lightning metrics
metrics = dict(val_loss=0, nmse=0, ssim=0, psnr=0)
for fname in outputs:
output = torch.cat([out for _, out in sorted(outputs[fname])]).numpy()
target = torch.cat([tgt for _, tgt in sorted(targets[fname])]).numpy()
output = torch.stack([out for _, out in sorted(outputs[fname])], dim=0).numpy()
target = torch.stack([tgt for _, tgt in sorted(targets[fname])], dim=0).numpy()
metrics["nmse"] = metrics["nmse"] + evaluate.nmse(target, output)
metrics["ssim"] = metrics["ssim"] + evaluate.ssim(target, output)
metrics["psnr"] = metrics["psnr"] + evaluate.psnr(target, output)
Expand All @@ -227,13 +227,12 @@ def validation_epoch_end(self, val_logs):

num_examples = torch.tensor(len(outputs)).to(device)
tot_examples = self.TotExamples(num_examples)

log_metrics = {
f"metrics/{metric}": values / tot_examples
for metric, values in metrics.items()
}
print(log_metrics)
metrics = {metric: values / tot_examples for metric, values in metrics.items()}

return dict(log=log_metrics, **metrics)

def test_epoch_end(self, test_logs):
Expand Down

0 comments on commit 2b442ad

Please sign in to comment.