Commit 6f4f7cf (1 parent: fe9d81d): 2 changed files with 257 additions and 0 deletions.
New file (module einops.experimental.einsplit, 213 lines added):
""" | ||
torch-only version for efficient production of multiple outputs from the same input, | ||
while making all rearranges. | ||
python 3.9+ because of typing. | ||
implementation is a bit fragile, pin exact version if using. | ||
Name isn't great, other names under consideration: | ||
- multilinear (... confusion with multilinearity), | ||
- mergedlinear | ||
- multiprojection | ||
""" | ||
from typing import Iterable

import torch
import torch.nn.functional as F
from torch.nn import ModuleList, Parameter

from einops.einops import _product
from einops.layers.torch import Rearrange
from einops.parsing import ParsedExpression

def _split_into_groups(pattern: str) -> list[list[str]]:
    # does not differentiate composed and non-composed ellipsis
    result: list[list[str]] = []
    pattern_rest: str = pattern
    # check there is a space before `(` and after `)` for proper style
    pattern_with_edges = f' {pattern} '
    msg = f'please add spaces before and after parentheses in {pattern=}'
    assert pattern_with_edges.count(' (') == pattern_with_edges.count('('), msg
    assert pattern_with_edges.count(') ') == pattern_with_edges.count(')'), msg

    while True:
        if pattern_rest.startswith('('):
            i = pattern_rest.index(')')
            group, pattern_rest = pattern_rest[1:i], pattern_rest[i + 1:]
            assert '(' not in group, 'unbalanced brackets'
            result.append(group.split())
        elif '(' in pattern_rest:
            i = pattern_rest.index('(')
            ungrouped, pattern_rest = pattern_rest[:i], pattern_rest[i:]
            assert ')' not in ungrouped, 'unbalanced brackets'
            result.extend([[x] for x in ungrouped.split()])
        else:
            # no more brackets, just parse the end
            result.extend([[x] for x in pattern_rest.split()])
            break
    return result
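
# for illustration: _split_into_groups('a (b c) d') -> [['a'], ['b', 'c'], ['d']]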


def _join_groups_to_pattern(groups: list[list[str]]) -> str:
    result = ''
    for group in groups:
        if len(group) == 1:
            result += f'{group[0]} '
        else:
            result += '(' + ' '.join(group) + ') '
    return result.strip()
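
# for illustration: _join_groups_to_pattern([['a'], ['b', 'c']]) -> 'a (b c)'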


def _assert_good_identifier(axis_label: str) -> None:
    valid, reason = ParsedExpression.check_axis_name_return_reason(
        axis_label, allow_underscore=False
    )
    assert valid, f'Bad {axis_label=}, {reason}'


def _get_name_for_anon_axis(disallowed_axes: Iterable[str], axis_len: int) -> str:
    prefix = 'c'
    while True:
        prefix += '_'
        axis_name = f'{prefix}{axis_len}'
        if axis_name not in disallowed_axes:
            return axis_name
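
# for illustration: _get_name_for_anon_axis(['b', 's'], 9) -> 'c_9' (or 'c__9' if 'c_9' is taken)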


def _process_input_pattern(input_pattern) -> tuple[Rearrange, list[str], int]:
    """
    examples of input patterns: 'a 1 (2 3 b) ()', 'c d e f 9', 'b s c1=5 c2=7'
    axes tagged with a size (c1=5) and plain integers are projected input axes,
    untagged named axes are treated as batch axes; ellipsis is not supported
    """
    groups = _split_into_groups(input_pattern)

    all_identifiers = [el.partition('=')[0] for group in groups for el in group if not str.isnumeric(el)]
    assert len(all_identifiers) == len(set(all_identifiers)), f"duplicate names in {input_pattern=}"

    batch_axes = []
    input_axes2size = {}
    named_groups = []
    for group in groups:
        named_group = []
        for axis in group:
            if '=' in axis:
                axis_name, _, axis_len_str = axis.partition('=')
                axis_len = int(axis_len_str)
                assert axis_len > 0, axis
                _assert_good_identifier(axis_name)
                input_axes2size[axis_name] = axis_len
                named_group.append(axis_name)
            elif str.isnumeric(axis):
                axis_len = int(axis)
                assert axis_len > 0, f'{axis_len=}'
                axis_name = _get_name_for_anon_axis(all_identifiers, axis_len)
                input_axes2size[axis_name] = axis_len
                named_group.append(axis_name)
            else:
                _assert_good_identifier(axis)
                batch_axes.append(axis)
                named_group.append(axis)

        named_groups.append(named_group)

    init_reordering_pattern = _join_groups_to_pattern(named_groups)
    init_reordering_pattern += ' -> ' + _join_groups_to_pattern([[x] for x in batch_axes] + [list(input_axes2size)])
    total_input_size = _product(list(input_axes2size.values()))

    return Rearrange(init_reordering_pattern, **input_axes2size), batch_axes, total_input_size
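
# for illustration: _process_input_pattern('b s c1=5 c2=7') builds the rearrange
# 'b s c1 c2 -> b s (c1 c2)' with batch_axes == ['b', 's'] and total_input_size == 35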


def _process_output_pattern(output_pattern, batch_axes) -> tuple[Rearrange, int]:
    groups = _split_into_groups(output_pattern)

    all_identifiers = [el.partition('=')[0] for group in groups for el in group if not str.isnumeric(el)]
    assert len(all_identifiers) == len(set(all_identifiers)), f"duplicate names in {output_pattern=}"

    output_axis2size = {}
    named_groups = []
    for group in groups:
        named_group = []
        for axis in group:
            assert '=' not in axis, f'wrong identifier {axis=}, sizes cannot be tagged in outputs'
            if str.isnumeric(axis):
                axis_len = int(axis)
                assert axis_len > 0, f'{axis_len=}'
                axis_name = _get_name_for_anon_axis(all_identifiers, axis_len)
                output_axis2size[axis_name] = axis_len
                named_group.append(axis_name)
            else:
                assert axis in batch_axes, f'unknown axis in output, allowed only {batch_axes=}'
                named_group.append(axis)

        named_groups.append(named_group)

    reordering_pattern = _join_groups_to_pattern([[x] for x in batch_axes] + [[*output_axis2size]])
    reordering_pattern += ' -> ' + _join_groups_to_pattern(named_groups)
    total_output_size = _product(list(output_axis2size.values()))
    return Rearrange(reordering_pattern, **output_axis2size), total_output_size
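
# for illustration: _process_output_pattern('b s 9', batch_axes=['b', 's']) builds
# 'b s c_9 -> b s c_9' (c_9 is a generated anonymous axis) with total_output_size == 9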


class EinSplit(torch.nn.Module):
    def __init__(self, input_pattern: str):
        """all projected axis sizes must be provided in the pattern"""
        super().__init__()
        self.input_pattern = input_pattern
        self.outputs: list[tuple] = []
        # intermediate parsing results

        # parsed = ParsedExpression(input_pattern)
        # if parsed.has_ellipsis:
        #     raise RuntimeError("no support for ellipsis so far")
        # self._required_identifiers = parsed.identifiers

        self._in_rearrange, self.batch_axes, self._total_input_size = \
            _process_input_pattern(input_pattern)
        self._out_rearranges = ModuleList([])
        self._out_sizes = []

        # weight/bias start empty and grow with every add_output call
        self.weight = Parameter(torch.empty([0, self._total_input_size]))
        self.bias = Parameter(torch.empty([0]))
        self.bias_mask = Parameter(torch.empty([0], dtype=torch.bool), requires_grad=False)

    def add_output(self, pattern: str, init: str = 'xavier_normal', bias: bool = True) -> int:
        """returns index in output list"""
        idx = len(self.outputs)
        out_rearrange, out_total_size = \
            _process_output_pattern(pattern, batch_axes=self.batch_axes)
        self.outputs.append((pattern, init, bias))
        self._out_sizes.append(out_total_size)
        self._out_rearranges.append(out_rearrange)

        W = self.weight.new_zeros(out_total_size, self._total_input_size)
        b = self.bias.new_zeros(out_total_size)
        b_mask = self.bias_mask.new_full(size=(out_total_size,), fill_value=int(bias), dtype=torch.bool)

        if init == 'xavier_normal':
            torch.nn.init.xavier_normal_(W)  # bias is zero
        elif init == 'zeros':
            torch.nn.init.zeros_(W)  # bias is zero
        else:
            raise ValueError(f'Unknown {init=}')

        with torch.no_grad():
            self.weight = Parameter(torch.concatenate([self.weight, W]))
            self.bias = Parameter(torch.concatenate([self.bias, b]))
            self.bias_mask = Parameter(torch.concatenate([self.bias_mask, b_mask]), requires_grad=False)

        return idx

    def __repr__(self):
        output = f"EinSplit({self.input_pattern})"
        for i, (pattern, init, bias) in enumerate(self.outputs):
            output += f'\n + output {i}: {pattern}; {bias=}, {init=}'
        return output

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        merged = F.linear(self._in_rearrange(x), self.weight, self.bias * self.bias_mask)
        split = torch.split(merged, self._out_sizes, dim=-1)
        return [
            rearr_out(x) for rearr_out, x in zip(self._out_rearranges, split)
        ]
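
A minimal usage sketch, mirroring the test below: EinSplit fuses several projections of the same input into one weight matrix, so forward runs a single F.linear and then splits and rearranges the result. The names q_idx, k_idx, and x here are illustrative, not from the commit.

import torch
from einops.experimental.einsplit import EinSplit

mod = EinSplit('b s c1=5 c2=7')          # batch axes b, s; projected input size 5 * 7 = 35
q_idx = mod.add_output('b s 9')          # output 0, shape (b, s, 9)
k_idx = mod.add_output('b s 11')         # output 1, shape (b, s, 11)

x = torch.randn(2, 3, 5, 7)              # (b, s, c1, c2)
q, k = mod(x)                            # one fused F.linear, then split + rearrange
assert q.shape == (2, 3, 9) and k.shape == (2, 3, 11)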

New test file (44 lines added):
import pytest

from . import is_backend_tested

def test_torch_einsplit():
    if not is_backend_tested("torch"):
        pytest.skip()

    import torch
    from einops.experimental.einsplit import EinSplit

    b = 2
    s = 3
    c1 = 5
    c2 = 7
    c_out1 = 9
    c_out2 = 11
    mod = EinSplit(f'b s {c1=} {c2=}')
    out1_idx = mod.add_output(f'b s {c_out1}', init='xavier_normal')
    out2_idx = mod.add_output(f'b s {c_out2}', init='xavier_normal')
    out3_idx = mod.add_output(f'(3 b 7) s {c_out2}', init='zeros')
    assert (out1_idx, out2_idx, out3_idx) == (0, 1, 2)

    optim = torch.optim.Adam(mod.parameters(), lr=1e-2)
    batch = torch.randn(b, s, c1, c2)
    out1_norms = []
    out2_norms = []
    out3_norms = []
    for iteration in range(100):
        out1, out2, out3 = mod(batch)
        loss = out1.norm() + out2.norm()
        loss.backward()
        optim.step()
        optim.zero_grad()
        out1_norms.append(out1.norm().item())
        out2_norms.append(out2.norm().item())
        out3_norms.append(out3.norm().item())

        if iteration % 10 == 0:
            print(f'{iteration:>5} {loss:6.2f}')

    # out3 is zero-initialized and never enters the loss, so it stays zero;
    # minimizing the loss shrinks the norms of out1 and out2
    assert out3_norms[0] == out3_norms[-1] == 0
    assert out1_norms[0] > 2 * out1_norms[-1]
    assert out2_norms[0] > 2 * out2_norms[-1]