import cando as cnd
import networkx as nx
from node2vec import Node2Vec
import os
import sys

# Command-line arguments: <num_walks> <walk_length> <adr_cutoff>
_usage = 'usage: {} <num_walks> <walk_length> <adr_cutoff>'.format(sys.argv[0])
try:
    # argv[1] and argv[2] are only converted with int() by the node2vec step at
    # the end of the script, i.e. after the whole graph has been built; check
    # them here so that a missing or non-integer argument fails immediately
    int(sys.argv[1])
    int(sys.argv[2])
    # kept as a string: it is only interpolated into file/directory names
    adrcutoff = sys.argv[3]
except (IndexError, ValueError):
    sys.exit(_usage)

# static mapping files, passed to cnd.CANDO() as its first two arguments
# (presumably the compound mapping and the indication mapping - confirm)
cm = '/projects/academic/rams/zmfalls/cando/v2.2/mappings/drugbank-v2.5.tsv'
im = '/projects/academic/rams/wmangion/mappings/v2.2.1/drugbank2ctd-v2.2.tsv'

# changing files, but necessary (indication -> proteins table)
dgn = 'genes/dgn0.1-summed-uniprot.tsv'

# optional data sources; each one is switched on/off by the add_* flag below
pw = 'pathways/human_pathways_UniProt2Reactome'
adr = 'adrs/adrs-snomed+sider+pred1295-top{}-v2.3-10perADR.tsv'.format(adrcutoff)
ppi = 'ppi/ppis-human_uniprot-700-filtered.tsv'

# 1 = add the corresponding nodes/edges to the graph, 0 = leave them out
add_pw = 1
add_adr = 1
add_ppi = 1

# Protein identifiers, one per line; every entry becomes a PROTEIN node later.
prot_list = 'data/uniprot_all_STRING-Reactome.txt'
# the context manager closes the handle (it was previously left open)
with open(prot_list, 'r') as fpl:
    all_proteins = [l.strip() for l in fpl]

# Lookup tables between the first column of each row (stored as `uni`) and the
# remaining tab-separated columns (stored as `pdbs`).
pdb2u = {}
u2pdb = {}
# the context manager closes the handle (it was previously left open)
with open('data/uniprot_to_pdb_all.txt', 'r') as fp2u:
    for l in fp2u:
        ls = l.strip().split('\t')
        uni = ls[0]
        pdbs = ls[1:]
        for pdb in pdbs:
            # an id listed on several rows keeps the value of the last row read
            pdb2u[pdb] = uni
        # rows without any further column map to an empty list
        u2pdb[uni] = pdbs

# Protein-protein interactions: the first two tab-separated columns of each
# line are the interacting proteins.
p2plist = {}
# the context manager closes the handle (it was previously left open)
with open(ppi, 'r') as f:
    for l in f:
        ls = l.strip().split('\t')
        # sort the pair so (A, B) and (B, A) collapse into a single key
        ppo = sorted([ls[0], ls[1]])
        p2plist[(ppo[0], ppo[1])] = 1
ppi_pairs = list(p2plist.keys())

# Adjacency lists keyed by the first (lexicographically smaller) member of each
# pair only; the reverse direction is not stored.
p2ps = {}
for p1, p2 in ppi_pairs:
    p2ps.setdefault(p1, []).append(p2)

# Compound ids to ignore: column 1 of every row after the first line, up to
# (but not including) the first row whose sixth column is 'allicin'.
ignore = []
with open('number-unique-i_scores-compounds-C.txt', 'r') as fig:
    score_rows = fig.readlines()
for l in score_rows[1:]:  # first line skipped (presumably a header)
    ls = l.strip().split('\t')
    if ls[5] == 'allicin':
        break
    ignore.append(ls[0])

# Compound -> list of targets, taken from columns 1 and 4 of the targets table
# (the first line of the file is skipped).
c2t = {}
with open('data/cmpds2targets-db-v2.3.tsv', 'r') as fc2t:
    target_rows = fc2t.readlines()[1:]
