Skip to content
Snippets Groups Projects
Commit d00c9911 authored by hannandarryl's avatar hannandarryl
Browse files

Pushing current model

parent e15f3e81
Branches
Tags phase-1-final
No related merge requests found
run.py 0 → 100644
import torch
import os
from sparse_coding_torch.conv_sparse_model import ConvSparseLayer
from sparse_coding_torch.small_data_classifier import SmallDataClassifierConv3d
import time
import numpy as np
import torchvision
from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler
from torchvision.datasets.video_utils import VideoClips
import csv
from datetime import datetime
from yolov4.get_bounding_boxes import YoloModel
import argparse
# Frame size that the detector's box coordinates are scaled by.
# NOTE(review): hard-coded to 1920x1080 exactly as in the original script;
# confirm whether videos of other resolutions must be supported.
FRAME_WIDTH = 1920
FRAME_HEIGHT = 1080

# Size (in pixels) of the window cut out around the selected detection.
CROP_WIDTH = 400
CROP_HEIGHT = 400

INPUT_DIR = 'input_videos'
CSV_FIELDS = ('FileName', 'Prediction', 'TotalTimeSec')


def parse_args(argv=None):
    """Parse the command-line flags (``--fast``, ``--accurate``, ``--verbose``).

    Args:
        argv: Optional list of argument strings; ``None`` means ``sys.argv``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description='Classify each video in input_videos as '
                    '"Sliding" or "No Sliding".')
    parser.add_argument('--fast', action='store_true',
                        help='optimized for runtime')
    parser.add_argument('--accurate', action='store_true',
                        help='optimized for accuracy')
    parser.add_argument('--verbose', action='store_true',
                        help='output verbose')
    return parser.parse_args(argv)


def select_clip_stride(fast, accurate):
    """Return the number of frames between consecutive clips.

    ``accurate`` takes precedence over ``fast`` when both are set, which
    matches the order of the checks in the original script.
    """
    if accurate:
        return 10
    if fast:
        return 20
    return 15


def clamp_window(center, size, limit):
    """Return ``(lower, upper)`` slice bounds of a window inside ``[0, limit]``.

    The window is ``size`` wide and centred on ``center``. Bounds are clamped
    because a negative slice start would wrap around to the far end of the
    axis and silently produce an empty or misplaced crop.
    """
    lower = int(round(center - size / 2))
    upper = int(round(center + size / 2))
    lower = min(max(lower, 0), limit)
    upper = min(max(upper, lower), limit)
    return lower, upper


def crop_around_widest_box(clip, bounding_boxes,
                           frame_width=FRAME_WIDTH, frame_height=FRAME_HEIGHT,
                           crop_width=CROP_WIDTH, crop_height=CROP_HEIGHT):
    """Cut a window out of ``clip`` centred on the widest bounding box.

    Args:
        clip: 4-D array whose last two axes are sliced as (rows, columns).
        bounding_boxes: Sequence of boxes indexed as ``[y0, x0, y1, x1]``
            (index 3 minus index 1 is treated as the width). Coordinates are
            presumably normalized to [0, 1] -- they are multiplied by the
            frame size here; confirm against the detector.
        frame_width: Factor applied to the box's horizontal centre.
        frame_height: Factor applied to the box's vertical centre.
        crop_width: Width of the window in pixels.
        crop_height: Height of the window in pixels.

    Returns:
        A view of ``clip`` restricted to the window. The window is clipped
        to the array bounds, so it can be smaller than the requested size
        when the box lies near an edge.
    """
    widths = [box[3] - box[1] for box in bounding_boxes]
    box = bounding_boxes[int(np.argmax(widths))]
    center_x = (box[3] + box[1]) / 2 * frame_width
    center_y = (box[2] + box[0]) / 2 * frame_height
    lower_y, upper_y = clamp_window(center_y, crop_height, clip.shape[2])
    lower_x, upper_x = clamp_window(center_x, crop_width, clip.shape[3])
    return clip[:, :, lower_y:upper_y, lower_x:upper_x]


def vote_prediction(probabilities):
    """Reduce per-clip probabilities to a single 0/1 label.

    The label is the majority of the rounded per-clip votes. On an exact
    tie (even number of clips, half of them voting 1) the rounded mean
    probability decides instead.

    Args:
        probabilities: Non-empty 1-D tensor of sigmoid outputs.

    Returns:
        ``0`` or ``1`` as a Python int.
    """
    votes = probabilities.round().to(torch.long)
    final_pred = torch.mode(votes)[0].item()
    if len(votes) % 2 == 0 and torch.sum(votes).item() == len(votes) // 2:
        final_pred = probabilities.mean().round().to(torch.long).item()
    return final_pred


def load_models(device):
    """Build the sparse layer and classifier, load their weights, set eval mode.

    Reads ``sparse.pt`` and ``classifier.pt`` from the working directory.

    Returns:
        ``(frozen_sparse, predictive_model)``, both moved to ``device``.
    """
    frozen_sparse = ConvSparseLayer(in_channels=1,
                                    out_channels=64,
                                    kernel_size=(5, 15, 15),
                                    stride=1,
                                    padding=(0, 7, 7),
                                    convo_dim=3,
                                    rectifier=True,
                                    lam=0.05,
                                    max_activation_iter=150,
                                    activation_lr=1e-2)
    sparse_param = torch.load('sparse.pt', map_location=device)
    frozen_sparse.load_state_dict(sparse_param['model_state_dict'])
    frozen_sparse.to(device)

    predictive_model = SmallDataClassifierConv3d()
    predictive_model.to(device)
    classifier_state = torch.load('classifier.pt', map_location=device)
    # Strip the 'module.' prefix from every parameter name so the keys
    # match the bare model.
    checkpoint = {k.replace('module.', ''): v
                  for k, v in classifier_state['model_state_dict'].items()}
    predictive_model.load_state_dict(checkpoint)

    frozen_sparse.eval()
    predictive_model.eval()
    return frozen_sparse, predictive_model


def build_transform():
    """Return the preprocessing pipeline applied to every cropped clip."""
    return torchvision.transforms.Compose(
        [VideoGrayScaler(),
         MinMaxScaler(0, 255),
         torchvision.transforms.Normalize((0.2592,), (0.1251,)),
         torchvision.transforms.CenterCrop((100, 200))
         ])


def classify_video(video_path, clip_stride, yolo_model, transform,
                   frozen_sparse, predictive_model, device, verbose=False):
    """Classify one video file.

    Returns:
        ``(label, seconds)`` where ``label`` is ``'Sliding'`` or
        ``'No Sliding'`` and ``seconds`` is the wall-clock time measured
        from just after the video index is built. A video in which no clip
        yields a detection is labelled ``'No Sliding'``.
    """
    vc = VideoClips([video_path],
                    clip_length_in_frames=5,
                    frame_rate=20,
                    frames_between_clips=clip_stride)

    # The timer deliberately starts after the video has been loaded.
    start_time = time.time()

    crops = []
    for clip_index in range(vc.num_clips()):
        clip, _, _, _ = vc.get_clip(clip_index)
        # Reorder axes 0,1,2,3 -> 3,0,1,2; assuming get_clip yields
        # (T, H, W, C) as torchvision documents, this gives (C, T, H, W).
        clip = clip.swapaxes(1, 3).swapaxes(0, 1).swapaxes(2, 3).numpy()
        # The detector is run on the middle frame (index 2 of 5) only.
        middle_frame = clip[:, 2, :, :].swapaxes(0, 2).swapaxes(0, 1)
        bounding_boxes = yolo_model.get_bounding_boxes(middle_frame).squeeze(0)
        if bounding_boxes.size == 0:
            continue

        cropped = crop_around_widest_box(clip, bounding_boxes)
        cropped = torch.tensor(cropped).to(torch.float)
        crops.append(transform(cropped))

    if crops:
        with torch.no_grad():
            batch = torch.stack(crops)
            if device.type == 'cuda':
                # pin_memory() returns a pinned copy; the result must be
                # kept for the non_blocking transfer below to benefit.
                batch = batch.pin_memory()
            batch = batch.to(device, non_blocking=True)
            activations = frozen_sparse(batch)
            pred, _ = predictive_model(activations)
            probabilities = torch.sigmoid(pred).detach().cpu().flatten()

        if verbose:
            print(probabilities.round().to(torch.long))
            print("num of clips: ", len(crops))

        final_pred = vote_prediction(probabilities)
        str_pred = 'No Sliding' if final_pred == 1 else 'Sliding'
    else:
        str_pred = "No Sliding"

    return str_pred, time.time() - start_time


def write_predictions(all_predictions):
    """Write the per-file results to a timestamped ``output_*.csv`` file.

    The header is always written, even when there are no rows.
    """
    out_name = 'output_' + datetime.now().strftime("%Y%m%d-%H%M%S") + '.csv'
    with open(out_name, 'w', newline='') as csv_out:
        writer = csv.DictWriter(csv_out, fieldnames=CSV_FIELDS)
        writer.writeheader()
        writer.writerows(all_predictions)


def main():
    """Classify every file in ``input_videos`` and write the results as CSV."""
    args = parse_args()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    frozen_sparse, predictive_model = load_models(device)
    yolo_model = YoloModel()
    transform = build_transform()
    clip_stride = select_clip_stride(args.fast, args.accurate)

    all_predictions = []
    for file_name in os.listdir(INPUT_DIR):
        print('Processing', file_name)
        str_pred, elapsed = classify_video(
            os.path.join(INPUT_DIR, file_name), clip_stride, yolo_model,
            transform, frozen_sparse, predictive_model, device,
            verbose=args.verbose)
        all_predictions.append({'FileName': file_name,
                                'Prediction': str_pred,
                                'TotalTimeSec': elapsed})

    write_predictions(all_predictions)


if __name__ == "__main__":
    main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment