Skip to content

Commit

Permalink
Fixing Quant % Calculation (#462)
Browse files Browse the repository at this point in the history
* initial fix

* style
  • Loading branch information
Satrat committed Feb 22, 2024
1 parent cd2b23a commit 36fd754
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 35 deletions.
13 changes: 8 additions & 5 deletions src/sparsezoo/analyze_v2/memory_access_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def get_quantization(self) -> List["QuantizationAnalysisSchema"]:
:returns: List of quantization analysis pydantic models for each grouping
if the node has weights
"""
data = get_memeory_access_bits(self.model_graph, self.node, self.node_shape)
data = get_memory_access_bits(self.model_graph, self.node, self.node_shape)
if data is not None:
quantization_analysis_model = []
for grouping, counts_dict in data.items():
Expand Down Expand Up @@ -152,7 +152,7 @@ def get_memory_access_counts(
}


def get_memeory_access_bits(
def get_memory_access_bits(
model_graph: ONNXGraph,
node: NodeProto,
node_shape: Dict,
Expand All @@ -164,12 +164,15 @@ def get_memeory_access_bits(
)
node_weight = get_node_weight(model_graph, node)
precision = get_numpy_quantization_level(node_weight)
bits = memory_access_counts["single"]["counts"] * precision
bits_quant = bits * is_quantized_layer(model_graph, node)
counts = memory_access_counts["single"]["counts"]
bits = counts * precision
is_quantized = is_quantized_layer(model_graph, node)

return {
"tensor": {
"bits": bits,
"bits_quant": bits_quant,
"bits_quant": bits * is_quantized,
"counts": counts,
"counts_quant": counts * is_quantized,
}
}
14 changes: 7 additions & 7 deletions src/sparsezoo/analyze_v2/model_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,10 @@ def calculate_sparsity_percentage(self, category: Dict):
counts = category["counts"]
return (counts_sparse / counts) * 100 if counts != 0 else 0

def calculate_quantized_percentage(self, tensor: Dict):
bits_quant = tensor["bits_quant"]
bits = tensor["bits"]
return (bits_quant / bits) * 100 if bits != 0 else 0
def calculate_quantized_percentage(self, tensor: Dict, counts_prefix: str):
    """Return the quantized share of ``tensor`` as a percentage.

    :param tensor: mapping that holds both ``<counts_prefix>`` and
        ``<counts_prefix>_quant`` totals
    :param counts_prefix: key prefix selecting which pair of totals to compare
        (e.g. ``"counts"``)
    :return: percentage of quantized entries, or 0 when the total is zero
    """
    total = tensor[counts_prefix]
    if total == 0:
        # Guard against division by zero for empty/weightless groupings
        return 0
    return (tensor[f"{counts_prefix}_quant"] / total) * 100

def __repr__(self):
data = self.to_dict()
Expand All @@ -93,7 +93,7 @@ def __repr__(self):
)
param_size = summaries["params"]["quantization"]["tensor"]["bits"]
param_quantized = self.calculate_quantized_percentage(
summaries["params"]["quantization"]["tensor"]
summaries["params"]["quantization"]["tensor"], "counts"
)

ops_total = summaries["ops"]["sparsity"]["single"]["counts"]
Expand All @@ -102,7 +102,7 @@ def __repr__(self):
)
ops_size = summaries["ops"]["quantization"]["tensor"]["bits"]
ops_quantized = self.calculate_quantized_percentage(
summaries["ops"]["quantization"]["tensor"]
summaries["ops"]["quantization"]["tensor"], "counts"
)

mem_access_total = summaries["mem_access"]["sparsity"]["single"]["counts"]
Expand All @@ -111,7 +111,7 @@ def __repr__(self):
)
mem_access_size = summaries["mem_access"]["quantization"]["tensor"]["bits"]
mem_access_quantized = self.calculate_quantized_percentage(
summaries["mem_access"]["quantization"]["tensor"]
summaries["mem_access"]["quantization"]["tensor"], "counts"
)

return (
Expand Down
21 changes: 12 additions & 9 deletions src/sparsezoo/analyze_v2/operation_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,18 +166,21 @@ def get_operation_bits(
precision = get_numpy_quantization_level(node_weight)
is_quantized_op = "32" not in str(precision)

bits = (ops["single"]["counts"]) * precision

bits_block4 = (ops["block4"]["counts"]) * precision

bits_quant = is_quantized_op * bits
single_counts = ops["single"]["counts"]
single_bits = single_counts * precision
block4_counts = ops["block4"]["counts"]
block4_bits = block4_counts * precision
return {
"tensor": {
"bits": bits,
"bits_quant": bits_quant,
"counts": single_counts,
"counts_quant": is_quantized_op * single_counts,
"bits": single_bits,
"bits_quant": is_quantized_op * single_bits,
},
"block4": {
"bits": bits_block4,
"bits_quant": bits_quant,
"counts": block4_counts,
"counts_quant": is_quantized_op * block4_counts,
"bits": block4_bits,
"bits_quant": is_quantized_op * block4_bits,
},
}
16 changes: 9 additions & 7 deletions src/sparsezoo/analyze_v2/parameter_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
get_node_num_four_block_zeros_and_size,
get_node_param_counts,
get_node_weight,
get_node_weight_bits,
get_node_weight_precision,
get_numpy_distribution_statistics,
get_numpy_entropy,
get_numpy_modes,
Expand Down Expand Up @@ -153,14 +153,16 @@ def get_parameter_bits(
If the layer is quantized, assume all its elements in the ndarray
are quantized
"""
node_weight = get_node_weight(model_graph, node)
if node_weight is not None and node_weight.size > 0:
bits = get_node_weight_bits(model_graph, node)

num_weights, _, _ = get_node_param_counts(node, model_graph)
if num_weights > 0:
precision = get_node_weight_precision(model_graph, node)
is_quantized = is_quantized_layer(model_graph, node)
return {
"tensor": {
"bits": bits,
"bits_quant": bits * is_quantized_layer(model_graph, node),
"counts": num_weights,
"counts_quant": num_weights * is_quantized,
"bits": num_weights * precision,
"bits_quant": num_weights * precision * is_quantized,
},
}

Expand Down
18 changes: 15 additions & 3 deletions src/sparsezoo/analyze_v2/schemas/quantization_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@


class QuantizationSummaryAnalysisSchema(BaseModel):
counts: float = Field(..., description="Total number of weights")
counts_quant: int = Field(
...,
description=(
"Total number of quantized weights."
"Here we assume if the layer is quantized, the entire array is quantized"
),
)
bits: float = Field(..., description="Total bits required to store the weights")
bits_quant: int = Field(
...,
Expand All @@ -39,9 +47,9 @@ def validate_types(cls, value):
@validator("percent", pre=True, always=True)
def calculate_percent_if_none(cls, value, values):
if value is None:
bits = values.get("bits", 0)
bits_quant = values.get("bits_quant", 0)
return bits_quant / bits if bits > 0 else 0.0
counts = values.get("counts", 0)
counts_quant = values.get("counts_quant", 0)
return counts_quant / counts if counts > 0 else 0.0
return value

def __add__(self, model: BaseModel):
Expand All @@ -51,7 +59,9 @@ def __add__(self, model: BaseModel):

if validator_model is not None:
return validator_model(
counts=self.counts + model.counts,
bits=self.bits + model.bits,
counts_quant=self.counts_quant + model.counts_quant,
bits_quant=self.bits_quant + model.bits_quant,
)

Expand All @@ -67,6 +77,8 @@ def __add__(self, model: BaseModel):
if validator_model is not None and self.grouping == model.grouping:
return validator_model(
grouping=self.grouping,
counts=self.counts + model.counts,
bits=self.bits + model.bits,
counts_quant=self.counts_quant + model.counts_quant,
bits_quant=self.bits_quant + model.bits_quant,
)
8 changes: 4 additions & 4 deletions src/sparsezoo/utils/onnx/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
"get_numpy_distribution_statistics",
"get_numpy_quantization_level",
"get_numpy_bits",
"get_node_weight_bits",
"get_node_weight_precision",
"get_node_param_counts",
"get_node_kernel_shape",
]
Expand Down Expand Up @@ -485,13 +485,13 @@ def get_node_param_counts(
return params, bias, sparse_params


def get_node_weight_bits(
def get_node_weight_precision(
model_graph: ONNXGraph,
node: NodeProto,
) -> int:
"""Get the bits needed to store the node weights"""
"""Get the precision of the node in number of bits"""
node_weight = get_node_weight(model_graph, node)
return get_numpy_bits(node_weight)
return get_numpy_quantization_level(node_weight)


def get_numpy_bits(arr: numpy.ndarray) -> int:
Expand Down

0 comments on commit 36fd754

Please sign in to comment.