import numpy as np
import torchvision
import torch
from sklearn.model_selection import train_test_split
from sparse_coding_torch.video_loader import MinMaxScaler
from sparse_coding_torch.video_loader import YoloClipLoader, get_video_participants, PNBLoader
from sparse_coding_torch.video_loader import VideoGrayScaler
import csv
from sklearn.model_selection import train_test_split, GroupShuffleSplit, LeaveOneGroupOut, LeaveOneOut, StratifiedGroupKFold, StratifiedKFold, KFold

def load_yolo_clips(batch_size, mode, num_clips=1, num_positives=100, device=None, n_splits=None, sparse_model=None, whole_video=False, positive_videos=None):
    """Build the YOLO dataset and return CV splits or a training loader.

    Args:
        batch_size: Batch size for the ``DataLoader``; only used when
            ``mode == 'all_train'``.
        mode: One of ``'leave_one_out'``, ``'all_train'`` or ``'k_fold'``.
        num_clips: Forwarded to the dataset constructor.
        num_positives: Forwarded to the dataset constructor.
        device: Forwarded to the dataset constructor.
        n_splits: Forwarded to ``StratifiedGroupKFold``; only used when
            ``mode == 'k_fold'``.
        sparse_model: Forwarded to the dataset constructor.
        whole_video: When true the dataset is a ``YoloVideoLoader``,
            otherwise a ``YoloClipLoader``.
        positive_videos: Forwarded to ``YoloClipLoader`` only; it is not
            passed on when ``whole_video`` is true.

    Returns:
        * ``'leave_one_out'`` / ``'k_fold'``: ``(splits, dataset)`` where
          ``splits`` is the generator returned by the splitter's ``split``
          method, grouped by the value ``get_video_participants()`` maps
          each file name to.
        * ``'all_train'``: ``(train_loader, None)``.
        * any other mode: ``None``.
    """
    video_path = "/shared_data/YOLO_Updated_PL_Model_Results/"

    video_to_participant = get_video_participants()

    transforms = torchvision.transforms.Compose([
        VideoGrayScaler(),
        torchvision.transforms.Normalize((0.2592,), (0.1251,)),
    ])
    augment_transforms = torchvision.transforms.Compose([
        torchvision.transforms.RandomRotation(45),
        torchvision.transforms.RandomHorizontalFlip(),
    ])

    if whole_video:
        # YoloVideoLoader is not imported at module level, so referencing it
        # directly raised NameError. Import it here so the failure mode (if
        # any) is an explicit ImportError.
        # NOTE(review): assumed to live next to YoloClipLoader in
        # sparse_coding_torch.video_loader -- confirm.
        from sparse_coding_torch.video_loader import YoloVideoLoader

        dataset = YoloVideoLoader(video_path, num_clips=num_clips, num_positives=num_positives, transform=transforms, augment_transform=augment_transforms, sparse_model=sparse_model, device=device)
    else:
        dataset = YoloClipLoader(video_path, num_clips=num_clips, num_positives=num_positives, positive_videos=positive_videos, transform=transforms, augment_transform=augment_transforms, sparse_model=sparse_model, device=device)

    targets = dataset.get_labels()

    if mode == 'all_train':
        # Every sample is used for training; there is no held-out loader.
        train_idx = np.arange(len(targets))
        train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
        train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                   sampler=train_sampler)
        return train_loader, None

    if mode == 'leave_one_out':
        splitter = LeaveOneGroupOut()
    elif mode == 'k_fold':
        splitter = StratifiedGroupKFold(n_splits=n_splits)
    else:
        return None

    # Both splitters group by the looked-up value so that samples sharing it
    # never straddle the train/test boundary. Lookup keys are the lower-cased
    # file names with any '_clean' marker removed.
    groups = [video_to_participant[name.lower().replace('_clean', '')]
              for name in dataset.get_filenames()]

    return splitter.split(np.arange(len(targets)), targets, groups), dataset

    
def load_pnb_videos(batch_size, mode, classify_mode=False, device=None, n_splits=None, sparse_model=None):
    """Build the PNB video dataset and return CV splits or a training loader.

    Args:
        batch_size: Batch size for the ``DataLoader``; only used when
            ``mode == 'all_train'``.
        mode: One of ``'leave_one_out'``, ``'all_train'`` or ``'k_fold'``.
        classify_mode: Forwarded positionally to ``PNBLoader``.
        device: Not used by this function.
        n_splits: Forwarded to ``StratifiedKFold``; only used when
            ``mode == 'k_fold'``.
        sparse_model: Not used by this function.

    Returns:
        * ``'leave_one_out'``: ``(splits, dataset)`` where every distinct
          file name is its own group for ``LeaveOneGroupOut``.
        * ``'k_fold'``: ``(splits, dataset)`` from a shuffled
          ``StratifiedKFold``; the folds are stratified on the labels but
          are *not* grouped by file.
        * ``'all_train'``: ``(train_loader, None)``.
        * any other mode: ``None``.
    """
    video_path = "/shared_data/bamc_pnb_data/full_training_data"

    transforms = torchvision.transforms.Compose([
        VideoGrayScaler(),
        MinMaxScaler(0, 255),
        torchvision.transforms.Resize((360, 304)),
    ])
    augment_transforms = torchvision.transforms.Compose([
        torchvision.transforms.RandomAffine(45),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ColorJitter(brightness=0.5),
        # A sharpness factor of 0 yields a blurred image (applied with p=0.15).
        torchvision.transforms.RandomAdjustSharpness(0, p=0.15),
        # Second affine pass: no rotation, only a small horizontal shift.
        torchvision.transforms.RandomAffine(degrees=0, translate=(0.05, 0)),
    ])
    dataset = PNBLoader(video_path, classify_mode, num_frames=5, frame_rate=20, transform=transforms, augmentation=augment_transforms)

    targets = dataset.get_labels()
    indices = np.arange(len(targets))

    if mode == 'leave_one_out':
        splitter = LeaveOneGroupOut()
        groups = list(dataset.get_filenames())
        return splitter.split(indices, targets, groups), dataset

    if mode == 'all_train':
        # Every sample is used for training; there is no held-out loader.
        train_sampler = torch.utils.data.SubsetRandomSampler(indices)
        train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                   sampler=train_sampler)
        return train_loader, None

    if mode == 'k_fold':
        # StratifiedKFold takes no groups, so none are computed here.
        splitter = StratifiedKFold(n_splits=n_splits, shuffle=True)
        return splitter.split(indices, targets), dataset

    return None