From f47d026ee95cf85864942d4a2977338e599c0ed7 Mon Sep 17 00:00:00 2001 From: pengzhao-intel Date: Wed, 30 May 2018 10:18:58 +0800 Subject: [PATCH] reorder the input format to output format for the in-place add in case they're different. --- src/operator/nn/mkldnn/mkldnn_base.cc | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc index 1bd1581dbc2d..4622c0cee05c 100644 --- a/src/operator/nn/mkldnn/mkldnn_base.cc +++ b/src/operator/nn/mkldnn/mkldnn_base.cc @@ -128,8 +128,14 @@ void CommitOutput(const NDArray &arr, const mkldnn_output_t &res) { if (res.first == CopyBack) { const_cast(arr).CopyFrom(*res.second); } else if (res.first == AddBack) { - auto mem = arr.GetMKLDNNData(res.second->get_primitive_desc()); - CHECK(mem != nullptr); + const mkldnn::memory *mem = arr.GetMKLDNNData(); + auto arr_desc = mem->get_primitive_desc(); + auto res_desc = res.second->get_primitive_desc(); + if (arr_desc == res_desc) { + mem = arr.GetMKLDNNData(res_desc); + } else { + mem = arr.GetMKLDNNDataReorder(res_desc); + } // We have to allocate new memory for the sum result. auto sum_res = TmpMemMgr::Get()->Alloc( res.second->get_primitive_desc());