Skip to content

Commit

Permalink
Remove finalize_()
Browse files Browse the repository at this point in the history
  • Loading branch information
sdatkinson committed Sep 8, 2024
1 parent 529a1c8 commit cd5a86c
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 17 deletions.
3 changes: 3 additions & 0 deletions NAM/convnet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ void nam::convnet::ConvNet::process(NAM_SAMPLE* input, NAM_SAMPLE* output, const
// Copy to required output array (TODO tighten this up)
for (int s = 0; s < num_frames; s++)
output[s] = this->_head_output(s);

// Prepare for next call:
nam::Buffer::_advance_input_buffer_(num_frames);
}

void nam::convnet::ConvNet::_verify_weights(const int channels, const std::vector<int>& dilations, const bool batchnorm,
Expand Down
9 changes: 4 additions & 5 deletions NAM/dsp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ void nam::DSP::prewarm()
for (long i = 0; i < _prewarm_samples; i++)
{
this->process(sample_ptr, sample_ptr, 1);
this->finalize_(1);
sample = 0;
}
}
Expand All @@ -58,8 +57,6 @@ void nam::DSP::SetLoudness(const double loudness)
mHasLoudness = true;
}

void nam::DSP::finalize_(const int num_frames) {}

// Buffer =====================================================================

nam::Buffer::Buffer(const int receptive_field, const double expected_sample_rate)
Expand Down Expand Up @@ -128,9 +125,8 @@ void nam::Buffer::_reset_input_buffer()
this->_input_buffer_offset = this->_receptive_field;
}

void nam::Buffer::finalize_(const int num_frames)
void nam::Buffer::_advance_input_buffer_(const int num_frames)
{
this->nam::DSP::finalize_(num_frames);
this->_input_buffer_offset += num_frames;
}

Expand Down Expand Up @@ -163,6 +159,9 @@ void nam::Linear::process(NAM_SAMPLE* input, NAM_SAMPLE* output, const int num_f
auto input = Eigen::Map<const Eigen::VectorXf>(&this->_input_buffer[offset], this->_receptive_field);
output[i] = this->_bias + this->_weight.dot(input);
}

// Prepare for next call:
nam::Buffer::_advance_input_buffer_(num_frames);
}

// NN modules =================================================================
Expand Down
6 changes: 1 addition & 5 deletions NAM/dsp.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,6 @@ class DSP
// overridden in subclasses).
// 2. The output level is applied and the result stored to `output`.
virtual void process(NAM_SAMPLE* input, NAM_SAMPLE* output, const int num_frames);
// Anything to take care of before next buffer comes in.
// For example:
// * Move the buffer index forward
virtual void finalize_(const int num_frames);
// Expected sample rate, in Hz.
// TODO throw if it doesn't know.
double GetExpectedSampleRate() const { return mExpectedSampleRate; };
Expand Down Expand Up @@ -86,7 +82,6 @@ class Buffer : public DSP
{
public:
Buffer(const int receptive_field, const double expected_sample_rate = -1.0);
void finalize_(const int num_frames);

protected:
// Input buffer
Expand All @@ -97,6 +92,7 @@ class Buffer : public DSP
std::vector<float> _input_buffer;
std::vector<float> _output_buffer;

void _advance_input_buffer_(const int num_frames);
void _set_receptive_field(const int new_receptive_field, const int input_buffer_size);
void _set_receptive_field(const int new_receptive_field);
void _reset_input_buffer();
Expand Down
9 changes: 3 additions & 6 deletions NAM/wavenet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,12 +272,6 @@ nam::wavenet::WaveNet::WaveNet(const std::vector<nam::wavenet::LayerArrayParams>
_prewarm_samples += this->_layer_arrays[i].get_receptive_field();
}

void nam::wavenet::WaveNet::finalize_(const int num_frames)
{
this->DSP::finalize_(num_frames);
this->_advance_buffers_(num_frames);
}

void nam::wavenet::WaveNet::set_weights_(std::vector<float>& weights)
{
std::vector<float>::iterator it = weights.begin();
Expand Down Expand Up @@ -347,6 +341,9 @@ void nam::wavenet::WaveNet::process(NAM_SAMPLE* input, NAM_SAMPLE* output, const
float out = this->_head_scale * this->_head_arrays[final_head_array](0, s);
output[s] = out;
}

// Finalize to rpepare for the next call:
this->_advance_buffers_(num_frames);
}

void nam::wavenet::WaveNet::_set_num_frames_(const long num_frames)
Expand Down
1 change: 0 additions & 1 deletion NAM/wavenet.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ class WaveNet : public DSP
std::vector<float> weights, const double expected_sample_rate = -1.0);
~WaveNet() = default;

void finalize_(const int num_frames) override;
void set_weights_(std::vector<float>& weights);

private:
Expand Down

0 comments on commit cd5a86c

Please sign in to comment.