Skip to content

Commit

Permalink
Allow 'status' and 'error' to be excluded from error response
Browse files Browse the repository at this point in the history
Update `ErrorAttributeOptions` to allow the `status` and `error`
fields to be excluded from the response without throwing a
NullPointerException.

Fixes gh-30011
  • Loading branch information
philwebb committed Jun 17, 2024
1 parent 1f698d8 commit 60b7e6c
Show file tree
Hide file tree
Showing 9 changed files with 200 additions and 54 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2012-2023 the original author or authors.
* Copyright 2012-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -35,6 +35,7 @@
import org.springframework.http.HttpStatus;
import org.springframework.http.InvalidMediaTypeException;
import org.springframework.http.MediaType;
import org.springframework.util.Assert;
import org.springframework.util.MimeTypeUtils;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.server.RequestPredicate;
Expand Down Expand Up @@ -90,6 +91,8 @@ public class DefaultErrorWebExceptionHandler extends AbstractErrorWebExceptionHa
SERIES_VIEWS = Collections.unmodifiableMap(views);
}

private static final ErrorAttributeOptions ONLY_STATUS = ErrorAttributeOptions.of(Include.STATUS);

private final ErrorProperties errorProperties;

/**
Expand Down Expand Up @@ -117,13 +120,13 @@ protected RouterFunction<ServerResponse> getRoutingFunction(ErrorAttributes erro
* @return a {@code Publisher} of the HTTP response
*/
protected Mono<ServerResponse> renderErrorView(ServerRequest request) {
Map<String, Object> error = getErrorAttributes(request, getErrorAttributeOptions(request, MediaType.TEXT_HTML));
int errorStatus = getHttpStatus(error);
ServerResponse.BodyBuilder responseBody = ServerResponse.status(errorStatus).contentType(TEXT_HTML_UTF8);
return Flux.just(getData(errorStatus).toArray(new String[] {}))
.flatMap((viewName) -> renderErrorView(viewName, responseBody, error))
int status = getHttpStatus(getErrorAttributes(request, ONLY_STATUS));
Map<String, Object> errorAttributes = getErrorAttributes(request, MediaType.TEXT_HTML);
ServerResponse.BodyBuilder responseBody = ServerResponse.status(status).contentType(TEXT_HTML_UTF8);
return Flux.just(getData(status).toArray(new String[] {}))
.flatMap((viewName) -> renderErrorView(viewName, responseBody, errorAttributes))
.switchIfEmpty(this.errorProperties.getWhitelabel().isEnabled()
? renderDefaultErrorView(responseBody, error) : Mono.error(getError(request)))
? renderDefaultErrorView(responseBody, errorAttributes) : Mono.error(getError(request)))
.next();
}

