import json
import logging
import os
from collections import Counter
from sys import argv

import coloredlogs
import funcy
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import ot
import pandas as pd
import scipy as sp
from imblearn.under_sampling import ClusterCentroids
from pgmpy.models.MarkovModel import MarkovModel
from scipy import interp
from sklearn import svm
from sklearn.decomposition import PCA
from sklearn.metrics import auc, roc_curve
from sklearn.model_selection import KFold

# Seed passed as ``random_state`` to the K-fold splitter and the SVC in this file.
state = 1234

logger = logging.getLogger(__name__)

# The log file lives in ``prob_dist`` (functions in this file also save their
# figures there).  Create the directory first so that importing this module
# does not fail with FileNotFoundError on a fresh working directory.
os.makedirs("prob_dist", exist_ok=True)

# The stream is opened in append mode and deliberately kept open: it is the
# sink of the logger for the remaining life of the process.
coloredlogs.install(
    logger=logger,
    stream=open(os.path.join("prob_dist", "{}.log".format("gromov-wasserstein")), "a+"),
)


def get_data(fname, mapping=None, index_col=0, exclude_col=("stage",)):
    """Load a CSV table and split it into a feature frame and a label array.

    After the columns named in ``exclude_col`` are dropped, the last remaining
    column is taken as the label and every column before it as a feature.

    :param fname: path (or any source accepted by ``pandas.read_csv``)
    :param mapping: optional dict translating each raw label to the value to
        return; a label missing from the dict raises ``KeyError``
    :param index_col: column used as the row index, forwarded to ``read_csv``
    :param exclude_col: collection of column names to discard before the split
    :returns: ``(X, y)`` with ``X`` a DataFrame and ``y`` a numpy array
    """
    df = pd.read_csv(fname, index_col=index_col)
    kept_columns = [col for col in df.columns if col not in exclude_col]
    df = df[kept_columns]

    X = df.iloc[:, :-1]
    y = df.iloc[:, -1].values
    if mapping is not None:
        y = np.array([mapping[raw_label] for raw_label in y])

    return X, y


class GraphicalModel:
    """Markov network whose edges are copied from a GraphML file."""

    def __init__(self, graphml):
        """Read *graphml* and mirror its edges into a ``MarkovModel``."""
        loaded_graph = nx.read_graphml(graphml)
        markov_model = MarkovModel()
        markov_model.add_edges_from(loaded_graph.edges())

        self.graph = loaded_graph
        self.model = markov_model

    @staticmethod
    def _assertion_to_record(assertion):
        """Turn one independence assertion into a plain dict."""
        # a _|_ c | b (graph: a-b, b-c)
        subject, independent_of, conditioned_on = assertion.get_assertion()
        return {
            "variable": list(subject)[0],
            "independent": list(independent_of),
            "given": list(conditioned_on),
        }

    def get_independencies(self):
        """Return the model's local independencies as a list of dicts.

        Each dict has the keys ``variable``, ``independent`` and ``given``.
        """
        assertions = self.model.get_local_independencies().independencies
        return [self._assertion_to_record(assertion) for assertion in assertions]

    def get_markov_blankets(self):
        """Return one ``{"variable", "markov_blanket"}`` dict per model node."""
        return [
            {"variable": node, "markov_blanket": list(self.model.markov_blanket(node))}
            for node in self.model.nodes()
        ]


