Skip to content

Commit

Permalink
Avoid nested constructor binding if there are no request parameters
Browse files Browse the repository at this point in the history
Closes gh-31821
  • Loading branch information
rstoyanchev committed Dec 13, 2023
1 parent 0970b1d commit ec0ec7a
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -948,7 +948,7 @@ private Object createObject(ResolvableType objectType, String nestedPath, ValueR
Class<?> paramType = paramTypes[i];
Object value = valueResolver.resolveValue(paramPath, paramType);

if (value == null && shouldConstructArgument(param)) {
if (value == null && shouldConstructArgument(param) && hasValuesFor(paramPath, valueResolver)) {
ResolvableType type = ResolvableType.forMethodParameter(param);
args[i] = createObject(type, paramPath + ".", valueResolver);
}
Expand Down Expand Up @@ -1022,6 +1022,15 @@ protected boolean shouldConstructArgument(MethodParameter param) {
type.getPackageName().startsWith("java."));
}

private boolean hasValuesFor(String paramPath, ValueResolver resolver) {
for (String name : resolver.getNames()) {
if (name.startsWith(paramPath + ".")) {
return true;
}
}
return false;
}

private void validateConstructorArgument(
Class<?> constructorClass, String nestedPath, String name, @Nullable Object value) {

Expand Down Expand Up @@ -1293,7 +1302,6 @@ public interface NameResolver {
* Strategy for {@link #construct constructor binding} to look up the values
* to bind to a given constructor parameter.
*/
@FunctionalInterface
public interface ValueResolver {

/**
Expand All @@ -1305,6 +1313,13 @@ public interface ValueResolver {
*/
@Nullable
Object resolveValue(String name, Class<?> type);

/**
* Return the names of all property values.
* @since 6.1.2
*/
Set<String> getNames();

}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@
import java.beans.ConstructorProperties;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import jakarta.validation.constraints.NotNull;
import org.junit.jupiter.api.Test;

import org.springframework.core.ResolvableType;
import org.springframework.format.support.DefaultFormattingConversionService;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

import static org.assertj.core.api.Assertions.assertThat;
Expand Down Expand Up @@ -76,6 +78,17 @@ void dataClassBindingWithMissingParameter() {
assertThat(bindingResult.getFieldValue("param3")).isNull();
}

@Test // gh-31821
void dataClassBindingWithNestedOptionalParameterWithMissingParameter() {
MapValueResolver valueResolver = new MapValueResolver(Map.of("param1", "value1"));
DataBinder binder = initDataBinder(NestedDataClass.class);
binder.construct(valueResolver);

NestedDataClass dataClass = getTarget(binder);
assertThat(dataClass.param1()).isEqualTo("value1");
assertThat(dataClass.nestedParam2()).isNull();
}

@Test
void dataClassBindingWithConversionError() {
MapValueResolver valueResolver = new MapValueResolver(Map.of("param1", "value1", "param2", "x"));
Expand All @@ -90,7 +103,7 @@ void dataClassBindingWithConversionError() {
}

@SuppressWarnings("SameParameterValue")
private static DataBinder initDataBinder(Class<DataClass> targetType) {
private static DataBinder initDataBinder(Class<?> targetType) {
DataBinder binder = new DataBinder(null);
binder.setTargetType(ResolvableType.forClass(targetType));
binder.setConversionService(new DefaultFormattingConversionService());
Expand Down Expand Up @@ -137,17 +150,45 @@ public int param3() {
}


private static class NestedDataClass {

private final String param1;

@Nullable
private final DataClass nestedParam2;

public NestedDataClass(String param1, @Nullable DataClass nestedParam2) {
this.param1 = param1;
this.nestedParam2 = nestedParam2;
}

public String param1() {
return this.param1;
}

@Nullable
public DataClass nestedParam2() {
return this.nestedParam2;
}
}


private static class MapValueResolver implements DataBinder.ValueResolver {

private final Map<String, Object> values;
private final Map<String, Object> map;

private MapValueResolver(Map<String, Object> values) {
this.values = values;
private MapValueResolver(Map<String, Object> map) {
this.map = map;
}

@Override
public Object resolveValue(String name, Class<?> type) {
return values.get(name);
return map.get(name);
}

@Override
public Set<String> getNames() {
return this.map.keySet();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
package org.springframework.web.bind;

import java.lang.reflect.Array;
import java.util.Enumeration;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

import jakarta.servlet.ServletRequest;
import jakarta.servlet.http.HttpServletRequest;
Expand Down Expand Up @@ -213,6 +216,9 @@ protected static class ServletRequestValueResolver implements ValueResolver {

private final WebDataBinder dataBinder;

@Nullable
private Set<String> parameterNames;

protected ServletRequestValueResolver(ServletRequest request, WebDataBinder dataBinder) {
this.request = request;
this.dataBinder = dataBinder;
Expand Down Expand Up @@ -261,6 +267,23 @@ else if (isFormDataPost(this.request)) {
}
return null;
}

@Override
public Set<String> getNames() {
if (this.parameterNames == null) {
this.parameterNames = initParameterNames(this.request);
}
return this.parameterNames;
}

protected Set<String> initParameterNames(ServletRequest request) {
Set<String> set = new LinkedHashSet<>();
Enumeration<String> enumeration = request.getParameterNames();
while (enumeration.hasMoreElements()) {
set.add(enumeration.nextElement());
}
return set;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

import reactor.core.publisher.Mono;
Expand Down Expand Up @@ -164,6 +165,11 @@ private record MapValueResolver(Map<String, Object> map) implements ValueResolve
public Object resolveValue(String name, Class<?> type) {
return this.map.get(name);
}

@Override
public Set<String> getNames() {
return this.map.keySet();
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.springframework.web.servlet.mvc.method.annotation;

import java.util.Map;
import java.util.Set;

import jakarta.servlet.ServletRequest;

Expand Down Expand Up @@ -121,6 +122,16 @@ protected Object getRequestParameter(String name, Class<?> type) {
}
return value;
}

@Override
protected Set<String> initParameterNames(ServletRequest request) {
Set<String> set = super.initParameterNames(request);
Map<String, String> uriVars = getUriVars(getRequest());
if (uriVars != null) {
set.addAll(uriVars.keySet());
}
return set;
}
}

}

0 comments on commit ec0ec7a

Please sign in to comment.