Skip to content

Commit

Permalink
updated unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
amoldavsky committed May 18, 2021
1 parent 0e6dd8d commit 23fa285
Show file tree
Hide file tree
Showing 8 changed files with 186 additions and 16 deletions.
2 changes: 1 addition & 1 deletion incubating/wrappers/java/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xs
<groupId>io.seldon.wrapper</groupId>
<artifactId>seldon-core-wrapper</artifactId>
<packaging>jar</packaging>
<version>0.4.1</version>
<version>0.4.2</version>
<name>Seldon Core Java Wrapper</name>
<url>http://maven.apache.org</url>
<description>Wrapper for seldon-core Java prediction models.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public class ModelPredictionController {
@Autowired SeldonPredictionService predictionService;

/**
* Will access a POST or a GET request with either a query parameter or a FORM parameter.
* Will accept a POST or a GET request with either a query parameter or a FORM parameter.
*
* Examples:
* GET -> /predict?json={ ... }
Expand All @@ -58,6 +58,22 @@ public ResponseEntity<String> predictLegacy(@RequestParam("json") String json) {
return this.predict(json);
}

/**
* Will accept a POST with a proper JSON body.
*
* Examples:
* POST -> /predict
* curl -s -X POST \
* -d '{"data": {"names": ["a", "b"], "ndarray": [[1.0, 2.0]]}}' \
* localhost:9000/predict
*
* curl -s -X POST \
* -d '{"jsonData": {"foo": "bar"}' \
* localhost:9000/predict
*
* @param jsonStr
* @return
*/
@RequestMapping(
value = "/predict",
method = {RequestMethod.POST},
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
package io.seldon.wrapper.api;

import static io.seldon.wrapper.util.TestUtils.readFile;
import static org.hamcrest.Matchers.*;
import static org.junit.Assert.assertNotNull;
import static org.springframework.test.web.servlet.result.MockMvcResultHandlers.print;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;

import io.seldon.protos.PredictionProtos;
import java.nio.charset.StandardCharsets;
import org.junit.Assert;
import org.junit.Before;
import org.junit.FixMethodOrder;
import org.junit.Test;
import org.junit.jupiter.api.MethodOrderer;
import org.junit.jupiter.api.TestMethodOrder;
import org.junit.runner.RunWith;
import org.junit.runners.MethodSorters;
import org.springframework.beans.factory.annotation.Autowired;
Expand Down Expand Up @@ -47,52 +49,63 @@ public void setup() {

@Test
public void testPredictLegacyGetQuery() throws Exception {
final String predictJson = readFile("src/test/resources/request.json", StandardCharsets.UTF_8);
final String predictJson = TestMessages.DEFAULT_DATA;
assertNotNull(predictJson);

MvcResult res =
mvc.perform(
MockMvcRequestBuilders.get("/predict")
.accept(MediaType.APPLICATION_JSON_UTF8)
.param("json", predictJson)
.contentType(MediaType.APPLICATION_JSON_UTF8))
.andReturn();

String response = res.getResponse().getContentAsString();
System.out.println(response);
Assert.assertEquals(200, res.getResponse().getStatus());
}

@Test
public void testPredictLegacyPostQuery() throws Exception {
final String predictJson = readFile("src/test/resources/request.json", StandardCharsets.UTF_8);
final String predictJson = TestMessages.DEFAULT_DATA;
assertNotNull(predictJson);

MvcResult res =
mvc.perform(
MockMvcRequestBuilders.post("/predict")
.accept(MediaType.APPLICATION_JSON_UTF8)
.queryParam("json", predictJson)
.contentType(MediaType.APPLICATION_JSON_UTF8))
.andReturn();

String response = res.getResponse().getContentAsString();
System.out.println(response);
Assert.assertEquals(200, res.getResponse().getStatus());
}

@Test
public void testPredictLegacyPostForm() throws Exception {
final String predictJson = readFile("src/test/resources/request.json", StandardCharsets.UTF_8);
final String predictJson = TestMessages.DEFAULT_DATA;
assertNotNull(predictJson);

MvcResult res =
mvc.perform(
MockMvcRequestBuilders.post("/predict")
.accept(MediaType.APPLICATION_JSON_UTF8)
.param("json", predictJson)
.contentType(MediaType.APPLICATION_FORM_URLENCODED))
.andReturn();

String response = res.getResponse().getContentAsString();
System.out.println(response);
Assert.assertEquals(200, res.getResponse().getStatus());
}

@Test
public void testPredictLegacyButNotPredict() throws Exception {
final String predictJson = readFile("src/test/resources/request.json", StandardCharsets.UTF_8);
final String predictJson = TestMessages.DEFAULT_DATA;
assertNotNull(predictJson);

MvcResult res =
mvc.perform(
MockMvcRequestBuilders.post("/predict")
Expand All @@ -101,6 +114,7 @@ public void testPredictLegacyButNotPredict() throws Exception {
.content(predictJson)
.contentType(MediaType.APPLICATION_FORM_URLENCODED))
.andReturn();

String response = res.getResponse().getContentAsString();
System.out.println(response);
Assert.assertEquals(200, res.getResponse().getStatus());
Expand All @@ -111,7 +125,9 @@ public void testPredictLegacyButNotPredict() throws Exception {

@Test
public void testPredictButNotPredictLegacy() throws Exception {
final String predictJson = readFile("src/test/resources/request.json", StandardCharsets.UTF_8);
final String predictJson = TestMessages.DEFAULT_DATA;
assertNotNull(predictJson);

MvcResult res =
mvc.perform(
MockMvcRequestBuilders.post("/predict")
Expand All @@ -129,14 +145,17 @@ public void testPredictButNotPredictLegacy() throws Exception {

@Test
public void testPredict() throws Exception {
final String predictJson = readFile("src/test/resources/request.json", StandardCharsets.UTF_8);
final String predictJson = TestMessages.DEFAULT_DATA;
assertNotNull(predictJson);

MvcResult res =
mvc.perform(
MockMvcRequestBuilders.post("/predict")
.accept(MediaType.APPLICATION_JSON)
.content(predictJson)
.contentType(MediaType.APPLICATION_JSON))
.andReturn();

String response = res.getResponse().getContentAsString();
System.out.println(response);
Assert.assertEquals(200, res.getResponse().getStatus());
Expand All @@ -145,13 +164,16 @@ public void testPredict() throws Exception {
@Test
public void testPredictWithUTF8Header() throws Exception {
final String predictJson = readFile("src/test/resources/request.json", StandardCharsets.UTF_8);
assertNotNull(predictJson);

MvcResult res =
mvc.perform(
MockMvcRequestBuilders.post("/predict")
.accept(MediaType.APPLICATION_JSON)
.content(predictJson)
.contentType(MediaType.APPLICATION_JSON_UTF8))
.andReturn();

String response = res.getResponse().getContentAsString();
System.out.println(response);
Assert.assertEquals(200, res.getResponse().getStatus());
Expand All @@ -160,16 +182,62 @@ public void testPredictWithUTF8Header() throws Exception {
Assert.assertEquals(res.getResponse().getContentType(), MediaType.APPLICATION_JSON_VALUE);
}

@Test
public void testPredictWithDefaultData() throws Exception {
final String predictJson = TestMessages.DEFAULT_DATA;
assertNotNull(predictJson);

MvcResult res =
mvc.perform(
MockMvcRequestBuilders.post("/predict")
.accept(MediaType.APPLICATION_JSON)
.content(predictJson)
.contentType(MediaType.APPLICATION_JSON_UTF8)
)
.andDo(print())
.andExpect(status().isOk())
.andExpect(jsonPath("$.data", is(notNullValue())))
.andReturn();

// if we get back a header of "application/json" then we are hitting the legacy predict
Assert.assertEquals(res.getResponse().getContentType(), MediaType.APPLICATION_JSON_VALUE);
}

@Test
public void testPredictWithJsonData_UTF8Header() throws Exception {
final String predictJson = TestMessages.JSON_DATA;
assertNotNull(predictJson);

MvcResult res =
mvc.perform(
MockMvcRequestBuilders.post("/predict")
.accept(MediaType.APPLICATION_JSON)
.content(predictJson)
.contentType(MediaType.APPLICATION_JSON_UTF8))
.andExpect(status().isOk())
.andExpect(jsonPath("$.jsonData", is(notNullValue())))
.andReturn();

String response = res.getResponse().getContentAsString();
System.out.println(response);

// if we get back a header of "application/json" then we are hitting the legacy predict
Assert.assertEquals(res.getResponse().getContentType(), MediaType.APPLICATION_JSON_VALUE);
}

@Test
public void testFeedback() throws Exception {
final String predictJson = readFile("src/test/resources/feedback.json", StandardCharsets.UTF_8);
final String predictJson = TestMessages.DEFAULT_DATA;
assertNotNull(predictJson);

MvcResult res =
mvc.perform(
MockMvcRequestBuilders.get("/send-feedback")
.accept(MediaType.APPLICATION_JSON_UTF8)
.param("json", predictJson)
.contentType(MediaType.APPLICATION_JSON_UTF8))
.andReturn();

String response = res.getResponse().getContentAsString();
System.out.println(response);
Assert.assertEquals(200, res.getResponse().getStatus());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package io.seldon.wrapper.api;

import static io.seldon.wrapper.util.TestUtils.readFileFromAbsolutePathOrResources;

final public class TestMessages {

/**
* All possible fields based on the SeldonMessage Proto:
* https://docs.seldon.io/projects/seldon-core/en/v1.6.0/reference/apis/prediction.html
*/
public static final String TF_DATA = readFile("requests/defaultData.json");
public static final String DEFAULT_DATA = TF_DATA;
public static final String JSON_DATA = readFile("requests/jsonData.json");
// TODO: add binData
// TODO: add strData
// TODO: add customData

private static String readFile(String file) {
return readFileFromAbsolutePathOrResources(file);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,23 @@
import io.seldon.protos.PredictionProtos.DefaultData;
import io.seldon.protos.PredictionProtos.SeldonMessage;
import io.seldon.protos.PredictionProtos.Tensor;
import io.seldon.wrapper.pb.ProtoBufUtils;
import org.springframework.stereotype.Component;

@Component
public class TestPredictionService implements SeldonPredictionService {
@Override
public SeldonMessage predict(SeldonMessage payload) {
return SeldonMessage.newBuilder()
.setData(DefaultData.newBuilder().setTensor(Tensor.newBuilder().addShape(1).addValues(1.0)))
.build();
// echo payload back
return payload.toBuilder().build();
// SeldonMessage.Builder builder = SeldonMessage.newBuilder();
// ProtoBufUtils.updateMessageBuilderFromJson(builder, payload.getData());
// request = builder.build();
// SeldonMessage response = SeldonMessage.newBuilder();
// ProtoBufUtils.updateMessageBuilderFromJson(response, payload);
// return response;
// return SeldonMessage.newBuilder()
// .setData(DefaultData.newBuilder().setTensor(Tensor.newBuilder().addShape(1).addValues(1.0)))
// .build();
}
}
Original file line number Diff line number Diff line change
@@ -1,14 +1,51 @@
package io.seldon.wrapper.util;

import java.io.IOException;
import java.io.*;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;

public class TestUtils {

private static final ClassLoader classLoader = TestUtils.class.getClassLoader();

public static String readFile(String path, Charset encoding) throws IOException {
byte[] encoded = Files.readAllBytes(Paths.get(path));
return new String(encoded, encoding);
}

/**
* Will load file from either an absolute path of a relative path from "target/test-classes"
* @param file file path (ex: "requests/jsonData.json", "/dev/null")
* @return
*/
public static String readFileFromAbsolutePathOrResources(String file) {
try {
InputStream is = getInputStreamFromAbsolutePathOrResources(file, classLoader);
byte[] bytes = is.readAllBytes();
return new String(bytes, StandardCharsets.UTF_8);
} catch(Throwable t) {
System.out.println(t);
t.printStackTrace();
// nothing
}
return null;
}

public static InputStream getInputStreamFromAbsolutePathOrResources(String file, ClassLoader classLoader) {
InputStream is = null;

// try loading assuming an absolute path
try {
is = new FileInputStream(file);
} catch ( FileNotFoundException fne ) {
// Nothing
}
if( is == null ) {
is = classLoader.getResourceAsStream(file);
}

return is;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"data": {
"ndarray": [
[
1,
2
]
]
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"jsonData": {
"data": {
"subject": "helpful message",
"body": "nothing strange, good, rewarding"
},
"foo": "bar"
}
}

0 comments on commit 23fa285

Please sign in to comment.