diff --git a/keras/keras_model.py b/keras/keras_model.py index 919dc7db48b92a3aba8bc164a43f31c0337fd9b8..5add25e02b950adb31f6ae7e8381f49d354fb39a 100644 --- a/keras/keras_model.py +++ b/keras/keras_model.py @@ -55,20 +55,31 @@ def conv_error_3d(filters, e, stride): @tf.function def normalize_weights(filters, out_channels): + #print('filters shape', tf.shape(filters)) norms = tf.norm(tf.reshape(tf.stack(filters), (out_channels, -1)), axis=1) norms = tf.broadcast_to(tf.math.maximum(norms, 1e-12*tf.ones_like(norms)), filters[0].shape) adjusted = [f / norms for f in filters] + #raise Exception('Beep') + return adjusted # @tf.function def normalize_weights_3d(filters, out_channels): - norms = tf.norm(tf.reshape(filters[0], (out_channels, -1)), axis=1) + #for f in filters: + # print('filters 3d shape', f.shape) + norms = tf.norm(tf.reshape(tf.transpose(filters[0], perm=[4, 0, 1, 2, 3]), (out_channels, -1)), axis=1) + # tf.print("norms", norms.shape, norms) norms = tf.broadcast_to(tf.math.maximum(norms, 1e-12*tf.ones_like(norms)), filters[0].shape) adjusted = [f / norms for f in filters] + + #for i in range(out_channels): + # tf.print("after normalization", tf.norm(adjusted[0][:,:,:,0,i])) + #print() + #raise Exception('Beep') return adjusted class SparseCode(keras.layers.Layer):