Skip to content

Commit

Permalink
Fix manually mapped functions in presets for PyTorch (pull #1391)
Browse files Browse the repository at this point in the history
  • Loading branch information
HGuillemet authored Aug 8, 2023
1 parent 251f317 commit bdba6a2
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 7 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.bytedeco.pytorch.functions;

import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.FunctionPointer;
import org.bytedeco.javacpp.Loader;
import org.bytedeco.javacpp.Pointer;
Expand All @@ -25,5 +26,5 @@ protected NamedModuleApplyFunction() {

private native void allocate();

public native void call(@Const @StdString @ByRef String name, @ByRef Module m);
public native void call(@Const @StdString BytePointer name, @ByRef Module m);
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.bytedeco.pytorch.functions;

import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.FunctionPointer;
import org.bytedeco.javacpp.Loader;
import org.bytedeco.javacpp.Pointer;
Expand All @@ -25,5 +26,5 @@ protected NamedSharedModuleApplyFunction() {

private native void allocate();

public native void call(@Const @StdString @ByRef String name, @ByRef @SharedPtr @Cast({"", "std::shared_ptr<torch::nn::Module>"}) Module m);
public native void call(@Const @StdString BytePointer name, @ByRef @SharedPtr @Cast({"", "std::shared_ptr<torch::nn::Module>"}) Module m);
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,5 @@ protected PickleWriter() {

private native void allocate();

public native void call(@Cast("const char *") BytePointer buf, @Cast("size_t") long nbytes);
public native void call(@Cast("const char*") BytePointer buf, @Cast("size_t") long nbytes);
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package org.bytedeco.pytorch.functions;

import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.FunctionPointer;
import org.bytedeco.javacpp.Loader;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.annotation.Cast;
import org.bytedeco.javacpp.annotation.Const;
import org.bytedeco.javacpp.annotation.Properties;
import org.bytedeco.javacpp.annotation.StdString;

Expand All @@ -26,5 +27,6 @@ protected StringConsumer() {

private native void allocate();

public native void call(@Cast({"", "const std::string&"}) @StdString String s);
// std::function<void(const std::string&)>
public native void call(@Const @StdString BytePointer s);
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package org.bytedeco.pytorch.functions;

import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.FunctionPointer;
import org.bytedeco.javacpp.Loader;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.annotation.Cast;
import org.bytedeco.javacpp.annotation.Properties;
import org.bytedeco.javacpp.annotation.StdString;

Expand All @@ -25,5 +27,6 @@ protected StringSupplier() {

private native void allocate();

public native @StdString String call();
// Without the cast, the function returns a std::basic_string<char>& and the cast from StringAdapter returns a reference to a variable in the stack.
public native @StdString @Cast({"", "char*"}) BytePointer call();
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,6 @@ protected TensorIdGetter() {
private native void allocate();

// std::function<std::string(const at::Tensor&)>
public native @StdString String call(@Const @ByRef Tensor tensor);
// Without the cast, the function returns a std::basic_string<char>& and the cast from StringAdapter returns a reference to a variable in the stack.
public native @StdString @Cast({"", "char*"}) BytePointer call(@Const @ByRef Tensor tensor);
}

0 comments on commit bdba6a2

Please sign in to comment.