diff --git a/keras/keras_model.py b/keras/keras_model.py
index 2360c59651090ccee188aa1ab899cc8567be7b83..2b11890b207db201fc955a520ec26d45ae3cb7e2 100644
--- a/keras/keras_model.py
+++ b/keras/keras_model.py
@@ -9,6 +9,51 @@ import torch.nn as nn
 from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler
 from sparse_coding_torch.conv_sparse_model import ConvSparseLayer
 
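+# Reconstruct the input clip from one shared activation map: each of the five frame
+# channels is rebuilt with its own transposed convolution and the results are
+# concatenated along the channel axis into a (batch, 100, 200, 5) tensor.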
+@tf.function
+def do_recon(filters_1, filters_2, filters_3, filters_4, filters_5, activations, batch_size, stride):
+    out_1 = tf.nn.conv2d_transpose(activations, filters_1, output_shape=(batch_size, 100, 200, 1), strides=stride)
+    out_2 = tf.nn.conv2d_transpose(activations, filters_2, output_shape=(batch_size, 100, 200, 1), strides=stride)
+    out_3 = tf.nn.conv2d_transpose(activations, filters_3, output_shape=(batch_size, 100, 200, 1), strides=stride)
+    out_4 = tf.nn.conv2d_transpose(activations, filters_4, output_shape=(batch_size, 100, 200, 1), strides=stride)
+    out_5 = tf.nn.conv2d_transpose(activations, filters_5, output_shape=(batch_size, 100, 200, 1), strides=stride)
+
+    recon = tf.concat([out_1, out_2, out_3, out_4, out_5], axis=3)
+
+    return recon
+
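+# 3D variant of the reconstruction: a single conv3d_transpose maps the activations
+# back to a (batch, 5, 100, 200, 1) volume.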
+@tf.function
+def do_recon_3d(filters, activations, batch_size, stride):
+    recon = tf.nn.conv3d_transpose(activations, filters, output_shape=(batch_size, 5, 100, 200, 1), strides=stride)
+
+    return recon
+
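+# Back-project the reconstruction error onto the dictionary: split the error into its
+# five frame channels, convolve each with the matching filter bank, and sum the results
+# to get the gradient contribution for the activations.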
+@tf.function
+def conv_error(filters_1, filters_2, filters_3, filters_4, filters_5, e, stride):
+    e1, e2, e3, e4, e5 = tf.split(e, 5, axis=3)
+
+    g = tf.nn.conv2d(e1, filters_1, strides=stride, padding='SAME')
+    g = g + tf.nn.conv2d(e2, filters_2, strides=stride, padding='SAME')
+    g = g + tf.nn.conv2d(e3, filters_3, strides=stride, padding='SAME')
+    g = g + tf.nn.conv2d(e4, filters_4, strides=stride, padding='SAME')
+    g = g + tf.nn.conv2d(e5, filters_5, strides=stride, padding='SAME')
+
+    return g
+
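+# 3D variant of the error back-projection.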
+@tf.function
+def conv_error_3d(filters, e, stride):
+    # tf.nn.conv3d requires the batch and channel strides to be 1.
+    g = tf.nn.conv3d(e, filters, strides=[1, stride, stride, stride, 1], padding='SAME')
+
+    return g
+
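+# Rescale every dictionary element to unit L2 norm (computed jointly over the five
+# filter banks), clamping tiny norms to avoid division by zero.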
+@tf.function
+def normalize_weights(filters, out_channels):
+    # One norm per dictionary element: out_channels is the last axis of each filter bank.
+    norms = tf.norm(tf.reshape(tf.stack(filters), (-1, out_channels)), axis=0)
+    norms = tf.broadcast_to(tf.math.maximum(norms, 1e-12*tf.ones_like(norms)), filters[0].shape)
+    
+    adjusted = [f / norms for f in filters]
+    
+    return adjusted
+
 class SparseCode(keras.layers.Layer):
     def __init__(self, pytorch_checkpoint, batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter):
         super(SparseCode, self).__init__()
@@ -27,104 +72,57 @@ class SparseCode(keras.layers.Layer):
         # weight_list = torch.chunk(weight_tensor, 5, dim=2)
         #
         initializer = tf.keras.initializers.HeNormal()
-        self.filters_1 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, out_channels, in_channels)), dtype='float32', trainable=True)
-        self.filters_2 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, out_channels, in_channels)), dtype='float32', trainable=True)
-        self.filters_3 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, out_channels, in_channels)), dtype='float32', trainable=True)
-        self.filters_4 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, out_channels, in_channels)), dtype='float32', trainable=True)
-        self.filters_5 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, out_channels, in_channels)), dtype='float32', trainable=True)
-
-    def normalize_weights(self):
-        norms = tf.norm(tf.reshape(tf.stack([self.filters_1, self.filters_2, self.filters_3, self.filters_4, self.filters_5]), (self.out_channels, -1)), axis=1, keepdims=True)
-        norms = tf.broadcast_to(tf.math.maximum(norms, 1e-12*tf.ones_like(norms)), self.filters_1.shape)
-        self.filters_1 = self.filters_1 / norms
-        self.filters_2 = self.filters_2 / norms
-        self.filters_3 = self.filters_3 / norms
-        self.filters_4 = self.filters_4 / norms
-        self.filters_5 = self.filters_5 / norms
-
-    @tf.function
-    def do_recon(self, activations):
-        out_1 = tf.nn.conv2d_transpose(activations, self.filters_1, output_shape=(self.batch_size, 100, 200, 1), strides=self.stride)
-        out_2 = tf.nn.conv2d_transpose(activations, self.filters_2, output_shape=(self.batch_size, 100, 200, 1), strides=self.stride)
-        out_3 = tf.nn.conv2d_transpose(activations, self.filters_3, output_shape=(self.batch_size, 100, 200, 1), strides=self.stride)
-        out_4 = tf.nn.conv2d_transpose(activations, self.filters_4, output_shape=(self.batch_size, 100, 200, 1), strides=self.stride)
-        out_5 = tf.nn.conv2d_transpose(activations, self.filters_5, output_shape=(self.batch_size, 100, 200, 1), strides=self.stride)
-
-        recon = tf.concat([out_1, out_2, out_3, out_4, out_5], axis=3)
-
-        return recon
-
-    @tf.function
-    def do_recon_3d(self, activations):
-        recon = tf.nn.conv3d_transpose(activations, self.filter, output_shape=(self.batch_size, 5, 100, 200, 1), strides=self.stride)
-
-        return recon
-
-    @tf.function
-    def conv_error(self, e):
-        e1, e2, e3, e4, e5 = tf.split(e, 5, axis=3)
-
-        g = tf.nn.conv2d(e1, self.filters_1, strides=self.stride, padding='SAME')
-        g = g + tf.nn.conv2d(e2, self.filters_2, strides=self.stride, padding='SAME')
-        g = g + tf.nn.conv2d(e3, self.filters_3, strides=self.stride, padding='SAME')
-        g = g + tf.nn.conv2d(e4, self.filters_4, strides=self.stride, padding='SAME')
-        g = g + tf.nn.conv2d(e5, self.filters_5, strides=self.stride, padding='SAME')
-
-        return g
-
-    @tf.function
-    def conv_error_3d(self, e):
-        g = tf.nn.conv3d(e, self.filter, strides=[1, 1, 1, 1, 1], padding='SAME')
-
-        return g
+        self.filters_1 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
+        self.filters_2 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
+        self.filters_3 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
+        self.filters_4 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
+        self.filters_5 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
+        
+        weights = normalize_weights(self.get_weights(), out_channels)
+        self.set_weights(weights)
 
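+    # One iteration of the activation update: shrink u with a non-negative soft
+    # threshold, reconstruct the input from the shrunk activations, back-project the
+    # reconstruction error through the filters, and take an Adam-smoothed step on u.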
     @tf.function
     def do_update(self, images, u, m, v, b1, b2, eps, i):
         activations = tf.nn.relu(u - self.lam)
 
-        recon = self.do_recon(activations)
+        recon = do_recon(self.filters_1, self.filters_2, self.filters_3, self.filters_4, self.filters_5, activations, self.batch_size, self.stride)
         e = images - recon
         g = -1 * u
-
-        g = g + self.conv_error(e)
+                
+        convd_error = conv_error(self.filters_1, self.filters_2, self.filters_3, self.filters_4, self.filters_5, e, self.stride)
+        
+        g = g + convd_error
 
         g = g + activations
-
+        
         m = b1 * m + (1-b1) * g
         v = b2 * v + (1-b2) * g**2
         mh = m / (1 - b1**(1+i))
         vh = v / (1 - b2**(1+i))
-        u = u + (self.activation_lr * mh / (tf.math.sqrt(vh) + eps))
+        
+        du = self.activation_lr * mh / (tf.math.sqrt(vh) + eps)
+        u += du
 
         return u, m, v
 
-    def loss(self, images, activations):
-        recon = self.do_recon(activations)
-        loss = 0.5 * (1/images.shape[0]) * tf.sum(tf.math.pow(images - recon, 2))
-        loss += self.lam * tf.reduce_mean(tf.sum(tf.math.abs(activations.reshape(activations.shape[0], -1)), axis=1))
-        return loss
-
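+    # Infer sparse activations for a batch: run max_activation_iter Adam-style updates
+    # from a zero initialization, then apply the final non-negative soft threshold.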
     @tf.function
     def call(self, images):
-        u = tf.zeros(shape=(self.batch_size, 5, 100 // self.stride, 200 // self.stride, self.out_channels))
-        m = tf.zeros(shape=(self.batch_size, 5, 100 // self.stride, 200 // self.stride, self.out_channels))
-        v = tf.zeros(shape=(self.batch_size, 5, 100 // self.stride, 200 // self.stride, self.out_channels))
+        u = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
+        m = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
+        v = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
+        
+        tf.print('activations before:', tf.reduce_sum(u))
 
         b1 = 0.9
         b2 = 0.999
         eps = 1e-8
 
-        # images, u, m, v, b1, b2, eps, i = self.do_update(images, u, m, v, b1, b2, eps, 0)
-
-        # i = tf.constant(0, dtype=tf.float16)
-        # c = lambda images, u, m, v, b1, b2, eps, i: tf.less(i, 20)
-        # images, u, m, v, b1, b2, eps, i = tf.while_loop(c, self.do_update, [images, u, m, v, b1, b2, eps, i], shape_invariants=[images.get_shape(), tf.TensorShape([None, 100, 200, self.out_channels]), tf.TensorShape([None, 100, 200, self.out_channels]), tf.TensorShape([None, 100, 200, self.out_channels]), None, None, None, i.get_shape()])
         for i in range(self.max_activation_iter):
             u, m, v = self.do_update(images, u, m, v, b1, b2, eps, i)
 
         u = tf.nn.relu(u - self.lam)
-
-        self.add_loss(self.loss(images, u))
+        
+        tf.print('activations after:', tf.reduce_sum(u))
 
         return u
 
diff --git a/keras/train_sparse_model.py b/keras/train_sparse_model.py
index 72c9cfcae85236fdd0fd6fc73ba12f7d5370de13..54e0209ef0449d1ac16b5458859de48b03ab4d24 100644
--- a/keras/train_sparse_model.py
+++ b/keras/train_sparse_model.py
@@ -8,6 +8,9 @@ from tqdm import tqdm
 import argparse
 import os
 from sparse_coding_torch.load_data import load_yolo_clips
+import tensorflow.keras as keras
+import tensorflow as tf
+from keras_model import SparseCode, do_recon, normalize_weights
 
 def plot_video(video):
 
@@ -83,16 +86,21 @@ def plot_filters(filters):
 
     return FuncAnimation(plt.gcf(), update, interval=1000/20)
 
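+# Sparse coding objective: half the batch-averaged summed squared reconstruction error
+# plus an L1 penalty on the activations weighted by lam.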
+def sparse_loss(filters_1, filters_2, filters_3, filters_4, filters_5, images, activations, batch_size, lam, stride):
+    recon = do_recon(filters_1, filters_2, filters_3, filters_4, filters_5, activations, batch_size, stride)
+    loss = 0.5 * (1/batch_size) * tf.math.reduce_sum(tf.math.pow(images - recon, 2))
+    loss += lam * tf.reduce_mean(tf.math.reduce_sum(tf.math.abs(tf.reshape(activations, (batch_size, -1))), axis=1))
+    return loss
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--batch_size', default=12, type=int)
+    parser.add_argument('--batch_size', default=8, type=int)
     parser.add_argument('--kernel_size', default=15, type=int)
     parser.add_argument('--num_kernels', default=64, type=int)
-    parser.add_argument('--stride', default=1, type=int)
-    parser.add_argument('--max_activation_iter', default=200, type=int)
+    parser.add_argument('--stride', default=2, type=int)
+    parser.add_argument('--max_activation_iter', default=150, type=int)
     parser.add_argument('--activation_lr', default=1e-2, type=float)
-    parser.add_argument('--lr', default=1e-3, type=float)
+    parser.add_argument('--lr', default=1e-5, type=float)
     parser.add_argument('--epochs', default=100, type=int)
     parser.add_argument('--lam', default=0.05, type=float)
     parser.add_argument('--output_dir', default='./output', type=str)
@@ -103,16 +111,18 @@ if __name__ == "__main__":
     output_dir = args.output_dir
     if not os.path.exists(output_dir):
         os.makedirs(output_dir)
+        
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
     with open(os.path.join(output_dir, 'arguments.txt'), 'w+') as out_f:
         out_f.write(str(args))
 
-    train_loader, _ = load_yolo_clips(batch_size, mode='all_train')
+    train_loader, _ = load_yolo_clips(args.batch_size, num_clips=1, num_positives=15, mode='all_train', device=device, n_splits=1, sparse_model=None, whole_video=False, positive_videos='../positive_videos.json')
     print('Loaded', len(train_loader), 'train examples')
 
     example_data = next(iter(train_loader))
 
-    inputs = keras.Input(shape=(5, 100, 200, 1))
+    inputs = keras.Input(shape=(100, 200, 5))
 
     output = SparseCode('../sparse.pt', batch_size=args.batch_size, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter)(inputs)
 
@@ -124,21 +134,40 @@ if __name__ == "__main__":
     loss_log = []
     best_so_far = float('inf')
 
-    for epoch in tqdm(range(args.epochs)):
+    for epoch in range(args.epochs):
         epoch_loss = 0
         epoch_start = time.perf_counter()
 
-        for labels, local_batch, vid_f in train_loader:
+        for labels, local_batch, vid_f in tqdm(train_loader):
+            tf.print(vid_f)
+            if local_batch.size(0) != args.batch_size:
+                continue
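+            # Drop the singleton channel dim and move the five frames to the last axis,
+            # giving the channels-last (batch, 100, 200, 5) layout the Keras model expects.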
+            images = local_batch.squeeze(1).permute(0, 2, 3, 1).numpy()
+            
             with tf.GradientTape() as tape:
-
-                activations = model(local_batch.numpy())
-                loss = tf.sum(model.losses)
-
-                epoch_loss += loss * local_batch.size(0)
-
-                gradients = tape.gradient(loss, model.trainable_weights)
-
-            optimizer.apply_gradients(zip(gradients, model.trainable_weights))
+                activations = model(images)
+                
+                # Use the trainable variables rather than numpy copies from get_weights(),
+                # so the reconstruction term contributes gradients w.r.t. the filters.
+                filters = model.trainable_weights
+
+                loss = sparse_loss(filters[0], filters[1], filters[2], filters[3], filters[4], images, activations, args.batch_size, args.lam, args.stride)
+#             recon = model.do_recon(activations)
+#             loss = 0.5 * (1/batch_size) * tf.math.reduce_sum(tf.math.pow(images - recon, 2))
+#             loss += lam * tf.reduce_mean(tf.math.reduce_sum(tf.math.abs(tf.reshape(activations, (batch_size, -1))), axis=1))
+
+            epoch_loss += loss * local_batch.size(0)
+            tf.print('loss:', loss)
+
+            gradients = tape.gradient(loss, model.trainable_weights)
+            tf.print('gradients:', tf.reduce_sum(gradients))
+
+            filter_optimizer.apply_gradients(zip(gradients, model.trainable_weights))
+            
+            tf.print('weights:', tf.reduce_sum(model.trainable_weights))
+            
+            weights = normalize_weights(model.get_weights(), args.num_kernels)
+            model.set_weights(weights)
+            
+            tf.print('normalized weights:', tf.reduce_sum(model.trainable_weights))
 
         epoch_end = time.perf_counter()
         epoch_loss /= len(train_loader.sampler)