From 092b5e42615e8d16c36acee63c13ad297ed29d16 Mon Sep 17 00:00:00 2001 From: hannandarryl <hannandarryl@gmail.com> Date: Thu, 10 Feb 2022 17:45:40 +0000 Subject: [PATCH] added updates to keras model to run in 2d and 3d --- keras/keras_model.py | 121 ++++++++++++++++++++++++++---------- keras/train_sparse_model.py | 71 ++++++++++++--------- 2 files changed, 129 insertions(+), 63 deletions(-) diff --git a/keras/keras_model.py b/keras/keras_model.py index 2b11890..6327fe9 100644 --- a/keras/keras_model.py +++ b/keras/keras_model.py @@ -23,7 +23,7 @@ def do_recon(filters_1, filters_2, filters_3, filters_4, filters_5, activations, @tf.function def do_recon_3d(filters, activations, batch_size, stride): - recon = tf.nn.conv3d_transpose(activations, filters, output_shape=(batch_size, 5, 100, 200, 1), strides=stride) + recon = tf.nn.conv3d_transpose(activations, filters, output_shape=(batch_size, 5, 100, 200, 1), strides=[1, stride, stride]) return recon @@ -41,7 +41,7 @@ def conv_error(filters_1, filters_2, filters_3, filters_4, filters_5, e, stride) @tf.function def conv_error_3d(filters, e, stride): - g = tf.nn.conv3d(e, filters, strides=[stride, stride, stride, stride, stride], padding='SAME') + g = tf.nn.conv3d(e, filters, strides=[1, 1, stride, stride, 1], padding='SAME') return g @@ -54,8 +54,17 @@ def normalize_weights(filters, out_channels): return adjusted +@tf.function +def normalize_weights_3d(filters, out_channels): + norms = tf.norm(tf.reshape(filters[0], (out_channels, -1)), axis=1) + norms = tf.broadcast_to(tf.math.maximum(norms, 1e-12*tf.ones_like(norms)), filters[0].shape) + + adjusted = [f / norms for f in filters] + + return adjusted + class SparseCode(keras.layers.Layer): - def __init__(self, pytorch_checkpoint, batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter): + def __init__(self, batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d): super(SparseCode, self).__init__() self.out_channels = out_channels @@ -65,32 +74,25 @@ class SparseCode(keras.layers.Layer): self.activation_lr = activation_lr self.max_activation_iter = max_activation_iter self.batch_size = batch_size - - # pytorch_checkpoint = torch.load(pytorch_checkpoint, map_location='cpu') - # weight_tensor = pytorch_checkpoint['model_state_dict']['filters'].swapaxes(1,3).swapaxes(2,4).swapaxes(0,4) - # self.filter = tf.Variable(initial_value=weight_tensor.numpy(), dtype='float32') - # weight_list = torch.chunk(weight_tensor, 5, dim=2) - # - initializer = tf.keras.initializers.HeNormal() - self.filters_1 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True) - self.filters_2 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True) - self.filters_3 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True) - self.filters_4 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True) - self.filters_5 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True) - - weights = normalize_weights(self.get_weights(), out_channels) - self.set_weights(weights) + self.run_2d = run_2d @tf.function - def do_update(self, images, u, m, v, b1, b2, eps, i): + def do_update(self, images, filters, u, m, v, b1, b2, eps, i): activations = tf.nn.relu(u - self.lam) - recon = do_recon(self.filters_1, self.filters_2, self.filters_3, self.filters_4, self.filters_5, activations, self.batch_size, self.stride) + if self.run_2d: + recon = do_recon(filters[0], filters[1], filters[2], filters[3], filters[4], activations, self.batch_size, self.stride) + else: + recon = do_recon_3d(filters, activations, self.batch_size, self.stride) + e = images - recon g = -1 * u - - convd_error = conv_error(self.filters_1, self.filters_2, self.filters_3, self.filters_4, self.filters_5, e, self.stride) + if self.run_2d: + convd_error = conv_error(filters[0], filters[1], filters[2], filters[3], filters[4], e, self.stride) + else: + convd_error = conv_error_3d(filters, e, self.stride) + g = g + convd_error g = g + activations @@ -102,29 +104,82 @@ class SparseCode(keras.layers.Layer): du = self.activation_lr * mh / (tf.math.sqrt(vh) + eps) u += du + +# i += 1 +# return images, u, m, v, b1, b2, eps, i return u, m, v @tf.function - def call(self, images): - u = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels)) - m = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels)) - v = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels)) + def call(self, images, filters): + if self.run_2d: + u = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels)) + m = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels)) + v = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels)) + else: + u = tf.zeros(shape=(self.batch_size, 5, 100 // self.stride, 200 // self.stride, self.out_channels)) + m = tf.zeros(shape=(self.batch_size, 5, 100 // self.stride, 200 // self.stride, self.out_channels)) + v = tf.zeros(shape=(self.batch_size, 5, 100 // self.stride, 200 // self.stride, self.out_channels)) - tf.print('activations before:', tf.reduce_sum(u)) - - b1 = 0.9 - b2 = 0.999 - eps = 1e-8 +# tf.print('activations before:', tf.reduce_sum(u)) + b1 = tf.constant(0.9, dtype='float32') + b2 = tf.constant(0.999, dtype='float32') + eps = tf.constant(1e-8, dtype='float32') + + +# i = tf.constant(0, dtype='float32') +# c = lambda images, u, m, v, b1, b2, eps, i: tf.less(i, self.max_activation_iter) +# images, u, m, v, b1, b2, eps, i = tf.while_loop(c, self.do_update, [images, u, m, v, b1, b2, eps, i]) for i in range(self.max_activation_iter): - u, m, v = self.do_update(images, u, m, v, b1, b2, eps, i) + u, m, v = self.do_update(images, filters, u, m, v, b1, b2, eps, i) u = tf.nn.relu(u - self.lam) - tf.print('activations after:', tf.reduce_sum(u)) +# tf.print('activations after:', tf.reduce_sum(u)) return u + +class SparseCodeConv(keras.Model): + def __init__(self, batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d): + super().__init__() + self.sparse_code = SparseCode(batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d) + + self.out_channels = out_channels + self.in_channels = in_channels + self.stride = stride + self.lam = lam + self.activation_lr = activation_lr + self.max_activation_iter = max_activation_iter + self.batch_size = batch_size + self.run_2d = run_2d + + initializer = tf.keras.initializers.HeNormal() + if run_2d: + self.filters_1 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True) + self.filters_2 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True) + self.filters_3 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True) + self.filters_4 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True) + self.filters_5 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True) + else: + self.filters = tf.Variable(initial_value=initializer(shape=(5, kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True) + + if run_2d: + weights = normalize_weights(self.get_weights(), out_channels) + else: + weights = normalize_weights_3d(self.get_weights(), out_channels) + self.set_weights(weights) + + @tf.function + def call(self, images): + if self.run_2d: + activations = self.sparse_code(images, [tf.stop_gradient(self.filters_1), tf.stop_gradient(self.filters_2), tf.stop_gradient(self.filters_3), tf.stop_gradient(self.filters_4), tf.stop_gradient(self.filters_5)]) + recon = do_recon(self.filters_1, self.filters_2, self.filters_3, self.filters_4, self.filters_5, activations, self.batch_size, self.stride) + else: + activations = self.sparse_code(images, tf.stop_gradient(self.filters)) + recon = do_recon_3d(self.filters, activations, self.batch_size, self.stride) + + return recon, activations class Classifier(keras.layers.Layer): def __init__(self): diff --git a/keras/train_sparse_model.py b/keras/train_sparse_model.py index 54e0209..f54caed 100644 --- a/keras/train_sparse_model.py +++ b/keras/train_sparse_model.py @@ -10,7 +10,7 @@ import os from sparse_coding_torch.load_data import load_yolo_clips import tensorflow.keras as keras import tensorflow as tf -from keras_model import SparseCode, do_recon, normalize_weights +from keras_model import normalize_weights_3d, normalize_weights, SparseCodeConv def plot_video(video): @@ -55,11 +55,11 @@ def plot_original_vs_recon(original, reconstruction, idx=0): def plot_filters(filters): - num_filters = filters.shape[0] + num_filters = filters.shape[4] ncol = 3 # ncol = int(np.sqrt(num_filters)) # nrow = int(np.sqrt(num_filters)) - T = filters.shape[2] + T = filters.shape[0] if num_filters // ncol == num_filters / ncol: nrow = num_filters // ncol @@ -74,7 +74,7 @@ def plot_filters(filters): for i in range(num_filters): r = i // ncol c = i % ncol - ims[(r, c)] = axes[r, c].imshow(filters[i, 0, 0, :, :], + ims[(r, c)] = axes[r, c].imshow(filters[0, :, :, 0, i], cmap=cm.Greys_r) def update(i): @@ -82,29 +82,30 @@ def plot_filters(filters): for i in range(num_filters): r = i // ncol c = i % ncol - ims[(r, c)].set_data(filters[i, 0, t, :, :]) + ims[(r, c)].set_data(filters[t, :, :, 0, i]) return FuncAnimation(plt.gcf(), update, interval=1000/20) -def sparse_loss(filters_1, filters_2, filters_3, filters_4, filters_5, images, activations, batch_size, lam, stride): - recon = do_recon(filters_1, filters_2, filters_3, filters_4, filters_5, activations, batch_size, stride) +def sparse_loss(recon, activations, batch_size, lam, stride): loss = 0.5 * (1/batch_size) * tf.math.reduce_sum(tf.math.pow(images - recon, 2)) loss += lam * tf.reduce_mean(tf.math.reduce_sum(tf.math.abs(tf.reshape(activations, (batch_size, -1))), axis=1)) return loss if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch_size', default=8, type=int) + parser.add_argument('--batch_size', default=6, type=int) parser.add_argument('--kernel_size', default=15, type=int) parser.add_argument('--num_kernels', default=64, type=int) parser.add_argument('--stride', default=2, type=int) - parser.add_argument('--max_activation_iter', default=150, type=int) + parser.add_argument('--max_activation_iter', default=50, type=int) parser.add_argument('--activation_lr', default=1e-2, type=float) - parser.add_argument('--lr', default=1e-5, type=float) + parser.add_argument('--lr', default=1e-2, type=float) parser.add_argument('--epochs', default=100, type=int) parser.add_argument('--lam', default=0.05, type=float) parser.add_argument('--output_dir', default='./output', type=str) parser.add_argument('--seed', default=42, type=int) + parser.add_argument('--run_2d', action='store_true') + parser.add_argument('--save_filters', action='store_true') args = parser.parse_args() @@ -122,14 +123,24 @@ if __name__ == "__main__": example_data = next(iter(train_loader)) - inputs = keras.Input(shape=(100, 200, 5)) + if args.run_2d: + inputs = keras.Input(shape=(100, 200, 5)) + else: + inputs = keras.Input(shape=(5, 100, 200, 1)) - output = SparseCode('../sparse.pt', batch_size=args.batch_size, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter)(inputs) + output = SparseCodeConv(batch_size=args.batch_size, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(inputs) model = keras.Model(inputs=inputs, outputs=output) + + if args.save_filters: + if args.run_2d: + filters = plot_filters(tf.stack(model.get_weights(), axis=0)) + else: + filters = plot_filters(model.get_weights()[0]) + filters.save(os.path.join(args.output_dir, 'filters_start.mp4')) learning_rate = args.lr - filter_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate) + filter_optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate) loss_log = [] best_so_far = float('inf') @@ -139,38 +150,38 @@ if __name__ == "__main__": epoch_start = time.perf_counter() for labels, local_batch, vid_f in tqdm(train_loader): - tf.print(vid_f) if local_batch.size(0) != args.batch_size: continue - images = local_batch.squeeze(1).permute(0, 2, 3, 1).numpy() + if args.run_2d: + images = local_batch.squeeze(1).permute(0, 2, 3, 1).numpy() + else: + images = local_batch.permute(0, 2, 3, 4, 1).numpy() with tf.GradientTape() as tape: - activations = model(images) - - filters = model.get_weights() - - loss = sparse_loss(filters[0], filters[1], filters[2], filters[3], filters[4], images, activations, args.batch_size, args.lam, args.stride) -# recon = model.do_recon(activations) -# loss = 0.5 * (1/batch_size) * tf.math.reduce_sum(tf.math.pow(images - recon, 2)) -# loss += lam * tf.reduce_mean(tf.math.reduce_sum(tf.math.abs(tf.reshape(activations, (batch_size, -1))), axis=1)) + recon, activations = model(images) + loss = sparse_loss(recon, activations, args.batch_size, args.lam, args.stride) epoch_loss += loss * local_batch.size(0) - tf.print('loss:', loss) gradients = tape.gradient(loss, model.trainable_weights) - tf.print('gradients:', tf.reduce_sum(gradients)) filter_optimizer.apply_gradients(zip(gradients, model.trainable_weights)) - tf.print('weights:', tf.reduce_sum(model.trainable_weights)) - - weights = normalize_weights(model.get_weights(), args.num_kernels) + if args.run_2d: + weights = normalize_weights(model.get_weights(), args.num_kernels) + else: + weights = normalize_weights_3d(model.get_weights(), args.num_kernels) model.set_weights(weights) - - tf.print('normalized weights:', tf.reduce_sum(model.trainable_weights)) epoch_end = time.perf_counter() epoch_loss /= len(train_loader.sampler) + + if args.save_filters and epoch % 5 == 0: + if args.run_2d: + filters = plot_filters(tf.stack(model.get_weights(), axis=0)) + else: + filters = plot_filters(model.get_weights()[0]) + filters.save(os.path.join(args.output_dir, 'filters_' + str(epoch) +'.mp4')) if epoch_loss < best_so_far: print("found better model") -- GitLab