[MXNET-551] Test CreateMKLDNNMem/CommitOutput (#11308)
* refactor copyfrom

* add boilerplate

* rename to MKLDNNCopy

* write to temp memory

* reorder mkldnn / views

* return memory from GetMKLDNNData

* add kaddto to unit test

* move orig output before creating new mem

* coerce memory if shape does not fit

* use MKLDNNCopy in commit

* uncomment addto test

* switch order of mkldnnsum params

* improving logging

* wait to read after copying arr

* remove extra white spaces

* remove extra white space

* remove unused var

* reorder output

* do not write to views

* remove shape check in test

* use input pdesc

* remove unused var

* fix merge

* put inplace in separate loop

* use two mem

* use sum_pd when calling CreateMKLDNNData

* reorder sum shapes if needed

* comment out getsumpd

* use MKLDNNCopy helper to reshape mem

* remove getsumpd

* use output mem for createmem

* remove todo

* wait to read output

* do not attempt to shape output

* use correct arr as input

* revert commit change to ps-lite

* revert change to tvm

* fix lint

* add comment to test

* reduce calls to get_primitive_desc

* skip tests that reorder2default

* push_back to inputs

* skip if view/mkldnn

* add noop test

* pass input ptr for write in place

* allow empty
azai91 authored and piiswrong committed Jun 26, 2018
1 parent 0538ad9 commit e4bf646
Showing 5 changed files with 242 additions and 96 deletions.
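
For context, the pair under test brackets an operator's output: CreateMKLDNNMem hands back a memory buffer that honors the request type `req', and CommitOutput later copies or accumulates the result into the destination NDArray if a temporary was handed out. A minimal sketch of that pattern (illustrative only; out_array, out_pd, and req are placeholders, not code from this commit):

// Sketch: the CreateMKLDNNMem/CommitOutput contract exercised by the new tests.
// mkldnn_output_t is a pair<OutDataOp, mkldnn::memory*>; the first member
// records whether CommitOutput must copy (CopyBack) or sum (AddBack) the
// temporary back into the array.
mkldnn_output_t out_mem = CreateMKLDNNMem(out_array, out_pd, req);
// ... register the primitive that writes its result into *out_mem.second ...
CommitOutput(out_array, out_mem);  // Noop, CopyBack, or AddBack, as recorded
MKLDNNStream::Get()->Submit();     // execute everything queued on the stream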
72 changes: 2 additions & 70 deletions src/ndarray/ndarray.cc
@@ -482,7 +482,7 @@ const mkldnn::memory *NDArray::GetMKLDNNData(
if (mem->get_primitive_desc() == desc
|| (desc1.data.format == GetDefaultFormat(desc1)
&& desc2.data.format == GetDefaultFormat(desc2))) {
-    return GetMKLDNNExact(ptr_->mkl_mem_->GetRaw(), desc);
+    return GetMKLDNNExact(mem, desc);
} else {
return nullptr;
}
@@ -638,82 +638,14 @@ void NDArray::CopyFrom(const mkldnn::memory &mem) {

CHECK(mem.get_primitive_desc().get_size() == shape().Size() * GetTypeSize(dtype_))
<< "The size of NDArray doesn't match the requested MKLDNN memory desc";
-  MKLDNNStream *stream = MKLDNNStream::Get();
// If this array uses MKLDNN layout, we have to make sure it's not a view.
// Otherwise, we'll have to change the layout inside the array.

if (IsMKLDNNData() && IsView())
ptr_->Reorder2Default();

const mkldnn::memory *this_mem = GetMKLDNNData();
-  mkldnn::memory::primitive_desc from_pd = mem.get_primitive_desc();
-  mkldnn::memory::desc from_desc = from_pd.desc();
-  mkldnn::memory::primitive_desc this_pd = this_mem->get_primitive_desc();
-  mkldnn::memory::desc this_desc = this_pd.desc();
-  mkldnn_memory_format_t from_def_format = GetDefaultFormat(from_desc);
-  mkldnn_memory_format_t this_def_format = GetDefaultFormat(this_desc);
-  if (IsView()) {
-    // Sliced array must use the default layout.
-    CHECK_EQ(GetDefaultFormat(this_desc), this_desc.data.format);
-  }
-  // It's possible that the memory and the NDArray don't have the same shape.
-  if (!same_shape(this_desc, from_desc)
-      // If the source memory uses the default layout, we can reshape directly.
-      && from_def_format == from_desc.data.format) {
-    // In this case, we can simply create a new MKLDNN memory for the required
-    // shape.
-    mkldnn::memory::dims dims(this_desc.data.dims,
-                              this_desc.data.dims + this_desc.data.ndims);
-    auto this_dtype = static_cast<mkldnn::memory::data_type>(this_desc.data.data_type);
-    auto this_format = static_cast<mkldnn::memory::format>(GetDefaultFormat(this_desc));
-    mkldnn::memory::desc data_md(dims, this_dtype, this_format);
-    mkldnn::memory::primitive_desc pd(data_md, from_pd.get_engine());
-    mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, mem.get_data_handle()));
-    stream->RegisterMem(tmp_mem);
-    stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *this_mem));
-  } else if (!same_shape(this_desc, from_desc)) {
-    // In this case, the source memory stores data in a customized layout. We
-    // need to reorganize the data in memory before we can reshape.
-    mkldnn::memory::primitive_desc def_pd = GetPrimitiveDesc(from_pd, from_def_format);
-    mkldnn::memory *def_mem = TmpMemMgr::Get()->Alloc(def_pd);
-    stream->RegisterPrim(mkldnn::reorder(mem, *def_mem));
-    // Now we can reshape it
-    mkldnn::memory::dims dims(this_desc.data.dims,
-                              this_desc.data.dims + this_desc.data.ndims);
-    auto this_dtype = static_cast<mkldnn::memory::data_type>(this_desc.data.data_type);
-    auto this_format = static_cast<mkldnn::memory::format>(GetDefaultFormat(this_desc));
-    mkldnn::memory::desc data_md(dims, this_dtype, this_format);
-    mkldnn::memory::primitive_desc pd(data_md, from_pd.get_engine());
-    mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, def_mem->get_data_handle()));
-    stream->RegisterMem(tmp_mem);
-    stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *this_mem));
-  } else if (from_pd == this_pd) {
-    // If the layout is the same, we can just copy data.
-    stream->RegisterPrim(mkldnn::reorder(mem, *this_mem));
-  } else {
-    // If both are not using the default layouts. There isn't much we can do,
-    // other than reorder data layout directly.
-    if (this_def_format != this_desc.data.format
-        && from_def_format != from_desc.data.format) {
-      stream->RegisterPrim(mkldnn::reorder(mem, *this_mem));
-    } else if (this_def_format == this_desc.data.format) {
-      // If the dest mem uses the default memory layout, we can simply use
-      // the default format of the source memory to improve perf of reorder.
-      mkldnn::memory::primitive_desc pd = GetPrimitiveDesc(from_pd,
-                                                           from_def_format);
-      mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, this_mem->get_data_handle()));
-      stream->RegisterMem(tmp_mem);
-      stream->RegisterPrim(mkldnn::reorder(mem, *tmp_mem));
-    } else {
-      // If the src mem uses the default memory layout, we can use
-      // the default format of the source memory to improve perf.
-      mkldnn::memory::primitive_desc pd = GetPrimitiveDesc(this_pd,
-                                                           this_def_format);
-      mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, mem.get_data_handle()));
-      stream->RegisterMem(tmp_mem);
-      stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *this_mem));
-    }
-  }
+  MKLDNNCopy(mem, this_mem);
}

mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::primitive_desc &desc) {
1 change: 1 addition & 0 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -318,6 +318,7 @@ enum OutDataOp {
};

typedef std::pair<OutDataOp, mkldnn::memory *> mkldnn_output_t;
+void MKLDNNCopy(const mkldnn::memory &mem, const mkldnn::memory* this_mem);

/*
* These two functions try to create MKLDNN memory in an NDArray based on `req'.
103 changes: 89 additions & 14 deletions src/operator/nn/mkldnn/mkldnn_base.cc
@@ -77,6 +77,75 @@ mkldnn::memory *TmpMemMgr::Alloc(const mkldnn::memory::primitive_desc &pd) {
}
}

+void MKLDNNCopy(const mkldnn::memory &mem, const mkldnn::memory* this_mem) {
+  MKLDNNStream *stream = MKLDNNStream::Get();
+
+  mkldnn::memory::primitive_desc from_pd = mem.get_primitive_desc();
+  mkldnn::memory::desc from_desc = from_pd.desc();
+  mkldnn::memory::primitive_desc this_pd = this_mem->get_primitive_desc();
+  mkldnn::memory::desc this_desc = this_pd.desc();
+  mkldnn_memory_format_t from_def_format = GetDefaultFormat(from_desc);
+  mkldnn_memory_format_t this_def_format = GetDefaultFormat(this_desc);
+  // It's possible that the memory and the NDArray don't have the same shape.
+  if (!same_shape(this_desc, from_desc)
+      // If the source memory uses the default layout, we can reshape directly.
+      && from_def_format == from_desc.data.format) {
+    // In this case, we can simply create a new MKLDNN memory for the required
+    // shape.
+    mkldnn::memory::dims dims(this_desc.data.dims,
+                              this_desc.data.dims + this_desc.data.ndims);
+    auto this_dtype = static_cast<mkldnn::memory::data_type>(this_desc.data.data_type);
+    auto this_format = static_cast<mkldnn::memory::format>(GetDefaultFormat(this_desc));
+    mkldnn::memory::desc data_md(dims, this_dtype, this_format);
+    mkldnn::memory::primitive_desc pd(data_md, from_pd.get_engine());
+    mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, mem.get_data_handle()));
+    stream->RegisterMem(tmp_mem);
+    stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *this_mem));
+  } else if (!same_shape(this_desc, from_desc)) {
+    // In this case, the source memory stores data in a customized layout. We
+    // need to reorganize the data in memory before we can reshape.
+    mkldnn::memory::primitive_desc def_pd = GetPrimitiveDesc(from_pd, from_def_format);
+    mkldnn::memory *def_mem = TmpMemMgr::Get()->Alloc(def_pd);
+    stream->RegisterPrim(mkldnn::reorder(mem, *def_mem));
+    // Now we can reshape it
+    mkldnn::memory::dims dims(this_desc.data.dims,
+                              this_desc.data.dims + this_desc.data.ndims);
+    auto this_dtype = static_cast<mkldnn::memory::data_type>(this_desc.data.data_type);
+    auto this_format = static_cast<mkldnn::memory::format>(GetDefaultFormat(this_desc));
+    mkldnn::memory::desc data_md(dims, this_dtype, this_format);
+    mkldnn::memory::primitive_desc pd(data_md, from_pd.get_engine());
+    mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, def_mem->get_data_handle()));
+    stream->RegisterMem(tmp_mem);
+    stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *this_mem));
+  } else if (from_pd == this_pd) {
+    // If the layout is the same, we can just copy data.
+    stream->RegisterPrim(mkldnn::reorder(mem, *this_mem));
+  } else {
+    // If both are not using the default layouts. There isn't much we can do,
+    // other than reorder data layout directly.
+    if (this_def_format != this_desc.data.format
+        && from_def_format != from_desc.data.format) {
+      stream->RegisterPrim(mkldnn::reorder(mem, *this_mem));
+    } else if (this_def_format == this_desc.data.format) {
+      // If the dest mem uses the default memory layout, we can simply use
+      // the default format of the source memory to improve perf of reorder.
+      mkldnn::memory::primitive_desc pd = GetPrimitiveDesc(from_pd,
+                                                           from_def_format);
+      mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, this_mem->get_data_handle()));
+      stream->RegisterMem(tmp_mem);
+      stream->RegisterPrim(mkldnn::reorder(mem, *tmp_mem));
+    } else {
+      // If the src mem uses the default memory layout, we can use
+      // the default format of the source memory to improve perf.
+      mkldnn::memory::primitive_desc pd = GetPrimitiveDesc(this_pd,
+                                                           this_def_format);
+      mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, mem.get_data_handle()));
+      stream->RegisterMem(tmp_mem);
+      stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *this_mem));
+    }
+  }
+}
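
Every branch of the helper above bottoms out in a mkldnn::reorder primitive registered on the stream. As a standalone illustration of that building block, here is a hedged sketch against the MKL-DNN 0.x API used in this tree; the dimensions and formats are arbitrary and the buffers are left uninitialized for brevity:

#include <mkldnn.hpp>
using namespace mkldnn;

int main() {
  engine eng(engine::cpu, 0);
  memory::dims dims = {1, 8, 4, 4};
  // Source in the plain NCHW layout, destination in a blocked layout.
  memory::desc src_md(dims, memory::data_type::f32, memory::format::nchw);
  memory::desc dst_md(dims, memory::data_type::f32, memory::format::nChw8c);
  memory src(memory::primitive_desc(src_md, eng));
  memory dst(memory::primitive_desc(dst_md, eng));
  // A layout change is expressed as a reorder primitive, exactly the
  // operation MKLDNNCopy registers for each of its cases.
  stream(stream::kind::eager).submit({reorder(src, dst)}).wait();
  return 0;
}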

bool CanWriteTo(const NDArray &out_arr,
const NDArray &in_arr,
const mkldnn::memory::primitive_desc &desc) {
@@ -94,22 +163,25 @@ mkldnn_output_t CreateMKLDNNMem(const NDArray &out_arr,
if (kAddTo == req) {
auto tmp = TmpMemMgr::Get()->Alloc(desc);
return mkldnn_output_t(OutDataOp::AddBack, tmp);
-  } else if (req == kWriteInplace && in_arr != nullptr && CanWriteTo(out_arr, *in_arr, desc)) {
+  } else if (kWriteInplace == req && in_arr != nullptr && CanWriteTo(out_arr, *in_arr, desc)) {
mkldnn::memory *mem = const_cast<NDArray &>(out_arr).CreateMKLDNNData(desc);
// mem is nullptr if out_arr is view and desc is MKLDNN format.
// need to Reorder2Default before calling CreateMKLDNNMem
CHECK(mem != nullptr);
return mkldnn_output_t(OutDataOp::Noop, mem);
-  } else if (req == kWriteInplace) {
-    auto tmp = TmpMemMgr::Get()->Alloc(desc);
-    return mkldnn_output_t(OutDataOp::CopyBack, tmp);
-  }
-  mkldnn::memory *mem = const_cast<NDArray &>(out_arr).CreateMKLDNNData(desc);
-  if (nullptr == mem) {
+  } else if (kWriteInplace == req) {
+    auto tmp = TmpMemMgr::Get()->Alloc(desc);
+    return mkldnn_output_t(OutDataOp::CopyBack, tmp);
+  } else if (kWriteTo == req) {
+    mkldnn::memory *mem = const_cast<NDArray &>(out_arr).CreateMKLDNNData(desc);
+    if (nullptr == mem) {
      auto tmp = TmpMemMgr::Get()->Alloc(desc);
      return mkldnn_output_t(OutDataOp::CopyBack, tmp);
    }
+    return mkldnn_output_t(OutDataOp::Noop, mem);
+  }
-  return mkldnn_output_t(OutDataOp::Noop, mem);
+  auto tmp = TmpMemMgr::Get()->Alloc(desc);
+  return mkldnn_output_t(OutDataOp::Noop, tmp);
}
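
A hedged sketch of how a unit test can pin down this dispatch (illustrative only: arr and pd stand in for whatever the test constructs, and in_arr is omitted on the assumption that it defaults to nullptr):

mkldnn_output_t add_out = CreateMKLDNNMem(arr, pd, kAddTo);
CHECK(add_out.first == OutDataOp::AddBack);  // temp buffer, summed on commit
mkldnn_output_t write_out = CreateMKLDNNMem(arr, pd, kWriteTo);
// Noop if arr can hand out memory matching pd, otherwise a CopyBack temp:
CHECK(write_out.first == OutDataOp::Noop
      || write_out.first == OutDataOp::CopyBack);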

mkldnn_output_t CreateMKLDNNWeightGrad(const NDArray &out_arr,
@@ -141,13 +213,16 @@ void CommitOutput(const NDArray &arr, const mkldnn_output_t &res) {
if (res.first == CopyBack) {
const_cast<NDArray &>(arr).CopyFrom(*res.second);
} else if (res.first == AddBack) {
+    auto res_memory = res.second;
+    auto target_pd = arr.GetMKLDNNData()->get_primitive_desc();
auto mem = arr.GetMKLDNNData(res.second->get_primitive_desc());
-    CHECK(mem != nullptr);
-    // We have to allocate new memory for the sum result.
-    auto sum_res = TmpMemMgr::Get()->Alloc(
-        res.second->get_primitive_desc());
-    op::MKLDNNSum(*res.second, *mem, *sum_res);
-    const_cast<NDArray &>(arr).CopyFrom(*sum_res);
+    if (mem == nullptr) {
+      auto tmp_memory = TmpMemMgr::Get()->Alloc(target_pd);
+      MKLDNNCopy(*res_memory, tmp_memory);
+      res_memory = tmp_memory;
+      mem = arr.GetMKLDNNData();
+    }
+    op::MKLDNNSum(*mem, *res_memory, *mem);
}
}
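
Restated: the AddBack branch now fetches the array's memory in the result's layout when it can, and otherwise reorders the temporary result into the array's own primitive desc via MKLDNNCopy before accumulating in place. A hedged usage sketch, with arr and out as illustrative placeholders:

// out came from CreateMKLDNNMem(arr, pd, kAddTo) and the operator wrote its
// result into *out.second. Committing queues arr = arr + result, reordering
// the result first if its layout does not match.
CommitOutput(arr, out);
MKLDNNStream::Get()->Submit();  // arr now holds the accumulated value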

20 changes: 16 additions & 4 deletions src/operator/nn/mkldnn/mkldnn_sum.cc
@@ -38,10 +38,22 @@ void MKLDNNSum(const mkldnn::memory &arr1, const mkldnn::memory &arr2,
std::vector<mkldnn::primitive::at> inputs;
input_pds[0] = arr1.get_primitive_desc();
input_pds[1] = arr2.get_primitive_desc();
-  CHECK(input_pds[0] == input_pds[1]);
-  inputs.push_back(arr1);
-  inputs.push_back(arr2);
-  // TODO(zhengda) I need to reorder memory here.
+  CHECK(input_pds[0] == input_pds[0]);
+  const mkldnn::memory *in_mem1 = &arr1;
+  const mkldnn::memory *in_mem2 = &arr2;
+  auto output_pd = out.get_primitive_desc();
+  if (input_pds[0] != output_pd) {
+    auto tmp_memory1 = TmpMemMgr::Get()->Alloc(output_pd);
+    auto tmp_memory2 = TmpMemMgr::Get()->Alloc(output_pd);
+    mxnet::MKLDNNCopy(arr1, tmp_memory1);
+    mxnet::MKLDNNCopy(arr2, tmp_memory2);
+    input_pds[0] = tmp_memory1->get_primitive_desc();
+    input_pds[1] = tmp_memory2->get_primitive_desc();
+    in_mem1 = tmp_memory1;
+    in_mem2 = tmp_memory2;
+  }
+  inputs.push_back(*in_mem1);
+  inputs.push_back(*in_mem2);
mkldnn::sum::primitive_desc sum_pd(scales, input_pds);
MKLDNNStream::Get()->RegisterPrim(mkldnn::sum(sum_pd, inputs, out));
}
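
The reorder-then-sum dance above exists because mkldnn::sum requires consistent input primitive descs. As a standalone, hedged illustration of the underlying primitive (again MKL-DNN 0.x API; shapes, formats, and scales are arbitrary):

#include <mkldnn.hpp>
#include <vector>
using namespace mkldnn;

int main() {
  engine eng(engine::cpu, 0);
  memory::dims dims = {1, 8, 4, 4};
  memory::desc md(dims, memory::data_type::f32, memory::format::nchw);
  memory::primitive_desc pd(md, eng);
  memory a(pd), b(pd), out(pd);
  // out = 1.0 * a + 1.0 * b; both inputs share the same primitive desc.
  std::vector<float> scales = {1.f, 1.f};
  std::vector<memory::primitive_desc> input_pds = {pd, pd};
  sum::primitive_desc sum_pd(scales, input_pds);
  std::vector<primitive::at> inputs = {a, b};
  stream(stream::kind::eager).submit({sum(sum_pd, inputs, out)}).wait();
  return 0;
}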
