Skip to content

Commit

Permalink
* Add methods overloaded with PointerPointer for MKL-DNN (issue by…
Browse files Browse the repository at this point in the history
  • Loading branch information
saudet committed Aug 23, 2018
1 parent d6dc945 commit 6a8d613
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 19 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@

* Add methods overloaded with `PointerPointer` for MKL-DNN ([issue bytedeco/javacpp#251](https://github.com/bytedeco/javacpp/issues/251))
* Bundle native resources (header files and import libraries) of MKL-DNN
* Make MSBuild compile more efficiently on multiple processors ([pull #599](https://github.com/bytedeco/javacpp-presets/pull/599))
* Add samples for Clang ([pull #598](https://github.com/bytedeco/javacpp-presets/pull/598))
Expand Down
102 changes: 98 additions & 4 deletions mkl-dnn/src/main/java/org/bytedeco/javacpp/mkldnn.java
Original file line number Diff line number Diff line change
Expand Up @@ -1654,6 +1654,10 @@ inputs and outputs memory (bytes) */
@ByPtrPtr mkldnn_primitive_desc_iterator iterator,
const_mkldnn_op_desc_t op_desc, mkldnn_engine engine,
@Const mkldnn_primitive_desc hint_forward_primitive_desc);
public static native @Cast("mkldnn_status_t") int mkldnn_primitive_desc_iterator_create(
@Cast("mkldnn_primitive_desc_iterator_t*") PointerPointer iterator,
const_mkldnn_op_desc_t op_desc, mkldnn_engine engine,
@Const mkldnn_primitive_desc hint_forward_primitive_desc);

/** Creates a primitive descriptor \p iterator for given \p op_desc, \p attr,
* \p engine, and optionally a hint primitive descriptor from forward
Expand All @@ -1665,6 +1669,11 @@ inputs and outputs memory (bytes) */
const_mkldnn_op_desc_t op_desc, @Const mkldnn_primitive_attr attr,
mkldnn_engine engine,
@Const mkldnn_primitive_desc hint_forward_primitive_desc);
public static native @Cast("mkldnn_status_t") int mkldnn_primitive_desc_iterator_create_v2(
@Cast("mkldnn_primitive_desc_iterator_t*") PointerPointer iterator,
const_mkldnn_op_desc_t op_desc, @Const mkldnn_primitive_attr attr,
mkldnn_engine engine,
@Const mkldnn_primitive_desc hint_forward_primitive_desc);

/** Iterates over primitive descriptors. Returns #mkldnn_iterator_ends if no
* more primitive descriptors are available */
Expand All @@ -1691,6 +1700,10 @@ public static native mkldnn_primitive_desc mkldnn_primitive_desc_iterator_fetch(
@ByPtrPtr mkldnn_primitive_desc primitive_desc,
const_mkldnn_op_desc_t op_desc, mkldnn_engine engine,
@Const mkldnn_primitive_desc hint_forward_primitive_desc);
public static native @Cast("mkldnn_status_t") int mkldnn_primitive_desc_create(
@Cast("mkldnn_primitive_desc_t*") PointerPointer primitive_desc,
const_mkldnn_op_desc_t op_desc, mkldnn_engine engine,
@Const mkldnn_primitive_desc hint_forward_primitive_desc);

/** Creates a \p primitive_desc using \p op_desc, \p attr, \p engine, and
* optionally a hint primitive descriptor from forward propagation. The call is
Expand All @@ -1701,11 +1714,19 @@ public static native mkldnn_primitive_desc mkldnn_primitive_desc_iterator_fetch(
const_mkldnn_op_desc_t op_desc, @Const mkldnn_primitive_attr attr,
mkldnn_engine engine,
@Const mkldnn_primitive_desc hint_forward_primitive_desc);
public static native @Cast("mkldnn_status_t") int mkldnn_primitive_desc_create_v2(
@Cast("mkldnn_primitive_desc_t*") PointerPointer primitive_desc,
const_mkldnn_op_desc_t op_desc, @Const mkldnn_primitive_attr attr,
mkldnn_engine engine,
@Const mkldnn_primitive_desc hint_forward_primitive_desc);

/** Makes a copy of a \p primitive_desc. */
public static native @Cast("mkldnn_status_t") int mkldnn_primitive_desc_clone(
@ByPtrPtr mkldnn_primitive_desc primitive_desc,
@Const mkldnn_primitive_desc existing_primitive_desc);
public static native @Cast("mkldnn_status_t") int mkldnn_primitive_desc_clone(
@Cast("mkldnn_primitive_desc_t*") PointerPointer primitive_desc,
@Const mkldnn_primitive_desc existing_primitive_desc);

/** Returns a constant reference to the attribute of a \p primitive_desc.
*
Expand All @@ -1718,6 +1739,9 @@ public static native mkldnn_primitive_desc mkldnn_primitive_desc_iterator_fetch(
public static native @Cast("mkldnn_status_t") int mkldnn_primitive_desc_get_attr(
@Const mkldnn_primitive_desc primitive_desc,
@Const @ByPtrPtr mkldnn_primitive_attr attr);
public static native @Cast("mkldnn_status_t") int mkldnn_primitive_desc_get_attr(
@Const mkldnn_primitive_desc primitive_desc,
@Cast("const_mkldnn_primitive_attr_t*") PointerPointer attr);

/** Deletes a \p primitive_desc. */
public static native @Cast("mkldnn_status_t") int mkldnn_primitive_desc_destroy(
Expand Down Expand Up @@ -1757,6 +1781,11 @@ public static native int mkldnn_primitive_desc_query_s32(
@Const mkldnn_primitive_desc primitive_desc,
@Const mkldnn_primitive_at_t inputs,
@Const @ByPtrPtr mkldnn_primitive outputs);
public static native @Cast("mkldnn_status_t") int mkldnn_primitive_create(
@Cast("mkldnn_primitive_t*") PointerPointer primitive,
@Const mkldnn_primitive_desc primitive_desc,
@Const mkldnn_primitive_at_t inputs,
@Cast("const_mkldnn_primitive_t*") PointerPointer outputs);

/** Retrieves a reference to the \p primitive_desc descriptor of given \p
* primitive.
Expand All @@ -1767,6 +1796,9 @@ public static native int mkldnn_primitive_desc_query_s32(
public static native @Cast("mkldnn_status_t") int mkldnn_primitive_get_primitive_desc(
@Const mkldnn_primitive primitive,
@Const @ByPtrPtr mkldnn_primitive_desc primitive_desc);
public static native @Cast("mkldnn_status_t") int mkldnn_primitive_get_primitive_desc(
@Const mkldnn_primitive primitive,
@Cast("const_mkldnn_primitive_desc_t*") PointerPointer primitive_desc);

/** For a \p primitive, returns \p input at the \p index position. */
public static native @Cast("mkldnn_status_t") int mkldnn_primitive_get_input_at(
Expand All @@ -1777,6 +1809,9 @@ public static native int mkldnn_primitive_desc_query_s32(
public static native @Cast("mkldnn_status_t") int mkldnn_primitive_get_output(
@Const mkldnn_primitive primitive, @Cast("size_t") long index,
@Const @ByPtrPtr mkldnn_primitive output);
public static native @Cast("mkldnn_status_t") int mkldnn_primitive_get_output(
@Const mkldnn_primitive primitive, @Cast("size_t") long index,
@Cast("const_mkldnn_primitive_t*") PointerPointer output);

/** Deletes a \p primitive. */
public static native @Cast("mkldnn_status_t") int mkldnn_primitive_destroy(
Expand Down Expand Up @@ -1804,11 +1839,16 @@ public static native int mkldnn_primitive_desc_query_s32(
*/
public static native @Cast("mkldnn_status_t") int mkldnn_primitive_attr_create(
@ByPtrPtr mkldnn_primitive_attr attr);
public static native @Cast("mkldnn_status_t") int mkldnn_primitive_attr_create(
@Cast("mkldnn_primitive_attr_t*") PointerPointer attr);

/** Makes a copy of an \p existing_attr. */
public static native @Cast("mkldnn_status_t") int mkldnn_primitive_attr_clone(
@ByPtrPtr mkldnn_primitive_attr attr,
@Const mkldnn_primitive_attr existing_attr);
public static native @Cast("mkldnn_status_t") int mkldnn_primitive_attr_clone(
@Cast("mkldnn_primitive_attr_t*") PointerPointer attr,
@Const mkldnn_primitive_attr existing_attr);

/** Deletes an \p attr. */
public static native @Cast("mkldnn_status_t") int mkldnn_primitive_attr_destroy(
Expand Down Expand Up @@ -1914,6 +1954,8 @@ public static native int mkldnn_primitive_desc_query_s32(
*/
public static native @Cast("mkldnn_status_t") int mkldnn_primitive_attr_get_post_ops(
@Const mkldnn_primitive_attr attr, @Const @ByPtrPtr mkldnn_post_ops post_ops);
public static native @Cast("mkldnn_status_t") int mkldnn_primitive_attr_get_post_ops(
@Const mkldnn_primitive_attr attr, @Cast("const_mkldnn_post_ops_t*") PointerPointer post_ops);

/** Sets configured \p post_ops to an attribute \p attr for future use (when
* primitive descriptor is being created.
Expand All @@ -1933,6 +1975,7 @@ public static native int mkldnn_primitive_desc_query_s32(

/** Creates an empty sequence of post operations \p post_ops. */
public static native @Cast("mkldnn_status_t") int mkldnn_post_ops_create(@ByPtrPtr mkldnn_post_ops post_ops);
public static native @Cast("mkldnn_status_t") int mkldnn_post_ops_create(@Cast("mkldnn_post_ops_t*") PointerPointer post_ops);

/** Deletes a \p post_ops sequence. */
public static native @Cast("mkldnn_status_t") int mkldnn_post_ops_destroy(mkldnn_post_ops post_ops);
Expand Down Expand Up @@ -2037,6 +2080,9 @@ public static native int mkldnn_primitive_desc_query_s32(
public static native @Cast("mkldnn_status_t") int mkldnn_memory_primitive_desc_create(
@ByPtrPtr mkldnn_primitive_desc memory_primitive_desc,
@Const mkldnn_memory_desc_t memory_desc, mkldnn_engine engine);
public static native @Cast("mkldnn_status_t") int mkldnn_memory_primitive_desc_create(
@Cast("mkldnn_primitive_desc_t*") PointerPointer memory_primitive_desc,
@Const mkldnn_memory_desc_t memory_desc, mkldnn_engine engine);

/** Creates a \p view_primitive_desc for a given \p memory_primitive_desc, with
* \p dims sizes and \p offset offsets. May fail if layout used does not allow
Expand All @@ -2046,13 +2092,25 @@ public static native int mkldnn_primitive_desc_query_s32(
@Const mkldnn_primitive_desc memory_primitive_desc,
@Const IntPointer dims, @Const IntPointer offsets);
public static native @Cast("mkldnn_status_t") int mkldnn_view_primitive_desc_create(
@ByPtrPtr mkldnn_primitive_desc view_primitive_desc,
@Cast("mkldnn_primitive_desc_t*") PointerPointer view_primitive_desc,
@Const mkldnn_primitive_desc memory_primitive_desc,
@Const IntBuffer dims, @Const IntBuffer offsets);
public static native @Cast("mkldnn_status_t") int mkldnn_view_primitive_desc_create(
@ByPtrPtr mkldnn_primitive_desc view_primitive_desc,
@Const mkldnn_primitive_desc memory_primitive_desc,
@Const int[] dims, @Const int[] offsets);
public static native @Cast("mkldnn_status_t") int mkldnn_view_primitive_desc_create(
@Cast("mkldnn_primitive_desc_t*") PointerPointer view_primitive_desc,
@Const mkldnn_primitive_desc memory_primitive_desc,
@Const IntPointer dims, @Const IntPointer offsets);
public static native @Cast("mkldnn_status_t") int mkldnn_view_primitive_desc_create(
@ByPtrPtr mkldnn_primitive_desc view_primitive_desc,
@Const mkldnn_primitive_desc memory_primitive_desc,
@Const IntBuffer dims, @Const IntBuffer offsets);
public static native @Cast("mkldnn_status_t") int mkldnn_view_primitive_desc_create(
@Cast("mkldnn_primitive_desc_t*") PointerPointer view_primitive_desc,
@Const mkldnn_primitive_desc memory_primitive_desc,
@Const int[] dims, @Const int[] offsets);

/** Compares two descriptors of memory primitives.
* @return 1 if the descriptors are the same.
Expand Down Expand Up @@ -2095,6 +2153,10 @@ public static native int mkldnn_memory_primitive_desc_equal(
@ByPtrPtr mkldnn_primitive_desc reorder_primitive_desc,
@Const mkldnn_primitive_desc input,
@Const mkldnn_primitive_desc output);
public static native @Cast("mkldnn_status_t") int mkldnn_reorder_primitive_desc_create(
@Cast("mkldnn_primitive_desc_t*") PointerPointer reorder_primitive_desc,
@Const mkldnn_primitive_desc input,
@Const mkldnn_primitive_desc output);

/** Initializes a \p reorder_primitive_desc using an \p attr attribute and
* descriptors of \p input and \p output memory primitives. */
Expand All @@ -2103,6 +2165,11 @@ public static native int mkldnn_memory_primitive_desc_equal(
@Const mkldnn_primitive_desc input,
@Const mkldnn_primitive_desc output,
@Const mkldnn_primitive_attr attr);
public static native @Cast("mkldnn_status_t") int mkldnn_reorder_primitive_desc_create_v2(
@Cast("mkldnn_primitive_desc_t*") PointerPointer reorder_primitive_desc,
@Const mkldnn_primitive_desc input,
@Const mkldnn_primitive_desc output,
@Const mkldnn_primitive_attr attr);

/** \} */

Expand All @@ -2119,6 +2186,10 @@ public static native int mkldnn_memory_primitive_desc_equal(
@ByPtrPtr mkldnn_primitive_desc concat_primitive_desc,
@Const mkldnn_memory_desc_t output_desc, int n, int concat_dimension,
@Const @ByPtrPtr mkldnn_primitive_desc input_pds);
public static native @Cast("mkldnn_status_t") int mkldnn_concat_primitive_desc_create(
@Cast("mkldnn_primitive_desc_t*") PointerPointer concat_primitive_desc,
@Const mkldnn_memory_desc_t output_desc, int n, int concat_dimension,
@Cast("const_mkldnn_primitive_desc_t*") PointerPointer input_pds);

// #if 0
// #endif
Expand All @@ -2139,13 +2210,25 @@ public static native int mkldnn_memory_primitive_desc_equal(
@Const mkldnn_memory_desc_t output_desc, int n, @Const FloatPointer scales,
@Const @ByPtrPtr mkldnn_primitive_desc input_pds);
public static native @Cast("mkldnn_status_t") int mkldnn_sum_primitive_desc_create(
@ByPtrPtr mkldnn_primitive_desc sum_primitive_desc,
@Cast("mkldnn_primitive_desc_t*") PointerPointer sum_primitive_desc,
@Const mkldnn_memory_desc_t output_desc, int n, @Const FloatBuffer scales,
@Const @ByPtrPtr mkldnn_primitive_desc input_pds);
@Cast("const_mkldnn_primitive_desc_t*") PointerPointer input_pds);
public static native @Cast("mkldnn_status_t") int mkldnn_sum_primitive_desc_create(
@ByPtrPtr mkldnn_primitive_desc sum_primitive_desc,
@Const mkldnn_memory_desc_t output_desc, int n, @Const float[] scales,
@Const @ByPtrPtr mkldnn_primitive_desc input_pds);
public static native @Cast("mkldnn_status_t") int mkldnn_sum_primitive_desc_create(
@Cast("mkldnn_primitive_desc_t*") PointerPointer sum_primitive_desc,
@Const mkldnn_memory_desc_t output_desc, int n, @Const FloatPointer scales,
@Cast("const_mkldnn_primitive_desc_t*") PointerPointer input_pds);
public static native @Cast("mkldnn_status_t") int mkldnn_sum_primitive_desc_create(
@ByPtrPtr mkldnn_primitive_desc sum_primitive_desc,
@Const mkldnn_memory_desc_t output_desc, int n, @Const FloatBuffer scales,
@Const @ByPtrPtr mkldnn_primitive_desc input_pds);
public static native @Cast("mkldnn_status_t") int mkldnn_sum_primitive_desc_create(
@Cast("mkldnn_primitive_desc_t*") PointerPointer sum_primitive_desc,
@Const mkldnn_memory_desc_t output_desc, int n, @Const float[] scales,
@Cast("const_mkldnn_primitive_desc_t*") PointerPointer input_pds);

/** \} */

Expand Down Expand Up @@ -2966,6 +3049,8 @@ public static native int mkldnn_rnn_cell_get_states_count(
/** Creates an \p engine of particular \p kind and \p index. */
public static native @Cast("mkldnn_status_t") int mkldnn_engine_create(@ByPtrPtr mkldnn_engine engine,
@Cast("mkldnn_engine_kind_t") int kind, @Cast("size_t") long index);
public static native @Cast("mkldnn_status_t") int mkldnn_engine_create(@Cast("mkldnn_engine_t*") PointerPointer engine,
@Cast("mkldnn_engine_kind_t") int kind, @Cast("size_t") long index);

/** Returns the kind of an \p engine. */
public static native @Cast("mkldnn_status_t") int mkldnn_engine_get_kind(mkldnn_engine engine,
Expand All @@ -2986,24 +3071,33 @@ public static native int mkldnn_rnn_cell_get_states_count(
/** Creates an execution \p stream of \p stream_kind. */
public static native @Cast("mkldnn_status_t") int mkldnn_stream_create(@ByPtrPtr mkldnn_stream stream,
@Cast("mkldnn_stream_kind_t") int stream_kind);
public static native @Cast("mkldnn_status_t") int mkldnn_stream_create(@Cast("mkldnn_stream_t*") PointerPointer stream,
@Cast("mkldnn_stream_kind_t") int stream_kind);

/** Submits \p primitives to an execution \p stream. The number of primitives
* is \p n. All or none of the primitives can be lazy. In case of an error,
* returns the offending \p error_primitive if it is not \c NULL. */
public static native @Cast("mkldnn_status_t") int mkldnn_stream_submit(mkldnn_stream stream,
@Cast("size_t") long n, @ByPtrPtr mkldnn_primitive primitives,
@ByPtrPtr mkldnn_primitive error_primitive);
public static native @Cast("mkldnn_status_t") int mkldnn_stream_submit(mkldnn_stream stream,
@Cast("size_t") long n, @Cast("mkldnn_primitive_t*") PointerPointer primitives,
@Cast("mkldnn_primitive_t*") PointerPointer error_primitive);

/** Waits for all primitives in the execution \p stream to finish. Returns
* immediately if \p block is zero. In case of an error, returns
* the offending \p error_primitive if it is not \c NULL. */
public static native @Cast("mkldnn_status_t") int mkldnn_stream_wait(mkldnn_stream stream,
int block, @ByPtrPtr mkldnn_primitive error_primitive);
public static native @Cast("mkldnn_status_t") int mkldnn_stream_wait(mkldnn_stream stream,
int block, @Cast("mkldnn_primitive_t*") PointerPointer error_primitive);

/** Reruns all the primitives within the \p stream. In case of an error,
* returns the offending \p error_primitive if it is not \c NULL. */
public static native @Cast("mkldnn_status_t") int mkldnn_stream_rerun(mkldnn_stream stream,
@ByPtrPtr mkldnn_primitive error_primitive);
public static native @Cast("mkldnn_status_t") int mkldnn_stream_rerun(mkldnn_stream stream,
@Cast("mkldnn_primitive_t*") PointerPointer error_primitive);

/** Destroys an execution \p stream. */
public static native @Cast("mkldnn_status_t") int mkldnn_stream_destroy(mkldnn_stream stream);
Expand Down Expand Up @@ -3374,7 +3468,7 @@ public static native void wrap_c_api(@Cast("mkldnn_status_t") int status,
@StdString BytePointer message);
public static native void wrap_c_api(@Cast("mkldnn_status_t") int status,
@StdString String message,
@ByPtrPtr mkldnn_primitive error_primitive/*=0*/);
@Cast("mkldnn_primitive_t*") PointerPointer error_primitive/*=0*/);
public static native void wrap_c_api(@Cast("mkldnn_status_t") int status,
@StdString String message);
}
Expand Down
Loading

0 comments on commit 6a8d613

Please sign in to comment.