support float16 tensor (#45)
* support float16 tensor

* use tensor.data_ptr directly

* add fp16 test

* fix

* add MAG240M Example

* fix bugs

* add fp16 test

* fix bugs

* update

* update distribute training

* update parameters

* add readme
Dalong committed Jul 2, 2022
1 parent a1e2413 commit 1cd9007
Showing 13 changed files with 601 additions and 27 deletions.
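The thrust of the change: buffer sizes and data pointers were previously hard-coded for float32 (`4 * numel()`, `data_ptr<float>()`); they now come from the tensor itself (`element_size() * numel()`, untyped `data_ptr()`), so float16 tensors go through the same path. A minimal PyTorch sketch of that size arithmetic (illustration only, no Quiver-Feature dependency):

```python
import torch

def buffer_size_in_bytes(t: torch.Tensor) -> int:
    # element_size() * numel() replaces the old hard-coded 4 * numel(),
    # so float32 (4 bytes/elem) and float16 (2 bytes/elem) both work.
    return t.element_size() * t.numel()

fp32 = torch.zeros(1000, 128, dtype=torch.float32)
fp16 = torch.zeros(1000, 128, dtype=torch.float16)
print(buffer_size_in_bytes(fp32))  # 512000
print(buffer_size_in_bytes(fp16))  # 256000
```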
13 changes: 8 additions & 5 deletions csrc/include/qvf/dist_tensor_client.h
@@ -11,6 +11,7 @@
 #include <infinity/queues/QueuePairFactory.h>
 #include <infinity/requests/RequestToken.h>
 #include <torch/extension.h>
+#include <ATen/ATen.h>
 #include <chrono>
 #include <deque>
 #include <thread>
@@ -98,15 +99,17 @@ class DistTensorClient {
         {tensor_shape[0], tensor_shape[1]}, tensor_option);
   }
 
-  void register_float32_tensor(torch::Tensor& float_tensor) {
+  void register_float_tensor(torch::Tensor& float_tensor) {
     QUIVER_FEATURE_ASSERT(
         float_tensor.dim() == 2,
         "Only support 2-dimensional tensor, But got %d-dimensional tensor\n",
         float_tensor.dim());
-    uint64_t size_in_bytes = 4 * float_tensor.numel();
+
+    uint64_t size_in_bytes = float_tensor.element_size() * float_tensor.numel();
+
     tensor_buffer = new infinity::memory::Buffer(
-        context, float_tensor.data_ptr<float>(), size_in_bytes);
+        context, float_tensor.data_ptr(), size_in_bytes);
+
     tensor_token = tensor_buffer->createRegionToken();
   }

@@ -134,12 +137,12 @@
                            torch::Tensor& local_offsets,
                            torch::Tensor& remote_offsets) {
     QUIVER_FEATURE_ASSERT(
-        reinterpret_cast<uint64_t>(res_tensor.data_ptr<float>()) ==
+        reinterpret_cast<uint64_t>(res_tensor.data_ptr()) ==
             tensor_buffer->getAddress(),
         "Result Tensor is not created from registered buffer");
 
     pipes[server_rank]->read(tensor_buffer, local_offsets, remote_offsets,
-                             res_tensor.size(1) * 4);
+                             res_tensor.size(1) * res_tensor.element_size());
   }
 
   void collect_inner(CollectionTask collection_task) {
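Note the read size in the hunk above: each RDMA read now transfers `res_tensor.size(1) * res_tensor.element_size()` bytes per node row rather than a fixed `dim * 4`. A small sketch of that arithmetic (the RDMA call itself is C++; only the size computation is shown, with illustrative shapes):

```python
import torch

# Illustrative result buffer: one row per requested node.
res_tensor = torch.empty(80000, 128, dtype=torch.float16)

# Bytes per row read over RDMA: feature dim times element size
# (128 * 2 = 256 bytes for float16; it would be 512 for float32).
row_bytes = res_tensor.size(1) * res_tensor.element_size()
assert row_bytes == 256
```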
6 changes: 4 additions & 2 deletions csrc/include/qvf/dist_tensor_server.h
@@ -16,6 +16,7 @@
 #include <vector>
 
 #include <torch/extension.h>
+#include <ATen/ATen.h>
 
 namespace qvf {
 class DistTensorServer {
@@ -51,9 +52,10 @@ class DistTensorServer {
 
   void serve_tensor(torch::Tensor& data) {
     std::cout << "Registering Buffer, Please Wait..." << std::endl;
-    uint64_t size_in_bytes = data.numel() * 4;
+    uint64_t size_in_bytes = data.numel() * data.element_size();
+
     feature_buffer = new infinity::memory::Buffer(
-        context, data.data_ptr<float>(), size_in_bytes);
+        context, data.data_ptr(), size_in_bytes);
     bufferToken = feature_buffer->createRegionToken();
     server_thread = std::thread(run, qpFactory, bufferToken,
                                 qp_per_pipe * (world_size - 1));
4 changes: 2 additions & 2 deletions csrc/src/register.cpp
@@ -47,8 +47,8 @@ void register_DistTensorClient(pybind11::module& m) {
       .def("create_registered_float32_tensor",
           &qvf::DistTensorClient::create_registered_float32_tensor,
           py::call_guard<py::gil_scoped_release>())
-      .def("register_float32_tensor",
-          &qvf::DistTensorClient::register_float32_tensor,
+      .def("register_float_tensor",
+          &qvf::DistTensorClient::register_float_tensor,
           py::call_guard<py::gil_scoped_release>())
       .def("create_registered_float32_tensor_cuda",
           &qvf::DistTensorClient::create_registered_float32_tensor_cuda,
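On the Python side, callers now use the renamed `register_float_tensor` binding for both fp32 and fp16 tensors. A self-contained stub of the call pattern (this is not the real `qvf.DistTensorClient`; its construction and RDMA transport are omitted here):

```python
import torch

class DistTensorClientStub:
    """Illustrative stand-in for qvf.DistTensorClient.

    The real client registers the tensor's memory with the RDMA NIC;
    this stub only records the pointer and size the way the C++ side does.
    """

    def register_float_tensor(self, t: torch.Tensor) -> None:
        assert t.dim() == 2, "Only support 2-dimensional tensor"
        self.addr = t.data_ptr()                  # untyped data_ptr()
        self.size = t.element_size() * t.numel()  # dtype-agnostic size

client = DistTensorClientStub()
client.register_float_tensor(torch.zeros(1000, 128, dtype=torch.float16))
print(hex(client.addr), client.size)  # size is 256000 bytes for fp16
```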
45 changes: 45 additions & 0 deletions examples/mag240m/README.md
@@ -0,0 +1,45 @@
# Introduction

The distributed training setup on the MAG240M dataset is almost the same as the [official example in DGL](https://github.com/dmlc/dgl/tree/master/examples/pytorch/ogb_lsc/MAG240M), except that we use `Quiver-Feature` for distributed feature collection.

Our implementation is much faster than DGL's official example while achieving similar accuracy.

# Data Preprocess & Partition

First, run [preprocess.py](./preprocess.py) to generate `graph.dgl` and `full.npy`; see [DGL's official guide](https://github.com/dmlc/dgl/tree/master/examples/pytorch/ogb_lsc/MAG240M) for more details.

Then we use [Range Partition](../../docs/partition_methods.md) to partition the feature data. The scheme is simple contiguous row ranges, as the figure and the sketch below show; see [process_quiver.py](./process_quiver.py) for the full details.

![](../../docs/imgs/range_partition.png)
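Below is a minimal sketch of the range-partition idea, assuming features are split into contiguous row ranges, one per machine (the helper names here are illustrative, not the ones used in [process_quiver.py](./process_quiver.py)):

```python
import torch

def range_partition(features: torch.Tensor, world_size: int):
    """Split feature rows into `world_size` contiguous ranges.

    Ownership is then a single division, node_id // rows_per_part,
    so no lookup table is needed at feature-collection time.
    """
    num_nodes = features.size(0)
    rows_per_part = (num_nodes + world_size - 1) // world_size
    parts = [
        features[rank * rows_per_part : (rank + 1) * rows_per_part]
        for rank in range(world_size)
    ]
    return parts, rows_per_part

parts, rows_per_part = range_partition(torch.randn(10, 4), world_size=2)
assert 7 // rows_per_part == 1  # node 7 lives on machine 1
```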


# Running Training Script

On each machine, please run:

    python3 distributed_training.py \
      --rootdir . \
      --graph-path ./graph.dgl \
      --feature-partition-path ./feature_part.pt \
      --server_world_size 2 \
      --server_rank 0

Remember to:

- Set the shm size limit to be as large as your physical memory. You can do this with:

      sudo mount -o remount,size=300G /dev/shm

- Set `MASTER_IP` in [config.py](./config.py) to your master node's IP address.


The validation accuracy is 0.680. Ground-truth test labels are not available, so we do not report test accuracy.

# Performance

With 2 machines and 1 GPU per machine, each epoch takes 2 minutes 10 seconds for training and 15 seconds for validation. This is 3x faster than [DGL's performance result](https://github.com/dmlc/dgl/tree/master/examples/pytorch/ogb_lsc/MAG240M).


# Hardware configurations

We use 2 machines, each with 377GB of memory, connected by 100Gbps InfiniBand. Running the training script consumes around 256GB of memory.
18 changes: 18 additions & 0 deletions examples/mag240m/config.py
@@ -0,0 +1,18 @@
PORT_NUMBER = 3344
MASTER_IP = "155.198.152.17"
#MASTER_IP = "127.0.0.1"
HLPER_PORT = 5678
NODE_COUNT = 1200000
FEATURE_DIM = 128
FEATURE_TYPE_SIZE = 4  # bytes per feature element (4 = float32, 2 = float16)
SAMPLE_NUM = 80000
ITER_NUM = 10
POST_LIST_SIZE = 128
QP_NUM = 8
TX_DEPTH = 2048
CTX_POLL_BATCH = TX_DEPTH // POST_LIST_SIZE  # 2048 // 128 = 16
TEST_TLB_OPTIMIZATION = True

# For MAG240M Training
SAMPLE_PARAM = [15, 25]  # neighbor-sampling fanout per layer
BATCH_SIZE = 1024
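For a feel of the transfer sizes these parameters imply, a back-of-the-envelope sketch (assuming one feature row is fetched per sampled node):

```python
FEATURE_DIM = 128
FEATURE_TYPE_SIZE = 4  # bytes; would be 2 if features were stored as float16
SAMPLE_NUM = 80000

bytes_per_node = FEATURE_DIM * FEATURE_TYPE_SIZE  # 512 bytes per feature row
bytes_per_sample = SAMPLE_NUM * bytes_per_node    # 40,960,000 bytes
print(f"{bytes_per_sample / 2**20:.1f} MiB per collection")  # ~39.1 MiB
```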