Skip to content

Commit

Permalink
refactor and added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
amoldavsky committed May 17, 2021
1 parent c6a1cdc commit c35ff95
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,9 @@
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnExpression;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.bind.annotation.*;

@RestController
@ConditionalOnExpression("${seldon.api.model.enabled:false}")
Expand All @@ -32,19 +30,51 @@ public class ModelPredictionController {

@Autowired SeldonPredictionService predictionService;

/**
* Will access a POST or a GET request with either a query parameter or a FORM parameter.
*
* Examples:
* GET -> /predict?json={ ... }
* curl -s \
* localhost:9000/predict?json={"data": {"names": ["a", "b"], "ndarray": [[1.0, 2.0]]}}' \
*
* POST FORM -> /predict
* curl -s -X POST \
* -d 'json={"data": {"names": ["a", "b"], "ndarray": [[1.0, 2.0]]}}' \
* localhost:9000/predict
*
* @param json
* @return
* @deprecated
*/
@Deprecated
@RequestMapping(
value = "/predict",
method = {RequestMethod.GET, RequestMethod.POST},
produces = "application/json; charset=utf-8")
public ResponseEntity<String> predict(@RequestParam("json") String json) {
produces = MediaType.APPLICATION_JSON_UTF8_VALUE
)
public ResponseEntity<String> predictLegacy(@RequestParam("json") String json) {
return this.predict(json);
}

@RequestMapping(
value = "/predict",
method = {RequestMethod.POST},
consumes = {
MediaType.APPLICATION_JSON_VALUE,
MediaType.APPLICATION_JSON_UTF8_VALUE
},
produces = MediaType.APPLICATION_JSON_VALUE
)
public ResponseEntity<String> predict(@RequestBody String jsonStr) {
SeldonMessage request;
try {
SeldonMessage.Builder builder = SeldonMessage.newBuilder();
ProtoBufUtils.updateMessageBuilderFromJson(builder, json);
ProtoBufUtils.updateMessageBuilderFromJson(builder, jsonStr);
request = builder.build();
} catch (InvalidProtocolBufferException e) {
logger.error("Bad request", e);
throw new APIException(ApiExceptionType.WRAPPER_INVALID_MESSAGE, json);
throw new APIException(ApiExceptionType.WRAPPER_INVALID_MESSAGE, jsonStr);
}

try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public void setup() {
@LocalServerPort private int port;

@Test
public void testPredict() throws Exception {
public void testPredictLegacyGetQuery() throws Exception {
final String predictJson = readFile("src/test/resources/request.json", StandardCharsets.UTF_8);
MvcResult res =
mvc.perform(
Expand All @@ -54,6 +54,51 @@ public void testPredict() throws Exception {
Assert.assertEquals(200, res.getResponse().getStatus());
}

@Test
public void testPredictLegacyPostQuery() throws Exception {
final String predictJson = readFile("src/test/resources/request.json", StandardCharsets.UTF_8);
MvcResult res =
mvc.perform(
MockMvcRequestBuilders.post("/predict")
.accept(MediaType.APPLICATION_JSON_UTF8)
.param("json", predictJson)
.contentType(MediaType.APPLICATION_JSON_UTF8))
.andReturn();
String response = res.getResponse().getContentAsString();
System.out.println(response);
Assert.assertEquals(200, res.getResponse().getStatus());
}

@Test
public void testPredictLegacyPostForm() throws Exception {
final String predictJson = readFile("src/test/resources/request.json", StandardCharsets.UTF_8);
MvcResult res =
mvc.perform(
MockMvcRequestBuilders.post("/predict")
.accept(MediaType.APPLICATION_JSON_UTF8)
.param("json", predictJson)
.contentType(MediaType.APPLICATION_FORM_URLENCODED))
.andReturn();
String response = res.getResponse().getContentAsString();
System.out.println(response);
Assert.assertEquals(200, res.getResponse().getStatus());
}

@Test
public void testPredict() throws Exception {
final String predictJson = readFile("src/test/resources/request.json", StandardCharsets.UTF_8);
MvcResult res =
mvc.perform(
MockMvcRequestBuilders.post("/predict")
.accept(MediaType.APPLICATION_JSON)
.content(predictJson)
.contentType(MediaType.APPLICATION_JSON))
.andReturn();
String response = res.getResponse().getContentAsString();
System.out.println(response);
Assert.assertEquals(200, res.getResponse().getStatus());
}

@Test
public void testFeedback() throws Exception {
final String predictJson = readFile("src/test/resources/feedback.json", StandardCharsets.UTF_8);
Expand Down

0 comments on commit c35ff95

Please sign in to comment.