diff --git a/keras/keras_model.py b/keras/keras_model.py
index 2360c59651090ccee188aa1ab899cc8567be7b83..2b11890b207db201fc955a520ec26d45ae3cb7e2 100644
--- a/keras/keras_model.py
+++ b/keras/keras_model.py
@@ -9,6 +9,51 @@ import torch.nn as nn
 from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler
 from sparse_coding_torch.conv_sparse_model import ConvSparseLayer
 
+@tf.function
+def do_recon(filters_1, filters_2, filters_3, filters_4, filters_5, activations, batch_size, stride):
+    out_1 = tf.nn.conv2d_transpose(activations, filters_1, output_shape=(batch_size, 100, 200, 1), strides=stride)
+    out_2 = tf.nn.conv2d_transpose(activations, filters_2, output_shape=(batch_size, 100, 200, 1), strides=stride)
+    out_3 = tf.nn.conv2d_transpose(activations, filters_3, output_shape=(batch_size, 100, 200, 1), strides=stride)
+    out_4 = tf.nn.conv2d_transpose(activations, filters_4, output_shape=(batch_size, 100, 200, 1), strides=stride)
+    out_5 = tf.nn.conv2d_transpose(activations, filters_5, output_shape=(batch_size, 100, 200, 1), strides=stride)
+
+    recon = tf.concat([out_1, out_2, out_3, out_4, out_5], axis=3)
+
+    return recon
+
+@tf.function
+def do_recon_3d(filters, activations, batch_size, stride):
+    recon = tf.nn.conv3d_transpose(activations, filters, output_shape=(batch_size, 5, 100, 200, 1), strides=stride)
+
+    return recon
+
+@tf.function
+def conv_error(filters_1, filters_2, filters_3, filters_4, filters_5, e, stride):
+    e1, e2, e3, e4, e5 = tf.split(e, 5, axis=3)
+
+    g = tf.nn.conv2d(e1, filters_1, strides=stride, padding='SAME')
+    g = g + tf.nn.conv2d(e2, filters_2, strides=stride, padding='SAME')
+    g = g + tf.nn.conv2d(e3, filters_3, strides=stride, padding='SAME')
+    g = g + tf.nn.conv2d(e4, filters_4, strides=stride, padding='SAME')
+    g = g + tf.nn.conv2d(e5, filters_5, strides=stride, padding='SAME')
+
+    return g
+
+@tf.function
+def conv_error_3d(filters, e, stride):
+    g = tf.nn.conv3d(e, filters, strides=[stride, stride, stride, stride, stride], padding='SAME')
+
+    return g
+
+@tf.function
+def normalize_weights(filters, out_channels):
+    norms = tf.norm(tf.reshape(tf.stack(filters), (out_channels, -1)), axis=1)
+    norms = tf.broadcast_to(tf.math.maximum(norms, 1e-12*tf.ones_like(norms)), filters[0].shape)
+
+    adjusted = [f / norms for f in filters]
+
+    return adjusted
+
 class SparseCode(keras.layers.Layer):
     def __init__(self, pytorch_checkpoint, batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter):
         super(SparseCode, self).__init__()
@@ -27,104 +72,57 @@ class SparseCode(keras.layers.Layer):
         # weight_list = torch.chunk(weight_tensor, 5, dim=2)
 
         # initializer = tf.keras.initializers.HeNormal()
-        self.filters_1 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, out_channels, in_channels)), dtype='float32', trainable=True)
-        self.filters_2 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, out_channels, in_channels)), dtype='float32', trainable=True)
-        self.filters_3 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, out_channels, in_channels)), dtype='float32', trainable=True)
-        self.filters_4 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, out_channels, in_channels)), dtype='float32', trainable=True)
-        self.filters_5 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, out_channels, in_channels)), dtype='float32', trainable=True)
-
-    def normalize_weights(self):
-        norms = tf.norm(tf.reshape(tf.stack([self.filters_1, self.filters_2, self.filters_3, self.filters_4, self.filters_5]), (self.out_channels, -1)), axis=1, keepdims=True)
-        norms = tf.broadcast_to(tf.math.maximum(norms, 1e-12*tf.ones_like(norms)), self.filters_1.shape)
-        self.filters_1 = self.filters_1 / norms
-        self.filters_2 = self.filters_2 / norms
-        self.filters_3 = self.filters_3 / norms
-        self.filters_4 = self.filters_4 / norms
-        self.filters_5 = self.filters_5 / norms
-
-    @tf.function
-    def do_recon(self, activations):
-        out_1 = tf.nn.conv2d_transpose(activations, self.filters_1, output_shape=(self.batch_size, 100, 200, 1), strides=self.stride)
-        out_2 = tf.nn.conv2d_transpose(activations, self.filters_2, output_shape=(self.batch_size, 100, 200, 1), strides=self.stride)
-        out_3 = tf.nn.conv2d_transpose(activations, self.filters_3, output_shape=(self.batch_size, 100, 200, 1), strides=self.stride)
-        out_4 = tf.nn.conv2d_transpose(activations, self.filters_4, output_shape=(self.batch_size, 100, 200, 1), strides=self.stride)
-        out_5 = tf.nn.conv2d_transpose(activations, self.filters_5, output_shape=(self.batch_size, 100, 200, 1), strides=self.stride)
-
-        recon = tf.concat([out_1, out_2, out_3, out_4, out_5], axis=3)
-
-        return recon
-
-    @tf.function
-    def do_recon_3d(self, activations):
-        recon = tf.nn.conv3d_transpose(activations, self.filter, output_shape=(self.batch_size, 5, 100, 200, 1), strides=self.stride)
-
-        return recon
-
-    @tf.function
-    def conv_error(self, e):
-        e1, e2, e3, e4, e5 = tf.split(e, 5, axis=3)
-
-        g = tf.nn.conv2d(e1, self.filters_1, strides=self.stride, padding='SAME')
-        g = g + tf.nn.conv2d(e2, self.filters_2, strides=self.stride, padding='SAME')
-        g = g + tf.nn.conv2d(e3, self.filters_3, strides=self.stride, padding='SAME')
-        g = g + tf.nn.conv2d(e4, self.filters_4, strides=self.stride, padding='SAME')
-        g = g + tf.nn.conv2d(e5, self.filters_5, strides=self.stride, padding='SAME')
-
-        return g
-
-    @tf.function
-    def conv_error_3d(self, e):
-        g = tf.nn.conv3d(e, self.filter, strides=[1, 1, 1, 1, 1], padding='SAME')
-
-        return g
+        self.filters_1 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
+        self.filters_2 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
+        self.filters_3 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
+        self.filters_4 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
+        self.filters_5 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
+
+        weights = normalize_weights(self.get_weights(), out_channels)
+        self.set_weights(weights)
 
     @tf.function
     def do_update(self, images, u, m, v, b1, b2, eps, i):
         activations = tf.nn.relu(u - self.lam)
-        recon = self.do_recon(activations)
+        recon = do_recon(self.filters_1, self.filters_2, self.filters_3, self.filters_4, self.filters_5, activations, self.batch_size, self.stride)
         e = images - recon
         g = -1 * u
-
-        g = g + self.conv_error(e)
+
+        convd_error = conv_error(self.filters_1, self.filters_2, self.filters_3, self.filters_4, self.filters_5, e, self.stride)
+
+        g = g + convd_error
         g = g + activations
-
+
         m = b1 * m + (1-b1) * g
         v = b2 * v + (1-b2) * g**2
         mh = m / (1 - b1**(1+i))
         vh = v / (1 - b2**(1+i))
-        u = u + (self.activation_lr * mh / (tf.math.sqrt(vh) + eps))
+
+        du = self.activation_lr * mh / (tf.math.sqrt(vh) + eps)
+        u += du
 
         return u, m, v
 
-    def loss(self, images, activations):
-        recon = self.do_recon(activations)
-        loss = 0.5 * (1/images.shape[0]) * tf.sum(tf.math.pow(images - recon, 2))
-        loss += self.lam * tf.reduce_mean(tf.sum(tf.math.abs(activations.reshape(activations.shape[0], -1)), axis=1))
-        return loss
-
     @tf.function
     def call(self, images):
