diff --git a/generate_yolo_regions.py b/generate_yolo_regions.py index a0351cdb1fd2125bd3cfe1947f2ff0e9a3544d9a..c39eb2d364950a2445f0e2f109bb1a3da784bca3 100644 --- a/generate_yolo_regions.py +++ b/generate_yolo_regions.py @@ -3,7 +3,7 @@ import os import time import numpy as np import torchvision -from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler, get_yolo_regions, classify_nerve_is_right, load_pnb_region_labels, calculate_angle +from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler, get_yolo_regions, classify_nerve_is_right, load_pnb_region_labels, calculate_angle, calculate_angle_video, get_needle_bb from torchvision.datasets.video_utils import VideoClips import torchvision as tv import csv @@ -24,7 +24,7 @@ if __name__ == "__main__": parser.add_argument('--output_dir', default='yolo_output', type=str, help='Location where yolo clips should be saved.') parser.add_argument('--num_frames', default=5, type=int) parser.add_argument('--stride', default=5, type=int) - parser.add_argument('--image_height', default=150, type=int) + parser.add_argument('--image_height', default=300, type=int) parser.add_argument('--image_width', default=400, type=int) args = parser.parse_args() @@ -50,7 +50,8 @@ if __name__ == "__main__": vc = tv.io.read_video(path)[0].permute(3, 0, 1, 2) is_right = classify_nerve_is_right(yolo_model, vc) - angle = calculate_angle(yolo_model, vc) +# video_angle = calculate_angle_video(yolo_model, vc) + needle_bb = get_needle_bb(yolo_model, vc) person_idx = path.split('/')[-2] label = path.split('/')[-3] @@ -72,10 +73,11 @@ if __name__ == "__main__": if vc_sub.size(1) < clip_depth: continue - for clip in get_yolo_regions(yolo_model, vc_sub, is_right, angle, image_width, image_height): + for clip in get_yolo_regions(yolo_model, vc_sub, is_right, needle_bb, image_width, image_height): clip = transforms(clip) ani = plot_video(clip) ani.save(os.path.join(args.output_dir, 'negative_yolo' + str(output_count) + '.mp4')) + print(output_count) output_count += 1 if positive_regions: @@ -92,11 +94,13 @@ if __name__ == "__main__": if vc_sub.size(1) < clip_depth: continue - for clip in get_yolo_regions(yolo_model, vc_sub, is_right, angle, image_width, image_height): + for clip in get_yolo_regions(yolo_model, vc_sub, is_right, needle_bb, image_width, image_height): clip = transforms(clip) ani = plot_video(clip) ani.save(os.path.join(args.output_dir, 'positive_yolo' + str(output_count) + '.mp4')) + print(output_count) output_count += 1 + elif vc.size(1) >= start_loc + clip_depth * frames_to_skip: end_loc = sub_region[1] if end_loc.strip().lower() == 'end': @@ -111,10 +115,12 @@ if __name__ == "__main__": if vc_sub.size(1) < clip_depth: continue - for clip in get_yolo_regions(yolo_model, vc_sub, is_right, angle, image_width, image_height): + + for clip in get_yolo_regions(yolo_model, vc_sub, is_right, needle_bb, image_width, image_height): clip = transforms(clip) ani = plot_video(clip) ani.save(os.path.join(args.output_dir, 'positive_yolo' + str(output_count) + '.mp4')) + print(output_count) output_count += 1 else: continue @@ -125,10 +131,11 @@ if __name__ == "__main__": if frames: vc_sub = torch.stack(frames, dim=1) if vc_sub.size(1) >= clip_depth: - for clip in get_yolo_regions(yolo_model, vc_sub, is_right, image_width, image_height): + for clip in get_yolo_regions(yolo_model, vc_sub, is_right, needle_bb, image_width, image_height): clip = transforms(clip) ani = plot_video(clip) ani.save(os.path.join(args.output_dir, 'positive_yolo' + 
str(output_count) + '.mp4')) + print(output_count) output_count += 1 elif label == 'Negatives': for j in range(0, vc.size(1) - clip_depth * frames_to_skip, clip_depth): @@ -136,11 +143,13 @@ if __name__ == "__main__": for k in range(j, j + clip_depth * frames_to_skip, frames_to_skip): frames.append(vc[:, k, :, :]) vc_sub = torch.stack(frames, dim=1) + if vc_sub.size(1) >= clip_depth: - for clip in get_yolo_regions(yolo_model, vc_sub, is_right, angle, image_width, image_height): + for clip in get_yolo_regions(yolo_model, vc_sub, is_right, needle_bb, image_width, image_height): clip = transforms(clip) ani = plot_video(clip) ani.save(os.path.join(args.output_dir, 'negative_yolo' + str(output_count) + '.mp4')) + print(output_count) output_count += 1 else: raise Exception('Invalid label') \ No newline at end of file diff --git a/sparse_coding_torch/onsd/classifier_model.py b/sparse_coding_torch/onsd/classifier_model.py new file mode 100644 index 0000000000000000000000000000000000000000..1162dc1f27ad3b8e28835b0b3086052f8d28f609 --- /dev/null +++ b/sparse_coding_torch/onsd/classifier_model.py @@ -0,0 +1,42 @@ +from tensorflow import keras +import numpy as np +import torch +import tensorflow as tf +import cv2 +import torchvision as tv +import torch +import torch.nn as nn +from sparse_coding_torch.utils import VideoGrayScaler, MinMaxScaler + +class ONSDClassifier(keras.layers.Layer): + def __init__(self): + super(ONSDClassifier, self).__init__() + + self.conv_1 = keras.layers.Conv2D(24, kernel_size=8, strides=4, activation='relu', padding='valid') + self.conv_2 = keras.layers.Conv2D(24, kernel_size=4, strides=2, activation='relu', padding='valid') + + self.flatten = keras.layers.Flatten() + + self.dropout = keras.layers.Dropout(0.5) + +# self.ff_1 = keras.layers.Dense(1000, activation='relu', use_bias=True) +# self.ff_2 = keras.layers.Dense(500, activation='relu', use_bias=True) +# self.ff_2 = keras.layers.Dense(20, activation='relu', use_bias=True) + self.ff_3 = keras.layers.Dense(20, activation='relu', use_bias=True) + self.ff_4 = keras.layers.Dense(1) + +# @tf.function + def call(self, activations): + x = self.conv_1(activations) + x = self.conv_2(x) + x = self.flatten(x) +# x = self.ff_1(x) +# x = self.dropout(x) +# x = self.ff_2(x) +# x = self.dropout(x) + x = self.ff_3(x) + x = self.dropout(x) + x = self.ff_4(x) + + return x + diff --git a/sparse_coding_torch/onsd/load_data.py b/sparse_coding_torch/onsd/load_data.py new file mode 100644 index 0000000000000000000000000000000000000000..de53a528ea30657fc574fa6ea84a59b9d8fa78b5 --- /dev/null +++ b/sparse_coding_torch/onsd/load_data.py @@ -0,0 +1,50 @@ +import numpy as np +import torchvision +import torch +from sklearn.model_selection import train_test_split +from sparse_coding_torch.utils import MinMaxScaler +from sparse_coding_torch.onsd.video_loader import get_participants, ONSDLoader +from sparse_coding_torch.utils import VideoGrayScaler +from typing import Sequence, Iterator +import csv +from sklearn.model_selection import train_test_split, GroupShuffleSplit, LeaveOneGroupOut, LeaveOneOut, StratifiedGroupKFold, StratifiedKFold, KFold, ShuffleSplit + +def load_onsd_videos(batch_size, input_size, yolo_model=None, mode=None, n_splits=None): + video_path = "/shared_data/bamc_onsd_data/preliminary_onsd_data" + + transforms = torchvision.transforms.Compose( + [torchvision.transforms.Grayscale(1), + MinMaxScaler(0, 255), + torchvision.transforms.Resize(input_size[:2]) + ]) + augment_transforms = torchvision.transforms.Compose( + 
[torchvision.transforms.RandomRotation(15) + ]) + dataset = ONSDLoader(video_path, input_size[1], input_size[0], transform=transforms, augmentation=augment_transforms, yolo_model=yolo_model) + + targets = dataset.get_labels() + + if mode == 'leave_one_out': + gss = LeaveOneGroupOut() + + groups = get_participants(dataset.get_filenames()) + + return gss.split(np.arange(len(targets)), targets, groups), dataset + elif mode == 'all_train': + train_idx = np.arange(len(targets)) + test_idx = None + + return [(train_idx, test_idx)], dataset + elif mode == 'k_fold': + gss = StratifiedGroupKFold(n_splits=n_splits, shuffle=True) + + groups = get_participants(dataset.get_filenames()) + + return gss.split(np.arange(len(targets)), targets, groups), dataset + else: +# gss = ShuffleSplit(n_splits=n_splits, test_size=0.2) + gss = GroupShuffleSplit(n_splits=n_splits, test_size=0.2) + + groups = get_participants(dataset.get_filenames()) + + return gss.split(np.arange(len(targets)), targets, groups), dataset \ No newline at end of file diff --git a/sparse_coding_torch/onsd/train_classifier.py b/sparse_coding_torch/onsd/train_classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..a879bf032944b5bec5f8a1b33a1bc907a5a888d3 --- /dev/null +++ b/sparse_coding_torch/onsd/train_classifier.py @@ -0,0 +1,299 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm import tqdm +import argparse +import os +from sparse_coding_torch.onsd.load_data import load_onsd_videos +from sparse_coding_torch.utils import SubsetWeightedRandomSampler, get_sample_weights +from sparse_coding_torch.sparse_model import SparseCode, ReconSparse, normalize_weights, normalize_weights_3d +from sparse_coding_torch.onsd.classifier_model import ONSDClassifier +import time +import numpy as np +from sklearn.metrics import f1_score, accuracy_score, confusion_matrix +import random +import pickle +import tensorflow.keras as keras +import tensorflow as tf +from sparse_coding_torch.onsd.train_sparse_model import sparse_loss +from yolov4.get_bounding_boxes import YoloModel +import torchvision +from sparse_coding_torch.utils import VideoGrayScaler, MinMaxScaler +import glob +import cv2 + +configproto = tf.compat.v1.ConfigProto() +configproto.gpu_options.polling_inactive_delay_msecs = 5000 +configproto.gpu_options.allow_growth = True +sess = tf.compat.v1.Session(config=configproto) +tf.compat.v1.keras.backend.set_session(sess) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--batch_size', default=12, type=int) + parser.add_argument('--kernel_size', default=15, type=int) + parser.add_argument('--kernel_depth', default=5, type=int) + parser.add_argument('--num_kernels', default=64, type=int) + parser.add_argument('--stride', default=1, type=int) + parser.add_argument('--max_activation_iter', default=150, type=int) + parser.add_argument('--activation_lr', default=1e-2, type=float) + parser.add_argument('--lr', default=5e-4, type=float) + parser.add_argument('--epochs', default=40, type=int) + parser.add_argument('--lam', default=0.05, type=float) + parser.add_argument('--output_dir', default='./output', type=str) + parser.add_argument('--sparse_checkpoint', default=None, type=str) + parser.add_argument('--checkpoint', default=None, type=str) + parser.add_argument('--splits', default=None, type=str, help='k_fold or leave_one_out or all_train') + parser.add_argument('--seed', default=26, type=int) + parser.add_argument('--train', action='store_true') + 
parser.add_argument('--num_positives', default=100, type=int) + parser.add_argument('--n_splits', default=5, type=int) + parser.add_argument('--save_train_test_splits', action='store_true') + parser.add_argument('--run_2d', action='store_true') + parser.add_argument('--balance_classes', action='store_true') + parser.add_argument('--dataset', default='pnb', type=str) + parser.add_argument('--train_sparse', action='store_true') + parser.add_argument('--mixing_ratio', type=float, default=1.0) + parser.add_argument('--sparse_lr', type=float, default=0.003) + parser.add_argument('--crop_height', type=int, default=285) + parser.add_argument('--crop_width', type=int, default=350) + parser.add_argument('--scale_factor', type=int, default=1) + parser.add_argument('--clip_depth', type=int, default=5) + parser.add_argument('--frames_to_skip', type=int, default=1) + + args = parser.parse_args() + + crop_height = args.crop_height + crop_width = args.crop_width + + image_height = int(crop_height / args.scale_factor) + image_width = int(crop_width / args.scale_factor) + clip_depth = args.clip_depth + + batch_size = args.batch_size + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + + output_dir = args.output_dir + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + with open(os.path.join(output_dir, 'arguments.txt'), 'w+') as out_f: + out_f.write(str(args)) + + yolo_model = YoloModel(args.dataset) + + all_errors = [] + + if args.run_2d: + inputs = keras.Input(shape=(image_height, image_width, clip_depth)) + else: + inputs = keras.Input(shape=(clip_depth, image_height, image_width, 1)) + + sparse_model = None + recon_model = None + + + splits, dataset = load_onsd_videos(args.batch_size, input_size=(image_height, image_width), yolo_model=yolo_model, mode=args.splits, n_splits=args.n_splits) + positive_class = 'Positives' + + overall_true = [] + overall_pred = [] + fn_ids = [] + fp_ids = [] + + i_fold = 0 + for train_idx, test_idx in splits: + train_sampler = torch.utils.data.SubsetRandomSampler(train_idx) +# train_sampler = SubsetWeightedRandomSampler(get_sample_weights(train_idx, dataset), train_idx, replacement=True) + train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, + sampler=train_sampler) + + if test_idx is not None: + test_sampler = torch.utils.data.SubsetRandomSampler(test_idx) + test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, + sampler=test_sampler) + +# with open(os.path.join(args.output_dir, 'test_videos_{}.txt'.format(i_fold)), 'w+') as test_videos_out: +# test_set = set([x for tup in test_loader for x in tup[2]]) +# test_videos_out.writelines(test_set) + else: + test_loader = None + + if args.checkpoint: + classifier_model = keras.models.load_model(args.checkpoint) + else: + classifier_inputs = keras.Input(shape=(image_height, image_width, 1)) + classifier_outputs = ONSDClassifier()(classifier_inputs) + + classifier_model = keras.Model(inputs=classifier_inputs, outputs=classifier_outputs) + + prediction_optimizer = keras.optimizers.Adam(learning_rate=args.lr) + filter_optimizer = tf.keras.optimizers.SGD(learning_rate=args.sparse_lr) + + best_so_far = float('-inf') + + criterion = keras.losses.BinaryCrossentropy(from_logits=True, reduction=keras.losses.Reduction.SUM) + + if args.train: + for epoch in range(args.epochs): + epoch_loss = 0 + t1 = time.perf_counter() + + y_true_train = None + y_pred_train = None + + for labels, local_batch, vid_f in tqdm(train_loader): + images = 
local_batch.permute(0, 2, 3, 1).numpy() + + torch_labels = np.zeros(len(labels)) + torch_labels[[i for i in range(len(labels)) if labels[i] == positive_class]] = 1 + torch_labels = np.expand_dims(torch_labels, axis=1) + + if args.train_sparse: + with tf.GradientTape() as tape: +# activations = sparse_model([images, tf.expand_dims(recon_model.trainable_weights[0], axis=0)]) + pred = classifier_model(activations) + loss = criterion(torch_labels, pred) + + print(loss) + else: + with tf.GradientTape() as tape: + pred = classifier_model(images) + loss = criterion(torch_labels, pred) + + epoch_loss += loss * local_batch.size(0) + + if args.train_sparse: + sparse_gradients, classifier_gradients = tape.gradient(loss, [recon_model.trainable_weights, classifier_model.trainable_weights]) + + prediction_optimizer.apply_gradients(zip(classifier_gradients, classifier_model.trainable_weights)) + + filter_optimizer.apply_gradients(zip(sparse_gradients, recon_model.trainable_weights)) + + if args.run_2d: + weights = normalize_weights(recon_model.get_weights(), args.num_kernels) + else: + weights = normalize_weights_3d(recon_model.get_weights(), args.num_kernels) + recon_model.set_weights(weights) + else: + gradients = tape.gradient(loss, classifier_model.trainable_weights) + + prediction_optimizer.apply_gradients(zip(gradients, classifier_model.trainable_weights)) + + + if y_true_train is None: + y_true_train = torch_labels + y_pred_train = tf.math.round(tf.math.sigmoid(pred)) + else: + y_true_train = tf.concat((y_true_train, torch_labels), axis=0) + y_pred_train = tf.concat((y_pred_train, tf.math.round(tf.math.sigmoid(pred))), axis=0) + + t2 = time.perf_counter() + + y_true = None + y_pred = None + test_loss = 0.0 + + eval_loader = test_loader + if args.splits == 'all_train': + eval_loader = train_loader + for labels, local_batch, vid_f in tqdm(eval_loader): + images = local_batch.permute(0, 2, 3, 1).numpy() + + torch_labels = np.zeros(len(labels)) + torch_labels[[i for i in range(len(labels)) if labels[i] == positive_class]] = 1 + torch_labels = np.expand_dims(torch_labels, axis=1) + + pred = classifier_model(images) + loss = criterion(torch_labels, pred) + + test_loss += loss + + if y_true is None: + y_true = torch_labels + y_pred = tf.math.round(tf.math.sigmoid(pred)) + else: + y_true = tf.concat((y_true, torch_labels), axis=0) + y_pred = tf.concat((y_pred, tf.math.round(tf.math.sigmoid(pred))), axis=0) + + t2 = time.perf_counter() + + y_true = tf.cast(y_true, tf.int32) + y_pred = tf.cast(y_pred, tf.int32) + + y_true_train = tf.cast(y_true_train, tf.int32) + y_pred_train = tf.cast(y_pred_train, tf.int32) + + f1 = f1_score(y_true, y_pred, average='macro') + accuracy = accuracy_score(y_true, y_pred) + + train_accuracy = accuracy_score(y_true_train, y_pred_train) + + print('epoch={}, i_fold={}, time={:.2f}, train_loss={:.2f}, test_loss={:.2f}, train_acc={:.2f}, test_f1={:.2f}, test_acc={:.2f}'.format(epoch, i_fold, t2-t1, epoch_loss, test_loss, train_accuracy, f1, accuracy)) + # print(epoch_loss) + if f1 >= best_so_far: + print("found better model") + # Save model parameters + classifier_model.save(os.path.join(output_dir, "best_classifier_{}.pt".format(i_fold))) +# recon_model.save(os.path.join(output_dir, "best_sparse_model_{}.pt".format(i_fold))) + pickle.dump(prediction_optimizer.get_weights(), open(os.path.join(output_dir, 'optimizer_{}.pt'.format(i_fold)), 'wb+')) + best_so_far = f1 + + classifier_model = keras.models.load_model(os.path.join(output_dir, "best_classifier_{}.pt".format(i_fold))) +# 
recon_model = keras.models.load_model(os.path.join(output_dir, 'best_sparse_model_{}.pt'.format(i_fold))) + + epoch_loss = 0 + + y_true = None + y_pred = None + + pred_dict = {} + gt_dict = {} + + t1 = time.perf_counter() + # test_videos = [vid_f for labels, local_batch, vid_f in batch for batch in test_loader] + raise Exception('Not yet implemented') + + t2 = time.perf_counter() + + print('i_fold={}, time={:.2f}'.format(i_fold, t2-t1)) + + y_true = tf.cast(y_true, tf.int32) + y_pred = tf.cast(y_pred, tf.int32) + + f1 = f1_score(y_true, y_pred, average='macro') + accuracy = accuracy_score(y_true, y_pred) + + fn_ids.extend(fn) + fp_ids.extend(fp) + + overall_true.extend(y_true) + overall_pred.extend(y_pred) + + print("Test f1={:.2f}, vid_acc={:.2f}".format(f1, accuracy)) + + print(confusion_matrix(y_true, y_pred)) + + i_fold += 1 + + fp_fn_file = os.path.join(args.output_dir, 'fp_fn.txt') + with open(fp_fn_file, 'w+') as in_f: + in_f.write('FP:\n') + in_f.write(str(fp_ids) + '\n\n') + in_f.write('FN:\n') + in_f.write(str(fn_ids) + '\n\n') + + overall_true = np.array(overall_true) + overall_pred = np.array(overall_pred) + + final_f1 = f1_score(overall_true, overall_pred, average='macro') + final_acc = accuracy_score(overall_true, overall_pred) + final_conf = confusion_matrix(overall_true, overall_pred) + + print("Final accuracy={:.2f}, f1={:.2f}".format(final_acc, final_f1)) + print(final_conf) + diff --git a/sparse_coding_torch/onsd/train_sparse_model.py b/sparse_coding_torch/onsd/train_sparse_model.py new file mode 100644 index 0000000000000000000000000000000000000000..0352d530d505f175e1e6cc3f3c5323d1be64a4c4 --- /dev/null +++ b/sparse_coding_torch/onsd/train_sparse_model.py @@ -0,0 +1,169 @@ +import time +import numpy as np +import torch +from matplotlib import pyplot as plt +from matplotlib import cm +from matplotlib.animation import FuncAnimation +from tqdm import tqdm +import argparse +import os +from sparse_coding_torch.onsd.load_data import load_onsd_videos +import tensorflow.keras as keras +import tensorflow as tf +from sparse_coding_torch.sparse_model import normalize_weights_3d, normalize_weights, SparseCode, load_pytorch_weights, ReconSparse +import random + +def sparse_loss(images, recon, activations, batch_size, lam, stride): + loss = 0.5 * (1/batch_size) * tf.math.reduce_sum(tf.math.pow(images - recon, 2)) + loss += lam * tf.reduce_mean(tf.math.reduce_sum(tf.math.abs(tf.reshape(activations, (batch_size, -1))), axis=1)) + return loss + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--batch_size', default=32, type=int) + parser.add_argument('--kernel_size', default=15, type=int) + parser.add_argument('--kernel_depth', default=5, type=int) + parser.add_argument('--num_kernels', default=32, type=int) + parser.add_argument('--stride', default=1, type=int) + parser.add_argument('--max_activation_iter', default=300, type=int) + parser.add_argument('--activation_lr', default=1e-2, type=float) + parser.add_argument('--lr', default=0.003, type=float) + parser.add_argument('--epochs', default=150, type=int) + parser.add_argument('--lam', default=0.05, type=float) + parser.add_argument('--output_dir', default='./output', type=str) + parser.add_argument('--seed', default=42, type=int) + parser.add_argument('--run_2d', action='store_true') + parser.add_argument('--save_filters', action='store_true') + parser.add_argument('--optimizer', default='sgd', type=str) + parser.add_argument('--dataset', default='onsd', type=str) + 
parser.add_argument('--crop_height', type=int, default=400) + parser.add_argument('--crop_width', type=int, default=400) + parser.add_argument('--scale_factor', type=int, default=1) + parser.add_argument('--clip_depth', type=int, default=5) + parser.add_argument('--frames_to_skip', type=int, default=1) + + + args = parser.parse_args() + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + + crop_height = args.crop_height + crop_width = args.crop_width + + image_height = int(crop_height / args.scale_factor) + image_width = int(crop_width / args.scale_factor) + clip_depth = args.clip_depth + + output_dir = args.output_dir + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + with open(os.path.join(output_dir, 'arguments.txt'), 'w+') as out_f: + out_f.write(str(args)) + + splits, dataset = load_onsd_videos(args.batch_size, input_size=(image_height, image_width, clip_depth), mode='all_train') + train_idx, test_idx = splits[0] + + train_sampler = torch.utils.data.SubsetRandomSampler(train_idx) + train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, + sampler=train_sampler) + + print('Loaded', len(train_loader), 'train examples') + + example_data = next(iter(train_loader)) + + if args.run_2d: + inputs = keras.Input(shape=(image_height, image_width, 5)) + else: + inputs = keras.Input(shape=(5, image_height, image_width, 1)) + + filter_inputs = keras.Input(shape=(5, args.kernel_size, args.kernel_size, 1, args.num_kernels), dtype='float32') + + output = SparseCode(batch_size=args.batch_size, image_height=image_height, image_width=image_width, clip_depth=clip_depth, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, kernel_depth=args.kernel_depth, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(inputs, filter_inputs) + + sparse_model = keras.Model(inputs=(inputs, filter_inputs), outputs=output) + + recon_inputs = keras.Input(shape=(1, (image_height - args.kernel_size) // args.stride + 1, (image_width - args.kernel_size) // args.stride + 1, args.num_kernels)) + + recon_outputs = ReconSparse(batch_size=args.batch_size, image_height=image_height, image_width=image_width, clip_depth=clip_depth, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, kernel_depth=args.kernel_depth, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(recon_inputs) + + recon_model = keras.Model(inputs=recon_inputs, outputs=recon_outputs) + + if args.save_filters: + if args.run_2d: + filters = plot_filters(tf.stack(recon_model.get_weights(), axis=0)) + else: + filters = plot_filters(recon_model.get_weights()[0]) + filters.save(os.path.join(args.output_dir, 'filters_start.mp4')) + + learning_rate = args.lr + if args.optimizer == 'sgd': + filter_optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate) + else: + filter_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate) + + loss_log = [] + best_so_far = float('inf') + + for epoch in range(args.epochs): + epoch_loss = 0 + running_loss = 0.0 + epoch_start = time.perf_counter() + + num_iters = 0 + + for labels, local_batch, vid_f in tqdm(train_loader): + if local_batch.size(0) != args.batch_size: + continue + if args.run_2d: + images = local_batch.squeeze(1).permute(0, 2, 3, 1).numpy() + else: + images = 
local_batch.permute(0, 2, 3, 4, 1).numpy() + + activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))])) + + with tf.GradientTape() as tape: + recon = recon_model(activations) + loss = sparse_loss(images, recon, activations, args.batch_size, args.lam, args.stride) + + epoch_loss += loss * local_batch.size(0) + running_loss += loss * local_batch.size(0) + + gradients = tape.gradient(loss, recon_model.trainable_weights) + + filter_optimizer.apply_gradients(zip(gradients, recon_model.trainable_weights)) + + if args.run_2d: + weights = normalize_weights(recon_model.get_weights(), args.num_kernels) + else: + weights = normalize_weights_3d(recon_model.get_weights(), args.num_kernels) + recon_model.set_weights(weights) + + num_iters += 1 + + epoch_end = time.perf_counter() + epoch_loss /= len(train_loader.sampler) + + if args.save_filters and epoch % 2 == 0: + if args.run_2d: + filters = plot_filters(tf.stack(recon_model.get_weights(), axis=0)) + else: + filters = plot_filters(recon_model.get_weights()[0]) + filters.save(os.path.join(args.output_dir, 'filters_' + str(epoch) +'.mp4')) + + if epoch_loss < best_so_far: + print("found better model") + # Save model parameters + recon_model.save(os.path.join(output_dir, "best_sparse.pt")) + best_so_far = epoch_loss + + loss_log.append(epoch_loss) + print('epoch={}, epoch_loss={:.2f}, time={:.2f}'.format(epoch, epoch_loss, epoch_end - epoch_start)) + + plt.plot(loss_log) + + plt.savefig(os.path.join(output_dir, 'loss_graph.png')) diff --git a/sparse_coding_torch/onsd/video_loader.py b/sparse_coding_torch/onsd/video_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..b644ac915d39f8de86085d250edf97b419afb2da --- /dev/null +++ b/sparse_coding_torch/onsd/video_loader.py @@ -0,0 +1,122 @@ +from os import listdir +from os.path import isfile +from os.path import join +from os.path import isdir +from os.path import abspath +from os.path import exists +import json +import glob + +from PIL import Image +from torchvision.transforms import ToTensor +from torchvision.datasets.video_utils import VideoClips +from tqdm import tqdm +import torch +import numpy as np +from torch.utils.data import Dataset +from torch.utils.data import DataLoader +from torchvision.io import read_video +import torchvision as tv +from torch import nn +import torchvision.transforms.functional as tv_f +import csv +import random +import cv2 +from yolov4.get_bounding_boxes import YoloModel + +def get_participants(filenames): + return [f.split('/')[-2] for f in filenames] + +def get_yolo_region_onsd(yolo_model, frame): + orig_height = frame.size(1) + orig_width = frame.size(2) + + bounding_boxes, classes, scores = yolo_model.get_bounding_boxes(frame.swapaxes(0, 2).swapaxes(0, 1).numpy()) + bounding_boxes = bounding_boxes.squeeze(0) + classes = classes.squeeze(0) + scores = scores.squeeze(0) + + all_frames = [] + for bb, class_pred, score in zip(bounding_boxes, classes, scores): + if class_pred != 0: + continue + + lower_y = round((bb[0] * orig_height)) + upper_y = round((bb[2] * orig_height)) + lower_x = round((bb[1] * orig_width)) + upper_x = round((bb[3] * orig_width)) + + trimmed_frame = frame[:, lower_y:upper_y, lower_x:upper_x] + +# cv2.imwrite('test_2.png', frame.numpy().swapaxes(0,1).swapaxes(1,2)) +# cv2.imwrite('test_3.png', trimmed_frame.numpy().swapaxes(0,1).swapaxes(1,2)) + + return trimmed_frame + + return None + +class ONSDLoader(Dataset): + + def __init__(self, video_path, clip_width, 
clip_height, transform=None, augmentation=None, yolo_model=None): + self.transform = transform + self.augmentation = augmentation + self.labels = [name for name in listdir(video_path) if isdir(join(video_path, name))] + + clip_cache_file = 'clip_cache_onsd_{}_{}_sparse.pt'.format(clip_width, clip_height) + + self.videos = [] + for label in self.labels: + self.videos.extend([(label, abspath(join(video_path, label, f)), f) for f in glob.glob(join(video_path, label, '*', '*.mp4'))]) + + self.clips = [] + + if exists(clip_cache_file): + self.clips = torch.load(open(clip_cache_file, 'rb')) + else: + vid_idx = 0 + for label, path, _ in tqdm(self.videos): + vc = tv.io.read_video(path)[0].permute(3, 0, 1, 2) + + for j in range(vc.size(1)): + frame = vc[:, j, :, :] + + if yolo_model is not None: + frame = get_yolo_region_onsd(yolo_model, frame) + + if frame is None: + continue + + if self.transform: + frame = self.transform(frame) + + self.clips.append((self.videos[vid_idx][0], frame, self.videos[vid_idx][2])) + + vid_idx += 1 + + torch.save(self.clips, open(clip_cache_file, 'wb+')) + + num_positive = len([clip[0] for clip in self.clips if clip[0] == 'Positives']) + num_negative = len([clip[0] for clip in self.clips if clip[0] == 'Negatives']) + + random.shuffle(self.clips) + + print('Loaded', num_positive, 'positive examples.') + print('Loaded', num_negative, 'negative examples.') + + def get_filenames(self): + return [self.clips[i][2] for i in range(len(self.clips))] + + def get_video_labels(self): + return [self.videos[i][0] for i in range(len(self.videos))] + + def get_labels(self): + return [self.clips[i][0] for i in range(len(self.clips))] + + def __getitem__(self, index): + label, frame, vid_f = self.clips[index] + if self.augmentation: + frame = self.augmentation(frame) + return (label, frame, vid_f) + + def __len__(self): + return len(self.clips) diff --git a/sparse_coding_torch/pnb/classifier_model.py b/sparse_coding_torch/pnb/classifier_model.py new file mode 100644 index 0000000000000000000000000000000000000000..8f363e6e623f719dc27e005f77238ae6d1daae67 --- /dev/null +++ b/sparse_coding_torch/pnb/classifier_model.py @@ -0,0 +1,134 @@ +from tensorflow import keras +import numpy as np +import torch +import tensorflow as tf +import cv2 +import torchvision as tv +import torch +import torch.nn as nn +from sparse_coding_torch.utils import VideoGrayScaler, MinMaxScaler + +class PNBClassifier(keras.layers.Layer): + def __init__(self): + super(PNBClassifier, self).__init__() + +# self.max_pool = keras.layers.MaxPooling2D(pool_size=(8, 8), strides=(2, 2)) + self.conv_1 = keras.layers.Conv2D(32, kernel_size=(8, 8), strides=(4, 4), activation='relu', padding='valid') + self.conv_2 = keras.layers.Conv2D(32, kernel_size=4, strides=2, activation='relu', padding='valid') +# self.conv_3 = keras.layers.Conv2D(12, kernel_size=4, strides=1, activation='relu', padding='valid') +# self.conv_4 = keras.layers.Conv2D(16, kernel_size=4, strides=2, activation='relu', padding='valid') + + self.flatten = keras.layers.Flatten() + +# self.dropout = keras.layers.Dropout(0.5) + +# self.ff_1 = keras.layers.Dense(1000, activation='relu', use_bias=True) + self.ff_2 = keras.layers.Dense(40, activation='relu', use_bias=True) + self.ff_3 = keras.layers.Dense(20, activation='relu', use_bias=True) + self.ff_4 = keras.layers.Dense(1) + +# @tf.function + def call(self, activations): + x = tf.squeeze(activations, axis=1) +# x = self.max_pool(x) +# print(x.shape) + x = self.conv_1(x) +# print(x.shape) + x = self.conv_2(x) +# 
print(x.shape) +# raise Exception +# x = self.conv_3(x) +# print(x.shape) +# x = self.conv_4(x) +# raise Exception + x = self.flatten(x) +# x = self.ff_1(x) +# x = self.dropout(x) + x = self.ff_2(x) +# x = self.dropout(x) + x = self.ff_3(x) +# x = self.dropout(x) + x = self.ff_4(x) + + return x + +class PNBTemporalClassifier(keras.layers.Layer): + def __init__(self): + super(PNBTemporalClassifier, self).__init__() + self.conv_1 = keras.layers.Conv3D(24, kernel_size=(5, 200, 50), strides=(1, 1, 10), activation='relu', padding='valid') + self.conv_2 = keras.layers.Conv1D(48, kernel_size=8, strides=4, activation='relu', padding='valid') + + self.ff_1 = keras.layers.Dense(100, activation='relu', use_bias=True) + +# self.gru = keras.layers.GRU(25) + + self.flatten = keras.layers.Flatten() + + self.ff_2 = keras.layers.Dense(10, activation='relu', use_bias=True) + self.ff_3 = keras.layers.Dense(1) + +# @tf.function + def call(self, clip): + width = clip.shape[3] + height = clip.shape[2] + depth = clip.shape[1] + + x = tf.expand_dims(clip, axis=4) +# x = tf.reshape(clip, (-1, height, width, 1)) + + x = self.conv_1(x) + x = tf.squeeze(x, axis=1) + x = tf.squeeze(x, axis=1) + x = self.conv_2(x) + + x = self.flatten(x) + x = self.ff_1(x) + +# x = tf.reshape(x, (-1, 5, 100)) +# x = self.gru(x) + + x = self.ff_2(x) + x = self.ff_3(x) + + return x + +class MobileModelPNB(keras.Model): + def __init__(self, sparse_weights, classifier_model, batch_size, image_height, image_width, clip_depth, out_channels, kernel_size, kernel_depth, stride, lam, activation_lr, max_activation_iter, run_2d): + super().__init__() + self.sparse_code = SparseCode(batch_size=batch_size, image_height=image_height, image_width=image_width, clip_depth=clip_depth, in_channels=1, out_channels=out_channels, kernel_size=kernel_size, kernel_depth=kernel_depth, stride=stride, lam=lam, activation_lr=activation_lr, max_activation_iter=max_activation_iter, run_2d=run_2d, padding='VALID') + self.classifier = classifier_model + + self.out_channels = out_channels + self.stride = stride + self.lam = lam + self.activation_lr = activation_lr + self.max_activation_iter = max_activation_iter + self.batch_size = batch_size + self.run_2d = run_2d + + if run_2d: + weight_list = np.split(sparse_weights, 5, axis=0) + self.filters_1 = tf.Variable(initial_value=weight_list[0].squeeze(0), dtype='float32', trainable=False) + self.filters_2 = tf.Variable(initial_value=weight_list[1].squeeze(0), dtype='float32', trainable=False) + self.filters_3 = tf.Variable(initial_value=weight_list[2].squeeze(0), dtype='float32', trainable=False) + self.filters_4 = tf.Variable(initial_value=weight_list[3].squeeze(0), dtype='float32', trainable=False) + self.filters_5 = tf.Variable(initial_value=weight_list[4].squeeze(0), dtype='float32', trainable=False) + else: + self.filters = tf.Variable(initial_value=sparse_weights, dtype='float32', trainable=False) + + @tf.function + def call(self, images): +# images = tf.squeeze(tf.image.rgb_to_grayscale(images), axis=-1) + images = tf.transpose(images, perm=[0, 2, 3, 1]) + images = images / 255 + + if self.run_2d: + activations = self.sparse_code(images, [tf.stop_gradient(self.filters_1), tf.stop_gradient(self.filters_2), tf.stop_gradient(self.filters_3), tf.stop_gradient(self.filters_4), tf.stop_gradient(self.filters_5)]) + activations = tf.expand_dims(activations, axis=1) + else: + activations = self.sparse_code(images, tf.stop_gradient(self.filters)) + + pred = tf.math.round(tf.math.sigmoid(self.classifier(activations))) +# pred = 
tf.math.reduce_sum(activations) + + return pred diff --git a/sparse_coding_torch/generate_tflite.py b/sparse_coding_torch/pnb/generate_tflite.py similarity index 95% rename from sparse_coding_torch/generate_tflite.py rename to sparse_coding_torch/pnb/generate_tflite.py index 2a8b2a245fd602208b954873d0c9a7701307993b..7edde64bca9eac02cab3aaafd19b905aa4356299 100644 --- a/sparse_coding_torch/generate_tflite.py +++ b/sparse_coding_torch/pnb/generate_tflite.py @@ -6,8 +6,8 @@ import cv2 import torchvision as tv import torch import torch.nn as nn -from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler -from sparse_coding_torch.keras_model import MobileModelPNB +from sparse_coding_torch.utils import VideoGrayScaler, MinMaxScaler +from sparse_coding_torch.pnb.classifier_model import MobileModelPNB import argparse if __name__ == "__main__": diff --git a/sparse_coding_torch/load_data.py b/sparse_coding_torch/pnb/load_data.py similarity index 54% rename from sparse_coding_torch/load_data.py rename to sparse_coding_torch/pnb/load_data.py index 6dea292a90a127da77ec9c6ce16d8c234a2cb5da..3059e04cb4ad8daef9123631e7827c0cff5a3568 100644 --- a/sparse_coding_torch/load_data.py +++ b/sparse_coding_torch/pnb/load_data.py @@ -2,72 +2,12 @@ import numpy as np import torchvision import torch from sklearn.model_selection import train_test_split -from sparse_coding_torch.video_loader import MinMaxScaler -from sparse_coding_torch.video_loader import YoloClipLoader, get_ptx_participants, PNBLoader, get_participants, NeedleLoader, ONSDLoader -from sparse_coding_torch.video_loader import VideoGrayScaler +from sparse_coding_torch.utils import MinMaxScaler +from sparse_coding_torch.pnb.video_loader import PNBLoader, get_participants, NeedleLoader +from sparse_coding_torch.utils import VideoGrayScaler from typing import Sequence, Iterator import csv from sklearn.model_selection import train_test_split, GroupShuffleSplit, LeaveOneGroupOut, LeaveOneOut, StratifiedGroupKFold, StratifiedKFold, KFold, ShuffleSplit - -def load_yolo_clips(batch_size, mode, num_clips=1, num_positives=100, device=None, n_splits=None, sparse_model=None, whole_video=False, positive_videos=None): - video_path = "/shared_data/YOLO_Updated_PL_Model_Results/" - - video_to_participant = get_ptx_participants() - - transforms = torchvision.transforms.Compose( - [VideoGrayScaler(), -# MinMaxScaler(0, 255), - torchvision.transforms.Normalize((0.2592,), (0.1251,)), - ]) - augment_transforms = torchvision.transforms.Compose( - [torchvision.transforms.RandomRotation(45), - torchvision.transforms.RandomHorizontalFlip(), - torchvision.transforms.CenterCrop((100, 200)) - ]) - if whole_video: - dataset = YoloVideoLoader(video_path, num_clips=num_clips, num_positives=num_positives, transform=transforms, augment_transform=augment_transforms, sparse_model=sparse_model, device=device) - else: - dataset = YoloClipLoader(video_path, num_clips=num_clips, num_positives=num_positives, positive_videos=positive_videos, transform=transforms, augment_transform=augment_transforms, sparse_model=sparse_model, device=device) - - targets = dataset.get_labels() - - if mode == 'leave_one_out': - gss = LeaveOneGroupOut() - -# groups = [v for v in dataset.get_filenames()] - groups = [video_to_participant[v.lower().replace('_clean', '')] for v in dataset.get_filenames()] - - return gss.split(np.arange(len(targets)), targets, groups), dataset - elif mode == 'all_train': - train_idx = np.arange(len(targets)) -# train_sampler = 
torch.utils.data.SubsetRandomSampler(train_idx) -# train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, -# sampler=train_sampler) -# test_loader = None - - return [(train_idx, None)], dataset - elif mode == 'k_fold': - gss = StratifiedGroupKFold(n_splits=n_splits) - - groups = [video_to_participant[v.lower().replace('_clean', '')] for v in dataset.get_filenames()] - - return gss.split(np.arange(len(targets)), targets, groups), dataset - else: - gss = GroupShuffleSplit(n_splits=n_splits, test_size=0.2) - - groups = [video_to_participant[v.lower().replace('_clean', '')] for v in dataset.get_filenames()] - - train_idx, test_idx = list(gss.split(np.arange(len(targets)), targets, groups))[0] - - train_sampler = torch.utils.data.SubsetRandomSampler(train_idx) - train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, - sampler=train_sampler) - - test_sampler = torch.utils.data.SubsetRandomSampler(test_idx) - test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, - sampler=test_sampler) - - return train_loader, test_loader, dataset def get_sample_weights(train_idx, dataset): dataset = list(dataset) @@ -112,7 +52,7 @@ class SubsetWeightedRandomSampler(torch.utils.data.Sampler[int]): return len(self.indicies) def load_pnb_videos(yolo_model, batch_size, input_size, crop_size=None, mode=None, classify_mode=False, balance_classes=False, device=None, n_splits=None, sparse_model=None, frames_to_skip=1): - video_path = "/shared_data/bamc_pnb_data/full_training_data" + video_path = "/shared_data/bamc_pnb_data/revised_training_data" # video_path = '/home/dwh48@drexel.edu/pnb_videos_for_testing/train' # video_path = '/home/dwh48@drexel.edu/special_splits/train' @@ -177,46 +117,6 @@ def load_pnb_videos(yolo_model, batch_size, input_size, crop_size=None, mode=Non return gss.split(np.arange(len(targets)), targets, groups), dataset -def load_onsd_videos(batch_size, input_size, mode=None, n_splits=None): - video_path = "/shared_data/bamc_onsd_data/preliminary_onsd_data" - - transforms = torchvision.transforms.Compose( - [VideoGrayScaler(), - MinMaxScaler(0, 255), - torchvision.transforms.Resize(input_size[:2]) - ]) - augment_transforms = torchvision.transforms.Compose( - [torchvision.transforms.RandomRotation(15) - ]) - dataset = ONSDLoader(video_path, input_size[1], input_size[0], input_size[2], num_frames=5, transform=transforms, augmentation=augment_transforms) - - targets = dataset.get_labels() - - if mode == 'leave_one_out': - gss = LeaveOneGroupOut() - - groups = get_participants(dataset.get_filenames()) - - return gss.split(np.arange(len(targets)), targets, groups), dataset - elif mode == 'all_train': - train_idx = np.arange(len(targets)) - test_idx = None - - return [(train_idx, test_idx)], dataset - elif mode == 'k_fold': - gss = StratifiedGroupKFold(n_splits=n_splits, shuffle=True) - - groups = get_participants(dataset.get_filenames()) - - return gss.split(np.arange(len(targets)), targets, groups), dataset - else: -# gss = ShuffleSplit(n_splits=n_splits, test_size=0.2) - gss = GroupShuffleSplit(n_splits=n_splits, test_size=0.2) - - groups = get_participants(dataset.get_filenames()) - - return gss.split(np.arange(len(targets)), targets, groups), dataset - def load_needle_clips(batch_size, input_size): video_path = "/shared_data/bamc_pnb_data/needle_data/non_needle" diff --git a/sparse_coding_torch/pnb/train_classifier.py b/sparse_coding_torch/pnb/train_classifier.py new file mode 100644 index 
0000000000000000000000000000000000000000..556bd5c991b52eda1930384f61a24454860ff420 --- /dev/null +++ b/sparse_coding_torch/pnb/train_classifier.py @@ -0,0 +1,442 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm import tqdm +import argparse +import os +from sparse_coding_torch.pnb.load_data import load_pnb_videos +from sparse_coding_torch.utils import SubsetWeightedRandomSampler, get_sample_weights +from sparse_coding_torch.sparse_model import SparseCode, ReconSparse, normalize_weights, normalize_weights_3d +from sparse_coding_torch.pnb.classifier_model import PNBClassifier, PNBTemporalClassifier +import time +import numpy as np +from sklearn.metrics import f1_score, accuracy_score, confusion_matrix +import random +import pickle +import tensorflow.keras as keras +import tensorflow as tf +from sparse_coding_torch.pnb.train_sparse_model import sparse_loss +from yolov4.get_bounding_boxes import YoloModel +import torchvision +from sparse_coding_torch.utils import VideoGrayScaler, MinMaxScaler +import glob +import cv2 + +configproto = tf.compat.v1.ConfigProto() +configproto.gpu_options.polling_inactive_delay_msecs = 5000 +configproto.gpu_options.allow_growth = True +sess = tf.compat.v1.Session(config=configproto) +tf.compat.v1.keras.backend.set_session(sess) + +def calculate_pnb_scores(input_videos, labels, yolo_model, sparse_model, recon_model, classifier_model, image_width, image_height, transform): + all_predictions = [] + + numerical_labels = [] + for label in labels: + if label == 'Positives': + numerical_labels.append(1.0) + else: + numerical_labels.append(0.0) + + final_list = [] + fp_ids = [] + fn_ids = [] + for v_idx, f in tqdm(enumerate(input_videos)): + vc = tv.io.read_video(f)[0].permute(3, 0, 1, 2) + is_right = classify_nerve_is_right(yolo_model, vc) + needle_bb = get_needle_bb(yolo_model, vc) + + all_preds = [] + for j in range(vc.size(1) - 5, vc.size(1) - 25, -5): + if j-5 < 0: + break + + vc_sub = vc[:, j-5:j, :, :] + + if vc_sub.size(1) < 5: + continue + + clip = get_yolo_regions(yolo_model, vc_sub, is_right, needle_bb, image_width, image_height) + + if not clip: + continue + + clip = clip[0] + clip = transform(clip).to(torch.float32) + clip = tf.expand_dims(clip, axis=4) + + if sparse_model is not None: + activations = tf.stop_gradient(sparse_model([clip, tf.stop_gradient(tf.expand_dims(recon_model.weights[0], axis=0))])) + + pred = tf.math.round(tf.math.sigmoid(classifier_model(activations))) + else: + pred = tf.math.round(tf.math.sigmoid(classifier_model(clip))) + + all_preds.append(pred) + + if all_preds: + final_pred = np.round(np.mean(np.array(all_preds))) + else: + final_pred = 1.0 + + if final_pred != numerical_labels[v_idx]: + if final_pred == 0: + fn_ids.append(f) + else: + fp_ids.append(f) + + final_list.append(final_pred) + + return np.array(final_list), np.array(numerical_labels), fn_ids, fp_ids + +def calculate_pnb_scores_skipped_frames(input_videos, labels, yolo_model, sparse_model, recon_model, classifier_model, frames_to_skip, image_width, image_height, transform): + all_predictions = [] + + numerical_labels = [] + for label in labels: + if label == 'Positives': + numerical_labels.append(1.0) + else: + numerical_labels.append(0.0) + + final_list = [] + fp_ids = [] + fn_ids = [] + for v_idx, f in tqdm(enumerate(input_videos)): + vc = tv.io.read_video(f)[0].permute(3, 0, 1, 2) + is_right = classify_nerve_is_right(yolo_model, vc) + needle_bb = get_needle_bb(yolo_model, vc) + + all_preds = [] + + frames = [] + for k in 
range(vc.size(1) - 1, vc.size(1) - 5 * frames_to_skip, -frames_to_skip): + frames.append(vc[:, k, :, :]) + vc_sub = torch.stack(frames, dim=1) + + if vc_sub.size(1) < 5: + continue + + clip = get_yolo_regions(yolo_model, vc_sub, is_right, needle_bb, image_width, image_height) + + if clip: + clip = clip[0] + clip = transform(clip).to(torch.float32) + clip = tf.expand_dims(clip, axis=4) + + if sparse_model is not None: + activations = tf.stop_gradient(sparse_model([clip, tf.stop_gradient(tf.expand_dims(recon_model.weights[0], axis=0))])) + + pred = tf.math.round(tf.math.sigmoid(classifier_model(activations))) + else: + pred = tf.math.round(tf.math.sigmoid(classifier_model(clip))) + else: + pred = 1.0 + + if pred != numerical_labels[v_idx]: + if pred == 0: + fn_ids.append(f) + else: + fp_ids.append(f) + + final_list.append(pred) + + return np.array(final_list), np.array(numerical_labels), fn_ids, fp_ids + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--batch_size', default=12, type=int) + parser.add_argument('--kernel_size', default=15, type=int) + parser.add_argument('--kernel_depth', default=5, type=int) + parser.add_argument('--num_kernels', default=64, type=int) + parser.add_argument('--stride', default=1, type=int) + parser.add_argument('--max_activation_iter', default=150, type=int) + parser.add_argument('--activation_lr', default=1e-2, type=float) + parser.add_argument('--lr', default=5e-4, type=float) + parser.add_argument('--epochs', default=40, type=int) + parser.add_argument('--lam', default=0.05, type=float) + parser.add_argument('--output_dir', default='./output', type=str) + parser.add_argument('--sparse_checkpoint', default=None, type=str) + parser.add_argument('--checkpoint', default=None, type=str) + parser.add_argument('--splits', default=None, type=str, help='k_fold or leave_one_out or all_train') + parser.add_argument('--seed', default=26, type=int) + parser.add_argument('--train', action='store_true') + parser.add_argument('--num_positives', default=100, type=int) + parser.add_argument('--n_splits', default=5, type=int) + parser.add_argument('--save_train_test_splits', action='store_true') + parser.add_argument('--run_2d', action='store_true') + parser.add_argument('--balance_classes', action='store_true') + parser.add_argument('--dataset', default='pnb', type=str) + parser.add_argument('--train_sparse', action='store_true') + parser.add_argument('--mixing_ratio', type=float, default=1.0) + parser.add_argument('--sparse_lr', type=float, default=0.003) + parser.add_argument('--crop_height', type=int, default=285) + parser.add_argument('--crop_width', type=int, default=350) + parser.add_argument('--scale_factor', type=int, default=1) + parser.add_argument('--clip_depth', type=int, default=5) + parser.add_argument('--frames_to_skip', type=int, default=1) + + args = parser.parse_args() + + crop_height = args.crop_height + crop_width = args.crop_width + + image_height = int(crop_height / args.scale_factor) + image_width = int(crop_width / args.scale_factor) + clip_depth = args.clip_depth + + batch_size = args.batch_size + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + + output_dir = args.output_dir + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + with open(os.path.join(output_dir, 'arguments.txt'), 'w+') as out_f: + out_f.write(str(args)) + + yolo_model = YoloModel(args.dataset) + + all_errors = [] + + if args.run_2d: + inputs = keras.Input(shape=(image_height, image_width, 
clip_depth)) + else: + inputs = keras.Input(shape=(clip_depth, image_height, image_width, 1)) + +# filter_inputs = keras.Input(shape=(args.kernel_depth, args.kernel_size, args.kernel_size, 1, args.num_kernels), dtype='float32') + +# output = SparseCode(batch_size=args.batch_size, image_height=image_height, image_width=image_width, clip_depth=clip_depth, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, kernel_depth=args.kernel_depth, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(inputs, filter_inputs) + +# sparse_model = keras.Model(inputs=(inputs, filter_inputs), outputs=output) + +# recon_inputs = keras.Input(shape=((clip_depth - args.kernel_depth) // 1 + 1, (image_height - args.kernel_size) // args.stride + 1, (image_width - args.kernel_size) // args.stride + 1, args.num_kernels)) + +# recon_outputs = ReconSparse(batch_size=args.batch_size, image_height=image_height, image_width=image_width, clip_depth=clip_depth, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, kernel_depth=args.kernel_depth, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(recon_inputs) + +# recon_model = keras.Model(inputs=recon_inputs, outputs=recon_outputs) + +# if args.sparse_checkpoint: +# recon_model.set_weights(keras.models.load_model(args.sparse_checkpoint).get_weights()) + + sparse_model = None + recon_model = None + + splits, dataset = load_pnb_videos(yolo_model, args.batch_size, input_size=(image_height, image_width, clip_depth), crop_size=(crop_height, crop_width, clip_depth), classify_mode=True, balance_classes=args.balance_classes, mode=args.splits, device=None, n_splits=args.n_splits, sparse_model=None, frames_to_skip=args.frames_to_skip) + positive_class = 'Positives' + + overall_true = [] + overall_pred = [] + fn_ids = [] + fp_ids = [] + + i_fold = 0 + for train_idx, test_idx in splits: + train_sampler = torch.utils.data.SubsetRandomSampler(train_idx) +# train_sampler = SubsetWeightedRandomSampler(get_sample_weights(train_idx, dataset), train_idx, replacement=True) + train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, + sampler=train_sampler) + + if test_idx is not None: + test_sampler = torch.utils.data.SubsetRandomSampler(test_idx) + test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, + sampler=test_sampler) + +# with open(os.path.join(args.output_dir, 'test_videos_{}.txt'.format(i_fold)), 'w+') as test_videos_out: +# test_set = set([x for tup in test_loader for x in tup[2]]) +# test_videos_out.writelines(test_set) + else: + test_loader = None + + if args.checkpoint: + classifier_model = keras.models.load_model(args.checkpoint) + else: + classifier_inputs = keras.Input(shape=(clip_depth, image_height, image_width)) + classifier_outputs = PNBTemporalClassifier()(classifier_inputs) + + classifier_model = keras.Model(inputs=classifier_inputs, outputs=classifier_outputs) + + prediction_optimizer = keras.optimizers.Adam(learning_rate=args.lr) + filter_optimizer = tf.keras.optimizers.SGD(learning_rate=args.sparse_lr) + + best_so_far = float('-inf') + + criterion = keras.losses.BinaryCrossentropy(from_logits=True, reduction=keras.losses.Reduction.SUM) + + if args.train: + for epoch in range(args.epochs): + epoch_loss = 0 + t1 = time.perf_counter() + + y_true_train = None + y_pred_train = None + + for labels, local_batch, vid_f in 
tqdm(train_loader): + images = local_batch.permute(0, 2, 3, 4, 1).numpy() + + torch_labels = np.zeros(len(labels)) + torch_labels[[i for i in range(len(labels)) if labels[i] == positive_class]] = 1 + torch_labels = np.expand_dims(torch_labels, axis=1) + + if args.train_sparse: + with tf.GradientTape() as tape: +# activations = sparse_model([images, tf.expand_dims(recon_model.trainable_weights[0], axis=0)]) + pred = classifier_model(activations) + loss = criterion(torch_labels, pred) + + print(loss) + else: + with tf.GradientTape() as tape: + pred = classifier_model(images) + loss = criterion(torch_labels, pred) + + epoch_loss += loss * local_batch.size(0) + + if args.train_sparse: + sparse_gradients, classifier_gradients = tape.gradient(loss, [recon_model.trainable_weights, classifier_model.trainable_weights]) + + prediction_optimizer.apply_gradients(zip(classifier_gradients, classifier_model.trainable_weights)) + + filter_optimizer.apply_gradients(zip(sparse_gradients, recon_model.trainable_weights)) + + if args.run_2d: + weights = normalize_weights(recon_model.get_weights(), args.num_kernels) + else: + weights = normalize_weights_3d(recon_model.get_weights(), args.num_kernels) + recon_model.set_weights(weights) + else: + gradients = tape.gradient(loss, classifier_model.trainable_weights) + + prediction_optimizer.apply_gradients(zip(gradients, classifier_model.trainable_weights)) + + + if y_true_train is None: + y_true_train = torch_labels + y_pred_train = tf.math.round(tf.math.sigmoid(pred)) + else: + y_true_train = tf.concat((y_true_train, torch_labels), axis=0) + y_pred_train = tf.concat((y_pred_train, tf.math.round(tf.math.sigmoid(pred))), axis=0) + + t2 = time.perf_counter() + + y_true = None + y_pred = None + test_loss = 0.0 + + eval_loader = test_loader + if args.splits == 'all_train': + eval_loader = train_loader + for labels, local_batch, vid_f in tqdm(eval_loader): + images = local_batch.permute(0, 2, 3, 4, 1).numpy() + + torch_labels = np.zeros(len(labels)) + torch_labels[[i for i in range(len(labels)) if labels[i] == positive_class]] = 1 + torch_labels = np.expand_dims(torch_labels, axis=1) + + pred = classifier_model(images) + loss = criterion(torch_labels, pred) + + test_loss += loss + + if y_true is None: + y_true = torch_labels + y_pred = tf.math.round(tf.math.sigmoid(pred)) + else: + y_true = tf.concat((y_true, torch_labels), axis=0) + y_pred = tf.concat((y_pred, tf.math.round(tf.math.sigmoid(pred))), axis=0) + + t2 = time.perf_counter() + + y_true = tf.cast(y_true, tf.int32) + y_pred = tf.cast(y_pred, tf.int32) + + y_true_train = tf.cast(y_true_train, tf.int32) + y_pred_train = tf.cast(y_pred_train, tf.int32) + + f1 = f1_score(y_true, y_pred, average='macro') + accuracy = accuracy_score(y_true, y_pred) + + train_accuracy = accuracy_score(y_true_train, y_pred_train) + + print('epoch={}, i_fold={}, time={:.2f}, train_loss={:.2f}, test_loss={:.2f}, train_acc={:.2f}, test_f1={:.2f}, test_acc={:.2f}'.format(epoch, i_fold, t2-t1, epoch_loss, test_loss, train_accuracy, f1, accuracy)) + # print(epoch_loss) + if f1 >= best_so_far: + print("found better model") + # Save model parameters + classifier_model.save(os.path.join(output_dir, "best_classifier_{}.pt".format(i_fold))) +# recon_model.save(os.path.join(output_dir, "best_sparse_model_{}.pt".format(i_fold))) + pickle.dump(prediction_optimizer.get_weights(), open(os.path.join(output_dir, 'optimizer_{}.pt'.format(i_fold)), 'wb+')) + best_so_far = f1 + + classifier_model = keras.models.load_model(os.path.join(output_dir, 
"best_classifier_{}.pt".format(i_fold))) +# recon_model = keras.models.load_model(os.path.join(output_dir, 'best_sparse_model_{}.pt'.format(i_fold))) + + epoch_loss = 0 + + y_true = None + y_pred = None + + pred_dict = {} + gt_dict = {} + + t1 = time.perf_counter() + # test_videos = [vid_f for labels, local_batch, vid_f in batch for batch in test_loader] + transform = torchvision.transforms.Compose( + [VideoGrayScaler(), + MinMaxScaler(0, 255), + torchvision.transforms.Resize((image_height, image_width)) + ]) + + test_videos = set() + for labels, local_batch, vid_f in test_loader: + test_videos.update(vid_f) + + test_labels = [vid_f.split('/')[-3] for vid_f in test_videos] + + if args.frames_to_skip == 1: + y_pred, y_true, fn, fp = calculate_pnb_scores(test_videos, test_labels, yolo_model, sparse_model, recon_model, classifier_model, image_width, image_height, transform) + else: + y_pred, y_true, fn, fp = calculate_pnb_scores_skipped_frames(test_videos, test_labels, yolo_model, sparse_model, recon_model, classifier_model, args.frames_to_skip, image_width, image_height, transform) + + t2 = time.perf_counter() + + print('i_fold={}, time={:.2f}'.format(i_fold, t2-t1)) + + y_true = tf.cast(y_true, tf.int32) + y_pred = tf.cast(y_pred, tf.int32) + + f1 = f1_score(y_true, y_pred, average='macro') + accuracy = accuracy_score(y_true, y_pred) + + fn_ids.extend(fn) + fp_ids.extend(fp) + + overall_true.extend(y_true) + overall_pred.extend(y_pred) + + print("Test f1={:.2f}, vid_acc={:.2f}".format(f1, accuracy)) + + print(confusion_matrix(y_true, y_pred)) + + i_fold += 1 + + fp_fn_file = os.path.join(args.output_dir, 'fp_fn.txt') + with open(fp_fn_file, 'w+') as in_f: + in_f.write('FP:\n') + in_f.write(str(fp_ids) + '\n\n') + in_f.write('FN:\n') + in_f.write(str(fn_ids) + '\n\n') + + overall_true = np.array(overall_true) + overall_pred = np.array(overall_pred) + + final_f1 = f1_score(overall_true, overall_pred, average='macro') + final_acc = accuracy_score(overall_true, overall_pred) + final_conf = confusion_matrix(overall_true, overall_pred) + + print("Final accuracy={:.2f}, f1={:.2f}".format(final_acc, final_f1)) + print(final_conf) + diff --git a/sparse_coding_torch/train_classifier_needle.py b/sparse_coding_torch/pnb/train_classifier_needle.py similarity index 98% rename from sparse_coding_torch/train_classifier_needle.py rename to sparse_coding_torch/pnb/train_classifier_needle.py index 79fa2494dcabd3049d7ecf67105bffaeef177ee2..6f20678dcaa3fcc34f436705fffa9135507a8815 100644 --- a/sparse_coding_torch/train_classifier_needle.py +++ b/sparse_coding_torch/pnb/train_classifier_needle.py @@ -4,8 +4,10 @@ import torch.nn.functional as F from tqdm import tqdm import argparse import os -from sparse_coding_torch.load_data import load_yolo_clips, load_pnb_videos, SubsetWeightedRandomSampler, get_sample_weights -from sparse_coding_torch.keras_model import SparseCode, PNBClassifier, PTXClassifier, ReconSparse, normalize_weights, normalize_weights_3d +from sparse_coding_torch.pnb.load_data import load_pnb_videos +from sparse_coding_torch.utils import SubsetWeightedRandomSampler, get_sample_weights +from sparse_coding_torch.sparse_model import SparseCode, ReconSparse, normalize_weights, normalize_weights_3d +from sparse_coding_torch.pnb.classifier_model import PNBClassifier import time import numpy as np from sklearn.metrics import f1_score, accuracy_score, confusion_matrix @@ -14,7 +16,7 @@ import pickle import tensorflow.keras as keras import tensorflow as tf from sparse_coding_torch.train_sparse_model import 
sparse_loss -from sparse_coding_torch.utils import calculate_pnb_scores +from sparse_coding_torch.pnb.train_classifier import calculate_pnb_scores from yolov4.get_bounding_boxes import YoloModel configproto = tf.compat.v1.ConfigProto() diff --git a/sparse_coding_torch/pnb/train_sparse_model.py b/sparse_coding_torch/pnb/train_sparse_model.py new file mode 100644 index 0000000000000000000000000000000000000000..698cbdc80a43a4887f8097f4ad1be3a8fdd35c9e --- /dev/null +++ b/sparse_coding_torch/pnb/train_sparse_model.py @@ -0,0 +1,172 @@ +import time +import numpy as np +import torch +from matplotlib import pyplot as plt +from matplotlib import cm +from matplotlib.animation import FuncAnimation +from tqdm import tqdm +import argparse +import os +from sparse_coding_torch.pnb.load_data import load_pnb_videos, load_needle_clips +import tensorflow.keras as keras +import tensorflow as tf +from sparse_coding_torch.sparse_model import normalize_weights_3d, normalize_weights, SparseCode, load_pytorch_weights, ReconSparse +import random + +def sparse_loss(images, recon, activations, batch_size, lam, stride): + loss = 0.5 * (1/batch_size) * tf.math.reduce_sum(tf.math.pow(images - recon, 2)) + loss += lam * tf.reduce_mean(tf.math.reduce_sum(tf.math.abs(tf.reshape(activations, (batch_size, -1))), axis=1)) + return loss + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--batch_size', default=32, type=int) + parser.add_argument('--kernel_size', default=15, type=int) + parser.add_argument('--kernel_depth', default=5, type=int) + parser.add_argument('--num_kernels', default=32, type=int) + parser.add_argument('--stride', default=1, type=int) + parser.add_argument('--max_activation_iter', default=300, type=int) + parser.add_argument('--activation_lr', default=1e-2, type=float) + parser.add_argument('--lr', default=0.003, type=float) + parser.add_argument('--epochs', default=150, type=int) + parser.add_argument('--lam', default=0.05, type=float) + parser.add_argument('--output_dir', default='./output', type=str) + parser.add_argument('--seed', default=42, type=int) + parser.add_argument('--run_2d', action='store_true') + parser.add_argument('--save_filters', action='store_true') + parser.add_argument('--optimizer', default='sgd', type=str) + parser.add_argument('--dataset', default='onsd', type=str) + parser.add_argument('--crop_height', type=int, default=400) + parser.add_argument('--crop_width', type=int, default=400) + parser.add_argument('--scale_factor', type=int, default=1) + parser.add_argument('--clip_depth', type=int, default=5) + parser.add_argument('--frames_to_skip', type=int, default=1) + + args = parser.parse_args() + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + + crop_height = args.crop_height + crop_width = args.crop_width + + image_height = int(crop_height / args.scale_factor) + image_width = int(crop_width / args.scale_factor) + clip_depth = args.clip_depth + + output_dir = args.output_dir + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + with open(os.path.join(output_dir, 'arguments.txt'), 'w+') as out_f: + out_f.write(str(args)) + + if args.dataset == 'pnb': + train_loader, test_loader, dataset = load_pnb_videos(args.batch_size, input_size=(image_height, image_width, clip_depth), crop_size=(crop_height, crop_width, clip_depth), classify_mode=False, balance_classes=False, mode='all_train', frames_to_skip=args.frames_to_skip) 
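(Editor's aside, not part of the patch.) The sparse_loss defined at the top of this file is the standard reconstruction-plus-L1 objective: 0.5/B * sum((x - x_hat)**2) + lam * mean over the batch of the L1 norm of each clip's activations; the stride argument is accepted but unused. A minimal NumPy sketch of the same computation, with illustrative shapes and values:

import numpy as np

def sparse_loss_reference(images, recon, activations, batch_size, lam):
    # 0.5 * (1/B) * total squared reconstruction error
    recon_term = 0.5 / batch_size * np.sum((images - recon) ** 2)
    # lam * mean over the batch of each clip's L1 activation norm
    l1_term = lam * np.mean(np.sum(np.abs(activations.reshape(batch_size, -1)), axis=1))
    return recon_term + l1_term

rng = np.random.default_rng(0)
x = rng.random((2, 5, 16, 16, 1))   # two 5-frame grayscale clips (illustrative size)
a = rng.random((2, 1, 8, 8, 4))     # their sparse activation maps (illustrative size)
print(sparse_loss_reference(x, 0.9 * x, a, batch_size=2, lam=0.05))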
+ elif args.dataset == 'needle': + train_loader, test_loader, dataset = load_needle_clips(args.batch_size, input_size=(image_height, image_width, clip_depth)) + else: + raise Exception('Invalid dataset') + + train_sampler = torch.utils.data.SubsetRandomSampler(train_idx) + train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, + sampler=train_sampler) + + print('Loaded', len(train_loader), 'train examples') + + example_data = next(iter(train_loader)) + + if args.run_2d: + inputs = keras.Input(shape=(image_height, image_width, 5)) + else: + inputs = keras.Input(shape=(5, image_height, image_width, 1)) + + filter_inputs = keras.Input(shape=(5, args.kernel_size, args.kernel_size, 1, args.num_kernels), dtype='float32') + + output = SparseCode(batch_size=args.batch_size, image_height=image_height, image_width=image_width, clip_depth=clip_depth, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, kernel_depth=args.kernel_depth, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(inputs, filter_inputs) + + sparse_model = keras.Model(inputs=(inputs, filter_inputs), outputs=output) + + recon_inputs = keras.Input(shape=(1, (image_height - args.kernel_size) // args.stride + 1, (image_width - args.kernel_size) // args.stride + 1, args.num_kernels)) + + recon_outputs = ReconSparse(batch_size=args.batch_size, image_height=image_height, image_width=image_width, clip_depth=clip_depth, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, kernel_depth=args.kernel_depth, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(recon_inputs) + + recon_model = keras.Model(inputs=recon_inputs, outputs=recon_outputs) + + if args.save_filters: + if args.run_2d: + filters = plot_filters(tf.stack(recon_model.get_weights(), axis=0)) + else: + filters = plot_filters(recon_model.get_weights()[0]) + filters.save(os.path.join(args.output_dir, 'filters_start.mp4')) + + learning_rate = args.lr + if args.optimizer == 'sgd': + filter_optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate) + else: + filter_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate) + + loss_log = [] + best_so_far = float('inf') + + for epoch in range(args.epochs): + epoch_loss = 0 + running_loss = 0.0 + epoch_start = time.perf_counter() + + num_iters = 0 + + for labels, local_batch, vid_f in tqdm(train_loader): + if local_batch.size(0) != args.batch_size: + continue + if args.run_2d: + images = local_batch.squeeze(1).permute(0, 2, 3, 1).numpy() + else: + images = local_batch.permute(0, 2, 3, 4, 1).numpy() + + activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))])) + + with tf.GradientTape() as tape: + recon = recon_model(activations) + loss = sparse_loss(images, recon, activations, args.batch_size, args.lam, args.stride) + + epoch_loss += loss * local_batch.size(0) + running_loss += loss * local_batch.size(0) + + gradients = tape.gradient(loss, recon_model.trainable_weights) + + filter_optimizer.apply_gradients(zip(gradients, recon_model.trainable_weights)) + + if args.run_2d: + weights = normalize_weights(recon_model.get_weights(), args.num_kernels) + else: + weights = normalize_weights_3d(recon_model.get_weights(), args.num_kernels) + recon_model.set_weights(weights) + + num_iters += 1 + + epoch_end = time.perf_counter() + 
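(Editor's aside, not part of the patch.) Each iteration above freezes the activations with tf.stop_gradient and takes a gradient step only on the reconstruction filters, then renormalizes them. A compact sketch of that alternation under simplifying assumptions: filters is a single tf.Variable, encode_fn/decode_fn stand in for SparseCode/ReconSparse, and the unit-norm renormalization is one common convention rather than the actual normalize_weights_3d code:

import tensorflow as tf

def dictionary_update_step(images, filters, encode_fn, decode_fn, optimizer, lam, batch_size):
    # Inference pass: sparse activations for the current dictionary; no gradient flows back.
    activations = tf.stop_gradient(encode_fn(images, tf.stop_gradient(filters)))
    with tf.GradientTape() as tape:
        recon = decode_fn(activations, filters)
        loss = 0.5 / batch_size * tf.reduce_sum(tf.square(images - recon))
        loss += lam * tf.reduce_mean(tf.reduce_sum(tf.abs(tf.reshape(activations, (batch_size, -1))), axis=1))
    grads = tape.gradient(loss, [filters])
    optimizer.apply_gradients(zip(grads, [filters]))
    # Constrain each kernel to unit L2 norm (illustrative; the real normalize_weights_3d may differ).
    norms = tf.sqrt(tf.reduce_sum(tf.square(filters), axis=[0, 1, 2, 3], keepdims=True)) + 1e-12
    filters.assign(filters / norms)
    return loss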
epoch_loss /= len(train_loader.sampler) + + if args.save_filters and epoch % 2 == 0: + if args.run_2d: + filters = plot_filters(tf.stack(recon_model.get_weights(), axis=0)) + else: + filters = plot_filters(recon_model.get_weights()[0]) + filters.save(os.path.join(args.output_dir, 'filters_' + str(epoch) +'.mp4')) + + if epoch_loss < best_so_far: + print("found better model") + # Save model parameters + recon_model.save(os.path.join(output_dir, "best_sparse.pt")) + best_so_far = epoch_loss + + loss_log.append(epoch_loss) + print('epoch={}, epoch_loss={:.2f}, time={:.2f}'.format(epoch, epoch_loss, epoch_end - epoch_start)) + + plt.plot(loss_log) + + plt.savefig(os.path.join(output_dir, 'loss_graph.png')) diff --git a/sparse_coding_torch/video_loader.py b/sparse_coding_torch/pnb/video_loader.py similarity index 52% rename from sparse_coding_torch/video_loader.py rename to sparse_coding_torch/pnb/video_loader.py index 6621f9a572094a43aa62fef0b74dfd8d0cc423f5..9adc4a1539776aa1db58b22ddf060b0ac85aa191 100644 --- a/sparse_coding_torch/video_loader.py +++ b/sparse_coding_torch/pnb/video_loader.py @@ -24,44 +24,8 @@ import random import cv2 from yolov4.get_bounding_boxes import YoloModel -def get_ptx_participants(): - video_to_participant = {} - with open('/shared_data/bamc_data/bamc_video_info.csv', 'r') as csv_in: - reader = csv.DictReader(csv_in) - for row in reader: - key = row['Filename'].split('.')[0].lower().replace('_clean', '') - if key == '37 (mislabeled as 38)': - key = '37' - video_to_participant[key] = row['Participant_id'] - - return video_to_participant - def get_participants(filenames): return [f.split('/')[-2] for f in filenames] - -class MinMaxScaler(object): - """ - Transforms each channel to the range [0, 1]. - """ - def __init__(self, min_val=0, max_val=254): - self.min_val = min_val - self.max_val = max_val - - def __call__(self, tensor): - return (tensor - self.min_val) / (self.max_val - self.min_val) - -class VideoGrayScaler(nn.Module): - - def __init__(self): - super().__init__() - self.grayscale = tv.transforms.Grayscale(num_output_channels=1) - - def forward(self, video): - # shape = channels, time, width, height - video = self.grayscale(video.swapaxes(-4, -3).swapaxes(-2, -1)) - video = video.swapaxes(-4, -3).swapaxes(-2, -1) - # print(video.shape) - return video def load_pnb_region_labels(file_path): all_regions = {} @@ -76,17 +40,22 @@ def load_pnb_region_labels(file_path): return all_regions -def get_yolo_regions(yolo_model, clip, is_right, angle, crop_width, crop_height): +def get_yolo_regions(yolo_model, clip, is_right, needle_bb, crop_width, crop_height): orig_height = clip.size(2) orig_width = clip.size(3) - bounding_boxes, classes = yolo_model.get_bounding_boxes(clip[:, 2, :, :].swapaxes(0, 2).swapaxes(0, 1).numpy()) + bounding_boxes, classes, scores = yolo_model.get_bounding_boxes(clip[:, 2, :, :].swapaxes(0, 2).swapaxes(0, 1).numpy()) bounding_boxes = bounding_boxes.squeeze(0) classes = classes.squeeze(0) + scores = scores.squeeze(0) + + for bb, class_pred in zip(bounding_boxes, classes): + if class_pred == 2: + needle_bb = bb rotate_box = True all_clips = [] - for bb, class_pred in zip(bounding_boxes, classes): + for bb, class_pred, score in zip(bounding_boxes, classes, scores): if class_pred != 0: continue center_x = round((bb[3] + bb[1]) / 2 * orig_width) @@ -97,7 +66,13 @@ def get_yolo_regions(yolo_model, clip, is_right, angle, crop_width, crop_height) lower_x = round((bb[1] * orig_width)) upper_x = round((bb[3] * orig_width)) - lower_y = upper_y - 
crop_height + if is_right: + angle = calculate_angle(needle_bb, upper_x, center_y, orig_height, orig_width) + else: + angle = calculate_angle(needle_bb, lower_x, center_y, orig_height, orig_width) + + lower_y = center_y - (crop_height // 2) + upper_y = center_y + (crop_height // 2) if is_right: lower_x = center_x - crop_width @@ -106,13 +81,22 @@ def get_yolo_regions(yolo_model, clip, is_right, angle, crop_width, crop_height) lower_x = center_x upper_x = center_x + crop_width + if lower_x < 0: + lower_x = 0 + if upper_x < 0: + upper_x = 0 + if lower_y < 0: + lower_y = 0 + if upper_y < 0: + upper_y = 0 + if rotate_box: # cv2.imwrite('test_1.png', clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2)) if is_right: - clip = tv.transforms.functional.rotate(clip, angle=angle, center=[upper_x, upper_y]) + clip = tv.transforms.functional.rotate(clip, angle=angle, center=[upper_x, center_y]) else: # cv2.imwrite('test_1.png', clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2)) - clip = tv.transforms.functional.rotate(clip, angle=-angle, center=[lower_x, upper_y]) + clip = tv.transforms.functional.rotate(clip, angle=-angle, center=[lower_x, center_y]) # cv2.imwrite('test_2.png', clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2)) trimmed_clip = clip[:, :, lower_y:upper_y, lower_x:upper_x] @@ -120,11 +104,8 @@ def get_yolo_regions(yolo_model, clip, is_right, angle, crop_width, crop_height) # if orig_width - center_x >= center_x: # if not is_right: # print(angle) -# cv2.imwrite('test_2.png', clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2)) -# cv2.imwrite('test_3.png', trimmed_clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2)) -# raise Exception - -# print(trimmed_clip.size()) +# cv2.imwrite('test_2{}.png'.format(lower_y), clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2)) +# cv2.imwrite('test_3{}.png'.format(lower_y), trimmed_clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2)) if not is_right: trimmed_clip = tv.transforms.functional.hflip(trimmed_clip) @@ -147,7 +128,7 @@ def classify_nerve_is_right(yolo_model, video): for frame in range(0, video.size(1), round(video.size(1) / 10)): frame = video[:, frame, :, :] - bounding_boxes, classes = yolo_model.get_bounding_boxes(frame.swapaxes(0, 2).swapaxes(0, 1).numpy()) + bounding_boxes, classes, scores = yolo_model.get_bounding_boxes(frame.swapaxes(0, 2).swapaxes(0, 1).numpy()) bounding_boxes = bounding_boxes.squeeze(0) classes = classes.squeeze(0) @@ -181,25 +162,49 @@ def classify_nerve_is_right(yolo_model, video): return final_pred == 1 -def calculate_angle(yolo_model, video): +def calculate_angle(needle_bb, vessel_x, vessel_y, orig_height, orig_width): + needle_x = needle_bb[1] * orig_width + needle_y = needle_bb[0] * orig_height + + return np.abs(np.degrees(np.arctan((needle_y-vessel_y)/(needle_x-vessel_x)))) + +def get_needle_bb(yolo_model, video): + orig_height = video.size(2) + orig_width = video.size(3) + + for frame in range(0, video.size(1), 1): + frame = video[:, frame, :, :] + + bounding_boxes, classes, scores = yolo_model.get_bounding_boxes(frame.swapaxes(0, 2).swapaxes(0, 1).numpy()) + bounding_boxes = bounding_boxes.squeeze(0) + classes = classes.squeeze(0) + + for bb, class_pred in zip(bounding_boxes, classes): + if class_pred == 2: + return bb + + return None + +def calculate_angle_video(yolo_model, video): orig_height = video.size(2) orig_width = video.size(3) all_preds = [] if video.size(1) < 10: - return 45 + return 30 - for frame in range(0, video.size(1), round(video.size(1) / 10)): + for frame in range(0, 
video.size(1), video.size(1) // 10): frame = video[:, frame, :, :] - bounding_boxes, classes = yolo_model.get_bounding_boxes(frame.swapaxes(0, 2).swapaxes(0, 1).numpy()) + + bounding_boxes, classes, scores = yolo_model.get_bounding_boxes(frame.swapaxes(0, 2).swapaxes(0, 1).numpy()) bounding_boxes = bounding_boxes.squeeze(0) classes = classes.squeeze(0) - + vessel_x = 0 vessel_y = 0 needle_x = 0 needle_y = 0 - + for bb, class_pred in zip(bounding_boxes, classes): if class_pred == 0 and vessel_x == 0: vessel_x = (bb[3] + bb[1]) / 2 * orig_width @@ -207,14 +212,14 @@ def calculate_angle(yolo_model, video): elif class_pred == 2 and needle_x == 0: needle_x = bb[1] * orig_width needle_y = bb[0] * orig_height - + if needle_x != 0 and vessel_x != 0: break - - if vessel_x == 0 or needle_x == 0: - return 45 - else: - return np.abs(np.degrees(np.arctan((needle_y-vessel_y)/(needle_x-vessel_x)))) + + if vessel_x > 0 and needle_x > 0: + all_preds.append(np.abs(np.degrees(np.arctan((needle_y-vessel_y)/(needle_x-vessel_x))))) + + return np.mean(np.array(all_preds)) class PNBLoader(Dataset): @@ -231,7 +236,7 @@ class PNBLoader(Dataset): clip_cache_file = 'clip_cache_pnb_{}_{}_sparse.pt'.format(clip_width, clip_height) clip_cache_final_file = 'clip_cache_pnb_{}_{}_final_sparse.pt'.format(clip_width, clip_height) - region_labels = load_pnb_region_labels('/shared_data/bamc_pnb_data/full_training_data/sme_region_labels.csv') + region_labels = load_pnb_region_labels('sme_region_labels.csv') self.videos = [] for label in self.labels: @@ -252,7 +257,7 @@ class PNBLoader(Dataset): for label, path, _ in tqdm(self.videos): vc = tv.io.read_video(path)[0].permute(3, 0, 1, 2) is_right = classify_nerve_is_right(yolo_model, vc) - angle = calculate_angle(yolo_model, vc) + needle_bb = get_needle_bb(yolo_model, vc) if classify_mode: # person_idx = path.split('/')[-2] @@ -277,7 +282,7 @@ class PNBLoader(Dataset): if vc_sub.size(1) < clip_depth: continue - for clip in get_yolo_regions(yolo_model, vc_sub, is_right, angle, clip_width, clip_height): + for clip in get_yolo_regions(yolo_model, vc_sub, is_right, needle_bb, clip_width, clip_height): if self.transform: clip = self.transform(clip) @@ -297,7 +302,7 @@ class PNBLoader(Dataset): if vc_sub.size(1) < clip_depth: continue - for clip in get_yolo_regions(yolo_model, vc_sub, is_right, angle, clip_width, clip_height): + for clip in get_yolo_regions(yolo_model, vc_sub, is_right, needle_bb, clip_width, clip_height): if self.transform: clip = self.transform(clip) @@ -316,7 +321,7 @@ class PNBLoader(Dataset): if vc_sub.size(1) < clip_depth: continue - for clip in get_yolo_regions(yolo_model, vc_sub, is_right, angle, clip_width, clip_height): + for clip in get_yolo_regions(yolo_model, vc_sub, is_right, needle_bb, clip_width, clip_height): if self.transform: clip = self.transform(clip) @@ -332,7 +337,7 @@ class PNBLoader(Dataset): vc_sub = torch.stack(frames, dim=1) if vc_sub.size(1) < clip_depth: continue - for clip in get_yolo_regions(yolo_model, vc_sub, is_right, clip_width, clip_height): + for clip in get_yolo_regions(yolo_model, vc_sub, is_right, needle_bb, clip_width, clip_height): if self.transform: clip = self.transform(clip) @@ -345,7 +350,7 @@ class PNBLoader(Dataset): vc_sub = torch.stack(frames, dim=1) if vc_sub.size(1) < clip_depth: continue - for clip in get_yolo_regions(yolo_model, vc_sub, is_right, angle, clip_width, clip_height): + for clip in get_yolo_regions(yolo_model, vc_sub, is_right, needle_bb, clip_width, clip_height): if self.transform: clip = 
self.transform(clip) @@ -419,85 +424,6 @@ class PNBLoader(Dataset): def __len__(self): return len(self.clips) -class ONSDLoader(Dataset): - - def __init__(self, video_path, clip_width, clip_height, clip_depth, num_frames=5, transform=None, augmentation=None): - self.transform = transform - self.augmentation = augmentation - self.labels = [name for name in listdir(video_path) if isdir(join(video_path, name))] - - clip_cache_file = 'clip_cache_onsd_{}_{}_sparse.pt'.format(clip_width, clip_height) - clip_cache_final_file = 'clip_cache_onsd_{}_{}_final_sparse.pt'.format(clip_width, clip_height) - - self.videos = [] - for label in self.labels: -# self.videos.extend([(label, abspath(join(video_path, label, f)), f) for f in glob.glob(join(video_path, label, '*', '*.mp4'))]) - self.videos.extend([(label, abspath(join(video_path, label, f)), f) for f in glob.glob(join(video_path, label, '*', '*.mp4'))]) - -# self.videos = list(filter(lambda x: x[1].split('/')[-2] in ['67', '94', '134', '193', '222', '240'], self.videos)) -# self.videos = list(filter(lambda x: x[1].split('/')[-2] in ['67'], self.videos)) - - self.clips = [] - self.final_clips = {} - - if exists(clip_cache_file): - self.clips = torch.load(open(clip_cache_file, 'rb')) - self.final_clips = torch.load(open(clip_cache_final_file, 'rb')) - else: - vid_idx = 0 - for label, path, _ in tqdm(self.videos): - vc = tv.io.read_video(path)[0].permute(3, 0, 1, 2) - - for j in range(0, vc.size(1) - clip_depth, 5): - frames = [] - for k in range(j, j + clip_depth, 1): - frames.append(vc[:, k, :, :]) - vc_sub = torch.stack(frames, dim=1) - - if vc_sub.size(1) != clip_depth: - continue - - if self.transform: - vc_sub = self.transform(vc_sub) - - self.clips.append((self.videos[vid_idx][0], vc_sub, self.videos[vid_idx][2])) - - self.final_clips[self.videos[vid_idx][2]] = self.clips[-1] - vid_idx += 1 - - torch.save(self.clips, open(clip_cache_file, 'wb+')) - torch.save(self.final_clips, open(clip_cache_final_file, 'wb+')) - - num_positive = len([clip[0] for clip in self.clips if clip[0] == 'Positives']) - num_negative = len([clip[0] for clip in self.clips if clip[0] == 'Negatives']) - - random.shuffle(self.clips) - - print('Loaded', num_positive, 'positive examples.') - print('Loaded', num_negative, 'negative examples.') - - def get_filenames(self): - return [self.clips[i][2] for i in range(len(self.clips))] - - def get_video_labels(self): - return [self.videos[i][0] for i in range(len(self.videos))] - - def get_labels(self): - return [self.clips[i][0] for i in range(len(self.clips))] - - def get_final_clips(self): - return self.final_clips - - def __getitem__(self, index): - label, clip, vid_f = self.clips[index] - if self.augmentation: - clip = clip.swapaxes(0, 1) - clip = self.augmentation(clip) - clip = clip.swapaxes(0, 1) - return (label, clip, vid_f) - - def __len__(self): - return len(self.clips) class NeedleLoader(Dataset): def __init__(self, video_path, transform=None, augmentation=None): @@ -534,205 +460,4 @@ class NeedleLoader(Dataset): def __len__(self): return len(self.clips) - -class YoloClipLoader(Dataset): - - def __init__(self, yolo_output_path, num_frames=5, frames_between_clips=None, - transform=None, augment_transform=None, num_clips=1, num_positives=1, positive_videos=None, sparse_model=None, device=None): - if (num_frames % 2) == 0: - raise ValueError("Num Frames must be an odd number, so we can extract a clip centered on each detected region") - - clip_cache_file = 'clip_cache.pt' - - self.num_clips = num_clips - - self.num_frames 
= num_frames - if frames_between_clips is None: - self.frames_between_clips = num_frames - else: - self.frames_between_clips = frames_between_clips - - self.transform = transform - self.augment_transform = augment_transform - - self.labels = [name for name in listdir(yolo_output_path) if isdir(join(yolo_output_path, name))] - self.clips = [] - if exists(clip_cache_file): - self.clips = torch.load(open(clip_cache_file, 'rb')) - else: - for label in self.labels: - print("Processing videos in category: {}".format(label)) - videos = list(listdir(join(yolo_output_path, label))) - for vi in tqdm(range(len(videos))): - video = videos[vi] - counter = 0 - all_trimmed = [] - with open(abspath(join(yolo_output_path, label, video, 'result.json'))) as fin: - results = json.load(fin) - max_frame = len(results) - - for i in range((num_frames-1)//2, max_frame - (num_frames-1)//2 - 1, self.frames_between_clips): - # for frame in results: - frame = results[i] - # print('loading frame:', i, frame['frame_id']) - frame_start = int(frame['frame_id']) - self.num_frames//2 - frames = [abspath(join(yolo_output_path, label, video, 'frame{}.png'.format(frame_start+fid))) - for fid in range(num_frames)] - # print(frames) - frames = torch.stack([ToTensor()(Image.open(f).convert("RGB")) for f in frames]).swapaxes(0, 1) - - for region in frame['objects']: - # print(region) - if region['name'] != "Pleural_Line": - continue - - center_x = region['relative_coordinates']["center_x"] * 1920 - center_y = region['relative_coordinates']['center_y'] * 1080 - - # width = region['relative_coordinates']['width'] * 1920 - # height = region['relative_coordinates']['height'] * 1080 - width=200 - height=100 - - lower_y = round(center_y - height / 2) - upper_y = round(center_y + height / 2) - lower_x = round(center_x - width / 2) - upper_x = round(center_x + width / 2) - - final_clip = frames[:, :, lower_y:upper_y, lower_x:upper_x] - - if self.transform: - final_clip = self.transform(final_clip) - - if sparse_model: - with torch.no_grad(): - final_clip = final_clip.unsqueeze(0).to(device) - final_clip = sparse_model(final_clip) - final_clip = final_clip.squeeze(0).detach().cpu() - - self.clips.append((label, final_clip, video)) - - torch.save(self.clips, open(clip_cache_file, 'wb+')) - - -# random.shuffle(self.clips) - -# video_to_clips = {} - if positive_videos: - vids_to_keep = json.load(open(positive_videos)) - - self.clips = [clip_tup for clip_tup in self.clips if clip_tup[2] in vids_to_keep or clip_tup[0] == 'Sliding'] - else: - video_to_labels = {} - - for lbl, clip, video in self.clips: - video = video.lower().replace('_clean', '') - if video not in video_to_labels: - # video_to_clips[video] = [] - video_to_labels[video] = [] - - # video_to_clips[video].append(clip) - video_to_labels[video].append(lbl) - - video_to_participants = get_ptx_participants() - participants_to_video = {} - for k, v in video_to_participants.items(): - if video_to_labels[k][0] == 'Sliding': - continue - if not v in participants_to_video: - participants_to_video[v] = [] - - participants_to_video[v].append(k) - - participants_to_video = dict(sorted(participants_to_video.items(), key=lambda x: len(x[1]), reverse=True)) - - num_to_remove = len([k for k,v in video_to_labels.items() if v[0] == 'No_Sliding']) - num_positives - vids_to_remove = set() - while num_to_remove > 0: - vids_to_remove.add(participants_to_video[list(participants_to_video.keys())[0]].pop()) - participants_to_video = dict(sorted(participants_to_video.items(), key=lambda x: len(x[1]), 
reverse=True)) - num_to_remove -= 1 - - self.clips = [clip_tup for clip_tup in self.clips if clip_tup[2].lower().replace('_clean', '') not in vids_to_remove] - - video_to_clips = {} - video_to_labels = {} - - for lbl, clip, video in self.clips: - if video not in video_to_clips: - video_to_clips[video] = [] - video_to_labels[video] = [] - - video_to_clips[video].append(clip) - video_to_labels[video].append(lbl) - - print([k for k,v in video_to_labels.items() if v[0] == 'No_Sliding']) - - print('Num positive:', len([k for k,v in video_to_labels.items() if v[0] == 'No_Sliding'])) - print('Num negative:', len([k for k,v in video_to_labels.items() if v[0] == 'Sliding'])) - - self.videos = None - self.max_video_clips = 0 - if num_clips > 1: - self.videos = [] - - for video in video_to_clips.keys(): - clip_list = video_to_clips[video] - lbl_list = video_to_labels[video] - - for i in range(0, len(clip_list) - num_clips, 1): - video_stack = torch.stack(clip_list[i:i+num_clips]) - - self.videos.append((max(set(lbl_list[i:i+num_clips]), key=lbl_list[i:i+num_clips].count), video_stack, video)) - - self.clips = None - - - def get_labels(self): - if self.num_clips > 1: - return [self.videos[i][0] for i in range(len(self.videos))] - else: - return [self.clips[i][0] for i in range(len(self.clips))] - - def get_filenames(self): - if self.num_clips > 1: - return [self.videos[i][2] for i in range(len(self.videos))] - else: - return [self.clips[i][2] for i in range(len(self.clips))] - - def __getitem__(self, index): - if self.num_clips > 1: - label = self.videos[index][0] - video = self.videos[index][1] - filename = self.videos[index][2] - - video = video.squeeze(2) - video = video.permute(1, 0, 2, 3) - - if self.augment_transform: - video = self.augment_transform(video) - - video = video.unsqueeze(2) - video = video.permute(1, 0, 2, 3, 4) -# video = video.permute(4, 1, 2, 3, 0) -# video = torch.nn.functional.pad(video, (0), 'constant', 0) -# video = video.permute(4, 1, 2, 3, 0) - - orig_len = video.size(0) - -# if orig_len < self.max_video_clips: -# video = torch.cat([video, torch.zeros(self.max_video_clips - len(video), video.size(1), video.size(2), video.size(3), video.size(4))]) - - return label, video, filename, orig_len - else: - label = self.clips[index][0] - video = self.clips[index][1] - filename = self.clips[index][2] - - if self.augment_transform: - video = self.augment_transform(video) - - return label, video, filename - - def __len__(self): - return len(self.clips) + \ No newline at end of file diff --git a/sparse_coding_torch/ptx/classifier_model.py b/sparse_coding_torch/ptx/classifier_model.py new file mode 100644 index 0000000000000000000000000000000000000000..fefc2fcc6c63e72d14517cec5d0fe3cefefa8769 --- /dev/null +++ b/sparse_coding_torch/ptx/classifier_model.py @@ -0,0 +1,120 @@ +from tensorflow import keras +import numpy as np +import torch +import tensorflow as tf +import cv2 +import torchvision as tv +import torch +import torch.nn as nn +from sparse_coding_torch.ptx.video_loader import VideoGrayScaler, MinMaxScaler +from sparse_coding_torch.sparse_model import SparseCode + +class PTXClassifier(keras.layers.Layer): + def __init__(self): + super(PTXClassifier, self).__init__() + + self.max_pool = keras.layers.MaxPooling2D(pool_size=4, strides=4) + self.conv_1 = keras.layers.Conv2D(24, kernel_size=8, strides=4, activation='relu', padding='valid') +# self.conv_2 = keras.layers.Conv2D(24, kernel_size=4, strides=2, activation='relu', padding='valid') + + self.flatten = keras.layers.Flatten() + 
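(Editor's aside, not part of the patch.) This head takes the sparse activation maps produced by SparseCode; their grid size follows the valid-convolution formula used for classifier_inputs in ptx/train_classifier.py. A quick sketch with the PTX defaults that appear elsewhere in this diff (100x200 frames, kernel_size 15, kernel_depth 5, stride 2, clip_depth 5); activation_grid_shape is only an illustrative helper:

def activation_grid_shape(image_height, image_width, clip_depth, kernel_size, kernel_depth, stride):
    # ((clip_depth - kernel_depth) // 1 + 1, (H - k) // stride + 1, (W - k) // stride + 1)
    frames = (clip_depth - kernel_depth) // 1 + 1
    height = (image_height - kernel_size) // stride + 1
    width = (image_width - kernel_size) // stride + 1
    return frames, height, width

print(activation_grid_shape(100, 200, 5, 15, 5, 2))  # (1, 43, 93), with num_kernels channels per cell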
+ self.dropout = keras.layers.Dropout(0.5) + +# self.ff_1 = keras.layers.Dense(1000, activation='relu', use_bias=True) +# self.ff_2 = keras.layers.Dense(500, activation='relu', use_bias=True) +# self.ff_2 = keras.layers.Dense(20, activation='relu', use_bias=True) + self.ff_3 = keras.layers.Dense(20, activation='relu', use_bias=True) + self.ff_4 = keras.layers.Dense(1) + +# @tf.function + def call(self, activations): + activations = tf.squeeze(activations, axis=1) + x = self.max_pool(activations) + x = self.conv_1(activations) +# x = self.conv_2(x) + x = self.flatten(x) +# x = self.ff_1(x) +# x = self.dropout(x) +# x = self.ff_2(x) +# x = self.dropout(x) + x = self.ff_3(x) + x = self.dropout(x) + x = self.ff_4(x) + + return x + +class BaselinePTX(keras.layers.Layer): + def __init__(self): + super(BaselinePTX, self).__init__() + + self.conv_1 = keras.layers.Conv3D(64, kernel_size=(5, 8, 8), strides=(1, 4, 4), activation='relu', padding='valid') + self.conv_2 = keras.layers.Conv2D(24, kernel_size=4, strides=2, activation='relu', padding='valid') + + self.flatten = keras.layers.Flatten() + + self.dropout = keras.layers.Dropout(0.5) + +# self.ff_1 = keras.layers.Dense(1000, activation='relu', use_bias=True) +# self.ff_2 = keras.layers.Dense(500, activation='relu', use_bias=True) + self.ff_3 = keras.layers.Dense(20, activation='relu', use_bias=True) + self.ff_4 = keras.layers.Dense(1) + +# @tf.function + def call(self, images): + x = self.conv_1(images) + x = tf.squeeze(x, axis=1) + x = self.conv_2(x) + x = self.flatten(x) +# x = self.ff_1(x) +# x = self.dropout(x) +# x = self.ff_2(x) +# x = self.dropout(x) + x = self.ff_3(x) + x = self.dropout(x) + x = self.ff_4(x) + + return x + +class MobileModelPTX(keras.Model): + def __init__(self, sparse_checkpoint, batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d): + super().__init__() + self.sparse_code = SparseCode(batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d) + self.classifier = Classifier() + + self.out_channels = out_channels + self.in_channels = in_channels + self.stride = stride + self.lam = lam + self.activation_lr = activation_lr + self.max_activation_iter = max_activation_iter + self.batch_size = batch_size + self.run_2d = run_2d + + pytorch_weights = load_pytorch_weights(sparse_checkpoint) + + if run_2d: + weight_list = np.split(pytorch_weights, 5, axis=0) + self.filters_1 = tf.Variable(initial_value=weight_list[0].squeeze(0), dtype='float32', trainable=False) + self.filters_2 = tf.Variable(initial_value=weight_list[1].squeeze(0), dtype='float32', trainable=False) + self.filters_3 = tf.Variable(initial_value=weight_list[2].squeeze(0), dtype='float32', trainable=False) + self.filters_4 = tf.Variable(initial_value=weight_list[3].squeeze(0), dtype='float32', trainable=False) + self.filters_5 = tf.Variable(initial_value=weight_list[4].squeeze(0), dtype='float32', trainable=False) + else: + self.filters = tf.Variable(initial_value=pytorch_weights, dtype='float32', trainable=False) + + @tf.function + def call(self, images): + images = tf.squeeze(tf.image.rgb_to_grayscale(images), axis=-1) + images = tf.transpose(images, perm=[0, 2, 3, 1]) + images = images / 255 + images = (images - 0.2592) / 0.1251 + + if self.run_2d: + activations = self.sparse_code(images, [tf.stop_gradient(self.filters_1), tf.stop_gradient(self.filters_2), tf.stop_gradient(self.filters_3), tf.stop_gradient(self.filters_4), tf.stop_gradient(self.filters_5)]) + 
else: + activations = self.sparse_code(images, tf.stop_gradient(self.filters)) + + pred = self.classifier(activations) + + return pred \ No newline at end of file diff --git a/sparse_coding_torch/convert_pytorch_to_keras.py b/sparse_coding_torch/ptx/convert_pytorch_to_keras.py similarity index 85% rename from sparse_coding_torch/convert_pytorch_to_keras.py rename to sparse_coding_torch/ptx/convert_pytorch_to_keras.py index 4f3e1abbd830c01acca62f924e118c4864ab7478..3e28b9d5e2a63a8f70d0f98736a34f6837eb6cac 100644 --- a/sparse_coding_torch/convert_pytorch_to_keras.py +++ b/sparse_coding_torch/ptx/convert_pytorch_to_keras.py @@ -1,7 +1,8 @@ import argparse from tensorflow import keras import tensorflow as tf -from sparse_coding_torch.keras_model import SparseCode, PNBClassifier, PTXClassifier, ReconSparse, load_pytorch_weights +from sparse_coding_torch.sparse_model import SparseCode, ReconSparse, load_pytorch_weights +from sparse_coding_torch.ptx.classifier_model import PTXClassifier import torch import os @@ -13,7 +14,6 @@ if __name__ == "__main__": parser.add_argument('--kernel_depth', default=5, type=int) parser.add_argument('--num_kernels', default=64, type=int) parser.add_argument('--stride', default=2, type=int) - parser.add_argument('--dataset', default='ptx', type=str) parser.add_argument('--input_image_height', default=100, type=int) parser.add_argument('--input_image_width', default=200, type=int) parser.add_argument('--output_dir', default='./converted_checkpoints', type=str) @@ -30,14 +30,8 @@ if __name__ == "__main__": if args.classifier_checkpoint: classifier_inputs = keras.Input(shape=(1, args.input_image_height // args.stride, args.input_image_width // args.stride, args.num_kernels)) - if args.dataset == 'pnb': - classifier_outputs = PNBClassifier()(classifier_inputs) - classifier_name = 'pnb_classifier' - elif args.dataset == 'ptx': - classifier_outputs = PTXClassifier()(classifier_inputs) - classifier_name = 'ptx_classifier' - else: - raise Exception('No classifier exists for that dataset') + classifier_outputs = PTXClassifier()(classifier_inputs) + classifier_name = 'ptx_classifier' classifier_model = keras.Model(inputs=classifier_inputs, outputs=classifier_outputs) diff --git a/sparse_coding_torch/ptx/generate_tflite.py b/sparse_coding_torch/ptx/generate_tflite.py new file mode 100644 index 0000000000000000000000000000000000000000..5240156eb831b845761a1a2815380dbe8ca3f215 --- /dev/null +++ b/sparse_coding_torch/ptx/generate_tflite.py @@ -0,0 +1,63 @@ +from tensorflow import keras +import numpy as np +import torch +import tensorflow as tf +import cv2 +import torchvision as tv +import torch +import torch.nn as nn +from sparse_coding_torch.utils import VideoGrayScaler, MinMaxScaler +from sparse_coding_torch.ptx.classifier_model import MobileModelPTX +import argparse + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--input_dir', default='/shared_data/bamc_pnb_data/revised_training_data', type=str) + parser.add_argument('--kernel_size', default=15, type=int) + parser.add_argument('--kernel_depth', default=5, type=int) + parser.add_argument('--num_kernels', default=32, type=int) + parser.add_argument('--stride', default=4, type=int) + parser.add_argument('--max_activation_iter', default=150, type=int) + parser.add_argument('--activation_lr', default=1e-2, type=float) + parser.add_argument('--lam', default=0.05, type=float) + parser.add_argument('--sparse_checkpoint', 
default='sparse_coding_torch/output/sparse_pnb_32_long_train/sparse_conv3d_model-best.pt/', type=str) + parser.add_argument('--checkpoint', default='sparse_coding_torch/classifier_outputs/32_filters_no_aug_3/best_classifier.pt/', type=str) + parser.add_argument('--run_2d', action='store_true') + parser.add_argument('--batch_size', default=1, type=int) + parser.add_argument('--image_height', type=int, default=285) + parser.add_argument('--image_width', type=int, default=400) + parser.add_argument('--clip_depth', type=int, default=5) + + args = parser.parse_args() + #print(args.accumulate(args.integers)) + batch_size = args.batch_size + + image_height = args.image_height + image_width = args.image_width + clip_depth = args.clip_depth + + recon_model = keras.models.load_model(args.sparse_checkpoint) + + classifier_model = keras.models.load_model(args.checkpoint) + + inputs = keras.Input(shape=(5, image_height, image_width)) + + outputs = MobileModelPTX(sparse_weights=recon_model.weights[0], classifier_model=classifier_model, batch_size=batch_size, image_height=image_height, image_width=image_width, clip_depth=clip_depth, out_channels=args.num_kernels, kernel_size=args.kernel_size, kernel_depth=args.kernel_depth, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(inputs) + + model = keras.Model(inputs=inputs, outputs=outputs) + + input_name = model.input_names[0] + index = model.input_names.index(input_name) + model.inputs[index].set_shape([1, 5, image_height, image_width]) + + converter = tf.lite.TFLiteConverter.from_keras_model(model) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.target_spec.supported_types = [tf.float16] + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS] + + tflite_model = converter.convert() + + print('Converted') + + with open("./sparse_coding_torch/mobile_output/ptx.tflite", "wb") as f: + f.write(tflite_model) diff --git a/sparse_coding_torch/ptx/load_data.py b/sparse_coding_torch/ptx/load_data.py new file mode 100644 index 0000000000000000000000000000000000000000..e307a281f46f7d9a54cdf3649ee710451d1aaba8 --- /dev/null +++ b/sparse_coding_torch/ptx/load_data.py @@ -0,0 +1,69 @@ +import numpy as np +import torchvision +import torch +from sklearn.model_selection import train_test_split +from sparse_coding_torch.utils import MinMaxScaler, VideoGrayScaler +from sparse_coding_torch.ptx.video_loader import YoloClipLoader, get_ptx_participants +import csv +from sklearn.model_selection import train_test_split, GroupShuffleSplit, LeaveOneGroupOut, LeaveOneOut, StratifiedGroupKFold, StratifiedKFold, KFold, ShuffleSplit + +def load_yolo_clips(batch_size, mode, num_clips=1, num_positives=100, device=None, n_splits=None, sparse_model=None, whole_video=False, positive_videos=None): + video_path = "/shared_data/YOLO_Updated_PL_Model_Results/" + + video_to_participant = get_ptx_participants() + + transforms = torchvision.transforms.Compose( + [VideoGrayScaler(), +# MinMaxScaler(0, 255), + torchvision.transforms.Normalize((0.2592,), (0.1251,)), + ]) + augment_transforms = torchvision.transforms.Compose( + [torchvision.transforms.RandomRotation(45), + torchvision.transforms.RandomHorizontalFlip(), + torchvision.transforms.CenterCrop((100, 200)) + ]) + if whole_video: + dataset = YoloVideoLoader(video_path, num_clips=num_clips, num_positives=num_positives, transform=transforms, augment_transform=augment_transforms, sparse_model=sparse_model, device=device) + else: + 
dataset = YoloClipLoader(video_path, num_clips=num_clips, num_positives=num_positives, positive_videos=positive_videos, transform=transforms, augment_transform=augment_transforms, sparse_model=sparse_model, device=device) + + targets = dataset.get_labels() + + if mode == 'leave_one_out': + gss = LeaveOneGroupOut() + +# groups = [v for v in dataset.get_filenames()] + groups = [video_to_participant[v.lower().replace('_clean', '')] for v in dataset.get_filenames()] + + return gss.split(np.arange(len(targets)), targets, groups), dataset + elif mode == 'all_train': + train_idx = np.arange(len(targets)) +# train_sampler = torch.utils.data.SubsetRandomSampler(train_idx) +# train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, +# sampler=train_sampler) +# test_loader = None + + return [(train_idx, None)], dataset + elif mode == 'k_fold': + gss = StratifiedGroupKFold(n_splits=n_splits) + + groups = [video_to_participant[v.lower().replace('_clean', '')] for v in dataset.get_filenames()] + + return gss.split(np.arange(len(targets)), targets, groups), dataset + else: + gss = GroupShuffleSplit(n_splits=n_splits, test_size=0.2) + + groups = [video_to_participant[v.lower().replace('_clean', '')] for v in dataset.get_filenames()] + + train_idx, test_idx = list(gss.split(np.arange(len(targets)), targets, groups))[0] + + train_sampler = torch.utils.data.SubsetRandomSampler(train_idx) + train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, + sampler=train_sampler) + + test_sampler = torch.utils.data.SubsetRandomSampler(test_idx) + test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, + sampler=test_sampler) + + return train_loader, test_loader, dataset + diff --git a/sparse_coding_torch/train_classifier.py b/sparse_coding_torch/ptx/train_classifier.py similarity index 55% rename from sparse_coding_torch/train_classifier.py rename to sparse_coding_torch/ptx/train_classifier.py index ca71ca2f59f6aec2f5b0d4eebaf60c4b049be13c..7ed890767f4a77cb505932e4f28fb4d4e8386d38 100644 --- a/sparse_coding_torch/train_classifier.py +++ b/sparse_coding_torch/ptx/train_classifier.py @@ -4,8 +4,9 @@ import torch.nn.functional as F from tqdm import tqdm import argparse import os -from sparse_coding_torch.load_data import load_yolo_clips, load_pnb_videos, SubsetWeightedRandomSampler, get_sample_weights -from sparse_coding_torch.keras_model import SparseCode, PNBClassifier, PTXClassifier, ReconSparse, normalize_weights, normalize_weights_3d, PNBTemporalClassifier +from sparse_coding_torch.ptx.load_data import load_yolo_clips +from sparse_coding_torch.sparse_model import SparseCode, ReconSparse, normalize_weights, normalize_weights_3d +from sparse_coding_torch.ptx.classifier_model import PTXClassifier import time import numpy as np from sklearn.metrics import f1_score, accuracy_score, confusion_matrix @@ -13,11 +14,9 @@ import random import pickle import tensorflow.keras as keras import tensorflow as tf -from sparse_coding_torch.train_sparse_model import sparse_loss -from sparse_coding_torch.utils import calculate_pnb_scores, calculate_pnb_scores_skipped_frames +from sparse_coding_torch.utils import VideoGrayScaler, MinMaxScaler from yolov4.get_bounding_boxes import YoloModel import torchvision -from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler import glob import cv2 @@ -27,6 +26,95 @@ configproto.gpu_options.allow_growth = True sess = tf.compat.v1.Session(config=configproto) tf.compat.v1.keras.backend.set_session(sess) +def 
calculate_ptx_scores(input_videos, labels, yolo_model, sparse_model, recon_model, classifier_model, image_width, image_height, transform): + all_predictions = [] + + numerical_labels = [] + for label in labels: + if label == 'No_Sliding': + numerical_labels.append(1.0) + else: + numerical_labels.append(0.0) + + final_list = [] + fp_ids = [] + fn_ids = [] + for v_idx, f in tqdm(enumerate(input_videos)): + clipstride = 15 + + vc = VideoClips([f], + clip_length_in_frames=5, + frame_rate=20, + frames_between_clips=clipstride) + + clip_predictions = [] + i = 0 + cliplist = [] + countclips = 0 + for i in range(vc.num_clips()): + clip, _, _, _ = vc.get_clip(i) + clip = clip.swapaxes(1, 3).swapaxes(0, 1).swapaxes(2, 3).numpy() + + bounding_boxes, classes = yolo_model.get_bounding_boxes(clip[:, 2, :, :].swapaxes(0, 2).swapaxes(0, 1)) + bounding_boxes = bounding_boxes.squeeze(0) + if bounding_boxes.size == 0: + continue + #widths = [] + countclips = countclips + len(bounding_boxes) + + widths = [(bounding_boxes[i][3] - bounding_boxes[i][1]) for i in range(len(bounding_boxes))] + + #for i in range(len(bounding_boxes)): + # widths.append(bounding_boxes[i][3] - bounding_boxes[i][1]) + + ind = np.argmax(np.array(widths)) + #for bb in bounding_boxes: + bb = bounding_boxes[ind] + center_x = (bb[3] + bb[1]) / 2 * 1920 + center_y = (bb[2] + bb[0]) / 2 * 1080 + + width=400 + height=400 + + lower_y = round(center_y - height / 2) + upper_y = round(center_y + height / 2) + lower_x = round(center_x - width / 2) + upper_x = round(center_x + width / 2) + + trimmed_clip = clip[:, :, lower_y:upper_y, lower_x:upper_x] + + trimmed_clip = torch.tensor(trimmed_clip).to(torch.float) + + trimmed_clip = transform(trimmed_clip) + trimmed_clip.pin_memory() + cliplist.append(trimmed_clip) + + if len(cliplist) > 0: + with torch.no_grad(): + trimmed_clip = torch.stack(cliplist) + images = trimmed_clip.permute(0, 2, 3, 4, 1).numpy() + activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.weights[0], axis=0))])) + + pred = classifier_model(activations) + #print(torch.nn.Sigmoid()(pred)) + clip_predictions = tf.math.round(tf.math.sigmoid(pred)) + + final_pred = torch.mode(torch.tensor(clip_predictions.numpy()).view(-1))[0].item() + if len(clip_predictions) % 2 == 0 and tf.math.reduce_sum(clip_predictions) == len(clip_predictions)//2: + #print("I'm here") + final_pred = torch.mode(torch.tensor(clip_predictions.numpy()).view(-1))[0].item() + else: + final_pred = 1.0 + + if final_pred != numerical_labels[v_idx]: + if final_pred == 0.0: + fn_ids.append(f) + else: + fp_ids.append(f) + + final_list.append(final_pred) + + return np.array(final_list), np.array(numerical_labels), fn_ids, fp_ids if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -51,7 +139,6 @@ if __name__ == "__main__": parser.add_argument('--save_train_test_splits', action='store_true') parser.add_argument('--run_2d', action='store_true') parser.add_argument('--balance_classes', action='store_true') - parser.add_argument('--dataset', default='pnb', type=str) parser.add_argument('--train_sparse', action='store_true') parser.add_argument('--mixing_ratio', type=float, default=1.0) parser.add_argument('--sparse_lr', type=float, default=0.003) @@ -63,18 +150,9 @@ if __name__ == "__main__": args = parser.parse_args() - if args.dataset == 'pnb': - crop_height = args.crop_height - crop_width = args.crop_width - - image_height = int(crop_height / args.scale_factor) - image_width = int(crop_width / args.scale_factor) - 
clip_depth = args.clip_depth - elif args.dataset == 'ptx': - image_height = 100 - image_width = 200 - else: - raise Exception('Invalid dataset') + image_height = 100 + image_width = 200 + clip_depth = args.clip_depth batch_size = args.batch_size @@ -89,7 +167,7 @@ if __name__ == "__main__": with open(os.path.join(output_dir, 'arguments.txt'), 'w+') as out_f: out_f.write(str(args)) - yolo_model = YoloModel() + yolo_model = YoloModel('ptx') all_errors = [] @@ -97,31 +175,24 @@ if __name__ == "__main__": inputs = keras.Input(shape=(image_height, image_width, clip_depth)) else: inputs = keras.Input(shape=(clip_depth, image_height, image_width, 1)) - + filter_inputs = keras.Input(shape=(args.kernel_depth, args.kernel_size, args.kernel_size, 1, args.num_kernels), dtype='float32') output = SparseCode(batch_size=args.batch_size, image_height=image_height, image_width=image_width, clip_depth=clip_depth, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, kernel_depth=args.kernel_depth, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(inputs, filter_inputs) sparse_model = keras.Model(inputs=(inputs, filter_inputs), outputs=output) - + recon_inputs = keras.Input(shape=((clip_depth - args.kernel_depth) // 1 + 1, (image_height - args.kernel_size) // args.stride + 1, (image_width - args.kernel_size) // args.stride + 1, args.num_kernels)) - + recon_outputs = ReconSparse(batch_size=args.batch_size, image_height=image_height, image_width=image_width, clip_depth=clip_depth, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, kernel_depth=args.kernel_depth, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(recon_inputs) - + recon_model = keras.Model(inputs=recon_inputs, outputs=recon_outputs) if args.sparse_checkpoint: recon_model.set_weights(keras.models.load_model(args.sparse_checkpoint).get_weights()) - positive_class = None - if args.dataset == 'pnb': - splits, dataset = load_pnb_videos(yolo_model, args.batch_size, input_size=(image_height, image_width, clip_depth), crop_size=(crop_height, crop_width, clip_depth), classify_mode=True, balance_classes=args.balance_classes, mode=args.splits, device=None, n_splits=args.n_splits, sparse_model=None, frames_to_skip=args.frames_to_skip) - positive_class = 'Positives' - elif args.dataset == 'ptx': - train_loader, test_loader, dataset = load_yolo_clips(args.batch_size, num_clips=1, num_positives=15, mode=args.splits, device=None, n_splits=args.n_splits, sparse_model=None, whole_video=False, positive_videos='positive_videos.json') - positive_class = 'No_Sliding' - else: - raise Exception('Invalid dataset') + splits, dataset = load_yolo_clips(args.batch_size, num_clips=1, num_positives=15, mode=args.splits, device=None, n_splits=args.n_splits, sparse_model=None, whole_video=False, positive_videos='positive_videos.json') + positive_class = 'No_Sliding' overall_true = [] overall_pred = [] @@ -131,7 +202,6 @@ if __name__ == "__main__": i_fold = 0 for train_idx, test_idx in splits: train_sampler = torch.utils.data.SubsetRandomSampler(train_idx) -# train_sampler = SubsetWeightedRandomSampler(get_sample_weights(train_idx, dataset), train_idx, replacement=True) train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=train_sampler) @@ -149,15 +219,8 @@ if __name__ == "__main__": if args.checkpoint: classifier_model = 
keras.models.load_model(args.checkpoint) else: -# classifier_inputs = keras.Input(shape=((clip_depth - args.kernel_depth) // 1 + 1, (image_height - args.kernel_size) // args.stride + 1, (image_width - args.kernel_size) // args.stride + 1, args.num_kernel)) - classifier_inputs = keras.Input(shape=(clip_depth, image_height, image_width)) - - if args.dataset == 'pnb': - classifier_outputs = PNBTemporalClassifier()(classifier_inputs) - elif args.dataset == 'ptx': - classifier_outputs = PTXClassifier()(classifier_inputs) - else: - raise Exception('No classifier exists for that dataset') + classifier_inputs = keras.Input(shape=((clip_depth - args.kernel_depth) // 1 + 1, (image_height - args.kernel_size) // args.stride + 1, (image_width - args.kernel_size) // args.stride + 1, args.num_kernels)) + classifier_outputs = PTXClassifier()(classifier_inputs) classifier_model = keras.Model(inputs=classifier_inputs, outputs=classifier_outputs) @@ -178,14 +241,6 @@ if __name__ == "__main__": for labels, local_batch, vid_f in tqdm(train_loader): images = local_batch.permute(0, 2, 3, 4, 1).numpy() - cv2.imwrite('example_video_2/test_{}_0.png'.format(labels[3]), images[3, 0, :, :] * 255) - cv2.imwrite('example_video_2/test_{}_1.png'.format(labels[3]), images[3, 1, :, :] * 255) - cv2.imwrite('example_video_2/test_{}_2.png'.format(labels[3]), images[3, 2, :, :] * 255) - cv2.imwrite('example_video_2/test_{}_3.png'.format(labels[3]), images[3, 3, :, :] * 255) - cv2.imwrite('example_video_2/test_{}_4.png'.format(labels[3]), images[3, 4, :, :] * 255) - print(vid_f[3]) - raise Exception - torch_labels = np.zeros(len(labels)) torch_labels[[i for i in range(len(labels)) if labels[i] == positive_class]] = 1 torch_labels = np.expand_dims(torch_labels, axis=1) @@ -198,11 +253,10 @@ if __name__ == "__main__": print(loss) else: -# activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))])) - # raise Exception + activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))])) with tf.GradientTape() as tape: - pred = classifier_model(images) + pred = classifier_model(activations) loss = criterion(torch_labels, pred) epoch_loss += loss * local_batch.size(0) @@ -224,7 +278,6 @@ if __name__ == "__main__": prediction_optimizer.apply_gradients(zip(gradients, classifier_model.trainable_weights)) - if y_true_train is None: y_true_train = torch_labels y_pred_train = tf.math.round(tf.math.sigmoid(pred)) @@ -237,16 +290,20 @@ if __name__ == "__main__": y_true = None y_pred = None test_loss = 0.0 - for labels, local_batch, vid_f in tqdm(test_loader): + + eval_loader = test_loader + if args.splits == 'all_train': + eval_loader = train_loader + for labels, local_batch, vid_f in tqdm(eval_loader): images = local_batch.permute(0, 2, 3, 4, 1).numpy() torch_labels = np.zeros(len(labels)) torch_labels[[i for i in range(len(labels)) if labels[i] == positive_class]] = 1 torch_labels = np.expand_dims(torch_labels, axis=1) + + activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))])) -# activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))])) - - pred = classifier_model(images) + pred = classifier_model(activations) loss = criterion(torch_labels, pred) test_loss += loss @@ -277,158 +334,55 @@ if __name__ == "__main__": print("found better model") # Save model 
parameters classifier_model.save(os.path.join(output_dir, "best_classifier_{}.pt".format(i_fold))) - recon_model.save(os.path.join(output_dir, "best_sparse_model_{}.pt".format(i_fold))) +# recon_model.save(os.path.join(output_dir, "best_sparse_model_{}.pt".format(i_fold))) pickle.dump(prediction_optimizer.get_weights(), open(os.path.join(output_dir, 'optimizer_{}.pt'.format(i_fold)), 'wb+')) best_so_far = f1 classifier_model = keras.models.load_model(os.path.join(output_dir, "best_classifier_{}.pt".format(i_fold))) - recon_model = keras.models.load_model(os.path.join(output_dir, 'best_sparse_model_{}.pt'.format(i_fold))) +# recon_model = keras.models.load_model(os.path.join(output_dir, 'best_sparse_model_{}.pt'.format(i_fold))) - if args.dataset == 'pnb': - epoch_loss = 0 - - transform = torchvision.transforms.Compose( - [VideoGrayScaler(), - MinMaxScaler(0, 255), - torchvision.transforms.Resize((image_height, image_width)) - ]) - - y_true = None - y_pred = None - - pred_dict = {} - gt_dict = {} - - t1 = time.perf_counter() - # test_videos = [vid_f for labels, local_batch, vid_f in batch for batch in test_loader] - test_videos = set() - for labels, local_batch, vid_f in test_loader: - test_videos.update(vid_f) - - test_labels = [vid_f.split('/')[-3] for vid_f in test_videos] - -# test_videos = glob.glob(pathname='/home/dwh48@drexel.edu/special_splits/test/*/*.mp4', recursive=True) -# test_labels = [f.split('/')[-2] for f in test_videos] -# print(test_videos) -# print(test_labels) - - if args.frames_to_skip == 1: - y_pred, y_true, fn, fp = calculate_pnb_scores(test_videos, test_labels, yolo_model, sparse_model, recon_model, classifier_model, image_width, image_height, transform) - else: - y_pred, y_true, fn, fp = calculate_pnb_scores_skipped_frames(test_videos, test_labels, yolo_model, sparse_model, recon_model, classifier_model, args.frames_to_skip, image_width, image_height, transform) - - t2 = time.perf_counter() + epoch_loss = 0 - print('i_fold={}, time={:.2f}'.format(i_fold, t2-t1)) + y_true = None + y_pred = None - y_true = tf.cast(y_true, tf.int32) - y_pred = tf.cast(y_pred, tf.int32) + pred_dict = {} + gt_dict = {} - f1 = f1_score(y_true, y_pred, average='macro') - accuracy = accuracy_score(y_true, y_pred) - - fn_ids.extend(fn) - fp_ids.extend(fp) + t1 = time.perf_counter() + + transform = torchvision.transforms.Compose( + [VideoGrayScaler(), + MinMaxScaler(0, 255), + torchvision.transforms.Normalize((0.2592,), (0.1251,)), + torchvision.transforms.CenterCrop((100, 200)) + ]) + + test_dir = '/shared_data/bamc_ph1_test_data' + test_videos = glob.glob(os.path.join(test_dir, '*', '*.*')) + test_labels = [vid_f.split('/')[-2] for vid_f in test_videos] + + y_pred, y_true, fn, fp = calculate_ptx_scores(test_videos, test_labels, yolo_model, sparse_model, recon_model, classifier_model, image_width, image_height, transform) - overall_true.extend(y_true) - overall_pred.extend(y_pred) - - print("Test f1={:.2f}, vid_acc={:.2f}".format(f1, accuracy)) - - print(confusion_matrix(y_true, y_pred)) - elif args.dataset == 'ptx': - epoch_loss = 0 - - y_true = None - y_pred = None - - pred_dict = {} - gt_dict = {} - - t1 = time.perf_counter() - for labels, local_batch, vid_f in test_loader: - images = local_batch.permute(0, 2, 3, 4, 1).numpy() - - torch_labels = np.zeros(len(labels)) - torch_labels[[i for i in range(len(labels)) if labels[i] == positive_class]] = 1 - torch_labels = np.expand_dims(torch_labels, axis=1) - - activations = tf.stop_gradient(sparse_model([images, 
tf.stop_gradient(tf.expand_dims(recon_model.weights[0], axis=0))])) - - pred = classifier_model(activations) - - loss = criterion(torch_labels, pred) - epoch_loss += loss * local_batch.size(0) - - for i, v_f in enumerate(vid_f): - if v_f not in pred_dict: - pred_dict[v_f] = tf.math.round(tf.math.sigmoid(pred[i])) - else: - pred_dict[v_f] = tf.concat((pred_dict[v_f], tf.math.round(tf.math.sigmoid(pred[i]))), axis=0) - - if v_f not in gt_dict: - gt_dict[v_f] = torch_labels[i] - else: - gt_dict[v_f] = tf.concat((gt_dict[v_f], torch_labels[i]), axis=0) - - if y_true is None: - y_true = torch_labels - y_pred = tf.math.round(tf.math.sigmoid(pred)) - else: - y_true = tf.concat((y_true, torch_labels), axis=0) - y_pred = tf.concat((y_pred, tf.math.round(tf.math.sigmoid(pred))), axis=0) - - t2 = time.perf_counter() - - vid_acc = [] - for k in pred_dict.keys(): - gt_tmp = torch.tensor(gt_dict[k].numpy()) - pred_tmp = torch.tensor(pred_dict[k].numpy()) - - gt_mode = torch.mode(gt_tmp)[0].item() - # perm = torch.randperm(pred_tmp.size(0)) - # cutoff = int(pred_tmp.size(0)/4) - # if cutoff < 3: - # cutoff = 3 - # idx = perm[:cutoff] - # samples = pred_tmp[idx] - pred_mode = torch.mode(pred_tmp)[0].item() - overall_true.append(gt_mode) - overall_pred.append(pred_mode) - if pred_mode == gt_mode: - vid_acc.append(1) - else: - vid_acc.append(0) - if pred_mode == 0: - fn_ids.append(k) - else: - fp_ids.append(k) + t2 = time.perf_counter() - vid_acc = np.array(vid_acc) + print('i_fold={}, time={:.2f}'.format(i_fold, t2-t1)) - print('----------------------------------------------------------------------------') - for k in pred_dict.keys(): - print(k) - print('Predictions:') - print(pred_dict[k]) - print('Ground Truth:') - print(gt_dict[k]) - print('Overall Prediction:') - print(torch.mode(torch.tensor(pred_dict[k].numpy()))[0].item()) - print('----------------------------------------------------------------------------') + y_true = tf.cast(y_true, tf.int32) + y_pred = tf.cast(y_pred, tf.int32) - print('loss={:.2f}, time={:.2f}'.format(loss, t2-t1)) + f1 = f1_score(y_true, y_pred, average='macro') + accuracy = accuracy_score(y_true, y_pred) - y_true = tf.cast(y_true, tf.int32) - y_pred = tf.cast(y_pred, tf.int32) + fn_ids.extend(fn) + fp_ids.extend(fp) - f1 = f1_score(y_true, y_pred, average='macro') - accuracy = accuracy_score(y_true, y_pred) - all_errors.append(np.sum(vid_acc) / len(vid_acc)) + overall_true.extend(y_true) + overall_pred.extend(y_pred) - print("Test f1={:.2f}, clip_acc={:.2f}, vid_acc={:.2f}".format(f1, accuracy, np.sum(vid_acc) / len(vid_acc))) + print("Test f1={:.2f}, vid_acc={:.2f}".format(f1, accuracy)) - print(confusion_matrix(y_true, y_pred)) + print(confusion_matrix(y_true, y_pred)) i_fold += 1 diff --git a/sparse_coding_torch/train_sparse_model.py b/sparse_coding_torch/ptx/train_sparse_model.py similarity index 88% rename from sparse_coding_torch/train_sparse_model.py rename to sparse_coding_torch/ptx/train_sparse_model.py index a15e73d4fb8cfce69abe2a29e89d1f8e66667333..43c03df937acf96291b000eb75edbf6adf4c1a88 100644 --- a/sparse_coding_torch/train_sparse_model.py +++ b/sparse_coding_torch/ptx/train_sparse_model.py @@ -7,10 +7,10 @@ from matplotlib.animation import FuncAnimation from tqdm import tqdm import argparse import os -from sparse_coding_torch.load_data import load_yolo_clips, load_pnb_videos, load_needle_clips, load_onsd_videos +from sparse_coding_torch.ptx.load_data import load_yolo_clips import tensorflow.keras as keras import tensorflow as tf -from 
sparse_coding_torch.keras_model import normalize_weights_3d, normalize_weights, SparseCode, load_pytorch_weights, ReconSparse +from sparse_coding_torch.sparse_model import normalize_weights_3d, normalize_weights, SparseCode, load_pytorch_weights, ReconSparse import random def plot_video(video): @@ -139,19 +139,9 @@ if __name__ == "__main__": with open(os.path.join(output_dir, 'arguments.txt'), 'w+') as out_f: out_f.write(str(args)) - - if args.dataset == 'onsd': - splits, dataset = load_onsd_videos(args.batch_size, input_size=(image_height, image_width, clip_depth), mode='all_train') - train_idx, test_idx = splits[0] - elif args.dataset == 'pnb': - train_loader, test_loader, dataset = load_pnb_videos(args.batch_size, input_size=(image_height, image_width, clip_depth), crop_size=(crop_height, crop_width, clip_depth), classify_mode=False, balance_classes=False, mode='all_train', frames_to_skip=args.frames_to_skip) - elif args.dataset == 'ptx': - splits, dataset = load_yolo_clips(args.batch_size, num_clips=1, num_positives=15, mode='all_train', device=device, n_splits=1, sparse_model=None, whole_video=False, positive_videos='positive_videos.json') - train_idx, test_idx = splits[0] - elif args.dataset == 'needle': - train_loader, test_loader, dataset = load_needle_clips(args.batch_size, input_size=(image_height, image_width, clip_depth)) - else: - raise Exception('Invalid dataset') + + splits, dataset = load_yolo_clips(args.batch_size, num_clips=1, num_positives=15, mode='all_train', device=device, n_splits=1, sparse_model=None, whole_video=False, positive_videos='positive_videos.json') + train_idx, test_idx = splits[0] train_sampler = torch.utils.data.SubsetRandomSampler(train_idx) train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, diff --git a/sparse_coding_torch/ptx/video_loader.py b/sparse_coding_torch/ptx/video_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..93ad37eb5b9dd71fd64986d9984a3202e7e1db2d --- /dev/null +++ b/sparse_coding_torch/ptx/video_loader.py @@ -0,0 +1,263 @@ +from os import listdir +from os.path import isfile +from os.path import join +from os.path import isdir +from os.path import abspath +from os.path import exists +import json +import glob + +from PIL import Image +from torchvision.transforms import ToTensor +from torchvision.datasets.video_utils import VideoClips +from tqdm import tqdm +import torch +import numpy as np +from torch.utils.data import Dataset +from torch.utils.data import DataLoader +from torchvision.io import read_video +import torchvision as tv +from torch import nn +import torchvision.transforms.functional as tv_f +import csv +import random +import cv2 +from yolov4.get_bounding_boxes import YoloModel + +def get_ptx_participants(): + video_to_participant = {} + with open('/shared_data/bamc_data/bamc_video_info.csv', 'r') as csv_in: + reader = csv.DictReader(csv_in) + for row in reader: + key = row['Filename'].split('.')[0].lower().replace('_clean', '') + if key == '37 (mislabeled as 38)': + key = '37' + video_to_participant[key] = row['Participant_id'] + + return video_to_participant + +class MinMaxScaler(object): + """ + Transforms each channel to the range [0, 1]. 
+ """ + def __init__(self, min_val=0, max_val=254): + self.min_val = min_val + self.max_val = max_val + + def __call__(self, tensor): + return (tensor - self.min_val) / (self.max_val - self.min_val) + +class VideoGrayScaler(nn.Module): + + def __init__(self): + super().__init__() + self.grayscale = tv.transforms.Grayscale(num_output_channels=1) + + def forward(self, video): + # shape = channels, time, width, height + video = self.grayscale(video.swapaxes(-4, -3).swapaxes(-2, -1)) + video = video.swapaxes(-4, -3).swapaxes(-2, -1) + # print(video.shape) + return video + +class YoloClipLoader(Dataset): + + def __init__(self, yolo_output_path, num_frames=5, frames_between_clips=None, + transform=None, augment_transform=None, num_clips=1, num_positives=1, positive_videos=None, sparse_model=None, device=None): + if (num_frames % 2) == 0: + raise ValueError("Num Frames must be an odd number, so we can extract a clip centered on each detected region") + + clip_cache_file = 'clip_cache.pt' + + self.num_clips = num_clips + + self.num_frames = num_frames + if frames_between_clips is None: + self.frames_between_clips = num_frames + else: + self.frames_between_clips = frames_between_clips + + self.transform = transform + self.augment_transform = augment_transform + + self.labels = [name for name in listdir(yolo_output_path) if isdir(join(yolo_output_path, name))] + self.clips = [] + if exists(clip_cache_file): + self.clips = torch.load(open(clip_cache_file, 'rb')) + else: + for label in self.labels: + print("Processing videos in category: {}".format(label)) + videos = list(listdir(join(yolo_output_path, label))) + for vi in tqdm(range(len(videos))): + video = videos[vi] + counter = 0 + all_trimmed = [] + with open(abspath(join(yolo_output_path, label, video, 'result.json'))) as fin: + results = json.load(fin) + max_frame = len(results) + + for i in range((num_frames-1)//2, max_frame - (num_frames-1)//2 - 1, self.frames_between_clips): + # for frame in results: + frame = results[i] + # print('loading frame:', i, frame['frame_id']) + frame_start = int(frame['frame_id']) - self.num_frames//2 + frames = [abspath(join(yolo_output_path, label, video, 'frame{}.png'.format(frame_start+fid))) + for fid in range(num_frames)] + # print(frames) + frames = torch.stack([ToTensor()(Image.open(f).convert("RGB")) for f in frames]).swapaxes(0, 1) + + for region in frame['objects']: + # print(region) + if region['name'] != "Pleural_Line": + continue + + center_x = region['relative_coordinates']["center_x"] * 1920 + center_y = region['relative_coordinates']['center_y'] * 1080 + + # width = region['relative_coordinates']['width'] * 1920 + # height = region['relative_coordinates']['height'] * 1080 + width=200 + height=100 + + lower_y = round(center_y - height / 2) + upper_y = round(center_y + height / 2) + lower_x = round(center_x - width / 2) + upper_x = round(center_x + width / 2) + + final_clip = frames[:, :, lower_y:upper_y, lower_x:upper_x] + + if self.transform: + final_clip = self.transform(final_clip) + + if sparse_model: + with torch.no_grad(): + final_clip = final_clip.unsqueeze(0).to(device) + final_clip = sparse_model(final_clip) + final_clip = final_clip.squeeze(0).detach().cpu() + + self.clips.append((label, final_clip, video)) + + torch.save(self.clips, open(clip_cache_file, 'wb+')) + + +# random.shuffle(self.clips) + +# video_to_clips = {} + if positive_videos: + vids_to_keep = json.load(open(positive_videos)) + + self.clips = [clip_tup for clip_tup in self.clips if clip_tup[2] in vids_to_keep or clip_tup[0] 
== 'Sliding'] + else: + video_to_labels = {} + + for lbl, clip, video in self.clips: + video = video.lower().replace('_clean', '') + if video not in video_to_labels: + # video_to_clips[video] = [] + video_to_labels[video] = [] + + # video_to_clips[video].append(clip) + video_to_labels[video].append(lbl) + + video_to_participants = get_ptx_participants() + participants_to_video = {} + for k, v in video_to_participants.items(): + if video_to_labels[k][0] == 'Sliding': + continue + if not v in participants_to_video: + participants_to_video[v] = [] + + participants_to_video[v].append(k) + + participants_to_video = dict(sorted(participants_to_video.items(), key=lambda x: len(x[1]), reverse=True)) + + num_to_remove = len([k for k,v in video_to_labels.items() if v[0] == 'No_Sliding']) - num_positives + vids_to_remove = set() + while num_to_remove > 0: + vids_to_remove.add(participants_to_video[list(participants_to_video.keys())[0]].pop()) + participants_to_video = dict(sorted(participants_to_video.items(), key=lambda x: len(x[1]), reverse=True)) + num_to_remove -= 1 + + self.clips = [clip_tup for clip_tup in self.clips if clip_tup[2].lower().replace('_clean', '') not in vids_to_remove] + + video_to_clips = {} + video_to_labels = {} + + for lbl, clip, video in self.clips: + if video not in video_to_clips: + video_to_clips[video] = [] + video_to_labels[video] = [] + + video_to_clips[video].append(clip) + video_to_labels[video].append(lbl) + + print([k for k,v in video_to_labels.items() if v[0] == 'No_Sliding']) + + print('Num positive:', len([k for k,v in video_to_labels.items() if v[0] == 'No_Sliding'])) + print('Num negative:', len([k for k,v in video_to_labels.items() if v[0] == 'Sliding'])) + + self.videos = None + self.max_video_clips = 0 + if num_clips > 1: + self.videos = [] + + for video in video_to_clips.keys(): + clip_list = video_to_clips[video] + lbl_list = video_to_labels[video] + + for i in range(0, len(clip_list) - num_clips, 1): + video_stack = torch.stack(clip_list[i:i+num_clips]) + + self.videos.append((max(set(lbl_list[i:i+num_clips]), key=lbl_list[i:i+num_clips].count), video_stack, video)) + + self.clips = None + + + def get_labels(self): + if self.num_clips > 1: + return [self.videos[i][0] for i in range(len(self.videos))] + else: + return [self.clips[i][0] for i in range(len(self.clips))] + + def get_filenames(self): + if self.num_clips > 1: + return [self.videos[i][2] for i in range(len(self.videos))] + else: + return [self.clips[i][2] for i in range(len(self.clips))] + + def __getitem__(self, index): + if self.num_clips > 1: + label = self.videos[index][0] + video = self.videos[index][1] + filename = self.videos[index][2] + + video = video.squeeze(2) + video = video.permute(1, 0, 2, 3) + + if self.augment_transform: + video = self.augment_transform(video) + + video = video.unsqueeze(2) + video = video.permute(1, 0, 2, 3, 4) +# video = video.permute(4, 1, 2, 3, 0) +# video = torch.nn.functional.pad(video, (0), 'constant', 0) +# video = video.permute(4, 1, 2, 3, 0) + + orig_len = video.size(0) + +# if orig_len < self.max_video_clips: +# video = torch.cat([video, torch.zeros(self.max_video_clips - len(video), video.size(1), video.size(2), video.size(3), video.size(4))]) + + return label, video, filename, orig_len + else: + label = self.clips[index][0] + video = self.clips[index][1] + filename = self.clips[index][2] + + if self.augment_transform: + video = self.augment_transform(video) + + return label, video, filename + + def __len__(self): + return len(self.clips) diff --git 
a/sparse_coding_torch/keras_model.py b/sparse_coding_torch/sparse_model.py similarity index 54% rename from sparse_coding_torch/keras_model.py rename to sparse_coding_torch/sparse_model.py index 86f015a7950113252f087f9ce4596b0a90c9cd0b..12ef1bed39321958c65cc4f1fef8f3572cc0b690 100644 --- a/sparse_coding_torch/keras_model.py +++ b/sparse_coding_torch/sparse_model.py @@ -6,7 +6,7 @@ import cv2 import torchvision as tv import torch import torch.nn as nn -from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler +from sparse_coding_torch.utils import VideoGrayScaler, MinMaxScaler def load_pytorch_weights(file_path): pytorch_checkpoint = torch.load(file_path, map_location='cpu') @@ -220,205 +220,4 @@ class ReconSparse(keras.Model): else: recon = do_recon_3d(self.filters, activations, self.image_height, self.image_width, self.clip_depth, self.stride, self.padding) - return recon - -class PTXClassifier(keras.layers.Layer): - def __init__(self): - super(PTXClassifier, self).__init__() - - self.max_pool = keras.layers.MaxPooling2D(pool_size=4, strides=4) - self.conv_1 = keras.layers.Conv2D(24, kernel_size=8, strides=4, activation='relu', padding='valid') -# self.conv_2 = keras.layers.Conv2D(24, kernel_size=4, strides=2, activation='relu', padding='valid') - - self.flatten = keras.layers.Flatten() - - self.dropout = keras.layers.Dropout(0.5) - -# self.ff_1 = keras.layers.Dense(1000, activation='relu', use_bias=True) -# self.ff_2 = keras.layers.Dense(500, activation='relu', use_bias=True) - self.ff_3 = keras.layers.Dense(20, activation='relu', use_bias=True) - self.ff_4 = keras.layers.Dense(1) - -# @tf.function - def call(self, activations): - activations = tf.squeeze(activations, axis=1) - x = self.max_pool(activations) - x = self.conv_1(x) -# x = self.conv_2(x) - x = self.flatten(x) -# x = self.ff_1(x) -# x = self.dropout(x) -# x = self.ff_2(x) -# x = self.dropout(x) - x = self.ff_3(x) - x = self.dropout(x) - x = self.ff_4(x) - - return x - -class PNBClassifier(keras.layers.Layer): - def __init__(self): - super(PNBClassifier, self).__init__() - -# self.max_pool = keras.layers.MaxPooling2D(pool_size=(8, 8), strides=(2, 2)) - self.conv_1 = keras.layers.Conv2D(32, kernel_size=(8, 8), strides=(4, 4), activation='relu', padding='valid') - self.conv_2 = keras.layers.Conv2D(32, kernel_size=4, strides=2, activation='relu', padding='valid') -# self.conv_3 = keras.layers.Conv2D(12, kernel_size=4, strides=1, activation='relu', padding='valid') -# self.conv_4 = keras.layers.Conv2D(16, kernel_size=4, strides=2, activation='relu', padding='valid') - - self.flatten = keras.layers.Flatten() - -# self.dropout = keras.layers.Dropout(0.5) - -# self.ff_1 = keras.layers.Dense(1000, activation='relu', use_bias=True) - self.ff_2 = keras.layers.Dense(40, activation='relu', use_bias=True) - self.ff_3 = keras.layers.Dense(20, activation='relu', use_bias=True) - self.ff_4 = keras.layers.Dense(1) - -# @tf.function - def call(self, activations): - x = tf.squeeze(activations, axis=1) -# x = self.max_pool(x) -# print(x.shape) - x = self.conv_1(x) -# print(x.shape) - x = self.conv_2(x) -# print(x.shape) -# raise Exception -# x = self.conv_3(x) -# print(x.shape) -# x = self.conv_4(x) -# raise Exception - x = self.flatten(x) -# x = self.ff_1(x) -# x = self.dropout(x) - x = self.ff_2(x) -# x = self.dropout(x) - x = self.ff_3(x) -# x = self.dropout(x) - x = self.ff_4(x) - - return x - -class PNBTemporalClassifier(keras.layers.Layer): - def __init__(self): - super(PNBTemporalClassifier, self).__init__() - self.conv_1 
= keras.layers.Conv2D(12, kernel_size=(150, 24), strides=(1, 8), activation='relu', padding='valid') - self.conv_2 = keras.layers.Conv1D(24, kernel_size=8, strides=4, activation='relu', padding='valid') - - self.ff_1 = keras.layers.Dense(100, activation='relu', use_bias=True) - - self.gru = keras.layers.GRU(25) - - self.flatten = keras.layers.Flatten() - - self.ff_2 = keras.layers.Dense(10, activation='relu', use_bias=True) - self.ff_3 = keras.layers.Dense(1) - -# @tf.function - def call(self, clip): - width = clip.shape[3] - height = clip.shape[2] - depth = clip.shape[1] - - x = tf.expand_dims(clip, axis=4) - x = tf.reshape(clip, (-1, height, width, 1)) - - x = self.conv_1(x) - x = tf.squeeze(x, axis=1) - x = self.conv_2(x) - - x = self.flatten(x) - x = self.ff_1(x) - - x = tf.reshape(x, (-1, 5, 100)) - x = self.gru(x) - - x = self.ff_2(x) - x = self.ff_3(x) - - return x - -class MobileModelPTX(keras.Model): - def __init__(self, sparse_checkpoint, batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d): - super().__init__() - self.sparse_code = SparseCode(batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d) - self.classifier = Classifier() - - self.out_channels = out_channels - self.in_channels = in_channels - self.stride = stride - self.lam = lam - self.activation_lr = activation_lr - self.max_activation_iter = max_activation_iter - self.batch_size = batch_size - self.run_2d = run_2d - - pytorch_weights = load_pytorch_weights(sparse_checkpoint) - - if run_2d: - weight_list = np.split(pytorch_weights, 5, axis=0) - self.filters_1 = tf.Variable(initial_value=weight_list[0].squeeze(0), dtype='float32', trainable=False) - self.filters_2 = tf.Variable(initial_value=weight_list[1].squeeze(0), dtype='float32', trainable=False) - self.filters_3 = tf.Variable(initial_value=weight_list[2].squeeze(0), dtype='float32', trainable=False) - self.filters_4 = tf.Variable(initial_value=weight_list[3].squeeze(0), dtype='float32', trainable=False) - self.filters_5 = tf.Variable(initial_value=weight_list[4].squeeze(0), dtype='float32', trainable=False) - else: - self.filters = tf.Variable(initial_value=pytorch_weights, dtype='float32', trainable=False) - - @tf.function - def call(self, images): - images = tf.squeeze(tf.image.rgb_to_grayscale(images), axis=-1) - images = tf.transpose(images, perm=[0, 2, 3, 1]) - images = images / 255 - images = (images - 0.2592) / 0.1251 - - if self.run_2d: - activations = self.sparse_code(images, [tf.stop_gradient(self.filters_1), tf.stop_gradient(self.filters_2), tf.stop_gradient(self.filters_3), tf.stop_gradient(self.filters_4), tf.stop_gradient(self.filters_5)]) - else: - activations = self.sparse_code(images, tf.stop_gradient(self.filters)) - - pred = self.classifier(activations) - - return pred - -class MobileModelPNB(keras.Model): - def __init__(self, sparse_weights, classifier_model, batch_size, image_height, image_width, clip_depth, out_channels, kernel_size, kernel_depth, stride, lam, activation_lr, max_activation_iter, run_2d): - super().__init__() - self.sparse_code = SparseCode(batch_size=batch_size, image_height=image_height, image_width=image_width, clip_depth=clip_depth, in_channels=1, out_channels=out_channels, kernel_size=kernel_size, kernel_depth=kernel_depth, stride=stride, lam=lam, activation_lr=activation_lr, max_activation_iter=max_activation_iter, run_2d=run_2d, padding='VALID') - self.classifier = classifier_model - - self.out_channels = 
out_channels - self.stride = stride - self.lam = lam - self.activation_lr = activation_lr - self.max_activation_iter = max_activation_iter - self.batch_size = batch_size - self.run_2d = run_2d - - if run_2d: - weight_list = np.split(sparse_weights, 5, axis=0) - self.filters_1 = tf.Variable(initial_value=weight_list[0].squeeze(0), dtype='float32', trainable=False) - self.filters_2 = tf.Variable(initial_value=weight_list[1].squeeze(0), dtype='float32', trainable=False) - self.filters_3 = tf.Variable(initial_value=weight_list[2].squeeze(0), dtype='float32', trainable=False) - self.filters_4 = tf.Variable(initial_value=weight_list[3].squeeze(0), dtype='float32', trainable=False) - self.filters_5 = tf.Variable(initial_value=weight_list[4].squeeze(0), dtype='float32', trainable=False) - else: - self.filters = tf.Variable(initial_value=sparse_weights, dtype='float32', trainable=False) - - @tf.function - def call(self, images): -# images = tf.squeeze(tf.image.rgb_to_grayscale(images), axis=-1) - images = tf.transpose(images, perm=[0, 2, 3, 1]) - images = images / 255 - - if self.run_2d: - activations = self.sparse_code(images, [tf.stop_gradient(self.filters_1), tf.stop_gradient(self.filters_2), tf.stop_gradient(self.filters_3), tf.stop_gradient(self.filters_4), tf.stop_gradient(self.filters_5)]) - activations = tf.expand_dims(activations, axis=1) - else: - activations = self.sparse_code(images, tf.stop_gradient(self.filters)) - - pred = tf.math.round(tf.math.sigmoid(self.classifier(activations))) -# pred = tf.math.reduce_sum(activations) - - return pred + return recon \ No newline at end of file diff --git a/sparse_coding_torch/utils.py b/sparse_coding_torch/utils.py index 98c807bb4be5f285ed70a35fcab2bdfb1af5de6d..62e67d34f3d586efff1851e41bb84aad01dc30f8 100644 --- a/sparse_coding_torch/utils.py +++ b/sparse_coding_torch/utils.py @@ -1,113 +1,149 @@ import numpy as np -from sparse_coding_torch.video_loader import get_yolo_regions, classify_nerve_is_right import torchvision as tv import torch import tensorflow as tf from tqdm import tqdm +from torchvision.datasets.video_utils import VideoClips +from typing import Sequence, Iterator +import torch.nn as nn -def calculate_pnb_scores(input_videos, labels, yolo_model, sparse_model, recon_model, classifier_model, image_width, image_height, transform): - all_predictions = [] +def get_sample_weights(train_idx, dataset): + dataset = list(dataset) + + num_positive = len([clip[0] for clip in dataset if clip[0] == 'Positives']) + negative_weight = num_positive / len(dataset) + positive_weight = 1.0 - negative_weight - numerical_labels = [] - for label in labels: + weights = [] + for idx in train_idx: + label = dataset[idx][0] if label == 'Positives': - numerical_labels.append(1.0) - else: - numerical_labels.append(0.0) - - final_list = [] - fp_ids = [] - fn_ids = [] - for v_idx, f in tqdm(enumerate(input_videos)): - vc = tv.io.read_video(f)[0].permute(3, 0, 1, 2) - is_right = classify_nerve_is_right(yolo_model, vc) - - all_preds = [] - for j in range(vc.size(1) - 5, vc.size(1) - 25, -5): - if j-5 < 0: - break - - vc_sub = vc[:, j-5:j, :, :] - - if vc_sub.size(1) < 5: - continue - - clip = get_yolo_regions(yolo_model, vc_sub, is_right, image_width, image_height) - - if not clip: - continue - - clip = clip[0] - clip = transform(clip).to(torch.float32) - clip = tf.expand_dims(clip, axis=4) - - activations = tf.stop_gradient(sparse_model([clip, tf.stop_gradient(tf.expand_dims(recon_model.weights[0], axis=0))])) - - pred = 
tf.math.round(tf.math.sigmoid(classifier_model(activations))) - - all_preds.append(pred) - - if all_preds: - final_pred = np.round(np.mean(np.array(all_preds))) + weights.append(positive_weight) + elif label == 'Negatives': + weights.append(negative_weight) else: - final_pred = 1.0 - - if final_pred != numerical_labels[v_idx]: - if final_pred == 0: - fn_ids.append(f) - else: - fp_ids.append(f) - - final_list.append(final_pred) - - return np.array(final_list), np.array(numerical_labels), fn_ids, fp_ids + raise Exception('Sampler encountered invalid label') + + return weights + +class SubsetWeightedRandomSampler(torch.utils.data.Sampler[int]): + weights: torch.Tensor + num_samples: int + replacement: bool + + def __init__(self, weights: Sequence[float], indicies: Sequence[int], + replacement: bool = True, generator=None) -> None: + if not isinstance(replacement, bool): + raise ValueError("replacement should be a boolean value, but got " + "replacement={}".format(replacement)) + self.weights = torch.as_tensor(weights, dtype=torch.double) + self.indicies = indicies + self.replacement = replacement + self.generator = generator + + def __iter__(self) -> Iterator[int]: + rand_tensor = torch.multinomial(self.weights, len(self.indicies), self.replacement, generator=self.generator) + for i in rand_tensor: + yield self.indicies[i] -def calculate_pnb_scores_skipped_frames(input_videos, labels, yolo_model, sparse_model, recon_model, classifier_model, frames_to_skip, image_width, image_height, transform): - all_predictions = [] + def __len__(self) -> int: + return len(self.indicies) + +class MinMaxScaler(object): + """ + Transforms each channel to the range [0, 1]. + """ + def __init__(self, min_val=0, max_val=254): + self.min_val = min_val + self.max_val = max_val - numerical_labels = [] - for label in labels: - if label == 'Positives': - numerical_labels.append(1.0) - else: - numerical_labels.append(0.0) - - final_list = [] - fp_ids = [] - fn_ids = [] - for v_idx, f in tqdm(enumerate(input_videos)): - vc = tv.io.read_video(f)[0].permute(3, 0, 1, 2) - is_right = classify_nerve_is_right(yolo_model, vc) - - all_preds = [] + def __call__(self, tensor): + return (tensor - self.min_val) / (self.max_val - self.min_val) + +class VideoGrayScaler(nn.Module): + + def __init__(self): + super().__init__() + self.grayscale = tv.transforms.Grayscale(num_output_channels=1) - frames = [] - for k in range(vc.size(1) - 1, vc.size(1) - 5 * frames_to_skip, -frames_to_skip): - frames.append(vc[:, k, :, :]) - vc_sub = torch.stack(frames, dim=1) - - if vc_sub.size(1) < 5: - continue + def forward(self, video): + # shape = channels, time, width, height + video = self.grayscale(video.swapaxes(-4, -3).swapaxes(-2, -1)) + video = video.swapaxes(-4, -3).swapaxes(-2, -1) + # print(video.shape) + return video - clip = get_yolo_regions(yolo_model, vc_sub, is_right, image_width, image_height) +def plot_video(video): - if clip: - clip = clip[0] - clip = transform(clip).to(torch.float32) - clip = tf.expand_dims(clip, axis=4) + fig = plt.gcf() + ax = plt.gca() - activations = tf.stop_gradient(sparse_model([clip, tf.stop_gradient(tf.expand_dims(recon_model.weights[0], axis=0))])) + DPI = fig.get_dpi() + fig.set_size_inches(video.shape[2]/float(DPI), video.shape[3]/float(DPI)) - pred = tf.math.round(tf.math.sigmoid(classifier_model(activations))) - else: - pred = 1.0 - - if pred != numerical_labels[v_idx]: - if pred == 0: - fn_ids.append(f) - else: - fp_ids.append(f) - - final_list.append(pred) - - return np.array(final_list), 
np.array(numerical_labels), fn_ids, fp_ids \ No newline at end of file + ax.set_title("Video") + + T = video.shape[1] + im = ax.imshow(video[0, 0, :, :], + cmap=cm.Greys_r) + + def update(i): + t = i % T + im.set_data(video[0, t, :, :]) + + return FuncAnimation(plt.gcf(), update, interval=1000/20) + +def plot_original_vs_recon(original, reconstruction, idx=0): + + # create two subplots + ax1 = plt.subplot(1, 2, 1) + ax2 = plt.subplot(1, 2, 2) + ax1.set_title("Original") + ax2.set_title("Reconstruction") + + T = original.shape[2] + im1 = ax1.imshow(original[idx, 0, 0, :, :], + cmap=cm.Greys_r) + im2 = ax2.imshow(reconstruction[idx, 0, 0, :, :], + cmap=cm.Greys_r) + + def update(i): + t = i % T + im1.set_data(original[idx, 0, t, :, :]) + im2.set_data(reconstruction[idx, 0, t, :, :]) + + return FuncAnimation(plt.gcf(), update, interval=1000/30) + + +def plot_filters(filters): + filters = filters.astype('float32') + num_filters = filters.shape[4] + ncol = 3 + # ncol = int(np.sqrt(num_filters)) + # nrow = int(np.sqrt(num_filters)) + T = filters.shape[0] + + if num_filters // ncol == num_filters / ncol: + nrow = num_filters // ncol + else: + nrow = num_filters // ncol + 1 + + fig, axes = plt.subplots(ncols=ncol, nrows=nrow, + constrained_layout=True, + figsize=(ncol*2, nrow*2)) + + ims = {} + for i in range(num_filters): + r = i // ncol + c = i % ncol + ims[(r, c)] = axes[r, c].imshow(filters[0, :, :, 0, i], + cmap=cm.Greys_r) + + def update(i): + t = i % T + for i in range(num_filters): + r = i // ncol + c = i % ncol + ims[(r, c)].set_data(filters[t, :, :, 0, i]) + + return FuncAnimation(plt.gcf(), update, interval=1000/20) \ No newline at end of file diff --git a/yolov4/get_bounding_boxes.py b/yolov4/get_bounding_boxes.py index 618338f63477a35a65e0fe6c528217d40db8dc2b..ae1d0477c171b3b4f7989f3e789d11f71c053f84 100644 --- a/yolov4/get_bounding_boxes.py +++ b/yolov4/get_bounding_boxes.py @@ -13,12 +13,17 @@ from tensorflow.compat.v1 import InteractiveSession import time class YoloModel(): - def __init__(self): + def __init__(self, dataset): flags.DEFINE_string('framework', 'tflite', '(tf, tflite, trt') - flags.DEFINE_string('weights', 'yolov4/Pleural_Line_TensorFlow/pnb_prelim_yolo/yolov4-416.tflite', - 'path to weights file') -# flags.DEFINE_string('weights', 'yolov4/yolov4-416.tflite', -# 'path to weights file') + if dataset == 'pnb': + flags.DEFINE_string('weights', 'yolov4/Pleural_Line_TensorFlow/pnb_prelim_yolo/yolov4-416.tflite', + 'path to weights file') + elif dataset == 'onsd': + flags.DEFINE_string('weights', 'yolov4/Pleural_Line_TensorFlow/onsd_prelim_yolo/yolov4-416.tflite', + 'path to weights file') + else: + flags.DEFINE_string('weights', 'yolov4/yolov4-416.tflite', + 'path to weights file') flags.DEFINE_integer('size', 416, 'resize images to') flags.DEFINE_boolean('tiny', False, 'yolo or yolo-tiny') flags.DEFINE_string('model', 'yolov4', 'yolov3 or yolov4') @@ -111,21 +116,26 @@ class YoloModel(): boxes = boxes.tolist() boxes = boxes[0] classes = classes[0] + scores = scores[0] boxes_list = [] class_list = [] - for box, class_idx in zip(boxes, classes): + score_list = [] + for box, class_idx, score in zip(boxes, classes, scores): sum = 0 for value in box: sum += value if sum > 0: boxes_list.append(box) class_list.append(class_idx) + score_list.append(score) boxes_list = [boxes_list] class_list = [class_list] + score_list = [score_list] boxes = np.array(boxes_list) classes = np.array(class_list) + scores = np.array(score_list) end = time.time() elapsed_time = end - start # 
print('Took %.2f seconds to run whole bounding box function\n' % (elapsed_time))
-        return boxes, classes
+        return boxes, classes, scores
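
For readers tracing the revised interface: below is a minimal, self-contained sketch (not part of the patch) of how the updated YoloModel might be driven after this change. The constructor now takes a dataset name ('pnb' or 'onsd' selects dataset-specific .tflite weights; anything else falls back to the default pleural-line weights), and the detection call now returns per-detection confidence scores alongside boxes and classes. The method name get_bounding_boxes, the single-frame input format, and the path example_clip.mp4 are illustrative assumptions only; they are not established by this diff.

# Sketch only: exercises the YoloModel changes introduced above.
# Assumptions (not confirmed by this patch): the detection method is named
# get_bounding_boxes and accepts one HxWxC frame as a numpy array.
import torchvision as tv
from yolov4.get_bounding_boxes import YoloModel

yolo_model = YoloModel('pnb')  # 'pnb' / 'onsd' pick dataset-specific weights

# Read a clip as (channels, frames, height, width), matching the loaders in this repo.
vc = tv.io.read_video('example_clip.mp4')[0].permute(3, 0, 1, 2)
frame = vc[:, 0, :, :].permute(1, 2, 0).numpy()  # first frame, HxWxC

# The call now returns confidence scores as a third value; boxes, classes and
# scores are each wrapped in an outer batch dimension, per the code above.
boxes, classes, scores = yolo_model.get_bounding_boxes(frame)
for box, cls, score in zip(boxes[0], classes[0], scores[0]):
    print(cls, score, box)

Exposing scores appears intended to let downstream callers (for example, the region-extraction helpers that consume YOLO output) filter low-confidence detections, which the previous two-value return did not allow.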