class Classification:
    """Helpers for a two-class data set with a minority and a majority class.

    Bundles under-sampling, cross-validated ROC evaluation, a 2-D scatter plot
    and a Gromov-Wasserstein comparison of the two classes.
    """

    def __init__(self, X, y, minority, majority):
        """Store the data and the class statistics.

        :param X: feature matrix
        :param y: labels, one per row of ``X``
        :param minority: ``(label, count)`` pair of the smaller class
        :param majority: ``(label, count)`` pair of the larger class
        """
        self.X = X
        self.y = y
        self.minority = minority
        self.majority = majority

    def under_sample(self):
        """Under-sample ``self.X`` / ``self.y`` with ``ClusterCentroids``.

        The sampler is asked for ``num_minority`` samples of the majority
        label.  Shapes before and after are printed.

        :returns: ``(X_cc, y_cc)`` as produced by the sampler
        """
        _, num_minority = self.minority
        label_majority, _ = self.majority

        # NOTE(review): ``ratio=`` / ``fit_sample`` are kept as in the original;
        # newer imbalanced-learn releases presumably spell these
        # ``sampling_strategy=`` / ``fit_resample`` -- confirm against the
        # version this project pins before upgrading.
        cc = ClusterCentroids(ratio={label_majority: num_minority})

        X_cc, y_cc = cc.fit_sample(self.X, self.y)

        print("Original Data...")
        print("X: {} || y: {}".format(self.X.shape, self.y.shape))

        print("Under Sampling... (Cluster Centroids)")
        print("X_cc: {} || y_cc: {}".format(X_cc.shape, y_cc.shape))
        return X_cc, y_cc

    def classify(self, classifier, data, label, k=5):
        """Plot per-fold and mean ROC curves of *classifier* under k-fold CV.

        :param classifier: estimator exposing ``fit`` and ``predict_proba``
        :param data: samples; must support integer-array indexing
            (``data[train]``)
        :param label: labels; indexed the same way as ``data``
        :param k: number of folds
        """
        cv = KFold(n_splits=k, shuffle=True, random_state=state)

        tprs = []
        aucs = []
        # Common FPR grid so the per-fold curves can be averaged point-wise.
        mean_fpr = np.linspace(0, 1, 100)

        for fold, (train, test) in enumerate(cv.split(data, label), start=1):
            probas_ = classifier.fit(data[train], label[train]).predict_proba(
                data[test]
            )
            # Compute the ROC curve and the area under it for this fold.
            fpr, tpr, _ = roc_curve(label[test], probas_[:, 1])
            # np.interp replaces the ``scipy.interp`` alias, which recent
            # SciPy releases no longer provide.
            tprs.append(np.interp(mean_fpr, fpr, tpr))
            tprs[-1][0] = 0.0
            roc_auc = auc(fpr, tpr)
            aucs.append(roc_auc)
            plt.plot(
                fpr,
                tpr,
                lw=1,
                alpha=0.3,
                label="ROC fold %d (AUC = %0.2f)" % (fold, roc_auc),
            )

        plt.plot(
            [0, 1], [0, 1], linestyle="--", lw=2, color="r", label="Chance", alpha=0.8
        )

        mean_tpr = np.mean(tprs, axis=0)
        mean_tpr[-1] = 1.0
        mean_auc = auc(mean_fpr, mean_tpr)
        std_auc = np.std(aucs)
        plt.plot(
            mean_fpr,
            mean_tpr,
            color="b",
            label=r"Mean ROC (AUC = %0.2f $\pm$ %0.2f)" % (mean_auc, std_auc),
            lw=2,
            alpha=0.8,
        )

        # Shade one standard deviation around the mean curve, clipped to [0, 1].
        std_tpr = np.std(tprs, axis=0)
        tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
        tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
        plt.fill_between(
            mean_fpr,
            tprs_lower,
            tprs_upper,
            color="grey",
            alpha=0.2,
            label=r"$\pm$ 1 std. dev.",
        )

        plt.xlim([-0.05, 1.05])
        plt.ylim([-0.05, 1.05])
        plt.xlabel("False Positive Rate")
        plt.ylabel("True Positive Rate")
        plt.title("Receiver Operating Characteristic")
        plt.legend(loc="lower right")
        plt.show()

    def plot_2d_space(self, X, y, label="Classes"):
        """Scatter the first two columns of *X*, one colour/marker per class.

        Only as many classes as there are colour/marker pairs (two) are drawn.

        :param X: 2-D array indexed as ``X[mask, column]``
        :param y: class label per row of ``X``
        :param label: figure title
        """
        colors = ["#1F77B4", "#FF7F0E"]
        markers = ["o", "s"]
        for cls, color, marker in zip(np.unique(y), colors, markers):
            rows = y == cls
            plt.scatter(X[rows, 0], X[rows, 1], c=color, label=cls, marker=marker)
        plt.title(label)
        plt.legend(loc="upper right")
        plt.show()

    @staticmethod
    def _normalized_cost_matrix(samples):
        """Pairwise ``seuclidean`` distances of *samples*, divided by their maximum."""
        cost = sp.spatial.distance.cdist(samples, samples, metric="seuclidean")
        cost /= cost.max()
        return cost

    def calc_EMD(self, X, y, fname, plot=True):
        """Compare the minority and majority samples with Gromov-Wasserstein.

        A normalised intra-class distance matrix is built for each class, both
        classes get uniform weights, and the plain and the entropic
        Gromov-Wasserstein couplings are computed with the square loss.

        :param X: samples; must support boolean row selection ``X[mask]``
        :param y: class label per row of ``X``
        :param fname: path prefix for the two PNG files written when *plot*
        :param plot: save ``<fname>_distance_kernels.png`` and
            ``<fname>_wasserstein.png`` when true
        :returns: ``(gw_dist, entropic_gw_dist)`` taken from the solver logs
        """
        X_minority = X[y == self.minority[0]]
        X_majority = X[y == self.majority[0]]

        C_minority = self._normalized_cost_matrix(X_minority)
        C_majority = self._normalized_cost_matrix(X_majority)

        if plot:
            fig = plt.figure()
            plt.subplot(121)
            plt.imshow(C_minority)
            plt.subplot(122)
            plt.imshow(C_majority)
            plt.savefig("{}_distance_kernels.png".format(fname))
            # Close explicitly: this method is called in loops and every
            # unclosed figure stays alive in pyplot's registry.
            plt.close(fig)

        p = ot.unif(X_minority.shape[0])
        q = ot.unif(X_majority.shape[0])

        gw0, log0 = ot.gromov.gromov_wasserstein(
            C_minority, C_majority, p, q, "square_loss", verbose=False, log=True
        )

        gw, log = ot.gromov.entropic_gromov_wasserstein(
            C_minority,
            C_majority,
            p,
            q,
            "square_loss",
            epsilon=5e-4,
            log=True,
            verbose=False,
        )

        if plot:
            # A fresh figure instead of figure number 1, which may already
            # hold earlier drawings (on a first call: the distance kernels).
            fig = plt.figure(figsize=(10, 5))
            plt.subplot(1, 2, 1)
            plt.imshow(gw0, cmap="jet")
            plt.title("Gromov Wasserstein")
            plt.subplot(1, 2, 2)
            plt.imshow(gw, cmap="jet")
            plt.title("Entropic Gromov Wasserstein")
            plt.savefig("{}_wasserstein.png".format(fname))
            plt.close(fig)

        logger.info("Gromov-Wasserstein distances: {}".format(log0["gw_dist"]))
        logger.info("Entropic Gromov-Wasserstein distances: {}".format(log["gw_dist"]))

        return log0["gw_dist"], log["gw_dist"]


def prep_data(fname, mapping):
    """Load the data set and identify its minority and majority class.

    The labels must contain exactly two distinct values, otherwise the
    unpacking below raises ``ValueError``.  When both classes are equally
    frequent, the label encountered first is reported as the majority.

    :param fname: CSV path handed to ``get_data``
    :param mapping: label translation handed to ``get_data``
    :returns: ``(X, y, minority, majority)`` where the last two are
        ``(label, count)`` tuples
    """
    X, y = get_data(fname, mapping)

    # most_common() orders by descending count and keeps first-seen order on
    # ties, so the first entry is the majority class.
    majority, minority = Counter(y).most_common()

    return X, y, minority, majority


def do_classification(args, mapping=None):
    """Under-sample the data set, evaluate a linear SVM and plot a PCA view.

    The original body used ``X``, ``y``, ``minority`` and ``majority`` without
    defining them; they are now loaded from ``args[0]`` through ``prep_data``,
    the same way ``do_prob_similarity`` does.

    :param args: sequence whose first element is the path of the sample CSV
    :param mapping: optional label translation forwarded to ``prep_data``
    """
    X, y, minority, majority = prep_data(args[0], mapping)
    classification = Classification(X, y, minority, majority)
    X_cc, y_cc = classification.under_sample()

    classifier = svm.SVC(kernel="linear", probability=True, random_state=state)
    classification.classify(classifier, X_cc, y_cc)

    # Project onto two principal components purely for visual inspection.
    pca = PCA(n_components=2)
    X_pca = pca.fit_transform(X_cc)
    classification.plot_2d_space(X_pca, y_cc, "Balanced dataset (2 PCA components)")


def do_prob_similarity(args, mapping=None):
    """Score each miRNA group by the class separation of its bound genes.

    For every entry of the miRNA independence list, the genes marked ``1`` in
    the evidence table for the entry's variable or any of its ``given``
    miRNAs are collected, restricted to the genes present in the data set,
    and handed to ``Classification.calc_EMD``.

    :param args: positional inputs --
        ``args[0]`` sample CSV,
        ``args[2]`` JSON list of dicts with ``variable`` and ``given`` keys,
        ``args[3]`` evidence CSV (first column is the index, miRNAs are
        columns), ``args[4]`` path of the JSON report to write.
        ``args[1]`` is not used by this step.
    :param mapping: optional label translation forwarded to ``prep_data``
    """

    # Placeholders for alternative gene-selection strategies; none of them is
    # implemented or called yet.
    def calc_local_indep_bindings(X, y):
        raise NotImplementedError

    def calc_markov_blanket_bindings(X, y):
        raise NotImplementedError

    def calc_only_mirna_binding(X, y):
        raise NotImplementedError

    X, y, minority, majority = prep_data(args[0], mapping)
    classification = Classification(X, y, minority, majority)

    with open(args[2]) as mirna_json:
        mirna_indep = json.load(mirna_json)

    evidence = pd.read_csv(args[3], index_col=0)
    report = []
    for i, entry in enumerate(mirna_indep):
        print("{}/{}".format(i + 1, len(mirna_indep)))

        variable = entry["variable"]
        mirnas = [variable] + entry["given"]

        # Select rows with a boolean mask; the previous positional
        # ``np.where`` indices were applied to a label-indexed Series.
        bound_genes = set()
        for name in evidence[mirnas].columns:
            bound_genes.update(evidence.index[evidence[name] == 1])

        # Walk the data columns so the gene order is deterministic.
        mirnas_genes = [gene for gene in X.columns if gene in bound_genes]
        if not mirnas_genes:
            logger.warning(
                "No measured genes bound by %s; skipping this group", mirnas
            )
            continue

        # Name the figures after the entry's variable.  The inner loop used to
        # reuse the outer loop's name, so files were labelled with the last
        # "given" miRNA and could overwrite each other.
        wasserstein, ent_wasserstein = classification.calc_EMD(
            X[mirnas_genes], y, "prob_dist/{}".format(variable)
        )

        report.append(
            {
                "mirnas": mirnas,
                "genes": mirnas_genes,
                # Plain floats keep the report JSON-serialisable whatever
                # numeric type the solver returns.
                "wassertein": float(wasserstein),
                "ent_wassertein": float(ent_wasserstein),
            }
        )

    export_json(report, args[4])


def do_get_independencies(args):
    """Write the local independencies of the GraphML model to a JSON file.

    :param args: ``args[0]`` is the GraphML input, ``args[1]`` the JSON output
    """
    export_json(GraphicalModel(args[0]).get_independencies(), args[1])


def do_get_markov_blankets(args):
    """Write the Markov blanket of every model node to a JSON file.

    :param args: ``args[0]`` is the GraphML input, ``args[1]`` the JSON output
    """
    export_json(GraphicalModel(args[0]).get_markov_blankets(), args[1])


