from sys import argv

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import powerlaw

plt.style.use("ggplot")

# Usage: <script> MODEL.graphml DIFF_GENES.csv
if len(argv) < 3:
    raise SystemExit(f"usage: {argv[0]} MODEL.graphml DIFF_GENES.csv")

modelfile = argv[1]
diff_genes_file = argv[2]

G = nx.read_graphml(modelfile)
# nx.info() was deprecated in networkx 2.6 and removed in 3.0; str(G) now
# yields the equivalent "<Graph type> with N nodes and M edges" summary.
print(G)

# lncRNA nodes come from the differential-expression table (rows whose
# "group" is "long_non_coding", identifier taken from the "mrna" column);
# miRNA nodes are recognised by their "hsa-" prefix; everything else is
# treated as protein coding.
diff_genes = pd.read_csv(diff_genes_file)
lncrna_nodes = list(diff_genes[diff_genes["group"] == "long_non_coding"]["mrna"].values)
mirna_nodes = [node for node in G.nodes() if node.startswith("hsa-")]
# Build the exclusion set once: concatenating and scanning a list per node
# was quadratic in the number of graph nodes.
classified_nodes = set(lncrna_nodes) | set(mirna_nodes)
other_nodes = [node for node in G.nodes() if node not in classified_nodes]

# The three groups must partition the graph exactly. This fails if some
# lncRNA from the table is absent from the graph or also matches "hsa-".
# Raised explicitly (not assert) so it still runs under `python -O`.
if len(lncrna_nodes) + len(mirna_nodes) + len(other_nodes) != len(G.nodes()):
    raise ValueError(
        f"node partition mismatch: {len(lncrna_nodes)} lncRNA + "
        f"{len(mirna_nodes)} miRNA + {len(other_nodes)} other "
        f"!= {len(G.nodes())} graph nodes"
    )

dd = G.degree()

# Per-group degree sequences, sorted in descending order.
dd_lncrna = sorted([dd[node] for node in lncrna_nodes], reverse=True)
dd_mirna = sorted([dd[node] for node in mirna_nodes], reverse=True)
dd_other = sorted([dd[node] for node in other_nodes], reverse=True)


def plot_degree(degrees):
    """Show a log-log plot of *degrees* against their position in the sequence.

    Values are plotted in the order given (x = index, y = degree) as a blue
    line with circle markers, and the figure is shown immediately. Pass a
    descending-sorted sequence to obtain a rank/degree plot.
    """
    axes = plt.gca()
    axes.loglog(degrees, "b-", marker="o")
    axes.set_ylabel("Degree")
    axes.set_xlabel("Rank")
    plt.show()


def powerlaw_fit(degrees, fig, ax, color, label=""):
    """Fit a power law to *degrees*, print the results and draw them on *ax*.

    Printed, in order: *label*; the fitted exponent alpha, its sigma and the
    chosen xmin; then the (R, p) tuples from the normalized log-likelihood
    ratio tests of power law vs lognormal and power law vs exponential.

    Drawn on *ax* (all in *color*): the empirical PDF (solid line), the
    fitted power-law PDF (dotted) and the fitted truncated power-law PDF
    (dashed), with axis labels and a major+minor grid.

    Parameters
    ----------
    degrees : sequence of numbers
        Degree values to fit.
    fig : matplotlib Figure
        Not modified here; returned so callers can chain.
    ax : matplotlib Axes
        Axes the PDFs are drawn on.
    color : matplotlib color
        Color used for all three curves.
    label : str, optional
        Heading printed before the fit results.

    Returns
    -------
    The *fig* argument, unchanged.
    """
    print(label)
    # xmin=None asks powerlaw to select xmin itself.
    fit = powerlaw.Fit(degrees, xmin=None)
    print(fit.power_law.alpha)
    print(fit.power_law.sigma)
    print(fit.power_law.xmin)
    fit.plot_pdf(ax=ax, color=color)
    fit.power_law.plot_pdf(ax=ax, linestyle="dotted", color=color)
    fit.truncated_power_law.plot_pdf(ax=ax, linestyle="dashed", color=color)

    ax.set_xlabel("Degree k")
    ax.set_ylabel("P(x = k)")
    ax.grid(True, which="both")

    # Each comparison is computed once; previously the lognormal test ran
    # twice, once into unused locals and once for printing.
    print(fit.distribution_compare("power_law", "lognormal", normalized_ratio=True))
    print(fit.distribution_compare("power_law", "exponential", normalized_ratio=True))

    return fig


def distribution_compare(dd, label):
    """Print pairwise likelihood-ratio comparisons of candidate distributions.

    Fits *dd* with ``powerlaw.Fit`` and, for every unordered pair (a, b) of
    names in ``fit.supported_distributions``, runs
    ``fit.distribution_compare(a, b, normalized_ratio=True)``. Two tables
    indexed by distribution name are then printed with 4 decimals:

    * R table: entry [a, b] is the normalized log-likelihood ratio R of a
      vs b. It is antisymmetric ([b, a] == -R) and zero on the diagonal.
    * p table: entry [a, b] == [b, a] is that comparison's p-value. The
      diagonal is NaN because a distribution is never compared with itself.

    Parameters
    ----------
    dd : sequence of numbers
        Degree values to fit.
    label : str
        Heading printed above the tables.

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        The R table and the p-value table. Existing callers may ignore this.
    """
    from itertools import combinations

    print(label)
    print("=" * 40)
    fit = powerlaw.Fit(dd)
    names = list(fit.supported_distributions.keys())
    n = len(names)
    ratios = np.zeros((n, n))
    pvalues = np.full((n, n), np.nan)
    for (i, pa), (j, pb) in combinations(enumerate(names), 2):
        R, p = fit.distribution_compare(pa, pb, normalized_ratio=True)
        ratios[i, j] = R
        ratios[j, i] = -R
        # The p-value does not depend on comparison direction, so both
        # triangles are filled; previously only [pb, pa] was set, leaving
        # zeros that read as "p = 0".
        pvalues[i, j] = pvalues[j, i] = p

    df = pd.DataFrame(ratios, columns=names, index=names)
    dfpval = pd.DataFrame(pvalues, columns=names, index=names)
    # Scope the display format to these prints rather than changing the
    # global pandas option for the rest of the process.
    with pd.option_context("display.float_format", "{:,.4f}".format):
        print(df)
        print(dfpval)
    return df, dfpval


# Run the pairwise distribution comparison for each node group, in order:
# lncRNA, miRNA, then the remaining (protein-coding) nodes.
for group_degrees, group_label in (
    (dd_lncrna, "lncrna"),
    (dd_mirna, "mirna"),
    (dd_other, "protein coding"),
):
    distribution_compare(group_degrees, group_label)
