import argparse
import csv
import os
import time
from datetime import datetime

import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
import torch
import torchvision
from torchvision.datasets.video_utils import VideoClips

from sparse_coding_torch.keras_model import SparseCode, PNBClassifier, PTXClassifier, ReconSparse
from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler
from yolov4.get_bounding_boxes import YoloModel


if __name__ == "__main__":
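    # Classify each video in --input_dir: YOLO localizes a region of interest,
    # a sparse-coding frontend extracts activations, and a classifier votes
    # over the clips sampled from the video.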
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_dir', default='/shared_data/bamc_data/PTX_Sliding', type=str)
    parser.add_argument('--kernel_size', default=15, type=int)
    parser.add_argument('--kernel_depth', default=5, type=int)
    parser.add_argument('--num_kernels', default=64, type=int)
    parser.add_argument('--stride', default=2, type=int)
    parser.add_argument('--max_activation_iter', default=100, type=int)
    parser.add_argument('--activation_lr', default=1e-2, type=float)
    parser.add_argument('--lam', default=0.05, type=float)
    parser.add_argument('--sparse_checkpoint', default='converted_checkpoints/sparse.pt', type=str)
    parser.add_argument('--checkpoint', default='converted_checkpoints/classifier.pt', type=str)
    parser.add_argument('--run_2d', action='store_true')
    parser.add_argument('--dataset', default='ptx', type=str)
    
    args = parser.parse_args()
    batch_size = 1
    
    if args.dataset == 'pnb':
        image_height = 250
        image_width = 600
    elif args.dataset == 'ptx':
        image_height = 100
        image_width = 200
    else:
        raise ValueError('Invalid dataset: ' + args.dataset)

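    # 2D sparse coding treats the 5 frames as channels of one image; 3D keeps
    # an explicit temporal axis of depth 5 with a single grayscale channel.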
    if args.run_2d:
        inputs = keras.Input(shape=(image_height, image_width, 5))
    else:
        inputs = keras.Input(shape=(5, image_height, image_width, 1))

    filter_inputs = keras.Input(shape=(5, args.kernel_size, args.kernel_size, 1, args.num_kernels), dtype='float32')

    output = SparseCode(batch_size=batch_size, image_height=image_height,
                        image_width=image_width, in_channels=1,
                        out_channels=args.num_kernels, kernel_size=args.kernel_size,
                        stride=args.stride, lam=args.lam,
                        activation_lr=args.activation_lr,
                        max_activation_iter=args.max_activation_iter,
                        run_2d=args.run_2d, padding='SAME')(inputs, filter_inputs)

    sparse_model = keras.Model(inputs=(inputs, filter_inputs), outputs=output)
    
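    # The reconstruction model owns the learned dictionary; its weights are
    # passed to the sparse-coding layer as the filter input at inference time.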
    recon_inputs = keras.Input(shape=(1, image_height // args.stride, image_width // args.stride, args.num_kernels))
    
    recon_outputs = ReconSparse(batch_size=batch_size, image_height=image_height,
                                image_width=image_width, in_channels=1,
                                out_channels=args.num_kernels, kernel_size=args.kernel_size,
                                stride=args.stride, lam=args.lam,
                                activation_lr=args.activation_lr,
                                max_activation_iter=args.max_activation_iter,
                                run_2d=args.run_2d)(recon_inputs)
    
    recon_model = keras.Model(inputs=recon_inputs, outputs=recon_outputs)

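    # A saved sparse-coding checkpoint replaces the freshly built recon_model;
    # only its dictionary weights are read below.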
    if args.sparse_checkpoint:
        recon_model = keras.models.load_model(args.sparse_checkpoint)
        
    if args.checkpoint:
        classifier_model = keras.models.load_model(args.checkpoint)
    else:
        classifier_inputs = keras.Input(shape=(1, image_height // args.stride, image_width // args.stride, args.num_kernels))

        if args.dataset == 'pnb':
            classifier_outputs = PNBClassifier()(classifier_inputs)
        elif args.dataset == 'ptx':
            classifier_outputs = PTXClassifier()(classifier_inputs)
        else:
            raise ValueError('No classifier exists for dataset ' + args.dataset)

        classifier_model = keras.Model(inputs=classifier_inputs, outputs=classifier_outputs)
        
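    # YOLO supplies bounding boxes used to crop the region of interest from
    # each clip before classification.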
    yolo_model = YoloModel()
        
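    # Preprocessing: grayscale, rescale from [0, 255], normalize with
    # precomputed dataset statistics, then center-crop to the PTX input size.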
    transform = torchvision.transforms.Compose([
        VideoGrayScaler(),
        MinMaxScaler(0, 255),
        torchvision.transforms.Normalize((0.2592,), (0.1251,)),
        torchvision.transforms.CenterCrop((100, 200)),
    ])
    
    all_predictions = []
    
    all_files = list(os.listdir(args.input_dir))
    
    for f in all_files:
        print('Processing', f)
        
        clipstride = 15
        
        vc = VideoClips([os.path.join(args.input_dir, f)],
                        clip_length_in_frames=5,
                        frame_rate=20,
                        frames_between_clips=clipstride)
    
        ### START time after loading video ###
        start_time = time.time()
        clip_predictions = []
        cliplist = []
        countclips = 0
        for i in range(vc.num_clips()):

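            # get_clip returns frames as (T, H, W, C); rearrange to (C, T, H, W).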
            clip, _, _, _ = vc.get_clip(i)
            clip = clip.swapaxes(1, 3).swapaxes(0, 1).swapaxes(2, 3).numpy()
            
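            # Run the detector on the middle frame of the 5-frame clip,
            # restored to (H, W, C) layout.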
            bounding_boxes, classes = yolo_model.get_bounding_boxes(clip[:, 2, :, :].swapaxes(0, 2).swapaxes(0, 1))
            bounding_boxes = bounding_boxes.squeeze(0)
            if bounding_boxes.size == 0:
                continue
            countclips = countclips + len(bounding_boxes)
            
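            # Boxes appear to be normalized [y1, x1, y2, x2]; keep the widest
            # detection as the region of interest.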
            widths = [(bounding_boxes[i][3] - bounding_boxes[i][1]) for i in range(len(bounding_boxes))]
            
            ind = np.argmax(np.array(widths))
            bb = bounding_boxes[ind]
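            # Coordinates are normalized; scaling by 1920x1080 assumes that
            # fixed source resolution.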
            center_x = (bb[3] + bb[1]) / 2 * 1920
            center_y = (bb[2] + bb[0]) / 2 * 1080

            width = 400
            height = 400

            # Clamp the crop to the frame so a near-edge box cannot yield a
            # negative slice index, which would silently wrap around.
            lower_y = max(0, round(center_y - height / 2))
            upper_y = round(center_y + height / 2)
            lower_x = max(0, round(center_x - width / 2))
            upper_x = round(center_x + width / 2)

            trimmed_clip = clip[:, :, lower_y:upper_y, lower_x:upper_x]

            trimmed_clip = torch.tensor(trimmed_clip).to(torch.float)

            trimmed_clip = transform(trimmed_clip)
            # pin_memory() is not in-place; keep the returned tensor.
            trimmed_clip = trimmed_clip.pin_memory()
            cliplist.append(trimmed_clip)

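        # Batch every cropped clip through the sparse-coding layer (with the
        # pretrained dictionary weights) and the classifier in one pass.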
        if len(cliplist) > 0:
            with torch.no_grad():
                trimmed_clip = torch.stack(cliplist)
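                # (N, C=1, T=5, H, W) -> (N, 5, H, W, 1) for the Keras model.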
                images = trimmed_clip.permute(0, 2, 3, 4, 1).numpy()
                activations = tf.stop_gradient(sparse_model([
                    images,
                    tf.stop_gradient(tf.expand_dims(recon_model.weights[0], axis=0)),
                ]))

                pred = classifier_model(activations)
                clip_predictions = tf.math.round(tf.math.sigmoid(pred))

            # Majority vote over the per-clip predictions; an exact tie falls
            # through to torch.mode's tie-breaking.
            final_pred = torch.mode(torch.tensor(clip_predictions.numpy()).view(-1))[0].item()
                
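            # Label 1 corresponds to the positive ('No Sliding') class.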
            if final_pred == 1:
                str_pred = 'No Sliding'
            else:
                str_pred = 'Sliding'

        else:
            str_pred = "No Sliding"
            
        print(str_pred)
            
        end_time = time.time()
        
        all_predictions.append({'FileName': f, 'Prediction': str_pred, 'TotalTimeSec': end_time - start_time})
        
    if all_predictions:
        with open('output_' + datetime.now().strftime("%Y%m%d-%H%M%S") + '.csv', 'w', newline='') as csv_out:
            writer = csv.DictWriter(csv_out, fieldnames=all_predictions[0].keys())

            writer.writeheader()
            writer.writerows(all_predictions)