diff --git a/keras/keras_model.py b/keras/keras_model.py
index 919dc7db48b92a3aba8bc164a43f31c0337fd9b8..5add25e02b950adb31f6ae7e8381f49d354fb39a 100644
--- a/keras/keras_model.py
+++ b/keras/keras_model.py
@@ -55,20 +55,31 @@ def conv_error_3d(filters, e, stride):
 
 @tf.function
 def normalize_weights(filters, out_channels):
+    #print('filters shape', tf.shape(filters))
     norms = tf.norm(tf.reshape(tf.stack(filters), (out_channels, -1)), axis=1)
     norms = tf.broadcast_to(tf.math.maximum(norms, 1e-12*tf.ones_like(norms)), filters[0].shape)
     
     adjusted = [f / norms for f in filters]
     
+    #raise Exception('Beep')
+    
     return adjusted
 
 # @tf.function
 def normalize_weights_3d(filters, out_channels):
-    norms = tf.norm(tf.reshape(filters[0], (out_channels, -1)), axis=1)
+    #for f in filters:
+    #    print('filters 3d shape', f.shape)
+    norms = tf.norm(tf.reshape(tf.transpose(filters[0], perm=[4, 0, 1, 2, 3]), (out_channels, -1)), axis=1)
+    # tf.print("norms", norms.shape, norms)
     norms = tf.broadcast_to(tf.math.maximum(norms, 1e-12*tf.ones_like(norms)), filters[0].shape)
     
     adjusted = [f / norms for f in filters]
+
+    #for i in range(out_channels):
+    #    tf.print("after normalization", tf.norm(adjusted[0][:,:,:,0,i]))
+    #print()
     
+    #raise Exception('Beep')
     return adjusted
 
 class SparseCode(keras.layers.Layer):