-        u = tf.zeros(shape=(self.batch_size, 5, 100 // self.stride, 200 // self.stride, self.out_channels))
-        m = tf.zeros(shape=(self.batch_size, 5, 100 // self.stride, 200 // self.stride, self.out_channels))
-        v = tf.zeros(shape=(self.batch_size, 5, 100 // self.stride, 200 // self.stride, self.out_channels))
+        u = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
+        m = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
+        v = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
+
+        tf.print('activations before:', tf.reduce_sum(u))
 
         b1 = 0.9
         b2 = 0.999
         eps = 1e-8
-        # images, u, m, v, b1, b2, eps, i = self.do_update(images, u, m, v, b1, b2, eps, 0)
-
-        # i = tf.constant(0, dtype=tf.float16)
-        # c = lambda images, u, m, v, b1, b2, eps, i: tf.less(i, 20)
-        # images, u, m, v, b1, b2, eps, i = tf.while_loop(c, self.do_update, [images, u, m, v, b1, b2, eps, i], shape_invariants=[images.get_shape(), tf.TensorShape([None, 100, 200, self.out_channels]), tf.TensorShape([None, 100, 200, self.out_channels]), tf.TensorShape([None, 100, 200, self.out_channels]), None, None, None, i.get_shape()])
         for i in range(self.max_activation_iter):
             u, m, v = self.do_update(images, u, m, v, b1, b2, eps, i)
 
         u = tf.nn.relu(u - self.lam)
-
-        self.add_loss(self.loss(images, u))
+
+        tf.print('activations after:', tf.reduce_sum(u))
 
         return u
diff --git a/keras/train_sparse_model.py b/keras/train_sparse_model.py
index 72c9cfcae85236fdd0fd6fc73ba12f7d5370de13..54e0209ef0449d1ac16b5458859de48b03ab4d24 100644
--- a/keras/train_sparse_model.py
+++ b/keras/train_sparse_model.py
@@ -8,6 +8,9 @@ from tqdm import tqdm
 import argparse
 import os
 from sparse_coding_torch.load_data import load_yolo_clips
+import tensorflow.keras as keras
+import tensorflow as tf
+from keras_model import SparseCode, do_recon, normalize_weights
 
 
 def plot_video(video):
@@ -83,16 +86,21 @@ def plot_filters(filters):
     return FuncAnimation(plt.gcf(), update, interval=1000/20)
 
 
+def sparse_loss(filters_1, filters_2, filters_3, filters_4, filters_5, images, activations, batch_size, lam, stride):
+    recon = do_recon(filters_1, filters_2, filters_3, filters_4, filters_5, activations, batch_size, stride)
+    loss = 0.5 * (1/batch_size) * tf.math.reduce_sum(tf.math.pow(images - recon, 2))
+    loss += lam * tf.reduce_mean(tf.math.reduce_sum(tf.math.abs(tf.reshape(activations, (batch_size, -1))), axis=1))
+    return loss
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--batch_size', default=12, type=int)
+    parser.add_argument('--batch_size', default=8, type=int)
    parser.add_argument('--kernel_size', default=15, type=int)
     parser.add_argument('--num_kernels', default=64, type=int)
-    parser.add_argument('--stride', default=1, type=int)
-    parser.add_argument('--max_activation_iter', default=200, type=int)
+    parser.add_argument('--stride', default=2, type=int)
+    parser.add_argument('--max_activation_iter', default=150, type=int)
     parser.add_argument('--activation_lr', default=1e-2, type=float)
-    parser.add_argument('--lr', default=1e-3, type=float)
+    parser.add_argument('--lr', default=1e-5, type=float)
     parser.add_argument('--epochs', default=100, type=int)
     parser.add_argument('--lam', default=0.05, type=float)
     parser.add_argument('--output_dir', default='./output', type=str)
@@ -103,16 +111,18 @@ if __name__ == "__main__":
     output_dir = args.output_dir
     if not os.path.exists(output_dir):
         os.makedirs(output_dir)
+
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
     with open(os.path.join(output_dir, 'arguments.txt'), 'w+') as out_f:
         out_f.write(str(args))
 
-    train_loader, _ = load_yolo_clips(batch_size, mode='all_train')
+    train_loader, _ = load_yolo_clips(args.batch_size, num_clips=1, num_positives=15, mode='all_train', device=device, n_splits=1, sparse_model=None, whole_video=False, positive_videos='../positive_videos.json')
     print('Loaded', len(train_loader), 'train examples')
 
     example_data = next(iter(train_loader))
 
-    inputs = keras.Input(shape=(5, 100, 200, 1))
+    inputs = keras.Input(shape=(100, 200, 5))
 
     output = SparseCode('../sparse.pt', batch_size=args.batch_size, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter)(inputs)
 
@@ -124,21 +134,40 @@ if __name__ == "__main__":
     loss_log = []
     best_so_far = float('inf')
 
-    for epoch in tqdm(range(args.epochs)):
+    for epoch in range(args.epochs):
         epoch_loss = 0
         epoch_start = time.perf_counter()
 
-        for labels, local_batch, vid_f in train_loader:
+        for labels, local_batch, vid_f in tqdm(train_loader):
+            tf.print(vid_f)
+            if local_batch.size(0) != args.batch_size:
+                continue
+            images = local_batch.squeeze(1).permute(0, 2, 3, 1).numpy()
+
             with tf.GradientTape() as tape:
-
-                activations = model(local_batch.numpy())
-                loss = tf.sum(model.losses)
-
-                epoch_loss += loss * local_batch.size(0)
-
-                gradients = tape.gradient(loss, model.trainable_weights)
-
-                optimizer.apply_gradients(zip(gradients, model.trainable_weights))
+                activations = model(images)
+
+                filters = model.get_weights()
+
+                loss = sparse_loss(filters[0], filters[1], filters[2], filters[3], filters[4], images, activations, args.batch_size, args.lam, args.stride)
+# recon = model.do_recon(activations)
+# loss = 0.5 * (1/batch_size) * tf.math.reduce_sum(tf.math.pow(images - recon, 2))
+# loss += lam * tf.reduce_mean(tf.math.reduce_sum(tf.math.abs(tf.reshape(activations, (batch_size, -1))), axis=1))
+
+            epoch_loss += loss * local_batch.size(0)
+            tf.print('loss:', loss)
+
+            gradients = tape.gradient(loss, model.trainable_weights)
+            tf.print('gradients:', tf.reduce_sum(gradients))
+
+            filter_optimizer.apply_gradients(zip(gradients, model.trainable_weights))
+
+            tf.print('weights:', tf.reduce_sum(model.trainable_weights))
+
+            weights = normalize_weights(model.get_weights(), args.num_kernels)
+            model.set_weights(weights)
+
+            tf.print('normalized weights:', tf.reduce_sum(model.trainable_weights))
 
         epoch_end = time.perf_counter()
         epoch_loss /= len(train_loader.sampler)