Skip to content

Commit

Permalink
[Fix] Allow for processing Path in the sparsezoo analysis (#417)
Browse files Browse the repository at this point in the history
  • Loading branch information
dbogunowicz committed Feb 2, 2024
1 parent 6e0d12b commit 1a9ee4b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/sparsezoo/analyze/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from typing import Any, Dict, List, Optional, Union

import numpy
import onnx
import yaml
from onnx import ModelProto, NodeProto
from pydantic import BaseModel, Field, PositiveFloat, PositiveInt
Expand Down Expand Up @@ -68,6 +67,7 @@
is_parameterized_prunable_layer,
is_quantized_layer,
is_sparse_layer,
load_model,
)


Expand Down Expand Up @@ -914,7 +914,7 @@ def from_onnx(cls, onnx_file_path: Union[str, ModelProto]):
model_onnx = onnx_file_path
model_name = ""
else:
model_onnx = onnx.load(onnx_file_path)
model_onnx = load_model(onnx_file_path)
model_name = str(onnx_file_path)

model_graph = ONNXGraph(model_onnx)
Expand Down
4 changes: 2 additions & 2 deletions src/sparsezoo/utils/onnx/external_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def validate_onnx(model: Union[str, ModelProto]):
raise ValueError(f"Invalid onnx model: {err}")


def load_model(model: Union[str, ModelProto]) -> ModelProto:
def load_model(model: Union[str, ModelProto, Path]) -> ModelProto:
"""
Load an ONNX model from an onnx model file path. If a ModelProto
is given, then it is returned.
Expand All @@ -185,7 +185,7 @@ def load_model(model: Union[str, ModelProto]) -> ModelProto:
if isinstance(model, ModelProto):
return model

if isinstance(model, str):
if isinstance(model, (Path, str)):
return onnx.load(clean_path(model))

raise ValueError(f"unknown type given for model: {type(model)}")
Expand Down

0 comments on commit 1a9ee4b

Please sign in to comment.