Skip to content

Commit

Permalink
add the support for multipart/form-data in python service and engine
Browse files Browse the repository at this point in the history
  • Loading branch information
lkuma37 committed Aug 3, 2019
1 parent db458fa commit 72d17b7
Show file tree
Hide file tree
Showing 18 changed files with 1,211 additions and 40 deletions.
4 changes: 4 additions & 0 deletions engine/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,10 @@
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-web</artifactId>
</dependency>
<dependency>
<groupId>com.h2database</groupId>
<artifactId>h2</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,16 @@
*******************************************************************************/
package io.seldon.engine.api.rest;

import java.io.IOException;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicBoolean;

import javax.annotation.PostConstruct;

import com.fasterxml.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
Expand All @@ -43,6 +48,8 @@
import io.seldon.engine.tracing.TracingProvider;
import io.seldon.protos.PredictionProtos.Feedback;
import io.seldon.protos.PredictionProtos.SeldonMessage;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.multipart.MultipartHttpServletRequest;

@RestController
public class RestClientController {
Expand Down Expand Up @@ -127,55 +134,105 @@ String unpause() {
@Timed
@CrossOrigin(origins = "*")
@RequestMapping(value = "/api/v0.1/predictions", method = RequestMethod.POST, consumes = "application/json; charset=utf-8", produces = "application/json; charset=utf-8")
public ResponseEntity<String> predictions(RequestEntity<String> requestEntity)
public ResponseEntity<String> predictions_json(RequestEntity<String> requestEntity)
{
logger.debug("Received predict request");
Scope tracingScope = null;
if (tracingProvider.isActive())
tracingScope = tracingProvider.getTracer().buildSpan("/api/v0.1/predictions").startActive(true);
try
{
SeldonMessage request;
try
{
SeldonMessage.Builder builder = SeldonMessage.newBuilder();
ProtoBufUtils.updateMessageBuilderFromJson(builder, requestEntity.getBody() );
request = builder.build();
}
catch (InvalidProtocolBufferException e)
{
logger.error("Bad request",e);
throw new APIException(ApiExceptionType.ENGINE_INVALID_JSON,requestEntity.getBody());
return _predictions(requestEntity.getBody());
}

try
finally
{
SeldonMessage response = predictionService.predict(request);
String responseJson = ProtoBufUtils.toJson(response);
return new ResponseEntity<String>(responseJson,HttpStatus.OK);
if (tracingScope != null)
tracingScope.close();
}
catch (InterruptedException e) {
throw new APIException(ApiExceptionType.ENGINE_INTERRUPTED,e.getMessage());
} catch (ExecutionException e) {
if (e.getCause().getClass() == APIException.class){
throw (APIException) e.getCause();

}


@Timed
@CrossOrigin(origins = "*")
@RequestMapping(value = "/api/v0.1/predictions", method = RequestMethod.POST, consumes = "multipart/form-data", produces = "application/json; charset=utf-8")
public ResponseEntity<String> predictions_multiform(MultipartHttpServletRequest requestEntity)
{
logger.debug("Received predict request");
Scope tracingScope = null;
if (tracingProvider.isActive())
tracingScope = tracingProvider.getTracer().buildSpan("/api/v0.1/predictions").startActive(true);
try {
ObjectMapper mapper = new ObjectMapper();
Map<String,Object> mergedParamMap = new HashMap<String,Object>();
if(requestEntity.getParameterMap() != null){
for(Map.Entry<String,String[]> formEntry: requestEntity.getParameterMap().entrySet()){
if(formEntry.getKey().equalsIgnoreCase(SeldonMessage.DataOneofCase.STRDATA.name())){
mergedParamMap.put(formEntry.getKey(),formEntry.getValue()[0]);
}else{
mergedParamMap.put(formEntry.getKey(),mapper.readTree(formEntry.getValue()[0]));
}
}
}
else
{
throw new APIException(ApiExceptionType.ENGINE_EXECUTION_FAILURE,e.getMessage());
if(requestEntity.getFileMap() != null){
for(Map.Entry<String ,MultipartFile> fileEntry: requestEntity.getFileMap().entrySet()){
if(fileEntry.getKey().equalsIgnoreCase(SeldonMessage.DataOneofCase.STRDATA.name())){
mergedParamMap.put(fileEntry.getKey(),new String(fileEntry.getValue().getBytes()));
}else{
mergedParamMap.put(fileEntry.getKey(),fileEntry.getValue().getBytes());
}
}
}
} catch (InvalidProtocolBufferException e) {
throw new APIException(ApiExceptionType.ENGINE_INVALID_JSON,"");
}
}
finally

return _predictions(mapper.writeValueAsString(mergedParamMap));
} catch (IOException e) {
logger.error("Bad request",e);
throw new APIException(ApiExceptionType.REQUEST_IO_EXCEPTION,e.getMessage());

} finally
{
if (tracingScope != null)
tracingScope.close();
}

}


private ResponseEntity<String> _predictions(String json)
{
SeldonMessage request;
try
{
SeldonMessage.Builder builder = SeldonMessage.newBuilder();
ProtoBufUtils.updateMessageBuilderFromJson(builder, json );
request = builder.build();
}
catch (InvalidProtocolBufferException e)
{
logger.error("Bad request",e);
throw new APIException(ApiExceptionType.ENGINE_INVALID_JSON,json);
}

try
{
SeldonMessage response = predictionService.predict(request);
String responseJson = ProtoBufUtils.toJson(response);
return new ResponseEntity<String>(responseJson,HttpStatus.OK);
}
catch (InterruptedException e) {
throw new APIException(ApiExceptionType.ENGINE_INTERRUPTED,e.getMessage());
} catch (ExecutionException e) {
if (e.getCause().getClass() == APIException.class){
throw (APIException) e.getCause();
}
else
{
throw new APIException(ApiExceptionType.ENGINE_EXECUTION_FAILURE,e.getMessage());
}
} catch (InvalidProtocolBufferException e) {
throw new APIException(ApiExceptionType.ENGINE_INVALID_JSON,"");
}
}

@Timed
@CrossOrigin(origins = "*")
@RequestMapping(value= "/api/v0.1/feedback", method = RequestMethod.POST, consumes = "application/json; charset=utf-8", produces = "application/json; charset=utf-8")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ public enum ApiExceptionType {
ENGINE_INVALID_COMBINER_RESPONSE(204,"Invalid number of predictions from combiner",500),
ENGINE_INTERRUPTED(205,"API call interrupted",500),
ENGINE_EXECUTION_FAILURE(206,"Execution failure",500),
ENGINE_INVALID_ROUTING(207,"Invalid Routing",500);
ENGINE_INVALID_ROUTING(207,"Invalid Routing",500),
REQUEST_IO_EXCEPTION(208,"IO Exception",500);

int id;
String message;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

import java.util.Arrays;

import com.google.protobuf.ByteString;
import io.seldon.protos.PredictionProtos;
import org.springframework.stereotype.Component;

import io.seldon.protos.PredictionProtos.DefaultData;
Expand All @@ -37,15 +39,25 @@ public SimpleModelUnit() {}

@Override
public SeldonMessage transformInput(SeldonMessage input, PredictiveUnitState state){
SeldonMessage output = SeldonMessage.newBuilder()
SeldonMessage.Builder builder = SeldonMessage.newBuilder()
.setStatus(Status.newBuilder().setStatus(Status.StatusFlag.SUCCESS).build())
.setMeta(Meta.newBuilder()
.addMetrics(Metric.newBuilder().setKey("mymetric_counter").setType(MetricType.COUNTER).setValue(1))
.addMetrics(Metric.newBuilder().setKey("mymetric_gauge").setType(MetricType.GAUGE).setValue(100))
.addMetrics(Metric.newBuilder().setKey("mymetric_timer").setType(MetricType.TIMER).setValue(22.1F)))
.setData(DefaultData.newBuilder().addAllNames(Arrays.asList(classes))
.addMetrics(Metric.newBuilder().setKey("mymetric_timer").setType(MetricType.TIMER).setValue(22.1F)));

// echo in case of strData and binData
if(input.getDataOneofCase().equals(SeldonMessage.DataOneofCase.BINDATA)){
builder.setBinData(input.getBinData());
} else if (input.getDataOneofCase().equals(SeldonMessage.DataOneofCase.STRDATA)){
builder.setStrData(input.getStrData());
}else{
builder.setData(DefaultData.newBuilder().addAllNames(Arrays.asList(classes))
.setTensor(Tensor.newBuilder().addShape(1).addShape(values.length)
.addAllValues(Arrays.asList(values)))).build();
.addAllValues(Arrays.asList(values))));
}

SeldonMessage output = builder.build();
System.out.println("Model " + state.name + " finishing computations");
return output;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,20 @@
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.context.SpringBootTest.WebEnvironment;
import org.springframework.http.MediaType;
import org.springframework.jmx.support.MetricType;
import org.springframework.test.context.junit4.SpringRunner;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult;
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders;
import org.springframework.test.web.servlet.setup.MockMvcBuilders;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.context.WebApplicationContext;

import io.seldon.engine.pb.ProtoBufUtils;
import io.seldon.protos.PredictionProtos.SeldonMessage;

import java.util.*;

@RunWith(SpringRunner.class)
@SpringBootTest(webEnvironment = WebEnvironment.RANDOM_PORT)
//@AutoConfigureMockMvc
Expand Down Expand Up @@ -135,4 +138,146 @@ public void testPredict_21dim_tensor() throws Exception
Assert.assertEquals("GAUGE", seldonMessage.getMeta().getMetrics(1).getType().toString());
Assert.assertEquals("TIMER", seldonMessage.getMeta().getMetrics(2).getType().toString());
}

@Test
public void testPredict_multiform_11dim_ndarry() throws Exception
{
final String predictJson = "{" +
"\"request\": {" +
"\"ndarray\": [[1.0]]}" +
"}";
final MultiValueMap<String,String> paramMap = new LinkedMultiValueMap<>();
paramMap.put("data", Arrays.asList(predictJson));
MvcResult res = mvc.perform(MockMvcRequestBuilders.post("/api/v0.1/predictions")
.accept(MediaType.APPLICATION_JSON_UTF8)
.params(paramMap)
.contentType(MediaType.MULTIPART_FORM_DATA)).andReturn();
String response = res.getResponse().getContentAsString();
System.out.println(response);
Assert.assertEquals(200, res.getResponse().getStatus());
}

@Test
public void testPredict_multiform_21dim_ndarry() throws Exception
{
final String predictJson = "{" +
"\"request\": {" +
"\"ndarray\": [[1.0],[2.0]]}" +
"}";
final MultiValueMap<String,String> paramMap = new LinkedMultiValueMap<>();
paramMap.put("data", Arrays.asList(predictJson));
MvcResult res = mvc.perform(MockMvcRequestBuilders.post("/api/v0.1/predictions")
.accept(MediaType.APPLICATION_JSON_UTF8)
.params(paramMap)
.contentType(MediaType.MULTIPART_FORM_DATA)).andReturn();
String response = res.getResponse().getContentAsString();
System.out.println(response);
Assert.assertEquals(200, res.getResponse().getStatus());
SeldonMessage.Builder builder = SeldonMessage.newBuilder();
ProtoBufUtils.updateMessageBuilderFromJson(builder, response );
SeldonMessage seldonMessage = builder.build();
Assert.assertEquals(3, seldonMessage.getMeta().getMetricsCount());
Assert.assertEquals("COUNTER", seldonMessage.getMeta().getMetrics(0).getType().toString());
Assert.assertEquals("GAUGE", seldonMessage.getMeta().getMetrics(1).getType().toString());
Assert.assertEquals("TIMER", seldonMessage.getMeta().getMetrics(2).getType().toString());
}

@Test
public void testPredict_multiform_21dim_tensor() throws Exception
{
final String predictJson = "{" +
"\"request\": {" +
"\"tensor\": {\"shape\":[2,1],\"values\":[1.0,2.0]}}" +
"}";
final MultiValueMap<String,String> paramMap = new LinkedMultiValueMap<>();
paramMap.put("data", Arrays.asList(predictJson));
MvcResult res = mvc.perform(MockMvcRequestBuilders.post("/api/v0.1/predictions")
.accept(MediaType.APPLICATION_JSON_UTF8)
.params(paramMap)
.contentType(MediaType.MULTIPART_FORM_DATA)).andReturn();
String response = res.getResponse().getContentAsString();
System.out.println(response);
Assert.assertEquals(200, res.getResponse().getStatus());
SeldonMessage.Builder builder = SeldonMessage.newBuilder();
ProtoBufUtils.updateMessageBuilderFromJson(builder, response );
SeldonMessage seldonMessage = builder.build();
Assert.assertEquals(3, seldonMessage.getMeta().getMetricsCount());
Assert.assertEquals("COUNTER", seldonMessage.getMeta().getMetrics(0).getType().toString());
Assert.assertEquals("GAUGE", seldonMessage.getMeta().getMetrics(1).getType().toString());
Assert.assertEquals("TIMER", seldonMessage.getMeta().getMetrics(2).getType().toString());
}
@Test
public void testPredict_multiform_binData() throws Exception
{
final String metaJson = "{\"puid\":\"1234\"}" ;
final MultiValueMap<String,String> paramMap = new LinkedMultiValueMap<>();
paramMap.put("meta", Arrays.asList(metaJson));
byte[] fileData = "test data".getBytes();
MvcResult res = mvc.perform(MockMvcRequestBuilders.fileUpload("/api/v0.1/predictions").file("binData",fileData)
.accept(MediaType.APPLICATION_JSON_UTF8)
.params(paramMap)
.contentType(MediaType.MULTIPART_FORM_DATA)).andReturn();
String response = res.getResponse().getContentAsString();
System.out.println(response);
Assert.assertEquals(200, res.getResponse().getStatus());
SeldonMessage.Builder builder = SeldonMessage.newBuilder();
ProtoBufUtils.updateMessageBuilderFromJson(builder, response );
SeldonMessage seldonMessage = builder.build();
Assert.assertEquals(3, seldonMessage.getMeta().getMetricsCount());
Assert.assertEquals("COUNTER", seldonMessage.getMeta().getMetrics(0).getType().toString());
Assert.assertEquals("GAUGE", seldonMessage.getMeta().getMetrics(1).getType().toString());
Assert.assertEquals("TIMER", seldonMessage.getMeta().getMetrics(2).getType().toString());
Assert.assertEquals(new String(fileData), seldonMessage.getBinData().toStringUtf8());
Assert.assertEquals("1234", seldonMessage.getMeta().getPuid());
}
@Test
public void testPredict_multiform_strData_as_file() throws Exception
{
final String metaJson = "{\"puid\":\"1234\"}" ;
final MultiValueMap<String,String> paramMap = new LinkedMultiValueMap<>();
paramMap.put("meta", Arrays.asList(metaJson));
byte[] fileData = "test data".getBytes();
MvcResult res = mvc.perform(MockMvcRequestBuilders.fileUpload("/api/v0.1/predictions").file("strData",fileData)
.accept(MediaType.APPLICATION_JSON_UTF8)
.params(paramMap)
.contentType(MediaType.MULTIPART_FORM_DATA)).andReturn();
String response = res.getResponse().getContentAsString();
System.out.println(response);
Assert.assertEquals(200, res.getResponse().getStatus());
SeldonMessage.Builder builder = SeldonMessage.newBuilder();
ProtoBufUtils.updateMessageBuilderFromJson(builder, response );
SeldonMessage seldonMessage = builder.build();
Assert.assertEquals(3, seldonMessage.getMeta().getMetricsCount());
Assert.assertEquals("COUNTER", seldonMessage.getMeta().getMetrics(0).getType().toString());
Assert.assertEquals("GAUGE", seldonMessage.getMeta().getMetrics(1).getType().toString());
Assert.assertEquals("TIMER", seldonMessage.getMeta().getMetrics(2).getType().toString());
Assert.assertEquals(new String(fileData), seldonMessage.getStrData());
Assert.assertEquals("1234", seldonMessage.getMeta().getPuid());

}
@Test
public void testPredict_multiform_strData_as_text() throws Exception
{
final String metaJson = "{\"puid\":\"1234\"}" ;
final MultiValueMap<String,String> paramMap = new LinkedMultiValueMap<>();
paramMap.put("meta", Arrays.asList(metaJson));
String strdata = "test data";
paramMap.put("strData",Arrays.asList(strdata));
MvcResult res = mvc.perform(MockMvcRequestBuilders.post("/api/v0.1/predictions")
.accept(MediaType.APPLICATION_JSON_UTF8)
.params(paramMap)
.contentType(MediaType.MULTIPART_FORM_DATA)).andReturn();
String response = res.getResponse().getContentAsString();
System.out.println(response);
Assert.assertEquals(200, res.getResponse().getStatus());
SeldonMessage.Builder builder = SeldonMessage.newBuilder();
ProtoBufUtils.updateMessageBuilderFromJson(builder, response );
SeldonMessage seldonMessage = builder.build();
Assert.assertEquals(3, seldonMessage.getMeta().getMetricsCount());
Assert.assertEquals("COUNTER", seldonMessage.getMeta().getMetrics(0).getType().toString());
Assert.assertEquals("GAUGE", seldonMessage.getMeta().getMetrics(1).getType().toString());
Assert.assertEquals("TIMER", seldonMessage.getMeta().getMetrics(2).getType().toString());
Assert.assertEquals(strdata, seldonMessage.getStrData());
Assert.assertEquals("1234", seldonMessage.getMeta().getPuid());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
MODEL_NAME=IrisClassifier
API_TYPE=REST
SERVICE_TYPE=MODEL
PERSISTENCE=0
Loading

0 comments on commit 72d17b7

Please sign in to comment.