Skip to content

Commit

Permalink
Merge pull request #223 from ar-nowaczynski/unclamped_fape_per_sample
Browse files — browse the repository at this point in the history
Set clamped vs unclamped FAPE for each sample in batch independently
  • (branch information unavailable — page captured before it loaded)
gahdritz committed Sep 30, 2022
2 parents: a1f77ad + 43941b8 · commit 55c5623
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
5 changes: 0 additions & 5 deletions openfold/data/data_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,11 +440,6 @@ def _prep_batch_properties_probs(self):
stage_cfg = self.config[self.stage]

max_iters = self.config.common.max_recycling_iters
if(stage_cfg.supervised):
clamp_prob = self.config.supervised.clamp_prob
keyed_probs.append(
("use_clamped_fape", [1 - clamp_prob, clamp_prob])
)

if(stage_cfg.uniform_recycling):
recycling_probs = [
Expand Down
15 changes: 15 additions & 0 deletions openfold/data/feature_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,21 @@ def np_example_to_features(
cfg[mode],
)

if mode == "train":
p = torch.rand(1).item()
use_clamped_fape_value = float(p < cfg.supervised.clamp_prob)
features["use_clamped_fape"] = torch.full(
size=[cfg.common.max_recycling_iters + 1],
fill_value=use_clamped_fape_value,
dtype=torch.float32,
)
else:
features["use_clamped_fape"] = torch.full(
size=[cfg.common.max_recycling_iters + 1],
fill_value=0.0,
dtype=torch.float32,
)

return {k: v for k, v in features.items()}


Expand Down

0 comments on commit 55c5623

Please sign in to comment.