Skip to content

Commit

Permalink
* Include graph_runner.h and shape_refiner.h for TensorFlow
Browse files Browse the repository at this point in the history
  • Loading branch information
saudet committed Sep 11, 2018
1 parent 906a8ae commit a95b465
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 9 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
* Bundle native resources (header files and import libraries) of MKL-DNN
* Make MSBuild compile more efficiently on multiple processors ([pull #599](https://github.com/bytedeco/javacpp-presets/pull/599))
* Add samples for Clang ([pull #598](https://github.com/bytedeco/javacpp-presets/pull/598))
* Include `python_api.h` and enable Python API for TensorFlow ([issue #602](https://github.com/bytedeco/javacpp-presets/issues/602))
* Include `graph_runner.h`, `shape_refiner.h`, `python_api.h`, and enable Python API for TensorFlow ([issue #602](https://github.com/bytedeco/javacpp-presets/issues/602))
* Add presets for Spinnaker 1.15.x ([pull #553](https://github.com/bytedeco/javacpp-presets/pull/553)), CPython 3.6.x, ONNX 1.2.2 ([pull #547](https://github.com/bytedeco/javacpp-presets/pull/547))
* Define `std::vector<tensorflow::OpDef>` type to `OpDefVector` for TensorFlow
* Link HDF5 with zlib on Windows also ([issue deeplearning4j/deeplearning4j#6017](https://github.com/deeplearning4j/deeplearning4j/issues/6017))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@
"tensorflow/core/common_runtime/process_function_library_runtime.h",
"tensorflow/core/graph/graph.h",
"tensorflow/core/graph/tensor_id.h",
"tensorflow/core/common_runtime/graph_runner.h",
"tensorflow/core/common_runtime/shape_refiner.h",
"tensorflow/core/framework/node_def_builder.h",
"tensorflow/core/framework/node_def_util.h",
"tensorflow/core/framework/selective_registration.h",
Expand Down Expand Up @@ -315,6 +317,8 @@
"tensorflow/core/common_runtime/process_function_library_runtime.h",
"tensorflow/core/graph/graph.h",
"tensorflow/core/graph/tensor_id.h",
"tensorflow/core/common_runtime/graph_runner.h",
"tensorflow/core/common_runtime/shape_refiner.h",
"tensorflow/core/framework/node_def_builder.h",
"tensorflow/core/framework/node_def_util.h",
"tensorflow/core/framework/selective_registration.h",
Expand Down Expand Up @@ -600,7 +604,7 @@ public void map(InfoMap infoMap) {
.put(new Info("tensorflow::gtl::FlatMap<TF_Session*,tensorflow::string>").pointerTypes("TF_SessionStringMap").define())

// Skip composite op scopes bc: call to implicitly-deleted default constructor of '::tensorflow::CompositeOpScopes'
.put(new Info("tensorflow::CompositeOpScopes").skip())
.put(new Info("tensorflow::CompositeOpScopes", "tensorflow::ExtendedInferenceContext").skip())

// Fixed shape inference
.put(new Info("std::vector<const tensorflow::Tensor*>").pointerTypes("ConstTensorPtrVector").define())
Expand Down
199 changes: 192 additions & 7 deletions tensorflow/src/main/java/org/bytedeco/javacpp/tensorflow.java
Original file line number Diff line number Diff line change
Expand Up @@ -21817,13 +21817,6 @@ private native void allocate(@Const KernelDef kernel_def, @StringPiece String ke
// #include "tensorflow/core/lib/core/status.h"
// #include "tensorflow/core/lib/gtl/inlined_vector.h"
// #include "tensorflow/core/platform/macros.h"

@Namespace("tensorflow") @Opaque public static class ShapeRefiner extends Pointer {
/** Empty constructor. Calls {@code super((Pointer)null)}. */
public ShapeRefiner() { super((Pointer)null); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public ShapeRefiner(Pointer p) { super(p); }
}
@Namespace("tensorflow") @Opaque public static class ShapeRefinerTest extends Pointer {
/** Empty constructor. Calls {@code super((Pointer)null)}. */
public ShapeRefinerTest() { super((Pointer)null); }
Expand Down Expand Up @@ -30606,6 +30599,198 @@ public static class Hasher extends Pointer {
// #endif // TENSORFLOW_GRAPH_TENSOR_ID_H_


// Parsed from tensorflow/core/common_runtime/graph_runner.h

/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_RUNNER_H_
// #define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_RUNNER_H_

// #include <memory>
// #include <string>
// #include <vector>

// #include "tensorflow/core/common_runtime/device.h"
// #include "tensorflow/core/framework/function.h"
// #include "tensorflow/core/framework/tensor.h"
// #include "tensorflow/core/graph/graph.h"
// #include "tensorflow/core/lib/core/status.h"
// #include "tensorflow/core/platform/env.h"

// GraphRunner takes a Graph, some inputs to feed, and some outputs
// to fetch and executes the graph required to feed and fetch the
// inputs and outputs.
//
// This class is only meant for internal use where one needs to
// partially evaluate inexpensive nodes in a graph, such as for shape
// inference or for constant folding. Because of its limited, simple
// use-cases, it executes all computation on the given device (CPU by default)
// and is not meant to be particularly lightweight, fast, or efficient.
// JavaCPP peer class wrapping the native {@code tensorflow::GraphRunner}
// (parsed from tensorflow/core/common_runtime/graph_runner.h). Instances own
// a native GraphRunner object; Java-side state is only the native pointer.
@Namespace("tensorflow") @NoOffset public static class GraphRunner extends Pointer {
// Ensure the native TensorFlow library is loaded before any native method is resolved.
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public GraphRunner(Pointer p) { super(p); }

// Each public constructor delegates to a private native allocate() overload
// that constructs the underlying C++ object (standard JavaCPP pattern).
// REQUIRES: `env` is not nullptr.
public GraphRunner(Env env) { super((Pointer)null); allocate(env); }
private native void allocate(Env env);
// REQUIRES: 'device' is not nullptr. Not owned.
public GraphRunner(Device device) { super((Pointer)null); allocate(device); }
private native void allocate(Device device);

// Function semantics for `inputs`, `output_names` and `outputs`
// matches those from Session::Run().
//
// NOTE: The output tensors share lifetime with the GraphRunner, and could
// be destroyed once the GraphRunner is destroyed.
//
// REQUIRES: `graph`, `env`, and `outputs` are not nullptr.
// `function_library` may be nullptr.
// Returns the native tensorflow::Status by value; callers should check it
// (e.g. via Status.ok()) before reading `outputs`.
public native @ByVal Status Run(Graph graph, FunctionLibraryRuntime function_library,
@Cast("const tensorflow::GraphRunner::NamedTensorList*") @ByRef StringTensorPairVector inputs,
@Const @ByRef StringVector output_names,
TensorVector outputs);
}

// namespace tensorflow

// #endif // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_RUNNER_H_


// Parsed from tensorflow/core/common_runtime/shape_refiner.h

/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_
// #define TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_

// #include <vector>

// #include "tensorflow/core/common_runtime/graph_runner.h"
// #include "tensorflow/core/framework/function.pb.h"
// #include "tensorflow/core/framework/shape_inference.h"
// #include "tensorflow/core/graph/graph.h"
// #include "tensorflow/core/lib/core/status.h"
// #include "tensorflow/core/platform/macros.h"


// This class stores extra inference information in addition to
// InferenceContext, such as inference tree for user-defined functions and node
// input and output types.

// ShapeRefiner performs shape inference for TensorFlow Graphs. It is
// responsible for instantiating InferenceContext objects for each
// Node in the Graph, and providing/storing the 'input_tensor' Tensors
// used by Shape Inference functions, when available at graph
// construction time.
// JavaCPP peer class wrapping the native {@code tensorflow::ShapeRefiner}
// (parsed from tensorflow/core/common_runtime/shape_refiner.h). This replaces
// the previous @Opaque placeholder mapping, exposing the full native API.
@Namespace("tensorflow") @NoOffset public static class ShapeRefiner extends Pointer {
// Ensure the native TensorFlow library is loaded before any native method is resolved.
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public ShapeRefiner(Pointer p) { super(p); }

// Constructs a refiner for graphs produced at the given GraphDef version,
// resolving op shape functions through `ops` (e.g. OpRegistry::Global()).
public ShapeRefiner(int graph_def_version, @Const OpRegistryInterface ops) { super((Pointer)null); allocate(graph_def_version, ops); }
private native void allocate(int graph_def_version, @Const OpRegistryInterface ops);

// Same as ShapeRefiner(versions.producer(), ops)
public ShapeRefiner(@Const @ByRef VersionDef versions, @Const OpRegistryInterface ops) { super((Pointer)null); allocate(versions, ops); }
private native void allocate(@Const @ByRef VersionDef versions, @Const OpRegistryInterface ops);

// Performs validation of 'node' and runs 'node's shape function,
// storing its shape outputs.
//
// All inputs of 'node' must be added to ShapeRefiner prior to
// adding 'node'.
//
// Returns an error if:
// - the shape function for 'node' was not registered.
// - 'node' was added before its inputs.
// - The shape inference function returns an error.
public native @ByVal Status AddNode(@Const Node node);

// Sets 'node's 'output_port' output to have shape 'shape'.
//
// Returns an error if 'node' was not previously added to this
// object, if 'output_port' is invalid, or if 'shape' is
// not compatible with the existing shape of the output.
public native @ByVal Status SetShape(@Const Node node, int output_port,
@ByVal ShapeHandle shape);

// Update the input shapes of node in case the shapes of the fan-ins of 'node'
// have themselves been modified (For example, in case of incremental shape
// refinement). If 'relax' is true, a new shape with the broadest set of
// information will be set as the new input (see InferenceContext::RelaxInput
// for full details and examples). Sets refined to true if any shapes have
// changed (in their string representations). Note that shapes may have been
// updated to newer versions (but with identical string representations) even
// if <*refined> is set to false.
// Two overloads for the native `bool* refined` out-parameter: pass either a
// BoolPointer or a boolean[] of length >= 1 to receive the flag.
public native @ByVal Status UpdateNode(@Const Node node, @Cast("bool") boolean relax, @Cast("bool*") BoolPointer refined);
public native @ByVal Status UpdateNode(@Const Node node, @Cast("bool") boolean relax, @Cast("bool*") boolean... refined);

// Returns the InferenceContext for 'node', if present.
// NOTE(review): presumably returns null when 'node' has no context — confirm
// against the native header before relying on it.
public native InferenceContext GetContext(@Const Node node);

// Returns the ExtendedInferenceContext for 'node', if present.
// (The native accessor is not mapped here: tensorflow::ExtendedInferenceContext
// is skipped in the presets configuration.)

// Getters and setters for graph_def_version_.
public native int graph_def_version();
public native void set_graph_def_version(int version);

public native void set_require_shape_inference_fns(@Cast("bool") boolean require_shape_inference_fns);
public native void set_disable_constant_propagation(@Cast("bool") boolean disable);

// Set function library to enable function shape inference.
// Without function library, function inference always yields unknown shapes.
// With this enabled, shape inference can take more time since it descends
// into all function calls. It doesn't do inference once for each function
// definition, but once for each function call.
// The function library must outlive the shape refiner.
public native void set_function_library_for_shape_inference(
@Const FunctionLibraryDefinition lib);

public native @Cast("bool") boolean function_shape_inference_supported();

// Call this to keep nested shapes information for user-defined functions:
// nested inferences will be available on the ExtendedInferenceContext for
// each function node, forming a tree of shape inferences corresponding to the
// tree of nested function calls. By default this setting is disabled, and
// only the shapes for the top-level function node will be reported on the
// InferenceContext for each function node, to reduce memory usage.
//
// This flag has no effect when the function inference is not enabled via
// set_function_library_for_shape_inference.
public native void set_keep_nested_shape_inferences();
}

// namespace tensorflow

// #endif // TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_


// Parsed from tensorflow/core/framework/node_def_builder.h

/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Expand Down

0 comments on commit a95b465

Please sign in to comment.