diff --git a/src/main/java/org/openrewrite/staticanalysis/InstanceOfPatternMatch.java b/src/main/java/org/openrewrite/staticanalysis/InstanceOfPatternMatch.java index 35cc9e506..2a134265d 100644 --- a/src/main/java/org/openrewrite/staticanalysis/InstanceOfPatternMatch.java +++ b/src/main/java/org/openrewrite/staticanalysis/InstanceOfPatternMatch.java @@ -19,6 +19,7 @@ import lombok.EqualsAndHashCode; import lombok.Value; import org.openrewrite.*; +import org.openrewrite.internal.ListUtils; import org.openrewrite.internal.lang.Nullable; import org.openrewrite.java.JavaVisitor; import org.openrewrite.java.VariableNameUtils; @@ -33,6 +34,7 @@ import java.util.*; import java.util.regex.Pattern; import java.util.stream.Collectors; +import java.util.stream.IntStream; import java.util.stream.Stream; import static java.util.Collections.emptyList; @@ -224,6 +226,25 @@ public J.InstanceOf processInstanceOf(J.InstanceOf instanceOf, Cursor cursor) { name, type, null)); + JavaType.FullyQualified fqType = TypeUtils.asFullyQualified(type); + if (fqType != null && !fqType.getTypeParameters().isEmpty() && !(instanceOf.getClazz() instanceof J.ParameterizedType)) { + TypedTree oldTypeTree = (TypedTree) instanceOf.getClazz(); + + // Each type parameter is turned into a wildcard, i.e. `List` -> `List` or `Map.Entry` -> `Map.Entry` + List wildcardsList = IntStream.range(0, fqType.getTypeParameters().size()) + .mapToObj(i -> new J.Wildcard(randomId(), Space.EMPTY, Markers.EMPTY, null, null)) + .collect(Collectors.toList()); + + J.ParameterizedType newTypeTree = new J.ParameterizedType( + randomId(), + oldTypeTree.getPrefix(), + Markers.EMPTY, + oldTypeTree.withPrefix(Space.EMPTY), + null, + oldTypeTree.getType() + ).withTypeParameters(wildcardsList); + result = result.withClazz(newTypeTree); + } // update entry in replacements to share the pattern variable name for (Map.Entry entry : replacements.entrySet()) { diff --git a/src/test/java/org/openrewrite/staticanalysis/InstanceOfPatternMatchTest.java b/src/test/java/org/openrewrite/staticanalysis/InstanceOfPatternMatchTest.java index c0b668e0c..9ea524856 100644 --- a/src/test/java/org/openrewrite/staticanalysis/InstanceOfPatternMatchTest.java +++ b/src/test/java/org/openrewrite/staticanalysis/InstanceOfPatternMatchTest.java @@ -133,6 +133,64 @@ void test(Object o) { ); } + @Test + void genericsWithoutParameters() { + rewriteRun( + //language=java + java( + """ + import java.util.Collections; + import java.util.List; + import java.util.Map; + import java.util.stream.Collectors; + public class A { + @SuppressWarnings("unchecked") + public static List> applyRoutesType(Object routes) { + if (routes instanceof List) { + List routesList = (List) routes; + if (routesList.isEmpty()) { + return Collections.emptyList(); + } + if (routesList.stream() + .anyMatch(route -> !(route instanceof Map))) { + return Collections.emptyList(); + } + return routesList.stream() + .map(route -> (Map) route) + .collect(Collectors.toList()); + } + return Collections.emptyList(); + } + } + """, + """ + import java.util.Collections; + import java.util.List; + import java.util.Map; + import java.util.stream.Collectors; + public class A { + @SuppressWarnings("unchecked") + public static List> applyRoutesType(Object routes) { + if (routes instanceof List routesList) { + if (routesList.isEmpty()) { + return Collections.emptyList(); + } + if (routesList.stream() + .anyMatch(route -> !(route instanceof Map))) { + return Collections.emptyList(); + } + return routesList.stream() + .map(route -> (Map) route) + .collect(Collectors.toList()); + } + return Collections.emptyList(); + } + } + """ + ) + ); + } + @Test void primitiveArray() { rewriteRun( @@ -245,7 +303,7 @@ void test(Object o) { public class A { void test(Object o) { Map.Entry entry = null; - if (o instanceof Map.Entry entry1) { + if (o instanceof Map.Entry entry1) { entry = entry1; } System.out.println(entry); @@ -700,7 +758,7 @@ Object test(Object o) { import java.util.List; public class A { Object test(Object o) { - return o instanceof List l ? l.get(0) : o.toString(); + return o instanceof List l ? l.get(0) : o.toString(); } } """ @@ -725,7 +783,7 @@ Object test(Object o) { import java.util.List; public class A { Object test(Object o) { - return o instanceof List l ? l.get(0) : o.toString(); + return o instanceof List l ? l.get(0) : o.toString(); } } """