Skip to content

Commit

Permalink
style
Browse files Browse the repository at this point in the history
  • Loading branch information
Satrat committed Aug 15, 2023
1 parent 1eb3202 commit 289f545
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 2 deletions.
5 changes: 4 additions & 1 deletion src/deepsparse/benchmark/data_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,10 @@ def generate_random_image_data(
f"Using default image shape {image_shape}"
)

input_data = [numpy.random.randint(0, high=255, size=image_shape).astype(numpy.uint8) for _ in range(batch_size)]
input_data = [
numpy.random.randint(0, high=255, size=image_shape).astype(numpy.uint8)
for _ in range(batch_size)
]
return input_data


Expand Down
1 change: 1 addition & 0 deletions src/deepsparse/benchmark/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
DEFAULT_STRING_LENGTH = 50
DEFAULT_IMAGE_SHAPE = (240, 240, 3)


class ThreadPinningMode:
CORE: str = "core"
NUMA: str = "numa"
Expand Down
4 changes: 3 additions & 1 deletion tests/test_pipeline_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,9 @@ def test_generate_random_question_data():
avg_word_len = 10
config_args = {"gen_sequence_length": 50}
config = PipelineBenchmarkConfig(**config_args)
question, context = generate_random_question_data(config, 1, avg_word_len=avg_word_len)
question, context = generate_random_question_data(
config, 1, avg_word_len=avg_word_len
)
assert len(question) == config.gen_sequence_length
assert len(context) == config.gen_sequence_length
num_q_spaces = question.count(" ")
Expand Down

0 comments on commit 289f545

Please sign in to comment.