diff --git a/keras/generate_tflite.py b/keras/generate_tflite.py
index f68b95b30a5b37321e04f793a5884b1aebfcad26..3dd871a95cc24a0d310393e502829c642600cfd1 100644
--- a/keras/generate_tflite.py
+++ b/keras/generate_tflite.py
@@ -8,28 +8,27 @@ import torch
 import torch.nn as nn
 from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler
 from sparse_coding_torch.conv_sparse_model import ConvSparseLayer
-from keras_model import SparseCode, Classifier
+from keras_model import MobileModel
 
-inputs = keras.Input(shape=(5, 100, 200, 1))
+inputs = keras.Input(shape=(100, 200, 5))
 
-x = SparseCode('../sparse.pt', batch_size=1, in_channels=1, out_channels=64, kernel_size=15, stride=1, lam=0.05, activation_lr=1, max_activation_iter=1)(inputs)
-outputs = Classifier()(x)
+outputs = MobileModel(sparse_checkpoint='../sparse.pt', batch_size=1, in_channels=1, out_channels=64, kernel_size=15, stride=1, lam=0.05, activation_lr=1e-2, max_activation_iter=40, run_2d=True)(inputs)
 
-model = keras.Model(inputs=inputs, outputs=x)
+model = keras.Model(inputs=inputs, outputs=outputs)
 
 
-pytorch_checkpoint = torch.load('../classifier.pt', map_location='cpu')['model_state_dict']
+pytorch_checkpoint = torch.load('../output/final_model_75_iter/model-best_fold_0.pt', map_location='cpu')['model_state_dict']
 conv_weights = [pytorch_checkpoint['module.compress_activations_conv_1.weight'].view(8, 8, 64, 24).numpy(), pytorch_checkpoint['module.compress_activations_conv_1.bias'].numpy()]
-model.get_layer('classifier').conv.set_weights(conv_weights)
+model.get_layer('mobile_model').classifier.conv.set_weights(conv_weights)
 
 ff_1_weights = [pytorch_checkpoint['module.fc1.weight'].permute(1,0).numpy(), pytorch_checkpoint['module.fc1.bias'].numpy()]
-model.get_layer('classifier').ff_1.set_weights(ff_1_weights)
+model.get_layer('mobile_model').classifier.ff_1.set_weights(ff_1_weights)
 ff_2_weights = [pytorch_checkpoint['module.fc2.weight'].permute(1,0).numpy(), pytorch_checkpoint['module.fc2.bias'].numpy()]
-model.get_layer('classifier').ff_2.set_weights(ff_2_weights)
+model.get_layer('mobile_model').classifier.ff_2.set_weights(ff_2_weights)
 ff_3_weights = [pytorch_checkpoint['module.fc3.weight'].permute(1,0).numpy(), pytorch_checkpoint['module.fc3.bias'].numpy()]
-model.get_layer('classifier').ff_3.set_weights(ff_3_weights)
+model.get_layer('mobile_model').classifier.ff_3.set_weights(ff_3_weights)
 ff_4_weights = [pytorch_checkpoint['module.fc4.weight'].permute(1,0).numpy(), pytorch_checkpoint['module.fc4.bias'].numpy()]
-model.get_layer('classifier').ff_4.set_weights(ff_4_weights)
+model.get_layer('mobile_model').classifier.ff_4.set_weights(ff_4_weights)
 
 # frozen_sparse = ConvSparseLayer(in_channels=1,
 #                                out_channels=64,
@@ -39,9 +38,9 @@ model.get_layer('classifier').ff_4.set_weights(ff_4_weights)
 #                                convo_dim=3,
 #                                rectifier=True,
 #                                lam=0.05,
-#                                max_activation_iter=1,
-#                                activation_lr=1)
-#
+#                                max_activation_iter=10,
+#                                activation_lr=1e-2)
+
 # sparse_param = torch.load('../sparse.pt', map_location='cpu')
 # frozen_sparse.load_state_dict(sparse_param['model_state_dict'])
 #
@@ -60,11 +59,16 @@ model.get_layer('classifier').ff_4.set_weights(ff_4_weights)
 #  tv.transforms.CenterCrop((100, 200))
 # ])
 # img = transform(img)
-#
+
 # with torch.no_grad():
 #     activations = frozen_sparse(img.unsqueeze(0))
-#
-# output = model(img.unsqueeze(4).numpy())
+
+# output = model(img.swapaxes(1, 3).swapaxes(1,2).numpy())
+
+# print(activations.size())
+# print(output.shape)
+# print(torch.sum(activations))
+# print(tf.math.reduce_sum(output))
 
 input_name = model.input_names[0]
 index = model.input_names.index(input_name)
@@ -80,5 +84,5 @@ tflite_model = converter.convert()
 
 print('Converted')
 
-with open("./output/tf_lite_model.tflite", "wb") as f:
+with open("./mobile_output/tf_lite_model.tflite", "wb") as f:
     f.write(tflite_model)
diff --git a/keras/train_sparse_model.py b/keras/train_sparse_model.py
index f54caede35daaf45bdd1c266a7a4c86f7f0e452e..d543e3af8d470e2db1b090f0718a829e40fcd18c 100644
--- a/keras/train_sparse_model.py
+++ b/keras/train_sparse_model.py
@@ -10,7 +10,7 @@ import os
 from sparse_coding_torch.load_data import load_yolo_clips
 import tensorflow.keras as keras
 import tensorflow as tf
-from keras_model import normalize_weights_3d, normalize_weights, SparseCodeConv
+from keras_model import normalize_weights_3d, normalize_weights, SparseCodeConv, load_pytorch_weights
 
 def plot_video(video):
 
@@ -99,7 +99,7 @@ if __name__ == "__main__":
     parser.add_argument('--stride', default=2, type=int)
     parser.add_argument('--max_activation_iter', default=50, type=int)
     parser.add_argument('--activation_lr', default=1e-2, type=float)
-    parser.add_argument('--lr', default=1e-2, type=float)
+    parser.add_argument('--lr', default=5e-2, type=float)
     parser.add_argument('--epochs', default=100, type=int)
     parser.add_argument('--lam', default=0.05, type=float)
     parser.add_argument('--output_dir', default='./output', type=str)