diff --git a/clips_to_test_swift/6.mp4 b/clips_to_test_swift/6.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..ef9814bc63de5bddf3976c7a27119d5673012940
Binary files /dev/null and b/clips_to_test_swift/6.mp4 differ
diff --git a/keras/generate_tflite.py b/keras/generate_tflite.py
index 3dd871a95cc24a0d310393e502829c642600cfd1..7baffc0b6091bc3a9abb928cc1ab32a65196591b 100644
--- a/keras/generate_tflite.py
+++ b/keras/generate_tflite.py
@@ -8,49 +8,56 @@ import torch
 import torch.nn as nn
 from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler
 from sparse_coding_torch.conv_sparse_model import ConvSparseLayer
+from sparse_coding_torch.small_data_classifier import SmallDataClassifierConv3d
 from keras_model import MobileModel
 
-inputs = keras.Input(shape=(100, 200, 5))
+inputs = keras.Input(shape=(5, 100, 200, 3))
 
-outputs = MobileModel(sparse_checkpoint='../sparse.pt', batch_size=1, in_channels=1, out_channels=64, kernel_size=15, stride=1, lam=0.05, activation_lr=1e-2, max_activation_iter=40, run_2d=True)(inputs)
+outputs = MobileModel(sparse_checkpoint='../sparse.pt', batch_size=1, in_channels=1, out_channels=64, kernel_size=15, stride=2, lam=0.05, activation_lr=1e-1, max_activation_iter=100, run_2d=True)(inputs)
+# outputs = tf.math.add(inputs, 1)
 
 model = keras.Model(inputs=inputs, outputs=outputs)
 
-pytorch_checkpoint = torch.load('../output/final_model_75_iter/model-best_fold_0.pt', map_location='cpu')['model_state_dict']
-conv_weights = [pytorch_checkpoint['module.compress_activations_conv_1.weight'].view(8, 8, 64, 24).numpy(), pytorch_checkpoint['module.compress_activations_conv_1.bias'].numpy()]
+pytorch_checkpoint = torch.load('../stride_2_100_iter.pt', map_location='cpu')['model_state_dict']
+conv_weights = [pytorch_checkpoint['module.compress_activations_conv_1.weight'].squeeze(2).swapaxes(0, 2).swapaxes(1, 3).swapaxes(2, 3).numpy(), pytorch_checkpoint['module.compress_activations_conv_1.bias'].numpy()]
 model.get_layer('mobile_model').classifier.conv.set_weights(conv_weights)
-
-ff_1_weights = [pytorch_checkpoint['module.fc1.weight'].permute(1,0).numpy(), pytorch_checkpoint['module.fc1.bias'].numpy()]
-model.get_layer('mobile_model').classifier.ff_1.set_weights(ff_1_weights)
-ff_2_weights = [pytorch_checkpoint['module.fc2.weight'].permute(1,0).numpy(), pytorch_checkpoint['module.fc2.bias'].numpy()]
-model.get_layer('mobile_model').classifier.ff_2.set_weights(ff_2_weights)
-ff_3_weights = [pytorch_checkpoint['module.fc3.weight'].permute(1,0).numpy(), pytorch_checkpoint['module.fc3.bias'].numpy()]
+# # ff_1_weights = [pytorch_checkpoint['module.fc1.weight'].swapaxes(1,0).numpy(), pytorch_checkpoint['module.fc1.bias'].numpy()]
+# # model.get_layer('mobile_model').classifier.ff_1.set_weights(ff_1_weights)
+# # ff_2_weights = [pytorch_checkpoint['module.fc2.weight'].swapaxes(1,0).numpy(), pytorch_checkpoint['module.fc2.bias'].numpy()]
+# # model.get_layer('mobile_model').classifier.ff_2.set_weights(ff_2_weights)
+ff_3_weights = [pytorch_checkpoint['module.fc3.weight'].swapaxes(1,0).numpy(), pytorch_checkpoint['module.fc3.bias'].numpy()]
 model.get_layer('mobile_model').classifier.ff_3.set_weights(ff_3_weights)
-ff_4_weights = [pytorch_checkpoint['module.fc4.weight'].permute(1,0).numpy(), pytorch_checkpoint['module.fc4.bias'].numpy()]
+ff_4_weights = [pytorch_checkpoint['module.fc4.weight'].swapaxes(1,0).numpy(), pytorch_checkpoint['module.fc4.bias'].numpy()]
 model.get_layer('mobile_model').classifier.ff_4.set_weights(ff_4_weights)
 
 
 # frozen_sparse = ConvSparseLayer(in_channels=1,
 #                                 out_channels=64,
 #                                 kernel_size=(5, 15, 15),
-#                                 stride=1,
-#                                 padding=(0, 7, 7),
+#                                 stride=2,
+#                                 padding=0,
 #                                 convo_dim=3,
 #                                 rectifier=True,
 #                                 lam=0.05,
-#                                 max_activation_iter=10,
-#                                 activation_lr=1e-2)
-
+#                                 max_activation_iter=100,
+#                                 activation_lr=1e-1)
+#
 # sparse_param = torch.load('../sparse.pt', map_location='cpu')
 # frozen_sparse.load_state_dict(sparse_param['model_state_dict'])
 #
-# # pytorch_filter = frozen_sparse.filters[30, :, 0, :, :].squeeze(0).unsqueeze(2).detach().numpy()
-# # keras_filter = model.get_layer('sparse_code').filter[0,:,:,:,30].numpy()
-# #
-# # cv2.imwrite('pytorch_filter.png', pytorch_filter / np.max(pytorch_filter) * 255.)
-# # cv2.imwrite('keras_filter.png', keras_filter / np.max(keras_filter) * 255.)
-# # raise Exception
+# predictive_model = SmallDataClassifierConv3d()
+# classifier_param = {k.replace('module.', ''): v for k,v in torch.load('../stride_2_100_iter.pt', map_location='cpu')['model_state_dict'].items()}
+# predictive_model.load_state_dict(classifier_param)
 #
+# predictive_model.eval()
+# #
+# # # pytorch_filter = frozen_sparse.filters[30, :, 0, :, :].squeeze(0).unsqueeze(2).detach().numpy()
+# # # keras_filter = model.get_layer('sparse_code').filter[0,:,:,:,30].numpy()
+# # #
+# # # cv2.imwrite('pytorch_filter.png', pytorch_filter / np.max(pytorch_filter) * 255.)
+# # # cv2.imwrite('keras_filter.png', keras_filter / np.max(keras_filter) * 255.)
+# # # raise Exception
+# #
 # img = tv.io.read_video('../clips/No_Sliding/Image_262499828648_clean1050.mp4')[0].permute(3, 0, 1, 2)
 # transform = tv.transforms.Compose(
 # [VideoGrayScaler(),
@@ -59,12 +66,13 @@ model.get_layer('mobile_model').classifier.ff_4.set_weights(ff_4_weights)
 # tv.transforms.CenterCrop((100, 200))
 # ])
 # img = transform(img)
-
+#
 # with torch.no_grad():
-# activations = frozen_sparse(img.unsqueeze(0))
-
+# activations, _ = predictive_model(frozen_sparse(img.unsqueeze(0)).squeeze(2))
+# activations = torch.nn.Sigmoid()(activations)
+#
 # output = model(img.swapaxes(1, 3).swapaxes(1,2).numpy())
-
+#
 # print(activations.size())
 # print(output.shape)
 # print(torch.sum(activations))
@@ -72,7 +80,7 @@ model.get_layer('mobile_model').classifier.ff_4.set_weights(ff_4_weights)
 
 input_name = model.input_names[0]
 index = model.input_names.index(input_name)
-model.inputs[index].set_shape([1, 100, 200, 5])
+model.inputs[index].set_shape([1, 5, 100, 200, 3])
 
 converter = tf.lite.TFLiteConverter.from_keras_model(model)
 # converter.experimental_new_converter = True
diff --git a/keras/keras_model.py b/keras/keras_model.py
index 5add25e02b950adb31f6ae7e8381f49d354fb39a..8f8e9b8fe1622c175eb9b3b6114efc7489daebcc 100644
--- a/keras/keras_model.py
+++ b/keras/keras_model.py
@@ -12,22 +12,22 @@ from sparse_coding_torch.conv_sparse_model import ConvSparseLayer
 def load_pytorch_weights(file_path):
     pytorch_checkpoint = torch.load(file_path, map_location='cpu')
     weight_tensor = pytorch_checkpoint['model_state_dict']['filters'].swapaxes(1,3).swapaxes(2,4).swapaxes(0,4).numpy()
-    
+
     return weight_tensor
 
 @tf.function
 def do_recon(filters_1, filters_2, filters_3, filters_4, filters_5, activations, batch_size, stride):
-    out_1 = tf.nn.conv2d_transpose(activations, filters_1, output_shape=(batch_size, 100, 200, 1), strides=stride)
-    out_2 = tf.nn.conv2d_transpose(activations, filters_2, output_shape=(batch_size, 100, 200, 1), strides=stride)
-    out_3 = tf.nn.conv2d_transpose(activations, filters_3, output_shape=(batch_size, 100, 200, 1), strides=stride)
-    out_4 = tf.nn.conv2d_transpose(activations, filters_4, output_shape=(batch_size, 100, 200, 1), strides=stride)
-    out_5 = tf.nn.conv2d_transpose(activations, filters_5, output_shape=(batch_size, 100, 200, 1), strides=stride)
+    out_1 = tf.nn.conv2d_transpose(activations, filters_1, output_shape=(batch_size, 100, 200, 1), strides=stride, padding='VALID')
+    out_2 = tf.nn.conv2d_transpose(activations, filters_2, output_shape=(batch_size, 100, 200, 1), strides=stride, padding='VALID')
+    out_3 = tf.nn.conv2d_transpose(activations, filters_3, output_shape=(batch_size, 100, 200, 1), strides=stride, padding='VALID')
+    out_4 = tf.nn.conv2d_transpose(activations, filters_4, output_shape=(batch_size, 100, 200, 1), strides=stride, padding='VALID')
+    out_5 = tf.nn.conv2d_transpose(activations, filters_5, output_shape=(batch_size, 100, 200, 1), strides=stride, padding='VALID')
 
     recon = tf.concat([out_1, out_2, out_3, out_4, out_5], axis=3)
 
     return recon
 
-# @tf.function
+@tf.function
 def do_recon_3d(filters, activations, batch_size, stride):
     activations = tf.pad(activations, paddings=[[0,0], [2, 2], [0, 0], [0, 0], [0, 0]])
     recon = tf.nn.conv3d_transpose(activations, filters, output_shape=(batch_size, 5, 100, 200, 1), strides=[1, stride, stride])
@@ -35,18 +35,13 @@ def do_recon_3d(filters, activations, batch_size, stride):
     return recon
 
 @tf.function
-def conv_error(filters_1, filters_2, filters_3, filters_4, filters_5, e, stride):
-    e1, e2, e3, e4, e5 = tf.split(e, 5, axis=3)
+def conv_error(filter, e, stride):
 
-    g = tf.nn.conv2d(e1, filters_1, strides=stride, padding='SAME')
-    g = g + tf.nn.conv2d(e2, filters_2, strides=stride, padding='SAME')
-    g = g + tf.nn.conv2d(e3, filters_3, strides=stride, padding='SAME')
-    g = g + tf.nn.conv2d(e4, filters_4, strides=stride, padding='SAME')
-    g = g + tf.nn.conv2d(e5, filters_5, strides=stride, padding='SAME')
+    g = tf.nn.conv2d(e, filter, strides=stride, padding='VALID')
 
     return g
 
-# @tf.function
+@tf.function
 def conv_error_3d(filters, e, stride):
     e = tf.pad(e, paddings=[[0,0], [0, 0], [7, 7], [7, 7], [0, 0]])
     g = tf.nn.conv3d(e, filters, strides=[1, 1, stride, stride, 1], padding='VALID')
@@ -58,27 +53,27 @@ def normalize_weights(filters, out_channels):
     #print('filters shape', tf.shape(filters))
     norms = tf.norm(tf.reshape(tf.stack(filters), (out_channels, -1)), axis=1)
     norms = tf.broadcast_to(tf.math.maximum(norms, 1e-12*tf.ones_like(norms)), filters[0].shape)
-    
+
     adjusted = [f / norms for f in filters]
-    
+
     #raise Exception('Beep')
-    
+
     return adjusted
 
-# @tf.function
+@tf.function
 def normalize_weights_3d(filters, out_channels):
     #for f in filters:
     #    print('filters 3d shape', f.shape)
     norms = tf.norm(tf.reshape(tf.transpose(filters[0], perm=[4, 0, 1, 2, 3]), (out_channels, -1)), axis=1)
     # tf.print("norms", norms.shape, norms)
     norms = tf.broadcast_to(tf.math.maximum(norms, 1e-12*tf.ones_like(norms)), filters[0].shape)
-    
+
     adjusted = [f / norms for f in filters]
 
     #for i in range(out_channels):
     #    tf.print("after normalization", tf.norm(adjusted[0][:,:,:,0,i]))
     #print()
-    
+
     #raise Exception('Beep')
 
     return adjusted
@@ -88,6 +83,7 @@ class SparseCode(keras.layers.Layer):
 
         self.out_channels = out_channels
         self.in_channels = in_channels
+        self.kernel_size = kernel_size
        self.stride = stride
         self.lam = lam
         self.activation_lr = activation_lr
@@ -106,24 +102,29 @@ class SparseCode(keras.layers.Layer):
 
         e = images - recon
         g = -1 * u
-        
+
         if self.run_2d:
-            convd_error = conv_error(filters[0], filters[1], filters[2], filters[3], filters[4], e, self.stride)
+            e1, e2, e3, e4, e5 = tf.split(e, 5, axis=3)
+            g += conv_error(filters[0], e1, self.stride)
+            g += conv_error(filters[1], e2, self.stride)
+            g += conv_error(filters[2], e3, self.stride)
+            g += conv_error(filters[3], e4, self.stride)
+            g += conv_error(filters[4], e5, self.stride)
         else:
             convd_error = conv_error_3d(filters, e, self.stride)
-
-        g = g + convd_error
+
+            g = g + convd_error
 
         g = g + activations
-        
+
         m = b1 * m + (1-b1) * g
         v = b2 * v + (1-b2) * g**2
         mh = m / (1 - b1**(1+i))
         vh = v / (1 - b2**(1+i))
-        
+
         du = self.activation_lr * mh / (tf.math.sqrt(vh) + eps)
         u += du
-        
+
         # i += 1
 
         # return images, u, m, v, b1, b2, eps, i
@@ -132,21 +133,21 @@ class SparseCode(keras.layers.Layer):
     @tf.function
     def call(self, images, filters):
         if self.run_2d:
-            u = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
-            m = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
-            v = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
+            u = tf.zeros(shape=(self.batch_size, 100 // self.stride - self.kernel_size // self.stride, 200 // self.stride - self.kernel_size // self.stride, self.out_channels))
+            m = tf.zeros(shape=(self.batch_size, 100 // self.stride - self.kernel_size // self.stride, 200 // self.stride - self.kernel_size // self.stride, self.out_channels))
+            v = tf.zeros(shape=(self.batch_size, 100 // self.stride - self.kernel_size // self.stride, 200 // self.stride - self.kernel_size // self.stride, self.out_channels))
         else:
             u = tf.zeros(shape=(self.batch_size, 1, 100 // self.stride, 200 // self.stride, self.out_channels))
             m = tf.zeros(shape=(self.batch_size, 1, 100 // self.stride, 200 // self.stride, self.out_channels))
             v = tf.zeros(shape=(self.batch_size, 1, 100 // self.stride, 200 // self.stride, self.out_channels))
-        
+
         # tf.print('activations before:', tf.reduce_sum(u))
 
         b1 = tf.constant(0.9, dtype='float32')
         b2 = tf.constant(0.999, dtype='float32')
         eps = tf.constant(1e-8, dtype='float32')
-        
-        
+
+
         # i = tf.constant(0, dtype='float32')
         # c = lambda images, u, m, v, b1, b2, eps, i: tf.less(i, self.max_activation_iter)
         # images, u, m, v, b1, b2, eps, i = tf.while_loop(c, self.do_update, [images, u, m, v, b1, b2, eps, i])
@@ -154,16 +155,16 @@ class SparseCode(keras.layers.Layer):
             u, m, v = self.do_update(images, filters, u, m, v, b1, b2, eps, i)
 
         u = tf.nn.relu(u - self.lam)
-        
+
         # tf.print('activations after:', tf.reduce_sum(u))
 
         return u
-    
+
 
 class SparseCodeConv(keras.Model):
     def __init__(self, batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d):
         super().__init__()
         self.sparse_code = SparseCode(batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d)
-        
+
         self.out_channels = out_channels
         self.in_channels = in_channels
         self.stride = stride
@@ -172,7 +173,7 @@ class SparseCodeConv(keras.Model):
         self.max_activation_iter = max_activation_iter
         self.batch_size = batch_size
         self.run_2d = run_2d
-        
+
         initializer = tf.keras.initializers.HeNormal()
         if run_2d:
             self.filters_1 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
@@ -182,7 +183,7 @@ class SparseCodeConv(keras.Model):
             self.filters_5 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
         else:
             self.filters = tf.Variable(initial_value=initializer(shape=(5, kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
-        
+
         if run_2d:
             weights = normalize_weights(self.get_weights(), out_channels)
         else:
@@ -197,7 +198,7 @@ class SparseCodeConv(keras.Model):
         else:
             activations = self.sparse_code(images, tf.stop_gradient(self.filters))
             recon = do_recon_3d(self.filters, activations, self.batch_size, self.stride)
-        
+
         return recon, activations
 
 class Classifier(keras.layers.Layer):
@@ -205,14 +206,14 @@ class Classifier(keras.layers.Layer):
         super(Classifier, self).__init__()
 
         self.max_pool = keras.layers.MaxPooling2D(pool_size=4, strides=4)
-        self.conv = keras.layers.Conv2D(24, kernel_size=8, strides=4, activation='relu', padding='SAME')
+        self.conv = keras.layers.Conv2D(24, kernel_size=8, strides=4, activation='relu', padding='valid')
 
         self.flatten = keras.layers.Flatten()
 
         self.dropout = keras.layers.Dropout(0.5)
 
-        self.ff_1 = keras.layers.Dense(1000, activation='relu', use_bias=True)
-        self.ff_2 = keras.layers.Dense(100, activation='relu', use_bias=True)
+        # self.ff_1 = keras.layers.Dense(1000, activation='relu', use_bias=True)
+        # self.ff_2 = keras.layers.Dense(100, activation='relu', use_bias=True)
         self.ff_3 = keras.layers.Dense(20, activation='relu', use_bias=True)
         self.ff_4 = keras.layers.Dense(1, activation='sigmoid')
 
@@ -221,10 +222,10 @@ class Classifier(keras.layers.Layer):
         x = self.max_pool(activations)
         x = self.conv(x)
         x = self.flatten(x)
-        x = self.ff_1(x)
-        x = self.dropout(x)
-        x = self.ff_2(x)
-        x = self.dropout(x)
+        # # x = self.ff_1(x)
+        # # x = self.dropout(x)
+        # # x = self.ff_2(x)
+        # # x = self.dropout(x)
         x = self.ff_3(x)
         x = self.dropout(x)
         x = self.ff_4(x)
@@ -236,7 +237,7 @@ class MobileModel(keras.Model):
         super().__init__()
         self.sparse_code = SparseCode(batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d)
         self.classifier = Classifier()
-        
+
         self.out_channels = out_channels
         self.in_channels = in_channels
         self.stride = stride
@@ -245,9 +246,9 @@ class MobileModel(keras.Model):
         self.max_activation_iter = max_activation_iter
         self.batch_size = batch_size
         self.run_2d = run_2d
-        
+
         pytorch_weights = load_pytorch_weights(sparse_checkpoint)
-        
+
         if run_2d:
             weight_list = np.split(pytorch_weights, 5, axis=0)
             self.filters_1 = tf.Variable(initial_value=weight_list[0].squeeze(0), dtype='float32', trainable=False)
@@ -260,11 +261,16 @@ class MobileModel(keras.Model):
 
     @tf.function
     def call(self, images):
+        images = tf.squeeze(tf.image.rgb_to_grayscale(images), axis=-1)
+        images = tf.transpose(images, perm=[0, 2, 3, 1])
+        images = images / 255
+        images = (images - 0.2592) / 0.1251
+
         if self.run_2d:
             activations = self.sparse_code(images, [tf.stop_gradient(self.filters_1), tf.stop_gradient(self.filters_2), tf.stop_gradient(self.filters_3), tf.stop_gradient(self.filters_4), tf.stop_gradient(self.filters_5)])
         else:
             activations = self.sparse_code(images, tf.stop_gradient(self.filters))
-        
+
         pred = self.classifier(activations)
-        
-        return pred
\ No newline at end of file
+
+        return pred
diff --git a/stride_2_100_iter.pt b/stride_2_100_iter.pt
new file mode 100644
index 0000000000000000000000000000000000000000..9c398c787dd27738188c4f56b8a2c938592f60dd
Binary files /dev/null and b/stride_2_100_iter.pt differ
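
A minimal smoke test for the converted model, not part of the patch above: it is a sketch that assumes the TFLite flatbuffer is written to mobile_model.tflite (a hypothetical path, since the save step falls outside this diff) and feeds a random clip in the (1, 5, 100, 200, 3) layout pinned in generate_tflite.py.

import numpy as np
import tensorflow as tf

# Load the flatbuffer produced by TFLiteConverter; 'mobile_model.tflite' is an assumed output path.
interpreter = tf.lite.Interpreter(model_path='mobile_model.tflite')
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()[0]

# Five 100x200 RGB frames with 0-255 values; MobileModel.call now performs the
# grayscale conversion and the (x / 255 - 0.2592) / 0.1251 normalization itself.
clip = np.random.randint(0, 256, size=(1, 5, 100, 200, 3)).astype(np.float32)

interpreter.set_tensor(input_details['index'], clip)
interpreter.invoke()
print('sigmoid prediction:', interpreter.get_tensor(output_details['index']))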