Skip to content

Commit

Permalink
ruff check and format
Browse files Browse the repository at this point in the history
  • Loading branch information
arogozhnikov committed May 12, 2024
1 parent 6f4f7cf commit 09c3dac
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 51 deletions.
84 changes: 39 additions & 45 deletions einops/experimental/einsplit.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
- mergedlinear
- multiprojection
"""

from typing import Iterable

import torch
Expand All @@ -25,21 +26,21 @@ def _split_into_groups(pattern: str) -> list[list[str]]:
result: list[list[str]] = []
pattern_rest: str = pattern
# check there is space before `(` and after `)` for proper style
pattern_with_edges = f' {pattern} '
msg = f'please add spaces before and after parenthesis in {pattern=}'
assert pattern_with_edges.count(' (') == pattern_with_edges.count('('), msg
assert pattern_with_edges.count(') ') == pattern_with_edges.count(')'), msg
pattern_with_edges = f" {pattern} "
msg = f"please add spaces before and after parenthesis in {pattern=}"
assert pattern_with_edges.count(" (") == pattern_with_edges.count("("), msg
assert pattern_with_edges.count(") ") == pattern_with_edges.count(")"), msg

while True:
if pattern_rest.startswith('('):
i = pattern_rest.index(')')
group, pattern_rest = pattern_rest[1:i], pattern_rest[i + 1:]
assert '(' not in group, 'unbalanced brackets'
if pattern_rest.startswith("("):
i = pattern_rest.index(")")
group, pattern_rest = pattern_rest[1:i], pattern_rest[i + 1 :]
assert "(" not in group, "unbalanced brackets"
result.append(group.split())
elif '(' in pattern_rest:
i = pattern_rest.index('(')
elif "(" in pattern_rest:
i = pattern_rest.index("(")
ungrouped, pattern_rest = pattern_rest[:i], pattern_rest[i:]
assert ')' not in ungrouped, 'unbalanced brackets'
assert ")" not in ungrouped, "unbalanced brackets"
result.extend([[x] for x in pattern_rest.split()])
else:
# no more brackets, just parse the end
Expand All @@ -49,27 +50,25 @@ def _split_into_groups(pattern: str) -> list[list[str]]:


def _join_groups_to_pattern(groups: list[list[str]]) -> str:
result = ''
result = ""
for group in groups:
if len(group) == 1:
result += f'{group[0]} '
result += f"{group[0]} "
else:
result += '(' + ' '.join(group) + ') '
result += "(" + " ".join(group) + ") "
return result.strip()


def _assert_good_identifier(axis_label: str) -> None:
    """Assert that ``axis_label`` is a valid einops axis name.

    Delegates validation to ``ParsedExpression.check_axis_name_return_reason``
    (imported from einops at module top); underscores are disallowed here.

    :param axis_label: candidate axis name
    :raises AssertionError: with the rejection reason if the name is invalid
    """
    valid, reason = ParsedExpression.check_axis_name_return_reason(axis_label, allow_underscore=False)
    assert valid, f"Bad {axis_label=}, {reason}"


def _get_name_for_anon_axis(disallowed_axes: Iterable[str], axis_len: int) -> str:
prefix = 'c'
prefix = "c"
while True:
prefix += '_'
axis_name = f'{prefix}{axis_len}'
prefix += "_"
axis_name = f"{prefix}{axis_len}"
if axis_name not in disallowed_axes:
return axis_name

Expand All @@ -81,7 +80,7 @@ def _process_input_pattern(input_pattern) -> tuple[Rearrange, list[str], int]:
"""
groups = _split_into_groups(input_pattern)

all_identifiers = [el.partition('=')[0] for group in groups for el in group if not str.isnumeric(el)]
all_identifiers = [el.partition("=")[0] for group in groups for el in group if not str.isnumeric(el)]
assert len(all_identifiers) == len(set(all_identifiers)), f"duplicate names in {input_pattern=}"

batch_axes = []
Expand All @@ -90,8 +89,8 @@ def _process_input_pattern(input_pattern) -> tuple[Rearrange, list[str], int]:
for group in groups:
named_group = []
for axis in group:
if '=' in axis:
axis_name, _, axis_len_str = axis.partition('=')
if "=" in axis:
axis_name, _, axis_len_str = axis.partition("=")
axis_len = int(axis_len_str)
assert axis_len > 0, axis
_assert_good_identifier(axis_name)
Expand All @@ -100,7 +99,7 @@ def _process_input_pattern(input_pattern) -> tuple[Rearrange, list[str], int]:

