
Cast learning_rate to float lambda for pickle safety when doing model.load #1901

Merged (5 commits, Apr 22, 2024)

Conversation

Contributor

@markscsmith markscsmith commented Apr 19, 2024

Description

closes #1900

Motivation and Context

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist

  • I've read the CONTRIBUTION guide (required)
  • I have updated the changelog accordingly (required).
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.
  • I have opened an associated PR on the SB3-Contrib repository (if necessary)
  • I have opened an associated PR on the RL-Zoo3 repository (if necessary)
  • I have reformatted the code using make format (required)
  • I have checked the codestyle using make check-codestyle and make lint (required)
  • I have ensured make pytest and make type both pass. (required)
  • I have checked that the documentation builds using make doc (required)

Note: You can run most of the checks using make commit-checks.

Note: we are using a maximum length of 127 characters per line

@@ -92,7 +92,7 @@ def get_schedule_fn(value_schedule: Union[Schedule, float]) -> Schedule:
         value_schedule = constant_fn(float(value_schedule))
     else:
         assert callable(value_schedule)
-    return value_schedule
+    return lambda _: float(value_schedule(_))
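To see what this one-line change does in context, here is a minimal, self-contained sketch of `get_schedule_fn` with the fix applied (simplified from SB3; it assumes SB3's `Schedule = Callable[[float], float]` convention). The wrapper guarantees the schedule returns a builtin `float` even if the user's function returns another numeric type (e.g. a numpy scalar), which is what made `torch.load(..., weights_only=True)` choke:

```python
from typing import Callable, Union

# SB3 convention: a schedule maps progress_remaining (1.0 -> 0.0) to a value.
Schedule = Callable[[float], float]

def constant_fn(val: float) -> Schedule:
    """Return a schedule that always yields the same value."""
    return lambda _: val

def get_schedule_fn(value_schedule: Union[Schedule, float]) -> Schedule:
    if isinstance(value_schedule, (float, int)):
        value_schedule = constant_fn(float(value_schedule))
    else:
        assert callable(value_schedule)
    # Cast the schedule's output to a builtin float so that non-float return
    # types (e.g. numpy scalars) cannot break unpickling under
    # torch.load(weights_only=True).
    return lambda progress_remaining: float(value_schedule(progress_remaining))
```

For example, `get_schedule_fn(lambda p: 3)(1.0)` yields `3.0` as a builtin `float`, and a plain constant like `get_schedule_fn(0.001)` still behaves as before.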
Member
@araffin araffin Apr 19, 2024

maybe a better solution is to do a call to value_schedule(1.0) and check that the return type is a float (and output a useful error message if not).
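That first suggestion could be sketched roughly like this (an illustrative sketch only, not SB3 code; `check_schedule_output` is a hypothetical helper name). Probing the schedule once with `1.0` fails fast with a useful message instead of deferring the error to load time:

```python
from typing import Callable

def check_schedule_output(value_schedule: Callable[[float], float]):
    """Probe the schedule once and raise a clear error if it does not
    return a builtin float (hypothetical helper, not part of SB3)."""
    initial_value = value_schedule(1.0)
    if not isinstance(initial_value, float):
        raise ValueError(
            f"Schedule must return a float, got {type(initial_value)}; "
            "non-float types can make torch.load(weights_only=True) fail."
        )
    return value_schedule
```

Note that `isinstance(np.float32(1.0), float)` is `False`, so this check would also catch numpy scalars, the case that triggered issue #1900.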

Member
@araffin araffin

or, what you do is fine but I would explicitly name the parameter progress_remaining and add a comment of why

Contributor Author
@markscsmith markscsmith

Hm... I see the value in both. Let me noodle a bit and I'll see if I can sort it out during my lunch. Thanks araffin!

Member
@araffin araffin left a comment

LGTM, thanks =)

@araffin araffin merged commit 9a74938 into DLR-RM:master Apr 22, 2024
4 checks passed
@markscsmith
Contributor Author

Awesome! Thanks again araffin! The docs on SB3 you and the crew wrote and your guidance made this a breeze :)

friedeggs pushed a commit to friedeggs/stable-baselines3 that referenced this pull request Jul 22, 2024
Cast learning_rate to float lambda for pickle safety when doing model.load (DLR-RM#1901)

* create failing test for unpickle error

* Fix learning_rate argument causing failure in weights_only=True if passed a function with non-float types

* Updated with feedback from araffin on PR#1901

* Update test and version

* Update changelog and SBX doc

---------

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
Successfully merging this pull request may close these issues.

[Bug]: if learning_rate function uses special types, they can cause torch.load to fail when weights_only=True