[TF Visualization] Incorrect reshaping of weights in range analysis of TF models

Hi,

While trying to use the TF visualization to analyze weight ranges, I noticed that the tool reported 171 channels for a layer that is supposed to have only 32 channels.

I then checked the following file: TrainingExtensions/tensorflow/src/python/aimet_tensorflow/plotting_utils.py. The way the weights are reshaped does not reflect the output channels. The docstring says "where each column is an output channel"; however, that is not the case in practice.

```python
def get_weights(conv_module, sess):
    """
    Returns the weights of a conv_module in a 2d matrix, where each column is an output channel.

    :param sess: tf.compat.v1.Session
    :param conv_module: convNd module
    :return: 2d numpy array
    """
    numpy_weight = WeightTensorUtils.get_tensor_as_numpy_data(sess, conv_module)
    numpy_weight = np.reshape(numpy_weight, (numpy_weight.shape[3], numpy_weight.shape[2], numpy_weight.shape[0],
                                             numpy_weight.shape[1]))
    axis_0_length = numpy_weight.shape[0]
    axis_1_length = np.prod(numpy_weight.shape[1:])
    reshaped_weights = numpy_weight.reshape(int(axis_0_length), int(axis_1_length))
    return reshaped_weights
```
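The reshape in the function above only changes the shape of the array; it keeps the flat element order, so it does not move the output-channel axis. A small NumPy-only sketch makes this visible. It assumes a synthetic weight tensor in TF layout [kh, kw, Nic, Noc] (a 3x3 kernel with 4 input and 32 output channels, values chosen arbitrarily) instead of a real TF session, and compares `np.reshape` with `np.transpose`:

```python
import numpy as np

# Synthetic TF-style conv weight: [kh, kw, Nic, Noc] (shapes chosen for illustration).
kh, kw, n_in, n_out = 3, 3, 4, 32
w = np.arange(kh * kw * n_in * n_out, dtype=np.float32).reshape(kh, kw, n_in, n_out)

# What get_weights() does: reshape to (Noc, Nic, kh, kw) without moving any data.
reshaped = np.reshape(w, (n_out, n_in, kh, kw)).reshape(n_out, -1)

# Permuting the axes instead actually puts each output channel in its own row.
transposed = np.transpose(w, (3, 2, 0, 1)).reshape(n_out, -1)

# All weights that belong to output channel 0, compared irrespective of order.
channel_0 = np.sort(w[..., 0].ravel())

print(np.array_equal(np.sort(reshaped[0]), channel_0))    # False: row 0 mixes channels
print(np.array_equal(np.sort(transposed[0]), channel_0))  # True
```

Both results have 32 rows, so the shape alone does not reveal the problem; only the contents of each row do.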

Here axis_0_length ends up as the number of output channels, because numpy_weight is reshaped to put the output channels in the first dimension. However, axis_1_length is what gets plotted as output channels. TF uses a channels-last layout for conv weights ([kh, kw, Nic, Noc]), unlike PyTorch. I therefore modified the file as follows so that the channels are plotted correctly:

```python
def get_weights(conv_module, sess):
    """
    Returns the weights of a conv_module in a 2d matrix, where each column is an output channel.

    :param sess: tf.compat.v1.Session
    :param conv_module: convNd module
    :return: 2d numpy array
    """
    numpy_weight = WeightTensorUtils.get_tensor_as_numpy_data(sess, conv_module)

    axis_0_length = np.prod(numpy_weight.shape[:3])
    axis_1_length = numpy_weight.shape[3]
    reshaped_weights = numpy_weight.reshape(int(axis_0_length), int(axis_1_length))
    return reshaped_weights
```
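The reshape in the modified function can be checked without a TF session. In the sketch below, `weights_as_columns` is a hypothetical stand-in that takes the NumPy array directly (instead of calling `WeightTensorUtils.get_tensor_as_numpy_data`), and the weights are random synthetic values. It then computes per-output-channel min/max, the kind of range the visualization is meant to show. Because Noc is the last (fastest-varying) axis in C order, a plain reshape to (kh*kw*Nic, Noc) keeps every weight of an output channel in the same column, so no transpose is needed.

```python
import numpy as np


def weights_as_columns(numpy_weight: np.ndarray) -> np.ndarray:
    """Hypothetical helper: flatten a TF weight [kh, kw, Nic, Noc] to (kh*kw*Nic, Noc).

    Column j holds all weights of output channel j.
    """
    axis_0_length = int(np.prod(numpy_weight.shape[:3]))
    axis_1_length = int(numpy_weight.shape[3])
    return numpy_weight.reshape(axis_0_length, axis_1_length)


rng = np.random.default_rng(0)
w = rng.standard_normal((3, 3, 16, 32)).astype(np.float32)  # synthetic weights

cols = weights_as_columns(w)
print(cols.shape)  # (144, 32)

# Per-output-channel ranges: reduce over rows, one value per column.
ch_min = cols.min(axis=0)
ch_max = cols.max(axis=0)
print(ch_min.shape, ch_max.shape)  # (32,) (32,)
```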

With this change, axis_1_length equals the correct number of output channels, 32. Could someone from the AIMET team check this and confirm the observation?

Thanks!

Hi Miguel,

Thank you for reporting this.
Yes, the TF format is [kh, kw, Nic, Noc].
A correction to the reshape is probably needed here, as below (weight represented in PyTorch format, [Nic, Noc, kh, kw]):

```python
numpy_weight = np.reshape(numpy_weight, (numpy_weight.shape[2], numpy_weight.shape[3], numpy_weight.shape[0], numpy_weight.shape[1]))
```
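When converting to a PyTorch-style layout, the axes need to be permuted with `np.transpose`; `np.reshape` with the same target shape keeps the flat element order and would still mix channels. A minimal sketch, assuming the target is the PyTorch `Conv2d` weight layout, which is [Noc, Nic, kh, kw] (out_channels first); the helper name `hwio_to_oihw` is hypothetical and the weights are synthetic:

```python
import numpy as np


def hwio_to_oihw(numpy_weight: np.ndarray) -> np.ndarray:
    """Hypothetical helper: permute TF [kh, kw, Nic, Noc] to PyTorch Conv2d [Noc, Nic, kh, kw]."""
    # np.transpose reorders the data along the axes; np.reshape alone would not.
    return np.transpose(numpy_weight, (3, 2, 0, 1))


w = np.random.default_rng(1).standard_normal((3, 3, 16, 32)).astype(np.float32)
oihw = hwio_to_oihw(w)

# One row per output channel, as in a PyTorch-style plotting path.
rows = oihw.reshape(oihw.shape[0], -1)
print(oihw.shape, rows.shape)  # (32, 16, 3, 3) (32, 144)
```

This gives one row per output channel, whereas the modified `get_weights` above gives one column per output channel; either works as long as the plotting code reads the matching axis.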

Hi Miguel,

I have also created an issue to fix this:
https://github.com/quic/aimet/issues/386
Please let us know if you have further questions.

Thanks,
Sangeetha