elif str.isnumeric(axis):
axis_len = int(axis)
assert axis_len > 0, f'{axis_len=}'
assert axis_len > 0, f"{axis_len=}"
axis_name = _get_name_for_anon_axis(all_identifiers, axis_len)
input_axes2size[axis_name] = axis_len
named_group.append(axis_name)
Expand All @@ -112,7 +111,7 @@ def _process_input_pattern(input_pattern) -> tuple[Rearrange, list[str], int]:
named_groups.append(named_group)

init_reordering_pattern = _join_groups_to_pattern(named_groups)
init_reordering_pattern += ' -> ' + _join_groups_to_pattern([[x] for x in batch_axes] + [list(input_axes2size)])
init_reordering_pattern += " -> " + _join_groups_to_pattern([[x] for x in batch_axes] + [list(input_axes2size)])
total_input_size = _product(list(input_axes2size.values()))

return Rearrange(init_reordering_pattern, **input_axes2size), batch_axes, total_input_size
Expand All @@ -121,29 +120,29 @@ def _process_input_pattern(input_pattern) -> tuple[Rearrange, list[str], int]:
def _process_output_pattern(output_pattern, batch_axes) -> tuple[Rearrange, int]:
    """Parse one output pattern of an EinSplit head.

    Output patterns may contain batch axes (which must be known from the
    input pattern) and anonymous numeric axes (e.g. ``"b s 9"``); named
    non-batch axes and ``name=len`` specs are not allowed in outputs.

    :param output_pattern: einops-like pattern for this output head
    :param batch_axes: batch axis names established by the input pattern
    :return: (Rearrange layer mapping the flat projected axis to the output
        layout, total size of the projected axis for this head)
    :raises AssertionError: on duplicate names, ``=`` specs, non-positive
        lengths, or axes not present in ``batch_axes``
    """
    groups = _split_into_groups(output_pattern)

    # names must be unique across the whole pattern (numeric axes excluded)
    all_identifiers = [el.partition("=")[0] for group in groups for el in group if not str.isnumeric(el)]
    assert len(all_identifiers) == len(set(all_identifiers)), f"duplicate names in {output_pattern=}"

    output_axis2size = {}
    named_groups = []
    for group in groups:
        named_group = []
        for axis in group:
            assert "=" not in axis, f"wrong identifier {axis=}, no names in outputs"
            if str.isnumeric(axis):
                # anonymous axis: invent a private name so Rearrange can refer to it
                axis_len = int(axis)
                assert axis_len > 0, f"{axis_len=}"
                axis_name = _get_name_for_anon_axis(all_identifiers, axis_len)
                output_axis2size[axis_name] = axis_len
                named_group.append(axis_name)
            else:
                # named axes in outputs may only be batch axes from the input
                assert axis in batch_axes, f"unknown axis in output, allowed only {batch_axes=}"
                named_group.append(axis)

        named_groups.append(named_group)

    # map [batch..., flat_projection] -> user-requested output layout
    reordering_pattern = _join_groups_to_pattern([[x] for x in batch_axes] + [[*output_axis2size]])
    reordering_pattern += " -> " + _join_groups_to_pattern(named_groups)
    total_output_size = _product(list(output_axis2size.values()))
    return Rearrange(reordering_pattern, **output_axis2size), total_output_size

Expand All @@ -161,8 +160,7 @@ def __init__(self, input_pattern: str):
# raise RuntimeError("no support for ellipsis so far")
# self._required_identifiers = parsed.identifiers

self._in_rearrange, self.batch_axes, self._total_input_size = \
_process_input_pattern(input_pattern)
self._in_rearrange, self.batch_axes, self._total_input_size = _process_input_pattern(input_pattern)
self._out_rearranges = ModuleList([])
self._out_sizes = []
self.linear = None # set after create_weights
Expand All @@ -171,11 +169,10 @@ def __init__(self, input_pattern: str):
self.bias = Parameter(torch.empty([0]))
self.bias_mask = Parameter(torch.empty([0], dtype=torch.bool), requires_grad=False)

