diff --git a/generate_yolo_regions.py b/generate_yolo_regions.py
index ceed5ca92273c08073b8d5c09b965053ec0d538e..a0351cdb1fd2125bd3cfe1947f2ff0e9a3544d9a 100644
--- a/generate_yolo_regions.py
+++ b/generate_yolo_regions.py
@@ -3,7 +3,7 @@ import os
 import time
 import numpy as np
 import torchvision
-from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler, get_yolo_regions, classify_nerve_is_right, load_pnb_region_labels
+from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler, get_yolo_regions, classify_nerve_is_right, load_pnb_region_labels, calculate_angle
 from torchvision.datasets.video_utils import VideoClips
 import torchvision as tv
 import csv
@@ -16,26 +16,30 @@ import cv2
 import tensorflow.keras as keras
 from sparse_coding_torch.keras_model import SparseCode, PNBClassifier, PTXClassifier, ReconSparse
 import glob
+from sparse_coding_torch.train_sparse_model import plot_video

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument('--input_video', required=True, type=str, help='Path to input video.')
     parser.add_argument('--output_dir', default='yolo_output', type=str, help='Location where yolo clips should be saved.')
     parser.add_argument('--num_frames', default=5, type=int)
-    parser.add_argument('--image_height', default=285, type=int)
-    parser.add_argument('--image_width', default=235, type=int)
+    parser.add_argument('--stride', default=5, type=int)
+    parser.add_argument('--image_height', default=150, type=int)
+    parser.add_argument('--image_width', default=400, type=int)

     args = parser.parse_args()

     path = args.input_video

-    region_labels = load_pnb_region_labels(os.path.join('/'.join(path.split('/')[:-3]), 'sme_region_labels.csv'))
+    region_labels = load_pnb_region_labels('sme_region_labels.csv')

     if not os.path.exists(args.output_dir):
         os.makedirs(args.output_dir)

     image_height = args.image_height
     image_width = args.image_width
