diff --git a/generate_yolo_regions.py b/generate_yolo_regions.py
new file mode 100644
index 0000000000000000000000000000000000000000..864c493833e2006a5c0670c056c01e5d07d7e39b
--- /dev/null
+++ b/generate_yolo_regions.py
@@ -0,0 +1,112 @@
+import torch
+import os
+import time
+import numpy as np
+import torchvision
+from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler, get_yolo_regions, classify_nerve_is_right, load_pnb_region_labels
+from torchvision.datasets.video_utils import VideoClips
+import torchvision as tv
+import csv
+from datetime import datetime
+from yolov4.get_bounding_boxes import YoloModel
+import argparse
+import tensorflow as tf
+import scipy.stats
+import cv2
+import tensorflow.keras as keras
+from sparse_coding_torch.keras_model import SparseCode, PNBClassifier, PTXClassifier, ReconSparse
+import glob
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--input_video', required=True, type=str, help='Path to input video.')
+    parser.add_argument('--output_dir', default='yolo_output', type=str, help='Location where yolo clips should be saved.')
+    parser.add_argument('--num_frames', default=5, type=int)
+    
+    args = parser.parse_args()
+    
+    path = args.input_video
+        
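+    # The SME region labels CSV is expected three directory levels above the input video
+    # (assumes a <root>/<Positives|Negatives>/<person_idx>/<video>.mp4 layout)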
+    region_labels = load_pnb_region_labels(os.path.join('/'.join(path.split('/')[:-3]), 'sme_region_labels.csv'))
+                        
+    if not os.path.exists(args.output_dir):
+        os.makedirs(args.output_dir)
+
+    image_height = 285
+    image_width = 235
+    
+    # Clip height and width must be even for video encoding, so odd sizes are rounded down to the nearest even value
+    transforms = torchvision.transforms.Compose([
+     torchvision.transforms.Resize(((image_height//2)*2, (image_width//2)*2))
+    ])
+        
+    yolo_model = YoloModel()
+        
+    vc = tv.io.read_video(path)[0].permute(3, 0, 1, 2)
+    is_right = classify_nerve_is_right(yolo_model, vc)
+    person_idx = path.split('/')[-2]
+    label = path.split('/')[-3]
+    
+    output_count = 0
+
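+    # Positive videos with SME region labels are split into their annotated negative and
+    # positive frame ranges; each range is scanned in 5-frame windows and every YOLO-cropped
+    # region is written out as a clip. Unlabeled positives use only the last 5 frames, and
+    # negatives are scanned end to end.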
+    if label == 'Positives' and person_idx in region_labels:
+        negative_regions, positive_regions = region_labels[person_idx]
+        for sub_region in negative_regions.split(','):
+            sub_region = sub_region.split('-')
+            start_loc = int(sub_region[0])
+            end_loc = int(sub_region[1]) + 1
+            for frame in range(start_loc, end_loc - 5, 5):
+                vc_sub = vc[:, frame:frame+5, :, :]
+                if vc_sub.size(1) < 5:
+                    continue
+
+                for clip in get_yolo_regions(yolo_model, vc_sub, is_right):
+                    clip = transforms(clip)
+                    tv.io.write_video(os.path.join(args.output_dir, 'negative_yolo' + str(output_count) + '.mp4'), clip.swapaxes(0,1).swapaxes(1,2).swapaxes(2,3).numpy(), fps=20)
+                    output_count += 1
+
+        if positive_regions:
+            for sub_region in positive_regions.split(','):
+                sub_region = sub_region.split('-')
+                start_loc = int(sub_region[0])
+                if len(sub_region) == 1:
+                    vc_sub = vc[:, start_loc:start_loc+5, :, :]
+                    if vc_sub.size(1) < 5:
+                        continue
+
+                    for clip in get_yolo_regions(yolo_model, vc_sub, is_right):
+                        clip = transforms(clip)
+                        tv.io.write_video(os.path.join(args.output_dir, 'positive_yolo' + str(output_count) + '.mp4'), clip.swapaxes(0,1).swapaxes(1,2).swapaxes(2,3).numpy(), fps=20)
+                        output_count += 1
+                else:
+                    end_loc = sub_region[1]
+                    if end_loc.strip().lower() == 'end':
+                        end_loc = vc.size(1)
+                    else:
+                        end_loc = int(end_loc)
+                    for frame in range(start_loc, end_loc - 5, 5):
+                        vc_sub = vc[:, frame:frame+5, :, :]
+#                                         cv2.imwrite('test.png', vc_sub[0, 0, :, :].unsqueeze(2).numpy())
+                        if vc_sub.size(1) < 5:
+                            continue
+                        for clip in get_yolo_regions(yolo_model, vc_sub, is_right):
+                            clip = transforms(clip)
+                            tv.io.write_video(os.path.join(args.output_dir, 'positive_yolo' + str(output_count) + '.mp4'), clip.swapaxes(0,1).swapaxes(1,2).swapaxes(2,3).numpy(), fps=20)
+                            output_count += 1
+    elif label == 'Positives':
+        vc_sub = vc[:, -5:, :, :]
+        if not vc_sub.size(1) < 5:
+            for clip in get_yolo_regions(yolo_model, vc_sub, is_right):
+                clip = transforms(clip)
+                tv.io.write_video(os.path.join(args.output_dir, 'positive_yolo' + str(output_count) + '.mp4'), clip.swapaxes(0,1).swapaxes(1,2).swapaxes(2,3).numpy(), fps=20)
+                output_count += 1
+    elif label == 'Negatives':
+        for j in range(0, vc.size(1) - 5, 5):
+            vc_sub = vc[:, j:j+5, :, :]
+            if not vc_sub.size(1) < 5:
+                for clip in get_yolo_regions(yolo_model, vc_sub, is_right):
+                    clip = transforms(clip)
+                    tv.io.write_video(os.path.join(args.output_dir, 'negative_yolo' + str(output_count) + '.mp4'), clip.swapaxes(0,1).swapaxes(1,2).swapaxes(2,3).numpy(), fps=20)
+                    output_count += 1
+    else:
+        raise Exception('Invalid label')
\ No newline at end of file
diff --git a/run_pnb.py b/run_pnb.py
new file mode 100644
index 0000000000000000000000000000000000000000..52a023c56604b9849a37eca5afe99f27ff306e5d
--- /dev/null
+++ b/run_pnb.py
@@ -0,0 +1,119 @@
+import torch
+import os
+import time
+import numpy as np
+import torchvision
+from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler, get_yolo_regions, classify_nerve_is_right
+from torchvision.datasets.video_utils import VideoClips
+import torchvision as tv
+import csv
+from datetime import datetime
+from yolov4.get_bounding_boxes import YoloModel
+import argparse
+import tensorflow as tf
+import scipy.stats
+import cv2
+import tensorflow.keras as keras
+from sparse_coding_torch.keras_model import SparseCode, PNBClassifier, PTXClassifier, ReconSparse
+import glob
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--input_dir', default='/shared_data/bamc_pnb_data/full_training_data', type=str)
+    parser.add_argument('--kernel_size', default=15, type=int)
+    parser.add_argument('--kernel_depth', default=5, type=int)
+    parser.add_argument('--num_kernels', default=48, type=int)
+    parser.add_argument('--stride', default=1, type=int)
+    parser.add_argument('--max_activation_iter', default=150, type=int)
+    parser.add_argument('--activation_lr', default=1e-2, type=float)
+    parser.add_argument('--lam', default=0.05, type=float)
+    parser.add_argument('--sparse_checkpoint', default='sparse_coding_torch/output/sparse_pnb_48/sparse_conv3d_model-best.pt/', type=str)
+    parser.add_argument('--checkpoint', default='sparse_coding_torch/classifier_outputs/48_filters_6/best_classifier.pt/', type=str)
+    parser.add_argument('--run_2d', action='store_true')
+    
+    args = parser.parse_args()
+    #print(args.accumulate(args.integers))
+    batch_size = 1
+
+    image_height = 285
+    image_width = 235
+
+    if args.run_2d:
+        inputs = keras.Input(shape=(image_height, image_width, 5))
+    else:
+        inputs = keras.Input(shape=(5, image_height, image_width, 1))
+
+    filter_inputs = keras.Input(shape=(5, args.kernel_size, args.kernel_size, 1, args.num_kernels), dtype='float32')
+
+    output = SparseCode(batch_size=batch_size, image_height=image_height, image_width=image_width, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d, padding='VALID')(inputs, filter_inputs)
+
+    sparse_model = keras.Model(inputs=(inputs, filter_inputs), outputs=output)
+
+    recon_model = keras.models.load_model(args.sparse_checkpoint)
+        
+    classifier_model = keras.models.load_model(args.checkpoint)
+        
+    yolo_model = YoloModel()
+
+    transform = torchvision.transforms.Compose(
+    [VideoGrayScaler(),
+     MinMaxScaler(0, 255),
+     torchvision.transforms.Resize((285, 235))
+    ])
+
+    all_predictions = []
+
+    all_files = glob.glob(pathname=os.path.join(args.input_dir, '**', '*.mp4'), recursive=True)
+
+    for f in all_files:
+        print('Processing', f)
+        
+        vc = tv.io.read_video(f)[0].permute(3, 0, 1, 2)
+        is_right = classify_nerve_is_right(yolo_model, vc)
+        
+        vc_sub = vc[:, -5:, :, :]
+        if vc_sub.size(1) < 5:
+            print(f + ' does not contain enough frames for processing')
+            continue
+            
+        ### START time after loading video ###
+        start_time = time.time()
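+        # Scan backwards from the end of the video in 5-frame windows (at most 4 windows)
+        # until YOLO returns a croppable nerve region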
+        clip = None
+        i = 1
+        while not clip and i < 5:
+            if vc_sub.size(1) < 5:
+                break
+            clip = get_yolo_regions(yolo_model, vc_sub, is_right)
+            vc_sub = vc[:, -5*(i+1):-5*i, :, :]
+            i += 1
+            
+        if clip:
+            clip = clip[0]
+            clip = transform(clip)
+            clip = tf.expand_dims(clip, axis=4) 
+
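+            # Encode the clip against the frozen sparse dictionary (the recon model's filters),
+            # then classify the resulting activation map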
+            activations = tf.stop_gradient(sparse_model([clip, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))]))
+
+            pred = tf.math.sigmoid(classifier_model(activations))
+
+            final_pred = tf.math.round(pred)
+
+            if final_pred == 1:
+                str_pred = 'Positive'
+            else:
+                str_pred = 'Negative'
+        else:
+            print('No YOLO region found in the sampled windows; defaulting to Positive')
+            str_pred = "Positive"
+
+        end_time = time.time()
+
+        print(str_pred)
+
+        all_predictions.append({'FileName': f, 'Prediction': str_pred, 'TotalTimeSec': end_time - start_time})
+
+    with open('output_' + datetime.now().strftime("%Y%m%d-%H%M%S") + '.csv', 'w+', newline='') as csv_out:
+        writer = csv.DictWriter(csv_out, fieldnames=all_predictions[0].keys())
+
+        writer.writeheader()
+        writer.writerows(all_predictions)
diff --git a/run_ptx.py b/run_ptx.py
index 5b5ee354fcb6138f8c96870b7c891f1699e64bd7..5ee5656f7e1db3638e5352dbf1e8e171c506b8ac 100644
--- a/run_ptx.py
+++ b/run_ptx.py
@@ -49,11 +49,11 @@ if __name__ == "__main__":
 
     filter_inputs = keras.Input(shape=(5, args.kernel_size, args.kernel_size, 1, args.num_kernels), dtype='float32')
 
-    output = SparseCode(batch_size=batch_size, image_height=image_height, image_width=image_width, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(inputs, filter_inputs)
+    output = SparseCode(batch_size=batch_size, image_height=image_height, image_width=image_width, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d, padding='SAME')(inputs, filter_inputs)
 
     sparse_model = keras.Model(inputs=(inputs, filter_inputs), outputs=output)
     
-    recon_inputs = keras.Input(shape=(1, (image_height - args.kernel_size) // args.stride + 1, (image_width - args.kernel_size) // args.stride + 1, args.num_kernels))
+    recon_inputs = keras.Input(shape=(1, image_height // args.stride, image_width // args.stride, args.num_kernels))
     
     recon_outputs = ReconSparse(batch_size=batch_size, image_height=image_height, image_width=image_width, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(recon_inputs)
     
@@ -65,7 +65,7 @@ if __name__ == "__main__":
     if args.checkpoint:
         classifier_model = keras.models.load_model(args.checkpoint)
     else:
-        classifier_inputs = keras.Input(shape=(1, (image_height - args.kernel_size) // args.stride + 1, (image_width - args.kernel_size) // args.stride + 1, args.num_kernels))
+        classifier_inputs = keras.Input(shape=(1, image_height // args.stride, image_width // args.stride, args.num_kernels))
 
         if args.dataset == 'pnb':
             classifier_outputs = PNBClassifier()(classifier_inputs)
@@ -111,7 +111,8 @@ if __name__ == "__main__":
             clip, _, _, _ = vc.get_clip(i)
             clip = clip.swapaxes(1, 3).swapaxes(0, 1).swapaxes(2, 3).numpy()
             
-            bounding_boxes = yolo_model.get_bounding_boxes(clip[:, 2, :, :].swapaxes(0, 2).swapaxes(0, 1)).squeeze(0)
+            bounding_boxes, classes = yolo_model.get_bounding_boxes(clip[:, 2, :, :].swapaxes(0, 2).swapaxes(0, 1))
+            bounding_boxes = bounding_boxes.squeeze(0)
             if bounding_boxes.size == 0:
                 continue
             #widths = []
diff --git a/run_tflite_pnb.py b/run_tflite_pnb.py
index 2b0f4fd0d57654a46a2d4e7266c067bbd00a4d03..ce690e34454a332b3bf191d63e18eb4f883a13a7 100644
--- a/run_tflite_pnb.py
+++ b/run_tflite_pnb.py
@@ -3,7 +3,7 @@ import os
 import time
 import numpy as np
 import torchvision
-from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler, get_yolo_regions
+from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler, get_yolo_regions, classify_nerve_is_right
 from torchvision.datasets.video_utils import VideoClips
 import csv
 from datetime import datetime
@@ -12,12 +12,14 @@ import argparse
 import tensorflow as tf
 import scipy.stats
 import cv2
+import glob
+import torchvision as tv
 
 if __name__ == "__main__":
 
     parser = argparse.ArgumentParser(description='Python program for processing PNB data')
-    parser.add_argument('--classifier', type=str, default='keras/mobile_output/tf_lite_model.tflite')
-    parser.add_argument('--input_dir', type=str, default='input_videos')
+    parser.add_argument('--classifier', type=str, default='sparse_coding_torch/mobile_output/pnb.tflite')
+    parser.add_argument('--input_dir', default='/shared_data/bamc_pnb_data/full_training_data', type=str)
     args = parser.parse_args()
 
     interpreter = tf.lite.Interpreter(args.classifier)
@@ -29,31 +31,41 @@ if __name__ == "__main__":
     yolo_model = YoloModel()
 
     transform = torchvision.transforms.Compose(
-    [VideoGrayScaler(),
+    [#VideoGrayScaler(),
      MinMaxScaler(0, 255),
-     torchvision.transforms.Resize((360, 304))
+#      torchvision.transforms.Resize((285, 235))
     ])
 
     all_predictions = []
 
-    all_files = list(os.listdir(args.input_dir))
+    all_files = glob.glob(pathname=os.path.join(args.input_dir, '**', '*.mp4'), recursive=True)
 
     for f in all_files:
         print('Processing', f)
         
-        vc = tv.io.read_video(os.path.join(args.input_dir, f))[0].permute(3, 0, 1, 2)
+        vc = tv.io.read_video(f)[0].permute(3, 0, 1, 2)
+        is_right = classify_nerve_is_right(yolo_model, vc)
         
         vc_sub = vc[:, -5:, :, :]
         if vc_sub.size(1) < 5:
-            raise Exception(f + ' does not contain enough frames for processing')
+            print(f + ' does not contain enough frames for processing')
+            continue
             
         ### START time after loading video ###
         start_time = time.time()
         
-        clip = get_yolo_regions(yolo_model, vc_sub)
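+        # Scan backwards from the end of the video in 5-frame windows (at most 4 windows)
+        # until YOLO returns a croppable nerve region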
+        clip = None
+        i = 1
+        while not clip and i < 5:
+            if vc_sub.size(1) < 5:
+                break
+            clip = get_yolo_regions(yolo_model, vc_sub, is_right)
+            vc_sub = vc[:, -5*(i+1):-5*i, :, :]
+            i += 1
+
         if clip:
             clip = clip[0]
-            clip = transform(clip)
+            clip = transform(clip).to(torch.float32)
 
             interpreter.set_tensor(input_details[0]['index'], clip)
 
@@ -66,10 +78,10 @@ if __name__ == "__main__":
 
             final_pred = pred.round()
 
-                if final_pred == 1:
-                    str_pred = 'Positive'
-                else:
-                    str_pred = 'Negative'
+            if final_pred == 1:
+                str_pred = 'Positive'
+            else:
+                str_pred = 'Negative'
         else:
             str_pred = "Positive"
 
diff --git a/sme_region_labels.csv b/sme_region_labels.csv
new file mode 100644
index 0000000000000000000000000000000000000000..c533ec9c715bf1da2cd30d320e034c5275fce9b9
--- /dev/null
+++ b/sme_region_labels.csv
@@ -0,0 +1,17 @@
+idx,negative_regions,positive_regions
+11,"0-121,158-516","121-157,517-end"
+54,0-397,397
+67,0-120,120-end
+93,"0-60,155-256","61-154,257-end"
+94,0-78,79-end
+134,0-393,394-end
+153,0-200,201-end
+189,"0-122,123-184",185
+193,0-111,112-end
+205,0-646,
+211,0-86,87-end
+217,"0-86,87-137",138-end
+222,0-112,113-end
+230,0-7,
+238,"0-140,141-152",153-end
+240,0-205,206-end
\ No newline at end of file
diff --git a/sparse_coding_torch/convert_pytorch_to_keras.py b/sparse_coding_torch/convert_pytorch_to_keras.py
index 9abf4ac07fdaab26ce25aeefc483eaf09639faf3..4f3e1abbd830c01acca62f924e118c4864ab7478 100644
--- a/sparse_coding_torch/convert_pytorch_to_keras.py
+++ b/sparse_coding_torch/convert_pytorch_to_keras.py
@@ -28,7 +28,7 @@ if __name__ == "__main__":
         os.makedirs(args.output_dir)
     
     if args.classifier_checkpoint:
-        classifier_inputs = keras.Input(shape=(1, (args.input_image_height - args.kernel_size) // args.stride + 1, (args.input_image_width - args.kernel_size) // args.stride + 1, args.num_kernels))
+        classifier_inputs = keras.Input(shape=(1, args.input_image_height // args.stride, args.input_image_width // args.stride, args.num_kernels))
 
         if args.dataset == 'pnb':
             classifier_outputs = PNBClassifier()(classifier_inputs)
@@ -52,10 +52,10 @@ if __name__ == "__main__":
         classifier_model.save(os.path.join(args.output_dir, "classifier.pt"))
         
     if args.sparse_checkpoint:
-        input_shape = [1, (args.input_image_height - args.kernel_size) // args.stride + 1, (args.input_image_width - args.kernel_size) // args.stride + 1, args.num_kernels]
+        input_shape = [1, args.input_image_height // args.stride , args.input_image_width // args.stride, args.num_kernels]
         recon_inputs = keras.Input(shape=input_shape)
     
-        recon_outputs = ReconSparse(batch_size=1, image_height=args.input_image_height, image_width=args.input_image_width, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(recon_inputs)
+        recon_outputs = ReconSparse(batch_size=1, image_height=args.input_image_height, image_width=args.input_image_width, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d, padding='SAME')(recon_inputs)
 
         recon_model = keras.Model(inputs=recon_inputs, outputs=recon_outputs)
         
diff --git a/sparse_coding_torch/generate_tflite.py b/sparse_coding_torch/generate_tflite.py
index 53a69a09016f52b46fd1618b65252c25ddc0ab2e..9903f002392c96c0d172a3a1a65edfd46cb53f1f 100644
--- a/sparse_coding_torch/generate_tflite.py
+++ b/sparse_coding_torch/generate_tflite.py
@@ -7,35 +7,53 @@ import torchvision as tv
 import torch
 import torch.nn as nn
 from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler
-from sparse_coding_torch.keras_model import MobileModel
-
-inputs = keras.Input(shape=(5, 100, 200, 3))
-
-outputs = MobileModel(sparse_checkpoint='../sparse.pt', batch_size=1, in_channels=1, out_channels=64, kernel_size=15, stride=2, lam=0.05, activation_lr=1e-1, max_activation_iter=100, run_2d=True)(inputs)
-
-model = keras.Model(inputs=inputs, outputs=outputs)
-
-
-pytorch_checkpoint = torch.load('../stride_2_100_iter.pt', map_location='cpu')['model_state_dict']
-conv_weights = [pytorch_checkpoint['module.compress_activations_conv_1.weight'].squeeze(2).swapaxes(0, 2).swapaxes(1, 3).swapaxes(2, 3).numpy(), pytorch_checkpoint['module.compress_activations_conv_1.bias'].numpy()]
-model.get_layer('mobile_model').classifier.conv.set_weights(conv_weights)
-ff_3_weights = [pytorch_checkpoint['module.fc3.weight'].swapaxes(1,0).numpy(), pytorch_checkpoint['module.fc3.bias'].numpy()]
-model.get_layer('mobile_model').classifier.ff_3.set_weights(ff_3_weights)
-ff_4_weights = [pytorch_checkpoint['module.fc4.weight'].swapaxes(1,0).numpy(), pytorch_checkpoint['module.fc4.bias'].numpy()]
-model.get_layer('mobile_model').classifier.ff_4.set_weights(ff_4_weights)
-
-input_name = model.input_names[0]
-index = model.input_names.index(input_name)
-model.inputs[index].set_shape([1, 5, 100, 200, 3])
-
-converter = tf.lite.TFLiteConverter.from_keras_model(model)
-converter.optimizations = [tf.lite.Optimize.DEFAULT]
-converter.target_spec.supported_types = [tf.float16]
-converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
-
-tflite_model = converter.convert()
-
-print('Converted')
-
-with open("./mobile_output/tf_lite_model.tflite", "wb") as f:
-    f.write(tflite_model)
+from sparse_coding_torch.keras_model import MobileModelPNB
+import argparse
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--input_dir', default='/shared_data/bamc_pnb_data/full_training_data', type=str)
+    parser.add_argument('--kernel_size', default=15, type=int)
+    parser.add_argument('--kernel_depth', default=5, type=int)
+    parser.add_argument('--num_kernels', default=32, type=int)
+    parser.add_argument('--stride', default=2, type=int)
+    parser.add_argument('--max_activation_iter', default=100, type=int)
+    parser.add_argument('--activation_lr', default=1e-2, type=float)
+    parser.add_argument('--lam', default=0.05, type=float)
+    parser.add_argument('--sparse_checkpoint', default='sparse_coding_torch/output/sparse_pnb_32/sparse_conv3d_model-best.pt/', type=str)
+    parser.add_argument('--checkpoint', default='sparse_coding_torch/classifier_outputs/32_filters/best_classifier.pt/', type=str)
+    parser.add_argument('--run_2d', action='store_true')
+    parser.add_argument('--batch_size', default=1, type=int)
+    
+    args = parser.parse_args()
+    #print(args.accumulate(args.integers))
+    batch_size = args.batch_size
+
+    image_height = 285
+    image_width = 235
+    
+    recon_model = keras.models.load_model(args.sparse_checkpoint)
+        
+    classifier_model = keras.models.load_model(args.checkpoint)
+
+    inputs = keras.Input(shape=(5, image_height, image_width))
+
+    outputs = MobileModelPNB(sparse_weights=recon_model.weights[0], classifier_model=classifier_model, batch_size=batch_size, image_height=image_height, image_width=image_width, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=True)(inputs)
+
+    model = keras.Model(inputs=inputs, outputs=outputs)
+
+    input_name = model.input_names[0]
+    index = model.input_names.index(input_name)
+    model.inputs[index].set_shape([1, 5, image_height, image_width])
+
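+    # Convert the wrapped Keras model to TFLite with default optimizations and float16 weights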
+    converter = tf.lite.TFLiteConverter.from_keras_model(model)
+    converter.optimizations = [tf.lite.Optimize.DEFAULT]
+    converter.target_spec.supported_types = [tf.float16]
+    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
+
+    tflite_model = converter.convert()
+
+    print('Converted')
+
+    with open("./sparse_coding_torch/mobile_output/pnb.tflite", "wb") as f:
+        f.write(tflite_model)
diff --git a/sparse_coding_torch/keras_model.py b/sparse_coding_torch/keras_model.py
index 3f00ae302197225f660a9c24b851927964c5e3d4..4850759c1db56aad56bfcd1a3121cb3749acca5b 100644
--- a/sparse_coding_torch/keras_model.py
+++ b/sparse_coding_torch/keras_model.py
@@ -15,36 +15,42 @@ def load_pytorch_weights(file_path):
     return weight_tensor
 
 # @tf.function
-def do_recon(filters_1, filters_2, filters_3, filters_4, filters_5, activations, image_height, image_width, stride):
+def do_recon(filters_1, filters_2, filters_3, filters_4, filters_5, activations, image_height, image_width, stride, padding='VALID'):
     batch_size = tf.shape(activations)[0]
-    out_1 = tf.nn.conv2d_transpose(activations, filters_1, output_shape=(batch_size, image_height, image_width, 1), strides=stride, padding='VALID')
-    out_2 = tf.nn.conv2d_transpose(activations, filters_2, output_shape=(batch_size, image_height, image_width, 1), strides=stride, padding='VALID')
-    out_3 = tf.nn.conv2d_transpose(activations, filters_3, output_shape=(batch_size, image_height, image_width, 1), strides=stride, padding='VALID')
-    out_4 = tf.nn.conv2d_transpose(activations, filters_4, output_shape=(batch_size, image_height, image_width, 1), strides=stride, padding='VALID')
-    out_5 = tf.nn.conv2d_transpose(activations, filters_5, output_shape=(batch_size, image_height, image_width, 1), strides=stride, padding='VALID')
+    out_1 = tf.nn.conv2d_transpose(activations, filters_1, output_shape=(batch_size, image_height, image_width, 1), strides=stride, padding=padding)
+    out_2 = tf.nn.conv2d_transpose(activations, filters_2, output_shape=(batch_size, image_height, image_width, 1), strides=stride, padding=padding)
+    out_3 = tf.nn.conv2d_transpose(activations, filters_3, output_shape=(batch_size, image_height, image_width, 1), strides=stride, padding=padding)
+    out_4 = tf.nn.conv2d_transpose(activations, filters_4, output_shape=(batch_size, image_height, image_width, 1), strides=stride, padding=padding)
+    out_5 = tf.nn.conv2d_transpose(activations, filters_5, output_shape=(batch_size, image_height, image_width, 1), strides=stride, padding=padding)
 
     recon = tf.concat([out_1, out_2, out_3, out_4, out_5], axis=3)
 
     return recon
 
 # @tf.function
-def do_recon_3d(filters, activations, image_height, image_width, stride):
+def do_recon_3d(filters, activations, image_height, image_width, stride, padding='VALID'):
 #     activations = tf.pad(activations, paddings=[[0,0], [2, 2], [0, 0], [0, 0], [0, 0]])
     batch_size = tf.shape(activations)[0]
-    recon = tf.nn.conv3d_transpose(activations, filters, output_shape=(batch_size, 5, image_height, image_width, 1), strides=[1, stride, stride], padding='VALID')
+    if padding == 'SAME':
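+        # The activation map has a single temporal slice (5-frame clip, depth-5 kernel), so padding
+        # 2 on each side lets the SAME transposed conv reproduce all 5 frames (assumes that default setup)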
+        activations = tf.pad(activations, paddings=[[0,0],[2,2],[0,0],[0,0],[0,0]])
+    recon = tf.nn.conv3d_transpose(activations, filters, output_shape=(batch_size, 5, image_height, image_width, 1), strides=[1, stride, stride], padding=padding)
 
     return recon
 
 # @tf.function
-def conv_error(filters, e, stride):
-    g = tf.nn.conv2d(e, filters, strides=stride, padding='VALID')
+def conv_error(filters, e, stride, padding='VALID'):
+    g = tf.nn.conv2d(e, filters, strides=stride, padding=padding)
 
     return g
 
 @tf.function
-def conv_error_3d(filters, e, stride):
+def conv_error_3d(filters, e, stride, padding='VALID'):
 #     e = tf.pad(e, paddings=[[0,0], [0, 0], [7, 7], [7, 7], [0, 0]])
-    g = tf.nn.conv3d(e, filters, strides=[1, 1, stride, stride, 1], padding='VALID')
+    if padding == 'SAME':
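+        # Emulate SAME spatial padding with an explicit pad of 7 == (kernel_size 15 - 1) // 2
+        # followed by a VALID conv (assumes the default 15x15 kernels)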
+        e = tf.pad(e, paddings=[[0,0], [0,0], [7,7], [7,7], [0,0]])
+        g = tf.nn.conv3d(e, filters, strides=[1, 1, stride, stride, 1], padding='VALID')
+    else:
+        g = tf.nn.conv3d(e, filters, strides=[1, 1, stride, stride, 1], padding=padding)
 
     return g
 
@@ -78,7 +84,7 @@ def normalize_weights_3d(filters, out_channels):
     return adjusted
 
 class SparseCode(keras.layers.Layer):
-    def __init__(self, batch_size, image_height, image_width, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d):
+    def __init__(self, batch_size, image_height, image_width, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d, padding='VALID'):
         super(SparseCode, self).__init__()
 
         self.out_channels = out_channels
@@ -93,28 +99,29 @@ class SparseCode(keras.layers.Layer):
         self.image_width = image_width
         self.kernel_size = kernel_size
         self.run_2d = run_2d
+        self.padding = padding
 
 #     @tf.function
     def do_update(self, images, filters, u, m, v, b1, b2, eps, i):
         activations = tf.nn.relu(u - self.lam)
 
         if self.run_2d:
-            recon = do_recon(filters[0], filters[1], filters[2], filters[3], filters[4], activations, self.batch_size, self.image_height, self.image_width, self.stride)
+            recon = do_recon(filters[0], filters[1], filters[2], filters[3], filters[4], activations, self.image_height, self.image_width, self.stride, self.padding)
         else:
-            recon = do_recon_3d(filters, activations, self.image_height, self.image_width, self.stride)
+            recon = do_recon_3d(filters, activations, self.image_height, self.image_width, self.stride, self.padding)
 
         e = images - recon
         g = -1 * u
 
         if self.run_2d:
             e1, e2, e3, e4, e5 = tf.split(e, 5, axis=3)
-            g += conv_error(filters[0], e1, self.stride)
-            g += conv_error(filters[1], e2, self.stride)
-            g += conv_error(filters[2], e3, self.stride)
-            g += conv_error(filters[3], e4, self.stride)
-            g += conv_error(filters[4], e5, self.stride)
+            g += conv_error(filters[0], e1, self.stride, self.padding)
+            g += conv_error(filters[1], e2, self.stride, self.padding)
+            g += conv_error(filters[2], e3, self.stride, self.padding)
+            g += conv_error(filters[3], e4, self.stride, self.padding)
+            g += conv_error(filters[4], e5, self.stride, self.padding)
         else:
-            convd_error = conv_error_3d(filters, e, self.stride)
+            convd_error = conv_error_3d(filters, e, self.stride, self.padding)
 
             g = g + convd_error
 
@@ -135,11 +142,18 @@ class SparseCode(keras.layers.Layer):
 
 #     @tf.function
     def call(self, images, filters):
-        filters = tf.squeeze(filters, axis=0)
-        if self.run_2d:
-            output_shape = (len(images), (self.image_height - self.kernel_size) // self.stride + 1, (self.image_width - self.kernel_size) // self.stride + 1, self.out_channels)
+        if not self.run_2d:
+            filters = tf.squeeze(filters, axis=0)
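+        # With SAME padding the activation map keeps the image dimensions (divided by stride);
+        # with VALID it shrinks by kernel_size - 1 before the stride is applied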
+        if self.padding == 'SAME':
+            if self.run_2d:
+                output_shape = (len(images), self.image_height // self.stride, self.image_width // self.stride, self.out_channels)
+            else:
+                output_shape = (len(images), 1, self.image_height // self.stride, self.image_width // self.stride, self.out_channels)
         else:
-            output_shape = (len(images), 1, (self.image_height - self.kernel_size) // self.stride + 1, (self.image_width - self.kernel_size) // self.stride + 1, self.out_channels)
+            if self.run_2d:
+                output_shape = (len(images), (self.image_height - self.kernel_size) // self.stride + 1, (self.image_width - self.kernel_size) // self.stride + 1, self.out_channels)
+            else:
+                output_shape = (len(images), 1, (self.image_height - self.kernel_size) // self.stride + 1, (self.image_width - self.kernel_size) // self.stride + 1, self.out_channels)
 
         u = tf.stop_gradient(tf.zeros(shape=output_shape))
         m = tf.stop_gradient(tf.zeros(shape=output_shape))
@@ -162,7 +176,7 @@ class SparseCode(keras.layers.Layer):
         return u
     
 class ReconSparse(keras.Model):
-    def __init__(self, batch_size, image_height, image_width, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d):
+    def __init__(self, batch_size, image_height, image_width, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d, padding='VALID'):
         super().__init__()
         
         self.out_channels = out_channels
@@ -175,6 +189,7 @@ class ReconSparse(keras.Model):
         self.image_height = image_height
         self.image_width = image_width
         self.run_2d = run_2d
+        self.padding = padding
 
         initializer = tf.keras.initializers.HeNormal()
         if run_2d:
@@ -197,9 +212,9 @@ class ReconSparse(keras.Model):
 #     @tf.function
     def call(self, activations):
         if self.run_2d:
-            recon = do_recon(self.filters_1, self.filters_2, self.filters_3, self.filters_4, self.filters_5, activations, self.image_height, self.image_width, self.stride)
+            recon = do_recon(self.filters_1, self.filters_2, self.filters_3, self.filters_4, self.filters_5, activations, self.image_height, self.image_width, self.stride, self.padding)
         else:
-            recon = do_recon_3d(self.filters, activations, self.image_height, self.image_width, self.stride)
+            recon = do_recon_3d(self.filters, activations, self.image_height, self.image_width, self.stride, self.padding)
             
         return recon
 
@@ -241,9 +256,11 @@ class PNBClassifier(keras.layers.Layer):
     def __init__(self):
         super(PNBClassifier, self).__init__()
 
-        self.max_pool = keras.layers.MaxPooling2D(pool_size=4, strides=4)
-        self.conv_1 = keras.layers.Conv2D(24, kernel_size=8, strides=4, activation='relu', padding='valid')
-#         self.conv_2 = keras.layers.Conv2D(24, kernel_size=4, strides=2, activation='relu', padding='valid')
+        self.max_pool = keras.layers.MaxPooling2D(pool_size=8, strides=8)
+#         self.conv_1 = keras.layers.Conv2D(32, kernel_size=8, strides=4, activation='relu', padding='valid')
+        self.conv_2 = keras.layers.Conv2D(48, kernel_size=4, strides=2, activation='relu', padding='valid')
+#         self.conv_3 = keras.layers.Conv2D(24, kernel_size=4, strides=2, activation='relu', padding='valid')
+#         self.conv_4 = keras.layers.Conv2D(24, kernel_size=4, strides=2, activation='relu', padding='valid')
 
         self.flatten = keras.layers.Flatten()
 
@@ -258,8 +275,10 @@ class PNBClassifier(keras.layers.Layer):
     def call(self, activations):
         activations = tf.squeeze(activations, axis=1)
         x = self.max_pool(activations)
-        x = self.conv_1(x)
-#         x = self.conv_2(x)
+#         x = self.conv_1(x)
+        x = self.conv_2(x)
+#         x = self.conv_3(x)
+#         x = self.conv_4(x)
         x = self.flatten(x)
 #         x = self.ff_1(x)
 #         x = self.dropout(x)
@@ -271,7 +290,7 @@ class PNBClassifier(keras.layers.Layer):
 
         return x
 
-class MobileModel(keras.Model):
+class MobileModelPTX(keras.Model):
     def __init__(self, sparse_checkpoint, batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d):
         super().__init__()
         self.sparse_code = SparseCode(batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d)
@@ -313,3 +332,43 @@ class MobileModel(keras.Model):
         pred = self.classifier(activations)
 
         return pred
+    
+class MobileModelPNB(keras.Model):
+    def __init__(self, sparse_weights, classifier_model, batch_size, image_height, image_width, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d):
+        super().__init__()
+        self.sparse_code = SparseCode(batch_size=batch_size, image_height=image_height, image_width=image_width, in_channels=1, out_channels=out_channels, kernel_size=kernel_size, stride=stride, lam=lam, activation_lr=activation_lr, max_activation_iter=max_activation_iter, run_2d=run_2d, padding='VALID')
+        self.classifier = classifier_model
+
+        self.out_channels = out_channels
+        self.stride = stride
+        self.lam = lam
+        self.activation_lr = activation_lr
+        self.max_activation_iter = max_activation_iter
+        self.batch_size = batch_size
+        self.run_2d = run_2d
+        
+        if run_2d:
+            weight_list = np.split(sparse_weights, 5, axis=0)
+            self.filters_1 = tf.Variable(initial_value=weight_list[0].squeeze(0), dtype='float32', trainable=False)
+            self.filters_2 = tf.Variable(initial_value=weight_list[1].squeeze(0), dtype='float32', trainable=False)
+            self.filters_3 = tf.Variable(initial_value=weight_list[2].squeeze(0), dtype='float32', trainable=False)
+            self.filters_4 = tf.Variable(initial_value=weight_list[3].squeeze(0), dtype='float32', trainable=False)
+            self.filters_5 = tf.Variable(initial_value=weight_list[4].squeeze(0), dtype='float32', trainable=False)
+        else:
+            self.filters = tf.Variable(initial_value=sparse_weights, dtype='float32', trainable=False)
+
+    @tf.function
+    def call(self, images):
+#         images = tf.squeeze(tf.image.rgb_to_grayscale(images), axis=-1)
+        images = tf.transpose(images, perm=[0, 2, 3, 1])
+        images = images / 255
+
+        if self.run_2d:
+            activations = self.sparse_code(images, [tf.stop_gradient(self.filters_1), tf.stop_gradient(self.filters_2), tf.stop_gradient(self.filters_3), tf.stop_gradient(self.filters_4), tf.stop_gradient(self.filters_5)])
+            activations = tf.expand_dims(activations, axis=1)
+        else:
+            activations = self.sparse_code(images, tf.stop_gradient(self.filters))
+
+        pred = tf.math.sigmoid(self.classifier(activations))
+
+        return pred
diff --git a/sparse_coding_torch/load_data.py b/sparse_coding_torch/load_data.py
index 4dd4228aa5445830df110cd51c3db56b27671895..9caa3549201dd0f8f8e9abfed1fc9a48b97d96f9 100644
--- a/sparse_coding_torch/load_data.py
+++ b/sparse_coding_torch/load_data.py
@@ -7,7 +7,7 @@ from sparse_coding_torch.video_loader import YoloClipLoader, get_ptx_participant
 from sparse_coding_torch.video_loader import VideoGrayScaler
 from typing import Sequence, Iterator
 import csv
-from sklearn.model_selection import train_test_split, GroupShuffleSplit, LeaveOneGroupOut, LeaveOneOut, StratifiedGroupKFold, StratifiedKFold, KFold
+from sklearn.model_selection import train_test_split, GroupShuffleSplit, LeaveOneGroupOut, LeaveOneOut, StratifiedGroupKFold, StratifiedKFold, KFold, ShuffleSplit
 
 def load_yolo_clips(batch_size, mode, num_clips=1, num_positives=100, device=None, n_splits=None, sparse_model=None, whole_video=False, positive_videos=None):   
     video_path = "/shared_data/YOLO_Updated_PL_Model_Results/"
@@ -69,6 +69,25 @@ def load_yolo_clips(batch_size, mode, num_clips=1, num_positives=100, device=Non
         
         return train_loader, test_loader, dataset
     
+def get_sample_weights(train_idx, dataset):
+    dataset = list(dataset)
+
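+    # Weight each clip by the prevalence of the opposite class so minority-class clips are sampled more often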
+    num_positive = len([clip[0] for clip in dataset if clip[0] == 'Positives'])
+    negative_weight = num_positive / len(dataset)
+    positive_weight = 1.0 - negative_weight
+    
+    weights = []
+    for idx in train_idx:
+        label = dataset[idx][0]
+        if label == 'Positives':
+            weights.append(positive_weight)
+        elif label == 'Negatives':
+            weights.append(negative_weight)
+        else:
+            raise Exception('Sampler encountered invalid label')
+    
+    return weights
+
 class SubsetWeightedRandomSampler(torch.utils.data.Sampler[int]):
     weights: torch.Tensor
     num_samples: int
@@ -92,18 +111,19 @@ class SubsetWeightedRandomSampler(torch.utils.data.Sampler[int]):
     def __len__(self) -> int:
         return len(self.indicies)
     
-def load_pnb_videos(batch_size, mode=None, classify_mode=False, balance_classes=False, device=None, n_splits=None, sparse_model=None):   
+def load_pnb_videos(batch_size, input_size, mode=None, classify_mode=False, balance_classes=False, device=None, n_splits=None, sparse_model=None):   
     video_path = "/shared_data/bamc_pnb_data/full_training_data"
     
     transforms = torchvision.transforms.Compose(
     [VideoGrayScaler(),
      MinMaxScaler(0, 255),
-     torchvision.transforms.Resize((250, 600))
+     torchvision.transforms.Resize(input_size)
     ])
     augment_transforms = torchvision.transforms.Compose(
-    [torchvision.transforms.RandomRotation(30),
+    [torchvision.transforms.RandomRotation(20),
      torchvision.transforms.RandomHorizontalFlip(),
-     torchvision.transforms.ColorJitter(brightness=0.1),
+#      torchvision.transforms.RandomVerticalFlip(),
+     torchvision.transforms.ColorJitter(brightness=0.1),     
 #      torchvision.transforms.RandomAdjustSharpness(0, p=0.15),
      torchvision.transforms.RandomAffine(degrees=0, translate=(0.01, 0))
 #      torchvision.transforms.CenterCrop((100, 200))
@@ -133,13 +153,17 @@ def load_pnb_videos(batch_size, mode=None, classify_mode=False, balance_classes=
         
         return gss.split(np.arange(len(targets)), targets, groups), dataset
     else:
+#         gss = ShuffleSplit(n_splits=n_splits, test_size=0.2)
         gss = GroupShuffleSplit(n_splits=n_splits, test_size=0.2)
 
         groups = get_pnb_participants(dataset.get_filenames())
         
         train_idx, test_idx = list(gss.split(np.arange(len(targets)), targets, groups))[0]
+#         train_idx, test_idx = list(gss.split(np.arange(len(targets)), targets))[0]
+
         
         train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
+#         train_sampler = SubsetWeightedRandomSampler(get_sample_weights(train_idx, dataset), train_idx, replacement=True)
         train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                sampler=train_sampler)
         
diff --git a/sparse_coding_torch/train_classifier.py b/sparse_coding_torch/train_classifier.py
index c78d6fd11d1b160f3a04d3e77432ed71af88be5b..8f0276cac663126541cda57ca90ce2d5db0c775d 100644
--- a/sparse_coding_torch/train_classifier.py
+++ b/sparse_coding_torch/train_classifier.py
@@ -20,24 +20,6 @@ configproto.gpu_options.allow_growth = True
 sess = tf.compat.v1.Session(config=configproto) 
 tf.compat.v1.keras.backend.set_session(sess)
 
-def get_sample_weights(train_idx, dataset):
-    dataset = list(dataset)
-
-    num_positive = len([clip[0] for clip in dataset if clip[0] == 'Positives'])
-    negative_weight = num_positive / len(dataset)
-    positive_weight = 1.0 - negative_weight
-    
-    weights = []
-    for idx in train_idx:
-        label = dataset[idx][0]
-        if label == 'Positives':
-            weights.append(positive_weight)
-        elif label == 'Negatives':
-            weights.append(negative_weight)
-        else:
-            raise Exception('Sampler encountered invalid label')
-    
-    return weights
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -48,7 +30,7 @@ if __name__ == "__main__":
     parser.add_argument('--stride', default=1, type=int)
     parser.add_argument('--max_activation_iter', default=150, type=int)
     parser.add_argument('--activation_lr', default=1e-2, type=float)
-    parser.add_argument('--lr', default=5e-2, type=float)
+    parser.add_argument('--lr', default=5e-4, type=float)
     parser.add_argument('--epochs', default=40, type=int)
     parser.add_argument('--lam', default=0.05, type=float)
     parser.add_argument('--output_dir', default='./output', type=str)
@@ -67,8 +49,8 @@ if __name__ == "__main__":
     args = parser.parse_args()
     
     if args.dataset == 'pnb':
-        image_height = 250
-        image_width = 600
+        image_height = 285
+        image_width = 235
     elif args.dataset == 'ptx':
         image_height = 100
         image_width = 200
@@ -110,7 +92,9 @@ if __name__ == "__main__":
         
     positive_class = None
     if args.dataset == 'pnb':
-        train_loader, test_loader, dataset = load_pnb_videos(args.batch_size, classify_mode=True, balance_classes=args.balance_classes, mode=args.splits, device=None, n_splits=args.n_splits, sparse_model=None)
+        train_loader, test_loader, dataset = load_pnb_videos(args.batch_size, input_size=(image_height, image_width), classify_mode=True, balance_classes=args.balance_classes, mode=args.splits, device=None, n_splits=args.n_splits, sparse_model=None)
+        print(len([label for labels, _, _ in test_loader for label in labels if label == 'Positives']))
+        print(len(test_loader))
         positive_class = 'Positives'
     elif args.dataset == 'ptx':
         train_loader, test_loader, dataset = load_yolo_clips(args.batch_size, num_clips=1, num_positives=15, mode=args.splits, device=None, n_splits=args.n_splits, sparse_model=None, whole_video=False, positive_videos='positive_videos.json')
diff --git a/sparse_coding_torch/train_sparse_model.py b/sparse_coding_torch/train_sparse_model.py
index 76226db631272bc98e9b13aa1c04663c40522e5b..72439f899751f59867a961c76b3b7e835e10be2a 100644
--- a/sparse_coding_torch/train_sparse_model.py
+++ b/sparse_coding_torch/train_sparse_model.py
@@ -95,20 +95,20 @@ def sparse_loss(recon, activations, batch_size, lam, stride):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--batch_size', default=6, type=int)
+    parser.add_argument('--batch_size', default=32, type=int)
     parser.add_argument('--kernel_size', default=15, type=int)
-    parser.add_argument('--num_kernels', default=64, type=int)
-    parser.add_argument('--stride', default=2, type=int)
-    parser.add_argument('--max_activation_iter', default=100, type=int)
+    parser.add_argument('--num_kernels', default=16, type=int)
+    parser.add_argument('--stride', default=1, type=int)
+    parser.add_argument('--max_activation_iter', default=200, type=int)
     parser.add_argument('--activation_lr', default=1e-2, type=float)
-    parser.add_argument('--lr', default=5e-2, type=float)
-    parser.add_argument('--epochs', default=20, type=int)
+    parser.add_argument('--lr', default=0.003, type=float)
+    parser.add_argument('--epochs', default=40, type=int)
     parser.add_argument('--lam', default=0.05, type=float)
     parser.add_argument('--output_dir', default='./output', type=str)
     parser.add_argument('--seed', default=42, type=int)
     parser.add_argument('--run_2d', action='store_true')
     parser.add_argument('--save_filters', action='store_true')
-    parser.add_argument('--optimizer', default='adam', type=str)
+    parser.add_argument('--optimizer', default='sgd', type=str)
     parser.add_argument('--dataset', default='pnb', type=str)
     
 
@@ -121,8 +121,8 @@ if __name__ == "__main__":
 #     policy = keras.mixed_precision.Policy('mixed_float16')
 #     keras.mixed_precision.set_global_policy(policy)
 
-    image_height = 360
-    image_width = 304
+    image_height = 285
+    image_width = 235
 
     output_dir = args.output_dir
     if not os.path.exists(output_dir):
@@ -134,7 +134,7 @@ if __name__ == "__main__":
         out_f.write(str(args))
 
     if args.dataset == 'pnb':
-        train_loader, _ = load_pnb_videos(args.batch_size, classify_mode=False, mode='all_train', device=device, n_splits=1, sparse_model=None)
+        train_loader, test_loader, dataset = load_pnb_videos(args.batch_size, input_size=(image_height, image_width), classify_mode=False, balance_classes=False, mode='all_train')
     elif args.dataset == 'ptx':
         train_loader, _ = load_yolo_clips(args.batch_size, num_clips=1, num_positives=15, mode='all_train', device=device, n_splits=1, sparse_model=None, whole_video=False, positive_videos='../positive_videos.json')
     else:
diff --git a/sparse_coding_torch/video_loader.py b/sparse_coding_torch/video_loader.py
index c0a6bab61a4a5d21a35be424d9d7cb75b8f52011..e7060d1c76d90271aa0ce7252b4a964e1ea99104 100644
--- a/sparse_coding_torch/video_loader.py
+++ b/sparse_coding_torch/video_loader.py
@@ -76,39 +76,93 @@ def load_pnb_region_labels(file_path):
             
         return all_regions
     
-def get_yolo_regions(yolo_model, clip):
+def get_yolo_regions(yolo_model, clip, is_right):
     orig_height = clip.size(2)
     orig_width = clip.size(3)
-    bounding_boxes = yolo_model.get_bounding_boxes(clip[:, 2, :, :].swapaxes(0, 2).swapaxes(0, 1).numpy()).squeeze(0)
+    bounding_boxes, classes = yolo_model.get_bounding_boxes(clip[:, 2, :, :].swapaxes(0, 2).swapaxes(0, 1).numpy())
+    bounding_boxes = bounding_boxes.squeeze(0)
+    classes = classes.squeeze(0)
     
     all_clips = []
-    for bb in bounding_boxes:
-        center_x = (bb[3] + bb[1]) / 2 * orig_width
-        center_y = (bb[2] + bb[0]) / 2 * orig_height
-
-        width_left = 400
-        width_right = 400
-        height_top = 200
-        height_bottom = 50
-
-        lower_y = round(center_y - height_top)
-        upper_y = round(center_y + height_bottom)
-        lower_x = round(center_x - width_left)
-        upper_x = round(center_x + width_right)
+    for bb, class_pred in zip(bounding_boxes, classes):
+        if class_pred != 0:
+            continue
+        center_x = round((bb[3] + bb[1]) / 2 * orig_width)
+        center_y = round((bb[2] + bb[0]) / 2 * orig_height)
+        
+        lower_y = round((bb[0] * orig_height))
+        upper_y = round((bb[2] * orig_height))
+        lower_x = round((bb[1] * orig_width))
+        upper_x = round((bb[3] * orig_width))
+        
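+        # Take a fixed 285x235 crop ending at the bottom edge of the detected box, 235 px wide
+        # to the left of the box center when the nerve is on the right, otherwise to the right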
+        lower_y = upper_y - 285
+        
+        if is_right:
+            lower_x = center_x - 235
+            upper_x = center_x
+        else:
+            lower_x = center_x
+            upper_x = center_x + 235
 
         trimmed_clip = clip[:, :, lower_y:upper_y, lower_x:upper_x]
         
-#         print(trimmed_clip.size())
-        
+#         if orig_width - center_x >= center_x:
+#         if not is_right:
 #         cv2.imwrite('test.png', clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2))
 #         cv2.imwrite('test_yolo.png', trimmed_clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2))
 #         raise Exception
         
+#         print(trimmed_clip.size())
+        
         if trimmed_clip.shape[2] == 0 or trimmed_clip.shape[3] == 0:
             continue
         all_clips.append(trimmed_clip)
 
     return all_clips
+
+def classify_nerve_is_right(yolo_model, video):
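+    # Sample ~10 evenly spaced frames and vote on whether the nerve is on the right: class-2
+    # detections in the left half of the frame vote "right", the class-1 fallback uses the
+    # opposite mapping, and the default is "right" when nothing is detected (class indices
+    # are assumptions about the YOLO model's label order).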
+    orig_height = video.size(2)
+    orig_width = video.size(3)
+
+    all_preds = []
+    if video.size(1) < 10:
+        return True
+
+    for frame in range(0, video.size(1), round(video.size(1) / 10)):
+        frame = video[:, frame, :, :]
+        bounding_boxes, classes = yolo_model.get_bounding_boxes(frame.swapaxes(0, 2).swapaxes(0, 1).numpy())
+        bounding_boxes = bounding_boxes.squeeze(0)
+        classes = classes.squeeze(0)
+    
+        for bb, class_pred in zip(bounding_boxes, classes):
+            if class_pred != 2:
+                continue
+            center_x = (bb[3] + bb[1]) / 2 * orig_width
+            center_y = (bb[2] + bb[0]) / 2 * orig_height
+
+            if orig_width - center_x < center_x:
+                all_preds.append(0)
+            else:
+                all_preds.append(1)
+        
+        if not all_preds:
+            for bb, class_pred in zip(bounding_boxes, classes):
+                if class_pred != 1:
+                    continue
+                center_x = (bb[3] + bb[1]) / 2 * orig_width
+                center_y = (bb[2] + bb[0]) / 2 * orig_height
+
+                if orig_width - center_x < center_x:
+                    all_preds.append(1)
+                else:
+                    all_preds.append(0)
+                    
+        if not all_preds:
+            all_preds.append(1)
+                
+    final_pred = round(sum(all_preds) / len(all_preds))
+
+    return final_pred == 1
                 
     
 class PNBLoader(Dataset):
@@ -146,6 +200,8 @@ class PNBLoader(Dataset):
             vid_idx = 0
             for label, path, _ in tqdm(self.videos):
                 vc = tv.io.read_video(path)[0].permute(3, 0, 1, 2)
+                is_right = classify_nerve_is_right(yolo_model, vc)
+
                 if classify_mode:
                     person_idx = path.split('/')[-2]
 
@@ -165,13 +221,16 @@ class PNBLoader(Dataset):
                                     
 #                                 cv2.imwrite('test.png', vc_sub[0, 0, :, :].unsqueeze(2).numpy())
 
-                                for clip in get_yolo_regions(yolo_model, vc_sub):
-                                    if self.transform:
-                                        clip = self.transform(clip)
+                                for clip in get_yolo_regions(yolo_model, vc_sub, is_right):
+                                    if self.transform:
+                                        clip = self.transform(clip)
                                         
 #                                     print(clip[0, 0, :, :].size())
 #                                     cv2.imwrite('test_yolo.png', clip[0, 0, :, :].unsqueeze(2).numpy())
-#                                     raise Exception
+#                                     print(clip.shape)
+#                                     tv.io.write_video('test_yolo.mp4', clip.swapaxes(0,1).swapaxes(1,2).swapaxes(2,3).numpy(), fps=20)
+#                                     print(path)
+#                                     raise Exception
 
                                     self.clips.append(('Negatives', clip, self.videos[vid_idx][2]))
 
@@ -184,7 +243,7 @@ class PNBLoader(Dataset):
                                     if vc_sub.size(1) < 5:
                                         continue
                                         
-                                    for clip in get_yolo_regions(yolo_model, vc_sub):
+                                    for clip in get_yolo_regions(yolo_model, vc_sub, is_right):
                                         if self.transform:
                                             clip = self.transform(clip)
 
@@ -200,7 +259,7 @@ class PNBLoader(Dataset):
 #                                         cv2.imwrite('test.png', vc_sub[0, 0, :, :].unsqueeze(2).numpy())
                                         if vc_sub.size(1) < 5:
                                             continue
-                                        for clip in get_yolo_regions(yolo_model, vc_sub):
+                                        for clip in get_yolo_regions(yolo_model, vc_sub, is_right):
                                             if self.transform:
                                                 clip = self.transform(clip)
                                                 
@@ -212,7 +271,7 @@ class PNBLoader(Dataset):
                         vc_sub = vc[:, -5:, :, :]
                         if vc_sub.size(1) < 5:
                             continue
-                        for clip in get_yolo_regions(yolo_model, vc_sub):
+                        for clip in get_yolo_regions(yolo_model, vc_sub, is_right):
                             if self.transform:
                                 clip = self.transform(clip)
 
@@ -222,7 +281,7 @@ class PNBLoader(Dataset):
                             vc_sub = vc[:, j:j+5, :, :]
                             if vc_sub.size(1) < 5:
                                 continue
-                            for clip in get_yolo_regions(yolo_model, vc_sub):
+                            for clip in get_yolo_regions(yolo_model, vc_sub, is_right):
                                 if self.transform:
                                     clip = self.transform(clip)
 
@@ -234,10 +293,11 @@ class PNBLoader(Dataset):
                         vc_sub = vc[:, j:j+5, :, :]
                         if vc_sub.size(1) < 5:
                             continue
-                        if self.transform:
-                            vc_sub = self.transform(vc_sub)
+                        for clip in get_yolo_regions(yolo_model, vc_sub, is_right):
+                            if self.transform:
+                                clip = self.transform(clip)
 
-                        self.clips.append((self.videos[vid_idx][0], vc_sub, self.videos[vid_idx][2]))
+                            self.clips.append((self.videos[vid_idx][0], clip, self.videos[vid_idx][2]))
 
                 self.final_clips[self.videos[vid_idx][2]] = self.clips[-1]
                 vid_idx += 1