Update to Pytorch 2.3.0 #1498

Merged: 18 commits, May 19, 2024
Changes shown are from 2 commits.

5 changes: 3 additions & 2 deletions pytorch/src/gen/java/org/bytedeco/pytorch/Half.java
@@ -46,8 +46,9 @@ public class Half extends Pointer {
// #if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__)
// #else
public Half(float value) { super((Pointer)null); allocate(value); }
private native void allocate(float value);
public native @Name("operator float") float asFloat();
@Namespace @Name("javacpp::allocate_Half") private native void allocate(float value);
public float asFloat() { return _asFloat(this); }
private static native @Namespace @Name("javacpp::cast_Half_to_float") float _asFloat(Half h);
// #endif

// #if defined(__CUDACC__) || defined(__HIPCC__)
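
For reference, a minimal usage sketch of the regenerated Half API from Java (the class and its two methods come from the diff above; the surrounding harness and its names are hypothetical). The public signatures are unchanged; only their native backing moves to the javacpp:: helper functions:

import org.bytedeco.pytorch.Half;

public class HalfRoundTrip {
    public static void main(String[] args) {
        // Construction now routes through javacpp::allocate_Half on every platform.
        Half h = new Half(1.5f);
        // Conversion back to float now routes through javacpp::cast_Half_to_float.
        float f = h.asFloat();
        System.out.println(f); // prints 1.5, which is exactly representable in half precision
    }
}
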
39 changes: 0 additions & 39 deletions pytorch/src/gen/java/org/bytedeco/pytorch/VectorReader.java

This file was deleted.

3 changes: 1 addition & 2 deletions pytorch/src/gen/java/org/bytedeco/pytorch/global/torch.java
@@ -79474,9 +79474,8 @@ scalar_t sf(scalar_t x, scalar_t y)
@ByVal(nullValue = "torch::jit::TypeResolver(nullptr)") TypeResolver type_resolver,
@ByVal(nullValue = "c10::ArrayRef<at::Tensor>{}") TensorArrayRef tensor_table,
TypeParser type_parser/*=torch::jit::Unpickler::defaultTypeParser*/);
// Targeting ../VectorReader.java


// #ifndef C10_MOBILE
// #endif
// namespace jit
// namespace torch
53 changes: 31 additions & 22 deletions pytorch/src/main/java/org/bytedeco/pytorch/presets/torch.java
@@ -71,7 +71,8 @@
"torch/csrc/jit/serialization/storage_context.h",

"datasets.h",
"pytorch_adapters.h"
"pytorch_adapters.h",
"platform_unification.h"
},
exclude = {"openblas_config.h", "cblas.h", "lapacke_config.h", "lapacke_mangling.h", "lapack.h", "lapacke.h", "lapacke_utils.h"},
link = {"c10", "torch_cpu", "torch"},
@@ -354,12 +355,12 @@ public void map(InfoMap infoMap) {
//.put(new Info("c10::DataPtr&&", "at::DataPtr&&").valueTypes("@Cast({\"\", \"c10::DataPtr&&\"}) @StdMove DataPtr").pointerTypes("DataPtr")) // DataPtr::operator= deleted
.put(new Info("c10::DataPtr", "at::DataPtr").valueTypes("@StdMove DataPtr").pointerTypes("DataPtr"))
.put(new Info("c10::StorageImpl::UniqueStorageShareExternalPointer(at::DataPtr&&, size_t)",
"c10::Storage::UniqueStorageShareExternalPointer(at::DataPtr&&, size_t)").javaText(
"c10::Storage::UniqueStorageShareExternalPointer(at::DataPtr&&, size_t)").javaText(
"public native void UniqueStorageShareExternalPointer(@Cast({\"\", \"c10::DataPtr&&\"}) @StdMove DataPtr data_ptr, @Cast(\"size_t\") long size_bytes);"
))
.put(new Info("c10::GetStorageImplCreate", "c10::SetStorageImplCreate",
"c10::intrusive_ptr<c10::StorageImpl> (*)(c10::StorageImpl::use_byte_size_t, c10::SymInt, c10::DataPtr, c10::Allocator*, bool)").skip())
;
//// Enumerations
infoMap
.put(new Info("c10::ScalarType", "at::ScalarType", "torch::Dtype").enumerate().valueTypes("ScalarType").pointerTypes("@Cast(\"c10::ScalarType*\") BytePointer"))
@@ -1113,21 +1114,21 @@ public void map(InfoMap infoMap) {
"c10::complex<c10::Half>::operator -=(c10::Half)",
"c10::complex<c10::Half>::operator *=(c10::Half)",
"c10::complex<c10::Half>::operator /=(c10::Half)"
).skip())
.put(new Info("c10::complex<c10::Half>::complex(const c10::Half&, const c10::Half&)").javaText( // Second argument not optional + add specific functions
"public HalfComplex(Half re, Half im) { super((Pointer)null); allocate(re, im); }\n" +
"private native void allocate(@Const @ByRef Half re, @Const @ByRef(nullValue = \"c10::Half()\") Half im);\n" +
"public HalfComplex(@Const @ByRef FloatComplex value) { super((Pointer)null); allocate(value); }\n" +
"private native void allocate(@Const @ByRef FloatComplex value);\n" +
"\n" +
"// Conversion operator\n" +
"public native @ByVal @Name(\"operator c10::complex<float>\") FloatComplex asFloatComplex();\n" +
"\n" +
"public native @ByRef @Name(\"operator +=\") HalfComplex addPut(@Const @ByRef HalfComplex other);\n" +
"\n" +
"public native @ByRef @Name(\"operator -=\") HalfComplex subtractPut(@Const @ByRef HalfComplex other);\n" +
"\n" +
"public native @ByRef @Name(\"operator *=\") HalfComplex multiplyPut(@Const @ByRef HalfComplex other);"
"public HalfComplex(Half re, Half im) { super((Pointer)null); allocate(re, im); }\n" +
"private native void allocate(@Const @ByRef Half re, @Const @ByRef(nullValue = \"c10::Half()\") Half im);\n" +
"public HalfComplex(@Const @ByRef FloatComplex value) { super((Pointer)null); allocate(value); }\n" +
"private native void allocate(@Const @ByRef FloatComplex value);\n" +
"\n" +
"// Conversion operator\n" +
"public native @ByVal @Name(\"operator c10::complex<float>\") FloatComplex asFloatComplex();\n" +
"\n" +
"public native @ByRef @Name(\"operator +=\") HalfComplex addPut(@Const @ByRef HalfComplex other);\n" +
"\n" +
"public native @ByRef @Name(\"operator -=\") HalfComplex subtractPut(@Const @ByRef HalfComplex other);\n" +
"\n" +
"public native @ByRef @Name(\"operator *=\") HalfComplex multiplyPut(@Const @ByRef HalfComplex other);"
)
)
;
@@ -1202,7 +1203,7 @@ public void map(InfoMap infoMap) {
.put(new Info("torch::jit::Wrap<torch::jit::Block>").pointerTypes("BlockWrap"))
.put(new Info("torch::jit::Wrap<torch::jit::Node>").pointerTypes("JitNodeWrap"))
.put(new Info("torch::jit::Wrap<torch::jit::Value>").pointerTypes("ValueWrap"))
;


//// Data loader
@@ -1786,7 +1787,7 @@ public void map(InfoMap infoMap) {
.put(new Info("torch::optim::" + opt + "Options", "torch::optim::" + opt + "ParamState")) // Help qualification
.put(new Info("torch::optim::OptimizerCloneableOptions<torch::optim::" + opt + "Options>").pointerTypes("OptimizerCloneable" + opt + "Options"))
.put(new Info("torch::optim::OptimizerCloneableParamState<torch::optim::" + opt + "ParamState>").pointerTypes("OptimizerCloneable" + opt + "ParamState"))
;
new PointerInfo("torch::optim::" + opt + "Options").makeUnique(infoMap);
new PointerInfo("torch::optim::OptimizerCloneableParamState<torch::optim::" + opt + "ParamState>").javaBaseName("OptimizerCloneable" + opt + "AdagradParamState").makeUnique(infoMap);
new PointerInfo("torch::optim::OptimizerCloneableOptions<torch::optim::" + opt + "Options>").javaBaseName("OptimizerCloneable" + opt + "Options").makeUnique(infoMap);
@@ -2342,6 +2343,7 @@ We need either to put an annotation info on each member, or javaName("@NoOffset
"torch::jit::Suspend",
"torch::jit::TokenTrie",
"torch::jit::TaggedRange",
"torch::jit::VectorReader",
"torch::jit::WithCurrentScope",
"torch::jit::WithInsertPoint",
"torch::jit::variable_tensor_list",
@@ -2432,8 +2434,7 @@ We need either to put an annotation info on each member, or javaName("@NoOffset
"std::enable_shared_from_this<torch::jit::tracer::TracingState>", "std::enable_shared_from_this<TracingState>",
"std::enable_shared_from_this<torch::nn::Module>", "std::enable_shared_from_this<Module>"
).pointerTypes("Pointer").cast())
.put(new Info("MTLCommandBuffer_t", "DispatchQueue_t").valueTypes("Pointer").pointerTypes("PointerPointer").skip());

.put(new Info("MTLCommandBuffer_t", "DispatchQueue_t").valueTypes("Pointer").pointerTypes("PointerPointer").skip());


///// Special cases needing javaText
@@ -2539,6 +2540,14 @@ We need either to put an annotation info on each member, or javaName("@NoOffset

infoMap.put(new Info("c10::VaryingShape<c10::Stride>::merge").skip()); // https://github.com/pytorch/pytorch/issues/123248, waiting for the fix in 2.3.1 or 2.4

//// Different C++ API between platforms
infoMap
.put(new Info("c10::Half::Half(float)").annotations("@Namespace @Name(\"javacpp::allocate_Half\")"))
.put(new Info("c10::Half::operator float()").javaText(
"public float asFloat() { return _asFloat(this); }\n" +
"private static native @Namespace @Name(\"javacpp::cast_Half_to_float\") float _asFloat(Half h);"
))
;
}

private static String template(String t, String... args) {
21 changes: 21 additions & 0 deletions platform_unification.h
@@ -0,0 +1,21 @@
/* Add the necessary C++ calls to present a unified C++ API whatever the platform */

namespace javacpp {

inline c10::Half *allocate_Half(jfloat value) {
#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__)
return new c10::Half((float16_t) value);
#else
return new c10::Half((float) value);
#endif
}

inline float cast_Half_to_float(c10::Half *h) {
#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__)
return (float) (float16_t) *h;
#else
return (float) *h;
#endif
}

}
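
For context, a rough sketch of the platform difference this header papers over (abridged signatures inferred from the conditional blocks visible in the generated Half.java above, not a verbatim copy of c10/util/Half.h): on aarch64 builds that are neither mobile nor CUDA, c10::Half is backed by the native float16_t type, so the float-based constructor and conversion operator that the generated Java bindings call do not exist there.

// Simplified, assumed view of the platform-dependent members of c10::Half.
struct Half {
#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__)
  Half(float16_t value);       // native half constructor; no Half(float) overload
  operator float16_t() const;  // converts to native half, not to float
#else
  Half(float value);           // emulated half constructor
  operator float() const;      // converts to float
#endif
};

javacpp::allocate_Half and javacpp::cast_Half_to_float insert the extra float16_t cast only where the native type is in play, so a single Java signature can be generated for all platforms.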