How to break kernel_fn of serial into multiple kernel_fns of Dense and Activation #179
-
Hi, I've written code that creates an MLP using `serial`. Can I do the same thing using the `kernel_fn`s of the individual layers, i.e. reproduce the output I get with the `kernel_fn` of the `serial` object? I tried, but I get an error.
-
Thanks for the detailed repro! Here's a fix for the last cell in your colab:

### Test Cell

```python
...
kernel_0_output = layers_functions["kernel_fns"][0](x1.reshape(1, 16), None)
print(kernel_0_output)
kernel_1_output = layers_functions["kernel_fns"][1](kernel_0_output)
print(kernel_1_output.nngp)
```

[Also, on a minor note, it's slightly more efficient to pass …]
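For intuition, here's a minimal, self-contained NumPy sketch of why chaining per-layer kernel maps reproduces the serial result, provided the whole kernel object (the covariance *plus* the diagonal variances that nonlinearities need) is passed between layers. All names here (`Kernel`, `dense_kernel_fn`, `relu_kernel_fn`, `input_kernel`) are hypothetical simplifications for illustration, not the real neural_tangents API:

```python
import numpy as np
from dataclasses import dataclass

# Simplified stand-in for neural_tangents' Kernel dataclass: it carries the
# NNGP covariance plus the diagonal entries that nonlinearity layers need.
@dataclass
class Kernel:
    nngp: np.ndarray   # covariance K(x1, x2)
    var1: np.ndarray   # diagonal K(x1, x1)
    var2: np.ndarray   # diagonal K(x2, x2)

def input_kernel(x1, x2):
    # Initial kernel built from raw inputs, as kernel_fn(x1, x2) would do.
    d = x1.shape[-1]
    return Kernel(nngp=x1 @ x2.T / d,
                  var1=np.sum(x1 * x1, -1) / d,
                  var2=np.sum(x2 * x2, -1) / d)

def dense_kernel_fn(k, w_std=1.0):
    # Infinite-width Dense layer rescales every covariance entry.
    s = w_std ** 2
    return Kernel(nngp=s * k.nngp, var1=s * k.var1, var2=s * k.var2)

def relu_kernel_fn(k):
    # Arc-cosine kernel of order 1 (infinite-width ReLU). Note that it needs
    # var1/var2, not just nngp -- this is the metadata a bare array lacks.
    norms = np.sqrt(np.outer(k.var1, k.var2))
    cos_t = np.clip(k.nngp / norms, -1.0, 1.0)
    theta = np.arccos(cos_t)
    nngp = norms * (np.sin(theta) + (np.pi - theta) * cos_t) / (2 * np.pi)
    # ReLU halves the diagonal variances (theta = 0 on the diagonal).
    return Kernel(nngp=nngp, var1=k.var1 / 2, var2=k.var2 / 2)

x1 = np.random.RandomState(0).randn(3, 16)
x2 = np.random.RandomState(1).randn(4, 16)

# "Serial": all layer maps composed in one shot.
k_serial = relu_kernel_fn(dense_kernel_fn(input_kernel(x1, x2)))

# Layer by layer: pass the *whole* Kernel object between layers.
k = input_kernel(x1, x2)
k = dense_kernel_fn(k)
k = relu_kernel_fn(k)

assert np.allclose(k.nngp, k_serial.nngp)  # identical results
```

The key design point this mirrors: the ReLU map consumes `var1`/`var2` as well as `nngp`, so forwarding only the `nngp` array between layers cannot work even in principle.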
The change lies in passing the whole `Kernel` dataclass as the input to the `kernel_fn`, as opposed to `Kernel.nngp`, which is just an array. When it's only an array, it misses the necessary metadata and is actually interpreted as if you passed an input (akin to `x1`), and this raises an error when it's passed to the ReLU function, the infinite-width limit of which requires it to be preceded by a Gaussian linear …
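To make that failure mode concrete, here's a hypothetical sketch of the argument dispatch a `kernel_fn` performs: anything that is not a `Kernel` object gets interpreted as raw inputs, so passing the bare `.nngp` array silently restarts the computation from "data" instead of continuing from the layer-0 kernel. (The real neural_tangents internals differ; this only illustrates the dispatch idea.)

```python
import numpy as np

class Kernel:
    def __init__(self, nngp):
        self.nngp = nngp

def kernel_fn(x_or_kernel, x2=None):
    if isinstance(x_or_kernel, Kernel):
        # Continue from an existing kernel: what we want between layers.
        return Kernel(x_or_kernel.nngp * 2.0)  # stand-in layer map
    # Otherwise the argument is interpreted as a batch of raw inputs,
    # akin to x1, and the kernel is rebuilt from scratch.
    x1 = np.atleast_2d(x_or_kernel)
    x2 = x1 if x2 is None else x2
    return Kernel(x1 @ x2.T / x1.shape[-1])

k0 = kernel_fn(np.ones((1, 16)))   # from data: k0.nngp == [[1.]]
good = kernel_fn(k0)               # Kernel in -> layer map applied
bad = kernel_fn(k0.nngp)           # bare array in -> treated as new data!

print(good.nngp, bad.nngp)  # [[2.]] [[1.]] -- the array path gave a different kernel
```

In the real library the array path doesn't merely give a different answer: once that "input" reaches a ReLU layer, the missing preceding Gaussian/linear structure triggers the error you saw.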