Skip to content

Commit

Permalink
* Map torch::data::datasets::ChunkDataReader and related data load…
Browse files Browse the repository at this point in the history
…ing classes from PyTorch (issue #1215)
  • Loading branch information
saudet committed Dec 11, 2022
1 parent bd67bb6 commit fa4dfdc
Show file tree
Hide file tree
Showing 29 changed files with 1,211 additions and 29 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@

* Map `torch::data::datasets::ChunkDataReader` and related data loading classes from PyTorch ([issue #1215](https://github.com/bytedeco/javacpp-presets/issues/1215))
* Add missing predefined `AVChannelLayout` in presets for FFmpeg ([issue #1286](https://github.com/bytedeco/javacpp-presets/issues/1286))
* Map `c10::impl::GenericDict` as returned by `c10::IValue::toGenericDict()` in presets for PyTorch
* Introduce `linux-armhf` and `linux-x86` builds to presets for TensorFlow Lite ([pull #1268](https://github.com/bytedeco/javacpp-presets/pull/1268))
Expand Down
4 changes: 4 additions & 0 deletions pytorch/cppbuild.sh
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,10 @@ sedinplace "s/var.startswith(('BUILD_', 'USE_', 'CMAKE_'))/var.startswith(('BUIL
sedinplace 's/TensorIndex(c10::nullopt_t)/TensorIndex(c10::nullopt_t none = None)/g' aten/src/ATen/TensorIndexing.h

# add missing declarations
sedinplace '/using ExampleType = ExampleType_;/a\
using BatchType = ChunkType;\
using DataType = ExampleType;\
' torch/csrc/api/include/torch/data/datasets/chunk.h
sedinplace '/^};/a\
TORCH_API std::ostream& operator<<(std::ostream& stream, const nn::Module& module);\
' torch/csrc/api/include/torch/nn/module.h
Expand Down
33 changes: 33 additions & 0 deletions pytorch/samples/TestChunkData.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import org.bytedeco.javacpp.*;
import org.bytedeco.pytorch.*;
import static org.bytedeco.pytorch.global.torch.*;

public class TestChunkData {
public static void main(String[] args) throws Exception {
try (PointerScope scope = new PointerScope()) {
long batch_size = 10;
long prefetch_count = 1;
ChunkDataReader data_reader = new ChunkDataReader() {
public ExampleVector read_chunk(long chunk_index) {
return new ExampleVector(
new Example(Tensor.create(100.0), Tensor.create(200.0)),
new Example(Tensor.create(300.0), Tensor.create(400.0)));
}
public long chunk_count() { return 1; }
public void reset() { }
};
RandomSampler sampler = new RandomSampler(0);
ChunkMapDataset data_set = new ChunkSharedBatchDataset(
new ChunkDataset(data_reader, sampler, sampler,
new ChunkDatasetOptions(prefetch_count, batch_size))).map(new ExampleStack());
ChunkRandomDataLoader data_loader = new ChunkRandomDataLoader(
data_set, new DataLoaderOptions(batch_size));
for (int epoch = 1; epoch <= 10; ++epoch) {
for (ExampleIterator it = data_loader.begin(); !it.equals(data_loader.end()); it = it.increment()) {
Example batch = it.access();
System.out.println(batch.data().createIndexer() + " " + batch.target().createIndexer());
}
}
}
}
}
39 changes: 39 additions & 0 deletions pytorch/src/gen/java/org/bytedeco/pytorch/ChunkBatchDataset.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Targeted by JavaCPP version 1.5.9-SNAPSHOT: DO NOT EDIT THIS FILE

package org.bytedeco.pytorch;

import org.bytedeco.pytorch.Allocator;
import org.bytedeco.pytorch.Function;
import org.bytedeco.pytorch.Module;
import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;

import static org.bytedeco.javacpp.presets.javacpp.*;
import static org.bytedeco.openblas.global.openblas_nolapack.*;
import static org.bytedeco.openblas.global.openblas.*;

import static org.bytedeco.pytorch.global.torch.*;
// namespace detail

/** A dataset that can yield data only in batches. */
@Name("torch::data::datasets::BatchDataset<torch::data::datasets::ChunkDataset<JavaCPP_torch_0003a_0003adata_0003a_0003adatasets_0003a_0003aChunkDataReader_0003ctorch_0003a_0003adata_0003a_0003aExample_0003c_0003e_0002cstd_0003a_0003avector_0003ctorch_0003a_0003adata_0003a_0003aExample_0003c_0003e_00020_0003e_00020_0003e,torch::data::samplers::RandomSampler,torch::data::samplers::RandomSampler>,c10::optional<JavaCPP_torch_0003a_0003adata_0003a_0003adatasets_0003a_0003aChunkDataReader_0003ctorch_0003a_0003adata_0003a_0003aExample_0003c_0003e_0002cstd_0003a_0003avector_0003ctorch_0003a_0003adata_0003a_0003aExample_0003c_0003e_00020_0003e_00020_0003e::BatchType>,size_t>") @Properties(inherit = org.bytedeco.pytorch.presets.torch.class)
public class ChunkBatchDataset extends Pointer {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public ChunkBatchDataset(Pointer p) { super(p); }

@MemberGetter public static native @Cast("const bool") boolean is_stateful();
public static final boolean is_stateful = is_stateful();

/** Returns a batch of data given an index. */
public native @ByVal ExampleVectorOptional get_batch(@Cast("size_t") long request);

/** Returns the size of the dataset, or an empty optional if it is unsized. */
public native @ByVal SizeTOptional size();

/** Creates a {@code MapDataset} that applies the given {@code transform} to this dataset. */

/** Creates a {@code MapDataset} that applies the given {@code transform} to this dataset. */

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Targeted by JavaCPP version 1.5.9-SNAPSHOT: DO NOT EDIT THIS FILE

package org.bytedeco.pytorch;

import org.bytedeco.pytorch.Allocator;
import org.bytedeco.pytorch.Function;
import org.bytedeco.pytorch.Module;
import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;

import static org.bytedeco.javacpp.presets.javacpp.*;
import static org.bytedeco.openblas.global.openblas_nolapack.*;
import static org.bytedeco.openblas.global.openblas.*;

import static org.bytedeco.pytorch.global.torch.*;

@Name("torch::data::datasets::BatchDataset<torch::data::datasets::SharedBatchDataset<torch::data::datasets::ChunkDataset<JavaCPP_torch_0003a_0003adata_0003a_0003adatasets_0003a_0003aChunkDataReader_0003ctorch_0003a_0003adata_0003a_0003aExample_0003c_0003e_0002cstd_0003a_0003avector_0003ctorch_0003a_0003adata_0003a_0003aExample_0003c_0003e_00020_0003e_00020_0003e,torch::data::samplers::RandomSampler,torch::data::samplers::RandomSampler> >,c10::optional<JavaCPP_torch_0003a_0003adata_0003a_0003adatasets_0003a_0003aChunkDataReader_0003ctorch_0003a_0003adata_0003a_0003aExample_0003c_0003e_0002cstd_0003a_0003avector_0003ctorch_0003a_0003adata_0003a_0003aExample_0003c_0003e_00020_0003e_00020_0003e::BatchType>,size_t>") @Properties(inherit = org.bytedeco.pytorch.presets.torch.class)
public class ChunkBatchSharedBatchDataset extends Pointer {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public ChunkBatchSharedBatchDataset(Pointer p) { super(p); }

@MemberGetter public static native @Cast("const bool") boolean is_stateful();
public static final boolean is_stateful = is_stateful();

/** Returns a batch of data given an index. */
public native @ByVal ExampleVectorOptional get_batch(@Cast("size_t") long request);

/** Returns the size of the dataset, or an empty optional if it is unsized. */
public native @ByVal SizeTOptional size();

/** Creates a {@code MapDataset} that applies the given {@code transform} to this dataset. */
public native @ByVal ChunkMapDataset map(@ByVal ExampleStack transform);

/** Creates a {@code MapDataset} that applies the given {@code transform} to this dataset. */

}
53 changes: 53 additions & 0 deletions pytorch/src/gen/java/org/bytedeco/pytorch/ChunkDataReader.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Targeted by JavaCPP version 1.5.9-SNAPSHOT: DO NOT EDIT THIS FILE

package org.bytedeco.pytorch;

import org.bytedeco.pytorch.Allocator;
import org.bytedeco.pytorch.Function;
import org.bytedeco.pytorch.Module;
import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;

import static org.bytedeco.javacpp.presets.javacpp.*;
import static org.bytedeco.openblas.global.openblas_nolapack.*;
import static org.bytedeco.openblas.global.openblas.*;

import static org.bytedeco.pytorch.global.torch.*;


/** Interface for chunk reader, which performs data chunking and reading of
* entire chunks.
*
* A chunk could be an entire file, such as an audio data file or an image,
* or part of a file in the case of a large text-file split based on seek
* positions. */
@Name("torch::data::datasets::ChunkDataReader<torch::data::Example<>,std::vector<torch::data::Example<> > >") @Properties(inherit = org.bytedeco.pytorch.presets.torch.class)
public class ChunkDataReader extends Pointer {
static { Loader.load(); }
/** Default native constructor. */
public ChunkDataReader() { super((Pointer)null); allocate(); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public ChunkDataReader(long size) { super((Pointer)null); allocateArray(size); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public ChunkDataReader(Pointer p) { super(p); }
private native void allocate();
private native void allocateArray(long size);
@Override public ChunkDataReader position(long position) {
return (ChunkDataReader)super.position(position);
}
@Override public ChunkDataReader getPointer(long i) {
return new ChunkDataReader((Pointer)this).offsetAddress(i);
}



/** Read an entire chunk. */
@Virtual(true) public native @ByVal @Cast("torch::data::datasets::ChunkDataReader<torch::data::Example<>,std::vector<torch::data::Example<> > >::ChunkType*") ExampleVector read_chunk(@Cast("size_t") long chunk_index);

/** Returns the number of chunks available in this reader. */
@Virtual(true) public native @Cast("size_t") long chunk_count();

/** This will clear any internal state associate with this reader. */
@Virtual(true) public native void reset();
}
75 changes: 75 additions & 0 deletions pytorch/src/gen/java/org/bytedeco/pytorch/ChunkDataset.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Targeted by JavaCPP version 1.5.9-SNAPSHOT: DO NOT EDIT THIS FILE

package org.bytedeco.pytorch;

import org.bytedeco.pytorch.Allocator;
import org.bytedeco.pytorch.Function;
import org.bytedeco.pytorch.Module;
import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;

import static org.bytedeco.javacpp.presets.javacpp.*;
import static org.bytedeco.openblas.global.openblas_nolapack.*;
import static org.bytedeco.openblas.global.openblas.*;

import static org.bytedeco.pytorch.global.torch.*;


/** A stateful dataset that support hierarchical sampling and prefetching of
* entre chunks.
*
* Unlike regular dataset, chunk dataset require two samplers to operate and
* keeps an internal state. {@code ChunkSampler} selects, which chunk to load next,
* while the {@code ExampleSampler} determins the order of Examples that are returned
* in each {@code get_batch} call. The hierarchical sampling approach used here is
* inspired by this paper http://martin.zinkevich.org/publications/nips2010.pdf */
@Name("torch::data::datasets::ChunkDataset<JavaCPP_torch_0003a_0003adata_0003a_0003adatasets_0003a_0003aChunkDataReader_0003ctorch_0003a_0003adata_0003a_0003aExample_0003c_0003e_0002cstd_0003a_0003avector_0003ctorch_0003a_0003adata_0003a_0003aExample_0003c_0003e_00020_0003e_00020_0003e,torch::data::samplers::RandomSampler,torch::data::samplers::RandomSampler>") @NoOffset @Properties(inherit = org.bytedeco.pytorch.presets.torch.class)
public class ChunkDataset extends ChunkStatefulDataset {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public ChunkDataset(Pointer p) { super(p); }


public ChunkDataset(
ChunkDataReader chunk_reader,
RandomSampler chunk_sampler,
RandomSampler example_sampler,
ChunkDatasetOptions options) { super((Pointer)null); allocate(chunk_reader, chunk_sampler, example_sampler, options, null); }
public ChunkDataset(
ChunkDataReader chunk_reader,
RandomSampler chunk_sampler,
RandomSampler example_sampler,
ChunkDatasetOptions options,
Pointer preprocessing_policy) { super((Pointer)null); allocate(chunk_reader, chunk_sampler, example_sampler, options, preprocessing_policy); }
private native void allocate(
@ByVal @Cast("JavaCPP_torch_0003a_0003adata_0003a_0003adatasets_0003a_0003aChunkDataReader_0003ctorch_0003a_0003adata_0003a_0003aExample_0003c_0003e_0002cstd_0003a_0003avector_0003ctorch_0003a_0003adata_0003a_0003aExample_0003c_0003e_00020_0003e_00020_0003e*") ChunkDataReader chunk_reader,
@ByVal RandomSampler chunk_sampler,
@ByVal RandomSampler example_sampler,
@ByVal ChunkDatasetOptions options,
@ByVal(nullValue = "std::function<void(std::vector<torch::data::Example<>>&)>()") @Cast("std::function<void(std::vector<torch::data::Example<>>&)>*") Pointer preprocessing_policy);

/** Default get_batch method of BatchDataset. This method returns
* Example batches created from the preloaded chunks. The implemenation
* is dataset agnostic and does not need overriding in different chunk
* datasets. */
public native @ByVal ExampleVectorOptional get_batch(@Cast("size_t") long batch_size);

/** Helper method around get_batch as {@code batch_size} is not strictly necessary */
public native @ByVal ExampleVectorOptional get_batch();

/** This will clear any internal state and starts the internal prefetching
* mechanism for the chunk dataset. */
public native void reset();

/** size is not used for chunk dataset. */
public native @ByVal SizeTOptional size();

// provide a references to chunk sampler. Used mainly in distributed data
// loading to set the epoch number for the sampler.
public native @Cast("torch::data::datasets::ChunkDataset<JavaCPP_torch_0003a_0003adata_0003a_0003adatasets_0003a_0003aChunkDataReader_0003ctorch_0003a_0003adata_0003a_0003aExample_0003c_0003e_0002cstd_0003a_0003avector_0003ctorch_0003a_0003adata_0003a_0003aExample_0003c_0003e_00020_0003e_00020_0003e,torch::data::samplers::RandomSampler,torch::data::samplers::RandomSampler>::ChunkSamplerType*") @ByRef RandomSampler chunk_sampler();

public native void save(@ByRef OutputArchive archive);

public native void load(@ByRef InputArchive archive);
}
47 changes: 47 additions & 0 deletions pytorch/src/gen/java/org/bytedeco/pytorch/ChunkDatasetOptions.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Targeted by JavaCPP version 1.5.9-SNAPSHOT: DO NOT EDIT THIS FILE

package org.bytedeco.pytorch;

import org.bytedeco.pytorch.Allocator;
import org.bytedeco.pytorch.Function;
import org.bytedeco.pytorch.Module;
import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;

import static org.bytedeco.javacpp.presets.javacpp.*;
import static org.bytedeco.openblas.global.openblas_nolapack.*;
import static org.bytedeco.openblas.global.openblas.*;

import static org.bytedeco.pytorch.global.torch.*;
// namespace detail

/** Options to configure a {@code ChunkDataset}. */
@Namespace("torch::data::datasets") @NoOffset @Properties(inherit = org.bytedeco.pytorch.presets.torch.class)
public class ChunkDatasetOptions extends Pointer {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public ChunkDatasetOptions(Pointer p) { super(p); }


public ChunkDatasetOptions(
@Cast("size_t") long preloader_count,
@Cast("size_t") long batch_size,
@Cast("size_t") long cache_size/*=2048*/,
@Cast("size_t") long cross_chunk_shuffle_count/*=1*/) { super((Pointer)null); allocate(preloader_count, batch_size, cache_size, cross_chunk_shuffle_count); }
private native void allocate(
@Cast("size_t") long preloader_count,
@Cast("size_t") long batch_size,
@Cast("size_t") long cache_size/*=2048*/,
@Cast("size_t") long cross_chunk_shuffle_count/*=1*/);
public ChunkDatasetOptions(
@Cast("size_t") long preloader_count,
@Cast("size_t") long batch_size) { super((Pointer)null); allocate(preloader_count, batch_size); }
private native void allocate(
@Cast("size_t") long preloader_count,
@Cast("size_t") long batch_size);
public native @Cast("size_t*") @ByRef @NoException(true) SizeTPointer preloader_count();
public native @Cast("size_t*") @ByRef @NoException(true) SizeTPointer batch_size();
public native @Cast("size_t*") @ByRef @NoException(true) SizeTPointer cache_size();
public native @Cast("size_t*") @ByRef @NoException(true) SizeTPointer cross_chunk_shuffle_count();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Targeted by JavaCPP version 1.5.9-SNAPSHOT: DO NOT EDIT THIS FILE

package org.bytedeco.pytorch;

import org.bytedeco.pytorch.Allocator;
import org.bytedeco.pytorch.Function;
import org.bytedeco.pytorch.Module;
import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;

import static org.bytedeco.javacpp.presets.javacpp.*;
import static org.bytedeco.openblas.global.openblas_nolapack.*;
import static org.bytedeco.openblas.global.openblas.*;

import static org.bytedeco.pytorch.global.torch.*;

@Name("torch::data::datasets::BatchDataset<torch::data::datasets::MapDataset<torch::data::datasets::SharedBatchDataset<torch::data::datasets::ChunkDataset<JavaCPP_torch_0003a_0003adata_0003a_0003adatasets_0003a_0003aChunkDataReader_0003ctorch_0003a_0003adata_0003a_0003aExample_0003c_0003e_0002cstd_0003a_0003avector_0003ctorch_0003a_0003adata_0003a_0003aExample_0003c_0003e_00020_0003e_00020_0003e,torch::data::samplers::RandomSampler,torch::data::samplers::RandomSampler> >,torch::data::transforms::Stack<torch::data::Example<> > >,std::vector<torch::data::Example<> >,at::ArrayRef<size_t> >") @Properties(inherit = org.bytedeco.pytorch.presets.torch.class)
public class ChunkMapBatchDataset extends Pointer {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public ChunkMapBatchDataset(Pointer p) { super(p); }

@MemberGetter public static native @Cast("const bool") boolean is_stateful();
public static final boolean is_stateful = is_stateful();

/** Returns a batch of data given an index. */
public native @ByVal ExampleVector get_batch(@ByVal SizeTArrayRef request);

/** Returns the size of the dataset, or an empty optional if it is unsized. */
public native @ByVal SizeTOptional size();

/** Creates a {@code MapDataset} that applies the given {@code transform} to this dataset. */

/** Creates a {@code MapDataset} that applies the given {@code transform} to this dataset. */

}
Loading

0 comments on commit fa4dfdc

Please sign in to comment.