+    clip_depth = args.num_frames
+    frames_to_skip = args.stride

     # For some reason the size has to be even for the clips, so it will add one if the size is odd
     transforms = torchvision.transforms.Compose([
@@ -46,69 +50,97 @@ if __name__ == "__main__":
     vc = tv.io.read_video(path)[0].permute(3, 0, 1, 2)

     is_right = classify_nerve_is_right(yolo_model, vc)
+    angle = calculate_angle(yolo_model, vc)

     person_idx = path.split('/')[-2]
     label = path.split('/')[-3]

     output_count = 0
-    
+
     if label == 'Positives' and person_idx in region_labels:
         negative_regions, positive_regions = region_labels[person_idx]
         for sub_region in negative_regions.split(','):
             sub_region = sub_region.split('-')
             start_loc = int(sub_region[0])
+#             end_loc = int(sub_region[1]) - 50
             end_loc = int(sub_region[1]) + 1
-            for frame in range(start_loc, end_loc - args.num_frames, args.num_frames):
-                vc_sub = vc[:, frame:frame+args.num_frames, :, :]
-                if vc_sub.size(1) < args.num_frames:
+            for j in range(start_loc, end_loc - clip_depth * frames_to_skip, clip_depth):
+                frames = []
+                for k in range(j, j + clip_depth * frames_to_skip, frames_to_skip):
+                    frames.append(vc[:, k, :, :])
+                vc_sub = torch.stack(frames, dim=1)
+
+                if vc_sub.size(1) < clip_depth:
                     continue
-                for clip in get_yolo_regions(yolo_model, vc_sub, is_right, image_width, image_height):
+                for clip in get_yolo_regions(yolo_model, vc_sub, is_right, angle, image_width, image_height):
                     clip = transforms(clip)
-                    tv.io.write_video(os.path.join(args.output_dir, 'negative_yolo' + str(output_count) + '.mp4'), clip.swapaxes(0,1).swapaxes(1,2).swapaxes(2,3).numpy(), fps=20)
+                    ani = plot_video(clip)
+                    ani.save(os.path.join(args.output_dir, 'negative_yolo' + str(output_count) + '.mp4'))
                     output_count += 1

         if positive_regions:
             for sub_region in positive_regions.split(','):
                 sub_region = sub_region.split('-')
+#                 start_loc = int(sub_region[0]) + 15
                 start_loc = int(sub_region[0])
-                if len(sub_region) == 1:
-                    vc_sub = vc[:, start_loc:start_loc+args.num_frames, :, :]
-                    if vc_sub.size(1) < args.num_frames:
+                if len(sub_region) == 1 and vc.size(1) >= start_loc + clip_depth * frames_to_skip:
+                    frames = []
+                    for k in range(start_loc, start_loc + clip_depth * frames_to_skip, frames_to_skip):
+                        frames.append(vc[:, k, :, :])
+                    vc_sub = torch.stack(frames, dim=1)
+
+                    if vc_sub.size(1) < clip_depth:
                         continue
-                    for clip in get_yolo_regions(yolo_model, vc_sub, is_right, image_width, image_height):
+                    for clip in get_yolo_regions(yolo_model, vc_sub, is_right, angle, image_width, image_height):
                         clip = transforms(clip)
-                        tv.io.write_video(os.path.join(args.output_dir, 'positive_yolo' + str(output_count) + '.mp4'), clip.swapaxes(0,1).swapaxes(1,2).swapaxes(2,3).numpy(), fps=20)
+                        ani = plot_video(clip)
+                        ani.save(os.path.join(args.output_dir, 'positive_yolo' + str(output_count) + '.mp4'))
                         output_count += 1
-                else:
+                elif vc.size(1) >= start_loc + clip_depth * frames_to_skip:
                     end_loc = sub_region[1]
                     if end_loc.strip().lower() == 'end':
                         end_loc = vc.size(1)
                     else:
                         end_loc = int(end_loc)
-                    for frame in range(start_loc, end_loc - args.num_frames, args.num_frames):
-                        vc_sub = vc[:, frame:frame+args.num_frames, :, :]
-#                         cv2.imwrite('test.png', vc_sub[0, 0, :, :].unsqueeze(2).numpy())
-                        if vc_sub.size(1) < args.num_frames:
+                    for j in range(start_loc, end_loc - clip_depth * frames_to_skip, clip_depth):
+                        frames = []
+                        for k in range(j, j + clip_depth * frames_to_skip, frames_to_skip):
+                            frames.append(vc[:, k, :, :])
+                        vc_sub = torch.stack(frames, dim=1)
+
+                        if vc_sub.size(1) < clip_depth:
                             continue
-                        for clip in get_yolo_regions(yolo_model, vc_sub, is_right, image_width, image_height):
+                        for clip in get_yolo_regions(yolo_model, vc_sub, is_right, angle, image_width, image_height):
                             clip = transforms(clip)
-                            tv.io.write_video(os.path.join(args.output_dir, 'positive_yolo' + str(output_count) + '.mp4'), clip.swapaxes(0,1).swapaxes(1,2).swapaxes(2,3).numpy(), fps=20)
+                            ani = plot_video(clip)
+                            ani.save(os.path.join(args.output_dir, 'positive_yolo' + str(output_count) + '.mp4'))
                             output_count += 1
+                else:
+                    continue
     elif label == 'Positives':
-        vc_sub = vc[:, -args.num_frames:, :, :]
-        if not vc_sub.size(1) < args.num_frames:
-            for clip in get_yolo_regions(yolo_model, vc_sub, is_right, image_width, image_height):
-                clip = transforms(clip)
-                tv.io.write_video(os.path.join(args.output_dir, 'positive_yolo' + str(output_count) + '.mp4'), clip.swapaxes(0,1).swapaxes(1,2).swapaxes(2,3).numpy(), fps=20)
-                output_count += 1
-    elif label == 'Negatives':
-        for j in range(0, vc.size(1) - args.num_frames, args.num_frames):
-            vc_sub = vc[:, j:j+args.num_frames, :, :]
-            if not vc_sub.size(1) < args.num_frames:
+        frames = []
+        for k in range(j, -1 * clip_depth * frames_to_skip, frames_to_skip):
+            frames.append(vc[:, k, :, :])
+        if frames:
+            vc_sub = torch.stack(frames, dim=1)
+            if vc_sub.size(1) >= clip_depth:
                 for clip in get_yolo_regions(yolo_model, vc_sub, is_right, image_width, image_height):
                     clip = transforms(clip)
-                    tv.io.write_video(os.path.join(args.output_dir, 'negative_yolo' + str(output_count) + '.mp4'), clip.swapaxes(0,1).swapaxes(1,2).swapaxes(2,3).numpy(), fps=20)
+                    ani = plot_video(clip)
+                    ani.save(os.path.join(args.output_dir, 'positive_yolo' + str(output_count) + '.mp4'))
+                    output_count += 1
+    elif label == 'Negatives':
+        for j in range(0, vc.size(1) - clip_depth * frames_to_skip, clip_depth):
+            frames = []
+            for k in range(j, j + clip_depth * frames_to_skip, frames_to_skip):
+                frames.append(vc[:, k, :, :])
+            vc_sub = torch.stack(frames, dim=1)
+            if vc_sub.size(1) >= clip_depth:
+                for clip in get_yolo_regions(yolo_model, vc_sub, is_right, angle, image_width, image_height):
+                    clip = transforms(clip)
+                    ani = plot_video(clip)
+                    ani.save(os.path.join(args.output_dir, 'negative_yolo' + str(output_count) + '.mp4'))
                     output_count += 1
     else:
         raise Exception('Invalid label')
\ No newline at end of file
diff --git a/run_ptx.py b/run_ptx.py
index 5ee5656f7e1db3638e5352dbf1e8e171c506b8ac..1b927f46a452eabe5fda728b7638d15389bfc4f9 100644
--- a/run_ptx.py
+++ b/run_ptx.py
@@ -24,8 +24,8 @@ if __name__ == "__main__":
     parser.add_argument('--max_activation_iter', default=100, type=int)
     parser.add_argument('--activation_lr', default=1e-2, type=float)
     parser.add_argument('--lam', default=0.05, type=float)
-    parser.add_argument('--sparse_checkpoint', default='converted_checkpoints/sparse.pt', type=str)
-    parser.add_argument('--checkpoint', default='converted_checkpoints/classifier.pt', type=str)
+    parser.add_argument('--sparse_checkpoint', default='ptx_tensorflow/sparse.pt', type=str)
+    parser.add_argument('--checkpoint', default='ptx_tensorflow/best_classifier.pt', type=str)
     parser.add_argument('--run_2d', action='store_true')
     parser.add_argument('--dataset', default='ptx', type=str)

@@ -33,12 +33,10 @@
     #print(args.accumulate(args.integers))
     batch_size = 1

-    if args.dataset == 'pnb':
-        image_height = 250
-        image_width = 600
-    elif args.dataset == 'ptx':
+    if args.dataset == 'ptx':
         image_height = 100
         image_width = 200
+        clip_depth = 5
     else:
         raise Exception('Invalid dataset')

@@ -49,15 +47,9 @@

     filter_inputs = keras.Input(shape=(5, args.kernel_size, args.kernel_size, 1, args.num_kernels), dtype='float32')

-    output = SparseCode(batch_size=batch_size, image_height=image_height, image_width=image_width, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d, padding='SAME')(inputs, filter_inputs)
+    output = SparseCode(batch_size=batch_size, image_height=image_height, image_width=image_width, clip_depth=clip_depth, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, kernel_depth=args.kernel_depth, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d, padding='SAME')(inputs, filter_inputs)

     sparse_model = keras.Model(inputs=(inputs, filter_inputs), outputs=output)
-
-    recon_inputs = keras.Input(shape=(1, image_height // args.stride, image_width // args.stride, args.num_kernels))
-
-    recon_outputs = ReconSparse(batch_size=batch_size, image_height=image_height, image_width=image_width, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(recon_inputs)
-
-    recon_model = keras.Model(inputs=recon_inputs, outputs=recon_outputs)

     if args.sparse_checkpoint:
         recon_model = keras.models.load_model(args.sparse_checkpoint)
diff --git a/sme_region_labels.csv b/sme_region_labels.csv
index c533ec9c715bf1da2cd30d320e034c5275fce9b9..5f584839b200a189c21252e9b83051480b7e261f 100644
--- a/sme_region_labels.csv
+++ b/sme_region_labels.csv
@@ -1,9 +1,11 @@
 idx,negative_regions,positive_regions
 11,"0-121,158-516","121-157,517-end"
+46,0-208,209-end
 54,0-397,397
 67,0-120,120-end
 93,"0-60,155-256","61-154,257-end"
 94,0-78,79-end
+125,0-372,373-end
 134,0-393,394-end
 153,0-200,201-end
 189,"0-122,123-184",185
@@ -12,6 +14,6 @@ idx,negative_regions,positive_regions
 211,0-86,87-end
 217,"0-86,87-137",138-end
 222,0-112,113-end
-230,0-7,
-238,"0-140,141-152",153-end)
-240,0-205,206-end
\ No newline at end of file
+230,0-75,76-end
+238,"0-140,141-152",153-end
+240,0-205,206-end
diff --git a/sparse_coding_torch/keras_model.py b/sparse_coding_torch/keras_model.py
index c73b1423eb1abb9e4b20d1cd4ffe21e88133acdd..86f015a7950113252f087f9ce4596b0a90c9cd0b 100644
--- a/sparse_coding_torch/keras_model.py
+++ b/sparse_coding_torch/keras_model.py
@@ -299,6 +299,45 @@ class PNBClassifier(keras.layers.Layer):
         x = self.ff_4(x)

         return x
+
+class PNBTemporalClassifier(keras.layers.Layer):
+    def __init__(self):
+        super(PNBTemporalClassifier, self).__init__()
+        self.conv_1 = keras.layers.Conv2D(12, kernel_size=(150, 24), strides=(1, 8), activation='relu', padding='valid')
+        self.conv_2 = keras.layers.Conv1D(24, kernel_size=8, strides=4, activation='relu', padding='valid')
+
+        self.ff_1 = keras.layers.Dense(100, activation='relu', use_bias=True)
+
+        self.gru = keras.layers.GRU(25)
+
+        self.flatten = keras.layers.Flatten()
+
+        self.ff_2 = keras.layers.Dense(10, activation='relu', use_bias=True)
+        self.ff_3 = keras.layers.Dense(1)
+
+#     @tf.function
+    def call(self, clip):
+        width = clip.shape[3]
+        height = clip.shape[2]
+        depth = clip.shape[1]
+
+        x = tf.expand_dims(clip, axis=4)
+        x = tf.reshape(clip, (-1, height, width, 1))
+
+        x = self.conv_1(x)
+        x = tf.squeeze(x, axis=1)
+        x = self.conv_2(x)
+
+        x = self.flatten(x)
+        x = self.ff_1(x)
+
+        x = tf.reshape(x, (-1, 5, 100))
+        x = self.gru(x)
+
+        x = self.ff_2(x)
+        x = self.ff_3(x)
+
+        return x

 class MobileModelPTX(keras.Model):
     def __init__(self, sparse_checkpoint, batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d):
diff --git a/sparse_coding_torch/load_data.py b/sparse_coding_torch/load_data.py
index 96ef3fa885a4a84dfba9f3966ca5a67b75673608..6dea292a90a127da77ec9c6ce16d8c234a2cb5da 100644
--- a/sparse_coding_torch/load_data.py
+++ b/sparse_coding_torch/load_data.py
@@ -40,12 +40,12 @@ def load_yolo_clips(batch_size, mode, num_clips=1, num_positives=100, device=Non
         return gss.split(np.arange(len(targets)), targets, groups), dataset
     elif mode == 'all_train':
         train_idx = np.arange(len(targets))
-        train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
-        train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
-                                               sampler=train_sampler)
-        test_loader = None
+#         train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
+#         train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
+#                                                sampler=train_sampler)
+#         test_loader = None

-        return train_loader, test_loader, dataset
+        return [(train_idx, None)], dataset
     elif mode == 'k_fold':
         gss = StratifiedGroupKFold(n_splits=n_splits)
diff --git a/sparse_coding_torch/train_classifier.py b/sparse_coding_torch/train_classifier.py
index 812769c676852d472491168fca65d770e1bbb76d..ca71ca2f59f6aec2f5b0d4eebaf60c4b049be13c 100644
--- a/sparse_coding_torch/train_classifier.py
+++ b/sparse_coding_torch/train_classifier.py
@@ -5,7 +5,7 @@ from tqdm import tqdm
 import argparse
 import os
 from sparse_coding_torch.load_data import load_yolo_clips, load_pnb_videos, SubsetWeightedRandomSampler, get_sample_weights
-from sparse_coding_torch.keras_model import SparseCode, PNBClassifier, PTXClassifier, ReconSparse, normalize_weights, normalize_weights_3d
+from sparse_coding_torch.keras_model import SparseCode, PNBClassifier, PTXClassifier, ReconSparse, normalize_weights, normalize_weights_3d, PNBTemporalClassifier
 import time
 import numpy as np
 from sklearn.metrics import f1_score, accuracy_score, confusion_matrix
@@ -19,6 +19,7 @@ from yolov4.get_bounding_boxes import YoloModel
 import torchvision
 from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler
 import glob
+import cv2

 configproto = tf.compat.v1.ConfigProto()
 configproto.gpu_options.polling_inactive_delay_msecs = 5000
@@ -129,8 +130,8 @@ if __name__ == "__main__":
     i_fold = 0
     for train_idx, test_idx in splits:
-#         train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
-        train_sampler = SubsetWeightedRandomSampler(get_sample_weights(train_idx, dataset), train_idx, replacement=True)
+        train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
+#         train_sampler = SubsetWeightedRandomSampler(get_sample_weights(train_idx, dataset), train_idx, replacement=True)

         train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                sampler=train_sampler)
@@ -139,19 +140,20 @@
             test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                               sampler=test_sampler)

-            with open(os.path.join(args.output_dir, 'test_videos_{}.txt'.format(i_fold)), 'w+') as test_videos_out:
-                test_set = set([x for tup in test_loader for x in tup[2]])
-                test_videos_out.writelines(test_set)
+#             with open(os.path.join(args.output_dir, 'test_videos_{}.txt'.format(i_fold)), 'w+') as test_videos_out:
+#                 test_set = set([x for tup in test_loader for x in tup[2]])
+#                 test_videos_out.writelines(test_set)
         else:
             test_loader = None

         if args.checkpoint:
             classifier_model = keras.models.load_model(args.checkpoint)
         else:
-            classifier_inputs = keras.Input(shape=((clip_depth - args.kernel_depth) // 1 + 1, (image_height - args.kernel_size) // args.stride + 1, (image_width - args.kernel_size) // args.stride + 1, args.num_kernels))
+#             classifier_inputs = keras.Input(shape=((clip_depth - args.kernel_depth) // 1 + 1, (image_height - args.kernel_size) // args.stride + 1, (image_width - args.kernel_size) // args.stride + 1, args.num_kernel))
+            classifier_inputs = keras.Input(shape=(clip_depth, image_height, image_width))

             if args.dataset == 'pnb':
-                classifier_outputs = PNBClassifier()(classifier_inputs)
+                classifier_outputs = PNBTemporalClassifier()(classifier_inputs)
             elif args.dataset == 'ptx':
                 classifier_outputs = PTXClassifier()(classifier_inputs)
             else:
@@ -176,6 +178,13 @@
             for labels, local_batch, vid_f in tqdm(train_loader):
                 images = local_batch.permute(0, 2, 3, 4, 1).numpy()
+                cv2.imwrite('example_video_2/test_{}_0.png'.format(labels[3]), images[3, 0, :, :] * 255)
+                cv2.imwrite('example_video_2/test_{}_1.png'.format(labels[3]), images[3, 1, :, :] * 255)
+                cv2.imwrite('example_video_2/test_{}_2.png'.format(labels[3]), images[3, 2, :, :] * 255)
+                cv2.imwrite('example_video_2/test_{}_3.png'.format(labels[3]), images[3, 3, :, :] * 255)
+                cv2.imwrite('example_video_2/test_{}_4.png'.format(labels[3]), images[3, 4, :, :] * 255)
+                print(vid_f[3])
+                raise Exception

                 torch_labels = np.zeros(len(labels))
                 torch_labels[[i for i in range(len(labels)) if labels[i] == positive_class]] = 1
@@ -183,17 +192,17 @@
                 if args.train_sparse:
                     with tf.GradientTape() as tape:
-                        activations = sparse_model([images, tf.expand_dims(recon_model.trainable_weights[0], axis=0)])
+#                         activations = sparse_model([images, tf.expand_dims(recon_model.trainable_weights[0], axis=0)])
                         pred = classifier_model(activations)
                         loss = criterion(torch_labels, pred)
                         print(loss)
                 else:
-                    activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))]))
+#                     activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))]))
 #                     raise Exception
                     with tf.GradientTape() as tape:
-                        pred = classifier_model(activations)
+                        pred = classifier_model(images)
                         loss = criterion(torch_labels, pred)

                 epoch_loss += loss * local_batch.size(0)
@@ -235,9 +244,9 @@
                     torch_labels[[i for i in range(len(labels)) if labels[i] == positive_class]] = 1
                     torch_labels = np.expand_dims(torch_labels, axis=1)

-                    activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))]))
+#                     activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))]))

-                    pred = classifier_model(activations)
+                    pred = classifier_model(images)
                     loss = criterion(torch_labels, pred)

                     test_loss += loss
diff --git a/sparse_coding_torch/train_sparse_model.py b/sparse_coding_torch/train_sparse_model.py
index 8fc91ca5d995f1f68483d5c4bbf0497c742dafcd..a15e73d4fb8cfce69abe2a29e89d1f8e66667333 100644
--- a/sparse_coding_torch/train_sparse_model.py
+++ b/sparse_coding_torch/train_sparse_model.py
@@ -146,7 +146,8 @@
     elif args.dataset == 'pnb':
         train_loader, test_loader, dataset = load_pnb_videos(args.batch_size, input_size=(image_height, image_width, clip_depth), crop_size=(crop_height, crop_width, clip_depth), classify_mode=False, balance_classes=False, mode='all_train', frames_to_skip=args.frames_to_skip)
     elif args.dataset == 'ptx':
