Skip to content

Commit

Permalink
Don't require patch version (#41)
Browse files Browse the repository at this point in the history
  • Loading branch information
sdatkinson authored May 5, 2023
1 parent ee1db2b commit a2322ce
Showing 1 changed file with 44 additions and 6 deletions.
50 changes: 44 additions & 6 deletions NAM/get_dsp.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#include <fstream>
#include <sstream>
#include <stdexcept>
#include <unordered_set>

#include "dsp.h"
Expand All @@ -7,20 +9,56 @@
#include "convnet.h"
#include "wavenet.h"

void verify_config_version(const std::string version)
struct Version {
int major;
int minor;
int patch;
};

Version ParseVersion(const std::string& versionStr) {
Version version;

// Split the version string into major, minor, and patch components
std::stringstream ss(versionStr);
std::string majorStr, minorStr, patchStr;
std::getline(ss, majorStr, '.');
std::getline(ss, minorStr, '.');
std::getline(ss, patchStr);

// Parse the components as integers and assign them to the version struct
try {
version.major = std::stoi(majorStr);
version.minor = std::stoi(minorStr);
version.patch = std::stoi(patchStr);
}
catch (const std::invalid_argument&) {
throw std::invalid_argument("Invalid version string: " + versionStr);
}
catch (const std::out_of_range&) {
throw std::out_of_range("Version string out of range: " + versionStr);
}

// Validate the semver components
if (version.major < 0 || version.minor < 0 || version.patch < 0) {
throw std::invalid_argument("Negative version component: " + versionStr);
}
return version;
}

void verify_config_version(const std::string versionStr)
{
const std::unordered_set<std::string> supported_versions({"0.5.0", "0.5.1"});
if (supported_versions.find(version) == supported_versions.end())
Version version = ParseVersion(versionStr);
if (version.major != 0 || version.minor != 5)
{
std::stringstream ss;
ss << "Model config is an unsupported version " << version
ss << "Model config is an unsupported version " << versionStr
<< ". Try either converting the model to a more recent version, or "
"update your version of the NAM plugin.";
throw std::runtime_error(ss.str());
}
}

std::vector<float> _get_weights(nlohmann::json const& j, const std::filesystem::path config_path)
std::vector<float> GetWeights(nlohmann::json const& j, const std::filesystem::path config_path)
{
if (j.find("weights") != j.end())
{
Expand Down Expand Up @@ -51,7 +89,7 @@ std::unique_ptr<DSP> get_dsp(const std::filesystem::path config_filename)

auto architecture = j["architecture"];
nlohmann::json config = j["config"];
std::vector<float> params = _get_weights(j, config_filename);
std::vector<float> params = GetWeights(j, config_filename);
bool haveLoudness = false;
double loudness = TARGET_DSP_LOUDNESS;
if (j.find("metadata") != j.end())
Expand Down

0 comments on commit a2322ce

Please sign in to comment.