From a7fa231921a53f3dff180376041bdc3cdadcde22 Mon Sep 17 00:00:00 2001
From: Chris MacLellan <cm3786@drexel.edu>
Date: Fri, 11 Feb 2022 19:54:26 +0000
Subject: [PATCH] updated 3d weight normalization

---
 keras/keras_model.py | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/keras/keras_model.py b/keras/keras_model.py
index 6327fe9..9d5ddc2 100644
--- a/keras/keras_model.py
+++ b/keras/keras_model.py
@@ -47,20 +47,31 @@ def conv_error_3d(filters, e, stride):
 
 @tf.function
 def normalize_weights(filters, out_channels):
+    #print('filters shape', tf.shape(filters))
     norms = tf.norm(tf.reshape(tf.stack(filters), (out_channels, -1)), axis=1)
     norms = tf.broadcast_to(tf.math.maximum(norms, 1e-12*tf.ones_like(norms)), filters[0].shape)
     
     adjusted = [f / norms for f in filters]
     
+    #raise Exception('Beep')
+    
     return adjusted
 
-@tf.function
+# @tf.function
 def normalize_weights_3d(filters, out_channels):
-    norms = tf.norm(tf.reshape(filters[0], (out_channels, -1)), axis=1)
+    #for f in filters:
+    #    print('filters 3d shape', f.shape)
+    norms = tf.norm(tf.reshape(tf.transpose(filters[0], perm=[4, 0, 1, 2, 3]), (out_channels, -1)), axis=1)
+    # tf.print("norms", norms.shape, norms)
     norms = tf.broadcast_to(tf.math.maximum(norms, 1e-12*tf.ones_like(norms)), filters[0].shape)
     
     adjusted = [f / norms for f in filters]
+
+    #for i in range(out_channels):
+    #    tf.print("after normalization", tf.norm(adjusted[0][:,:,:,0,i]))
+    #print()
     
+    #raise Exception('Beep')
     return adjusted
 
 class SparseCode(keras.layers.Layer):
-- 
GitLab