import os
from sys import argv

import networkx as nx

# Usage: script SIGNIFICANT_MIRNA_LIST GRAPH_DIR
# Fail early with a usage message instead of an IndexError further down.
if len(argv) < 3:
    raise SystemExit(
        "usage: {} SIGNIFICANT_MIRNA_LIST GRAPH_DIR".format(argv[0] if argv else "script")
    )

SIGNIFICANT_MIRNA = argv[1]  # significant mirna list: two tab-separated columns per line
with open(SIGNIFICANT_MIRNA, "r", encoding="utf-8") as handle:
    add_edges = handle.readlines()

# Parse each line into a (first_column, second_column) pair. Which column
# holds the miRNA is decided later by add_to_network.
lst = []
for line_no, line in enumerate(add_edges, start=1):
    if not line.strip():
        continue  # tolerate blank lines (e.g. a trailing empty line)
    fields = line.split("\t")
    if len(fields) < 2:
        raise ValueError(
            "{}:{}: expected two tab-separated columns, got {!r}".format(
                SIGNIFICANT_MIRNA, line_no, line
            )
        )
    lst.append((fields[0], fields[1].strip()))


def _graph_summary(graph):
    """Return a one-line node/edge count summary of ``graph``.

    Replaces ``nx.info``, which was removed in NetworkX 3.0; this form works
    on both 2.x and 3.x.
    """
    return "{} with {} nodes and {} edges".format(
        type(graph).__name__, graph.number_of_nodes(), graph.number_of_edges()
    )


def add_to_network(G, lst, mirna_prefix="hsa-"):
    """Return a name-keyed, undirected copy of ``G`` extended with miRNA edges.

    Every node of ``G`` is re-created in a new ``nx.Graph`` using its
    ``"name"`` attribute as the node key. Nodes without one fall back to
    their node id. Every edge of ``G`` is copied between the corresponding
    names and gets a ``name`` attribute of the form ``"src-dest"``.

    For each pair in ``lst``, the member starting with ``mirna_prefix`` is
    treated as the miRNA; if neither does, the second member is. The miRNA
    node and a miRNA-target edge are added only when the target is a node
    name that already existed in ``G``.

    Node/edge count summaries are printed before and after.

    Args:
        G: input networkx graph.
        lst: sequence of pairs of strings; either position may hold the miRNA.
        mirna_prefix: prefix identifying the miRNA member of a pair
            (defaults to the previously hard-coded ``"hsa-"``).

    Returns:
        A new ``nx.Graph``; ``G`` itself is not modified.
    """
    print("Before:")
    print(_graph_summary(G))

    # Original node id -> display name.
    id_to_name = {
        node_id: data.get("name", node_id) for node_id, data in G.nodes(data=True)
    }
    # Snapshot of names present in G. Targets must be in this set, not merely
    # added earlier in this call. A set keeps the lookup O(1); a dict.values()
    # membership test would be a linear scan.
    known_names = set(id_to_name.values())

    H = nx.Graph()
    for name in id_to_name.values():
        H.add_node(name, name=name)

    for u, v in G.edges():
        src, dest = id_to_name[u], id_to_name[v]
        H.add_edge(src, dest, name="{}-{}".format(src, dest))

    for pair in lst:
        first, second = pair[0], pair[1]
        # Orient the pair as (miRNA, target).
        if first.startswith(mirna_prefix):
            src, dest = first, second
        else:
            src, dest = second, first

        if dest in known_names:
            H.add_node(src, name=src)
            H.add_edge(src, dest, name="{}-{}".format(src, dest))

    print("After:")
    print(_graph_summary(H))

    return H


root = argv[2]  # directory where the graph files reside

# Every stage is processed the same way: read its graph, add the miRNA
# edges from `lst`, and write the augmented graph next to the input.
# NOTE(review): the output name says "degree_le100", but this script applies
# no degree filter; presumably the inputs were filtered upstream. Confirm.
for stage in range(1, 5):
    stage_graph = nx.read_graphml(
        os.path.join(root, "stage{}_removed_intersection.graphml".format(stage))
    )
    stage_graph = add_to_network(stage_graph, lst)
    nx.write_graphml(
        stage_graph,
        os.path.join(
            root,
            "stage{}_removed_intersection_with_mirnas_degree_le100.graphml".format(stage),
        ),
    )
