Skip to content

Commit

Permalink
TFRecord to support S3 index URIs
Browse files Browse the repository at this point in the history
Signed-off-by: Joaquin Anton <janton@nvidia.com>
  • Loading branch information
jantonguirao committed Jun 11, 2024
1 parent e7526de commit f0cfd19
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 4 deletions.
21 changes: 18 additions & 3 deletions dali/operators/reader/loader/indexed_file_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "dali/util/uri.h"
#include "dali/util/file.h"
#include "dali/util/odirect_file.h"
#include "dali/core/call_at_exit.h"

namespace dali {

Expand Down Expand Up @@ -192,13 +193,27 @@ class IndexedFileLoader : public Loader<CPUBackend, Tensor<CPUBackend>, true> {
DALI_ENFORCE(index_uris.size() == paths_.size(),
"Number of index files needs to match the number of data files");
for (size_t i = 0; i < index_uris.size(); ++i) {
std::ifstream fin(index_uris[i]);
DALI_ENFORCE(fin.good(), "Failed to open file " + index_uris[i]);
const auto& path = index_uris[i];
auto uri = URI::Parse(path);
bool local_file = !uri.valid() || uri.scheme() == "file";
FileStream::Options opts;
opts.read_ahead = read_ahead_;
opts.use_mmap = local_file && !copy_read_data_;
opts.use_odirect = local_file && use_o_direct_;

auto index_file = FileStream::Open(path, opts);
auto index_file_cleanup = AtScopeExit([&index_file] {
if (index_file)
index_file->Close();
});

FileStreamBuf<> stream_buf(index_file.get());
std::istream fin(&stream_buf);
DALI_ENFORCE(fin.good(), "Failed to open file " + path);
int64 pos, size;
while (fin >> pos >> size) {
indices_.emplace_back(pos, size, i);
}
fin.close();
}
}

Expand Down
29 changes: 29 additions & 0 deletions dali/util/file.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#define DALI_UTIL_FILE_H_

#include <cstdio>
#include <streambuf>
#include <memory>
#include <string>
#include <optional>
Expand Down Expand Up @@ -107,6 +108,34 @@ class DLL_PUBLIC FileStream : public InputStream {
std::string path_;
};

/**
 * @brief Custom streambuf implementation that reads from FileStream.
 *
 * Data is pulled from the FileStream in chunks of up to BufferSize bytes into an internal
 * buffer, which serves as the get area of the streambuf. Only input is supported.
 *
 * @remarks Intended to be wrapped in a std::istream. The FileStream is not owned by this
 *          object (it is neither closed nor deleted here), so it must outlive it.
 *          The type is non-copyable, because the get area pointers refer to the internal
 *          buffer of one particular instance.
 *
 * @tparam BufferSize Size, in bytes, of the internal read buffer.
 */
template <size_t BufferSize = (1 << 10)>
class FileStreamBuf : public std::streambuf {
  static_assert(BufferSize > 0, "BufferSize must be greater than 0");

 public:
  /**
   * @param reader Stream to read from; stored as a non-owning pointer.
   */
  explicit FileStreamBuf(FileStream *reader) : reader_(reader) {
    // Start with an empty get area, so that the first extraction triggers underflow()
    setg(buffer_, buffer_, buffer_);
  }

  // A copy would inherit get area pointers that point into the buffer of the source object,
  // reading stale data or dangling once the source is destroyed.
  FileStreamBuf(const FileStreamBuf &) = delete;
  FileStreamBuf &operator=(const FileStreamBuf &) = delete;

 protected:
  /**
   * @brief Refills the get area from the FileStream when it is exhausted.
   *
   * @return The current character (without consuming it), or traits_type::eof() when
   *         the FileStream reports that 0 bytes were read.
   */
  int_type underflow() override {
    if (gptr() == egptr()) {  // get area is exhausted - fetch the next chunk
      size_t nbytes = reader_->Read(buffer_, BufferSize);
      if (nbytes == 0)
        return traits_type::eof();
      setg(buffer_, buffer_, buffer_ + nbytes);
    }
    return traits_type::to_int_type(*gptr());
  }

 private:
  FileStream *reader_;       // non-owning
  char buffer_[BufferSize];  // backing storage of the get area
};

} // namespace dali

#endif // DALI_UTIL_FILE_H_
5 changes: 4 additions & 1 deletion dali/util/s3_client_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ struct S3ClientManager {
static S3ClientManager& Instance() {
static std::once_flag once;
std::call_once(once, []() {
RunInitOrShutdown([&](int) { Aws::InitAPI(Aws::SDKOptions{}); });
Aws::SDKOptions options;
options.loggingOptions.logLevel = Aws::Utils::Logging::LogLevel::Debug;
RunInitOrShutdown([&](int) { Aws::InitAPI(options); });
});
// We want RunInitOrShutdown s_thread_pool_ to outlive s_manager_
static S3ClientManager s_manager_;
Expand All @@ -59,6 +61,7 @@ struct S3ClientManager {
auto endpoint_url_ptr = std::getenv("AWS_ENDPOINT_URL");
if (endpoint_url_ptr) {
config.endpointOverride = std::string(endpoint_url_ptr);
config.scheme = Aws::Http::Scheme::HTTP;
}
client_ = std::make_unique<Aws::S3::S3Client>(std::move(config));
}
Expand Down

0 comments on commit f0cfd19

Please sign in to comment.