-        train_loader, _ = load_yolo_clips(args.batch_size, num_clips=1, num_positives=15, mode='all_train', device=device, n_splits=1, sparse_model=None, whole_video=False, positive_videos='../positive_videos.json')
+        splits, dataset = load_yolo_clips(args.batch_size, num_clips=1, num_positives=15, mode='all_train', device=device, n_splits=1, sparse_model=None, whole_video=False, positive_videos='positive_videos.json')
+        train_idx, test_idx = splits[0]
     elif args.dataset == 'needle':
         train_loader, test_loader, dataset = load_needle_clips(args.batch_size, input_size=(image_height, image_width, clip_depth))
     else:
diff --git a/sparse_coding_torch/video_loader.py b/sparse_coding_torch/video_loader.py
index fbe9e03b1fb93ef9ce670a8fecd8a8861101bd2f..6621f9a572094a43aa62fef0b74dfd8d0cc423f5 100644
--- a/sparse_coding_torch/video_loader.py
+++ b/sparse_coding_torch/video_loader.py
@@ -76,15 +76,14 @@ def load_pnb_region_labels(file_path):

     return all_regions

-def get_yolo_regions(yolo_model, clip, is_right, crop_width, crop_height):
+def get_yolo_regions(yolo_model, clip, is_right, angle, crop_width, crop_height):
     orig_height = clip.size(2)
     orig_width = clip.size(3)
     bounding_boxes, classes = yolo_model.get_bounding_boxes(clip[:, 2, :, :].swapaxes(0, 2).swapaxes(0, 1).numpy())
     bounding_boxes = bounding_boxes.squeeze(0)
     classes = classes.squeeze(0)

