From 092b5e42615e8d16c36acee63c13ad297ed29d16 Mon Sep 17 00:00:00 2001
From: hannandarryl <hannandarryl@gmail.com>
Date: Thu, 10 Feb 2022 17:45:40 +0000
Subject: [PATCH] Add 2D and 3D modes to the Keras sparse coding model

---
 keras/keras_model.py        | 121 ++++++++++++++++++++++++++----------
 keras/train_sparse_model.py |  71 ++++++++++++---------
 2 files changed, 129 insertions(+), 63 deletions(-)

diff --git a/keras/keras_model.py b/keras/keras_model.py
index 2b11890..6327fe9 100644
--- a/keras/keras_model.py
+++ b/keras/keras_model.py
@@ -23,7 +23,7 @@ def do_recon(filters_1, filters_2, filters_3, filters_4, filters_5, activations,
 
 @tf.function
 def do_recon_3d(filters, activations, batch_size, stride):
-    recon = tf.nn.conv3d_transpose(activations, filters, output_shape=(batch_size, 5, 100, 200, 1), strides=stride)
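+    # A 3-element stride list maps to (depth, height, width); depth stride 1 keeps all 5 frames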
+    recon = tf.nn.conv3d_transpose(activations, filters, output_shape=(batch_size, 5, 100, 200, 1), strides=[1, stride, stride])
 
     return recon
 
@@ -41,7 +41,7 @@ def conv_error(filters_1, filters_2, filters_3, filters_4, filters_5, e, stride)
 
 @tf.function
 def conv_error_3d(filters, e, stride):
-    g = tf.nn.conv3d(e, filters, strides=[stride, stride, stride, stride, stride], padding='SAME')
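+    # conv3d strides follow NDHWC: [batch, depth, height, width, channels]; only H and W are strided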
+    g = tf.nn.conv3d(e, filters, strides=[1, 1, stride, stride, 1], padding='SAME')
 
     return g
 
@@ -54,8 +54,17 @@ def normalize_weights(filters, out_channels):
     
     return adjusted
 
+@tf.function
+def normalize_weights_3d(filters, out_channels):
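+    # Per-filter L2 norm over every axis of the (5, k, k, in, out) tensor except out_channels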
+    norms = tf.norm(tf.reshape(filters[0], (-1, out_channels)), axis=0)
+    norms = tf.broadcast_to(tf.math.maximum(norms, 1e-12), filters[0].shape)
+    
+    adjusted = [f / norms for f in filters]
+    
+    return adjusted
+
 class SparseCode(keras.layers.Layer):
-    def __init__(self, pytorch_checkpoint, batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter):
+    def __init__(self, batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d):
         super(SparseCode, self).__init__()
 
         self.out_channels = out_channels
@@ -65,32 +74,25 @@ class SparseCode(keras.layers.Layer):
         self.activation_lr = activation_lr
         self.max_activation_iter = max_activation_iter
         self.batch_size = batch_size
-
-        # pytorch_checkpoint = torch.load(pytorch_checkpoint, map_location='cpu')
-        # weight_tensor = pytorch_checkpoint['model_state_dict']['filters'].swapaxes(1,3).swapaxes(2,4).swapaxes(0,4)
-        # self.filter = tf.Variable(initial_value=weight_tensor.numpy(), dtype='float32')
-        # weight_list = torch.chunk(weight_tensor, 5, dim=2)
-        #
-        initializer = tf.keras.initializers.HeNormal()
-        self.filters_1 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
-        self.filters_2 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
-        self.filters_3 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
-        self.filters_4 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
-        self.filters_5 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
-        
-        weights = normalize_weights(self.get_weights(), out_channels)
-        self.set_weights(weights)
+        self.run_2d = run_2d
 
     @tf.function
-    def do_update(self, images, u, m, v, b1, b2, eps, i):
+    def do_update(self, images, filters, u, m, v, b1, b2, eps, i):
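+        # One inference step: soft-threshold u into activations, reconstruct, then move u
+        # along the reconstruction error with an Adam-style update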
         activations = tf.nn.relu(u - self.lam)
 
-        recon = do_recon(self.filters_1, self.filters_2, self.filters_3, self.filters_4, self.filters_5, activations, self.batch_size, self.stride)
+        if self.run_2d:
+            recon = do_recon(filters[0], filters[1], filters[2], filters[3], filters[4], activations, self.batch_size, self.stride)
+        else:
+            recon = do_recon_3d(filters, activations, self.batch_size, self.stride)
+
         e = images - recon
         g = -1 * u
-                
-        convd_error = conv_error(self.filters_1, self.filters_2, self.filters_3, self.filters_4, self.filters_5, e, self.stride)
         
+        if self.run_2d:
+            convd_error = conv_error(filters[0], filters[1], filters[2], filters[3], filters[4], e, self.stride)
+        else:
+            convd_error = conv_error_3d(filters, e, self.stride)
+            
         g = g + convd_error
 
         g = g + activations
@@ -102,29 +104,82 @@ class SparseCode(keras.layers.Layer):
         
         du = self.activation_lr * mh / (tf.math.sqrt(vh) + eps)
         u += du
 
         return u, m, v
 
     @tf.function
-    def call(self, images):
-        u = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
-        m = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
-        v = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
+    def call(self, images, filters):
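+        # u: membrane potentials; m, v: first and second Adam moment estimates for the activation maps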
+        if self.run_2d:
+            u = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
+            m = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
+            v = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
+        else:
+            u = tf.zeros(shape=(self.batch_size, 5, 100 // self.stride, 200 // self.stride, self.out_channels))
+            m = tf.zeros(shape=(self.batch_size, 5, 100 // self.stride, 200 // self.stride, self.out_channels))
+            v = tf.zeros(shape=(self.batch_size, 5, 100 // self.stride, 200 // self.stride, self.out_channels))
         
-        tf.print('activations before:', tf.reduce_sum(u))
-
-        b1 = 0.9
-        b2 = 0.999
-        eps = 1e-8
 
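+        # Adam hyperparameters, held as float32 tensors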
+        b1 = tf.constant(0.9, dtype='float32')
+        b2 = tf.constant(0.999, dtype='float32')
+        eps = tf.constant(1e-8, dtype='float32')
+
+        # NOTE: a tf.while_loop version of this update loop was prototyped; the unrolled Python loop is kept for now
         for i in range(self.max_activation_iter):
-            u, m, v = self.do_update(images, u, m, v, b1, b2, eps, i)
+            u, m, v = self.do_update(images, filters, u, m, v, b1, b2, eps, i)
 
         u = tf.nn.relu(u - self.lam)
         
-        tf.print('activations after:', tf.reduce_sum(u))
 
         return u
+    
+class SparseCodeConv(keras.Model):
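+    """Sparse coding model: runs SparseCode inference, then reconstructs the input with trainable dictionary filters."""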
+    def __init__(self, batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d):
+        super().__init__()
+        self.sparse_code = SparseCode(batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d)
+        
+        self.out_channels = out_channels
+        self.in_channels = in_channels
+        self.stride = stride
+        self.lam = lam
+        self.activation_lr = activation_lr
+        self.max_activation_iter = max_activation_iter
+        self.batch_size = batch_size
+        self.run_2d = run_2d
+        
+        initializer = tf.keras.initializers.HeNormal()
+        if run_2d:
+            self.filters_1 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
+            self.filters_2 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
+            self.filters_3 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
+            self.filters_4 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
+            self.filters_5 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
+        else:
+            self.filters = tf.Variable(initial_value=initializer(shape=(5, kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
+        
+        if run_2d:
+            weights = normalize_weights(self.get_weights(), out_channels)
+        else:
+            weights = normalize_weights_3d(self.get_weights(), out_channels)
+        self.set_weights(weights)
+
+    @tf.function
+    def call(self, images):
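+        # Filters are passed through tf.stop_gradient for inference, so filter gradients flow
+        # only through the reconstruction below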
+        if self.run_2d:
+            activations = self.sparse_code(images, [tf.stop_gradient(self.filters_1), tf.stop_gradient(self.filters_2), tf.stop_gradient(self.filters_3), tf.stop_gradient(self.filters_4), tf.stop_gradient(self.filters_5)])
+            recon = do_recon(self.filters_1, self.filters_2, self.filters_3, self.filters_4, self.filters_5, activations, self.batch_size, self.stride)
+        else:
+            activations = self.sparse_code(images, tf.stop_gradient(self.filters))
+            recon = do_recon_3d(self.filters, activations, self.batch_size, self.stride)
+            
+        return recon, activations
 
 class Classifier(keras.layers.Layer):
     def __init__(self):
diff --git a/keras/train_sparse_model.py b/keras/train_sparse_model.py
index 54e0209..f54caed 100644
--- a/keras/train_sparse_model.py
+++ b/keras/train_sparse_model.py
@@ -10,7 +10,7 @@ import os
 from sparse_coding_torch.load_data import load_yolo_clips
 import tensorflow.keras as keras
 import tensorflow as tf
-from keras_model import SparseCode, do_recon, normalize_weights
+from keras_model import normalize_weights_3d, normalize_weights, SparseCodeConv
 
 def plot_video(video):
 
@@ -55,11 +55,11 @@ def plot_original_vs_recon(original, reconstruction, idx=0):
 
 
 def plot_filters(filters):
-    num_filters = filters.shape[0]
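+    # filters now arrive as (T, H, W, in_channels, num_filters)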
+    num_filters = filters.shape[4]
     ncol = 3
     # ncol = int(np.sqrt(num_filters))
     # nrow = int(np.sqrt(num_filters))
-    T = filters.shape[2]
+    T = filters.shape[0]
 
     if num_filters // ncol == num_filters / ncol:
         nrow = num_filters // ncol
@@ -74,7 +74,7 @@ def plot_filters(filters):
     for i in range(num_filters):
         r = i // ncol
         c = i % ncol
-        ims[(r, c)] = axes[r, c].imshow(filters[i, 0, 0, :, :],
+        ims[(r, c)] = axes[r, c].imshow(filters[0, :, :, 0, i],
                                         cmap=cm.Greys_r)
 
     def update(i):
@@ -82,29 +82,30 @@ def plot_filters(filters):
         for i in range(num_filters):
             r = i // ncol
             c = i % ncol
-            ims[(r, c)].set_data(filters[i, 0, t, :, :])
+            ims[(r, c)].set_data(filters[t, :, :, 0, i])
 
     return FuncAnimation(plt.gcf(), update, interval=1000/20)
 
-def sparse_loss(filters_1, filters_2, filters_3, filters_4, filters_5, images, activations, batch_size, lam, stride):
-    recon = do_recon(filters_1, filters_2, filters_3, filters_4, filters_5, activations, batch_size, stride)
+def sparse_loss(images, recon, activations, batch_size, lam):
     loss = 0.5 * (1/batch_size) * tf.math.reduce_sum(tf.math.pow(images - recon, 2))
     loss += lam * tf.reduce_mean(tf.math.reduce_sum(tf.math.abs(tf.reshape(activations, (batch_size, -1))), axis=1))
     return loss
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--batch_size', default=8, type=int)
+    parser.add_argument('--batch_size', default=6, type=int)
     parser.add_argument('--kernel_size', default=15, type=int)
     parser.add_argument('--num_kernels', default=64, type=int)
     parser.add_argument('--stride', default=2, type=int)
-    parser.add_argument('--max_activation_iter', default=150, type=int)
+    parser.add_argument('--max_activation_iter', default=50, type=int)
     parser.add_argument('--activation_lr', default=1e-2, type=float)
-    parser.add_argument('--lr', default=1e-5, type=float)
+    parser.add_argument('--lr', default=1e-2, type=float)
     parser.add_argument('--epochs', default=100, type=int)
     parser.add_argument('--lam', default=0.05, type=float)
     parser.add_argument('--output_dir', default='./output', type=str)
     parser.add_argument('--seed', default=42, type=int)
+    parser.add_argument('--run_2d', action='store_true')
+    parser.add_argument('--save_filters', action='store_true')
 
     args = parser.parse_args()
 
@@ -122,14 +123,24 @@ if __name__ == "__main__":
 
     example_data = next(iter(train_loader))
 
-    inputs = keras.Input(shape=(100, 200, 5))
+    if args.run_2d:
+        inputs = keras.Input(shape=(100, 200, 5))
+    else:
+        inputs = keras.Input(shape=(5, 100, 200, 1))
 
-    output = SparseCode('../sparse.pt', batch_size=args.batch_size, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter)(inputs)
+    output = SparseCodeConv(batch_size=args.batch_size, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(inputs)
 
     model = keras.Model(inputs=inputs, outputs=output)
+    
+    if args.save_filters:
+        if args.run_2d:
+            filters = plot_filters(tf.stack(model.get_weights(), axis=0))
+        else:
+            filters = plot_filters(model.get_weights()[0])
+        filters.save(os.path.join(args.output_dir, 'filters_start.mp4'))
 
     learning_rate = args.lr
-    filter_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
+    filter_optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
 
     loss_log = []
     best_so_far = float('inf')
@@ -139,38 +150,38 @@ if __name__ == "__main__":
         epoch_start = time.perf_counter()
 
         for labels, local_batch, vid_f in tqdm(train_loader):
-            tf.print(vid_f)
             if local_batch.size(0) != args.batch_size:
                 continue
-            images = local_batch.squeeze(1).permute(0, 2, 3, 1).numpy()
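+            # torch batch is (N, 1, 5, H, W); convert to channels-last for TensorFlow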
+            if args.run_2d:
+                images = local_batch.squeeze(1).permute(0, 2, 3, 1).numpy()
+            else:
+                images = local_batch.permute(0, 2, 3, 4, 1).numpy()
             
             with tf.GradientTape() as tape:
-                activations = model(images)
-                
-                filters = model.get_weights()
-
-                loss = sparse_loss(filters[0], filters[1], filters[2], filters[3], filters[4], images, activations, args.batch_size, args.lam, args.stride)
-#             recon = model.do_recon(activations)
-#             loss = 0.5 * (1/batch_size) * tf.math.reduce_sum(tf.math.pow(images - recon, 2))
-#             loss += lam * tf.reduce_mean(tf.math.reduce_sum(tf.math.abs(tf.reshape(activations, (batch_size, -1))), axis=1))
+                recon, activations = model(images)
+                loss = sparse_loss(images, recon, activations, args.batch_size, args.lam)
 
             epoch_loss += loss * local_batch.size(0)
-            tf.print('loss:', loss)
 
             gradients = tape.gradient(loss, model.trainable_weights)
-            tf.print('gradients:', tf.reduce_sum(gradients))
 
             filter_optimizer.apply_gradients(zip(gradients, model.trainable_weights))
             
-            tf.print('weights:', tf.reduce_sum(model.trainable_weights))
-            
-            weights = normalize_weights(model.get_weights(), args.num_kernels)
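+            # re-project dictionary filters to unit norm after the gradient step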
+            if args.run_2d:
+                weights = normalize_weights(model.get_weights(), args.num_kernels)
+            else:
+                weights = normalize_weights_3d(model.get_weights(), args.num_kernels)
             model.set_weights(weights)
-            
-            tf.print('normalized weights:', tf.reduce_sum(model.trainable_weights))
 
         epoch_end = time.perf_counter()
         epoch_loss /= len(train_loader.sampler)
+        
+        if args.save_filters and epoch % 5 == 0:
+            if args.run_2d:
+                filters = plot_filters(tf.stack(model.get_weights(), axis=0))
+            else:
+                filters = plot_filters(model.get_weights()[0])
+            filters.save(os.path.join(args.output_dir, 'filters_' + str(epoch) + '.mp4'))
 
         if epoch_loss < best_so_far:
             print("found better model")
-- 
GitLab