From dcc83e7c2d7283b3c67461fadb445209a9656458 Mon Sep 17 00:00:00 2001 From: jxtps Date: Sun, 27 Mar 2022 00:09:26 -0700 Subject: [PATCH] * Add `long[] pytorch.Tensor.shape()` method for convenience (pull #1161) --- CHANGELOG.md | 1 + .../java/org/bytedeco/pytorch/AbstractTensor.java | 14 ++++++++++++++ 2 files changed, 15 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 20501ae34b9..8bea6b664aa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,5 @@ + * Add `long[] pytorch.Tensor.shape()` method for convenience ([pull #1161](https://github.com/bytedeco/javacpp-presets/pull/1161)) * Enable DNNL codegen as BYOC backend in presets for TVM * Allow passing raw pointer as deleter to `from_blob()`, etc functions of PyTorch ([discussion #1160](https://github.com/bytedeco/javacpp-presets/discussions/1160)) * Include `cudnn_backend.h` header file in presets for CUDA ([issue #1158](https://github.com/bytedeco/javacpp-presets/issues/1158)) diff --git a/pytorch/src/main/java/org/bytedeco/pytorch/AbstractTensor.java b/pytorch/src/main/java/org/bytedeco/pytorch/AbstractTensor.java index c647dfbc5be..e7d7e704549 100644 --- a/pytorch/src/main/java/org/bytedeco/pytorch/AbstractTensor.java +++ b/pytorch/src/main/java/org/bytedeco/pytorch/AbstractTensor.java @@ -65,6 +65,20 @@ public static Tensor create(byte[] data, boolean signed, long... shape) { public abstract long nbytes(); public abstract Pointer data_ptr(); + /** + * Convenience method, similar to {@code sizes().vec().get()}. + * + * Returns a new {@code long[]} with each call since e.g. transpose_() and squeeze_() can change the shape of the tensor, + * and the caller could otherwise modify the contents, surprising subsequent callers. + * + * Please memoize externally if you're concerned about performance. + */ + public long[] shape() { + long[] out = new long[(int) ndimension()]; + for (int i = 0; i < out.length; i++) out[i] = size(i); + return out; + } + /** Returns {@code createBuffer(0)}. */ public B createBuffer() { return (B)createBuffer(0);