diff --git a/NAM/get_dsp.cpp b/NAM/get_dsp.cpp index a2275bf..10777e5 100644 --- a/NAM/get_dsp.cpp +++ b/NAM/get_dsp.cpp @@ -1,4 +1,6 @@ #include +#include +#include #include #include "dsp.h" @@ -7,20 +9,56 @@ #include "convnet.h" #include "wavenet.h" -void verify_config_version(const std::string version) +struct Version { + int major; + int minor; + int patch; +}; + +Version ParseVersion(const std::string& versionStr) { + Version version; + + // Split the version string into major, minor, and patch components + std::stringstream ss(versionStr); + std::string majorStr, minorStr, patchStr; + std::getline(ss, majorStr, '.'); + std::getline(ss, minorStr, '.'); + std::getline(ss, patchStr); + + // Parse the components as integers and assign them to the version struct + try { + version.major = std::stoi(majorStr); + version.minor = std::stoi(minorStr); + version.patch = std::stoi(patchStr); + } + catch (const std::invalid_argument&) { + throw std::invalid_argument("Invalid version string: " + versionStr); + } + catch (const std::out_of_range&) { + throw std::out_of_range("Version string out of range: " + versionStr); + } + + // Validate the semver components + if (version.major < 0 || version.minor < 0 || version.patch < 0) { + throw std::invalid_argument("Negative version component: " + versionStr); + } + return version; +} + +void verify_config_version(const std::string versionStr) { - const std::unordered_set supported_versions({"0.5.0", "0.5.1"}); - if (supported_versions.find(version) == supported_versions.end()) + Version version = ParseVersion(versionStr); + if (version.major != 0 || version.minor != 5) { std::stringstream ss; - ss << "Model config is an unsupported version " << version + ss << "Model config is an unsupported version " << versionStr << ". Try either converting the model to a more recent version, or " "update your version of the NAM plugin."; throw std::runtime_error(ss.str()); } } -std::vector _get_weights(nlohmann::json const& j, const std::filesystem::path config_path) +std::vector GetWeights(nlohmann::json const& j, const std::filesystem::path config_path) { if (j.find("weights") != j.end()) { @@ -51,7 +89,7 @@ std::unique_ptr get_dsp(const std::filesystem::path config_filename) auto architecture = j["architecture"]; nlohmann::json config = j["config"]; - std::vector params = _get_weights(j, config_filename); + std::vector params = GetWeights(j, config_filename); bool haveLoudness = false; double loudness = TARGET_DSP_LOUDNESS; if (j.find("metadata") != j.end())