diff --git a/presto/presto.py b/presto/presto.py index 1edbc7c..6821455 100644 --- a/presto/presto.py +++ b/presto/presto.py @@ -1,6 +1,7 @@ import math from copy import deepcopy from typing import Optional, Sized, Tuple, Union, cast +from pathlib import Path import numpy as np import torch @@ -785,7 +786,7 @@ def construct_finetuning_model( return model @classmethod - def load_pretrained(cls): + def load_pretrained(cls, modelpath: Union[str, Path] = default_model_path): model = cls.construct() - model.load_state_dict(torch.load(default_model_path, map_location=device)) + model.load_state_dict(torch.load(modelpath, map_location=device)) return model