From 974eb177ab14c1f86895e9d9c00f8a81e823520b Mon Sep 17 00:00:00 2001 From: Kevin Hsieh Date: Tue, 5 Apr 2022 09:00:13 -0700 Subject: [PATCH] Enable TF quantsim per channel range learning Signed-off-by: Kevin Hsieh --- .../src/python/aimet_tensorflow/quantsim.py | 6 +- .../quantsim_straight_through_grad.py | 81 +++++--- .../python/test_per_channel_quantization.py | 175 +++++++++++++++++- .../tensorflow/test/python/test_quantsim.py | 52 +++++- 4 files changed, 289 insertions(+), 25 deletions(-) diff --git a/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/quantsim.py b/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/quantsim.py index 6db8f10c68..48ced206d6 100644 --- a/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/quantsim.py +++ b/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/quantsim.py @@ -1174,7 +1174,8 @@ def create_quantize_op(): if self._quant_scheme in [QuantScheme.training_range_learning_with_tf_init, QuantScheme.training_range_learning_with_tf_enhanced_init]: with self.session.graph.gradient_override_map( - {"QcQuantize": "QcQuantizeRangeLearningCustomGradient"}): + {"QcQuantize": "QcQuantizeRangeLearningCustomGradient", + "QcQuantizePerChannel": "QcQuantizePerChannelRangeLearningCustomGradient"}): q_op_out = create_quantize_op() else: q_op_out = create_quantize_op() @@ -1184,7 +1185,8 @@ def create_quantize_op(): if self._quant_scheme in [QuantScheme.training_range_learning_with_tf_init, QuantScheme.training_range_learning_with_tf_enhanced_init]: with self.session.graph.gradient_override_map( - {"QcQuantize": "QcQuantizeRangeLearningCustomGradient"}): + {"QcQuantize": "QcQuantizeRangeLearningCustomGradient", + "QcQuantizePerChannel": "QcQuantizePerChannelRangeLearningCustomGradient"}): q_op_out = create_quantize_op() else: q_op_out = create_quantize_op() diff --git a/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/quantsim_straight_through_grad.py b/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/quantsim_straight_through_grad.py index cd4bb5c478..2660458b85 100644 --- a/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/quantsim_straight_through_grad.py +++ b/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/quantsim_straight_through_grad.py @@ -164,38 +164,43 @@ def _compute_dloss_by_dmax(x, grad, scaling, offset, bitwidth, use_symmetric_enc # to single value before returning gradient # this uses chain rule, multiply by loss and sum it to get scalar. dq_by_dmax = tf.where(tf.less_equal(n, r_x_by_s_plus_round_o), inner_cond, false_expr) - dloss_by_dmax = tf.reduce_sum(dq_by_dmax * grad) + # If per channel is active, scaling tensor will be rank 1 (an array instead of a singular value). + # In case of per channel, we reduce by all but the last dimension. Otherwise, we reduce all dimensions. + dloss_by_dmax = tf.cond(tf.equal(tf.rank(scaling), 0), lambda: tf.reduce_sum(dq_by_dmax * grad), + lambda: tf.reduce_sum(dq_by_dmax * grad, axis=tf.range(0, tf.rank(x) - 1))) return dloss_by_dmax -@tf_ops.RegisterGradient("QcQuantizeRangeLearningCustomGradient") -def quantsim_custom_grad_learned_grid(op, grad): +# pylint: disable=too-many-locals +def _compute_dloss_by_dmin_dmax_and_dx(inputs: tf.Tensor, bitwidth: tf.Tensor, op_mode: tf.Tensor, + encoding_min: tf.Tensor, encoding_max: tf.Tensor, is_symmetric: tf.Tensor, + grad: tf.Tensor): """ - Performs custom gradient calculations for trained Quantize op - :param op: Tf operation for which gradients are to be computed - :param grad: Gradient flowing through + Return tensors for dloss_by_dmin, dloss_by_dmax, and dloss_by_dx. + :param inputs: Inputs to op + :param bitwidth: Bitwidth used to quantize + :param op_mode: Op mode (if passthrough, gradient is returned as is) + :param encoding_min: Encoding min value(s), will be more than one if per channel is active + :param encoding_max: Encoding max value(s), will be more than one if per channel is active + :param is_symmetric: True if symmetric encodings are used, False otherwise + :param grad: Gradient from child layer + :return: Tensors for dloss_by_dmin, dloss_by_dmax, and dloss_by_dx """ - # pylint: disable=R0914 - - # read bitwidth, use_symmetric_encoding_flag, - # encoding_min and encoding_max from the op inputs - x = tf.cast(op.inputs[0], tf.float32) - bitwidth = tf.cast(op.inputs[int(QuantizeOpIndices.bit_width)], tf.float32) - op_mode = tf.cast(op.inputs[int(QuantizeOpIndices.op_mode)], tf.int8) - - encoding_min = tf.cast(op.inputs[int(QuantizeOpIndices.encoding_min)], tf.float32) - encoding_max_read = tf.cast(op.inputs[int(QuantizeOpIndices.encoding_max)], tf.float32) - + x = tf.cast(inputs, tf.float32) + bitwidth = tf.cast(bitwidth, tf.float32) + op_mode = tf.cast(op_mode, tf.int8) + encoding_min = tf.cast(encoding_min, tf.float32) + encoding_max = tf.cast(encoding_max, tf.float32) # handle min == max to avoid divide by zero epsilon = tf.constant(1e-5, dtype=tf.float32) - encoding_max = tf.math.maximum(encoding_max_read, tf.add(encoding_min, epsilon)) + encoding_max = tf.math.maximum(encoding_max, tf.add(encoding_min, epsilon)) # compute n, p, scaling and offset params # choose n based on symmetric or asymmetric flag # symmetric : -two_pow_bw + 1 # asymmetric : 0 - n, p = _get_n_and_p(bitwidth, op.inputs[int(QuantizeOpIndices.use_symmetric_encoding)]) + n, p = _get_n_and_p(bitwidth, is_symmetric) steps = tf.cast(tf.pow(tf.cast(tf.constant(2), tf.float32), bitwidth) - 1, tf.float32) scaling = tf.cast(((encoding_max - encoding_min) / steps), tf.float32) rounded_offset = tf.round(-encoding_min / scaling) # pylint: disable=invalid-unary-operand-type @@ -211,12 +216,46 @@ def quantsim_custom_grad_learned_grid(op, grad): inner_cond, # execute if true tf.zeros_like(r_x_by_s_plus_round_o))) * grad - dloss_by_dmax = tf.cast(_compute_dloss_by_dmax(x, grad, scaling, rounded_offset, bitwidth, - op.inputs[int(QuantizeOpIndices.use_symmetric_encoding)]), + dloss_by_dmax = tf.cast(_compute_dloss_by_dmax(x, grad, scaling, rounded_offset, bitwidth, is_symmetric), tf.float64) dloss_by_dmin = tf.cast(_compute_dloss_by_dmin_using_dmax(dloss_by_dmax), tf.float64) # Pass through gradient for skipped ops dloss_by_dx = tf.cond(tf.equal(op_mode, 3), lambda: grad, lambda: dloss_by_dx) + return dloss_by_dmin, dloss_by_dmax, dloss_by_dx + +@tf_ops.RegisterGradient("QcQuantizeRangeLearningCustomGradient") +def quantsim_custom_grad_learned_grid(op, grad): + """ + Performs custom gradient calculations for trained Quantize op + :param op: Tf operation for which gradients are to be computed + :param grad: Gradient flowing through + """ + dloss_by_dmin, dloss_by_dmax, dloss_by_dx = \ + _compute_dloss_by_dmin_dmax_and_dx(op.inputs[0], + op.inputs[int(QuantizeOpIndices.bit_width)], + op.inputs[int(QuantizeOpIndices.op_mode)], + op.inputs[int(QuantizeOpIndices.encoding_min)], + op.inputs[int(QuantizeOpIndices.encoding_max)], + op.inputs[int(QuantizeOpIndices.use_symmetric_encoding)], + grad) + return dloss_by_dx, None, None, dloss_by_dmin, dloss_by_dmax, None, None, None + + +@tf_ops.RegisterGradient("QcQuantizePerChannelRangeLearningCustomGradient") +def quantsim_per_channel_custom_grad_learned_grid(op, grad): + """ + Performs custom gradient calculations for trained QcQuantizePerChannel op + :param op: Tf operation for which gradients are to be computed + :param grad: Gradient flowing through + """ + dloss_by_dmin, dloss_by_dmax, dloss_by_dx = \ + _compute_dloss_by_dmin_dmax_and_dx(op.inputs[0], + op.inputs[int(QuantizeOpIndices.bit_width)], + op.inputs[int(QuantizeOpIndices.op_mode)], + op.inputs[int(QuantizeOpIndices.encoding_min)], + op.inputs[int(QuantizeOpIndices.encoding_max)], + op.inputs[int(QuantizeOpIndices.use_symmetric_encoding)], + grad) return dloss_by_dx, None, None, dloss_by_dmin, dloss_by_dmax, None, None, None diff --git a/TrainingExtensions/tensorflow/test/python/test_per_channel_quantization.py b/TrainingExtensions/tensorflow/test/python/test_per_channel_quantization.py index 5a032bc777..5dc0a2dbbd 100644 --- a/TrainingExtensions/tensorflow/test/python/test_per_channel_quantization.py +++ b/TrainingExtensions/tensorflow/test/python/test_per_channel_quantization.py @@ -47,7 +47,9 @@ from aimet_tensorflow.common.graph_eval import initialize_uninitialized_vars from aimet_tensorflow.quantsim import QuantizationSimModel from aimet_tensorflow.examples.test_models import depthwise_conv2d_model +from aimet_tensorflow.utils.constants import QuantizeOpIndices from aimet_tensorflow.utils.op.conv import WeightTensorUtils +from aimet_common.defs import QuantScheme from aimet_common.quantsim import calculate_delta_offset tf.compat.v1.disable_eager_execution() @@ -482,7 +484,6 @@ def dummy_forward_pass(sess, args): assert np.allclose(encoding_numpy, encodings, rtol=0.01) - @pytest.mark.cuda def test_to_compare_time_per_channel_and_per_tensor_quantization(self): save_config_file_for_per_channel_quantization() @@ -552,6 +553,153 @@ def dummy_forward_pass(sess, args): encoding = quantizer_info.get_encoding() assert isinstance(encoding, list) + # Mark below test as cuda until per channel on cpu is supported. + @pytest.mark.cuda + def test_per_channel_range_learning(self): + """ + Test to validate per channel range learning + """ + tf.compat.v1.reset_default_graph() + tf.compat.v1.set_random_seed(0) + np.random.seed(0) + with tf.device('/cpu:0'): + inputs = tf.keras.Input(shape=(32, 32, 4,)) + conv_op = tf.keras.layers.Conv2D(2, (3, 3), + kernel_initializer=tf.random_uniform_initializer(-1, 2), + bias_initializer='random_uniform', + padding='SAME')(inputs) + relu_op = tf.nn.relu(conv_op) + reshape = tf.keras.layers.Flatten()(relu_op) + _ = tf.keras.layers.Dense(10, bias_initializer='random_uniform')(reshape) + + sess = tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph()) + initialize_uninitialized_vars(sess) + + save_config_file_bias_quantized_for_per_channel_quantization() + + # create quantsim model without config file + sim = QuantizationSimModel(sess, ['input_1'], ['dense/BiasAdd'], use_cuda=True, + quant_scheme=QuantScheme.training_range_learning_with_tf_init, + config_file='./quantsim_config.json') + + def dummy_forward_pass(sess, _): + model_output = sess.graph.get_tensor_by_name('dense/BiasAdd_quantized:0') + model_input = sess.graph.get_tensor_by_name('input_1:0') + shape = model_input.shape + dummy_input = np.random.randn(1, shape[1], shape[2], shape[3]) + sess.run(model_output, feed_dict={model_input: dummy_input}) + + conv2d_weight_quant_op = sim.session.graph.get_operation_by_name('conv2d/Conv2D/ReadVariableOp_quantized') + conv2d_output_quant_op = sim.session.graph.get_operation_by_name('conv2d/BiasAdd_quantized') + dense_bias_quant_op = sim.session.graph.get_operation_by_name('dense/BiasAdd/ReadVariableOp_quantized') + + # enable input + sim.compute_encodings(dummy_forward_pass, None) + + inp_tensor = sim.session.graph.get_tensor_by_name('input_1:0') + w_shape = inp_tensor.shape + batches = 32 + inp_data = np.random.rand(batches, w_shape[1], w_shape[2], w_shape[3]) + logits = sim.session.graph.get_tensor_by_name('dense/BiasAdd_quantized:0') + + labels = np.random.randint(10, size=batches) + one_hot_labels = np.eye(10)[labels] + + with sim.session.graph.as_default(): + var_list = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES) + labels_placeholder = tf.compat.v1.placeholder(tf.float32, [None, 10], name='labels') + loss = tf.compat.v1.losses.softmax_cross_entropy(onehot_labels=labels_placeholder, logits=logits) + + update_ops = [] + global_step = tf.compat.v1.train.create_global_step() + initialize_uninitialized_vars(sim.session) + + optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=1e-3) + gradients = optimizer.compute_gradients(loss, var_list) + + grad_updates = optimizer.apply_gradients(gradients, global_step=global_step) + update_ops.append(grad_updates) + update_op = tf.group(*update_ops) + + conv_inp_tensor = conv2d_weight_quant_op.inputs[0] + grads = tf.gradients(loss, [conv_inp_tensor, + conv2d_weight_quant_op.inputs[QuantizeOpIndices.encoding_min], + conv2d_weight_quant_op.inputs[QuantizeOpIndices.encoding_max], + dense_bias_quant_op.inputs[QuantizeOpIndices.encoding_min], + dense_bias_quant_op.inputs[QuantizeOpIndices.encoding_max]]) + _, conv_dqbydmin, conv_dqbydmax, dense_dqbydmin, dense_dqbydmax = grads + conv2d_weight_min_gradient = sim.session.run(conv_dqbydmin, + feed_dict={inp_tensor: inp_data, + labels_placeholder: one_hot_labels}) + conv2d_weight_max_gradient = sim.session.run(conv_dqbydmax, + feed_dict={inp_tensor: inp_data, + labels_placeholder: one_hot_labels}) + dense_bias_min_gradient = sim.session.run(dense_dqbydmin, + feed_dict={inp_tensor: inp_data, + labels_placeholder: one_hot_labels}) + dense_bias_max_gradient = sim.session.run(dense_dqbydmax, + feed_dict={inp_tensor: inp_data, + labels_placeholder: one_hot_labels}) + + assert len(conv2d_weight_min_gradient) == 2 + assert len(conv2d_weight_max_gradient) == 2 + assert len(dense_bias_min_gradient) == 10 + assert len(dense_bias_max_gradient) == 10 + + weights_before_train = sim.session.run(conv2d_weight_quant_op.inputs[0]) + encoding_min_before_train = sim.session.run(conv2d_weight_quant_op.inputs[QuantizeOpIndices.encoding_min]) + encoding_max_before_train = sim.session.run(conv2d_weight_quant_op.inputs[QuantizeOpIndices.encoding_max]) + conv2d_output_encoding_min_before_train = sim.session.run(conv2d_output_quant_op.inputs[ + QuantizeOpIndices.encoding_min]) + conv2d_output_encoding_max_before_train = sim.session.run(conv2d_output_quant_op.inputs[ + QuantizeOpIndices.encoding_max]) + dense_bias_encoding_min_before_train = \ + sim.session.run(dense_bias_quant_op.inputs[QuantizeOpIndices.encoding_min]) + dense_bias_encoding_max_before_train = \ + sim.session.run(dense_bias_quant_op.inputs[QuantizeOpIndices.encoding_max]) + with tf.control_dependencies([update_op]): + train_op = tf.identity(loss, name='train_op') + + for quant_op_name in sim._param_quantizers.keys(): + print(quant_op_name + '_min_before_train = ' + str(sim.session.run( + sim.session.graph.get_operation_by_name(quant_op_name).inputs[QuantizeOpIndices.encoding_min]))) + print(quant_op_name + '_max_before_train = ' + str(sim.session.run( + sim.session.graph.get_operation_by_name(quant_op_name).inputs[QuantizeOpIndices.encoding_max]))) + + # start training + _ = sim.session.run(train_op, feed_dict={inp_tensor: inp_data, labels_placeholder: one_hot_labels}) + + for quant_op_name in sim._param_quantizers.keys(): + print(quant_op_name + '_min = ' + str(sim.session.run(sim.session.graph.get_operation_by_name + (quant_op_name).inputs[ + QuantizeOpIndices.encoding_min]))) + print(quant_op_name + '_max = ' + str(sim.session.run(sim.session.graph.get_operation_by_name + (quant_op_name).inputs[ + QuantizeOpIndices.encoding_max]))) + + weights_after_train = sim.session.run(conv2d_weight_quant_op.inputs[0]) + conv2d_output_encoding_min_after_train = sim.session.run(conv2d_output_quant_op.inputs[ + QuantizeOpIndices.encoding_min]) + conv2d_output_encoding_max_after_train = sim.session.run(conv2d_output_quant_op.inputs[ + QuantizeOpIndices.encoding_max]) + encoding_min_after_train = sim.session.run(conv2d_weight_quant_op.inputs[QuantizeOpIndices.encoding_min]) + encoding_max_after_train = sim.session.run(conv2d_weight_quant_op.inputs[QuantizeOpIndices.encoding_max]) + dense_bias_encoding_min_after_train = \ + sim.session.run(dense_bias_quant_op.inputs[QuantizeOpIndices.encoding_min]) + dense_bias_encoding_max_after_train = \ + sim.session.run(dense_bias_quant_op.inputs[QuantizeOpIndices.encoding_max]) + + assert not np.allclose(weights_before_train, weights_after_train, atol=1e-6) + assert not np.array_equal(encoding_min_before_train, encoding_min_after_train) + assert not np.array_equal(encoding_max_before_train, encoding_max_after_train) + assert not np.array_equal(conv2d_output_encoding_min_before_train, conv2d_output_encoding_min_after_train) + assert not np.array_equal(conv2d_output_encoding_max_before_train, conv2d_output_encoding_max_after_train) + assert not np.array_equal(dense_bias_encoding_min_before_train, dense_bias_encoding_min_after_train) + assert not np.array_equal(dense_bias_encoding_max_before_train, dense_bias_encoding_max_after_train) + + sess.close() + sim.session.close() + def save_config_file_for_per_channel_quantization(): quantsim_config = { @@ -578,6 +726,31 @@ def save_config_file_for_per_channel_quantization(): with open('./quantsim_config.json', 'w') as f: json.dump(quantsim_config, f) + +def save_config_file_bias_quantized_for_per_channel_quantization(): + quantsim_config = { + "defaults": { + "ops": { + "is_output_quantized": "True", + "is_symmetric": "False" + }, + "params": { + "is_quantized": "True", + "is_symmetric": "False" + }, + "per_channel_quantization": "True", + }, + "params": {}, + "op_type": {}, + "supergroups": [], + "model_input": {}, + "model_output": {} + } + + with open('./quantsim_config.json', 'w') as f: + json.dump(quantsim_config, f) + + def compute_tf_encodings(sess, op, axis): data = WeightTensorUtils.get_tensor_as_numpy_data(sess, op) diff --git a/TrainingExtensions/tensorflow/test/python/test_quantsim.py b/TrainingExtensions/tensorflow/test/python/test_quantsim.py index 8c6fa008b4..c1aee1c450 100644 --- a/TrainingExtensions/tensorflow/test/python/test_quantsim.py +++ b/TrainingExtensions/tensorflow/test/python/test_quantsim.py @@ -47,7 +47,7 @@ import libpymo import aimet_tensorflow.utils.quantsim from aimet_tensorflow.quantsim import QuantizationSimModel, check_accumulator_overflow -from aimet_tensorflow.quantsim_straight_through_grad import _get_n_and_p +from aimet_tensorflow.quantsim_straight_through_grad import _get_n_and_p, _compute_dloss_by_dmax from aimet_tensorflow.utils.graph_saver import load_model_from_meta from aimet_tensorflow.common.graph_eval import initialize_uninitialized_vars from aimet_tensorflow.defs import ParameterInfo @@ -1260,6 +1260,56 @@ def test_n_p_computation(self): sess.close() + def test_compute_dloss_by_dmax_shape(self): + """ Test compute_dloss_by_dmax returns tensor with correct shape """ + tf.compat.v1.set_random_seed(0) + + # Per tensor case + tf.compat.v1.reset_default_graph() + graph = tf.Graph() + sess = tf.compat.v1.Session(graph=graph) + with graph.as_default(): + inputs = tf.random.uniform(shape=[3, 3, 4, 2], dtype=tf.float32) + grad = tf.random.uniform(shape=[3, 3, 4, 2], dtype=tf.float32) + scaling = tf.random.uniform(shape=[], dtype=tf.float32) + offset = tf.random.uniform(shape=[], dtype=tf.float32) + bitwidth = tf.constant(8.0, dtype=tf.float32) + is_symmetric = tf.constant(False) + + dloss_by_dmax = _compute_dloss_by_dmax(inputs, grad, scaling, offset, bitwidth, is_symmetric) + assert sess.run(dloss_by_dmax).shape == () + + # Per channel case with weights + tf.compat.v1.reset_default_graph() + graph = tf.Graph() + sess = tf.compat.v1.Session(graph=graph) + with graph.as_default(): + inputs = tf.random.uniform(shape=[3, 3, 4, 2], dtype=tf.float32) + grad = tf.random.uniform(shape=[3, 3, 4, 2], dtype=tf.float32) + scaling = tf.random.uniform(shape=[2,], dtype=tf.float32) + offset = tf.random.uniform(shape=[2,], dtype=tf.float32) + bitwidth = tf.constant(8.0, dtype=tf.float32) + is_symmetric = tf.constant(False) + + dloss_by_dmax = _compute_dloss_by_dmax(inputs, grad, scaling, offset, bitwidth, is_symmetric) + assert sess.run(dloss_by_dmax).shape == (2,) + + # Per channel case with bias + tf.compat.v1.reset_default_graph() + graph = tf.Graph() + sess = tf.compat.v1.Session(graph=graph) + with graph.as_default(): + inputs = tf.random.uniform(shape=[10,], dtype=tf.float32) + grad = tf.random.uniform(shape=[10,], dtype=tf.float32) + scaling = tf.random.uniform(shape=[10,], dtype=tf.float32) + offset = tf.random.uniform(shape=[10,], dtype=tf.float32) + bitwidth = tf.constant(8.0, dtype=tf.float32) + is_symmetric = tf.constant(False) + + dloss_by_dmax = _compute_dloss_by_dmax(inputs, grad, scaling, offset, bitwidth, is_symmetric) + assert sess.run(dloss_by_dmax).shape == (10,) + + def test_qc_custom_gradient_training_loop_range_learning(self, iterations=1): """ test to get average time spent in range learning grad