From 0f470d7ee25f31b34030fdd6b3510e262404001c Mon Sep 17 00:00:00 2001 From: David Gomez-Urquiza Date: Thu, 15 Jun 2023 15:43:21 -0600 Subject: [PATCH] Clone native tensor to avoid having buffers with corrupted data MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fixes #25 Co-authored-by: Sören Brunk --- core/src/main/scala/torch/Tensor.scala | 1 + core/src/test/scala/torch/TensorSuite.scala | 13 +++++++++++++ 2 files changed, 14 insertions(+) diff --git a/core/src/main/scala/torch/Tensor.scala b/core/src/main/scala/torch/Tensor.scala index 26e2f0fd..2b1ca339 100644 --- a/core/src/main/scala/torch/Tensor.scala +++ b/core/src/main/scala/torch/Tensor.scala @@ -797,6 +797,7 @@ object Tensor: Array(data.length.toLong), NativeConverters.tensorOptions(inputDType, layout, CPU, requiresGrad) ) + .clone() ).to(device = device) case data: U => val dtype = scalaToDType(data) diff --git a/core/src/test/scala/torch/TensorSuite.scala b/core/src/test/scala/torch/TensorSuite.scala index 40e21c53..e51c71da 100644 --- a/core/src/test/scala/torch/TensorSuite.scala +++ b/core/src/test/scala/torch/TensorSuite.scala @@ -884,4 +884,17 @@ class TensorSuite extends TensorCheckSuite { ), expectedTensor = Tensor(Seq(0.0, 0.0, 0.0, 0.0, Double.NaN)) ) + test("Tensor creation properly handling buffers") { + val value = 100L + val data = Seq.fill(10000)(value) + val tensors = 1.to(1000).map { _ => + Tensor(data) + } + assert( + tensors.forall { t => + t.min().item == value && + t.max().item == value + } + ) + } }