Skip to content

Commit

Permalink
Rename KerasBackend -> TFKerasBackend to avoid confusion with keras 3.
Browse files Browse the repository at this point in the history
  • Loading branch information
arogozhnikov committed Apr 28, 2024
1 parent f33113e commit b67cfea
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion einops/_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ def einsum(self, pattern, *x):
return self.tf.einsum(pattern, *x)


class KerasBackend(AbstractBackend):
class TFKerasBackend(AbstractBackend):
framework_name = "tensorflow.keras"

def __init__(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def collect_test_backends(symbolic=False, layers=False) -> List[_backends.Abstra
backend_types = []
else:
backend_types = [
_backends.KerasBackend,
_backends.TFKerasBackend,
]

backend_names_to_test = parse_backends_to_test()
Expand Down
4 changes: 2 additions & 2 deletions tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,13 +351,13 @@ def logsumexp_numpy(x, tuple_of_axes):
y = numpy.sum(y, axis=tuple_of_axes)
return numpy.log(y) + minused

from einops._backends import TorchBackend, ChainerBackend, TensorflowBackend, KerasBackend, NumpyBackend
from einops._backends import TorchBackend, ChainerBackend, TensorflowBackend, TFKerasBackend, NumpyBackend

backend2callback = {
TorchBackend.framework_name: logsumexp_torch,
ChainerBackend.framework_name: logsumexp_chainer,
TensorflowBackend.framework_name: logsumexp_tf,
KerasBackend.framework_name: logsumexp_keras,
TFKerasBackend.framework_name: logsumexp_keras,
NumpyBackend.framework_name: logsumexp_numpy,
}

Expand Down

0 comments on commit b67cfea

Please sign in to comment.