Skip to content

Commit

Permalink
Removed WASM_MEMORY64 macro
Browse files Browse the repository at this point in the history
  • Loading branch information
satyajandhyala committed Jul 13, 2024
1 parent 640c5ce commit 96241e5
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 34 deletions.
4 changes: 2 additions & 2 deletions cmake/onnxruntime_webassembly.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,8 @@ else()
target_link_options(onnxruntime_webassembly PRIVATE
"SHELL:-s MEMORY64=1"
)
string(APPEND CMAKE_C_FLAGS " -DWASM_MEMORY64 -sMEMORY64 -Wno-experimental")
string(APPEND CMAKE_CXX_FLAGS " -DWASM_MEMORY64 -sMEMORY64 -Wno-experimental")
string(APPEND CMAKE_C_FLAGS " -sMEMORY64 -Wno-experimental")
string(APPEND CMAKE_CXX_FLAGS " -sMEMORY64 -Wno-experimental")
set(SMEMORY_FLAG "-sMEMORY64")

target_compile_options(onnx PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
Expand Down
28 changes: 1 addition & 27 deletions onnxruntime/core/providers/js/js_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,28 +110,17 @@ class JsKernel : public OpKernel {
temp_data_size += sizeof(size_t) * 3;
}
}
#ifdef WASM_MEMORY64
uintptr_t* p_serialized_kernel_context = reinterpret_cast<uintptr_t*>(alloc->Alloc(temp_data_size));
#else
uint32_t* p_serialized_kernel_context = reinterpret_cast<uint32_t*>(alloc->Alloc(temp_data_size));
#endif
if (p_serialized_kernel_context == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to allocate memory for serialized kernel context.");
}

#ifdef WASM_MEMORY64
p_serialized_kernel_context[0] = reinterpret_cast<uintptr_t>(context);
p_serialized_kernel_context[1] = static_cast<uintptr_t>(context->InputCount());
p_serialized_kernel_context[2] = static_cast<uintptr_t>(context->OutputCount());
p_serialized_kernel_context[3] = reinterpret_cast<uintptr_t>(custom_data_ptr);
p_serialized_kernel_context[4] = static_cast<uintptr_t>(custom_data_size);
#else
p_serialized_kernel_context[0] = reinterpret_cast<uint32_t>(context);
p_serialized_kernel_context[1] = static_cast<uint32_t>(context->InputCount());
p_serialized_kernel_context[2] = static_cast<uint32_t>(context->OutputCount());
p_serialized_kernel_context[3] = reinterpret_cast<uint32_t>(custom_data_ptr);
p_serialized_kernel_context[4] = static_cast<uint32_t>(custom_data_size);
#endif

size_t index = 5;
for (int i = 0; i < context->InputCount(); i++) {
const auto* input_ptr = context->Input<Tensor>(i);
Expand All @@ -142,21 +131,12 @@ class JsKernel : public OpKernel {
p_serialized_kernel_context[index++] = 0;
continue;
}
#ifdef WASM_MEMORY64
p_serialized_kernel_context[index++] = static_cast<uintptr_t>(input_ptr->GetElementType());
p_serialized_kernel_context[index++] = reinterpret_cast<uintptr_t>(input_ptr->DataRaw());
p_serialized_kernel_context[index++] = static_cast<uintptr_t>(input_ptr->Shape().NumDimensions());
for (size_t d = 0; d < input_ptr->Shape().NumDimensions(); d++) {
p_serialized_kernel_context[index++] = static_cast<uintptr_t>(input_ptr->Shape()[d]);
}
#else
p_serialized_kernel_context[index++] = static_cast<uint32_t>(input_ptr->GetElementType());
p_serialized_kernel_context[index++] = reinterpret_cast<uint32_t>(input_ptr->DataRaw());
p_serialized_kernel_context[index++] = static_cast<uint32_t>(input_ptr->Shape().NumDimensions());
for (size_t d = 0; d < input_ptr->Shape().NumDimensions(); d++) {
p_serialized_kernel_context[index++] = static_cast<uint32_t>(input_ptr->Shape()[d]);
}
#endif
}

#ifndef NDEBUG
Expand Down Expand Up @@ -220,15 +200,9 @@ class JsKernel : public OpKernel {
return status;
}

#ifdef WASM_MEMORY64
intptr_t status_code = EM_ASM_INT(
{ return Module.jsepRunKernel($0, $1, Module.jsepSessionState.sessionHandle, Module.jsepSessionState.errors); },
this, reinterpret_cast<uintptr_t>(p_serialized_kernel_context));
#else
int status_code = EM_ASM_INT(
{ return Module.jsepRunKernel($0, $1, Module.jsepSessionState.sessionHandle, Module.jsepSessionState.errors); },
this, reinterpret_cast<uint32_t>(p_serialized_kernel_context));
#endif

LOGS_DEFAULT(VERBOSE) << "outputs = " << context->OutputCount() << ". Y.data="
<< (size_t)(context->Output<Tensor>(0)->DataRaw()) << ".";
Expand Down
5 changes: 0 additions & 5 deletions onnxruntime/wasm/api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,6 @@ enum DataLocation {
};

static_assert(sizeof(const char*) == sizeof(size_t), "size of a pointer and a size_t value should be the same.");
#ifdef WASM_MEMORY64
static_assert(sizeof(size_t) == 8, "size of size_t should be 8 in this build (wasm64).");
#else
static_assert(sizeof(size_t) == 4, "size of size_t should be 4 in this build (wasm32).");
#endif

OrtErrorCode CheckStatus(OrtStatusPtr status) {
if (status) {
Expand Down

0 comments on commit 96241e5

Please sign in to comment.