Skip to content

Commit

Permalink
Added support for :bool dtype to numo method
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Aug 19, 2024
1 parent b509c33 commit 9dcf3ce
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
10 changes: 7 additions & 3 deletions lib/torch/tensor.rb
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,13 @@ def new

# TODO read directly from memory
def numo
cls = Torch._dtype_to_numo[dtype]
raise Error, "Cannot convert #{dtype} to Numo" unless cls
cls.from_string(_data_str).reshape(*shape)
if dtype == :bool
Numo::UInt8.from_string(_data_str).ne(0).reshape(*shape)
else
cls = Torch._dtype_to_numo[dtype]
raise Error, "Cannot convert #{dtype} to Numo" unless cls
cls.from_string(_data_str).reshape(*shape)
end
end

def requires_grad=(requires_grad)
Expand Down
4 changes: 4 additions & 0 deletions test/numo_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ def test_numo
x = Torch.tensor([[1, 2, 3], [4, 5, 6]])
assert x.numo.is_a?(Numo::Int64)
assert_equal x.to_a, x.numo.to_a

x = Torch.tensor([[true, false], [false, true]])
assert x.numo.is_a?(Numo::Bit)
assert_equal [[1, 0], [0, 1]], x.numo.to_a
end

def test_from_numo
Expand Down

0 comments on commit 9dcf3ce

Please sign in to comment.