Skip to content

Commit

Permalink
fix: Revert previous fix for unfold method in torch frontend and push…
Browse files Browse the repository at this point in the history
… a better fix for shape mismatch issues with the native fw
  • Loading branch information
hmahmood24 committed Jul 2, 2024
1 parent 4396c35 commit 1399769
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion ivy/functional/frontends/torch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,9 +767,17 @@ def unfold(self, dimension, size, step):
slicing[dimension] = slice(i, i + size)
slices.append(self.ivy_array[tuple(slicing)])
stacked = torch_frontend.stack(slices, dim=dimension)
new_shape = list(self.shape)
num_slices = (self.shape[dimension] - size) // step + 1
new_shape[dimension] = num_slices
if dimension == -1:
new_shape.insert(dimension, size)
else:
new_shape.insert(dimension + 1, size)
reshaped = stacked.reshape(new_shape)
dims = list(range(len(stacked.shape)))
dims[-2], dims[-1] = dims[-1], dims[-2]
return stacked.permute(*dims)
return reshaped.permute(*dims)

def long(self, memory_format=None):
self.ivy_array = ivy.astype(self.ivy_array, ivy.int64, copy=False)
Expand Down

0 comments on commit 1399769

Please sign in to comment.