diff --git a/generate_yolo_regions.py b/generate_yolo_regions.py
index ceed5ca92273c08073b8d5c09b965053ec0d538e..a0351cdb1fd2125bd3cfe1947f2ff0e9a3544d9a 100644
--- a/generate_yolo_regions.py
+++ b/generate_yolo_regions.py
@@ -3,7 +3,7 @@ import os
 import time
 import numpy as np
 import torchvision
-from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler, get_yolo_regions, classify_nerve_is_right, load_pnb_region_labels
+from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler, get_yolo_regions, classify_nerve_is_right, load_pnb_region_labels, calculate_angle
 from torchvision.datasets.video_utils import VideoClips
 import torchvision as tv
 import csv
@@ -16,26 +16,31 @@ import cv2
 import tensorflow.keras as keras
 from sparse_coding_torch.keras_model import SparseCode, PNBClassifier, PTXClassifier, ReconSparse
 import glob
+from sparse_coding_torch.train_sparse_model import plot_video
+import torch
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument('--input_video', required=True, type=str, help='Path to input video.')
     parser.add_argument('--output_dir', default='yolo_output', type=str, help='Location where yolo clips should be saved.')
     parser.add_argument('--num_frames', default=5, type=int)
-    parser.add_argument('--image_height', default=285, type=int)
-    parser.add_argument('--image_width', default=235, type=int)
+    parser.add_argument('--stride', default=5, type=int)
+    parser.add_argument('--image_height', default=150, type=int)
+    parser.add_argument('--image_width', default=400, type=int)
     
     args = parser.parse_args()
     
     path = args.input_video
         
-    region_labels = load_pnb_region_labels(os.path.join('/'.join(path.split('/')[:-3]), 'sme_region_labels.csv'))
+    region_labels = load_pnb_region_labels('sme_region_labels.csv')
                         
     if not os.path.exists(args.output_dir):
         os.makedirs(args.output_dir)
 
     image_height = args.image_height
     image_width = args.image_width
+    clip_depth = args.num_frames
+    frames_to_skip = args.stride
     
     # For some reason the size has to be even for the clips, so it will add one if the size is odd
     transforms = torchvision.transforms.Compose([
@@ -46,69 +51,100 @@ if __name__ == "__main__":
         
     vc = tv.io.read_video(path)[0].permute(3, 0, 1, 2)
     is_right = classify_nerve_is_right(yolo_model, vc)
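+    # Estimate the needle-to-vessel angle once per video; get_yolo_regions uses it when rotating the crop box.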
+    angle = calculate_angle(yolo_model, vc)
     person_idx = path.split('/')[-2]
     label = path.split('/')[-3]
     
     output_count = 0
-
+    
     if label == 'Positives' and person_idx in region_labels:
         negative_regions, positive_regions = region_labels[person_idx]
         for sub_region in negative_regions.split(','):
             sub_region = sub_region.split('-')
             start_loc = int(sub_region[0])
+#                             end_loc = int(sub_region[1]) - 50
             end_loc = int(sub_region[1]) + 1
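+            # Sample clips from this negative region: clip_depth frames spaced frames_to_skip apart.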
-            for frame in range(start_loc, end_loc - args.num_frames, args.num_frames):
-                vc_sub = vc[:, frame:frame+args.num_frames, :, :]
-                if vc_sub.size(1) < args.num_frames:
+            for j in range(start_loc, end_loc - clip_depth * frames_to_skip, clip_depth):
+                frames = []
+                for k in range(j, j + clip_depth * frames_to_skip, frames_to_skip):
+                    frames.append(vc[:, k, :, :])
+                vc_sub = torch.stack(frames, dim=1)
+
+                if vc_sub.size(1) < clip_depth:
                     continue
 
-                for clip in get_yolo_regions(yolo_model, vc_sub, is_right, image_width, image_height):
+                for clip in get_yolo_regions(yolo_model, vc_sub, is_right, angle, image_width, image_height):
                     clip = transforms(clip)
-                    tv.io.write_video(os.path.join(args.output_dir, 'negative_yolo' + str(output_count) + '.mp4'), clip.swapaxes(0,1).swapaxes(1,2).swapaxes(2,3).numpy(), fps=20)
+                    ani = plot_video(clip)
+                    ani.save(os.path.join(args.output_dir, 'negative_yolo' + str(output_count) + '.mp4'))
                     output_count += 1
 
         if positive_regions:
             for sub_region in positive_regions.split(','):
                 sub_region = sub_region.split('-')
+#                                 start_loc = int(sub_region[0]) + 15
                 start_loc = int(sub_region[0])
-                if len(sub_region) == 1:
-                    vc_sub = vc[:, start_loc:start_loc+args.num_frames, :, :]
-                    if vc_sub.size(1) < args.num_frames:
+                if len(sub_region) == 1 and vc.size(1) >= start_loc + clip_depth * frames_to_skip:
+                    frames = []
+                    for k in range(start_loc, start_loc + clip_depth * frames_to_skip, frames_to_skip):
+                        frames.append(vc[:, k, :, :])
+                    vc_sub = torch.stack(frames, dim=1)
+
+                    if vc_sub.size(1) < clip_depth:
                         continue
 
-                    for clip in get_yolo_regions(yolo_model, vc_sub, is_right, image_width, image_height):
+                    for clip in get_yolo_regions(yolo_model, vc_sub, is_right, angle, image_width, image_height):
                         clip = transforms(clip)
-                        tv.io.write_video(os.path.join(args.output_dir, 'positive_yolo' + str(output_count) + '.mp4'), clip.swapaxes(0,1).swapaxes(1,2).swapaxes(2,3).numpy(), fps=20)
+                        ani = plot_video(clip)
+                        ani.save(os.path.join(args.output_dir, 'positive_yolo' + str(output_count) + '.mp4'))
                         output_count += 1
-                else:
+                elif vc.size(1) >= start_loc + clip_depth * frames_to_skip:
                     end_loc = sub_region[1]
                     if end_loc.strip().lower() == 'end':
                         end_loc = vc.size(1)
                     else:
                         end_loc = int(end_loc)
-                    for frame in range(start_loc, end_loc - args.num_frames, args.num_frames):
-                        vc_sub = vc[:, frame:frame+args.num_frames, :, :]
-#                                         cv2.imwrite('test.png', vc_sub[0, 0, :, :].unsqueeze(2).numpy())
-                        if vc_sub.size(1) < args.num_frames:
+                    for j in range(start_loc, end_loc - clip_depth * frames_to_skip, clip_depth):
+                        frames = []
+                        for k in range(j, j + clip_depth * frames_to_skip, frames_to_skip):
+                            frames.append(vc[:, k, :, :])
+                        vc_sub = torch.stack(frames, dim=1)
+
+                        if vc_sub.size(1) < clip_depth:
                             continue
-                        for clip in get_yolo_regions(yolo_model, vc_sub, is_right, image_width, image_height):
+                        for clip in get_yolo_regions(yolo_model, vc_sub, is_right, angle, image_width, image_height):
                             clip = transforms(clip)
-                            tv.io.write_video(os.path.join(args.output_dir, 'positive_yolo' + str(output_count) + '.mp4'), clip.swapaxes(0,1).swapaxes(1,2).swapaxes(2,3).numpy(), fps=20)
+                            ani = plot_video(clip)
+                            ani.save(os.path.join(args.output_dir, 'positive_yolo' + str(output_count) + '.mp4'))
                             output_count += 1
+                else:
+                    continue
     elif label == 'Positives':
-        vc_sub = vc[:, -args.num_frames:, :, :]
-        if not vc_sub.size(1) < args.num_frames:
-            for clip in get_yolo_regions(yolo_model, vc_sub, is_right, image_width, image_height):
-                clip = transforms(clip)
-                tv.io.write_video(os.path.join(args.output_dir, 'positive_yolo' + str(output_count) + '.mp4'), clip.swapaxes(0,1).swapaxes(1,2).swapaxes(2,3).numpy(), fps=20)
-                output_count += 1
-    elif label == 'Negatives':
-        for j in range(0, vc.size(1) - args.num_frames, args.num_frames):
-            vc_sub = vc[:, j:j+args.num_frames, :, :]
-            if not vc_sub.size(1) < args.num_frames:
+        frames = []
+        for k in range(max(0, vc.size(1) - clip_depth * frames_to_skip), vc.size(1), frames_to_skip):
+            frames.append(vc[:, k, :, :])
+        if frames:
+            vc_sub = torch.stack(frames, dim=1)
+            if vc_sub.size(1) >= clip_depth:
-                for clip in get_yolo_regions(yolo_model, vc_sub, is_right, image_width, image_height):
+                for clip in get_yolo_regions(yolo_model, vc_sub, is_right, angle, image_width, image_height):
                     clip = transforms(clip)
-                    tv.io.write_video(os.path.join(args.output_dir, 'negative_yolo' + str(output_count) + '.mp4'), clip.swapaxes(0,1).swapaxes(1,2).swapaxes(2,3).numpy(), fps=20)
+                    ani = plot_video(clip)
+                    ani.save(os.path.join(args.output_dir, 'positive_yolo' + str(output_count) + '.mp4'))
+                    output_count += 1
+    elif label == 'Negatives':
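+        # Negative videos: sample clips across the entire video.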
+        for j in range(0, vc.size(1) - clip_depth * frames_to_skip, clip_depth):
+            frames = []
+            for k in range(j, j + clip_depth * frames_to_skip, frames_to_skip):
+                frames.append(vc[:, k, :, :])
+            vc_sub = torch.stack(frames, dim=1)
+            if vc_sub.size(1) >= clip_depth:
+                for clip in get_yolo_regions(yolo_model, vc_sub, is_right, angle, image_width, image_height):
+                    clip = transforms(clip)
+                    ani = plot_video(clip)
+                    ani.save(os.path.join(args.output_dir, 'negative_yolo' + str(output_count) + '.mp4'))
                     output_count += 1
     else:
         raise Exception('Invalid label')
\ No newline at end of file
diff --git a/run_ptx.py b/run_ptx.py
index 5ee5656f7e1db3638e5352dbf1e8e171c506b8ac..1b927f46a452eabe5fda728b7638d15389bfc4f9 100644
--- a/run_ptx.py
+++ b/run_ptx.py
@@ -24,8 +24,8 @@ if __name__ == "__main__":
     parser.add_argument('--max_activation_iter', default=100, type=int)
     parser.add_argument('--activation_lr', default=1e-2, type=float)
     parser.add_argument('--lam', default=0.05, type=float)
-    parser.add_argument('--sparse_checkpoint', default='converted_checkpoints/sparse.pt', type=str)
-    parser.add_argument('--checkpoint', default='converted_checkpoints/classifier.pt', type=str)
+    parser.add_argument('--sparse_checkpoint', default='ptx_tensorflow/sparse.pt', type=str)
+    parser.add_argument('--checkpoint', default='ptx_tensorflow/best_classifier.pt', type=str)
     parser.add_argument('--run_2d', action='store_true')
     parser.add_argument('--dataset', default='ptx', type=str)
     
@@ -33,12 +33,10 @@ if __name__ == "__main__":
     #print(args.accumulate(args.integers))
     batch_size = 1
     
-    if args.dataset == 'pnb':
-        image_height = 250
-        image_width = 600
-    elif args.dataset == 'ptx':
+    if args.dataset == 'ptx':
         image_height = 100
         image_width = 200
+        clip_depth = 5
     else:
         raise Exception('Invalid dataset')
 
@@ -49,15 +47,9 @@ if __name__ == "__main__":
 
     filter_inputs = keras.Input(shape=(5, args.kernel_size, args.kernel_size, 1, args.num_kernels), dtype='float32')
 
-    output = SparseCode(batch_size=batch_size, image_height=image_height, image_width=image_width, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d, padding='SAME')(inputs, filter_inputs)
+    output = SparseCode(batch_size=batch_size, image_height=image_height, image_width=image_width, clip_depth=clip_depth, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, kernel_depth=args.kernel_depth, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d, padding='SAME')(inputs, filter_inputs)
 
     sparse_model = keras.Model(inputs=(inputs, filter_inputs), outputs=output)
-    
-    recon_inputs = keras.Input(shape=(1, image_height // args.stride, image_width // args.stride, args.num_kernels))
-    
-    recon_outputs = ReconSparse(batch_size=batch_size, image_height=image_height, image_width=image_width, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(recon_inputs)
-    
-    recon_model = keras.Model(inputs=recon_inputs, outputs=recon_outputs)
 
     if args.sparse_checkpoint:
         recon_model = keras.models.load_model(args.sparse_checkpoint)
diff --git a/sme_region_labels.csv b/sme_region_labels.csv
index c533ec9c715bf1da2cd30d320e034c5275fce9b9..5f584839b200a189c21252e9b83051480b7e261f 100644
--- a/sme_region_labels.csv
+++ b/sme_region_labels.csv
@@ -1,9 +1,11 @@
 idx,negative_regions,positive_regions
 11,"0-121,158-516","121-157,517-end"
+46,0-208,209-end
 54,0-397,397
 67,0-120,120-end
 93,"0-60,155-256","61-154,257-end"
 94,0-78,79-end
+125,0-372,373-end
 134,0-393,394-end
 153,0-200,201-end
 189,"0-122,123-184",185
@@ -12,6 +14,6 @@ idx,negative_regions,positive_regions
 211,0-86,87-end
 217,"0-86,87-137",138-end
 222,0-112,113-end
-230,0-7,
-238,"0-140,141-152",153-end)
-240,0-205,206-end
\ No newline at end of file
+230,0-75,76-end
+238,"0-140,141-152",153-end
+240,0-205,206-end
diff --git a/sparse_coding_torch/keras_model.py b/sparse_coding_torch/keras_model.py
index c73b1423eb1abb9e4b20d1cd4ffe21e88133acdd..86f015a7950113252f087f9ce4596b0a90c9cd0b 100644
--- a/sparse_coding_torch/keras_model.py
+++ b/sparse_coding_torch/keras_model.py
@@ -299,6 +299,47 @@ class PNBClassifier(keras.layers.Layer):
         x = self.ff_4(x)
 
         return x
+    
+class PNBTemporalClassifier(keras.layers.Layer):
+    def __init__(self):
+        super(PNBTemporalClassifier, self).__init__()
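+        # A full-height 2D conv (kernel height 150) and a 1D conv extract a 100-dimensional feature per frame;
+        # a GRU then aggregates the per-frame features across the clip before the final dense layers.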
+        self.conv_1 = keras.layers.Conv2D(12, kernel_size=(150, 24), strides=(1, 8), activation='relu', padding='valid')
+        self.conv_2 = keras.layers.Conv1D(24, kernel_size=8, strides=4, activation='relu', padding='valid')
+        
+        self.ff_1 = keras.layers.Dense(100, activation='relu', use_bias=True)
+        
+        self.gru = keras.layers.GRU(25)
+
+        self.flatten = keras.layers.Flatten()
+
+        self.ff_2 = keras.layers.Dense(10, activation='relu', use_bias=True)
+        self.ff_3 = keras.layers.Dense(1)
+
+#     @tf.function
+    def call(self, clip):
+        width = clip.shape[3]
+        height = clip.shape[2]
+        depth = clip.shape[1]
+        
+        # Fold the time dimension into the batch so each frame goes through the 2D conv independently.
+        x = tf.reshape(clip, (-1, height, width, 1))
+
+        x = self.conv_1(x)
+        x = tf.squeeze(x, axis=1)
+        x = self.conv_2(x)
+
+        x = self.flatten(x)
+        x = self.ff_1(x)
+
+        x = tf.reshape(x, (-1, depth, 100))
+        x = self.gru(x)
+        
+        x = self.ff_2(x)
+        x = self.ff_3(x)
+
+        return x
 
 class MobileModelPTX(keras.Model):
     def __init__(self, sparse_checkpoint, batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d):
diff --git a/sparse_coding_torch/load_data.py b/sparse_coding_torch/load_data.py
index 96ef3fa885a4a84dfba9f3966ca5a67b75673608..6dea292a90a127da77ec9c6ce16d8c234a2cb5da 100644
--- a/sparse_coding_torch/load_data.py
+++ b/sparse_coding_torch/load_data.py
@@ -40,12 +40,13 @@ def load_yolo_clips(batch_size, mode, num_clips=1, num_positives=100, device=Non
         return gss.split(np.arange(len(targets)), targets, groups), dataset
     elif mode == 'all_train':
         train_idx = np.arange(len(targets))
-        train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
-        train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
-                                               sampler=train_sampler)
-        test_loader = None
+#         train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
+#         train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
+#                                                sampler=train_sampler)
+#         test_loader = None
         
-        return train_loader, test_loader, dataset
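+        # Return a single (train_idx, test_idx) split covering the whole dataset so callers can reuse the k_fold interface.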
+        return [(train_idx, None)], dataset
     elif mode == 'k_fold':
         gss = StratifiedGroupKFold(n_splits=n_splits)
 
diff --git a/sparse_coding_torch/train_classifier.py b/sparse_coding_torch/train_classifier.py
index 812769c676852d472491168fca65d770e1bbb76d..ca71ca2f59f6aec2f5b0d4eebaf60c4b049be13c 100644
--- a/sparse_coding_torch/train_classifier.py
+++ b/sparse_coding_torch/train_classifier.py
@@ -5,7 +5,7 @@ from tqdm import tqdm
 import argparse
 import os
 from sparse_coding_torch.load_data import load_yolo_clips, load_pnb_videos, SubsetWeightedRandomSampler, get_sample_weights
-from sparse_coding_torch.keras_model import SparseCode, PNBClassifier, PTXClassifier, ReconSparse, normalize_weights, normalize_weights_3d
+from sparse_coding_torch.keras_model import SparseCode, PNBClassifier, PTXClassifier, ReconSparse, normalize_weights, normalize_weights_3d, PNBTemporalClassifier
 import time
 import numpy as np
 from sklearn.metrics import f1_score, accuracy_score, confusion_matrix
@@ -19,6 +19,7 @@ from yolov4.get_bounding_boxes import YoloModel
 import torchvision
 from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler
 import glob
+import cv2
 
 configproto = tf.compat.v1.ConfigProto()
 configproto.gpu_options.polling_inactive_delay_msecs = 5000
@@ -129,8 +130,8 @@ if __name__ == "__main__":
     
     i_fold = 0
     for train_idx, test_idx in splits:
-#         train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
-        train_sampler = SubsetWeightedRandomSampler(get_sample_weights(train_idx, dataset), train_idx, replacement=True)
+        train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
+#         train_sampler = SubsetWeightedRandomSampler(get_sample_weights(train_idx, dataset), train_idx, replacement=True)
         train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                sampler=train_sampler)
         
@@ -139,19 +140,20 @@ if __name__ == "__main__":
             test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                    sampler=test_sampler)
             
-            with open(os.path.join(args.output_dir, 'test_videos_{}.txt'.format(i_fold)), 'w+') as test_videos_out:
-                test_set = set([x for tup in test_loader for x in tup[2]])
-                test_videos_out.writelines(test_set)
+#             with open(os.path.join(args.output_dir, 'test_videos_{}.txt'.format(i_fold)), 'w+') as test_videos_out:
+#                 test_set = set([x for tup in test_loader for x in tup[2]])
+#                 test_videos_out.writelines(test_set)
         else:
             test_loader = None
         
         if args.checkpoint:
             classifier_model = keras.models.load_model(args.checkpoint)
         else:
-            classifier_inputs = keras.Input(shape=((clip_depth - args.kernel_depth) // 1 + 1, (image_height - args.kernel_size) // args.stride + 1, (image_width - args.kernel_size) // args.stride + 1, args.num_kernels))
+#             classifier_inputs = keras.Input(shape=((clip_depth - args.kernel_depth) // 1 + 1, (image_height - args.kernel_size) // args.stride + 1, (image_width - args.kernel_size) // args.stride + 1, args.num_kernel))
+            classifier_inputs = keras.Input(shape=(clip_depth, image_height, image_width))
 
             if args.dataset == 'pnb':
-                classifier_outputs = PNBClassifier()(classifier_inputs)
+                classifier_outputs = PNBTemporalClassifier()(classifier_inputs)
             elif args.dataset == 'ptx':
                 classifier_outputs = PTXClassifier()(classifier_inputs)
             else:
@@ -176,6 +178,13 @@ if __name__ == "__main__":
 
                 for labels, local_batch, vid_f in tqdm(train_loader):
                     images = local_batch.permute(0, 2, 3, 4, 1).numpy()
+#                     cv2.imwrite('example_video_2/test_{}_0.png'.format(labels[3]), images[3, 0, :, :] * 255)
+#                     cv2.imwrite('example_video_2/test_{}_1.png'.format(labels[3]), images[3, 1, :, :] * 255)
+#                     cv2.imwrite('example_video_2/test_{}_2.png'.format(labels[3]), images[3, 2, :, :] * 255)
+#                     cv2.imwrite('example_video_2/test_{}_3.png'.format(labels[3]), images[3, 3, :, :] * 255)
+#                     cv2.imwrite('example_video_2/test_{}_4.png'.format(labels[3]), images[3, 4, :, :] * 255)
+#                     print(vid_f[3])
+#                     raise Exception
 
                     torch_labels = np.zeros(len(labels))
                     torch_labels[[i for i in range(len(labels)) if labels[i] == positive_class]] = 1
@@ -183,17 +192,17 @@ if __name__ == "__main__":
 
                     if args.train_sparse:
                         with tf.GradientTape() as tape:
-                            activations = sparse_model([images, tf.expand_dims(recon_model.trainable_weights[0], axis=0)])
+#                             activations = sparse_model([images, tf.expand_dims(recon_model.trainable_weights[0], axis=0)])
-                            pred = classifier_model(activations)
+                            pred = classifier_model(images)
                             loss = criterion(torch_labels, pred)
 
                             print(loss)
                     else:
-                        activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))]))
+#                         activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))]))
     #                     raise Exception
 
                         with tf.GradientTape() as tape:
-                            pred = classifier_model(activations)
+                            pred = classifier_model(images)
                             loss = criterion(torch_labels, pred)
 
                     epoch_loss += loss * local_batch.size(0)
@@ -235,9 +244,9 @@ if __name__ == "__main__":
                     torch_labels[[i for i in range(len(labels)) if labels[i] == positive_class]] = 1
                     torch_labels = np.expand_dims(torch_labels, axis=1)
 
-                    activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))]))
+#                     activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))]))
 
-                    pred = classifier_model(activations)
+                    pred = classifier_model(images)
                     loss = criterion(torch_labels, pred)
 
                     test_loss += loss
diff --git a/sparse_coding_torch/train_sparse_model.py b/sparse_coding_torch/train_sparse_model.py
index 8fc91ca5d995f1f68483d5c4bbf0497c742dafcd..a15e73d4fb8cfce69abe2a29e89d1f8e66667333 100644
--- a/sparse_coding_torch/train_sparse_model.py
+++ b/sparse_coding_torch/train_sparse_model.py
@@ -146,7 +146,9 @@ if __name__ == "__main__":
     elif args.dataset == 'pnb':
         train_loader, test_loader, dataset = load_pnb_videos(args.batch_size, input_size=(image_height, image_width, clip_depth), crop_size=(crop_height, crop_width, clip_depth), classify_mode=False, balance_classes=False, mode='all_train', frames_to_skip=args.frames_to_skip)
     elif args.dataset == 'ptx':
-        train_loader, _ = load_yolo_clips(args.batch_size, num_clips=1, num_positives=15, mode='all_train', device=device, n_splits=1, sparse_model=None, whole_video=False, positive_videos='../positive_videos.json')
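+        # In 'all_train' mode load_yolo_clips now returns index splits and the dataset rather than DataLoaders.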
+        splits, dataset = load_yolo_clips(args.batch_size, num_clips=1, num_positives=15, mode='all_train', device=device, n_splits=1, sparse_model=None, whole_video=False, positive_videos='positive_videos.json')
+        train_idx, test_idx = splits[0]
     elif args.dataset == 'needle':
         train_loader, test_loader, dataset = load_needle_clips(args.batch_size, input_size=(image_height, image_width, clip_depth))
     else:
diff --git a/sparse_coding_torch/video_loader.py b/sparse_coding_torch/video_loader.py
index fbe9e03b1fb93ef9ce670a8fecd8a8861101bd2f..6621f9a572094a43aa62fef0b74dfd8d0cc423f5 100644
--- a/sparse_coding_torch/video_loader.py
+++ b/sparse_coding_torch/video_loader.py
@@ -76,15 +76,14 @@ def load_pnb_region_labels(file_path):
             
         return all_regions
     
-def get_yolo_regions(yolo_model, clip, is_right, crop_width, crop_height):
+def get_yolo_regions(yolo_model, clip, is_right, angle, crop_width, crop_height):
     orig_height = clip.size(2)
     orig_width = clip.size(3)
     bounding_boxes, classes = yolo_model.get_bounding_boxes(clip[:, 2, :, :].swapaxes(0, 2).swapaxes(0, 1).numpy())
     bounding_boxes = bounding_boxes.squeeze(0)
     classes = classes.squeeze(0)
     
-    rotate_box = False
-    angle = 20
+    rotate_box = True
     
     all_clips = []
     for bb, class_pred in zip(bounding_boxes, classes):
@@ -120,9 +119,10 @@ def get_yolo_regions(yolo_model, clip, is_right, crop_width, crop_height):
         
 #         if orig_width - center_x >= center_x:
 #         if not is_right:
-# #         cv2.imwrite('test.png', clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2))
-#             cv2.imwrite('test_3.png', trimmed_clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2))
-#             raise Exception
+#         print(angle)
+#         cv2.imwrite('test_2.png', clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2))
+#         cv2.imwrite('test_3.png', trimmed_clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2))
+#         raise Exception
         
 #         print(trimmed_clip.size())
 
@@ -180,6 +180,41 @@ def classify_nerve_is_right(yolo_model, video):
     final_pred = round(sum(all_preds) / len(all_preds))
 
     return final_pred == 1
+
+def calculate_angle(yolo_model, video):
+    orig_height = video.size(2)
+    orig_width = video.size(3)
+
+    if video.size(1) < 10:
+        return 45
+
+    all_preds = []
+    for frame_idx in range(0, video.size(1), round(video.size(1) / 10)):
+        frame = video[:, frame_idx, :, :]
+        bounding_boxes, classes = yolo_model.get_bounding_boxes(frame.swapaxes(0, 2).swapaxes(0, 1).numpy())
+        bounding_boxes = bounding_boxes.squeeze(0)
+        classes = classes.squeeze(0)
+
+        vessel_x = 0
+        vessel_y = 0
+        needle_x = 0
+        needle_y = 0
+
+        for bb, class_pred in zip(bounding_boxes, classes):
+            if class_pred == 0 and vessel_x == 0:
+                vessel_x = (bb[3] + bb[1]) / 2 * orig_width
+                vessel_y = (bb[2] + bb[0]) / 2 * orig_height
+            elif class_pred == 2 and needle_x == 0:
+                needle_x = bb[1] * orig_width
+                needle_y = bb[0] * orig_height
+            if needle_x != 0 and vessel_x != 0:
+                break
+
+        # Skip frames where the vessel or needle was not detected.
+        if vessel_x == 0 or needle_x == 0 or needle_x == vessel_x:
+            continue
+        all_preds.append(np.abs(np.degrees(np.arctan((needle_y - vessel_y) / (needle_x - vessel_x)))))
+
+    return sum(all_preds) / len(all_preds) if all_preds else 45
                 
     
 class PNBLoader(Dataset):
@@ -217,6 +252,7 @@ class PNBLoader(Dataset):
             for label, path, _ in tqdm(self.videos):
                 vc = tv.io.read_video(path)[0].permute(3, 0, 1, 2)
                 is_right = classify_nerve_is_right(yolo_model, vc)
+                angle = calculate_angle(yolo_model, vc)
 
                 if classify_mode:
 #                     person_idx = path.split('/')[-2]
@@ -232,7 +268,7 @@ class PNBLoader(Dataset):
                             start_loc = int(sub_region[0])
 #                             end_loc = int(sub_region[1]) - 50
                             end_loc = int(sub_region[1]) + 1
-                            for j in range(start_loc, end_loc - clip_depth * frames_to_skip, 5):
+                            for j in range(start_loc, end_loc - clip_depth * frames_to_skip, clip_depth):
                                 frames = []
                                 for k in range(j, j + clip_depth * frames_to_skip, frames_to_skip):
                                     frames.append(vc[:, k, :, :])
@@ -241,7 +277,7 @@ class PNBLoader(Dataset):
                                 if vc_sub.size(1) < clip_depth:
                                     continue
 
-                                for clip in get_yolo_regions(yolo_model, vc_sub, is_right, clip_width, clip_height):
+                                for clip in get_yolo_regions(yolo_model, vc_sub, is_right, angle, clip_width, clip_height):
                                     if self.transform:
                                         clip = self.transform(clip)
 
@@ -261,7 +297,7 @@ class PNBLoader(Dataset):
                                     if vc_sub.size(1) < clip_depth:
                                         continue
                                         
-                                    for clip in get_yolo_regions(yolo_model, vc_sub, is_right, clip_width, clip_height):
+                                    for clip in get_yolo_regions(yolo_model, vc_sub, is_right, angle, clip_width, clip_height):
                                         if self.transform:
                                             clip = self.transform(clip)
 
@@ -272,7 +308,7 @@ class PNBLoader(Dataset):
                                         end_loc = vc.size(1)
                                     else:
                                         end_loc = int(end_loc)
-                                    for j in range(start_loc, end_loc - clip_depth * frames_to_skip, 3):
+                                    for j in range(start_loc, end_loc - clip_depth * frames_to_skip, clip_depth):
                                         frames = []
                                         for k in range(j, j + clip_depth * frames_to_skip, frames_to_skip):
                                             frames.append(vc[:, k, :, :])
@@ -280,7 +316,7 @@ class PNBLoader(Dataset):
 
                                         if vc_sub.size(1) < clip_depth:
                                             continue
-                                        for clip in get_yolo_regions(yolo_model, vc_sub, is_right, clip_width, clip_height):
+                                        for clip in get_yolo_regions(yolo_model, vc_sub, is_right, angle, clip_width, clip_height):
                                             if self.transform:
                                                 clip = self.transform(clip)
 
@@ -302,14 +338,14 @@ class PNBLoader(Dataset):
 
                             self.clips.append((self.videos[vid_idx][0], clip, self.videos[vid_idx][2]))
                     elif label == 'Negatives':
-                        for j in range(0, vc.size(1) - clip_depth * frames_to_skip, 10):
+                        for j in range(0, vc.size(1) - clip_depth * frames_to_skip, clip_depth):
                             frames = []
                             for k in range(j, j + clip_depth * frames_to_skip, frames_to_skip):
                                 frames.append(vc[:, k, :, :])
                             vc_sub = torch.stack(frames, dim=1)
                             if vc_sub.size(1) < clip_depth:
                                 continue
-                            for clip in get_yolo_regions(yolo_model, vc_sub, is_right, clip_width, clip_height):
+                            for clip in get_yolo_regions(yolo_model, vc_sub, is_right, angle, clip_width, clip_height):
                                 if self.transform:
                                     clip = self.transform(clip)
 
@@ -317,7 +353,7 @@ class PNBLoader(Dataset):
                     else:
                         raise Exception('Invalid label')
                 else:
-                    for j in range(0, vc.size(1) - clip_depth * frames_to_skip, 10):
+                    for j in range(0, vc.size(1) - clip_depth * frames_to_skip, clip_depth):
                         frames = []
                         for k in range(j, j + clip_depth * frames_to_skip, frames_to_skip):
                             frames.append(vc[:, k, :, :])