From 3c149064df80fe7785ae1cc04f6ca2606c094316 Mon Sep 17 00:00:00 2001 From: snarayan21 Date: Wed, 31 Jan 2024 20:36:26 -0800 Subject: [PATCH] Fix lp layernorm weight (#2954) * lp layernorm weight fix * Update composer/algorithms/low_precision_layernorm/low_precision_layernorm.py Co-authored-by: Mihir Patel --------- Co-authored-by: Brian <23239305+b-chu@users.noreply.github.com> Co-authored-by: Mihir Patel --- .../low_precision_layernorm/low_precision_layernorm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composer/algorithms/low_precision_layernorm/low_precision_layernorm.py b/composer/algorithms/low_precision_layernorm/low_precision_layernorm.py index 112de20803..9324289351 100644 --- a/composer/algorithms/low_precision_layernorm/low_precision_layernorm.py +++ b/composer/algorithms/low_precision_layernorm/low_precision_layernorm.py @@ -143,7 +143,7 @@ def _to_LPLayerNorm(layer: torch.nn.Module, module_index: int) -> LPLayerNorm: lp_layernorm = LPLayerNorm(layer.normalized_shape, layer.eps, layer.elementwise_affine) with torch.no_grad(): - if hasattr(layer, 'weight'): + if layer.weight is None: # pyright: ignore[reportUnnecessaryComparison] lp_layernorm.register_parameter('weight', None) else: lp_layernorm.weight.copy_(layer.weight) # type: ignore