Make Scan convert ScalarTypes to TensorTypes #512
I'm guessing that this is a real bug in At the very least, |
Also, don't forget to make your MWE code runnable; there are a few imports missing.
Here's a complete MWE that illustrates the problem:

```python
import numpy as np

import aesara
import aesara.tensor as at


def add_five_scan(a):
    def step(last_count):
        return last_count + 1.0

    counter, _ = aesara.scan(
        fn=step,
        sequences=None,
        outputs_info=[a],
        n_steps=5,
    )
    return counter[-1]


a = at.scalar('a')
y = add_five_scan(a)
print(y.eval({a: 5}), y.type, y.broadcastable)
# 10.0 TensorType(float64, scalar) ()

from aesara.scalar import float64

a = float64('a')
y = add_five_scan(a)
```

The problem is that the first
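For readers less familiar with `aesara.scan`, the loop in the MWE can be sketched in plain Python. This is only an illustrative reference for what the graph computes, not how Aesara implements `Scan`; `scan_reference` and `add_five_reference` are hypothetical names introduced here, and the sketch ignores symbolic types entirely.

```python
def scan_reference(fn, initial, n_steps):
    """Apply `fn` repeatedly, collecting the output of every step.

    Hypothetical stand-in for a scan with a single recurrent output:
    `initial` plays the role of `outputs_info=[a]`.
    """
    outputs = []
    last = initial
    for _ in range(n_steps):
        last = fn(last)
        outputs.append(last)
    return outputs


def add_five_reference(a):
    def step(last_count):
        return last_count + 1.0

    # `counter` holds all five intermediate values; the last one is a + 5.
    counter = scan_reference(step, a, n_steps=5)
    return counter[-1]


print(add_five_reference(5))
# 10.0
```

The printed value matches the `10.0` reported for the `at.scalar` case in the MWE.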
Is there a reason why `scan` can't / shouldn't be made to work with these scalar types?
No, it definitely should work with |
I investigated this a little further. It is pretty straightforward to allow `Scan` to accept pure scalars by changing this line (Line 596 in 7c4871c) to:

`if not isinstance(actual_arg, TensorVariable):`

However, this is not of much help when it comes to having `Scan`s inside a gradient expression.

aesara/aesara/tensor/elemwise.py, Lines 622 to 653 in 6f68579

I tried a couple of hacks to accommodate `Scan` graphs (such as manually bypassing `ScalarFromTensor`s and `Rebroadcast`s, as well as not trying to convert nodes that are already `Elemwise`), but couldn't find anything that worked. Perhaps we would need a more bare-bones scalar `Scan` that can be safely "Elemwised" for these situations? I have no idea if that makes sense...
Yes, there are two issues in the original example:
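For your example, the latter is easily fixed with something like

```python
from aesara.scalar.basic import ScalarOp


class UnaryOpScan(ScalarOp):
    nin = 1

    def grad(self, inp, grads):
        (x,) = inp
        (gz,) = grads
        # Convert the scalar-typed input to a tensor so `Scan` accepts it,
        # then convert the `Scan` result back to a scalar type.
        x_at = at.tensor_from_scalar(x)
        add_5_res = at.scalar_from_tensor(add_five_scan(x_at))
        return [gz * add_5_res]
```

We could update The real problem appears to have little to do with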
I have been trying to get `Scan` to work within a gradient expression without success. I wouldn't be surprised if I am using `Scan` incorrectly, so let me know :)