Skip to content

Commit

Permalink
Codebase standards + update TF layers (#318)
Browse files Browse the repository at this point in the history
* add gitignore rule for .personal files
* resolve all complains from current ruff
* add ruff's W rule, fix ruff's complains
* resolve ruff's deprecation
* apply ruff formatting to whole codebase
* add pre-commit-config
* add ruff checks to github CI
* improve comments in test
* more ruff config
* EinMix: minor improvements in documentation
* update layers to be TF 2.16 - compatible
* increase timeout for cell execution in notebook
* skip tests for previos TF versions
  • Loading branch information
arogozhnikov authored Apr 28, 2024
1 parent 2ab5c2a commit cff63f2
Show file tree
Hide file tree
Showing 34 changed files with 592 additions and 486 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
cache: pip
- name: Check for ruff compliance
run: |
pip install ruff && ruff check . && ruff format . --check
- name: Run tests
run: |
python test.py ${{ matrix.frameworks }}
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ logo
# code of internal checks
_uncommited_explorations
trash_notebooks
*.personal*

# oneflow's output trash
log
Expand Down
12 changes: 12 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
---
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.4.2
hooks:
# Run the linter.
- id: ruff
types_or: [ python, pyi, jupyter ]
# Run the formatter.
- id: ruff-format
types_or: [ python, pyi, jupyter ]
47 changes: 24 additions & 23 deletions docs/source_examples/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,36 +3,37 @@
It will convert pytorch.ipynb to html page docs/pytorch-examples.html
"""

import nbformat
import markdown

from pygments import highlight
from pygments.lexers import PythonLexer
from pygments.formatters import HtmlFormatter

notebook = nbformat.read('Pytorch.ipynb', as_version=nbformat.NO_CONVERT)
notebook = nbformat.read("Pytorch.ipynb", as_version=nbformat.NO_CONVERT)

content = ''
cache = ''
content = ""
cache = ""

for cell in notebook['cells']:
if cell['cell_type'] == 'code':
source = cell['source']
if source.startswith('#left') or source.startswith('#right'):
trimmed_source = source[source.index('\n') + 1:]
for cell in notebook["cells"]:
if cell["cell_type"] == "code":
source = cell["source"]
if source.startswith("#left") or source.startswith("#right"):
trimmed_source = source[source.index("\n") + 1 :]
cache += "<div>{}</div>".format(highlight(trimmed_source, PythonLexer(), HtmlFormatter()))
if source.startswith('#right'):
if source.startswith("#right"):
content += "<div class='leftright-wrapper'><div class='leftright-cells'>{}</div></div> ".format(cache)
cache = ''
cache = ""

elif cell['cell_type'] == 'markdown':
content += "<div class='markdown-cell'>{}</div>".format(markdown.markdown(cell['source']))
elif cell["cell_type"] == "markdown":
content += "<div class='markdown-cell'>{}</div>".format(markdown.markdown(cell["source"]))
else:
raise RuntimeError('not expected type of cell' + cell['cell_type'])
raise RuntimeError("not expected type of cell" + cell["cell_type"])

styles = HtmlFormatter().get_style_defs('.highlight')
styles = HtmlFormatter().get_style_defs(".highlight")

styles += '''
styles += """
body {
padding: 50px 10px;
}
Expand All @@ -56,9 +57,9 @@
text-align: center;
padding: 10px 0px 0px;
}
'''
"""

meta_tags = '''
meta_tags = """
<meta property="og:title" content="Writing better code with pytorch and einops">
<meta property="og:description" content="Learning by example: rewriting and fixing popular code fragments">
<meta property="og:image" content="http://arogozhnikov.github.io/images/einops/einops_video.gif">
Expand All @@ -70,18 +71,18 @@
<meta property="og:site_name" content="Writing better code with pytorch and einops">
<meta name="twitter:image:alt" content="Learning by example: rewriting and fixing popular code fragments">
'''
"""