-    rotate_box = False
-    angle = 20
+    rotate_box = True

     all_clips = []
     for bb, class_pred in zip(bounding_boxes, classes):
@@ -120,9 +119,10 @@
 #         if orig_width - center_x >= center_x:
 #             if not is_right:
-#                 # cv2.imwrite('test.png', clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2))
-#                 cv2.imwrite('test_3.png', trimmed_clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2))
-#                 raise Exception
+#                 print(angle)
+#                 cv2.imwrite('test_2.png', clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2))
+#                 cv2.imwrite('test_3.png', trimmed_clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2))
+#                 raise Exception

 #     print(trimmed_clip.size())

@@ -180,6 +180,41 @@ def classify_nerve_is_right(yolo_model, video):
     final_pred = round(sum(all_preds) / len(all_preds))

     return final_pred == 1
+
+def calculate_angle(yolo_model, video):
+    orig_height = video.size(2)
+    orig_width = video.size(3)
+
+    all_preds = []
+    if video.size(1) < 10:
+        return 45
+
+    for frame in range(0, video.size(1), round(video.size(1) / 10)):
+        frame = video[:, frame, :, :]
+        bounding_boxes, classes = yolo_model.get_bounding_boxes(frame.swapaxes(0, 2).swapaxes(0, 1).numpy())
+        bounding_boxes = bounding_boxes.squeeze(0)
+        classes = classes.squeeze(0)
+
+        vessel_x = 0
+        vessel_y = 0
+        needle_x = 0
+        needle_y = 0
+
+        for bb, class_pred in zip(bounding_boxes, classes):
+            if class_pred == 0 and vessel_x == 0:
+                vessel_x = (bb[3] + bb[1]) / 2 * orig_width
+                vessel_y = (bb[2] + bb[0]) / 2 * orig_height
+            elif class_pred == 2 and needle_x == 0:
+                needle_x = bb[1] * orig_width
+                needle_y = bb[0] * orig_height
+
+        if needle_x != 0 and vessel_x != 0:
+            break
+
+    if vessel_x == 0 or needle_x == 0:
+        return 45
+    else:
+        return np.abs(np.degrees(np.arctan((needle_y-vessel_y)/(needle_x-vessel_x))))

 class PNBLoader(Dataset):
