Skip to content

Commit

Permalink
ojAlgo v47
Browse files Browse the repository at this point in the history
  • Loading branch information
apete committed Jan 12, 2019
1 parent 0e9a579 commit afeec4e
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 21 deletions.
2 changes: 1 addition & 1 deletion engine/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@
<dependency>
<groupId>org.ojalgo</groupId>
<artifactId>ojalgo</artifactId>
<version>45.1.1</version>
<version>47.0.0</version>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
import java.util.Iterator;
import java.util.List;

import org.ojalgo.matrix.BasicMatrix;
import static org.ojalgo.function.PrimitiveFunction.*;
import org.ojalgo.matrix.PrimitiveMatrix;
import org.ojalgo.matrix.PrimitiveMatrix.DenseReceiver;
import org.springframework.stereotype.Component;

import io.seldon.engine.exception.APIException;
Expand Down Expand Up @@ -47,8 +48,7 @@ public SeldonMessage aggregate(List<SeldonMessage> outputs, PredictiveUnitState
if (shape.length!=2){
throw new APIException(APIException.ApiExceptionType.ENGINE_INVALID_COMBINER_RESPONSE, String.format("Combiner received data that is not 2 dimensional"));
}
BasicMatrix.Factory<PrimitiveMatrix> matrixFactory = PrimitiveMatrix.FACTORY;
PrimitiveMatrix currentSum = matrixFactory.makeZero(shape[0], shape[1]);
DenseReceiver currentSum = PrimitiveMatrix.FACTORY.makeDense(shape[0], shape[1]);
SeldonMessage.Builder respBuilder = SeldonMessage.newBuilder();

for (Iterator<SeldonMessage> i = outputs.iterator(); i.hasNext();)
Expand All @@ -67,12 +67,11 @@ public SeldonMessage aggregate(List<SeldonMessage> outputs, PredictiveUnitState
if (inputShape[1] != shape[1]){
throw new APIException(APIException.ApiExceptionType.ENGINE_INVALID_COMBINER_RESPONSE, String.format("Expected batch length %d but found %d",shape[1],inputShape[1]));
}
PrimitiveMatrix inputArr = PredictorUtils.getOJMatrix(inputData);
currentSum = currentSum.add(inputArr);
PredictorUtils.add(inputData, currentSum);
}
currentSum = currentSum.divide((float)outputs.size());
currentSum.modifyAll(DIVIDE.by(outputs.size()));

DefaultData newData = PredictorUtils.updateData(outputs.get(0).getData(), currentSum);
DefaultData newData = PredictorUtils.updateData(outputs.get(0).getData(), currentSum.get());
respBuilder.setData(newData);
respBuilder.setMeta(outputs.get(0).getMeta());
respBuilder.setStatus(outputs.get(0).getStatus());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import java.util.Iterator;
import java.util.List;

import org.ojalgo.matrix.BasicMatrix;
import org.ojalgo.matrix.PrimitiveMatrix;

import com.google.protobuf.ListValue;
Expand Down Expand Up @@ -65,8 +64,38 @@ else if (data.getDataOneofCase() == DataOneofCase.NDARRAY)
return null;
}

public static void add(DefaultData data, PrimitiveMatrix.DenseReceiver receiver) {

if (data.getDataOneofCase() == DataOneofCase.TENSOR) {

List<Double> valuesList = data.getTensor().getValuesList();
List<Integer> shapeList = data.getTensor().getShapeList();

int rows = shapeList.get(0);
int columns = shapeList.get(1);

for (int i = 0; i < rows * columns; i++) {
receiver.add(i / columns, i % columns, valuesList.get(i));
}

} else if (data.getDataOneofCase() == DataOneofCase.NDARRAY) {

ListValue list = data.getNdarray();

int rows = list.getValuesCount();
int cols = list.getValues(0).getListValue().getValuesCount();

for (int i = 0; i < rows; ++i) {
ListValue rowListValue = list.getValues(i).getListValue();
for (int j = 0; j < cols; j++) {
receiver.add(i, j, rowListValue.getValues(j).getNumberValue());
}
}
}
}

public static PrimitiveMatrix getOJMatrix(DefaultData data){
BasicMatrix.Factory<PrimitiveMatrix> matrixFactory = PrimitiveMatrix.FACTORY;
PrimitiveMatrix.Factory matrixFactory = PrimitiveMatrix.FACTORY;
if (data.getDataOneofCase() == DataOneofCase.TENSOR)
{

Expand Down Expand Up @@ -101,9 +130,6 @@ else if (data.getDataOneofCase() == DataOneofCase.NDARRAY)
return null;
}




public static int[] getShape(DefaultData data){
if (data.getDataOneofCase() == DataOneofCase.TENSOR){
List<Integer> shapeList = data.getTensor().getShapeList();
Expand Down Expand Up @@ -135,26 +161,29 @@ public static DefaultData updateData(DefaultData oldData, PrimitiveMatrix newDat
// index++;
// }

int rows = (int) newData.countRows();
int cols = (int) newData.countColumns();

if (oldData.getDataOneofCase() == DataOneofCase.TENSOR){
Tensor.Builder tBuilder = Tensor.newBuilder();

tBuilder.addShape((int)newData.countRows());
tBuilder.addShape((int)newData.countColumns());
tBuilder.addShape(rows);
tBuilder.addShape(cols);

for (int i=0; i<newData.countRows(); ++i){
for (int j=0; j<newData.countColumns(); ++j){
tBuilder.addValues(newData.get(i,j));
for (int i = 0; i < rows; ++i) {
for (int j = 0; j < cols; ++j) {
tBuilder.addValues(newData.doubleValue(i, j));
}
}
dataBuilder.setTensor(tBuilder);
return dataBuilder.build();
}
else if (oldData.getDataOneofCase() == DataOneofCase.NDARRAY){
ListValue.Builder b1 = ListValue.newBuilder();
for (int i = 0; i < newData.countRows(); ++i) {
for (int i = 0; i < rows; ++i) {
ListValue.Builder b2 = ListValue.newBuilder();
for (int j = 0; j < newData.countColumns(); j++){
b2.addValues(Value.newBuilder().setNumberValue(newData.get(i,j)));
for (int j = 0; j < cols; j++) {
b2.addValues(Value.newBuilder().setNumberValue(newData.doubleValue(i, j)));
}
b1.addValues(Value.newBuilder().setListValue(b2.build()));
}
Expand Down

0 comments on commit afeec4e

Please sign in to comment.