Skip to content

Commit

Permalink
perf: 一些逻辑优化
Browse files Browse the repository at this point in the history
  • Loading branch information
MistEO committed Oct 1, 2024
1 parent 7295e48 commit 0bbfca9
Showing 1 changed file with 8 additions and 20 deletions.
28 changes: 8 additions & 20 deletions source/MaaFramework/Resource/ONNXResMgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <filesystem>
#include <ranges>
#include <unordered_set>

#ifdef _WIN32
#include "Utils/SafeWindows.hpp"
Expand Down Expand Up @@ -62,34 +63,21 @@ bool ONNXResMgr::use_gpu(int device_id)
}
options_ = {};

auto all_providers = Ort::GetAvailableProviders();

auto all_providers_vec = Ort::GetAvailableProviders();
std::unordered_set<std::string> all_providers(
std::make_move_iterator(all_providers_vec.begin()),
std::make_move_iterator(all_providers_vec.end()));
LogInfo << VAR(all_providers);

bool support_cuda = false;
[[maybe_unused]] bool support_dml = false;
[[maybe_unused]] bool support_coreml = false;
for (const auto& provider : all_providers) {
if (provider == "CUDAExecutionProvider") {
support_cuda = true;
}
if (provider == "DmlExecutionProvider") {
support_dml = true;
}
if (provider == "CoreMLExecutionProvider") {
support_coreml = true;
}
}

if (support_cuda) {
if (all_providers.contains("CUDAExecutionProvider")) {
OrtCUDAProviderOptions cuda_options {};
cuda_options.device_id = device_id;
options_.AppendExecutionProvider_CUDA(cuda_options);

LogInfo << "Using CUDA execution provider with device_id " << device_id;
}
#ifdef MAA_WITH_DML
else if (support_dml) {
else if (all_providers.contains("DmlExecutionProvider")) {
auto status = OrtSessionOptionsAppendExecutionProvider_DML(options_, device_id);
if (!Ort::Status(status).IsOK()) {
LogError << "Failed to append DML execution provider";
Expand All @@ -99,7 +87,7 @@ bool ONNXResMgr::use_gpu(int device_id)
}
#endif
#ifdef MAA_WITH_COREML
else if (support_coreml) {
else if (all_providers.contains("CoreMLExecutionProvider")) {
auto status = OrtSessionOptionsAppendExecutionProvider_CoreML((OrtSessionOptions*)options_, 0);
if (!Ort::Status(status).IsOK()) {
LogError << "Failed to append CoreML execution provider";
Expand Down

0 comments on commit 0bbfca9

Please sign in to comment.