From 43d53d9ecc4d7f10c5dbce177784f943ba2aa5b6 Mon Sep 17 00:00:00 2001
From: Mike Oliphant
Date: Wed, 29 Mar 2023 18:29:13 -0700
Subject: [PATCH] Added Hardtanh activation function

---
 dsp/dsp.cpp     | 26 ++++++++++++++++++++++++++
 dsp/dsp.h       |  8 ++++++++
 dsp/wavenet.cpp |  4 +++-
 3 files changed, 37 insertions(+), 1 deletion(-)

diff --git a/dsp/dsp.cpp b/dsp/dsp.cpp
index c83736c..b97052f 100644
--- a/dsp/dsp.cpp
+++ b/dsp/dsp.cpp
@@ -205,6 +205,11 @@ inline float fast_tanh_(const float x) {
       (2.44506634652299f + x2) * fabs(x + 0.814642734961073f * x * ax)));
 }
 
+inline float hard_tanh_(const float x) {
+  const float t = x < -1 ? -1 : x;
+  return t > 1 ? 1 : t;
+}
+
 void tanh_(Eigen::MatrixXf &x, const long i_start, const long i_end,
            const long j_start, const long j_end) {
   for (long j = j_start; j < j_end; j++)
@@ -227,6 +232,27 @@ void tanh_(Eigen::MatrixXf &x) {
   }
 }
 
+void hard_tanh_(Eigen::MatrixXf& x, const long i_start, const long i_end,
+                const long j_start, const long j_end) {
+  for (long j = j_start; j < j_end; j++)
+    for (long i = i_start; i < i_end; i++)
+      x(i, j) = hard_tanh_(x(i, j));
+}
+
+void hard_tanh_(Eigen::MatrixXf& x, const long j_start, const long j_end) {
+  hard_tanh_(x, 0, x.rows(), j_start, j_end);
+}
+
+void hard_tanh_(Eigen::MatrixXf& x) {
+  float* ptr = x.data();
+
+  long size = x.rows() * x.cols();
+
+  for (long pos = 0; pos < size; pos++) {
+    ptr[pos] = hard_tanh_(ptr[pos]);
+  }
+}
+
 void Conv1D::set_params_(std::vector<float>::iterator &params) {
   if (this->_weight.size() > 0) {
     const long out_channels = this->_weight[0].rows();
diff --git a/dsp/dsp.h b/dsp/dsp.h
index 359b572..1e5efeb 100644
--- a/dsp/dsp.h
+++ b/dsp/dsp.h
@@ -153,6 +153,14 @@ void tanh_(Eigen::MatrixXf &x, const long i_start, const long i_end);
 
 void tanh_(Eigen::MatrixXf &x);
 
+// In-place Hardtanh on (N,M) array
+void hard_tanh_(Eigen::MatrixXf& x, const long i_start, const long i_end,
+                const long j_start, const long j_end);
+// Subset of the columns
+void hard_tanh_(Eigen::MatrixXf& x, const long i_start, const long i_end);
+
+void hard_tanh_(Eigen::MatrixXf& x);
+
 class Conv1D {
 public:
   Conv1D() { this->_dilation = 1; };
diff --git a/dsp/wavenet.cpp b/dsp/wavenet.cpp
index f72a7fe..e466b78 100644
--- a/dsp/wavenet.cpp
+++ b/dsp/wavenet.cpp
@@ -30,7 +30,9 @@ void wavenet::_Layer::process_(const Eigen::MatrixXf &input,
   this->_conv.process_(input, this->_z, i_start, ncols, 0);
   // Mix-in condition
   this->_z += this->_input_mixin.process(condition);
-  if (this->_activation == "Tanh")
+  if (this->_activation == "Hardtanh")
+    hard_tanh_(this->_z);
+  else if (this->_activation == "Tanh")
     tanh_(this->_z);
   else if (this->_activation == "ReLU")
     relu_(this->_z, 0, channels, 0, this->_z.cols());