Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add OutputVector class and a more complex example to Tensorflow #563

Merged
merged 2 commits into from
May 26, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@

* Add more samples for TensorFlow including a complete training example ([pull #563](https://github.com/bytedeco/javacpp-presets/pull/563))
* Add helper for `PIX`, `FPIX`, and `DPIX` of Leptonica, facilitating access to image data of Tesseract ([issue #517](https://github.com/bytedeco/javacpp-presets/issues/517))
* Add presets for the NVBLAS, NVGRAPH, NVRTC, and NVML modules of CUDA ([issue deeplearning4j/nd4j#2895](https://github.com/deeplearning4j/nd4j/issues/2895))
* Link OpenBLAS with `-Wl,-z,noexecstack` on `linux-armhf` as required by the JDK ([issue deeplearning4j/libnd4j#700](https://github.com/deeplearning4j/libnd4j/issues/700))
Expand Down
24 changes: 24 additions & 0 deletions tensorflow/samples/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>

<groupId>org.bytedeco.javacpp-presets</groupId>
<artifactId>tensorflow-samples</artifactId>
<version>1.8.0-1.4.2-SNAPSHOT</version>
<name>JavaCPP Presets Samples for TensorFlow</name>

<properties>
<maven.compiler.target>1.7</maven.compiler.target>
<maven.compiler.source>1.7</maven.compiler.source>
</properties>

<dependencies>
<dependency>
<groupId>${project.groupId}</groupId>
<artifactId>tensorflow-platform</artifactId>
<version>${project.version}</version>
</dependency>
</dependencies>

</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package org.bytedeco.javacpp.samples.tensorflow;

import static org.bytedeco.javacpp.tensorflow.Const;
import static org.bytedeco.javacpp.tensorflow.InitMain;
import static org.bytedeco.javacpp.tensorflow.TF_CHECK_OK;

import java.nio.IntBuffer;

import org.bytedeco.javacpp.Loader;
import org.bytedeco.javacpp.tensorflow.AddN;
import org.bytedeco.javacpp.tensorflow.GraphDef;
import org.bytedeco.javacpp.tensorflow.Input;
import org.bytedeco.javacpp.tensorflow.InputList;
import org.bytedeco.javacpp.tensorflow.L2Loss;
import org.bytedeco.javacpp.tensorflow.Output;
import org.bytedeco.javacpp.tensorflow.OutputVector;
import org.bytedeco.javacpp.tensorflow.Scope;
import org.bytedeco.javacpp.tensorflow.Session;
import org.bytedeco.javacpp.tensorflow.SessionOptions;
import org.bytedeco.javacpp.tensorflow.StringTensorPairVector;
import org.bytedeco.javacpp.tensorflow.StringVector;
import org.bytedeco.javacpp.tensorflow.Tensor;
import org.bytedeco.javacpp.tensorflow.TensorShape;
import org.bytedeco.javacpp.tensorflow.TensorVector;

/**
* Showcase the usage of OutputVector and the AddN operator.
*
* @author Nico Hezel
*/
public class AddNExample {

public static void main(String[] args) {

// Load all javacpp-preset classes and native libraries
Loader.load(org.bytedeco.javacpp.tensorflow.class);

// Platform-specific initialization routine
InitMain("trainer", (int[])null, null);

// Create a new empty graph
Scope scope = Scope.NewRootScope();

// (2,1) matrix of ones, sixes and tens
TensorShape shape = new TensorShape(2, 1);
Output ones = Const(scope.WithOpName("ones"), 1, shape);
Output sixes = Const(scope.WithOpName("sixes"), 6, shape);
Output tens = Const(scope.WithOpName("tens"), 10, shape);

// Adding all matrices element-wise
OutputVector ov = new OutputVector(ones, sixes, tens);
InputList inputList = new InputList(ov);
AddN add = new AddN(scope.WithOpName("add"), inputList);

// Build a graph definition object
GraphDef def = new GraphDef();
TF_CHECK_OK(scope.ToGraphDef(def));

// Creates a session.
SessionOptions options = new SessionOptions();
try(final Session session = new Session(options)) {

// Create the graph to be used for the session.
TF_CHECK_OK(session.Create(def));

// Input and output of a single session run.
StringTensorPairVector input_feed = new StringTensorPairVector();
StringVector output_tensor_name = new StringVector("add:0");
StringVector target_tensor_name = new StringVector();
TensorVector outputs = new TensorVector();

// Run the session once
TF_CHECK_OK(session.Run(input_feed, output_tensor_name, target_tensor_name, outputs));

// Print the add-output
for (Tensor output : outputs.get()) {
IntBuffer y_flat = output.createBuffer();
for (int i = 0; i < output.NumElements(); i++)
System.out.println(y_flat.get(i));
}
}
}
}
Loading