diff --git a/keras/load_data.py b/keras/load_data.py new file mode 100644 index 0000000000000000000000000000000000000000..dc808f636c6862a1be1c7260393715db003b39e7 --- /dev/null +++ b/keras/load_data.py @@ -0,0 +1,232 @@ +import numpy as np +import torchvision +import torch +from sklearn.model_selection import train_test_split +from sparse_coding_torch.video_loader import MinMaxScaler +from sparse_coding_torch.video_loader import VideoLoader +from sparse_coding_torch.video_loader import VideoClipLoader, YoloClipLoader, get_video_participants, YoloVideoLoader, MobileLoader, PNBLoader +from sparse_coding_torch.video_loader import VideoGrayScaler +import csv +from sklearn.model_selection import train_test_split, GroupShuffleSplit, LeaveOneGroupOut, LeaveOneOut, StratifiedGroupKFold, StratifiedKFold, KFold + +def load_balls_data(batch_size): + + with open('ball_videos.npy', 'rb') as fin: + ball_videos = torch.tensor(np.load(fin)).float() + + batch_size = batch_size + train_loader = torch.utils.data.DataLoader(ball_videos, + batch_size=batch_size, + shuffle=True) + + return train_loader + +def load_bamc_data(batch_size): + video_path = "/shared_data/bamc_data" + + width = 350 + height = 200 + + transforms = torchvision.transforms.Compose([VideoGrayScaler(), + MinMaxScaler(0, 255), + BamcPreprocessor(), + torchvision.transforms.Resize(size=(height, width)) + ]) + dataset = VideoLoader(video_path, transform=transforms, num_frames=60) + + targets = dataset.get_labels() + + train_idx, test_idx = train_test_split(np.arange(len(targets)), test_size=0.2, shuffle=True, stratify=targets) + + train_sampler = torch.utils.data.SubsetRandomSampler(train_idx) + test_sampler = torch.utils.data.SubsetRandomSampler(test_idx) + + train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, + # shuffle=True, + sampler=train_sampler) + test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, + # shuffle=True, + sampler=test_sampler) + + return train_loader, test_loader + +def load_covid_data(batch_size, clip_length_in_frames=10, frame_rate=20): + video_path = "/home/cm3786@drexel.edu/Projects/covid19_ultrasound/data/pocus_videos/convex" + # video_path = "/home/cm3786@drexel.edu/Projects/covid19_ultrasound/data/pocus_videos/pneumonia-viral" + + scale = 0.5 + + base_width = 1920 + base_height = 1080 + + cropped_width = round(140/320 * base_width) + cropped_height = round(140/180 * base_height) + + #width = round(cropped_width * scale) + #height = round(cropped_height * scale) + + width = 128 + height = 128 + + transforms = torchvision.transforms.Compose([torchvision.transforms.Grayscale(num_output_channels=1), + #torchvision.transforms.Resize(size=(base_width, base_height)), + #torchvision.transforms.CenterCrop(size=(cropped_height, cropped_width)), + torchvision.transforms.Resize(size=(width, height)), + MinMaxScaler(0, 255)]) + dataset = VideoClipLoader(video_path, transform=transforms, + clip_length_in_frames=clip_length_in_frames, + frame_rate=frame_rate) + + targets = dataset.get_video_labels() + train_vidx, test_vidx = train_test_split(np.arange(len(targets)), test_size=0.2, shuffle=True, stratify=targets) + + train_vidx = set(train_vidx) + test_vidx = set(test_vidx) + + train_cidx = [i for i in range(len(dataset)) if dataset.video_idx[i] in train_vidx] + test_cidx = [i for i in range(len(dataset)) if dataset.video_idx[i] in test_vidx] + + train_sampler = torch.utils.data.SubsetRandomSampler(train_cidx) + test_sampler = torch.utils.data.SubsetRandomSampler(test_cidx) + + 
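Worth pausing on the pattern here: load_covid_data splits at the video level and only then expands back to clip indices, so clips from one video never straddle train and test. A minimal standalone sketch of that pattern, using hypothetical toy data in place of the real dataset:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Toy stand-ins (hypothetical data): 10 videos, 3 clips each.
video_labels = np.array(['covid'] * 5 + ['pneumonia'] * 5)  # one label per video
clip_video_idx = np.repeat(np.arange(10), 3)                # dataset.video_idx equivalent

# Split at the video level so clips from one video never leak across the split...
train_vidx, test_vidx = train_test_split(
    np.arange(len(video_labels)), test_size=0.2, shuffle=True,
    stratify=video_labels)
train_vidx, test_vidx = set(train_vidx), set(test_vidx)

# ...then expand the chosen videos back out to clip indices for the samplers.
train_cidx = [i for i, v in enumerate(clip_video_idx) if v in train_vidx]
test_cidx = [i for i, v in enumerate(clip_video_idx) if v in test_vidx]
assert not set(clip_video_idx[train_cidx]) & set(clip_video_idx[test_cidx])
```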
train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, + sampler=train_sampler) + test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, + sampler=test_sampler) + + return train_loader, test_loader + +def load_yolo_clips(batch_size, mode, num_clips=1, num_positives=100, device=None, n_splits=None, sparse_model=None, whole_video=False, positive_videos=None): + video_path = "/shared_data/YOLO_Updated_PL_Model_Results/" + + video_to_participant = get_video_participants() + + transforms = torchvision.transforms.Compose( + [VideoGrayScaler(), + torchvision.transforms.Normalize((0.2592,), (0.1251,)), + ]) + augment_transforms = torchvision.transforms.Compose( + [torchvision.transforms.RandomRotation(45), + torchvision.transforms.RandomHorizontalFlip() +# torchvision.transforms.CenterCrop((100, 200)) + ]) + if whole_video: + dataset = YoloVideoLoader(video_path, num_clips=num_clips, num_positives=num_positives, transform=transforms, augment_transform=augment_transforms, sparse_model=sparse_model, device=device) + else: + dataset = YoloClipLoader(video_path, num_clips=num_clips, num_positives=num_positives, positive_videos=positive_videos, transform=transforms, augment_transform=augment_transforms, sparse_model=sparse_model, device=device) + + targets = dataset.get_labels() + + if mode == 'leave_one_out': + gss = LeaveOneGroupOut() + +# groups = [v for v in dataset.get_filenames()] + groups = [video_to_participant[v.lower().replace('_clean', '')] for v in dataset.get_filenames()] + + return gss.split(np.arange(len(targets)), targets, groups), dataset + elif mode == 'all_train': + train_idx = np.arange(len(targets)) + train_sampler = torch.utils.data.SubsetRandomSampler(train_idx) + train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, + sampler=train_sampler) + test_loader = None + + return train_loader, test_loader + elif mode == 'k_fold': + gss = StratifiedGroupKFold(n_splits=n_splits) + + groups = [video_to_participant[v.lower().replace('_clean', '')] for v in dataset.get_filenames()] + + return gss.split(np.arange(len(targets)), targets, groups), dataset + else: + return None + +def load_mobile_clips(batch_size, mode, num_clips=1, num_positives=100, n_splits=None): + video_path = "/home/dwh48@drexel.edu/clips" + + video_to_participant = {} + with open('/shared_data/bamc_data/bamc_video_info.csv', 'r') as csv_in: + reader = csv.DictReader(csv_in) + for row in reader: + key = row['Filename'].split('.')[0].lower().replace('_clean', '') + if key == '37 (mislabeled as 38)': + key = '37' + video_to_participant[key] = row['Participant_id'] + + augment_transforms = torchvision.transforms.Compose( + [torchvision.transforms.RandomRotation(45), + torchvision.transforms.RandomHorizontalFlip(), + torchvision.transforms.CenterCrop((100, 200)) + ]) + dataset = MobileLoader(video_path, transform=None, augment_transform=augment_transforms, num_positives=num_positives) + + targets = dataset.get_labels() + + if mode == 'leave_one_out': + gss = LeaveOneGroupOut() + +# groups = [v for v in dataset.get_filenames()] + groups = [video_to_participant[v.lower().replace('_clean', '')] for v in dataset.get_filenames()] + + return gss.split(np.arange(len(targets)), targets, groups), dataset + elif mode == 'all_train': + train_idx = np.arange(len(targets)) + train_sampler = torch.utils.data.SubsetRandomSampler(train_idx) + train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, + sampler=train_sampler) + test_loader = None + + return 
train_loader, test_loader + elif mode == 'k_fold': + gss = StratifiedGroupKFold(n_splits=n_splits) + + groups = [video_to_participant[v.lower().replace('_clean', '')] for v in dataset.get_filenames()] + + return gss.split(np.arange(len(targets)), targets, groups), dataset + else: + return None + +def load_pnb_videos(batch_size, mode, device=None, n_splits=None, sparse_model=None): + video_path = "/shared_data/bamc_pnb_data/full_training_data" + + transforms = torchvision.transforms.Compose( + [VideoGrayScaler(), + MinMaxScaler(0, 255), + torchvision.transforms.Resize((360, 304)) + ]) + augment_transforms = torchvision.transforms.Compose( + [torchvision.transforms.RandomAffine(45), + torchvision.transforms.RandomHorizontalFlip(), + torchvision.transforms.ColorJitter(brightness=0.5), + torchvision.transforms.RandomAdjustSharpness(0, p=0.15), + torchvision.transforms.RandomAffine(degrees=0, translate=(0.05, 0)) +# torchvision.transforms.CenterCrop((100, 200)) + ]) + dataset = PNBLoader(video_path, num_frames=5, frame_rate=20, transform=transforms) + + targets = dataset.get_labels() + + if mode == 'leave_one_out': + gss = LeaveOneGroupOut() + + groups = [v for v in dataset.get_filenames()] +# groups = [video_to_participant[v.lower().replace('_clean', '')] for v in dataset.get_filenames()] + + return gss.split(np.arange(len(targets)), targets, groups), dataset + elif mode == 'all_train': + train_idx = np.arange(len(targets)) + train_sampler = torch.utils.data.SubsetRandomSampler(train_idx) + train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, + sampler=train_sampler) + test_loader = None + + return train_loader, test_loader + elif mode == 'k_fold': + gss = StratifiedKFold(n_splits=n_splits, shuffle=True) + +# groups = [video_to_participant[v.lower().replace('_clean', '')] for v in dataset.get_filenames()] + groups = [v for v in dataset.get_filenames()] + + return gss.split(np.arange(len(targets)), targets), dataset + else: + return None \ No newline at end of file diff --git a/keras/train_classifier.py b/keras/train_classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..a07d0c2e5a1d3b4ef8e344f93f45dde7b348b537 --- /dev/null +++ b/keras/train_classifier.py @@ -0,0 +1,325 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm import tqdm +import argparse +import os +from sparse_coding_torch.load_data import load_yolo_clips, load_pnb_videos +from keras_model import SparseCode, Classifier, ReconSparse +import time +import numpy as np +from sklearn.metrics import f1_score, accuracy_score, confusion_matrix +import random +import pickle +import tensorflow.keras as keras +import tensorflow as tf + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--batch_size', default=12, type=int) + parser.add_argument('--kernel_size', default=15, type=int) + parser.add_argument('--kernel_depth', default=5, type=int) + parser.add_argument('--num_kernels', default=64, type=int) + parser.add_argument('--stride', default=1, type=int) + parser.add_argument('--max_activation_iter', default=150, type=int) + parser.add_argument('--activation_lr', default=1e-1, type=float) + parser.add_argument('--lr', default=5e-5, type=float) + parser.add_argument('--epochs', default=10, type=int) + parser.add_argument('--lam', default=0.05, type=float) + parser.add_argument('--output_dir', default='./output', type=str) + parser.add_argument('--sparse_checkpoint', default=None, type=str) + 
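In 'leave_one_out' and 'k_fold' modes these loaders hand back the raw scikit-learn split generator together with the dataset, leaving DataLoader construction to the caller (as train_classifier.py does below). A minimal sketch of consuming such a generator, with hypothetical labels and participant groups; note that load_pnb_videos' k_fold branch uses plain StratifiedKFold, so its folds are stratified but not group-disjoint:

```python
import numpy as np
from sklearn.model_selection import StratifiedGroupKFold

# Hypothetical stand-ins for dataset.get_labels() and the participant groups.
targets = np.array([0, 0, 1, 1, 0, 1, 0, 1])
groups = np.array(['p1', 'p1', 'p2', 'p2', 'p3', 'p3', 'p4', 'p4'])

splits = StratifiedGroupKFold(n_splits=2).split(np.arange(len(targets)), targets, groups)
for train_idx, test_idx in splits:
    # A participant's clips land entirely on one side of each fold.
    assert set(groups[train_idx]).isdisjoint(groups[test_idx])
```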
parser.add_argument('--checkpoint', default=None, type=str) + parser.add_argument('--splits', default='k_fold', type=str, help='k_fold or leave_one_out or all_train') + parser.add_argument('--seed', default=42, type=int) + parser.add_argument('--train', action='store_true') + parser.add_argument('--num_positives', default=100, type=int) + parser.add_argument('--n_splits', default=5, type=int) + parser.add_argument('--save_train_test_splits', action='store_true') + parser.add_argument('--run_2d', action='store_true') + + args = parser.parse_args() + + image_height = 360 + image_width = 304 + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + + output_dir = args.output_dir + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + with open(os.path.join(output_dir, 'arguments.txt'), 'w+') as out_f: + out_f.write(str(args)) + + all_errors = [] + + recon_inputs = keras.Input(shape=(1, (image_height - args.kernel_size) // args.stride + 1, (image_width - args.kernel_size) // args.stride + 1, args.num_kernels)) + + recon_outputs = ReconSparse(batch_size=args.batch_size, image_height=image_height, image_width=image_width, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(recon_inputs) + + recon_model = keras.Model(inputs=recon_inputs, outputs=recon_outputs) + + if args.sparse_checkpoint: + recon_model = keras.models.load_model(args.sparse_checkpoint) + + splits, dataset = load_pnb_videos(args.batch_size, mode='k_fold', device=None, n_splits=args.n_splits, sparse_model=None) + i_fold = 0 + + for train_idx, test_idx in splits: + + train_sampler = torch.utils.data.SubsetRandomSampler(train_idx) + test_sampler = torch.utils.data.SubsetRandomSampler(test_idx) + + train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, + # shuffle=True, + sampler=train_sampler) + + test_loader = torch.utils.data.DataLoader(dataset, batch_size=1, + # shuffle=True, + sampler=test_sampler) + + classifier_inputs = keras.Input(shape=(1, (image_height - args.kernel_size) // args.stride + 1, (image_width - args.kernel_size) // args.stride + 1, args.num_kernels)) + + classifier_outputs = Classifier()(classifier_inputs) + + classifier_model = keras.Model(inputs=classifier_inputs, outputs=classifier_outputs) + + + overall_true = [] + overall_pred = [] + fn_ids = [] + fp_ids = [] + + best_so_far = float('inf') + + criterion = keras.losses.BinaryCrossentropy(from_logits=False) + + if args.checkpoint: + classifier_model = keras.models.load_model(args.checkpoint) + + if args.train: + prediction_optimizer = keras.optimizers.Adam(learning_rate=args.lr) + + for epoch in range(args.epochs): + epoch_loss = 0 + t1 = time.perf_counter() + + if args.run_2d: + inputs = keras.Input(shape=(image_height, image_width, 5)) + else: + inputs = keras.Input(shape=(5, image_height, image_width, 1)) + + filter_inputs = keras.Input(shape=(5, args.kernel_size, args.kernel_size, 1, args.num_kernels), dtype='float32') + + output = SparseCode(batch_size=args.batch_size, image_height=image_height, image_width=image_width, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(inputs, filter_inputs) + + sparse_model = keras.Model(inputs=(inputs, filter_inputs), outputs=output) + + for labels, local_batch, 
vid_f in tqdm(train_loader): + if local_batch.size(0) != args.batch_size: + continue + images = local_batch.permute(0, 2, 3, 4, 1).numpy() + + torch_labels = np.zeros(len(labels)) + torch_labels[[i for i in range(len(labels)) if labels[i] == 'Positives']] = 1 + torch_labels = np.expand_dims(torch_labels, axis=1) + + activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))])) + + with tf.GradientTape() as tape: + pred = classifier_model(activations) + loss = criterion(torch_labels, pred) + + epoch_loss += loss * local_batch.size(0) + + gradients = tape.gradient(loss, classifier_model.trainable_weights) + + prediction_optimizer.apply_gradients(zip(gradients, classifier_model.trainable_weights)) + + t2 = time.perf_counter() + + if args.run_2d: + inputs = keras.Input(shape=(image_height, image_width, 5)) + else: + inputs = keras.Input(shape=(5, image_height, image_width, 1)) + + filter_inputs = keras.Input(shape=(5, args.kernel_size, args.kernel_size, 1, args.num_kernels), dtype='float32') + + output = SparseCode(batch_size=1, image_height=image_height, image_width=image_width, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(inputs, filter_inputs) + + sparse_model = keras.Model(inputs=(inputs, filter_inputs), outputs=output) + + y_true = None + y_pred = None + for labels, local_batch, vid_f in test_loader: + images = local_batch.permute(0, 2, 3, 4, 1).numpy() + + torch_labels = np.zeros(len(labels)) + torch_labels[[i for i in range(len(labels)) if labels[i] == 'Positives']] = 1 + torch_labels = np.expand_dims(torch_labels, axis=1) + + activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))])) + + pred = classifier_model(activations) + + if y_true is None: + y_true = torch_labels + y_pred = tf.math.round(tf.math.sigmoid(pred)) + else: + y_true = tf.concat((y_true, torch_labels), axis=0) + y_pred = tf.concat((y_pred, tf.math.round(tf.math.sigmoid(pred))), axis=0) + + t2 = time.perf_counter() + + y_true = tf.cast(y_true, tf.int32) + y_pred = tf.cast(y_pred, tf.int32) + + f1 = f1_score(y_true, y_pred, average='macro') + accuracy = accuracy_score(y_true, y_pred) + + print('fold={}, epoch={}, time={:.2f}, loss={:.2f}, f1={:.2f}, acc={:.2f}'.format(i_fold, epoch, t2-t1, epoch_loss, f1, accuracy)) + # print(epoch_loss) + if epoch_loss <= best_so_far: + print("found better model") + # Save model parameters + classifier_model.save(os.path.join(output_dir, "model-best_fold_" + str(i_fold) + ".pt")) + best_so_far = epoch_loss + + classifier_model = keras.models.load_model(os.path.join(output_dir, "model-best_fold_" + str(i_fold) + ".pt")) + + if args.run_2d: + inputs = keras.Input(shape=(image_height, image_width, 5)) + else: + inputs = keras.Input(shape=(5, image_height, image_width, 1)) + + filter_inputs = keras.Input(shape=(5, args.kernel_size, args.kernel_size, 1, args.num_kernels), dtype='float32') + + output = SparseCode(batch_size=1, image_height=image_height, image_width=image_width, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(inputs, filter_inputs) + + sparse_model = keras.Model(inputs=(inputs, filter_inputs), outputs=output) + + 
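The training step above freezes the sparse-coding front end with tf.stop_gradient and back-propagates only through the classifier head. A self-contained sketch of that pattern with toy layers standing in for SparseCode and Classifier; since the evaluation code rounds tf.math.sigmoid(pred), the head presumably emits logits, so the sketch uses from_logits=True (the script itself passes from_logits=False):

```python
import tensorflow as tf
from tensorflow import keras

# Toy stand-ins for the sparse-coding encoder and the Classifier head.
encoder = keras.Sequential([keras.layers.Dense(8, activation='relu')])
head = keras.Sequential([keras.layers.Dense(1)])  # one logit per clip
criterion = keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam(learning_rate=5e-5)

x = tf.random.normal((4, 16))                   # a fake batch of inputs
y = tf.constant([[0.0], [1.0], [1.0], [0.0]])   # binary clip labels

# Freeze the encoder: gradients stop here, as with sparse_model above.
activations = tf.stop_gradient(encoder(x))
with tf.GradientTape() as tape:
    pred = head(activations)
    loss = criterion(y, pred)
grads = tape.gradient(loss, head.trainable_weights)
optimizer.apply_gradients(zip(grads, head.trainable_weights))
```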
epoch_loss = 0 + + y_true = None + y_pred = None + + pred_dict = {} + gt_dict = {} + + t1 = time.perf_counter() + for labels, local_batch, vid_f in test_loader: + images = local_batch.permute(0, 2, 3, 4, 1).numpy() + + torch_labels = np.zeros(len(labels)) + torch_labels[[i for i in range(len(labels)) if labels[i] == 'Positives']] = 1 + torch_labels = np.expand_dims(torch_labels, axis=1) + + activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))])) + + pred = classifier_model(activations) + + loss = criterion(torch_labels, pred) + epoch_loss += loss * local_batch.size(0) + + for i, v_f in enumerate(vid_f): + if v_f not in pred_dict: + pred_dict[v_f] = tf.math.round(tf.math.sigmoid(pred[i])) + else: + pred_dict[v_f] = tf.concat((pred_dict[v_f], tf.math.round(tf.math.sigmoid(pred[i]))), axis=0) + + if v_f not in gt_dict: + gt_dict[v_f] = tf.constant(torch_labels[i]) + else: + gt_dict[v_f] = tf.concat((gt_dict[v_f], torch_labels[i]), axis=0) + + if y_true is None: + y_true = torch_labels + y_pred = tf.math.round(tf.math.sigmoid(pred)) + else: + y_true = tf.concat((y_true, torch_labels), axis=0) + y_pred = tf.concat((y_pred, tf.math.round(tf.math.sigmoid(pred))), axis=0) + + t2 = time.perf_counter() + + vid_acc = [] + for k in pred_dict.keys(): + print(k) + print(pred_dict[k]) + print(gt_dict[k]) + gt_mode = torch.mode(torch.tensor(gt_dict[k]))[0].item() + perm = torch.randperm(torch.tensor(pred_dict[k]).size(0)) + cutoff = int(torch.tensor(pred_dict[k]).size(0)/4) + if cutoff < 3: + cutoff = 3 + idx = perm[:cutoff] + samples = pred_dict[k][idx] + pred_mode = torch.mode(torch.tensor(samples))[0].item() + overall_true.append(gt_mode) + overall_pred.append(pred_mode) + if pred_mode == gt_mode: + vid_acc.append(1) + else: + vid_acc.append(0) + if pred_mode == 0: + fn_ids.append(k) + else: + fp_ids.append(k) + + vid_acc = np.array(vid_acc) + + print('----------------------------------------------------------------------------') + for k in pred_dict.keys(): + print(k) + print('Predictions:') + print(pred_dict[k]) + print('Ground Truth:') + print(gt_dict[k]) + print('Overall Prediction:') + # pred_mode = 1 + # contiguous_zeros = 0 + # best_num = 0 + # for val in pred_dict[k]: + # if val.item() == 0: + # contiguous_zeros += 1 + # else: + # if contiguous_zeros > best_num: + # best_num = contiguous_zeros + # contiguous_zeros = 0 + # if best_num >= 4 or contiguous_zeros >= 4: + # pred_mode = 0 + print(torch.mode(pred_dict[k])[0].item()) + print('----------------------------------------------------------------------------') + + print('fold={}, loss={:.2f}, time={:.2f}'.format(i_fold, loss, t2-t1)) + + y_true = tf.cast(y_true, tf.int32) + y_pred = tf.cast(y_pred, tf.int32) + + f1 = f1_score(y_true, y_pred, average='macro') + accuracy = accuracy_score(y_true, y_pred) + all_errors.append(np.sum(vid_acc) / len(vid_acc)) + + print("Test f1={:.2f}, clip_acc={:.2f}, vid_acc={:.2f} fold={}".format(f1, accuracy, np.sum(vid_acc) / len(vid_acc), i_fold)) + + print(confusion_matrix(y_true, y_pred)) + + i_fold = i_fold + 1 + + fp_fn_file = os.path.join(args.output_dir, 'fp_fn.txt') + with open(fp_fn_file, 'w+') as in_f: + in_f.write('FP:\n') + in_f.write(str(fp_ids) + '\n\n') + in_f.write('FN:\n') + in_f.write(str(fn_ids) + '\n\n') + + overall_true = np.array(overall_true) + overall_pred = np.array(overall_pred) + + final_f1 = f1_score(overall_true, overall_pred, average='macro') + final_acc = accuracy_score(overall_true, overall_pred) + 
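The per-video decision above samples roughly a quarter of a video's clip predictions (never fewer than three) and takes their mode as the video-level call; the ground-truth mode is computed the same way over labels that are constant per video. Restated standalone with toy predictions:

```python
import torch

clip_preds = torch.tensor([1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0])  # toy values

perm = torch.randperm(clip_preds.size(0))    # shuffle the clip order
cutoff = max(clip_preds.size(0) // 4, 3)     # sample >= 3 clips
video_pred = torch.mode(clip_preds[perm[:cutoff]])[0].item()
```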
final_conf = confusion_matrix(overall_true, overall_pred) + + print("Final accuracy={:.2f}, f1={:.2f}".format(final_acc, final_f1)) + print(final_conf) diff --git a/keras/video_loader.py b/keras/video_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..4e7188cd1d5aacdbae02be36acff80081f16d515 --- /dev/null +++ b/keras/video_loader.py @@ -0,0 +1,770 @@ +from os import listdir +from os.path import isfile +from os.path import join +from os.path import isdir +from os.path import abspath +from os.path import exists +import json +import glob + +from PIL import Image +from torchvision.transforms import ToTensor +from torchvision.datasets.video_utils import VideoClips +from tqdm import tqdm +import torch +import numpy as np +from torch.utils.data import Dataset +from torch.utils.data import DataLoader +from torchvision.io import read_video +import torchvision as tv +from torch import nn +import torchvision.transforms.functional as tv_f +import csv +import random +import cv2 + +# def get_augmented_examples(clip): +# augmented_examples = [] +# augmented_examples.append(tv_f.hflip(clip)) +# augmented_examples.append(tv_f.rotate(clip, 20)) +# augmented_examples.append(tv_f.rotate(clip, -20)) +# augmented_examples.append(tv_f.rotate(clip, 40)) +# augmented_examples.append(tv_f.rotate(clip, -40)) + +# return augmented_examples +def get_video_participants(): + video_to_participant = {} + with open('/shared_data/bamc_data/bamc_video_info.csv', 'r') as csv_in: + reader = csv.DictReader(csv_in) + for row in reader: + key = row['Filename'].split('.')[0].lower().replace('_clean', '') + if key == '37 (mislabeled as 38)': + key = '37' + video_to_participant[key] = row['Participant_id'] + + return video_to_participant + +class MinMaxScaler(object): + """ + Transforms each channel to the range [0, 1]. 
+ """ + def __init__(self, min_val=0, max_val=254): + self.min_val = min_val + self.max_val = max_val + + def __call__(self, tensor): + return (tensor - self.min_val) / (self.max_val - self.min_val) + +class VideoGrayScaler(nn.Module): + + def __init__(self): + super().__init__() + self.grayscale = tv.transforms.Grayscale(num_output_channels=1) + + def forward(self, video): + # shape = channels, time, width, height + video = self.grayscale(video.swapaxes(-4, -3).swapaxes(-2, -1)) + video = video.swapaxes(-4, -3).swapaxes(-2, -1) + # print(video.shape) + return video + +class VideoLoader(Dataset): + + def __init__(self, video_path, transform=None, num_frames=None): + self.num_frames = num_frames + self.transform = transform + + self.labels = [name for name in listdir(video_path) if isdir(join(video_path, name))] + self.videos = [] + for label in self.labels: + self.videos.extend([(label, abspath(join(video_path, label, f)), f) for f in listdir(join(video_path, label)) if isfile(join(video_path, label, f))]) + + self.cache = {} + + def get_labels(self): + return [self.videos[i][0] for i in range(len(self.videos))] + + def get_filenames(self): + return [self.videos[i][2] for i in range(len(self.videos))] + + def __getitem__(self, index): + #print('index: {}'.format(index)) + + if index in self.cache: + return self.cache[index] + + label = self.videos[index][0] + filename = self.videos[index][2] + video, _, info = read_video(self.videos[index][1]) + # print(info) + video = torch.swapaxes(video, 1, 3) + + # print('length', len(video)) + if self.num_frames: + video = video[-self.num_frames:] + + if len(video) < self.num_frames: + padding = torch.zeros(self.num_frames - len(video), video.shape[1], video.shape[2], video.shape[3]) + video = torch.cat((video, padding)) + + video = video.swapaxes(0, 1).swapaxes(2, 3) + + if self.transform: + video = self.transform(video) + + self.cache[index] = (label, video, filename) + + return label, video, filename + + def __len__(self): + return len(self.videos) + +class VideoClipLoader(Dataset): + + def __init__(self, video_path, num_frames=20, frame_rate=20, frames_between_clips=None, transform=None): + self.transform = transform + self.labels = [name for name in listdir(video_path) if isdir(join(video_path, name))] + + self.videos = [] + for label in self.labels: + self.videos.extend([(label, abspath(join(video_path, label, f)), f) for f in listdir(join(video_path, label)) if isfile(join(video_path, label, f))]) + + #for v in self.videos: + # video, _, info = read_video(v[1]) + # print(video.shape) + # print(info) + + if not frames_between_clips: + frames_between_clips = num_frames + + + + vc = VideoClips([path for _, path, _ in self.videos], + clip_length_in_frames=num_frames, + frame_rate=frame_rate, + frames_between_clips=frames_between_clips) + self.clips = [] + + self.video_idx = [] + for i in tqdm(range(vc.num_clips())): + try: + clip, _, _, vid_idx = vc.get_clip(i) + clip = clip.swapaxes(1, 3).swapaxes(2, 3) + clip = clip.swapaxes(0, 1) + if self.transform: + clip = self.transform(clip) + self.clips.append((self.videos[vid_idx][0], clip, self.videos[vid_idx][2])) + self.video_idx.append(vid_idx) + except Exception: + pass + + def get_filenames(self): + return [self.clips[i][2] for i in range(len(self.clips))] + + def get_video_labels(self): + return [self.videos[i][0] for i in range(len(self.videos))] + + def get_labels(self): + return [self.clips[i][0] for i in range(len(self.clips))] + + def __getitem__(self, index): + return self.clips[index] + + 
def __len__(self): + return len(self.clips) + +class PNBLoader(Dataset): + + def __init__(self, video_path, num_frames=5, frame_rate=20, frames_between_clips=None, transform=None): + self.transform = transform + self.labels = [name for name in listdir(video_path) if isdir(join(video_path, name))] + + self.videos = [] + for label in self.labels: + self.videos.extend([(label, abspath(join(video_path, label, f)), f) for f in glob.glob(join(video_path, label, '*', '*.mp4'))]) + + #for v in self.videos: + # video, _, info = read_video(v[1]) + # print(video.shape) + # print(info) + + if not frames_between_clips: + frames_between_clips = num_frames + + self.clips = [] + + self.video_idx = [] + + vid_idx = 0 + for _, path, _ in self.videos: + vc = tv.io.read_video(path)[0].permute(3, 0, 1, 2) +# for j in range(vc.size(1), vc.size(1) - 10, -5): + for j in range(0, vc.size(1) - 5, 5): +# if j-5 < 0: +# continue +# vc_sub = vc_1 = vc[:, j-5:j, :, :] + vc_sub = vc[:, j:j+5, :, :] + if self.transform: + vc_sub = self.transform(vc_sub) + + self.clips.append((self.videos[vid_idx][0], vc_sub, self.videos[vid_idx][2])) + self.video_idx.append(vid_idx) + vid_idx += 1 + + def get_filenames(self): + return [self.clips[i][2] for i in range(len(self.clips))] + + def get_video_labels(self): + return [self.videos[i][0] for i in range(len(self.videos))] + + def get_labels(self): + return [self.clips[i][0] for i in range(len(self.clips))] + + def __getitem__(self, index): + return self.clips[index] + + def __len__(self): + return len(self.clips) + +class VideoFrameLoader(Dataset): + + def __init__(self, video_path, transform=None): + self.transform = transform + + self.labels = [name for name in listdir(video_path) if isdir(join(video_path, name))] + self.videos = [] + for label in self.labels: + self.videos.extend([(label, abspath(join(video_path, label, f)), f) for f in listdir(join(video_path, label)) if isfile(join(video_path, label, f))]) + + self.frame_labels = [] + self.frames = [] + for label, path, filename in self.videos: + video, _, info = read_video(path) + video = torch.swapaxes(video, 1, 3) + video = video.swapaxes(0, 1).swapaxes(2, 3) + + if self.transform: + video = self.transform(video) + + for frame in range(video.shape[1]): + self.frames.append((label, video[:, frame, :, :].unsqueeze(0), filename)) + + def get_labels(self): + return [self.frames[i][0] for i in range(len(self.videos))] + + def get_filenames(self): + return [self.frames[i][2] for i in range(len(self.videos))] + + def __getitem__(self, index): + return self.frames[index] + + def __len__(self): + return len(self.frames) + +class YoloClipLoader(Dataset): + + def __init__(self, yolo_output_path, num_frames=5, frames_between_clips=None, + transform=None, augment_transform=None, num_clips=1, num_positives=1, positive_videos=None, sparse_model=None, device=None): + if (num_frames % 2) == 0: + raise ValueError("Num Frames must be an odd number, so we can extract a clip centered on each detected region") + + clip_cache_file = 'clip_cache.pt' + + self.num_clips = num_clips + + self.num_frames = num_frames + if frames_between_clips is None: + self.frames_between_clips = num_frames + else: + self.frames_between_clips = frames_between_clips + + self.transform = transform + self.augment_transform = augment_transform + + self.labels = [name for name in listdir(yolo_output_path) if isdir(join(yolo_output_path, name))] + self.clips = [] + if exists(clip_cache_file): + self.clips = torch.load(open(clip_cache_file, 'rb')) + else: + for label in 
self.labels: + print("Processing videos in category: {}".format(label)) + videos = list(listdir(join(yolo_output_path, label))) + for vi in tqdm(range(len(videos))): + video = videos[vi] + counter = 0 + all_trimmed = [] + with open(abspath(join(yolo_output_path, label, video, 'result.json'))) as fin: + results = json.load(fin) + max_frame = len(results) + + for i in range((num_frames-1)//2, max_frame - (num_frames-1)//2 - 1, self.frames_between_clips): + # for frame in results: + frame = results[i] + # print('loading frame:', i, frame['frame_id']) + frame_start = int(frame['frame_id']) - self.num_frames//2 + frames = [abspath(join(yolo_output_path, label, video, 'frame{}.png'.format(frame_start+fid))) + for fid in range(num_frames)] + # print(frames) + frames = torch.stack([ToTensor()(Image.open(f).convert("RGB")) for f in frames]).swapaxes(0, 1) + + for region in frame['objects']: + # print(region) + if region['name'] != "Pleural_Line": + continue + + center_x = region['relative_coordinates']["center_x"] * 1920 + center_y = region['relative_coordinates']['center_y'] * 1080 + + # width = region['relative_coordinates']['width'] * 1920 + # height = region['relative_coordinates']['height'] * 1080 + width=200 + height=100 + + lower_y = round(center_y - height / 2) + upper_y = round(center_y + height / 2) + lower_x = round(center_x - width / 2) + upper_x = round(center_x + width / 2) + + final_clip = frames[:, :, lower_y:upper_y, lower_x:upper_x] + + if self.transform: + final_clip = self.transform(final_clip) + + if sparse_model: + with torch.no_grad(): + final_clip = final_clip.unsqueeze(0).to(device) + final_clip = sparse_model(final_clip) + final_clip = final_clip.squeeze(0).detach().cpu() + + self.clips.append((label, final_clip, video)) + + torch.save(self.clips, open(clip_cache_file, 'wb+')) + + +# random.shuffle(self.clips) + +# video_to_clips = {} + if positive_videos: + vids_to_keep = json.load(open(positive_videos)) + + self.clips = [clip_tup for clip_tup in self.clips if clip_tup[2] in vids_to_keep or clip_tup[0] == 'Sliding'] + else: + video_to_labels = {} + + for lbl, clip, video in self.clips: + video = video.lower().replace('_clean', '') + if video not in video_to_labels: + # video_to_clips[video] = [] + video_to_labels[video] = [] + + # video_to_clips[video].append(clip) + video_to_labels[video].append(lbl) + + video_to_participants = get_video_participants() + participants_to_video = {} + for k, v in video_to_participants.items(): + if video_to_labels[k][0] == 'Sliding': + continue + if not v in participants_to_video: + participants_to_video[v] = [] + + participants_to_video[v].append(k) + + participants_to_video = dict(sorted(participants_to_video.items(), key=lambda x: len(x[1]), reverse=True)) + + num_to_remove = len([k for k,v in video_to_labels.items() if v[0] == 'No_Sliding']) - num_positives + vids_to_remove = set() + while num_to_remove > 0: + vids_to_remove.add(participants_to_video[list(participants_to_video.keys())[0]].pop()) + participants_to_video = dict(sorted(participants_to_video.items(), key=lambda x: len(x[1]), reverse=True)) + num_to_remove -= 1 + + self.clips = [clip_tup for clip_tup in self.clips if clip_tup[2].lower().replace('_clean', '') not in vids_to_remove] + + video_to_clips = {} + video_to_labels = {} + + for lbl, clip, video in self.clips: + if video not in video_to_clips: + video_to_clips[video] = [] + video_to_labels[video] = [] + + video_to_clips[video].append(clip) + video_to_labels[video].append(lbl) + + print([k for k,v in 
video_to_labels.items() if v[0] == 'No_Sliding']) + + print('Num positive:', len([k for k,v in video_to_labels.items() if v[0] == 'No_Sliding'])) + print('Num negative:', len([k for k,v in video_to_labels.items() if v[0] == 'Sliding'])) + + self.videos = None + self.max_video_clips = 0 + if num_clips > 1: + self.videos = [] + + for video in video_to_clips.keys(): + clip_list = video_to_clips[video] + lbl_list = video_to_labels[video] + + for i in range(0, len(clip_list) - num_clips, 1): + video_stack = torch.stack(clip_list[i:i+num_clips]) + + self.videos.append((max(set(lbl_list[i:i+num_clips]), key=lbl_list[i:i+num_clips].count), video_stack, video)) + + self.clips = None + + + def get_labels(self): + if self.num_clips > 1: + return [self.videos[i][0] for i in range(len(self.videos))] + else: + return [self.clips[i][0] for i in range(len(self.clips))] + + def get_filenames(self): + if self.num_clips > 1: + return [self.videos[i][2] for i in range(len(self.videos))] + else: + return [self.clips[i][2] for i in range(len(self.clips))] + + def __getitem__(self, index): + if self.num_clips > 1: + label = self.videos[index][0] + video = self.videos[index][1] + filename = self.videos[index][2] + + video = video.squeeze(2) + video = video.permute(1, 0, 2, 3) + + if self.augment_transform: + video = self.augment_transform(video) + + video = video.unsqueeze(2) + video = video.permute(1, 0, 2, 3, 4) +# video = video.permute(4, 1, 2, 3, 0) +# video = torch.nn.functional.pad(video, (0), 'constant', 0) +# video = video.permute(4, 1, 2, 3, 0) + + orig_len = video.size(0) + +# if orig_len < self.max_video_clips: +# video = torch.cat([video, torch.zeros(self.max_video_clips - len(video), video.size(1), video.size(2), video.size(3), video.size(4))]) + + return label, video, filename, orig_len + else: + label = self.clips[index][0] + video = self.clips[index][1] + filename = self.clips[index][2] + + if self.augment_transform: + video = self.augment_transform(video) + + return label, video, filename + + def __len__(self): + return len(self.clips) + +class YoloVideoLoader(Dataset): + + def __init__(self, yolo_output_path, num_frames=5, frames_between_clips=None, + transform=None, augment_transform=None, num_clips=1, num_positives=1, sparse_model=None, device=None): + if (num_frames % 2) == 0: + raise ValueError("Num Frames must be an odd number, so we can extract a clip centered on each detected region") + + clip_cache_file = 'video_cache.pt' + + self.num_clips = num_clips + + self.num_frames = num_frames + if frames_between_clips is None: + self.frames_between_clips = num_frames + else: + self.frames_between_clips = frames_between_clips + + self.transform = transform + self.augment_transform = augment_transform + + self.labels = [name for name in listdir(yolo_output_path) if isdir(join(yolo_output_path, name))] + self.clips = [] + if exists(clip_cache_file) and sparse_model: + self.clips = torch.load(open(clip_cache_file, 'rb')) + else: + for label in self.labels: + print("Processing videos in category: {}".format(label)) + videos = list(listdir(join(yolo_output_path, label))) + for vi in tqdm(range(len(videos))): + video = videos[vi] + with open(abspath(join(yolo_output_path, label, video, 'result.json'))) as fin: + results = json.load(fin) + max_frame = len(results) + + regions = None + + for i in range((num_frames-1)//2, max_frame - (num_frames-1)//2 - 1, self.frames_between_clips): + # for frame in results: + frame = results[i] + # print('loading frame:', i, frame['frame_id']) + frame_start = 
int(frame['frame_id']) - self.num_frames//2 + frames = [abspath(join(yolo_output_path, label, video, 'frame{}.png'.format(frame_start+fid))) + for fid in range(num_frames)] + # print(frames) + frames = torch.stack([ToTensor()(Image.open(f).convert("RGB")) for f in frames]).swapaxes(0, 1) + + if regions is None: + regions = frame['objects'] + + for region_idx, region in enumerate(regions): + # print(region) + if region['name'] != "Pleural_Line": + continue + + center_x = region['relative_coordinates']["center_x"] * 1920 + center_y = region['relative_coordinates']['center_y'] * 1080 + + # width = region['relative_coordinates']['width'] * 1920 + # height = region['relative_coordinates']['height'] * 1080 + width=400 + height=400 + + lower_y = round(center_y - height / 2) + upper_y = round(center_y + height / 2) + lower_x = round(center_x - width / 2) + upper_x = round(center_x + width / 2) + + final_clip = frames[:, :, lower_y:upper_y, lower_x:upper_x] + + if self.transform: + final_clip = self.transform(final_clip) + + if sparse_model: + with torch.no_grad(): + final_clip = final_clip.unsqueeze(0).to(device) + final_clip = sparse_model(final_clip) + final_clip = final_clip.squeeze(0).detach().cpu() + + self.clips.append((label, final_clip, video, region_idx)) + + torch.save(self.clips, open(clip_cache_file, 'wb+')) + + +# random.shuffle(self.clips) + +# video_to_clips = {} + video_to_labels = {} + + for lbl, clip, video, region_idx in self.clips: + video = video.lower().replace('_clean', '') + if video not in video_to_labels: +# video_to_clips[video] = [] + video_to_labels[video] = [] + +# video_to_clips[video].append(clip) + video_to_labels[video].append(lbl) + + video_to_participants = get_video_participants() + participants_to_video = {} + for k, v in video_to_participants.items(): + if video_to_labels[k][0] == 'Sliding': + continue + if not v in participants_to_video: + participants_to_video[v] = [] + + participants_to_video[v].append(k) + + participants_to_video = dict(sorted(participants_to_video.items(), key=lambda x: len(x[1]), reverse=True)) + + num_to_remove = len([k for k,v in video_to_labels.items() if v[0] == 'No_Sliding']) - num_positives + vids_to_remove = set() + while num_to_remove > 0: + vids_to_remove.add(participants_to_video[list(participants_to_video.keys())[0]].pop()) + participants_to_video = dict(sorted(participants_to_video.items(), key=lambda x: len(x[1]), reverse=True)) + num_to_remove -= 1 + + self.clips = [clip_tup for clip_tup in self.clips if clip_tup[2].lower().replace('_clean', '') not in vids_to_remove] + + video_to_clips = {} + video_to_labels = {} + + for lbl, clip, video, region_idx in self.clips: + if video not in video_to_clips: + video_to_clips[video] = [] + video_to_labels[video] = [] + + video_to_clips[video].append((clip, region_idx)) + video_to_labels[video].append(lbl) + + print([k for k,v in video_to_labels.items() if v[0] == 'No_Sliding']) + + print('Num positive:', len([k for k,v in video_to_labels.items() if v[0] == 'No_Sliding'])) + print('Num negative:', len([k for k,v in video_to_labels.items() if v[0] == 'Sliding'])) + + self.videos = None + self.max_video_clips = 0 + self.videos = [] + + for video in video_to_clips.keys(): + clip_list = video_to_clips[video] + lbl_list = video_to_labels[video] + + region_sets = {} + + for clip in clip_list: + if clip[1] not in region_sets: + region_sets[clip[1]] = [] + + region_sets[clip[1]].append(clip[0]) + + for k,v in region_sets.items(): + region_sets[k] = torch.stack(v) + + video_stack = 
torch.stack(list(region_sets.values())) + + self.videos.append((lbl_list[0], video_stack, video)) + + self.clips = None + + + def get_labels(self): + return [self.videos[i][0] for i in range(len(self.videos))] + + def get_filenames(self): + return [self.videos[i][2] for i in range(len(self.videos))] + + def __getitem__(self, index): + label = self.videos[index][0] + video = self.videos[index][1] + filename = self.videos[index][2] + + if video.size(0) < 3: + video = torch.cat([video, torch.zeros(3 - video.size(0), video.size(1), video.size(2), video.size(3), video.size(4), video.size(5))]) + + video = video[:3, :, :, :, :, :] + + video = video.permute(1, 0, 2, 3, 4, 5) + + if video.size(0) < 10: + video = torch.cat([video, torch.zeros(12 - video.size(0), video.size(1), video.size(2), video.size(3), video.size(4), video.size(5))]) + + video = video[:10, :, :, :, :, :] + + video = video.permute(1, 0, 2, 3, 4, 5) + + video = video.squeeze(3) + + video = video.reshape(30, video.size(2), video.size(3), video.size(4)) + + if self.augment_transform: + video = self.augment_transform(video) + + video = video.view(3, 10, video.size(1), video.size(2), video.size(3)) + video = video.unsqueeze(3) +# video = video.permute(1, 0, 2, 3, 4) +# video = video.permute(4, 1, 2, 3, 0) +# video = torch.nn.functional.pad(video, (0), 'constant', 0) +# video = video.permute(4, 1, 2, 3, 0) + + orig_len = video.size(0) + +# if orig_len < self.max_video_clips: +# video = torch.cat([video, torch.zeros(self.max_video_clips - len(video), video.size(1), video.size(2), video.size(3), video.size(4))]) + + return label, video, filename, orig_len + + def __len__(self): + return len(self.videos) + +class MobileLoader(Dataset): + + def __init__(self, yolo_output_path, transform=None, augment_transform=None, num_positives=1): + self.transform = transform + self.augment_transform = augment_transform + + self.labels = [name for name in listdir(yolo_output_path) if isdir(join(yolo_output_path, name))] + self.clips = [] + for label in self.labels: + print("Processing videos in category: {}".format(label)) + videos = list(listdir(join(yolo_output_path, label))) + for vi in tqdm(range(len(videos))): + video = videos[vi] + vid_tensor, _, _ = tv.io.read_video(join(yolo_output_path, label, video)) + vid_tensor = vid_tensor.permute(0, 3, 1, 2) + + if self.transform: + vid_tensor = self.transform(vid_tensor) + + self.clips.append((label, vid_tensor, video[:video.rfind('_')])) + + random.shuffle(self.clips) + + video_to_clips = {} + video_to_labels = {} + + for lbl, clip, video in self.clips: + if video not in video_to_clips: + video_to_clips[video] = [] + video_to_labels[video] = [] + + video_to_clips[video].append(clip) + video_to_labels[video].append(lbl) + + pos_count = 0 + vids_to_remove = [] + for vid, lbl_list in video_to_labels.items(): + if lbl_list[0] == 'No_Sliding': + pos_count += 1 + if pos_count > num_positives: + vids_to_remove.append(vid) + + self.clips = [clip_tup for clip_tup in self.clips if clip_tup[2] not in vids_to_remove] + + video_to_clips = {} + video_to_labels = {} + + for lbl, clip, video in self.clips: + if video not in video_to_clips: + video_to_clips[video] = [] + video_to_labels[video] = [] + + video_to_clips[video].append(clip) + video_to_labels[video].append(lbl) + + print('Num positive:', len([k for k,v in video_to_labels.items() if v[0] == 'No_Sliding'])) + print('Num negative:', len([k for k,v in video_to_labels.items() if v[0] == 'Sliding'])) + + + def get_labels(self): + return [self.clips[i][0] for i in 
range(len(self.clips))] + + def get_filenames(self): + return [self.clips[i][2] for i in range(len(self.clips))] + + def __getitem__(self, index): + label = self.clips[index][0] + video = self.clips[index][1] + filename = self.clips[index][2] + + if self.augment_transform: + video = self.augment_transform(video) + + return label, video, filename, 0 + + def __len__(self): + return len(self.clips) + +if __name__ == "__main__": + video_path = "/shared_data/bamc_data/" + + transforms = tv.transforms.Compose([VideoGrayScaler()]) + + # dataset = VideoLoader(video_path, transform=transforms, num_frames=60) + dataset = VideoClipLoader(video_path, transform=transforms, num_frames=20) + #for data in dataset: + # print(data[0], data[1].shape) + + loader = DataLoader( + dataset, + batch_size=2, + shuffle=True) + + for data in loader: + print(data[0], data[1].shape, data[2]) + #print(data) \ No newline at end of file diff --git a/notebooks/YOLO-to-CoreML.ipynb b/notebooks/YOLO-to-CoreML.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..dfac3ce88eacc6d66e5d380c8ad52a89c82fe7ad --- /dev/null +++ b/notebooks/YOLO-to-CoreML.ipynb @@ -0,0 +1,776 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "83fa2e70", + "metadata": {}, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "from absl import app, flags, logging\n", + "from absl.flags import FLAGS\n", + "import yolov4.core.utils as utils\n", + "from yolov4.core.yolov4 import filter_boxes\n", + "from tensorflow.python.saved_model import tag_constants\n", + "from PIL import Image\n", + "import cv2\n", + "import numpy as np\n", + "from tensorflow.compat.v1 import ConfigProto\n", + "from tensorflow.compat.v1 import InteractiveSession\n", + "import coremltools as ct \n", + "import time\n", + "import os\n", + "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\" # see issue #152\n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"3\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "badae98a", + "metadata": {}, + "outputs": [], + "source": [ + "flags.DEFINE_string('framework', 'tf', '(tf, tflite, trt')\n", + "flags.DEFINE_string('weights', 'yolov4/Pleural_Line_TensorFlow',\n", + " 'path to weights file')\n", + "flags.DEFINE_integer('size', 416, 'resize images to')\n", + "flags.DEFINE_boolean('tiny', False, 'yolo or yolo-tiny')\n", + "flags.DEFINE_string('model', 'yolov4', 'yolov3 or yolov4')\n", + "flags.DEFINE_string('image', '/shared_data/YOLO_Updated_PL_Model_Results/Sliding/image_677741729740_clean/frame0.png', 'path to input image')\n", + "flags.DEFINE_string('output', 'result.png', 'path to output image')\n", + "flags.DEFINE_float('iou', 0.45, 'iou threshold')\n", + "flags.DEFINE_float('score', 0.25, 'score threshold')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2a6b59b5", + "metadata": {}, + "outputs": [], + "source": [ + "FLAGS(['yolov4/detect.py'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6eb766a7", + "metadata": {}, + "outputs": [], + "source": [ + "config = ConfigProto()\n", + "config.gpu_options.allow_growth = True\n", + "session = InteractiveSession(config=config)\n", + "STRIDES, ANCHORS, NUM_CLASS, XYSCALE = utils.load_config(FLAGS)\n", + "input_size = FLAGS.size\n", + "image_path = FLAGS.image\n", + "\n", + "# original_image = cv2.imread(\"/shared_data/YOLO_Updated_PL_Model_Results/Sliding/image_677741729740_clean/frame0.png\")\n", + "original_image = 
cv2.imread(\"/shared_data/YOLO_Updated_PL_Model_Results/No_Sliding/Image_262499828648_clean/frame0.png\")\n", + "original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)\n", + "\n", + "# image_data = utils.image_preprocess(np.copy(original_image), [input_size, input_size])\n", + "image_data = cv2.resize(original_image, (input_size, input_size))\n", + "image_data = image_data / 255.\n", + "# image_data = image_data[np.newaxis, ...].astype(np.float32)\n", + "\n", + "images_data = []\n", + "for i in range(1):\n", + " images_data.append(image_data)\n", + "images_data = np.asarray(images_data).astype(np.float32)\n", + "\n", + "if FLAGS.framework == 'tflite':\n", + " interpreter = tf.lite.Interpreter(model_path=FLAGS.weights)\n", + " interpreter.allocate_tensors()\n", + " input_details = interpreter.get_input_details()\n", + " output_details = interpreter.get_output_details()\n", + " print(input_details)\n", + " print(output_details)\n", + " interpreter.set_tensor(input_details[0]['index'], images_data)\n", + " interpreter.invoke()\n", + " pred = [interpreter.get_tensor(output_details[i]['index']) for i in range(len(output_details))]\n", + " if FLAGS.model == 'yolov3' and FLAGS.tiny == True:\n", + " boxes, pred_conf = filter_boxes(pred[1], pred[0], score_threshold=0.25, input_shape=tf.constant([input_size, input_size]))\n", + " else:\n", + " boxes, pred_conf = filter_boxes(pred[0], pred[1], score_threshold=0.25, input_shape=tf.constant([input_size, input_size]))\n", + "else:\n", + " saved_model_loaded = tf.saved_model.load(FLAGS.weights, tags=[tag_constants.SERVING])\n", + " infer = saved_model_loaded.signatures['serving_default']\n", + " \n", + " start_time = time.time()\n", + " \n", + " print(image_data.shape)\n", + " print(image_data.min())\n", + " print(image_data.max())\n", + " batch_data = tf.constant(images_data)\n", + " \n", + " pred_bbox = infer(batch_data)\n", + " print(pred_bbox)\n", + " print(pred_bbox['tf_op_layer_concat_18'].shape)\n", + " for key, value in pred_bbox.items():\n", + " boxes = value[:, :, 0:4]\n", + " pred_conf = value[:, :, 4:]\n", + " print(\"VALUE\", value)\n", + " \n", + " print('boxes')\n", + " print(tf.reshape(boxes, (tf.shape(boxes)[0], -1, 1, 4)))\n", + "\n", + " print('conf')\n", + " print(tf.reshape(\n", + " pred_conf, (tf.shape(pred_conf)[0], -1, tf.shape(pred_conf)[-1])))\n", + "\n", + "boxes, scores, classes, valid_detections = tf.image.combined_non_max_suppression(\n", + " boxes=tf.reshape(boxes, (tf.shape(boxes)[0], -1, 1, 4)),\n", + " scores=tf.reshape(\n", + " pred_conf, (tf.shape(pred_conf)[0], -1, tf.shape(pred_conf)[-1])),\n", + " max_output_size_per_class=50,\n", + " max_total_size=50,\n", + " iou_threshold=0.5,\n", + " score_threshold=0.25\n", + ")\n", + "end_time = time.time()\n", + "\n", + "pred_bbox = [boxes.numpy(), scores.numpy(), classes.numpy(), valid_detections.numpy()]\n", + "print(pred_bbox)\n", + "print(end_time-start_time)\n", + "# image = utils.draw_bbox(original_image, pred_bbox)\n", + "# # image = utils.draw_bbox(image_data*255, pred_bbox)\n", + "# image = Image.fromarray(image.astype(np.uint8))\n", + "# image.show()\n", + "# image = cv2.cvtColor(np.array(image), cv2.COLOR_BGR2RGB)\n", + "# cv2.imwrite(FLAGS.output, image)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6e0378a9", + "metadata": {}, + "outputs": [], + "source": [ + "# coreml_model = ct.convert(saved_model_loaded, source=\"tensorflow\")\n", + "coreml_model = ct.convert(\"yolov4/Pleural_Line_TensorFlow/\", source=\"tensorflow\",\n", + 
" inputs=[ct.ImageType(name=\"input_1\", shape=(1, 416, 416, 3), scale=1/255)],\n", + " outputs=[ct.TensorType(shape=(1, 18, 5))])\n", + "print(coreml_model)" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "id": "4a5e0b6d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: \"model_9\"\n", + "__________________________________________________________________________________________________\n", + " Layer (type) Output Shape Param # Connected to \n", + "==================================================================================================\n", + " image (InputLayer) [(None, 416, 416, 3 0 [] \n", + " )] \n", + " \n", + " functional_1 (Functional) (None, None, 5) 64003990 ['image[0][0]'] \n", + " \n", + " max_pooling2d_9 (MaxPooling2D) (None, 6, 6, 3) 0 ['image[0][0]'] \n", + " \n", + " tf.__operators__.getitem_20 (S (None, None, 4) 0 ['functional_1[0][0]'] \n", + " licingOpLambda) \n", + " \n", + " flatten_9 (Flatten) (None, 108) 0 ['max_pooling2d_9[0][0]'] \n", + " \n", + " tf.__operators__.getitem_21 (S (None, None, 1) 0 ['functional_1[0][0]'] \n", + " licingOpLambda) \n", + " \n", + " tf.reshape_20 (TFOpLambda) (None, 4) 0 ['tf.__operators__.getitem_20[0][\n", + " 0]'] \n", + " \n", + " dense_14 (Dense) (None, 4) 436 ['flatten_9[0][0]'] \n", + " \n", + " tf.reshape_21 (TFOpLambda) (None, 1) 0 ['tf.__operators__.getitem_21[0][\n", + " 0]'] \n", + " \n", + " dense_13 (Dense) (None, 1) 109 ['flatten_9[0][0]'] \n", + " \n", + " tf.concat_12 (TFOpLambda) (None, 4) 0 ['tf.reshape_20[0][0]', \n", + " 'dense_14[0][0]'] \n", + " \n", + " tf.concat_13 (TFOpLambda) (None, 1) 0 ['tf.reshape_21[0][0]', \n", + " 'dense_13[0][0]'] \n", + " \n", + "==================================================================================================\n", + "Total params: 64,004,535\n", + "Trainable params: 63,938,231\n", + "Non-trainable params: 66,304\n", + "__________________________________________________________________________________________________\n" + ] + } + ], + "source": [ + "import tensorflow as tf\n", + "import coremltools as ct\n", + "keras_model = tf.keras.models.load_model(\"yolov4/Pleural_Line_TensorFlow/\", compile=False)\n", + "input_image = tf.keras.layers.Input(shape=(416, 416, 3), name=\"image\")\n", + "pred_bbox = keras_model(input_image)\n", + "x = tf.keras.layers.MaxPooling2D(pool_size=(64, 64), strides=64)(input_image)\n", + "x = tf.keras.layers.Flatten()(x)\n", + "x1 = tf.keras.layers.Dense(1)(x)\n", + "x2 = tf.keras.layers.Dense(4)(x)\n", + "\n", + "boxes = tf.reshape(pred_bbox[:, :, 0:4], (-1, 4))\n", + "scores = tf.reshape(pred_bbox[:, :, 4:], (-1, 1))\n", + "\n", + "coordinates = tf.concat([boxes, x2], axis=0)\n", + "confidence = tf.concat([scores, x1], axis=0)\n", + "\n", + "full_model = tf.keras.Model(inputs=[input_image], outputs=[coordinates, confidence])\n", + "# full_model.save('keras_model.tf')\n", + "full_model.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "176c17f1", + "metadata": {}, + "outputs": [], + "source": [ + "keras_model.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "id": "d8886f5d", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2022-02-25 18:00:05.628994: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:936] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n", + "2022-02-25 
18:00:05.904811: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:936] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n", + "[... the same NUMA-node warning is repeated dozens of times; duplicates elided ...]\n", + "2022-02-25 18:00:05.910679: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 4\n", + "2022-02-25 18:00:05.910860: I tensorflow/core/grappler/clusters/single_machine.cc:358] Starting new session\n", + "2022-02-25 18:00:05.935244: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 43667 MB memory: -> device: 0, name: NVIDIA A40, pci bus id: 0000:01:00.0, compute capability: 8.6\n", + "[... devices GPU:1 through GPU:3 (NVIDIA A40, 43667 MB each) created identically; the NUMA warnings and device setup repeat for each of the four grappler sessions ...]\n", + "2022-02-25 18:00:05.979593: I tensorflow/core/grappler/optimizers/meta_optimizer.cc:1164] Optimization results for grappler item: graph_to_optimize\n", + " function_optimizer: function_optimizer did nothing. time = 0.026ms.\n", + " function_optimizer: function_optimizer did nothing. time = 0.001ms.\n", + "\n", + "2022-02-25 18:00:10.192380: I tensorflow/core/grappler/optimizers/meta_optimizer.cc:1164] Optimization results for grappler item: graph_to_optimize\n", + " constant_folding: Graph size after: 1374 nodes (-545), 2070 edges (-543), time = 269.379ms.\n", + " dependency_optimizer: Graph size after: 1369 nodes (-5), 1519 edges (-551), time = 93.216ms.\n", + " debug_stripper: debug_stripper did nothing. time = 0.134ms.\n", + " constant_folding: Graph size after: 1369 nodes (0), 1519 edges (0), time = 153.604ms.\n", + " dependency_optimizer: Graph size after: 1369 nodes (0), 1519 edges (0), time = 35.265ms.\n", + " debug_stripper: debug_stripper did nothing. time = 0.137ms.\n", + "\n",
 + "Running TensorFlow Graph Passes: 100%|██████████| 5/5 [00:08<00:00, 1.64s/ passes]\n", + "Converting Frontend ==> MIL Ops: 100%|██████████| 1383/1383 [00:04<00:00, 293.82 ops/s]\n", + "Running MIL Common passes: 100%|██████████| 34/34 [00:06<00:00, 5.63 passes/s]\n", + "Running MIL Clean up passes: 100%|██████████| 9/9 [00:00<00:00, 14.32 passes/s]\n", + "Translating MIL ==> NeuralNetwork Ops: 100%|██████████| 1574/1574 [00:08<00:00, 179.57 ops/s]\n" + ] + } + ], + "source": [ + "model = ct.convert(\n", + " full_model,\n", + " inputs=[ct.ImageType(scale=1/255.0)]\n", + ")\n", + "\n", + "output_sizes = [4, 1]\n", + "for i in range(2):\n", + " ma_type = model._spec.description.output[i].type.multiArrayType\n", + 
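" # Mark this output as flexible: dim 0 (box count) is unbounded (upperBound -1),\n", + " # dim 1 is pinned to output_sizes[i] (4 box coordinates or 1 confidence\n", + " # score); the old static shape is deleted at the end of the loop.\n", + 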
ma_type.shapeRange.sizeRanges.add()\n", + " ma_type.shapeRange.sizeRanges[0].lowerBound = 0\n", + " ma_type.shapeRange.sizeRanges[0].upperBound = -1\n", + " ma_type.shapeRange.sizeRanges.add()\n", + " ma_type.shapeRange.sizeRanges[1].lowerBound = output_sizes[i]\n", + " ma_type.shapeRange.sizeRanges[1].upperBound = output_sizes[i]\n", + " del ma_type.shape[:]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "63b824f7", + "metadata": {}, + "outputs": [], + "source": [ + "# ct.utils.rename_feature(model._spec, 'Identity', 'coordinates')\n", + "# ct.utils.rename_feature(model._spec, 'Identity_1', 'confidence')\n", + "\n", + "# # Add descriptions to the inputs and outputs.\n", + "# model._spec.description.output[1].shortDescription = u\"Boxes \\xd7 Class confidence\"\n", + "# model._spec.description.output[0].shortDescription = u\"Boxes \\xd7 [x, y, width, height] (relative to image size)\"\n", + "\n", + "# # Add metadata to the model.\n", + "# pipeline.spec.description.metadata.versionString = \"YOLO v4 - Pleural Line Detector - 11/18/21\"\n", + "# pipeline.spec.description.metadata.shortDescription = \"YOLOv4 - Pleural Line Detector\"\n", + "# pipeline.spec.description.metadata.author = \"Converted to Core ML by Chris MacLellan\"\n", + "# pipeline.spec.description.metadata.license = \"None\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2ded46f7", + "metadata": {}, + "outputs": [], + "source": [ + "model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "62600f9c", + "metadata": {}, + "outputs": [], + "source": [ + "# model.save(\"YoloPleuralLine.mlmodel\")" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "b580fdab", + "metadata": {}, + "outputs": [], + "source": [ + "nms_spec = ct.proto.Model_pb2.Model()\n", + "nms_spec.specificationVersion = 5\n", + "for i in range(2):\n", + " decoder_output = model._spec.description.output[i].SerializeToString()\n", + " nms_spec.description.input.add()\n", + " nms_spec.description.input[i].ParseFromString(decoder_output)\n", + " nms_spec.description.output.add()\n", + " nms_spec.description.output[i].ParseFromString(decoder_output)\n", + " \n", + "nms_spec.description.output[0].name = \"coordinates\"\n", + "nms_spec.description.output[1].name = \"confidence\"\n", + "\n", + "# output_sizes = [1, 4]\n", + "# for i in range(2):\n", + "# ma_type = nms_spec.description.output[i].type.multiArrayType\n", + "# ma_type.shapeRange.sizeRanges.add()\n", + "# ma_type.shapeRange.sizeRanges[0].lowerBound = 0\n", + "# ma_type.shapeRange.sizeRanges[0].upperBound = -1\n", + "# ma_type.shapeRange.sizeRanges.add()\n", + "# ma_type.shapeRange.sizeRanges[1].lowerBound = output_sizes[i]\n", + "# ma_type.shapeRange.sizeRanges[1].upperBound = output_sizes[i]\n", + "# del ma_type.shape[:]\n", + " \n", + "# ma_type2 = nms_spec.description.input[i].type.multiArrayType\n", + "# ma_type2.shapeRange.sizeRanges.add()\n", + "# ma_type2.shapeRange.sizeRanges[0].lowerBound = 0\n", + "# ma_type2.shapeRange.sizeRanges[0].upperBound = -1\n", + "# ma_type2.shapeRange.sizeRanges.add()\n", + "# ma_type2.shapeRange.sizeRanges[1].lowerBound = output_sizes[i]\n", + "# ma_type2.shapeRange.sizeRanges[1].upperBound = output_sizes[i]\n", + "# del ma_type2.shape[:]\n", + " \n", + "nms = nms_spec.nonMaximumSuppression\n", + "nms.confidenceInputFeatureName = \"Identity_1\"\n", + "nms.coordinatesInputFeatureName = \"Identity\"\n", + "nms.confidenceOutputFeatureName = \"confidence\"\n", + 
"nms.coordinatesOutputFeatureName = \"coordinates\"\n", + "nms.iouThresholdInputFeatureName = \"iouThreshold\"\n", + "nms.confidenceThresholdInputFeatureName = \"confidenceThreshold\"\n", + "\n", + "# print(nms)\n", + "\n", + "default_iou_threshold = 0.75\n", + "default_confidence_threshold = 0.25\n", + "nms.iouThreshold = default_iou_threshold\n", + "nms.confidenceThreshold = default_confidence_threshold\n", + "nms.pickTop.perClass = True\n", + "\n", + "#print(type(nms))\n", + "\n", + "labels = ['Pleural_Line']\n", + "nms.stringClassLabels.vector.extend(labels)\n", + "\n", + "nms_model = ct.models.MLModel(nms_spec)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "72598c0e", + "metadata": {}, + "outputs": [], + "source": [ + "nms_model" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "id": "d8f9632f", + "metadata": {}, + "outputs": [], + "source": [ + "from coremltools.models import datatypes\n", + "from coremltools.models.pipeline import *\n", + "input_features = [ (\"image\", datatypes.Array(416, 416, 3))]\n", + "# (\"iouThreshold\", datatypes.Double()),\n", + "# (\"confidenceThreshold\", datatypes.Double()) ]\n", + "\n", + "output_features = [ \"confidence\", \"coordinates\" ]\n", + "\n", + "pipeline = Pipeline(input_features, output_features)\n", + "\n", + "pipeline.add_model(model)\n", + "# pipeline.add_model(nms_model)\n", + "\n", + "pipeline.spec.description.input[0].ParseFromString(\n", + " model._spec.description.input[0].SerializeToString())\n", + "pipeline.spec.description.output[0].ParseFromString(\n", + " model._spec.description.output[0].SerializeToString())\n", + "pipeline.spec.description.output[1].ParseFromString(\n", + " model._spec.description.output[1].SerializeToString())\n", + "\n", + "# Add descriptions to the inputs and outputs.\n", + "# pipeline.spec.description.input[1].shortDescription = \"(optional) IOU Threshold override\"\n", + "# pipeline.spec.description.input[2].shortDescription = \"(optional) Confidence Threshold override\"\n", + "pipeline.spec.description.output[1].shortDescription = u\"Boxes \\xd7 Class confidence\"\n", + "pipeline.spec.description.output[0].shortDescription = u\"Boxes \\xd7 [x, y, width, height] (relative to image size)\"\n", + "\n", + "# Add metadata to the model.\n", + "pipeline.spec.description.metadata.versionString = \"YOLO v4 - Pleural Line Detector - 11/15/21\"\n", + "pipeline.spec.description.metadata.shortDescription = \"YOLOv4 - Pleural Line Detector\"\n", + "pipeline.spec.description.metadata.author = \"Converted to Core ML by Chris MacLellan\"\n", + "pipeline.spec.description.metadata.license = \"None\"\n", + "\n", + "labels = ['Pleural_Line']\n", + "# Add the list of class labels and the default threshold values too.\n", + "user_defined_metadata = {\n", + " # \"iouThreshold\": str(default_iou_threshold),\n", + " # \"confidenceThreshold\": str(default_confidence_threshold),\n", + " \"classes\": \",\".join(labels)\n", + "}\n", + "pipeline.spec.description.metadata.userDefined.update(user_defined_metadata)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "id": "75a37bb2", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "input {\n", + " name: \"image\"\n", + " type {\n", + " imageType {\n", + " width: 416\n", + " height: 416\n", + " colorSpace: RGB\n", + " imageSizeRange {\n", + " widthRange {\n", + " lowerBound: 416\n", + " upperBound: 416\n", + " }\n", + " heightRange {\n", + " lowerBound: 416\n", + " upperBound: 416\n", + " }\n", + " }\n", + " }\n", 
+ " }\n", + "}\n", + "output {\n", + " name: \"coordinates\"\n", + " shortDescription: \"Boxes \\303\\227 [x, y, width, height] (relative to image size)\"\n", + " type {\n", + " multiArrayType {\n", + " dataType: FLOAT32\n", + " shapeRange {\n", + " sizeRanges {\n", + " upperBound: -1\n", + " }\n", + " sizeRanges {\n", + " lowerBound: 4\n", + " upperBound: 4\n", + " }\n", + " }\n", + " }\n", + " }\n", + "}\n", + "output {\n", + " name: \"confidence\"\n", + " shortDescription: \"Boxes \\303\\227 Class confidence\"\n", + " type {\n", + " multiArrayType {\n", + " dataType: FLOAT32\n", + " shapeRange {\n", + " sizeRanges {\n", + " upperBound: -1\n", + " }\n", + " sizeRanges {\n", + " lowerBound: 1\n", + " upperBound: 1\n", + " }\n", + " }\n", + " }\n", + " }\n", + "}\n", + "metadata {\n", + " shortDescription: \"YOLOv4 - Pleural Line Detector\"\n", + " versionString: \"YOLO v4 - Pleural Line Detector - 11/15/21\"\n", + " author: \"Converted to Core ML by Chris MacLellan\"\n", + " license: \"None\"\n", + " userDefined {\n", + " key: \"classes\"\n", + " value: \"Pleural_Line\"\n", + " }\n", + "}" + ] + }, + "execution_count": 56, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "final_model = ct.models.MLModel(pipeline.spec)\n", + "final_model" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "id": "cdb2fddc", + "metadata": {}, + "outputs": [], + "source": [ + "final_model.save(\"YoloPleuralLine.mlmodel\")\n", + "# Load a saved model\n", + "# loaded_model = ct.models.MLModel(\"../Models/yolo/yolov4-tiny-416.mlmodel\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "95d7588d", + "metadata": {}, + "outputs": [], + "source": [ + "final_model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7199b22a", + "metadata": {}, + "outputs": [], + "source": [ + "from PIL import Image\n", + "import numpy as np\n", + "example_image = Image.open(\"/shared_data/YOLO_Updated_PL_Model_Results/Sliding/image_677741729740_clean/frame0.png\").resize((416,416)).convert('RGB')\n", + "example_image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a993d8f5", + "metadata": {}, + "outputs": [], + "source": [ + "out_dict = final_model.predict({\"image\": example_image, \"iouThreshold\": 0.5, \"confidenceThreshold\": 0.5})\n", + "out_dict" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14e400ce", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python (pocus_project)", + "language": "python", + "name": "darryl_pocus" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/visualize_prediction_sparse_codes.ipynb b/notebooks/visualize_prediction_sparse_codes.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..d87968f7f9731623044ab808ba0f307de91dac8b --- /dev/null +++ b/notebooks/visualize_prediction_sparse_codes.ipynb @@ -0,0 +1,167 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "0285a1bd-83ef-44b7-9654-012294fd5653", + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "from datetime import datetime\n", + "import numpy as np\n", + "import torch\n", + "import torchvision\n", + "from 
matplotlib import pyplot as plt\n", + "from matplotlib import cm\n", + "from matplotlib.animation import FuncAnimation\n", + "\n", + "from sparse_coding_torch.conv_sparse_model import ConvSparseLayer\n", + "from sparse_coding_torch.small_data_classifier import SmallDataClassifierConv3d\n", + "\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "from sparse_coding_torch.utils import plot_filters\n", + "from sparse_coding_torch.utils import plot_video\n", + "\n", + "from sparse_coding_torch.load_data import load_yolo_clips\n", + "\n", + "from IPython.display import HTML\n", + "\n", + "from tqdm import tqdm\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2aba945d-1794-4d47-b6b5-3577803347d0", + "metadata": {}, + "outputs": [], + "source": [ + "device = torch.device(\"cuda:2\" if torch.cuda.is_available() else \"cpu\")\n", + "batch_size = 1\n", + " # batch_size = 3\n", + "\n", + "# train_loader = load_balls_data(batch_size)\n", + "train_loader, _ = load_yolo_clips(batch_size, mode='all_train', device=device, n_splits=1, sparse_model=None)\n", + "print('Loaded', len(train_loader), 'train examples')\n", + "\n", + "example_data = next(iter(train_loader))\n", + "\n", + "sparse_layer = ConvSparseLayer(in_channels=1,\n", + " out_channels=64,\n", + " kernel_size=(5, 15, 15),\n", + " stride=1,\n", + " padding=(0, 7, 7),\n", + " convo_dim=3,\n", + " rectifier=True,\n", + " lam=0.05,\n", + " max_activation_iter=200,\n", + " activation_lr=1e-1)\n", + "sparse_layer.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2ac3f6af-1f40-47ed-bf65-ccb0253b3066", + "metadata": {}, + "outputs": [], + "source": [ + "# Load models if we'd like to\n", + "checkpoint = torch.load(\"/home/dwh48@drexel.edu/sparse_coding_torch/sparse_conv3d_model-pleural_clips2_5x15x15-11-14-21.pt\")\n", + "sparse_layer.load_state_dict(checkpoint['model_state_dict'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "22ce2015-183f-4f88-98b2-7cc23790c1d6", + "metadata": {}, + "outputs": [], + "source": [ + "fp_ids = ['image_24164968068436_CLEAN', 'image_73815992352100_clean', 'image_74132233134844_clean']\n", + "fn_ids = ['image_610066411380_CLEAN', 'image_634125159704_CLEAN', 'image_588695055398_clean', 'image_584357289931_clean', 'Image_262499828648_clean', 'image_267456908021_clean', 'image_2743083265515_CLEAN', 'image_1749559540112_clean']\n", + "\n", + "incorrect_sparsity = []\n", + "correct_sparsity = []\n", + "incorrect_filter_act = torch.zeros(64)\n", + "correct_filter_act = torch.zeros(64)\n", + "\n", + "for labels, local_batch, vid_f in tqdm(train_loader):\n", + " activations = sparse_layer(local_batch.to(device))\n", + " sparsity = torch.count_nonzero(activations) / torch.numel(activations)\n", + " filter_act = torch.sum(activations.squeeze(), dim=[1, 2])\n", + " filter_act = filter_act / torch.max(filter_act)\n", + " filter_act = filter_act.detach().cpu()\n", + " \n", + " if vid_f[0] in fp_ids or vid_f[0] in fn_ids:\n", + " incorrect_sparsity.append(sparsity)\n", + " incorrect_filter_act += filter_act\n", + " else:\n", + " correct_sparsity.append(sparsity)\n", + " correct_filter_act += filter_act\n", + " \n", + "print(torch.mean(torch.tensor(correct_sparsity)))\n", + "print(torch.mean(torch.tensor(incorrect_sparsity)))\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "48b8dfc9-9736-4b1e-bb4b-de6b69056103", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + 
"execution_count": null, + "id": "5b202266-6b44-4c8d-9442-b86e6ad9b11b", + "metadata": {}, + "outputs": [], + "source": [ + "filters = sparse_layer.filters.cpu().detach()\n", + "print(filters.size())\n", + "\n", + "filters = torch.stack([filters[val] for val in incorrect_filter_act.argsort(descending=True)])\n", + "\n", + "print(filters.size())\n", + "\n", + "ani = plot_filters(filters)\n", + "# HTML(ani.to_html5_video())\n", + "ani.save(\"/home/dwh48@drexel.edu/sparse_coding_torch/incorrect_vis.mp4\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c6d1f47c-ba4c-4c3a-9f96-4901c39e16e5", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python (pocus_project)", + "language": "python", + "name": "darryl_pocus" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/yolo.tflite b/yolo.tflite new file mode 100644 index 0000000000000000000000000000000000000000..02c95fdb5d1e0bb228ba1cadf080e9e5d8790ac0 Binary files /dev/null and b/yolo.tflite differ