Skip to content

Commit

Permalink
XGBoost classifier example (#3088)
Browse files Browse the repository at this point in the history
* XGBoost classifier for Iris dataset

* XGBoost classifier for Iris dataset

* XGBoost classifier for Iris dataset

* Added requirements.txt file

* spellcheck
  • Loading branch information
agunapal authored Apr 23, 2024
1 parent f2163e8 commit 946da22
Show file tree
Hide file tree
Showing 9 changed files with 138 additions and 0 deletions.
46 changes: 46 additions & 0 deletions examples/xgboost_classfication/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# XGBoost Classifier for Iris dataset

This example shows how to serve an XGBoost classifier model using TorchServe.
Here we train a model to classify iris dataset

## Pre-requisites

Train an XGBoost classifier model for Iris dataset

```
pip install -r requirements.txt
python xgboost_train.py
```

results in

```
Model accuracy is 1.0
Saving trained model to iris_model.json
```

## Create model archive

```
mkdir model_store
torch-model-archiver --model-name xgb_iris --version 1.0 --serialized-file iris_model.json --handler xgboost_iris_handler.py --export-path model_store --extra-files index_to_name.json --config-file model-config.yaml -f
```

## Start TorchServe

```
torchserve --start --ncs --model-store model_store --models xgb_iris=xgb_iris.mar
```

## Inference request

We send a batch of 2 requests
```
curl -X POST http://127.0.0.1:8080/predictions/xgb_iris -T sample_input_2.txt & curl -X POST http://127.0.0.1:8080/predictions/xgb_iris -T sample_input_1.txt
```

results in

```
versicolor setosa
```
1 change: 1 addition & 0 deletions examples/xgboost_classfication/index_to_name.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"0": "setosa", "1": "versicolor", "2": "virginica"}
3 changes: 3 additions & 0 deletions examples/xgboost_classfication/model-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
minWorkers: 1
maxWorkers: 1
batchSize: 2
2 changes: 2 additions & 0 deletions examples/xgboost_classfication/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
scikit-learn
xgboost
1 change: 1 addition & 0 deletions examples/xgboost_classfication/sample_input_1.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
6.1,2.8,4.7,1.2
1 change: 1 addition & 0 deletions examples/xgboost_classfication/sample_input_2.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
5.7,3.8,1.7,0.3
54 changes: 54 additions & 0 deletions examples/xgboost_classfication/xgboost_iris_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import logging
import os

import numpy as np
import torch
from xgboost import XGBClassifier

from ts.torch_handler.base_handler import BaseHandler
from ts.utils.util import load_label_mapping

logger = logging.getLogger(__name__)


class XGBIrisHandler(BaseHandler):
def __init__(self):
super().__init__()

def initialize(self, context):
# Set device type
self.device = torch.device("cpu")

# Load the model
properties = context.system_properties
self.manifest = context.manifest
model_dir = properties.get("model_dir")
self.model = XGBClassifier()
if "serializedFile" in self.manifest["model"]:
serialized_file = self.manifest["model"]["serializedFile"]
model_weights = os.path.join(model_dir, serialized_file)
self.model.load_model(model_weights)

mapping_file_path = os.path.join(model_dir, "index_to_name.json")
self.mapping = load_label_mapping(mapping_file_path)

logger.info(
f"XGBoost Classifier for iris dataset with weights {model_weights} loaded successfully"
)
self.initialized = True

def preprocess(self, requests):
inputs = []
for row in requests:
input = row.get("data") or row.get("body")
if isinstance(input, (bytes, bytearray)):
input = [float(value) for value in input.decode("utf-8").split(",")]
inputs.append(input)
return np.array(inputs)

def inference(self, data):
return self.model.predict(data)

def postprocess(self, result):
output = [self.mapping[str(res)] for res in result.tolist()]
return output
29 changes: 29 additions & 0 deletions examples/xgboost_classfication/xgboost_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Importing dataset from sklearn
from sklearn import datasets, metrics

iris = datasets.load_iris() # dataset loading
X = iris.data # Features stored in X
y = iris.target # Class variable

# Splitting dataset into Training (80%) and testing data (20%) using train_test_split
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)

# Create an XGB classifier and instance of the same
from xgboost import XGBClassifier

clf = XGBClassifier()

clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)
# classification accuracy
from sklearn import metrics

print(f"Model accuracy is {metrics.accuracy_score(y_test, y_pred)}")
saved_model = "iris_model.json"
print(f"Saving trained model to {saved_model}")
clf.save_model(saved_model)
1 change: 1 addition & 0 deletions ts_scripts/spellcheck_conf/wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1218,6 +1218,7 @@ venv
TorchInductor
Pytests
deviceType
XGBoost
Clamd
Fickling
TorchServer
Expand Down

0 comments on commit 946da22

Please sign in to comment.