Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: dynamically get cuda toolkit version #1053

Merged
merged 2 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 77 additions & 20 deletions engine/commands/engine_init_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
// clang-format on
#include "utils/cuda_toolkit_utils.h"
#include "utils/engine_matcher_utils.h"
#if defined(_WIN32) || defined(__linux__)
#include "utils/file_manager_utils.h"
#endif

namespace commands {

Expand Down Expand Up @@ -60,21 +63,22 @@ bool EngineInitCmd::Exec() const {
variants.push_back(asset_name);
}

auto cuda_version = system_info_utils::GetCudaVersion();
LOG_INFO << "engineName_: " << engineName_;
LOG_INFO << "CUDA version: " << cuda_version;
std::string matched_variant = "";
auto cuda_driver_version = system_info_utils::GetCudaVersion();
LOG_INFO << "Engine: " << engineName_
<< ", CUDA driver version: " << cuda_driver_version;

std::string matched_variant{""};
if (engineName_ == "cortex.tensorrt-llm") {
matched_variant = engine_matcher_utils::ValidateTensorrtLlm(
variants, system_info.os, cuda_version);
variants, system_info.os, cuda_driver_version);
} else if (engineName_ == "cortex.onnx") {
matched_variant = engine_matcher_utils::ValidateOnnx(
variants, system_info.os, system_info.arch);
} else if (engineName_ == "cortex.llamacpp") {
auto suitable_avx = engine_matcher_utils::GetSuitableAvxVariant();
matched_variant = engine_matcher_utils::Validate(
variants, system_info.os, system_info.arch, suitable_avx,
cuda_version);
cuda_driver_version);
}
LOG_INFO << "Matched variant: " << matched_variant;
if (matched_variant.empty()) {
Expand Down Expand Up @@ -105,17 +109,46 @@ bool EngineInitCmd::Exec() const {
}}};

DownloadService download_service;
download_service.AddDownloadTask(downloadTask, [](const std::string&
absolute_path,
bool unused) {
download_service.AddDownloadTask(downloadTask, [this](
const std::string&
absolute_path,
bool unused) {
// try to unzip the downloaded file
std::filesystem::path downloadedEnginePath{absolute_path};
LOG_INFO << "Downloaded engine path: "
<< downloadedEnginePath.string();

archive_utils::ExtractArchive(
downloadedEnginePath.string(),
downloadedEnginePath.parent_path().parent_path().string());
std::filesystem::path extract_path =
downloadedEnginePath.parent_path().parent_path();

archive_utils::ExtractArchive(downloadedEnginePath.string(),
extract_path.string());
#if defined(_WIN32) || defined(__linux__)
// FIXME: hacky try to copy the file. Remove this when we are able to set the library path
auto engine_path = extract_path / engineName_;
LOG_INFO << "Source path: " << engine_path.string();
auto executable_path =
file_manager_utils::GetExecutableFolderContainerPath();
for (const auto& entry :
std::filesystem::recursive_directory_iterator(engine_path)) {
if (entry.is_regular_file() &&
entry.path().extension() != ".gz") {
std::filesystem::path relative_path =
std::filesystem::relative(entry.path(), engine_path);
std::filesystem::path destFile =
executable_path / relative_path;

std::filesystem::create_directories(destFile.parent_path());
std::filesystem::copy_file(
entry.path(), destFile,
std::filesystem::copy_options::overwrite_existing);

std::cout << "Copied: " << entry.path().filename().string()
<< " to " << destFile.string() << std::endl;
}
}
std::cout << "DLL copying completed successfully." << std::endl;
#endif

// remove the downloaded file
// TODO(any) Could not delete file on Windows because it is currently hold by httplib(?)
Expand All @@ -128,23 +161,47 @@ bool EngineInitCmd::Exec() const {
LOG_INFO << "Finished!";
});
if (system_info.os == "mac" || engineName_ == "cortex.onnx") {
return false;
// mac and onnx engine does not require cuda toolkit
return true;
namchuai marked this conversation as resolved.
Show resolved Hide resolved
}

// download cuda toolkit
const std::string jan_host = "https://catalog.jan.ai";
const std::string cuda_toolkit_file_name = "cuda.tar.gz";
const std::string download_id = "cuda";

auto gpu_driver_version = system_info_utils::GetDriverVersion();
// TODO: we don't have API to retrieve list of cuda toolkit dependencies atm because we hosting it at jan
// will have better logic after https://github.com/janhq/cortex/issues/1046 finished
// for now, assume that we have only 11.7 and 12.4
auto suitable_toolkit_version = "";
if (engineName_ == "cortex.tensorrt-llm") {
// for tensorrt-llm, we need to download cuda toolkit v12.4
suitable_toolkit_version = "12.4";
} else {
// llamacpp
auto cuda_driver_semver =
semantic_version_utils::SplitVersion(cuda_driver_version);
if (cuda_driver_semver.major == 11) {
suitable_toolkit_version = "11.7";
} else if (cuda_driver_semver.major == 12) {
suitable_toolkit_version = "12.4";
}
}

auto cuda_runtime_version =
cuda_toolkit_utils::GetCompatibleCudaToolkitVersion(
gpu_driver_version, system_info.os, engineName_);
// compare cuda driver version with cuda toolkit version
// cuda driver version should be greater than toolkit version to ensure compatibility
if (semantic_version_utils::CompareSemanticVersion(
cuda_driver_version, suitable_toolkit_version) < 0) {
LOG_ERROR << "Your Cuda driver version " << cuda_driver_version
<< " is not compatible with cuda toolkit version "
<< suitable_toolkit_version;
return false;
}

std::ostringstream cuda_toolkit_path;
cuda_toolkit_path << "dist/cuda-dependencies/" << 11.7 << "/"
<< system_info.os << "/"
<< cuda_toolkit_file_name;
cuda_toolkit_path << "dist/cuda-dependencies/"
<< cuda_driver_version << "/" << system_info.os
<< "/" << cuda_toolkit_file_name;

LOG_DEBUG << "Cuda toolkit download url: " << jan_host
<< cuda_toolkit_path.str();
Expand Down
20 changes: 15 additions & 5 deletions engine/utils/engine_matcher_utils.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#include <trantor/utils/Logger.h>
#include <algorithm>
#include <iostream>
#include <iterator>
#include <regex>
#include <string>
Expand Down Expand Up @@ -93,9 +93,19 @@ inline std::string GetSuitableCudaVariant(
bestMatchMinor = variantMinor;
}
}
} else if (cuda_version.empty() && selectedVariant.empty()) {
// If no CUDA version is provided, select the variant without any CUDA in the name
selectedVariant = variant;
}
}

