Skip to content

Commit

Permalink
Update sklearn example's server and client (#1347)
Browse files Browse the repository at this point in the history
Co-authored-by: George <george@pop-os.localdomain>
  • Loading branch information
gxenos and George authored Aug 1, 2022
1 parent 3d6b2f5 commit 0eafca6
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions examples/sklearn-logreg-mnist/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

# Define Flower client
class MnistClient(fl.client.NumPyClient):
def get_parameters(self): # type: ignore
def get_parameters(self, config): # type: ignore
return utils.get_model_parameters(model)

def fit(self, parameters, config): # type: ignore
Expand All @@ -46,4 +46,4 @@ def evaluate(self, parameters, config): # type: ignore
return loss, len(X_test), {"accuracy": accuracy}

# Start Flower client
fl.client.start_numpy_client("0.0.0.0:8080", client=MnistClient())
fl.client.start_numpy_client(server_address="0.0.0.0:8080", client=MnistClient())
2 changes: 1 addition & 1 deletion examples/sklearn-logreg-mnist/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ authors = [

[tool.poetry.dependencies]
python = "^3.8"
flwr = "^0.19.0"
flwr = "^1.0.0"
# flwr = { path = "../../", develop = true } # Development
scikit-learn = "^1.1.1"
openml = "^0.12.2"
4 changes: 2 additions & 2 deletions examples/sklearn-logreg-mnist/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def get_evaluate_fn(model: LogisticRegression):
_, (X_test, y_test) = utils.load_mnist()

# The `evaluate` function will be called after every round
def evaluate(parameters: fl.common.Weights):
def evaluate(server_round, parameters: fl.common.NDArrays, config):
# Update model with the latest parameters
utils.set_model_params(model, parameters)
loss = log_loss(y_test, model.predict_proba(X_test))
Expand All @@ -39,5 +39,5 @@ def evaluate(parameters: fl.common.Weights):
fl.server.start_server(
server_address="0.0.0.0:8080",
strategy=strategy,
config={"num_rounds": 5},
config=fl.server.ServerConfig(num_rounds=5),
)

0 comments on commit 0eafca6

Please sign in to comment.