Skip to content

Commit

Permalink
Attempt to fix build break from torch version 1.12.0 released June 28…
Browse files Browse the repository at this point in the history
…, 2022.
  • Loading branch information
WilBrady committed Jun 29, 2022
1 parent 9be2b60 commit b481353
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 4 deletions.
7 changes: 6 additions & 1 deletion orttraining/orttraining/eager/opgen/opgen/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
# Licensed under the MIT License.

import io
from typing import TextIO, List, Union
from typing import List, TextIO, Union

from opgen.lexer import Token


Expand Down Expand Up @@ -355,6 +356,10 @@ class StreamType(ConcreteType):
pass


class SymIntType(ConcreteType):
pass


# region Decls


Expand Down
2 changes: 1 addition & 1 deletion orttraining/orttraining/eager/opgen/opgen/atenops.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ def __init__(self, dY, X):
"aten::gelu": Gelu("self"),
"aten::max": ReduceMax("self", keepdims=0),
"aten::min": ReduceMin("self", keepdims=0),
"aten::_cat": Concat("tensors", "dim"),
"aten::fill_.Scalar": ConstantOfShape("self", value="value"),
"aten::ne.Scalar": MakeTorchFallback(),
"aten::ne.Scalar_out": MakeTorchFallback(),
Expand All @@ -147,6 +146,7 @@ def __init__(self, dY, X):
# This is done to make sure it is backward and future compatible
if version.parse(torch.__version__) < version.parse(TORCH_API_CHANGE_VERSION):
hand_implemented["aten::gelu_backward"] = GeluGrad("grad", "self")
hand_implemented["aten::_cat"] = Concat("tensors", "dim")
else:
hand_implemented["aten::gelu_backward"] = GeluGrad("grad_output", "self")

Expand Down
6 changes: 4 additions & 2 deletions orttraining/orttraining/eager/opgen/opgen/parser.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

from opgen.lexer import *
from typing import List, Optional, Tuple, Union

from opgen.ast import *
from typing import List, Tuple, Union, Optional
from opgen.lexer import *


class UnexpectedTokenError(RuntimeError):
Expand Down Expand Up @@ -291,6 +292,7 @@ def _parse_torch_base_type(self) -> Type:
"Storage": StorageType,
"ConstQuantizerPtr": ConstQuantizerPtrType,
"Stream": StreamType,
"SymInt": SymIntType,
}
identifier = self._expect_token(TokenKind.IDENTIFIER)
base_type_parser = base_type_parsers.get(identifier.value)
Expand Down

0 comments on commit b481353

Please sign in to comment.