masking compatible with fullgraph compile #91

Open · wants to merge 3 commits into master
Conversation

theAdamColton (Contributor)

This adds some slightly confusing masking code, but it improves speed by 3x by keeping the shapes of intermediate tensors static (non-dynamic). The masked_mean code is equivalent, up to floating-point precision, to the old code that used tensor indexing.

Previously, using LFQ with masking was not compatible with torch.compile with fullgraph=True or with dynamic=False. It did work with plain torch.compile, but the masked tensor indexing caused graph breaks.

I added an example that uses masked sequences to make sure it works properly.

I ran a benchmark using the masking example code on a 3090 GPU:

  • the previous masked LFQ implementation, using torch.compile(model, fullgraph=False, mode='max-autotune'), had an average model.forward time of 1.18 milliseconds
  • with this commit, using torch.compile(model, fullgraph=True, mode='max-autotune'), the average time is 0.40 milliseconds

The speedup might be worth the extra complexity in the code.
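
Roughly, the idea is the following (a simplified sketch, not the actual diff; static_mean stands in for the new masked_mean):

import torch

def indexed_mean(x, mask):
    # x[mask] has a data-dependent length, which forces dynamic shapes / graph breaks
    return x[mask].mean()

def static_mean(x, mask, eps = 1e-5):
    # same result up to floating-point precision, but every intermediate shape is static
    mask = mask.to(x.dtype)
    return (x * mask).sum() / mask.sum().clamp(min = eps)

x = torch.randn(4, 16)
mask = torch.rand(4, 16) > 0.5
assert torch.allclose(indexed_mean(x, mask), static_mean(x, mask), atol = 1e-5)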

@lucidrains (Owner)

ah yea, that does look a bit confusing, needs a tiny bit more work

do you think you can try fitting all the logic into one function, masked_mean, where if mask is None, it simply takes a regular .mean()?
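
something roughly like this is what I have in mind (just a sketch, names and dims are placeholders):

def masked_mean(t, mask = None, dim = 1, eps = 1e-5):
    # no mask -> plain mean over the sequence dimension
    if mask is None:
        return t.mean(dim = dim)

    mask = mask.to(t.dtype)
    # zero out masked positions, then divide by the number of kept positions
    num = (t * mask.unsqueeze(-1)).sum(dim = dim)
    den = mask.sum(dim = dim, keepdim = True).clamp(min = eps)
    return num / den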

@lucidrains (Owner)

we can reassess after your refactor

@lucidrains (Owner)

@theAdamColton have you tried the updated LFQ? curious how you got good results on the previous broken one

@theAdamColton (Contributor, Author)

With the previous LFQ, I set the entropy loss and commit loss to very low weights and it did actually work.

@theAdamColton (Contributor, Author)

I've also been experimenting with the entropy loss from MaskGIT; it does it slightly differently than the current LFQ code here. The one there seems to work pretty well.

@theAdamColton (Contributor, Author)

Also, this is a different issue, but here where the entropy is computed, I think it should maybe use F.log_softmax to compute the log probs directly from the distances, instead of taking the log of the probs to get the log probs.

@lucidrains (Owner)

Also, this is a different issue, but here where the entropy is computed, I think it should maybe use F.log_softmax to compute the log probs directly from the distances, instead of taking the log of the probs to get the log probs.

@theAdamColton how is that different? can you show me in code?

@theAdamColton (Contributor, Author)

@lucidrains
for example, instead of

prob = (-distance * inv_temperature).softmax(dim = -1)
per_sample_entropy = (-prob * log(prob)).sum(dim = -1).mean()  # log() applied to the already-softmaxed probs

this is what I mean:

logits = -distance * inv_temperature
prob = logits.softmax(dim = -1)
log_prob = F.log_softmax(logits, dim = -1)  # log probs computed directly from the logits
per_sample_entropy = (-prob * log_prob).sum(dim = -1).mean()

I don't know if it would make a difference, but it's what the MaskGIT code does. Using log_softmax might fix precision issues.

From the PyTorch log_softmax docs:
"While mathematically equivalent to log(softmax(x)), doing these two operations separately is slower and numerically unstable. This function uses an alternative formulation to compute the output and gradient correctly."

@lucidrains (Owner)

I think the numerical stability is accounted for by the epsilon in the log I have in the file, but do let me know otherwise
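
(the log helper referred to is the usual clamp-before-log pattern, roughly like the below; illustrative, not necessarily verbatim from the file)

def log(t, eps = 1e-20):
    # clamp away exact zeros so the log never returns -inf
    return t.clamp(min = eps).log()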

@lucidrains (Owner) commented Dec 9, 2023

anyways, I've put in my hours today, happy Saturday! See if you can get that mask to go into the masked mean fn and I'll review it again

lucidrains force-pushed the master branch 2 times, most recently from d9967be to 34b9e97 on May 10, 2024 at 14:52