Added Hardtanh activation function #14

Merged
merged 1 commit into sdatkinson:main on Mar 31, 2023

Conversation

@mikeoliphant (Contributor) commented on Mar 30, 2023

This adds support for the "Hardtanh" activation function in WaveNet models. It will have no impact on current models.

My (admittedly limited so far) testing indicates that using a hard tanh activation function (basically a clamp to [-1, 1]) gives the same ESR as a regular tanh, but is much faster to compute.
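For reference, a hard tanh is just a clamp. Here is a minimal element-wise sketch; the name hard_tanh matches the helper discussed in the review below, but the exact signature in the PR may differ:

```cpp
#include <algorithm>

// Hard tanh: identity on [-1, 1], saturated outside - a simple clamp,
// so there is no exponential to evaluate as with std::tanh.
inline float hard_tanh(float x)
{
  return std::clamp(x, -1.0f, 1.0f);
}
```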

It seems to perform about the same as the fast tanh function that is currently disabled, but it might be further optimizable with some SSE magic. It also has the benefit of being available as an activation function in PyTorch.

If we get it implemented here now, then once it has been out in the wild for a bit we can switch to using it in training (maybe by default, maybe as an option, maybe only on lite/feather?).

If you want to test training a model with it, just swap out "Tanh" for "Hardtanh" (lowercase "t") in the WaveNet layer config.

@sdatkinson (Owner) commented:

> My (admittedly limited so far) testing indicates that using a hard tanh activation function (basically a clamp to [-1, 1]) gives the same ESR as a regular tanh, but is much faster to compute.

Very exciting! I'll have to do some looking into this as well 🙂

@sdatkinson merged commit 2e5e1b2 into sdatkinson:main on Mar 31, 2023
Review thread on the diff - excerpt of the element-wise loop applying hard_tanh_:

```cpp
                        const long j_start, const long j_end) {
  for (long j = j_start; j < j_end; j++)
    for (long i = i_start; i < i_end; i++)
      x(i, j) = hard_tanh_(x(i, j));
```


I wonder if it would perform better if you used unaryExpr instead? Like so:

```cpp
x = x.unaryExpr([](float in) { return hard_tanh(in); });
```

Seems to do better according to this comparison: https://godbolt.org/z/zfePz8MTa


Well, you don't want to update the entire matrix, so some x.middleCols() calls are needed.
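For instance, something along these lines (a sketch only, assuming an Eigen::MatrixXf x, the j_start/j_end bounds from the excerpt above, and the element-wise hard_tanh helper suggested earlier):

```cpp
// Apply hard_tanh only to columns [j_start, j_end) of x, leaving the rest
// of the matrix untouched. If the row range also needs restricting,
// a .block() would be needed instead of .middleCols().
x.middleCols(j_start, j_end - j_start) =
    x.middleCols(j_start, j_end - j_start)
        .unaryExpr([](float in) { return hard_tanh(in); });
```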

@mikeoliphant (Contributor, Author) replied:

> I wonder if it would perform better if you used unaryExpr instead? Like so:
> x = x.unaryExpr([](float in) { return hard_tanh(in); });

Agreed - I just copy/pasted a duplicate of the tanh function structure, but I didn't worry about it for now since that code path isn't used by WaveNet - it just uses the full-matrix overload, and I've already optimized those (for both tanh and hardtanh).

My plan is to refactor the activation code in general soon, but I wanted to get hardtanh in there ASAP so we can start thinking about targeting it when training models.
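As a point of reference, the full-matrix case can be written without explicit loops using Eigen's coefficient-wise min/max. This is only a sketch, with an illustrative name (hard_tanh_full), and not necessarily the exact optimization used in the PR:

```cpp
#include <Eigen/Dense>

// Clamp every coefficient of x to [-1, 1] in one vectorized expression.
// Eigen can SIMD-vectorize cwiseMax/cwiseMin, unlike a scalar std::tanh call.
inline void hard_tanh_full(Eigen::MatrixXf& x)
{
  x = x.cwiseMax(-1.0f).cwiseMin(1.0f);
}
```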

@mikeoliphant (Contributor, Author) added:

Btw, I tried using unaryExpr(), and if I recall it performed slightly worse than just rolling over the data directly.

That Compiler Explorer tool you linked to is pretty cool.

@mikeoliphant deleted the hard_tanh branch on April 2, 2023