Hi,
We were using the TF visualization to analyze the weight ranges when I noticed that the tool was reporting 171 channels in a layer that is supposed to have only 32 channels.
I then checked the following file: `TrainingExtensions/tensorflow/src/python/aimet_tensorflow/plotting_utils.py`. The way the weight tensor is reshaped there does not reflect the output channels. The docstring says "where each column is an output channel", but that is not the case in practice:
```python
def get_weights(conv_module, sess):
    """
    Returns the weights of a conv_module in a 2d matrix, where each column is an output channel.
    :param sess: tf.compat.v1.Session
    :param conv_module: convNd module
    :return: 2d numpy array
    """
    numpy_weight = WeightTensorUtils.get_tensor_as_numpy_data(sess, conv_module)
    numpy_weight = np.reshape(numpy_weight, (numpy_weight.shape[3], numpy_weight.shape[2], numpy_weight.shape[0],
                                             numpy_weight.shape[1]))
    axis_0_length = numpy_weight.shape[0]
    axis_1_length = np.prod(numpy_weight.shape[1:])
    reshaped_weights = numpy_weight.reshape(int(axis_0_length), int(axis_1_length))
    return reshaped_weights
```
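The column count can be reproduced without a TF session. Below is a minimal NumPy sketch, assuming a hypothetical 3x3 kernel with 19 input and 32 output channels in TF's `[height, width, in_channels, out_channels]` layout (19 is chosen only because 3 * 3 * 19 = 171). `get_weights_original` is a hypothetical helper that mirrors the function above, with `WeightTensorUtils` replaced by a plain array. It also shows that `np.reshape` reinterprets the flat memory order rather than permuting axes, so the rows are not clean per-channel slices either; moving axes would need `np.transpose`.

```python
import numpy as np

# Hypothetical HWIO kernel: 3x3 spatial, 19 input channels, 32 output channels.
# np.arange gives every weight a unique, deterministic value.
kernel = np.arange(3 * 3 * 19 * 32, dtype=np.float32).reshape(3, 3, 19, 32)

def get_weights_original(numpy_weight):
    """Hypothetical mirror of the original reshape logic, without the TF session."""
    numpy_weight = np.reshape(numpy_weight, (numpy_weight.shape[3], numpy_weight.shape[2],
                                             numpy_weight.shape[0], numpy_weight.shape[1]))
    axis_0_length = numpy_weight.shape[0]
    axis_1_length = np.prod(numpy_weight.shape[1:])
    return numpy_weight.reshape(int(axis_0_length), int(axis_1_length))

original = get_weights_original(kernel)
print(original.shape)  # (32, 171): the 171 columns are what get plotted as "channels"

# A true HWIO -> OIHW axis permutation differs from the plain reshape above.
permuted = np.transpose(kernel, (3, 2, 0, 1))
print(np.array_equal(original.reshape(permuted.shape), permuted))  # False
```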
In the original function, `axis_0_length` ends up holding the output channels, because `numpy_weight` is reshaped to put the channels in the first dimension. However, what is plotted as output channels is `axis_1_length`. TF stores conv kernels channels-last, as `[kernel_height, kernel_width, in_channels, out_channels]`, unlike PyTorch's `[out_channels, in_channels, kernel_height, kernel_width]`. So I modified the file as follows to plot the channels properly:
```python
def get_weights(conv_module, sess):
    """
    Returns the weights of a conv_module in a 2d matrix, where each column is an output channel.
    :param sess: tf.compat.v1.Session
    :param conv_module: convNd module
    :return: 2d numpy array
    """
    numpy_weight = WeightTensorUtils.get_tensor_as_numpy_data(sess, conv_module)
    axis_0_length = np.prod(numpy_weight.shape[:3])
    axis_1_length = numpy_weight.shape[3]
    reshaped_weights = numpy_weight.reshape(int(axis_0_length), int(axis_1_length))
    return reshaped_weights
```
With this change, `axis_1_length` has the correct number of output channels (32). Could someone from the AIMET team check this and confirm the observation?
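A matching NumPy sketch for the modified function, again assuming a hypothetical 3x3 kernel with 19 input and 32 output channels in HWIO layout. `get_weights_fixed` is a hypothetical name for the modified `get_weights` with the session call removed. It checks that there are 32 columns and that column `j` holds exactly the weights of output channel `j`.

```python
import numpy as np

# Hypothetical HWIO kernel with unique values so column contents can be checked exactly.
kernel = np.arange(3 * 3 * 19 * 32, dtype=np.float32).reshape(3, 3, 19, 32)

def get_weights_fixed(numpy_weight):
    """Hypothetical mirror of the modified reshape logic, without the TF session."""
    axis_0_length = np.prod(numpy_weight.shape[:3])  # kernel_h * kernel_w * in_channels
    axis_1_length = numpy_weight.shape[3]            # out_channels
    return numpy_weight.reshape(int(axis_0_length), int(axis_1_length))

fixed = get_weights_fixed(kernel)
print(fixed.shape)  # (171, 32)

# Every column equals the flattened weights of the corresponding output channel.
print(all(np.array_equal(fixed[:, j], kernel[..., j].ravel()) for j in range(32)))  # True
```

No transpose is needed here because the output-channel axis is already the last one, so a C-order reshape keeps it as the column index.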
Thanks!