Spectral Normalization failure in Torch backend. #19527
The issue appears to be the use of the
It is unclear to me what the desired behavior is in this case. Please determine when you want to update the wrapped kernel, and to what value, and then you can implement it outside of the loop.
Thanks for responding, @fchollet. The desired use case is a cGAN -- an RNN/CNN hybrid used for nowcasting of meteorological fields -- to replace the one that was hooked into the ConvRNN API. It's based on the pseudo-code for Deep Generative Modelling of Radar (DGMR) from DeepMind: https://github.com/google-deepmind/deepmind-research/tree/master/nowcasting

Looking around the internet, DGMR has been implemented in various ways, with the crux of the method using a ConvGRU layer. My ConvGRU2D is heavily inspired by those implementations, although the Cambier van Nooten implementation doesn't do SpecNorm at all and looks like what I think you have in mind, as it matches the implementation of ConvLSTM stylistically. It also uses the old ConvRNN API!

DeepMind doesn't provide the explicit source code, but some of the implementations indicate that SpecNorm is applied to all three convolutions, every time, within the GRU. While the Nature article is light on the theoretical justification of why all three are SpecNorm'd every time, my reductive understanding is, "Make GAN stable now," as Miyato et al. (2018) intended. The OCF PyTorch implementation uses PyTorch's native SpecNorm wrapper, for example. In the original paper (Ravuri et al. 2021; linked above), they state (p. 688, in the methods section):
When and where SpecNorm is applied within the GRU appears to be open to a certain amount of interpretation, and I don't have a good answer for you beyond, "that's what I think they did, and other people think that, too." The residual blocks they use (D, G, and L blocks) plus the regular convolutions are all confirmed to use SpecNorm, so I would extrapolate that understanding to include the GRU convolutions as well.

I suppose my question for you is: why would I want to reserve the spectral normalization until after the recurrent portion is finished? I don't think I understand this conceptually. I'm not even sure how I would implement it. [35] references Zhang et al. (2019), "Self-attention generative adversarial networks" (ICML, PMLR 97:7354-7363), for what it's worth.
The behavior is basically undefined in this case. You need to answer: what value of the kernel do you intend to use at each step, exactly? The interpretation that makes the most sense, from a theoretical perspective, is to perform normalization on the conv kernels before applying the recurrent loop. Applying it inside the loop and using a different kernel value at every step of the loop makes no sense (by the last step you'd get a degenerate kernel value), and even if you get it to work I don't expect it to train well (the normalization op uses a trainable weight, the gradients of which are going to end up very nearly random).

Now, if you ask me, the whole idea of using SpectralNormalization here may not be sound. You should only add complexity to a DL system if you are facing a known problem for which this is demonstrably the right solution -- in this case, if you have training stability issues, and if you have run a with/without experiment to show that SN solves the problem. You should not do it "because someone else does this in a paper", since 90% of the time that prior implementation was either buggy or wasn't the subject of an ablation study demonstrating the importance of the trick.

If you have training stability issues, there are N different things I'd try before SpectralNormalization: you could clip gradients, you could use layer normalization, you could normalize gradients, etc.
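One of the alternatives mentioned above, gradient clipping by global norm, is simple enough to sketch framework-independently. This is an illustrative helper (`clip_by_global_norm` here is a made-up name, not a library function):

```python
import numpy as np

def clip_by_global_norm(grads, clip_norm):
    """Rescale a list of gradient arrays so that their combined L2 norm
    does not exceed clip_norm; gradients already below the threshold
    are returned unchanged."""
    global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    scale = min(1.0, clip_norm / (global_norm + 1e-12))
    return [g * scale for g in grads]
```

In practice no custom code is needed: Keras optimizers accept `clipnorm`, `clipvalue`, and `global_clipnorm` arguments that do the same thing.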
You raise some interesting points. DGMR has been shown to be very diffusive, so perhaps that empirical evidence echoes what you're saying about the degenerate kernel. On the other hand, Cambier van Nooten's paper addressed that diffusive nature specifically by adding in a second input field beyond precip and tweaking the loss function. I wasn't in a position where I could try a bunch of different things because my schedule was (and still is) very tight. I don't really have the luxury of trying things incrementally.

Having done a lot of reading on the subject of GAN stabilization (https://arxiv.org/abs/2009.02773 and its associated web note https://blog.ml.cmu.edu/2022/01/21/why-spectral-normalization-stabilizes-gans-analysis-and-improvements/), there are a lot of very persuasive arguments as to why it's the go-to stabilizer, as well as other studies that followed Miyato et al. I can't go through the whole lineage of DGMR, which probably can be traced back to Goodfellow's original GAN paper, but it does come from DVD-GAN and BigGAN, which also used Spectral Normalization... although as I look at a particular implementation here, the GRU doesn't get SpecNorm'd. The residual blocks, however, do. It's difficult to argue with its lineage here.

I can see not implementing SpecNorm on the GRU based on what you've said (thank you for that discussion), but it's hard to justify not using it basically everywhere else in an adversarial setting since, at least according to the literature, it handles a lot of the gradient issues in both directions (exploding, vanishing).
If you want to use it in this context, then my recommendation is to normalize the conv kernel before applying the recurrent loop (so it is only done once and you use the same kernel value at every step of the loop, and the trainable weight for the normalization is updated via a gradient that reflects its contribution across all steps). The way I'd do it is probably to implement a
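A minimal sketch of that recommendation, using a plain matrix in the place of a flattened conv kernel and a toy recurrence in place of the GRU step (all names here are illustrative, not Keras API):

```python
import numpy as np

def spectral_normalize(kernel, u, n_iters=1):
    """Estimate the largest singular value of `kernel` by power
    iteration and return (kernel / sigma, updated u estimate).
    Requires n_iters >= 1."""
    w = kernel.reshape(kernel.shape[0], -1)
    v = None
    for _ in range(n_iters):
        v = w.T @ u
        v /= np.linalg.norm(v) + 1e-12
        u = w @ v
        u /= np.linalg.norm(u) + 1e-12
    sigma = u @ w @ v
    return kernel / sigma, u

def recurrent_forward(xs, kernel, u):
    """Normalize ONCE before the loop, then reuse the same normalized
    kernel at every time step (no re-normalization inside the loop)."""
    kernel_bar, u = spectral_normalize(kernel, u, n_iters=5)
    h = np.zeros(kernel.shape[1])
    for x in xs:
        h = np.tanh(x @ kernel_bar + h)  # same kernel_bar each step
    return h, u
```

The key point is that `spectral_normalize` is called outside the time loop, so every step sees the same kernel value and the stored kernel receives one coherent gradient across all steps.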
I don't think I possess the coding chops to be able to do that. This does raise a follow-on question I had. You said that since SpecNorm has its own trainable parameters, they'll constantly get updated during the time loop. Does that mean my current implementation -- and the other implementations I've seen -- that initializes the three separate Conv2D layers is going through this process of updating the weights each pass through the time steps?
The kernel gets updated at each iteration. Let's call the successive kernel values k_0, k_1, ..., k_T. The kernel is trainable, so we need to compute its gradients, and the gradients here get really messy: the gradient of the loss with respect to k_0 has to flow through every subsequent in-loop reassignment.
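Reconstructing the argument in symbols (my paraphrase, not the comment's original notation): each pass through the loop reassigns

```latex
k_{t+1} = \frac{k_t}{\hat{\sigma}(k_t)}, \qquad t = 0, 1, \dots, T-1,
```

so the gradient of the loss with respect to the stored kernel has to traverse the whole chain of reassignments,

```latex
\frac{\partial L}{\partial k_0}
  = \frac{\partial L}{\partial k_T}
    \prod_{t=0}^{T-1} \frac{\partial k_{t+1}}{\partial k_t},
```

where \(\hat{\sigma}\) is the power-iteration estimate of the spectral norm. Each factor in the product depends on the stateful estimate vector that the wrapper updates by assignment, which is consistent with the earlier point that the resulting gradients end up close to noise.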
I think I follow. You've sold me on leaving the convolutions within the GRU alone. But I guess I'm still a little unclear as to the best practice. If I'm reading the ConvLSTM implementation correctly (no guarantee here!), then you have one kernel and one bias that you split multiple ways before doing the low-level convolution directly. In my implementation, those kernels stay separate in the three Conv2D layers. I apologize if this is too far in the weeds (just trying to understand), but is there a functional difference as to who/what gets updated and when? Are the two methods equivalent, accounting for the fact that mine is more simplistic than the Keras implementation?
Haven't looked at the details, but I assume they're equivalent? Using a conv in a LSTM is perfectly fine as long as you're not trying to call
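The equivalence of the two layouts (three separate kernels versus one fused kernel split three ways, ConvLSTM-style) is easy to see with dense matrices standing in for conv kernels; this is purely an illustration, not code from either implementation:

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(size=(5, 8))            # a batch of 5 examples, 8 features

# Three separate kernels, as in the separate-Conv2D implementation...
kz, kr, kh = (rng.normal(size=(8, 4)) for _ in range(3))

# ...versus one fused kernel that is split three ways before use,
# the way ConvLSTM stores a single kernel and bias.
k_fused = np.concatenate([kz, kr, kh], axis=1)
z, r, h = np.split(x @ k_fused, 3, axis=1)
```

Each split output equals the corresponding separate-kernel product, and since gradients flow through the concatenation slice-by-slice, each sub-kernel is updated exactly as it would be if stored separately.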
Now I understand. The SpecNorm was calling the assign, under the hood, every time step. That's the no-no. |
Have I written custom code?
Yes
OS platform
Fedora 39
TensorFlow, Torch, Keras versions
TensorFlow: 2.16.1
Torch: 2.2.2
Keras: 3.2.1
Python version
3.11.6
Motivated by the fact that the ConvRNN API is now gone (#19360), I took it upon myself to build my own ConvGRU2D layer modeled after the Keras ConvLSTM. See gist here.
That is the full ConvGRU2DCell layer plus the ConvGRU2D that subclasses RNN. One of the options is spectral normalization. In my steps to move away from Tensorflow, I've replaced the Spectral Normalization from TF-Addons with the Keras implementation.
Well, the ConvGRU2D demo will run with both Torch and Tensorflow backend (commented out lines at top of gist). However, if one engages the spectral normalization (`use_spectral_normalization = True`), then the behavior is different. Tensorflow will run. Torch gives this error:

I have tried this a variety of ways, including re-ordering the input arrays to match what Torch wants natively: [B, T, C, H, W]. At least, I think that's what it wants. Either way, I couldn't get it to work, as it kept failing with this particular error. Am I doing something wrong, or is this a bug?
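The class of error Torch raises here can be reproduced in isolation: mutating a tensor in place after it has been used in the autograd graph trips Torch's version check at backward time. This is a generic illustration of the failure mode, not the exact traceback from the gist:

```python
import torch

w = torch.ones(3, requires_grad=True)
loss = (w ** 2).sum()      # backward() will need the saved value of w

with torch.no_grad():
    w /= 2.0               # in-place update, like an assign inside the loop

try:
    loss.backward()
except RuntimeError as e:
    # Torch reports that a variable needed for gradient computation
    # has been modified by an in-place operation.
    print("autograd rejected the in-place update:", e)
```

A spectral-norm wrapper that reassigns the kernel at every time step hits this same check once the kernel's old value is needed for the backward pass.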