Skip to content

Commit

Permalink
Improved support for the 'OneHotEncoder.infrequent_categories_' attri…
Browse files Browse the repository at this point in the history
…bute. Fixes jpmml/sklearn2pmml#412
  • Loading branch information
vruusmann committed Feb 23, 2024
1 parent fa8eab4 commit ec23fbc
Show file tree
Hide file tree
Showing 8 changed files with 21 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,11 @@ public List<List<Integer>> transformInfrequentIndices(List<HasArray> arrays){

@Override
public List<Integer> apply(HasArray hasArray){

if(hasArray == null){
return null;
}

return ValueUtil.asIntegers((List)hasArray.getArrayContent());
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

import com.google.common.collect.Lists;
import org.dmg.pmml.DataField;
Expand Down Expand Up @@ -67,11 +65,16 @@ public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encod
for(int i = 0; i < features.size(); i++){
Feature feature = features.get(i);
List<Object> featureCategories = new ArrayList<>(categories.get(i));
Set<Integer> featureInfrequentIndices = infrequentEnabled ? new LinkedHashSet<>(infrequentIndices.get(i)) : Collections.emptySet();
List<Integer> featureInfrequentIndices = (infrequentEnabled ? infrequentIndices.get(i) : null);

boolean featureInfrequentEnabled = infrequentEnabled;
if(featureInfrequentIndices == null || featureInfrequentIndices.isEmpty()){
featureInfrequentEnabled = false;
}

Object infrequentCategory = null;

if(infrequentEnabled){
if(featureInfrequentEnabled){
infrequentCategory = getInfrequentCategory(feature);
}

Expand All @@ -90,7 +93,7 @@ public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encod
break;
case "infrequent_if_exist":
{
if(infrequentEnabled){
if(featureInfrequentEnabled){
invalidValueDecorator = new InvalidValueDecorator(InvalidValueTreatmentMethod.AS_VALUE, infrequentCategory);
} else

Expand Down Expand Up @@ -162,7 +165,7 @@ public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encod
throw new IllegalArgumentException();
} // End if

if(infrequentEnabled){
if(featureInfrequentEnabled){

if(infrequentCategory == null || featureCategories.contains(infrequentCategory)){
throw new IllegalArgumentException();
Expand Down Expand Up @@ -198,7 +201,7 @@ public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encod
}
}

if(infrequentEnabled){
if(featureInfrequentEnabled){
result.add(new BinaryFeature(encoder, feature, infrequentCategory));
}
}
Expand Down Expand Up @@ -268,6 +271,11 @@ private Object getInfrequentCategory(Feature feature){

static
private <E> List<E> selectValues(List<E> values, Collection<Integer> indices){

if(indices == null || indices.isEmpty()){
return Collections.emptyList();
}

List<E> result = new ArrayList<>();

for(Integer index : indices){
Expand Down
2 changes: 1 addition & 1 deletion pmml-sklearn/src/test/resources/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,7 +808,7 @@ def build_multi_auto(auto_df, regressor, name, with_kneighbors = False):
(["displacement", "horsepower", "weight"], ContinuousDomain(with_statistics = True)),
(["cylinders"], [CategoricalDomain(with_statistics = True, invalid_value_treatment = "as_is"), OneHotEncoder(handle_unknown = "infrequent_if_exist")]),
(["model_year"], [CategoricalDomain(with_statistics = True, invalid_value_treatment = "as_is"), OneHotEncoder(max_categories = 10, handle_unknown = "infrequent_if_exist")]),
(["origin"], [CategoricalDomain(with_statistics = True, invalid_value_treatment = "as_is"), OneHotEncoder()])
(["origin"], [CategoricalDomain(with_statistics = True, invalid_value_treatment = "as_is"), OneHotEncoder(handle_unknown = "infrequent_if_exist", max_categories = 5)])
])
pipeline = PMMLPipeline([
("mapper", mapper),
Expand Down
Binary file not shown.
Binary file modified pmml-sklearn/src/test/resources/pkl/MultiKNNAuto.pkl
Binary file not shown.
Binary file not shown.
Binary file modified pmml-sklearn/src/test/resources/pkl/MultiLinearSVRAuto.pkl
Binary file not shown.
Binary file modified pmml-sklearn/src/test/resources/pkl/MultiMLPAuto.pkl
Binary file not shown.

0 comments on commit ec23fbc

Please sign in to comment.