Skip to content
Snippets Groups Projects
Commit e3e0dafd authored by hannandarryl's avatar hannandarryl
Browse files

Merge remote-tracking branch 'origin/main' into main

parents 6c365931 a7fa2319
No related branches found
No related tags found
No related merge requests found
......@@ -55,20 +55,31 @@ def conv_error_3d(filters, e, stride):
@tf.function
def normalize_weights(filters, out_channels):
#print('filters shape', tf.shape(filters))
norms = tf.norm(tf.reshape(tf.stack(filters), (out_channels, -1)), axis=1)
norms = tf.broadcast_to(tf.math.maximum(norms, 1e-12*tf.ones_like(norms)), filters[0].shape)
adjusted = [f / norms for f in filters]
#raise Exception('Beep')
return adjusted
# @tf.function
def normalize_weights_3d(filters, out_channels):
norms = tf.norm(tf.reshape(filters[0], (out_channels, -1)), axis=1)
#for f in filters:
# print('filters 3d shape', f.shape)
norms = tf.norm(tf.reshape(tf.transpose(filters[0], perm=[4, 0, 1, 2, 3]), (out_channels, -1)), axis=1)
# tf.print("norms", norms.shape, norms)
norms = tf.broadcast_to(tf.math.maximum(norms, 1e-12*tf.ones_like(norms)), filters[0].shape)
adjusted = [f / norms for f in filters]
#for i in range(out_channels):
# tf.print("after normalization", tf.norm(adjusted[0][:,:,:,0,i]))
#print()
#raise Exception('Beep')
return adjusted
class SparseCode(keras.layers.Layer):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment