diff --git a/keras/keras_model.py b/keras/keras_model.py
index 5add25e02b950adb31f6ae7e8381f49d354fb39a..530282d40f25a3b0486fb7d685013349ae0cb296 100644
--- a/keras/keras_model.py
+++ b/keras/keras_model.py
@@ -15,48 +15,42 @@ def load_pytorch_weights(file_path):
     
     return weight_tensor
 
-@tf.function
-def do_recon(filters_1, filters_2, filters_3, filters_4, filters_5, activations, batch_size, stride):
-    out_1 = tf.nn.conv2d_transpose(activations, filters_1, output_shape=(batch_size, 100, 200, 1), strides=stride)
-    out_2 = tf.nn.conv2d_transpose(activations, filters_2, output_shape=(batch_size, 100, 200, 1), strides=stride)
-    out_3 = tf.nn.conv2d_transpose(activations, filters_3, output_shape=(batch_size, 100, 200, 1), strides=stride)
-    out_4 = tf.nn.conv2d_transpose(activations, filters_4, output_shape=(batch_size, 100, 200, 1), strides=stride)
-    out_5 = tf.nn.conv2d_transpose(activations, filters_5, output_shape=(batch_size, 100, 200, 1), strides=stride)
+# @tf.function
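+# Reconstruct the 5-frame input: each filter bank deconvolves the shared activations into one single-channel frame, and the frames are stacked along the channel axis.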
+def do_recon(filters_1, filters_2, filters_3, filters_4, filters_5, activations, batch_size, image_height, image_width, stride):
+    out_1 = tf.nn.conv2d_transpose(activations, filters_1, output_shape=(batch_size, image_height, image_width, 1), strides=stride, padding='VALID')
+    out_2 = tf.nn.conv2d_transpose(activations, filters_2, output_shape=(batch_size, image_height, image_width, 1), strides=stride, padding='VALID')
+    out_3 = tf.nn.conv2d_transpose(activations, filters_3, output_shape=(batch_size, image_height, image_width, 1), strides=stride, padding='VALID')
+    out_4 = tf.nn.conv2d_transpose(activations, filters_4, output_shape=(batch_size, image_height, image_width, 1), strides=stride, padding='VALID')
+    out_5 = tf.nn.conv2d_transpose(activations, filters_5, output_shape=(batch_size, image_height, image_width, 1), strides=stride, padding='VALID')
 
     recon = tf.concat([out_1, out_2, out_3, out_4, out_5], axis=3)
 
     return recon
 
 # @tf.function
-def do_recon_3d(filters, activations, batch_size, stride):
-    activations = tf.pad(activations, paddings=[[0,0], [2, 2], [0, 0], [0, 0], [0, 0]])
-    recon = tf.nn.conv3d_transpose(activations, filters, output_shape=(batch_size, 5, 100, 200, 1), strides=[1, stride, stride])
+def do_recon_3d(filters, activations, batch_size, image_height, image_width, stride):
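+    # 3D transposed convolution maps the activations back to a (batch, 5, image_height, image_width, 1) clip (depth stride 1, spatial stride = stride).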
+#     activations = tf.pad(activations, paddings=[[0,0], [2, 2], [0, 0], [0, 0], [0, 0]])
+    recon = tf.nn.conv3d_transpose(activations, filters, output_shape=(batch_size, 5, image_height, image_width, 1), strides=[1, stride, stride], padding='VALID')
 
     return recon
 
-@tf.function
-def conv_error(filters_1, filters_2, filters_3, filters_4, filters_5, e, stride):
-    e1, e2, e3, e4, e5 = tf.split(e, 5, axis=3)
-
-    g = tf.nn.conv2d(e1, filters_1, strides=stride, padding='SAME')
-    g = g + tf.nn.conv2d(e2, filters_2, strides=stride, padding='SAME')
-    g = g + tf.nn.conv2d(e3, filters_3, strides=stride, padding='SAME')
-    g = g + tf.nn.conv2d(e4, filters_4, strides=stride, padding='SAME')
-    g = g + tf.nn.conv2d(e5, filters_5, strides=stride, padding='SAME')
+# @tf.function
+def conv_error(filters, e, stride):
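+    # Correlate one frame's reconstruction error with its filter bank (VALID padding) to get that frame's contribution to the activation gradient.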
+    g = tf.nn.conv2d(e, filters, strides=stride, padding='VALID')
 
     return g
 
 # @tf.function
 def conv_error_3d(filters, e, stride):
-    e = tf.pad(e, paddings=[[0,0], [0, 0], [7, 7], [7, 7], [0, 0]])
+#     e = tf.pad(e, paddings=[[0,0], [0, 0], [7, 7], [7, 7], [0, 0]])
     g = tf.nn.conv3d(e, filters, strides=[1, 1, stride, stride, 1], padding='VALID')
 
     return g
 
-@tf.function
+# @tf.function
 def normalize_weights(filters, out_channels):
     #print('filters shape', tf.shape(filters))
-    norms = tf.norm(tf.reshape(tf.stack(filters), (out_channels, -1)), axis=1)
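+    # L2 norm per output channel: stack the per-frame filters, bring out_channels to the front, flatten, and take row norms.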
+    norms = tf.norm(tf.reshape(tf.transpose(tf.stack(filters), perm=[4, 0, 1, 2, 3]), (out_channels, -1)), axis=1)
     norms = tf.broadcast_to(tf.math.maximum(norms, 1e-12*tf.ones_like(norms)), filters[0].shape)
     
     adjusted = [f / norms for f in filters]
@@ -83,7 +77,7 @@ def normalize_weights_3d(filters, out_channels):
     return adjusted
 
 class SparseCode(keras.layers.Layer):
-    def __init__(self, batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d):
+    def __init__(self, batch_size, image_height, image_width, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d):
         super(SparseCode, self).__init__()
 
         self.out_channels = out_channels
@@ -93,63 +87,78 @@ class SparseCode(keras.layers.Layer):
         self.activation_lr = activation_lr
         self.max_activation_iter = max_activation_iter
         self.batch_size = batch_size
+        self.image_height = image_height
+        self.image_width = image_width
+        self.kernel_size = kernel_size
         self.run_2d = run_2d
 
-    @tf.function
+#     @tf.function
     def do_update(self, images, filters, u, m, v, b1, b2, eps, i):
         activations = tf.nn.relu(u - self.lam)
 
         if self.run_2d:
-            recon = do_recon(filters[0], filters[1], filters[2], filters[3], filters[4], activations, self.batch_size, self.stride)
+            recon = do_recon(filters[0], filters[1], filters[2], filters[3], filters[4], activations, self.batch_size, self.image_height, self.image_width, self.stride)
         else:
-            recon = do_recon_3d(filters, activations, self.batch_size, self.stride)
+            recon = do_recon_3d(filters, activations, self.batch_size, self.image_height, self.image_width, self.stride)
 
         e = images - recon
         g = -1 * u
         
         if self.run_2d:
-            convd_error = conv_error(filters[0], filters[1], filters[2], filters[3], filters[4], e, self.stride)
+            e1, e2, e3, e4, e5 = tf.split(e, 5, axis=3)
+            g += conv_error(filters[0], e1, self.stride)
+            g += conv_error(filters[1], e2, self.stride)
+            g += conv_error(filters[2], e3, self.stride)
+            g += conv_error(filters[3], e4, self.stride)
+            g += conv_error(filters[4], e5, self.stride)
         else:
             convd_error = conv_error_3d(filters, e, self.stride)
             
-        g = g + convd_error
+            g = g + convd_error
 
         g = g + activations
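+        # Adam-style update of the activations: first/second moment estimates with bias correction, then the step du is applied to u.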
         
         m = b1 * m + (1-b1) * g
-        v = b2 * v + (1-b2) * g**2
-        mh = m / (1 - b1**(1+i))
-        vh = v / (1 - b2**(1+i))
+        
+        v = b2 * v + (1-b2) * tf.math.pow(g, 2)
+        
+        mh = m / (1 - tf.math.pow(b1, (1+i)))
+        
+        vh = v / (1 - tf.math.pow(b2, (1+i)))
         
         du = self.activation_lr * mh / (tf.math.sqrt(vh) + eps)
+        
         u += du
         
 #         i += 1
 
-#         return images, u, m, v, b1, b2, eps, i
+#         return images, filters, u, m, v, b1, b2, eps, i
         return u, m, v
 
-    @tf.function
+#     @tf.function
     def call(self, images, filters):
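+        # Filters arrive with a dummy leading batch dimension (added so they can be fed as a Keras input); drop it.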
+        filters = tf.squeeze(filters, axis=0)
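+        # VALID-padding output size per spatial dim: (dim - kernel_size) // stride + 1.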
         if self.run_2d:
-            u = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
-            m = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
-            v = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
+            output_shape = (self.batch_size, (self.image_height - self.kernel_size) // self.stride + 1, (self.image_width - self.kernel_size) // self.stride + 1, self.out_channels)
         else:
-            u = tf.zeros(shape=(self.batch_size, 1, 100 // self.stride, 200 // self.stride, self.out_channels))
-            m = tf.zeros(shape=(self.batch_size, 1, 100 // self.stride, 200 // self.stride, self.out_channels))
-            v = tf.zeros(shape=(self.batch_size, 1, 100 // self.stride, 200 // self.stride, self.out_channels))
+            output_shape = (self.batch_size, 1, (self.image_height - self.kernel_size) // self.stride + 1, (self.image_width - self.kernel_size) // self.stride + 1, self.out_channels)
+
+        u = tf.zeros(shape=output_shape)
+        m = tf.zeros(shape=output_shape)
+        v = tf.zeros(shape=output_shape)
         
 #         tf.print('activations before:', tf.reduce_sum(u))
 
         b1 = tf.constant(0.9, dtype='float32')
-        b2 = tf.constant(0.999, dtype='float32')
+        b2 = tf.constant(0.99, dtype='float32')
         eps = tf.constant(1e-8, dtype='float32')
         
+#         print(u)
+        
         
 #         i = tf.constant(0, dtype='float32')
-#         c = lambda images, u, m, v, b1, b2, eps, i: tf.less(i, self.max_activation_iter)
-#         images, u, m, v, b1, b2, eps, i = tf.while_loop(c, self.do_update, [images, u, m, v, b1, b2, eps, i])
+#         c = lambda images, filters, u, m, v, b1, b2, eps, i: tf.less(i, self.max_activation_iter)
+#         images, filters, u, m, v, b1, b2, eps, i = tf.while_loop(c, self.do_update, [images, filters, u, m, v, b1, b2, eps, i])
         for i in range(self.max_activation_iter):
             u, m, v = self.do_update(images, filters, u, m, v, b1, b2, eps, i)
 
@@ -159,10 +168,9 @@ class SparseCode(keras.layers.Layer):
 
         return u
     
-class SparseCodeConv(keras.Model):
-    def __init__(self, batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d):
+class ReconSparse(keras.Model):
+    def __init__(self, batch_size, image_height, image_width, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d):
         super().__init__()
-        self.sparse_code = SparseCode(batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d)
         
         self.out_channels = out_channels
         self.in_channels = in_channels
@@ -171,6 +179,8 @@ class SparseCodeConv(keras.Model):
         self.activation_lr = activation_lr
         self.max_activation_iter = max_activation_iter
         self.batch_size = batch_size
+        self.image_height = image_height
+        self.image_width = image_width
         self.run_2d = run_2d
         
         initializer = tf.keras.initializers.HeNormal()
@@ -181,7 +191,7 @@ class SparseCodeConv(keras.Model):
             self.filters_4 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
             self.filters_5 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
         else:
-            self.filters = tf.Variable(initial_value=initializer(shape=(5, kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
+            self.filters = tf.Variable(initial_value=initializer(shape=(5, kernel_size, kernel_size, in_channels, out_channels), dtype='float32'), trainable=True)
         
         if run_2d:
             weights = normalize_weights(self.get_weights(), out_channels)
@@ -189,16 +199,14 @@ class SparseCodeConv(keras.Model):
             weights = normalize_weights_3d(self.get_weights(), out_channels)
         self.set_weights(weights)
 
-    @tf.function
-    def call(self, images):
+#     @tf.function
+    def call(self, activations):
         if self.run_2d:
-            activations = self.sparse_code(images, [tf.stop_gradient(self.filters_1), tf.stop_gradient(self.filters_2), tf.stop_gradient(self.filters_3), tf.stop_gradient(self.filters_4), tf.stop_gradient(self.filters_5)])
-            recon = do_recon(self.filters_1, self.filters_2, self.filters_3, self.filters_4, self.filters_5, activations, self.batch_size, self.stride)
+            recon = do_recon(self.filters_1, self.filters_2, self.filters_3, self.filters_4, self.filters_5, activations, self.batch_size, self.image_height, self.image_width, self.stride)
         else:
-            activations = self.sparse_code(images, tf.stop_gradient(self.filters))
-            recon = do_recon_3d(self.filters, activations, self.batch_size, self.stride)
+            recon = do_recon_3d(self.filters, activations, self.batch_size, self.image_height, self.image_width, self.stride)
             
-        return recon, activations
+        return recon
 
 class Classifier(keras.layers.Layer):
     def __init__(self):
@@ -218,6 +226,7 @@ class Classifier(keras.layers.Layer):
 
     @tf.function
     def call(self, activations):
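+        # Activations come in as (batch, 1, H', W', C); drop the singleton time axis before pooling.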
+        activations = tf.squeeze(activations, axis=1)
         x = self.max_pool(activations)
         x = self.conv(x)
         x = self.flatten(x)
diff --git a/keras/train_sparse_model.py b/keras/train_sparse_model.py
index d543e3af8d470e2db1b090f0718a829e40fcd18c..7a074267b9bbb394006887f0a747e4565265873a 100644
--- a/keras/train_sparse_model.py
+++ b/keras/train_sparse_model.py
@@ -7,10 +7,11 @@ from matplotlib.animation import FuncAnimation
 from tqdm import tqdm
 import argparse
 import os
-from sparse_coding_torch.load_data import load_yolo_clips
+from sparse_coding_torch.load_data import load_yolo_clips, load_pnb_videos
 import tensorflow.keras as keras
 import tensorflow as tf
-from keras_model import normalize_weights_3d, normalize_weights, SparseCodeConv, load_pytorch_weights
+from keras_model import normalize_weights_3d, normalize_weights, SparseCode, load_pytorch_weights, ReconSparse
+import random
+import torch  # used below for torch.manual_seed; harmless if torch is already imported earlier in the file
 
 def plot_video(video):
 
@@ -55,6 +56,7 @@ def plot_original_vs_recon(original, reconstruction, idx=0):
 
 
 def plot_filters(filters):
+    filters = np.array(filters).astype('float32')  # accept either numpy weight arrays or a stacked TF tensor
     num_filters = filters.shape[4]
     ncol = 3
     # ncol = int(np.sqrt(num_filters))
@@ -97,17 +99,30 @@ if __name__ == "__main__":
     parser.add_argument('--kernel_size', default=15, type=int)
     parser.add_argument('--num_kernels', default=64, type=int)
     parser.add_argument('--stride', default=2, type=int)
-    parser.add_argument('--max_activation_iter', default=50, type=int)
+    parser.add_argument('--max_activation_iter', default=100, type=int)
     parser.add_argument('--activation_lr', default=1e-2, type=float)
     parser.add_argument('--lr', default=5e-2, type=float)
-    parser.add_argument('--epochs', default=100, type=int)
+    parser.add_argument('--epochs', default=20, type=int)
     parser.add_argument('--lam', default=0.05, type=float)
     parser.add_argument('--output_dir', default='./output', type=str)
     parser.add_argument('--seed', default=42, type=int)
     parser.add_argument('--run_2d', action='store_true')
     parser.add_argument('--save_filters', action='store_true')
+    parser.add_argument('--optimizer', default='adam', type=str)
+    parser.add_argument('--dataset', default='pnb', type=str)
+    
 
     args = parser.parse_args()
+    
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    
+#     policy = keras.mixed_precision.Policy('mixed_float16')
+#     keras.mixed_precision.set_global_policy(policy)
+
+    image_height = 360
+    image_width = 304
 
     output_dir = args.output_dir
     if not os.path.exists(output_dir):
@@ -118,36 +133,55 @@ if __name__ == "__main__":
     with open(os.path.join(output_dir, 'arguments.txt'), 'w+') as out_f:
         out_f.write(str(args))
 
-    train_loader, _ = load_yolo_clips(args.batch_size, num_clips=1, num_positives=15, mode='all_train', device=device, n_splits=1, sparse_model=None, whole_video=False, positive_videos='../positive_videos.json')
+    if args.dataset == 'pnb':
+        train_loader, _ = load_pnb_videos(args.batch_size, mode='all_train', device=device, n_splits=1, sparse_model=None)
+    elif args.dataset == 'ptx':
+        train_loader, _ = load_yolo_clips(args.batch_size, num_clips=1, num_positives=15, mode='all_train', device=device, n_splits=1, sparse_model=None, whole_video=False, positive_videos='../positive_videos.json')
+    else:
+        raise Exception('Invalid dataset')
     print('Loaded', len(train_loader), 'train examples')
 
     example_data = next(iter(train_loader))
 
     if args.run_2d:
-        inputs = keras.Input(shape=(100, 200, 5))
+        inputs = keras.Input(shape=(image_height, image_width, 5))
     else:
-        inputs = keras.Input(shape=(5, 100, 200, 1))
+        inputs = keras.Input(shape=(5, image_height, image_width, 1))
+        
+    filter_inputs = keras.Input(shape=(5, args.kernel_size, args.kernel_size, 1, args.num_kernels), dtype='float32')
 
-    output = SparseCodeConv(batch_size=args.batch_size, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(inputs)
+    output = SparseCode(batch_size=args.batch_size, image_height=image_height, image_width=image_width, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(inputs, filter_inputs)
 
-    model = keras.Model(inputs=inputs, outputs=output)
+    sparse_model = keras.Model(inputs=(inputs, filter_inputs), outputs=output)
+    
+    recon_inputs = keras.Input(shape=(1, (image_height - args.kernel_size) // args.stride + 1, (image_width - args.kernel_size) // args.stride + 1, args.num_kernels))
+    
+    recon_outputs = ReconSparse(batch_size=args.batch_size, image_height=image_height, image_width=image_width, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(recon_inputs)
+    
+    recon_model = keras.Model(inputs=recon_inputs, outputs=recon_outputs)
     
     if args.save_filters:
         if args.run_2d:
-            filters = plot_filters(tf.stack(model.get_weights(), axis=0))
+            filters = plot_filters(tf.stack(recon_model.get_weights(), axis=0))
         else:
-            filters = plot_filters(model.get_weights()[0])
+            filters = plot_filters(recon_model.get_weights()[0])
         filters.save(os.path.join(args.output_dir, 'filters_start.mp4'))
 
     learning_rate = args.lr
-    filter_optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
+    if args.optimizer == 'sgd':
+        filter_optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
+    else:
+        filter_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
 
     loss_log = []
     best_so_far = float('inf')
 
     for epoch in range(args.epochs):
         epoch_loss = 0
+        running_loss = 0.0
         epoch_start = time.perf_counter()
+        
+        num_iters = 0
 
         for labels, local_batch, vid_f in tqdm(train_loader):
             if local_batch.size(0) != args.batch_size:
@@ -156,37 +190,52 @@ if __name__ == "__main__":
                 images = local_batch.squeeze(1).permute(0, 2, 3, 1).numpy()
             else:
                 images = local_batch.permute(0, 2, 3, 4, 1).numpy()
+                
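+            # Infer sparse activations with the current filters held fixed (stop_gradient on both the filters and the activations),
+            # then update the filters below to minimize the reconstruction loss.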
+            activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))]))
             
             with tf.GradientTape() as tape:
-                recon, activations = model(images)
+                recon = recon_model(activations)
                 loss = sparse_loss(recon, activations, args.batch_size, args.lam, args.stride)
 
             epoch_loss += loss * local_batch.size(0)
+            running_loss += loss * local_batch.size(0)
 
-            gradients = tape.gradient(loss, model.trainable_weights)
+            gradients = tape.gradient(loss, recon_model.trainable_weights)
 
-            filter_optimizer.apply_gradients(zip(gradients, model.trainable_weights))
+            filter_optimizer.apply_gradients(zip(gradients, recon_model.trainable_weights))
             
             if args.run_2d:
-                weights = normalize_weights(model.get_weights(), args.num_kernels)
+                weights = normalize_weights(recon_model.get_weights(), args.num_kernels)
             else:
-                weights = normalize_weights_3d(model.get_weights(), args.num_kernels)
-            model.set_weights(weights)
+                weights = normalize_weights_3d(recon_model.get_weights(), args.num_kernels)
+            recon_model.set_weights(weights)
+            
+#             if args.save_filters and num_iters % 25 == 0:
+#                 if args.run_2d:
+#                     filters = plot_filters(tf.stack(recon_model.get_weights(), axis=0))
+#                 else:
+#                     filters = plot_filters(recon_model.get_weights()[0])
+#                 filters.save(os.path.join(args.output_dir, 'filters_' + str(epoch) + '_' + str(num_iters) + '.mp4'))
+#                 loss_log.append(running_loss)
+#                 print(running_loss)
+#                 running_loss = 0.0
+                
+            num_iters += 1
 
         epoch_end = time.perf_counter()
         epoch_loss /= len(train_loader.sampler)
         
-        if args.save_filters and epoch % 5 == 0:
+        if args.save_filters and epoch % 2 == 0:
             if args.run_2d:
-                filters = plot_filters(tf.stack(model.get_weights(), axis=0))
+                filters = plot_filters(tf.stack(recon_model.get_weights(), axis=0))
             else:
-                filters = plot_filters(model.get_weights()[0])
+                filters = plot_filters(recon_model.get_weights()[0])
             filters.save(os.path.join(args.output_dir, 'filters_' + str(epoch) +'.mp4'))
 
         if epoch_loss < best_so_far:
             print("found better model")
             # Save model parameters
-            model.save(os.path.join(output_dir, "sparse_conv3d_model-best.pt"))
+            recon_model.save(os.path.join(output_dir, "sparse_conv3d_model-best.pt"))
             best_so_far = epoch_loss
 
         loss_log.append(epoch_loss)
diff --git a/scripts/train_classifier.py b/scripts/train_classifier.py
index 38d43adb0a39a102e05c266616ef9f734da931b3..79b76808371b7883f5c1b6033bef54c40552a549 100644
--- a/scripts/train_classifier.py
+++ b/scripts/train_classifier.py
@@ -20,9 +20,9 @@ if __name__ == "__main__":
     parser.add_argument('--kernel_width', default=15, type=int)
     parser.add_argument('--kernel_depth', default=5, type=int)
     parser.add_argument('--num_kernels', default=64, type=int)
-    parser.add_argument('--stride', default=1, type=int)
-    parser.add_argument('--max_activation_iter', default=200, type=int)
-    parser.add_argument('--activation_lr', default=1e-2, type=float)
+    parser.add_argument('--stride', default=2, type=int)
+    parser.add_argument('--max_activation_iter', default=100, type=int)
+    parser.add_argument('--activation_lr', default=1e-1, type=float)
     parser.add_argument('--lr', default=5e-5, type=float)
     parser.add_argument('--epochs', default=40, type=int)
     parser.add_argument('--lam', default=0.05, type=float)
@@ -39,7 +39,7 @@ if __name__ == "__main__":
     parser.add_argument('--n_splits', default=5, type=int)
     parser.add_argument('--whole_video', action='store_true')
     parser.add_argument('--save_train_test_splits', action='store_true')
-    parser.add_argument('--positive_videos', default='positive_videos.json', type=str)
+    parser.add_argument('--positive_videos', default=None, type=str)
     
     args = parser.parse_args()
     
@@ -66,7 +66,7 @@ if __name__ == "__main__":
                                    out_channels=args.num_kernels,
                                    kernel_size=(args.kernel_depth, args.kernel_height, args.kernel_width),
                                    stride=args.stride,
-                                   padding=(0, 7, 7),
+                                   padding=0,
                                    convo_dim=3,
                                    rectifier=True,
                                    lam=args.lam,
@@ -82,7 +82,9 @@ if __name__ == "__main__":
 
     frozen_sparse.to(device)
     
-    splits, dataset = load_yolo_clips(batch_size, num_clips=args.num_clips, num_positives=args.num_positives, mode=args.splits, device=device, n_splits=args.n_splits, sparse_model=frozen_sparse, whole_video=args.whole_video, positive_videos=args.positive_videos)
+#     splits, dataset = load_yolo_clips(batch_size, num_clips=args.num_clips, num_positives=args.num_positives, mode=args.splits, device=device, n_splits=args.n_splits, sparse_model=frozen_sparse, whole_video=args.whole_video, positive_videos=args.positive_videos)
+    
+    train_loader, test_loader = load_yolo_clips(batch_size, num_clips=args.num_clips, num_positives=args.num_positives, mode='all_train', device=device, n_splits=args.n_splits, sparse_model=frozen_sparse, whole_video=args.whole_video, positive_videos=args.positive_videos)
     
     overall_true = []
     overall_pred = []
@@ -90,122 +92,122 @@ if __name__ == "__main__":
     fp_ids = []
     
     i_fold = 0
-    for train_idx, test_idx in [list(splits)[0]]:
-        
-        if args.save_train_test_splits:
-            with open(os.path.join(output_dir, 'train_idx_' + str(i_fold) + '.pkl'), 'wb+') as train_out:
-                pickle.dump(train_idx, train_out)
-                      
-            with open(os.path.join(output_dir, 'test_idx_' + str(i_fold) + '.pkl'), 'wb+') as test_out:
-                pickle.dump(test_idx, test_out) 
+#     for train_idx, test_idx in splits:
         
-        train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
-        test_sampler = torch.utils.data.SubsetRandomSampler(test_idx)
+#     if args.save_train_test_splits:
+#         with open(os.path.join(output_dir, 'train_idx_' + str(i_fold) + '.pkl'), 'wb+') as train_out:
+#             pickle.dump(train_idx, train_out)
 
-        train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
-                                                   # shuffle=True,
-                                                   sampler=train_sampler)
+#         with open(os.path.join(output_dir, 'test_idx_' + str(i_fold) + '.pkl'), 'wb+') as test_out:
+#             pickle.dump(test_idx, test_out) 
 
-        test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
-                                                        # shuffle=True,
-                                                        sampler=test_sampler)
-        
-        if args.save_train_test_splits:
-            with open(os.path.join(output_dir, 'train_idx_' + str(i_fold) + '.txt'), 'w+') as train_out:
-                train_videos = set()
-                for tmp in train_loader:
-                    train_videos.update(tmp[2])
-                train_out.write(str(train_videos))
-                      
-            with open(os.path.join(output_dir, 'test_idx_' + str(i_fold) + '.txt'), 'w+') as test_out:
-                test_videos = set()
-                for tmp in test_loader:
-                    test_videos.update(tmp[2])
-                test_out.write(str(test_videos))
-
-        best_so_far = float('inf')
-
-        if args.num_clips > 1 or args.whole_video:
-            predictive_model = torch.nn.DataParallel(SmallDataClassifierVideo(args.num_clips))
-        else:
-            predictive_model = torch.nn.DataParallel(SmallDataClassifierConv3d())
-        predictive_model.to(device)
-        
-        criterion = torch.nn.BCEWithLogitsLoss()
-        
-        if args.checkpoint:
-            checkpoint = torch.load(args.checkpoint)
-            predictive_model.load_state_dict(checkpoint['model_state_dict'])
-        
-        if args.train:
-            prediction_optimizer = torch.optim.Adam(predictive_model.parameters(),
-                                                    lr=args.lr)
+#     train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
+#     test_sampler = torch.utils.data.SubsetRandomSampler(test_idx)
 
-            for epoch in range(args.epochs):
-                predictive_model.train()
-                epoch_loss = 0
-                t1 = time.perf_counter()
+#     train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
+#                                                # shuffle=True,
+#                                                sampler=train_sampler)
 
-                for labels, local_batch, vid_f in tqdm(train_loader):
-                    local_batch = local_batch.to(device)
+#     test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
+#                                                     # shuffle=True,
+#                                                     sampler=test_sampler)
 
-                    torch_labels = torch.zeros(len(labels))
-                    torch_labels[[i for i in range(len(labels)) if labels[i] == 'No_Sliding']] = 1
-                    torch_labels = torch_labels.unsqueeze(1).to(device)
+#     if args.save_train_test_splits:
+#         with open(os.path.join(output_dir, 'train_idx_' + str(i_fold) + '.txt'), 'w+') as train_out:
+#             train_videos = set()
+#             for tmp in train_loader:
+#                 train_videos.update(tmp[2])
+#             train_out.write(str(train_videos))
 
-                    pred, activations = predictive_model(local_batch)
+#         with open(os.path.join(output_dir, 'test_idx_' + str(i_fold) + '.txt'), 'w+') as test_out:
+#             test_videos = set()
+#             for tmp in test_loader:
+#                 test_videos.update(tmp[2])
+#             test_out.write(str(test_videos))
 
-                    loss = criterion(pred, torch_labels)
-                    if args.train_sparse:
-                        loss += args.recon_scale * frozen_sparse.loss(local_batch, activations)
-                    epoch_loss += loss.item() * local_batch.size(0)
+    best_so_far = float('inf')
 
-                    prediction_optimizer.zero_grad()
-                    loss.backward()
-                    prediction_optimizer.step()
+    if args.num_clips > 1 or args.whole_video:
+        predictive_model = torch.nn.DataParallel(SmallDataClassifierVideo(args.num_clips))
+    else:
+        predictive_model = torch.nn.DataParallel(SmallDataClassifierConv3d())
+    predictive_model.to(device)
 
-                t2 = time.perf_counter()
+    criterion = torch.nn.BCEWithLogitsLoss()
+
+    if args.checkpoint:
+        checkpoint = torch.load(args.checkpoint)
+        predictive_model.load_state_dict(checkpoint['model_state_dict'])
 
-                predictive_model.eval()
-                with torch.no_grad():
-                    y_true = None
-                    y_pred = None
-                    for labels, local_batch, vid_f in test_loader:
+    if args.train:
+        prediction_optimizer = torch.optim.Adam(predictive_model.parameters(),
+                                                lr=args.lr)
 
-                        local_batch = local_batch.to(device)
+        for epoch in range(args.epochs):
+            predictive_model.train()
+            epoch_loss = 0
+            t1 = time.perf_counter()
+
+            for labels, local_batch, vid_f in tqdm(train_loader):
+                local_batch = local_batch.to(device).squeeze(2)
 
-                        torch_labels = torch.zeros(len(labels))
-                        torch_labels[[i for i in range(len(labels)) if labels[i] == 'No_Sliding']] = 1
-                        torch_labels = torch_labels.unsqueeze(1).to(device)
+                torch_labels = torch.zeros(len(labels))
+                torch_labels[[i for i in range(len(labels)) if labels[i] == 'No_Sliding']] = 1
+                torch_labels = torch_labels.unsqueeze(1).to(device)
 
+                pred, activations = predictive_model(local_batch)
 
-                        pred, _ = predictive_model(local_batch)
+                loss = criterion(pred, torch_labels)
+                if args.train_sparse:
+                    loss += args.recon_scale * frozen_sparse.loss(local_batch, activations)
+                epoch_loss += loss.item() * local_batch.size(0)
 
-                        if y_true is None:
-                            y_true = torch_labels.detach().cpu().flatten().to(torch.long)
-                            y_pred = torch.nn.Sigmoid()(pred).round().detach().cpu().flatten().to(torch.long)
-                        else:
-                            y_true = torch.cat((y_true, torch_labels.detach().cpu().flatten().to(torch.long)))
-                            y_pred = torch.cat((y_pred, torch.nn.Sigmoid()(pred).round().detach().cpu().flatten().to(torch.long)))
+                prediction_optimizer.zero_grad()
+                loss.backward()
+                prediction_optimizer.step()
 
-                    t2 = time.perf_counter()
+            t2 = time.perf_counter()
 
-                    f1 = f1_score(y_true, y_pred, average='macro')
-                    accuracy = accuracy_score(y_true, y_pred)
+            predictive_model.eval()
+            with torch.no_grad():
+                y_true = None
+                y_pred = None
+                for labels, local_batch, vid_f in train_loader:
 
-                    print('fold={}, epoch={}, time={:.2f}, loss={:.2f}, f1={:.2f}, acc={:.2f}'.format(i_fold, epoch, t2-t1, epoch_loss, f1, accuracy))
+                    local_batch = local_batch.to(device).squeeze(2)
 
-                if epoch_loss <= best_so_far:
-                    print("found better model")
-                    # Save model parameters
-                    torch.save({
-                        'model_state_dict': predictive_model.state_dict(),
-                        'optimizer_state_dict': prediction_optimizer.state_dict(),
-                    }, os.path.join(output_dir, "model-best_fold_" + str(i_fold) + ".pt"))
-                    best_so_far = epoch_loss
+                    torch_labels = torch.zeros(len(labels))
+                    torch_labels[[i for i in range(len(labels)) if labels[i] == 'No_Sliding']] = 1
+                    torch_labels = torch_labels.unsqueeze(1).to(device)
 
-            checkpoint = torch.load(os.path.join(output_dir, "model-best_fold_" + str(i_fold) + ".pt"))
-            predictive_model.load_state_dict(checkpoint['model_state_dict'])
+
+                    pred, _ = predictive_model(local_batch)
+
+                    if y_true is None:
+                        y_true = torch_labels.detach().cpu().flatten().to(torch.long)
+                        y_pred = torch.nn.Sigmoid()(pred).round().detach().cpu().flatten().to(torch.long)
+                    else:
+                        y_true = torch.cat((y_true, torch_labels.detach().cpu().flatten().to(torch.long)))
+                        y_pred = torch.cat((y_pred, torch.nn.Sigmoid()(pred).round().detach().cpu().flatten().to(torch.long)))
+
+                t2 = time.perf_counter()
+
+                f1 = f1_score(y_true, y_pred, average='macro')
+                accuracy = accuracy_score(y_true, y_pred)
+
+                print('fold={}, epoch={}, time={:.2f}, loss={:.2f}, f1={:.2f}, acc={:.2f}'.format(i_fold, epoch, t2-t1, epoch_loss, f1, accuracy))
+#             print(epoch_loss)
+            if epoch_loss <= best_so_far:
+                print("found better model")
+                # Save model parameters
+                torch.save({
+                    'model_state_dict': predictive_model.state_dict(),
+                    'optimizer_state_dict': prediction_optimizer.state_dict(),
+                }, os.path.join(output_dir, "model-best_fold_" + str(i_fold) + ".pt"))
+                best_so_far = epoch_loss
+
+#         checkpoint = torch.load(os.path.join(output_dir, "model-best_fold_" + str(i_fold) + ".pt"))
+#         predictive_model.load_state_dict(checkpoint['model_state_dict'])
 
         predictive_model.eval()
         with torch.no_grad():
@@ -213,13 +215,13 @@ if __name__ == "__main__":
 
             y_true = None
             y_pred = None
-            
+
             pred_dict = {}
             gt_dict = {}
 
             t1 = time.perf_counter()
-            for labels, local_batch, vid_f in test_loader:
-                local_batch = local_batch.to(device)
+            for labels, local_batch, vid_f in train_loader:
+                local_batch = local_batch.to(device).squeeze(2)
 
                 torch_labels = torch.zeros(len(labels))
                 torch_labels[[i for i in range(len(labels)) if labels[i] == 'No_Sliding']] = 1
@@ -229,13 +231,13 @@ if __name__ == "__main__":
 
                 loss = criterion(pred, torch_labels)
                 epoch_loss += loss.item() * local_batch.size(0)
-                
+
                 for i, v_f in enumerate(vid_f):
                     if v_f not in pred_dict:
                         pred_dict[v_f] = torch.nn.Sigmoid()(pred[i]).round().detach().cpu().flatten().to(torch.long)
                     else:
                         pred_dict[v_f] = torch.cat((pred_dict[v_f], torch.nn.Sigmoid()(pred[i]).detach().round().cpu().flatten().to(torch.long)))
-                        
+
                     if v_f not in gt_dict:
                         gt_dict[v_f] = torch_labels[i].detach().cpu().flatten().to(torch.long)
                     else:
@@ -249,11 +251,17 @@ if __name__ == "__main__":
                     y_pred = torch.cat((y_pred, torch.nn.Sigmoid()(pred).detach().round().cpu().flatten().to(torch.long)))
 
             t2 = time.perf_counter()
-            
+
             vid_acc = []
             for k in pred_dict.keys():
                 gt_mode = torch.mode(gt_dict[k])[0].item()
-                pred_mode = torch.mode(pred_dict[k])[0].item()
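+                # Video-level prediction: majority vote over a random subset of clips (at least 3, roughly a quarter of the video's clips).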
+                perm = torch.randperm(pred_dict[k].size(0))
+                cutoff = int(pred_dict[k].size(0)/4)
+                if cutoff < 3:
+                    cutoff = 3
+                idx = perm[:cutoff]
+                samples = pred_dict[k][idx]
+                pred_mode = torch.mode(samples)[0].item()
                 overall_true.append(gt_mode)
                 overall_pred.append(pred_mode)
                 if pred_mode == gt_mode:
@@ -264,9 +272,9 @@ if __name__ == "__main__":
                         fn_ids.append(k)
                     else:
                         fp_ids.append(k)
-                    
+
             vid_acc = np.array(vid_acc)
-            
+
             print('----------------------------------------------------------------------------')
             for k in pred_dict.keys():
                 print(k)
@@ -275,31 +283,31 @@ if __name__ == "__main__":
                 print('Ground Truth:')
                 print(gt_dict[k])
                 print('Overall Prediction:')
-#                 pred_mode = 1
-#                 contiguous_zeros = 0
-#                 best_num = 0
-#                 for val in pred_dict[k]:
-#                     if val.item() == 0:
-#                         contiguous_zeros += 1
-#                     else:
-#                         if contiguous_zeros > best_num:
-#                             best_num = contiguous_zeros
-#                             contiguous_zeros = 0
-#                 if best_num >= 4 or contiguous_zeros >= 4:
-#                     pred_mode = 0
+    #                 pred_mode = 1
+    #                 contiguous_zeros = 0
+    #                 best_num = 0
+    #                 for val in pred_dict[k]:
+    #                     if val.item() == 0:
+    #                         contiguous_zeros += 1
+    #                     else:
+    #                         if contiguous_zeros > best_num:
+    #                             best_num = contiguous_zeros
+    #                             contiguous_zeros = 0
+    #                 if best_num >= 4 or contiguous_zeros >= 4:
+    #                     pred_mode = 0
                 print(torch.mode(pred_dict[k])[0].item())
                 print('----------------------------------------------------------------------------')
 
             print('fold={}, loss={:.2f}, time={:.2f}'.format(i_fold, loss, t2-t1))
-            
+
             f1 = f1_score(y_true, y_pred, average='macro')
             accuracy = accuracy_score(y_true, y_pred)
             all_errors.append(np.sum(vid_acc) / len(vid_acc))
 
             print("Test f1={:.2f}, clip_acc={:.2f}, vid_acc={:.2f} fold={}".format(f1, accuracy, np.sum(vid_acc) / len(vid_acc), i_fold))
-            
+
             print(confusion_matrix(y_true, y_pred))
-            
+
         i_fold = i_fold + 1
                               
     fp_fn_file = os.path.join(args.output_dir, 'fp_fn.txt')
diff --git a/sparse_coding_torch/load_data.py b/sparse_coding_torch/load_data.py
index 2b3c91d92f0fe878b693dd29ea5436826ac3912d..dc808f636c6862a1be1c7260393715db003b39e7 100644
--- a/sparse_coding_torch/load_data.py
+++ b/sparse_coding_torch/load_data.py
@@ -4,10 +4,10 @@ import torch
 from sklearn.model_selection import train_test_split
 from sparse_coding_torch.video_loader import MinMaxScaler
 from sparse_coding_torch.video_loader import VideoLoader
-from sparse_coding_torch.video_loader import VideoClipLoader, YoloClipLoader, get_video_participants, YoloVideoLoader, MobileLoader
+from sparse_coding_torch.video_loader import VideoClipLoader, YoloClipLoader, get_video_participants, YoloVideoLoader, MobileLoader, PNBLoader
 from sparse_coding_torch.video_loader import VideoGrayScaler
 import csv
-from sklearn.model_selection import train_test_split, GroupShuffleSplit, LeaveOneGroupOut, LeaveOneOut, StratifiedGroupKFold, StratifiedKFold
+from sklearn.model_selection import train_test_split, GroupShuffleSplit, LeaveOneGroupOut, LeaveOneOut, StratifiedGroupKFold, StratifiedKFold, KFold
 
 def load_balls_data(batch_size):
     
@@ -107,8 +107,8 @@ def load_yolo_clips(batch_size, mode, num_clips=1, num_positives=100, device=Non
     ])
     augment_transforms = torchvision.transforms.Compose(
     [torchvision.transforms.RandomRotation(45),
-     torchvision.transforms.RandomHorizontalFlip(),
-     torchvision.transforms.CenterCrop((100, 200))
+     torchvision.transforms.RandomHorizontalFlip()
+#      torchvision.transforms.CenterCrop((100, 200))
     ])
     if whole_video:
         dataset = YoloVideoLoader(video_path, num_clips=num_clips, num_positives=num_positives, transform=transforms, augment_transform=augment_transforms, sparse_model=sparse_model, device=device)
@@ -183,5 +183,50 @@ def load_mobile_clips(batch_size, mode, num_clips=1, num_positives=100, n_splits
         groups = [video_to_participant[v.lower().replace('_clean', '')] for v in dataset.get_filenames()]
 
         return gss.split(np.arange(len(targets)), targets, groups), dataset
+    else:
+        return None
+    
+def load_pnb_videos(batch_size, mode, device=None, n_splits=None, sparse_model=None):   
+    video_path = "/shared_data/bamc_pnb_data/full_training_data"
+    
+    transforms = torchvision.transforms.Compose(
+    [VideoGrayScaler(),
+     MinMaxScaler(0, 255),
+     torchvision.transforms.Resize((360, 304))
+    ])
+    augment_transforms = torchvision.transforms.Compose(
+    [torchvision.transforms.RandomAffine(45),
+     torchvision.transforms.RandomHorizontalFlip(),
+     torchvision.transforms.ColorJitter(brightness=0.5),
+     torchvision.transforms.RandomAdjustSharpness(0, p=0.15),
+     torchvision.transforms.RandomAffine(degrees=0, translate=(0.05, 0))
+#      torchvision.transforms.CenterCrop((100, 200))
+    ])
+    dataset = PNBLoader(video_path, num_frames=5, frame_rate=20, transform=transforms)
+    
+    targets = dataset.get_labels()
+    
+    if mode == 'leave_one_out':
+        gss = LeaveOneGroupOut()
+
+        groups = [v for v in dataset.get_filenames()]
+#         groups = [video_to_participant[v.lower().replace('_clean', '')] for v in dataset.get_filenames()]
+        
+        return gss.split(np.arange(len(targets)), targets, groups), dataset
+    elif mode == 'all_train':
+        train_idx = np.arange(len(targets))
+        train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
+        train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
+                                               sampler=train_sampler)
+        test_loader = None
+        
+        return train_loader, test_loader
+    elif mode == 'k_fold':
+        gss = StratifiedKFold(n_splits=n_splits, shuffle=True)
+
+#         groups = [video_to_participant[v.lower().replace('_clean', '')] for v in dataset.get_filenames()]
+        groups = [v for v in dataset.get_filenames()]
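+        # NOTE: groups are computed but not passed to split(); StratifiedKFold stratifies on the targets only.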
+        
+        return gss.split(np.arange(len(targets)), targets), dataset
     else:
         return None
\ No newline at end of file
diff --git a/sparse_coding_torch/small_data_classifier.py b/sparse_coding_torch/small_data_classifier.py
index 330596a614c7ac10bb6e8d44507010af2ad6960f..9f3ddc4d40effe665978080322e4834c5dd2a006 100644
--- a/sparse_coding_torch/small_data_classifier.py
+++ b/sparse_coding_torch/small_data_classifier.py
@@ -36,9 +36,9 @@ class SmallDataClassifierConv3d(nn.Module):
     def __init__(self):
         super().__init__()
         
-        self.max_pool_1 = nn.MaxPool3d(kernel_size=(1, 4, 4))
+        self.max_pool_1 = nn.MaxPool2d(kernel_size=(4, 4))
 
-        self.compress_activations_conv_1 = nn.Conv3d(in_channels=64, out_channels=24, kernel_size=(1, 8, 8), stride=(1, 4, 4), padding=(0, 4, 4))
+        self.compress_activations_conv_1 = nn.Conv2d(in_channels=64, out_channels=24, kernel_size=(8, 8), stride=(4, 4))
 #         self.compress_activations_conv_2 = nn.Conv3d(in_channels=32, out_channels=16, kernel_size=(1, 8, 8), stride=(1, 4, 4), padding=(1, 4, 4))
 
 #         self.gru = nn.GRU(37, 100)
@@ -46,16 +46,16 @@ class SmallDataClassifierConv3d(nn.Module):
         self.dropout = torch.nn.Dropout(p=0.5)
         
         # First fully connected layer
-        self.fc1 = nn.Linear(2184, 1000)
-        self.fc2 = nn.Linear(1000, 100)
-        self.fc3 = nn.Linear(100, 20)
+#         self.fc1 = nn.Linear(672, 1000)
+#         self.fc2 = nn.Linear(240, 100)
+        self.fc3 = nn.Linear(96, 20)
         self.fc4 = nn.Linear(20, 1)
 
     # x represents our data
     def forward(self, activations):
-        batch_size, channel_size, time_size, height_size, width_size = activations.size()
+        batch_size, channel_size, height_size, width_size = activations.size()
         
-        activations = activations.view(-1, channel_size, time_size, height_size, width_size)
+        activations = activations.view(-1, channel_size, height_size, width_size)
 
         x = self.max_pool_1(activations)
         
@@ -72,12 +72,14 @@ class SmallDataClassifierConv3d(nn.Module):
         
 #         x = x.to('cuda:' + str(save_device))
 
+        x = x.swapaxes(1, 3)
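+        # Channels are moved to the last axis so flattening yields a channels-last feature vector.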
+
         x = torch.flatten(x, 1)
 
-        x = F.relu(self.fc1(x))
-        x = self.dropout(x)
-        x = F.relu(self.fc2(x))
-        x = self.dropout(x)
+#         x = F.relu(self.fc1(x))
+#         x = self.dropout(x)
+#         x = F.relu(self.fc2(x))
+#         x = self.dropout(x)
         x = F.relu(self.fc3(x))
         x = self.dropout(x)
         x = self.fc4(x)
diff --git a/sparse_coding_torch/video_loader.py b/sparse_coding_torch/video_loader.py
index a318a1c4da9d87b685cf419f3d161ad3a0284208..4e7188cd1d5aacdbae02be36acff80081f16d515 100644
--- a/sparse_coding_torch/video_loader.py
+++ b/sparse_coding_torch/video_loader.py
@@ -5,6 +5,7 @@ from os.path import isdir
 from os.path import abspath
 from os.path import exists
 import json
+import glob
 
 from PIL import Image
 from torchvision.transforms import ToTensor
@@ -170,6 +173,59 @@ class VideoClipLoader(Dataset):
     def __len__(self):
         return len(self.clips)
     
+class PNBLoader(Dataset):
+    
+    def __init__(self, video_path, num_frames=5, frame_rate=20, frames_between_clips=None, transform=None):
+        self.transform = transform
+        self.labels = [name for name in listdir(video_path) if isdir(join(video_path, name))]
+        
+        self.videos = []
+        for label in self.labels:
+            self.videos.extend([(label, abspath(join(video_path, label, f)), f) for f in glob.glob(join(video_path, label, '*', '*.mp4'))])
+            
+        #for v in self.videos:
+        #    video, _, info = read_video(v[1])
+        #    print(video.shape)
+        #    print(info)
+            
+        if not frames_between_clips:
+            frames_between_clips = num_frames
+            
+        self.clips = []
+                   
+        self.video_idx = []
+        
+        vid_idx = 0
+        for _, path, _ in self.videos:
+            vc = tv.io.read_video(path)[0].permute(3, 0, 1, 2)
+#             for j in range(vc.size(1), vc.size(1) - 10, -5):
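+            # Slice each video into consecutive, non-overlapping 5-frame clips and apply the transform eagerly.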
+            for j in range(0, vc.size(1) - 5, 5):
+#                 if j-5 < 0:
+#                     continue
+#                 vc_sub = vc_1 = vc[:, j-5:j, :, :]
+                vc_sub = vc[:, j:j+5, :, :]
+                if self.transform:
+                    vc_sub = self.transform(vc_sub)
+                    
+                self.clips.append((self.videos[vid_idx][0], vc_sub, self.videos[vid_idx][2]))
+                self.video_idx.append(vid_idx)
+            vid_idx += 1
+        
+    def get_filenames(self):
+        return [self.clips[i][2] for i in range(len(self.clips))]
+        
+    def get_video_labels(self):
+        return [self.videos[i][0] for i in range(len(self.videos))]
+        
+    def get_labels(self):
+        return [self.clips[i][0] for i in range(len(self.clips))]
+    
+    def __getitem__(self, index):
+        return self.clips[index]
+        
+    def __len__(self):
+        return len(self.clips)
+    
 class VideoFrameLoader(Dataset):
     
     def __init__(self, video_path, transform=None):
@@ -261,8 +317,8 @@ class YoloClipLoader(Dataset):
 
                                 # width = region['relative_coordinates']['width'] * 1920
                                 # height = region['relative_coordinates']['height'] * 1080
-                                width=400
-                                height=400
+                                width=200
+                                height=100
 
                                 lower_y = round(center_y - height / 2)
                                 upper_y = round(center_y + height / 2)