-
Notifications
You must be signed in to change notification settings - Fork 858
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* XGBoost classifier for Iris dataset * XGBoost classifier for Iris dataset * XGBoost classifier for Iris dataset * Added requirements.txt file * spellcheck
- Loading branch information
Showing
9 changed files
with
138 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# XGBoost Classifier for Iris dataset | ||
|
||
This example shows how to serve an XGBoost classifier model using TorchServe. | ||
Here we train a model to classify iris dataset | ||
|
||
## Pre-requisites | ||
|
||
Train an XGBoost classifier model for Iris dataset | ||
|
||
``` | ||
pip install -r requirements.txt | ||
python xgboost_train.py | ||
``` | ||
|
||
results in | ||
|
||
``` | ||
Model accuracy is 1.0 | ||
Saving trained model to iris_model.json | ||
``` | ||
|
||
## Create model archive | ||
|
||
``` | ||
mkdir model_store | ||
torch-model-archiver --model-name xgb_iris --version 1.0 --serialized-file iris_model.json --handler xgboost_iris_handler.py --export-path model_store --extra-files index_to_name.json --config-file model-config.yaml -f | ||
``` | ||
|
||
## Start TorchServe | ||
|
||
``` | ||
torchserve --start --ncs --model-store model_store --models xgb_iris=xgb_iris.mar | ||
``` | ||
|
||
## Inference request | ||
|
||
We send a batch of 2 requests | ||
``` | ||
curl -X POST http://127.0.0.1:8080/predictions/xgb_iris -T sample_input_2.txt & curl -X POST http://127.0.0.1:8080/predictions/xgb_iris -T sample_input_1.txt | ||
``` | ||
|
||
results in | ||
|
||
``` | ||
versicolor setosa | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"0": "setosa", "1": "versicolor", "2": "virginica"} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
minWorkers: 1 | ||
maxWorkers: 1 | ||
batchSize: 2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
scikit-learn | ||
xgboost |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
6.1,2.8,4.7,1.2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
5.7,3.8,1.7,0.3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import logging | ||
import os | ||
|
||
import numpy as np | ||
import torch | ||
from xgboost import XGBClassifier | ||
|
||
from ts.torch_handler.base_handler import BaseHandler | ||
from ts.utils.util import load_label_mapping | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class XGBIrisHandler(BaseHandler): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def initialize(self, context): | ||
# Set device type | ||
self.device = torch.device("cpu") | ||
|
||
# Load the model | ||
properties = context.system_properties | ||
self.manifest = context.manifest | ||
model_dir = properties.get("model_dir") | ||
self.model = XGBClassifier() | ||
if "serializedFile" in self.manifest["model"]: | ||
serialized_file = self.manifest["model"]["serializedFile"] | ||
model_weights = os.path.join(model_dir, serialized_file) | ||
self.model.load_model(model_weights) | ||
|
||
mapping_file_path = os.path.join(model_dir, "index_to_name.json") | ||
self.mapping = load_label_mapping(mapping_file_path) | ||
|
||
logger.info( | ||
f"XGBoost Classifier for iris dataset with weights {model_weights} loaded successfully" | ||
) | ||
self.initialized = True | ||
|
||
def preprocess(self, requests): | ||
inputs = [] | ||
for row in requests: | ||
input = row.get("data") or row.get("body") | ||
if isinstance(input, (bytes, bytearray)): | ||
input = [float(value) for value in input.decode("utf-8").split(",")] | ||
inputs.append(input) | ||
return np.array(inputs) | ||
|
||
def inference(self, data): | ||
return self.model.predict(data) | ||
|
||
def postprocess(self, result): | ||
output = [self.mapping[str(res)] for res in result.tolist()] | ||
return output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# Importing dataset from sklearn | ||
from sklearn import datasets, metrics | ||
|
||
iris = datasets.load_iris() # dataset loading | ||
X = iris.data # Features stored in X | ||
y = iris.target # Class variable | ||
|
||
# Splitting dataset into Training (80%) and testing data (20%) using train_test_split | ||
from sklearn.model_selection import train_test_split | ||
|
||
X_train, X_test, y_train, y_test = train_test_split( | ||
X, y, test_size=0.2, random_state=42 | ||
) | ||
|
||
# Create an XGB classifier and instance of the same | ||
from xgboost import XGBClassifier | ||
|
||
clf = XGBClassifier() | ||
|
||
clf.fit(X_train, y_train) | ||
|
||
y_pred = clf.predict(X_test) | ||
# classification accuracy | ||
from sklearn import metrics | ||
|
||
print(f"Model accuracy is {metrics.accuracy_score(y_test, y_pred)}") | ||
saved_model = "iris_model.json" | ||
print(f"Saving trained model to {saved_model}") | ||
clf.save_model(saved_model) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1218,6 +1218,7 @@ venv | |
TorchInductor | ||
Pytests | ||
deviceType | ||
XGBoost | ||
Clamd | ||
Fickling | ||
TorchServer | ||
|