Expand All @@ -144,10 +147,15 @@ private List<String> getData(int errorStatus) {
* @return a {@code Publisher} of the HTTP response
*/
protected Mono<ServerResponse> renderErrorResponse(ServerRequest request) {
Map<String, Object> error = getErrorAttributes(request, getErrorAttributeOptions(request, MediaType.ALL));
return ServerResponse.status(getHttpStatus(error))
int status = getHttpStatus(getErrorAttributes(request, ONLY_STATUS));
Map<String, Object> errorAttributes = getErrorAttributes(request, MediaType.ALL);
return ServerResponse.status(status)
.contentType(MediaType.APPLICATION_JSON)
.body(BodyInserters.fromValue(error));
.body(BodyInserters.fromValue(errorAttributes));
}

private Map<String, Object> getErrorAttributes(ServerRequest request, MediaType mediaType) {
return getErrorAttributes(request, getErrorAttributeOptions(request, mediaType));
}

protected ErrorAttributeOptions getErrorAttributeOptions(ServerRequest request, MediaType mediaType) {
Expand Down Expand Up @@ -215,7 +223,9 @@ protected boolean isIncludeBindingErrors(ServerRequest request, MediaType produc
* @return the error HTTP status
*/
protected int getHttpStatus(Map<String, Object> errorAttributes) {
return (int) errorAttributes.get("status");
Object status = errorAttributes.get("status");
Assert.state(status instanceof Integer, "ErrorAttributes must contain a status integer");
return (int) status;
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2012-2023 the original author or authors.
* Copyright 2012-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -25,9 +25,12 @@
import org.junit.jupiter.api.extension.ExtendWith;
import reactor.core.publisher.Mono;

import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.autoconfigure.context.PropertyPlaceholderAutoConfiguration;
import org.springframework.boot.autoconfigure.mustache.MustacheAutoConfiguration;
import org.springframework.boot.autoconfigure.web.ServerProperties;
import org.springframework.boot.autoconfigure.web.WebProperties;
import org.springframework.boot.autoconfigure.web.reactive.HttpHandlerAutoConfiguration;
import org.springframework.boot.autoconfigure.web.reactive.ReactiveWebServerFactoryAutoConfiguration;
import org.springframework.boot.autoconfigure.web.reactive.WebFluxAutoConfiguration;
Expand All @@ -36,12 +39,17 @@
import org.springframework.boot.test.system.CapturedOutput;
import org.springframework.boot.test.system.OutputCaptureExtension;
import org.springframework.boot.web.error.ErrorAttributeOptions;
import org.springframework.boot.web.error.ErrorAttributeOptions.Include;
import org.springframework.boot.web.reactive.error.DefaultErrorAttributes;
import org.springframework.boot.web.reactive.error.ErrorAttributes;
import org.springframework.boot.web.reactive.error.ErrorWebExceptionHandler;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.annotation.Order;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerCodecConfigurer;
import org.springframework.test.web.reactive.server.HttpHandlerConnector.FailureAfterResponseCompletedException;
import org.springframework.test.web.reactive.server.WebTestClient;
import org.springframework.web.bind.annotation.GetMapping;
Expand All @@ -50,6 +58,7 @@
import org.springframework.web.bind.annotation.ResponseBody;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.reactive.result.view.ViewResolver;
import org.springframework.web.server.ResponseStatusException;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
Expand Down Expand Up @@ -573,6 +582,21 @@ void defaultErrorAttributesSubclassWithoutDelegation() {
});
}

@Test
void customErrorWebExceptionHandlerWithoutStatus() {
this.contextRunner.withUserConfiguration(CustomErrorWebExceptionHandlerWithoutStatus.class).run((context) -> {
WebTestClient client = getWebClient(context);
client.get()
.uri("/badRequest")
.exchange()
.expectStatus()
.isBadRequest()
.expectBody()
.jsonPath("status")
.doesNotExist();
});
}

private String getErrorTemplatesLocation() {
String packageName = getClass().getPackage().getName();
return "classpath:/" + packageName.replace('.', '/') + "/templates/";
Expand Down Expand Up @@ -675,4 +699,29 @@ public Map<String, Object> getErrorAttributes(ServerRequest request, ErrorAttrib

}

static class CustomErrorWebExceptionHandlerWithoutStatus {

@Bean
@Order(-1)
ErrorWebExceptionHandler errorWebExceptionHandler(ServerProperties serverProperties,
ErrorAttributes errorAttributes, WebProperties webProperties,
ObjectProvider<ViewResolver> viewResolvers, ServerCodecConfigurer serverCodecConfigurer,
ApplicationContext applicationContext) {
DefaultErrorWebExceptionHandler exceptionHandler = new DefaultErrorWebExceptionHandler(errorAttributes,
webProperties.getResources(), serverProperties.getError(), applicationContext) {

@Override
protected ErrorAttributeOptions getErrorAttributeOptions(ServerRequest request, MediaType mediaType) {
return super.getErrorAttributeOptions(request, mediaType).excluding(Include.STATUS, Include.ERROR);
}

};
exceptionHandler.setViewResolvers(viewResolvers.orderedStream().toList());
exceptionHandler.setMessageWriters(serverCodecConfigurer.getWriters());
exceptionHandler.setMessageReaders(serverCodecConfigurer.getReaders());
return exceptionHandler;
}

}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2012-2023 the original author or authors.
* Copyright 2012-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,7 +18,6 @@

import java.util.Collections;
import java.util.List;
import java.util.Map;

import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;
Expand Down Expand Up @@ -55,7 +54,7 @@ class DefaultErrorWebExceptionHandlerTests {
@Test
void nonStandardErrorStatusCodeShouldNotFail() {
ErrorAttributes errorAttributes = mock(ErrorAttributes.class);
given(errorAttributes.getErrorAttributes(any(), any())).willReturn(getErrorAttributes());
given(errorAttributes.getErrorAttributes(any(), any())).willReturn(Collections.singletonMap("status", 498));
Resources resourceProperties = new Resources();
ErrorProperties errorProperties = new ErrorProperties();
ApplicationContext context = new AnnotationConfigReactiveWebApplicationContext();
Expand All @@ -67,10 +66,6 @@ void nonStandardErrorStatusCodeShouldNotFail() {
exceptionHandler.handle(exchange, new RuntimeException()).block();
}

private Map<String, Object> getErrorAttributes() {
return Collections.singletonMap("status", 498);
}

private void setupViewResolver(DefaultErrorWebExceptionHandler exceptionHandler) {
View view = mock(View.class);
given(view.render(any(), any(), any())).willReturn(Mono.empty());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,20 @@
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;

import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.ImportAutoConfiguration;
import org.springframework.boot.autoconfigure.context.PropertyPlaceholderAutoConfiguration;
import org.springframework.boot.autoconfigure.freemarker.FreeMarkerAutoConfiguration;
import org.springframework.boot.autoconfigure.http.HttpMessageConvertersAutoConfiguration;
import org.springframework.boot.autoconfigure.web.ServerProperties;
import org.springframework.boot.autoconfigure.web.servlet.DispatcherServletAutoConfiguration;
import org.springframework.boot.autoconfigure.web.servlet.ServletWebServerFactoryAutoConfiguration;
import org.springframework.boot.autoconfigure.web.servlet.WebMvcAutoConfiguration;
import org.springframework.boot.test.web.client.TestRestTemplate;
import org.springframework.boot.web.error.ErrorAttributeOptions;
import org.springframework.boot.web.error.ErrorAttributeOptions.Include;
import org.springframework.boot.web.servlet.error.ErrorAttributes;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
Expand Down Expand Up @@ -343,6 +348,18 @@ void testIncompatibleMediaType() {
assertThat(entity.getBody()).isNull();
}

@Test
@SuppressWarnings({ "rawtypes", "unchecked" })
void customErrorControllerWithoutStatusConfiguration() {
load(CustomErrorControllerWithoutStatusConfiguration.class);
RequestEntity request = RequestEntity.post(URI.create(createUrl("/bodyValidation")))
.accept(MediaType.APPLICATION_JSON)
.contentType(MediaType.APPLICATION_JSON)
.body("{}");
ResponseEntity<Map> entity = new TestRestTemplate().exchange(request, Map.class);
assertThat(entity.getBody()).doesNotContainKey("status");
}

private void assertErrorAttributes(Map<?, ?> content, String status, String error, Class<?> exception,
String message, String path) {
assertThat(content.get("status")).as("Wrong status").hasToString(status);
Expand All @@ -363,12 +380,16 @@ private String createUrl(String path) {
}

private void load(String... arguments) {
load(TestConfiguration.class, arguments);
}

private void load(Class<?> configuration, String... arguments) {
List<String> args = new ArrayList<>();
args.add("--server.port=0");
if (arguments != null) {
args.addAll(Arrays.asList(arguments));
}
this.context = SpringApplication.run(TestConfiguration.class, StringUtils.toStringArray(args));
this.context = SpringApplication.run(configuration, StringUtils.toStringArray(args));
}

@Target(ElementType.TYPE)
Expand All @@ -394,11 +415,13 @@ static void main(String[] args) {
@Bean
View error() {
return new AbstractView() {

@Override
protected void renderMergedOutputModel(Map<String, Object> model, HttpServletRequest request,
HttpServletResponse response) throws Exception {
response.getWriter().write("ERROR_BEAN");
}

};
}

Expand Down Expand Up @@ -498,4 +521,23 @@ void setContent(String content) {

}

static class CustomErrorControllerWithoutStatusConfiguration extends TestConfiguration {

@Bean
BasicErrorController basicErrorController(ServerProperties serverProperties, ErrorAttributes errorAttributes,
ObjectProvider<ErrorViewResolver> errorViewResolvers) {
return new BasicErrorController(errorAttributes, serverProperties.getError(),
errorViewResolvers.orderedStream().toList()) {

@Override
protected ErrorAttributeOptions getErrorAttributeOptions(HttpServletRequest request,
MediaType mediaType) {
return super.getErrorAttributeOptions(request, mediaType).excluding(Include.STATUS);
}

};
}

}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2012-2023 the original author or authors.
* Copyright 2012-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,6 +20,7 @@
import java.util.Collection;
import java.util.Collections;
import java.util.EnumSet;
import java.util.Map;
import java.util.Set;

/**
Expand Down Expand Up @@ -79,6 +80,19 @@ public ErrorAttributeOptions excluding(Include... excludes) {
return new ErrorAttributeOptions(Collections.unmodifiableSet(updated));
}

/**
* Remove elements from the given map if they are not included in this set of options.
* @param map the map to update
* @since 3.2.7
*/
public void retainIncluded(Map<String, Object> map) {
for (Include candidate : Include.values()) {
if (!this.includes.contains(candidate)) {
map.remove(candidate.key);
}
}
}

private EnumSet<Include> copyIncludes() {
return (this.includes.isEmpty()) ? EnumSet.noneOf(Include.class) : EnumSet.copyOf(this.includes);
}
Expand All @@ -88,7 +102,7 @@ private EnumSet<Include> copyIncludes() {
* @return an {@code ErrorAttributeOptions}
*/
public static ErrorAttributeOptions defaults() {
return of();
return of(Include.STATUS, Include.ERROR);
}

/**
Expand Down Expand Up @@ -120,22 +134,40 @@ public enum Include {
/**
* Include the exception class name attribute.
*/
EXCEPTION,
EXCEPTION("exception"),

/**
* Include the stack trace attribute.
*/
STACK_TRACE,
STACK_TRACE("trace"),

/**
* Include the message attribute.
*/
MESSAGE,
MESSAGE("message"),

/**
* Include the binding errors attribute.
*/
BINDING_ERRORS
BINDING_ERRORS("errors"),

/**
* Include the HTTP status code.
* @since 3.2.7
*/
STATUS("status"),

/**
* Include the HTTP status code.
* @since 3.2.7
*/
ERROR("error");

private final String key;

Include(String key) {
this.key = key;
}

}

Expand Down
Loading

0 comments on commit 60b7e6c

Please sign in to comment.