
Minor optimizations & fixes to support ESMFold #199

Merged
merged 4 commits into aqlaboratory:main from upstream_updates on Aug 23, 2022

Conversation

nikitos9000
Contributor

Hi OpenFold team! Thanks for your great implementation. Here are some suggested changes to the codebase to support the upcoming ESMFold release (a sketch of each pattern follows the list):

  1. Constant tensors like default_frames, group_idx, atom_mask, and lit_positions in StructureModule are lazily initialized as buffers, allowing for seamless CPU<->GPU model conversion;
  2. Vectorized ops for rot_matmul and rot_vec_mul, with mixed precision disabled for them, since precision loss may otherwise occur;
  3. Constant quaternions are cached so they are not recreated on each call.
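For point 1, a minimal sketch of the lazy-buffer pattern, assuming a simplified module and placeholder constant data (the real StructureModule builds these tensors from residue-constant arrays):

```python
import torch
import torch.nn as nn

# Placeholder constant data; the real module derives default_frames and
# the other constants from residue_constants arrays.
_DEFAULT_FRAMES = [[[0.0] * 4 for _ in range(4)] for _ in range(8)]

class LazyBufferSketch(nn.Module):
    def _init_residue_constants(self, float_dtype, device):
        # Build the constant on first use, on the caller's device and dtype,
        # and register it as a non-persistent buffer so that later
        # model.to(...) / model.cuda() calls move it together with the
        # parameters while keeping it out of the state dict.
        if not hasattr(self, "default_frames"):
            self.register_buffer(
                "default_frames",
                torch.tensor(
                    _DEFAULT_FRAMES,
                    dtype=float_dtype,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )

    def forward(self, s: torch.Tensor) -> torch.Tensor:
        self._init_residue_constants(s.dtype, s.device)
        # The real module uses default_frames to construct rigid frames.
        return s
```

For point 2, a sketch assuming native torch amp: the hand-written per-element products collapse into einsums, wrapped so autocast cannot downcast them. The function names follow OpenFold's rigid_utils; the einsum bodies are illustrative, and the decorator only suppresses native autocast, which is what the discussion below turns on:

```python
import torch

@torch.cuda.amp.autocast(enabled=False)
def rot_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Compose two stacks of 3x3 rotation matrices in one vectorized op;
    # autocast is disabled so the matmul is not downcast to fp16.
    return torch.einsum("...ij,...jk->...ik", a.float(), b.float())

@torch.cuda.amp.autocast(enabled=False)
def rot_vec_mul(r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # Rotate points t[..., 3] by rotation matrices r[..., 3, 3].
    return torch.einsum("...ij,...j->...i", r.float(), t.float())
```

And for point 3, one way to cache a constant quaternion per (dtype, device) pair with functools.lru_cache; the name and signature here are illustrative, and the cached tensor must be treated as read-only:

```python
import torch
from functools import lru_cache

@lru_cache(maxsize=None)
def identity_quats(dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # Built once per (dtype, device) pair instead of on every call.
    q = torch.zeros(4, dtype=dtype, device=device)
    q[0] = 1.0  # the identity rotation quaternion (1, 0, 0, 0)
    return q
```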

@gahdritz
Collaborator

Thanks!

Do those autocast fixes work during DeepSpeed training, where an APEX-based autocast framework is used instead of the native torch one? The reason those operations are spelled out like that manually in the first place is to avoid automatic casting of all kinds.
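For context, a hand-spelled variant in the style this comment describes, close to OpenFold's own rot_vec_mul. Plain element-wise multiplies and adds are not rewritten by native autocast and are not on APEX amp's cast lists either, so the computation stays in the input dtype under both frameworks:

```python
import torch

def rot_vec_mul(r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # Applies rotation matrices r[..., 3, 3] to points t[..., 3], written
    # out element by element: neither native autocast nor APEX amp patches
    # plain multiplies and adds, so no automatic downcasting occurs.
    x, y, z = torch.unbind(t, dim=-1)
    return torch.stack(
        [
            r[..., 0, 0] * x + r[..., 0, 1] * y + r[..., 0, 2] * z,
            r[..., 1, 0] * x + r[..., 1, 1] * y + r[..., 1, 2] * z,
            r[..., 2, 0] * x + r[..., 2, 1] * y + r[..., 2, 2] * z,
        ],
        dim=-1,
    )
```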

@nikitos9000
Contributor Author

@gahdritz Yeah, you're right, this won't work with APEX amp, only with native torch autocast. I guess I can just roll back this change.

@gahdritz gahdritz merged commit 4b41059 into aqlaboratory:main Aug 23, 2022
@ebetica ebetica deleted the upstream_updates branch October 2, 2023 21:41