Use torch.finfo.eps as minimum scale (#3159)
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
quic-kyunggeu committed Jul 10, 2024
1 parent c1f38cb commit 51780c8
Showing 3 changed files with 20 additions and 20 deletions.
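Note (illustrative, not part of the commit): `torch.finfo(dtype).tiny` is the smallest positive normal value of a dtype (about 1.18e-38 for float32), while `torch.finfo(dtype).eps` is the gap between 1.0 and the next representable value (about 1.19e-07 for float32). The sketch below compares the two constants and shows one practical difference: a scale floored at `tiny` underflows to zero under a half-precision cast, whereas one floored at `eps` stays representable.

```python
import torch

# float32 constants: tiny is the smallest positive normal number,
# eps is the gap between 1.0 and the next representable float.
print(torch.finfo(torch.float32).tiny)  # 1.1754943508222875e-38
print(torch.finfo(torch.float32).eps)   # 1.1920928955078125e-07

# A scale floored at `tiny` underflows to zero when cast to half precision,
# while a scale floored at `eps` remains representable.
scale_tiny = torch.tensor(torch.finfo(torch.float32).tiny)
scale_eps = torch.tensor(torch.finfo(torch.float32).eps)
print(scale_tiny.half())  # tensor(0., dtype=torch.float16)
print(scale_eps.half())   # tensor(1.1921e-07, dtype=torch.float16)
```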
@@ -314,14 +314,14 @@ def compute_encodings_from_stats(self, stats: _MinMaxRange, num_steps: int, is_s
         if stats.min is None or stats.max is None:
             raise StatisticsNotFoundError('No statistics present to compute encodings.')
 
-        tiny_num = torch.finfo(stats.min.dtype).tiny
+        eps = torch.finfo(stats.min.dtype).eps
         # enforces that 0 is within the min/max
         min_with_zero = torch.clamp(stats.min, max=0)
         max_with_zero = torch.clamp(stats.max, min=0)
 
         # adjusts any min/max pairing that are too close
         tensor_diff = (max_with_zero - min_with_zero) / num_steps
-        adjustment_step = tiny_num * (tensor_diff < tiny_num)
+        adjustment_step = eps * (tensor_diff < eps)
 
         updated_max = max_with_zero + math.floor(num_steps / 2) * adjustment_step
         updated_min = min_with_zero - math.ceil(num_steps / 2) * adjustment_step
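For an all-zero tensor, stats.min and stats.max both collapse to 0 and the adjustment above pads them apart by one eps per quantization step. A standalone re-run of that arithmetic (an illustrative sketch, not the library call) reproduces the expectation used in the updated zero-tensor test further down:

```python
import math
import torch

# Re-running the adjustment above for an all-zero tensor (8-bit asymmetric, 255 steps).
num_steps = 2 ** 8 - 1
eps = torch.finfo(torch.float32).eps
min_with_zero = torch.tensor(0.0)
max_with_zero = torch.tensor(0.0)

tensor_diff = (max_with_zero - min_with_zero) / num_steps    # 0, i.e. min/max too close
adjustment_step = eps * (tensor_diff < eps)                  # eps wherever padding is needed
updated_max = max_with_zero + math.floor(num_steps / 2) * adjustment_step  # 127 * eps
updated_min = min_with_zero - math.ceil(num_steps / 2) * adjustment_step   # -128 * eps
print(updated_min.item(), updated_max.item())
```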
@@ -350,10 +350,10 @@ def adjust_min_max(curr_min, curr_max, num_steps, is_symmetric):
     curr_max.clamp_(min=0, max=torch.finfo(curr_max.dtype).max)
 
     # ensure that min/max aren't too close
-    tiny_num = torch.finfo(curr_min.dtype).tiny
+    eps = torch.finfo(curr_min.dtype).eps
     tensor_threshold = (curr_max - curr_min) / num_steps
-    curr_min[tensor_threshold < tiny_num] -= tiny_num * math.ceil(num_steps / 2)
-    curr_max[tensor_threshold < tiny_num] += tiny_num * math.floor(num_steps / 2)
+    curr_min[tensor_threshold < eps] -= eps * math.ceil(num_steps / 2)
+    curr_max[tensor_threshold < eps] += eps * math.floor(num_steps / 2)
 
     if is_symmetric:
         num_pos_steps = math.floor(num_steps / 2)
@@ -510,7 +510,7 @@ def _pick_test_candidates(self, stats, num_steps, symmetric):
         max_vals = torch.stack([stat.max for stat in stats])
         min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
         max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
-        max_vals = torch.max(max_vals, min_vals + torch.finfo(min_vals.dtype).tiny * num_steps)
+        max_vals = torch.max(max_vals, min_vals + torch.finfo(min_vals.dtype).eps * num_steps)
         if symmetric:
             return self._pick_test_candidates_symmetric(min_vals, max_vals, num_steps)
         return self._pick_test_candidates_asymmetric(min_vals, max_vals, num_steps)
@@ -552,7 +552,7 @@ def _pick_test_candidates_symmetric(self, min_vals, max_vals, num_steps):
         test_deltas = max_delta[:, None] * search_space[None, :] / (num_deltas - 1)
         # test_deltas.shape = (num_histograms, num_deltas, 1)
         # test_offsets.shape = (1, 1, 1)
-        min_delta = torch.Tensor([torch.finfo(test_deltas.dtype).tiny]).to(**tensor_kwargs)
+        min_delta = torch.Tensor([torch.finfo(test_deltas.dtype).eps]).to(**tensor_kwargs)
         test_deltas = torch.max(test_deltas, min_delta)
         return test_deltas[:, :, None], test_offsets[:, None, None]

@@ -570,7 +570,7 @@ def _clamp_delta_offset_values(min_vals, max_vals, num_steps, test_deltas, test_
     # Recompute delta/offset with clamped min/max
     # Returned delta/offset shapes = (num_histograms, num_deltas, num_offsets)
     test_deltas = (test_max - test_min) / num_steps
-    min_delta = torch.Tensor([torch.finfo(test_deltas.dtype).tiny]).to(device=test_deltas.device,
+    min_delta = torch.Tensor([torch.finfo(test_deltas.dtype).eps]).to(device=test_deltas.device,
                                                                        dtype=test_deltas.dtype)
     test_deltas = torch.max(test_deltas, min_delta)
     test_offsets = torch.round(test_min / test_deltas)
@@ -169,8 +169,8 @@ def test_compute_encodings_with_only_zero_tensor(self):
 
         num_steps = pow(2, 8) - 1
         asymmetric_min, asymmetric_max = encoding_analyzer.compute_encodings(num_steps=num_steps, is_symmetric = False)
-        updated_min = torch.finfo(asymmetric_min.dtype).tiny * (2 ** (8 - 1))
-        updated_max = torch.finfo(asymmetric_min.dtype).tiny * ((2 **(8 - 1)) - 1)
+        updated_min = torch.finfo(asymmetric_min.dtype).eps * (2 ** (8 - 1))
+        updated_max = torch.finfo(asymmetric_min.dtype).eps * ((2 **(8 - 1)) - 1)
         assert torch.all(torch.eq(asymmetric_min, torch.full(tuple(encoding_analyzer.observer.shape), -updated_min)))
         assert torch.all(torch.eq(asymmetric_max, torch.full(tuple(encoding_analyzer.observer.shape), updated_max)))

@@ -187,15 +187,15 @@ def test_compute_encodings_with_only_zero_tensor(self):
     @pytest.mark.parametrize('symmetric', [True, False])
     def test_overflow(self, symmetric):
         encoding_analyzer = MinMaxEncodingAnalyzer((1,))
-        float_input_min = (torch.arange(10) * torch.finfo(torch.float).tiny)
+        float_input_min = (torch.arange(10) * torch.finfo(torch.float).eps)
         encoding_analyzer.update_stats(float_input_min)
         num_steps = pow(2, 8) - 2
         min, max = encoding_analyzer.compute_encodings(num_steps=num_steps, is_symmetric=symmetric)
         scale = (max - min) / 255
 
         # Scale should be at least as large as torch.min
         assert scale != 0
-        assert torch.allclose(scale, torch.tensor(torch.finfo(scale.dtype).tiny), atol=1e-10)
+        assert torch.allclose(scale, torch.tensor(torch.finfo(scale.dtype).eps), rtol=0.01)
 
         float_input_max = (torch.arange(10) * torch.finfo(torch.float).max)
         encoding_analyzer.update_stats(float_input_max)
16 changes: 8 additions & 8 deletions TrainingExtensions/torch/test/python/v2/test_seq_mse_.py
@@ -153,7 +153,7 @@ def test_seq_mse(self):
         assert list(cand_min.size())[0] == linear.out_features
 
     @pytest.mark.parametrize("enable_pcq", [True, False])
-    @pytest.mark.parametrize("param_bw", [2, 31])
+    @pytest.mark.parametrize("param_bw", [4, 16])
     @pytest.mark.parametrize("loss_fn", ['mse', 'l1', 'sqnr'])
     @pytest.mark.parametrize("qparam_requires_grad", [True, False])
     def test_optimize_module_linear(self, enable_pcq, param_bw, loss_fn, qparam_requires_grad):
@@ -180,15 +180,15 @@ def test_optimize_module_linear(self, enable_pcq, param_bw, loss_fn, qparam_requ
 
         # If we use higher param_bw (for example 16, 31), then it should always choose larger candidates so
         # before and after param encodings should be almost same.
-        if param_bw == 31:
-            assert torch.allclose(before.min, after.min)
-            assert torch.allclose(before.max, after.max)
+        if param_bw >= 16:
+            assert torch.allclose(before.min, after.min, rtol=1e-4)
+            assert torch.allclose(before.max, after.max, rtol=1e-4)
         else:
             assert not torch.allclose(before.min, after.min)
             assert not torch.allclose(before.max, after.max)
 
     @pytest.mark.parametrize("enable_pcq", [True, False])
-    @pytest.mark.parametrize("param_bw", [2, 31])
+    @pytest.mark.parametrize("param_bw", [4, 16])
     @pytest.mark.parametrize("loss_fn", ['mse', 'l1', 'sqnr'])
     def test_optimize_module_conv(self, enable_pcq, param_bw, loss_fn):
         """ test optimize module for linear """
@@ -213,9 +213,9 @@ def test_optimize_module_conv(self, enable_pcq, param_bw, loss_fn):
 
         # If we use higher param_bw (for example 16, 31), then it should always choose larger candidates so
         # before and after param encodings should be almost same.
-        if param_bw == 31:
-            assert torch.allclose(before.min, after.min)
-            assert torch.allclose(before.max, after.max)
+        if param_bw >= 16:
+            assert torch.allclose(before.min, after.min, rtol=1e-4)
+            assert torch.allclose(before.max, after.max, rtol=1e-4)
         else:
             assert not torch.allclose(before.min, after.min)
             assert not torch.allclose(before.max, after.max)
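Since a 16-bit parameter bitwidth can still leave tiny rounding differences between the before and after encodings, the updated assertions accept "almost the same" values through a relative tolerance rather than exact equality. A quick illustration of how rtol loosens torch.allclose (illustrative only, not from the commit):

```python
import torch

a = torch.tensor([1.0])
b = torch.tensor([1.00005])

# allclose checks |a - b| <= atol + rtol * |b|; defaults are rtol=1e-5, atol=1e-8.
print(torch.allclose(a, b))              # False under the default tolerances
print(torch.allclose(a, b, rtol=1e-4))   # True once the relative tolerance is widened
```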
