[Pytorch] New version of the presets #1360

Merged
71 commits merged on Jul 25, 2023

Changes from all commits

Commits (71)
96e3a65
Reorganization, use of new JavaCPP features, more mapping
HGuillemet Apr 16, 2023
9152ff6
Add missing exports in module-info
HGuillemet May 22, 2023
a80ea2e
Add 2 missing includes, reindent
HGuillemet May 23, 2023
809f2c3
Add missing includes
HGuillemet May 23, 2023
42e877c
Remove 3 classes not in API
HGuillemet May 23, 2023
ad9e7d0
Update gen
HGuillemet May 23, 2023
03af32c
Fix Module::apply and JitModule::apply
HGuillemet May 23, 2023
50c2d70
Remove includes needing CUDA installed
HGuillemet May 23, 2023
f11a5c4
Fix windows build.
HGuillemet May 24, 2023
9001e16
Skip some "internal-only" functions
HGuillemet May 24, 2023
309b8f4
Update gen
HGuillemet May 24, 2023
c499367
Fix make_generator for Windows
HGuillemet May 24, 2023
ebb9249
Move cuda-specific to torch_cuda
HGuillemet May 27, 2023
c64ff76
Add nvfuser to preloads
HGuillemet May 27, 2023
8ec1916
Exclude more non-exported symbols
HGuillemet May 27, 2023
534fe1b
gen update
HGuillemet May 27, 2023
5b198c1
cuda gen update
HGuillemet May 27, 2023
7fb3268
Merge 2.0.1 changes from master
HGuillemet May 29, 2023
7527850
Skip EnumHolder::is
HGuillemet May 29, 2023
2c293c5
Add include path for CUDA on Windows
HGuillemet May 30, 2023
d720854
Fix torch_cuda windows linking
HGuillemet May 30, 2023
a9b3d2a
Fix torch_cuda windows linking
HGuillemet May 30, 2023
4f53b93
Change TORCH_CUDA_ARCH_LIST
HGuillemet May 31, 2023
0fc6776
* Upgrade presets for FFmpeg 6.0, HDF5 1.14.1, LLVM 16.0.4, NVIDIA V…
saudet May 28, 2023
0324667
* Add new `SampleJpegEncoder` code for nvJPEG module of CUDA (pull #…
devjeonghwan May 31, 2023
ec64a17
Merge branch 'master' into hg_pytorch
HGuillemet Jun 1, 2023
3e3fe5c
Add rm in deploy-centos to preserve disk space
HGuillemet Jun 2, 2023
b6d0123
Change TORCH_CUDA_ARCH_LIST
HGuillemet Jun 2, 2023
7333646
Change TORCH_CUDA_ARCH_LIST.
HGuillemet Jun 3, 2023
c556443
Merge remote-tracking branch 'origin/master' into hg_pytorch
HGuillemet Jun 7, 2023
1b5b94f
Change version to 2.0.1-new. gen update.
HGuillemet Jun 7, 2023
038b07a
Revert TORCH_CUDA_ARCH_LIST to 5.0+PTX
HGuillemet Jun 9, 2023
97f4aaa
Merge remote-tracking branch 'origin/master' into hg_pytorch
HGuillemet Jun 10, 2023
6c9188e
Update to 1.5.10-SNAPSHOT
HGuillemet Jun 10, 2023
0b20648
Deploy on Ubuntu instead of Centos
HGuillemet Jun 10, 2023
9263f7c
Try to fix CUDA builds on Ubuntu
saudet Jun 12, 2023
657ce64
Fix CUDA builds on Ubuntu some more
saudet Jun 12, 2023
49a73de
Fix incorrect versions
saudet Jun 12, 2023
ddab4a5
Fix workflow for ccache
saudet Jun 12, 2023
12b4523
Revert unnecessary changes to deploy-centos/action.yml
saudet Jun 12, 2023
47a3e16
Load include list from resources in init().
HGuillemet Jun 13, 2023
15012ba
Use C format for list of parsed headers
HGuillemet Jun 14, 2023
ac34d7c
Link jnitorch_cuda with cudart on windows
HGuillemet Jun 15, 2023
c74a931
Fix linking jni torch_cuda with cudart
HGuillemet Jun 16, 2023
808a2c8
Add linking jni torch_cuda with cusparse
HGuillemet Jun 16, 2023
14fed5e
Add linking jni torch_cuda with nvJitLink
HGuillemet Jun 16, 2023
4d5fd3d
Check against parser class name instead of parser class.
HGuillemet Jun 20, 2023
01e2996
Merge branch 'master' into hg_pytorch
HGuillemet Jun 21, 2023
f341fd5
Simplify initIncludes
HGuillemet Jun 21, 2023
58372cb
Revert nvJitLink linking
HGuillemet Jun 26, 2023
3ba2c44
Fix cusolver version in preloads
HGuillemet Jun 26, 2023
fb172f3
Merge branch 'master' into hg_pytorch
HGuillemet Jun 26, 2023
992ed3a
Add GenericDictEntryRef
HGuillemet Jul 7, 2023
ee49627
Cleanup OrderedDict
HGuillemet Jul 7, 2023
34c7ec8
Changes to functions after bytedeco/javacpp@d8b1890
HGuillemet Jul 8, 2023
07dfacd
Add missing gen files for OrderedDict
HGuillemet Jul 8, 2023
8c025db
Remove useless mapping after bytedeco/javacpp@2dacec9
HGuillemet Jul 9, 2023
3c95bef
Update gen after bytedeco/javacpp@ec90945
HGuillemet Jul 10, 2023
c81d251
Add `@NoOffset` on Call
HGuillemet Jul 10, 2023
4ebe97c
Workaround for TransformerActivation.get2 returning a std::function
HGuillemet Jul 11, 2023
43db4eb
Rename c10::variant instances for consistency
HGuillemet Jul 11, 2023
2edaa3b
Fix TensorActivation.get2 now returning a TensorMapper. Change access…
HGuillemet Jul 11, 2023
aba8fd8
Merge remote-tracking branch 'origin/master' into hg_pytorch
HGuillemet Jul 11, 2023
83dce9c
Remove TensorActivation.get2
HGuillemet Jul 12, 2023
8a28331
Merge branch 'master' into hg_pytorch
HGuillemet Jul 22, 2023
eff0259
Update gen after bytedeco/javacpp@8646e97
HGuillemet Jul 22, 2023
f8cd7ec
Update CHANGELOG.md and fix nits
saudet Jul 22, 2023
c14b39a
Add missing at::sqrt(Tensor) and other complex math operators
HGuillemet Jul 23, 2023
2c4ff2d
Add ska::detailv3::log2 masked by last commit
HGuillemet Jul 23, 2023
8c16124
Skip one-element constructor for all ArrayRef instances
HGuillemet Jul 23, 2023
f5cc0be
Add ArrayRef constructor taking a std::vector
HGuillemet Jul 23, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -1,4 +1,5 @@

* Refactor and improve presets for PyTorch ([pull #1360](https://github.com/bytedeco/javacpp-presets/pull/1360))
* Include `mkl_lapack.h` header file in presets for MKL ([issue #1388](https://github.com/bytedeco/javacpp-presets/issues/1388))
* Map new higher-level C++ API of Triton Inference Server ([pull #1361](https://github.com/bytedeco/javacpp-presets/pull/1361))
* Upgrade presets for OpenCV 4.8.0, DNNL 3.1.1, CPython 3.11.4, NumPy 1.25.1, SciPy 1.11.1, LLVM 16.0.6, TensorFlow Lite 2.13.0, Triton Inference Server 2.34.0, ONNX Runtime 1.15.1, and their dependencies
10 changes: 5 additions & 5 deletions pytorch/README.md
@@ -40,36 +40,36 @@ We can use [Maven 3](http://maven.apache.org/) to download and install automatic
<modelVersion>4.0.0</modelVersion>
<groupId>org.bytedeco.pytorch</groupId>
<artifactId>simplemnist</artifactId>
<version>1.5.9</version>
<version>1.5.10-SNAPSHOT</version>
<properties>
<exec.mainClass>SimpleMNIST</exec.mainClass>
</properties>
<dependencies>
<dependency>
<groupId>org.bytedeco</groupId>
<artifactId>pytorch-platform</artifactId>
<version>2.0.1-1.5.9</version>
<version>2.0.1-1.5.10-SNAPSHOT</version>
</dependency>

<!-- Additional dependencies required to use CUDA, cuDNN, and NCCL -->
<dependency>
<groupId>org.bytedeco</groupId>
<artifactId>pytorch-platform-gpu</artifactId>
<version>2.0.1-1.5.9</version>
<version>2.0.1-1.5.10-SNAPSHOT</version>
</dependency>

<!-- Additional dependencies to use bundled CUDA, cuDNN, and NCCL -->
<dependency>
<groupId>org.bytedeco</groupId>
<artifactId>cuda-platform-redist</artifactId>
<version>12.1-8.9-1.5.9</version>
<version>12.1-8.9-1.5.10-SNAPSHOT</version>
</dependency>

<!-- Additional dependencies to use bundled full version of MKL -->
<dependency>
<groupId>org.bytedeco</groupId>
<artifactId>mkl-platform-redist</artifactId>
<version>2023.1-1.5.9</version>
<version>2023.1-1.5.10-SNAPSHOT</version>
</dependency>
</dependencies>
<build>
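A minimal sketch of a program that the pom.xml above would build and run, assuming only the pytorch-platform dependency; the class name HelloTorch and the tensor sizes are illustrative and not part of this PR:

import org.bytedeco.pytorch.Tensor;
import static org.bytedeco.pytorch.global.torch.*;

public class HelloTorch {
    public static void main(String[] args) {
        // Create a 2x3 tensor of ones via the generated factory function.
        Tensor t = ones(2, 3);
        // dim() and numel() come from the generated Tensor class.
        System.out.println("dims=" + t.dim() + " elements=" + t.numel());
    }
}

To run it with the pom.xml above, point exec.mainClass at this class and invoke mvn compile exec:java.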
4 changes: 3 additions & 1 deletion pytorch/src/gen/java/org/bytedeco/pytorch/ASMoutput.java
@@ -1,10 +1,12 @@
// Targeted by JavaCPP version 1.5.9: DO NOT EDIT THIS FILE
// Targeted by JavaCPP version 1.5.10-SNAPSHOT: DO NOT EDIT THIS FILE

package org.bytedeco.pytorch;

import org.bytedeco.pytorch.Allocator;
import org.bytedeco.pytorch.Function;
import org.bytedeco.pytorch.functions.*;
import org.bytedeco.pytorch.Module;
import org.bytedeco.javacpp.annotation.Cast;
import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;
26 changes: 26 additions & 0 deletions pytorch/src/gen/java/org/bytedeco/pytorch/ActivityTraceWrapper.java
@@ -0,0 +1,26 @@
// Targeted by JavaCPP version 1.5.10-SNAPSHOT: DO NOT EDIT THIS FILE

package org.bytedeco.pytorch;

import org.bytedeco.pytorch.Allocator;
import org.bytedeco.pytorch.Function;
import org.bytedeco.pytorch.functions.*;
import org.bytedeco.pytorch.Module;
import org.bytedeco.javacpp.annotation.Cast;
import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;

import static org.bytedeco.javacpp.presets.javacpp.*;
import static org.bytedeco.openblas.global.openblas_nolapack.*;
import static org.bytedeco.openblas.global.openblas.*;

import static org.bytedeco.pytorch.global.torch.*;

@Namespace("torch::profiler::impl::kineto") @Opaque @Properties(inherit = org.bytedeco.pytorch.presets.torch.class)
public class ActivityTraceWrapper extends Pointer {
/** Empty constructor. Calls {@code super((Pointer)null)}. */
public ActivityTraceWrapper() { super((Pointer)null); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public ActivityTraceWrapper(Pointer p) { super(p); }
}
46 changes: 46 additions & 0 deletions pytorch/src/gen/java/org/bytedeco/pytorch/ActivityTypeSet.java
@@ -0,0 +1,46 @@
// Targeted by JavaCPP version 1.5.10-SNAPSHOT: DO NOT EDIT THIS FILE

package org.bytedeco.pytorch;

import org.bytedeco.pytorch.Allocator;
import org.bytedeco.pytorch.Function;
import org.bytedeco.pytorch.functions.*;
import org.bytedeco.pytorch.Module;
import org.bytedeco.javacpp.annotation.Cast;
import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;

import static org.bytedeco.javacpp.presets.javacpp.*;
import static org.bytedeco.openblas.global.openblas_nolapack.*;
import static org.bytedeco.openblas.global.openblas.*;

import static org.bytedeco.pytorch.global.torch.*;

@Name("std::set<torch::profiler::impl::ActivityType>") @Properties(inherit = org.bytedeco.pytorch.presets.torch.class)
public class ActivityTypeSet extends Pointer {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public ActivityTypeSet(Pointer p) { super(p); }
public ActivityTypeSet() { allocate(); }
private native void allocate();
public native @Name("operator =") @ByRef ActivityTypeSet put(@ByRef ActivityTypeSet x);

public boolean empty() { return size() == 0; }
public native long size();

public ActivityType front() { try (Iterator it = begin()) { return it.get(); } }
public native void insert(@ByRef ActivityType value);
public native void erase(@ByRef ActivityType value);
public native @ByVal Iterator begin();
public native @ByVal Iterator end();
@NoOffset @Name("iterator") public static class Iterator extends Pointer {
public Iterator(Pointer p) { super(p); }
public Iterator() { }

public native @Name("operator ++") @ByRef Iterator increment();
public native @Name("operator ==") boolean equals(@ByRef Iterator it);
public native @Name("operator *") @ByRef @Const ActivityType get();
}
}
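A hedged usage sketch for the generated std::set wrapper above, using only the methods declared in this class; where the ActivityType values come from (e.g. a profiler configuration) is left open:

ActivityTypeSet activities = new ActivityTypeSet();
System.out.println("empty=" + activities.empty() + " size=" + activities.size());
// Iterate with the nested Iterator class: increment() maps operator++, get() maps operator*.
for (ActivityTypeSet.Iterator it = activities.begin(); !it.equals(activities.end()); it.increment()) {
    ActivityType type = it.get();
    System.out.println(type);
}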

12 changes: 7 additions & 5 deletions pytorch/src/gen/java/org/bytedeco/pytorch/Adagrad.java
@@ -1,10 +1,12 @@
// Targeted by JavaCPP version 1.5.9: DO NOT EDIT THIS FILE
// Targeted by JavaCPP version 1.5.10-SNAPSHOT: DO NOT EDIT THIS FILE

package org.bytedeco.pytorch;

import org.bytedeco.pytorch.Allocator;
import org.bytedeco.pytorch.Function;
import org.bytedeco.pytorch.functions.*;
import org.bytedeco.pytorch.Module;
import org.bytedeco.javacpp.annotation.Cast;
import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;
@@ -33,10 +35,10 @@ public Adagrad(
private native void allocate(
@ByVal OptimizerParamGroupVector param_groups);

public Adagrad(@Cast({"", "std::vector<at::Tensor>"}) @StdMove TensorVector params, @ByVal(nullValue = "torch::optim::AdagradOptions{}") AdagradOptions defaults) { super((Pointer)null); allocate(params, defaults); }
private native void allocate(@Cast({"", "std::vector<at::Tensor>"}) @StdMove TensorVector params, @ByVal(nullValue = "torch::optim::AdagradOptions{}") AdagradOptions defaults);
public Adagrad(@Cast({"", "std::vector<at::Tensor>"}) @StdMove TensorVector params) { super((Pointer)null); allocate(params); }
private native void allocate(@Cast({"", "std::vector<at::Tensor>"}) @StdMove TensorVector params);
public Adagrad(@Cast({"", "std::vector<torch::Tensor>"}) @StdMove TensorVector params, @ByVal(nullValue = "torch::optim::AdagradOptions{}") AdagradOptions defaults) { super((Pointer)null); allocate(params, defaults); }
private native void allocate(@Cast({"", "std::vector<torch::Tensor>"}) @StdMove TensorVector params, @ByVal(nullValue = "torch::optim::AdagradOptions{}") AdagradOptions defaults);
public Adagrad(@Cast({"", "std::vector<torch::Tensor>"}) @StdMove TensorVector params) { super((Pointer)null); allocate(params); }
private native void allocate(@Cast({"", "std::vector<torch::Tensor>"}) @StdMove TensorVector params);

public native @ByVal Tensor step(@ByVal(nullValue = "torch::optim::Optimizer::LossClosure(nullptr)") LossClosure closure);
public native @ByVal Tensor step();
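A hedged sketch of constructing the optimizer through the torch::Tensor-based signature shown above; the parameter tensors and the loss computation are placeholders that would come from an actual model:

// 'weight' and 'bias' are assumed to be existing Tensors created with requires_grad=true.
TensorVector params = new TensorVector(weight, bias);
Adagrad optimizer = new Adagrad(params);   // uses the default torch::optim::AdagradOptions{}
// ... forward pass, compute a scalar loss Tensor, call loss.backward() ...
optimizer.step();        // apply the Adagrad update
optimizer.zero_grad();   // zero_grad() is inherited from the Optimizer base class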
9 changes: 7 additions & 2 deletions pytorch/src/gen/java/org/bytedeco/pytorch/AdagradOptions.java
@@ -1,10 +1,12 @@
// Targeted by JavaCPP version 1.5.9: DO NOT EDIT THIS FILE
// Targeted by JavaCPP version 1.5.10-SNAPSHOT: DO NOT EDIT THIS FILE

package org.bytedeco.pytorch;

import org.bytedeco.pytorch.Allocator;
import org.bytedeco.pytorch.Function;
import org.bytedeco.pytorch.functions.*;
import org.bytedeco.pytorch.Module;
import org.bytedeco.javacpp.annotation.Cast;
import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;
@@ -33,7 +35,10 @@ public class AdagradOptions extends OptimizerCloneableAdagradOptions {
public native @ByRef @NoException(true) DoublePointer eps();



private static native @Namespace @Cast("bool") @Name("operator ==") boolean equals(
@Const @ByRef AdagradOptions lhs,
@Const @ByRef AdagradOptions rhs);
public boolean equals(AdagradOptions rhs) { return equals(this, rhs); }
public native double get_lr();
public native void set_lr(double lr);
}
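A hedged sketch exercising the newly mapped operator== and the lr accessors above; construction through a no-argument constructor is an assumption about the generated class:

AdagradOptions a = new AdagradOptions();
AdagradOptions b = new AdagradOptions();
a.set_lr(0.05);                       // set_lr/get_lr are declared above
System.out.println(a.get_lr());       // 0.05
System.out.println(a.equals(b));      // false once the learning rates differ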
pytorch/src/gen/java/org/bytedeco/pytorch/AdagradParamState.java
@@ -1,10 +1,12 @@
// Targeted by JavaCPP version 1.5.9: DO NOT EDIT THIS FILE
// Targeted by JavaCPP version 1.5.10-SNAPSHOT: DO NOT EDIT THIS FILE

package org.bytedeco.pytorch;

import org.bytedeco.pytorch.Allocator;
import org.bytedeco.pytorch.Function;
import org.bytedeco.pytorch.functions.*;
import org.bytedeco.pytorch.Module;
import org.bytedeco.javacpp.annotation.Cast;
import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;
@@ -38,5 +40,8 @@ public class AdagradParamState extends OptimizerCloneableAdagradParamState {
public native @Cast("int64_t*") @ByRef @NoException(true) LongPointer step();



private static native @Namespace @Cast("bool") @Name("operator ==") boolean equals(
@Const @ByRef AdagradParamState lhs,
@Const @ByRef AdagradParamState rhs);
public boolean equals(AdagradParamState rhs) { return equals(this, rhs); }
}
12 changes: 7 additions & 5 deletions pytorch/src/gen/java/org/bytedeco/pytorch/Adam.java
@@ -1,10 +1,12 @@
// Targeted by JavaCPP version 1.5.9: DO NOT EDIT THIS FILE
// Targeted by JavaCPP version 1.5.10-SNAPSHOT: DO NOT EDIT THIS FILE

package org.bytedeco.pytorch;

import org.bytedeco.pytorch.Allocator;
import org.bytedeco.pytorch.Function;
import org.bytedeco.pytorch.functions.*;
import org.bytedeco.pytorch.Module;
import org.bytedeco.javacpp.annotation.Cast;
import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;
@@ -32,10 +34,10 @@ public Adam(
@ByVal OptimizerParamGroupVector param_groups) { super((Pointer)null); allocate(param_groups); }
private native void allocate(
@ByVal OptimizerParamGroupVector param_groups);
public Adam(@Cast({"", "std::vector<at::Tensor>"}) @StdMove TensorVector params, @ByVal(nullValue = "torch::optim::AdamOptions{}") AdamOptions defaults) { super((Pointer)null); allocate(params, defaults); }
private native void allocate(@Cast({"", "std::vector<at::Tensor>"}) @StdMove TensorVector params, @ByVal(nullValue = "torch::optim::AdamOptions{}") AdamOptions defaults);
public Adam(@Cast({"", "std::vector<at::Tensor>"}) @StdMove TensorVector params) { super((Pointer)null); allocate(params); }
private native void allocate(@Cast({"", "std::vector<at::Tensor>"}) @StdMove TensorVector params);
public Adam(@Cast({"", "std::vector<torch::Tensor>"}) @StdMove TensorVector params, @ByVal(nullValue = "torch::optim::AdamOptions{}") AdamOptions defaults) { super((Pointer)null); allocate(params, defaults); }
private native void allocate(@Cast({"", "std::vector<torch::Tensor>"}) @StdMove TensorVector params, @ByVal(nullValue = "torch::optim::AdamOptions{}") AdamOptions defaults);
public Adam(@Cast({"", "std::vector<torch::Tensor>"}) @StdMove TensorVector params) { super((Pointer)null); allocate(params); }
private native void allocate(@Cast({"", "std::vector<torch::Tensor>"}) @StdMove TensorVector params);

public native @ByVal Tensor step(@ByVal(nullValue = "torch::optim::Optimizer::LossClosure(nullptr)") LossClosure closure);
public native @ByVal Tensor step();
9 changes: 7 additions & 2 deletions pytorch/src/gen/java/org/bytedeco/pytorch/AdamOptions.java
@@ -1,10 +1,12 @@
// Targeted by JavaCPP version 1.5.9: DO NOT EDIT THIS FILE
// Targeted by JavaCPP version 1.5.10-SNAPSHOT: DO NOT EDIT THIS FILE

package org.bytedeco.pytorch;

import org.bytedeco.pytorch.Allocator;
import org.bytedeco.pytorch.Function;
import org.bytedeco.pytorch.functions.*;
import org.bytedeco.pytorch.Module;
import org.bytedeco.javacpp.annotation.Cast;
import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;
@@ -33,7 +35,10 @@ public class AdamOptions extends OptimizerCloneableAdamOptions {
public native @Cast("bool*") @ByRef @NoException(true) BoolPointer amsgrad();



private static native @Namespace @Cast("bool") @Name("operator ==") boolean equals(
@Const @ByRef AdamOptions lhs,
@Const @ByRef AdamOptions rhs);
public boolean equals(AdamOptions rhs) { return equals(this, rhs); }
public native double get_lr();
public native void set_lr(double lr);
}
9 changes: 7 additions & 2 deletions pytorch/src/gen/java/org/bytedeco/pytorch/AdamParamState.java
@@ -1,10 +1,12 @@
// Targeted by JavaCPP version 1.5.9: DO NOT EDIT THIS FILE
// Targeted by JavaCPP version 1.5.10-SNAPSHOT: DO NOT EDIT THIS FILE

package org.bytedeco.pytorch;

import org.bytedeco.pytorch.Allocator;
import org.bytedeco.pytorch.Function;
import org.bytedeco.pytorch.functions.*;
import org.bytedeco.pytorch.Module;
import org.bytedeco.javacpp.annotation.Cast;
import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;
@@ -40,5 +42,8 @@ public class AdamParamState extends OptimizerCloneableAdamParamState {
public native @ByRef @NoException(true) Tensor max_exp_avg_sq();



private static native @Namespace @Cast("bool") @Name("operator ==") boolean equals(
@Const @ByRef AdamParamState lhs,
@Const @ByRef AdamParamState rhs);
public boolean equals(AdamParamState rhs) { return equals(this, rhs); }
}
12 changes: 7 additions & 5 deletions pytorch/src/gen/java/org/bytedeco/pytorch/AdamW.java
@@ -1,10 +1,12 @@
// Targeted by JavaCPP version 1.5.9: DO NOT EDIT THIS FILE
// Targeted by JavaCPP version 1.5.10-SNAPSHOT: DO NOT EDIT THIS FILE

package org.bytedeco.pytorch;

import org.bytedeco.pytorch.Allocator;
import org.bytedeco.pytorch.Function;
import org.bytedeco.pytorch.functions.*;
import org.bytedeco.pytorch.Module;
import org.bytedeco.javacpp.annotation.Cast;
import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;
@@ -32,10 +34,10 @@ public AdamW(
@ByVal OptimizerParamGroupVector param_groups) { super((Pointer)null); allocate(param_groups); }
private native void allocate(
@ByVal OptimizerParamGroupVector param_groups);
public AdamW(@Cast({"", "std::vector<at::Tensor>"}) @StdMove TensorVector params, @ByVal(nullValue = "torch::optim::AdamWOptions{}") AdamWOptions defaults) { super((Pointer)null); allocate(params, defaults); }
private native void allocate(@Cast({"", "std::vector<at::Tensor>"}) @StdMove TensorVector params, @ByVal(nullValue = "torch::optim::AdamWOptions{}") AdamWOptions defaults);
public AdamW(@Cast({"", "std::vector<at::Tensor>"}) @StdMove TensorVector params) { super((Pointer)null); allocate(params); }
private native void allocate(@Cast({"", "std::vector<at::Tensor>"}) @StdMove TensorVector params);
public AdamW(@Cast({"", "std::vector<torch::Tensor>"}) @StdMove TensorVector params, @ByVal(nullValue = "torch::optim::AdamWOptions{}") AdamWOptions defaults) { super((Pointer)null); allocate(params, defaults); }
private native void allocate(@Cast({"", "std::vector<torch::Tensor>"}) @StdMove TensorVector params, @ByVal(nullValue = "torch::optim::AdamWOptions{}") AdamWOptions defaults);
public AdamW(@Cast({"", "std::vector<torch::Tensor>"}) @StdMove TensorVector params) { super((Pointer)null); allocate(params); }
private native void allocate(@Cast({"", "std::vector<torch::Tensor>"}) @StdMove TensorVector params);

public native @ByVal Tensor step(@ByVal(nullValue = "torch::optim::Optimizer::LossClosure(nullptr)") LossClosure closure);
public native @ByVal Tensor step();
9 changes: 7 additions & 2 deletions pytorch/src/gen/java/org/bytedeco/pytorch/AdamWOptions.java
@@ -1,10 +1,12 @@
// Targeted by JavaCPP version 1.5.9: DO NOT EDIT THIS FILE
// Targeted by JavaCPP version 1.5.10-SNAPSHOT: DO NOT EDIT THIS FILE

package org.bytedeco.pytorch;

import org.bytedeco.pytorch.Allocator;
import org.bytedeco.pytorch.Function;
import org.bytedeco.pytorch.functions.*;
import org.bytedeco.pytorch.Module;
import org.bytedeco.javacpp.annotation.Cast;
import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;
@@ -33,7 +35,10 @@ public class AdamWOptions extends OptimizerCloneableAdamWOptions {
public native @Cast("bool*") @ByRef @NoException(true) BoolPointer amsgrad();



private static native @Namespace @Cast("bool") @Name("operator ==") boolean equals(
@Const @ByRef AdamWOptions lhs,
@Const @ByRef AdamWOptions rhs);
public boolean equals(AdamWOptions rhs) { return equals(this, rhs); }
public native double get_lr();
public native void set_lr(double lr);
}
pytorch/src/gen/java/org/bytedeco/pytorch/AdamWParamState.java
@@ -1,10 +1,12 @@
// Targeted by JavaCPP version 1.5.9: DO NOT EDIT THIS FILE
// Targeted by JavaCPP version 1.5.10-SNAPSHOT: DO NOT EDIT THIS FILE

package org.bytedeco.pytorch;

import org.bytedeco.pytorch.Allocator;
import org.bytedeco.pytorch.Function;
import org.bytedeco.pytorch.functions.*;
import org.bytedeco.pytorch.Module;
import org.bytedeco.javacpp.annotation.Cast;
import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;
@@ -40,5 +42,8 @@ public class AdamWParamState extends OptimizerCloneableAdamWParamState {
public native @ByRef @NoException(true) Tensor max_exp_avg_sq();



private static native @Namespace @Cast("bool") @Name("operator ==") boolean equals(
@Const @ByRef AdamWParamState lhs,
@Const @ByRef AdamWParamState rhs);
public boolean equals(AdamWParamState rhs) { return equals(this, rhs); }
}