diff --git a/src/sparsezoo/deployment_package/utils/extractors.py b/src/sparsezoo/deployment_package/utils/extractors.py index c9701efa..a1cddec9 100644 --- a/src/sparsezoo/deployment_package/utils/extractors.py +++ b/src/sparsezoo/deployment_package/utils/extractors.py @@ -18,6 +18,7 @@ import logging from types import MappingProxyType +from typing import Optional from sparsezoo import Model @@ -32,7 +33,12 @@ def _size(model: Model) -> float: return size -def _throughput(model: Model, num_cores: int = 24, batch_size: int = 64) -> float: +def _throughput( + model: Model, + num_cores: int = 24, + batch_size: int = 64, + device_info: Optional[str] = None, +) -> float: # num_cores : 24, batch_size: 64 are standard defaults in sparsezoo throughput_results = getattr(model, "validation_results", {}).get("throughput", []) @@ -40,6 +46,7 @@ def _throughput(model: Model, num_cores: int = 24, batch_size: int = 64) -> floa if ( throughput_result.batch_size == batch_size and throughput_result.num_cores == num_cores + and (device_info is None or (throughput_result.device_info == device_info)) ): return throughput_result.recorded_value diff --git a/tests/sparsezoo/deployment_package/utils/test_extractors.py b/tests/sparsezoo/deployment_package/utils/test_extractors.py index 62f24d2f..23ce699f 100644 --- a/tests/sparsezoo/deployment_package/utils/test_extractors.py +++ b/tests/sparsezoo/deployment_package/utils/test_extractors.py @@ -64,16 +64,17 @@ def model(): @pytest.mark.parametrize( - "num_cores, batch_size, expected", + "num_cores,batch_size,device_info,expected", [ - (24, 64, 1948.45), + (24, 64, "c6i.12xlarge", 1948.45), ], ) -def test_throughput_extractor(model, num_cores, batch_size, expected): +def test_throughput_extractor(model, num_cores, batch_size, device_info, expected): actual_throughput = _throughput( model=model, num_cores=num_cores, batch_size=batch_size, + device_info=device_info, ) assert actual_throughput == expected