Skip to content

Commit

Permalink
Cleaned up code
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Apr 20, 2024
1 parent f90041c commit b9c2734
Show file tree
Hide file tree
Showing 16 changed files with 42 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import com.google.common.base.Strings;
import com.google.common.collect.LinkedHashMultimap;
Expand Down Expand Up @@ -64,13 +65,13 @@ public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encod

Object missingCategory = null;

if((BaseNEncoder.HANDLEMISSING_VALUE).equals(handleMissing)){
if(Objects.equals(BaseNEncoder.HANDLEMISSING_VALUE, handleMissing)){
missingCategory = BaseEncoder.CATEGORY_NAN;
}

Integer defaultValue = null;

if((BaseNEncoder.HANDLEUNKNOWN_VALUE).equals(handleUnknown)){
if(Objects.equals(BaseNEncoder.HANDLEUNKNOWN_VALUE, handleUnknown)){
defaultValue = 0;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

import com.google.common.base.Functions;
Expand Down Expand Up @@ -63,13 +64,13 @@ public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encod

Object missingCategory = null;

if((CountEncoder.HANDLEMISSING_COUNT).equals(handleMissing) || (CountEncoder.HANDLEMISSING_VALUE).equals(handleMissing)){
if(Objects.equals(CountEncoder.HANDLEMISSING_COUNT, handleMissing) || Objects.equals(CountEncoder.HANDLEMISSING_VALUE, handleMissing)){
missingCategory = BaseEncoder.CATEGORY_NAN;
}

Integer defaultValue = null;

if((CountEncoder.HANDLEUNKNOWN_VALUE).equals(handleUnknown)){
if(Objects.equals(CountEncoder.HANDLEUNKNOWN_VALUE, handleUnknown)){
defaultValue = getDefaultValue();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.BiFunction;

import com.google.common.base.Functions;
Expand Down Expand Up @@ -63,13 +64,13 @@ public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encod

Object missingCategory = null;

if((MeanEncoder.HANDLEMISSING_VALUE).equals(handleMissing)){
if(Objects.equals(MeanEncoder.HANDLEMISSING_VALUE, handleMissing)){
missingCategory = BaseEncoder.CATEGORY_NAN;
}

Number defaultValue = null;

if((MeanEncoder.HANDLEUNKNOWN_VALUE).equals(handleUnknown)){
if(Objects.equals(MeanEncoder.HANDLEUNKNOWN_VALUE, handleUnknown)){
defaultValue = getMean();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import com.google.common.base.Function;
import com.google.common.base.Functions;
Expand Down Expand Up @@ -52,13 +53,13 @@ public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encod

Integer mapMissingTo = null;

if((OrdinalEncoder.HANDLEMISSING_VALUE).equals(handleMissing)){
if(Objects.equals(OrdinalEncoder.HANDLEMISSING_VALUE, handleMissing)){
mapMissingTo = -1;
}

Integer defaultValue = null;

if((OrdinalEncoder.HANDLEUNKNOWN_VALUE).equals(handleUnknown)){
if(Objects.equals(OrdinalEncoder.HANDLEUNKNOWN_VALUE, handleUnknown)){
defaultValue = -2;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import numpy.core.ScalarUtil;
import org.dmg.pmml.Field;
Expand Down Expand Up @@ -55,13 +56,13 @@ public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encod

Object missingCategory = null;

if((OrdinalMapEncoder.HANDLEMISSING_VALUE).equals(handleMissing)){
if(Objects.equals(OrdinalMapEncoder.HANDLEMISSING_VALUE, handleMissing)){
missingCategory = BaseEncoder.CATEGORY_NAN;
}

Object unknownCategory = null;

if((OrdinalMapEncoder.HANDLEUNKNOWN_VALUE).equals(handleUnknown)){
if(Objects.equals(OrdinalMapEncoder.HANDLEUNKNOWN_VALUE, handleUnknown)){
unknownCategory = OrdinalEncoder.CATEGORY_UNKNOWN;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

import org.dmg.pmml.PMML;
import org.jpmml.converter.Feature;
Expand Down Expand Up @@ -106,11 +107,11 @@ protected Transformer getTransformer(Object[] fittedTransformer){
@Override
public Transformer apply(Object object){

if((SkLearnSteps.DROP).equals(object)){
if(Objects.equals(SkLearnSteps.DROP, object)){
return Drop.INSTANCE;
} else

if((SkLearnSteps.PASSTHROUGH).equals(object)){
if(Objects.equals(SkLearnSteps.PASSTHROUGH, object)){
return PassThrough.INSTANCE;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

import com.google.common.primitives.Booleans;
import org.jpmml.converter.ValueUtil;
Expand All @@ -43,7 +44,7 @@ public List<Boolean> getSupportMask(){
Object k = getK();
List<Number> scores = getScores();

if(("all").equals(k)){
if(Objects.equals("all", k)){
return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
Expand Down Expand Up @@ -59,7 +60,7 @@ public Model encodeModel(Schema schema){
String sklearnVersion = getSkLearnVersion();
String multiClass = getMultiClass();

if((LogisticRegression.MULTICLASS_AUTO).equals(multiClass)){
if(Objects.equals(LogisticRegression.MULTICLASS_AUTO, multiClass)){
int[] shape = getCoefShape();
String solver = getSolver();

Expand Down Expand Up @@ -166,7 +167,7 @@ public String getMultiClass(){
String multiClass = getEnum("multi_class", this::getString, Arrays.asList(LogisticRegression.MULTICLASS_AUTO, LogisticRegression.MULTICLASS_MULTINOMIAL, LogisticRegression.MULTICLASS_OVR, LogisticRegression.MULTICLASS_WARN));

// SkLearn 0.20
if((LogisticRegression.MULTICLASS_WARN).equals(multiClass)){
if(Objects.equals(LogisticRegression.MULTICLASS_WARN, multiClass)){
multiClass = LogisticRegression.MULTICLASS_OVR;
}

Expand All @@ -182,7 +183,7 @@ private String getAutoMultiClass(String solver, int[] shape){
int numberOfClasses = shape[0];
int numberOfFeatures = shape[1];

if((LogisticRegression.SOLVER_LIBLINEAR).equals(solver)){
if(Objects.equals(LogisticRegression.SOLVER_LIBLINEAR, solver)){
return LogisticRegression.MULTICLASS_OVR;
} // End if

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package sklearn.pipeline;

import java.util.List;
import java.util.Objects;

import com.google.common.collect.Lists;
import net.razorvine.pickle.objects.ClassDict;
Expand Down Expand Up @@ -80,7 +81,7 @@ public boolean hasFinalEstimator(){

Object step = TupleUtil.extractElement(finalStep, 1);

if((step == null) || (SkLearnSteps.PASSTHROUGH).equals(step)){
if((step == null) || Objects.equals(SkLearnSteps.PASSTHROUGH, step)){
return false;
} // End if

Expand Down Expand Up @@ -130,7 +131,7 @@ public List<? extends Transformer> getTransformers(){
@Override
public Transformer apply(Object object){

if((object == null) || (SkLearnSteps.PASSTHROUGH).equals(object)){
if((object == null) || Objects.equals(SkLearnSteps.PASSTHROUGH, object)){
return PassThrough.INSTANCE;
}

Expand Down Expand Up @@ -163,7 +164,7 @@ public <E extends Estimator> E getFinalEstimator(Class<? extends E> clazz){

Object step = TupleUtil.extractElement(finalStep, 1);

if((step == null) || (SkLearnSteps.PASSTHROUGH).equals(step)){
if((step == null) || Objects.equals(SkLearnSteps.PASSTHROUGH, step)){
throw new SkLearnException("The pipeline ends with a transformer-like object");
}

Expand Down Expand Up @@ -195,7 +196,7 @@ public Step getHead(){
@Override
public Step apply(Object object){

if((object == null) || (SkLearnSteps.PASSTHROUGH).equals(object)){
if((object == null) || Objects.equals(SkLearnSteps.PASSTHROUGH, object)){
return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package sklearn.preprocessing;

import java.util.List;
import java.util.Objects;

import sklearn.impute.SimpleImputer;

Expand All @@ -32,7 +33,7 @@ public Imputer(String module, String name){
public Object getMissingValues(){
Object missingValues = super.getMissingValues();

if(("NaN").equals(missingValues)){
if(Objects.equals("NaN", missingValues)){
missingValues = null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

import com.google.common.collect.ContiguousSet;
import com.google.common.collect.DiscreteDomain;
Expand Down Expand Up @@ -107,7 +108,7 @@ public List<Number> getValues(){

Object numberOfValues = getOptionalScalar("n_values");

if(("auto").equals(numberOfValues)){
if(Objects.equals("auto", numberOfValues)){
return getActiveFeatures();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;

import org.dmg.pmml.DataType;
import org.dmg.pmml.InvalidValueTreatmentMethod;
Expand Down Expand Up @@ -48,7 +49,7 @@ public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encod

if(handleUnknown != null){

if((OrdinalEncoder.HANDLEUNKNOWN_USE_ENCODED_VALUE).equals(handleUnknown)){
if(Objects.equals(OrdinalEncoder.HANDLEUNKNOWN_USE_ENCODED_VALUE, handleUnknown)){
unknownValue = getUnknownValue();

if(ValueUtil.isNaN(unknownValue)){
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ private DocumentNamespaceContext(Document document){
public String getNamespaceURI(String prefix){
Document document = getDocument();

if((XMLConstants.DEFAULT_NS_PREFIX).equals(prefix)){
if(Objects.equals(XMLConstants.DEFAULT_NS_PREFIX, prefix)){
return document.lookupNamespaceURI(null);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import java.util.Collections;
import java.util.List;
import java.util.Objects;

import org.dmg.pmml.Apply;
import org.dmg.pmml.DataType;
Expand Down Expand Up @@ -50,7 +51,7 @@ public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encod
for(int i = 0; i < features.size(); i++){
Feature feature = features.get(i);

if((i > 0) && !("").equals(separator)){
if((i > 0) && !Objects.equals("", separator)){
expressions.add(ExpressionUtil.createConstant(DataType.STRING, separator));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import java.util.Collections;
import java.util.List;
import java.util.Objects;

import org.dmg.pmml.DataType;
import org.jpmml.converter.Feature;
Expand Down Expand Up @@ -57,7 +58,7 @@ public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encod
feature = CategoricalDtypeUtil.refineFeature(feature, categoricalDtype, encoder);
} // End if

if(name != null && !(feature.getName()).equals(name)){
if(name != null && !Objects.equals(feature.getName(), name)){
encoder.renameFeature(feature, name);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import java.util.Collections;
import java.util.List;
import java.util.Objects;

import org.dmg.pmml.MissingValueTreatmentMethod;
import org.jpmml.converter.Feature;
Expand All @@ -41,7 +42,7 @@ public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encod

ClassDictUtil.checkSize(1, features);

if(("NaN").equals(missingValues)){
if(Objects.equals("NaN", missingValues)){
missingValues = null;
}

Expand Down

0 comments on commit b9c2734

Please sign in to comment.