def add_output(self, pattern: str, init: str = 'xavier_normal', bias: bool = True) -> int:
""" returns index in output list """
def add_output(self, pattern: str, init: str = "xavier_normal", bias: bool = True) -> int:
"""returns index in output list"""
idx = len(self.outputs)
out_rearrange, out_total_size = \
_process_output_pattern(pattern, batch_axes=self.batch_axes)
out_rearrange, out_total_size = _process_output_pattern(pattern, batch_axes=self.batch_axes)
self.outputs.append((pattern, init, bias))
self._out_sizes.append(out_total_size)
self._out_rearranges.append(out_rearrange)
Expand All @@ -184,12 +181,12 @@ def add_output(self, pattern: str, init: str = 'xavier_normal', bias: bool = Tru
b = self.bias.new_zeros(out_total_size)
b_mask = self.bias_mask.new_full(size=(out_total_size,), fill_value=int(bias), dtype=torch.bool)

if init == 'xavier_normal':
if init == "xavier_normal":
torch.nn.init.xavier_normal_(W) # bias is zero
elif init == 'zeros':
elif init == "zeros":
torch.nn.init.zeros_(W) # bias is zero
else:
raise ValueError(f'Unknown {init=}')
raise ValueError(f"Unknown {init=}")

with torch.no_grad():
self.weight = Parameter(torch.concatenate([self.weight, W]))
Expand All @@ -201,13 +198,10 @@ def add_output(self, pattern: str, init: str = 'xavier_normal', bias: bool = Tru
def __repr__(self):
output = f"EinSplit({self.input_pattern})"
for i, (pattern, init, bias, *_) in self.outputs:
output += f'\n + output {i}: {pattern}; {bias=}, {init=}'
output += f"\n + output {i}: {pattern}; {bias=}, {init=}"
return output

def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
    """Apply the fused projection and split the result into the output heads.

    Rearranges the input to [batch..., flat_features], applies one shared
    linear layer (bias entries are zeroed per-head via ``bias_mask``), splits
    the last dim into per-head chunks, and rearranges each chunk to its
    requested output layout.

    :param x: input tensor matching the input pattern
    :return: one tensor per registered output head, in registration order
    """
    # bias * bias_mask zeroes the bias of heads registered with bias=False
    merged = F.linear(self._in_rearrange(x), self.weight, self.bias * self.bias_mask)
    split = torch.split(merged, self._out_sizes, dim=-1)
    return [rearr_out(x) for rearr_out, x in zip(self._out_rearranges, split)]
13 changes: 7 additions & 6 deletions tests/test_experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,17 @@ def test_torch_einsplit():

import torch
from einops.experimental.einsplit import EinSplit

b = 2
s = 3
c1 = 5
c2 = 7
c_out1 = 9
c_out2 = 11
mod = EinSplit(f'b s {c1=} {c2=}')
out1_idx = mod.add_output(f'b s {c_out1}', init='xavier_normal')
out2_idx = mod.add_output(f'b s {c_out2}', init='xavier_normal')
out3_idx = mod.add_output(f'(3 b 7) s {c_out2}', init='zeros')
mod = EinSplit(f"b s {c1=} {c2=}")
out1_idx = mod.add_output(f"b s {c_out1}", init="xavier_normal")
out2_idx = mod.add_output(f"b s {c_out2}", init="xavier_normal")
out3_idx = mod.add_output(f"(3 b 7) s {c_out2}", init="zeros")
assert (out1_idx, out2_idx, out3_idx) == (0, 1, 2)

optim = torch.optim.Adam(mod.parameters(), lr=1e-2)
Expand All @@ -37,8 +38,8 @@ def test_torch_einsplit():
out3_norms.append(out3.norm().item())

if iteration % 10 == 0:
print(f'{iteration:>5} {loss:6.2f}')
print(f"{iteration:>5} {loss:6.2f}")

assert out3_norms[0] == out3_norms[-1] == 0
assert out1_norms[0] > 2 * out1_norms[-1]
assert out2_norms[0] > 2 * out2_norms[-1]
assert out2_norms[0] > 2 * out2_norms[-1]

0 comments on commit 09c3dac

Please sign in to comment.