
Make Scan convert Scalar Types to TensorTypes #512

Open
ricardoV94 opened this issue Jul 8, 2021 · 7 comments
Assignees: ricardoV94
Labels: bug (Something isn't working), help wanted (Extra attention is needed), important, Scan (Involves the `Scan` `Op`)

Comments

ricardoV94 (Contributor) commented Jul 8, 2021

I have been trying to get Scan to work within a gradient expression without success. I wouldn't be surprised if I am using Scan incorrectly, so let me know :)

import aesara
import aesara.tensor as at
from aesara.scalar import ScalarOp, upgrade_to_float_no_complex
from aesara.tensor.elemwise import Elemwise


def add_five(a):
    return a + 5

def add_five_scan(a):    
    def step(last_count):
        return last_count + 1.0

    counter, _ = aesara.scan(
        fn=step,
        sequences=None,
        outputs_info=[a],
        n_steps=5,
    )

    return counter[-1]

a = at.scalar('a')
y = add_five(a)
print(y.eval({a: 5}), y.type, y.broadcastable)
# 10.0 TensorType(float64, scalar) ()

a = at.scalar('a')
y = add_five_scan(a)
print(y.eval({a: 5}), y.type, y.broadcastable)
# 10.0 TensorType(float64, scalar) ()

class UnaryOp(ScalarOp):
    nin = 1

    def grad(self, inp, grads):
        (x,) = inp
        (gz,) = grads
        return [gz * add_five(x)]
    
unary_op = Elemwise(
    UnaryOp(upgrade_to_float_no_complex, "unary_op"),
    name="Elemwise{unary_op,no_inplace}"
)

class UnaryOpScan(ScalarOp):
    nin = 1

    def grad(self, inp, grads):
        (x,) = inp
        (gz,) = grads
        return [gz * add_five_scan(x)]
    
unary_op_scan = Elemwise(
    UnaryOpScan(upgrade_to_float_no_complex, "unary_op_scan"),
    name="Elemwise{unary_op_scan,no_inplace}"
)

x = at.scalar('x')
out = unary_op(x)
grad = aesara.grad(out, x)
print(grad.eval({x: 3}))
# 8.0

x = at.scalar('x')
out = unary_op_scan(x)
grad = aesara.grad(out, x)  # <-- Raises AttributeError
grad.eval({x: 3})
---------------------------------------------------------------------------

AttributeError                            Traceback (most recent call last)

<ipython-input-16-e2874236308b> in <module>()
      1 x = at.scalar('x')
      2 out = unary_op_scan(x)
----> 3 grad = aesara.grad(out, x)
      4 grad.eval({x: 3})

11 frames

/usr/local/lib/python3.7/dist-packages/aesara/gradient.py in grad(cost, wrt, consider_constant, disconnected_inputs, add_names, known_grads, return_disconnected, null_gradients)
    628             assert g.type.dtype in aesara.tensor.type.float_dtypes
    629 
--> 630     rval = _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name)
    631 
    632     for i in range(len(rval)):

/usr/local/lib/python3.7/dist-packages/aesara/gradient.py in _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name)
   1439         return grad_dict[var]
   1440 
-> 1441     rval = [access_grad_cache(elem) for elem in wrt]
   1442 
   1443     return rval

/usr/local/lib/python3.7/dist-packages/aesara/gradient.py in <listcomp>(.0)
   1439         return grad_dict[var]
   1440 
-> 1441     rval = [access_grad_cache(elem) for elem in wrt]
   1442 
   1443     return rval

/usr/local/lib/python3.7/dist-packages/aesara/gradient.py in access_grad_cache(var)
   1392                     for idx in node_to_idx[node]:
   1393 
-> 1394                         term = access_term_cache(node)[idx]
   1395 
   1396                         if not isinstance(term, Variable):

/usr/local/lib/python3.7/dist-packages/aesara/gradient.py in access_term_cache(node)
   1219                             )
   1220 
-> 1221                 input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)
   1222 
   1223                 if input_grads is None:

/usr/local/lib/python3.7/dist-packages/aesara/tensor/elemwise.py in L_op(self, inputs, outs, ograds)
    549 
    550         # Compute grad with respect to broadcasted input
--> 551         rval = self._bgrad(inputs, outs, ograds)
    552 
    553         # TODO: make sure that zeros are clearly identifiable

/usr/local/lib/python3.7/dist-packages/aesara/tensor/elemwise.py in _bgrad(self, inputs, outputs, ograds)
    608             ).outputs
    609             scalar_igrads = self.scalar_op.L_op(
--> 610                 scalar_inputs, scalar_outputs, scalar_ograds
    611             )
    612             for igrad in scalar_igrads:

/usr/local/lib/python3.7/dist-packages/aesara/scalar/basic.py in L_op(self, inputs, outputs, output_gradients)
   1141 
   1142     def L_op(self, inputs, outputs, output_gradients):
-> 1143         return self.grad(inputs, output_gradients)
   1144 
   1145     def __eq__(self, other):

<ipython-input-14-e66194d4e28c> in grad(self, inp, grads)
     18         (x,) = inp
     19         (gz,) = grads
---> 20         return [gz * add_five_scan(x)]
     21 
     22 unary_op_scan = Elemwise(

<ipython-input-3-34c2b27606c2> in add_five_scan(a)
     10         sequences=None,
     11         outputs_info=[a],
---> 12         n_steps=5,
     13     )
     14 

