Skip to content

Commit

Permalink
Clone native tensor to avoid having buffers with corrupted data
Browse files Browse the repository at this point in the history
fixes #25

Co-authored-by: Sören Brunk <soeren@brunk.io>
  • Loading branch information
davoclavo and sbrunk committed Jun 15, 2023
1 parent 3fb3c5d commit 0f470d7
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
1 change: 1 addition & 0 deletions core/src/main/scala/torch/Tensor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,7 @@ object Tensor:
Array(data.length.toLong),
NativeConverters.tensorOptions(inputDType, layout, CPU, requiresGrad)
)
.clone()
).to(device = device)
case data: U =>
val dtype = scalaToDType(data)
Expand Down
13 changes: 13 additions & 0 deletions core/src/test/scala/torch/TensorSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -884,4 +884,17 @@ class TensorSuite extends TensorCheckSuite {
),
expectedTensor = Tensor(Seq(0.0, 0.0, 0.0, 0.0, Double.NaN))
)
test("Tensor creation properly handling buffers") {
val value = 100L
val data = Seq.fill(10000)(value)
val tensors = 1.to(1000).map { _ =>
Tensor(data)
}
assert(
tensors.forall { t =>
t.min().item == value &&
t.max().item == value
}
)
}
}

0 comments on commit 0f470d7

Please sign in to comment.