Skip to content

Commit

Permalink
Allow custom model path
Browse files Browse the repository at this point in the history
  • Loading branch information
kvantricht committed Oct 18, 2023
1 parent 445d7ca commit ef3133f
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions presto/presto.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import math
from copy import deepcopy
from typing import Optional, Sized, Tuple, Union, cast
from pathlib import Path

import numpy as np
import torch
Expand Down Expand Up @@ -785,7 +786,7 @@ def construct_finetuning_model(
return model

@classmethod
def load_pretrained(cls):
def load_pretrained(cls, modelpath: Union[str, Path] = default_model_path):
model = cls.construct()
model.load_state_dict(torch.load(default_model_path, map_location=device))
model.load_state_dict(torch.load(modelpath, map_location=device))
return model

0 comments on commit ef3133f

Please sign in to comment.