Skip to content

Commit

Permalink
Fix mapping of unique_ptr
Browse files Browse the repository at this point in the history
  • Loading branch information
HGuillemet committed Jan 19, 2024
1 parent b977780 commit e7393d9
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,6 @@ public class Interpreter extends Pointer {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public Interpreter(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public Interpreter(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public Interpreter position(long position) {
return (Interpreter)super.position(position);
}
@Override public Interpreter getPointer(long i) {
return new Interpreter((Pointer)this).offsetAddress(i);
}

// Instantiate an interpreter. All errors associated with reading and
// processing this model will be forwarded to the error_reporter object.
Expand All @@ -32,9 +23,9 @@ public class Interpreter extends Pointer {
// WARNING: Use of this constructor outside of an InterpreterBuilder is not
// recommended.
public Interpreter(ErrorReporter error_reporter/*=tflite::DefaultErrorReporter()*/) { super((Pointer)null); allocate(error_reporter); }
private native void allocate(ErrorReporter error_reporter/*=tflite::DefaultErrorReporter()*/);
@UniquePtr @Name("std::make_unique<tflite::impl::Interpreter>") private native void allocate(ErrorReporter error_reporter/*=tflite::DefaultErrorReporter()*/);
public Interpreter() { super((Pointer)null); allocate(); }
private native void allocate();
@UniquePtr @Name("std::make_unique<tflite::impl::Interpreter>") private native void allocate();

// Interpreters are not copyable as they have non-trivial memory semantics.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public Subgraph(ErrorReporter error_reporter,
@Cast("tflite::resource::ResourceIDMap*") StringIntMap resource_ids,
@Cast("tflite::resource::InitializationStatusMap*") IntResourceBaseMap initialization_status_map,
int subgraph_index/*=kInvalidSubgraphIndex*/) { super((Pointer)null); allocate(error_reporter, external_contexts, subgraphs, resources, resource_ids, initialization_status_map, subgraph_index); }
private native void allocate(ErrorReporter error_reporter,
@UniquePtr @Name("std::make_unique<tflite::Subgraph>") private native void allocate(ErrorReporter error_reporter,
@Cast("TfLiteExternalContext**") PointerPointer external_contexts,
SubgraphVector subgraphs,
@Cast("tflite::resource::ResourceMap*") IntResourceBaseMap resources,
Expand All @@ -38,7 +38,7 @@ public Subgraph(ErrorReporter error_reporter,
@Cast("tflite::resource::ResourceMap*") IntResourceBaseMap resources,
@Cast("tflite::resource::ResourceIDMap*") StringIntMap resource_ids,
@Cast("tflite::resource::InitializationStatusMap*") IntResourceBaseMap initialization_status_map) { super((Pointer)null); allocate(error_reporter, external_contexts, subgraphs, resources, resource_ids, initialization_status_map); }
private native void allocate(ErrorReporter error_reporter,
@UniquePtr @Name("std::make_unique<tflite::Subgraph>") private native void allocate(ErrorReporter error_reporter,
@ByPtrPtr TfLiteExternalContext external_contexts,
SubgraphVector subgraphs,
@Cast("tflite::resource::ResourceMap*") IntResourceBaseMap resources,
Expand All @@ -51,7 +51,7 @@ public Subgraph(ErrorReporter error_reporter,
@Cast("tflite::resource::ResourceIDMap*") StringIntMap resource_ids,
@Cast("tflite::resource::InitializationStatusMap*") IntResourceBaseMap initialization_status_map,
int subgraph_index/*=kInvalidSubgraphIndex*/) { super((Pointer)null); allocate(error_reporter, external_contexts, subgraphs, resources, resource_ids, initialization_status_map, subgraph_index); }
private native void allocate(ErrorReporter error_reporter,
@UniquePtr @Name("std::make_unique<tflite::Subgraph>") private native void allocate(ErrorReporter error_reporter,
@ByPtrPtr TfLiteExternalContext external_contexts,
SubgraphVector subgraphs,
@Cast("tflite::resource::ResourceMap*") IntResourceBaseMap resources,
Expand All @@ -63,7 +63,7 @@ private native void allocate(ErrorReporter error_reporter,

// Subgraphs should be movable but not copyable.
public Subgraph(@StdMove Subgraph arg0) { super((Pointer)null); allocate(arg0); }
private native void allocate(@StdMove Subgraph arg0);
@UniquePtr @Name("std::make_unique<tflite::Subgraph>") private native void allocate(@StdMove Subgraph arg0);


// Provide a list of tensor indexes that are inputs to the model.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,10 @@ public void map(InfoMap infoMap) {
.put(new Info("tflite::impl::Interpreter::typed_output_tensor<double>").javaNames("typed_output_tensor_double"))
.put(new Info("tflite::impl::Interpreter::typed_output_tensor<bool>").javaNames("typed_output_tensor_bool"))
.put(new Info("tflite::impl::Interpreter::typed_output_tensor<TfLiteFloat16>").javaNames("typed_input_tensor_float16"))

// Classes passed to some native functions as unique_ptr and that can be allocated Java-side
.put(new Info("tflite::impl::Interpreter::Interpreter").annotations("@UniquePtr", "@Name(\"std::make_unique<tflite::impl::Interpreter>\")"))
.put(new Info("tflite::Subgraph::Subgraph").annotations("@UniquePtr", "@Name(\"std::make_unique<tflite::Subgraph>\")"))
;
}
}

0 comments on commit e7393d9

Please sign in to comment.