github_ribbon = '''
github_ribbon = """
<a href="https://github.com/arogozhnikov/einops" class="github-corner" aria-label="View source on GitHub">
<svg width="80" height="80" viewBox="0 0 250 250" style="fill:#151513; color:#fff; position: absolute; top: 0; border: 0; right: 0;" aria-hidden="true">
<path d="M0,0 L115,115 L130,115 L142,142 L250,250 L250,0 Z"></path><path d="M128.3,109.0 C113.8,99.7 119.0,89.6 119.0,89.6 C122.0,82.7 120.5,78.6 120.5,78.6 C119.2,72.0 123.4,76.3 123.4,76.3 C127.3,80.9 125.5,87.3 125.5,87.3 C122.9,97.6 130.6,101.9 134.4,103.2" fill="currentColor" style="transform-origin: 130px 106px;" class="octo-arm"></path>
<path d="M115.0,115.0 C114.9,115.1 118.7,116.5 119.8,115.4 L133.7,101.6 C136.9,99.2 139.9,98.4 142.2,98.6 C133.8,88.0 127.5,74.4 143.8,58.0 C148.5,53.4 154.0,51.2 159.7,51.0 C160.3,49.4 163.2,43.6 171.4,40.1 C171.4,40.1 176.1,42.5 178.8,56.2 C183.1,58.6 187.2,61.8 190.9,65.4 C194.5,69.0 197.7,73.2 200.1,77.6 C213.8,80.2 216.3,84.9 216.3,84.9 C212.7,93.1 206.9,96.0 205.4,96.6 C205.1,102.4 203.0,107.8 198.3,112.5 C181.9,128.9 168.3,122.5 157.7,114.1 C157.9,116.9 156.7,120.9 152.7,124.9 L141.0,136.5 C139.8,137.7 141.6,141.9 141.8,141.8 Z" fill="currentColor" class="octo-body"></path>
</svg></a>
<style>.github-corner:hover .octo-arm{animation:octocat-wave 560ms ease-in-out}@keyframes octocat-wave{0%,100%{transform:rotate(0)}20%,60%{transform:rotate(-25deg)}40%,80%{transform:rotate(10deg)}}@media (max-width:500px){.github-corner:hover .octo-arm{animation:none}.github-corner .octo-arm{animation:octocat-wave 560ms ease-in-out}}</style>
'''
"""

result = f'''
result = f"""
<!DOCTYPE html>
<html lang="en">
<head>
Expand All @@ -96,7 +97,7 @@
</body>
</html>
'''
"""

with open('../pytorch-examples.html', 'w') as f:
with open("../pytorch-examples.html", "w") as f:
f.write(result)
21 changes: 10 additions & 11 deletions docs/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,41 @@

from PIL.Image import fromarray
from IPython import get_ipython
from IPython.display import display_html


def display_np_arrays_as_images():
def np_to_png(a):
if 2 <= len(a.shape) <= 3:
return fromarray(np.array(np.clip(a, 0, 1) * 255, dtype='uint8'))._repr_png_()
return fromarray(np.array(np.clip(a, 0, 1) * 255, dtype="uint8"))._repr_png_()
else:
return fromarray(np.zeros([1, 1], dtype='uint8'))._repr_png_()
return fromarray(np.zeros([1, 1], dtype="uint8"))._repr_png_()

def np_to_text(obj, p, cycle):
if len(obj.shape) < 2:
print(repr(obj))
if 2 <= len(obj.shape) <= 3:
pass
else:
print('<array of shape {}>'.format(obj.shape))

get_ipython().display_formatter.formatters['image/png'].for_type(np.ndarray, np_to_png)
get_ipython().display_formatter.formatters['text/plain'].for_type(np.ndarray, np_to_text)
print("<array of shape {}>".format(obj.shape))

get_ipython().display_formatter.formatters["image/png"].for_type(np.ndarray, np_to_png)
get_ipython().display_formatter.formatters["text/plain"].for_type(np.ndarray, np_to_text)

from IPython.display import display_html

_style_inline = """<style>
.einops-answer {
color: transparent;
padding: 5px 15px;
background-color: #def;
}
.einops-answer:hover { color: blue; }
.einops-answer:hover { color: blue; }
</style>
"""


def guess(x):
display_html(
_style_inline
+ "<h4>Answer is: <span class='einops-answer'>{x}</span> (hover to see)</h4>".format(x=tuple(x)),
raw=True)
_style_inline + "<h4>Answer is: <span class='einops-answer'>{x}</span> (hover to see)</h4>".format(x=tuple(x)),
raw=True,
)
16 changes: 9 additions & 7 deletions einops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
__author__ = 'Alex Rogozhnikov'
__version__ = '0.7.0'
# imports can use EinopsError class
# ruff: noqa: E402

__author__ = "Alex Rogozhnikov"
__version__ = "0.7.0"


class EinopsError(RuntimeError):
""" Runtime error thrown by einops """
"""Runtime error thrown by einops"""

pass


__all__ = ['rearrange', 'reduce', 'repeat', 'einsum',
'pack', 'unpack',
'parse_shape', 'asnumpy', 'EinopsError']
__all__ = ["rearrange", "reduce", "repeat", "einsum", "pack", "unpack", "parse_shape", "asnumpy", "EinopsError"]

from .einops import rearrange, reduce, repeat, einsum, parse_shape, asnumpy
from .packing import pack, unpack
from .packing import pack, unpack
9 changes: 5 additions & 4 deletions einops/_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ def shape(self, x):
try:
hash(shape)
return shape
except:
except BaseException:
# unhashable symbols in shape. Wrap tuple to be hashable.
return HashableTuple(shape)

Expand Down Expand Up @@ -661,12 +661,13 @@ def einsum(self, pattern, *x):
def shape(self, x):
return tuple(x.shape)


class TinygradBackend(AbstractBackend):
framework_name = "tinygrad"

def __init__(self):
import tinygrad

self.tinygrad = tinygrad

def is_appropriate_type(self, tensor):
Expand Down Expand Up @@ -709,6 +710,6 @@ def concat(self, tensors, axis: int):

def is_float_type(self, x):
return self.tinygrad.dtypes.is_float(x.dtype)

def einsum(self, pattern, *x):
return self.tinygrad.Tensor.einsum(pattern, *x)
return self.tinygrad.Tensor.einsum(pattern, *x)
3 changes: 2 additions & 1 deletion einops/_torch_specific.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
a number of additional moves is needed.
Design of main operations (dynamic resolution by lookup) is unlikely
to be implemented by torch.jit.script,
to be implemented by torch.jit.script,
but torch.compile seems to work with operations just fine.
"""

import warnings
from typing import Dict, List, Tuple

Expand Down
37 changes: 21 additions & 16 deletions einops/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ def reduce(tensor: Tensor, pattern: str, reduction: Reduction, **axes_lengths: i
recipe = _prepare_transformation_recipe(pattern, reduction, axes_names=tuple(axes_lengths), ndim=tensor.ndim)
return _apply_recipe_array_api(
xp,
recipe=recipe, tensor=tensor, reduction_type=reduction, axes_lengths=hashable_axes_lengths,
recipe=recipe,
tensor=tensor,
reduction_type=reduction,
axes_lengths=hashable_axes_lengths,
)
except EinopsError as e:
message = ' Error while processing {}-reduction pattern "{}".'.format(reduction, pattern)
Expand All @@ -28,7 +31,6 @@ def reduce(tensor: Tensor, pattern: str, reduction: Reduction, **axes_lengths: i
raise EinopsError(message + "\n {}".format(e))



def repeat(tensor: Tensor, pattern: str, **axes_lengths) -> Tensor:
return reduce(tensor, pattern, reduction="repeat", **axes_lengths)

Expand All @@ -39,44 +41,45 @@ def rearrange(tensor: Tensor, pattern: str, **axes_lengths) -> Tensor:

def asnumpy(tensor: Tensor):
import numpy as np

return np.from_dlpack(tensor)


Shape = Tuple


def pack(tensors: Sequence[Tensor], pattern: str) -> Tuple[Tensor, List[Shape]]:
n_axes_before, n_axes_after, min_axes = analyze_pattern(pattern, 'pack')
n_axes_before, n_axes_after, min_axes = analyze_pattern(pattern, "pack")
xp = tensors[0].__array_namespace__()

reshaped_tensors: List[Tensor] = []
packed_shapes: List[Shape] = []
for i, tensor in enumerate(tensors):
shape = tensor.shape
if len(shape) < min_axes:
raise EinopsError(f'packed tensor #{i} (enumeration starts with 0) has shape {shape}, '
f'while pattern {pattern} assumes at least {min_axes} axes')
raise EinopsError(
f"packed tensor #{i} (enumeration starts with 0) has shape {shape}, "
f"while pattern {pattern} assumes at least {min_axes} axes"
)
axis_after_packed_axes = len(shape) - n_axes_after
packed_shapes.append(shape[n_axes_before:axis_after_packed_axes])
reshaped_tensors.append(xp.reshape(tensor, (*shape[:n_axes_before], -1, *shape[axis_after_packed_axes:])))

return xp.concat(reshaped_tensors, axis=n_axes_before), packed_shapes



def unpack(tensor: Tensor, packed_shapes: List[Shape], pattern: str) -> List[Tensor]:
xp = tensor.__array_namespace__()
n_axes_before, n_axes_after, min_axes = analyze_pattern(pattern, opname='unpack')
n_axes_before, n_axes_after, min_axes = analyze_pattern(pattern, opname="unpack")

# backend = get_backend(tensor)
input_shape = tensor.shape
if len(input_shape) != n_axes_before + 1 + n_axes_after:
raise EinopsError(f'unpack(..., {pattern}) received input of wrong dim with shape {input_shape}')
raise EinopsError(f"unpack(..., {pattern}) received input of wrong dim with shape {input_shape}")

unpacked_axis: int = n_axes_before

lengths_of_composed_axes: List[int] = [
-1 if -1 in p_shape else prod(p_shape)
for p_shape in packed_shapes
]
lengths_of_composed_axes: List[int] = [-1 if -1 in p_shape else prod(p_shape) for p_shape in packed_shapes]

n_unknown_composed_axes = sum(x == -1 for x in lengths_of_composed_axes)
if n_unknown_composed_axes > 1:
Expand All @@ -102,18 +105,20 @@ def unpack(tensor: Tensor, packed_shapes: List[Shape], pattern: str) -> List[Ten
split_positions[j] = split_positions[j + 1] - lengths_of_composed_axes[j]

shape_start = input_shape[:unpacked_axis]
shape_end = input_shape[unpacked_axis + 1:]
shape_end = input_shape[unpacked_axis + 1 :]
slice_filler = (slice(None, None),) * unpacked_axis
try:
return [
xp.reshape(
# shortest way slice arbitrary axis
tensor[(*slice_filler, slice(split_positions[i], split_positions[i + 1]), ...)],
(*shape_start, *element_shape, *shape_end)
(*shape_start, *element_shape, *shape_end),
)
for i, element_shape in enumerate(packed_shapes)
]
except BaseException:
# this hits if there is an error during reshapes, which means passed shapes were incorrect
raise RuntimeError(f'Error during unpack(..., "{pattern}"): could not split axis of size {split_positions[-1]}'
f' into requested {packed_shapes}')
raise RuntimeError(
f'Error during unpack(..., "{pattern}"): could not split axis of size {split_positions[-1]}'
f" into requested {packed_shapes}"
)
14 changes: 5 additions & 9 deletions einops/einops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

if typing.TYPE_CHECKING:
# for docstrings in pycharm
import numpy as np
import numpy as np # noqa E401

from . import EinopsError
from ._backends import get_backend
Expand Down Expand Up @@ -814,23 +814,19 @@ def _compactify_pattern_for_einsum(pattern: str) -> str:


@typing.overload
def einsum(tensor: Tensor, pattern: str, /) -> Tensor:
...
def einsum(tensor: Tensor, pattern: str, /) -> Tensor: ...


@typing.overload
def einsum(tensor1: Tensor, tensor2: Tensor, pattern: str, /) -> Tensor:
...
def einsum(tensor1: Tensor, tensor2: Tensor, pattern: str, /) -> Tensor: ...


@typing.overload
def einsum(tensor1: Tensor, tensor2: Tensor, tensor3: Tensor, pattern: str, /) -> Tensor:
...
def einsum(tensor1: Tensor, tensor2: Tensor, tensor3: Tensor, pattern: str, /) -> Tensor: ...


@typing.overload
def einsum(tensor1: Tensor, tensor2: Tensor, tensor3: Tensor, tensor4: Tensor, pattern: str, /) -> Tensor:
...
def einsum(tensor1: Tensor, tensor2: Tensor, tensor3: Tensor, tensor4: Tensor, pattern: str, /) -> Tensor: ...


def einsum(*tensors_and_pattern: Union[Tensor, str]) -> Tensor:
Expand Down
Loading

0 comments on commit cff63f2

Please sign in to comment.