diff --git a/NAM/activations.cpp b/NAM/activations.cpp
index c4b6324..96b37fa 100644
--- a/NAM/activations.cpp
+++ b/NAM/activations.cpp
@@ -6,12 +6,8 @@
 activations::ActivationHardTanh _HARD_TANH = activations::ActivationHardTanh();
 activations::ActivationReLU _RELU = activations::ActivationReLU();
 activations::ActivationSigmoid _SIGMOID = activations::ActivationSigmoid();
-std::unordered_map<std::string, activations::Activation*> activations::Activation::_activations = {
-  {"Tanh", &_TANH},
-  {"Hardtanh", &_HARD_TANH},
-  {"Fasttanh", &_FAST_TANH},
-  {"ReLU", &_RELU},
-  {"Sigmoid", &_SIGMOID}};
+std::unordered_map<std::string, activations::Activation*> activations::Activation::_activations =
+  {{"Tanh", &_TANH}, {"Hardtanh", &_HARD_TANH}, {"Fasttanh", &_FAST_TANH}, {"ReLU", &_RELU}, {"Sigmoid", &_SIGMOID}};
 
 activations::Activation* tanh_bak = nullptr;
 
diff --git a/NAM/activations.h b/NAM/activations.h
index 1ee1c2e..6c10fe2 100644
--- a/NAM/activations.h
+++ b/NAM/activations.h
@@ -9,27 +9,27 @@ namespace activations
 {
 inline float relu(float x)
 {
-return x > 0.0f ? x : 0.0f;
+  return x > 0.0f ? x : 0.0f;
 };
 
 inline float sigmoid(float x)
 {
-return 1.0f / (1.0f + expf(-x));
+  return 1.0f / (1.0f + expf(-x));
 };
 
 inline float hard_tanh(float x)
 {
-const float t = x < -1 ? -1 : x;
-return t > 1 ? 1 : t;
+  const float t = x < -1 ? -1 : x;
+  return t > 1 ? 1 : t;
 }
 
 inline float fast_tanh(const float x)
 {
-const float ax = fabsf(x);
-const float x2 = x * x;
+  const float ax = fabsf(x);
+  const float x2 = x * x;
 
-return (x * (2.45550750702956f + 2.45550750702956f * ax + (0.893229853513558f + 0.821226666969744f * ax) * x2)
-        / (2.44506634652299f + (2.44506634652299f + x2) * fabsf(x + 0.814642734961073f * x * ax)));
+  return (x * (2.45550750702956f + 2.45550750702956f * ax + (0.893229853513558f + 0.821226666969744f * ax) * x2)
+          / (2.44506634652299f + (2.44506634652299f + x2) * fabsf(x + 0.814642734961073f * x * ax)));
 }
 
 class Activation
@@ -37,14 +37,8 @@ class Activation
 public:
   Activation() = default;
   virtual ~Activation() = default;
-  virtual void apply(Eigen::MatrixXf& matrix)
-  {
-    apply(matrix.data(), matrix.rows() * matrix.cols());
-  }
-  virtual void apply(Eigen::Block<Eigen::MatrixXf> block)
-  {
-    apply(block.data(), block.rows() * block.cols());
-  }
+  virtual void apply(Eigen::MatrixXf& matrix) { apply(matrix.data(), matrix.rows() * matrix.cols()); }
+  virtual void apply(Eigen::Block<Eigen::MatrixXf> block) { apply(block.data(), block.rows() * block.cols()); }
   virtual void apply(Eigen::Block<Eigen::MatrixXf, -1, -1, true> block)
   {
     apply(block.data(), block.rows() * block.cols());
@@ -55,68 +49,68 @@ class Activation
   static void enable_fast_tanh();
   static void disable_fast_tanh();
 
-  protected:
-    static std::unordered_map<std::string, Activation*> _activations;
+protected:
+  static std::unordered_map<std::string, Activation*> _activations;
 };
 
 class ActivationTanh : public Activation
 {
-  public:
-    void apply(float *data, long size) override
+public:
+  void apply(float* data, long size) override
+  {
+    for (long pos = 0; pos < size; pos++)
     {
-      for (long pos = 0; pos < size; pos++)
-      {
-        data[pos] = std::tanh(data[pos]);
-      }
+      data[pos] = std::tanh(data[pos]);
     }
+  }
 };
 
 class ActivationHardTanh : public Activation
 {
-  public:
-    void apply(float* data, long size) override
+public:
+  void apply(float* data, long size) override
+  {
+    for (long pos = 0; pos < size; pos++)
     {
-      for (long pos = 0; pos < size; pos++)
-      {
-        data[pos] = hard_tanh(data[pos]);
-      }
+      data[pos] = hard_tanh(data[pos]);
     }
+  }
 };
 
 class ActivationFastTanh : public Activation
 {
-  public:
-    void apply(float* data, long size) override
+public:
+  void apply(float* data, long size) override
+  {
+    for (long pos = 0; pos < size; pos++)
     {
-      for (long pos = 0; pos < size; pos++)
-      {
-        data[pos] = fast_tanh(data[pos]);
-      }
+      data[pos] = fast_tanh(data[pos]);
     }
+  }
 };
 
 class ActivationReLU : public Activation
 {
-  public:
-    void apply(float* data, long size) override
+public:
+  void apply(float* data, long size) override
+  {
+    for (long pos = 0; pos < size; pos++)
     {
-      for (long pos = 0; pos < size; pos++)
-      {
-        data[pos] = relu(data[pos]);
-      }
+      data[pos] = relu(data[pos]);
     }
+  }
 };
 
 class ActivationSigmoid : public Activation
 {
-  public:
-    void apply(float* data, long size) override
+public:
+  void apply(float* data, long size) override
+  {
+    for (long pos = 0; pos < size; pos++)
    {
-      for (long pos = 0; pos < size; pos++)
-      {
-        data[pos] = sigmoid(data[pos]);
-      }
+      data[pos] = sigmoid(data[pos]);
    }
+  }
 };
 }; // namespace activations
\ No newline at end of file
diff --git a/NAM/convnet.h b/NAM/convnet.h
index b77e581..49a3e6f 100644
--- a/NAM/convnet.h
+++ b/NAM/convnet.h
@@ -47,7 +47,7 @@ class ConvNetBlock
 private:
   BatchNorm batchnorm;
   bool _batchnorm;
-  activations::Activation *activation;
+  activations::Activation* activation;
 };
 
 class _Head
diff --git a/NAM/dsp.cpp b/NAM/dsp.cpp
index 3e97665..6d10231 100644
--- a/NAM/dsp.cpp
+++ b/NAM/dsp.cpp
@@ -292,4 +292,3 @@ Eigen::MatrixXf Conv1x1::process(const Eigen::MatrixXf& input) const
   else
     return this->_weight * input;
 }
-
diff --git a/NAM/get_dsp.cpp b/NAM/get_dsp.cpp
index 10777e5..f115ecd 100644
--- a/NAM/get_dsp.cpp
+++ b/NAM/get_dsp.cpp
@@ -9,37 +9,43 @@
 #include "convnet.h"
 #include "wavenet.h"
 
-struct Version {
+struct Version
+{
   int major;
   int minor;
   int patch;
 };
 
-Version ParseVersion(const std::string& versionStr) {
+Version ParseVersion(const std::string& versionStr)
+{
   Version version;
-  
+
   // Split the version string into major, minor, and patch components
   std::stringstream ss(versionStr);
   std::string majorStr, minorStr, patchStr;
   std::getline(ss, majorStr, '.');
   std::getline(ss, minorStr, '.');
   std::getline(ss, patchStr);
-  
+
   // Parse the components as integers and assign them to the version struct
-  try {
+  try
+  {
     version.major = std::stoi(majorStr);
     version.minor = std::stoi(minorStr);
     version.patch = std::stoi(patchStr);
   }
-  catch (const std::invalid_argument&) {
+  catch (const std::invalid_argument&)
+  {
     throw std::invalid_argument("Invalid version string: " + versionStr);
   }
-  catch (const std::out_of_range&) {
+  catch (const std::out_of_range&)
+  {
     throw std::out_of_range("Version string out of range: " + versionStr);
   }
-  
+
   // Validate the semver components
-  if (version.major < 0 || version.minor < 0 || version.patch < 0) {
+  if (version.major < 0 || version.minor < 0 || version.patch < 0)
+  {
     throw std::invalid_argument("Negative version component: " + versionStr);
   }
   return version;
diff --git a/NAM/wavenet.cpp b/NAM/wavenet.cpp
index b844e1d..cfeff30 100644
--- a/NAM/wavenet.cpp
+++ b/NAM/wavenet.cpp
@@ -29,7 +29,7 @@ void wavenet::_Layer::process_(const Eigen::MatrixXf& input, const Eigen::Matrix
   this->_conv.process_(input, this->_z, i_start, ncols, 0);
   // Mix-in condition
   this->_z += this->_input_mixin.process(condition);
-  
+
   this->_activation->apply(this->_z);
 
   if (this->_gated)
diff --git a/NAM/wavenet.h b/NAM/wavenet.h
index ee28e8f..75f8548 100644
--- a/NAM/wavenet.h
+++ b/NAM/wavenet.h
@@ -48,7 +48,7 @@ class _Layer
   // The internal state
   Eigen::MatrixXf _z;
 
-  activations::Activation *_activation;
+  activations::Activation* _activation;
   const bool _gated;
 };
 
@@ -152,7 +152,7 @@ class _Head
   int _channels;
   std::vector<Conv1x1> _layers;
   Conv1x1 _head;
-  activations::Activation *_activation;
+  activations::Activation* _activation;
 
   // Stores the outputs of the convs *except* the last one, which goes in
   // The array `outputs` provided to .process_()
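
For reviewers who want to confirm that the `fast_tanh` change above is purely cosmetic, here is a minimal standalone sketch (a hypothetical helper, not part of this PR) that copies the rational approximation verbatim from `NAM/activations.h` and compares it against `std::tanh`:

```cpp
// fast_tanh_check.cpp -- hypothetical sanity check, not part of this PR.
// Copies the fast_tanh rational approximation verbatim from NAM/activations.h
// and reports its maximum deviation from std::tanh over a typical audio range.
#include <cmath>
#include <cstdio>

static float fast_tanh(const float x)
{
  const float ax = fabsf(x);
  const float x2 = x * x;

  return (x * (2.45550750702956f + 2.45550750702956f * ax + (0.893229853513558f + 0.821226666969744f * ax) * x2)
          / (2.44506634652299f + (2.44506634652299f + x2) * fabsf(x + 0.814642734961073f * x * ax)));
}

int main()
{
  float max_err = 0.0f;
  // Sweep the range where NAM's activations typically operate.
  for (float x = -5.0f; x <= 5.0f; x += 1e-3f)
  {
    const float err = fabsf(fast_tanh(x) - std::tanh(x));
    if (err > max_err)
      max_err = err;
  }
  std::printf("max |fast_tanh(x) - tanh(x)| on [-5, 5]: %g\n", max_err);
  return 0;
}
```

Since the diff only re-indents the expression, the reported error should be bit-identical before and after this change.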