From aca5c0b7000cc9f126073492868a819ae1cb34a3 Mon Sep 17 00:00:00 2001
From: Darryl Hannan <dwh1996@gmail.com>
Date: Tue, 8 Mar 2022 18:02:24 -0600
Subject: [PATCH] added tflite testing code

---
 run_tflite.py | 158 ++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 158 insertions(+)
 create mode 100644 run_tflite.py

diff --git a/run_tflite.py b/run_tflite.py
new file mode 100644
index 0000000..d4e23ec
--- /dev/null
+++ b/run_tflite.py
@@ -0,0 +1,158 @@
+import torch
+import os
+import time
+import numpy as np
+import torchvision
+from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler
+from torchvision.datasets.video_utils import VideoClips
+import csv
+from datetime import datetime
+from yolov4.get_bounding_boxes import YoloModel
+import argparse
+import tensorflow as tf
+import scipy.stats
+import cv2
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser(description='Run TFLite inference on the videos in input_videos/.')
+    parser.add_argument('--fast', action='store_true',
+                        help='optimize for runtime (larger stride between clips)')
+    parser.add_argument('--accurate', action='store_true',
+                        help='optimize for accuracy (smaller stride between clips)')
+    parser.add_argument('--verbose', action='store_true',
+                        help='print per-clip predictions')
+    args = parser.parse_args()
+    device = 'cpu'
+    batch_size = 1
+
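+    # Load the converted TFLite model once and allocate its tensors before the per-video loop.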
+    interpreter = tf.lite.Interpreter("keras/mobile_output/tf_lite_model.tflite")
+    interpreter.allocate_tensors()
+
+    input_details = interpreter.get_input_details()
+    output_details = interpreter.get_output_details()
+
+    yolo_model = YoloModel()
+
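+    # Preprocessing for every cropped clip: grayscale conversion followed by a
+    # 100x200 center crop (intensity normalization is currently disabled).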
+    transform = torchvision.transforms.Compose(
+    [VideoGrayScaler(),
+     # MinMaxScaler(0, 255),
+     # torchvision.transforms.Normalize((0.2592,), (0.1251,)),
+     torchvision.transforms.CenterCrop((100, 200))
+    ])
+
+    all_predictions = []
+
+    all_files = list(os.listdir('input_videos'))
+
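+    # Classify every video found in input_videos/ independently.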
+    for f in all_files:
+        print('Processing', f)
+        #start_time = time.time()
+
+        clipstride = 15
+        if args.fast:
+            clipstride = 20
+        if args.accurate:
+            clipstride = 10
+
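+        # Sample 5-frame clips at 20 fps; a larger stride between clips means fewer
+        # clips per video (faster), a smaller stride means more clips (more accurate).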
+        vc = VideoClips([os.path.join('input_videos', f)],
+                        clip_length_in_frames=5,
+                        frame_rate=20,
+                        frames_between_clips=clipstride)
+
+        ### START time after loading video ###
+        start_time = time.time()
+        clip_predictions = []
+        cliplist = []
+        countclips = 0
+        for i in range(vc.num_clips()):
+
+            clip, _, _, _ = vc.get_clip(i)
+            clip = clip.swapaxes(1, 3).swapaxes(0, 1).swapaxes(2, 3).numpy()
+
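+            # Run YOLO on the middle frame of the clip to find candidate regions;
+            # clips with no detections are skipped.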
+            bounding_boxes = yolo_model.get_bounding_boxes(clip[:, 2, :, :].swapaxes(0, 2).swapaxes(0, 1)).squeeze(0)
+            # for bb in bounding_boxes:
+            #     print(bb[1])
+            if bounding_boxes.size == 0:
+                continue
+            countclips = countclips + len(bounding_boxes)
+
+            # Width (x-extent) of each detected box; boxes are [y1, x1, y2, x2].
+            widths = [box[3] - box[1] for box in bounding_boxes]
+
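+            # Keep only the widest detection and crop a fixed 400x400 pixel window
+            # around its center (box coordinates are scaled to a 1920x1080 frame).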
+            ind = np.argmax(widths)
+            bb = bounding_boxes[ind]
+            center_x = (bb[3] + bb[1]) / 2 * 1920
+            center_y = (bb[2] + bb[0]) / 2 * 1080
+
+            width = 400
+            height = 400
+
+            lower_y = round(center_y - height / 2)
+            upper_y = round(center_y + height / 2)
+            lower_x = round(center_x - width / 2)
+            upper_x = round(center_x + width / 2)
+
+            trimmed_clip = clip[:, :, lower_y:upper_y, lower_x:upper_x]
+
+            trimmed_clip = torch.tensor(trimmed_clip).to(torch.float)
+
+            trimmed_clip = transform(trimmed_clip)
+
+            # tensor_to_write = trimmed_clip.swapaxes(0, 1).swapaxes(1, 2).swapaxes(2, 3)
+            # tensor_to_write[0][0][0][0] = 100
+            # tensor_to_write[0][0][0][1] = 100
+            # tensor_to_write[0][0][0][2] = 100
+            # torchvision.io.write_video('clips_to_test_swift/' + str(countclips) + '.mp4', tensor_to_write, fps=20)
+            # countclips += 1
+            # trimmed_clip.pin_memory()
+            cliplist.append(trimmed_clip)
+
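+        # Run each preprocessed clip through the TFLite interpreter and collect the
+        # rounded per-clip predictions.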
+        if len(cliplist) > 0:
+            with torch.no_grad():
+                for trimmed_clip in cliplist:
+                    # TFLite expects a NumPy array; convert from the torch tensor.
+                    interpreter.set_tensor(input_details[0]['index'], trimmed_clip.numpy())
+
+                    interpreter.invoke()
+
+                    output_array = np.array(interpreter.get_tensor(output_details[0]['index']))
+
+                    pred = output_array[0][0]
+                    print(pred)
+
+                    clip_predictions.append(pred.round())
+
+            if args.verbose:
+                print(clip_predictions)
+                print("num of clips: ", countclips)
+
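+            # Majority vote over the per-clip predictions decides the video-level label.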
+            final_pred = scipy.stats.mode(clip_predictions)[0][0]
+            # if len(clip_predictions) % 2 == 0 and torch.sum(clip_predictions).item() == len(clip_predictions)//2:
+            #     #print("I'm here")
+            #     final_pred = (torch.nn.Sigmoid()(pred)).mean().round().detach().cpu().to(torch.long).item()
+
+            if final_pred == 1:
+                str_pred = 'No Sliding'
+            else:
+                str_pred = 'Sliding'
+
+        else:
+            str_pred = "No Sliding"
+
+        end_time = time.time()
+
+        print(str_pred)
+
+        all_predictions.append({'FileName': f, 'Prediction': str_pred, 'TotalTimeSec': end_time - start_time})
+
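+    # Write one row per video (file name, predicted label, wall-clock seconds) to a
+    # timestamped CSV.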
+    with open('output_' + datetime.now().strftime("%Y%m%d-%H%M%S") + '.csv', 'w+', newline='') as csv_out:
+        writer = csv.DictWriter(csv_out, fieldnames=all_predictions[0].keys())
+
+        writer.writeheader()
+        writer.writerows(all_predictions)
-- 
GitLab