Skip to content
Snippets Groups Projects
Commit d00c9911 authored by hannandarryl's avatar hannandarryl
Browse files

Pushing current model

parent e15f3e81
No related branches found
No related tags found
No related merge requests found
run.py 0 → 100644
import torch
import os
from sparse_coding_torch.conv_sparse_model import ConvSparseLayer
from sparse_coding_torch.small_data_classifier import SmallDataClassifierConv3d
import time
import numpy as np
import torchvision
from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler
from torchvision.datasets.video_utils import VideoClips
import csv
from datetime import datetime
from yolov4.get_bounding_boxes import YoloModel
import argparse
# Nominal frame size (pixels) used to scale YOLO box coordinates.
# NOTE(review): assumes the boxes are normalized to [0, 1] and the input
# videos are 1920x1080 -- confirm against YoloModel.get_bounding_boxes.
FRAME_WIDTH = 1920
FRAME_HEIGHT = 1080

# Side length (pixels) of the square window cut around the widest detection.
CROP_SIZE = 400

# Column order of the output CSV. Fixed up front so that the header can be
# written even when no video was processed.
PREDICTION_FIELDS = ('FileName', 'Prediction', 'TotalTimeSec')


def select_clip_stride(fast, accurate):
    """Return the number of frames between consecutive clips.

    Args:
        fast: True when ``--fast`` was given (fewer clips, stride 20).
        accurate: True when ``--accurate`` was given (more clips, stride 10).

    Returns:
        10 if ``accurate`` is set (it wins over ``fast``), 20 if only
        ``fast`` is set, and 15 otherwise.
    """
    if accurate:
        return 10
    if fast:
        return 20
    return 15


def crop_bounds(center, size, limit):
    """Return ``(lower, upper)`` slice bounds for a window along one axis.

    The window is ``size`` pixels wide and centred on ``center``. If it
    would extend outside ``[0, limit]`` it is shifted back inside, so the
    returned bounds are never negative (a negative start would wrap around
    when used as a slice index) and never exceed ``limit``. When ``limit``
    is smaller than ``size`` the whole axis is returned.

    Args:
        center: Window centre in pixels (may be fractional).
        size: Desired window length in pixels.
        limit: Length of the axis being sliced.

    Returns:
        A tuple of two ints ``(lower, upper)`` with
        ``0 <= lower <= upper <= limit``.
    """
    lower = int(round(center - size / 2))
    upper = int(round(center + size / 2))
    if lower < 0:
        # Shift the window right so it starts at the first pixel.
        lower, upper = 0, upper - lower
    if upper > limit:
        # Shift the window left so it ends at the last pixel.
        lower, upper = max(0, lower - (upper - limit)), limit
    return lower, upper


def prediction_label(final_pred):
    """Map the binary class index to the label written to the CSV.

    Class ``1`` is reported as ``'No Sliding'``; anything else as
    ``'Sliding'``.
    """
    return 'No Sliding' if final_pred == 1 else 'Sliding'


