Skip to content

Commit

Permalink
Vertically align the norm outputs for the log_norms in the constituency parser --log_norms. Also, count the zeros (or close to zero) terms from the tensors as a signal of how many neurons are dying
Browse files Browse the repository at this point in the history
  • Loading branch information
AngledLuffa committed Aug 30, 2023
1 parent 5c84a29 commit e4ce3c9
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion stanza/models/constituency/lstm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,9 +667,15 @@ def get_norms(self):
lines.append("reduce_linear:")
for c_idx, c_open in enumerate(self.constituent_opens):
lines.append(" %s weight %.6g bias %.6g" % (c_open, torch.norm(self.reduce_linear_weight[c_idx]).item(), torch.norm(self.reduce_linear_bias[c_idx]).item()))
max_name_len = max(len(name) for name, param in self.named_parameters() if param.requires_grad and name not in skip)
max_norm_len = max(len("%.6g" % torch.norm(param).item()) for name, param in self.named_parameters() if param.requires_grad and name not in skip)
format_string = "%-" + str(max_name_len) + "s norm %" + str(max_norm_len) + "s zeros %d / %d"
print(format_string)
for name, param in self.named_parameters():
if param.requires_grad and name not in skip:
lines.append("%s %.6g" % (name, torch.norm(param).item()))
zeros = torch.sum(param.abs() < 0.000001).item()
norm = "%.6g" % torch.norm(param).item()
lines.append(format_string % (name, norm, zeros, param.nelement()))
return lines

def log_norms(self):
Expand Down

0 comments on commit e4ce3c9

Please sign in to comment.