/usr/local/lib/python3.7/dist-packages/aesara/scan/basic.py in scan(fn, sequences, outputs_info, non_sequences, n_steps, truncate_gradient, go_backwards, mode, name, profile, allow_gc, strict, return_list)
   1058     info["strict"] = strict
   1059 
-> 1060     local_op = Scan(inner_inputs, new_outs, info)
   1061 
   1062     ##

/usr/local/lib/python3.7/dist-packages/aesara/scan/op.py in __init__(self, inputs, outputs, info, typeConstructor)
    177             self.output_types.append(
    178                 typeConstructor(
--> 179                     broadcastable=(False,) + o.type.broadcastable, dtype=o.type.dtype
    180                 )
    181             )

AttributeError: 'Scalar' object has no attribute 'broadcastable'
brandonwillard added the important, question (Further information is requested), and bug (Something isn't working) labels, and removed the question label, on Jul 8, 2021
brandonwillard (Member) commented Jul 8, 2021

I'm guessing this is a real bug in Scan: Scalar Types don't have a broadcastable attribute, and it looks like the Scan code implicitly requires one.

At the very least, Scan/scan should not accept Types it doesn't actually support.

brandonwillard (Member) commented:

Also, don't forget to make your MWE code runnable; there are a few imports missing.

brandonwillard (Member) commented Jul 8, 2021

Here's a complete MWE that illustrates the Scan-only scalar input situation more directly:

import numpy as np

import aesara
import aesara.tensor as at


def add_five_scan(a):
    def step(last_count):
        return last_count + 1.0

    counter, _ = aesara.scan(
        fn=step,
        sequences=None,
        outputs_info=[a],
        n_steps=5,
    )

    return counter[-1]



a = at.scalar('a')
y = add_five_scan(a)
print(y.eval({a: 5}), y.type, y.broadcastable)
# 10.0 TensorType(float64, scalar) ()


from aesara.scalar import float64
a = float64('a')
y = add_five_scan(a)  # <-- raises AttributeError: 'Scalar' object has no attribute 'broadcastable'

The problem is that the first `a` is a TensorType "scalar", while the second is a Scalar Type "scalar"; i.e., these are two different Types.

ricardoV94 (Contributor, Author) commented:

Is there a reason why scan can't / shouldn't be made to work with these scalar types?

brandonwillard (Member) commented:

No, it definitely should work with Scalar Types (e.g. using TensorFromScalar).

brandonwillard changed the title from "Cannot use Scan inside a gradient expression" to "Make Scan convert Scalar Types to TensorTypes" on Jul 19, 2021
brandonwillard added the help wanted (Extra attention is needed) label on Jul 19, 2021
brandonwillard added the Scan (Involves the `Scan` `Op`) label on Sep 12, 2021
ricardoV94 self-assigned this on Oct 22, 2021
ricardoV94 (Contributor, Author) commented Oct 22, 2021

I investigated this a little bit further. It is pretty straightforward to allow Scan to accept pure scalars by changing this line:

if not isinstance(actual_arg, Variable):

to:

if not isinstance(actual_arg, TensorVariable):

However, this does not help much with using Scan inside a gradient expression.
The bigger problem is that the Elemwise gradient logic expects the gradient graph to consist entirely of scalar Ops, and then recursively converts it to a "broadcastable" tensor version; that conversion does not really make sense for Scan graphs:

def transform(r):
    # From a graph of ScalarOps, make a graph of Broadcast ops.
    if isinstance(r.type, (NullType, DisconnectedType)):
        return r
    if r in scalar_inputs:
        return inputs[scalar_inputs.index(r)]
    if r in scalar_outputs:
        return outputs[scalar_outputs.index(r)]
    if r in scalar_ograds:
        return ograds[scalar_ograds.index(r)]
    node = r.owner
    if node is None:
        # the gradient contains a constant, translate it as
        # an equivalent TensorType of size 1 and proper number of
        # dimensions
        res = aesara.tensor.basic.constant(
            np.asarray(r.data), dtype=r.type.dtype
        )
        return DimShuffle((), ["x"] * nd)(res)
    new_r = Elemwise(node.op, {})(*[transform(ipt) for ipt in node.inputs])
    return new_r

ret = []
for scalar_igrad, ipt in zip(scalar_igrads, inputs):
    if scalar_igrad is None:
        # undefined gradient
        ret.append(None)
        continue
    ret.append(transform(scalar_igrad))
return ret

I tried a couple of hacks to accommodate Scan graphs (such as manually bypassing ScalarFromTensor and Rebroadcast nodes, and not trying to convert nodes that are already Elemwise), but couldn't find anything that worked.

Perhaps we would need a more bare-bones scalar Scan, that can be safely "Elemwised" for these situations? I have no idea if that makes sense...

brandonwillard (Member) commented:

Yes, there are two issues in the original example:

  1. Elemwise.[R|L]_op don't work with scalar Ops whose gradients aren't composed exclusively of scalar Ops (this is largely due to the implementation of Elemwise._bgrad), and
  2. Scan doesn't accept or return Scalar typed arguments.

For your example, the latter is easily fixed with something like

class UnaryOpScan(ScalarOp):
    nin = 1

    def grad(self, inp, grads):
        (x,) = inp
        (gz,) = grads
        x_at = at.tensor_from_scalar(x)
        add_5_res = at.scalar_from_tensor(add_five_scan(x_at))
        return [gz * add_5_res]

We could update Scan so that it handles Scalar-typed inputs, but that's a minor convenience.

The real problem appears to have little to do with Scan, so either this issue needs to be updated, or a new one needs to be opened for the Elemwise._bgrad issue.
