Commit

Format
sdatkinson committed May 5, 2023
1 parent a2322ce commit 2aafede
Showing 7 changed files with 63 additions and 68 deletions.
NAM/activations.cpp (8 changes: 2 additions & 6 deletions)
@@ -6,12 +6,8 @@ activations::ActivationHardTanh _HARD_TANH = activations::ActivationHardTanh();
activations::ActivationReLU _RELU = activations::ActivationReLU();
activations::ActivationSigmoid _SIGMOID = activations::ActivationSigmoid();

-std::unordered_map<std::string, activations::Activation*> activations::Activation::_activations = {
-{"Tanh", &_TANH},
-{"Hardtanh", &_HARD_TANH},
-{"Fasttanh", &_FAST_TANH},
-{"ReLU", &_RELU},
-{"Sigmoid", &_SIGMOID}};
+std::unordered_map<std::string, activations::Activation*> activations::Activation::_activations =
+{{"Tanh", &_TANH}, {"Hardtanh", &_HARD_TANH}, {"Fasttanh", &_FAST_TANH}, {"ReLU", &_RELU}, {"Sigmoid", &_SIGMOID}};

activations::Activation* tanh_bak = nullptr;

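For context on the hunk above: the activations are stateless singletons registered in a static name-to-pointer map, so a loader can pick one from the string stored in a model file. Below is a minimal, self-contained sketch of that pattern, trimmed to a single activation; the free-standing map and the .at() lookup are illustrative, not necessarily how the library exposes the registry.

// Sketch only: the registry idea above, reduced to one activation.
#include <cmath>
#include <iostream>
#include <string>
#include <unordered_map>

class Activation
{
public:
  virtual ~Activation() = default;
  virtual void apply(float* data, long size) = 0;
};

class ActivationTanh : public Activation
{
public:
  void apply(float* data, long size) override
  {
    for (long pos = 0; pos < size; pos++)
      data[pos] = std::tanh(data[pos]);
  }
};

ActivationTanh _TANH;
std::unordered_map<std::string, Activation*> _activations = {{"Tanh", &_TANH}};

int main()
{
  float buf[3] = {-2.0f, 0.0f, 2.0f};
  Activation* act = _activations.at("Tanh"); // unknown names throw std::out_of_range
  act->apply(buf, 3);
  std::cout << buf[0] << " " << buf[1] << " " << buf[2] << std::endl;
  return 0;
}
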
NAM/activations.h (90 changes: 42 additions & 48 deletions)
@@ -9,42 +9,36 @@ namespace activations
{
inline float relu(float x)
{
-return x > 0.0f ? x : 0.0f;
+return x > 0.0f ? x : 0.0f;
};

inline float sigmoid(float x)
{
-return 1.0f / (1.0f + expf(-x));
+return 1.0f / (1.0f + expf(-x));
};

inline float hard_tanh(float x)
{
-const float t = x < -1 ? -1 : x;
-return t > 1 ? 1 : t;
+const float t = x < -1 ? -1 : x;
+return t > 1 ? 1 : t;
}

inline float fast_tanh(const float x)
{
-const float ax = fabsf(x);
-const float x2 = x * x;
+const float ax = fabsf(x);
+const float x2 = x * x;

-return (x * (2.45550750702956f + 2.45550750702956f * ax + (0.893229853513558f + 0.821226666969744f * ax) * x2)
-/ (2.44506634652299f + (2.44506634652299f + x2) * fabsf(x + 0.814642734961073f * x * ax)));
+return (x * (2.45550750702956f + 2.45550750702956f * ax + (0.893229853513558f + 0.821226666969744f * ax) * x2)
+/ (2.44506634652299f + (2.44506634652299f + x2) * fabsf(x + 0.814642734961073f * x * ax)));
}
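
The function above is a rational approximation of tanh. Below is a standalone sketch, not part of the library, that measures how closely it tracks the standard tanh over a typical signal range; the [-4, 4] sweep and 1/128 step are arbitrary choices for illustration.

// Sketch only: compare the rational approximation against std::tanh.
#include <cmath>
#include <cstdio>

static float fast_tanh(const float x)
{
  const float ax = fabsf(x);
  const float x2 = x * x;
  return (x * (2.45550750702956f + 2.45550750702956f * ax + (0.893229853513558f + 0.821226666969744f * ax) * x2)
          / (2.44506634652299f + (2.44506634652299f + x2) * fabsf(x + 0.814642734961073f * x * ax)));
}

int main()
{
  float worst = 0.0f;
  for (float x = -4.0f; x <= 4.0f; x += 1.0f / 128.0f)
  {
    const float err = std::fabs(fast_tanh(x) - std::tanh(x));
    if (err > worst)
      worst = err;
  }
  std::printf("max abs error on [-4, 4]: %g\n", worst);
  return 0;
}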

class Activation
{
public:
Activation() = default;
virtual ~Activation() = default;
-virtual void apply(Eigen::MatrixXf& matrix)
-{
-apply(matrix.data(), matrix.rows() * matrix.cols());
-}
-virtual void apply(Eigen::Block<Eigen::MatrixXf> block)
-{
-apply(block.data(), block.rows() * block.cols());
-}
+virtual void apply(Eigen::MatrixXf& matrix) { apply(matrix.data(), matrix.rows() * matrix.cols()); }
+virtual void apply(Eigen::Block<Eigen::MatrixXf> block) { apply(block.data(), block.rows() * block.cols()); }
virtual void apply(Eigen::Block<Eigen::MatrixXf, -1, -1, true> block)
{
apply(block.data(), block.rows() * block.cols());
@@ -55,68 +49,68 @@ class Activation
static void enable_fast_tanh();
static void disable_fast_tanh();

-protected:
-static std::unordered_map<std::string, Activation *> _activations;
+protected:
+static std::unordered_map<std::string, Activation*> _activations;
};

class ActivationTanh : public Activation
{
-public:
-void apply(float *data, long size) override
+public:
+void apply(float* data, long size) override
{
-for (long pos = 0; pos < size; pos++)
-{
+for (long pos = 0; pos < size; pos++)
+{
-data[pos] = std::tanh(data[pos]);
-}
+data[pos] = std::tanh(data[pos]);
+}
}
};

class ActivationHardTanh : public Activation
{
-public:
-void apply(float* data, long size) override
+public:
+void apply(float* data, long size) override
{
-for (long pos = 0; pos < size; pos++)
-{
+for (long pos = 0; pos < size; pos++)
+{
-data[pos] = hard_tanh(data[pos]);
-}
+data[pos] = hard_tanh(data[pos]);
+}
}
};

class ActivationFastTanh : public Activation
{
-public:
-void apply(float* data, long size) override
+public:
+void apply(float* data, long size) override
{
-for (long pos = 0; pos < size; pos++)
-{
+for (long pos = 0; pos < size; pos++)
+{
-data[pos] = fast_tanh(data[pos]);
-}
+data[pos] = fast_tanh(data[pos]);
+}
}
};

class ActivationReLU : public Activation
{
-public:
-void apply(float* data, long size) override
+public:
+void apply(float* data, long size) override
{
-for (long pos = 0; pos < size; pos++)
-{
+for (long pos = 0; pos < size; pos++)
+{
-data[pos] = relu(data[pos]);
-}
+data[pos] = relu(data[pos]);
+}
}
};

class ActivationSigmoid : public Activation
{
-public:
-void apply(float* data, long size) override
+public:
+void apply(float* data, long size) override
{
-for (long pos = 0; pos < size; pos++)
-{
+for (long pos = 0; pos < size; pos++)
+{
-data[pos] = sigmoid(data[pos]);
-}
+data[pos] = sigmoid(data[pos]);
+}
}
};

}; // namespace activations
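
As a usage sketch for the header above, assuming Eigen and NAM/activations.h are on the include path: the Eigen-facing apply() overloads simply forward matrix.data() and the element count to the float*-based override, so a whole matrix is processed as one flat buffer.

#include <iostream>
#include <Eigen/Dense>
#include "NAM/activations.h"

int main()
{
  Eigen::MatrixXf m(2, 3);
  m << -2.0f, -1.0f, 0.0f,
       1.0f, 2.0f, 3.0f;

  activations::ActivationReLU relu;
  relu.apply(m); // dispatches to apply(m.data(), m.rows() * m.cols())

  std::cout << m << std::endl; // negative entries clamped to zero
  return 0;
}
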
NAM/convnet.h (2 changes: 1 addition & 1 deletion)
@@ -47,7 +47,7 @@ class ConvNetBlock
private:
BatchNorm batchnorm;
bool _batchnorm;
-activations::Activation *activation;
+activations::Activation* activation;
};

class _Head
NAM/dsp.cpp (1 change: 0 additions & 1 deletion)
@@ -292,4 +292,3 @@ Eigen::MatrixXf Conv1x1::process(const Eigen::MatrixXf& input) const
else
return this->_weight * input;
}

NAM/get_dsp.cpp (24 changes: 15 additions & 9 deletions)
@@ -9,37 +9,43 @@
#include "convnet.h"
#include "wavenet.h"

-struct Version {
+struct Version
+{
int major;
int minor;
int patch;
};

-Version ParseVersion(const std::string& versionStr) {
+Version ParseVersion(const std::string& versionStr)
+{
Version version;

// Split the version string into major, minor, and patch components
std::stringstream ss(versionStr);
std::string majorStr, minorStr, patchStr;
std::getline(ss, majorStr, '.');
std::getline(ss, minorStr, '.');
std::getline(ss, patchStr);

// Parse the components as integers and assign them to the version struct
-try {
+try
+{
version.major = std::stoi(majorStr);
version.minor = std::stoi(minorStr);
version.patch = std::stoi(patchStr);
}
-catch (const std::invalid_argument&) {
+catch (const std::invalid_argument&)
+{
throw std::invalid_argument("Invalid version string: " + versionStr);
}
-catch (const std::out_of_range&) {
+catch (const std::out_of_range&)
+{
throw std::out_of_range("Version string out of range: " + versionStr);
}

// Validate the semver components
-if (version.major < 0 || version.minor < 0 || version.patch < 0) {
+if (version.major < 0 || version.minor < 0 || version.patch < 0)
+{
throw std::invalid_argument("Negative version component: " + versionStr);
}
return version;
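
ParseVersion above splits the string on '.' with std::getline and converts each component with std::stoi. Below is a condensed, standalone sketch of the same logic; it folds the two catch clauses into one std::exception handler, whereas the code above distinguishes std::invalid_argument from std::out_of_range.

#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>

struct Version
{
  int major;
  int minor;
  int patch;
};

Version ParseVersion(const std::string& versionStr)
{
  std::stringstream ss(versionStr);
  std::string majorStr, minorStr, patchStr;
  std::getline(ss, majorStr, '.');
  std::getline(ss, minorStr, '.');
  std::getline(ss, patchStr);

  Version version;
  try
  {
    version.major = std::stoi(majorStr);
    version.minor = std::stoi(minorStr);
    version.patch = std::stoi(patchStr);
  }
  catch (const std::exception&)
  {
    throw std::invalid_argument("Invalid version string: " + versionStr);
  }
  if (version.major < 0 || version.minor < 0 || version.patch < 0)
    throw std::invalid_argument("Negative version component: " + versionStr);
  return version;
}

int main()
{
  const Version v = ParseVersion("0.7.3"); // example input
  std::cout << v.major << "." << v.minor << "." << v.patch << std::endl; // prints 0.7.3
  return 0;
}
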
NAM/wavenet.cpp (2 changes: 1 addition & 1 deletion)
@@ -29,7 +29,7 @@ void wavenet::_Layer::process_(const Eigen::MatrixXf& input, const Eigen::Matrix
this->_conv.process_(input, this->_z, i_start, ncols, 0);
// Mix-in condition
this->_z += this->_input_mixin.process(condition);

this->_activation->apply(this->_z);

if (this->_gated)
NAM/wavenet.h (4 changes: 2 additions & 2 deletions)
@@ -48,7 +48,7 @@ class _Layer
// The internal state
Eigen::MatrixXf _z;

-activations::Activation *_activation;
+activations::Activation* _activation;
const bool _gated;
};

@@ -152,7 +152,7 @@ class _Head
int _channels;
std::vector<Conv1x1> _layers;
Conv1x1 _head;
-activations::Activation *_activation;
+activations::Activation* _activation;

// Stores the outputs of the convs *except* the last one, which goes in
// The array `outputs` provided to .process_()
