feat: add native async model average algorithm (#41)
wangraying authored and NOBLES5E committed Sep 15, 2021
1 parent 22fc369 commit 1cc64c9
Showing 10 changed files with 366 additions and 60 deletions.
107 changes: 51 additions & 56 deletions bagua-core/Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions bagua-core/bagua-core-internal/Cargo.toml
@@ -20,6 +20,7 @@ derivative = "2.2.0"
oneshot = "0.1"
cpp = "0.5"
sized-object-pool = "0.2"
dynamic-pool = "0.2"
once_cell = "1.7"
ndarray = "0.15.3"
serde = { version = "1", features = ["derive"] }
17 changes: 17 additions & 0 deletions bagua-core/bagua-core-internal/kernels/bagua_kernels.cu
@@ -254,6 +254,16 @@ __global__ void divide_inplace_f16(__half *x, float D_, int N) {
}
}

__global__ void async_model_average(float *tensor, const float *reduced_tensor_copy,
const float *tensor_copy, const float nranks, const int N) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) {
    // Atomically apply the difference between the cross-rank average of the
    // snapshot and the local snapshot, so concurrent updates to `tensor` are kept.
    atomicAdd(&tensor[i], reduced_tensor_copy[i] / nranks - tensor_copy[i]);
}
}

template<typename T>
size_t array_min_max_size(
@@ -610,6 +620,13 @@ void average_inplace_f16_host(__half *x, __half *y, int N, cudaStream_t stream)
average_inplace_f16<<<DIVUP(N, 1024), 1024, 0, stream>>>(x, y, N);
CUDACHECK(cudaGetLastError());
}

void async_model_average_host(float *tensor, const float *reduced_tensor_copy,
const float *tensor_copy, const float nranks, const int N, cudaStream_t stream) {
async_model_average<<<DIVUP(N, 1024), 1024, 0, stream>>>(tensor, reduced_tensor_copy, tensor_copy, nranks, N);
CUDACHECK(cudaGetLastError());
}

//// decentralize, recvbuf should get the average of sendbuf and peer's sendbuf
//ncclResult_t ncclPeerAverage(void *sendbuf, void *recvbuf, size_t sendcount,
// int peer_rank, ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream) {

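For context, the new kernel applies the update tensor[i] += reduced_tensor_copy[i] / nranks - tensor_copy[i], i.e. it atomically folds the cross-rank average of a model snapshot back into the live model. Below is a minimal sketch of how one averaging step could drive the new host wrapper, assuming a plain NCCL all-reduce; the helper async_average_step, the comm handle, and the explicit snapshot copy are illustrative assumptions and not part of this commit.

#include <cuda_runtime.h>
#include <nccl.h>

// Forward declaration of the host wrapper added in this commit.
void async_model_average_host(float *tensor, const float *reduced_tensor_copy,
                              const float *tensor_copy, const float nranks,
                              const int N, cudaStream_t stream);

// Hypothetical driver for one asynchronous model-averaging step (sketch only).
void async_average_step(float *tensor, float *tensor_copy, float *reduced_tensor_copy,
                        int N, int nranks, ncclComm_t comm, cudaStream_t stream) {
    // 1. Snapshot the live model; later optimizer updates to `tensor` are untouched.
    cudaMemcpyAsync(tensor_copy, tensor, N * sizeof(float),
                    cudaMemcpyDeviceToDevice, stream);
    // 2. Sum the snapshots across ranks; the kernel divides by nranks itself.
    ncclAllReduce(tensor_copy, reduced_tensor_copy, N, ncclFloat, ncclSum, comm, stream);
    // 3. tensor[i] += reduced_tensor_copy[i] / nranks - tensor_copy[i], via atomicAdd
    //    so that writes happening concurrently on `tensor` are not lost.
    async_model_average_host(tensor, reduced_tensor_copy, tensor_copy,
                             (float)nranks, N, stream);
}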