Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Torch's Conv1d strides and ConvTranspose1d #145

Merged
merged 9 commits into from
Sep 30, 2024

Conversation

fcaspe
Copy link
Contributor

@fcaspe fcaspe commented Sep 24, 2024

Hi Jatin!

This library is really awesome, I have been using it for low latency inference of big convolutional autoencoders, so I have implemented the 1d Transposed Convolution and convolutional strides.

  • ConvTranspose1d is implemented with RTNeural's Conv1D class, but a different loading function has to be called, RTNeural::torch_helpers::loadConvTranspose1D. See torch_convtranspose1d_test.cpp for an example.

  • Conv1d strides are implemented using a .skip() method that performs a single stride step. This just updates the circular buffer of the Conv1D layer with the new input we jump over. For example, if strides=2 is required, then .skip() has to be called every time after a .forward() call is made. See torch_conv1d_stride_test.cpp for an example.

I know these new functionalities are not fully incorporated into the library. For instance, strides are still missing in Conv1DT and the non-streaming versions of Conv1D and Conv2D. Let me know what you think about these additions and I will be happy to improve them so that hopefully they can be integrated into the library!

Best,
Franco

Copy link
Owner

@jatinchowdhury18 jatinchowdhury18 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello! These changes look great.

I think the "ConvTranspose" changes are good to go. I guess it would be cool to add some tests to double-check that it works correctly with the compile-time API as well, but I guess that's not 100% necessary.

For the Conv1D strides, the "skip" implementation looks correct. I would love to take a shot at implementing that in the compile-time implementations of the Conv1D layer as well... I have a rough idea how that should work. I'm also thinking about making a "wrapper" with a counter to keep track of whether forward() or skip() should be called.

Do you mind if I push any changes back to your branch?


TEST(TestTorchConvTranspose1D, modelOutputMatchesPythonImplementationForDoubles)
{
testTorchConvTranspose1DModel<float,4,15,5,3,1,1,3>(
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
testTorchConvTranspose1DModel<float,4,15,5,3,1,1,3>(
testTorchConvTranspose1DModel<double,4,15,5,3,1,1,3>(


TEST(TestTorchConvTranspose1D, streaming_modelOutputMatchesPythonImplementationForDoubles)
{
testStreamingTorchConvTranspose1DModel<float,4,15,5,3,1,1,3>(
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
testStreamingTorchConvTranspose1DModel<float,4,15,5,3,1,1,3>(
testStreamingTorchConvTranspose1DModel<double,4,15,5,3,1,1,3>(

@jatinchowdhury18
Copy link
Owner

Also linking issue #144 for visibility.

@fcaspe
Copy link
Contributor Author

fcaspe commented Sep 25, 2024

Sounds good! The stride counter idea sounds great! Feel free to push changes and I'll also take a look at that!

@codecov-commenter
Copy link

codecov-commenter commented Sep 27, 2024

Codecov Report

Attention: Patch coverage is 95.00000% with 2 lines in your changes missing coverage. Please review.

Project coverage is 94.76%. Comparing base (a8ae9c5) to head (717df48).
Report is 4 commits behind head on main.

Files with missing lines Patch % Lines
RTNeural/torch_helpers.h 89.47% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #145      +/-   ##
==========================================
- Coverage   95.70%   94.76%   -0.95%     
==========================================
  Files          58       40      -18     
  Lines        3892     2578    -1314     
==========================================
- Hits         3725     2443    -1282     
+ Misses        167      135      -32     
Flag Coverage Δ
94.76% <95.00%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@jatinchowdhury18
Copy link
Owner

Alright, I think I've made all the changes that I want to make... Still need to do a pass for cleanup and documentation, but if @fcaspe wants to have another look and make sure I didn't mess things up too bad, that would be much appreciated :).

@fcaspe
Copy link
Contributor Author

fcaspe commented Sep 30, 2024

Just reviewed. The 'StridedConv1d` class is a good idea. The examples look good and cleaner than mine. I also tried on the conv models I am developing and they are working ok with this new version!

@jatinchowdhury18
Copy link
Owner

@fcaspe Awesome! I'm going to go ahead and merge this PR.

@jatinchowdhury18 jatinchowdhury18 merged commit 32b8664 into jatinchowdhury18:main Sep 30, 2024
21 of 22 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants