From a7fa231921a53f3dff180376041bdc3cdadcde22 Mon Sep 17 00:00:00 2001 From: Chris MacLellan <cm3786@drexel.edu> Date: Fri, 11 Feb 2022 19:54:26 +0000 Subject: [PATCH] updated 3d weight normalization --- keras/keras_model.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/keras/keras_model.py b/keras/keras_model.py index 6327fe9..9d5ddc2 100644 --- a/keras/keras_model.py +++ b/keras/keras_model.py @@ -47,20 +47,31 @@ def conv_error_3d(filters, e, stride): @tf.function def normalize_weights(filters, out_channels): + #print('filters shape', tf.shape(filters)) norms = tf.norm(tf.reshape(tf.stack(filters), (out_channels, -1)), axis=1) norms = tf.broadcast_to(tf.math.maximum(norms, 1e-12*tf.ones_like(norms)), filters[0].shape) adjusted = [f / norms for f in filters] + #raise Exception('Beep') + return adjusted -@tf.function +# @tf.function def normalize_weights_3d(filters, out_channels): - norms = tf.norm(tf.reshape(filters[0], (out_channels, -1)), axis=1) + #for f in filters: + # print('filters 3d shape', f.shape) + norms = tf.norm(tf.reshape(tf.transpose(filters[0], perm=[4, 0, 1, 2, 3]), (out_channels, -1)), axis=1) + # tf.print("norms", norms.shape, norms) norms = tf.broadcast_to(tf.math.maximum(norms, 1e-12*tf.ones_like(norms)), filters[0].shape) adjusted = [f / norms for f in filters] + + #for i in range(out_channels): + # tf.print("after normalization", tf.norm(adjusted[0][:,:,:,0,i])) + #print() + #raise Exception('Beep') return adjusted class SparseCode(keras.layers.Layer): -- GitLab