Skip to content

Commit

Permalink
add device_info to _extract_throughput (#294)
Browse files Browse the repository at this point in the history
  • Loading branch information
bfineran committed Apr 7, 2023
1 parent b060d3e commit f7d8588
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
9 changes: 8 additions & 1 deletion src/sparsezoo/deployment_package/utils/extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import logging
from types import MappingProxyType
from typing import Optional

from sparsezoo import Model

Expand All @@ -32,14 +33,20 @@ def _size(model: Model) -> float:
return size


def _throughput(model: Model, num_cores: int = 24, batch_size: int = 64) -> float:
def _throughput(
model: Model,
num_cores: int = 24,
batch_size: int = 64,
device_info: Optional[str] = None,
) -> float:
# num_cores : 24, batch_size: 64 are standard defaults in sparsezoo
throughput_results = getattr(model, "validation_results", {}).get("throughput", [])

for throughput_result in throughput_results:
if (
throughput_result.batch_size == batch_size
and throughput_result.num_cores == num_cores
and (device_info is None or (throughput_result.device_info == device_info))
):
return throughput_result.recorded_value

Expand Down
7 changes: 4 additions & 3 deletions tests/sparsezoo/deployment_package/utils/test_extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,17 @@ def model():


@pytest.mark.parametrize(
"num_cores, batch_size, expected",
"num_cores,batch_size,device_info,expected",
[
(24, 64, 1948.45),
(24, 64, "c6i.12xlarge", 1948.45),
],
)
def test_throughput_extractor(model, num_cores, batch_size, expected):
def test_throughput_extractor(model, num_cores, batch_size, device_info, expected):
actual_throughput = _throughput(
model=model,
num_cores=num_cores,
batch_size=batch_size,
device_info=device_info,
)
assert actual_throughput == expected

Expand Down

0 comments on commit f7d8588

Please sign in to comment.