diff --git a/run.py b/run.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ce10ed5d7ddcd34767702a63eb1812a649bd2c3
--- /dev/null
+++ b/run.py
@@ -0,0 +1,165 @@
+"""Run the frozen sparse-coding model and classifier over every video in
+input_videos/ and write per-file predictions with timings to a CSV."""
+import argparse
+import csv
+import os
+import time
+from datetime import datetime
+
+import numpy as np
+import torch
+import torchvision
+from torchvision.datasets.video_utils import VideoClips
+
+from sparse_coding_torch.conv_sparse_model import ConvSparseLayer
+from sparse_coding_torch.small_data_classifier import SmallDataClassifierConv3d
+from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler
+from yolov4.get_bounding_boxes import YoloModel
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description='Classify each video in input_videos/ as Sliding or No Sliding.')
+    parser.add_argument('--fast', action='store_true',
+                        help='optimize for runtime (larger clip stride)')
+    parser.add_argument('--accurate', action='store_true',
+                        help='optimize for accuracy (smaller clip stride)')
+    parser.add_argument('--verbose', action='store_true',
+                        help='print per-clip predictions')
+    args = parser.parse_args()
+
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+    # Sparse-coding front end; pretrained weights are loaded below and frozen.
+    frozen_sparse = ConvSparseLayer(in_channels=1,
+                                    out_channels=64,
+                                    kernel_size=(5, 15, 15),
+                                    stride=1,
+                                    padding=(0, 7, 7),
+                                    convo_dim=3,
+                                    rectifier=True,
+                                    lam=0.05,
+                                    max_activation_iter=150,
+                                    activation_lr=1e-2)
+
+    sparse_param = torch.load('sparse.pt', map_location=device)
+    frozen_sparse.load_state_dict(sparse_param['model_state_dict'])
+    frozen_sparse.to(device)
+
+    predictive_model = SmallDataClassifierConv3d()
+    predictive_model.to(device)
+
+    # Strip the 'module.' prefix that DataParallel adds to state-dict keys.
+    checkpoint = {k.replace('module.', ''): v for k, v in
+                  torch.load('classifier.pt', map_location=device)['model_state_dict'].items()}
+    predictive_model.load_state_dict(checkpoint)
+
+    yolo_model = YoloModel()
+
+    transform = torchvision.transforms.Compose(
+        [VideoGrayScaler(),
+         MinMaxScaler(0, 255),
+         torchvision.transforms.Normalize((0.2592,), (0.1251,)),
+         torchvision.transforms.CenterCrop((100, 200))
+         ])
+
+    frozen_sparse.eval()
+    predictive_model.eval()
+
+    all_predictions = []
+
+    all_files = list(os.listdir('input_videos'))
+
+    for f in all_files:
+        print('Processing', f)
+
+        # Stride between clips trades accuracy for speed.
+        clip_stride = 15
+        if args.fast:
+            clip_stride = 20
+        if args.accurate:
+            clip_stride = 10
+
+        vc = VideoClips([os.path.join('input_videos', f)],
+                        clip_length_in_frames=5,
+                        frame_rate=20,
+                        frames_between_clips=clip_stride)
+
+        # Start timing after the video has been loaded.
+        start_time = time.time()
+        clip_predictions = []
+        cliplist = []
+        countclips = 0
+        for i in range(vc.num_clips()):
+            clip, _, _, _ = vc.get_clip(i)
+            # (T, H, W, C) -> (C, T, H, W)
+            clip = clip.swapaxes(1, 3).swapaxes(0, 1).swapaxes(2, 3).numpy()
+
+            # Detect regions of interest on the middle frame of the clip.
+            bounding_boxes = yolo_model.get_bounding_boxes(
+                clip[:, 2, :, :].swapaxes(0, 2).swapaxes(0, 1)).squeeze(0)
+            if bounding_boxes.size == 0:
+                continue
+            countclips = countclips + len(bounding_boxes)
+
+            # Keep only the widest detection.
+            widths = [bounding_boxes[j][3] - bounding_boxes[j][1]
+                      for j in range(len(bounding_boxes))]
+            ind = np.argmax(np.array(widths))
+            bb = bounding_boxes[ind]
+
+            # Box coordinates are normalized; scale to 1920x1080 pixels.
+            center_x = (bb[3] + bb[1]) / 2 * 1920
+            center_y = (bb[2] + bb[0]) / 2 * 1080
+
+            width = 400
+            height = 400
+
+            lower_y = round(center_y - height / 2)
+            upper_y = round(center_y + height / 2)
+            lower_x = round(center_x - width / 2)
+            upper_x = round(center_x + width / 2)
+
+            trimmed_clip = clip[:, :, lower_y:upper_y, lower_x:upper_x]
+            trimmed_clip = torch.tensor(trimmed_clip).to(torch.float)
+            trimmed_clip = transform(trimmed_clip)
+            # pin_memory() is not in-place; keep the returned tensor.
+            trimmed_clip = trimmed_clip.pin_memory()
+            cliplist.append(trimmed_clip)
+
+        if len(cliplist) > 0:
+            with torch.no_grad():
+                trimmed_clip = torch.stack(cliplist)
+                trimmed_clip = trimmed_clip.to(device, non_blocking=True)
+                activations = frozen_sparse(trimmed_clip)
+
+                pred, activations = predictive_model(activations)
+                clip_predictions = torch.sigmoid(pred).round().detach().cpu().flatten().to(torch.long)
+
+            if args.verbose:
+                print(clip_predictions)
+                print("num of clips: ", countclips)
+
+            # Majority vote over clips; break exact ties with the mean probability.
+            final_pred = torch.mode(clip_predictions)[0].item()
+            if (len(clip_predictions) % 2 == 0 and
+                    torch.sum(clip_predictions).item() == len(clip_predictions) // 2):
+                final_pred = torch.sigmoid(pred).mean().round().detach().cpu().to(torch.long).item()
+
+            if final_pred == 1:
+                str_pred = 'No Sliding'
+            else:
+                str_pred = 'Sliding'
+        else:
+            str_pred = 'No Sliding'
+
+        end_time = time.time()
+
+        all_predictions.append({'FileName': f,
+                                'Prediction': str_pred,
+                                'TotalTimeSec': end_time - start_time})
+
+    with open('output_' + datetime.now().strftime("%Y%m%d-%H%M%S") + '.csv', 'w', newline='') as csv_out:
+        writer = csv.DictWriter(csv_out, fieldnames=all_predictions[0].keys())
+        writer.writeheader()
+        writer.writerows(all_predictions)