for l in target_rows:
    ls = l.strip().split('\t')
    c, uni = ls[0], ls[3]
    # repeated (compound, target) rows are appended again, not de-duplicated
    c2t.setdefault(c, []).append(uni)

# Merge a second protein -> compounds table into c2t.  Column 1 is translated
# through pdb2u when possible (otherwise used as-is); column 2 is a
# ';'-separated list of compounds.
with open('cmpd2protein-v2.3.txt', 'r') as fcp:
    for l in fcp:
        ls = l.strip().split('\t')
        u = pdb2u.get(ls[0], ls[0])
        for c in ls[1].split(';'):
            known_targets = c2t.setdefault(c, [])
            # only add targets the compound does not already have
            if u not in known_targets:
                known_targets.append(u)

# parse pathway proteins: column 1 is the pathway, the remaining tab-separated
# columns are its proteins
pws = {}
# read from the configured `pw` path instead of repeating the literal, and let
# the context manager close the handle (it was previously left open)
with open(pw, 'r') as fpw:
    for l in fpw:
        ls = l.strip().split('\t')
        pwy = ls[0]
        prots = ls[1:]
        pws[pwy] = prots

# parse dgn: column 1 = indication id, column 2 = indication name,
# column 3 = ';'-separated list of proteins
ind2prots = {}
ind2name = {}
indname2prots = {}
# the context manager closes the handle (it was previously left open)
with open(dgn, 'r') as dgnf:
    for l in dgnf:
        ls = l.strip().split('\t')
        ind = ls[0]
        name = ls[1]
        prots = ls[2].split(';')
        ind2prots[ind] = prots
        ind2name[ind] = name
        # indications sharing a name overwrite each other here
        indname2prots[name] = prots

# CANDO object built from the two mapping files and the ADR mapping
# (no matrix=mat argument is passed)
cando = cnd.CANDO(cm, im, adr_map=adr)

# Switches for the node/edge categories of the graph (truthy = include) #
compounds = proteins = indications = 1
compound_protein = protein_indication = 1
compound_indication = 0
pathways = protein_pathway = add_pw
adverse_drug_reactions = compound_adr = add_adr
protein_protein = add_ppi

# START NETWORKX GRAPH #
G = nx.Graph()
# every node is an integer; these tables translate between the integer, the
# original identifier and the node type
node2int, int2node, int2type = {}, {}, {}
current_node_int = 0

# number of add_edge calls per edge category, reported in the stats file
edgecounts = dict.fromkeys(('P-P', 'C-I', 'C-P', 'A-C', 'I-P', 'P-PW', 'G-P'), 0)

# protein nodes: register every listed protein under the next free integer
pnodes = []
for pid in all_proteins:
    node2int[pid] = current_node_int
    int2node[current_node_int] = pid
    int2type[current_node_int] = 'PROTEIN'
    pnodes.append(current_node_int)
    current_node_int += 1

# protein-protein edges: one (int, int) tuple per stored interaction partner
p2p_edges = [
    (node2int[p], node2int[pp])
    for p in all_proteins
    for pp in p2ps.get(p, ())
]

# add protein nodes and protein-protein edges
if proteins:
    G.add_nodes_from(pnodes)
if protein_protein:
    G.add_edges_from(p2p_edges)
    edgecounts['P-P'] += len(p2p_edges)

# compound nodes and compound-protein edges
cnodes = []
c2t_edges = []
# never filled in this script; kept so the edge-adding step below is unchanged
c2p_edges = []
# set copy of the ignore list: O(1) membership tests inside the loop
ignored_ids = set(ignore)
for c in cando.compounds:
    cid = str(c.id_)

    # an ignored compound is still kept when it has at least one known target
    if cid in ignored_ids and cid not in c2t:
        continue

    cint = current_node_int
    node2int[cid] = cint
    int2node[cint] = cid
    int2type[cint] = 'COMPOUND'
    current_node_int += 1

    cnodes.append(cint)

    # cmpd to target from DrugBank and pre-generated, normalized BANDOCK scores
    # cint is new for every compound, so duplicate edges can only come from
    # this compound's own target list; a per-compound set replaces the scan of
    # the whole (growing) edge list
    seen_targets = set()
    for tgt in c2t.get(cid, ()):
        pint = node2int.get(tgt)
        # targets without a node are skipped, as are repeated targets
        if pint is None or pint in seen_targets:
            continue
        seen_targets.add(pint)
        c2t_edges.append((cint, pint))

