diff --git a/libcst/metadata/tests/test_type_inference_provider.py b/libcst/metadata/tests/test_type_inference_provider.py index f6c977519..e7cad72aa 100644 --- a/libcst/metadata/tests/test_type_inference_provider.py +++ b/libcst/metadata/tests/test_type_inference_provider.py @@ -66,3 +66,10 @@ def test_simple_class_types(self, source_path: Path, data_path: Path) -> None: cache={TypeInferenceProvider: data}, ) _test_simple_class_helper(self, wrapper) + + def test_with_empty_cache(self) -> None: + tip = TypeInferenceProvider({}) + self.assertEqual(tip.lookup, {}) + + tip = TypeInferenceProvider(PyreData()) + self.assertEqual(tip.lookup, {}) diff --git a/libcst/metadata/type_inference_provider.py b/libcst/metadata/type_inference_provider.py index 9975d0235..4924738e7 100644 --- a/libcst/metadata/type_inference_provider.py +++ b/libcst/metadata/type_inference_provider.py @@ -32,7 +32,7 @@ class InferredType(TypedDict): annotation: str -class PyreData(TypedDict): +class PyreData(TypedDict, total=False): types: Sequence[InferredType] @@ -75,7 +75,8 @@ def gen_cache( def __init__(self, cache: PyreData) -> None: super().__init__(cache) lookup: Dict[CodeRange, str] = {} - for item in cache["types"]: + cache_types = cache.get("types", []) + for item in cache_types: location = item["location"] start = location["start"] end = location["stop"]