Fix module initialization with other dtypes and simplify module registration #47

Merged: 2 commits into main from fix-module-initialization on Jul 30, 2023

Conversation

sbrunk (Owner) commented on Jul 30, 2023

Fixes and improves initializing modules with parameter dtypes other than the default.

import torch.*

// Default:
nn.Linear(10, 10) // Linear[Float32]

// Override explicitly:
nn.Linear[BFloat16](10, 10) // Linear[BFloat16]

// Override default via context parameter:
import Default.float64
nn.Linear(10, 10) // Linear[Float64]

This is currently inconsistent with tensor creation ops like torch.ones, where we use a default parameter for the dtype. That makes the dtype easier to override at runtime, but the default itself is fixed. Both designs have trade-offs, and we need to test what works well in practice. Perhaps we can even find a way to combine both approaches.
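
For comparison, here is a minimal sketch of the two designs side by side (onesLike and LinearLike are hypothetical names with simplified signatures, not the actual storch API):

import torch.*

// Default value parameter (tensor creation style): easy to override at
// the call site, but the default value itself is hard-coded.
def onesLike[D <: DType](size: Int, dtype: D = float32): Tensor[D] = ???

onesLike(10)                  // Tensor[Float32]
onesLike(10, dtype = float64) // Tensor[Float64], overridden per call

// Context parameter (module style): overridden per call via the type
// argument, while the default itself can be swapped in scope, e.g. with
// `import Default.float64`.
class LinearLike[D <: DType: Default](inFeatures: Int, outFeatures: Int)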

Type of all modules fixed at creation time

To make container classes like Sequential work, we currently fix the dtype on module initialization, even for parameterless modules like Softmax:

final class Softmax[D <: DType: Default](dim: Int) extends TensorModule[D]:
  def apply(t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native))

Instead of a generic apply method:

final class Softmax(dim: Int):
  def apply[D <: DType](t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native))

If we find a better way, we might be able to change this back in the future.
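
To illustrate why the fixed type matters, here is a sketch of the container use case (module constructors taken from the examples above; Sequential usage approximated):

import torch.*

// Each stage is a TensorModule[Float32], so Sequential can thread one
// element type through the whole forward pass:
val model = nn.Sequential(
  nn.Linear(10, 10),   // Linear[Float32]
  nn.Softmax(dim = 1)  // Softmax[Float32], dtype fixed at creation
)

// With a generic apply[D] on Softmax instead, the container could not
// express "a chain of modules that all operate on the same dtype".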

DType conversion of modules

Module type conversion will require rethinking the module design, taking into account the (im)mutability of modules and concerns like mixed-precision training and perhaps quantization. For now, the type is fixed at module creation.

Other changes

  • Simplify module registration after improvements in the new presets.
  • Remove broken copy and to(dtype) methods from module.

sbrunk force-pushed the fix-module-initialization branch from 20770c4 to b2b5a2f on July 30, 2023 08:47
sbrunk added the "enhancement" label on Jul 30, 2023
sbrunk merged commit 131ba89 into main on Jul 30, 2023
7 checks passed
sbrunk deleted the fix-module-initialization branch on July 30, 2023 09:02
sbrunk added a commit to davoclavo/storch that referenced this pull request on Jul 30, 2023:
Move type parameter in all modules from apply to constructor for consistency and compat with Sequential.
See sbrunk#47 for details about the current design.