Skip to content

Commit

Permalink
[pytorch] Avoid toByteBuffer() crash for large tensor (#2780)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored Sep 14, 2023
1 parent 20423a6 commit 2eb77e5
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@
*/
package ai.djl.pytorch.integration;

import ai.djl.engine.EngineException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;

import org.testng.Assert;
import org.testng.annotations.Test;
Expand All @@ -34,4 +37,12 @@ public void testStringTensor() {
Assert.assertThrows(UnsupportedOperationException.class, () -> arr.get(0));
}
}

@Test
public void testLargeTensor() {
try (NDManager manager = NDManager.newBaseManager()) {
NDArray array = manager.zeros(new Shape(10 * 2850, 18944), DataType.FLOAT32);
Assert.assertThrows(EngineException.class, array::toByteArray);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,13 @@ JNIEXPORT jbyteArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchDataPtr
// sparse and mkldnn are required to be converted to dense to access data ptr
auto tensor = (tensor_ptr->is_sparse() || tensor_ptr->is_mkldnn()) ? tensor_ptr->to_dense() : *tensor_ptr;
tensor = (tensor.is_contiguous()) ? tensor : tensor.contiguous();
jbyteArray result = env->NewByteArray(tensor.nbytes());
env->SetByteArrayRegion(result, 0, tensor.nbytes(), static_cast<const jbyte*>(tensor.data_ptr()));
size_t nbytes = tensor.nbytes();
if (nbytes > 0x7fffffff) {
env->ThrowNew(ENGINE_EXCEPTION_CLASS, "toByteBuffer() is not supported for large tensor");
return env->NewByteArray(0);
}
jbyteArray result = env->NewByteArray(nbytes);
env->SetByteArrayRegion(result, 0, nbytes, static_cast<const jbyte*>(tensor.data_ptr()));
return result;
API_END_RETURN()
}
Expand Down

0 comments on commit 2eb77e5

Please sign in to comment.