@@ -217,6 +252,7 @@
         for label, path, _ in tqdm(self.videos):
             vc = tv.io.read_video(path)[0].permute(3, 0, 1, 2)
             is_right = classify_nerve_is_right(yolo_model, vc)
+            angle = calculate_angle(yolo_model, vc)

             if classify_mode:
 #                 person_idx = path.split('/')[-2]
@@ -232,7 +268,7 @@
                         start_loc = int(sub_region[0])
 #                         end_loc = int(sub_region[1]) - 50
                         end_loc = int(sub_region[1]) + 1
-                        for j in range(start_loc, end_loc - clip_depth * frames_to_skip, 5):
+                        for j in range(start_loc, end_loc - clip_depth * frames_to_skip, clip_depth):
                             frames = []
                             for k in range(j, j + clip_depth * frames_to_skip, frames_to_skip):
                                 frames.append(vc[:, k, :, :])
@@ -241,7 +277,7 @@
                             if vc_sub.size(1) < clip_depth:
                                 continue
-                            for clip in get_yolo_regions(yolo_model, vc_sub, is_right, clip_width, clip_height):
+                            for clip in get_yolo_regions(yolo_model, vc_sub, is_right, angle, clip_width, clip_height):
                                 if self.transform:
                                     clip = self.transform(clip)

