diff --git a/keras/generate_tflite.py b/keras/generate_tflite.py
new file mode 100644
index 0000000000000000000000000000000000000000..f68b95b30a5b37321e04f793a5884b1aebfcad26
--- /dev/null
+++ b/keras/generate_tflite.py
@@ -0,0 +1,84 @@
+import os
+from tensorflow import keras
+import numpy as np
+import torch
+import tensorflow as tf
+import cv2
+import torchvision as tv
+import torch.nn as nn
+from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler
+from sparse_coding_torch.conv_sparse_model import ConvSparseLayer
+from keras_model import SparseCode, Classifier
+
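+# Functional model: SparseCode infers sparse activation maps for a clip whose five
+# frames are stacked along the channel axis, and Classifier maps those activations
+# to a single sigmoid output.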
+inputs = keras.Input(shape=(100, 200, 5))
+
+x = SparseCode('../sparse.pt', batch_size=1, in_channels=1, out_channels=64, kernel_size=15, stride=1, lam=0.05, activation_lr=1, max_activation_iter=1)(inputs)
+outputs = Classifier()(x)
+
+model = keras.Model(inputs=inputs, outputs=outputs)
+
+
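+# Copy the classifier weights from the PyTorch checkpoint. PyTorch stores Linear
+# weights as (out_features, in_features) while Keras Dense expects (in, out), hence
+# the permute(1, 0); the Conv2d weight (out, in, kh, kw) is permuted to the Keras
+# Conv2D layout (kh, kw, in, out).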
+pytorch_checkpoint = torch.load('../classifier.pt', map_location='cpu')['model_state_dict']
+conv_weights = [pytorch_checkpoint['module.compress_activations_conv_1.weight'].permute(2, 3, 1, 0).numpy(), pytorch_checkpoint['module.compress_activations_conv_1.bias'].numpy()]
+model.get_layer('classifier').conv.set_weights(conv_weights)
+
+ff_1_weights = [pytorch_checkpoint['module.fc1.weight'].permute(1,0).numpy(), pytorch_checkpoint['module.fc1.bias'].numpy()]
+model.get_layer('classifier').ff_1.set_weights(ff_1_weights)
+ff_2_weights = [pytorch_checkpoint['module.fc2.weight'].permute(1,0).numpy(), pytorch_checkpoint['module.fc2.bias'].numpy()]
+model.get_layer('classifier').ff_2.set_weights(ff_2_weights)
+ff_3_weights = [pytorch_checkpoint['module.fc3.weight'].permute(1,0).numpy(), pytorch_checkpoint['module.fc3.bias'].numpy()]
+model.get_layer('classifier').ff_3.set_weights(ff_3_weights)
+ff_4_weights = [pytorch_checkpoint['module.fc4.weight'].permute(1,0).numpy(), pytorch_checkpoint['module.fc4.bias'].numpy()]
+model.get_layer('classifier').ff_4.set_weights(ff_4_weights)
+
+# frozen_sparse = ConvSparseLayer(in_channels=1,
+#                                out_channels=64,
+#                                kernel_size=(5, 15, 15),
+#                                stride=1,
+#                                padding=(0, 7, 7),
+#                                convo_dim=3,
+#                                rectifier=True,
+#                                lam=0.05,
+#                                max_activation_iter=1,
+#                                activation_lr=1)
+#
+# sparse_param = torch.load('../sparse.pt', map_location='cpu')
+# frozen_sparse.load_state_dict(sparse_param['model_state_dict'])
+#
+# # pytorch_filter = frozen_sparse.filters[30, :, 0, :, :].squeeze(0).unsqueeze(2).detach().numpy()
+# # keras_filter = model.get_layer('sparse_code').filter[0,:,:,:,30].numpy()
+# #
+# # cv2.imwrite('pytorch_filter.png', pytorch_filter / np.max(pytorch_filter) * 255.)
+# # cv2.imwrite('keras_filter.png', keras_filter / np.max(keras_filter) * 255.)
+# # raise Exception
+#
+# img = tv.io.read_video('../clips/No_Sliding/Image_262499828648_clean1050.mp4')[0].permute(3, 0, 1, 2)
+# transform = tv.transforms.Compose(
+# [VideoGrayScaler(),
+#  MinMaxScaler(0, 255),
+#  tv.transforms.Normalize((0.2592,), (0.1251,)),
+#  tv.transforms.CenterCrop((100, 200))
+# ])
+# img = transform(img)
+#
+# with torch.no_grad():
+#     activations = frozen_sparse(img.unsqueeze(0))
+#
+# output = model(img.unsqueeze(4).numpy())
+
+# Pin the batch dimension to 1 so the converter sees a fully static input shape.
+model.inputs[0].set_shape([1, 100, 200, 5])
+
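+# Post-training float16 quantization: Optimize.DEFAULT together with the float16
+# supported type stores the weights at half precision while keeping builtin TFLite ops.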
+converter = tf.lite.TFLiteConverter.from_keras_model(model)
+# converter.experimental_new_converter = True
+converter.optimizations = [tf.lite.Optimize.DEFAULT]
+converter.target_spec.supported_types = [tf.float16]
+converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
+
+tflite_model = converter.convert()
+
+print('Converted')
+
+os.makedirs('./output', exist_ok=True)
+with open("./output/tf_lite_model.tflite", "wb") as f:
+    f.write(tflite_model)
diff --git a/keras/keras_model.py b/keras/keras_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..2360c59651090ccee188aa1ab899cc8567be7b83
--- /dev/null
+++ b/keras/keras_model.py
@@ -0,0 +1,160 @@
+from tensorflow import keras
+import numpy as np
+import torch
+import tensorflow as tf
+import cv2
+import torchvision as tv
+import torch.nn as nn
+from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler
+from sparse_coding_torch.conv_sparse_model import ConvSparseLayer
+
+class SparseCode(keras.layers.Layer):
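+    """Convolutional sparse coding layer.
+
+    Runs a fixed number of activation-inference iterations (locally-competitive-style
+    updates with Adam-like momentum, followed by a non-negative soft threshold) against
+    five per-frame 2D filter banks. Intended to mirror the PyTorch ConvSparseLayer used
+    elsewhere in this repository.
+    """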
+    def __init__(self, pytorch_checkpoint, batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter):
+        super(SparseCode, self).__init__()
+
+        self.out_channels = out_channels
+        self.in_channels = in_channels
+        self.stride = stride
+        self.lam = lam
+        self.activation_lr = activation_lr
+        self.max_activation_iter = max_activation_iter
+        self.batch_size = batch_size
+
+        # pytorch_checkpoint = torch.load(pytorch_checkpoint, map_location='cpu')
+        # weight_tensor = pytorch_checkpoint['model_state_dict']['filters'].swapaxes(1,3).swapaxes(2,4).swapaxes(0,4)
+        # self.filter = tf.Variable(initial_value=weight_tensor.numpy(), dtype='float32')
+        # weight_list = torch.chunk(weight_tensor, 5, dim=2)
+        #
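+        # The five filter banks below correspond to the five temporal slices of the 3D
+        # sparse-coding dictionary (one 2D bank per input frame). They are randomly
+        # initialized here; the commented-out code above shows how they could instead be
+        # populated from the PyTorch checkpoint.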
+        initializer = tf.keras.initializers.HeNormal()
+        # Filters are laid out as (kernel_h, kernel_w, image_channels, activation_channels),
+        # the shape expected by the tf.nn.conv2d / conv2d_transpose calls below.
+        self.filters_1 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
+        self.filters_2 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
+        self.filters_3 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
+        self.filters_4 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
+        self.filters_5 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
+
+    def normalize_weights(self):
+        # Give each dictionary atom (output channel) unit L2 norm across all five
+        # temporal filter banks; use assign() so the variables remain trainable.
+        stacked = tf.stack([self.filters_1, self.filters_2, self.filters_3, self.filters_4, self.filters_5])
+        norms = tf.norm(tf.reshape(tf.transpose(stacked, perm=[4, 0, 1, 2, 3]), (self.out_channels, -1)), axis=1)
+        norms = tf.reshape(tf.math.maximum(norms, 1e-12), (1, 1, 1, self.out_channels))
+        self.filters_1.assign(self.filters_1 / norms)
+        self.filters_2.assign(self.filters_2 / norms)
+        self.filters_3.assign(self.filters_3 / norms)
+        self.filters_4.assign(self.filters_4 / norms)
+        self.filters_5.assign(self.filters_5 / norms)
+
+    @tf.function
+    def do_recon(self, activations):
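+        # Reconstruct each of the five frames from the shared activation map with a
+        # transposed convolution, then stack the frames back along the channel axis.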
+        out_1 = tf.nn.conv2d_transpose(activations, self.filters_1, output_shape=(self.batch_size, 100, 200, 1), strides=self.stride)
+        out_2 = tf.nn.conv2d_transpose(activations, self.filters_2, output_shape=(self.batch_size, 100, 200, 1), strides=self.stride)
+        out_3 = tf.nn.conv2d_transpose(activations, self.filters_3, output_shape=(self.batch_size, 100, 200, 1), strides=self.stride)
+        out_4 = tf.nn.conv2d_transpose(activations, self.filters_4, output_shape=(self.batch_size, 100, 200, 1), strides=self.stride)
+        out_5 = tf.nn.conv2d_transpose(activations, self.filters_5, output_shape=(self.batch_size, 100, 200, 1), strides=self.stride)
+
+        recon = tf.concat([out_1, out_2, out_3, out_4, out_5], axis=3)
+
+        return recon
+
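+    # NOTE: the *_3d variants below operate on the single 3D dictionary `self.filter`,
+    # which is only created by the commented-out checkpoint loading in __init__.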
+    @tf.function
+    def do_recon_3d(self, activations):
+        recon = tf.nn.conv3d_transpose(activations, self.filter, output_shape=(self.batch_size, 5, 100, 200, 1), strides=self.stride)
+
+        return recon
+
+    @tf.function
+    def conv_error(self, e):
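+        # Correlate each frame's reconstruction error with its filter bank and sum over
+        # frames, projecting the residual back into the activation domain.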
+        e1, e2, e3, e4, e5 = tf.split(e, 5, axis=3)
+
+        g = tf.nn.conv2d(e1, self.filters_1, strides=self.stride, padding='SAME')
+        g = g + tf.nn.conv2d(e2, self.filters_2, strides=self.stride, padding='SAME')
+        g = g + tf.nn.conv2d(e3, self.filters_3, strides=self.stride, padding='SAME')
+        g = g + tf.nn.conv2d(e4, self.filters_4, strides=self.stride, padding='SAME')
+        g = g + tf.nn.conv2d(e5, self.filters_5, strides=self.stride, padding='SAME')
+
+        return g
+
+    @tf.function
+    def conv_error_3d(self, e):
+        g = tf.nn.conv3d(e, self.filter, strides=[1, 1, 1, 1, 1], padding='SAME')
+
+        return g
+
+    @tf.function
+    def do_update(self, images, u, m, v, b1, b2, eps, i):
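+        # One locally-competitive-style inference step: soft-threshold the state u,
+        # compute the reconstruction residual, then apply an Adam-style step to u
+        # (m and v are first/second-moment accumulators, i indexes the bias correction).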
+        activations = tf.nn.relu(u - self.lam)
+
+        recon = self.do_recon(activations)
+        e = images - recon
+        g = -1 * u
+
+        g = g + self.conv_error(e)
+
+        g = g + activations
+
+        m = b1 * m + (1-b1) * g
+        v = b2 * v + (1-b2) * g**2
+        mh = m / (1 - b1**(1+i))
+        vh = v / (1 - b2**(1+i))
+        u = u + (self.activation_lr * mh / (tf.math.sqrt(vh) + eps))
+
+        return u, m, v
+
+    def loss(self, images, activations):
+        recon = self.do_recon(activations)
+        loss = 0.5 * (1 / self.batch_size) * tf.reduce_sum(tf.math.pow(images - recon, 2))
+        loss += self.lam * tf.reduce_mean(tf.reduce_sum(tf.math.abs(tf.reshape(activations, (self.batch_size, -1))), axis=1))
+        return loss
+
+    @tf.function
+    def call(self, images):
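+        # u holds the pre-threshold activation state; m and v are its Adam moment
+        # estimates. The inference loop below is unrolled max_activation_iter times.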
+        u = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
+        m = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
+        v = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
+
+        b1 = 0.9
+        b2 = 0.999
+        eps = 1e-8
+
+        # images, u, m, v, b1, b2, eps, i = self.do_update(images, u, m, v, b1, b2, eps, 0)
+
+        # i = tf.constant(0, dtype=tf.float16)
+        # c = lambda images, u, m, v, b1, b2, eps, i: tf.less(i, 20)
+        # images, u, m, v, b1, b2, eps, i = tf.while_loop(c, self.do_update, [images, u, m, v, b1, b2, eps, i], shape_invariants=[images.get_shape(), tf.TensorShape([None, 100, 200, self.out_channels]), tf.TensorShape([None, 100, 200, self.out_channels]), tf.TensorShape([None, 100, 200, self.out_channels]), None, None, None, i.get_shape()])
+        for i in range(self.max_activation_iter):
+            u, m, v = self.do_update(images, u, m, v, b1, b2, eps, i)
+
+        u = tf.nn.relu(u - self.lam)
+
+        self.add_loss(self.loss(images, u))
+
+        return u
+
+class Classifier(keras.layers.Layer):
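+    """Small convolutional head that maps the sparse activation maps to a single
+    sigmoid probability (max-pool, conv, then a stack of dense layers with dropout)."""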
+    def __init__(self):
+        super(Classifier, self).__init__()
+
+        self.max_pool = keras.layers.MaxPooling2D(pool_size=4, strides=4)
+        self.conv = keras.layers.Conv2D(24, kernel_size=8, strides=4, activation='relu', padding='SAME')
+
+        self.flatten = keras.layers.Flatten()
+
+        self.dropout = keras.layers.Dropout(0.5)
+
+        self.ff_1 = keras.layers.Dense(1000, activation='relu', use_bias=True)
+        self.ff_2 = keras.layers.Dense(100, activation='relu', use_bias=True)
+        self.ff_3 = keras.layers.Dense(20, activation='relu', use_bias=True)
+        self.ff_4 = keras.layers.Dense(1, activation='sigmoid')
+
+    @tf.function
+    def call(self, activations, training=None):
+        x = self.max_pool(activations)
+        x = self.conv(x)
+        x = self.flatten(x)
+        x = self.ff_1(x)
+        x = self.dropout(x, training=training)
+        x = self.ff_2(x)
+        x = self.dropout(x, training=training)
+        x = self.ff_3(x)
+        x = self.dropout(x, training=training)
+        x = self.ff_4(x)
+
+        return x
diff --git a/keras/train_sparse_model.py b/keras/train_sparse_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..72c9cfcae85236fdd0fd6fc73ba12f7d5370de13
--- /dev/null
+++ b/keras/train_sparse_model.py
@@ -0,0 +1,157 @@
+import time
+import numpy as np
+import torch
+import tensorflow as tf
+from tensorflow import keras
+from matplotlib import pyplot as plt
+from matplotlib import cm
+from matplotlib.animation import FuncAnimation
+from tqdm import tqdm
+import argparse
+import os
+from sparse_coding_torch.load_data import load_yolo_clips
+from keras_model import SparseCode
+
+def plot_video(video):
+
+    fig = plt.gcf()
+    ax = plt.gca()
+
+    DPI = fig.get_dpi()
+    fig.set_size_inches(video.shape[3]/float(DPI), video.shape[2]/float(DPI))
+
+    ax.set_title("Video")
+
+    T = video.shape[1]
+    im = ax.imshow(video[0, 0, :, :],
+                     cmap=cm.Greys_r)
+
+    def update(i):
+        t = i % T
+        im.set_data(video[0, t, :, :])
+
+    return FuncAnimation(plt.gcf(), update, interval=1000/20)
+
+def plot_original_vs_recon(original, reconstruction, idx=0):
+
+    # create two subplots
+    ax1 = plt.subplot(1, 2, 1)
+    ax2 = plt.subplot(1, 2, 2)
+    ax1.set_title("Original")
+    ax2.set_title("Reconstruction")
+
+    T = original.shape[2]
+    im1 = ax1.imshow(original[idx, 0, 0, :, :],
+                     cmap=cm.Greys_r)
+    im2 = ax2.imshow(reconstruction[idx, 0, 0, :, :],
+                     cmap=cm.Greys_r)
+
+    def update(i):
+        t = i % T
+        im1.set_data(original[idx, 0, t, :, :])
+        im2.set_data(reconstruction[idx, 0, t, :, :])
+
+    return FuncAnimation(plt.gcf(), update, interval=1000/30)
+
+
+def plot_filters(filters):
+    num_filters = filters.shape[0]
+    ncol = 3
+    # ncol = int(np.sqrt(num_filters))
+    # nrow = int(np.sqrt(num_filters))
+    T = filters.shape[2]
+
+    if num_filters // ncol == num_filters / ncol:
+        nrow = num_filters // ncol
+    else:
+        nrow = num_filters // ncol + 1
+
+    fig, axes = plt.subplots(ncols=ncol, nrows=nrow,
+                             constrained_layout=True,
+                             figsize=(ncol*2, nrow*2))
+
+    ims = {}
+    for i in range(num_filters):
+        r = i // ncol
+        c = i % ncol
+        ims[(r, c)] = axes[r, c].imshow(filters[i, 0, 0, :, :],
+                                        cmap=cm.Greys_r)
+
+    def update(i):
+        t = i % T
+        for i in range(num_filters):
+            r = i // ncol
+            c = i % ncol
+            ims[(r, c)].set_data(filters[i, 0, t, :, :])
+
+    return FuncAnimation(plt.gcf(), update, interval=1000/20)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--batch_size', default=12, type=int)
+    parser.add_argument('--kernel_size', default=15, type=int)
+    parser.add_argument('--num_kernels', default=64, type=int)
+    parser.add_argument('--stride', default=1, type=int)
+    parser.add_argument('--max_activation_iter', default=200, type=int)
+    parser.add_argument('--activation_lr', default=1e-2, type=float)
+    parser.add_argument('--lr', default=1e-3, type=float)
+    parser.add_argument('--epochs', default=100, type=int)
+    parser.add_argument('--lam', default=0.05, type=float)
+    parser.add_argument('--output_dir', default='./output', type=str)
+    parser.add_argument('--seed', default=42, type=int)
+
+    args = parser.parse_args()
+
+    output_dir = args.output_dir
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+
+    with open(os.path.join(output_dir, 'arguments.txt'), 'w+') as out_f:
+        out_f.write(str(args))
+
+    train_loader, _ = load_yolo_clips(args.batch_size, mode='all_train')
+    print('Loaded', len(train_loader), 'train batches')
+
+    example_data = next(iter(train_loader))
+
+    inputs = keras.Input(shape=(100, 200, 5))
+
+    output = SparseCode('../sparse.pt', batch_size=args.batch_size, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter)(inputs)
+
+    model = keras.Model(inputs=inputs, outputs=output)
+
+    learning_rate = args.lr
+    filter_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
+
+    loss_log = []
+    best_so_far = float('inf')
+
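+    # Only the dictionary filters (the SparseCode layer's variables) are trained here;
+    # the sparse-coding objective is exposed through model.losses via add_loss() inside
+    # the layer's forward pass.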
+    for epoch in tqdm(range(args.epochs)):
+        epoch_loss = 0
+        epoch_start = time.perf_counter()
+
+        for labels, local_batch, vid_f in train_loader:
+            with tf.GradientTape() as tape:
+                # The batch comes from the PyTorch loader; convert it to NumPy before
+                # feeding the Keras model.
+                activations = model(local_batch.numpy())
+                loss = tf.add_n(model.losses)
+
+            epoch_loss += loss * local_batch.size(0)
+
+            gradients = tape.gradient(loss, model.trainable_weights)
+            filter_optimizer.apply_gradients(zip(gradients, model.trainable_weights))
+
+        epoch_end = time.perf_counter()
+        epoch_loss /= len(train_loader.sampler)
+
+        if epoch_loss < best_so_far:
+            print("found better model")
+            # Save model parameters
+            # Saves in TensorFlow's SavedModel directory format.
+            model.save(os.path.join(output_dir, "sparse_conv3d_model-best"))
+            best_so_far = epoch_loss
+
+        loss_log.append(epoch_loss)
+        print('epoch={}, epoch_loss={:.2f}, time={:.2f}'.format(epoch, epoch_loss, epoch_end - epoch_start))
+
+    plt.plot(loss_log)
+
+    plt.savefig(os.path.join(output_dir, 'loss_graph.png'))