def main():
    """Classify every video in ``input_videos/`` and write a CSV report.

    For each video, 5-frame clips are extracted, cropped around the widest
    YOLO detection, passed through the frozen sparse-coding layer and the
    classifier, and the per-clip predictions are combined by majority vote.
    A video with no usable detection is reported as ``'No Sliding'``.
    Results are written to ``output_<timestamp>.csv``.
    """
    parser = argparse.ArgumentParser(
        description='Classify each video in input_videos/ as Sliding or No Sliding.')
    parser.add_argument('--fast', action='store_true',
                        help='optimized for runtime')
    parser.add_argument('--accurate', action='store_true',
                        help='optimized for accuracy')
    parser.add_argument('--verbose', action='store_true',
                        help='output verbose')
    args = parser.parse_args()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    frozen_sparse = ConvSparseLayer(in_channels=1,
                                    out_channels=64,
                                    kernel_size=(5, 15, 15),
                                    stride=1,
                                    padding=(0, 7, 7),
                                    convo_dim=3,
                                    rectifier=True,
                                    lam=0.05,
                                    max_activation_iter=150,
                                    activation_lr=1e-2)
    sparse_param = torch.load('sparse.pt', map_location=device)
    frozen_sparse.load_state_dict(sparse_param['model_state_dict'])
    frozen_sparse.to(device)

    predictive_model = SmallDataClassifierConv3d()
    predictive_model.to(device)
    classifier_state = torch.load('classifier.pt', map_location=device)['model_state_dict']
    # Drop any 'module.' prefix from the saved parameter names
    # (presumably left by a DataParallel wrapper at training time).
    checkpoint = {k.replace('module.', ''): v for k, v in classifier_state.items()}
    predictive_model.load_state_dict(checkpoint)

    yolo_model = YoloModel()

    transform = torchvision.transforms.Compose(
        [VideoGrayScaler(),
         MinMaxScaler(0, 255),
         torchvision.transforms.Normalize((0.2592,), (0.1251,)),
         torchvision.transforms.CenterCrop((100, 200))
         ])

    frozen_sparse.eval()
    predictive_model.eval()

    # The stride depends only on the CLI flags, so compute it once.
    clipstride = select_clip_stride(args.fast, args.accurate)

    all_predictions = []
    for f in os.listdir('input_videos'):
        print('Processing', f)

        vc = VideoClips([os.path.join('input_videos', f)],
                        clip_length_in_frames=5,
                        frame_rate=20,
                        frames_between_clips=clipstride)

        # Timing deliberately starts after the video has been indexed.
        start_time = time.time()

        cliplist = []
        countclips = 0  # running total of detected boxes across clips
        for clip_index in range(vc.num_clips()):
            clip, _, _, _ = vc.get_clip(clip_index)
            # Reorder axes so that axes 2 and 3 are the spatial (y, x) axes
            # used by the crop below.
            clip = clip.swapaxes(1, 3).swapaxes(0, 1).swapaxes(2, 3).numpy()

            # Detect on the middle frame (index 2 of 5) of the clip.
            middle_frame = clip[:, 2, :, :].swapaxes(0, 2).swapaxes(0, 1)
            bounding_boxes = yolo_model.get_bounding_boxes(middle_frame).squeeze(0)
            if bounding_boxes.size == 0:
                continue

            countclips += len(bounding_boxes)

            # Keep only the widest detection.
            widths = [box[3] - box[1] for box in bounding_boxes]
            bb = bounding_boxes[int(np.argmax(np.array(widths)))]

            center_x = (bb[3] + bb[1]) / 2 * FRAME_WIDTH
            center_y = (bb[2] + bb[0]) / 2 * FRAME_HEIGHT

            # Clamp the window to the real frame so a detection near the
            # border cannot yield negative (wrapping) or out-of-range
            # slice bounds.
            lower_y, upper_y = crop_bounds(center_y, CROP_SIZE, clip.shape[2])
            lower_x, upper_x = crop_bounds(center_x, CROP_SIZE, clip.shape[3])

            trimmed_clip = clip[:, :, lower_y:upper_y, lower_x:upper_x]
            trimmed_clip = torch.tensor(trimmed_clip).to(torch.float)
            cliplist.append(transform(trimmed_clip))

        if cliplist:
            with torch.no_grad():
                batch = torch.stack(cliplist)
                if device.type == 'cuda':
                    # pin_memory() returns a pinned copy rather than pinning
                    # in place, so the result must be kept for the
                    # non-blocking transfer to benefit from it.
                    batch = batch.pin_memory()
                batch = batch.to(device, non_blocking=True)

                activations = frozen_sparse(batch)
                pred, _ = predictive_model(activations)

                probabilities = torch.sigmoid(pred)
                clip_predictions = (probabilities.round().detach().cpu()
                                    .flatten().to(torch.long))

                if args.verbose:
                    print(clip_predictions)
                    print("num of clips: ", countclips)

                # Majority vote over the per-clip predictions.
                final_pred = torch.mode(clip_predictions)[0].item()
                num_predictions = len(clip_predictions)
                is_tie = (num_predictions % 2 == 0 and
                          torch.sum(clip_predictions).item() == num_predictions // 2)
                if is_tie:
                    # Exact tie: fall back to the rounded mean probability.
                    final_pred = (probabilities.mean().round().detach().cpu()
                                  .to(torch.long).item())

                str_pred = prediction_label(final_pred)
        else:
            # No clip produced a detection; default label.
            str_pred = "No Sliding"

        end_time = time.time()

        all_predictions.append({'FileName': f,
                                'Prediction': str_pred,
                                'TotalTimeSec': end_time - start_time})

    output_name = 'output_' + datetime.now().strftime("%Y%m%d-%H%M%S") + '.csv'
    with open(output_name, 'w+', newline='') as csv_out:
        # Explicit field names: indexing all_predictions[0] would fail when
        # the input directory is empty.
        writer = csv.DictWriter(csv_out, fieldnames=PREDICTION_FIELDS)
        writer.writeheader()
        writer.writerows(all_predictions)


if __name__ == "__main__":
    main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment