This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

add linear
szha committed May 25, 2018
1 parent 7643e21 commit de2a823
Showing 1 changed file with 79 additions and 21 deletions.
src/storage/pooled_storage_manager.h (79 additions, 21 deletions)
@@ -44,7 +44,8 @@ namespace storage {

#if MXNET_USE_CUDA
/*!
- * \brief Storage manager with a memory pool on gpu.
+ * \brief Storage manager with a memory pool on gpu. Memory chunks are reused based on exact size
+ * match.
*/
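// Illustrative note (not part of this commit): with exact-size matching, a 4000-byte request
// and a 4096-byte request land in different free lists, so neither can reuse a chunk cached
// for the other. The rounded storage manager added below is meant to address this.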
class GPUPooledStorageManager final : public StorageManager {
public:
@@ -131,9 +132,9 @@ void GPUPooledStorageManager::Free(Storage::Handle handle) {
}

void GPUPooledStorageManager::ReleaseAll() {
-  Storage::Handle handle;
  for (auto&& i : memory_pool_) {
    for (auto&& j : i.second) {
+      Storage::Handle handle;
handle.dptr = j;
handle.size = i.first;
DirectFreeNoLock(handle);
@@ -144,6 +145,17 @@ void GPUPooledStorageManager::ReleaseAll() {

/*!
* \brief Storage manager with a memory pool, with rounded size, on gpu.
*
* This GPU mem pool uses a mixture of nearest pow2 (exponential) rounding and
* nearest multiple (linear) rounding to help alleviate the memory allocation stress
* in which the default naive exact-size-match pool falls short, such as in variable-length
* input/output cases like RNN workloads.
*
* \param cutoff the cutoff at which rounding is switched from exponential to linear. It's set
* through MXNET_GPU_MEM_POOL_ROUND_LINEAR_CUTOFF environment variable. Must be between 20 (1 MB)
* and 34 (16 GB).
* Suppose the cutoff is X, the memory size buckets look like this:
* exp2(0), exp2(1), ..., exp2(X), 2*exp2(X), 3*exp2(X), ...
*/
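// Worked example (illustrative, not part of this commit): with the default min chunk of
// 4096 bytes and the default cutoff of 24, the buckets are 4 KB, 8 KB, ..., 8 MB, 16 MB
// (exponential region), then 32 MB, 48 MB, 64 MB, ... (multiples of 16 MB, linear region).
// A 100 KB request rounds up to the 128 KB bucket; a 40 MB request to the 48 MB bucket.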
class GPUPooledRoundedStorageManager final : public StorageManager {
public:
@@ -152,11 +164,28 @@ class GPUPooledRoundedStorageManager final : public StorageManager {
*/
GPUPooledRoundedStorageManager() {
reserve_ = dmlc::GetEnv("MXNET_GPU_MEM_POOL_RESERVE", 5);
min_chunk_ = dmlc::GetEnv("MXNET_GPU_MEM_POOL_LOG2_MIN_CHUNK", 5);
if (min_chunk_ < 5) {
LOG(FATAL) << "MXNET_GPU_MEM_POOL_LOG2_MIN_CHUNK cannot be set to a value smaller than 5. " \
<< "Got " << min_chunk_ << ".";
min_chunk_ = dmlc::GetEnv("MXNET_GPU_MEM_POOL_MIN_CHUNK", 4096);
cut_off_ = dmlc::GetEnv("MXNET_GPU_MEM_POOL_ROUND_LINEAR_CUTOFF", 24);
if (min_chunk_ < 32) {
LOG(FATAL) << "MXNET_GPU_MEM_POOL_MIN_CHUNK cannot be set to a value smaller than 32. " \
<< "Got: " << min_chunk_ << ".";
}
if (min_chunk_ != 1ul << log2_round_up(min_chunk_)) {
LOG(FATAL) << "MXNET_GPU_MEM_POOL_MIN_CHUNK must be a power of 2. Got: " << min_chunk_ << ".";
} else {
min_chunk_ = log2_round_up(min_chunk_);
}
if (cut_off_ < 20 || cut_off_ > LOG2_MAX_MEM) {
LOG(FATAL) << "MXNET_GPU_MEM_POOL_ROUND_LINEAR_CUTOFF cannot be set to a value " \
<< "smaller than 20 or greater than " << LOG2_MAX_MEM << ". Got: " \
<< cut_off_ << ".";
}
if (cut_off_ < min_chunk_) {
LOG(FATAL) << "MXNET_GPU_MEM_POOL_ROUND_LINEAR_CUTOFF cannot be set to a value " \
<< "smaller than log2 of MXNET_GPU_MEM_POOL_MIN_CHUNK. Got: " \
<< cut_off_ << " vs " << min_chunk_ << ".";
}
memory_pool_ = std::vector<std::vector<void*>>((1ul << (LOG2_MAX_MEM - cut_off_)) + cut_off_);
}
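  // Sizing note (illustrative, not part of this commit): with the defaults
  // (LOG2_MAX_MEM = 34, cut_off_ = 24), the memory_pool_ vector holds
  // (1 << 10) + 24 = 1048 buckets, covering rounded allocations up to 1024 * 16 MB = 16 GB.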
/*!
* \brief Default destructor.
@@ -169,7 +198,7 @@ class GPUPooledRoundedStorageManager final : public StorageManager {
void Free(Storage::Handle handle) override;

void DirectFree(Storage::Handle handle) override {
-    handle.size = 1ul << log2_round_up(handle.size);
+    handle.size = get_size(get_bucket(handle.size));
std::lock_guard<std::mutex> lock(Storage::Get()->GetMutex(Context::kGPU));
DirectFreeNoLock(handle);
}
@@ -219,16 +248,43 @@ class GPUPooledRoundedStorageManager final : public StorageManager {

#if defined(__clang__) || defined(__GNUC__) || defined(__WINDOWS__)
inline int log2_round_up(size_t s) {
-    int fls = clz(s); // find last set
-    // must be bigger than min_chunk_ (which is at least 32 for nccl scatter)
-    return std::max(static_cast<int>(min_chunk_), (addr_width-fls) + ((ctz(s) < fls - 1)?1:0));
+    int result = addr_width - 1 - clz(s);
+    return result + ((ctz(s) < result)?1:0);
}
inline int div_pow2_round_up(size_t s, int divisor_log2) {
// (1025, 10) -> 2
// (2048, 10) -> 2
// (2049, 10) -> 3
int ffs = ctz(s); // find first set
return (s >> divisor_log2) + (ffs < divisor_log2 ? 1 : 0);
}
#else
inline int log2_round_up(size_t s) {
-    return std::max(static_cast<int>(min_chunk_),
-                    static_cast<int>(std::ceil(std::log2(s))));
+    return static_cast<int>(std::ceil(std::log2(s)));
}
inline int div_pow2_round_up(size_t s, int divisor_log2) {
// (1025, 10) -> 2
// (2048, 10) -> 2
// (2049, 10) -> 3
int divisor = std::pow(2, divisor_log2);
return s / divisor + (s % divisor ? 1 : 0);
}
#endif // defined(__clang__) || defined(__GNUC__) || defined(__WINDOWS__)
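  // log2_round_up examples (illustrative, not part of this commit):
  // (4096) -> 12
  // (4097) -> 13
  // (8192) -> 13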
inline int get_bucket(size_t s) {
int log_size = log2_round_up(s);
if (log_size > static_cast<int>(cut_off_))
return div_pow2_round_up(s, cut_off_) - 1 + cut_off_;
else
return std::max(log_size, static_cast<int>(min_chunk_));
}

inline size_t get_size(int bucket) {
if (bucket <= static_cast<int>(cut_off_))
return 1ul << bucket;
else
return (bucket - cut_off_ + 1) * (1ul << cut_off_);
}
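  // Round-trip examples (illustrative, not part of this commit), assuming the default
  // min chunk of 4096 bytes (so min_chunk_ = 12) and the default cut_off_ = 24:
  // get_bucket(100)      -> 12, get_size(12) -> 4096      (clamped to the min chunk)
  // get_bucket(5000)     -> 13, get_size(13) -> 8192      (exponential region)
  // get_bucket(20 << 20) -> 25, get_size(25) -> 32 << 20  (linear region, 2 * 16 MB)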

void DirectFreeNoLock(Storage::Handle handle) {
cudaError_t err = cudaFree(handle.dptr);
// ignore unloading error, as memory has already been recycled
@@ -242,20 +298,21 @@ class GPUPooledRoundedStorageManager final : public StorageManager {
void ReleaseAll();
// number of devices
const int NDEV = 32;
+  const size_t LOG2_MAX_MEM = 34;
  static const int addr_width = sizeof(size_t) * 8;
  // used memory
-  size_t used_memory_ = 0, min_chunk_;
+  size_t used_memory_ = 0, min_chunk_, cut_off_;
  // percentage of reserved memory
  int reserve_;
  // memory pool
-  std::array<std::vector<void*>, addr_width> memory_pool_;
+  std::vector<std::vector<void*>> memory_pool_;
DISALLOW_COPY_AND_ASSIGN(GPUPooledRoundedStorageManager);
}; // class GPUPooledRoundedStorageManager

void GPUPooledRoundedStorageManager::Alloc(Storage::Handle* handle) {
-  int log2_size = log2_round_up(handle->size);
-  size_t size = 1ul << log2_size;
-  auto&& reuse_pool = memory_pool_[log2_size];
+  int bucket = get_bucket(handle->size);
+  size_t size = get_size(bucket);
+  auto&& reuse_pool = memory_pool_[bucket];
std::lock_guard<std::mutex> lock(Storage::Get()->GetMutex(Context::kGPU));
if (reuse_pool.size() == 0) {
size_t free, total;
@@ -278,17 +335,18 @@ void GPUPooledRoundedStorageManager::Alloc(Storage::Handle* handle) {
}

void GPUPooledRoundedStorageManager::Free(Storage::Handle handle) {
-  int log2_size = log2_round_up(handle.size);
-  auto&& reuse_pool = memory_pool_[log2_size];
+  int bucket = get_bucket(handle.size);
+  auto&& reuse_pool = memory_pool_[bucket];
std::lock_guard<std::mutex> lock(Storage::Get()->GetMutex(Context::kGPU));
reuse_pool.push_back(handle.dptr);
}

void GPUPooledRoundedStorageManager::ReleaseAll() {
-  Storage::Handle handle;
  for (size_t i = 0; i < memory_pool_.size(); i++) {
-    handle.size = 1ul << i;
+    int size = get_size(i);
    for (auto& j : memory_pool_[i]) {
+      Storage::Handle handle;
+      handle.size = size;
handle.dptr = j;
DirectFreeNoLock(handle);
}
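
For reference, here is a minimal, self-contained sketch (not part of the commit) that mirrors the bucket arithmetic introduced above, assuming the default MXNET_GPU_MEM_POOL_MIN_CHUNK of 4096 bytes and MXNET_GPU_MEM_POOL_ROUND_LINEAR_CUTOFF of 24. The demo_get_bucket and demo_get_size helpers are hypothetical stand-ins for the private class members, reimplemented with plain arithmetic so the rounding can be checked in isolation.

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>

// Hypothetical stand-ins for GPUPooledRoundedStorageManager::get_bucket/get_size:
// power-of-two (exponential) rounding up to 2^cut_off, then linear rounding to
// multiples of 2^cut_off.
inline int demo_get_bucket(std::size_t s, int min_chunk_log2 = 12, int cut_off = 24) {
  int log_size = static_cast<int>(std::ceil(std::log2(static_cast<double>(s))));
  if (log_size > cut_off) {
    // ceiling division by 2^cut_off selects the nearest-multiple (linear) bucket
    std::size_t chunks = (s + (std::size_t(1) << cut_off) - 1) >> cut_off;
    return static_cast<int>(chunks) - 1 + cut_off;
  }
  return std::max(log_size, min_chunk_log2);
}

inline std::size_t demo_get_size(int bucket, int cut_off = 24) {
  if (bucket <= cut_off) return std::size_t(1) << bucket;            // exponential region
  return static_cast<std::size_t>(bucket - cut_off + 1) << cut_off;  // linear region
}

int main() {
  assert(demo_get_size(demo_get_bucket(100)) == 4096u);              // clamped to min chunk
  assert(demo_get_size(demo_get_bucket(5000)) == 8192u);             // nearest power of two
  assert(demo_get_size(demo_get_bucket(20u << 20)) == (32u << 20));  // 2 * 16 MB
  assert(demo_get_size(demo_get_bucket(40u << 20)) == (48u << 20));  // 3 * 16 MB
  return 0;
}

The asserts encode the same examples used in the comments above; sizes below the minimum chunk are clamped up, sizes up to 16 MB round to the next power of two, and larger sizes round to the next multiple of 16 MB.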
