Skip to content

Commit

Permalink
fix: default num_ln_in_parallel_attn to one if not supplied (#2364)
Browse files Browse the repository at this point in the history
  • Loading branch information
drbh authored Aug 6, 2024
1 parent 1768c00 commit a64d407
Showing 1 changed file with 3 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,9 @@ def forward(
class FlashRWLayerNorm(nn.Module):
def __init__(self, config, prefix: str, weights):
super().__init__()
self.num_ln = config.num_ln_in_parallel_attn
# Falcon2 includes the number of layer norms in the config
# in the case no number of layer norms is provided, we default to 1
self.num_ln = getattr(config, "num_ln_in_parallel_attn", 1)

if self.num_ln == 1:
self.input_ln = FastLayerNorm.load(
Expand Down

0 comments on commit a64d407

Please sign in to comment.