@@ -261,7 +297,7 @@
                                 if vc_sub.size(1) < clip_depth:
                                     continue
-                                for clip in get_yolo_regions(yolo_model, vc_sub, is_right, clip_width, clip_height):
+                                for clip in get_yolo_regions(yolo_model, vc_sub, is_right, angle, clip_width, clip_height):
                                     if self.transform:
                                         clip = self.transform(clip)

@@ -272,7 +308,7 @@
                                     end_loc = vc.size(1)
                                 else:
                                     end_loc = int(end_loc)
-                                for j in range(start_loc, end_loc - clip_depth * frames_to_skip, 3):
+                                for j in range(start_loc, end_loc - clip_depth * frames_to_skip, clip_depth):
                                     frames = []
                                     for k in range(j, j + clip_depth * frames_to_skip, frames_to_skip):
                                         frames.append(vc[:, k, :, :])
@@ -280,7 +316,7 @@
                                     if vc_sub.size(1) < clip_depth:
                                         continue
-                                    for clip in get_yolo_regions(yolo_model, vc_sub, is_right, clip_width, clip_height):
+                                    for clip in get_yolo_regions(yolo_model, vc_sub, is_right, angle, clip_width, clip_height):
                                         if self.transform:
                                             clip = self.transform(clip)

@@ -302,14 +338,14 @@
                             self.clips.append((self.videos[vid_idx][0], clip, self.videos[vid_idx][2]))

             elif label == 'Negatives':
-                for j in range(0, vc.size(1) - clip_depth * frames_to_skip, 10):
+                for j in range(0, vc.size(1) - clip_depth * frames_to_skip, clip_depth):
                     frames = []
                     for k in range(j, j + clip_depth * frames_to_skip, frames_to_skip):
                         frames.append(vc[:, k, :, :])
                     vc_sub = torch.stack(frames, dim=1)
                     if vc_sub.size(1) < clip_depth:
                         continue
-                    for clip in get_yolo_regions(yolo_model, vc_sub, is_right, clip_width, clip_height):
+                    for clip in get_yolo_regions(yolo_model, vc_sub, is_right, angle, clip_width, clip_height):
                         if self.transform:
                             clip = self.transform(clip)

@@ -317,7 +353,7 @@
             else:
                 raise Exception('Invalid label')
         else:
-            for j in range(0, vc.size(1) - clip_depth * frames_to_skip, 10):
+            for j in range(0, vc.size(1) - clip_depth * frames_to_skip, clip_depth):
                 frames = []
                 for k in range(j, j + clip_depth * frames_to_skip, frames_to_skip):
                     frames.append(vc[:, k, :, :])