
Commit b5503da

Merge pull request #3988 from google:nnx-lora-param
PiperOrigin-RevId: 648503216
Flax Authors committed Jul 1, 2024
2 parents afaa721 + 24c44f6 commit b5503da
Showing 2 changed files with 3 additions and 2 deletions.
flax/nnx/nnx/nn/lora.py (3 changes: 2 additions & 1 deletion)
@@ -46,10 +46,11 @@
 default_kernel_init = initializers.lecun_normal()


-class LoRAParam(variables.Variable[A]):
+class LoRAParam(variables.Param[A]):
   pass


+
 class LoRA(Module):
   """A standalone LoRA layer.
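
Why this matters: LoRAParam now subclasses variables.Param instead of the generic variables.Variable, so any state filter that matches nnx.Param also matches LoRA parameters. nnx.split assigns each variable to the first filter it matches, which means the more specific nnx.LoRAParam filter has to come before the broad nnx.Param filter; that is exactly the reordering in the test diff below. A minimal sketch of the new filter semantics, assuming the public flax.nnx API and mirroring the test's constructor arguments:

from flax import nnx

rngs = nnx.Rngs(0)
model = nnx.LoRA(3, 4, 2, lora_param_type=nnx.LoRAParam, rngs=rngs)

# nnx.split assigns each variable to the first filter that matches it.
# Since LoRAParam is now a Param subclass, listing nnx.Param first would
# swallow lora_a and lora_b, so the specific filter goes first.
graphdef, lora_params, params = nnx.split(model, nnx.LoRAParam, nnx.Param)

assert 'lora_a' in lora_params and 'lora_b' in lora_params
assert params == {}  # the standalone LoRA layer holds no other Params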
flax/nnx/tests/nn/lora_test.py (2 changes: 1 addition & 1 deletion)
@@ -105,7 +105,7 @@ def __call__(self, x):
   def test_lora_param_type(self):
     rngs = nnx.Rngs(0)
     model = nnx.LoRA(3, 4, 2, lora_param_type=nnx.LoRAParam, rngs=rngs)
-    _, params, lora_params = nnx.split(model, nnx.Param, nnx.LoRAParam)
+    _, lora_params, params = nnx.split(model, nnx.LoRAParam, nnx.Param)
     assert params == {}
     assert ('lora_a' in lora_params) and ('lora_b' in lora_params)
     np.testing.assert_allclose(lora_params.lora_a.value, model.lora_a.value)
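
The payoff of the new subtype relation is that a single broad nnx.Param filter now collects LoRA weights together with ordinary parameters, so utilities that filter on nnx.Param (optimizers, weight-decay masks, checkpointing) can see LoRA weights without extra configuration. A short sketch, reusing model from the snippet above:

# One broad filter now matches every parameter, LoRA ones included.
graphdef, all_params = nnx.split(model, nnx.Param)
assert 'lora_a' in all_params and 'lora_b' in all_params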