# add compound nodes and compound-protein edges (from matrix and targets file)
if compounds:
    for cn in cnodes:
        G.add_node(cn)
if compound_protein:
    for c2p in c2p_edges:
        G.add_edge(c2p[0], c2p[1])
        edgecounts['C-P'] += 1
    # loop variable is not called `c2t`, so the compound -> targets dict is no
    # longer overwritten by the last edge tuple
    for c2t_edge in c2t_edges:
        G.add_edge(c2t_edge[0], c2t_edge[1])
        edgecounts['C-P'] += 1

# indication nodes and indication-protein edges
# (compound-indication edges are no longer built here; the original note on
# them read: "seems to cause errors/file size explosion if > 250")
inodes = []
p2i_edges = []
for ind, ind_proteins in ind2prots.items():
    iint = current_node_int
    node2int[ind] = iint
    int2node[iint] = ind
    int2type[iint] = 'INDICATION'
    current_node_int += 1

    inodes.append(iint)

    # link the indication to each of its proteins that has a node;
    # proteins without a node are skipped
    for p in ind_proteins:
        if p in node2int:
            p2i_edges.append((node2int[p], iint))

# add indication nodes and indication-protein edges
if indications:
    G.add_nodes_from(inodes)
if protein_indication:
    G.add_edges_from(p2i_edges)
    edgecounts['I-P'] += len(p2i_edges)

# adr nodes, adr-compound edges
anodes = []
c2a_edges = []
# the loop variable is not called `adr`, so the ADR mapping path stored in
# `adr` is no longer overwritten by a CANDO object
for adr_obj in cando.adrs:
    aid = adr_obj.id_

    # ADRs without any associated compound get no node
    if len(adr_obj.compounds) == 0:
        continue

    aint = current_node_int
    node2int[aid] = aint
    int2node[aint] = aid
    int2type[aint] = 'ADR'
    current_node_int += 1

    anodes.append(aint)

    for c in adr_obj.compounds:
        try:
            cint = node2int[str(c.id_)]
        except KeyError:
            # compound has no node (it was skipped when compound nodes were built)
            continue
        c2a_edges.append((cint, aint))

# add adr nodes and compound-adr edges
if adverse_drug_reactions:
    for anode in anodes:
        G.add_node(anode)
if compound_adr:
    for c2a in c2a_edges:
        G.add_edge(c2a[0], c2a[1])
        edgecounts['A-C'] += 1

# pathway nodes and pathway-protein edges
pwnodes = []
pw2p_edges = []
# the loop variable is not called `pw`, so the pathway file path stored in
# `pw` is no longer overwritten by the last pathway id
for pwid, pw_proteins in pws.items():
    pwint = current_node_int
    node2int[pwid] = pwint
    int2node[pwint] = pwid
    int2type[pwint] = 'PATHWAY'
    current_node_int += 1

    pwnodes.append(pwint)

    for p in pw_proteins:
        # NOTE(review): unlike the compound and indication sections, a protein
        # without a node raises KeyError here instead of being skipped -
        # confirm that every pathway protein is in the protein list
        pint = node2int[p]
        pw2p_edges.append((pwint, pint))

# add pathway nodes and pathway-protein edges
if pathways:
    for pwn in pwnodes:
        G.add_node(pwn)
if protein_pathway:
    for pw2p in pw2p_edges:
        G.add_edge(pw2p[0], pw2p[1])
        edgecounts['P-PW'] += 1


