From a343c7f08b4339b39d2ef85d83f25fe765cec522 Mon Sep 17 00:00:00 2001
From: hannandarryl <hannandarryl@gmail.com>
Date: Thu, 26 May 2022 01:01:09 +0000
Subject: [PATCH] first push with refactoring

---
 generate_yolo_regions.py                      |  25 +-
 sparse_coding_torch/onsd/classifier_model.py  |  42 ++
 sparse_coding_torch/onsd/load_data.py         |  50 ++
 sparse_coding_torch/onsd/train_classifier.py  | 299 ++++++++++++
 .../onsd/train_sparse_model.py                | 169 +++++++
 sparse_coding_torch/onsd/video_loader.py      | 122 +++++
 sparse_coding_torch/pnb/classifier_model.py   | 134 ++++++
 .../{ => pnb}/generate_tflite.py              |   4 +-
 sparse_coding_torch/{ => pnb}/load_data.py    | 108 +----
 sparse_coding_torch/pnb/train_classifier.py   | 442 ++++++++++++++++++
 .../{ => pnb}/train_classifier_needle.py      |   8 +-
 sparse_coding_torch/pnb/train_sparse_model.py | 172 +++++++
 sparse_coding_torch/{ => pnb}/video_loader.py | 421 +++--------------
 sparse_coding_torch/ptx/classifier_model.py   | 120 +++++
 .../{ => ptx}/convert_pytorch_to_keras.py     |  14 +-
 sparse_coding_torch/ptx/generate_tflite.py    |  63 +++
 sparse_coding_torch/ptx/load_data.py          |  69 +++
 .../{ => ptx}/train_classifier.py             | 342 ++++++--------
 .../{ => ptx}/train_sparse_model.py           |  20 +-
 sparse_coding_torch/ptx/video_loader.py       | 263 +++++++++++
 .../{keras_model.py => sparse_model.py}       | 205 +-------
 sparse_coding_torch/utils.py                  | 230 +++++----
 yolov4/get_bounding_boxes.py                  |  24 +-
 23 files changed, 2355 insertions(+), 991 deletions(-)
 create mode 100644 sparse_coding_torch/onsd/classifier_model.py
 create mode 100644 sparse_coding_torch/onsd/load_data.py
 create mode 100644 sparse_coding_torch/onsd/train_classifier.py
 create mode 100644 sparse_coding_torch/onsd/train_sparse_model.py
 create mode 100644 sparse_coding_torch/onsd/video_loader.py
 create mode 100644 sparse_coding_torch/pnb/classifier_model.py
 rename sparse_coding_torch/{ => pnb}/generate_tflite.py (95%)
 rename sparse_coding_torch/{ => pnb}/load_data.py (54%)
 create mode 100644 sparse_coding_torch/pnb/train_classifier.py
 rename sparse_coding_torch/{ => pnb}/train_classifier_needle.py (98%)
 create mode 100644 sparse_coding_torch/pnb/train_sparse_model.py
 rename sparse_coding_torch/{ => pnb}/video_loader.py (52%)
 create mode 100644 sparse_coding_torch/ptx/classifier_model.py
 rename sparse_coding_torch/{ => ptx}/convert_pytorch_to_keras.py (85%)
 create mode 100644 sparse_coding_torch/ptx/generate_tflite.py
 create mode 100644 sparse_coding_torch/ptx/load_data.py
 rename sparse_coding_torch/{ => ptx}/train_classifier.py (55%)
 rename sparse_coding_torch/{ => ptx}/train_sparse_model.py (88%)
 create mode 100644 sparse_coding_torch/ptx/video_loader.py
 rename sparse_coding_torch/{keras_model.py => sparse_model.py} (54%)

diff --git a/generate_yolo_regions.py b/generate_yolo_regions.py
index a0351cd..c39eb2d 100644
--- a/generate_yolo_regions.py
+++ b/generate_yolo_regions.py
@@ -3,7 +3,7 @@ import os
 import time
 import numpy as np
 import torchvision
-from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler, get_yolo_regions, classify_nerve_is_right, load_pnb_region_labels, calculate_angle
+from sparse_coding_torch.utils import VideoGrayScaler, MinMaxScaler
+from sparse_coding_torch.pnb.video_loader import get_yolo_regions, classify_nerve_is_right, load_pnb_region_labels, calculate_angle_video, get_needle_bb
 from torchvision.datasets.video_utils import VideoClips
 import torchvision as tv
 import csv
@@ -24,7 +24,7 @@ if __name__ == "__main__":
     parser.add_argument('--output_dir', default='yolo_output', type=str, help='Location where yolo clips should be saved.')
     parser.add_argument('--num_frames', default=5, type=int)
     parser.add_argument('--stride', default=5, type=int)
-    parser.add_argument('--image_height', default=150, type=int)
+    parser.add_argument('--image_height', default=300, type=int)
     parser.add_argument('--image_width', default=400, type=int)
     
     args = parser.parse_args()
@@ -50,7 +50,8 @@ if __name__ == "__main__":
         
     vc = tv.io.read_video(path)[0].permute(3, 0, 1, 2)
     is_right = classify_nerve_is_right(yolo_model, vc)
-    angle = calculate_angle(yolo_model, vc)
+#     video_angle = calculate_angle_video(yolo_model, vc)
+    needle_bb = get_needle_bb(yolo_model, vc)
     person_idx = path.split('/')[-2]
     label = path.split('/')[-3]
     
@@ -72,10 +73,11 @@ if __name__ == "__main__":
                 if vc_sub.size(1) < clip_depth:
                     continue
 
-                for clip in get_yolo_regions(yolo_model, vc_sub, is_right, angle, image_width, image_height):
+                for clip in get_yolo_regions(yolo_model, vc_sub, is_right, needle_bb, image_width, image_height):
                     clip = transforms(clip)
                     ani = plot_video(clip)
                     ani.save(os.path.join(args.output_dir, 'negative_yolo' + str(output_count) + '.mp4'))
+                    print(output_count)
                     output_count += 1
 
         if positive_regions:
@@ -92,11 +94,13 @@ if __name__ == "__main__":
                     if vc_sub.size(1) < clip_depth:
                         continue
 
-                    for clip in get_yolo_regions(yolo_model, vc_sub, is_right, angle, image_width, image_height):
+                    for clip in get_yolo_regions(yolo_model, vc_sub, is_right, needle_bb, image_width, image_height):
                         clip = transforms(clip)
                         ani = plot_video(clip)
                         ani.save(os.path.join(args.output_dir, 'positive_yolo' + str(output_count) + '.mp4'))
+                        print(output_count)
                         output_count += 1
+
                 elif vc.size(1) >= start_loc + clip_depth * frames_to_skip:
                     end_loc = sub_region[1]
                     if end_loc.strip().lower() == 'end':
@@ -111,10 +115,12 @@ if __name__ == "__main__":
 
                         if vc_sub.size(1) < clip_depth:
                             continue
-                        for clip in get_yolo_regions(yolo_model, vc_sub, is_right, angle, image_width, image_height):
+
+                        for clip in get_yolo_regions(yolo_model, vc_sub, is_right, needle_bb, image_width, image_height):
                             clip = transforms(clip)
                             ani = plot_video(clip)
                             ani.save(os.path.join(args.output_dir, 'positive_yolo' + str(output_count) + '.mp4'))
+                            print(output_count)
                             output_count += 1
                 else:
                     continue
@@ -125,10 +131,11 @@ if __name__ == "__main__":
         if frames:
             vc_sub = torch.stack(frames, dim=1)
             if vc_sub.size(1) >= clip_depth:
-                for clip in get_yolo_regions(yolo_model, vc_sub, is_right, image_width, image_height):
+                for clip in get_yolo_regions(yolo_model, vc_sub, is_right, needle_bb, image_width, image_height):
                     clip = transforms(clip)
                     ani = plot_video(clip)
                     ani.save(os.path.join(args.output_dir, 'positive_yolo' + str(output_count) + '.mp4'))
+                    print(output_count)
                     output_count += 1
     elif label == 'Negatives':
         for j in range(0, vc.size(1) - clip_depth * frames_to_skip, clip_depth):
@@ -136,11 +143,13 @@ if __name__ == "__main__":
             for k in range(j, j + clip_depth * frames_to_skip, frames_to_skip):
                 frames.append(vc[:, k, :, :])
             vc_sub = torch.stack(frames, dim=1)
+
             if vc_sub.size(1) >= clip_depth:
-                for clip in get_yolo_regions(yolo_model, vc_sub, is_right, angle, image_width, image_height):
+                for clip in get_yolo_regions(yolo_model, vc_sub, is_right, needle_bb, image_width, image_height):
                     clip = transforms(clip)
                     ani = plot_video(clip)
                     ani.save(os.path.join(args.output_dir, 'negative_yolo' + str(output_count) + '.mp4'))
+                    print(output_count)
                     output_count += 1
     else:
         raise Exception('Invalid label')
\ No newline at end of file
diff --git a/sparse_coding_torch/onsd/classifier_model.py b/sparse_coding_torch/onsd/classifier_model.py
new file mode 100644
index 0000000..1162dc1
--- /dev/null
+++ b/sparse_coding_torch/onsd/classifier_model.py
@@ -0,0 +1,42 @@
+from tensorflow import keras
+import numpy as np
+import torch
+import tensorflow as tf
+import cv2
+import torchvision as tv
+import torch.nn as nn
+from sparse_coding_torch.utils import VideoGrayScaler, MinMaxScaler
+    
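+# Small 2D CNN head: two strided conv layers, flatten, dropout, and two dense layers that
+# map a cropped ONSD frame to a single logit (trained with from_logits binary cross-entropy).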
+class ONSDClassifier(keras.layers.Layer):
+    def __init__(self):
+        super(ONSDClassifier, self).__init__()
+
+        self.conv_1 = keras.layers.Conv2D(24, kernel_size=8, strides=4, activation='relu', padding='valid')
+        self.conv_2 = keras.layers.Conv2D(24, kernel_size=4, strides=2, activation='relu', padding='valid')
+
+        self.flatten = keras.layers.Flatten()
+
+        self.dropout = keras.layers.Dropout(0.5)
+
+#         self.ff_1 = keras.layers.Dense(1000, activation='relu', use_bias=True)
+#         self.ff_2 = keras.layers.Dense(500, activation='relu', use_bias=True)
+#         self.ff_2 = keras.layers.Dense(20, activation='relu', use_bias=True)
+        self.ff_3 = keras.layers.Dense(20, activation='relu', use_bias=True)
+        self.ff_4 = keras.layers.Dense(1)
+
+#     @tf.function
+    def call(self, activations):
+        x = self.conv_1(activations)
+        x = self.conv_2(x)
+        x = self.flatten(x)
+#         x = self.ff_1(x)
+#         x = self.dropout(x)
+#         x = self.ff_2(x)
+#         x = self.dropout(x)
+        x = self.ff_3(x)
+        x = self.dropout(x)
+        x = self.ff_4(x)
+
+        return x
+
diff --git a/sparse_coding_torch/onsd/load_data.py b/sparse_coding_torch/onsd/load_data.py
new file mode 100644
index 0000000..de53a52
--- /dev/null
+++ b/sparse_coding_torch/onsd/load_data.py
@@ -0,0 +1,50 @@
+import numpy as np
+import torchvision
+import torch
+from sklearn.model_selection import train_test_split
+from sparse_coding_torch.utils import MinMaxScaler
+from sparse_coding_torch.onsd.video_loader import get_participants, ONSDLoader
+from sparse_coding_torch.utils import VideoGrayScaler
+from typing import Sequence, Iterator
+import csv
+from sklearn.model_selection import train_test_split, GroupShuffleSplit, LeaveOneGroupOut, LeaveOneOut, StratifiedGroupKFold, StratifiedKFold, KFold, ShuffleSplit
+    
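+# Builds the ONSD frame dataset and returns (splits, dataset), where splits yields
+# (train_idx, test_idx) index pairs grouped by participant so the same subject never
+# appears in both train and test ('all_train' returns a single (all_idx, None) split).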
+def load_onsd_videos(batch_size, input_size, yolo_model=None, mode=None, n_splits=None):   
+    video_path = "/shared_data/bamc_onsd_data/preliminary_onsd_data"
+    
+    transforms = torchvision.transforms.Compose(
+    [torchvision.transforms.Grayscale(1),
+     MinMaxScaler(0, 255),
+     torchvision.transforms.Resize(input_size[:2])
+    ])
+    augment_transforms = torchvision.transforms.Compose(
+    [torchvision.transforms.RandomRotation(15)
+    ])
+    dataset = ONSDLoader(video_path, input_size[1], input_size[0], transform=transforms, augmentation=augment_transforms, yolo_model=yolo_model)
+    
+    targets = dataset.get_labels()
+    
+    if mode == 'leave_one_out':
+        gss = LeaveOneGroupOut()
+
+        groups = get_participants(dataset.get_filenames())
+        
+        return gss.split(np.arange(len(targets)), targets, groups), dataset
+    elif mode == 'all_train':
+        train_idx = np.arange(len(targets))
+        test_idx = None
+        
+        return [(train_idx, test_idx)], dataset
+    elif mode == 'k_fold':
+        gss = StratifiedGroupKFold(n_splits=n_splits, shuffle=True)
+
+        groups = get_participants(dataset.get_filenames())
+        
+        return gss.split(np.arange(len(targets)), targets, groups), dataset
+    else:
+#         gss = ShuffleSplit(n_splits=n_splits, test_size=0.2)
+        gss = GroupShuffleSplit(n_splits=n_splits, test_size=0.2)
+
+        groups = get_participants(dataset.get_filenames())
+        
+        return gss.split(np.arange(len(targets)), targets, groups), dataset
\ No newline at end of file
diff --git a/sparse_coding_torch/onsd/train_classifier.py b/sparse_coding_torch/onsd/train_classifier.py
new file mode 100644
index 0000000..a879bf0
--- /dev/null
+++ b/sparse_coding_torch/onsd/train_classifier.py
@@ -0,0 +1,299 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from tqdm import tqdm
+import argparse
+import os
+from sparse_coding_torch.onsd.load_data import load_onsd_videos
+from sparse_coding_torch.utils import SubsetWeightedRandomSampler, get_sample_weights
+from sparse_coding_torch.sparse_model import SparseCode, ReconSparse, normalize_weights, normalize_weights_3d
+from sparse_coding_torch.onsd.classifier_model import ONSDClassifier
+import time
+import numpy as np
+from sklearn.metrics import f1_score, accuracy_score, confusion_matrix
+import random
+import pickle
+import tensorflow.keras as keras
+import tensorflow as tf
+from sparse_coding_torch.onsd.train_sparse_model import sparse_loss
+from yolov4.get_bounding_boxes import YoloModel
+import torchvision
+from sparse_coding_torch.utils import VideoGrayScaler, MinMaxScaler
+import glob
+import cv2
+
+configproto = tf.compat.v1.ConfigProto()
+configproto.gpu_options.polling_inactive_delay_msecs = 5000
+configproto.gpu_options.allow_growth = True
+sess = tf.compat.v1.Session(config=configproto) 
+tf.compat.v1.keras.backend.set_session(sess)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--batch_size', default=12, type=int)
+    parser.add_argument('--kernel_size', default=15, type=int)
+    parser.add_argument('--kernel_depth', default=5, type=int)
+    parser.add_argument('--num_kernels', default=64, type=int)
+    parser.add_argument('--stride', default=1, type=int)
+    parser.add_argument('--max_activation_iter', default=150, type=int)
+    parser.add_argument('--activation_lr', default=1e-2, type=float)
+    parser.add_argument('--lr', default=5e-4, type=float)
+    parser.add_argument('--epochs', default=40, type=int)
+    parser.add_argument('--lam', default=0.05, type=float)
+    parser.add_argument('--output_dir', default='./output', type=str)
+    parser.add_argument('--sparse_checkpoint', default=None, type=str)
+    parser.add_argument('--checkpoint', default=None, type=str)
+    parser.add_argument('--splits', default=None, type=str, help='k_fold or leave_one_out or all_train')
+    parser.add_argument('--seed', default=26, type=int)
+    parser.add_argument('--train', action='store_true')
+    parser.add_argument('--num_positives', default=100, type=int)
+    parser.add_argument('--n_splits', default=5, type=int)
+    parser.add_argument('--save_train_test_splits', action='store_true')
+    parser.add_argument('--run_2d', action='store_true')
+    parser.add_argument('--balance_classes', action='store_true')
+    parser.add_argument('--dataset', default='pnb', type=str)
+    parser.add_argument('--train_sparse', action='store_true')
+    parser.add_argument('--mixing_ratio', type=float, default=1.0)
+    parser.add_argument('--sparse_lr', type=float, default=0.003)
+    parser.add_argument('--crop_height', type=int, default=285)
+    parser.add_argument('--crop_width', type=int, default=350)
+    parser.add_argument('--scale_factor', type=int, default=1)
+    parser.add_argument('--clip_depth', type=int, default=5)
+    parser.add_argument('--frames_to_skip', type=int, default=1)
+    
+    args = parser.parse_args()
+    
+    crop_height = args.crop_height
+    crop_width = args.crop_width
+
+    image_height = int(crop_height / args.scale_factor)
+    image_width = int(crop_width / args.scale_factor)
+    clip_depth = args.clip_depth
+
+    batch_size = args.batch_size
+    
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    
+    output_dir = args.output_dir
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+        
+    with open(os.path.join(output_dir, 'arguments.txt'), 'w+') as out_f:
+        out_f.write(str(args))
+    
+    yolo_model = YoloModel(args.dataset)
+
+    all_errors = []
+    
+    if args.run_2d:
+        inputs = keras.Input(shape=(image_height, image_width, clip_depth))
+    else:
+        inputs = keras.Input(shape=(clip_depth, image_height, image_width, 1))
+
+    sparse_model = None
+    recon_model = None
+        
+    
+    splits, dataset = load_onsd_videos(args.batch_size, input_size=(image_height, image_width), yolo_model=yolo_model, mode=args.splits, n_splits=args.n_splits)
+    positive_class = 'Positives'
+
+    overall_true = []
+    overall_pred = []
+    fn_ids = []
+    fp_ids = []
+    
+    i_fold = 0
+    for train_idx, test_idx in splits:
+        train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
+#         train_sampler = SubsetWeightedRandomSampler(get_sample_weights(train_idx, dataset), train_idx, replacement=True)
+        train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
+                                               sampler=train_sampler)
+        
+        if test_idx is not None:
+            test_sampler = torch.utils.data.SubsetRandomSampler(test_idx)
+            test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
+                                                   sampler=test_sampler)
+            
+#             with open(os.path.join(args.output_dir, 'test_videos_{}.txt'.format(i_fold)), 'w+') as test_videos_out:
+#                 test_set = set([x for tup in test_loader for x in tup[2]])
+#                 test_videos_out.writelines(test_set)
+        else:
+            test_loader = None
+        
+        if args.checkpoint:
+            classifier_model = keras.models.load_model(args.checkpoint)
+        else:
+            classifier_inputs = keras.Input(shape=(image_height, image_width, 1))
+            classifier_outputs = ONSDClassifier()(classifier_inputs)
+
+            classifier_model = keras.Model(inputs=classifier_inputs, outputs=classifier_outputs)
+
+        prediction_optimizer = keras.optimizers.Adam(learning_rate=args.lr)
+        filter_optimizer = tf.keras.optimizers.SGD(learning_rate=args.sparse_lr)
+
+        best_so_far = float('-inf')
+
+        criterion = keras.losses.BinaryCrossentropy(from_logits=True, reduction=keras.losses.Reduction.SUM)
+
+        if args.train:
+            for epoch in range(args.epochs):
+                epoch_loss = 0
+                t1 = time.perf_counter()
+
+                y_true_train = None
+                y_pred_train = None
+
+                for labels, local_batch, vid_f in tqdm(train_loader):
+                    images = local_batch.permute(0, 2, 3, 1).numpy()
+
+                    torch_labels = np.zeros(len(labels))
+                    torch_labels[[i for i in range(len(labels)) if labels[i] == positive_class]] = 1
+                    torch_labels = np.expand_dims(torch_labels, axis=1)
+
+                    if args.train_sparse:
+                        with tf.GradientTape() as tape:
+                            activations = sparse_model([images, tf.expand_dims(recon_model.trainable_weights[0], axis=0)])
+                            pred = classifier_model(activations)
+                            loss = criterion(torch_labels, pred)
+                    else:
+                        with tf.GradientTape() as tape:
+                            pred = classifier_model(images)
+                            loss = criterion(torch_labels, pred)
+
+                    epoch_loss += loss * local_batch.size(0)
+
+                    if args.train_sparse:
+                        sparse_gradients, classifier_gradients = tape.gradient(loss, [recon_model.trainable_weights, classifier_model.trainable_weights])
+
+                        prediction_optimizer.apply_gradients(zip(classifier_gradients, classifier_model.trainable_weights))
+
+                        filter_optimizer.apply_gradients(zip(sparse_gradients, recon_model.trainable_weights))
+
+                        if args.run_2d:
+                            weights = normalize_weights(recon_model.get_weights(), args.num_kernels)
+                        else:
+                            weights = normalize_weights_3d(recon_model.get_weights(), args.num_kernels)
+                        recon_model.set_weights(weights)
+                    else:
+                        gradients = tape.gradient(loss, classifier_model.trainable_weights)
+
+                        prediction_optimizer.apply_gradients(zip(gradients, classifier_model.trainable_weights))
+
+
+                    if y_true_train is None:
+                        y_true_train = torch_labels
+                        y_pred_train = tf.math.round(tf.math.sigmoid(pred))
+                    else:
+                        y_true_train = tf.concat((y_true_train, torch_labels), axis=0)
+                        y_pred_train = tf.concat((y_pred_train, tf.math.round(tf.math.sigmoid(pred))), axis=0)
+
+                t2 = time.perf_counter()
+
+                y_true = None
+                y_pred = None
+                test_loss = 0.0
+                
+                eval_loader = test_loader
+                if args.splits == 'all_train':
+                    eval_loader = train_loader
+                for labels, local_batch, vid_f in tqdm(eval_loader):
+                    images = local_batch.permute(0, 2, 3, 1).numpy()
+
+                    torch_labels = np.zeros(len(labels))
+                    torch_labels[[i for i in range(len(labels)) if labels[i] == positive_class]] = 1
+                    torch_labels = np.expand_dims(torch_labels, axis=1)
+
+                    pred = classifier_model(images)
+                    loss = criterion(torch_labels, pred)
+
+                    test_loss += loss
+
+                    if y_true is None:
+                        y_true = torch_labels
+                        y_pred = tf.math.round(tf.math.sigmoid(pred))
+                    else:
+                        y_true = tf.concat((y_true, torch_labels), axis=0)
+                        y_pred = tf.concat((y_pred, tf.math.round(tf.math.sigmoid(pred))), axis=0)
+
+                t2 = time.perf_counter()
+
+                y_true = tf.cast(y_true, tf.int32)
+                y_pred = tf.cast(y_pred, tf.int32)
+
+                y_true_train = tf.cast(y_true_train, tf.int32)
+                y_pred_train = tf.cast(y_pred_train, tf.int32)
+
+                f1 = f1_score(y_true, y_pred, average='macro')
+                accuracy = accuracy_score(y_true, y_pred)
+
+                train_accuracy = accuracy_score(y_true_train, y_pred_train)
+
+                print('epoch={}, i_fold={}, time={:.2f}, train_loss={:.2f}, test_loss={:.2f}, train_acc={:.2f}, test_f1={:.2f}, test_acc={:.2f}'.format(epoch, i_fold, t2-t1, epoch_loss, test_loss, train_accuracy, f1, accuracy))
+    #             print(epoch_loss)
+                if f1 >= best_so_far:
+                    print("found better model")
+                    # Save model parameters
+                    classifier_model.save(os.path.join(output_dir, "best_classifier_{}.pt".format(i_fold)))
+#                     recon_model.save(os.path.join(output_dir, "best_sparse_model_{}.pt".format(i_fold)))
+                    pickle.dump(prediction_optimizer.get_weights(), open(os.path.join(output_dir, 'optimizer_{}.pt'.format(i_fold)), 'wb+'))
+                    best_so_far = f1
+
+            classifier_model = keras.models.load_model(os.path.join(output_dir, "best_classifier_{}.pt".format(i_fold)))
+#             recon_model = keras.models.load_model(os.path.join(output_dir, 'best_sparse_model_{}.pt'.format(i_fold)))
+
+        epoch_loss = 0
+
+        y_true = None
+        y_pred = None
+
+        pred_dict = {}
+        gt_dict = {}
+
+        t1 = time.perf_counter()
+        # Video-level evaluation for ONSD is not implemented yet; the code below this
+        # raise (fn/fp collection and per-fold metrics) is unreachable until it is.
+        raise NotImplementedError('ONSD video-level evaluation is not yet implemented')
+            
+        t2 = time.perf_counter()
+
+        print('i_fold={}, time={:.2f}'.format(i_fold, t2-t1))
+
+        y_true = tf.cast(y_true, tf.int32)
+        y_pred = tf.cast(y_pred, tf.int32)
+
+        f1 = f1_score(y_true, y_pred, average='macro')
+        accuracy = accuracy_score(y_true, y_pred)
+
+        fn_ids.extend(fn)
+        fp_ids.extend(fp)
+
+        overall_true.extend(y_true)
+        overall_pred.extend(y_pred)
+
+        print("Test f1={:.2f}, vid_acc={:.2f}".format(f1, accuracy))
+
+        print(confusion_matrix(y_true, y_pred))
+            
+        i_fold += 1
+
+    fp_fn_file = os.path.join(args.output_dir, 'fp_fn.txt')
+    with open(fp_fn_file, 'w+') as in_f:
+        in_f.write('FP:\n')
+        in_f.write(str(fp_ids) + '\n\n')
+        in_f.write('FN:\n')
+        in_f.write(str(fn_ids) + '\n\n')
+        
+    overall_true = np.array(overall_true)
+    overall_pred = np.array(overall_pred)
+            
+    final_f1 = f1_score(overall_true, overall_pred, average='macro')
+    final_acc = accuracy_score(overall_true, overall_pred)
+    final_conf = confusion_matrix(overall_true, overall_pred)
+            
+    print("Final accuracy={:.2f}, f1={:.2f}".format(final_acc, final_f1))
+    print(final_conf)
+
diff --git a/sparse_coding_torch/onsd/train_sparse_model.py b/sparse_coding_torch/onsd/train_sparse_model.py
new file mode 100644
index 0000000..0352d53
--- /dev/null
+++ b/sparse_coding_torch/onsd/train_sparse_model.py
@@ -0,0 +1,169 @@
+import time
+import numpy as np
+import torch
+from matplotlib import pyplot as plt
+from matplotlib import cm
+from matplotlib.animation import FuncAnimation
+from tqdm import tqdm
+import argparse
+import os
+from sparse_coding_torch.onsd.load_data import load_onsd_videos
+import tensorflow.keras as keras
+import tensorflow as tf
+from sparse_coding_torch.sparse_model import normalize_weights_3d, normalize_weights, SparseCode, load_pytorch_weights, ReconSparse
+from sparse_coding_torch.utils import plot_filters  # used with --save_filters; assumed to live in utils after the refactor
+import random
+
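+# Sparse coding objective: 0.5/B * ||x - x_hat||^2 (reconstruction error) plus
+# lam * mean over the batch of the L1 norm of the activations (sparsity penalty).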
+def sparse_loss(images, recon, activations, batch_size, lam, stride):
+    loss = 0.5 * (1/batch_size) * tf.math.reduce_sum(tf.math.pow(images - recon, 2))
+    loss += lam * tf.reduce_mean(tf.math.reduce_sum(tf.math.abs(tf.reshape(activations, (batch_size, -1))), axis=1))
+    return loss
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--batch_size', default=32, type=int)
+    parser.add_argument('--kernel_size', default=15, type=int)
+    parser.add_argument('--kernel_depth', default=5, type=int)
+    parser.add_argument('--num_kernels', default=32, type=int)
+    parser.add_argument('--stride', default=1, type=int)
+    parser.add_argument('--max_activation_iter', default=300, type=int)
+    parser.add_argument('--activation_lr', default=1e-2, type=float)
+    parser.add_argument('--lr', default=0.003, type=float)
+    parser.add_argument('--epochs', default=150, type=int)
+    parser.add_argument('--lam', default=0.05, type=float)
+    parser.add_argument('--output_dir', default='./output', type=str)
+    parser.add_argument('--seed', default=42, type=int)
+    parser.add_argument('--run_2d', action='store_true')
+    parser.add_argument('--save_filters', action='store_true')
+    parser.add_argument('--optimizer', default='sgd', type=str)
+    parser.add_argument('--dataset', default='onsd', type=str)
+    parser.add_argument('--crop_height', type=int, default=400)
+    parser.add_argument('--crop_width', type=int, default=400)
+    parser.add_argument('--scale_factor', type=int, default=1)
+    parser.add_argument('--clip_depth', type=int, default=5)
+    parser.add_argument('--frames_to_skip', type=int, default=1)
+    
+
+    args = parser.parse_args()
+    
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+
+    crop_height = args.crop_height
+    crop_width = args.crop_width
+
+    image_height = int(crop_height / args.scale_factor)
+    image_width = int(crop_width / args.scale_factor)
+    clip_depth = args.clip_depth
+
+    output_dir = args.output_dir
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+        
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+    with open(os.path.join(output_dir, 'arguments.txt'), 'w+') as out_f:
+        out_f.write(str(args))
+
+    splits, dataset = load_onsd_videos(args.batch_size, input_size=(image_height, image_width, clip_depth), mode='all_train')
+    train_idx, test_idx = splits[0]
+    
+    train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
+    train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size,
+                                           sampler=train_sampler)
+    
+    print('Loaded', len(train_loader), 'train examples')
+
+    example_data = next(iter(train_loader))
+
+    if args.run_2d:
+        inputs = keras.Input(shape=(image_height, image_width, clip_depth))
+    else:
+        inputs = keras.Input(shape=(clip_depth, image_height, image_width, 1))
+
+    filter_inputs = keras.Input(shape=(args.kernel_depth, args.kernel_size, args.kernel_size, 1, args.num_kernels), dtype='float32')
+
+    output = SparseCode(batch_size=args.batch_size, image_height=image_height, image_width=image_width, clip_depth=clip_depth, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, kernel_depth=args.kernel_depth, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(inputs, filter_inputs)
+
+    sparse_model = keras.Model(inputs=(inputs, filter_inputs), outputs=output)
+    
+    recon_inputs = keras.Input(shape=(1, (image_height - args.kernel_size) // args.stride + 1, (image_width - args.kernel_size) // args.stride + 1, args.num_kernels))
+    
+    recon_outputs = ReconSparse(batch_size=args.batch_size, image_height=image_height, image_width=image_width, clip_depth=clip_depth, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, kernel_depth=args.kernel_depth, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(recon_inputs)
+    
+    recon_model = keras.Model(inputs=recon_inputs, outputs=recon_outputs)
+    
+    if args.save_filters:
+        if args.run_2d:
+            filters = plot_filters(tf.stack(recon_model.get_weights(), axis=0))
+        else:
+            filters = plot_filters(recon_model.get_weights()[0])
+        filters.save(os.path.join(args.output_dir, 'filters_start.mp4'))
+
+    learning_rate = args.lr
+    if args.optimizer == 'sgd':
+        filter_optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
+    else:
+        filter_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
+
+    loss_log = []
+    best_so_far = float('inf')
+
+    for epoch in range(args.epochs):
+        epoch_loss = 0
+        running_loss = 0.0
+        epoch_start = time.perf_counter()
+        
+        num_iters = 0
+
+        for labels, local_batch, vid_f in tqdm(train_loader):
+            if local_batch.size(0) != args.batch_size:
+                continue
+            if args.run_2d:
+                images = local_batch.squeeze(1).permute(0, 2, 3, 1).numpy()
+            else:
+                images = local_batch.permute(0, 2, 3, 4, 1).numpy()
+                
+            activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))]))
+            
+            with tf.GradientTape() as tape:
+                recon = recon_model(activations)
+                loss = sparse_loss(images, recon, activations, args.batch_size, args.lam, args.stride)
+
+            epoch_loss += loss * local_batch.size(0)
+            running_loss += loss * local_batch.size(0)
+
+            gradients = tape.gradient(loss, recon_model.trainable_weights)
+
+            filter_optimizer.apply_gradients(zip(gradients, recon_model.trainable_weights))
+            
+            if args.run_2d:
+                weights = normalize_weights(recon_model.get_weights(), args.num_kernels)
+            else:
+                weights = normalize_weights_3d(recon_model.get_weights(), args.num_kernels)
+            recon_model.set_weights(weights)
+                
+            num_iters += 1
+
+        epoch_end = time.perf_counter()
+        epoch_loss /= len(train_loader.sampler)
+        
+        if args.save_filters and epoch % 2 == 0:
+            if args.run_2d:
+                filters = plot_filters(tf.stack(recon_model.get_weights(), axis=0))
+            else:
+                filters = plot_filters(recon_model.get_weights()[0])
+            filters.save(os.path.join(args.output_dir, 'filters_' + str(epoch) +'.mp4'))
+
+        if epoch_loss < best_so_far:
+            print("found better model")
+            # Save model parameters
+            recon_model.save(os.path.join(output_dir, "best_sparse.pt"))
+            best_so_far = epoch_loss
+
+        loss_log.append(epoch_loss)
+        print('epoch={}, epoch_loss={:.2f}, time={:.2f}'.format(epoch, epoch_loss, epoch_end - epoch_start))
+
+    plt.plot(loss_log)
+
+    plt.savefig(os.path.join(output_dir, 'loss_graph.png'))
diff --git a/sparse_coding_torch/onsd/video_loader.py b/sparse_coding_torch/onsd/video_loader.py
new file mode 100644
index 0000000..b644ac9
--- /dev/null
+++ b/sparse_coding_torch/onsd/video_loader.py
@@ -0,0 +1,122 @@
+from os import listdir
+from os.path import isfile
+from os.path import join
+from os.path import isdir
+from os.path import abspath
+from os.path import exists
+import json
+import glob
+
+from PIL import Image
+from torchvision.transforms import ToTensor
+from torchvision.datasets.video_utils import VideoClips
+from tqdm import tqdm
+import torch
+import numpy as np
+from torch.utils.data import Dataset
+from torch.utils.data import DataLoader
+from torchvision.io import read_video
+import torchvision as tv
+from torch import nn
+import torchvision.transforms.functional as tv_f
+import csv
+import random
+import cv2
+from yolov4.get_bounding_boxes import YoloModel
+
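+# Participant IDs come from the directory layout .../<label>/<participant>/<video>.mp4,
+# i.e. the second-to-last path component, and are used to group train/test splits.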
+def get_participants(filenames):
+    return [f.split('/')[-2] for f in filenames]
+    
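+# Runs YOLO on a single (C, H, W) frame and crops the first class-0 detection using its
+# normalized [y1, x1, y2, x2] box; returns None if no class-0 region is found.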
+def get_yolo_region_onsd(yolo_model, frame):
+    orig_height = frame.size(1)
+    orig_width = frame.size(2)
+    
+    bounding_boxes, classes, scores = yolo_model.get_bounding_boxes(frame.swapaxes(0, 2).swapaxes(0, 1).numpy())
+    bounding_boxes = bounding_boxes.squeeze(0)
+    classes = classes.squeeze(0)
+    scores = scores.squeeze(0)
+    
+    all_frames = []
+    for bb, class_pred, score in zip(bounding_boxes, classes, scores):
+        if class_pred != 0:
+            continue
+        
+        lower_y = round(bb[0] * orig_height)
+        upper_y = round(bb[2] * orig_height)
+        lower_x = round(bb[1] * orig_width)
+        upper_x = round(bb[3] * orig_width)
+
+        trimmed_frame = frame[:, lower_y:upper_y, lower_x:upper_x]
+        
+#         cv2.imwrite('test_2.png', frame.numpy().swapaxes(0,1).swapaxes(1,2))
+#         cv2.imwrite('test_3.png', trimmed_frame.numpy().swapaxes(0,1).swapaxes(1,2))
+        
+        return trimmed_frame
+
+    return None
+    
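+# Frame-level dataset: reads every frame of every video under Positives/ and Negatives/,
+# optionally crops it with YOLO, applies the grayscale/resize transform, and caches the
+# resulting (label, frame, filename) tuples to clip_cache_onsd_<w>_<h>_sparse.pt.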
+class ONSDLoader(Dataset):
+    
+    def __init__(self, video_path, clip_width, clip_height, transform=None, augmentation=None, yolo_model=None):
+        self.transform = transform
+        self.augmentation = augmentation
+        self.labels = [name for name in listdir(video_path) if isdir(join(video_path, name))]
+        
+        clip_cache_file = 'clip_cache_onsd_{}_{}_sparse.pt'.format(clip_width, clip_height)
+        
+        self.videos = []
+        for label in self.labels:
+            self.videos.extend([(label, abspath(join(video_path, label, f)), f) for f in glob.glob(join(video_path, label, '*', '*.mp4'))])
+            
+        self.clips = []
+        
+        if exists(clip_cache_file):
+            self.clips = torch.load(open(clip_cache_file, 'rb'))
+        else:
+            vid_idx = 0
+            for label, path, _ in tqdm(self.videos):
+                vc = tv.io.read_video(path)[0].permute(3, 0, 1, 2)
+                
+                for j in range(vc.size(1)):
+                    frame = vc[:, j, :, :]
+                    
+                    if yolo_model is not None:
+                        frame = get_yolo_region_onsd(yolo_model, frame)
+                        
+                    if frame is None:
+                        continue
+
+                    if self.transform:
+                        frame = self.transform(frame)
+
+                    self.clips.append((self.videos[vid_idx][0], frame, self.videos[vid_idx][2]))
+
+                vid_idx += 1
+                
+            torch.save(self.clips, open(clip_cache_file, 'wb+'))
+            
+        num_positive = len([clip[0] for clip in self.clips if clip[0] == 'Positives'])
+        num_negative = len([clip[0] for clip in self.clips if clip[0] == 'Negatives'])
+        
+        random.shuffle(self.clips)
+        
+        print('Loaded', num_positive, 'positive examples.')
+        print('Loaded', num_negative, 'negative examples.')
+        
+    def get_filenames(self):
+        return [self.clips[i][2] for i in range(len(self.clips))]
+        
+    def get_video_labels(self):
+        return [self.videos[i][0] for i in range(len(self.videos))]
+        
+    def get_labels(self):
+        return [self.clips[i][0] for i in range(len(self.clips))]
+    
+    def __getitem__(self, index):
+        label, frame, vid_f = self.clips[index]
+        if self.augmentation:
+            frame = self.augmentation(frame)
+        return (label, frame, vid_f)
+        
+    def __len__(self):
+        return len(self.clips)
diff --git a/sparse_coding_torch/pnb/classifier_model.py b/sparse_coding_torch/pnb/classifier_model.py
new file mode 100644
index 0000000..8f363e6
--- /dev/null
+++ b/sparse_coding_torch/pnb/classifier_model.py
@@ -0,0 +1,134 @@
+from tensorflow import keras
+import numpy as np
+import torch
+import tensorflow as tf
+import cv2
+import torchvision as tv
+import torch.nn as nn
+from sparse_coding_torch.utils import VideoGrayScaler, MinMaxScaler
+from sparse_coding_torch.sparse_model import SparseCode
+    
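+# 2D CNN head for PNB: squeezes the singleton depth axis of the sparse activation maps,
+# applies two strided conv layers, then three dense layers down to a single logit.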
+class PNBClassifier(keras.layers.Layer):
+    def __init__(self):
+        super(PNBClassifier, self).__init__()
+
+#         self.max_pool = keras.layers.MaxPooling2D(pool_size=(8, 8), strides=(2, 2))
+        self.conv_1 = keras.layers.Conv2D(32, kernel_size=(8, 8), strides=(4, 4), activation='relu', padding='valid')
+        self.conv_2 = keras.layers.Conv2D(32, kernel_size=4, strides=2, activation='relu', padding='valid')
+#         self.conv_3 = keras.layers.Conv2D(12, kernel_size=4, strides=1, activation='relu', padding='valid')
+#         self.conv_4 = keras.layers.Conv2D(16, kernel_size=4, strides=2, activation='relu', padding='valid')
+
+        self.flatten = keras.layers.Flatten()
+
+#         self.dropout = keras.layers.Dropout(0.5)
+
+#         self.ff_1 = keras.layers.Dense(1000, activation='relu', use_bias=True)
+        self.ff_2 = keras.layers.Dense(40, activation='relu', use_bias=True)
+        self.ff_3 = keras.layers.Dense(20, activation='relu', use_bias=True)
+        self.ff_4 = keras.layers.Dense(1)
+
+#     @tf.function
+    def call(self, activations):
+        x = tf.squeeze(activations, axis=1)
+#         x = self.max_pool(x)
+#         print(x.shape)
+        x = self.conv_1(x)
+#         print(x.shape)
+        x = self.conv_2(x)
+#         print(x.shape)
+#         raise Exception
+#         x = self.conv_3(x)
+#         print(x.shape)
+#         x = self.conv_4(x)
+#         raise Exception
+        x = self.flatten(x)
+#         x = self.ff_1(x)
+#         x = self.dropout(x)
+        x = self.ff_2(x)
+#         x = self.dropout(x)
+        x = self.ff_3(x)
+#         x = self.dropout(x)
+        x = self.ff_4(x)
+
+        return x
+    
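+# Classifier that works on raw 5-frame clips: a 3D conv whose kernel spans the full
+# 5-frame depth and 200-pixel height (so both collapse to 1), a 1D conv along the
+# remaining width axis, and dense layers that produce a single logit.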
+class PNBTemporalClassifier(keras.layers.Layer):
+    def __init__(self):
+        super(PNBTemporalClassifier, self).__init__()
+        self.conv_1 = keras.layers.Conv3D(24, kernel_size=(5, 200, 50), strides=(1, 1, 10), activation='relu', padding='valid')
+        self.conv_2 = keras.layers.Conv1D(48, kernel_size=8, strides=4, activation='relu', padding='valid')
+        
+        self.ff_1 = keras.layers.Dense(100, activation='relu', use_bias=True)
+        
+#         self.gru = keras.layers.GRU(25)
+
+        self.flatten = keras.layers.Flatten()
+
+        self.ff_2 = keras.layers.Dense(10, activation='relu', use_bias=True)
+        self.ff_3 = keras.layers.Dense(1)
+
+#     @tf.function
+    def call(self, clip):
+        width = clip.shape[3]
+        height = clip.shape[2]
+        depth = clip.shape[1]
+        
+        x = tf.expand_dims(clip, axis=4)
+#         x = tf.reshape(clip, (-1, height, width, 1))
+
+        x = self.conv_1(x)
+        x = tf.squeeze(x, axis=1)
+        x = tf.squeeze(x, axis=1)
+        x = self.conv_2(x)
+
+        x = self.flatten(x)
+        x = self.ff_1(x)
+
+#         x = tf.reshape(x, (-1, 5, 100))
+#         x = self.gru(x)
+        
+        x = self.ff_2(x)
+        x = self.ff_3(x)
+
+        return x
+    
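+# Wraps the frozen sparse-coding filters and a trained classifier into one Keras model
+# for TFLite export: rescales incoming frames to [0, 1], computes the sparse activations,
+# and returns the rounded sigmoid prediction.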
+class MobileModelPNB(keras.Model):
+    def __init__(self, sparse_weights, classifier_model, batch_size, image_height, image_width, clip_depth, out_channels, kernel_size, kernel_depth, stride, lam, activation_lr, max_activation_iter, run_2d):
+        super().__init__()
+        self.sparse_code = SparseCode(batch_size=batch_size, image_height=image_height, image_width=image_width, clip_depth=clip_depth, in_channels=1, out_channels=out_channels, kernel_size=kernel_size, kernel_depth=kernel_depth, stride=stride, lam=lam, activation_lr=activation_lr, max_activation_iter=max_activation_iter, run_2d=run_2d, padding='VALID')
+        self.classifier = classifier_model
+
+        self.out_channels = out_channels
+        self.stride = stride
+        self.lam = lam
+        self.activation_lr = activation_lr
+        self.max_activation_iter = max_activation_iter
+        self.batch_size = batch_size
+        self.run_2d = run_2d
+        
+        if run_2d:
+            weight_list = np.split(sparse_weights, 5, axis=0)
+            self.filters_1 = tf.Variable(initial_value=weight_list[0].squeeze(0), dtype='float32', trainable=False)
+            self.filters_2 = tf.Variable(initial_value=weight_list[1].squeeze(0), dtype='float32', trainable=False)
+            self.filters_3 = tf.Variable(initial_value=weight_list[2].squeeze(0), dtype='float32', trainable=False)
+            self.filters_4 = tf.Variable(initial_value=weight_list[3].squeeze(0), dtype='float32', trainable=False)
+            self.filters_5 = tf.Variable(initial_value=weight_list[4].squeeze(0), dtype='float32', trainable=False)
+        else:
+            self.filters = tf.Variable(initial_value=sparse_weights, dtype='float32', trainable=False)
+
+    @tf.function
+    def call(self, images):
+#         images = tf.squeeze(tf.image.rgb_to_grayscale(images), axis=-1)
+        images = tf.transpose(images, perm=[0, 2, 3, 1])
+        images = images / 255
+
+        if self.run_2d:
+            activations = self.sparse_code(images, [tf.stop_gradient(self.filters_1), tf.stop_gradient(self.filters_2), tf.stop_gradient(self.filters_3), tf.stop_gradient(self.filters_4), tf.stop_gradient(self.filters_5)])
+            activations = tf.expand_dims(activations, axis=1)
+        else:
+            activations = self.sparse_code(images, tf.stop_gradient(self.filters))
+
+        pred = tf.math.round(tf.math.sigmoid(self.classifier(activations)))
+#         pred = tf.math.reduce_sum(activations)
+
+        return pred
diff --git a/sparse_coding_torch/generate_tflite.py b/sparse_coding_torch/pnb/generate_tflite.py
similarity index 95%
rename from sparse_coding_torch/generate_tflite.py
rename to sparse_coding_torch/pnb/generate_tflite.py
index 2a8b2a2..7edde64 100644
--- a/sparse_coding_torch/generate_tflite.py
+++ b/sparse_coding_torch/pnb/generate_tflite.py
@@ -6,8 +6,8 @@ import cv2
 import torchvision as tv
 import torch
 import torch.nn as nn
-from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler
-from sparse_coding_torch.keras_model import MobileModelPNB
+from sparse_coding_torch.utils import VideoGrayScaler, MinMaxScaler
+from sparse_coding_torch.pnb.classifier_model import MobileModelPNB
 import argparse
 
 if __name__ == "__main__":
diff --git a/sparse_coding_torch/load_data.py b/sparse_coding_torch/pnb/load_data.py
similarity index 54%
rename from sparse_coding_torch/load_data.py
rename to sparse_coding_torch/pnb/load_data.py
index 6dea292..3059e04 100644
--- a/sparse_coding_torch/load_data.py
+++ b/sparse_coding_torch/pnb/load_data.py
@@ -2,72 +2,12 @@ import numpy as np
 import torchvision
 import torch
 from sklearn.model_selection import train_test_split
-from sparse_coding_torch.video_loader import MinMaxScaler
-from sparse_coding_torch.video_loader import YoloClipLoader, get_ptx_participants, PNBLoader, get_participants, NeedleLoader, ONSDLoader
-from sparse_coding_torch.video_loader import VideoGrayScaler
+from sparse_coding_torch.utils import MinMaxScaler
+from sparse_coding_torch.pnb.video_loader import PNBLoader, get_participants, NeedleLoader
+from sparse_coding_torch.utils import VideoGrayScaler
 from typing import Sequence, Iterator
 import csv
 from sklearn.model_selection import train_test_split, GroupShuffleSplit, LeaveOneGroupOut, LeaveOneOut, StratifiedGroupKFold, StratifiedKFold, KFold, ShuffleSplit
-
-def load_yolo_clips(batch_size, mode, num_clips=1, num_positives=100, device=None, n_splits=None, sparse_model=None, whole_video=False, positive_videos=None):   
-    video_path = "/shared_data/YOLO_Updated_PL_Model_Results/"
-
-    video_to_participant = get_ptx_participants()
-    
-    transforms = torchvision.transforms.Compose(
-    [VideoGrayScaler(),
-#      MinMaxScaler(0, 255),
-     torchvision.transforms.Normalize((0.2592,), (0.1251,)),
-    ])
-    augment_transforms = torchvision.transforms.Compose(
-    [torchvision.transforms.RandomRotation(45),
-     torchvision.transforms.RandomHorizontalFlip(),
-     torchvision.transforms.CenterCrop((100, 200))
-    ])
-    if whole_video:
-        dataset = YoloVideoLoader(video_path, num_clips=num_clips, num_positives=num_positives, transform=transforms, augment_transform=augment_transforms, sparse_model=sparse_model, device=device)
-    else:
-        dataset = YoloClipLoader(video_path, num_clips=num_clips, num_positives=num_positives, positive_videos=positive_videos, transform=transforms, augment_transform=augment_transforms, sparse_model=sparse_model, device=device)
-    
-    targets = dataset.get_labels()
-    
-    if mode == 'leave_one_out':
-        gss = LeaveOneGroupOut()
-
-#         groups = [v for v in dataset.get_filenames()]
-        groups = [video_to_participant[v.lower().replace('_clean', '')] for v in dataset.get_filenames()]
-        
-        return gss.split(np.arange(len(targets)), targets, groups), dataset
-    elif mode == 'all_train':
-        train_idx = np.arange(len(targets))
-#         train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
-#         train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
-#                                                sampler=train_sampler)
-#         test_loader = None
-        
-        return [(train_idx, None)], dataset
-    elif mode == 'k_fold':
-        gss = StratifiedGroupKFold(n_splits=n_splits)
-
-        groups = [video_to_participant[v.lower().replace('_clean', '')] for v in dataset.get_filenames()]
-        
-        return gss.split(np.arange(len(targets)), targets, groups), dataset
-    else:
-        gss = GroupShuffleSplit(n_splits=n_splits, test_size=0.2)
-
-        groups = [video_to_participant[v.lower().replace('_clean', '')] for v in dataset.get_filenames()]
-        
-        train_idx, test_idx = list(gss.split(np.arange(len(targets)), targets, groups))[0]
-        
-        train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
-        train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
-                                               sampler=train_sampler)
-        
-        test_sampler = torch.utils.data.SubsetRandomSampler(test_idx)
-        test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
-                                               sampler=test_sampler)
-        
-        return train_loader, test_loader, dataset
     
 def get_sample_weights(train_idx, dataset):
     dataset = list(dataset)
@@ -112,7 +52,7 @@ class SubsetWeightedRandomSampler(torch.utils.data.Sampler[int]):
         return len(self.indicies)
     
 def load_pnb_videos(yolo_model, batch_size, input_size, crop_size=None, mode=None, classify_mode=False, balance_classes=False, device=None, n_splits=None, sparse_model=None, frames_to_skip=1):   
-    video_path = "/shared_data/bamc_pnb_data/full_training_data"
+    video_path = "/shared_data/bamc_pnb_data/revised_training_data"
 #     video_path = '/home/dwh48@drexel.edu/pnb_videos_for_testing/train'
 #     video_path = '/home/dwh48@drexel.edu/special_splits/train'
 
@@ -177,46 +117,6 @@ def load_pnb_videos(yolo_model, batch_size, input_size, crop_size=None, mode=Non
         
         return gss.split(np.arange(len(targets)), targets, groups), dataset
     
-def load_onsd_videos(batch_size, input_size, mode=None, n_splits=None):   
-    video_path = "/shared_data/bamc_onsd_data/preliminary_onsd_data"
-    
-    transforms = torchvision.transforms.Compose(
-    [VideoGrayScaler(),
-     MinMaxScaler(0, 255),
-     torchvision.transforms.Resize(input_size[:2])
-    ])
-    augment_transforms = torchvision.transforms.Compose(
-    [torchvision.transforms.RandomRotation(15)
-    ])
-    dataset = ONSDLoader(video_path, input_size[1], input_size[0], input_size[2], num_frames=5, transform=transforms, augmentation=augment_transforms)
-    
-    targets = dataset.get_labels()
-    
-    if mode == 'leave_one_out':
-        gss = LeaveOneGroupOut()
-
-        groups = get_participants(dataset.get_filenames())
-        
-        return gss.split(np.arange(len(targets)), targets, groups), dataset
-    elif mode == 'all_train':
-        train_idx = np.arange(len(targets))
-        test_idx = None
-        
-        return [(train_idx, test_idx)], dataset
-    elif mode == 'k_fold':
-        gss = StratifiedGroupKFold(n_splits=n_splits, shuffle=True)
-
-        groups = get_participants(dataset.get_filenames())
-        
-        return gss.split(np.arange(len(targets)), targets, groups), dataset
-    else:
-#         gss = ShuffleSplit(n_splits=n_splits, test_size=0.2)
-        gss = GroupShuffleSplit(n_splits=n_splits, test_size=0.2)
-
-        groups = get_participants(dataset.get_filenames())
-        
-        return gss.split(np.arange(len(targets)), targets, groups), dataset
-    
 def load_needle_clips(batch_size, input_size):   
     video_path = "/shared_data/bamc_pnb_data/needle_data/non_needle"
     
diff --git a/sparse_coding_torch/pnb/train_classifier.py b/sparse_coding_torch/pnb/train_classifier.py
new file mode 100644
index 0000000..556bd5c
--- /dev/null
+++ b/sparse_coding_torch/pnb/train_classifier.py
@@ -0,0 +1,442 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from tqdm import tqdm
+import argparse
+import os
+from sparse_coding_torch.pnb.load_data import load_pnb_videos
+from sparse_coding_torch.utils import SubsetWeightedRandomSampler, get_sample_weights
+from sparse_coding_torch.sparse_model import SparseCode, ReconSparse, normalize_weights, normalize_weights_3d
+from sparse_coding_torch.pnb.classifier_model import PNBClassifier, PNBTemporalClassifier
+import time
+import numpy as np
+from sklearn.metrics import f1_score, accuracy_score, confusion_matrix
+import random
+import pickle
+import tensorflow.keras as keras
+import tensorflow as tf
+from sparse_coding_torch.pnb.train_sparse_model import sparse_loss
+from yolov4.get_bounding_boxes import YoloModel
+import torchvision
+import torchvision as tv
+from sparse_coding_torch.pnb.video_loader import get_yolo_regions, classify_nerve_is_right, get_needle_bb
+from sparse_coding_torch.utils import VideoGrayScaler, MinMaxScaler
+import glob
+import cv2
+
+configproto = tf.compat.v1.ConfigProto()
+configproto.gpu_options.polling_inactive_delay_msecs = 5000
+configproto.gpu_options.allow_growth = True
+sess = tf.compat.v1.Session(config=configproto) 
+tf.compat.v1.keras.backend.set_session(sess)
+
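+# Video-level scoring: steps backward from the end of each video in 5-frame clips,
+# crops the YOLO-detected region, classifies every clip, and averages the clip
+# predictions (defaulting to positive when no region is found). Returns the video
+# predictions, numerical labels, and the false-negative / false-positive video paths.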
+def calculate_pnb_scores(input_videos, labels, yolo_model, sparse_model, recon_model, classifier_model, image_width, image_height, transform):
+    all_predictions = []
+    
+    numerical_labels = []
+    for label in labels:
+        if label == 'Positives':
+            numerical_labels.append(1.0)
+        else:
+            numerical_labels.append(0.0)
+
+    final_list = []
+    fp_ids = []
+    fn_ids = []
+    for v_idx, f in tqdm(enumerate(input_videos)):
+        vc = tv.io.read_video(f)[0].permute(3, 0, 1, 2)
+        is_right = classify_nerve_is_right(yolo_model, vc)
+        needle_bb = get_needle_bb(yolo_model, vc)
+        
+        all_preds = []
+        for j in range(vc.size(1) - 5, vc.size(1) - 25, -5):
+            if j-5 < 0:
+                break
+
+            vc_sub = vc[:, j-5:j, :, :]
+            
+            if vc_sub.size(1) < 5:
+                continue
+            
+            clip = get_yolo_regions(yolo_model, vc_sub, is_right, needle_bb, image_width, image_height)
+            
+            if not clip:
+                continue
+
+            clip = clip[0]
+            clip = transform(clip).to(torch.float32)
+            clip = tf.expand_dims(clip, axis=4) 
+            
+            if sparse_model is not None:
+                activations = tf.stop_gradient(sparse_model([clip, tf.stop_gradient(tf.expand_dims(recon_model.weights[0], axis=0))]))
+
+                pred = tf.math.round(tf.math.sigmoid(classifier_model(activations)))
+            else:
+                pred = tf.math.round(tf.math.sigmoid(classifier_model(clip)))
+
+            all_preds.append(pred)
+                
+        if all_preds:
+            final_pred = np.round(np.mean(np.array(all_preds)))
+        else:
+            final_pred = 1.0
+            
+        if final_pred != numerical_labels[v_idx]:
+            if final_pred == 0:
+                fn_ids.append(f)
+            else:
+                fp_ids.append(f)
+            
+        final_list.append(final_pred)
+        
+    return np.array(final_list), np.array(numerical_labels), fn_ids, fp_ids
+
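+# Variant of calculate_pnb_scores that builds a single clip per video by sampling every
+# frames_to_skip-th frame from the end, so each video gets exactly one prediction.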
+def calculate_pnb_scores_skipped_frames(input_videos, labels, yolo_model, sparse_model, recon_model, classifier_model, frames_to_skip, image_width, image_height, transform):
+    all_predictions = []
+    
+    numerical_labels = []
+    for label in labels:
+        if label == 'Positives':
+            numerical_labels.append(1.0)
+        else:
+            numerical_labels.append(0.0)
+
+    final_list = []
+    fp_ids = []
+    fn_ids = []
+    for v_idx, f in tqdm(enumerate(input_videos)):
+        vc = tv.io.read_video(f)[0].permute(3, 0, 1, 2)
+        is_right = classify_nerve_is_right(yolo_model, vc)
+        needle_bb = get_needle_bb(yolo_model, vc)
+        
+        all_preds = []
+        
+        frames = []
+        for k in range(vc.size(1) - 1, vc.size(1) - 1 - 5 * frames_to_skip, -frames_to_skip):
+            frames.append(vc[:, k, :, :])
+        vc_sub = torch.stack(frames, dim=1)
+            
+        if vc_sub.size(1) < 5:
+            continue
+
+        clip = get_yolo_regions(yolo_model, vc_sub, is_right, needle_bb, image_width, image_height)
+
+        if clip:
+            clip = clip[0]
+            clip = transform(clip).to(torch.float32)
+            clip = tf.expand_dims(clip, axis=4) 
+
+            if sparse_model is not None:
+                activations = tf.stop_gradient(sparse_model([clip, tf.stop_gradient(tf.expand_dims(recon_model.weights[0], axis=0))]))
+
+                pred = tf.math.round(tf.math.sigmoid(classifier_model(activations)))
+            else:
+                pred = tf.math.round(tf.math.sigmoid(classifier_model(clip)))
+        else:
+            pred = 1.0
+            
+        if pred != numerical_labels[v_idx]:
+            if pred == 0:
+                fn_ids.append(f)
+            else:
+                fp_ids.append(f)
+            
+        final_list.append(pred)
+        
+    return np.array(final_list), np.array(numerical_labels), fn_ids, fp_ids
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--batch_size', default=12, type=int)
+    parser.add_argument('--kernel_size', default=15, type=int)
+    parser.add_argument('--kernel_depth', default=5, type=int)
+    parser.add_argument('--num_kernels', default=64, type=int)
+    parser.add_argument('--stride', default=1, type=int)
+    parser.add_argument('--max_activation_iter', default=150, type=int)
+    parser.add_argument('--activation_lr', default=1e-2, type=float)
+    parser.add_argument('--lr', default=5e-4, type=float)
+    parser.add_argument('--epochs', default=40, type=int)
+    parser.add_argument('--lam', default=0.05, type=float)
+    parser.add_argument('--output_dir', default='./output', type=str)
+    parser.add_argument('--sparse_checkpoint', default=None, type=str)
+    parser.add_argument('--checkpoint', default=None, type=str)
+    parser.add_argument('--splits', default=None, type=str, help='k_fold or leave_one_out or all_train')
+    parser.add_argument('--seed', default=26, type=int)
+    parser.add_argument('--train', action='store_true')
+    parser.add_argument('--num_positives', default=100, type=int)
+    parser.add_argument('--n_splits', default=5, type=int)
+    parser.add_argument('--save_train_test_splits', action='store_true')
+    parser.add_argument('--run_2d', action='store_true')
+    parser.add_argument('--balance_classes', action='store_true')
+    parser.add_argument('--dataset', default='pnb', type=str)
+    parser.add_argument('--train_sparse', action='store_true')
+    parser.add_argument('--mixing_ratio', type=float, default=1.0)
+    parser.add_argument('--sparse_lr', type=float, default=0.003)
+    parser.add_argument('--crop_height', type=int, default=285)
+    parser.add_argument('--crop_width', type=int, default=350)
+    parser.add_argument('--scale_factor', type=int, default=1)
+    parser.add_argument('--clip_depth', type=int, default=5)
+    parser.add_argument('--frames_to_skip', type=int, default=1)
+    
+    args = parser.parse_args()
+    
+    crop_height = args.crop_height
+    crop_width = args.crop_width
+
+    image_height = int(crop_height / args.scale_factor)
+    image_width = int(crop_width / args.scale_factor)
+    clip_depth = args.clip_depth
+        
+    batch_size = args.batch_size
+    
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    
+    output_dir = args.output_dir
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+        
+    with open(os.path.join(output_dir, 'arguments.txt'), 'w+') as out_f:
+        out_f.write(str(args))
+    
+    yolo_model = YoloModel(args.dataset)
+
+    all_errors = []
+    
+    if args.run_2d:
+        inputs = keras.Input(shape=(image_height, image_width, clip_depth))
+    else:
+        inputs = keras.Input(shape=(clip_depth, image_height, image_width, 1))
+        
+#     filter_inputs = keras.Input(shape=(args.kernel_depth, args.kernel_size, args.kernel_size, 1, args.num_kernels), dtype='float32')
+
+#     output = SparseCode(batch_size=args.batch_size, image_height=image_height, image_width=image_width, clip_depth=clip_depth, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, kernel_depth=args.kernel_depth, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(inputs, filter_inputs)
+
+#     sparse_model = keras.Model(inputs=(inputs, filter_inputs), outputs=output)
+
+#     recon_inputs = keras.Input(shape=((clip_depth - args.kernel_depth) // 1 + 1, (image_height - args.kernel_size) // args.stride + 1, (image_width - args.kernel_size) // args.stride + 1, args.num_kernels))
+
+#     recon_outputs = ReconSparse(batch_size=args.batch_size, image_height=image_height, image_width=image_width, clip_depth=clip_depth, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, kernel_depth=args.kernel_depth, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(recon_inputs)
+
+#     recon_model = keras.Model(inputs=recon_inputs, outputs=recon_outputs)
+
+#     if args.sparse_checkpoint:
+#         recon_model.set_weights(keras.models.load_model(args.sparse_checkpoint).get_weights())
+
+    sparse_model = None
+    recon_model = None
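+    # The sparse-coding path above is currently disabled; the classifier is trained
+    # directly on the YOLO-cropped clips.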
+        
+    splits, dataset = load_pnb_videos(yolo_model, args.batch_size, input_size=(image_height, image_width, clip_depth), crop_size=(crop_height, crop_width, clip_depth), classify_mode=True, balance_classes=args.balance_classes, mode=args.splits, device=None, n_splits=args.n_splits, sparse_model=None, frames_to_skip=args.frames_to_skip)
+    positive_class = 'Positives'
+
+    overall_true = []
+    overall_pred = []
+    fn_ids = []
+    fp_ids = []
+    
+    i_fold = 0
+    for train_idx, test_idx in splits:
+        train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
+#         train_sampler = SubsetWeightedRandomSampler(get_sample_weights(train_idx, dataset), train_idx, replacement=True)
+        train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
+                                               sampler=train_sampler)
+        
+        if test_idx is not None:
+            test_sampler = torch.utils.data.SubsetRandomSampler(test_idx)
+            test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
+                                                   sampler=test_sampler)
+            
+#             with open(os.path.join(args.output_dir, 'test_videos_{}.txt'.format(i_fold)), 'w+') as test_videos_out:
+#                 test_set = set([x for tup in test_loader for x in tup[2]])
+#                 test_videos_out.writelines(test_set)
+        else:
+            test_loader = None
+        
+        if args.checkpoint:
+            classifier_model = keras.models.load_model(args.checkpoint)
+        else:
+            classifier_inputs = keras.Input(shape=(clip_depth, image_height, image_width))
+            classifier_outputs = PNBTemporalClassifier()(classifier_inputs)
+
+            classifier_model = keras.Model(inputs=classifier_inputs, outputs=classifier_outputs)
+
+        prediction_optimizer = keras.optimizers.Adam(learning_rate=args.lr)
+        filter_optimizer = tf.keras.optimizers.SGD(learning_rate=args.sparse_lr)
+
+        best_so_far = float('-inf')
+
+        criterion = keras.losses.BinaryCrossentropy(from_logits=True, reduction=keras.losses.Reduction.SUM)
+
+        if args.train:
+            for epoch in range(args.epochs):
+                epoch_loss = 0
+                t1 = time.perf_counter()
+
+                y_true_train = None
+                y_pred_train = None
+
+                for labels, local_batch, vid_f in tqdm(train_loader):
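+                    # (batch, channels, depth, height, width) -> (batch, depth, height, width, channels) for Keras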
+                    images = local_batch.permute(0, 2, 3, 4, 1).numpy()
+
+                    torch_labels = np.zeros(len(labels))
+                    torch_labels[[i for i in range(len(labels)) if labels[i] == positive_class]] = 1
+                    torch_labels = np.expand_dims(torch_labels, axis=1)
+
+                    if args.train_sparse:
+                        with tf.GradientTape() as tape:
+                            activations = sparse_model([images, tf.expand_dims(recon_model.trainable_weights[0], axis=0)])
+                            pred = classifier_model(activations)
+                            loss = criterion(torch_labels, pred)
+
+                            print(loss)
+                    else:
+                        with tf.GradientTape() as tape:
+                            pred = classifier_model(images)
+                            loss = criterion(torch_labels, pred)
+
+                    epoch_loss += loss * local_batch.size(0)
+
+                    if args.train_sparse:
+                        sparse_gradients, classifier_gradients = tape.gradient(loss, [recon_model.trainable_weights, classifier_model.trainable_weights])
+
+                        prediction_optimizer.apply_gradients(zip(classifier_gradients, classifier_model.trainable_weights))
+
+                        filter_optimizer.apply_gradients(zip(sparse_gradients, recon_model.trainable_weights))
+
+                        if args.run_2d:
+                            weights = normalize_weights(recon_model.get_weights(), args.num_kernels)
+                        else:
+                            weights = normalize_weights_3d(recon_model.get_weights(), args.num_kernels)
+                        recon_model.set_weights(weights)
+                    else:
+                        gradients = tape.gradient(loss, classifier_model.trainable_weights)
+
+                        prediction_optimizer.apply_gradients(zip(gradients, classifier_model.trainable_weights))
+
+
+                    if y_true_train is None:
+                        y_true_train = torch_labels
+                        y_pred_train = tf.math.round(tf.math.sigmoid(pred))
+                    else:
+                        y_true_train = tf.concat((y_true_train, torch_labels), axis=0)
+                        y_pred_train = tf.concat((y_pred_train, tf.math.round(tf.math.sigmoid(pred))), axis=0)
+
+                t2 = time.perf_counter()
+
+                y_true = None
+                y_pred = None
+                test_loss = 0.0
+                
+                eval_loader = test_loader
+                if args.splits == 'all_train':
+                    eval_loader = train_loader
+                for labels, local_batch, vid_f in tqdm(eval_loader):
+                    images = local_batch.permute(0, 2, 3, 4, 1).numpy()
+
+                    torch_labels = np.zeros(len(labels))
+                    torch_labels[[i for i in range(len(labels)) if labels[i] == positive_class]] = 1
+                    torch_labels = np.expand_dims(torch_labels, axis=1)
+                    
+                    pred = classifier_model(images)
+                    loss = criterion(torch_labels, pred)
+
+                    test_loss += loss
+
+                    if y_true is None:
+                        y_true = torch_labels
+                        y_pred = tf.math.round(tf.math.sigmoid(pred))
+                    else:
+                        y_true = tf.concat((y_true, torch_labels), axis=0)
+                        y_pred = tf.concat((y_pred, tf.math.round(tf.math.sigmoid(pred))), axis=0)
+
+                t2 = time.perf_counter()
+
+                y_true = tf.cast(y_true, tf.int32)
+                y_pred = tf.cast(y_pred, tf.int32)
+
+                y_true_train = tf.cast(y_true_train, tf.int32)
+                y_pred_train = tf.cast(y_pred_train, tf.int32)
+
+                f1 = f1_score(y_true, y_pred, average='macro')
+                accuracy = accuracy_score(y_true, y_pred)
+
+                train_accuracy = accuracy_score(y_true_train, y_pred_train)
+
+                print('epoch={}, i_fold={}, time={:.2f}, train_loss={:.2f}, test_loss={:.2f}, train_acc={:.2f}, test_f1={:.2f}, test_acc={:.2f}'.format(epoch, i_fold, t2-t1, epoch_loss, test_loss, train_accuracy, f1, accuracy))
+    #             print(epoch_loss)
+                if f1 >= best_so_far:
+                    print("found better model")
+                    # Save model parameters
+                    classifier_model.save(os.path.join(output_dir, "best_classifier_{}.pt".format(i_fold)))
+#                     recon_model.save(os.path.join(output_dir, "best_sparse_model_{}.pt".format(i_fold)))
+                    pickle.dump(prediction_optimizer.get_weights(), open(os.path.join(output_dir, 'optimizer_{}.pt'.format(i_fold)), 'wb+'))
+                    best_so_far = f1
+
+            classifier_model = keras.models.load_model(os.path.join(output_dir, "best_classifier_{}.pt".format(i_fold)))
+#             recon_model = keras.models.load_model(os.path.join(output_dir, 'best_sparse_model_{}.pt'.format(i_fold)))
+
+        epoch_loss = 0
+
+        y_true = None
+        y_pred = None
+
+        pred_dict = {}
+        gt_dict = {}
+
+        t1 = time.perf_counter()
+    #         test_videos = [vid_f for labels, local_batch, vid_f in batch for batch in test_loader]
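+        # Video-level evaluation: crop each held-out video with YOLO, classify the resulting
+        # clip(s), and average the per-clip predictions into one prediction per video.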
+        transform = torchvision.transforms.Compose(
+        [VideoGrayScaler(),
+         MinMaxScaler(0, 255),
+         torchvision.transforms.Resize((image_height, image_width))
+        ])
+
+        test_videos = set()
+        for labels, local_batch, vid_f in (test_loader if test_loader is not None else train_loader):
+            test_videos.update(vid_f)
+
+        test_labels = [vid_f.split('/')[-3] for vid_f in test_videos]
+
+        if args.frames_to_skip == 1:
+            y_pred, y_true, fn, fp = calculate_pnb_scores(test_videos, test_labels, yolo_model, sparse_model, recon_model, classifier_model, image_width, image_height, transform)
+        else:
+            y_pred, y_true, fn, fp = calculate_pnb_scores_skipped_frames(test_videos, test_labels, yolo_model, sparse_model, recon_model, classifier_model, args.frames_to_skip, image_width, image_height, transform)
+            
+        t2 = time.perf_counter()
+
+        print('i_fold={}, time={:.2f}'.format(i_fold, t2-t1))
+
+        y_true = tf.cast(y_true, tf.int32)
+        y_pred = tf.cast(y_pred, tf.int32)
+
+        f1 = f1_score(y_true, y_pred, average='macro')
+        accuracy = accuracy_score(y_true, y_pred)
+
+        fn_ids.extend(fn)
+        fp_ids.extend(fp)
+
+        overall_true.extend(y_true)
+        overall_pred.extend(y_pred)
+
+        print("Test f1={:.2f}, vid_acc={:.2f}".format(f1, accuracy))
+
+        print(confusion_matrix(y_true, y_pred))
+            
+        i_fold += 1
+
+    fp_fn_file = os.path.join(args.output_dir, 'fp_fn.txt')
+    with open(fp_fn_file, 'w+') as out_f:
+        out_f.write('FP:\n')
+        out_f.write(str(fp_ids) + '\n\n')
+        out_f.write('FN:\n')
+        out_f.write(str(fn_ids) + '\n\n')
+        
+    overall_true = np.array(overall_true)
+    overall_pred = np.array(overall_pred)
+            
+    final_f1 = f1_score(overall_true, overall_pred, average='macro')
+    final_acc = accuracy_score(overall_true, overall_pred)
+    final_conf = confusion_matrix(overall_true, overall_pred)
+            
+    print("Final accuracy={:.2f}, f1={:.2f}".format(final_acc, final_f1))
+    print(final_conf)
+
diff --git a/sparse_coding_torch/train_classifier_needle.py b/sparse_coding_torch/pnb/train_classifier_needle.py
similarity index 98%
rename from sparse_coding_torch/train_classifier_needle.py
rename to sparse_coding_torch/pnb/train_classifier_needle.py
index 79fa249..6f20678 100644
--- a/sparse_coding_torch/train_classifier_needle.py
+++ b/sparse_coding_torch/pnb/train_classifier_needle.py
@@ -4,8 +4,10 @@ import torch.nn.functional as F
 from tqdm import tqdm
 import argparse
 import os
-from sparse_coding_torch.load_data import load_yolo_clips, load_pnb_videos, SubsetWeightedRandomSampler, get_sample_weights
-from sparse_coding_torch.keras_model import SparseCode, PNBClassifier, PTXClassifier, ReconSparse, normalize_weights, normalize_weights_3d
+from sparse_coding_torch.pnb.load_data import load_pnb_videos
+from sparse_coding_torch.utils import SubsetWeightedRandomSampler, get_sample_weights
+from sparse_coding_torch.sparse_model import SparseCode, ReconSparse, normalize_weights, normalize_weights_3d
+from sparse_coding_torch.pnb.classifier_model import PNBClassifier
 import time
 import numpy as np
 from sklearn.metrics import f1_score, accuracy_score, confusion_matrix
@@ -14,7 +16,7 @@ import pickle
 import tensorflow.keras as keras
 import tensorflow as tf
 from sparse_coding_torch.train_sparse_model import sparse_loss
-from sparse_coding_torch.utils import calculate_pnb_scores
+from sparse_coding_torch.pnb.train_classifier import calculate_pnb_scores
 from yolov4.get_bounding_boxes import YoloModel
 
 configproto = tf.compat.v1.ConfigProto()
diff --git a/sparse_coding_torch/pnb/train_sparse_model.py b/sparse_coding_torch/pnb/train_sparse_model.py
new file mode 100644
index 0000000..698cbdc
--- /dev/null
+++ b/sparse_coding_torch/pnb/train_sparse_model.py
@@ -0,0 +1,172 @@
+import time
+import numpy as np
+import torch
+from matplotlib import pyplot as plt
+from matplotlib import cm
+from matplotlib.animation import FuncAnimation
+from tqdm import tqdm
+import argparse
+import os
+from sparse_coding_torch.pnb.load_data import load_pnb_videos, load_needle_clips
+import tensorflow.keras as keras
+import tensorflow as tf
+from sparse_coding_torch.sparse_model import normalize_weights_3d, normalize_weights, SparseCode, load_pytorch_weights, ReconSparse
+import random
+# Assumed location of plot_filters after the refactor; adjust if it lives elsewhere.
+from sparse_coding_torch.utils import plot_filters
+
+def sparse_loss(images, recon, activations, batch_size, lam, stride):
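+    # 0.5 * per-sample reconstruction SSE averaged over the batch, plus lam times the
+    # mean L1 norm of the activations (the standard sparse-coding objective).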
+    loss = 0.5 * (1/batch_size) * tf.math.reduce_sum(tf.math.pow(images - recon, 2))
+    loss += lam * tf.reduce_mean(tf.math.reduce_sum(tf.math.abs(tf.reshape(activations, (batch_size, -1))), axis=1))
+    return loss
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--batch_size', default=32, type=int)
+    parser.add_argument('--kernel_size', default=15, type=int)
+    parser.add_argument('--kernel_depth', default=5, type=int)
+    parser.add_argument('--num_kernels', default=32, type=int)
+    parser.add_argument('--stride', default=1, type=int)
+    parser.add_argument('--max_activation_iter', default=300, type=int)
+    parser.add_argument('--activation_lr', default=1e-2, type=float)
+    parser.add_argument('--lr', default=0.003, type=float)
+    parser.add_argument('--epochs', default=150, type=int)
+    parser.add_argument('--lam', default=0.05, type=float)
+    parser.add_argument('--output_dir', default='./output', type=str)
+    parser.add_argument('--seed', default=42, type=int)
+    parser.add_argument('--run_2d', action='store_true')
+    parser.add_argument('--save_filters', action='store_true')
+    parser.add_argument('--optimizer', default='sgd', type=str)
+    parser.add_argument('--dataset', default='onsd', type=str)
+    parser.add_argument('--crop_height', type=int, default=400)
+    parser.add_argument('--crop_width', type=int, default=400)
+    parser.add_argument('--scale_factor', type=int, default=1)
+    parser.add_argument('--clip_depth', type=int, default=5)
+    parser.add_argument('--frames_to_skip', type=int, default=1)
+
+    args = parser.parse_args()
+    
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+
+    crop_height = args.crop_height
+    crop_width = args.crop_width
+
+    image_height = int(crop_height / args.scale_factor)
+    image_width = int(crop_width / args.scale_factor)
+    clip_depth = args.clip_depth
+
+    output_dir = args.output_dir
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+        
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+    with open(os.path.join(output_dir, 'arguments.txt'), 'w+') as out_f:
+        out_f.write(str(args))
+
+    if args.dataset == 'pnb':
+        train_loader, test_loader, dataset = load_pnb_videos(args.batch_size, input_size=(image_height, image_width, clip_depth), crop_size=(crop_height, crop_width, clip_depth), classify_mode=False, balance_classes=False, mode='all_train', frames_to_skip=args.frames_to_skip)
+    elif args.dataset == 'needle':
+        train_loader, test_loader, dataset = load_needle_clips(args.batch_size, input_size=(image_height, image_width, clip_depth))
+    else:
+        raise Exception('Invalid dataset')
+    
+    print('Loaded', len(train_loader), 'train examples')
+
+    example_data = next(iter(train_loader))
+
+    if args.run_2d:
+        inputs = keras.Input(shape=(image_height, image_width, clip_depth))
+    else:
+        inputs = keras.Input(shape=(clip_depth, image_height, image_width, 1))
+        
+    filter_inputs = keras.Input(shape=(args.kernel_depth, args.kernel_size, args.kernel_size, 1, args.num_kernels), dtype='float32')
+
+    output = SparseCode(batch_size=args.batch_size, image_height=image_height, image_width=image_width, clip_depth=clip_depth, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, kernel_depth=args.kernel_depth, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(inputs, filter_inputs)
+
+    sparse_model = keras.Model(inputs=(inputs, filter_inputs), outputs=output)
+    
+    recon_inputs = keras.Input(shape=(1, (image_height - args.kernel_size) // args.stride + 1, (image_width - args.kernel_size) // args.stride + 1, args.num_kernels))
+    
+    recon_outputs = ReconSparse(batch_size=args.batch_size, image_height=image_height, image_width=image_width, clip_depth=clip_depth, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, kernel_depth=args.kernel_depth, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(recon_inputs)
+    
+    recon_model = keras.Model(inputs=recon_inputs, outputs=recon_outputs)
+    
+    if args.save_filters:
+        if args.run_2d:
+            filters = plot_filters(tf.stack(recon_model.get_weights(), axis=0))
+        else:
+            filters = plot_filters(recon_model.get_weights()[0])
+        filters.save(os.path.join(args.output_dir, 'filters_start.mp4'))
+
+    learning_rate = args.lr
+    if args.optimizer == 'sgd':
+        filter_optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
+    else:
+        filter_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
+
+    loss_log = []
+    best_so_far = float('inf')
+
+    for epoch in range(args.epochs):
+        epoch_loss = 0
+        running_loss = 0.0
+        epoch_start = time.perf_counter()
+        
+        num_iters = 0
+
+        for labels, local_batch, vid_f in tqdm(train_loader):
+            if local_batch.size(0) != args.batch_size:
+                continue
+            if args.run_2d:
+                images = local_batch.squeeze(1).permute(0, 2, 3, 1).numpy()
+            else:
+                images = local_batch.permute(0, 2, 3, 4, 1).numpy()
+                
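+            # Encode the batch with the current dictionary; stop_gradient keeps the encoder
+            # fixed so only the reconstruction weights are updated below.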
+            activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))]))
+            
+            with tf.GradientTape() as tape:
+                recon = recon_model(activations)
+                loss = sparse_loss(images, recon, activations, args.batch_size, args.lam, args.stride)
+
+            epoch_loss += loss * local_batch.size(0)
+            running_loss += loss * local_batch.size(0)
+
+            gradients = tape.gradient(loss, recon_model.trainable_weights)
+
+            filter_optimizer.apply_gradients(zip(gradients, recon_model.trainable_weights))
+            
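+            # Re-normalize the dictionary filters after each gradient step.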
+            if args.run_2d:
+                weights = normalize_weights(recon_model.get_weights(), args.num_kernels)
+            else:
+                weights = normalize_weights_3d(recon_model.get_weights(), args.num_kernels)
+            recon_model.set_weights(weights)
+                
+            num_iters += 1
+
+        epoch_end = time.perf_counter()
+        epoch_loss /= len(train_loader.sampler)
+        
+        if args.save_filters and epoch % 2 == 0:
+            if args.run_2d:
+                filters = plot_filters(tf.stack(recon_model.get_weights(), axis=0))
+            else:
+                filters = plot_filters(recon_model.get_weights()[0])
+            filters.save(os.path.join(args.output_dir, 'filters_' + str(epoch) +'.mp4'))
+
+        if epoch_loss < best_so_far:
+            print("found better model")
+            # Save model parameters
+            recon_model.save(os.path.join(output_dir, "best_sparse.pt"))
+            best_so_far = epoch_loss
+
+        loss_log.append(epoch_loss)
+        print('epoch={}, epoch_loss={:.2f}, time={:.2f}'.format(epoch, epoch_loss, epoch_end - epoch_start))
+
+    plt.plot(loss_log)
+
+    plt.savefig(os.path.join(output_dir, 'loss_graph.png'))
diff --git a/sparse_coding_torch/video_loader.py b/sparse_coding_torch/pnb/video_loader.py
similarity index 52%
rename from sparse_coding_torch/video_loader.py
rename to sparse_coding_torch/pnb/video_loader.py
index 6621f9a..9adc4a1 100644
--- a/sparse_coding_torch/video_loader.py
+++ b/sparse_coding_torch/pnb/video_loader.py
@@ -24,44 +24,8 @@ import random
 import cv2
 from yolov4.get_bounding_boxes import YoloModel
 
-def get_ptx_participants():
-    video_to_participant = {}
-    with open('/shared_data/bamc_data/bamc_video_info.csv', 'r') as csv_in:
-        reader = csv.DictReader(csv_in)
-        for row in reader:
-            key = row['Filename'].split('.')[0].lower().replace('_clean', '')
-            if key == '37 (mislabeled as 38)':
-                key = '37'
-            video_to_participant[key] = row['Participant_id']
-            
-    return video_to_participant
-
 def get_participants(filenames):
     return [f.split('/')[-2] for f in filenames]
-
-class MinMaxScaler(object):
-    """
-    Transforms each channel to the range [0, 1].
-    """
-    def __init__(self, min_val=0, max_val=254):
-        self.min_val = min_val
-        self.max_val = max_val
-    
-    def __call__(self, tensor):
-        return (tensor - self.min_val) / (self.max_val - self.min_val)
-
-class VideoGrayScaler(nn.Module):
-    
-    def __init__(self):
-        super().__init__()
-        self.grayscale = tv.transforms.Grayscale(num_output_channels=1)
-        
-    def forward(self, video):
-        # shape = channels, time, width, height
-        video = self.grayscale(video.swapaxes(-4, -3).swapaxes(-2, -1))
-        video = video.swapaxes(-4, -3).swapaxes(-2, -1)
-        # print(video.shape)
-        return video
     
 def load_pnb_region_labels(file_path):
     all_regions = {}
@@ -76,17 +40,22 @@ def load_pnb_region_labels(file_path):
             
         return all_regions
     
-def get_yolo_regions(yolo_model, clip, is_right, angle, crop_width, crop_height):
+def get_yolo_regions(yolo_model, clip, is_right, needle_bb, crop_width, crop_height):
     orig_height = clip.size(2)
     orig_width = clip.size(3)
-    bounding_boxes, classes = yolo_model.get_bounding_boxes(clip[:, 2, :, :].swapaxes(0, 2).swapaxes(0, 1).numpy())
+    bounding_boxes, classes, scores = yolo_model.get_bounding_boxes(clip[:, 2, :, :].swapaxes(0, 2).swapaxes(0, 1).numpy())
     bounding_boxes = bounding_boxes.squeeze(0)
     classes = classes.squeeze(0)
+    scores = scores.squeeze(0)
+    
+    for bb, class_pred in zip(bounding_boxes, classes):
+        if class_pred == 2:
+            needle_bb = bb
     
     rotate_box = True
     
     all_clips = []
-    for bb, class_pred in zip(bounding_boxes, classes):
+    for bb, class_pred, score in zip(bounding_boxes, classes, scores):
         if class_pred != 0:
             continue
         center_x = round((bb[3] + bb[1]) / 2 * orig_width)
@@ -97,7 +66,13 @@ def get_yolo_regions(yolo_model, clip, is_right, angle, crop_width, crop_height)
         lower_x = round((bb[1] * orig_width))
         upper_x = round((bb[3] * orig_width))
         
-        lower_y = upper_y - crop_height
+        if is_right:
+            angle = calculate_angle(needle_bb, upper_x, center_y, orig_height, orig_width)
+        else:
+            angle = calculate_angle(needle_bb, lower_x, center_y, orig_height, orig_width)
+        
+        lower_y = center_y - (crop_height // 2)
+        upper_y = center_y + (crop_height // 2) 
         
         if is_right:
             lower_x = center_x - crop_width
@@ -106,13 +81,22 @@ def get_yolo_regions(yolo_model, clip, is_right, angle, crop_width, crop_height)
             lower_x = center_x
             upper_x = center_x + crop_width
             
+        if lower_x < 0:
+            lower_x = 0
+        if upper_x < 0:
+            upper_x = 0
+        if lower_y < 0:
+            lower_y = 0
+        if upper_y < 0:
+            upper_y = 0
+            
         if rotate_box:
 #             cv2.imwrite('test_1.png', clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2))
             if is_right:
-                clip = tv.transforms.functional.rotate(clip, angle=angle, center=[upper_x, upper_y])
+                clip = tv.transforms.functional.rotate(clip, angle=angle, center=[upper_x, center_y])
             else:
 #                 cv2.imwrite('test_1.png', clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2))
-                clip = tv.transforms.functional.rotate(clip, angle=-angle, center=[lower_x, upper_y])
+                clip = tv.transforms.functional.rotate(clip, angle=-angle, center=[lower_x, center_y])
 #                 cv2.imwrite('test_2.png', clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2))
 
         trimmed_clip = clip[:, :, lower_y:upper_y, lower_x:upper_x]
@@ -120,11 +104,8 @@ def get_yolo_regions(yolo_model, clip, is_right, angle, crop_width, crop_height)
 #         if orig_width - center_x >= center_x:
 #         if not is_right:
 #         print(angle)
-#         cv2.imwrite('test_2.png', clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2))
-#         cv2.imwrite('test_3.png', trimmed_clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2))
-#         raise Exception
-        
-#         print(trimmed_clip.size())
+#         cv2.imwrite('test_2{}.png'.format(lower_y), clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2))
+#         cv2.imwrite('test_3{}.png'.format(lower_y), trimmed_clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2))
 
         if not is_right:
             trimmed_clip = tv.transforms.functional.hflip(trimmed_clip)
@@ -147,7 +128,7 @@ def classify_nerve_is_right(yolo_model, video):
 
     for frame in range(0, video.size(1), round(video.size(1) / 10)):
         frame = video[:, frame, :, :]
-        bounding_boxes, classes = yolo_model.get_bounding_boxes(frame.swapaxes(0, 2).swapaxes(0, 1).numpy())
+        bounding_boxes, classes, scores = yolo_model.get_bounding_boxes(frame.swapaxes(0, 2).swapaxes(0, 1).numpy())
         bounding_boxes = bounding_boxes.squeeze(0)
         classes = classes.squeeze(0)
     
@@ -181,25 +162,49 @@ def classify_nerve_is_right(yolo_model, video):
 
     return final_pred == 1
 
-def calculate_angle(yolo_model, video):
+def calculate_angle(needle_bb, vessel_x, vessel_y, orig_height, orig_width):
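+    # Absolute angle, in degrees, of the line from the needle bounding-box corner
+    # (needle_bb[1], needle_bb[0]) to the given vessel point.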
+    needle_x = needle_bb[1] * orig_width
+    needle_y = needle_bb[0] * orig_height
+
+    return np.abs(np.degrees(np.arctan((needle_y-vessel_y)/(needle_x-vessel_x))))
+
+def get_needle_bb(yolo_model, video):
+    orig_height = video.size(2)
+    orig_width = video.size(3)
+
+    for frame in range(0, video.size(1), 1):
+        frame = video[:, frame, :, :]
+        
+        bounding_boxes, classes, scores = yolo_model.get_bounding_boxes(frame.swapaxes(0, 2).swapaxes(0, 1).numpy())
+        bounding_boxes = bounding_boxes.squeeze(0)
+        classes = classes.squeeze(0)
+
+        for bb, class_pred in zip(bounding_boxes, classes):
+            if class_pred == 2:
+                return bb
+        
+    return None
+        
+def calculate_angle_video(yolo_model, video):
     orig_height = video.size(2)
     orig_width = video.size(3)
 
     all_preds = []
     if video.size(1) < 10:
-        return 45
+        return 30
 
-    for frame in range(0, video.size(1), round(video.size(1) / 10)):
+    for frame in range(0, video.size(1), video.size(1) // 10):
         frame = video[:, frame, :, :]
-        bounding_boxes, classes = yolo_model.get_bounding_boxes(frame.swapaxes(0, 2).swapaxes(0, 1).numpy())
+        
+        bounding_boxes, classes, scores = yolo_model.get_bounding_boxes(frame.swapaxes(0, 2).swapaxes(0, 1).numpy())
         bounding_boxes = bounding_boxes.squeeze(0)
         classes = classes.squeeze(0)
-        
+
         vessel_x = 0
         vessel_y = 0
         needle_x = 0
         needle_y = 0
-    
+
         for bb, class_pred in zip(bounding_boxes, classes):
             if class_pred == 0 and vessel_x == 0:
                 vessel_x = (bb[3] + bb[1]) / 2 * orig_width
@@ -207,14 +212,14 @@ def calculate_angle(yolo_model, video):
             elif class_pred == 2 and needle_x == 0:
                 needle_x = bb[1] * orig_width
                 needle_y = bb[0] * orig_height
-                
+
             if needle_x != 0 and vessel_x != 0:
                 break
-                
-        if vessel_x == 0 or needle_x == 0:
-            return 45
-        else:
-            return np.abs(np.degrees(np.arctan((needle_y-vessel_y)/(needle_x-vessel_x))))
+
+        if vessel_x > 0 and needle_x > 0:
+            all_preds.append(np.abs(np.degrees(np.arctan((needle_y-vessel_y)/(needle_x-vessel_x)))))
+        
+    return np.mean(np.array(all_preds))
                 
     
 class PNBLoader(Dataset):
@@ -231,7 +236,7 @@ class PNBLoader(Dataset):
             clip_cache_file = 'clip_cache_pnb_{}_{}_sparse.pt'.format(clip_width, clip_height)
             clip_cache_final_file = 'clip_cache_pnb_{}_{}_final_sparse.pt'.format(clip_width, clip_height)
         
-        region_labels = load_pnb_region_labels('/shared_data/bamc_pnb_data/full_training_data/sme_region_labels.csv')
+        region_labels = load_pnb_region_labels('sme_region_labels.csv')
         
         self.videos = []
         for label in self.labels:
@@ -252,7 +257,7 @@ class PNBLoader(Dataset):
             for label, path, _ in tqdm(self.videos):
                 vc = tv.io.read_video(path)[0].permute(3, 0, 1, 2)
                 is_right = classify_nerve_is_right(yolo_model, vc)
-                angle = calculate_angle(yolo_model, vc)
+                needle_bb = get_needle_bb(yolo_model, vc)
 
                 if classify_mode:
 #                     person_idx = path.split('/')[-2]
@@ -277,7 +282,7 @@ class PNBLoader(Dataset):
                                 if vc_sub.size(1) < clip_depth:
                                     continue
 
-                                for clip in get_yolo_regions(yolo_model, vc_sub, is_right, angle, clip_width, clip_height):
+                                for clip in get_yolo_regions(yolo_model, vc_sub, is_right, needle_bb, clip_width, clip_height):
                                     if self.transform:
                                         clip = self.transform(clip)
 
@@ -297,7 +302,7 @@ class PNBLoader(Dataset):
                                     if vc_sub.size(1) < clip_depth:
                                         continue
                                         
-                                    for clip in get_yolo_regions(yolo_model, vc_sub, is_right, angle, clip_width, clip_height):
+                                    for clip in get_yolo_regions(yolo_model, vc_sub, is_right, needle_bb, clip_width, clip_height):
                                         if self.transform:
                                             clip = self.transform(clip)
 
@@ -316,7 +321,7 @@ class PNBLoader(Dataset):
 
                                         if vc_sub.size(1) < clip_depth:
                                             continue
-                                        for clip in get_yolo_regions(yolo_model, vc_sub, is_right, angle, clip_width, clip_height):
+                                        for clip in get_yolo_regions(yolo_model, vc_sub, is_right, needle_bb, clip_width, clip_height):
                                             if self.transform:
                                                 clip = self.transform(clip)
 
@@ -332,7 +337,7 @@ class PNBLoader(Dataset):
                         vc_sub = torch.stack(frames, dim=1)
                         if vc_sub.size(1) < clip_depth:
                             continue
-                        for clip in get_yolo_regions(yolo_model, vc_sub, is_right, clip_width, clip_height):
+                        for clip in get_yolo_regions(yolo_model, vc_sub, is_right, needle_bb, clip_width, clip_height):
                             if self.transform:
                                 clip = self.transform(clip)
 
@@ -345,7 +350,7 @@ class PNBLoader(Dataset):
                             vc_sub = torch.stack(frames, dim=1)
                             if vc_sub.size(1) < clip_depth:
                                 continue
-                            for clip in get_yolo_regions(yolo_model, vc_sub, is_right, angle, clip_width, clip_height):
+                            for clip in get_yolo_regions(yolo_model, vc_sub, is_right, needle_bb, clip_width, clip_height):
                                 if self.transform:
                                     clip = self.transform(clip)
 
@@ -419,85 +424,6 @@ class PNBLoader(Dataset):
     def __len__(self):
         return len(self.clips)
     
-class ONSDLoader(Dataset):
-    
-    def __init__(self, video_path, clip_width, clip_height, clip_depth, num_frames=5, transform=None, augmentation=None):
-        self.transform = transform
-        self.augmentation = augmentation
-        self.labels = [name for name in listdir(video_path) if isdir(join(video_path, name))]
-        
-        clip_cache_file = 'clip_cache_onsd_{}_{}_sparse.pt'.format(clip_width, clip_height)
-        clip_cache_final_file = 'clip_cache_onsd_{}_{}_final_sparse.pt'.format(clip_width, clip_height)
-        
-        self.videos = []
-        for label in self.labels:
-#             self.videos.extend([(label, abspath(join(video_path, label, f)), f) for f in glob.glob(join(video_path, label, '*', '*.mp4'))])
-            self.videos.extend([(label, abspath(join(video_path, label, f)), f) for f in glob.glob(join(video_path, label, '*', '*.mp4'))])
-        
-#         self.videos = list(filter(lambda x: x[1].split('/')[-2] in ['67', '94', '134', '193', '222', '240'], self.videos))
-#         self.videos = list(filter(lambda x: x[1].split('/')[-2] in ['67'], self.videos))
-            
-        self.clips = []
-        self.final_clips = {}
-        
-        if exists(clip_cache_file):
-            self.clips = torch.load(open(clip_cache_file, 'rb'))
-            self.final_clips = torch.load(open(clip_cache_final_file, 'rb'))
-        else:
-            vid_idx = 0
-            for label, path, _ in tqdm(self.videos):
-                vc = tv.io.read_video(path)[0].permute(3, 0, 1, 2)
-                
-                for j in range(0, vc.size(1) - clip_depth, 5):
-                    frames = []
-                    for k in range(j, j + clip_depth, 1):
-                        frames.append(vc[:, k, :, :])
-                    vc_sub = torch.stack(frames, dim=1)
-
-                    if vc_sub.size(1) != clip_depth:
-                        continue
-
-                    if self.transform:
-                        vc_sub = self.transform(vc_sub)
-
-                    self.clips.append((self.videos[vid_idx][0], vc_sub, self.videos[vid_idx][2]))
-
-                self.final_clips[self.videos[vid_idx][2]] = self.clips[-1]
-                vid_idx += 1
-                
-            torch.save(self.clips, open(clip_cache_file, 'wb+'))
-            torch.save(self.final_clips, open(clip_cache_final_file, 'wb+'))
-            
-        num_positive = len([clip[0] for clip in self.clips if clip[0] == 'Positives'])
-        num_negative = len([clip[0] for clip in self.clips if clip[0] == 'Negatives'])
-        
-        random.shuffle(self.clips)
-        
-        print('Loaded', num_positive, 'positive examples.')
-        print('Loaded', num_negative, 'negative examples.')
-        
-    def get_filenames(self):
-        return [self.clips[i][2] for i in range(len(self.clips))]
-        
-    def get_video_labels(self):
-        return [self.videos[i][0] for i in range(len(self.videos))]
-        
-    def get_labels(self):
-        return [self.clips[i][0] for i in range(len(self.clips))]
-    
-    def get_final_clips(self):
-        return self.final_clips
-    
-    def __getitem__(self, index):
-        label, clip, vid_f = self.clips[index]
-        if self.augmentation:
-            clip = clip.swapaxes(0, 1)
-            clip = self.augmentation(clip)
-            clip = clip.swapaxes(0, 1)
-        return (label, clip, vid_f)
-        
-    def __len__(self):
-        return len(self.clips)
     
 class NeedleLoader(Dataset):
     def __init__(self, video_path, transform=None, augmentation=None):
@@ -534,205 +460,4 @@ class NeedleLoader(Dataset):
         
     def __len__(self):
         return len(self.clips)
-    
-class YoloClipLoader(Dataset):
-    
-    def __init__(self, yolo_output_path, num_frames=5, frames_between_clips=None,
-                 transform=None, augment_transform=None, num_clips=1, num_positives=1, positive_videos=None, sparse_model=None, device=None):
-        if (num_frames % 2) == 0:
-            raise ValueError("Num Frames must be an odd number, so we can extract a clip centered on each detected region")
-        
-        clip_cache_file = 'clip_cache.pt'
-        
-        self.num_clips = num_clips
-        
-        self.num_frames = num_frames
-        if frames_between_clips is None:
-            self.frames_between_clips = num_frames
-        else:
-            self.frames_between_clips = frames_between_clips
-
-        self.transform = transform
-        self.augment_transform = augment_transform
-         
-        self.labels = [name for name in listdir(yolo_output_path) if isdir(join(yolo_output_path, name))]
-        self.clips = []
-        if exists(clip_cache_file):
-            self.clips = torch.load(open(clip_cache_file, 'rb'))
-        else:
-            for label in self.labels:
-                print("Processing videos in category: {}".format(label))
-                videos = list(listdir(join(yolo_output_path, label)))
-                for vi in tqdm(range(len(videos))):
-                    video = videos[vi]
-                    counter = 0
-                    all_trimmed = []
-                    with open(abspath(join(yolo_output_path, label, video, 'result.json'))) as fin:
-                        results = json.load(fin)
-                        max_frame = len(results)
-
-                        for i in range((num_frames-1)//2, max_frame - (num_frames-1)//2 - 1, self.frames_between_clips):
-                        # for frame in results:
-                            frame = results[i]
-                            # print('loading frame:', i, frame['frame_id'])
-                            frame_start = int(frame['frame_id']) - self.num_frames//2
-                            frames = [abspath(join(yolo_output_path, label, video, 'frame{}.png'.format(frame_start+fid)))
-                                      for fid in range(num_frames)]
-                            # print(frames)
-                            frames = torch.stack([ToTensor()(Image.open(f).convert("RGB")) for f in frames]).swapaxes(0, 1)
-
-                            for region in frame['objects']:
-                                # print(region)
-                                if region['name'] != "Pleural_Line":
-                                    continue
-
-                                center_x = region['relative_coordinates']["center_x"] * 1920
-                                center_y = region['relative_coordinates']['center_y'] * 1080
-
-                                # width = region['relative_coordinates']['width'] * 1920
-                                # height = region['relative_coordinates']['height'] * 1080
-                                width=200
-                                height=100
-
-                                lower_y = round(center_y - height / 2)
-                                upper_y = round(center_y + height / 2)
-                                lower_x = round(center_x - width / 2)
-                                upper_x = round(center_x + width / 2)
-
-                                final_clip = frames[:, :, lower_y:upper_y, lower_x:upper_x]
-
-                                if self.transform:
-                                    final_clip = self.transform(final_clip)
-
-                                if sparse_model:
-                                    with torch.no_grad():
-                                        final_clip = final_clip.unsqueeze(0).to(device)
-                                        final_clip = sparse_model(final_clip)
-                                        final_clip = final_clip.squeeze(0).detach().cpu()
-
-                                self.clips.append((label, final_clip, video))
-
-            torch.save(self.clips, open(clip_cache_file, 'wb+'))
-            
-            
-#         random.shuffle(self.clips)
-            
-#         video_to_clips = {}
-        if positive_videos:
-            vids_to_keep = json.load(open(positive_videos))
-            
-            self.clips = [clip_tup for clip_tup in self.clips if clip_tup[2] in vids_to_keep or clip_tup[0] == 'Sliding']
-        else:
-            video_to_labels = {}
-
-            for lbl, clip, video in self.clips:
-                video = video.lower().replace('_clean', '')
-                if video not in video_to_labels:
-    #                 video_to_clips[video] = []
-                    video_to_labels[video] = []
-
-    #             video_to_clips[video].append(clip)
-                video_to_labels[video].append(lbl)
-
-            video_to_participants = get_ptx_participants()
-            participants_to_video = {}
-            for k, v in video_to_participants.items():
-                if video_to_labels[k][0] == 'Sliding':
-                    continue
-                if not v in participants_to_video:
-                    participants_to_video[v] = []
-
-                participants_to_video[v].append(k)
-
-            participants_to_video = dict(sorted(participants_to_video.items(), key=lambda x: len(x[1]), reverse=True))
-
-            num_to_remove = len([k for k,v in video_to_labels.items() if v[0] == 'No_Sliding']) - num_positives
-            vids_to_remove = set()
-            while num_to_remove > 0:
-                vids_to_remove.add(participants_to_video[list(participants_to_video.keys())[0]].pop())
-                participants_to_video = dict(sorted(participants_to_video.items(), key=lambda x: len(x[1]), reverse=True))
-                num_to_remove -= 1
-                    
-            self.clips = [clip_tup for clip_tup in self.clips if clip_tup[2].lower().replace('_clean', '') not in vids_to_remove]
-        
-        video_to_clips = {}
-        video_to_labels = {}
-
-        for lbl, clip, video in self.clips:
-            if video not in video_to_clips:
-                video_to_clips[video] = []
-                video_to_labels[video] = []
-
-            video_to_clips[video].append(clip)
-            video_to_labels[video].append(lbl)
-            
-        print([k for k,v in video_to_labels.items() if v[0] == 'No_Sliding'])
-            
-        print('Num positive:', len([k for k,v in video_to_labels.items() if v[0] == 'No_Sliding']))
-        print('Num negative:', len([k for k,v in video_to_labels.items() if v[0] == 'Sliding']))
-
-        self.videos = None
-        self.max_video_clips = 0
-        if num_clips > 1:
-            self.videos = []
-
-            for video in video_to_clips.keys():
-                clip_list = video_to_clips[video]
-                lbl_list = video_to_labels[video]
-                
-                for i in range(0, len(clip_list) - num_clips, 1):
-                    video_stack = torch.stack(clip_list[i:i+num_clips])
-                
-                    self.videos.append((max(set(lbl_list[i:i+num_clips]), key=lbl_list[i:i+num_clips].count), video_stack, video))
-            
-            self.clips = None
-
-            
-    def get_labels(self):
-        if self.num_clips > 1:
-            return [self.videos[i][0] for i in range(len(self.videos))]
-        else:
-            return [self.clips[i][0] for i in range(len(self.clips))]
-    
-    def get_filenames(self):
-        if self.num_clips > 1:
-            return [self.videos[i][2] for i in range(len(self.videos))]
-        else:
-            return [self.clips[i][2] for i in range(len(self.clips))]
-    
-    def __getitem__(self, index): 
-        if self.num_clips > 1:
-            label = self.videos[index][0]
-            video = self.videos[index][1]
-            filename = self.videos[index][2]
-            
-            video = video.squeeze(2)
-            video = video.permute(1, 0, 2, 3)
-
-            if self.augment_transform:
-                video = self.augment_transform(video)
-                
-            video = video.unsqueeze(2)
-            video = video.permute(1, 0, 2, 3, 4)
-#             video = video.permute(4, 1, 2, 3, 0)
-#             video = torch.nn.functional.pad(video, (0), 'constant', 0)
-#             video = video.permute(4, 1, 2, 3, 0)
-
-            orig_len = video.size(0)
-
-#             if orig_len < self.max_video_clips:
-#                 video = torch.cat([video, torch.zeros(self.max_video_clips - len(video), video.size(1), video.size(2), video.size(3), video.size(4))])
-
-            return label, video, filename, orig_len
-        else:
-            label = self.clips[index][0]
-            video = self.clips[index][1]
-            filename = self.clips[index][2]
-
-            if self.augment_transform:
-                video = self.augment_transform(video)
-
-            return label, video, filename
-        
-    def __len__(self):
-        return len(self.clips)
+    
\ No newline at end of file
diff --git a/sparse_coding_torch/ptx/classifier_model.py b/sparse_coding_torch/ptx/classifier_model.py
new file mode 100644
index 0000000..fefc2fc
--- /dev/null
+++ b/sparse_coding_torch/ptx/classifier_model.py
@@ -0,0 +1,120 @@
+from tensorflow import keras
+import numpy as np
+import torch
+import tensorflow as tf
+import cv2
+import torchvision as tv
+import torch.nn as nn
+from sparse_coding_torch.ptx.video_loader import VideoGrayScaler, MinMaxScaler
+from sparse_coding_torch.sparse_model import SparseCode, load_pytorch_weights
+
+class PTXClassifier(keras.layers.Layer):
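+    """Shallow conv + dense head over sparse-code activation maps; outputs a single logit."""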
+    def __init__(self):
+        super(PTXClassifier, self).__init__()
+
+        self.max_pool = keras.layers.MaxPooling2D(pool_size=4, strides=4)
+        self.conv_1 = keras.layers.Conv2D(24, kernel_size=8, strides=4, activation='relu', padding='valid')
+#         self.conv_2 = keras.layers.Conv2D(24, kernel_size=4, strides=2, activation='relu', padding='valid')
+
+        self.flatten = keras.layers.Flatten()
+
+        self.dropout = keras.layers.Dropout(0.5)
+
+#         self.ff_1 = keras.layers.Dense(1000, activation='relu', use_bias=True)
+#         self.ff_2 = keras.layers.Dense(500, activation='relu', use_bias=True)
+#         self.ff_2 = keras.layers.Dense(20, activation='relu', use_bias=True)
+        self.ff_3 = keras.layers.Dense(20, activation='relu', use_bias=True)
+        self.ff_4 = keras.layers.Dense(1)
+
+#     @tf.function
+    def call(self, activations):
+        activations = tf.squeeze(activations, axis=1)
+        x = self.max_pool(activations)
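+        # NOTE: conv_1 below takes the raw activations, so the pooled output above is currently unused.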
+        x = self.conv_1(activations)
+#         x = self.conv_2(x)
+        x = self.flatten(x)
+#         x = self.ff_1(x)
+#         x = self.dropout(x)
+#         x = self.ff_2(x)
+#         x = self.dropout(x)
+        x = self.ff_3(x)
+        x = self.dropout(x)
+        x = self.ff_4(x)
+
+        return x
+    
+class BaselinePTX(keras.layers.Layer):
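+    """Baseline that classifies raw clips directly with a Conv3D stem instead of sparse codes."""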
+    def __init__(self):
+        super(BaselinePTX, self).__init__()
+
+        self.conv_1 = keras.layers.Conv3D(64, kernel_size=(5, 8, 8), strides=(1, 4, 4), activation='relu', padding='valid')
+        self.conv_2 = keras.layers.Conv2D(24, kernel_size=4, strides=2, activation='relu', padding='valid')
+
+        self.flatten = keras.layers.Flatten()
+
+        self.dropout = keras.layers.Dropout(0.5)
+
+#         self.ff_1 = keras.layers.Dense(1000, activation='relu', use_bias=True)
+#         self.ff_2 = keras.layers.Dense(500, activation='relu', use_bias=True)
+        self.ff_3 = keras.layers.Dense(20, activation='relu', use_bias=True)
+        self.ff_4 = keras.layers.Dense(1)
+
+#     @tf.function
+    def call(self, images):
+        x = self.conv_1(images)
+        x = tf.squeeze(x, axis=1)
+        x = self.conv_2(x)
+        x = self.flatten(x)
+#         x = self.ff_1(x)
+#         x = self.dropout(x)
+#         x = self.ff_2(x)
+#         x = self.dropout(x)
+        x = self.ff_3(x)
+        x = self.dropout(x)
+        x = self.ff_4(x)
+
+        return x
+
+class MobileModelPTX(keras.Model):
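+    """Export model that chains frozen sparse-coding filters (loaded from a PyTorch
+    checkpoint) with the PTX classifier; preprocessing is built into call()."""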
+    def __init__(self, sparse_checkpoint, batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d):
+        super().__init__()
+        self.sparse_code = SparseCode(batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d)
+        self.classifier = PTXClassifier()
+
+        self.out_channels = out_channels
+        self.in_channels = in_channels
+        self.stride = stride
+        self.lam = lam
+        self.activation_lr = activation_lr
+        self.max_activation_iter = max_activation_iter
+        self.batch_size = batch_size
+        self.run_2d = run_2d
+
+        pytorch_weights = load_pytorch_weights(sparse_checkpoint)
+
+        if run_2d:
+            weight_list = np.split(pytorch_weights, 5, axis=0)
+            self.filters_1 = tf.Variable(initial_value=weight_list[0].squeeze(0), dtype='float32', trainable=False)
+            self.filters_2 = tf.Variable(initial_value=weight_list[1].squeeze(0), dtype='float32', trainable=False)
+            self.filters_3 = tf.Variable(initial_value=weight_list[2].squeeze(0), dtype='float32', trainable=False)
+            self.filters_4 = tf.Variable(initial_value=weight_list[3].squeeze(0), dtype='float32', trainable=False)
+            self.filters_5 = tf.Variable(initial_value=weight_list[4].squeeze(0), dtype='float32', trainable=False)
+        else:
+            self.filters = tf.Variable(initial_value=pytorch_weights, dtype='float32', trainable=False)
+
+    @tf.function
+    def call(self, images):
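+        # Preprocess: grayscale, reorder frames to the channel axis, scale to [0, 1],
+        # then standardize with the fixed training mean/std below.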
+        images = tf.squeeze(tf.image.rgb_to_grayscale(images), axis=-1)
+        images = tf.transpose(images, perm=[0, 2, 3, 1])
+        images = images / 255
+        images = (images - 0.2592) / 0.1251
+
+        if self.run_2d:
+            activations = self.sparse_code(images, [tf.stop_gradient(self.filters_1), tf.stop_gradient(self.filters_2), tf.stop_gradient(self.filters_3), tf.stop_gradient(self.filters_4), tf.stop_gradient(self.filters_5)])
+        else:
+            activations = self.sparse_code(images, tf.stop_gradient(self.filters))
+
+        pred = self.classifier(activations)
+
+        return pred
\ No newline at end of file
diff --git a/sparse_coding_torch/convert_pytorch_to_keras.py b/sparse_coding_torch/ptx/convert_pytorch_to_keras.py
similarity index 85%
rename from sparse_coding_torch/convert_pytorch_to_keras.py
rename to sparse_coding_torch/ptx/convert_pytorch_to_keras.py
index 4f3e1ab..3e28b9d 100644
--- a/sparse_coding_torch/convert_pytorch_to_keras.py
+++ b/sparse_coding_torch/ptx/convert_pytorch_to_keras.py
@@ -1,7 +1,8 @@
 import argparse
 from tensorflow import keras
 import tensorflow as tf
-from sparse_coding_torch.keras_model import SparseCode, PNBClassifier, PTXClassifier, ReconSparse, load_pytorch_weights
+from sparse_coding_torch.sparse_model import SparseCode, ReconSparse, load_pytorch_weights
+from sparse_coding_torch.ptx.classifier_model import PTXClassifier
 import torch
 import os
 
@@ -13,7 +14,6 @@ if __name__ == "__main__":
     parser.add_argument('--kernel_depth', default=5, type=int)
     parser.add_argument('--num_kernels', default=64, type=int)
     parser.add_argument('--stride', default=2, type=int)
-    parser.add_argument('--dataset', default='ptx', type=str)
     parser.add_argument('--input_image_height', default=100, type=int)
     parser.add_argument('--input_image_width', default=200, type=int)
     parser.add_argument('--output_dir', default='./converted_checkpoints', type=str)
@@ -30,14 +30,8 @@ if __name__ == "__main__":
     if args.classifier_checkpoint:
         classifier_inputs = keras.Input(shape=(1, args.input_image_height // args.stride, args.input_image_width // args.stride, args.num_kernels))
 
-        if args.dataset == 'pnb':
-            classifier_outputs = PNBClassifier()(classifier_inputs)
-            classifier_name = 'pnb_classifier'
-        elif args.dataset == 'ptx':
-            classifier_outputs = PTXClassifier()(classifier_inputs)
-            classifier_name = 'ptx_classifier'
-        else:
-            raise Exception('No classifier exists for that dataset')
+        classifier_outputs = PTXClassifier()(classifier_inputs)
+        classifier_name = 'ptx_classifier'
 
         classifier_model = keras.Model(inputs=classifier_inputs, outputs=classifier_outputs)
         
diff --git a/sparse_coding_torch/ptx/generate_tflite.py b/sparse_coding_torch/ptx/generate_tflite.py
new file mode 100644
index 0000000..5240156
--- /dev/null
+++ b/sparse_coding_torch/ptx/generate_tflite.py
@@ -0,0 +1,63 @@
+from tensorflow import keras
+import numpy as np
+import torch
+import tensorflow as tf
+import cv2
+import torchvision as tv
+import torch
+import torch.nn as nn
+from sparse_coding_torch.utils import VideoGrayScaler, MinMaxScaler
+from sparse_coding_torch.ptx.classifier_model import MobileModelPTX
+import argparse
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--input_dir', default='/shared_data/bamc_pnb_data/revised_training_data', type=str)
+    parser.add_argument('--kernel_size', default=15, type=int)
+    parser.add_argument('--kernel_depth', default=5, type=int)
+    parser.add_argument('--num_kernels', default=32, type=int)
+    parser.add_argument('--stride', default=4, type=int)
+    parser.add_argument('--max_activation_iter', default=150, type=int)
+    parser.add_argument('--activation_lr', default=1e-2, type=float)
+    parser.add_argument('--lam', default=0.05, type=float)
+    parser.add_argument('--sparse_checkpoint', default='sparse_coding_torch/output/sparse_pnb_32_long_train/sparse_conv3d_model-best.pt/', type=str)
+    parser.add_argument('--checkpoint', default='sparse_coding_torch/classifier_outputs/32_filters_no_aug_3/best_classifier.pt/', type=str)
+    parser.add_argument('--run_2d', action='store_true')
+    parser.add_argument('--batch_size', default=1, type=int)
+    parser.add_argument('--image_height', type=int, default=285)
+    parser.add_argument('--image_width', type=int, default=400)
+    parser.add_argument('--clip_depth', type=int, default=5)
+    
+    args = parser.parse_args()
+    batch_size = args.batch_size
+
+    image_height = args.image_height
+    image_width = args.image_width
+    clip_depth = args.clip_depth
+    
+    recon_model = keras.models.load_model(args.sparse_checkpoint)
+        
+    classifier_model = keras.models.load_model(args.checkpoint)
+
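+    # Wrap the loaded sparse dictionary and classifier into a single Keras graph for TFLite export.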
+    inputs = keras.Input(shape=(5, image_height, image_width))
+
+    outputs = MobileModelPTX(sparse_weights=recon_model.weights[0], classifier_model=classifier_model, batch_size=batch_size, image_height=image_height, image_width=image_width, clip_depth=clip_depth, out_channels=args.num_kernels, kernel_size=args.kernel_size, kernel_depth=args.kernel_depth, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(inputs)
+
+    model = keras.Model(inputs=inputs, outputs=outputs)
+
+    input_name = model.input_names[0]
+    index = model.input_names.index(input_name)
+    model.inputs[index].set_shape([1, 5, image_height, image_width])
+
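+    # Convert with default optimizations and float16 weights, restricted to built-in TFLite ops.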
+    converter = tf.lite.TFLiteConverter.from_keras_model(model)
+    converter.optimizations = [tf.lite.Optimize.DEFAULT]
+    converter.target_spec.supported_types = [tf.float16]
+    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
+
+    tflite_model = converter.convert()
+
+    print('Converted')
+
+    with open("./sparse_coding_torch/mobile_output/ptx.tflite", "wb") as f:
+        f.write(tflite_model)
diff --git a/sparse_coding_torch/ptx/load_data.py b/sparse_coding_torch/ptx/load_data.py
new file mode 100644
index 0000000..e307a28
--- /dev/null
+++ b/sparse_coding_torch/ptx/load_data.py
@@ -0,0 +1,69 @@
+import numpy as np
+import torchvision
+import torch
+from sklearn.model_selection import train_test_split
+from sparse_coding_torch.utils import MinMaxScaler, VideoGrayScaler
+from sparse_coding_torch.ptx.video_loader import YoloClipLoader, get_ptx_participants
+import csv
+from sklearn.model_selection import train_test_split, GroupShuffleSplit, LeaveOneGroupOut, LeaveOneOut, StratifiedGroupKFold, StratifiedKFold, KFold, ShuffleSplit
+
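+# Build the PTX clip dataset from precomputed YOLO results and return cross-validation splits:
+# leave-one-participant-out, all_train, stratified group k-fold, or (default) a grouped shuffle split.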
+def load_yolo_clips(batch_size, mode, num_clips=1, num_positives=100, device=None, n_splits=None, sparse_model=None, whole_video=False, positive_videos=None):   
+    video_path = "/shared_data/YOLO_Updated_PL_Model_Results/"
+
+    video_to_participant = get_ptx_participants()
+    
+    transforms = torchvision.transforms.Compose(
+    [VideoGrayScaler(),
+#      MinMaxScaler(0, 255),
+     torchvision.transforms.Normalize((0.2592,), (0.1251,)),
+    ])
+    augment_transforms = torchvision.transforms.Compose(
+    [torchvision.transforms.RandomRotation(45),
+     torchvision.transforms.RandomHorizontalFlip(),
+     torchvision.transforms.CenterCrop((100, 200))
+    ])
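+    # Note: only YoloClipLoader is imported here; the whole_video=True branch expects a YoloVideoLoader
+    # that is not defined in sparse_coding_torch/ptx/video_loader.py in this patch.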
+    if whole_video:
+        dataset = YoloVideoLoader(video_path, num_clips=num_clips, num_positives=num_positives, transform=transforms, augment_transform=augment_transforms, sparse_model=sparse_model, device=device)
+    else:
+        dataset = YoloClipLoader(video_path, num_clips=num_clips, num_positives=num_positives, positive_videos=positive_videos, transform=transforms, augment_transform=augment_transforms, sparse_model=sparse_model, device=device)
+    
+    targets = dataset.get_labels()
+    
+    if mode == 'leave_one_out':
+        gss = LeaveOneGroupOut()
+
+#         groups = [v for v in dataset.get_filenames()]
+        groups = [video_to_participant[v.lower().replace('_clean', '')] for v in dataset.get_filenames()]
+        
+        return gss.split(np.arange(len(targets)), targets, groups), dataset
+    elif mode == 'all_train':
+        train_idx = np.arange(len(targets))
+#         train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
+#         train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
+#                                                sampler=train_sampler)
+#         test_loader = None
+        
+        return [(train_idx, None)], dataset
+    elif mode == 'k_fold':
+        gss = StratifiedGroupKFold(n_splits=n_splits)
+
+        groups = [video_to_participant[v.lower().replace('_clean', '')] for v in dataset.get_filenames()]
+        
+        return gss.split(np.arange(len(targets)), targets, groups), dataset
+    else:
+        gss = GroupShuffleSplit(n_splits=n_splits, test_size=0.2)
+
+        groups = [video_to_participant[v.lower().replace('_clean', '')] for v in dataset.get_filenames()]
+        
+        train_idx, test_idx = list(gss.split(np.arange(len(targets)), targets, groups))[0]
+        
+        train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
+        train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
+                                               sampler=train_sampler)
+        
+        test_sampler = torch.utils.data.SubsetRandomSampler(test_idx)
+        test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
+                                               sampler=test_sampler)
+        
+        return train_loader, test_loader, dataset
+    
diff --git a/sparse_coding_torch/train_classifier.py b/sparse_coding_torch/ptx/train_classifier.py
similarity index 55%
rename from sparse_coding_torch/train_classifier.py
rename to sparse_coding_torch/ptx/train_classifier.py
index ca71ca2..7ed8907 100644
--- a/sparse_coding_torch/train_classifier.py
+++ b/sparse_coding_torch/ptx/train_classifier.py
@@ -4,8 +4,9 @@ import torch.nn.functional as F
 from tqdm import tqdm
 import argparse
 import os
-from sparse_coding_torch.load_data import load_yolo_clips, load_pnb_videos, SubsetWeightedRandomSampler, get_sample_weights
-from sparse_coding_torch.keras_model import SparseCode, PNBClassifier, PTXClassifier, ReconSparse, normalize_weights, normalize_weights_3d, PNBTemporalClassifier
+from sparse_coding_torch.ptx.load_data import load_yolo_clips
+from sparse_coding_torch.sparse_model import SparseCode, ReconSparse, normalize_weights, normalize_weights_3d
+from sparse_coding_torch.ptx.classifier_model import PTXClassifier
 import time
 import numpy as np
 from sklearn.metrics import f1_score, accuracy_score, confusion_matrix
@@ -13,11 +14,9 @@ import random
 import pickle
 import tensorflow.keras as keras
 import tensorflow as tf
-from sparse_coding_torch.train_sparse_model import sparse_loss
-from sparse_coding_torch.utils import calculate_pnb_scores, calculate_pnb_scores_skipped_frames
+from sparse_coding_torch.utils import VideoGrayScaler, MinMaxScaler
 from yolov4.get_bounding_boxes import YoloModel
 import torchvision
-from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler
 import glob
 import cv2
 
@@ -27,6 +26,95 @@ configproto.gpu_options.allow_growth = True
 sess = tf.compat.v1.Session(config=configproto) 
 tf.compat.v1.keras.backend.set_session(sess)
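+# Video-level PTX evaluation: sample 5-frame clips every 15 frames, crop a fixed 400x400 window around
+# the widest YOLO detection in each clip, run sparse coding plus the classifier, and take the majority
+# vote over clips ('No_Sliding' is mapped to 1.0; videos with no detections default to positive).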
 
+def calculate_ptx_scores(input_videos, labels, yolo_model, sparse_model, recon_model, classifier_model, image_width, image_height, transform):
+    all_predictions = []
+    
+    numerical_labels = []
+    for label in labels:
+        if label == 'No_Sliding':
+            numerical_labels.append(1.0)
+        else:
+            numerical_labels.append(0.0)
+
+    final_list = []
+    fp_ids = []
+    fn_ids = []
+    for v_idx, f in tqdm(enumerate(input_videos)):
+        clipstride = 15
+        
+        vc = VideoClips([f],
+                        clip_length_in_frames=5,
+                        frame_rate=20,
+                       frames_between_clips=clipstride)
+
+        clip_predictions = []
+        cliplist = []
+        countclips = 0
+        for i in range(vc.num_clips()):
+            clip, _, _, _ = vc.get_clip(i)
+            clip = clip.swapaxes(1, 3).swapaxes(0, 1).swapaxes(2, 3).numpy()
+            
+            bounding_boxes, classes = yolo_model.get_bounding_boxes(clip[:, 2, :, :].swapaxes(0, 2).swapaxes(0, 1))
+            bounding_boxes = bounding_boxes.squeeze(0)
+            if bounding_boxes.size == 0:
+                continue
+            countclips = countclips + len(bounding_boxes)
+
+            widths = [(bounding_boxes[i][3] - bounding_boxes[i][1]) for i in range(len(bounding_boxes))]
+
+            # Use the widest detected region in this clip.
+            ind = np.argmax(np.array(widths))
+            bb = bounding_boxes[ind]
+            center_x = (bb[3] + bb[1]) / 2 * 1920
+            center_y = (bb[2] + bb[0]) / 2 * 1080
+
+            width=400
+            height=400
+
+            lower_y = round(center_y - height / 2)
+            upper_y = round(center_y + height / 2)
+            lower_x = round(center_x - width / 2)
+            upper_x = round(center_x + width / 2)
+
+            trimmed_clip = clip[:, :, lower_y:upper_y, lower_x:upper_x]
+
+            trimmed_clip = torch.tensor(trimmed_clip).to(torch.float)
+
+            trimmed_clip = transform(trimmed_clip)
+            trimmed_clip.pin_memory()
+            cliplist.append(trimmed_clip)
+
+        if len(cliplist) > 0:
+            with torch.no_grad():
+                trimmed_clip = torch.stack(cliplist)
+                images = trimmed_clip.permute(0, 2, 3, 4, 1).numpy()
+                activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.weights[0], axis=0))]))
+
+                pred = classifier_model(activations)
+                clip_predictions = tf.math.round(tf.math.sigmoid(pred))
+
+            final_pred = torch.mode(torch.tensor(clip_predictions.numpy()).view(-1))[0].item()
+            # An even split of clip votes is left to torch.mode's tie-breaking.
+        else:
+            final_pred = 1.0
+            
+        if final_pred != numerical_labels[v_idx]:
+            if final_pred == 0.0:
+                fn_ids.append(f)
+            else:
+                fp_ids.append(f)
+            
+        final_list.append(final_pred)
+        
+    return np.array(final_list), np.array(numerical_labels), fn_ids, fp_ids
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -51,7 +139,6 @@ if __name__ == "__main__":
     parser.add_argument('--save_train_test_splits', action='store_true')
     parser.add_argument('--run_2d', action='store_true')
     parser.add_argument('--balance_classes', action='store_true')
-    parser.add_argument('--dataset', default='pnb', type=str)
     parser.add_argument('--train_sparse', action='store_true')
     parser.add_argument('--mixing_ratio', type=float, default=1.0)
     parser.add_argument('--sparse_lr', type=float, default=0.003)
@@ -63,18 +150,9 @@ if __name__ == "__main__":
     
     args = parser.parse_args()
     
-    if args.dataset == 'pnb':
-        crop_height = args.crop_height
-        crop_width = args.crop_width
-
-        image_height = int(crop_height / args.scale_factor)
-        image_width = int(crop_width / args.scale_factor)
-        clip_depth = args.clip_depth
-    elif args.dataset == 'ptx':
-        image_height = 100
-        image_width = 200
-    else:
-        raise Exception('Invalid dataset')
+    image_height = 100
+    image_width = 200
+    clip_depth = args.clip_depth
         
     batch_size = args.batch_size
     
@@ -89,7 +167,7 @@ if __name__ == "__main__":
     with open(os.path.join(output_dir, 'arguments.txt'), 'w+') as out_f:
         out_f.write(str(args))
     
-    yolo_model = YoloModel()
+    yolo_model = YoloModel('ptx')
 
     all_errors = []
     
@@ -97,31 +175,24 @@ if __name__ == "__main__":
         inputs = keras.Input(shape=(image_height, image_width, clip_depth))
     else:
         inputs = keras.Input(shape=(clip_depth, image_height, image_width, 1))
-
+        
     filter_inputs = keras.Input(shape=(args.kernel_depth, args.kernel_size, args.kernel_size, 1, args.num_kernels), dtype='float32')
 
     output = SparseCode(batch_size=args.batch_size, image_height=image_height, image_width=image_width, clip_depth=clip_depth, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, kernel_depth=args.kernel_depth, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(inputs, filter_inputs)
 
     sparse_model = keras.Model(inputs=(inputs, filter_inputs), outputs=output)
-    
+
     recon_inputs = keras.Input(shape=((clip_depth - args.kernel_depth) // 1 + 1, (image_height - args.kernel_size) // args.stride + 1, (image_width - args.kernel_size) // args.stride + 1, args.num_kernels))
-    
+
     recon_outputs = ReconSparse(batch_size=args.batch_size, image_height=image_height, image_width=image_width, clip_depth=clip_depth, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, kernel_depth=args.kernel_depth, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(recon_inputs)
-    
+
     recon_model = keras.Model(inputs=recon_inputs, outputs=recon_outputs)
 
     if args.sparse_checkpoint:
         recon_model.set_weights(keras.models.load_model(args.sparse_checkpoint).get_weights())
         
-    positive_class = None
-    if args.dataset == 'pnb':
-        splits, dataset = load_pnb_videos(yolo_model, args.batch_size, input_size=(image_height, image_width, clip_depth), crop_size=(crop_height, crop_width, clip_depth), classify_mode=True, balance_classes=args.balance_classes, mode=args.splits, device=None, n_splits=args.n_splits, sparse_model=None, frames_to_skip=args.frames_to_skip)
-        positive_class = 'Positives'
-    elif args.dataset == 'ptx':
-        train_loader, test_loader, dataset = load_yolo_clips(args.batch_size, num_clips=1, num_positives=15, mode=args.splits, device=None, n_splits=args.n_splits, sparse_model=None, whole_video=False, positive_videos='positive_videos.json')
-        positive_class = 'No_Sliding'
-    else:
-        raise Exception('Invalid dataset')
+    splits, dataset = load_yolo_clips(args.batch_size, num_clips=1, num_positives=15, mode=args.splits, device=None, n_splits=args.n_splits, sparse_model=None, whole_video=False, positive_videos='positive_videos.json')
+    positive_class = 'No_Sliding'
 
     overall_true = []
     overall_pred = []
@@ -131,7 +202,6 @@ if __name__ == "__main__":
     i_fold = 0
     for train_idx, test_idx in splits:
         train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
-#         train_sampler = SubsetWeightedRandomSampler(get_sample_weights(train_idx, dataset), train_idx, replacement=True)
         train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                sampler=train_sampler)
         
@@ -149,15 +219,8 @@ if __name__ == "__main__":
         if args.checkpoint:
             classifier_model = keras.models.load_model(args.checkpoint)
         else:
-#             classifier_inputs = keras.Input(shape=((clip_depth - args.kernel_depth) // 1 + 1, (image_height - args.kernel_size) // args.stride + 1, (image_width - args.kernel_size) // args.stride + 1, args.num_kernel))
-            classifier_inputs = keras.Input(shape=(clip_depth, image_height, image_width))
-
-            if args.dataset == 'pnb':
-                classifier_outputs = PNBTemporalClassifier()(classifier_inputs)
-            elif args.dataset == 'ptx':
-                classifier_outputs = PTXClassifier()(classifier_inputs)
-            else:
-                raise Exception('No classifier exists for that dataset')
+            classifier_inputs = keras.Input(shape=((clip_depth - args.kernel_depth) // 1 + 1, (image_height - args.kernel_size) // args.stride + 1, (image_width - args.kernel_size) // args.stride + 1, args.num_kernels))
+            classifier_outputs = PTXClassifier()(classifier_inputs)
 
             classifier_model = keras.Model(inputs=classifier_inputs, outputs=classifier_outputs)
 
@@ -178,14 +241,6 @@ if __name__ == "__main__":
 
                 for labels, local_batch, vid_f in tqdm(train_loader):
                     images = local_batch.permute(0, 2, 3, 4, 1).numpy()
-                    cv2.imwrite('example_video_2/test_{}_0.png'.format(labels[3]), images[3, 0, :, :] * 255)
-                    cv2.imwrite('example_video_2/test_{}_1.png'.format(labels[3]), images[3, 1, :, :] * 255)
-                    cv2.imwrite('example_video_2/test_{}_2.png'.format(labels[3]), images[3, 2, :, :] * 255)
-                    cv2.imwrite('example_video_2/test_{}_3.png'.format(labels[3]), images[3, 3, :, :] * 255)
-                    cv2.imwrite('example_video_2/test_{}_4.png'.format(labels[3]), images[3, 4, :, :] * 255)
-                    print(vid_f[3])
-                    raise Exception
-
                     torch_labels = np.zeros(len(labels))
                     torch_labels[[i for i in range(len(labels)) if labels[i] == positive_class]] = 1
                     torch_labels = np.expand_dims(torch_labels, axis=1)
@@ -198,11 +253,10 @@ if __name__ == "__main__":
 
                             print(loss)
                     else:
-#                         activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))]))
-    #                     raise Exception
+                        activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))]))
 
                         with tf.GradientTape() as tape:
-                            pred = classifier_model(images)
+                            pred = classifier_model(activations)
                             loss = criterion(torch_labels, pred)
 
                     epoch_loss += loss * local_batch.size(0)
@@ -224,7 +278,6 @@ if __name__ == "__main__":
 
                         prediction_optimizer.apply_gradients(zip(gradients, classifier_model.trainable_weights))
 
-
                     if y_true_train is None:
                         y_true_train = torch_labels
                         y_pred_train = tf.math.round(tf.math.sigmoid(pred))
@@ -237,16 +290,20 @@ if __name__ == "__main__":
                 y_true = None
                 y_pred = None
                 test_loss = 0.0
-                for labels, local_batch, vid_f in tqdm(test_loader):
+                
+                eval_loader = test_loader
+                if args.splits == 'all_train':
+                    eval_loader = train_loader
+                for labels, local_batch, vid_f in tqdm(eval_loader):
                     images = local_batch.permute(0, 2, 3, 4, 1).numpy()
 
                     torch_labels = np.zeros(len(labels))
                     torch_labels[[i for i in range(len(labels)) if labels[i] == positive_class]] = 1
                     torch_labels = np.expand_dims(torch_labels, axis=1)
+                    
+                    activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))]))
 
-#                     activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))]))
-
-                    pred = classifier_model(images)
+                    pred = classifier_model(activations)
                     loss = criterion(torch_labels, pred)
 
                     test_loss += loss
@@ -277,158 +334,55 @@ if __name__ == "__main__":
                     print("found better model")
                     # Save model parameters
                     classifier_model.save(os.path.join(output_dir, "best_classifier_{}.pt".format(i_fold)))
-                    recon_model.save(os.path.join(output_dir, "best_sparse_model_{}.pt".format(i_fold)))
+#                     recon_model.save(os.path.join(output_dir, "best_sparse_model_{}.pt".format(i_fold)))
                     pickle.dump(prediction_optimizer.get_weights(), open(os.path.join(output_dir, 'optimizer_{}.pt'.format(i_fold)), 'wb+'))
                     best_so_far = f1
 
             classifier_model = keras.models.load_model(os.path.join(output_dir, "best_classifier_{}.pt".format(i_fold)))
-            recon_model = keras.models.load_model(os.path.join(output_dir, 'best_sparse_model_{}.pt'.format(i_fold)))
+#             recon_model = keras.models.load_model(os.path.join(output_dir, 'best_sparse_model_{}.pt'.format(i_fold)))
 
-        if args.dataset == 'pnb':
-            epoch_loss = 0
-            
-            transform = torchvision.transforms.Compose(
-            [VideoGrayScaler(),
-             MinMaxScaler(0, 255),
-             torchvision.transforms.Resize((image_height, image_width))
-            ])
-
-            y_true = None
-            y_pred = None
-
-            pred_dict = {}
-            gt_dict = {}
-
-            t1 = time.perf_counter()
-    #         test_videos = [vid_f for labels, local_batch, vid_f in batch for batch in test_loader]
-            test_videos = set()
-            for labels, local_batch, vid_f in test_loader:
-                test_videos.update(vid_f)
-                
-            test_labels = [vid_f.split('/')[-3] for vid_f in test_videos]
-
-#             test_videos = glob.glob(pathname='/home/dwh48@drexel.edu/special_splits/test/*/*.mp4', recursive=True)
-#             test_labels = [f.split('/')[-2] for f in test_videos]
-#             print(test_videos)
-#             print(test_labels)
-            
-            if args.frames_to_skip == 1:
-                y_pred, y_true, fn, fp = calculate_pnb_scores(test_videos, test_labels, yolo_model, sparse_model, recon_model, classifier_model, image_width, image_height, transform)
-            else:
-                y_pred, y_true, fn, fp = calculate_pnb_scores_skipped_frames(test_videos, test_labels, yolo_model, sparse_model, recon_model, classifier_model, args.frames_to_skip, image_width, image_height, transform)
-
-            t2 = time.perf_counter()
+        epoch_loss = 0
 
-            print('i_fold={}, time={:.2f}'.format(i_fold, t2-t1))
+        y_true = None
+        y_pred = None
 
-            y_true = tf.cast(y_true, tf.int32)
-            y_pred = tf.cast(y_pred, tf.int32)
+        pred_dict = {}
+        gt_dict = {}
 
-            f1 = f1_score(y_true, y_pred, average='macro')
-            accuracy = accuracy_score(y_true, y_pred)
-            
-            fn_ids.extend(fn)
-            fp_ids.extend(fp)
+        t1 = time.perf_counter()
+        
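+        # Hold-out evaluation: grayscale, scale, normalize, and center-crop the test videos from
+        # bamc_ph1_test_data, then score whole videos with calculate_ptx_scores.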
+        transform = torchvision.transforms.Compose(
+        [VideoGrayScaler(),
+         MinMaxScaler(0, 255),
+         torchvision.transforms.Normalize((0.2592,), (0.1251,)),
+         torchvision.transforms.CenterCrop((100, 200))
+        ])
+
+        test_dir = '/shared_data/bamc_ph1_test_data'
+        test_videos = glob.glob(os.path.join(test_dir, '*', '*.*'))
+        test_labels = [vid_f.split('/')[-2] for vid_f in test_videos]
+
+        y_pred, y_true, fn, fp = calculate_ptx_scores(test_videos, test_labels, yolo_model, sparse_model, recon_model, classifier_model, image_width, image_height, transform)
             
-            overall_true.extend(y_true)
-            overall_pred.extend(y_pred)
-
-            print("Test f1={:.2f}, vid_acc={:.2f}".format(f1, accuracy))
-
-            print(confusion_matrix(y_true, y_pred))
-        elif args.dataset == 'ptx':
-            epoch_loss = 0
-
-            y_true = None
-            y_pred = None
-
-            pred_dict = {}
-            gt_dict = {}
-
-            t1 = time.perf_counter()
-            for labels, local_batch, vid_f in test_loader:
-                images = local_batch.permute(0, 2, 3, 4, 1).numpy()
-
-                torch_labels = np.zeros(len(labels))
-                torch_labels[[i for i in range(len(labels)) if labels[i] == positive_class]] = 1
-                torch_labels = np.expand_dims(torch_labels, axis=1)
-
-                activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.weights[0], axis=0))]))
-
-                pred = classifier_model(activations)
-
-                loss = criterion(torch_labels, pred)
-                epoch_loss += loss * local_batch.size(0)
-
-                for i, v_f in enumerate(vid_f):
-                    if v_f not in pred_dict:
-                        pred_dict[v_f] = tf.math.round(tf.math.sigmoid(pred[i]))
-                    else:
-                        pred_dict[v_f] = tf.concat((pred_dict[v_f], tf.math.round(tf.math.sigmoid(pred[i]))), axis=0)
-
-                    if v_f not in gt_dict:
-                        gt_dict[v_f] = torch_labels[i]
-                    else:
-                        gt_dict[v_f] = tf.concat((gt_dict[v_f], torch_labels[i]), axis=0)
-
-                if y_true is None:
-                    y_true = torch_labels
-                    y_pred = tf.math.round(tf.math.sigmoid(pred))
-                else:
-                    y_true = tf.concat((y_true, torch_labels), axis=0)
-                    y_pred = tf.concat((y_pred, tf.math.round(tf.math.sigmoid(pred))), axis=0)
-
-            t2 = time.perf_counter()
-
-            vid_acc = []
-            for k in pred_dict.keys():
-                gt_tmp = torch.tensor(gt_dict[k].numpy())
-                pred_tmp = torch.tensor(pred_dict[k].numpy())
-
-                gt_mode = torch.mode(gt_tmp)[0].item()
-    #             perm = torch.randperm(pred_tmp.size(0))
-    #             cutoff = int(pred_tmp.size(0)/4)
-    #             if cutoff < 3:
-    #                 cutoff = 3
-    #             idx = perm[:cutoff]
-    #             samples = pred_tmp[idx]
-                pred_mode = torch.mode(pred_tmp)[0].item()
-                overall_true.append(gt_mode)
-                overall_pred.append(pred_mode)
-                if pred_mode == gt_mode:
-                    vid_acc.append(1)
-                else:
-                    vid_acc.append(0)
-                    if pred_mode == 0:
-                        fn_ids.append(k)
-                    else:
-                        fp_ids.append(k)
+        t2 = time.perf_counter()
 
-            vid_acc = np.array(vid_acc)
+        print('i_fold={}, time={:.2f}'.format(i_fold, t2-t1))
 
-            print('----------------------------------------------------------------------------')
-            for k in pred_dict.keys():
-                print(k)
-                print('Predictions:')
-                print(pred_dict[k])
-                print('Ground Truth:')
-                print(gt_dict[k])
-                print('Overall Prediction:')
-                print(torch.mode(torch.tensor(pred_dict[k].numpy()))[0].item())
-                print('----------------------------------------------------------------------------')
+        y_true = tf.cast(y_true, tf.int32)
+        y_pred = tf.cast(y_pred, tf.int32)
 
-            print('loss={:.2f}, time={:.2f}'.format(loss, t2-t1))
+        f1 = f1_score(y_true, y_pred, average='macro')
+        accuracy = accuracy_score(y_true, y_pred)
 
-            y_true = tf.cast(y_true, tf.int32)
-            y_pred = tf.cast(y_pred, tf.int32)
+        fn_ids.extend(fn)
+        fp_ids.extend(fp)
 
-            f1 = f1_score(y_true, y_pred, average='macro')
-            accuracy = accuracy_score(y_true, y_pred)
-            all_errors.append(np.sum(vid_acc) / len(vid_acc))
+        overall_true.extend(y_true)
+        overall_pred.extend(y_pred)
 
-            print("Test f1={:.2f}, clip_acc={:.2f}, vid_acc={:.2f}".format(f1, accuracy, np.sum(vid_acc) / len(vid_acc)))
+        print("Test f1={:.2f}, vid_acc={:.2f}".format(f1, accuracy))
 
-            print(confusion_matrix(y_true, y_pred))
+        print(confusion_matrix(y_true, y_pred))
             
         i_fold += 1
 
diff --git a/sparse_coding_torch/train_sparse_model.py b/sparse_coding_torch/ptx/train_sparse_model.py
similarity index 88%
rename from sparse_coding_torch/train_sparse_model.py
rename to sparse_coding_torch/ptx/train_sparse_model.py
index a15e73d..43c03df 100644
--- a/sparse_coding_torch/train_sparse_model.py
+++ b/sparse_coding_torch/ptx/train_sparse_model.py
@@ -7,10 +7,10 @@ from matplotlib.animation import FuncAnimation
 from tqdm import tqdm
 import argparse
 import os
-from sparse_coding_torch.load_data import load_yolo_clips, load_pnb_videos, load_needle_clips, load_onsd_videos
+from sparse_coding_torch.ptx.load_data import load_yolo_clips
 import tensorflow.keras as keras
 import tensorflow as tf
-from sparse_coding_torch.keras_model import normalize_weights_3d, normalize_weights, SparseCode, load_pytorch_weights, ReconSparse
+from sparse_coding_torch.sparse_model import normalize_weights_3d, normalize_weights, SparseCode, load_pytorch_weights, ReconSparse
 import random
 
 def plot_video(video):
@@ -139,19 +139,9 @@ if __name__ == "__main__":
 
     with open(os.path.join(output_dir, 'arguments.txt'), 'w+') as out_f:
         out_f.write(str(args))
-
-    if args.dataset == 'onsd':
-        splits, dataset = load_onsd_videos(args.batch_size, input_size=(image_height, image_width, clip_depth), mode='all_train')
-        train_idx, test_idx = splits[0]
-    elif args.dataset == 'pnb':
-        train_loader, test_loader, dataset = load_pnb_videos(args.batch_size, input_size=(image_height, image_width, clip_depth), crop_size=(crop_height, crop_width, clip_depth), classify_mode=False, balance_classes=False, mode='all_train', frames_to_skip=args.frames_to_skip)
-    elif args.dataset == 'ptx':
-        splits, dataset = load_yolo_clips(args.batch_size, num_clips=1, num_positives=15, mode='all_train', device=device, n_splits=1, sparse_model=None, whole_video=False, positive_videos='positive_videos.json')
-        train_idx, test_idx = splits[0]
-    elif args.dataset == 'needle':
-        train_loader, test_loader, dataset = load_needle_clips(args.batch_size, input_size=(image_height, image_width, clip_depth))
-    else:
-        raise Exception('Invalid dataset')
+        
+    splits, dataset = load_yolo_clips(args.batch_size, num_clips=1, num_positives=15, mode='all_train', device=device, n_splits=1, sparse_model=None, whole_video=False, positive_videos='positive_videos.json')
+    train_idx, test_idx = splits[0]
     
     train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
     train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size,
diff --git a/sparse_coding_torch/ptx/video_loader.py b/sparse_coding_torch/ptx/video_loader.py
new file mode 100644
index 0000000..93ad37e
--- /dev/null
+++ b/sparse_coding_torch/ptx/video_loader.py
@@ -0,0 +1,263 @@
+from os import listdir
+from os.path import isfile
+from os.path import join
+from os.path import isdir
+from os.path import abspath
+from os.path import exists
+import json
+import glob
+
+from PIL import Image
+from torchvision.transforms import ToTensor
+from torchvision.datasets.video_utils import VideoClips
+from tqdm import tqdm
+import torch
+import numpy as np
+from torch.utils.data import Dataset
+from torch.utils.data import DataLoader
+from torchvision.io import read_video
+import torchvision as tv
+from torch import nn
+import torchvision.transforms.functional as tv_f
+import csv
+import random
+import cv2
+from yolov4.get_bounding_boxes import YoloModel
+
+def get_ptx_participants():
+    video_to_participant = {}
+    with open('/shared_data/bamc_data/bamc_video_info.csv', 'r') as csv_in:
+        reader = csv.DictReader(csv_in)
+        for row in reader:
+            key = row['Filename'].split('.')[0].lower().replace('_clean', '')
+            if key == '37 (mislabeled as 38)':
+                key = '37'
+            video_to_participant[key] = row['Participant_id']
+            
+    return video_to_participant
+
+class MinMaxScaler(object):
+    """
+    Transforms each channel to the range [0, 1].
+    """
+    def __init__(self, min_val=0, max_val=254):
+        self.min_val = min_val
+        self.max_val = max_val
+    
+    def __call__(self, tensor):
+        return (tensor - self.min_val) / (self.max_val - self.min_val)
+
+class VideoGrayScaler(nn.Module):
+    
+    def __init__(self):
+        super().__init__()
+        self.grayscale = tv.transforms.Grayscale(num_output_channels=1)
+        
+    def forward(self, video):
+        # shape = channels, time, width, height
+        video = self.grayscale(video.swapaxes(-4, -3).swapaxes(-2, -1))
+        video = video.swapaxes(-4, -3).swapaxes(-2, -1)
+        # print(video.shape)
+        return video
+    
+class YoloClipLoader(Dataset):
+    
+    def __init__(self, yolo_output_path, num_frames=5, frames_between_clips=None,
+                 transform=None, augment_transform=None, num_clips=1, num_positives=1, positive_videos=None, sparse_model=None, device=None):
+        if (num_frames % 2) == 0:
+            raise ValueError("Num Frames must be an odd number, so we can extract a clip centered on each detected region")
+        
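+        # Cropped clips are cached to clip_cache.pt so later runs skip re-reading frames and regions.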
+        clip_cache_file = 'clip_cache.pt'
+        
+        self.num_clips = num_clips
+        
+        self.num_frames = num_frames
+        if frames_between_clips is None:
+            self.frames_between_clips = num_frames
+        else:
+            self.frames_between_clips = frames_between_clips
+
+        self.transform = transform
+        self.augment_transform = augment_transform
+         
+        self.labels = [name for name in listdir(yolo_output_path) if isdir(join(yolo_output_path, name))]
+        self.clips = []
+        if exists(clip_cache_file):
+            self.clips = torch.load(open(clip_cache_file, 'rb'))
+        else:
+            for label in self.labels:
+                print("Processing videos in category: {}".format(label))
+                videos = list(listdir(join(yolo_output_path, label)))
+                for vi in tqdm(range(len(videos))):
+                    video = videos[vi]
+                    counter = 0
+                    all_trimmed = []
+                    with open(abspath(join(yolo_output_path, label, video, 'result.json'))) as fin:
+                        results = json.load(fin)
+                        max_frame = len(results)
+
+                        for i in range((num_frames-1)//2, max_frame - (num_frames-1)//2 - 1, self.frames_between_clips):
+                        # for frame in results:
+                            frame = results[i]
+                            # print('loading frame:', i, frame['frame_id'])
+                            frame_start = int(frame['frame_id']) - self.num_frames//2
+                            frames = [abspath(join(yolo_output_path, label, video, 'frame{}.png'.format(frame_start+fid)))
+                                      for fid in range(num_frames)]
+                            # print(frames)
+                            frames = torch.stack([ToTensor()(Image.open(f).convert("RGB")) for f in frames]).swapaxes(0, 1)
+
+                            for region in frame['objects']:
+                                # print(region)
+                                if region['name'] != "Pleural_Line":
+                                    continue
+
+                                center_x = region['relative_coordinates']["center_x"] * 1920
+                                center_y = region['relative_coordinates']['center_y'] * 1080
+
+                                # width = region['relative_coordinates']['width'] * 1920
+                                # height = region['relative_coordinates']['height'] * 1080
+                                width=200
+                                height=100
+
+                                lower_y = round(center_y - height / 2)
+                                upper_y = round(center_y + height / 2)
+                                lower_x = round(center_x - width / 2)
+                                upper_x = round(center_x + width / 2)
+
+                                final_clip = frames[:, :, lower_y:upper_y, lower_x:upper_x]
+
+                                if self.transform:
+                                    final_clip = self.transform(final_clip)
+
+                                if sparse_model:
+                                    with torch.no_grad():
+                                        final_clip = final_clip.unsqueeze(0).to(device)
+                                        final_clip = sparse_model(final_clip)
+                                        final_clip = final_clip.squeeze(0).detach().cpu()
+
+                                self.clips.append((label, final_clip, video))
+
+            torch.save(self.clips, open(clip_cache_file, 'wb+'))
+            
+            
+#         random.shuffle(self.clips)
+            
+#         video_to_clips = {}
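+        # Either keep an explicit whitelist of positive videos, or down-sample No_Sliding videos
+        # (removing from the participants with the most videos first) until num_positives remain.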
+        if positive_videos:
+            vids_to_keep = json.load(open(positive_videos))
+            
+            self.clips = [clip_tup for clip_tup in self.clips if clip_tup[2] in vids_to_keep or clip_tup[0] == 'Sliding']
+        else:
+            video_to_labels = {}
+
+            for lbl, clip, video in self.clips:
+                video = video.lower().replace('_clean', '')
+                if video not in video_to_labels:
+    #                 video_to_clips[video] = []
+                    video_to_labels[video] = []
+
+    #             video_to_clips[video].append(clip)
+                video_to_labels[video].append(lbl)
+
+            video_to_participants = get_ptx_participants()
+            participants_to_video = {}
+            for k, v in video_to_participants.items():
+                if video_to_labels[k][0] == 'Sliding':
+                    continue
+                if not v in participants_to_video:
+                    participants_to_video[v] = []
+
+                participants_to_video[v].append(k)
+
+            participants_to_video = dict(sorted(participants_to_video.items(), key=lambda x: len(x[1]), reverse=True))
+
+            num_to_remove = len([k for k,v in video_to_labels.items() if v[0] == 'No_Sliding']) - num_positives
+            vids_to_remove = set()
+            while num_to_remove > 0:
+                vids_to_remove.add(participants_to_video[list(participants_to_video.keys())[0]].pop())
+                participants_to_video = dict(sorted(participants_to_video.items(), key=lambda x: len(x[1]), reverse=True))
+                num_to_remove -= 1
+                    
+            self.clips = [clip_tup for clip_tup in self.clips if clip_tup[2].lower().replace('_clean', '') not in vids_to_remove]
+        
+        video_to_clips = {}
+        video_to_labels = {}
+
+        for lbl, clip, video in self.clips:
+            if video not in video_to_clips:
+                video_to_clips[video] = []
+                video_to_labels[video] = []
+
+            video_to_clips[video].append(clip)
+            video_to_labels[video].append(lbl)
+            
+        print([k for k,v in video_to_labels.items() if v[0] == 'No_Sliding'])
+            
+        print('Num positive:', len([k for k,v in video_to_labels.items() if v[0] == 'No_Sliding']))
+        print('Num negative:', len([k for k,v in video_to_labels.items() if v[0] == 'Sliding']))
+
+        self.videos = None
+        self.max_video_clips = 0
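+        # With num_clips > 1, build sliding windows of consecutive clips per video and label each
+        # window with the majority clip label.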
+        if num_clips > 1:
+            self.videos = []
+
+            for video in video_to_clips.keys():
+                clip_list = video_to_clips[video]
+                lbl_list = video_to_labels[video]
+                
+                for i in range(0, len(clip_list) - num_clips, 1):
+                    video_stack = torch.stack(clip_list[i:i+num_clips])
+                
+                    self.videos.append((max(set(lbl_list[i:i+num_clips]), key=lbl_list[i:i+num_clips].count), video_stack, video))
+            
+            self.clips = None
+
+            
+    def get_labels(self):
+        if self.num_clips > 1:
+            return [self.videos[i][0] for i in range(len(self.videos))]
+        else:
+            return [self.clips[i][0] for i in range(len(self.clips))]
+    
+    def get_filenames(self):
+        if self.num_clips > 1:
+            return [self.videos[i][2] for i in range(len(self.videos))]
+        else:
+            return [self.clips[i][2] for i in range(len(self.clips))]
+    
+    def __getitem__(self, index): 
+        if self.num_clips > 1:
+            label = self.videos[index][0]
+            video = self.videos[index][1]
+            filename = self.videos[index][2]
+            
+            video = video.squeeze(2)
+            video = video.permute(1, 0, 2, 3)
+
+            if self.augment_transform:
+                video = self.augment_transform(video)
+                
+            video = video.unsqueeze(2)
+            video = video.permute(1, 0, 2, 3, 4)
+#             video = video.permute(4, 1, 2, 3, 0)
+#             video = torch.nn.functional.pad(video, (0), 'constant', 0)
+#             video = video.permute(4, 1, 2, 3, 0)
+
+            orig_len = video.size(0)
+
+#             if orig_len < self.max_video_clips:
+#                 video = torch.cat([video, torch.zeros(self.max_video_clips - len(video), video.size(1), video.size(2), video.size(3), video.size(4))])
+
+            return label, video, filename, orig_len
+        else:
+            label = self.clips[index][0]
+            video = self.clips[index][1]
+            filename = self.clips[index][2]
+
+            if self.augment_transform:
+                video = self.augment_transform(video)
+
+            return label, video, filename
+        
+    def __len__(self):
+        if self.num_clips > 1:
+            return len(self.videos)
+        return len(self.clips)
diff --git a/sparse_coding_torch/keras_model.py b/sparse_coding_torch/sparse_model.py
similarity index 54%
rename from sparse_coding_torch/keras_model.py
rename to sparse_coding_torch/sparse_model.py
index 86f015a..12ef1be 100644
--- a/sparse_coding_torch/keras_model.py
+++ b/sparse_coding_torch/sparse_model.py
@@ -6,7 +6,7 @@ import cv2
 import torchvision as tv
 import torch
 import torch.nn as nn
-from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler
+from sparse_coding_torch.utils import VideoGrayScaler, MinMaxScaler
 
 def load_pytorch_weights(file_path):
     pytorch_checkpoint = torch.load(file_path, map_location='cpu')
@@ -220,205 +220,4 @@ class ReconSparse(keras.Model):
         else:
             recon = do_recon_3d(self.filters, activations, self.image_height, self.image_width, self.clip_depth, self.stride, self.padding)
             
-        return recon
-
-class PTXClassifier(keras.layers.Layer):
-    def __init__(self):
-        super(PTXClassifier, self).__init__()
-
-        self.max_pool = keras.layers.MaxPooling2D(pool_size=4, strides=4)
-        self.conv_1 = keras.layers.Conv2D(24, kernel_size=8, strides=4, activation='relu', padding='valid')
-#         self.conv_2 = keras.layers.Conv2D(24, kernel_size=4, strides=2, activation='relu', padding='valid')
-
-        self.flatten = keras.layers.Flatten()
-
-        self.dropout = keras.layers.Dropout(0.5)
-
-#         self.ff_1 = keras.layers.Dense(1000, activation='relu', use_bias=True)
-#         self.ff_2 = keras.layers.Dense(500, activation='relu', use_bias=True)
-        self.ff_3 = keras.layers.Dense(20, activation='relu', use_bias=True)
-        self.ff_4 = keras.layers.Dense(1)
-
-#     @tf.function
-    def call(self, activations):
-        activations = tf.squeeze(activations, axis=1)
-        x = self.max_pool(activations)
-        x = self.conv_1(x)
-#         x = self.conv_2(x)
-        x = self.flatten(x)
-#         x = self.ff_1(x)
-#         x = self.dropout(x)
-#         x = self.ff_2(x)
-#         x = self.dropout(x)
-        x = self.ff_3(x)
-        x = self.dropout(x)
-        x = self.ff_4(x)
-
-        return x
-    
-class PNBClassifier(keras.layers.Layer):
-    def __init__(self):
-        super(PNBClassifier, self).__init__()
-
-#         self.max_pool = keras.layers.MaxPooling2D(pool_size=(8, 8), strides=(2, 2))
-        self.conv_1 = keras.layers.Conv2D(32, kernel_size=(8, 8), strides=(4, 4), activation='relu', padding='valid')
-        self.conv_2 = keras.layers.Conv2D(32, kernel_size=4, strides=2, activation='relu', padding='valid')
-#         self.conv_3 = keras.layers.Conv2D(12, kernel_size=4, strides=1, activation='relu', padding='valid')
-#         self.conv_4 = keras.layers.Conv2D(16, kernel_size=4, strides=2, activation='relu', padding='valid')
-
-        self.flatten = keras.layers.Flatten()
-
-#         self.dropout = keras.layers.Dropout(0.5)
-
-#         self.ff_1 = keras.layers.Dense(1000, activation='relu', use_bias=True)
-        self.ff_2 = keras.layers.Dense(40, activation='relu', use_bias=True)
-        self.ff_3 = keras.layers.Dense(20, activation='relu', use_bias=True)
-        self.ff_4 = keras.layers.Dense(1)
-
-#     @tf.function
-    def call(self, activations):
-        x = tf.squeeze(activations, axis=1)
-#         x = self.max_pool(x)
-#         print(x.shape)
-        x = self.conv_1(x)
-#         print(x.shape)
-        x = self.conv_2(x)
-#         print(x.shape)
-#         raise Exception
-#         x = self.conv_3(x)
-#         print(x.shape)
-#         x = self.conv_4(x)
-#         raise Exception
-        x = self.flatten(x)
-#         x = self.ff_1(x)
-#         x = self.dropout(x)
-        x = self.ff_2(x)
-#         x = self.dropout(x)
-        x = self.ff_3(x)
-#         x = self.dropout(x)
-        x = self.ff_4(x)
-
-        return x
-    
-class PNBTemporalClassifier(keras.layers.Layer):
-    def __init__(self):
-        super(PNBTemporalClassifier, self).__init__()
-        self.conv_1 = keras.layers.Conv2D(12, kernel_size=(150, 24), strides=(1, 8), activation='relu', padding='valid')
-        self.conv_2 = keras.layers.Conv1D(24, kernel_size=8, strides=4, activation='relu', padding='valid')
-        
-        self.ff_1 = keras.layers.Dense(100, activation='relu', use_bias=True)
-        
-        self.gru = keras.layers.GRU(25)
-
-        self.flatten = keras.layers.Flatten()
-
-        self.ff_2 = keras.layers.Dense(10, activation='relu', use_bias=True)
-        self.ff_3 = keras.layers.Dense(1)
-
-#     @tf.function
-    def call(self, clip):
-        width = clip.shape[3]
-        height = clip.shape[2]
-        depth = clip.shape[1]
-        
-        x = tf.expand_dims(clip, axis=4)
-        x = tf.reshape(clip, (-1, height, width, 1))
-
-        x = self.conv_1(x)
-        x = tf.squeeze(x, axis=1)
-        x = self.conv_2(x)
-
-        x = self.flatten(x)
-        x = self.ff_1(x)
-
-        x = tf.reshape(x, (-1, 5, 100))
-        x = self.gru(x)
-        
-        x = self.ff_2(x)
-        x = self.ff_3(x)
-
-        return x
-
-class MobileModelPTX(keras.Model):
-    def __init__(self, sparse_checkpoint, batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d):
-        super().__init__()
-        self.sparse_code = SparseCode(batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d)
-        self.classifier = Classifier()
-
-        self.out_channels = out_channels
-        self.in_channels = in_channels
-        self.stride = stride
-        self.lam = lam
-        self.activation_lr = activation_lr
-        self.max_activation_iter = max_activation_iter
-        self.batch_size = batch_size
-        self.run_2d = run_2d
-
-        pytorch_weights = load_pytorch_weights(sparse_checkpoint)
-
-        if run_2d:
-            weight_list = np.split(pytorch_weights, 5, axis=0)
-            self.filters_1 = tf.Variable(initial_value=weight_list[0].squeeze(0), dtype='float32', trainable=False)
-            self.filters_2 = tf.Variable(initial_value=weight_list[1].squeeze(0), dtype='float32', trainable=False)
-            self.filters_3 = tf.Variable(initial_value=weight_list[2].squeeze(0), dtype='float32', trainable=False)
-            self.filters_4 = tf.Variable(initial_value=weight_list[3].squeeze(0), dtype='float32', trainable=False)
-            self.filters_5 = tf.Variable(initial_value=weight_list[4].squeeze(0), dtype='float32', trainable=False)
-        else:
-            self.filters = tf.Variable(initial_value=pytorch_weights, dtype='float32', trainable=False)
-
-    @tf.function
-    def call(self, images):
-        images = tf.squeeze(tf.image.rgb_to_grayscale(images), axis=-1)
-        images = tf.transpose(images, perm=[0, 2, 3, 1])
-        images = images / 255
-        images = (images - 0.2592) / 0.1251
-
-        if self.run_2d:
-            activations = self.sparse_code(images, [tf.stop_gradient(self.filters_1), tf.stop_gradient(self.filters_2), tf.stop_gradient(self.filters_3), tf.stop_gradient(self.filters_4), tf.stop_gradient(self.filters_5)])
-        else:
-            activations = self.sparse_code(images, tf.stop_gradient(self.filters))
-
-        pred = self.classifier(activations)
-
-        return pred
-    
-class MobileModelPNB(keras.Model):
-    def __init__(self, sparse_weights, classifier_model, batch_size, image_height, image_width, clip_depth, out_channels, kernel_size, kernel_depth, stride, lam, activation_lr, max_activation_iter, run_2d):
-        super().__init__()
-        self.sparse_code = SparseCode(batch_size=batch_size, image_height=image_height, image_width=image_width, clip_depth=clip_depth, in_channels=1, out_channels=out_channels, kernel_size=kernel_size, kernel_depth=kernel_depth, stride=stride, lam=lam, activation_lr=activation_lr, max_activation_iter=max_activation_iter, run_2d=run_2d, padding='VALID')
-        self.classifier = classifier_model
-
-        self.out_channels = out_channels
-        self.stride = stride
-        self.lam = lam
-        self.activation_lr = activation_lr
-        self.max_activation_iter = max_activation_iter
-        self.batch_size = batch_size
-        self.run_2d = run_2d
-        
-        if run_2d:
-            weight_list = np.split(sparse_weights, 5, axis=0)
-            self.filters_1 = tf.Variable(initial_value=weight_list[0].squeeze(0), dtype='float32', trainable=False)
-            self.filters_2 = tf.Variable(initial_value=weight_list[1].squeeze(0), dtype='float32', trainable=False)
-            self.filters_3 = tf.Variable(initial_value=weight_list[2].squeeze(0), dtype='float32', trainable=False)
-            self.filters_4 = tf.Variable(initial_value=weight_list[3].squeeze(0), dtype='float32', trainable=False)
-            self.filters_5 = tf.Variable(initial_value=weight_list[4].squeeze(0), dtype='float32', trainable=False)
-        else:
-            self.filters = tf.Variable(initial_value=sparse_weights, dtype='float32', trainable=False)
-
-    @tf.function
-    def call(self, images):
-#         images = tf.squeeze(tf.image.rgb_to_grayscale(images), axis=-1)
-        images = tf.transpose(images, perm=[0, 2, 3, 1])
-        images = images / 255
-
-        if self.run_2d:
-            activations = self.sparse_code(images, [tf.stop_gradient(self.filters_1), tf.stop_gradient(self.filters_2), tf.stop_gradient(self.filters_3), tf.stop_gradient(self.filters_4), tf.stop_gradient(self.filters_5)])
-            activations = tf.expand_dims(activations, axis=1)
-        else:
-            activations = self.sparse_code(images, tf.stop_gradient(self.filters))
-
-        pred = tf.math.round(tf.math.sigmoid(self.classifier(activations)))
-#         pred = tf.math.reduce_sum(activations)
-
-        return pred
+        return recon
\ No newline at end of file
diff --git a/sparse_coding_torch/utils.py b/sparse_coding_torch/utils.py
index 98c807b..62e67d3 100644
--- a/sparse_coding_torch/utils.py
+++ b/sparse_coding_torch/utils.py
@@ -1,113 +1,149 @@
 import numpy as np
-from sparse_coding_torch.video_loader import get_yolo_regions, classify_nerve_is_right
 import torchvision as tv
 import torch
 import tensorflow as tf
 from tqdm import tqdm
+from torchvision.datasets.video_utils import VideoClips
+from typing import Sequence, Iterator
+import torch.nn as nn
 
-def calculate_pnb_scores(input_videos, labels, yolo_model, sparse_model, recon_model, classifier_model, image_width, image_height, transform):
-    all_predictions = []
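+# Per-sample weights for PNB clips: each label is weighted by the opposite class's frequency,
+# so the minority class is drawn more often under weighted sampling.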
+def get_sample_weights(train_idx, dataset):
+    dataset = list(dataset)
+
+    num_positive = len([clip[0] for clip in dataset if clip[0] == 'Positives'])
+    negative_weight = num_positive / len(dataset)
+    positive_weight = 1.0 - negative_weight
     
-    numerical_labels = []
-    for label in labels:
+    weights = []
+    for idx in train_idx:
+        label = dataset[idx][0]
         if label == 'Positives':
-            numerical_labels.append(1.0)
-        else:
-            numerical_labels.append(0.0)
-
-    final_list = []
-    fp_ids = []
-    fn_ids = []
-    for v_idx, f in tqdm(enumerate(input_videos)):
-        vc = tv.io.read_video(f)[0].permute(3, 0, 1, 2)
-        is_right = classify_nerve_is_right(yolo_model, vc)
-        
-        all_preds = []
-        for j in range(vc.size(1) - 5, vc.size(1) - 25, -5):
-            if j-5 < 0:
-                break
-
-            vc_sub = vc[:, j-5:j, :, :]
-            
-            if vc_sub.size(1) < 5:
-                continue
-            
-            clip = get_yolo_regions(yolo_model, vc_sub, is_right, image_width, image_height)
-            
-            if not clip:
-                continue
-
-            clip = clip[0]
-            clip = transform(clip).to(torch.float32)
-            clip = tf.expand_dims(clip, axis=4) 
-
-            activations = tf.stop_gradient(sparse_model([clip, tf.stop_gradient(tf.expand_dims(recon_model.weights[0], axis=0))]))
-
-            pred = tf.math.round(tf.math.sigmoid(classifier_model(activations)))
-
-            all_preds.append(pred)
-                
-        if all_preds:
-            final_pred = np.round(np.mean(np.array(all_preds)))
+            weights.append(positive_weight)
+        elif label == 'Negatives':
+            weights.append(negative_weight)
         else:
-            final_pred = 1.0
-            
-        if final_pred != numerical_labels[v_idx]:
-            if final_pred == 0:
-                fn_ids.append(f)
-            else:
-                fp_ids.append(f)
-            
-        final_list.append(final_pred)
-        
-    return np.array(final_list), np.array(numerical_labels), fn_ids, fp_ids
+            raise Exception('Sampler encountered invalid label')
+    
+    return weights
+
+class SubsetWeightedRandomSampler(torch.utils.data.Sampler[int]):
+    weights: torch.Tensor
+    num_samples: int
+    replacement: bool
+
+    def __init__(self, weights: Sequence[float], indicies: Sequence[int],
+                 replacement: bool = True, generator=None) -> None:
+        if not isinstance(replacement, bool):
+            raise ValueError("replacement should be a boolean value, but got "
+                             "replacement={}".format(replacement))
+        self.weights = torch.as_tensor(weights, dtype=torch.double)
+        self.indicies = indicies
+        self.replacement = replacement
+        self.generator = generator
+
+    def __iter__(self) -> Iterator[int]:
+        rand_tensor = torch.multinomial(self.weights, len(self.indicies), self.replacement, generator=self.generator)
+        for i in rand_tensor:
+            yield self.indicies[i]
 
-def calculate_pnb_scores_skipped_frames(input_videos, labels, yolo_model, sparse_model, recon_model, classifier_model, frames_to_skip, image_width, image_height, transform):
-    all_predictions = []
+    def __len__(self) -> int:
+        return len(self.indicies)
+
+class MinMaxScaler(object):
+    """
+    Transforms each channel to the range [0, 1].
+    """
+    def __init__(self, min_val=0, max_val=254):
+        self.min_val = min_val
+        self.max_val = max_val
     
-    numerical_labels = []
-    for label in labels:
-        if label == 'Positives':
-            numerical_labels.append(1.0)
-        else:
-            numerical_labels.append(0.0)
-
-    final_list = []
-    fp_ids = []
-    fn_ids = []
-    for v_idx, f in tqdm(enumerate(input_videos)):
-        vc = tv.io.read_video(f)[0].permute(3, 0, 1, 2)
-        is_right = classify_nerve_is_right(yolo_model, vc)
-        
-        all_preds = []
+    def __call__(self, tensor):
+        return (tensor - self.min_val) / (self.max_val - self.min_val)
+
+class VideoGrayScaler(nn.Module):
+    
+    def __init__(self):
+        super().__init__()
+        self.grayscale = tv.transforms.Grayscale(num_output_channels=1)
         
-        frames = []
-        for k in range(vc.size(1) - 1, vc.size(1) - 5 * frames_to_skip, -frames_to_skip):
-            frames.append(vc[:, k, :, :])
-        vc_sub = torch.stack(frames, dim=1)
-            
-        if vc_sub.size(1) < 5:
-            continue
+    def forward(self, video):
+        # shape = channels, time, width, height
+        video = self.grayscale(video.swapaxes(-4, -3).swapaxes(-2, -1))
+        video = video.swapaxes(-4, -3).swapaxes(-2, -1)
+        # print(video.shape)
+        return video
 
-        clip = get_yolo_regions(yolo_model, vc_sub, is_right, image_width, image_height)
+def plot_video(video):
 
-        if clip:
-            clip = clip[0]
-            clip = transform(clip).to(torch.float32)
-            clip = tf.expand_dims(clip, axis=4) 
+    fig = plt.gcf()
+    ax = plt.gca()
 
-            activations = tf.stop_gradient(sparse_model([clip, tf.stop_gradient(tf.expand_dims(recon_model.weights[0], axis=0))]))
+    DPI = fig.get_dpi()
+    fig.set_size_inches(video.shape[2]/float(DPI), video.shape[3]/float(DPI))
 
-            pred = tf.math.round(tf.math.sigmoid(classifier_model(activations)))
-        else:
-            pred = 1.0
-            
-        if pred != numerical_labels[v_idx]:
-            if pred == 0:
-                fn_ids.append(f)
-            else:
-                fp_ids.append(f)
-            
-        final_list.append(pred)
-        
-    return np.array(final_list), np.array(numerical_labels), fn_ids, fp_ids
\ No newline at end of file
+    ax.set_title("Video")
+
+    T = video.shape[1]
+    im = ax.imshow(video[0, 0, :, :],
+                     cmap=cm.Greys_r)
+
+    def update(i):
+        t = i % T
+        im.set_data(video[0, t, :, :])
+
+    return FuncAnimation(plt.gcf(), update, interval=1000/20)
+
+def plot_original_vs_recon(original, reconstruction, idx=0):
+
+    # create two subplots
+    ax1 = plt.subplot(1, 2, 1)
+    ax2 = plt.subplot(1, 2, 2)
+    ax1.set_title("Original")
+    ax2.set_title("Reconstruction")
+
+    T = original.shape[2]
+    im1 = ax1.imshow(original[idx, 0, 0, :, :],
+                     cmap=cm.Greys_r)
+    im2 = ax2.imshow(reconstruction[idx, 0, 0, :, :],
+                     cmap=cm.Greys_r)
+
+    def update(i):
+        t = i % T
+        im1.set_data(original[idx, 0, t, :, :])
+        im2.set_data(reconstruction[idx, 0, t, :, :])
+
+    return FuncAnimation(plt.gcf(), update, interval=1000/30)
+
+
+def plot_filters(filters):
+    filters = filters.astype('float32')
+    num_filters = filters.shape[4]
+    ncol = 3
+    # ncol = int(np.sqrt(num_filters))
+    # nrow = int(np.sqrt(num_filters))
+    T = filters.shape[0]
+
+    if num_filters % ncol == 0:
+        nrow = num_filters // ncol
+    else:
+        nrow = num_filters // ncol + 1
+
+    fig, axes = plt.subplots(ncols=ncol, nrows=nrow,
+                             constrained_layout=True,
+                             figsize=(ncol*2, nrow*2))
+
+    ims = {}
+    for i in range(num_filters):
+        r = i // ncol
+        c = i % ncol
+        ims[(r, c)] = axes[r, c].imshow(filters[0, :, :, 0, i],
+                                        cmap=cm.Greys_r)
+
+    def update(i):
+        t = i % T
+        for j in range(num_filters):
+            r = j // ncol
+            c = j % ncol
+            ims[(r, c)].set_data(filters[t, :, :, 0, j])
+
+    return FuncAnimation(plt.gcf(), update, interval=1000/20)
\ No newline at end of file
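For reference, a minimal sketch (not part of the patch) of how the new sampling utilities above might be wired into a torch DataLoader. It assumes the dataset yields tuples whose first element is the 'Positives'/'Negatives' label, as get_sample_weights expects; make_balanced_loader and the batch size are illustrative names and values.

import torch
from sparse_coding_torch.utils import get_sample_weights, SubsetWeightedRandomSampler

def make_balanced_loader(dataset, train_idx, batch_size=12):
    # Inverse-frequency weights: positives are weighted by the negative fraction
    # and vice versa, so batches drawn by the sampler are roughly class-balanced.
    weights = get_sample_weights(train_idx, dataset)
    sampler = SubsetWeightedRandomSampler(weights, train_idx, replacement=True)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=sampler)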
diff --git a/yolov4/get_bounding_boxes.py b/yolov4/get_bounding_boxes.py
index 618338f..ae1d047 100644
--- a/yolov4/get_bounding_boxes.py
+++ b/yolov4/get_bounding_boxes.py
@@ -13,12 +13,17 @@ from tensorflow.compat.v1 import InteractiveSession
 import time
 
 class YoloModel():
-    def __init__(self):
+    def __init__(self, dataset):
         flags.DEFINE_string('framework', 'tflite', '(tf, tflite, trt')
-        flags.DEFINE_string('weights', 'yolov4/Pleural_Line_TensorFlow/pnb_prelim_yolo/yolov4-416.tflite',
-                            'path to weights file')
-#         flags.DEFINE_string('weights', 'yolov4/yolov4-416.tflite',
-#                             'path to weights file')
+        if dataset == 'pnb':
+            flags.DEFINE_string('weights', 'yolov4/Pleural_Line_TensorFlow/pnb_prelim_yolo/yolov4-416.tflite',
+                                'path to weights file')
+        elif dataset == 'onsd':
+            flags.DEFINE_string('weights', 'yolov4/Pleural_Line_TensorFlow/onsd_prelim_yolo/yolov4-416.tflite',
+                                'path to weights file')
+        else:
+            flags.DEFINE_string('weights', 'yolov4/yolov4-416.tflite',
+                                'path to weights file')
         flags.DEFINE_integer('size', 416, 'resize images to')
         flags.DEFINE_boolean('tiny', False, 'yolo or yolo-tiny')
         flags.DEFINE_string('model', 'yolov4', 'yolov3 or yolov4')
@@ -111,21 +116,26 @@ class YoloModel():
         boxes = boxes.tolist()
         boxes = boxes[0]
         classes = classes[0]
+        scores = scores[0]
         boxes_list = []
         class_list = []
-        for box, class_idx in zip(boxes, classes):
+        score_list = []
+        for box, class_idx, score in zip(boxes, classes, scores):
             sum = 0
             for value in box:
                 sum += value
             if sum > 0:
                 boxes_list.append(box)
                 class_list.append(class_idx)
+                score_list.append(score)
         boxes_list = [boxes_list]
         class_list = [class_list]
+        score_list = [score_list]
         boxes = np.array(boxes_list)
         classes = np.array(class_list)
+        scores = np.array(score_list)
 
         end = time.time()
         elapsed_time = end - start
 #         print('Took %.2f seconds to run whole bounding box function\n' % (elapsed_time))
-        return boxes, classes
+        return boxes, classes, scores
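
A hypothetical caller sketch (not part of the patch) for the updated YoloModel: the constructor now selects the weights file by dataset name, and callers must unpack three arrays instead of two. The method name get_bounding_boxes and the confidence threshold are assumptions; only the constructor argument and the (boxes, classes, scores) return shown in the diff are from the patch.

from yolov4.get_bounding_boxes import YoloModel

yolo_model = YoloModel('onsd')  # 'pnb', 'onsd', or anything else for the default weights

def detect(frame, min_score=0.5):
    # get_bounding_boxes is assumed to be the detection entry point.
    boxes, classes, scores = yolo_model.get_bounding_boxes(frame)
    # Each return value is wrapped in a single-element outer list, as in the diff above.
    return [(box, cls, score)
            for box, cls, score in zip(boxes[0], classes[0], scores[0])
            if score >= min_score]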
-- 
GitLab