import itertools
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn.mixture


def cluster(expr_data, n_components_range, plot=True):
    """Pick a Gaussian mixture by lowest BIC and write each cluster to CSV.

    Every column of ``expr_data`` is standardized to zero mean and unit
    (population) standard deviation.  The standardized matrix is then
    transposed, so the *columns* of ``expr_data`` are the observations that
    get clustered and its rows are their features.  One ``GaussianMixture``
    is fitted for every combination of covariance type ("spherical", "tied",
    "diag", "full") and component count; the fit with the lowest BIC assigns
    the cluster labels.

    Each cluster is written to ``cluster/cluster<k>.csv`` (``k`` is the
    predicted label plus one) as the matching rows of ``expr_data.T``.  The
    ``cluster`` directory is created if it does not exist.  Progress and the
    selected model are printed to stdout.

    Args:
        expr_data: pandas DataFrame of numeric values.
        n_components_range: Iterable of component counts to try.  It is
            materialized into a list, so one-shot iterators are accepted.
        plot: When true, show a bar chart of the BIC of every fitted model.

    Returns:
        None.

    Raises:
        ValueError: If ``n_components_range`` is empty.
    """
    component_counts = list(n_components_range)
    if not component_counts:
        raise ValueError("n_components_range must contain at least one value")

    X = expr_data.values
    # Column-wise z-score.  A constant column has zero standard deviation and
    # turns into NaN here (0 / 0).
    X = (X - np.mean(X, axis=0)) / np.std(X, axis=0)

    # Transpose so that each column of expr_data becomes one observation.
    expr_data_norm = X.T

    # np.inf rather than the np.infty alias, which NumPy 2.0 removed.
    lowest_bic = np.inf
    bic = []
    cv_types = ["spherical", "tied", "diag", "full"]
    best_cv_type = None
    best_n_components = None
    best_gmm = None

    for cv_type in cv_types:
        print(cv_type)
        print("=" * 10)
        for n_components in component_counts:
            print(n_components)
            # Fit a Gaussian mixture with EM
            gmm = sklearn.mixture.GaussianMixture(
                n_components=n_components, covariance_type=cv_type
            )
            gmm.fit(expr_data_norm)
            bic.append(gmm.bic(expr_data_norm))
            if bic[-1] < lowest_bic:
                lowest_bic = bic[-1]
                best_cv_type = cv_type
                best_n_components = n_components
                best_gmm = gmm
        print()

    if plot:
        _plot_bic_scores(bic, cv_types, component_counts)

    print("Best CV Type: {}".format(best_cv_type))
    print("Best Number of Components: {}".format(best_n_components))

    # best_gmm is already fitted.  It is deliberately not refitted: a second
    # fit would start from a new random initialisation and could label the
    # data with a different solution than the one whose BIC was selected.
    y = best_gmm.predict(expr_data_norm)

    # to_csv does not create missing directories.
    os.makedirs("cluster", exist_ok=True)
    observations = expr_data.T
    for label in np.unique(y):
        members = observations.iloc[np.flatnonzero(y == label), :]
        members.to_csv("cluster/cluster{}.csv".format(label + 1))


def _plot_bic_scores(bic, cv_types, component_counts):
    """Show grouped bars of the BIC per model and mark the lowest with '*'.

    ``bic`` holds one score per (covariance type, component count) pair,
    ordered with the covariance type as the outer loop.
    """
    bic = np.array(bic)
    n_counts = len(component_counts)
    color_iter = itertools.cycle(
        ["navy", "turquoise", "cornflowerblue", "darkorange"]
    )
    bars = []

    # Plot the BIC scores
    plt.figure(figsize=(8, 6))
    spl = plt.subplot(2, 1, 1)
    for i, (_, color) in enumerate(zip(cv_types, color_iter)):
        # Shift each covariance type sideways so its bars sit next to the
        # other types' bars for the same component count.
        xpos = np.array(component_counts) + 0.2 * (i - 2)
        bars.append(
            plt.bar(
                xpos,
                bic[i * n_counts : (i + 1) * n_counts],
                width=0.2,
                color=color,
            )
        )

    plt.xticks(component_counts)
    plt.ylim([bic.min() * 1.01 - 0.01 * bic.max(), bic.max()])
    plt.title("BIC score per model")

    # Put the marker over the winning bar.  The position is derived from the
    # actual component count, so it is also right for ranges that do not
    # start at 1 (for 1, 2, 3, ... this equals the former "index + 0.65").
    cv_index, count_index = divmod(int(bic.argmin()), n_counts)
    xpos = component_counts[count_index] - 0.35 + 0.2 * cv_index
    plt.text(xpos, bic.min() * 0.97 + 0.03 * bic.max(), "*", fontsize=14)
    spl.set_xlabel("Number of components")
    spl.legend([b[0] for b in bars], cv_types)

    plt.show()


if __name__ == "__main__":
    # The path of the input CSV is the only command-line argument.  Exit with
    # a usage message instead of an IndexError traceback when it is missing.
    if len(sys.argv) < 2:
        sys.exit("usage: {} EXPRESSION_CSV".format(sys.argv[0]))

    # The first CSV column is used as the row index.
    data = pd.read_csv(sys.argv[1], index_col=0)
    # cluster() groups the columns of the frame it receives, so handing it
    # the transpose clusters the rows of the CSV.
    data_only_expr = data.T

    # Try every component count from 1 to num_clusters inclusive.
    num_clusters = 50
    clusters = np.arange(1, num_clusters + 1)
    cluster(data_only_expr, clusters, plot=True)
