import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import argparse
import os
from sparse_coding_torch.load_data import load_yolo_clips, load_pnb_videos, SubsetWeightedRandomSampler
from sparse_coding_torch.keras_model import SparseCode, PNBClassifier, PTXClassifier, ReconSparse
import time
import numpy as np
from sklearn.metrics import f1_score, accuracy_score, confusion_matrix
import random
import pickle
import tensorflow.keras as keras
import tensorflow as tf

# Build a TF1-compat session at import time and register it as the session
# used by the Keras backend, so the models created below run inside it.
configproto = tf.compat.v1.ConfigProto()
# NOTE(review): presumably the delay (in ms) between GPU polls while the
# device is idle -- confirm against the GPUOptions proto documentation.
configproto.gpu_options.polling_inactive_delay_msecs = 5000
# Per the TF GPUOptions docs, allow_growth makes the process claim GPU memory
# incrementally instead of reserving it all up front.
configproto.gpu_options.allow_growth = True
sess = tf.compat.v1.Session(config=configproto) 
tf.compat.v1.keras.backend.set_session(sess)

def get_sample_weights(train_idx, dataset):
    """Compute a class-balancing sampling weight for each training index.

    The label of an item is read as ``item[0]``. Items labelled
    ``'Positives'`` receive the fraction of non-positive items in the whole
    dataset as their weight; items labelled ``'Negatives'`` receive the
    fraction of positive items. Whichever of the two groups is the minority
    therefore gets the larger weight.

    Args:
        train_idx: Iterable of integer positions into ``dataset``.
        dataset: Iterable of indexable items whose first element is the
            label. It is materialised into a list, so generators are fine.

    Returns:
        A list of floats, one per entry of ``train_idx``, in the same order.

    Raises:
        ValueError: If a selected item is labelled neither ``'Positives'``
            nor ``'Negatives'``.
        IndexError: If an index in ``train_idx`` is out of range.
    """
    dataset = list(dataset)

    num_positive = sum(1 for clip in dataset if clip[0] == 'Positives')
    # An empty dataset has no class ratio; guard the division so that asking
    # for zero weights returns [] instead of raising ZeroDivisionError.
    negative_weight = num_positive / len(dataset) if dataset else 0.0
    positive_weight = 1.0 - negative_weight

    weights = []
    for idx in train_idx:
        label = dataset[idx][0]
        if label == 'Positives':
            weights.append(positive_weight)
        elif label == 'Negatives':
            weights.append(negative_weight)
        else:
            # ValueError is still an Exception, so existing handlers keep working.
            raise ValueError('Sampler encountered invalid label')

    return weights

def _binary_labels(labels, positive_class):
    """Return an (N, 1) float array: 1.0 where ``labels[i] == positive_class``, else 0.0."""
    targets = np.zeros(len(labels))
    targets[[i for i, label in enumerate(labels) if label == positive_class]] = 1
    return np.expand_dims(targets, axis=1)


