import torch
import os
import time
import numpy as np
import torchvision
from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler, get_yolo_regions
from torchvision.datasets.video_utils import VideoClips
import csv
from datetime import datetime
from yolov4.get_bounding_boxes import YoloModel
import argparse
import tensorflow as tf
import scipy.stats
import cv2

if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Python program for processing PNB data')
    parser.add_argument('--classifier', type=str, default='keras/mobile_output/tf_lite_model.tflite')
    parser.add_argument('--input_dir', type=str, default='input_videos')
    args = parser.parse_args()

    interpreter = tf.lite.Interpreter(args.classifier)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    yolo_model = YoloModel()

    transform = torchvision.transforms.Compose(
    [VideoGrayScaler(),
     MinMaxScaler(0, 255),
     torchvision.transforms.Resize((360, 304))
    ])

    all_predictions = []

    all_files = list(os.listdir(args.input_dir))

    for f in all_files:
        print('Processing', f)
        
        vc = tv.io.read_video(os.path.join(args.input_dir, f))[0].permute(3, 0, 1, 2)
        
        vc_sub = vc[:, -5:, :, :]
        if vc_sub.size(1) < 5:
            raise Exception(f + ' does not contain enough frames for processing')
            
        ### START time after loading video ###
        start_time = time.time()
        
        clip = get_yolo_regions(yolo_model, vc_sub)
        if clip:
            clip = clip[0]
            clip = transform(clip)

            interpreter.set_tensor(input_details[0]['index'], clip)

            interpreter.invoke()

            output_array = np.array(interpreter.get_tensor(output_details[0]['index']))

            pred = output_array[0][0]
            print(pred)

            final_pred = pred.round()

                if final_pred == 1:
                    str_pred = 'Positive'
                else:
                    str_pred = 'Negative'
        else:
            str_pred = "Positive"

        end_time = time.time()

        print(str_pred)

        all_predictions.append({'FileName': f, 'Prediction': str_pred, 'TotalTimeSec': end_time - start_time})

    with open('output_' + datetime.now().strftime("%Y%m%d-%H%M%S") + '.csv', 'w+', newline='') as csv_out:
        writer = csv.DictWriter(csv_out, fieldnames=all_predictions[0].keys())

        writer.writeheader()
        writer.writerows(all_predictions)
