import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import argparse
import os
from sparse_coding_torch.load_data import load_yolo_clips, load_pnb_videos
from sparse_coding_torch.keras_model import SparseCode, Classifier, ReconSparse
import time
import numpy as np
from sklearn.metrics import f1_score, accuracy_score, confusion_matrix
import random
import pickle
import tensorflow.keras as keras
import tensorflow as tf

# Frame size (in pixels) that every Keras input in this script is built for.
IMAGE_HEIGHT = 360
IMAGE_WIDTH = 304

# Label string that is mapped to the positive class (1); anything else is 0.
POSITIVE_LABEL = 'Positives'


def parse_args():
    """Parse the command-line options of the classifier training script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', default=12, type=int)
    parser.add_argument('--kernel_size', default=15, type=int)
    parser.add_argument('--kernel_depth', default=5, type=int)
    parser.add_argument('--num_kernels', default=64, type=int)
    parser.add_argument('--stride', default=1, type=int)
    parser.add_argument('--max_activation_iter', default=150, type=int)
    parser.add_argument('--activation_lr', default=1e-1, type=float)
    parser.add_argument('--lr', default=5e-5, type=float)
    parser.add_argument('--epochs', default=10, type=int)
    parser.add_argument('--lam', default=0.05, type=float)
    parser.add_argument('--output_dir', default='./output', type=str)
    parser.add_argument('--sparse_checkpoint', default=None, type=str)
    parser.add_argument('--checkpoint', default=None, type=str)
    parser.add_argument('--splits', default='k_fold', type=str, help='k_fold or leave_one_out or all_train')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--train', action='store_true')
    parser.add_argument('--num_positives', default=100, type=int)
    parser.add_argument('--n_splits', default=5, type=int)
    parser.add_argument('--save_train_test_splits', action='store_true')
    parser.add_argument('--run_2d', action='store_true')
    return parser.parse_args()


def activation_shape(kernel_size, stride, num_kernels):
    """Return the per-sample shape of the sparse activations.

    The spatial extent is that of a 'valid' convolution of an
    IMAGE_HEIGHT x IMAGE_WIDTH frame with a square kernel.
    """
    out_height = (IMAGE_HEIGHT - kernel_size) // stride + 1
    out_width = (IMAGE_WIDTH - kernel_size) // stride + 1
    return (1, out_height, out_width, num_kernels)


def labels_to_targets(labels):
    """Convert a sequence of label strings into an (N, 1) float64 target array.

    Entries equal to POSITIVE_LABEL become 1.0, everything else 0.0.
    """
    targets = np.array(
        [1.0 if label == POSITIVE_LABEL else 0.0 for label in labels],
        dtype=np.float64,
    )
    return np.expand_dims(targets, axis=1)


def split_misclassified(y_true, y_pred, video_ids):
    """Return ``(false_positive_ids, false_negative_ids)`` for wrong predictions."""
    fp_ids = []
    fn_ids = []
    for gt, pred, vid in zip(y_true, y_pred, video_ids):
        if pred == gt:
            continue
        if pred == 0:
            fn_ids.append(vid)
        else:
            fp_ids.append(vid)
    return fp_ids, fn_ids


def build_sparse_model(args, batch_size):
    """Build the Keras model that wraps a SparseCode layer.

    The model takes ``(clips, filters)`` and returns whatever SparseCode
    returns for them; ``batch_size`` is forwarded to the layer.
    """
    if args.run_2d:
        inputs = keras.Input(shape=(IMAGE_HEIGHT, IMAGE_WIDTH, 5))
    else:
        inputs = keras.Input(shape=(5, IMAGE_HEIGHT, IMAGE_WIDTH, 1))

    # NOTE(review): the leading 5 is hard-coded; --kernel_depth is not consulted.
    filter_inputs = keras.Input(
        shape=(5, args.kernel_size, args.kernel_size, 1, args.num_kernels),
        dtype='float32',
    )

    output = SparseCode(
        batch_size=batch_size,
        image_height=IMAGE_HEIGHT,
        image_width=IMAGE_WIDTH,
        in_channels=1,
        out_channels=args.num_kernels,
        kernel_size=args.kernel_size,
        stride=args.stride,
        lam=args.lam,
        activation_lr=args.activation_lr,
        max_activation_iter=args.max_activation_iter,
        run_2d=args.run_2d,
    )(inputs, filter_inputs)

    return keras.Model(inputs=(inputs, filter_inputs), outputs=output)


def compute_activations(sparse_model, recon_model, images):
    """Run the sparse model on ``images`` using the first weight of ``recon_model``.

    Gradients are stopped on both the filters and the result so that only the
    classifier is trained.
    """
    filters = tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))
    return tf.stop_gradient(sparse_model([images, filters]))


def train_one_epoch(classifier_model, sparse_model, recon_model, loader, criterion, optimizer, batch_size):
    """Train the classifier for one pass over ``loader``.

    Returns the sum over batches of ``loss * batch_size`` as a Python float.
    """
    epoch_loss = 0.0
    for labels, local_batch, _ in tqdm(loader):
        # The sparse model is built for a fixed batch size, so a short
        # trailing batch is skipped.
        if local_batch.size(0) != batch_size:
            continue

        images = local_batch.permute(0, 2, 3, 4, 1).numpy()
        targets = labels_to_targets(labels)
        activations = compute_activations(sparse_model, recon_model, images)

        with tf.GradientTape() as tape:
            pred = classifier_model(activations)
            loss = criterion(targets, pred)

        epoch_loss += float(loss) * local_batch.size(0)

        gradients = tape.gradient(loss, classifier_model.trainable_weights)
        optimizer.apply_gradients(zip(gradients, classifier_model.trainable_weights))

    return epoch_loss


def evaluate(classifier_model, sparse_model, recon_model, loader, criterion):
    """Evaluate the classifier on ``loader``.

    Returns ``(total_loss, y_true, y_pred, video_ids)`` where ``y_true`` and
    ``y_pred`` are lists of 0/1 ints and ``video_ids`` holds the third element
    of every batch, flattened.
    """
    total_loss = 0.0
    y_true = []
    y_pred = []
    video_ids = []

    for labels, local_batch, vid_f in loader:
        images = local_batch.permute(0, 2, 3, 4, 1).numpy()
        targets = labels_to_targets(labels)
        activations = compute_activations(sparse_model, recon_model, images)

        pred = classifier_model(activations)
        total_loss += float(criterion(targets, pred)) * local_batch.size(0)

        # The loss is configured with from_logits=False, so the classifier
        # output is treated as a probability and thresholded directly.
        # Applying a sigmoid first would push every value to >= 0.5.
        # NOTE(review): assumes Classifier ends in a sigmoid -- confirm.
        y_true.extend(targets[:, 0].astype(int).tolist())
        y_pred.extend(np.rint(pred.numpy()[:, 0]).astype(int).tolist())
        video_ids.extend(vid_f)

    return total_loss, y_true, y_pred, video_ids


def main():
    """Train and/or evaluate the classifier on every cross-validation fold."""
    args = parse_args()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # Keras weight initialisation draws from TensorFlow's RNG.
    tf.random.set_seed(args.seed)

    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    with open(os.path.join(output_dir, 'arguments.txt'), 'w') as out_f:
        out_f.write(str(args))

    feature_shape = activation_shape(args.kernel_size, args.stride, args.num_kernels)

    # Only the first trainable weight of this model is used (as the filters).
    if args.sparse_checkpoint:
        recon_model = keras.models.load_model(args.sparse_checkpoint)
    else:
        recon_inputs = keras.Input(shape=feature_shape)
        recon_outputs = ReconSparse(
            batch_size=args.batch_size,
            image_height=IMAGE_HEIGHT,
            image_width=IMAGE_WIDTH,
            in_channels=1,
            out_channels=args.num_kernels,
            kernel_size=args.kernel_size,
            stride=args.stride,
            lam=args.lam,
            activation_lr=args.activation_lr,
            max_activation_iter=args.max_activation_iter,
            run_2d=args.run_2d,
        )(recon_inputs)
        recon_model = keras.Model(inputs=recon_inputs, outputs=recon_outputs)

    # NOTE(review): --splits is parsed but the mode is fixed to 'k_fold' here.
    splits, dataset = load_pnb_videos(args.batch_size, classify_mode=True, mode='k_fold', device=None, n_splits=args.n_splits, sparse_model=None)

    # Built once instead of once per epoch; the filters are passed in on every
    # call. NOTE(review): assumes SparseCode keeps no state between calls.
    train_sparse_model = build_sparse_model(args, args.batch_size) if args.train else None
    eval_sparse_model = build_sparse_model(args, 1)

    criterion = keras.losses.BinaryCrossentropy(from_logits=False)

    if not args.train and not args.checkpoint:
        print("warning: neither --train nor --checkpoint given; evaluating an untrained classifier")

    overall_true = []
    overall_pred = []
    fn_ids = []
    fp_ids = []

    for i_fold, (train_idx, test_idx) in enumerate(splits):
        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=args.batch_size,
            sampler=torch.utils.data.SubsetRandomSampler(train_idx),
        )
        test_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=1,
            sampler=torch.utils.data.SubsetRandomSampler(test_idx),
        )

        if args.checkpoint:
            # keras.Model has no load() method; a saved model has to be
            # restored through keras.models.load_model.
            classifier_model = keras.models.load_model(args.checkpoint)
        else:
            classifier_inputs = keras.Input(shape=feature_shape)
            classifier_outputs = Classifier()(classifier_inputs)
            classifier_model = keras.Model(inputs=classifier_inputs, outputs=classifier_outputs)

        if args.train:
            best_model_path = os.path.join(output_dir, "model-best_fold_" + str(i_fold) + ".pt")
            prediction_optimizer = keras.optimizers.Adam(learning_rate=args.lr)
            best_so_far = float('inf')
            saved_best = False

            for epoch in range(args.epochs):
                t1 = time.perf_counter()

                epoch_loss = train_one_epoch(
                    classifier_model, train_sparse_model, recon_model,
                    train_loader, criterion, prediction_optimizer, args.batch_size,
                )
                _, y_true, y_pred, _ = evaluate(
                    classifier_model, eval_sparse_model, recon_model, test_loader, criterion,
                )

                t2 = time.perf_counter()

                f1 = f1_score(y_true, y_pred, average='macro')
                accuracy = accuracy_score(y_true, y_pred)

                print('fold={}, epoch={}, time={:.2f}, loss={:.2f}, f1={:.2f}, acc={:.2f}'.format(i_fold, epoch, t2-t1, epoch_loss, f1, accuracy))

                # Model selection is based on the training loss of the epoch.
                if epoch_loss <= best_so_far:
                    print("found better model")
                    classifier_model.save(best_model_path)
                    best_so_far = epoch_loss
                    saved_best = True

            # Nothing is saved when no epoch ran or the loss was never
            # comparable (e.g. NaN); keep the in-memory model in that case.
            if saved_best:
                classifier_model = keras.models.load_model(best_model_path)

        # Final evaluation of the fold; it also runs when only --checkpoint
        # was given, so a saved classifier can be scored without training.
        t1 = time.perf_counter()
        test_loss, y_true, y_pred, video_ids = evaluate(
            classifier_model, eval_sparse_model, recon_model, test_loader, criterion,
        )
        t2 = time.perf_counter()

        print('fold={}, loss={:.2f}, time={:.2f}'.format(i_fold, test_loss, t2-t1))

        overall_true.extend(y_true)
        overall_pred.extend(y_pred)

        fold_fp, fold_fn = split_misclassified(y_true, y_pred, video_ids)
        fp_ids.extend(fold_fp)
        fn_ids.extend(fold_fn)

        f1 = f1_score(y_true, y_pred, average='macro')
        accuracy = accuracy_score(y_true, y_pred)

        print("Test f1={:.2f}, clip_acc={:.2f}, fold={}".format(f1, accuracy, i_fold))

        print(confusion_matrix(y_true, y_pred))

    fp_fn_file = os.path.join(args.output_dir, 'fp_fn.txt')
    with open(fp_fn_file, 'w') as out_f:
        out_f.write('FP:\n')
        out_f.write(str(fp_ids) + '\n\n')
        out_f.write('FN:\n')
        out_f.write(str(fn_ids) + '\n\n')

    overall_true = np.array(overall_true)
    overall_pred = np.array(overall_pred)

    final_f1 = f1_score(overall_true, overall_pred, average='macro')
    final_acc = accuracy_score(overall_true, overall_pred)
    final_conf = confusion_matrix(overall_true, overall_pred)

    print("Final accuracy={:.2f}, f1={:.2f}".format(final_acc, final_f1))
    print(final_conf)


if __name__ == "__main__":
    main()
