diff --git a/keras/keras_model.py b/keras/keras_model.py index 5add25e02b950adb31f6ae7e8381f49d354fb39a..530282d40f25a3b0486fb7d685013349ae0cb296 100644 --- a/keras/keras_model.py +++ b/keras/keras_model.py @@ -15,48 +15,42 @@ def load_pytorch_weights(file_path): return weight_tensor -@tf.function -def do_recon(filters_1, filters_2, filters_3, filters_4, filters_5, activations, batch_size, stride): - out_1 = tf.nn.conv2d_transpose(activations, filters_1, output_shape=(batch_size, 100, 200, 1), strides=stride) - out_2 = tf.nn.conv2d_transpose(activations, filters_2, output_shape=(batch_size, 100, 200, 1), strides=stride) - out_3 = tf.nn.conv2d_transpose(activations, filters_3, output_shape=(batch_size, 100, 200, 1), strides=stride) - out_4 = tf.nn.conv2d_transpose(activations, filters_4, output_shape=(batch_size, 100, 200, 1), strides=stride) - out_5 = tf.nn.conv2d_transpose(activations, filters_5, output_shape=(batch_size, 100, 200, 1), strides=stride) +# @tf.function +def do_recon(filters_1, filters_2, filters_3, filters_4, filters_5, activations, batch_size, image_height, image_width, stride): + out_1 = tf.nn.conv2d_transpose(activations, filters_1, output_shape=(batch_size, image_height, image_width, 1), strides=stride, padding='VALID') + out_2 = tf.nn.conv2d_transpose(activations, filters_2, output_shape=(batch_size, image_height, image_width, 1), strides=stride, padding='VALID') + out_3 = tf.nn.conv2d_transpose(activations, filters_3, output_shape=(batch_size, image_height, image_width, 1), strides=stride, padding='VALID') + out_4 = tf.nn.conv2d_transpose(activations, filters_4, output_shape=(batch_size, image_height, image_width, 1), strides=stride, padding='VALID') + out_5 = tf.nn.conv2d_transpose(activations, filters_5, output_shape=(batch_size, image_height, image_width, 1), strides=stride, padding='VALID') recon = tf.concat([out_1, out_2, out_3, out_4, out_5], axis=3) return recon # @tf.function -def do_recon_3d(filters, activations, batch_size, stride): - activations = tf.pad(activations, paddings=[[0,0], [2, 2], [0, 0], [0, 0], [0, 0]]) - recon = tf.nn.conv3d_transpose(activations, filters, output_shape=(batch_size, 5, 100, 200, 1), strides=[1, stride, stride]) +def do_recon_3d(filters, activations, batch_size, image_height, image_width, stride): +# activations = tf.pad(activations, paddings=[[0,0], [2, 2], [0, 0], [0, 0], [0, 0]]) + recon = tf.nn.conv3d_transpose(activations, filters, output_shape=(batch_size, 5, image_height, image_width, 1), strides=[1, stride, stride], padding='VALID') return recon -@tf.function -def conv_error(filters_1, filters_2, filters_3, filters_4, filters_5, e, stride): - e1, e2, e3, e4, e5 = tf.split(e, 5, axis=3) - - g = tf.nn.conv2d(e1, filters_1, strides=stride, padding='SAME') - g = g + tf.nn.conv2d(e2, filters_2, strides=stride, padding='SAME') - g = g + tf.nn.conv2d(e3, filters_3, strides=stride, padding='SAME') - g = g + tf.nn.conv2d(e4, filters_4, strides=stride, padding='SAME') - g = g + tf.nn.conv2d(e5, filters_5, strides=stride, padding='SAME') +# @tf.function +def conv_error(filters, e, stride): + g = tf.nn.conv2d(e, filters, strides=stride, padding='VALID') return g # @tf.function def conv_error_3d(filters, e, stride): - e = tf.pad(e, paddings=[[0,0], [0, 0], [7, 7], [7, 7], [0, 0]]) +# e = tf.pad(e, paddings=[[0,0], [0, 0], [7, 7], [7, 7], [0, 0]]) g = tf.nn.conv3d(e, filters, strides=[1, 1, stride, stride, 1], padding='VALID') return g -@tf.function +# @tf.function def normalize_weights(filters, out_channels): #print('filters shape', 
tf.shape(filters)) - norms = tf.norm(tf.reshape(tf.stack(filters), (out_channels, -1)), axis=1) + norms = tf.norm(tf.reshape(tf.transpose(tf.stack(filters), perm=[4, 0, 1, 2, 3]), (out_channels, -1)), axis=1) norms = tf.broadcast_to(tf.math.maximum(norms, 1e-12*tf.ones_like(norms)), filters[0].shape) adjusted = [f / norms for f in filters] @@ -83,7 +77,7 @@ def normalize_weights_3d(filters, out_channels): return adjusted class SparseCode(keras.layers.Layer): - def __init__(self, batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d): + def __init__(self, batch_size, image_height, image_width, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d): super(SparseCode, self).__init__() self.out_channels = out_channels @@ -93,63 +87,78 @@ class SparseCode(keras.layers.Layer): self.activation_lr = activation_lr self.max_activation_iter = max_activation_iter self.batch_size = batch_size + self.image_height = image_height + self.image_width = image_width + self.kernel_size = kernel_size self.run_2d = run_2d - @tf.function +# @tf.function def do_update(self, images, filters, u, m, v, b1, b2, eps, i): activations = tf.nn.relu(u - self.lam) if self.run_2d: - recon = do_recon(filters[0], filters[1], filters[2], filters[3], filters[4], activations, self.batch_size, self.stride) + recon = do_recon(filters[0], filters[1], filters[2], filters[3], filters[4], activations, self.batch_size, self.image_height, self.image_width, self.stride) else: - recon = do_recon_3d(filters, activations, self.batch_size, self.stride) + recon = do_recon_3d(filters, activations, self.batch_size, self.image_height, self.image_width, self.stride) e = images - recon g = -1 * u if self.run_2d: - convd_error = conv_error(filters[0], filters[1], filters[2], filters[3], filters[4], e, self.stride) + e1, e2, e3, e4, e5 = tf.split(e, 5, axis=3) + g += conv_error(filters[0], e1, self.stride) + g += conv_error(filters[1], e2, self.stride) + g += conv_error(filters[2], e3, self.stride) + g += conv_error(filters[3], e4, self.stride) + g += conv_error(filters[4], e5, self.stride) else: convd_error = conv_error_3d(filters, e, self.stride) - g = g + convd_error + g = g + convd_error g = g + activations m = b1 * m + (1-b1) * g - v = b2 * v + (1-b2) * g**2 - mh = m / (1 - b1**(1+i)) - vh = v / (1 - b2**(1+i)) + + v = b2 * v + (1-b2) * tf.math.pow(g, 2) + + mh = m / (1 - tf.math.pow(b1, (1+i))) + + vh = v / (1 - tf.math.pow(b2, (1+i))) du = self.activation_lr * mh / (tf.math.sqrt(vh) + eps) + u += du # i += 1 -# return images, u, m, v, b1, b2, eps, i +# return images, filters, u, m, v, b1, b2, eps, i return u, m, v - @tf.function +# @tf.function def call(self, images, filters): + filters = tf.squeeze(filters, axis=0) if self.run_2d: - u = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels)) - m = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels)) - v = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels)) + output_shape = (self.batch_size, (self.image_height - self.kernel_size) // self.stride + 1, (self.image_width - self.kernel_size) // self.stride + 1, self.out_channels) else: - u = tf.zeros(shape=(self.batch_size, 1, 100 // self.stride, 200 // self.stride, self.out_channels)) - m = tf.zeros(shape=(self.batch_size, 1, 100 // self.stride, 200 // self.stride, self.out_channels)) - v = tf.zeros(shape=(self.batch_size, 1, 100 // 
self.stride, 200 // self.stride, self.out_channels)) + output_shape = (self.batch_size, 1, (self.image_height - self.kernel_size) // self.stride + 1, (self.image_width - self.kernel_size) // self.stride + 1, self.out_channels) + + u = tf.zeros(shape=output_shape) + m = tf.zeros(shape=output_shape) + v = tf.zeros(shape=output_shape) # tf.print('activations before:', tf.reduce_sum(u)) b1 = tf.constant(0.9, dtype='float32') - b2 = tf.constant(0.999, dtype='float32') + b2 = tf.constant(0.99, dtype='float32') eps = tf.constant(1e-8, dtype='float32') +# print(u) + # i = tf.constant(0, dtype='float32') -# c = lambda images, u, m, v, b1, b2, eps, i: tf.less(i, self.max_activation_iter) -# images, u, m, v, b1, b2, eps, i = tf.while_loop(c, self.do_update, [images, u, m, v, b1, b2, eps, i]) +# c = lambda images, filters, u, m, v, b1, b2, eps, i: tf.less(i, self.max_activation_iter) +# images, filters, u, m, v, b1, b2, eps, i = tf.while_loop(c, self.do_update, [images, filters, u, m, v, b1, b2, eps, i]) for i in range(self.max_activation_iter): u, m, v = self.do_update(images, filters, u, m, v, b1, b2, eps, i) @@ -159,10 +168,9 @@ class SparseCode(keras.layers.Layer): return u -class SparseCodeConv(keras.Model): - def __init__(self, batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d): +class ReconSparse(keras.Model): + def __init__(self, batch_size, image_height, image_width, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d): super().__init__() - self.sparse_code = SparseCode(batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d) self.out_channels = out_channels self.in_channels = in_channels @@ -171,6 +179,8 @@ class SparseCodeConv(keras.Model): self.activation_lr = activation_lr self.max_activation_iter = max_activation_iter self.batch_size = batch_size + self.image_height = image_height + self.image_width = image_width self.run_2d = run_2d initializer = tf.keras.initializers.HeNormal() @@ -181,7 +191,7 @@ class SparseCodeConv(keras.Model): self.filters_4 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True) self.filters_5 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True) else: - self.filters = tf.Variable(initial_value=initializer(shape=(5, kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True) + self.filters = tf.Variable(initial_value=initializer(shape=(5, kernel_size, kernel_size, in_channels, out_channels), dtype='float32'), trainable=True) if run_2d: weights = normalize_weights(self.get_weights(), out_channels) @@ -189,16 +199,14 @@ class SparseCodeConv(keras.Model): weights = normalize_weights_3d(self.get_weights(), out_channels) self.set_weights(weights) - @tf.function - def call(self, images): +# @tf.function + def call(self, activations): if self.run_2d: - activations = self.sparse_code(images, [tf.stop_gradient(self.filters_1), tf.stop_gradient(self.filters_2), tf.stop_gradient(self.filters_3), tf.stop_gradient(self.filters_4), tf.stop_gradient(self.filters_5)]) - recon = do_recon(self.filters_1, self.filters_2, self.filters_3, self.filters_4, self.filters_5, activations, self.batch_size, self.stride) + recon = do_recon(self.filters_1, self.filters_2, self.filters_3, self.filters_4, self.filters_5, activations, self.batch_size, 
self.image_height, self.image_width, self.stride) else: - activations = self.sparse_code(images, tf.stop_gradient(self.filters)) - recon = do_recon_3d(self.filters, activations, self.batch_size, self.stride) + recon = do_recon_3d(self.filters, activations, self.batch_size, self.image_height, self.image_width, self.stride) - return recon, activations + return recon class Classifier(keras.layers.Layer): def __init__(self): @@ -218,6 +226,7 @@ class Classifier(keras.layers.Layer): @tf.function def call(self, activations): + activations = tf.squeeze(activations, axis=1) x = self.max_pool(activations) x = self.conv(x) x = self.flatten(x) diff --git a/keras/train_sparse_model.py b/keras/train_sparse_model.py index d543e3af8d470e2db1b090f0718a829e40fcd18c..7a074267b9bbb394006887f0a747e4565265873a 100644 --- a/keras/train_sparse_model.py +++ b/keras/train_sparse_model.py @@ -7,10 +7,11 @@ from matplotlib.animation import FuncAnimation from tqdm import tqdm import argparse import os -from sparse_coding_torch.load_data import load_yolo_clips +from sparse_coding_torch.load_data import load_yolo_clips, load_pnb_videos import tensorflow.keras as keras import tensorflow as tf -from keras_model import normalize_weights_3d, normalize_weights, SparseCodeConv, load_pytorch_weights +from keras_model import normalize_weights_3d, normalize_weights, SparseCode, load_pytorch_weights, ReconSparse +import random def plot_video(video): @@ -55,6 +56,7 @@ def plot_original_vs_recon(original, reconstruction, idx=0): def plot_filters(filters): + filters = filters.astype('float32') num_filters = filters.shape[4] ncol = 3 # ncol = int(np.sqrt(num_filters)) @@ -97,17 +99,30 @@ if __name__ == "__main__": parser.add_argument('--kernel_size', default=15, type=int) parser.add_argument('--num_kernels', default=64, type=int) parser.add_argument('--stride', default=2, type=int) - parser.add_argument('--max_activation_iter', default=50, type=int) + parser.add_argument('--max_activation_iter', default=100, type=int) parser.add_argument('--activation_lr', default=1e-2, type=float) parser.add_argument('--lr', default=5e-2, type=float) - parser.add_argument('--epochs', default=100, type=int) + parser.add_argument('--epochs', default=20, type=int) parser.add_argument('--lam', default=0.05, type=float) parser.add_argument('--output_dir', default='./output', type=str) parser.add_argument('--seed', default=42, type=int) parser.add_argument('--run_2d', action='store_true') parser.add_argument('--save_filters', action='store_true') + parser.add_argument('--optimizer', default='adam', type=str) + parser.add_argument('--dataset', default='pnb', type=str) + args = parser.parse_args() + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + +# policy = keras.mixed_precision.Policy('mixed_float16') +# keras.mixed_precision.set_global_policy(policy) + + image_height = 360 + image_width = 304 output_dir = args.output_dir if not os.path.exists(output_dir): @@ -118,36 +133,55 @@ if __name__ == "__main__": with open(os.path.join(output_dir, 'arguments.txt'), 'w+') as out_f: out_f.write(str(args)) - train_loader, _ = load_yolo_clips(args.batch_size, num_clips=1, num_positives=15, mode='all_train', device=device, n_splits=1, sparse_model=None, whole_video=False, positive_videos='../positive_videos.json') + if args.dataset == 'pnb': + train_loader, _ = load_pnb_videos(args.batch_size, mode='all_train', device=device, n_splits=1, sparse_model=None) + elif args.dataset == 'ptx': + train_loader, _ = 
load_yolo_clips(args.batch_size, num_clips=1, num_positives=15, mode='all_train', device=device, n_splits=1, sparse_model=None, whole_video=False, positive_videos='../positive_videos.json') + else: + raise Exception('Invalid dataset') print('Loaded', len(train_loader), 'train examples') example_data = next(iter(train_loader)) if args.run_2d: - inputs = keras.Input(shape=(100, 200, 5)) + inputs = keras.Input(shape=(image_height, image_width, 5)) else: - inputs = keras.Input(shape=(5, 100, 200, 1)) + inputs = keras.Input(shape=(5, image_height, image_width, 1)) + + filter_inputs = keras.Input(shape=(5, args.kernel_size, args.kernel_size, 1, args.num_kernels), dtype='float32') - output = SparseCodeConv(batch_size=args.batch_size, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(inputs) + output = SparseCode(batch_size=args.batch_size, image_height=image_height, image_width=image_width, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(inputs, filter_inputs) - model = keras.Model(inputs=inputs, outputs=output) + sparse_model = keras.Model(inputs=(inputs, filter_inputs), outputs=output) + + recon_inputs = keras.Input(shape=(1, (image_height - args.kernel_size) // args.stride + 1, (image_width - args.kernel_size) // args.stride + 1, args.num_kernels)) + + recon_outputs = ReconSparse(batch_size=args.batch_size, image_height=image_height, image_width=image_width, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(recon_inputs) + + recon_model = keras.Model(inputs=recon_inputs, outputs=recon_outputs) if args.save_filters: if args.run_2d: - filters = plot_filters(tf.stack(model.get_weights(), axis=0)) + filters = plot_filters(tf.stack(recon_model.get_weights(), axis=0)) else: - filters = plot_filters(model.get_weights()[0]) + filters = plot_filters(recon_model.get_weights()[0]) filters.save(os.path.join(args.output_dir, 'filters_start.mp4')) learning_rate = args.lr - filter_optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate) + if args.optimizer == 'sgd': + filter_optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate) + else: + filter_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate) loss_log = [] best_so_far = float('inf') for epoch in range(args.epochs): epoch_loss = 0 + running_loss = 0.0 epoch_start = time.perf_counter() + + num_iters = 0 for labels, local_batch, vid_f in tqdm(train_loader): if local_batch.size(0) != args.batch_size: @@ -156,37 +190,52 @@ if __name__ == "__main__": images = local_batch.squeeze(1).permute(0, 2, 3, 1).numpy() else: images = local_batch.permute(0, 2, 3, 4, 1).numpy() + + activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))])) with tf.GradientTape() as tape: - recon, activations = model(images) + recon = recon_model(activations) loss = sparse_loss(recon, activations, args.batch_size, args.lam, args.stride) epoch_loss += loss * local_batch.size(0) + running_loss += loss * local_batch.size(0) - gradients = tape.gradient(loss, model.trainable_weights) + gradients = tape.gradient(loss, 
recon_model.trainable_weights) - filter_optimizer.apply_gradients(zip(gradients, model.trainable_weights)) + filter_optimizer.apply_gradients(zip(gradients, recon_model.trainable_weights)) if args.run_2d: - weights = normalize_weights(model.get_weights(), args.num_kernels) + weights = normalize_weights(recon_model.get_weights(), args.num_kernels) else: - weights = normalize_weights_3d(model.get_weights(), args.num_kernels) - model.set_weights(weights) + weights = normalize_weights_3d(recon_model.get_weights(), args.num_kernels) + recon_model.set_weights(weights) + +# if args.save_filters and num_iters % 25 == 0: +# if args.run_2d: +# filters = plot_filters(tf.stack(recon_model.get_weights(), axis=0)) +# else: +# filters = plot_filters(recon_model.get_weights()[0]) +# filters.save(os.path.join(args.output_dir, 'filters_' + str(epoch) + '_' + str(num_iters) + '.mp4')) +# loss_log.append(running_loss) +# print(running_loss) +# running_loss = 0.0 + + num_iters += 1 epoch_end = time.perf_counter() epoch_loss /= len(train_loader.sampler) - if args.save_filters and epoch % 5 == 0: + if args.save_filters and epoch % 2 == 0: if args.run_2d: - filters = plot_filters(tf.stack(model.get_weights(), axis=0)) + filters = plot_filters(tf.stack(recon_model.get_weights(), axis=0)) else: - filters = plot_filters(model.get_weights()[0]) + filters = plot_filters(recon_model.get_weights()[0]) filters.save(os.path.join(args.output_dir, 'filters_' + str(epoch) +'.mp4')) if epoch_loss < best_so_far: print("found better model") # Save model parameters - model.save(os.path.join(output_dir, "sparse_conv3d_model-best.pt")) + recon_model.save(os.path.join(output_dir, "sparse_conv3d_model-best.pt")) best_so_far = epoch_loss loss_log.append(epoch_loss) diff --git a/scripts/train_classifier.py b/scripts/train_classifier.py index 38d43adb0a39a102e05c266616ef9f734da931b3..79b76808371b7883f5c1b6033bef54c40552a549 100644 --- a/scripts/train_classifier.py +++ b/scripts/train_classifier.py @@ -20,9 +20,9 @@ if __name__ == "__main__": parser.add_argument('--kernel_width', default=15, type=int) parser.add_argument('--kernel_depth', default=5, type=int) parser.add_argument('--num_kernels', default=64, type=int) - parser.add_argument('--stride', default=1, type=int) - parser.add_argument('--max_activation_iter', default=200, type=int) - parser.add_argument('--activation_lr', default=1e-2, type=float) + parser.add_argument('--stride', default=2, type=int) + parser.add_argument('--max_activation_iter', default=100, type=int) + parser.add_argument('--activation_lr', default=1e-1, type=float) parser.add_argument('--lr', default=5e-5, type=float) parser.add_argument('--epochs', default=40, type=int) parser.add_argument('--lam', default=0.05, type=float) @@ -39,7 +39,7 @@ if __name__ == "__main__": parser.add_argument('--n_splits', default=5, type=int) parser.add_argument('--whole_video', action='store_true') parser.add_argument('--save_train_test_splits', action='store_true') - parser.add_argument('--positive_videos', default='positive_videos.json', type=str) + parser.add_argument('--positive_videos', default=None, type=str) args = parser.parse_args() @@ -66,7 +66,7 @@ if __name__ == "__main__": out_channels=args.num_kernels, kernel_size=(args.kernel_depth, args.kernel_height, args.kernel_width), stride=args.stride, - padding=(0, 7, 7), + padding=0, convo_dim=3, rectifier=True, lam=args.lam, @@ -82,7 +82,9 @@ if __name__ == "__main__": frozen_sparse.to(device) - splits, dataset = load_yolo_clips(batch_size, num_clips=args.num_clips, 
num_positives=args.num_positives, mode=args.splits, device=device, n_splits=args.n_splits, sparse_model=frozen_sparse, whole_video=args.whole_video, positive_videos=args.positive_videos) +# splits, dataset = load_yolo_clips(batch_size, num_clips=args.num_clips, num_positives=args.num_positives, mode=args.splits, device=device, n_splits=args.n_splits, sparse_model=frozen_sparse, whole_video=args.whole_video, positive_videos=args.positive_videos) + + train_loader, test_loader = load_yolo_clips(batch_size, num_clips=args.num_clips, num_positives=args.num_positives, mode='all_train', device=device, n_splits=args.n_splits, sparse_model=frozen_sparse, whole_video=args.whole_video, positive_videos=args.positive_videos) overall_true = [] overall_pred = [] @@ -90,122 +92,122 @@ if __name__ == "__main__": fp_ids = [] i_fold = 0 - for train_idx, test_idx in [list(splits)[0]]: - - if args.save_train_test_splits: - with open(os.path.join(output_dir, 'train_idx_' + str(i_fold) + '.pkl'), 'wb+') as train_out: - pickle.dump(train_idx, train_out) - - with open(os.path.join(output_dir, 'test_idx_' + str(i_fold) + '.pkl'), 'wb+') as test_out: - pickle.dump(test_idx, test_out) +# for train_idx, test_idx in splits: - train_sampler = torch.utils.data.SubsetRandomSampler(train_idx) - test_sampler = torch.utils.data.SubsetRandomSampler(test_idx) +# if args.save_train_test_splits: +# with open(os.path.join(output_dir, 'train_idx_' + str(i_fold) + '.pkl'), 'wb+') as train_out: +# pickle.dump(train_idx, train_out) - train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, - # shuffle=True, - sampler=train_sampler) +# with open(os.path.join(output_dir, 'test_idx_' + str(i_fold) + '.pkl'), 'wb+') as test_out: +# pickle.dump(test_idx, test_out) - test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, - # shuffle=True, - sampler=test_sampler) - - if args.save_train_test_splits: - with open(os.path.join(output_dir, 'train_idx_' + str(i_fold) + '.txt'), 'w+') as train_out: - train_videos = set() - for tmp in train_loader: - train_videos.update(tmp[2]) - train_out.write(str(train_videos)) - - with open(os.path.join(output_dir, 'test_idx_' + str(i_fold) + '.txt'), 'w+') as test_out: - test_videos = set() - for tmp in test_loader: - test_videos.update(tmp[2]) - test_out.write(str(test_videos)) - - best_so_far = float('inf') - - if args.num_clips > 1 or args.whole_video: - predictive_model = torch.nn.DataParallel(SmallDataClassifierVideo(args.num_clips)) - else: - predictive_model = torch.nn.DataParallel(SmallDataClassifierConv3d()) - predictive_model.to(device) - - criterion = torch.nn.BCEWithLogitsLoss() - - if args.checkpoint: - checkpoint = torch.load(args.checkpoint) - predictive_model.load_state_dict(checkpoint['model_state_dict']) - - if args.train: - prediction_optimizer = torch.optim.Adam(predictive_model.parameters(), - lr=args.lr) +# train_sampler = torch.utils.data.SubsetRandomSampler(train_idx) +# test_sampler = torch.utils.data.SubsetRandomSampler(test_idx) - for epoch in range(args.epochs): - predictive_model.train() - epoch_loss = 0 - t1 = time.perf_counter() +# train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, +# # shuffle=True, +# sampler=train_sampler) - for labels, local_batch, vid_f in tqdm(train_loader): - local_batch = local_batch.to(device) +# test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, +# # shuffle=True, +# sampler=test_sampler) - torch_labels = torch.zeros(len(labels)) - torch_labels[[i for i in 
range(len(labels)) if labels[i] == 'No_Sliding']] = 1 - torch_labels = torch_labels.unsqueeze(1).to(device) +# if args.save_train_test_splits: +# with open(os.path.join(output_dir, 'train_idx_' + str(i_fold) + '.txt'), 'w+') as train_out: +# train_videos = set() +# for tmp in train_loader: +# train_videos.update(tmp[2]) +# train_out.write(str(train_videos)) - pred, activations = predictive_model(local_batch) +# with open(os.path.join(output_dir, 'test_idx_' + str(i_fold) + '.txt'), 'w+') as test_out: +# test_videos = set() +# for tmp in test_loader: +# test_videos.update(tmp[2]) +# test_out.write(str(test_videos)) - loss = criterion(pred, torch_labels) - if args.train_sparse: - loss += args.recon_scale * frozen_sparse.loss(local_batch, activations) - epoch_loss += loss.item() * local_batch.size(0) + best_so_far = float('inf') - prediction_optimizer.zero_grad() - loss.backward() - prediction_optimizer.step() + if args.num_clips > 1 or args.whole_video: + predictive_model = torch.nn.DataParallel(SmallDataClassifierVideo(args.num_clips)) + else: + predictive_model = torch.nn.DataParallel(SmallDataClassifierConv3d()) + predictive_model.to(device) - t2 = time.perf_counter() + criterion = torch.nn.BCEWithLogitsLoss() + + if args.checkpoint: + checkpoint = torch.load(args.checkpoint) + predictive_model.load_state_dict(checkpoint['model_state_dict']) - predictive_model.eval() - with torch.no_grad(): - y_true = None - y_pred = None - for labels, local_batch, vid_f in test_loader: + if args.train: + prediction_optimizer = torch.optim.Adam(predictive_model.parameters(), + lr=args.lr) - local_batch = local_batch.to(device) + for epoch in range(args.epochs): + predictive_model.train() + epoch_loss = 0 + t1 = time.perf_counter() + + for labels, local_batch, vid_f in tqdm(train_loader): + local_batch = local_batch.to(device).squeeze(2) - torch_labels = torch.zeros(len(labels)) - torch_labels[[i for i in range(len(labels)) if labels[i] == 'No_Sliding']] = 1 - torch_labels = torch_labels.unsqueeze(1).to(device) + torch_labels = torch.zeros(len(labels)) + torch_labels[[i for i in range(len(labels)) if labels[i] == 'No_Sliding']] = 1 + torch_labels = torch_labels.unsqueeze(1).to(device) + pred, activations = predictive_model(local_batch) - pred, _ = predictive_model(local_batch) + loss = criterion(pred, torch_labels) + if args.train_sparse: + loss += args.recon_scale * frozen_sparse.loss(local_batch, activations) + epoch_loss += loss.item() * local_batch.size(0) - if y_true is None: - y_true = torch_labels.detach().cpu().flatten().to(torch.long) - y_pred = torch.nn.Sigmoid()(pred).round().detach().cpu().flatten().to(torch.long) - else: - y_true = torch.cat((y_true, torch_labels.detach().cpu().flatten().to(torch.long))) - y_pred = torch.cat((y_pred, torch.nn.Sigmoid()(pred).round().detach().cpu().flatten().to(torch.long))) + prediction_optimizer.zero_grad() + loss.backward() + prediction_optimizer.step() - t2 = time.perf_counter() + t2 = time.perf_counter() - f1 = f1_score(y_true, y_pred, average='macro') - accuracy = accuracy_score(y_true, y_pred) + predictive_model.eval() + with torch.no_grad(): + y_true = None + y_pred = None + for labels, local_batch, vid_f in train_loader: - print('fold={}, epoch={}, time={:.2f}, loss={:.2f}, f1={:.2f}, acc={:.2f}'.format(i_fold, epoch, t2-t1, epoch_loss, f1, accuracy)) + local_batch = local_batch.to(device).squeeze(2) - if epoch_loss <= best_so_far: - print("found better model") - # Save model parameters - torch.save({ - 'model_state_dict': 
predictive_model.state_dict(), - 'optimizer_state_dict': prediction_optimizer.state_dict(), - }, os.path.join(output_dir, "model-best_fold_" + str(i_fold) + ".pt")) - best_so_far = epoch_loss + torch_labels = torch.zeros(len(labels)) + torch_labels[[i for i in range(len(labels)) if labels[i] == 'No_Sliding']] = 1 + torch_labels = torch_labels.unsqueeze(1).to(device) - checkpoint = torch.load(os.path.join(output_dir, "model-best_fold_" + str(i_fold) + ".pt")) - predictive_model.load_state_dict(checkpoint['model_state_dict']) + + pred, _ = predictive_model(local_batch) + + if y_true is None: + y_true = torch_labels.detach().cpu().flatten().to(torch.long) + y_pred = torch.nn.Sigmoid()(pred).round().detach().cpu().flatten().to(torch.long) + else: + y_true = torch.cat((y_true, torch_labels.detach().cpu().flatten().to(torch.long))) + y_pred = torch.cat((y_pred, torch.nn.Sigmoid()(pred).round().detach().cpu().flatten().to(torch.long))) + + t2 = time.perf_counter() + + f1 = f1_score(y_true, y_pred, average='macro') + accuracy = accuracy_score(y_true, y_pred) + + print('fold={}, epoch={}, time={:.2f}, loss={:.2f}, f1={:.2f}, acc={:.2f}'.format(i_fold, epoch, t2-t1, epoch_loss, f1, accuracy)) +# print(epoch_loss) + if epoch_loss <= best_so_far: + print("found better model") + # Save model parameters + torch.save({ + 'model_state_dict': predictive_model.state_dict(), + 'optimizer_state_dict': prediction_optimizer.state_dict(), + }, os.path.join(output_dir, "model-best_fold_" + str(i_fold) + ".pt")) + best_so_far = epoch_loss + +# checkpoint = torch.load(os.path.join(output_dir, "model-best_fold_" + str(i_fold) + ".pt")) +# predictive_model.load_state_dict(checkpoint['model_state_dict']) predictive_model.eval() with torch.no_grad(): @@ -213,13 +215,13 @@ if __name__ == "__main__": y_true = None y_pred = None - + pred_dict = {} gt_dict = {} t1 = time.perf_counter() - for labels, local_batch, vid_f in test_loader: - local_batch = local_batch.to(device) + for labels, local_batch, vid_f in train_loader: + local_batch = local_batch.to(device).squeeze(2) torch_labels = torch.zeros(len(labels)) torch_labels[[i for i in range(len(labels)) if labels[i] == 'No_Sliding']] = 1 @@ -229,13 +231,13 @@ if __name__ == "__main__": loss = criterion(pred, torch_labels) epoch_loss += loss.item() * local_batch.size(0) - + for i, v_f in enumerate(vid_f): if v_f not in pred_dict: pred_dict[v_f] = torch.nn.Sigmoid()(pred[i]).round().detach().cpu().flatten().to(torch.long) else: pred_dict[v_f] = torch.cat((pred_dict[v_f], torch.nn.Sigmoid()(pred[i]).detach().round().cpu().flatten().to(torch.long))) - + if v_f not in gt_dict: gt_dict[v_f] = torch_labels[i].detach().cpu().flatten().to(torch.long) else: @@ -249,11 +251,17 @@ if __name__ == "__main__": y_pred = torch.cat((y_pred, torch.nn.Sigmoid()(pred).detach().round().cpu().flatten().to(torch.long))) t2 = time.perf_counter() - + vid_acc = [] for k in pred_dict.keys(): gt_mode = torch.mode(gt_dict[k])[0].item() - pred_mode = torch.mode(pred_dict[k])[0].item() + perm = torch.randperm(pred_dict[k].size(0)) + cutoff = int(pred_dict[k].size(0)/4) + if cutoff < 3: + cutoff = 3 + idx = perm[:cutoff] + samples = pred_dict[k][idx] + pred_mode = torch.mode(samples)[0].item() overall_true.append(gt_mode) overall_pred.append(pred_mode) if pred_mode == gt_mode: @@ -264,9 +272,9 @@ if __name__ == "__main__": fn_ids.append(k) else: fp_ids.append(k) - + vid_acc = np.array(vid_acc) - + print('----------------------------------------------------------------------------') for k in 
pred_dict.keys(): print(k) @@ -275,31 +283,31 @@ if __name__ == "__main__": print('Ground Truth:') print(gt_dict[k]) print('Overall Prediction:') -# pred_mode = 1 -# contiguous_zeros = 0 -# best_num = 0 -# for val in pred_dict[k]: -# if val.item() == 0: -# contiguous_zeros += 1 -# else: -# if contiguous_zeros > best_num: -# best_num = contiguous_zeros -# contiguous_zeros = 0 -# if best_num >= 4 or contiguous_zeros >= 4: -# pred_mode = 0 + # pred_mode = 1 + # contiguous_zeros = 0 + # best_num = 0 + # for val in pred_dict[k]: + # if val.item() == 0: + # contiguous_zeros += 1 + # else: + # if contiguous_zeros > best_num: + # best_num = contiguous_zeros + # contiguous_zeros = 0 + # if best_num >= 4 or contiguous_zeros >= 4: + # pred_mode = 0 print(torch.mode(pred_dict[k])[0].item()) print('----------------------------------------------------------------------------') print('fold={}, loss={:.2f}, time={:.2f}'.format(i_fold, loss, t2-t1)) - + f1 = f1_score(y_true, y_pred, average='macro') accuracy = accuracy_score(y_true, y_pred) all_errors.append(np.sum(vid_acc) / len(vid_acc)) print("Test f1={:.2f}, clip_acc={:.2f}, vid_acc={:.2f} fold={}".format(f1, accuracy, np.sum(vid_acc) / len(vid_acc), i_fold)) - + print(confusion_matrix(y_true, y_pred)) - + i_fold = i_fold + 1 fp_fn_file = os.path.join(args.output_dir, 'fp_fn.txt') diff --git a/sparse_coding_torch/load_data.py b/sparse_coding_torch/load_data.py index 2b3c91d92f0fe878b693dd29ea5436826ac3912d..dc808f636c6862a1be1c7260393715db003b39e7 100644 --- a/sparse_coding_torch/load_data.py +++ b/sparse_coding_torch/load_data.py @@ -4,10 +4,10 @@ import torch from sklearn.model_selection import train_test_split from sparse_coding_torch.video_loader import MinMaxScaler from sparse_coding_torch.video_loader import VideoLoader -from sparse_coding_torch.video_loader import VideoClipLoader, YoloClipLoader, get_video_participants, YoloVideoLoader, MobileLoader +from sparse_coding_torch.video_loader import VideoClipLoader, YoloClipLoader, get_video_participants, YoloVideoLoader, MobileLoader, PNBLoader from sparse_coding_torch.video_loader import VideoGrayScaler import csv -from sklearn.model_selection import train_test_split, GroupShuffleSplit, LeaveOneGroupOut, LeaveOneOut, StratifiedGroupKFold, StratifiedKFold +from sklearn.model_selection import train_test_split, GroupShuffleSplit, LeaveOneGroupOut, LeaveOneOut, StratifiedGroupKFold, StratifiedKFold, KFold def load_balls_data(batch_size): @@ -107,8 +107,8 @@ def load_yolo_clips(batch_size, mode, num_clips=1, num_positives=100, device=Non ]) augment_transforms = torchvision.transforms.Compose( [torchvision.transforms.RandomRotation(45), - torchvision.transforms.RandomHorizontalFlip(), - torchvision.transforms.CenterCrop((100, 200)) + torchvision.transforms.RandomHorizontalFlip() +# torchvision.transforms.CenterCrop((100, 200)) ]) if whole_video: dataset = YoloVideoLoader(video_path, num_clips=num_clips, num_positives=num_positives, transform=transforms, augment_transform=augment_transforms, sparse_model=sparse_model, device=device) @@ -183,5 +183,50 @@ def load_mobile_clips(batch_size, mode, num_clips=1, num_positives=100, n_splits groups = [video_to_participant[v.lower().replace('_clean', '')] for v in dataset.get_filenames()] return gss.split(np.arange(len(targets)), targets, groups), dataset + else: + return None + +def load_pnb_videos(batch_size, mode, device=None, n_splits=None, sparse_model=None): + video_path = "/shared_data/bamc_pnb_data/full_training_data" + + transforms = 
torchvision.transforms.Compose( + [VideoGrayScaler(), + MinMaxScaler(0, 255), + torchvision.transforms.Resize((360, 304)) + ]) + augment_transforms = torchvision.transforms.Compose( + [torchvision.transforms.RandomAffine(45), + torchvision.transforms.RandomHorizontalFlip(), + torchvision.transforms.ColorJitter(brightness=0.5), + torchvision.transforms.RandomAdjustSharpness(0, p=0.15), + torchvision.transforms.RandomAffine(degrees=0, translate=(0.05, 0)) +# torchvision.transforms.CenterCrop((100, 200)) + ]) + dataset = PNBLoader(video_path, num_frames=5, frame_rate=20, transform=transforms) + + targets = dataset.get_labels() + + if mode == 'leave_one_out': + gss = LeaveOneGroupOut() + + groups = [v for v in dataset.get_filenames()] +# groups = [video_to_participant[v.lower().replace('_clean', '')] for v in dataset.get_filenames()] + + return gss.split(np.arange(len(targets)), targets, groups), dataset + elif mode == 'all_train': + train_idx = np.arange(len(targets)) + train_sampler = torch.utils.data.SubsetRandomSampler(train_idx) + train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, + sampler=train_sampler) + test_loader = None + + return train_loader, test_loader + elif mode == 'k_fold': + gss = StratifiedKFold(n_splits=n_splits, shuffle=True) + +# groups = [video_to_participant[v.lower().replace('_clean', '')] for v in dataset.get_filenames()] + groups = [v for v in dataset.get_filenames()] + + return gss.split(np.arange(len(targets)), targets), dataset else: return None \ No newline at end of file diff --git a/sparse_coding_torch/small_data_classifier.py b/sparse_coding_torch/small_data_classifier.py index 330596a614c7ac10bb6e8d44507010af2ad6960f..9f3ddc4d40effe665978080322e4834c5dd2a006 100644 --- a/sparse_coding_torch/small_data_classifier.py +++ b/sparse_coding_torch/small_data_classifier.py @@ -36,9 +36,9 @@ class SmallDataClassifierConv3d(nn.Module): def __init__(self): super().__init__() - self.max_pool_1 = nn.MaxPool3d(kernel_size=(1, 4, 4)) + self.max_pool_1 = nn.MaxPool2d(kernel_size=(4, 4)) - self.compress_activations_conv_1 = nn.Conv3d(in_channels=64, out_channels=24, kernel_size=(1, 8, 8), stride=(1, 4, 4), padding=(0, 4, 4)) + self.compress_activations_conv_1 = nn.Conv2d(in_channels=64, out_channels=24, kernel_size=(8, 8), stride=(4, 4)) # self.compress_activations_conv_2 = nn.Conv3d(in_channels=32, out_channels=16, kernel_size=(1, 8, 8), stride=(1, 4, 4), padding=(1, 4, 4)) # self.gru = nn.GRU(37, 100) @@ -46,16 +46,16 @@ class SmallDataClassifierConv3d(nn.Module): self.dropout = torch.nn.Dropout(p=0.5) # First fully connected layer - self.fc1 = nn.Linear(2184, 1000) - self.fc2 = nn.Linear(1000, 100) - self.fc3 = nn.Linear(100, 20) +# self.fc1 = nn.Linear(672, 1000) +# self.fc2 = nn.Linear(240, 100) + self.fc3 = nn.Linear(96, 20) self.fc4 = nn.Linear(20, 1) # x represents our data def forward(self, activations): - batch_size, channel_size, time_size, height_size, width_size = activations.size() + batch_size, channel_size, height_size, width_size = activations.size() - activations = activations.view(-1, channel_size, time_size, height_size, width_size) + activations = activations.view(-1, channel_size, height_size, width_size) x = self.max_pool_1(activations) @@ -72,12 +72,14 @@ class SmallDataClassifierConv3d(nn.Module): # x = x.to('cuda:' + str(save_device)) + x = x.swapaxes(1, 3) + x = torch.flatten(x, 1) - x = F.relu(self.fc1(x)) - x = self.dropout(x) - x = F.relu(self.fc2(x)) - x = self.dropout(x) +# x = F.relu(self.fc1(x)) +# x = self.dropout(x) 
+# x = F.relu(self.fc2(x)) +# x = self.dropout(x) x = F.relu(self.fc3(x)) x = self.dropout(x) x = self.fc4(x) diff --git a/sparse_coding_torch/video_loader.py b/sparse_coding_torch/video_loader.py index a318a1c4da9d87b685cf419f3d161ad3a0284208..4e7188cd1d5aacdbae02be36acff80081f16d515 100644 --- a/sparse_coding_torch/video_loader.py +++ b/sparse_coding_torch/video_loader.py @@ -5,6 +5,7 @@ from os.path import isdir from os.path import abspath from os.path import exists import json +import glob from PIL import Image from torchvision.transforms import ToTensor @@ -136,6 +137,8 @@ class VideoClipLoader(Dataset): if not frames_between_clips: frames_between_clips = num_frames + + vc = VideoClips([path for _, path, _ in self.videos], clip_length_in_frames=num_frames, frame_rate=frame_rate, @@ -170,6 +173,59 @@ class VideoClipLoader(Dataset): def __len__(self): return len(self.clips) +class PNBLoader(Dataset): + + def __init__(self, video_path, num_frames=5, frame_rate=20, frames_between_clips=None, transform=None): + self.transform = transform + self.labels = [name for name in listdir(video_path) if isdir(join(video_path, name))] + + self.videos = [] + for label in self.labels: + self.videos.extend([(label, abspath(join(video_path, label, f)), f) for f in glob.glob(join(video_path, label, '*', '*.mp4'))]) + + #for v in self.videos: + # video, _, info = read_video(v[1]) + # print(video.shape) + # print(info) + + if not frames_between_clips: + frames_between_clips = num_frames + + self.clips = [] + + self.video_idx = [] + + vid_idx = 0 + for _, path, _ in self.videos: + vc = tv.io.read_video(path)[0].permute(3, 0, 1, 2) +# for j in range(vc.size(1), vc.size(1) - 10, -5): + for j in range(0, vc.size(1) - 5, 5): +# if j-5 < 0: +# continue +# vc_sub = vc_1 = vc[:, j-5:j, :, :] + vc_sub = vc[:, j:j+5, :, :] + if self.transform: + vc_sub = self.transform(vc_sub) + + self.clips.append((self.videos[vid_idx][0], vc_sub, self.videos[vid_idx][2])) + self.video_idx.append(vid_idx) + vid_idx += 1 + + def get_filenames(self): + return [self.clips[i][2] for i in range(len(self.clips))] + + def get_video_labels(self): + return [self.videos[i][0] for i in range(len(self.videos))] + + def get_labels(self): + return [self.clips[i][0] for i in range(len(self.clips))] + + def __getitem__(self, index): + return self.clips[index] + + def __len__(self): + return len(self.clips) + class VideoFrameLoader(Dataset): def __init__(self, video_path, transform=None): @@ -261,8 +317,8 @@ class YoloClipLoader(Dataset): # width = region['relative_coordinates']['width'] * 1920 # height = region['relative_coordinates']['height'] * 1080 - width=400 - height=400 + width=200 + height=100 lower_y = round(center_y - height / 2) upper_y = round(center_y + height / 2)
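
For reference, the decoupled training flow introduced in keras/train_sparse_model.py condenses to the sketch below: the SparseCode graph infers activations with the dictionary held fixed, and only the ReconSparse filters receive gradients before being re-normalized to unit norm. This is an illustrative sketch, not part of the patch: a tiny random batch stands in for load_pnb_videos, the sparse_loss defined here is a simplified stand-in for the loss used by the real script, max_activation_iter is lowered to keep the example quick, and the remaining values follow the script's defaults (360x304 frames, 5-frame clips, 15x15 kernels, stride 2).

    # Illustrative sketch of one dictionary-update step (assumes keras_model.py from this patch is importable).
    import numpy as np
    import tensorflow as tf
    import tensorflow.keras as keras
    from keras_model import SparseCode, ReconSparse, normalize_weights_3d

    batch_size, image_height, image_width = 1, 360, 304
    kernel_size, num_kernels, stride, lam = 15, 64, 2, 0.05
    act_h = (image_height - kernel_size) // stride + 1   # 173
    act_w = (image_width - kernel_size) // stride + 1    # 145

    # Sparse-coding model: maps (images, filters) to activations; filters arrive as an input tensor.
    inputs = keras.Input(shape=(5, image_height, image_width, 1))
    filter_inputs = keras.Input(shape=(5, kernel_size, kernel_size, 1, num_kernels), dtype='float32')
    activ_out = SparseCode(batch_size=batch_size, image_height=image_height, image_width=image_width,
                           in_channels=1, out_channels=num_kernels, kernel_size=kernel_size, stride=stride,
                           lam=lam, activation_lr=1e-2, max_activation_iter=10,  # script default is 100
                           run_2d=False)(inputs, filter_inputs)
    sparse_model = keras.Model(inputs=(inputs, filter_inputs), outputs=activ_out)

    # Reconstruction model: owns the trainable dictionary and decodes activations back to clips.
    recon_inputs = keras.Input(shape=(1, act_h, act_w, num_kernels))
    recon_out = ReconSparse(batch_size=batch_size, image_height=image_height, image_width=image_width,
                            in_channels=1, out_channels=num_kernels, kernel_size=kernel_size, stride=stride,
                            lam=lam, activation_lr=1e-2, max_activation_iter=10,
                            run_2d=False)(recon_inputs)
    recon_model = keras.Model(inputs=recon_inputs, outputs=recon_out)

    filter_optimizer = tf.keras.optimizers.Adam(learning_rate=5e-2)

    def sparse_loss(recon, images, activations, lam):
        # Stand-in loss: reconstruction error plus an L1 penalty on the activations.
        return 0.5 * tf.reduce_sum(tf.square(recon - images)) + lam * tf.reduce_sum(tf.abs(activations))

    images = np.random.rand(batch_size, 5, image_height, image_width, 1).astype('float32')

    # Infer activations with the current dictionary frozen (stop_gradient), as in the training loop.
    activations = tf.stop_gradient(
        sparse_model([images, tf.expand_dims(recon_model.trainable_weights[0], axis=0)]))

    with tf.GradientTape() as tape:
        recon = recon_model(activations)
        loss = sparse_loss(recon, images, activations, lam)

    gradients = tape.gradient(loss, recon_model.trainable_weights)
    filter_optimizer.apply_gradients(zip(gradients, recon_model.trainable_weights))

    # Keep each dictionary atom at unit norm after every step, as the script does.
    recon_model.set_weights(normalize_weights_3d(recon_model.get_weights(), num_kernels))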
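
Similarly, the new PNBLoader in sparse_coding_torch/video_loader.py can be exercised on its own: it walks <video_path>/<label>/*/*.mp4, slices each video into non-overlapping 5-frame clips, and yields (label, clip_tensor, filename) tuples. In the hedged sketch below, the video_path is a placeholder (the training script points at /shared_data/bamc_pnb_data/full_training_data), and the final clip shape depends on how VideoGrayScaler handles channels.

    # Illustrative sketch of loading PNB clips (path is a placeholder).
    import torchvision
    from sparse_coding_torch.video_loader import PNBLoader, VideoGrayScaler, MinMaxScaler

    transforms = torchvision.transforms.Compose([
        VideoGrayScaler(),                          # grayscale the clip, as in load_pnb_videos
        MinMaxScaler(0, 255),                       # rescale pixel intensities, as in load_pnb_videos
        torchvision.transforms.Resize((360, 304)),  # match the image_height x image_width used for training
    ])

    dataset = PNBLoader('/path/to/pnb_videos', num_frames=5, frame_rate=20, transform=transforms)
    label, clip, filename = dataset[0]
    print(len(dataset), label, filename, clip.shape)  # clip is expected to be (channels, 5, 360, 304)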