diff --git a/keras/keras_model.py b/keras/keras_model.py
index 6327fe99250ce8141393da80d3a7ddbaec3cbb05..919dc7db48b92a3aba8bc164a43f31c0337fd9b8 100644
--- a/keras/keras_model.py
+++ b/keras/keras_model.py
@@ -9,6 +9,12 @@ import torch.nn as nn
 from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler
 from sparse_coding_torch.conv_sparse_model import ConvSparseLayer
 
+def load_pytorch_weights(file_path):
+    pytorch_checkpoint = torch.load(file_path, map_location='cpu')
+    weight_tensor = pytorch_checkpoint['model_state_dict']['filters'].swapaxes(1,3).swapaxes(2,4).swapaxes(0,4).numpy()
+    
+    return weight_tensor
+
 @tf.function
 def do_recon(filters_1, filters_2, filters_3, filters_4, filters_5, activations, batch_size, stride):
     out_1 = tf.nn.conv2d_transpose(activations, filters_1, output_shape=(batch_size, 100, 200, 1), strides=stride)
@@ -21,8 +27,9 @@ def do_recon(filters_1, filters_2, filters_3, filters_4, filters_5, activations,
 
     return recon
 
-@tf.function
+# @tf.function
 def do_recon_3d(filters, activations, batch_size, stride):
+    activations = tf.pad(activations, paddings=[[0,0], [2, 2], [0, 0], [0, 0], [0, 0]])
     recon = tf.nn.conv3d_transpose(activations, filters, output_shape=(batch_size, 5, 100, 200, 1), strides=[1, stride, stride])
 
     return recon
@@ -39,9 +46,10 @@ def conv_error(filters_1, filters_2, filters_3, filters_4, filters_5, e, stride)
 
     return g
 
-@tf.function
+# @tf.function
 def conv_error_3d(filters, e, stride):
-    g = tf.nn.conv3d(e, filters, strides=[1, 1, stride, stride, 1], padding='SAME')
+    e = tf.pad(e, paddings=[[0,0], [0, 0], [7, 7], [7, 7], [0, 0]])
+    g = tf.nn.conv3d(e, filters, strides=[1, 1, stride, stride, 1], padding='VALID')
 
     return g
 
@@ -54,7 +62,7 @@ def normalize_weights(filters, out_channels):
     
     return adjusted
 
-@tf.function
+# @tf.function
 def normalize_weights_3d(filters, out_channels):
     norms = tf.norm(tf.reshape(filters[0], (out_channels, -1)), axis=1)
     norms = tf.broadcast_to(tf.math.maximum(norms, 1e-12*tf.ones_like(norms)), filters[0].shape)
@@ -117,9 +125,9 @@ class SparseCode(keras.layers.Layer):
             m = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
             v = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
         else:
-            u = tf.zeros(shape=(self.batch_size, 5, 100 // self.stride, 200 // self.stride, self.out_channels))
-            m = tf.zeros(shape=(self.batch_size, 5, 100 // self.stride, 200 // self.stride, self.out_channels))
-            v = tf.zeros(shape=(self.batch_size, 5, 100 // self.stride, 200 // self.stride, self.out_channels))
+            u = tf.zeros(shape=(self.batch_size, 1, 100 // self.stride, 200 // self.stride, self.out_channels))
+            m = tf.zeros(shape=(self.batch_size, 1, 100 // self.stride, 200 // self.stride, self.out_channels))
+            v = tf.zeros(shape=(self.batch_size, 1, 100 // self.stride, 200 // self.stride, self.out_channels))
         
 #         tf.print('activations before:', tf.reduce_sum(u))
 
@@ -211,3 +219,41 @@ class Classifier(keras.layers.Layer):
         x = self.ff_4(x)
 
         return x
+
+class MobileModel(keras.Model):
+    def __init__(self, sparse_checkpoint, batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d):
+        super().__init__()
+        self.sparse_code = SparseCode(batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d)
+        self.classifier = Classifier()
+        
+        self.out_channels = out_channels
+        self.in_channels = in_channels
+        self.stride = stride
+        self.lam = lam
+        self.activation_lr = activation_lr
+        self.max_activation_iter = max_activation_iter
+        self.batch_size = batch_size
+        self.run_2d = run_2d
+        
+        pytorch_weights = load_pytorch_weights(sparse_checkpoint)
+        
+        if run_2d:
+            weight_list = np.split(pytorch_weights, 5, axis=0)
+            self.filters_1 = tf.Variable(initial_value=weight_list[0].squeeze(0), dtype='float32', trainable=False)
+            self.filters_2 = tf.Variable(initial_value=weight_list[1].squeeze(0), dtype='float32', trainable=False)
+            self.filters_3 = tf.Variable(initial_value=weight_list[2].squeeze(0), dtype='float32', trainable=False)
+            self.filters_4 = tf.Variable(initial_value=weight_list[3].squeeze(0), dtype='float32', trainable=False)
+            self.filters_5 = tf.Variable(initial_value=weight_list[4].squeeze(0), dtype='float32', trainable=False)
+        else:
+            self.filters = tf.Variable(initial_value=pytorch_weights, dtype='float32', trainable=False)
+
+    @tf.function
+    def call(self, images):
+        if self.run_2d:
+            activations = self.sparse_code(images, [tf.stop_gradient(self.filters_1), tf.stop_gradient(self.filters_2), tf.stop_gradient(self.filters_3), tf.stop_gradient(self.filters_4), tf.stop_gradient(self.filters_5)])
+        else:
+            activations = self.sparse_code(images, tf.stop_gradient(self.filters))
+            
+        pred = self.classifier(activations)
+            
+        return pred
\ No newline at end of file