From b509c3349c3fe68e6e6a2fc27b1bf8eaed3156a1 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Mon, 5 Aug 2024 19:11:07 -0700 Subject: [PATCH] Removed non-persistent buffers from state_dict [skip ci] --- lib/torch/nn/module.rb | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lib/torch/nn/module.rb b/lib/torch/nn/module.rb index 96375a1..f9a76f9 100644 --- a/lib/torch/nn/module.rb +++ b/lib/torch/nn/module.rb @@ -407,7 +407,10 @@ def save_to_state_dict(destination, prefix: "") destination[prefix + k] = v end named_buffers.each do |k, v| - destination[prefix + k] = v + # TODO exclude v.nil? + if !@non_persistent_buffers_set.include?(k) + destination[prefix + k] = v + end end end