
[dask] remove 'client' kwarg from fit() and predict() (fixes #3808) #3883

Merged (21 commits) Feb 3, 2021

Conversation

jameslamb
Collaborator

@jameslamb commented Jan 31, 2021

This PR proposes a different pattern for how users can provide a Dask client to model objects in the Dask module.

It removes the keyword argument client from fit() and predict(). As of this PR, users can optionally pass a client into the model's constructor. If they do that, it is saved on the model object and used by fit() and predict(). Otherwise, distributed.default_client() is used at runtime.

Given that you've set up a cluster and client...

import lightgbm as lgb
from distributed import LocalCluster, Client

cluster = LocalCluster()
client = Client(cluster)

...any of the following would work as of this PR.

# keyword arg in constructor
clf = lgb.DaskLGBMClassifier(client=client)
clf.fit(X, y)

# updating after construction with `.set_params(client=client)`
clf = lgb.DaskLGBMClassifier()
clf.set_params(client=client)
clf.fit(X, y)

# never telling LightGBM to use a specific client
clf = lgb.DaskLGBMClassifier()
clf.fit(X, y)

Why prefer this pattern?

Based on feedback from @jsignell and @martindurant in #3808 (comment) and #3808 (comment), who felt comfortable with the use of default_client() when none is given, and who made these excellent points:

I think the dasky-est approach would be to ... assume that they have initiated a client already

I do think that allowing an optional argument to the model constructor is a reasonable and useful addition; you might just want to have separate things fitting on different clusters simultaneously

This pattern of an optional keyword-only argument in the constructor makes the signatures of .predict() and .fit() identical to those in the scikit-learn model objects.

I was also able to work around not being able to pickle a model object that had a Dask client as an attribute, which was my main concern with that approach.

This is very close to what xgboost chose (https://github.com/dmlc/xgboost/blob/d8ec7aad5a9a3eb580c55680aee8ad1a975cba20/python-package/xgboost/dask.py#L1386-L1394), except that today I believe they require users to set .client after constructing a model object.

cuml also allows for passing client into the model constructor (https://github.com/rapidsai/cuml/blob/a22681c89630926c48f98ac3cb54d39bd2b91026/python/cuml/dask/common/base.py#L41-L45).

Changes in this PR

  • adds optional keyword-only argument client to constructor for Dask model objects
  • adds option for users to set a model object's Dask client by using .client after construction
  • removes keyword argument client from .fit() and .predict() for Dask model objects
  • adds a __getstate__() method on lightgbm.sklearn.LGBMModel, to handle removing a stored Dask client from a model before serializing it. (see the pickle docs for information on how this works)
  • adds unit tests that Dask model objects can be successfully serialized and read back in with pickle, joblib, and cloudpickle
  • copies the constructor from lightgbm.sklearn equivalents into the Dask module, so that customizing documentation works correctly (see comment on the relevant part of the diff). Adds a unit test that the signatures for Dask model objects and their sklearn equivalent are identical.

Notes for Reviewers

  • using the docs/jlamb branch so we can see readthedocs builds
  • I chose not to label this breaking because there has never been a lightgbm release with the dask module. This change does make lightgbm's API different from the one in dask-lightgbm, but that shouldn't be considered a breaking change from LightGBM's perspective
  • I believe we should strive to make the Dask model objects pickleable (can be documented more fully in [docs] [dask] Document how to save a Dask model #3838), so that users can store and load them just like they do other model objects. This PR makes that explicit with tests.

I know this is a small PR but with a lot of explanation and considerations. Thank you for your time and energy reviewing it!



class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
"""Distributed version of lightgbm.LGBMClassifier."""

def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1,
Collaborator Author

This seems to have the desired effect! When I built the docs locally, I saw that __init__() docs for DaskLGBMClassifier, DaskLGBMRegressor, and DaskLGBMRanker have the client doc added, and the docs for their scikit-learn equivalents do not.

DaskLGBMClassifier.__init__() (screenshot: client documented)

LGBMClassifier.__init__() (screenshot: client not documented)

Why I think copying is the best alternative

I tried several other ways to update the docstrings, but none of them quite worked. I'll explain in terms of *Classifier, but this applies for *Regressor and *Ranker as well.

  1. Editing DaskLGBMClassifier.__init__.__doc__
class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
    ...

_base_doc = LGBMClassifier.__init__.__doc__
_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
DaskLGBMClassifier.__init__.__doc__ = (
    _before_kwargs
    + 'client : dask.distributed.Client or None, optional (default=None)\n'
    + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n'
    + ' ' * 8 + _kwargs + _after_kwargs
)

Doing this for each of the 3 model objects, the API docs show 3 copies of the doc for client. Those 3 copies also show up on the docs for lightgbm.sklearn.LGBMClassifier.

(screenshot: duplicated client docs on LGBMClassifier)

  2. Setting up a pass-through __init__()
class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    _base_doc = __init__.__doc__
    _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
    __init__.__doc__ = (
        _before_kwargs
        + 'client : dask.distributed.Client or None, optional (default=None)\n'
        + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n'
        + ' ' * 8 + _kwargs + _after_kwargs
    )

This results in this error at runtime.

RuntimeError: scikit-learn estimators should always specify their parameters in
the signature of their __init__ (no varargs).
<class 'lightgbm.dask.DaskLGBMClassifier'> with constructor (self, *args, **kwargs)
doesn't follow this convention.
  3. Just copying __init__

Both of these variations:

class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
    __init__ = LGBMClassifier.__init__

    _base_doc = __init__.__doc__
    _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
    __init__.__doc__ = (
        _before_kwargs
        + 'client : dask.distributed.Client or None, optional (default=None)\n'
        + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n'
        + ' ' * 8 + _kwargs + _after_kwargs
    )

Doing this for each of the 3 model objects, the API docs show 3 copies of the doc for client. Those 3 copies also show up on the docs for lightgbm.sklearn.LGBMClassifier. (same as in the image above)

Collaborator

@StrikerRUS left a comment

@jameslamb I'm afraid this is an antipattern in the scikit-learn world! Our classes will be incompatible with the scikit-learn/Dask ecosystem.
This is why you got

RuntimeError: scikit-learn estimators should always specify their parameters in
the signature of their __init__ (no varargs).

I believe the best option will be passing client into __init__ as a normal argument with None as the default value. I mean the following:

class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
    def __init__(self, ..., client=None, **kwargs):
        self.client = client
        super().__init__(...)

Also, do not suggest users setting client directly via attribute. set_params(client=new_client) should be used.

Comment on lines 331 to 336
def __getstate__(self):
    """Remove un-picklable attributes before serialization."""
    client = self.__dict__.pop("_client", None)
    out = copy.deepcopy(self.__dict__)
    self._client = client
    return out
Collaborator

I believe this can be done in Dask estimators, not in parent class.

Collaborator Author

If we do it there, this code would have to be copied 3 times. Because _DaskLGBMModel comes second in MRO, it wouldn't be safe to just put this on that mixin. Are you ok with that duplication?

Collaborator

@StrikerRUS Jan 31, 2021

Yeah, probably __getstate__() is overridden somewhere above in the pure scikit-learn branch of inheritance. But we can name this method _get_state() in _DaskLGBMModel and call it, so no duplication will be required.

def __getstate__(self):
    return self._get_state()

Just like we currently do for all other methods.

Collaborator Author

ok, good suggestion! did that in 344376b.

I called it _lgb_get_state because I'm nervous about using general-sounding method names in _DaskLGBMModel when there are so many classes upstream of LGBMClassifier. I'm nervous about some version of scikit-learn introducing a _get_state that would override this. We might not catch that in tests because we don't test against many different versions of scikit-learn.

I think I'll propose a PR to do that for other methods on the _DaskLGBMModel mixin as well.

@jameslamb
Collaborator Author

jameslamb commented Jan 31, 2021

I believe the best option will be passing client in init as normal argument with None as default value. I mean the following:

class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
    def __init__(self, ..., client=None, **kwargs):
        self.client = client
        super().__init__(...)
Also, do not suggest users setting client directly via attribute. set_params(client=new_client) should be used.

I'm really struggling to adopt this pattern. sklearn.base.BaseEstimator (the parent class that get_params() comes from) assumes that every keyword argument in the constructor signature counts as a parameter and is returned by .get_params().

https://github.com/scikit-learn/scikit-learn/blob/b3ea3ed6a/sklearn/base.py#L152

So treating client like this is causing all of the tests to fail, with complaints like

TypeError: cannot pickle '_asyncio.Task' object

Probably because get_params() is called in

params = self.get_params(True)
and then passed through to _train(). The client is in it and not pickleable.
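The failure mode can be reproduced without Dask at all. MiniEstimator and FakeClient below are hypothetical stand-ins: any unpicklable constructor kwarg that get_params() returns will break the pickling step in the same way.

```python
import pickle
import threading

class FakeClient:
    """Stand-in for dask.distributed.Client: holds an unpicklable handle."""
    def __init__(self):
        self._handle = threading.Lock()  # locks cannot be pickled

class MiniEstimator:
    """Mimics sklearn.base.BaseEstimator: every constructor kwarg is a
    parameter and is returned from get_params()."""
    def __init__(self, learning_rate=0.1, client=None):
        self.learning_rate = learning_rate
        self.client = client

    def get_params(self):
        return {"learning_rate": self.learning_rate, "client": self.client}

model = MiniEstimator(client=FakeClient())
try:
    pickle.dumps(model.get_params())  # the client rides along in the params
    params_picklable = True
except TypeError:  # e.g. "cannot pickle '_thread.lock' object"
    params_picklable = False
```

As soon as the client appears in the params dict, any code path that pickles or deep-copies the params (as the distributed training path does) fails.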

I pushed 344376b so you can see what I tried.


I really think that we should not expose client as a keyword argument in the signature, and should revert that part of 344376b and use my original proposal to .pop() it off of kwargs. I also believe we should not use .set_params() for client.

I understand that you said my original proposal is a scikit-learn antipattern, but I don't think it's appropriate to treat client like "just another parameter". scikit-learn's preference for explicit kwargs, with each argument becoming an instance attribute, is described in "developing scikit-learn estimators":

every keyword argument accepted by __init__ should correspond to an attribute on the instance. Scikit-learn relies on this to find the relevant attributes to set on an estimator when doing model selection.

There should be no logic, not even input validation, and the parameters should not be changed...The reason for postponing the validation is that the same validation would have to be performed in set_params, which is used in algorithms like GridSearchCV.

Given this, I really don't think the pattern I proposed violates the intention from scikit-learn.

  • you would never search over values of client in hyperparameter tuning
  • client isn't required to be the same to reproduce a model result
  • client shouldn't ever be saved in a model file, even if it could be pickled, because it refers to temporary infrastructure that isn't guaranteed to be there when you load that model in the future
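The originally proposed pop pattern can be sketched with a toy class (names hypothetical): because client never appears in the constructor signature, sklearn's parameter machinery never sees it.

```python
class DaskModelSketch:
    """Toy illustration of the proposed pattern: pop 'client' off kwargs so
    it is stored privately and never treated as an sklearn parameter."""
    def __init__(self, **kwargs):
        self._client = kwargs.pop("client", None)
        self._other_params = kwargs  # what would flow on to LGBMModel.__init__

model = DaskModelSketch(client="a-live-client", num_leaves=31)
```

A real estimator would still spell out its keyword arguments explicitly (that is exactly what the RuntimeError above enforces); the sketch only shows how popping keeps client out of get_params().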

@StrikerRUS
Collaborator

Given this, I really don't think the pattern I proposed violates the intention from scikit-learn.

Scikit-learn is a really complex system with many undocumented "features", rules, and assumptions about inherited classes. I guess you already know this. So it is very likely that violating some of their rules can break a lot of integrations. Of course, the inability to use hyperparameter search over the client argument is not a serious problem, but there is no guarantee that it does not break hyperparameter tunability as a whole.

Also, I'm strongly against altering the sklearn.py module with changes required by the child Dask module. Making a parent aware of its children is another serious programming antipattern.

I believe that we can achieve the goal of making the _DaskLGBMModel class serializable without much harm to sklearn compatibility and without hacky workarounds. I'll try some ideas in the next few days.

@jameslamb
Collaborator Author

I'll try some ideas in the next few days

Don't worry about it, I know you've mentioned you can't test Dask features in your local dev environment. I'll continue down this path you've suggested and we can look at the diff at the end.

@StrikerRUS
Collaborator

Don't worry about it, I know you've mentioned you can't test Dask features in your local dev environment.

But I could use our Linux CI job 😉. Not as comfortable as my own local setup, but better than nothing.

Only one thing puzzles me.

# fitted model should survive pickling round trip, and pickling
# shouldn't have side effects on the model object

Given that dask.distributed.Client is unpicklable, how can fitted_model.client == model_loaded_from_disk.client be achieved without overriding __setstate__ and finding a way to save client metainfo that will help re-create it after loading from disk?
For example, Booster object is saved into a string and recreated from it during unpickling process:

def __getstate__(self):
    this = self.__dict__.copy()
    handle = this['handle']
    this.pop('train_set', None)
    this.pop('valid_sets', None)
    if handle is not None:
        this["handle"] = self.model_to_string(num_iteration=-1)
    return this

def __setstate__(self, state):
    model_str = state.get('handle', None)
    if model_str is not None:
        handle = ctypes.c_void_p()
        out_num_iterations = ctypes.c_int(0)
        _safe_call(_LIB.LGBM_BoosterLoadModelFromString(
            c_str(model_str),
            ctypes.byref(out_num_iterations),
            ctypes.byref(handle)))
        state['handle'] = handle
    self.__dict__.update(state)

@jameslamb
Collaborator Author

jameslamb commented Jan 31, 2021

Please don't spend time on it, I want to be respectful of your time and I think I can go faster testing locally.

For your question about pickling... it's acceptable and expected that the model loaded from a file doesn't have .client set. A client is a live connection to a specific Dask scheduler, and you have no guarantee that the exact same scheduler will still exist or be accessible from wherever you load the model. This has no impact on the correctness of the model. Load the same model.pkl and call .predict() on the same data on two different Dask clusters, and you will get identical results.

@StrikerRUS
Collaborator

Yeah, you'll definitely be faster implementing this, but I'd like to help.

Then I'm listing some initial ideas I wanted to try first, for your consideration:

  • move if client is None: client = default_client() into _train() function;
  • simply use self.client = client in __init__s; set_params should be used outside class implementation by users, not by us in the internals;
  • pop self.client in __getstate__ and replace it with something that will help to "clone" current client;
  • in __setstate__ use saved metainfo about client to create a new one and save in self.client.
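One way the last two bullets could look, sketched with hypothetical stand-in classes (a real implementation would call distributed.Client(address); note the PR ultimately chose not to restore the client at all):

```python
import pickle
import threading

class FakeClient:
    """Stand-in for distributed.Client: an address plus an unpicklable handle."""
    def __init__(self, address):
        self.address = address
        self._handle = threading.Lock()  # locks cannot be pickled

class ModelSketch:
    def __init__(self, client=None):
        self.client = client
        self.coef_ = [1.0, 2.0]  # stands in for fitted state

    def __getstate__(self):
        state = self.__dict__.copy()
        client = state.pop("client", None)
        # idea 3: keep metainfo that would let us "clone" the client later
        # (for distributed.Client this could be the scheduler address)
        state["_client_address"] = getattr(client, "address", None)
        return state

    def __setstate__(self, state):
        address = state.pop("_client_address", None)
        self.__dict__.update(state)
        # idea 4: re-create a client from the saved metainfo
        # (real code would call distributed.Client(address))
        self.client = FakeClient(address) if address is not None else None

model = ModelSketch(client=FakeClient("tcp://10.0.0.1:8786"))
restored = pickle.loads(pickle.dumps(model))
```

The fitted state round-trips, and the restored object gets a freshly built client pointing at the saved address rather than the original live connection.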

@StrikerRUS
Collaborator

For the last two points in my list of initial ideas I meant to find something close to the following:
https://github.com/kichappa/QC/blob/ad9f756d8fafac3916688fd6dfb3786ae7c2aa14/distributed/variable.py#L203-L213

@StrikerRUS
Collaborator

@jameslamb Seems pickling tests are still failing. Please check it.

I disagree with this. I think using an @property on the estimator is a better pattern, because then user code can access model.client and know for sure that that will always return the client to be used in training or prediction (either one they set or distributed.default_client()).

Agree, makes sense! But please rename this property to client_ according to scikit-learn rules. All normal attributes (without trailing or leading underscores) should be set in __init__ and their values should never be changed. Attributes computed during the fit stage should use trailing underscores.

https://scikit-learn.org/stable/developers/develop.html#parameters-and-init

Please check the following example with objective argument.

objective : string, callable or None, optional (default=None)
Specify the learning task and the corresponding learning objective or
a custom objective function to be used (see note below).
Default: 'regression' for LGBMRegressor, 'binary' or 'multiclass' for LGBMClassifier, 'lambdarank' for LGBMRanker.

self.objective = objective

self._objective = objective

if self._objective is None:
    if isinstance(self, LGBMRegressor):
        self._objective = "regression"
    elif isinstance(self, LGBMClassifier):
        self._objective = "binary"
    elif isinstance(self, LGBMRanker):
        self._objective = "lambdarank"
    else:
        raise ValueError("Unknown LGBMModel type.")

@property
def objective_(self):
    """:obj:`string` or :obj:`callable`: The concrete objective used while fitting this model."""
    if self._n_features is None:
        raise LGBMNotFittedError('No objective found. Need to call fit beforehand.')
    return self._objective
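Following that objective_ pattern, a client_ property could look like this sketch (stand-in class; default_client() is represented by a string so the example runs without Dask):

```python
class ModelSketch:
    """Toy model showing the fitted-attribute convention for client_."""
    def __init__(self, client=None):
        self.client = client      # plain constructor param, never mutated
        self._n_features = None   # set during fit(), used as the fitted check

    def fit(self):
        self._n_features = 10  # pretend we fitted on 10 features
        return self

    @property
    def client_(self):
        """The concrete client used while fitting this model."""
        if self._n_features is None:
            raise RuntimeError('No client found. Need to call fit beforehand.')
        # real code would fall back to distributed.default_client()
        return self.client if self.client is not None else "default_client()"

model = ModelSketch()
try:
    model.client_  # accessing before fit() must raise
    raised_before_fit = False
except RuntimeError:
    raised_before_fit = True
```

The unsuffixed client stays exactly as the user passed it (so get_params/set_params keep working), while client_ reports the resolved client and enforces the fitted check.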

@jameslamb
Collaborator Author

Ok I think I've made the requested changes in af269b5. Added client_, so now we have client, _client, and client_. And fixed the pickling tests.

Collaborator

@StrikerRUS left a comment

@jameslamb
Awesome work!
Thanks a lot for digging deep into sklearn internals and finding a way not to make a parent aware of its children. I left some comments for your consideration.

This property can be passed in the constructor or updated
with ``model.set_params(client=client)``.
"""
if self._client is None:
Collaborator

sklearn requires that unfitted models raise NotFittedError, and checks this by accessing "post-fit" attributes: attributes with trailing underscores.
BTW, I think it would be great to set up sklearn integration tests for the Dask classes in our CI. It won't guarantee that our classes are fully compatible with sklearn, but at least it will check basic compatibility. WDYT?

def test_sklearn_integration(estimator, check):

Suggested change
- if self._client is None:
+ if self._n_features is None:
+     raise LGBMNotFittedError('No client found. Need to call fit beforehand.')
+ if self._client is None:

Collaborator Author

Oh I see, ok. Why should the check be on ._n_features? That seems kind of indirect. Shouldn't it be on .fitted_ directly for a check of whether or not the model has been fit?

I'm also confused how you would like _fit() to work. I'm currently passing client=self.client_ into self._fit(), and relying on that to resolve whether there is a client stored on the object or get_default_client() should be used. This suggested change would break that behavior, because of course the model is probably not fitted yet at the time you call .fit(). What do you suggest I do for that?

BTW, I think it will be great to setup sklearn integration tests at our CI for Dask classes. It will not allow us to be sure that our classes fully compatible with sklearn but at least will check basic compatibility. WDYT

Sure, I think that's fine. I think it's an ok fit for this PR and will add it here.

Collaborator

Why should the check be on ._n_features?

For historical reasons. fitted_ was introduced quite recently in our sklearn wrapper. But now you can use it, indeed 👍 .

I'm also confused how you would like _fit() to work.

Oh I see now! Can you pass self.client into _fit() and assign the result of the check to self._client there, while client_ always returns self._client or raises an error? Just like all other properties in the sklearn wrapper.

Collaborator

I think it's an ok fit for this PR and will add it here.

Maybe in a follow-up PR? I'm afraid it could fail right now and block merging the client argument migration.

Collaborator Author

I want to be really sure that there is exactly one place where we resolve which client to use (#3883 (comment)).

So I don't want to put this code into the body of _fit().

if self.client is None:
    client = default_client()
else:
    client = self.client

I think it would work to pull that out into a small function like this:

def _choose_client(client: Optional[Client]) -> Client:
    if client is None:
        return default_client()
    else:
        return client

Then use that in both the client_ property (after the fitted check) and in _fit(). That would give us confidence that accessing .client_ returns the same client that will be used in training.

Collaborator

OK, sounds good!

Collaborator Author

Added #3894 to capture the task of adding a scikit-learn compatibility test. I made this a "good first issue" because I think it could be done without deep knowledge of LightGBM, but I also think it's ok for you or me to pick up in the near future (we don't need to reserve it for new contributors).

Collaborator Author

Ok I think I've addressed these suggestions in 555a57a

  • added an LGBMNotFittedError if accessing .client_ for a not-yet-fitted model
  • added an internal function _get_dask_cluster() to hold the logic of using default_client() if client is None
  • changed the unit tests to use two different clusters and clients

Note that the change to have two clusters active in the same test will result in a lot of this warning in the logs:

test_dask.py::test_model_and_local_version_are_picklable_whether_or_not_client_set_explicitly[False-ranking-cloudpickle]
/home/jlamb/miniconda3/lib/python3.7/site-packages/distributed/node.py:155: UserWarning: Port 8787 is already in use.
Perhaps you already have a cluster running?
Hosting the HTTP server on port 40743 instead
http_address["port"], self.http_server.port

This is not something we have to worry about, and I'd rather leave it alone and let Dask do the right thing (picking a random port for the scheduler when the default one is unavailable) than add more complexity to the tests by adding our own logic to set the scheduler port.

__init__.__doc__ = (
_before_kwargs
+ 'client : dask.distributed.Client or None, optional (default=None)\n'
+ ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n'
Collaborator

TBH, it is not clear enough that no client will be saved. "This" may seem to refer to distributed.default_client(), I guess, and confuse users into thinking a custom client will be saved...

Collaborator Author

Ok I can change it

Collaborator Author

updated in 555a57a to this

The Dask client used by this class will not be saved if the model object is pickled.

dask_model = model_factory(**params)
assert dask_model._client is None
assert dask_model.client is None
assert dask_model.client_ == client
Collaborator

This should raise NotFittedError according to sklearn policy.


# should be able to set client after construction
dask_model = model_factory(**params)
dask_model.set_params(client=client)
Collaborator

Can we set a custom client here to ensure that the else branch of the if self._client is None: condition works?

}

if set_client:
params.update({"client": client})
Collaborator

Same here: use a custom client to check the difference between default_client() and a user-provided one.

Comment on lines 792 to 794
_compare_spec(lgb.DaskLGBMClassifier, lgb.LGBMClassifier)
_compare_spec(lgb.DaskLGBMRegressor, lgb.LGBMRegressor)
_compare_spec(lgb.DaskLGBMRanker, lgb.LGBMRanker)
Collaborator

Maybe @pytest.mark.parametrize?

Collaborator Author

sure, updated in 555a57a

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
Collaborator

@StrikerRUS left a comment

@jameslamb Great job! And I think this API is quite intuitive.
Just two nits.

Comment on lines 456 to 457
# self._client is set in the constructor of classes that use this mixin
_client: Optional[Client] = None
Collaborator

Looks like _client was replaced with _get_dask_client() function in the latest commit and is not needed anymore.

Collaborator Author

@jameslamb Feb 3, 2021

oh good point! Removed in 876bfe5

Comment on lines 835 to 837
def _compare_spec(dask_cls, sklearn_cls):
    dask_spec = inspect.getfullargspec(dask_cls)
    sklearn_spec = inspect.getfullargspec(sklearn_cls)
Collaborator

I think it is possible to simplify here by removing inner function and setting directly

dask_spec = inspect.getfullargspec(classes[0])
sklearn_spec = inspect.getfullargspec(classes[1])

Collaborator Author

oh right! didn't think about how using parametrize meant this function was useless

Collaborator Author

removed this function in 876bfe5

@jameslamb
Collaborator Author

Great job! And I think this API is quite intuitive.

Thanks for talking me through it! Learned a lot more about scikit-learn doing this, you're a good teacher.

I've addressed the last two comments in 876bfe5. Will merge after I see if CI + readthedocs passes, and after checking the docs site.

@jameslamb
Collaborator Author

readthedocs looks ok to me!

https://lightgbm.readthedocs.io/en/docs-jlamb/

  • no change to LGBMClassifier docs
  • dask docs have client in them (I checked classifier, regressor, and ranker)
  • client_ shows up in the attributes table, with the expected docs
  • client is gone from the fit() docs

@StrikerRUS
Collaborator

@jameslamb Thank you! Sorry my help was only in the form of words, not actual code 🙁.