// If no CUDA version is provided, select the variant without any CUDA in the name
if (selectedVariant.empty()) {
LOG_WARN
<< "No suitable CUDA variant found, selecting a variant without CUDA";
for (const auto& variant : variants) {
if (variant.find("cuda") == std::string::npos) {
selectedVariant = variant;
LOG_INFO << "Found variant without CUDA: " << selectedVariant << "\n";
break;
}
}
}

Expand Down Expand Up @@ -177,4 +187,4 @@ inline std::string Validate(const std::vector<std::string>& variants,

return cuda_compatible;
}
} // namespace engine_matcher_utils
} // namespace engine_matcher_utils
67 changes: 47 additions & 20 deletions engine/utils/semantic_version_utils.h
Original file line number Diff line number Diff line change
@@ -1,34 +1,61 @@
#include <trantor/utils/Logger.h>
#include <sstream>
#include <vector>

namespace semantic_version_utils {
inline std::vector<int> SplitVersion(const std::string& version) {
std::vector<int> parts;
std::stringstream ss(version);
std::string part;
struct SemVer {
int major;
int minor;
int patch;
};

while (std::getline(ss, part, '.')) {
parts.push_back(std::stoi(part));
inline SemVer SplitVersion(const std::string& version) {
if (version.empty()) {
LOG_WARN << "Passed in version is empty!";
}
SemVer semVer = {0, 0, 0}; // default value
std::stringstream ss(version);
std::string part;

while (parts.size() < 3) {
parts.push_back(0);
int index = 0;
while (std::getline(ss, part, '.') && index < 3) {
int value = std::stoi(part);
switch (index) {
case 0:
semVer.major = value;
break;
case 1:
semVer.minor = value;
break;
case 2:
semVer.patch = value;
break;
}
++index;
}

return parts;
return semVer;
}

inline int CompareSemanticVersion(const std::string& version1,
const std::string& version2) {
std::vector<int> v1 = SplitVersion(version1);
std::vector<int> v2 = SplitVersion(version2);

for (size_t i = 0; i < 3; ++i) {
if (v1[i] < v2[i])
return -1;
if (v1[i] > v2[i])
return 1;
}
SemVer v1 = SplitVersion(version1);
SemVer v2 = SplitVersion(version2);

if (v1.major < v2.major)
return -1;
if (v1.major > v2.major)
return 1;

if (v1.minor < v2.minor)
return -1;
if (v1.minor > v2.minor)
return 1;

if (v1.patch < v2.patch)
return -1;
if (v1.patch > v2.patch)
return 1;

return 0;
}
} // namespace semantic_version_utils
} // namespace semantic_version_utils
Loading