def add_links(ranks, cutoff, model, evidence):
    """
    Add miRNA-gene edges for the best-scoring miRNA groups to a network.

    Entries of the ranking whose ``ent_wassertein`` score reaches *cutoff* are
    sorted by descending score and the first 20% of them are kept.  For every
    miRNA named by a kept entry, an edge is added to each ``ENSG`` node of the
    network that the evidence table marks with ``1``.

    :param ranks: path of a JSON list of dicts with ``mirnas`` and
        ``ent_wassertein`` keys
    :param cutoff: minimum ``ent_wassertein`` score for an entry to be ranked
    :param model: path of the GraphML network
    :param evidence: path of the evidence CSV (first column is the index,
        miRNAs are columns)
    :returns: graph with new links
    :rtype: nx.Graph
    """
    with open(ranks) as ranks_file:
        entries = json.load(ranks_file)

    # sorted() is stable, so equal scores keep their order from the file.
    ranked = sorted(
        (entry for entry in entries if entry["ent_wassertein"] >= cutoff),
        key=lambda entry: entry["ent_wassertein"],
        reverse=True,
    )
    top20 = ranked[: int(len(ranked) * 0.2)]
    # A comprehension instead of funcy.join, which returns None for an empty
    # input.  Sorted into a list because pandas needs a list-like (not a set)
    # as column indexer, and to make the edge order reproducible.
    top_mirnas = sorted({name for entry in top20 for name in entry["mirnas"]})

    model = nx.read_graphml(model)
    genes = {node for node in model.nodes() if node[:4] == "ENSG"}
    print(
        "Network model has {} nodes and {} edges".format(
            len(model.nodes()), len(model.edges())
        )
    )
    evidence = pd.read_csv(evidence, index_col=0)
    evidence = evidence[top_mirnas]
    # Keep only the rows that correspond to gene nodes of the network.
    evidence = evidence.loc[evidence.index.isin(genes)]

    for mirna in evidence.columns:
        # Boolean mask instead of positional ``np.where`` indices applied to a
        # label-indexed Series.
        for gene in evidence.index[evidence[mirna] == 1]:
            model.add_edge(mirna, gene)

    print(
        "New network model has {} nodes and {} edges".format(
            len(model.nodes()), len(model.edges())
        )
    )

    return model


def export_json(lst, fname):
    """Dump *lst* to the file *fname* as JSON with a 4-space indent.

    Keys are written in their existing order (no sorting); an existing file
    is overwritten.
    """
    with open(fname, "w") as out_file:
        json.dump(lst, out_file, indent=4, sort_keys=False)


def main(args):
    """Run every analysis step in sequence on the same argument list.

    NOTE(review): the steps below do not agree on what each position of
    ``args`` means (see the inline notes), so running all of them in a single
    invocation looks unintended -- presumably steps were toggled by hand.
    Confirm the intended command line before relying on this entry point.

    :param args: command-line arguments without the script name
    """
    # Label encoding applied to the last column of the sample CSV.
    mapping = {"PrimaryTumor": 0, "SolidTissueNormal": 1}
    # Reads args[0] as the sample CSV and writes its JSON report to args[4].
    do_prob_similarity(args, mapping)

    # NOTE(review): both calls read args[0] as a GraphML file (a CSV above)
    # and both write to args[1], so the second output replaces the first.
    do_get_independencies(args)
    do_get_markov_blankets(args)

    X, y, minority, majority = prep_data(args[0], mapping)
    classification = Classification(X, y, minority, majority)
    # Distance over all features; its entropic value is the cutoff below.
    wasserstein, ent_wasserstein = classification.calc_EMD(
        X, y, "prob_dist/{}".format("all")
    )
    # NOTE(review): add_links reads its ranking from args[1] (just overwritten
    # above), the GraphML model from args[2] and the evidence CSV from
    # args[3]; the resulting graph then overwrites args[4].
    new_network = add_links(args[1], ent_wasserstein, args[2], args[3])
    nx.write_graphml(new_network, args[4])


if __name__ == "__main__":
    # Hand everything after the script name to the pipeline.
    main(argv[1:])
