diff --git a/sparse_coding_torch/pnb/classifier_model.py b/sparse_coding_torch/pnb/classifier_model.py
index 8d6f11fd2200ca104f8a8915a29d4d40dcd22028..90cc850c20d7a8a441da9e99f0c6550ad2a6bcd1 100644
--- a/sparse_coding_torch/pnb/classifier_model.py
+++ b/sparse_coding_torch/pnb/classifier_model.py
@@ -55,7 +55,7 @@ class PNBClassifier(keras.layers.Layer):
 class PNBTemporalClassifier(keras.layers.Layer):
     def __init__(self):
         super(PNBTemporalClassifier, self).__init__()
-        self.conv_1 = keras.layers.Conv3D(24, kernel_size=(1, 200, 50), strides=(1, 1, 10), activation='relu', padding='valid')
+        self.conv_1 = keras.layers.Conv3D(24, kernel_size=(1, 250, 50), strides=(1, 1, 10), activation='relu', padding='valid')
         self.conv_2 = keras.layers.Conv2D(36, kernel_size=(5, 10), strides=(1, 5), activation='relu', padding='valid')
         self.conv_3 = keras.layers.Conv1D(48, kernel_size=2, strides=2, activation='relu', padding='valid')
         
@@ -73,7 +73,7 @@ class PNBTemporalClassifier(keras.layers.Layer):
         width = clip.shape[3]
         height = clip.shape[2]
         depth = clip.shape[1]
-        
+
         x = tf.expand_dims(clip, axis=4)
 #         x = tf.reshape(clip, (-1, height, width, 1))
 
diff --git a/sparse_coding_torch/pnb/train_classifier.py b/sparse_coding_torch/pnb/train_classifier.py
index 4255b170aaf654972b5d15ac27ec966d8836d590..9acc61d795d45b48ad289f97dfd107dd5633673c 100644
--- a/sparse_coding_torch/pnb/train_classifier.py
+++ b/sparse_coding_torch/pnb/train_classifier.py
@@ -139,10 +139,8 @@ def calculate_pnb_scores_skipped_frames(input_videos, labels, yolo_model, sparse
                 fn_ids.append(f)
             else:
                 fp_ids.append(f)
-            
-        print(float(pred[0]))
-        raise Exception
-        final_list.append(pred)
+
+        final_list.append(float(pred))
         
     return np.array(final_list), np.array(numerical_labels), fn_ids, fp_ids
 
@@ -238,8 +236,8 @@ if __name__ == "__main__":
     
     i_fold = 0
     for train_idx, test_idx in splits:
-        train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
-#         train_sampler = SubsetWeightedRandomSampler(get_sample_weights(train_idx, dataset), train_idx, replacement=True)
+#         train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
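+        # sample training clips with replacement, weighted by get_sample_weights
+        # (presumably to counter class imbalance), instead of uniformly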
+        train_sampler = SubsetWeightedRandomSampler(get_sample_weights(train_idx, dataset), train_idx, replacement=True)
         train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                sampler=train_sampler)
         
@@ -410,9 +408,6 @@ if __name__ == "__main__":
 
         y_true = tf.cast(y_true, tf.int32)
         y_pred = tf.cast(y_pred, tf.int32)
-        
-        print(y_true)
-        print(y_pred)
 
         f1 = f1_score(y_true, y_pred, average='macro')
         accuracy = accuracy_score(y_true, y_pred)
diff --git a/sparse_coding_torch/pnb/video_loader.py b/sparse_coding_torch/pnb/video_loader.py
index 9adc4a1539776aa1db58b22ddf060b0ac85aa191..10da89dfc49b27ad4871fce24aa56fbe1139d059 100644
--- a/sparse_coding_torch/pnb/video_loader.py
+++ b/sparse_coding_torch/pnb/video_loader.py
@@ -23,6 +23,11 @@ import csv
 import random
 import cv2
 from yolov4.get_bounding_boxes import YoloModel
+from skimage.transform import warp_polar
+from skimage.io import imsave
+
+from matplotlib import pyplot as plt
+from matplotlib import cm
 
 def get_participants(filenames):
     return [f.split('/')[-2] for f in filenames]
@@ -52,7 +57,7 @@ def get_yolo_regions(yolo_model, clip, is_right, needle_bb, crop_width, crop_hei
         if class_pred == 2:
             needle_bb = bb
     
-    rotate_box = True
+    rotate_box = False
     
     all_clips = []
     for bb, class_pred, score in zip(bounding_boxes, classes, scores):
@@ -61,60 +66,119 @@ def get_yolo_regions(yolo_model, clip, is_right, needle_bb, crop_width, crop_hei
         center_x = round((bb[3] + bb[1]) / 2 * orig_width)
         center_y = round((bb[2] + bb[0]) / 2 * orig_height)
         
-        lower_y = round((bb[0] * orig_height))
-        upper_y = round((bb[2] * orig_height))
-        lower_x = round((bb[1] * orig_width))
-        upper_x = round((bb[3] * orig_width))
+        if not is_right:
+            clip = tv.transforms.functional.hflip(clip)
+            center_x = orig_width - center_x
+            # needle_bb is in normalized [0, 1] coordinates, so mirror it in that
+            # range rather than subtracting from the pixel width
+            needle_bb[1] = 1.0 - needle_bb[1]
+            needle_bb[3] = 1.0 - needle_bb[3]
         
-        if is_right:
-            angle = calculate_angle(needle_bb, upper_x, center_y, orig_height, orig_width)
-        else:
-            angle = calculate_angle(needle_bb, lower_x, center_y, orig_height, orig_width)
+#         lower_y = round((bb[0] * orig_height))
+#         upper_y = round((bb[2] * orig_height))
+#         lower_x = round((bb[1] * orig_width))
+#         upper_x = round((bb[3] * orig_width))
         
-        lower_y = center_y - (crop_height // 2)
-        upper_y = center_y + (crop_height // 2) 
+#         if is_right:
+        angle = calculate_angle(needle_bb, center_x, center_y, orig_height, orig_width)
+#         else:
+#             angle = calculate_angle(needle_bb, lower_x, center_y, orig_height, orig_width)
         
-        if is_right:
-            lower_x = center_x - crop_width
-            upper_x = center_x
-        else:
-            lower_x = center_x
-            upper_x = center_x + crop_width
+#         lower_y = center_y - (crop_height // 2)
+#         upper_y = center_y + (crop_height // 2) 
+        
+#         if is_right:
+#             lower_x = center_x - crop_width
+#             upper_x = center_x
+#         else:
+#             lower_x = center_x
+#             upper_x = center_x + crop_width
             
-        if lower_x < 0:
-            lower_x = 0
-        if upper_x < 0:
-            upper_x = 0
-        if lower_y < 0:
-            lower_y = 0
-        if upper_y < 0:
-            upper_y = 0
+#         if lower_x < 0:
+#             lower_x = 0
+#         if upper_x < 0:
+#             upper_x = 0
+#         if lower_y < 0:
+#             lower_y = 0
+#         if upper_y < 0:
+#             upper_y = 0
+        clip = tv.transforms.functional.rotate(clip, angle=angle, center=[center_x, center_y])
             
-        if rotate_box:
-#             cv2.imwrite('test_1.png', clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2))
-            if is_right:
-                clip = tv.transforms.functional.rotate(clip, angle=angle, center=[upper_x, center_y])
-            else:
-#                 cv2.imwrite('test_1.png', clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2))
-                clip = tv.transforms.functional.rotate(clip, angle=-angle, center=[lower_x, center_y])
+#         plt.clf()
+#         plt.imshow(clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2), cmap=cm.Greys_r)
+#         # plt.scatter([214], [214], color="red")
+#         plt.scatter([center_x, int(needle_bb[1]*orig_width)], [center_y, int(needle_bb[0] * orig_height)], color=["red", 'red'])
+# #         cv2.imwrite('test_normal.png', clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2))
+#         plt.savefig('test_normal.png')
+            
+#         if rotate_box:
+# #             cv2.imwrite('test_1.png', clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2))
+#             if is_right:
+#         clip = tv.transforms.functional.rotate(clip, angle=angle, center=[center_x, center_y])
+#             else:
+# #                 cv2.imwrite('test_1.png', clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2))
+#                 clip = tv.transforms.functional.rotate(clip, angle=-angle, center=[center_x, center_y])
 #                 cv2.imwrite('test_2.png', clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2))
 
-        trimmed_clip = clip[:, :, lower_y:upper_y, lower_x:upper_x]
+#         plt.imshow(clip[0, 0, :, :], cmap=cm.Greys_r)
+#         # plt.annotate('25, 50', xy=(25, 50), xycoords='data',
+#         #             xytext=(0.5, 0.5), textcoords='figure fraction',
+#         #             arrowprops=dict(arrowstyle="->"))
+#         plt.scatter([center_x], [center_y], color="red")
+#         plt.savefig('red_dot.png')
+#         clip = clip[:, :, :upper_y, :]
+
+        rows, cols = clip[0, 0, :, :].shape
+        max_radius = int(np.sqrt(rows ** 2 + cols ** 2) / 2)
+#         print(upper_y)
+#         print(bb[0])
+#         print(center_x)
+#         print(center_y)
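+        # warp each frame to polar coordinates about the detected box centre
+        # (radius = half the frame diagonal); rows in the result correspond to angle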
+        trimmed_clip = []
+        for i in range(clip.shape[0]):
+            sub_clip = []
+            for j in range(clip.shape[1]):
+                sub_clip.append(cv2.linearPolar(clip[i, j, :, :].numpy(), (center_x, center_y), max_radius, cv2.WARP_FILL_OUTLIERS))
+#                 sub_clip.append(warp_polar(clip[i, j, :, :].numpy(), center=(center_x, center_y), radius=max_radius, preserve_range=True))
+            trimmed_clip.append(np.stack(sub_clip))
+        trimmed_clip = np.stack(trimmed_clip)
+        
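+        # the needle should lie near the row matching its angle in the polar image;
+        # the 150 degree offset presumably aligns calculate_angle's reference frame
+        # with the polar angular origin (assumption, not verified)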
+        approximate_needle_position = int(((angle+150)/360)*orig_height)
+        
+        # debug visualisation, disabled so the loader does not write an image per clip
+#         plt.clf()
+#         plt.imshow(trimmed_clip[:, 0, :, :].swapaxes(0,1).swapaxes(1,2), cmap=cm.Greys_r)
+#         plt.scatter([214], [214], color="red")
+#         plt.scatter([center_x], [approximate_needle_position], color=["red"])
+#         cv2.imwrite('test_normal.png', clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2))
+#         plt.savefig('test_polar.png')
+        
+        trimmed_clip = trimmed_clip[:, :, approximate_needle_position - (crop_height//2):approximate_needle_position + (crop_height//2), :]
+                
+#         trimmed_clip=cv2.linearPolar(clip[0, 0, :, :].numpy(), (center_x, center_y), max_radius, cv2.WARP_FILL_OUTLIERS)
+#         trimmed_clip = warp_polar(clip[0, 0, :, :].numpy(), center=(center_x, center_y), radius=max_radius)
+
+#         trimmed_clip = clip[:, :, lower_y:upper_y, lower_x:upper_x]
         
 #         if orig_width - center_x >= center_x:
 #         if not is_right:
 #         print(angle)
-#         cv2.imwrite('test_2{}.png'.format(lower_y), clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2))
-#         cv2.imwrite('test_3{}.png'.format(lower_y), trimmed_clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2))
+#         if not is_right:
+#         cv2.imwrite('test_polar.png', trimmed_clip[:, 0, :, :].swapaxes(0,1).swapaxes(1,2))
+        # debug visualisation of the cropped polar band, disabled as above
+#         plt.clf()
+#         plt.imshow(trimmed_clip[:, 0, :, :].swapaxes(0,1).swapaxes(1,2), cmap=cm.Greys_r)
+#         plt.scatter([214], [214], color="red")
+#         plt.scatter([center_x], [approximate_needle_position], color=["red"])
+#         cv2.imwrite('test_normal.png', clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2))
+#         plt.savefig('test_polar_trim.png')
+#         raise Exception
 
-        if not is_right:
-            trimmed_clip = tv.transforms.functional.hflip(trimmed_clip)
+#         if not is_right:
+#             trimmed_clip = tv.transforms.functional.hflip(trimmed_clip)
+#             cv2.imwrite('test_polar.png', trimmed_clip)
 #             cv2.imwrite('test_yolo.png', trimmed_clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2))
 #             raise Exception
         
         if trimmed_clip.shape[2] == 0 or trimmed_clip.shape[3] == 0:
             continue
-        all_clips.append(trimmed_clip)
+        all_clips.append(torch.tensor(trimmed_clip))
 
     return all_clips
 
diff --git a/sparse_coding_torch/ptx/classifier_model.py b/sparse_coding_torch/ptx/classifier_model.py
index 9abe6cbb41ca5dec247822017814537241477348..bd2bde38e343ddeb83943921ec49b2b722c6f0a1 100644
--- a/sparse_coding_torch/ptx/classifier_model.py
+++ b/sparse_coding_torch/ptx/classifier_model.py
@@ -44,6 +44,20 @@ class PTXClassifier(keras.layers.Layer):
 
         return x
     
+class PTXVAEClassifier(keras.layers.Layer):
+    def __init__(self):
+        super(PTXVAEClassifier, self).__init__()
+
+        self.ff_3 = keras.layers.Dense(20, activation='relu', use_bias=True)
+        self.ff_4 = keras.layers.Dense(1)
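+        # two-layer head mapping a VAE latent vector to a single classification logit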
+
+#     @tf.function
+    def call(self, z):
+        x = self.ff_3(z)
+        x = self.ff_4(x)
+
+        return x
+    
 class Sampling(keras.layers.Layer):
     """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""
 
@@ -55,16 +69,17 @@ class Sampling(keras.layers.Layer):
         return z_mean + tf.exp(0.5 * z_log_var) * epsilon
     
 class VAEEncoderPTX(keras.layers.Layer):
-    def __init__(self):
+    def __init__(self, latent_dim):
         super(VAEEncoderPTX, self).__init__()
 
-        self.conv_1 = keras.layers.Conv3D(64, kernel_size=(5, 8, 8), strides=(1, 4, 4), activation='relu', padding='valid')
-        self.conv_2 = keras.layers.Conv2D(24, kernel_size=4, strides=2, activation='relu', padding='valid')
+        self.conv_1 = keras.layers.Conv3D(24, kernel_size=(5, 16, 16), strides=(1, 4, 4), activation='relu', padding='valid')
+        self.conv_2 = keras.layers.Conv2D(36, kernel_size=8, strides=2, activation='relu', padding='valid')
+        self.conv_3 = keras.layers.Conv2D(48, kernel_size=4, strides=2, activation='relu', padding='valid')
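+        # shape sketch for (5, 100, 200, 1) inputs: conv_1 -> (22, 47, 24),
+        # conv_2 -> (8, 20, 36), conv_3 -> (3, 9, 48), i.e. 1296 features after flatten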
 
         self.flatten = keras.layers.Flatten()
 
-        self.ff_mean = keras.layers.Dense(100, activation='relu', use_bias=True)
-        self.ff_var = keras.layers.Dense(100, activation='relu', use_bias=True)
+        self.ff_mean = keras.layers.Dense(latent_dim, activation='relu', use_bias=True)
+        self.ff_var = keras.layers.Dense(latent_dim, activation='relu', use_bias=True)
         
         self.sample = Sampling()
 
@@ -73,36 +88,36 @@ class VAEEncoderPTX(keras.layers.Layer):
         x = self.conv_1(images)
         x = tf.squeeze(x, axis=1)
         x = self.conv_2(x)
+        x = self.conv_3(x)
         x = self.flatten(x)
         z_mean = self.ff_mean(x)
         z_var = self.ff_var(x)
         z = self.sample([z_mean, z_var])
-        return z
+        return z, z_mean, z_var
     
 class VAEDecoderPTX(keras.layers.Layer):
     def __init__(self):
         super(VAEDecoderPTX, self).__init__()
-
-        self.conv_1 = keras.layers.Conv3D(64, kernel_size=(5, 8, 8), strides=(1, 4, 4), activation='relu', padding='valid')
-        self.conv_2 = keras.layers.Conv2D(24, kernel_size=4, strides=2, activation='relu', padding='valid')
-
-        self.flatten = keras.layers.Flatten()
-
-        self.ff_mean = keras.layers.Dense(100, activation='relu', use_bias=True)
-        self.ff_var = keras.layers.Dense(100, activation='relu', use_bias=True)
         
-        self.sample = Sampling()
+        self.ff = keras.layers.Dense(3 * 9 * 48, activation='relu', use_bias=True)
+        self.reshape = keras.layers.Reshape((3, 9, 48))
+        
+        self.deconv_1 = keras.layers.Conv2DTranspose(36, kernel_size=4, strides=2, activation='relu', padding='valid')
+        self.deconv_2 = keras.layers.Conv2DTranspose(24, kernel_size=8, strides=2, activation='relu', padding='valid')
+        self.deconv_3 = keras.layers.Conv3DTranspose(1, kernel_size=(5, 16, 16), strides=(1, 4, 4), activation='relu', padding='valid')
+        
+        self.padding = keras.layers.ZeroPadding3D((0, 0, 2))
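+        # mirrors the encoder for (5, 100, 200, 1) targets: latent -> (3, 9, 48) ->
+        # (8, 20, 36) -> (22, 46, 24) -> (5, 100, 196, 1), zero-padded in width to 200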
 
 #     @tf.function
     def call(self, images):
-        x = self.conv_1(images)
-        x = tf.squeeze(x, axis=1)
-        x = self.conv_2(x)
-        x = self.flatten(x)
-        z_mean = self.ff_mean(x)
-        z_var = self.ff_var(x)
-        z = self.sample([z_mean, z_var])
-        return z
+        x = self.ff(images)
+        x = self.reshape(x)
+        x = self.deconv_1(x)
+        x = self.deconv_2(x)
+        x = tf.expand_dims(x, axis=1)
+        x = self.deconv_3(x)
+        x = self.padding(x)
+        return x
 
 class MobileModelPTX(keras.Model):
     def __init__(self, sparse_weights, classifier_model, batch_size, image_height, image_width, clip_depth, out_channels, kernel_size, kernel_depth, stride, lam, activation_lr, max_activation_iter, run_2d):
diff --git a/sparse_coding_torch/ptx/load_data.py b/sparse_coding_torch/ptx/load_data.py
index e307a281f46f7d9a54cdf3649ee710451d1aaba8..088c18d0ef3bc0927d2ecea3bc834439e4b72863 100644
--- a/sparse_coding_torch/ptx/load_data.py
+++ b/sparse_coding_torch/ptx/load_data.py
@@ -3,7 +3,7 @@ import torchvision
 import torch
 from sklearn.model_selection import train_test_split
 from sparse_coding_torch.utils import MinMaxScaler, VideoGrayScaler
-from sparse_coding_torch.ptx.video_loader import YoloClipLoader, get_ptx_participants
+from sparse_coding_torch.ptx.video_loader import YoloClipLoader, get_ptx_participants, COVID19Loader
 import csv
 from sklearn.model_selection import train_test_split, GroupShuffleSplit, LeaveOneGroupOut, LeaveOneOut, StratifiedGroupKFold, StratifiedKFold, KFold, ShuffleSplit
 
@@ -22,10 +22,8 @@ def load_yolo_clips(batch_size, mode, num_clips=1, num_positives=100, device=Non
      torchvision.transforms.RandomHorizontalFlip(),
      torchvision.transforms.CenterCrop((100, 200))
     ])
-    if whole_video:
-        dataset = YoloVideoLoader(video_path, num_clips=num_clips, num_positives=num_positives, transform=transforms, augment_transform=augment_transforms, sparse_model=sparse_model, device=device)
-    else:
-        dataset = YoloClipLoader(video_path, num_clips=num_clips, num_positives=num_positives, positive_videos=positive_videos, transform=transforms, augment_transform=augment_transforms, sparse_model=sparse_model, device=device)
+
+    dataset = YoloClipLoader(video_path, num_clips=num_clips, num_positives=num_positives, positive_videos=positive_videos, transform=transforms, augment_transform=augment_transforms, sparse_model=sparse_model, device=device)
     
     targets = dataset.get_labels()
     
@@ -67,3 +65,52 @@ def load_yolo_clips(batch_size, mode, num_clips=1, num_positives=100, device=Non
         
         return train_loader, test_loader, dataset
     
+def load_covid_clips(batch_size, yolo_model, mode, clip_height, clip_width, clip_depth, device=None, n_splits=None, classify_mode=False):   
+    video_path = "/home/dwh48@drexel.edu/covid19_ultrasound/data/pocus_videos"
+    
+    transforms = torchvision.transforms.Compose(
+    [VideoGrayScaler(),
+     MinMaxScaler(0, 255),
+    ])
+    augment_transforms = torchvision.transforms.Compose(
+    [torchvision.transforms.RandomRotation(45),
+     torchvision.transforms.RandomHorizontalFlip(),
+     torchvision.transforms.Resize((clip_height, clip_width))
+    ])
+
+    dataset = COVID19Loader(yolo_model, video_path, clip_depth, classify_mode=classify_mode, transform=transforms, augmentation=augment_transforms)
+    
+    targets = dataset.get_labels()
+    
+    if mode == 'leave_one_out':
+        gss = LeaveOneGroupOut()
+
+        groups = [v for v in dataset.get_filenames()]
+        
+        return gss.split(np.arange(len(targets)), targets, groups), dataset
+    elif mode == 'all_train':
+        train_idx = np.arange(len(targets))
+        
+        return [(train_idx, None)], dataset
+    elif mode == 'k_fold':
+        gss = StratifiedGroupKFold(n_splits=n_splits)
+
+        groups = [v for v in dataset.get_filenames()]
+        
+        return gss.split(np.arange(len(targets)), targets, groups), dataset
+    else:
+        gss = GroupShuffleSplit(n_splits=n_splits, test_size=0.2)
+
+        groups = [v for v in dataset.get_filenames()]
+        
+        train_idx, test_idx = list(gss.split(np.arange(len(targets)), targets, groups))[0]
+        
+        train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
+        train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
+                                               sampler=train_sampler)
+        
+        test_sampler = torch.utils.data.SubsetRandomSampler(test_idx)
+        test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
+                                               sampler=test_sampler)
+        
+        return train_loader, test_loader, dataset
\ No newline at end of file
diff --git a/sparse_coding_torch/ptx/train_classifier.py b/sparse_coding_torch/ptx/train_classifier.py
index 8e3c5c59d1db86c6f0b77e42ae5a21c72d39bcca..4152869aed14b87275b95667a876b9251e54e0db 100644
--- a/sparse_coding_torch/ptx/train_classifier.py
+++ b/sparse_coding_torch/ptx/train_classifier.py
@@ -6,7 +6,7 @@ import argparse
 import os
 from sparse_coding_torch.ptx.load_data import load_yolo_clips
 from sparse_coding_torch.sparse_model import SparseCode, ReconSparse, normalize_weights, normalize_weights_3d
-from sparse_coding_torch.ptx.classifier_model import PTXClassifier
+from sparse_coding_torch.ptx.classifier_model import PTXClassifier, VAEEncoderPTX, PTXVAEClassifier
 import time
 import numpy as np
 from sklearn.metrics import f1_score, accuracy_score, confusion_matrix
diff --git a/sparse_coding_torch/ptx/train_classifier_vae.py b/sparse_coding_torch/ptx/train_classifier_vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d884b8607bd09d481d0960263f84718fcc915d7
--- /dev/null
+++ b/sparse_coding_torch/ptx/train_classifier_vae.py
@@ -0,0 +1,356 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from tqdm import tqdm
+import argparse
+import os
+from sparse_coding_torch.ptx.load_data import load_yolo_clips
+from sparse_coding_torch.sparse_model import SparseCode, ReconSparse, normalize_weights, normalize_weights_3d
+from sparse_coding_torch.ptx.classifier_model import PTXClassifier, VAEEncoderPTX, PTXVAEClassifier
+import time
+import numpy as np
+from sklearn.metrics import f1_score, accuracy_score, confusion_matrix
+import random
+import pickle
+import tensorflow.keras as keras
+import tensorflow as tf
+from sparse_coding_torch.utils import VideoGrayScaler, MinMaxScaler
+from yolov4.get_bounding_boxes import YoloModel
+import torchvision
+import glob
+from torchvision.datasets.video_utils import VideoClips
+import cv2
+
+configproto = tf.compat.v1.ConfigProto()
+configproto.gpu_options.polling_inactive_delay_msecs = 5000
+configproto.gpu_options.allow_growth = True
+sess = tf.compat.v1.Session(config=configproto) 
+tf.compat.v1.keras.backend.set_session(sess)
+
+
+def calculate_ptx_scores(input_videos, labels, yolo_model, encoder_model, classifier_model, image_width, image_height, transform):
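+    """Video-level evaluation: crop each clip to the widest YOLO box, encode it with
+    the frozen VAE encoder, classify every clip, and majority-vote per video."""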
+    all_predictions = []
+    
+    numerical_labels = []
+    for label in labels:
+        if label == 'No_Sliding':
+            numerical_labels.append(1.0)
+        else:
+            numerical_labels.append(0.0)
+
+    final_list = []
+    clip_correct = []
+    fp_ids = []
+    fn_ids = []
+    for v_idx, f in tqdm(enumerate(input_videos)):
+        clipstride = 15
+        
+        vc = VideoClips([f],
+                        clip_length_in_frames=5,
+                        frame_rate=20,
+                       frames_between_clips=clipstride)
+
+        clip_predictions = []
+        i = 0
+        cliplist = []
+        countclips = 0
+        for i in range(vc.num_clips()):
+            clip, _, _, _ = vc.get_clip(i)
+            clip = clip.swapaxes(1, 3).swapaxes(0, 1).swapaxes(2, 3).numpy()
+            
+            bounding_boxes, classes, scores = yolo_model.get_bounding_boxes(clip[:, 2, :, :].swapaxes(0, 2).swapaxes(0, 1))
+            bounding_boxes = bounding_boxes.squeeze(0)
+            if bounding_boxes.size == 0:
+                continue
+            #widths = []
+            countclips = countclips + len(bounding_boxes)
+            
+            widths = [(bounding_boxes[i][3] - bounding_boxes[i][1]) for i in range(len(bounding_boxes))]
+
+            ind =  np.argmax(np.array(widths))
+
+            bb = bounding_boxes[ind]
+            center_x = (bb[3] + bb[1]) / 2 * 1920
+            center_y = (bb[2] + bb[0]) / 2 * 1080
+
+            width=400
+            height=400
+
+            lower_y = round(center_y - height / 2)
+            upper_y = round(center_y + height / 2)
+            lower_x = round(center_x - width / 2)
+            upper_x = round(center_x + width / 2)
+
+            trimmed_clip = clip[:, :, lower_y:upper_y, lower_x:upper_x]
+
+            trimmed_clip = torch.tensor(trimmed_clip).to(torch.float)
+
+            trimmed_clip = transform(trimmed_clip)
+            trimmed_clip.pin_memory()
+            cliplist.append(trimmed_clip)
+
+        if len(cliplist) > 0:
+            with torch.no_grad():
+                trimmed_clip = torch.stack(cliplist)
+                images = trimmed_clip.permute(0, 2, 3, 4, 1).numpy()
+                z, _, _ = encoder_model(images)
+                z = tf.stop_gradient(z)
+
+                pred = classifier_model(z)
+
+                clip_predictions = tf.math.round(tf.math.sigmoid(pred))
+
+            final_pred = torch.mode(torch.tensor(clip_predictions.numpy()).view(-1))[0].item()
+            if len(clip_predictions) % 2 == 0 and tf.math.reduce_sum(clip_predictions) == len(clip_predictions)//2:
+                #print("I'm here")
+                final_pred = torch.mode(torch.tensor(clip_predictions.numpy()).view(-1))[0].item()
+        else:
+            final_pred = 1.0
+            
+        if final_pred != numerical_labels[v_idx]:
+            if final_pred == 0.0:
+                fn_ids.append(f)
+            else:
+                fp_ids.append(f)
+            
+        final_list.append(final_pred)
+        
+        clip_correct.extend([1 if clip_pred == numerical_labels[v_idx] else 0 for clip_pred in clip_predictions])
+        
+    return np.array(final_list), np.array(numerical_labels), fn_ids, fp_ids, sum(clip_correct) / len(clip_correct)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--batch_size', default=12, type=int)
+    parser.add_argument('--lr', default=5e-4, type=float)
+    parser.add_argument('--epochs', default=40, type=int)
+    parser.add_argument('--output_dir', default='./output', type=str)
+    parser.add_argument('--vae_checkpoint', default=None, type=str)
+    parser.add_argument('--checkpoint', default=None, type=str)
+    parser.add_argument('--splits', default=None, type=str, help='k_fold or leave_one_out or all_train')
+    parser.add_argument('--seed', default=26, type=int)
+    parser.add_argument('--train', action='store_true')
+    parser.add_argument('--num_positives', default=100, type=int)
+    parser.add_argument('--n_splits', default=5, type=int)
+    parser.add_argument('--save_train_test_splits', action='store_true')
+    parser.add_argument('--balance_classes', action='store_true')
+    parser.add_argument('--crop_height', type=int, default=100)
+    parser.add_argument('--crop_width', type=int, default=200)
+    parser.add_argument('--scale_factor', type=int, default=1)
+    parser.add_argument('--clip_depth', type=int, default=5)
+    parser.add_argument('--frames_to_skip', type=int, default=1)
+    parser.add_argument('--latent_dim', type=int, default=1000)
+    
+    args = parser.parse_args()
+    
+    image_height = 100
+    image_width = 200
+    clip_depth = args.clip_depth
+        
+    batch_size = args.batch_size
+    
+    output_dir = args.output_dir
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+        
+    with open(os.path.join(output_dir, 'arguments.txt'), 'w+') as out_f:
+        out_f.write(str(args))
+    
+    yolo_model = YoloModel('ptx')
+
+    all_errors = []
+    
+    encoder_inputs = keras.Input(shape=(5, image_height, image_width, 1))
+        
+    encoder_outputs = VAEEncoderPTX(args.latent_dim)(encoder_inputs)
+    
+    encoder_model = keras.Model(inputs=encoder_inputs, outputs=encoder_outputs)
+
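+    # if provided, the VAE encoder checkpoint comes from train_vae.py; the encoder is
+    # never trained here, only the classifier head receives gradients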
+    if args.vae_checkpoint:
+        encoder_model.set_weights(keras.models.load_model(args.vae_checkpoint).get_weights())
+        
+    splits, dataset = load_yolo_clips(args.batch_size, num_clips=1, num_positives=15, mode=args.splits, device=None, n_splits=args.n_splits, sparse_model=None, whole_video=False, positive_videos='positive_videos.json')
+    positive_class = 'No_Sliding'
+
+    overall_true = []
+    overall_pred = []
+    fn_ids = []
+    fp_ids = []
+    
+    i_fold = 0
+    for train_idx, test_idx in splits:
+        train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
+        train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
+                                               sampler=train_sampler)
+        
+        if test_idx is not None:
+            test_sampler = torch.utils.data.SubsetRandomSampler(test_idx)
+            test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
+                                                   sampler=test_sampler)
+        else:
+            test_loader = None
+        
+        if args.checkpoint:
+            classifier_model = keras.models.load_model(args.checkpoint)
+        else:
+            classifier_inputs = keras.Input(shape=(args.latent_dim,))
+            classifier_outputs = PTXVAEClassifier()(classifier_inputs)
+
+            classifier_model = keras.Model(inputs=classifier_inputs, outputs=classifier_outputs)
+
+        prediction_optimizer = keras.optimizers.Adam(learning_rate=args.lr)
+
+        best_so_far = float('-inf')
+
+        criterion = keras.losses.BinaryCrossentropy(from_logits=True, reduction=keras.losses.Reduction.SUM)
+
+        if args.train:
+            for epoch in range(args.epochs):
+                epoch_loss = 0
+                t1 = time.perf_counter()
+
+                y_true_train = None
+                y_pred_train = None
+
+                for labels, local_batch, vid_f in tqdm(train_loader):
+                    images = local_batch.permute(0, 2, 3, 4, 1).numpy()
+                    torch_labels = np.zeros(len(labels))
+                    torch_labels[[i for i in range(len(labels)) if labels[i] == positive_class]] = 1
+                    torch_labels = np.expand_dims(torch_labels, axis=1)
+
+                    z, _, _ = encoder_model(images)
+                    z = tf.stop_gradient(z)
+
+                    with tf.GradientTape() as tape:
+                        pred = classifier_model(z)
+                        loss = criterion(torch_labels, pred)
+
+                    epoch_loss += loss * local_batch.size(0)
+
+                    gradients = tape.gradient(loss, classifier_model.trainable_weights)
+
+                    prediction_optimizer.apply_gradients(zip(gradients, classifier_model.trainable_weights))
+
+                    if y_true_train is None:
+                        y_true_train = torch_labels
+                        y_pred_train = tf.math.round(tf.math.sigmoid(pred))
+                    else:
+                        y_true_train = tf.concat((y_true_train, torch_labels), axis=0)
+                        y_pred_train = tf.concat((y_pred_train, tf.math.round(tf.math.sigmoid(pred))), axis=0)
+
+                t2 = time.perf_counter()
+
+                y_true = None
+                y_pred = None
+                test_loss = 0.0
+                
+                eval_loader = test_loader
+                if args.splits == 'all_train':
+                    eval_loader = train_loader
+                for labels, local_batch, vid_f in tqdm(eval_loader):
+                    images = local_batch.permute(0, 2, 3, 4, 1).numpy()
+
+                    torch_labels = np.zeros(len(labels))
+                    torch_labels[[i for i in range(len(labels)) if labels[i] == positive_class]] = 1
+                    torch_labels = np.expand_dims(torch_labels, axis=1)
+                    
+                    z, _, _ = encoder_model(images)
+                    z = tf.stop_gradient(z)
+
+                    pred = classifier_model(z)
+                    loss = criterion(torch_labels, pred)
+
+                    test_loss += loss
+
+                    if y_true is None:
+                        y_true = torch_labels
+                        y_pred = tf.math.round(tf.math.sigmoid(pred))
+                    else:
+                        y_true = tf.concat((y_true, torch_labels), axis=0)
+                        y_pred = tf.concat((y_pred, tf.math.round(tf.math.sigmoid(pred))), axis=0)
+
+                t2 = time.perf_counter()
+
+                y_true = tf.cast(y_true, tf.int32)
+                y_pred = tf.cast(y_pred, tf.int32)
+
+                y_true_train = tf.cast(y_true_train, tf.int32)
+                y_pred_train = tf.cast(y_pred_train, tf.int32)
+
+                f1 = f1_score(y_true, y_pred, average='macro')
+                accuracy = accuracy_score(y_true, y_pred)
+
+                train_accuracy = accuracy_score(y_true_train, y_pred_train)
+
+                print('epoch={}, i_fold={}, time={:.2f}, train_loss={:.2f}, test_loss={:.2f}, train_acc={:.2f}, test_f1={:.2f}, test_acc={:.2f}'.format(epoch, i_fold, t2-t1, epoch_loss, test_loss, train_accuracy, f1, accuracy))
+    #             print(epoch_loss)
+                if f1 >= best_so_far:
+                    print("found better model")
+                    # Save model parameters
+                    classifier_model.save(os.path.join(output_dir, "best_classifier_{}.pt".format(i_fold)))
+#                     recon_model.save(os.path.join(output_dir, "best_sparse_model_{}.pt".format(i_fold)))
+                    pickle.dump(prediction_optimizer.get_weights(), open(os.path.join(output_dir, 'optimizer_{}.pt'.format(i_fold)), 'wb+'))
+                    best_so_far = f1
+
+            classifier_model = keras.models.load_model(os.path.join(output_dir, "best_classifier_{}.pt".format(i_fold)))
+#             recon_model = keras.models.load_model(os.path.join(output_dir, 'best_sparse_model_{}.pt'.format(i_fold)))
+
+        epoch_loss = 0
+
+        y_true = None
+        y_pred = None
+
+        pred_dict = {}
+        gt_dict = {}
+
+        t1 = time.perf_counter()
+        
+        transform = torchvision.transforms.Compose(
+        [VideoGrayScaler(),
+         MinMaxScaler(0, 255),
+         torchvision.transforms.Normalize((0.2592,), (0.1251,)),
+         torchvision.transforms.CenterCrop((100, 200))
+        ])
+
+        test_dir = '/shared_data/bamc_ph1_test_data'
+        test_videos = glob.glob(os.path.join(test_dir, '*', '*.*'))
+        test_labels = [vid_f.split('/')[-2] for vid_f in test_videos]
+
+        y_pred, y_true, fn, fp, clip_acc = calculate_ptx_scores(test_videos, test_labels, yolo_model, encoder_model, classifier_model, image_width, image_height, transform)
+            
+        t2 = time.perf_counter()
+
+        print('i_fold={}, time={:.2f}'.format(i_fold, t2-t1))
+
+        y_true = tf.cast(y_true, tf.int32)
+        y_pred = tf.cast(y_pred, tf.int32)
+
+        f1 = f1_score(y_true, y_pred, average='macro')
+        accuracy = accuracy_score(y_true, y_pred)
+
+        fn_ids.extend(fn)
+        fp_ids.extend(fp)
+
+        overall_true.extend(y_true)
+        overall_pred.extend(y_pred)
+
+        print("Test f1={:.2f}, vid_acc={:.2f}, clip_acc={:.2f}".format(f1, accuracy, clip_acc))
+
+        print(confusion_matrix(y_true, y_pred))
+            
+        i_fold += 1
+
+    fp_fn_file = os.path.join(args.output_dir, 'fp_fn.txt')
+    with open(fp_fn_file, 'w+') as in_f:
+        in_f.write('FP:\n')
+        in_f.write(str(fp_ids) + '\n\n')
+        in_f.write('FN:\n')
+        in_f.write(str(fn_ids) + '\n\n')
+        
+    overall_true = np.array(overall_true)
+    overall_pred = np.array(overall_pred)
+            
+    final_f1 = f1_score(overall_true, overall_pred, average='macro')
+    final_acc = accuracy_score(overall_true, overall_pred)
+    final_conf = confusion_matrix(overall_true, overall_pred)
+            
+    print("Final accuracy={:.2f}, f1={:.2f}".format(final_acc, final_f1))
+    print(final_conf)
+
diff --git a/sparse_coding_torch/ptx/train_sparse_model.py b/sparse_coding_torch/ptx/train_sparse_model.py
index 43c03df937acf96291b000eb75edbf6adf4c1a88..042c55b92d2de3ada8a610d02fe478d343efce6a 100644
--- a/sparse_coding_torch/ptx/train_sparse_model.py
+++ b/sparse_coding_torch/ptx/train_sparse_model.py
@@ -7,7 +7,7 @@ from matplotlib.animation import FuncAnimation
 from tqdm import tqdm
 import argparse
 import os
-from sparse_coding_torch.ptx.load_data import load_yolo_clips
+from sparse_coding_torch.ptx.load_data import load_yolo_clips, load_covid_clips
 import tensorflow.keras as keras
 import tensorflow as tf
 from sparse_coding_torch.sparse_model import normalize_weights_3d, normalize_weights, SparseCode, load_pytorch_weights, ReconSparse
@@ -110,7 +110,7 @@ if __name__ == "__main__":
     parser.add_argument('--run_2d', action='store_true')
     parser.add_argument('--save_filters', action='store_true')
     parser.add_argument('--optimizer', default='sgd', type=str)
-    parser.add_argument('--dataset', default='onsd', type=str)
+    parser.add_argument('--dataset', default='ptx', type=str)
     parser.add_argument('--crop_height', type=int, default=400)
     parser.add_argument('--crop_width', type=int, default=400)
     parser.add_argument('--scale_factor', type=int, default=1)
@@ -139,8 +139,11 @@ if __name__ == "__main__":
 
     with open(os.path.join(output_dir, 'arguments.txt'), 'w+') as out_f:
         out_f.write(str(args))
-        
-    splits, dataset = load_yolo_clips(args.batch_size, num_clips=1, num_positives=15, mode='all_train', device=device, n_splits=1, sparse_model=None, whole_video=False, positive_videos='positive_videos.json')
+    
+    if args.dataset == 'ptx':
+        splits, dataset = load_yolo_clips(args.batch_size, num_clips=1, num_positives=15, mode='all_train', device=device, n_splits=1, sparse_model=None, whole_video=False, positive_videos='positive_videos.json')
+    elif args.dataset == 'covid':
+        splits, dataset = load_covid_clips(batch_size=args.batch_size, yolo_model=None, mode='all_train', clip_height=crop_height, clip_width=crop_width, clip_depth=clip_depth, device=device, n_splits=1, classify_mode=False)
     train_idx, test_idx = splits[0]
     
     train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
@@ -192,8 +195,6 @@ if __name__ == "__main__":
         num_iters = 0
 
         for labels, local_batch, vid_f in tqdm(train_loader):
-            if local_batch.size(0) != args.batch_size:
-                continue
             if args.run_2d:
                 images = local_batch.squeeze(1).permute(0, 2, 3, 1).numpy()
             else:
diff --git a/sparse_coding_torch/ptx/train_vae.py b/sparse_coding_torch/ptx/train_vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..3abe8ade70e5897af9d0a05040784aa2f3b64e46
--- /dev/null
+++ b/sparse_coding_torch/ptx/train_vae.py
@@ -0,0 +1,129 @@
+import time
+import numpy as np
+import torch
+from matplotlib import pyplot as plt
+from matplotlib import cm
+from matplotlib.animation import FuncAnimation
+from tqdm import tqdm
+import argparse
+import os
+from sparse_coding_torch.ptx.load_data import load_yolo_clips
+import tensorflow.keras as keras
+import tensorflow as tf
+from sparse_coding_torch.ptx.classifier_model import VAEEncoderPTX, VAEDecoderPTX
+import random
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--batch_size', default=32, type=int)
+    parser.add_argument('--lr', default=0.003, type=float)
+    parser.add_argument('--epochs', default=50, type=int)
+    parser.add_argument('--output_dir', default='./output', type=str)
+    parser.add_argument('--seed', default=42, type=int)
+    parser.add_argument('--optimizer', default='adam', type=str)
+    parser.add_argument('--dataset', default='ptx', type=str)
+    parser.add_argument('--crop_height', type=int, default=100)
+    parser.add_argument('--crop_width', type=int, default=200)
+    parser.add_argument('--scale_factor', type=int, default=1)
+    parser.add_argument('--clip_depth', type=int, default=5)
+    parser.add_argument('--frames_to_skip', type=int, default=1)
+    parser.add_argument('--latent_dim', type=int, default=1000)
+    
+
+    args = parser.parse_args()
+    
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+
+    crop_height = args.crop_height
+    crop_width = args.crop_width
+
+    image_height = int(crop_height / args.scale_factor)
+    image_width = int(crop_width / args.scale_factor)
+    clip_depth = args.clip_depth
+
+    output_dir = args.output_dir
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+        
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+    with open(os.path.join(output_dir, 'arguments.txt'), 'w+') as out_f:
+        out_f.write(str(args))
+        
+    splits, dataset = load_yolo_clips(args.batch_size, num_clips=1, num_positives=15, mode='all_train', device=device, n_splits=1, sparse_model=None, whole_video=False, positive_videos='positive_videos.json')
+    train_idx, test_idx = splits[0]
+    
+    train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
+    train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size,
+                                           sampler=train_sampler)
+    
+    print('Loaded', len(train_loader), 'train examples')
+
+    example_data = next(iter(train_loader))
+
+    encoder_inputs = keras.Input(shape=(5, image_height, image_width, 1))
+        
+    encoder_outputs = VAEEncoderPTX(args.latent_dim)(encoder_inputs)
+    
+    encoder_model = keras.Model(inputs=encoder_inputs, outputs=encoder_outputs)
+    
+    decoder_inputs = keras.Input(shape=(args.latent_dim,))
+    
+    decoder_outputs = VAEDecoderPTX()(decoder_inputs)
+    
+    decoder_model = keras.Model(inputs=decoder_inputs, outputs=decoder_outputs)
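+    # the encoder and decoder are separate keras.Models so the trained encoder can be
+    # reloaded on its own by train_classifier_vae.py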
+
+    optimizer = tf.keras.optimizers.Adam(learning_rate=args.lr)
+
+    loss_log = []
+    best_so_far = float('inf')
+
+    for epoch in range(args.epochs):
+        epoch_loss = 0
+        running_loss = 0.0
+        epoch_start = time.perf_counter()
+        
+        num_iters = 0
+
+        for labels, local_batch, vid_f in tqdm(train_loader):
+            images = local_batch.permute(0, 2, 3, 4, 1).numpy()
+            
+            with tf.GradientTape() as tape:
+                z, z_mean, z_var = encoder_model(images)
+                recon = decoder_model(z)
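+                # VAE loss: BCE reconstruction + KL divergence; z_var is a
+                # log-variance, matching Sampling's exp(0.5 * z_log_var)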
+                reconstruction_loss = tf.reduce_mean(
+                    tf.reduce_sum(
+                        keras.losses.binary_crossentropy(images, recon), axis=(1, 2, 3)
+                    )
+                )
+                kl_loss = -0.5 * (1 + z_var - tf.square(z_mean) - tf.exp(z_var))
+                kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
+                loss = reconstruction_loss + kl_loss
+
+            epoch_loss += loss * local_batch.size(0)
+            running_loss += loss * local_batch.size(0)
+
+            gradients = tape.gradient(loss, encoder_model.trainable_weights + decoder_model.trainable_weights)
+
+            optimizer.apply_gradients(zip(gradients, encoder_model.trainable_weights + decoder_model.trainable_weights))
+                
+            num_iters += 1
+
+        epoch_end = time.perf_counter()
+        epoch_loss /= len(train_loader.sampler)
+
+        if epoch_loss < best_so_far:
+            print("found better model")
+            # Save model parameters
+            encoder_model.save(os.path.join(output_dir, "best_encoder.pt"))
+            decoder_model.save(os.path.join(output_dir, "best_decoder.pt"))
+            best_so_far = epoch_loss
+
+        loss_log.append(epoch_loss)
+        print('epoch={}, epoch_loss={:.2f}, time={:.2f}'.format(epoch, epoch_loss, epoch_end - epoch_start))
+
+    plt.plot(loss_log)
+
+    plt.savefig(os.path.join(output_dir, 'loss_graph.png'))
diff --git a/sparse_coding_torch/ptx/video_loader.py b/sparse_coding_torch/ptx/video_loader.py
index 93ad37eb5b9dd71fd64986d9984a3202e7e1db2d..63fc7b3a204406e3d4311a8863b79825b39fdf37 100644
--- a/sparse_coding_torch/ptx/video_loader.py
+++ b/sparse_coding_torch/ptx/video_loader.py
@@ -261,3 +261,87 @@ class YoloClipLoader(Dataset):
         
     def __len__(self):
         return len(self.clips)
+    
+def get_yolo_regions(yolo_model, clip):
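+    """Crop the clip to every YOLO bounding box (normalized [y1, x1, y2, x2]) and
+    return the non-empty crops."""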
+    orig_height = clip.size(2)
+    orig_width = clip.size(3)
+    bounding_boxes, classes, scores = yolo_model.get_bounding_boxes(clip[:, 2, :, :].swapaxes(0, 2).swapaxes(0, 1).numpy())
+    bounding_boxes = bounding_boxes.squeeze(0)
+    classes = classes.squeeze(0)
+    scores = scores.squeeze(0)
+    
+    all_clips = []
+    for bb, class_pred, score in zip(bounding_boxes, classes, scores):
+        lower_y = round((bb[0] * orig_height))
+        upper_y = round((bb[2] * orig_height))
+        lower_x = round((bb[1] * orig_width))
+        upper_x = round((bb[3] * orig_width))
+
+        trimmed_clip = clip[:, :, lower_y:upper_y, lower_x:upper_x]
+        
+        if trimmed_clip.shape[2] == 0 or trimmed_clip.shape[3] == 0:
+            continue
+        all_clips.append(trimmed_clip.clone())
+
+    return all_clips
+
+class COVID19Loader(Dataset):
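+    """Clip dataset over the covid19_ultrasound POCUS videos: each video is split into
+    non-overlapping clip_depth-frame clips, optionally cropped to YOLO regions when
+    classify_mode is set."""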
+    def __init__(self, yolo_model, video_path, clip_depth, classify_mode=False, transform=None, augmentation=None):
+        self.transform = transform
+        self.augmentation = augmentation
+        
+        self.videos = glob.glob(join(video_path, '*', '*.*'))
+        
+        vid_to_label = {}
+        with open('/home/dwh48@drexel.edu/covid19_ultrasound/data/dataset_metadata.csv') as csv_in:
+            reader = csv.DictReader(csv_in)
+            for row in reader:
+                vid_to_label[row['Filename']] = row['Label']
+            
+        self.clips = []
+        
+        vid_idx = 0
+        for path in tqdm(self.videos):
+            vc = tv.io.read_video(path)[0].permute(3, 0, 1, 2)
+            label = vid_to_label[path.split('/')[-1].split('.')[0]]
+            
+            if classify_mode:
+                for j in range(0, vc.size(1) - clip_depth, clip_depth):
+                    vc_sub = vc[:, j:j+clip_depth, :, :]
+                    if vc_sub.size(1) < clip_depth:
+                        continue
+                    for clip in get_yolo_regions(yolo_model, vc_sub):
+                        if self.transform:
+                            clip = self.transform(clip)
+
+                        self.clips.append((label, clip, path))
+            else:
+                for j in range(0, vc.size(1) - clip_depth, clip_depth):
+                    vc_sub = vc[:, j:j+clip_depth, :, :]
+                    if vc_sub.size(1) != clip_depth:
+                        continue
+                    if self.transform:
+                        vc_sub = self.transform(vc_sub)
+
+                    self.clips.append((label, vc_sub, path))
+
+            vid_idx += 1
+        
+        random.shuffle(self.clips)
+        
+    def get_filenames(self):
+        return [self.clips[i][2] for i in range(len(self.clips))]
+        
+    def get_labels(self):
+        return [self.clips[i][0] for i in range(len(self.clips))]
+    
+    def __getitem__(self, index):
+        label, clip, vid_f = self.clips[index]
+        if self.augmentation:
+            clip = clip.swapaxes(0, 1)
+            clip = self.augmentation(clip)
+            clip = clip.swapaxes(0, 1)
+        return (label, clip, vid_f)
+        
+    def __len__(self):
+        return len(self.clips)
\ No newline at end of file
diff --git a/sparse_coding_torch/utils.py b/sparse_coding_torch/utils.py
index 62e67d34f3d586efff1851e41bb84aad01dc30f8..e2398c3646eaa6afabdfc9ae1cd76761fb0432f6 100644
--- a/sparse_coding_torch/utils.py
+++ b/sparse_coding_torch/utils.py
@@ -6,6 +6,9 @@ from tqdm import tqdm
 from torchvision.datasets.video_utils import VideoClips
 from typing import Sequence, Iterator
 import torch.nn as nn
+from matplotlib import pyplot as plt
+from matplotlib import cm
+from matplotlib.animation import FuncAnimation
 
 def get_sample_weights(train_idx, dataset):
     dataset = list(dataset)
@@ -146,4 +149,28 @@ def plot_filters(filters):
             c = i % ncol
             ims[(r, c)].set_data(filters[t, :, :, 0, i])
 
-    return FuncAnimation(plt.gcf(), update, interval=1000/20)
\ No newline at end of file
+    return FuncAnimation(plt.gcf(), update, interval=1000/20)
+
+def plot_filters_image(filters):
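+    """Static counterpart of plot_filters: a grid of the first temporal slice of each filter."""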
+    filters = filters.astype('float32')
+    num_filters = filters.shape[4]
+    ncol = 12
+    T = filters.shape[0]
+
+    if num_filters // ncol == num_filters / ncol:
+        nrow = num_filters // ncol
+    else:
+        nrow = num_filters // ncol + 1
+
+    fig, axes = plt.subplots(ncols=ncol, nrows=nrow,
+                             constrained_layout=True,
+                             figsize=(ncol*2, nrow*2))
+
+    ims = {}
+    for i in range(num_filters):
+        r = i // ncol
+        c = i % ncol
+        ims[(r, c)] = axes[r, c].imshow(filters[0, :, :, 0, i],
+                                        cmap=cm.Greys_r)
+
+    return plt.gcf()
\ No newline at end of file
diff --git a/yolov4/Pleural_Line_TensorFlow/pnb_prelim_yolo/yolov4-416.tflite b/yolov4/Pleural_Line_TensorFlow/pnb_prelim_yolo/yolov4-416.tflite
index cad9fc040b9d61ae4eb1ed18854ad48460e5a59d..a1e1bb296c6bd6d1d5c8df4c3f1cb8b5ebca6fc4 100644
Binary files a/yolov4/Pleural_Line_TensorFlow/pnb_prelim_yolo/yolov4-416.tflite and b/yolov4/Pleural_Line_TensorFlow/pnb_prelim_yolo/yolov4-416.tflite differ