# Write a summary of the graph plus node and edge tables (set stats = 0 to skip).
stats = 1
if stats:
    # create the output directories when missing, as is already done for the
    # embeddings directory, instead of failing on open()
    os.makedirs('network-stats-new', exist_ok=True)
    os.makedirs('graphs-new', exist_ok=True)

    # per-type node totals, counted from every registered node in int2type
    dd = {'PROTEIN': 0, 'GO-TERM': 0, 'COMPOUND': 0, 'INDICATION': 0, 'ADR': 0,
          'PATHWAY': 0, }
    for i in int2type.keys():
        typ = int2type[i]
        dd[typ] += 1

    # `with` closes each file even if a write fails
    with open('network-stats-new/complete-adrs.txt', 'w') as fo:
        fo.write('# nodes: {}\n'.format(G.number_of_nodes()))
        fo.write('\tproteins\t{}\n'.format(dd['PROTEIN']))
        fo.write('\tcompounds\t{}\n'.format(dd['COMPOUND']))
        fo.write('\tindications\t{}\n'.format(dd['INDICATION']))
        fo.write('\tADRs\t{}\n'.format(dd['ADR']))
        fo.write('\tpathways\t{}\n'.format(dd['PATHWAY']))
        fo.write('\tGO-terms\t{}\n\n'.format(dd['GO-TERM']))

        # the per-category figures count add_edge calls, so they can add up to
        # more than G.number_of_edges() when the same edge was added twice
        fo.write('# edges: {}\n'.format(G.number_of_edges()))
        fo.write('\tprotein-protein\t{}\n'.format(edgecounts['P-P']))
        fo.write('\tcompound-protein\t{}\n'.format(edgecounts['C-P']))
        fo.write('\tcompound-indication\t{}\n'.format(edgecounts['C-I']))
        fo.write('\tindication-protein\t{}\n'.format(edgecounts['I-P']))
        fo.write('\tcompound-ADR\t{}\n'.format(edgecounts['A-C']))
        fo.write('\tprotein-pathway\t{}\n'.format(edgecounts['P-PW']))
        fo.write('\tprotein-GO\t{}\n'.format(edgecounts['G-P']))

    # node table: integer id, original identifier, node type
    with open('graphs-new/nodes-complete-adrs.tsv', 'w') as fo2:
        fo2.write('n_int\tn_id\tn_type\n')
        for ni in G.nodes():
            name = int2node[ni]
            typ = int2type[ni]
            fo2.write('{}\t{}\t{}\n'.format(ni, name, typ))

    # edge table: both endpoints as integer, identifier and type
    with open('graphs-new/edges-complete-adrs.tsv', 'w') as fo3:
        fo3.write('n_int1\tn_int2\tn_id1\tn_id2\tn_type1\tn_type2\n')
        for n1, n2 in G.edges():
            name1 = int2node[n1]
            name2 = int2node[n2]
            typ1 = int2type[n1]
            typ2 = int2type[n2]
            fo3.write('{}\t{}\t{}\t{}\t{}\t{}\n'.format(n1, n2, name1, name2, typ1, typ2))

# start node2vec embeddings #
# walk parameters come from the first two command-line arguments
num_walks = int(sys.argv[1])
wl = int(sys.argv[2])

node2vec = Node2Vec(G, dimensions=128, walk_length=wl, num_walks=num_walks, workers=16, quiet=True, temp_folder='tmp-n2v/')
model = node2vec.fit(window=10, min_count=1, batch_words=4)

# makedirs(exist_ok=True) replaces the exists()/mkdir() pair, which raised
# FileExistsError if the directory appeared between the check and the mkdir
# (e.g. created by another run using the same ADR cutoff)
os.makedirs("embeddings-complete-adrs{}/".format(adrcutoff), exist_ok=True)
# one embedding file per (num_walks, walk_length) combination
model.wv.save_word2vec_format("embeddings-complete-adrs{}/embed-{}-{}".format(adrcutoff, num_walks, wl))









