Skip to content

Commit

Permalink
* Overload Tensor.create() factory methods for TensorFlow with han…
Browse files Browse the repository at this point in the history
…dy `long... shape` (issue bytedeco/javacpp#301)
  • Loading branch information
saudet committed May 16, 2019
1 parent 5b4f3b5 commit 06c1b76
Show file tree
Hide file tree
Showing 9 changed files with 21 additions and 12 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@

* Overload `Tensor.create()` factory methods for TensorFlow with handy `long... shape` ([issue bytedeco/javacpp#301](https://github.com/bytedeco/javacpp/issues/301))
* Add build for `linux-arm64` to presets for OpenBLAS ([pull #726](https://github.com/bytedeco/javacpp-presets/pull/726))
* Bundle complete binary packages of CPython itself for convenience ([issue #712](https://github.com/bytedeco/javacpp-presets/issues/712))
* Fix and refine mapping for `HoughLines`, `HoughLinesP`, and `HoughCircles` ([issue #717](https://github.com/bytedeco/javacpp-presets/issues/717))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ private native void allocate(CollectiveExecutor col_exec, @Const DeviceMgr dev_m
@Cast("tensorflow::int64") long step_id, @Const Tensor input, Tensor output);

public native CollectiveExecutor col_exec(); public native CollectiveContext col_exec(CollectiveExecutor setter); // Not owned
@MemberGetter public native @Const DeviceMgr dev_mgr(); // Not owned
public native @Const DeviceMgr dev_mgr(); public native CollectiveContext dev_mgr(DeviceMgr setter); // Not owned
public native OpKernelContext op_ctx(); public native CollectiveContext op_ctx(OpKernelContext setter); // Not owned
public native OpKernelContext.Params op_params(); public native CollectiveContext op_params(OpKernelContext.Params setter); // Not owned
@MemberGetter public native @Const @ByRef CollectiveParams col_params();
@MemberGetter public native @StdString BytePointer exec_key();
@MemberGetter public native @Cast("const tensorflow::int64") long step_id();
@MemberGetter public native @Const Tensor input(); // Not owned
public native @Const Tensor input(); public native CollectiveContext input(Tensor setter); // Not owned
public native Tensor output(); public native CollectiveContext output(Tensor setter); // Not owned
public native Device device(); public native CollectiveContext device(Device setter); // The device for which this instance labors
@MemberGetter public native @StdString BytePointer device_name();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public static class InstantiateOptions extends Pointer {
// between a set of libraries (e.g. by allowing a
// `FunctionLibraryDefinition` to store an `outer_scope` pointer
// and implementing name resolution across libraries).
@MemberGetter public native @Const FunctionLibraryDefinition overlay_lib();
public native @Const FunctionLibraryDefinition overlay_lib(); public native InstantiateOptions overlay_lib(FunctionLibraryDefinition setter);

// This interface is EXPERIMENTAL and subject to change.
//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public static class Params extends Pointer {
public native @Cast("bool") boolean record_tensor_accesses(); public native Params record_tensor_accesses(boolean setter);

// Array indexed by output number for this node
@MemberGetter public native @Const AllocatorAttributes output_attr_array();
public native @Const AllocatorAttributes output_attr_array(); public native Params output_attr_array(AllocatorAttributes setter);

// Shared resources accessible by this op kernel invocation.
public native ResourceMgr resource_manager(); public native Params resource_manager(ResourceMgr setter);
Expand Down Expand Up @@ -94,13 +94,13 @@ public static class Params extends Pointer {
public native CancellationManager cancellation_manager(); public native Params cancellation_manager(CancellationManager setter);

// Inputs to this op kernel.
@MemberGetter public native @Const TensorValueVector inputs();
public native @Const TensorValueVector inputs(); public native Params inputs(TensorValueVector setter);
public native @Cast("bool") boolean is_input_dead(); public native Params is_input_dead(boolean setter);

@MemberGetter public native @Const AllocatorAttributesVector input_alloc_attrs();
public native @Const AllocatorAttributesVector input_alloc_attrs(); public native Params input_alloc_attrs(AllocatorAttributesVector setter);

// Device contexts.
@MemberGetter public native @Const DeviceContextInlinedVector input_device_contexts();
public native @Const DeviceContextInlinedVector input_device_contexts(); public native Params input_device_contexts(DeviceContextInlinedVector setter);
public native DeviceContext op_device_context(); public native Params op_device_context(DeviceContext setter);

// Control-flow op supports.
Expand All @@ -122,7 +122,7 @@ public static class Params extends Pointer {
@MemberGetter public static native int kNoReservation();
public static final int kNoReservation = kNoReservation();
// Values in [0,...) represent reservations for the indexed output.
@MemberGetter public native @Const IntPointer forward_from_array();
public native @Const IntPointer forward_from_array(); public native Params forward_from_array(IntPointer setter);
}

// params must outlive the OpKernelContext.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,5 @@ public class SessionState extends Pointer {

public native @Cast("tensorflow::int64") long GetNewId();

@MemberGetter public static native @Cast("const char*") BytePointer kTensorHandleResourceTypeName();
public static native @Cast("const char*") BytePointer kTensorHandleResourceTypeName(); public static native void kTensorHandleResourceTypeName(BytePointer setter);
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public class TF_ImportGraphDefResults extends Pointer {

public native @StdVector TF_Output return_tensors(); public native TF_ImportGraphDefResults return_tensors(TF_Output setter);
public native @Cast("TF_Operation**") @StdVector PointerPointer return_nodes(); public native TF_ImportGraphDefResults return_nodes(PointerPointer setter);
@MemberGetter public native @Cast("const char**") @StdVector PointerPointer missing_unused_key_names();
public native @Cast("const char**") @StdVector PointerPointer missing_unused_key_names(); public native TF_ImportGraphDefResults missing_unused_key_names(PointerPointer setter);
public native @StdVector IntPointer missing_unused_key_indexes(); public native TF_ImportGraphDefResults missing_unused_key_indexes(IntPointer setter);

// Backing memory for missing_unused_key_names values.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,5 @@ public class TF_WhileParams extends Pointer {

// Unique null-terminated name for this while loop. This is used as a prefix
// for created operations.
@MemberGetter public native @Cast("const char*") BytePointer name();
public native @Cast("const char*") BytePointer name(); public native TF_WhileParams name(BytePointer setter);
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public static class SliceInfo extends Pointer {

public native @ByRef TensorSlice slice(); public native SliceInfo slice(TensorSlice setter);
public native @StdString BytePointer tag(); public native SliceInfo tag(BytePointer setter);
@MemberGetter public native @Const FloatPointer data();
public native @Const FloatPointer data(); public native SliceInfo data(FloatPointer setter);
public native @Cast("tensorflow::int64") long num_floats(); public native SliceInfo num_floats(long setter);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ public abstract class AbstractTensor extends Pointer implements Indexable {
static { Loader.load(); }
public AbstractTensor(Pointer p) { super(p); }

public static Tensor create(float[] data, long... shape) { return create(data, new TensorShape(shape)); }
public static Tensor create(double[] data, long... shape) { return create(data, new TensorShape(shape)); }
public static Tensor create(int[] data, long... shape) { return create(data, new TensorShape(shape)); }
public static Tensor create(short[] data, long... shape) { return create(data, new TensorShape(shape)); }
public static Tensor create(byte[] data, long... shape) { return create(data, new TensorShape(shape)); }
public static Tensor create(long[] data, long... shape) { return create(data, new TensorShape(shape)); }
public static Tensor create(String[] data, long... shape) { return create(data, new TensorShape(shape)); }

public static Tensor create(float[] data, TensorShape shape) { Tensor t = new Tensor(DT_FLOAT, shape); FloatBuffer b = t.createBuffer(); b.put(data); return t; }
public static Tensor create(double[] data, TensorShape shape) { Tensor t = new Tensor(DT_DOUBLE, shape); DoubleBuffer b = t.createBuffer(); b.put(data); return t; }
public static Tensor create(int[] data, TensorShape shape) { Tensor t = new Tensor(DT_INT32, shape); IntBuffer b = t.createBuffer(); b.put(data); return t; }
Expand Down

0 comments on commit 06c1b76

Please sign in to comment.