from sklearn.decomposition import PCA
from math import ceil
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

def find_mean_img(full_mat, size=(170, 120)):
    """Plot the pixel-wise mean image of every category on a single figure.

    Args:
        full_mat (Mapping): Maps a category label to an array of images
            stacked along axis 0. Each image must contain exactly
            ``size[0] * size[1]`` values so its mean can be reshaped.
        size (tuple, optional): Shape used to reshape each mean vector back
            into a 2-D image. Defaults to (170, 120).

    Returns:
        None: The subplots are drawn on a new matplotlib figure. The figure
        is not shown here; the caller decides when to show or save it.
    """
    cols = 4
    # Ceiling division: exactly as many rows as needed, so a category count
    # that is a multiple of `cols` does not reserve an empty extra row.
    rows = max(1, ceil(len(full_mat) / cols))

    fig = plt.figure(figsize=(12, 6))

    for i, label in enumerate(full_mat):
        # Average over the image axis, then restore the 2-D pixel layout.
        mean_img = np.mean(full_mat[label], axis=0).reshape(size)
        ax = fig.add_subplot(rows, cols, i + 1)
        # Fixed 0-255 range keeps brightness comparable across categories.
        ax.imshow(mean_img, vmin=0, vmax=255, cmap='Greys_r')
        ax.set_title(f'Average {label}')
        # Act on this axes explicitly rather than on pyplot's "current" axes.
        ax.axis('off')

    plt.tight_layout()
    
def plot_pca(pca, title, size = (170, 120)):
    """Plot every principal component as an image, titled with its variance share.

    Args:
        pca (sklearn PCA object): A fitted PCA object; its ``n_components_``,
            ``components_`` and ``explained_variance_ratio_`` attributes are read.
        title (String): Name appended to the figure title and the printed summary.
        size (tuple, optional): Shape each component vector is reshaped to
            before display. Defaults to (170, 120).

    Returns:
        None: The figure is displayed with ``plt.show()``.
    """
    n = pca.n_components_
    print('Number of PC in ' + title + ':', n)

    fig = plt.figure(figsize=(8, 8))
    fig.suptitle('PCA Components of ' + title)

    # Near-square grid: floor(sqrt(n)) rows and enough columns to hold all n.
    # Both are clamped to 1 so n == 0 yields an empty figure rather than a
    # ZeroDivisionError.
    rows = max(1, int(n ** 0.5))
    cols = max(1, ceil(n / rows))

    for i in range(n):
        ax = fig.add_subplot(rows, cols, i + 1)
        ax.imshow(pca.components_[i].reshape(size), cmap='Greys_r')
        ax.set_title("Variance " + "{0:.2f}%".format(pca.explained_variance_ratio_[i] * 100))
        # Act on this axes explicitly rather than on pyplot's "current" axes.
        ax.axis('off')

    # Lay out once after all subplots exist instead of once per component.
    plt.tight_layout()
    plt.show()


def eigenimages(full_mat,n_comp = 0.7, size = (170, 120)):
    """Create and fit a whitened sklearn PCA estimator on an image matrix.

    Args:
        full_mat (np.ndarray): A vectorized array of images, passed unchanged
            to ``PCA.fit`` (presumably one flattened image per row; confirm
            against the caller).
        n_comp (float, optional): Forwarded as ``n_components`` to ``PCA``;
            with a value between 0 and 1 sklearn keeps enough components to
            explain that fraction of the variance. Defaults to 0.7.
        size (tuple, optional): Unused by this function; kept so existing
            callers that pass it keep working. Defaults to (170, 120).

    Returns:
        sklearn PCA object: Fitted PCA model.
    """
    # The former debug print of the full matrix was removed: it flooded
    # stdout with the whole image array on every call.
    pca = PCA(n_components=n_comp, whiten=True)
    pca.fit(full_mat)

    return pca