-
Notifications
You must be signed in to change notification settings - Fork 741
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Map torch::data::datasets::ChunkDataReader and related data loading classes from PyTorch (issue #1215)
- Loading branch information
Showing
29 changed files
with
1,211 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import org.bytedeco.javacpp.*; | ||
import org.bytedeco.pytorch.*; | ||
import static org.bytedeco.pytorch.global.torch.*; | ||
|
||
public class TestChunkData { | ||
public static void main(String[] args) throws Exception { | ||
try (PointerScope scope = new PointerScope()) { | ||
long batch_size = 10; | ||
long prefetch_count = 1; | ||
ChunkDataReader data_reader = new ChunkDataReader() { | ||
public ExampleVector read_chunk(long chunk_index) { | ||
return new ExampleVector( | ||
new Example(Tensor.create(100.0), Tensor.create(200.0)), | ||
new Example(Tensor.create(300.0), Tensor.create(400.0))); | ||
} | ||
public long chunk_count() { return 1; } | ||
public void reset() { } | ||
}; | ||
RandomSampler sampler = new RandomSampler(0); | ||
ChunkMapDataset data_set = new ChunkSharedBatchDataset( | ||
new ChunkDataset(data_reader, sampler, sampler, | ||
new ChunkDatasetOptions(prefetch_count, batch_size))).map(new ExampleStack()); | ||
ChunkRandomDataLoader data_loader = new ChunkRandomDataLoader( | ||
data_set, new DataLoaderOptions(batch_size)); | ||
for (int epoch = 1; epoch <= 10; ++epoch) { | ||
for (ExampleIterator it = data_loader.begin(); !it.equals(data_loader.end()); it = it.increment()) { | ||
Example batch = it.access(); | ||
System.out.println(batch.data().createIndexer() + " " + batch.target().createIndexer()); | ||
} | ||
} | ||
} | ||
} | ||
} |
39 changes: 39 additions & 0 deletions
39
pytorch/src/gen/java/org/bytedeco/pytorch/ChunkBatchDataset.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
// Targeted by JavaCPP version 1.5.9-SNAPSHOT: DO NOT EDIT THIS FILE | ||
|
||
package org.bytedeco.pytorch; | ||
|
||
import org.bytedeco.pytorch.Allocator; | ||
import org.bytedeco.pytorch.Function; | ||
import org.bytedeco.pytorch.Module; | ||
import java.nio.*; | ||
import org.bytedeco.javacpp.*; | ||
import org.bytedeco.javacpp.annotation.*; | ||
|
||
import static org.bytedeco.javacpp.presets.javacpp.*; | ||
import static org.bytedeco.openblas.global.openblas_nolapack.*; | ||
import static org.bytedeco.openblas.global.openblas.*; | ||
|
||
import static org.bytedeco.pytorch.global.torch.*; | ||
// namespace detail | ||
|
||
/** A dataset that can yield data only in batches.
 *
 * <p>JavaCPP mapping of the {@code torch::data::datasets::BatchDataset}
 * template instantiated for the chunk dataset (see the {@code @Name}
 * annotation for the exact native type). Generated code — do not edit. */
@Name("torch::data::datasets::BatchDataset<torch::data::datasets::ChunkDataset<JavaCPP_torch_0003a_0003adata_0003a_0003adatasets_0003a_0003aChunkDataReader_0003ctorch_0003a_0003adata_0003a_0003aExample_0003c_0003e_0002cstd_0003a_0003avector_0003ctorch_0003a_0003adata_0003a_0003aExample_0003c_0003e_00020_0003e_00020_0003e,torch::data::samplers::RandomSampler,torch::data::samplers::RandomSampler>,c10::optional<JavaCPP_torch_0003a_0003adata_0003a_0003adatasets_0003a_0003aChunkDataReader_0003ctorch_0003a_0003adata_0003a_0003adatasets_0003a_0003aChunkDataReader is long; kept verbatim below") @Properties(inherit = org.bytedeco.pytorch.presets.torch.class)
public class ChunkBatchDataset extends Pointer {
    static { Loader.load(); }
    /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
    public ChunkBatchDataset(Pointer p) { super(p); }

    // Mirrors the native compile-time constant BatchDataset::is_stateful.
    @MemberGetter public static native @Cast("const bool") boolean is_stateful();
    public static final boolean is_stateful = is_stateful();

    /** Returns a batch of data given an index. */
    public native @ByVal ExampleVectorOptional get_batch(@Cast("size_t") long request);

    /** Returns the size of the dataset, or an empty optional if it is unsized. */
    public native @ByVal SizeTOptional size();

    /** Creates a {@code MapDataset} that applies the given {@code transform} to this dataset. */
    // NOTE(review): the map() overloads were not mapped for this
    // instantiation; only the generated javadoc stubs remain.

    /** Creates a {@code MapDataset} that applies the given {@code transform} to this dataset. */

}
38 changes: 38 additions & 0 deletions
38
pytorch/src/gen/java/org/bytedeco/pytorch/ChunkBatchSharedBatchDataset.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
// Targeted by JavaCPP version 1.5.9-SNAPSHOT: DO NOT EDIT THIS FILE | ||
|
||
package org.bytedeco.pytorch; | ||
|
||
import org.bytedeco.pytorch.Allocator; | ||
import org.bytedeco.pytorch.Function; | ||
import org.bytedeco.pytorch.Module; | ||
import java.nio.*; | ||
import org.bytedeco.javacpp.*; | ||
import org.bytedeco.javacpp.annotation.*; | ||
|
||
import static org.bytedeco.javacpp.presets.javacpp.*; | ||
import static org.bytedeco.openblas.global.openblas_nolapack.*; | ||
import static org.bytedeco.openblas.global.openblas.*; | ||
|
||
import static org.bytedeco.pytorch.global.torch.*; | ||
|
||
// JavaCPP mapping of BatchDataset instantiated for the SharedBatchDataset
// wrapping the chunk dataset (exact native type in @Name). Generated code.
@Name("torch::data::datasets::BatchDataset<torch::data::datasets::SharedBatchDataset<torch::data::datasets::ChunkDataset<JavaCPP_torch_0003a_0003adata_0003a_0003adatasets_0003a_0003aChunkDataReader_0003ctorch_0003a_0003adata_0003a_0003aExample_0003c_0003e_0002cstd_0003a_0003avector_0003ctorch_0003a_0003adata_0003a_0003aExample_0003c_0003e_00020_0003e_00020_0003e,torch::data::samplers::RandomSampler,torch::data::samplers::RandomSampler> >,c10::optional<JavaCPP_torch_0003a_0003adata_0003a_0003adatasets_0003a_0003aChunkDataReader_0003ctorch_0003a_0003adata_0003a_0003aExample_0003c_0003e_0002cstd_0003a_0003avector_0003ctorch_0003a_0003adata_0003a_0003aExample_0003c_0003e_00020_0003e_00020_0003e::BatchType>,size_t>") @Properties(inherit = org.bytedeco.pytorch.presets.torch.class)
public class ChunkBatchSharedBatchDataset extends Pointer {
    static { Loader.load(); }
    /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
    public ChunkBatchSharedBatchDataset(Pointer p) { super(p); }

    // Mirrors the native compile-time constant BatchDataset::is_stateful.
    @MemberGetter public static native @Cast("const bool") boolean is_stateful();
    public static final boolean is_stateful = is_stateful();

    /** Returns a batch of data given an index. */
    public native @ByVal ExampleVectorOptional get_batch(@Cast("size_t") long request);

    /** Returns the size of the dataset, or an empty optional if it is unsized. */
    public native @ByVal SizeTOptional size();

    /** Creates a {@code MapDataset} that applies the given {@code transform} to this dataset. */
    public native @ByVal ChunkMapDataset map(@ByVal ExampleStack transform);

    /** Creates a {@code MapDataset} that applies the given {@code transform} to this dataset. */
    // NOTE(review): the second map() overload was not mapped for this
    // instantiation; only the generated javadoc stub remains.

}
53 changes: 53 additions & 0 deletions
53
pytorch/src/gen/java/org/bytedeco/pytorch/ChunkDataReader.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
// Targeted by JavaCPP version 1.5.9-SNAPSHOT: DO NOT EDIT THIS FILE | ||
|
||
package org.bytedeco.pytorch; | ||
|
||
import org.bytedeco.pytorch.Allocator; | ||
import org.bytedeco.pytorch.Function; | ||
import org.bytedeco.pytorch.Module; | ||
import java.nio.*; | ||
import org.bytedeco.javacpp.*; | ||
import org.bytedeco.javacpp.annotation.*; | ||
|
||
import static org.bytedeco.javacpp.presets.javacpp.*; | ||
import static org.bytedeco.openblas.global.openblas_nolapack.*; | ||
import static org.bytedeco.openblas.global.openblas.*; | ||
|
||
import static org.bytedeco.pytorch.global.torch.*; | ||
|
||
|
||
/** Interface for chunk reader, which performs data chunking and reading of
 * entire chunks.
 *
 * A chunk could be an entire file, such as an audio data file or an image,
 * or part of a file in the case of a large text-file split based on seek
 * positions.
 *
 * <p>The {@code @Virtual(true)} methods below are pure-virtual callbacks:
 * Java subclasses (including anonymous ones) override them and native code
 * calls back into the JVM. Generated code — do not edit. */
@Name("torch::data::datasets::ChunkDataReader<torch::data::Example<>,std::vector<torch::data::Example<> > >") @Properties(inherit = org.bytedeco.pytorch.presets.torch.class)
public class ChunkDataReader extends Pointer {
    static { Loader.load(); }
    /** Default native constructor. */
    public ChunkDataReader() { super((Pointer)null); allocate(); }
    /** Native array allocator. Access with {@link Pointer#position(long)}. */
    public ChunkDataReader(long size) { super((Pointer)null); allocateArray(size); }
    /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
    public ChunkDataReader(Pointer p) { super(p); }
    private native void allocate();
    private native void allocateArray(long size);
    @Override public ChunkDataReader position(long position) {
        return (ChunkDataReader)super.position(position);
    }
    @Override public ChunkDataReader getPointer(long i) {
        return new ChunkDataReader((Pointer)this).offsetAddress(i);
    }

    /** Read an entire chunk. */
    @Virtual(true) public native @ByVal @Cast("torch::data::datasets::ChunkDataReader<torch::data::Example<>,std::vector<torch::data::Example<> > >::ChunkType*") ExampleVector read_chunk(@Cast("size_t") long chunk_index);

    /** Returns the number of chunks available in this reader. */
    @Virtual(true) public native @Cast("size_t") long chunk_count();

    /** This will clear any internal state associated with this reader. */
    @Virtual(true) public native void reset();
}
75 changes: 75 additions & 0 deletions
75
pytorch/src/gen/java/org/bytedeco/pytorch/ChunkDataset.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
// Targeted by JavaCPP version 1.5.9-SNAPSHOT: DO NOT EDIT THIS FILE | ||
|
||
package org.bytedeco.pytorch; | ||
|
||
import org.bytedeco.pytorch.Allocator; | ||
import org.bytedeco.pytorch.Function; | ||
import org.bytedeco.pytorch.Module; | ||
import java.nio.*; | ||
import org.bytedeco.javacpp.*; | ||
import org.bytedeco.javacpp.annotation.*; | ||
|
||
import static org.bytedeco.javacpp.presets.javacpp.*; | ||
import static org.bytedeco.openblas.global.openblas_nolapack.*; | ||
import static org.bytedeco.openblas.global.openblas.*; | ||
|
||
import static org.bytedeco.pytorch.global.torch.*; | ||
|
||
|
||
/** A stateful dataset that supports hierarchical sampling and prefetching of
 * entire chunks.
 *
 * Unlike a regular dataset, a chunk dataset requires two samplers to operate
 * and keeps an internal state. {@code ChunkSampler} selects which chunk to
 * load next, while the {@code ExampleSampler} determines the order of
 * Examples that are returned in each {@code get_batch} call. The
 * hierarchical sampling approach used here is inspired by this paper
 * http://martin.zinkevich.org/publications/nips2010.pdf
 *
 * <p>Generated JavaCPP binding — do not edit. */
@Name("torch::data::datasets::ChunkDataset<JavaCPP_torch_0003a_0003adata_0003a_0003adatasets_0003a_0003aChunkDataReader_0003ctorch_0003a_0003adata_0003a_0003aExample_0003c_0003e_0002cstd_0003a_0003avector_0003ctorch_0003a_0003adata_0003a_0003aExample_0003c_0003e_00020_0003e_00020_0003e,torch::data::samplers::RandomSampler,torch::data::samplers::RandomSampler>") @NoOffset @Properties(inherit = org.bytedeco.pytorch.presets.torch.class)
public class ChunkDataset extends ChunkStatefulDataset {
    static { Loader.load(); }
    /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
    public ChunkDataset(Pointer p) { super(p); }

    /** Constructs without a preprocessing policy (passes null for it). */
    public ChunkDataset(
            ChunkDataReader chunk_reader,
            RandomSampler chunk_sampler,
            RandomSampler example_sampler,
            ChunkDatasetOptions options) { super((Pointer)null); allocate(chunk_reader, chunk_sampler, example_sampler, options, null); }
    public ChunkDataset(
            ChunkDataReader chunk_reader,
            RandomSampler chunk_sampler,
            RandomSampler example_sampler,
            ChunkDatasetOptions options,
            Pointer preprocessing_policy) { super((Pointer)null); allocate(chunk_reader, chunk_sampler, example_sampler, options, preprocessing_policy); }
    private native void allocate(
            @ByVal @Cast("JavaCPP_torch_0003a_0003adata_0003a_0003adatasets_0003a_0003aChunkDataReader_0003ctorch_0003a_0003adata_0003a_0003aExample_0003c_0003e_0002cstd_0003a_0003avector_0003ctorch_0003a_0003adata_0003a_0003aExample_0003c_0003e_00020_0003e_00020_0003e*") ChunkDataReader chunk_reader,
            @ByVal RandomSampler chunk_sampler,
            @ByVal RandomSampler example_sampler,
            @ByVal ChunkDatasetOptions options,
            @ByVal(nullValue = "std::function<void(std::vector<torch::data::Example<>>&)>()") @Cast("std::function<void(std::vector<torch::data::Example<>>&)>*") Pointer preprocessing_policy);

    /** Default get_batch method of BatchDataset. This method returns
     * Example batches created from the preloaded chunks. The implementation
     * is dataset agnostic and does not need overriding in different chunk
     * datasets. */
    public native @ByVal ExampleVectorOptional get_batch(@Cast("size_t") long batch_size);

    /** Helper method around get_batch as {@code batch_size} is not strictly necessary */
    public native @ByVal ExampleVectorOptional get_batch();

    /** This will clear any internal state and starts the internal prefetching
     * mechanism for the chunk dataset. */
    public native void reset();

    /** size is not used for chunk dataset. */
    public native @ByVal SizeTOptional size();

    // Provides a reference to the chunk sampler. Used mainly in distributed
    // data loading to set the epoch number for the sampler.
    public native @Cast("torch::data::datasets::ChunkDataset<JavaCPP_torch_0003a_0003adata_0003a_0003adatasets_0003a_0003aChunkDataReader_0003ctorch_0003a_0003adata_0003a_0003aExample_0003c_0003e_0002cstd_0003a_0003avector_0003ctorch_0003a_0003adata_0003a_0003aExample_0003c_0003e_00020_0003e_00020_0003e,torch::data::samplers::RandomSampler,torch::data::samplers::RandomSampler>::ChunkSamplerType*") @ByRef RandomSampler chunk_sampler();

    /** Serializes internal sampler state into {@code archive}. */
    public native void save(@ByRef OutputArchive archive);

    /** Restores internal sampler state from {@code archive}. */
    public native void load(@ByRef InputArchive archive);
}
47 changes: 47 additions & 0 deletions
47
pytorch/src/gen/java/org/bytedeco/pytorch/ChunkDatasetOptions.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
// Targeted by JavaCPP version 1.5.9-SNAPSHOT: DO NOT EDIT THIS FILE | ||
|
||
package org.bytedeco.pytorch; | ||
|
||
import org.bytedeco.pytorch.Allocator; | ||
import org.bytedeco.pytorch.Function; | ||
import org.bytedeco.pytorch.Module; | ||
import java.nio.*; | ||
import org.bytedeco.javacpp.*; | ||
import org.bytedeco.javacpp.annotation.*; | ||
|
||
import static org.bytedeco.javacpp.presets.javacpp.*; | ||
import static org.bytedeco.openblas.global.openblas_nolapack.*; | ||
import static org.bytedeco.openblas.global.openblas.*; | ||
|
||
import static org.bytedeco.pytorch.global.torch.*; | ||
// namespace detail | ||
|
||
/** Options to configure a {@code ChunkDataset}.
 *
 * <p>Generated JavaCPP binding — do not edit. The accessors at the bottom
 * return references ({@code SizeTPointer}) to the native fields, so they can
 * be used both to read and to write each option. */
@Namespace("torch::data::datasets") @NoOffset @Properties(inherit = org.bytedeco.pytorch.presets.torch.class)
public class ChunkDatasetOptions extends Pointer {
    static { Loader.load(); }
    /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
    public ChunkDatasetOptions(Pointer p) { super(p); }

    /** Full constructor; native defaults: cache_size=2048, cross_chunk_shuffle_count=1. */
    public ChunkDatasetOptions(
            @Cast("size_t") long preloader_count,
            @Cast("size_t") long batch_size,
            @Cast("size_t") long cache_size/*=2048*/,
            @Cast("size_t") long cross_chunk_shuffle_count/*=1*/) { super((Pointer)null); allocate(preloader_count, batch_size, cache_size, cross_chunk_shuffle_count); }
    private native void allocate(
            @Cast("size_t") long preloader_count,
            @Cast("size_t") long batch_size,
            @Cast("size_t") long cache_size/*=2048*/,
            @Cast("size_t") long cross_chunk_shuffle_count/*=1*/);
    /** Two-argument overload using the native defaults for the remaining options. */
    public ChunkDatasetOptions(
            @Cast("size_t") long preloader_count,
            @Cast("size_t") long batch_size) { super((Pointer)null); allocate(preloader_count, batch_size); }
    private native void allocate(
            @Cast("size_t") long preloader_count,
            @Cast("size_t") long batch_size);
    public native @Cast("size_t*") @ByRef @NoException(true) SizeTPointer preloader_count();
    public native @Cast("size_t*") @ByRef @NoException(true) SizeTPointer batch_size();
    public native @Cast("size_t*") @ByRef @NoException(true) SizeTPointer cache_size();
    public native @Cast("size_t*") @ByRef @NoException(true) SizeTPointer cross_chunk_shuffle_count();
}
37 changes: 37 additions & 0 deletions
37
pytorch/src/gen/java/org/bytedeco/pytorch/ChunkMapBatchDataset.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
// Targeted by JavaCPP version 1.5.9-SNAPSHOT: DO NOT EDIT THIS FILE | ||
|
||
package org.bytedeco.pytorch; | ||
|
||
import org.bytedeco.pytorch.Allocator; | ||
import org.bytedeco.pytorch.Function; | ||
import org.bytedeco.pytorch.Module; | ||
import java.nio.*; | ||
import org.bytedeco.javacpp.*; | ||
import org.bytedeco.javacpp.annotation.*; | ||
|
||
import static org.bytedeco.javacpp.presets.javacpp.*; | ||
import static org.bytedeco.openblas.global.openblas_nolapack.*; | ||
import static org.bytedeco.openblas.global.openblas.*; | ||
|
||
import static org.bytedeco.pytorch.global.torch.*; | ||
|
||
// JavaCPP mapping of BatchDataset instantiated for the MapDataset produced
// by map(ExampleStack) over the shared chunk dataset (exact native type in
// @Name). Generated code — do not edit.
@Name("torch::data::datasets::BatchDataset<torch::data::datasets::MapDataset<torch::data::datasets::SharedBatchDataset<torch::data::datasets::ChunkDataset<JavaCPP_torch_0003a_0003adata_0003a_0003adatasets_0003a_0003aChunkDataReader_0003ctorch_0003a_0003adata_0003a_0003aExample_0003c_0003e_0002cstd_0003a_0003avector_0003ctorch_0003a_0003adata_0003a_0003aExample_0003c_0003e_00020_0003e_00020_0003e,torch::data::samplers::RandomSampler,torch::data::samplers::RandomSampler> >,torch::data::transforms::Stack<torch::data::Example<> > >,std::vector<torch::data::Example<> >,at::ArrayRef<size_t> >") @Properties(inherit = org.bytedeco.pytorch.presets.torch.class)
public class ChunkMapBatchDataset extends Pointer {
    static { Loader.load(); }
    /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
    public ChunkMapBatchDataset(Pointer p) { super(p); }

    // Mirrors the native compile-time constant BatchDataset::is_stateful.
    @MemberGetter public static native @Cast("const bool") boolean is_stateful();
    public static final boolean is_stateful = is_stateful();

    /** Returns a batch of data given an index. */
    public native @ByVal ExampleVector get_batch(@ByVal SizeTArrayRef request);

    /** Returns the size of the dataset, or an empty optional if it is unsized. */
    public native @ByVal SizeTOptional size();

    /** Creates a {@code MapDataset} that applies the given {@code transform} to this dataset. */
    // NOTE(review): the map() overloads were not mapped for this
    // instantiation; only the generated javadoc stubs remain.

    /** Creates a {@code MapDataset} that applies the given {@code transform} to this dataset. */

}
Oops, something went wrong.