
[Bug] Model's state_dict not loading properly when loading a ckpt #484

Closed
inafergra opened this issue Jul 2, 2024 · 5 comments · Fixed by #516
Labels
bug Something isn't working

Comments

@inafergra
Collaborator

Right now, when loading model checkpoints (QuantumModel or QNN), the model's state_dict is not loaded properly. The issue is in the load_model() function, which treats the QuantumModel._from_dict() classmethod as an in-place operation:

```python
iteration, model_dict = torch.load(folder / model_ckpt_name, *args, **kwargs)
if isinstance(model, (QuantumModel, QNN)):
    model._from_dict(model_dict, as_torch=True)  # return value is silently discarded
```

when in reality ._from_dict() is a classmethod that does not operate in place: it returns a new class instance with the parameters taken from the input model state_dict:

qadence/qadence/model.py, lines 276 to 308 at 70e12d4:

```python
@classmethod
def _from_dict(cls, d: dict, as_torch: bool = False) -> QuantumModel:
    from qadence.serialization import deserialize

    qm: QuantumModel
    try:
        qm_dict = d[cls.__name__]
        qm = cls(
            circuit=QuantumCircuit._from_dict(qm_dict["circuit"]),
            observable=(
                None
                if not isinstance(qm_dict["observable"], list)
                else [deserialize(q_obs) for q_obs in qm_dict["observable"]]  # type: ignore[misc]
            ),
            backend=qm_dict["backend"],
            diff_mode=qm_dict["diff_mode"],
            measurement=Measurements._from_dict(qm_dict["measurement"]),
            noise=Noise._from_dict(qm_dict["noise"]),
            configuration=config_factory(qm_dict["backend"], qm_dict["backend_configuration"]),
        )
        if as_torch:
            conv_pd = torch.nn.ParameterDict()
            param_dict = d["param_dict"]
            for n, param in param_dict.items():
                conv_pd[n] = torch.nn.Parameter(param)
            qm._params = conv_pd
        logger.debug(f"Initialized {cls.__name__} from {d}.")
    except Exception as e:
        logger.warning(f"Unable to deserialize object {d} to {cls.__name__} due to {e}.")
    return qm
```
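Since a classmethod builds a fresh object, calling it on an existing instance and discarding the return value is a no-op on that instance. A minimal torch-free sketch (Model and its _from_dict here are stand-ins, not the real qadence classes):

```python
class Model:
    """Stand-in for QuantumModel; not the real qadence class."""

    def __init__(self, value: int = 0):
        self.value = value

    @classmethod
    def _from_dict(cls, d: dict) -> "Model":
        # Builds and returns a *new* instance; whatever instance the
        # method happened to be called on is never mutated.
        return cls(value=d["value"])


m = Model(value=1)
m._from_dict({"value": 42})  # return value discarded, as in load_model()
print(m.value)               # prints 1: the checkpoint was never applied to m
```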

The most straightforward solution would be to assign the result, i.e. model = model._from_dict(...), when loading the state_dict. However, this breaks the train_grad() loop: the parameters passed to the optimizer before training no longer belong to the newly created model instance, so optimizer.step() updates parameters the new model never sees.
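This staleness can be sketched without torch (Param, Model and Optimizer below are toy stand-ins, not torch or qadence classes): the optimizer captures references to the parameter objects at construction time, so rebinding the model name afterwards leaves it stepping objects the new model does not own.

```python
class Param:
    def __init__(self, value: float):
        self.value = value

class Model:
    def __init__(self, value: float):
        self.params = [Param(value)]

class Optimizer:
    def __init__(self, params):
        self.params = list(params)  # parameter references captured here, once
    def step(self):
        for p in self.params:       # updates whatever was captured above
            p.value -= 0.1

model = Model(1.0)
opt = Optimizer(model.params)       # like torch.optim.Adam(model.parameters())

model = Model(5.0)                  # rebinding, like model = model._from_dict(...)
opt.step()
print(model.params[0].value)        # 5.0: the rebound model is untouched by step()
```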

One solution would be to turn ._from_dict() into an instance method that mutates the instance's attributes in place instead of creating a new instance, so that it behaves as an in-place operation. This might cause some additional issues (@dominikandreasseitz mentioned the uuid assignments might cause problems?).
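A torch-free sketch of that in-place approach (Model and the load_dict_ name are hypothetical illustrations, not the qadence API): because the same object is mutated, references held elsewhere, such as the parameter list already handed to an optimizer, keep seeing the restored state.

```python
class Model:
    """Stand-in for QuantumModel; load_dict_ is a hypothetical name."""

    def __init__(self, value: int = 0):
        self.value = value

    def load_dict_(self, d: dict) -> "Model":
        # Instance method that mutates self rather than building a new
        # object; trailing underscore follows torch's in-place convention.
        self.value = d["value"]
        return self


m = Model(1)
held_elsewhere = m           # e.g. parameters already registered with an optimizer
m.load_dict_({"value": 42})
print(held_elsewhere.value)  # prints 42: same object, now carrying the checkpoint
```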

@inafergra inafergra added the bug Something isn't working label Jul 2, 2024
@inafergra inafergra changed the title [Bug] Model's state_dict not loading properly [Bug] Model's state_dict not loading properly when loading a ckpt Jul 2, 2024
@Roland-djee
Collaborator

Tagging @dominikandreasseitz @jpmoutinho @Roland-djee

Thanks for this @inafergra. I guess @dominikandreasseitz or @chMoussa are the best placed to have a look ?

@dominikandreasseitz
Collaborator

thanks @inafergra. @Roland-djee , since @inafergra already knows how to solve it he could coordinate with @chMoussa on fixing this. i can ofc also do it. Let me know

@Roland-djee
Collaborator

If that is fine with them sure.

@Roland-djee
Collaborator

@inafergra @chMoussa Would you be able to coordinate on this ?
