diff --git a/sparse_coding_torch/onsd/train_nerve_slice_all_vids.py b/sparse_coding_torch/onsd/train_nerve_slice_all_vids.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ea7696e401c76982868c5810e589c14ea6bc9c7
--- /dev/null
+++ b/sparse_coding_torch/onsd/train_nerve_slice_all_vids.py
@@ -0,0 +1,707 @@
+import os
+import tensorflow as tf
+import tensorflow.keras as keras
+import random
+import numpy as np
+import cv2
+import glob
+from sklearn.model_selection import LeaveOneOut
+from sklearn.metrics import accuracy_score
+from keras_unet_collection.models import unet_2d, transunet_2d, u2net_2d, att_unet_2d, unet_plus_2d, r2_unet_2d, resunet_a_2d
+from yolov4.get_bounding_boxes import YoloModel
+import torchvision as tv
+import torch
+import math
+import csv
+
+from scipy.ndimage import rotate, shift
+
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+import absl.logging
+absl.logging.set_verbosity(absl.logging.ERROR)
+
+def retrieve_outliers(file_loc, fraction_to_discard, pos_neg_cutoff):
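+    """Read per-frame nerve width measurements from a CSV and return the
+    filenames of label-contradicting frames (wide 'Negatives' or narrow
+    'Positives'), split evenly between the widest and narrowest such frames."""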
+    all_rows = []
+    with open(file_loc, 'r') as csv_f:
+        reader = csv.DictReader(csv_f)
+        
+        for row in reader:
+            all_rows.append((row['file'], int(row['width']), row['label']))
+            
+    all_rows = [row for row in all_rows if (row[1] >= pos_neg_cutoff and row[2] == 'Negatives') or (row[1] < pos_neg_cutoff and row[2] == 'Positives')]
+            
+    all_rows_asc = sorted(all_rows, key=lambda tup: tup[1])
+    all_rows_asc = [tup[0] for tup in all_rows_asc]
+    
+    # Use true division: floor division would zero out fractional inputs,
+    # and a num_to_discard of 0 would make the slices below return every row.
+    fraction_to_discard = fraction_to_discard / 2
+    
+    num_to_discard = int(len(all_rows_asc) * fraction_to_discard)
+    
+    if num_to_discard == 0:
+        return []
+    
+    top = all_rows_asc[-num_to_discard:]
+    bottom = all_rows_asc[:num_to_discard]
+    
+    return top + bottom
+
+def load_videos(input_dir, nerve_dir, pos_neg_cutoff, fraction_to_discard=None):
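+    """Pair each raw frame in input_dir with its nerve segmentation mask in
+    nerve_dir, optionally discarding label-contradicting outlier frames."""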
+    nerve_img_paths = sorted(
+    [
+        fname
+        for fname in os.listdir(nerve_dir)
+        if fname.endswith(".png") and not fname.startswith(".")
+    ]
+    )
+    
+    if fraction_to_discard is not None:
+        outliers = set(retrieve_outliers('segmentation/more_slices/outliers.csv', fraction_to_discard, pos_neg_cutoff))
+        nerve_img_paths = [fname for fname in nerve_img_paths if fname not in outliers]
+
+    input_img_paths = sorted(
+        [
+            os.path.join(input_dir, fname)
+            for fname in os.listdir(input_dir)
+            if fname.endswith(".png") and not fname.startswith(".") and fname in nerve_img_paths
+        ]
+    )
+    
+    nerve_img_paths = [os.path.join(nerve_dir, fname) for fname in nerve_img_paths]
+
+    assert len(input_img_paths) == len(nerve_img_paths)
+
+    print("Number of training samples:", len(input_img_paths))
+
+    input_data = list(zip(input_img_paths, nerve_img_paths))
+        
+    return input_data
+
+def get_videos(input_participant):
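+    """Return all videos belonging to input_participant (uses the module-level video_path)."""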
+    all_vids = glob.glob(os.path.join(video_path, '*', '*', '*.mp4'))
+    
+    out_vids = []
+    
+    for vid in all_vids:
+        participant = vid.split('/')[-2]
+        
+        if input_participant == participant:
+            out_vids.append(vid)
+            
+    return out_vids
+
+def get_test_videos():
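+    """Collect the held-out test videos and their class labels (parent directory names)."""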
+    all_vids = glob.glob(os.path.join('/shared_data/bamc_onsd_test_data', '*', '*.mp4'))
+    
+    out_vids = []
+    out_lbls = []
+    
+    for vid in all_vids:
+        txt_label = vid.split('/')[-2]
+
+        out_vids.append(vid)
+        out_lbls.append(txt_label)
+            
+    return out_vids, out_lbls
+
+def get_participants(video_path, input_data):
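+    """Group annotated frames by participant, tagging each frame with the
+    class label of the directory its participant's video lives in."""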
+    all_vids = glob.glob(os.path.join(video_path, '*', '*', '*.mp4'))
+
+    participant_to_data = {}
+
+    for vid in all_vids:
+        participant = vid.split('/')[-2]
+        txt_label = vid.split('/')[-3]
+
+        for frame in input_data:
+            frame_name = frame[0].split('/')[-1].split(' ')[1][:-2]
+
+            if frame_name != participant:
+                continue
+
+            if not participant in participant_to_data:
+                participant_to_data[participant] = []
+
+            participant_to_data[participant].append((frame[0], frame[1], txt_label))
+
+    print('{} participants.'.format(len(participant_to_data)))
+    
+    return participant_to_data
+
+def create_splits(participant_to_data):
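+    """Shuffle the participants and build leave-one-participant-out splits."""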
+    participants = list(participant_to_data.keys())
+
+    random.shuffle(participants)
+    
+    gss = LeaveOneOut()
+
+    splits = gss.split(participants)
+    
+    return splits, participants
+
+def make_numpy_arrays(split_participants, participant_to_data, img_size):
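+    """Load, resize, and stack grayscale frames, color frames for YOLO, and
+    binarized nerve masks for the given participants."""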
+    all_x = []
+    all_yolo = []
+    all_nerve = []
+    all_txt = []
+    
+    for participant in split_participants:
+        for x, nerve, txt_label in participant_to_data[participant]:
+            yolo = cv2.resize(cv2.imread(x), (img_size[1], img_size[0]))
+            x = cv2.resize(cv2.imread(x, cv2.IMREAD_GRAYSCALE), (img_size[1], img_size[0]))
+            nerve = cv2.resize(cv2.imread(nerve, cv2.IMREAD_GRAYSCALE), (img_size[1], img_size[0]))
+            # Binarize the mask: 255 (nerve) -> 1, everything else -> 0.
+            nerve = (nerve == 255).astype(nerve.dtype)
+
+            all_x.append(x)
+            all_yolo.append(yolo)
+            all_nerve.append(nerve)
+            all_txt.append(txt_label)
+            
+    return np.expand_dims(np.stack(all_x), axis=-1), np.stack(all_yolo), np.stack(all_nerve), all_txt
+
+def display_mask_test(model, input_mask):
+    """Quick utility to display a model's prediction."""
+    test_pred = model.predict(np.expand_dims(np.expand_dims(input_mask, axis=0), axis=-1), verbose=False)[0]
+    mask = np.argmax(test_pred, axis=-1)
+    mask = np.expand_dims(mask, axis=-1) * 255
+    
+    return mask
+
+
+def get_obj_coordinates_yolo(yolo_model, frame):
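+    """Run YOLO on a frame and return pixel centers of the best nerve (class 0)
+    and eye (class 1) detections, plus the distance from the eye center to the
+    bottom of the eye box. Returns all Nones if either detection is missing."""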
+    orig_height = frame.shape[0]
+    orig_width = frame.shape[1]
+    
+    bounding_boxes, classes, scores = yolo_model.get_bounding_boxes_v5(frame)
+    
+    eye_bounding_box = (None, 0.0)
+    nerve_bounding_box = (None, 0.0)
+    
+    for bb, class_pred, score in zip(bounding_boxes, classes, scores):
+        if class_pred == 0 and score > nerve_bounding_box[1]:
+            nerve_bounding_box = (bb, score)
+        elif class_pred == 1 and score > eye_bounding_box[1]:
+            eye_bounding_box = (bb, score)
+    
+    eye_bounding_box = eye_bounding_box[0]
+    nerve_bounding_box = nerve_bounding_box[0]
+    
+    if eye_bounding_box is None or nerve_bounding_box is None:
+        return None, None, None, None, None
+    
+    nerve_center_x = round((nerve_bounding_box[2] + nerve_bounding_box[0]) / 2 * orig_width)
+    nerve_center_y = round((nerve_bounding_box[3] + nerve_bounding_box[1]) / 2 * orig_height)
+    
+    eye_center_x = round((eye_bounding_box[2] + eye_bounding_box[0]) / 2 * orig_width)
+    eye_center_y = round((eye_bounding_box[3] + eye_bounding_box[1]) / 2 * orig_height)
+    
+    dist_to_bottom = round((eye_bounding_box[3] * orig_height) - eye_center_y)
+            
+    return nerve_center_x, nerve_center_y, eye_center_x, eye_center_y, dist_to_bottom
+
+def get_line(yolo_model, input_frame):
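+    """Fit the line through the eye and nerve centers. Returns slope m and
+    intercept b (None for a vertical line), both centers, and the
+    eye-center-to-box-bottom distance."""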
+    nerve_center_x, nerve_center_y, eye_center_x, eye_center_y, dist_to_bottom = get_obj_coordinates_yolo(yolo_model, input_frame)
+    
+    if dist_to_bottom is None:
+        return None, None, (None, None), (None, None), None
+    
+    if eye_center_x != nerve_center_x:
+        m = (eye_center_y - nerve_center_y) / (eye_center_x - nerve_center_x)
+    else:
+        return None, None, (eye_center_x, eye_center_y), (nerve_center_x, nerve_center_y), dist_to_bottom
+    
+    b = (-1 * (m * nerve_center_x)) + nerve_center_y
+    
+    assert eye_center_y == round(m*eye_center_x + b, 2)
+    assert nerve_center_y == round(m*nerve_center_x + b, 2)
+    
+    return m, b, (eye_center_x, eye_center_y), (nerve_center_x, nerve_center_y), dist_to_bottom
+
+def get_nerve_slice(yolo_model, input_frames, yolo_frames, nerve_frames, nerve_size):
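+    """Shift each frame (and its mask) so the nerve measurement point, a fixed
+    offset below the eye along the eye-nerve axis, sits at the center, rotate
+    that axis to vertical, and crop a nerve_size window from both. Assumes at
+    least one frame yields detections."""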
+    all_nerve_slices = []
+    all_nerve_masks = []
+    
+    for input_frame, yolo_frame, nerve_frame in zip(input_frames, yolo_frames, nerve_frames):
+        m, b, (eye_center_x, eye_center_y), (nerve_center_x, nerve_center_y), dist_to_bottom = get_line(yolo_model, yolo_frame)
+
+        # Skip frames where YOLO found no eye/nerve pair (mirrors the test-time version).
+        if dist_to_bottom is None:
+            continue
+
+        # Offset of 65 px (defined at 1080 px frame height) below the bottom of the eye.
+        target_length = (65/1080)*input_frame.shape[0] + dist_to_bottom
+
+        if m is not None and b is not None:
+            nerve_start = eye_center_y
+            for i in range(round(eye_center_y) + 1, input_frame.shape[0]):
+                x_val = round((i - b) / m)
+                distance_from_eye = math.sqrt((x_val - eye_center_x)**2 + (i - eye_center_y)**2)
+                if distance_from_eye > target_length:
+                    nerve_start = i
+                    break
+
+            nerve_measure_y = nerve_start
+            nerve_measure_x = (nerve_measure_y - b) / m
+        else:   
+            nerve_measure_y = round(eye_center_y + target_length)
+            nerve_measure_x = eye_center_x
+
+        shift_y = (input_frame.shape[0] // 2) - nerve_measure_y
+        shift_x = (input_frame.shape[1] // 2) - nerve_measure_x
+
+        shifted_image = shift(input_frame, shift=(shift_y, shift_x, 0))
+        shifted_nerve = shift(nerve_frame, shift=(shift_y, shift_x))
+
+        if m is not None:
+            angle = 90 + np.degrees(np.arctan(m))
+
+            rotated_image = rotate(shifted_image, angle=angle)
+            rotated_nerve = rotate(shifted_nerve, angle=angle)
+        else:
+            rotated_image = shifted_image
+            rotated_nerve = shifted_nerve
+
+        center_y = rotated_image.shape[0] // 2
+        center_x = rotated_image.shape[1] // 2
+
+        crop_y = nerve_size[0] // 2
+        crop_x = nerve_size[1] // 2
+
+        cropped_image = rotated_image[center_y-crop_y:center_y+crop_y, center_x-crop_x:center_x+crop_x, :]
+        cropped_nerve = rotated_nerve[center_y-crop_y:center_y+crop_y, center_x-crop_x:center_x+crop_x]
+        
+        all_nerve_slices.append(cropped_image)
+        all_nerve_masks.append(cropped_nerve)
+        
+    all_nerve_slices = np.stack(all_nerve_slices)
+    all_nerve_masks = np.stack(all_nerve_masks)
+    
+    return np.expand_dims(all_nerve_slices, axis=-1), all_nerve_masks
+
+def get_nerve_slice_test_time(yolo_model, input_frames, yolo_frames, nerve_size):
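+    """Test-time version of get_nerve_slice: no ground-truth masks, and frames
+    without detections are skipped. Returns None if no frame yields a slice."""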
+    all_nerve_slices = []
+    
+    for input_frame, yolo_frame in zip(input_frames, yolo_frames):
+        m, b, (eye_center_x, eye_center_y), (nerve_center_x, nerve_center_y), dist_to_bottom = get_line(yolo_model, yolo_frame)
+        
+        if dist_to_bottom is None:
+            continue
+
+        target_length = (65/1080)*input_frame.shape[0] + dist_to_bottom
+
+        if m is not None and b is not None:
+            nerve_start = eye_center_y
+            for i in range(round(eye_center_y) + 1, input_frame.shape[0]):
+                x_val = round((i - b) / m)
+                distance_from_eye = math.sqrt((x_val - eye_center_x)**2 + (i - eye_center_y)**2)
+                if distance_from_eye > target_length:
+                    nerve_start = i
+                    break
+
+            nerve_measure_y = nerve_start
+            nerve_measure_x = (nerve_measure_y - b) / m
+        else:   
+            nerve_measure_y = round(eye_center_y + target_length)
+            nerve_measure_x = eye_center_x
+
+        shift_y = (input_frame.shape[0] // 2) - nerve_measure_y
+        shift_x = (input_frame.shape[1] // 2) - nerve_measure_x
+
+        shifted_image = shift(input_frame, shift=(shift_y, shift_x, 0))
+
+        if m is not None:
+            angle = 90 + np.degrees(np.arctan(m))
+
+            rotated_image = rotate(shifted_image, angle=angle)
+        else:
+            rotated_image = shifted_image
+
+        center_y = rotated_image.shape[0] // 2
+        center_x = rotated_image.shape[1] // 2
+
+        crop_y = nerve_size[0] // 2
+        crop_x = nerve_size[1] // 2
+
+        cropped_image = rotated_image[center_y-crop_y:center_y+crop_y, center_x-crop_x:center_x+crop_x, :]
+        
+        all_nerve_slices.append(cropped_image)
+        
+    if not all_nerve_slices:
+        return None
+        
+    all_nerve_slices = np.stack(all_nerve_slices)
+    
+    return np.expand_dims(all_nerve_slices, axis=-1)
+
+def get_width_measurement(nerve_frame):
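+    """Measure nerve width in pixels by walking left and right from the center
+    of the mask's middle row until leaving the nerve region."""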
+    nerve_center_y = nerve_frame.shape[0] // 2
+    nerve_center_x = nerve_frame.shape[1] // 2
+    
+    left_boundary = nerve_center_x
+    for j in range(round(nerve_center_x) - 1, 0, -1):
+        if nerve_frame[nerve_center_y, j] != 1.0:
+            left_boundary = j + 1
+            break
+    
+    right_boundary = nerve_center_x
+    for j in range(round(nerve_center_x) + 1, nerve_frame.shape[1]):
+        if nerve_frame[nerve_center_y, j] != 1.0:
+            right_boundary = j - 1
+            break
+    
+    width = right_boundary - left_boundary
+
+    return width
+
+def get_width_predictions(yolo_model, nerve_model, X, yolo, nerve, lbls, pos_neg_cutoff, nerve_size):
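+    """Compute ground-truth and predicted nerve widths on annotated frames and
+    threshold both at pos_neg_cutoff to get frame-level class predictions."""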
+    all_widths = []
+    pred_widths = []
+    class_preds = []
+    gt_mask_preds = []
+    class_gt = []
+    
+    nerve_slices, nerve_masks = get_nerve_slice(yolo_model, X, yolo, nerve, nerve_size)
+    
+    nerve_pred = np.argmax(nerve_model.predict(nerve_slices, verbose=False), axis=-1)
+    
+    for nerve_p, nerve_gt, lbl in zip(nerve_pred, nerve_masks, lbls):
+        width = get_width_measurement(nerve_gt)
+
+        pred_width = get_width_measurement(nerve_p)
+
+        all_widths.append(width)
+        pred_widths.append(pred_width)
+        
+        if width >= pos_neg_cutoff:
+            gt_mask_preds.append(1)
+        else:
+            gt_mask_preds.append(0)
+        
+        if pred_width >= pos_neg_cutoff:
+            class_preds.append(1)
+        else:
+            class_preds.append(0)
+            
+        if lbl == 'Positives':
+            class_gt.append(1)
+        else:
+            class_gt.append(0)
+
+    return np.array(all_widths), np.array(pred_widths), np.array(gt_mask_preds), np.array(class_preds), np.array(class_gt)
+
+def run_full_eval(nerve_model, yolo_model, videos, lbl, pos_neg_cutoff, img_size, nerve_size):
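+    """Video-level evaluation for one participant: slice every 10th frame,
+    average the predicted widths, and threshold at pos_neg_cutoff."""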
+    pred_widths = []
+    class_preds = []
+    class_gt = []
+    
+    transforms = tv.transforms.Compose([tv.transforms.Grayscale(1)])
+    
+    resize = tv.transforms.Resize(img_size)
+    
+    all_slices = []
+    for video_path in videos:
+        vc = tv.io.read_video(video_path)[0].permute(3, 0, 1, 2)
+
+        all_frames = [resize(vc[:, j, :, :]) for j in range(0, vc.size(1), 10)]
+        
+        all_yolo = np.stack([frame.numpy().swapaxes(0,1).swapaxes(1,2) for frame in all_frames])
+        
+        all_frames = np.stack([transforms(frame).numpy().swapaxes(0,1).swapaxes(1,2) for frame in all_frames])
+        
+        slices = get_nerve_slice_test_time(yolo_model, all_frames, all_yolo, nerve_size)
+        
+        if slices is None:
+            continue
+
+        all_slices.append(slices)
+        
+    if not all_slices:
+        if lbl == 'Positives':
+            class_gt.append(1)
+        else:
+            class_gt.append(0)
+            
+        class_preds.append(0)
+    else:
+        all_slices = np.concatenate(all_slices)
+
+        pred = np.argmax(nerve_model.predict(all_slices, verbose=False), axis=-1)
+
+        for p in pred:
+            pred_width = get_width_measurement(p)
+
+            pred_widths.append(pred_width)
+
+        pred_width = np.average(pred_widths)
+
+        if pred_width >= pos_neg_cutoff:
+            class_preds.append(1)
+        else:
+            class_preds.append(0)
+
+        if lbl == 'Positives':
+            class_gt.append(1)
+        else:
+            class_gt.append(0)
+
+    return np.array(class_preds), np.array(class_gt)
+
+def run_full_eval_test_set(nerve_model, yolo_model, videos, lbls, pos_neg_cutoff, img_size, nerve_size):
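+    """Video-level evaluation over the held-out test set: for each video,
+    average the nonzero predicted widths across sampled frames and threshold
+    at pos_neg_cutoff."""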
+    class_preds = []
+    class_gt = []
+    
+    transforms = tv.transforms.Compose([tv.transforms.Grayscale(1)])
+    
+    resize = tv.transforms.Resize(img_size)
+
+    for video_path, lbl in zip(videos, lbls):
+        pred_widths = []
+        
+        vc = tv.io.read_video(video_path)[0].permute(3, 0, 1, 2)
+
+        all_frames = [resize(vc[:, j, :, :]) for j in range(0, vc.size(1), 10)]
+        
+        all_yolo = np.stack([frame.numpy().swapaxes(0,1).swapaxes(1,2) for frame in all_frames])
+        
+        all_frames = np.stack([transforms(frame).numpy().swapaxes(0,1).swapaxes(1,2) for frame in all_frames])
+        
+#         cv2.imwrite('onsd_validation/' + video_path.split('/')[-1][:-4] + '_frame.png', all_frames[0])
+        
+        slices = get_nerve_slice_test_time(yolo_model, all_frames, all_yolo, nerve_size)
+        
+#         cv2.imwrite('onsd_validation/' + video_path.split('/')[-1][:-4] + '_slice.png', np.squeeze(slices[0], axis=-1))
+        
+        if slices is None:
+            print('No nerve slices found for {}'.format(video_path))
+            class_preds.append(0)
+
+            if lbl == 'Positives':
+                class_gt.append(1)
+            else:
+                class_gt.append(0)
+            continue
+
+        pred = np.argmax(nerve_model.predict(slices, verbose=False), axis=-1)
+
+        for p in pred:
+            pred_width = get_width_measurement(p)
+            if pred_width == 0:
+                continue
+
+            pred_widths.append(pred_width)
+            
+        print('Predicted widths:', pred_widths)
+        if not pred_widths:
+            pred_widths.append(0)
+
+        pred_width = np.average(pred_widths)
+        
+        print('Average predicted width:', pred_width)
+
+        if pred_width >= pos_neg_cutoff:
+            class_preds.append(1)
+        else:
+            class_preds.append(0)
+
+        if lbl == 'Positives':
+            class_gt.append(1)
+        else:
+            class_gt.append(0)
+
+    return np.array(class_preds), np.array(class_gt)
+
+
+random.seed(321534)
+np.random.seed(321534)
+tf.random.set_seed(321534)
+
+output_dir = 'sparse_coding_torch/unet_output/unet_nerve_all_vids'
+
+if not os.path.exists(output_dir):
+    os.makedirs(output_dir)
+
+video_path = "/shared_data/bamc_onsd_data/revised_extended_onsd_data"
+
+input_dir = 'segmentation/more_slices/raw_frames'
+
+nerve_dir = 'segmentation/more_slices/nerve_segmentation'
+
+yolo_model = YoloModel('onsd')
+
+img_size = (416, 416)
+nerve_size = (16, 128)
+batch_size = 12
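+# Width in pixels (at img_size resolution) above which a nerve is classified positive.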
+# pos_neg_cutoff = (102 / 1080) * img_size[0]
+pos_neg_cutoff = 46
+
+input_data = load_videos(input_dir, nerve_dir, pos_neg_cutoff, None)
+
+participant_to_data = get_participants(video_path, input_data) 
+
+splits, participants = create_splits(participant_to_data)
+
+all_train_frame_pred = []
+all_train_frame_gt = []
+
+all_test_frame_pred = []
+all_test_frame_gt = []
+all_test_video_pred = []
+all_test_video_gt = []
+
+all_yolo_gt = []
+all_yolo_pred = []
+
+i_fold = 0
+for train_idx, test_idx in splits:
+    train_participants = [p for i, p in enumerate(participants) if i in train_idx]
+    test_participants = [p for i, p in enumerate(participants) if i in test_idx]
+    
+    assert len(set(train_participants).intersection(set(test_participants))) == 0
+
+    # Build the numpy arrays for each split
+    train_X, train_yolo, train_nerve, train_txt = make_numpy_arrays(train_participants, participant_to_data, img_size)
+    test_X, test_yolo, test_nerve, test_txt = make_numpy_arrays(test_participants, participant_to_data, img_size)
+    print('Test labels:', test_txt)
+
+    keras.backend.clear_session()
+    
+    nerve_inputs = keras.Input(shape=(None, None, 1))
+    
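+    # Augmentation is currently disabled; the empty Sequential is a pass-through.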
+    data_preprocessing = keras.Sequential([#keras.layers.RandomFlip('horizontal_and_vertical'),
+                                           #keras.layers.RandomBrightness(0.10),
+#                                            keras.layers.RandomContrast(0.01),
+#                                            keras.layers.RandomRotation(0.10)
+    ])(nerve_inputs)
+
+    nerve_outputs = unet_2d((None, None, 1), [64, 128, 256], n_labels=2,
+                      stack_num_down=2, stack_num_up=1,
+                      activation='GELU', output_activation='Softmax', 
+                      batch_norm=True, pool='max', unpool='nearest', name='unet')(data_preprocessing)
+#     nerve_outputs = att_unet_2d((None, None, 1), [64, 128, 256], n_labels=2,
+#                            stack_num_down=2, stack_num_up=2,
+#                            activation='ReLU', atten_activation='ReLU', attention='add', output_activation='Softmax', 
+#                            batch_norm=True, pool=False, unpool='bilinear', name='attunet')(data_preprocessing)
+#     nerve_outputs = unet_plus_2d((None, None, 1), [64, 128, 256], n_labels=2,
+#                             stack_num_down=2, stack_num_up=2,
+#                             activation='LeakyReLU', output_activation='Softmax', 
+#                             batch_norm=False, pool='max', unpool=False, deep_supervision=True, name='xnet')(data_preprocessing)
+#     nerve_outputs = r2_unet_2d((None, None, 1), [64, 128, 256], n_labels=2,
+#                           stack_num_down=2, stack_num_up=1, recur_num=2,
+#                           activation='ReLU', output_activation='Softmax', 
+#                           batch_norm=True, pool='max', unpool='bilinear', name='r2unet')(data_preprocessing)
+#     nerve_outputs = resunet_a_2d(nerve_size + (1,), [32, 64, 128, 256], 
+#                             dilation_num=[1, 3, 15, 31], 
+#                             n_labels=2, aspp_num_down=256, aspp_num_up=128, 
+#                             activation='ReLU', output_activation='Softmax', 
+#                             batch_norm=True, pool=False, unpool='nearest', name='resunet')(data_preprocessing)
+#     nerve_outputs = u2net_2d((None, None, 1), n_labels=2, 
+#                         filter_num_down=[64, 128, 256],
+#                         activation='ReLU', output_activation='Softmax', 
+#                         batch_norm=True, pool=False, unpool=False, deep_supervision=True, name='u2net')(data_preprocessing)
+#     nerve_model = swin_unet_2d(nerve_size + (1,), filter_num_begin=64, n_labels=2, depth=4, stack_num_down=2, stack_num_up=2, 
+#                             patch_size=(2, 4), num_heads=[4, 8, 8, 8], window_size=[4, 2, 2, 2], num_mlp=512, 
+#                             output_activation='Softmax', shift_window=True, name='swin_unet')
+#     nerve_model = transunet_2d(nerve_size + (1,), filter_num=[64, 128, 256], n_labels=12, stack_num_down=2, stack_num_up=2,
+#                                 embed_dim=768, num_mlp=3072, num_heads=12, num_transformer=12,
+#                                 activation='ReLU', mlp_activation='GELU', output_activation='Softmax', 
+#                                 batch_norm=True, pool=True, unpool='bilinear', name='transunet')
+    
+    nerve_model = keras.Model(nerve_inputs, nerve_outputs)
+
+    nerve_model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-4), loss="sparse_categorical_crossentropy")
+#     nerve_callbacks = [
+#         keras.callbacks.ModelCheckpoint(os.path.join(output_dir, "best_nerve_model_{}.h5".format(i_fold)), save_best_only=True, save_weights_only=True)
+#     ]
+
+    # Train the model, doing validation at the end of each epoch.
+#     if os.path.exists(os.path.join(output_dir, "best_unet_model_{}.h5".format(i_fold))):
+#         model.load_weights(os.path.join(output_dir, "best_unet_model_{}.h5".format(i_fold)))
+#     else:
+    epochs = 200
+    
+    train_slices, train_masks = get_nerve_slice(yolo_model, train_X, train_yolo, train_nerve, nerve_size)
+
+    nerve_model.fit(train_slices, train_masks, validation_split=0.2, epochs=epochs, batch_size=batch_size, verbose=0)
+    
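+    # Rebuild a bare U-Net and copy the trained weights over; the empty
+    # preprocessing Sequential holds no weights, so the weight lists line up.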
+    nerve_weights = nerve_model.get_weights()
+    
+    nerve_model = unet_2d((None, None, 1), [64, 128, 256], n_labels=2,
+                      stack_num_down=2, stack_num_up=1,
+                      activation='GELU', output_activation='Softmax', 
+                      batch_norm=True, pool='max', unpool='nearest', name='unet')
+    nerve_model.set_weights(nerve_weights)
+
+    final_width_train, final_pred_width_train, class_gt_mask_train, class_pred_train, class_gt_train = get_width_predictions(yolo_model, nerve_model, train_X, train_yolo, train_nerve, train_txt, pos_neg_cutoff, nerve_size)
+    
+    train_average_width_difference = np.average(np.abs(np.array(final_width_train) - np.array(final_pred_width_train)))
+    
+    train_gt_mask_class_score = accuracy_score(class_gt_train, class_gt_mask_train)
+    
+    train_pred_mask_class_score = accuracy_score(class_gt_train, class_pred_train)
+    
+    print('Training results fold {}: average width difference={:.2f}, ground truth mask classification={:.2f}, predicted mask classification={:.2f}'.format(i_fold, train_average_width_difference, train_gt_mask_class_score, train_pred_mask_class_score))
+
+#     videos = get_videos(participants[i_fold])
+#     lbl = test_txt[0]
+#     test_pred, test_gt = run_full_eval(nerve_model, yolo_model, videos, lbl, pos_neg_cutoff, img_size, nerve_size)
+    
+#     test_pred_mask_class_score = accuracy_score(test_gt, test_pred)
+    
+#     pred_video_pred = np.array([np.round(np.average(test_pred))])
+    
+#     if test_txt[0] == 'Positives':
+#         video_class = np.array([1])
+#     else:
+#         video_class = np.array([0])
+    
+#     test_video_pred_mask_score = accuracy_score(video_class, pred_video_pred)
+    
+    videos, lbls = get_test_videos()
+    test_pred, test_gt = run_full_eval_test_set(nerve_model, yolo_model, videos, lbls, pos_neg_cutoff, img_size, nerve_size)
+    
+    print('Test predictions:', test_pred)
+    print('Test ground truth:', test_gt)
+    
+    test_video_pred_mask_score = accuracy_score(test_gt, test_pred)
+    
+    print('Testing results fold {}: predicted mask video-level classification={:.2f}'.format(i_fold, test_video_pred_mask_score))
+    
+    all_train_frame_pred.append(class_pred_train)
+    all_train_frame_gt.append(class_gt_train)
+    
+#     all_test_frame_pred.append(test_pred)
+#     all_test_frame_gt.append(test_gt)
+    all_test_video_pred.append(test_pred)
+    all_test_video_gt.append(test_gt)
+    
+    i_fold += 1
+    
+all_train_frame_pred = np.concatenate(all_train_frame_pred)
+all_train_frame_gt = np.concatenate(all_train_frame_gt)
+
+# all_test_frame_pred = np.concatenate(all_test_frame_pred)
+all_test_video_pred = np.concatenate(all_test_video_pred)
+all_test_video_gt = np.concatenate(all_test_video_gt)
+
+final_train_frame_acc = accuracy_score(all_train_frame_gt, all_train_frame_pred)
+# final_test_frame_acc = accuracy_score(all_test_frame_gt, all_test_frame_pred)
+final_test_video_acc = accuracy_score(all_test_video_gt, all_test_video_pred)
+
+print('Final results: Train frame-level classification={:.2f}, Test video-level classification={:.2f}'.format(final_train_frame_acc, final_test_video_acc))
\ No newline at end of file