diff --git a/flake8_annotations/checker.py b/flake8_annotations/checker.py index 594495d..54c1777 100644 --- a/flake8_annotations/checker.py +++ b/flake8_annotations/checker.py @@ -1,9 +1,14 @@ from functools import lru_cache -from pathlib import Path -from typing import Generator, List, Tuple +from typing import Generator, List from flake8_annotations import ( - Argument, Function, FunctionVisitor, PY_GTE_38, __version__, enums, error_codes + Argument, + Function, + FunctionVisitor, + PY_GTE_38, + __version__, + enums, + error_codes, ) # Check if we can use the stdlib ast module instead of typed_ast @@ -20,11 +25,10 @@ class TypeHintChecker: name = "flake8-annotations" version = __version__ - def __init__(self, tree: ast.Module, filename: str): - # Unfortunately no way that I can find around requesting the ast-parsed tree from flake8 - # Removing tree unregisters the plugin, and per the documentation the alternative is - # requesting per-line information - self.tree, self.lines = self.load_file(Path(filename)) + def __init__(self, lines: List[str]): + # Request `lines` here and join to allow for correct handling of input from stdin + self.lines = lines + self.tree = self.get_typed_tree("".join(lines)) # flake8 doesn't strip newlines def run(self) -> Generator[error_codes.Error, None, None]: """ @@ -68,11 +72,8 @@ def run(self) -> Generator[error_codes.Error, None, None]: yield classify_error(function, arg).to_flake8() @staticmethod - def load_file(src_filepath: Path) -> Tuple[ast.Module, List[str]]: - """Parse the provided Python file and return an (typed AST, source) tuple.""" - with src_filepath.open("r", encoding="utf-8") as f: - src = f.read() - + def get_typed_tree(src: str) -> ast.Module: + """Parse the provided source into a typed AST.""" if PY_GTE_38: # Built-in ast requires a flag to parse type comments tree = ast.parse(src, type_comments=True) @@ -80,9 +81,7 @@ def load_file(src_filepath: Path) -> Tuple[ast.Module, List[str]]: # typed-ast will implicitly parse type comments tree = ast.parse(src) - lines = src.splitlines() - - return tree, lines + return tree def classify_error(function: Function, arg: Argument) -> error_codes.Error: diff --git a/testing/helpers.py b/testing/helpers.py index f3f7a71..56f5455 100644 --- a/testing/helpers.py +++ b/testing/helpers.py @@ -19,17 +19,15 @@ def parse_source(src: str) -> Tuple[ast.Module, List[str]]: # typed-ast will implicitly parse type comments tree = ast.parse(src) - lines = src.splitlines() + lines = src.splitlines(keepends=True) return tree, lines def check_source(src: str) -> Generator[Error, None, None]: """Helper for generating linting errors from the provided source code.""" - # Because TypeHintChecker is expecting a filename to initialize, rather than change this logic - # we can use this file as a dummy, then update its tree & lines attributes as parsed from source - checker_instance = TypeHintChecker(None, __file__) - checker_instance.tree, checker_instance.lines = parse_source(src) + _, lines = parse_source(src) + checker_instance = TypeHintChecker(lines) return checker_instance.run() diff --git a/testing/test_cases/variable_formatting_test_cases.py b/testing/test_cases/variable_formatting_test_cases.py index cb7b693..eb10ad0 100644 --- a/testing/test_cases/variable_formatting_test_cases.py +++ b/testing/test_cases/variable_formatting_test_cases.py @@ -20,9 +20,9 @@ def foo(some_arg, *some_args, **some_kwargs) -> int: "protected_function": FormatTestCase( src=dedent( """\ - def _foo(some_arg, *some_args, **some_kwargs) -> int: - pass - """ + def _foo(some_arg, *some_args, **some_kwargs) -> int: + pass + """ ), ), "private_function": FormatTestCase(