diff --git a/generate_yolo_regions.py b/generate_yolo_regions.py
index 864c493833e2006a5c0670c056c01e5d07d7e39b..8dc6f8676d142a7127b485492899d3d27d9375d9 100644
--- a/generate_yolo_regions.py
+++ b/generate_yolo_regions.py
@@ -22,6 +22,8 @@ if __name__ == "__main__":
     parser.add_argument('--input_video', required=True, type=str, help='Path to input video.')
     parser.add_argument('--output_dir', default='yolo_output', type=str, help='Location where yolo clips should be saved.')
     parser.add_argument('--num_frames', default=5, type=int)
+    parser.add_argument('--image_height', default=285, type=int)
+    parser.add_argument('--image_width', default=235, type=int)
     
     args = parser.parse_args()
     
@@ -32,8 +34,8 @@ if __name__ == "__main__":
     if not os.path.exists(args.output_dir):
         os.makedirs(args.output_dir)
 
-    image_height = 285
-    image_width = 235
+    image_height = args.image_height
+    image_width = args.image_width
     
     # For some reason the size has to be even for the clips, so it will add one if the size is odd
     transforms = torchvision.transforms.Compose([
@@ -60,7 +62,7 @@ if __name__ == "__main__":
                 if vc_sub.size(1) < 5:
                     continue
 
-                for clip in get_yolo_regions(yolo_model, vc_sub, is_right):
+                for clip in get_yolo_regions(yolo_model, vc_sub, is_right, image_width, image_height):
                     clip = transforms(clip)
                     tv.io.write_video(os.path.join(args.output_dir, 'negative_yolo' + str(output_count) + '.mp4'), clip.swapaxes(0,1).swapaxes(1,2).swapaxes(2,3).numpy(), fps=20)
                     output_count += 1
@@ -74,7 +76,7 @@ if __name__ == "__main__":
                     if vc_sub.size(1) < 5:
                         continue
 
-                    for clip in get_yolo_regions(yolo_model, vc_sub, is_right):
+                    for clip in get_yolo_regions(yolo_model, vc_sub, is_right, image_width, image_height):
                         clip = transforms(clip)
                         tv.io.write_video(os.path.join(args.output_dir, 'positive_yolo' + str(output_count) + '.mp4'), clip.swapaxes(0,1).swapaxes(1,2).swapaxes(2,3).numpy(), fps=20)
                         output_count += 1
@@ -89,14 +91,14 @@ if __name__ == "__main__":
 #                                         cv2.imwrite('test.png', vc_sub[0, 0, :, :].unsqueeze(2).numpy())
                         if vc_sub.size(1) < 5:
                             continue
-                        for clip in get_yolo_regions(yolo_model, vc_sub, is_right):
+                        for clip in get_yolo_regions(yolo_model, vc_sub, is_right, image_width, image_height):
                             clip = transforms(clip)
                             tv.io.write_video(os.path.join(args.output_dir, 'positive_yolo' + str(output_count) + '.mp4'), clip.swapaxes(0,1).swapaxes(1,2).swapaxes(2,3).numpy(), fps=20)
                             output_count += 1
     elif label == 'Positives':
         vc_sub = vc[:, -5:, :, :]
         if not vc_sub.size(1) < 5:
-            for clip in get_yolo_regions(yolo_model, vc_sub, is_right):
+            for clip in get_yolo_regions(yolo_model, vc_sub, is_right, image_width, image_height):
                 clip = transforms(clip)
                 tv.io.write_video(os.path.join(args.output_dir, 'positive_yolo' + str(output_count) + '.mp4'), clip.swapaxes(0,1).swapaxes(1,2).swapaxes(2,3).numpy(), fps=20)
                 output_count += 1
@@ -104,7 +106,7 @@ if __name__ == "__main__":
         for j in range(0, vc.size(1) - 5, 5):
             vc_sub = vc[:, j:j+5, :, :]
             if not vc_sub.size(1) < 5:
-                for clip in get_yolo_regions(yolo_model, vc_sub, is_right):
+                for clip in get_yolo_regions(yolo_model, vc_sub, is_right, image_width, image_height):
                     clip = transforms(clip)
                     tv.io.write_video(os.path.join(args.output_dir, 'negative_yolo' + str(output_count) + '.mp4'), clip.swapaxes(0,1).swapaxes(1,2).swapaxes(2,3).numpy(), fps=20)
                     output_count += 1
diff --git a/sparse_coding_torch/load_data.py b/sparse_coding_torch/load_data.py
index 9caa3549201dd0f8f8e9abfed1fc9a48b97d96f9..967d60f1fa05bc6ccda6792a331124b3d0adb14a 100644
--- a/sparse_coding_torch/load_data.py
+++ b/sparse_coding_torch/load_data.py
@@ -128,7 +128,7 @@ def load_pnb_videos(batch_size, input_size, mode=None, classify_mode=False, bala
      torchvision.transforms.RandomAffine(degrees=0, translate=(0.01, 0))
 #      torchvision.transforms.CenterCrop((100, 200))
     ])
-    dataset = PNBLoader(video_path, classify_mode, balance_classes=balance_classes, num_frames=5, frame_rate=20, transform=transforms, augmentation=augment_transforms)
+    dataset = PNBLoader(video_path, input_size[1], input_size[0], classify_mode, balance_classes=balance_classes, num_frames=5, frame_rate=20, transform=transforms, augmentation=augment_transforms)
     
     targets = dataset.get_labels()
     
diff --git a/sparse_coding_torch/video_loader.py b/sparse_coding_torch/video_loader.py
index a855e5e472742ac1f750b89c99fb56428f8e1ccd..9908d6e7ddbce0a413424a5b04bb4042d5d53b6a 100644
--- a/sparse_coding_torch/video_loader.py
+++ b/sparse_coding_torch/video_loader.py
@@ -76,7 +76,7 @@ def load_pnb_region_labels(file_path):
             
         return all_regions
     
-def get_yolo_regions(yolo_model, clip, is_right):
+def get_yolo_regions(yolo_model, clip, is_right, crop_width, crop_height):
     orig_height = clip.size(2)
     orig_width = clip.size(3)
     bounding_boxes, classes = yolo_model.get_bounding_boxes(clip[:, 2, :, :].swapaxes(0, 2).swapaxes(0, 1).numpy())
@@ -95,14 +95,14 @@ def get_yolo_regions(yolo_model, clip, is_right):
         lower_x = round((bb[1] * orig_width))
         upper_x = round((bb[3] * orig_width))
         
-        lower_y = upper_y - 285
+        lower_y = upper_y - crop_height
         
         if is_right:
-            lower_x = center_x - 470
+            lower_x = center_x - crop_width
             upper_x = center_x
         else:
             lower_x = center_x
-            upper_x = center_x + 470
+            upper_x = center_x + crop_width
 
         trimmed_clip = clip[:, :, lower_y:upper_y, lower_x:upper_x]
         
@@ -167,13 +167,13 @@ def classify_nerve_is_right(yolo_model, video):
     
 class PNBLoader(Dataset):
     
-    def __init__(self, video_path, classify_mode=False, balance_classes=False, num_frames=5, frame_rate=20, frames_between_clips=None, transform=None, augmentation=None):
+    def __init__(self, video_path, clip_width, clip_height, classify_mode=False, balance_classes=False, num_frames=5, frame_rate=20, frames_between_clips=None, transform=None, augmentation=None):
         self.transform = transform
         self.augmentation = augmentation
         self.labels = [name for name in listdir(video_path) if isdir(join(video_path, name))]
         
-        clip_cache_file = 'clip_cache_pnb_double.pt'
-        clip_cache_final_file = 'clip_cache_pnb_final_double.pt'
+        clip_cache_file = 'clip_cache_pnb_{}_{}.pt'.format(clip_width, clip_height)
+        clip_cache_final_file = 'clip_cache_pnb_{}_{}_final.pt'.format(clip_width, clip_height)
         
         region_labels = load_pnb_region_labels(join(video_path, 'sme_region_labels.csv'))
 
@@ -221,7 +221,7 @@ class PNBLoader(Dataset):
                                     
 #                                 cv2.imwrite('test.png', vc_sub[0, 0, :, :].unsqueeze(2).numpy())
 
-                                for clip in get_yolo_regions(yolo_model, vc_sub, is_right):
+                                for clip in get_yolo_regions(yolo_model, vc_sub, is_right, clip_width, clip_height):
                                     if self.transform:
                                         clip = self.transform(clip)
                                         
@@ -243,7 +243,7 @@ class PNBLoader(Dataset):
                                     if vc_sub.size(1) < 5:
                                         continue
                                         
-                                    for clip in get_yolo_regions(yolo_model, vc_sub, is_right):
+                                    for clip in get_yolo_regions(yolo_model, vc_sub, is_right, clip_width, clip_height):
                                         if self.transform:
                                             clip = self.transform(clip)
 
@@ -259,7 +259,7 @@ class PNBLoader(Dataset):
 #                                         cv2.imwrite('test.png', vc_sub[0, 0, :, :].unsqueeze(2).numpy())
                                         if vc_sub.size(1) < 5:
                                             continue
-                                        for clip in get_yolo_regions(yolo_model, vc_sub, is_right):
+                                        for clip in get_yolo_regions(yolo_model, vc_sub, is_right, clip_width, clip_height):
                                             if self.transform:
                                                 clip = self.transform(clip)
                                                 
@@ -271,7 +271,7 @@ class PNBLoader(Dataset):
                         vc_sub = vc[:, -5:, :, :]
                         if vc_sub.size(1) < 5:
                             continue
-                        for clip in get_yolo_regions(yolo_model, vc_sub, is_right):
+                        for clip in get_yolo_regions(yolo_model, vc_sub, is_right, clip_width, clip_height):
                             if self.transform:
                                 clip = self.transform(clip)
 
@@ -281,7 +281,7 @@ class PNBLoader(Dataset):
                             vc_sub = vc[:, j:j+5, :, :]
                             if vc_sub.size(1) < 5:
                                 continue
-                            for clip in get_yolo_regions(yolo_model, vc_sub, is_right):
+                            for clip in get_yolo_regions(yolo_model, vc_sub, is_right, clip_width, clip_height):
                                 if self.transform:
                                     clip = self.transform(clip)
 
@@ -293,7 +293,7 @@ class PNBLoader(Dataset):
                         vc_sub = vc[:, j:j+5, :, :]
                         if vc_sub.size(1) < 5:
                             continue
-                        for clip in get_yolo_regions(yolo_model, vc_sub, is_right):
+                        for clip in get_yolo_regions(yolo_model, vc_sub, is_right, clip_width, clip_height):
                             if self.transform:
                                 clip = self.transform(clip)