Dear Author,
I noticed that the `FDS.smooth` code contains:
`features[labels == label] = self.calibrate_mean_var(features[labels == label], ...)`
Isn't this an in-place operation, and does it cause a problem during the backward pass? `features` receives gradients from the regressor (decoder) during backward, so could this in-place write raise an error?
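A minimal sketch of how this pattern can break backward, assuming `features` comes from a layer that keeps its own output for the backward pass (a ReLU here). `calibrate` is a hypothetical, simplified stand-in for `calibrate_mean_var`, and the shapes, labels and statistics are made up for illustration:

```python
import torch


def calibrate(matrix, m1, v1, m2, v2):
    # Hypothetical, simplified stand-in for calibrate_mean_var:
    # standardize with (m1, v1), then rescale to (m2, v2).
    factor = torch.clamp(v2 / v1, 0.1, 10.0)
    return (matrix - m1) * torch.sqrt(factor) + m2


def smooth_inplace(features, labels, stats):
    # Same pattern as in FDS.smooth: overwrite the selected rows in place.
    for label in torch.unique(labels):
        m1, v1, m2, v2 = stats[int(label)]
        mask = labels == label
        features[mask] = calibrate(features[mask], m1, v1, m2, v2)
    return features


N, D, B = 8, 4, 3  # samples, feature dim, number of label buckets (made up)
torch.manual_seed(0)
labels = torch.randint(0, B, (N,))
stats = [(torch.zeros(D), torch.ones(D), torch.full((D,), 0.5), torch.full((D,), 2.0))
         for _ in range(B)]


def backward_raises(smooth_fn):
    # Returns True if autograd rejects the graph during backward.
    x = torch.randn(N, D, requires_grad=True)
    features = torch.relu(x)  # ReLU saves its output for the backward pass
    try:
        smooth_fn(features, labels, stats).sum().backward()
    except RuntimeError as err:
        print("backward failed:", str(err).splitlines()[0])
        return True
    return False


print(backward_raises(smooth_inplace))
```

Whether the error actually shows up depends on which layer produced `features`: the in-place write bumps the tensor's version counter, and autograd only complains if some node saved that exact tensor for its backward.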
Looking forward to your reply. Thanks!
I found that some issues from 2021 also pointed out this potential bug when people applied FDS to their own work. I think we could fix the problem easily with `torch.masked_scatter`, like this:
`features = torch.masked_scatter(features, (labels == label).unsqueeze(1).repeat(1, features.shape[1]), self.calibrate_mean_var(features[labels == label], self.running_mean_last_epoch[int(label - self.bucket_start)], self.running_var_last_epoch[int(label - self.bucket_start)], self.smoothed_mean_last_epoch[int(label - self.bucket_start)], self.smoothed_var_last_epoch[int(label - self.bucket_start)]))`
Is this right?
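A minimal sketch for checking the proposed fix, again with a hypothetical, simplified `calibrate` stand-in and made-up data. It assumes `labels` is 1-D (one bucket index per row); `mask.unsqueeze(1).expand_as(features)` builds the same mask as `.repeat(1, features.shape[1])` without copying. The sketch compares the out-of-place result against a plain in-place version run on a detached copy, and runs backward through a ReLU:

```python
import torch


def calibrate(matrix, m1, v1, m2, v2):
    # Hypothetical, simplified stand-in for calibrate_mean_var.
    factor = torch.clamp(v2 / v1, 0.1, 10.0)
    return (matrix - m1) * torch.sqrt(factor) + m2


def smooth_masked_scatter(features, labels, stats):
    # Out-of-place variant: each step returns a new tensor, so the tensor
    # saved by the producing layer is never modified.
    for label in torch.unique(labels):
        m1, v1, m2, v2 = stats[int(label)]
        mask = labels == label
        calibrated = calibrate(features[mask], m1, v1, m2, v2)
        # masked_scatter fills True positions in row-major order, consuming
        # `calibrated` row by row, matching the row order of features[mask].
        features = torch.masked_scatter(
            features, mask.unsqueeze(1).expand_as(features), calibrated)
    return features


def smooth_reference(features, labels, stats):
    # In-place version on a detached copy, used only to compare values.
    out = features.detach().clone()
    for label in torch.unique(labels):
        m1, v1, m2, v2 = stats[int(label)]
        mask = labels == label
        out[mask] = calibrate(out[mask], m1, v1, m2, v2)
    return out


N, D, B = 8, 4, 3  # samples, feature dim, number of label buckets (made up)
torch.manual_seed(0)
x = torch.randn(N, D, requires_grad=True)
labels = torch.randint(0, B, (N,))
stats = [(torch.zeros(D), torch.ones(D), torch.full((D,), 0.5), torch.full((D,), 2.0))
         for _ in range(B)]

features = torch.relu(x)  # ReLU saves its output for backward
out = smooth_masked_scatter(features, labels, stats)
out.sum().backward()  # no in-place modification error here

print(torch.allclose(out.detach(), smooth_reference(features, labels, stats)))
print(x.grad)
```

Gradients flow both to the untouched rows of `features` and, through the `source` argument, to the calibrated rows, so every row of `x` gets a gradient.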