# Script entry point: builds the sparse-coding model and a binary classifier,
# optionally trains the classifier on sparse activations (--train), evaluates
# it on the test split, and writes false-positive / false-negative ids to
# <output_dir>/fp_fn.txt.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', default=12, type=int)
    parser.add_argument('--kernel_size', default=15, type=int)
    # NOTE(review): --kernel_depth is parsed but never read; the temporal depth
    # below is hard-coded to 5.
    parser.add_argument('--kernel_depth', default=5, type=int)
    parser.add_argument('--num_kernels', default=64, type=int)
    parser.add_argument('--stride', default=1, type=int)
    parser.add_argument('--max_activation_iter', default=150, type=int)
    parser.add_argument('--activation_lr', default=1e-2, type=float)
    parser.add_argument('--lr', default=5e-2, type=float)
    parser.add_argument('--epochs', default=40, type=int)
    parser.add_argument('--lam', default=0.05, type=float)
    parser.add_argument('--output_dir', default='./output', type=str)
    parser.add_argument('--sparse_checkpoint', default=None, type=str)
    parser.add_argument('--checkpoint', default=None, type=str)
    parser.add_argument('--splits', default=None, type=str, help='k_fold or leave_one_out or all_train')
    parser.add_argument('--seed', default=26, type=int)
    parser.add_argument('--train', action='store_true')
    parser.add_argument('--num_positives', default=100, type=int)
    parser.add_argument('--n_splits', default=1, type=int)
    parser.add_argument('--save_train_test_splits', action='store_true')
    parser.add_argument('--run_2d', action='store_true')
    parser.add_argument('--balance_classes', action='store_true')
    parser.add_argument('--dataset', default='pnb', type=str)

    args = parser.parse_args()

    # Frame size handed to the sparse-coding layers depends on the dataset.
    if args.dataset == 'pnb':
        image_height = 250
        image_width = 600
    elif args.dataset == 'ptx':
        image_height = 100
        image_width = 200
    else:
        raise Exception('Invalid dataset')

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    # Record the exact arguments of this run next to its outputs.
    with open(os.path.join(output_dir, 'arguments.txt'), 'w+') as out_f:
        out_f.write(str(args))

    all_errors = []

    if args.run_2d:
        inputs = keras.Input(shape=(image_height, image_width, 5))
    else:
        inputs = keras.Input(shape=(5, image_height, image_width, 1))

    filter_inputs = keras.Input(shape=(5, args.kernel_size, args.kernel_size, 1, args.num_kernels), dtype='float32')

    # SparseCode and ReconSparse are constructed with identical settings.
    sparse_layer_kwargs = dict(
        batch_size=args.batch_size,
        image_height=image_height,
        image_width=image_width,
        in_channels=1,
        out_channels=args.num_kernels,
        kernel_size=args.kernel_size,
        stride=args.stride,
        lam=args.lam,
        activation_lr=args.activation_lr,
        max_activation_iter=args.max_activation_iter,
        run_2d=args.run_2d,
    )

    output = SparseCode(**sparse_layer_kwargs)(inputs, filter_inputs)

    sparse_model = keras.Model(inputs=(inputs, filter_inputs), outputs=output)

    # Shape of the activation volume (valid-convolution output size), shared
    # by the reconstruction model input and the classifier input.
    activation_shape = (
        1,
        (image_height - args.kernel_size) // args.stride + 1,
        (image_width - args.kernel_size) // args.stride + 1,
        args.num_kernels,
    )

    recon_inputs = keras.Input(shape=activation_shape)

    recon_outputs = ReconSparse(**sparse_layer_kwargs)(recon_inputs)

    recon_model = keras.Model(inputs=recon_inputs, outputs=recon_outputs)

    # The first weight of recon_model is passed to sparse_model as the filter
    # bank below; without a checkpoint it is whatever ReconSparse initialises.
    if args.sparse_checkpoint:
        recon_model = keras.models.load_model(args.sparse_checkpoint)

    positive_class = None
    if args.dataset == 'pnb':
        train_loader, test_loader, dataset = load_pnb_videos(args.batch_size, classify_mode=True, balance_classes=args.balance_classes, mode=args.splits, device=None, n_splits=args.n_splits, sparse_model=None)
        positive_class = 'Positives'
    elif args.dataset == 'ptx':
        train_loader, test_loader, dataset = load_yolo_clips(args.batch_size, num_clips=1, num_positives=15, mode=args.splits, device=None, n_splits=args.n_splits, sparse_model=None, whole_video=False, positive_videos='positive_videos.json')
        positive_class = 'No_Sliding'
    else:
        raise Exception('Invalid dataset')

    overall_true = []
    overall_pred = []
    fn_ids = []
    fp_ids = []

    best_classifier_path = os.path.join(output_dir, 'best_classifier.pt')

    if args.checkpoint:
        classifier_model = keras.models.load_model(args.checkpoint)
    elif os.path.exists(best_classifier_path):
        # Reuse the classifier saved by an earlier run in this output dir.
        # Previously this load was immediately overwritten by a fresh model.
        classifier_model = keras.models.load_model(best_classifier_path)
    else:
        classifier_inputs = keras.Input(shape=activation_shape)

        if args.dataset == 'pnb':
            classifier_outputs = PNBClassifier()(classifier_inputs)
        elif args.dataset == 'ptx':
            classifier_outputs = PTXClassifier()(classifier_inputs)
        else:
            raise Exception('No classifier exists for that dataset')

        classifier_model = keras.Model(inputs=classifier_inputs, outputs=classifier_outputs)

    with open(os.path.join(output_dir, 'classifier_summary.txt'), 'w+') as out_f:
        def _emit_summary_line(line, **_unused):
            # summary() returns None, so capture its lines through print_fn.
            # Extra keyword arguments some Keras versions pass are ignored.
            print(line)
            out_f.write(line + '\n')

        classifier_model.summary(print_fn=_emit_summary_line)

    prediction_optimizer = keras.optimizers.Adam(learning_rate=args.lr)

    best_so_far = float('inf')

    # The classifier emits logits; the loss is summed (not averaged) per batch.
    criterion = keras.losses.BinaryCrossentropy(from_logits=True, reduction=keras.losses.Reduction.SUM)

    if args.train:
        for epoch in range(args.epochs):
            epoch_loss = 0
            t1 = time.perf_counter()

            y_true_train = None
            y_pred_train = None

            for labels, local_batch, vid_f in tqdm(train_loader):
                # Move axis 1 of the 5-D batch to the end.
                # NOTE(review): presumably channels-first -> channels-last; confirm against the loader.
                images = local_batch.permute(0, 2, 3, 4, 1).numpy()

                torch_labels = _binary_labels(labels, positive_class)

                # Sparse codes are treated as fixed features: no gradient flows
                # into the sparse model or the filter bank.
                activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))]))

                with tf.GradientTape() as tape:
                    pred = classifier_model(activations)
                    loss = criterion(torch_labels, pred)

                # NOTE(review): loss is already a batch SUM, so scaling by the
                # batch size weights large batches twice; kept as in the original.
                epoch_loss += loss * local_batch.size(0)

                gradients = tape.gradient(loss, classifier_model.trainable_weights)

                prediction_optimizer.apply_gradients(zip(gradients, classifier_model.trainable_weights))

                batch_pred = tf.math.round(tf.math.sigmoid(pred))
                if y_true_train is None:
                    y_true_train = torch_labels
                    y_pred_train = batch_pred
                else:
                    y_true_train = tf.concat((y_true_train, torch_labels), axis=0)
                    y_pred_train = tf.concat((y_pred_train, batch_pred), axis=0)

            y_true = None
            y_pred = None
            test_loss = 0.0
            for labels, local_batch, vid_f in tqdm(test_loader):
                images = local_batch.permute(0, 2, 3, 4, 1).numpy()

                torch_labels = _binary_labels(labels, positive_class)

                activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))]))

                pred = classifier_model(activations)
                loss = criterion(torch_labels, pred)

                test_loss += loss

                batch_pred = tf.math.round(tf.math.sigmoid(pred))
                if y_true is None:
                    y_true = torch_labels
                    y_pred = batch_pred
                else:
                    y_true = tf.concat((y_true, torch_labels), axis=0)
                    y_pred = tf.concat((y_pred, batch_pred), axis=0)

            # Reported time covers both the training and the test pass.
            t2 = time.perf_counter()

            y_true = tf.cast(y_true, tf.int32)
            y_pred = tf.cast(y_pred, tf.int32)

            y_true_train = tf.cast(y_true_train, tf.int32)
            y_pred_train = tf.cast(y_pred_train, tf.int32)

            f1 = f1_score(y_true, y_pred, average='macro')
            accuracy = accuracy_score(y_true, y_pred)

            train_accuracy = accuracy_score(y_true_train, y_pred_train)

            print('epoch={}, time={:.2f}, train_loss={:.2f}, test_loss={:.2f}, train_acc={:.2f}, test_f1={:.2f}, test_acc={:.2f}'.format(epoch, t2-t1, epoch_loss, test_loss, train_accuracy, f1, accuracy))

            # NOTE(review): the "best" model is chosen by training loss, not test loss.
            if epoch_loss <= best_so_far:
                print("found better model")
                # Save model parameters
                classifier_model.save(os.path.join(output_dir, "best_classifier.pt"))
                # Context manager closes the file (the original leaked the handle).
                with open(os.path.join(output_dir, 'optimizer.pt'), 'wb+') as opt_f:
                    pickle.dump(prediction_optimizer.get_weights(), opt_f)
                best_so_far = epoch_loss

        # Evaluate with the best checkpoint rather than the last-epoch weights.
        classifier_model = keras.models.load_model(os.path.join(output_dir, "best_classifier.pt"))

    if args.dataset == 'pnb':
        epoch_loss = 0

        y_true = None
        y_pred = None

        t1 = time.perf_counter()
        # Set of video names that appear in the test split, for O(1) membership tests.
        test_videos = {single_vid for labels, local_batch, vid_f in test_loader for single_vid in vid_f}

        for k, v in tqdm(dataset.get_final_clips().items()):
            if k not in test_videos:
                continue
            labels, local_batch, vid_f = v
            # A single clip: add a batch axis before reordering the axes.
            images = local_batch.unsqueeze(0).permute(0, 2, 3, 4, 1).numpy()

            torch_labels = _binary_labels([labels], positive_class)

            activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))]))

            pred = classifier_model(activations)

            loss = criterion(torch_labels, pred)
            epoch_loss += loss

            final_pred = tf.math.round(tf.math.sigmoid(pred))
            gt = torch_labels

            overall_true.append(gt)
            overall_pred.append(final_pred)

            if final_pred != gt:
                # Both the dict key and the stored video name are recorded, in
                # that order, exactly as the original did.
                # NOTE(review): if k always equals vid_f this lists every
                # misclassified video twice -- confirm which one is wanted.
                if final_pred == 0:
                    fn_ids.append(k)
                    fn_ids.append(vid_f)
                else:
                    fp_ids.append(k)
                    fp_ids.append(vid_f)

            if y_true is None:
                y_true = torch_labels
                y_pred = final_pred
            else:
                y_true = tf.concat((y_true, torch_labels), axis=0)
                y_pred = tf.concat((y_pred, final_pred), axis=0)

        t2 = time.perf_counter()

        # Report the accumulated loss; the original printed only the last
        # clip's loss (and hit a NameError when no clip was evaluated).
        print('loss={:.2f}, time={:.2f}'.format(epoch_loss, t2-t1))

        y_true = tf.cast(y_true, tf.int32)
        y_pred = tf.cast(y_pred, tf.int32)

        f1 = f1_score(y_true, y_pred, average='macro')
        accuracy = accuracy_score(y_true, y_pred)

        print("Test f1={:.2f}, vid_acc={:.2f}".format(f1, accuracy))

        print(confusion_matrix(y_true, y_pred))
    elif args.dataset == 'ptx':
        epoch_loss = 0

        y_true = None
        y_pred = None

        # Per-video clip predictions / ground truth, keyed by video name.
        pred_dict = {}
        gt_dict = {}

        t1 = time.perf_counter()
        for labels, local_batch, vid_f in test_loader:
            images = local_batch.permute(0, 2, 3, 4, 1).numpy()

            torch_labels = _binary_labels(labels, positive_class)

            # NOTE(review): this branch reads weights[0] while the others read
            # trainable_weights[0]; kept as in the original.
            activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.weights[0], axis=0))]))

            pred = classifier_model(activations)

            loss = criterion(torch_labels, pred)
            epoch_loss += loss * local_batch.size(0)

            batch_pred = tf.math.round(tf.math.sigmoid(pred))

            for i, v_f in enumerate(vid_f):
                if v_f not in pred_dict:
                    pred_dict[v_f] = batch_pred[i]
                else:
                    pred_dict[v_f] = tf.concat((pred_dict[v_f], batch_pred[i]), axis=0)

                if v_f not in gt_dict:
                    gt_dict[v_f] = torch_labels[i]
                else:
                    gt_dict[v_f] = tf.concat((gt_dict[v_f], torch_labels[i]), axis=0)

            if y_true is None:
                y_true = torch_labels
                y_pred = batch_pred
            else:
                y_true = tf.concat((y_true, torch_labels), axis=0)
                y_pred = tf.concat((y_pred, batch_pred), axis=0)

        t2 = time.perf_counter()

        vid_acc = []
        for k in pred_dict:
            # gt_dict[k] is still a NumPy array when the video has one clip and
            # a TF tensor otherwise; np.asarray handles both (the original
            # called .numpy(), which fails on the NumPy case).
            gt_tmp = torch.tensor(np.asarray(gt_dict[k]))
            pred_tmp = torch.tensor(pred_dict[k].numpy())

            # Video-level label / prediction is the majority vote over its clips.
            gt_mode = torch.mode(gt_tmp)[0].item()
            pred_mode = torch.mode(pred_tmp)[0].item()
            overall_true.append(gt_mode)
            overall_pred.append(pred_mode)
            if pred_mode == gt_mode:
                vid_acc.append(1)
            else:
                vid_acc.append(0)
                if pred_mode == 0:
                    fn_ids.append(k)
                else:
                    fp_ids.append(k)

        vid_acc = np.array(vid_acc)

        print('----------------------------------------------------------------------------')
        for k in pred_dict:
            print(k)
            print('Predictions:')
            print(pred_dict[k])
            print('Ground Truth:')
            print(gt_dict[k])
            print('Overall Prediction:')
            print(torch.mode(torch.tensor(pred_dict[k].numpy()))[0].item())
            print('----------------------------------------------------------------------------')

        # Report the accumulated loss rather than the last batch's loss.
        print('loss={:.2f}, time={:.2f}'.format(epoch_loss, t2-t1))

        y_true = tf.cast(y_true, tf.int32)
        y_pred = tf.cast(y_pred, tf.int32)

        f1 = f1_score(y_true, y_pred, average='macro')
        accuracy = accuracy_score(y_true, y_pred)
        video_accuracy = np.sum(vid_acc) / len(vid_acc)
        all_errors.append(video_accuracy)

        print("Test f1={:.2f}, clip_acc={:.2f}, vid_acc={:.2f}".format(f1, accuracy, video_accuracy))

        print(confusion_matrix(y_true, y_pred))

    # Persist the ids of misclassified videos for later inspection.
    fp_fn_file = os.path.join(output_dir, 'fp_fn.txt')
    with open(fp_fn_file, 'w+') as in_f:
        in_f.write('FP:\n')
        in_f.write(str(fp_ids) + '\n\n')
        in_f.write('FN:\n')
        in_f.write(str(fn_ids) + '\n\n')