import os
import random
from sklearn.ensemble import RandomForestClassifier
from sklearn.cross_decomposition import CCA

# Path of the approved-compound mapping table (tab-separated, one header row).
cma = 'mappings/modified-drugbank-v2.5-approved.tsv'

# cids: first column of every data row -> empty list. The rest of this script
# only uses it for membership tests ("is this compound in the approved set?").
cids = {}
# `with` guarantees the handle is closed even if a row fails to parse.
with open(cma, 'r') as f:
    next(f, None)  # skip the header row (no-op on an empty file)
    for line in f:
        fields = line.strip().split('\t')
        cids[fields[0]] = []

# Previously read from:
#   /projects/academic/rams/wmangion/mappings/v2.2.1/group_disease-top_level.tsv
#
# Each row is read as (column 0, column 1) with no header row skipped.
# Presumably column 0 is a MeSH disease ID and column 1 its top-level disease
# group - confirm against the file.
m2g = {}     # column-0 value -> list of column-1 values it appears with
g2ms = {}    # column-1 value -> list of column-0 values it appears with
g2cids = {}  # column-1 value -> empty list here; filled with compound IDs later
with open('mappings/group_disease-top_level.tsv', 'r') as fg:
    for line in fg:
        fields = line.strip().split('\t')
        mesh_id, group = fields[0], fields[1]
        m2g.setdefault(mesh_id, []).append(group)
        g2ms.setdefault(group, []).append(mesh_id)
        g2cids[group] = []

# Compound <-> indication associations. Column 0 is the compound ID and
# column 2 the MeSH ID; the first row is a header. Rows whose compound is not
# in `cids` are ignored.
mesh2cids = {}  # MeSH ID -> list of compound IDs (one entry per row)
cid2ms = {}     # compound ID -> list of MeSH IDs (one entry per row)
cid2gs = {}     # compound ID -> list of groups (may contain repeats)
with open('mappings/drugbank2ctd-v2.2.tsv', 'r') as fm:
    next(fm, None)  # skip the header row
    for line in fm:
        fields = line.strip().split('\t')
        cid = fields[0]
        mesh = fields[2]
        if cid not in cids:
            continue
        mesh2cids.setdefault(mesh, []).append(cid)
        cid2ms.setdefault(cid, []).append(mesh)
        # Propagate the association up to every group the MeSH ID belongs to.
        for g in m2g.get(mesh, ()):
            cid2gs.setdefault(cid, []).append(g)
            # g2cids keeps each compound at most once per group.
            if cid not in g2cids[g]:
                g2cids[g].append(cid)

# Each row: first column, followed by zero or more further tab-separated
# columns. Going by the variable and file names these are a UniProt accession
# and its PDB IDs - confirm against the file.
pdb2u = {}  # later-column value -> first-column value (last row seen wins)
u2pdb = {}  # first-column value -> list of the remaining columns (may be empty)
with open('data/uniprot_to_pdb_all.txt', 'r') as fp2u:
    for line in fp2u:
        uni, *pdbs = line.strip().split('\t')
        for pdb in pdbs:
            pdb2u[pdb] = uni
        u2pdb[uni] = pdbs

# targets: identifier -> 1, used as a set of target identifiers gathered from
# two files.
targets = {}

# Source 1: column 3 of every data row (first row is a header).
with open('data/cmpds2targets-db-v2.3.tsv', 'r') as fc2t:
    next(fc2t, None)  # skip the header row
    for line in fc2t:
        fields = line.strip().split('\t')
        targets[fields[3]] = 1

# Source 2: column 0 of every row (no header skipped). Identifiers that are
# keys of pdb2u are translated through it; anything else is stored as-is.
with open('data/cmpd2protein-v2.3.txt', 'r') as fcp:
    for line in fcp:
        prot = line.strip().split('\t')[0]
        targets[pdb2u.get(prot, prot)] = 1

# filt: pathway ID -> 1 for every pathway that must be skipped downstream.
# A pathway is filtered out when ANY of its three integer columns falls
# outside the accepted range. The meaning of the columns is not visible here;
# the names npr/pc/nc are kept from the original code.
filt = {}
with open('data/pws-targets-dists.tsv', 'r') as f:
    next(f, None)  # skip the header row
    for line in f:
        fields = line.strip().split('\t')
        pw = fields[0]
        npr = int(fields[1])
        pc = int(fields[2])
        nc = int(fields[3])
        if not (5 <= npr < 250) or pc < 2 or not (130 <= nc <= 4317):
            filt[pw] = 1

# NOTE(review): 2219 is hard-coded; presumably the total number of pathways
# in the table above - confirm it still matches the data file.
print('N pathways =', 2219 - len(filt))

# cid2pwds: compound ID -> list of integer distances, one appended per retained
# pathway file in which the compound appears. These lists are the feature
# vectors used for training below.
# NOTE(review): this assumes every compound appears exactly once in every
# retained pathway file; otherwise the vectors differ in length or drift out
# of step with pw-order-features.txt - confirm against the data.
cid2pwds = {}
pw2direct = {}
pwc = 0
tally = {}   # pathway -> number of MESH:D009362 compounds at distance <= 3
tally2 = {}  # same count for a second, fixed set of pathways

# Loop-invariant lookups, built once as sets instead of re-scanning lists for
# every input line. A missing MESH:D009362 entry simply yields empty tallies.
tally_cids = frozenset(mesh2cids.get('MESH:D009362', ()))
tally_pws = frozenset([
    'R-HSA-156581', 'R-HSA-9018677', 'R-HSA-9018678', 'R-HSA-5423646', 'R-HSA-9018682',
    'R-HSA-2142670', 'R-HSA-2142691', 'R-HSA-9027307', 'R-HSA-77289',
])
tally2_pws = frozenset([
    'R-HSA-390651', 'R-HSA-375280', 'R-HSA-1296071', 'R-HSA-390666', 'R-HSA-211958',
])

with open('pw-order-features.txt', 'w') as fop:
    # os.listdir() returns entries in arbitrary order; sorting makes the
    # feature order reproducible across runs and machines. The order actually
    # used is still recorded in pw-order-features.txt.
    for dr in sorted(os.listdir('shortest-paths/pathway-compound/')):
        pw = dr[:-4]  # drop the last four characters (the file extension)
        if pw in filt:
            continue
        with open('shortest-paths/pathway-compound/{}'.format(dr), 'r') as f:
            fop.write('{}\n'.format(pw))
            for line in f:
                fields = line.strip().split('\t')
                d = int(fields[1])
                cid = fields[0]
                if cid not in cids:
                    continue
                if d <= 3 and cid in tally_cids:
                    if pw in tally_pws:
                        tally[pw] = tally.get(pw, 0) + 1
                    if pw in tally2_pws:
                        tally2[pw] = tally2.get(pw, 0) + 1
                cid2pwds.setdefault(cid, []).append(d)
print(tally)
print(tally2)

# OVERALL IDEA - train RF models that predict drug-indication associations, using the "pathway distance"
# values as the features for each compound. Then extract the feature importance of each pathway. Finally,
# tally how often a pathway is an "important" feature for a compound, based on how often that compound is
# associated with an indication for which that pathway is an important feature. Interpretation: the pathway
# is important for predicting whether this drug will treat this indication.

# Keys of mesh2cids that have at least 50 associated compound entries, sorted.
meshs50 = sorted(
    mesh_id for mesh_id, members in mesh2cids.items() if len(members) >= 50
)

# For every indication in meshs50, repeated 10 times:
#   1. positives = feature vectors of compounds associated with the indication
#   2. negatives = an equally sized random sample of other compounds
#   3. 90/10 train/test split, fit a random forest, and write the feature
#      importances, test-set probabilities and a confusion summary to disk.

# Create the output directories up front so a missing directory cannot abort
# the run after a model has already been trained.
for out_dir in ('ind-feature-importances-rf', 'ind-probabilities-rf', 'ind-performance-rf'):
    os.makedirs(out_dir, exist_ok=True)

for iteration in range(10):
    for query_mesh in meshs50:
        # Positives: associated compounds that have a distance vector.
        pos_cids = [cid for cid in mesh2cids[query_mesh] if cid in cid2pwds]
        positives = [cid2pwds[cid] for cid in pos_cids]
        npos = len(positives)

        # Compounds that must never be drawn as negatives: the positives
        # themselves, plus every compound associated with any disease group
        # of the query indication. The original code tried to apply the group
        # rule with a `continue` that only affected the inner group loop, and
        # only for positives, so it had no effect.
        excluded = set(pos_cids)
        for g in m2g.get(query_mesh, ()):
            excluded.update(g2cids.get(g, ()))

        negatives = []
        cs = list(cid2pwds.keys())
        random.shuffle(cs)
        for cid in cs:
            if len(negatives) == npos:
                break
            if cid in excluded:
                continue
            negatives.append(cid2pwds[cid])

        random.shuffle(positives)
        random.shuffle(negatives)

        # 90/10 split per class. Negatives are split on their own length
        # because the exclusion above can leave fewer negatives than positives.
        split = int(0.9 * npos)
        neg_split = int(0.9 * len(negatives))
        train_pos, test_pos = positives[:split], positives[split:]
        train_neg, test_neg = negatives[:neg_split], negatives[neg_split:]

        # A classifier needs training examples of both classes.
        if not train_pos or not train_neg:
            print('Skipping {} (iteration {}): not enough compounds to train'.format(
                query_mesh, iteration))
            continue

        X = train_pos + train_neg
        Y = [1] * len(train_pos) + [0] * len(train_neg)

        rf = RandomForestClassifier(n_estimators=1000)
        rf.fit(X, Y)
        # Rows follow the order test_pos then test_neg. Column 0 is the
        # probability of the first entry of rf.classes_, i.e. label 0.
        preds = rf.predict_proba(test_pos + test_neg)

        # query_mesh[5:] drops the first five characters of the ID (the
        # "MESH:" prefix, judging by the ID used elsewhere in this file).
        with open('ind-feature-importances-rf/{}-{}.txt'.format(query_mesh[5:], iteration), 'w') as fo1:
            for fi in rf.feature_importances_:
                fo1.write('{}\n'.format(fi))

        # One line per test sample: probability of label 0 (negative).
        with open('ind-probabilities-rf/{}-{}.txt'.format(query_mesh[5:], iteration), 'w') as fo1:
            for prob in preds:
                fo1.write('{}\n'.format(prob[0]))

        # A sample is called negative when P(label 0) >= 0.5.
        n_test_pos = len(test_pos)
        fn = sum(1 for pr in preds[:n_test_pos] if pr[0] >= 0.5)
        tp = n_test_pos - fn
        tn = sum(1 for pr in preds[n_test_pos:] if pr[0] >= 0.5)
        fp = len(test_neg) - tn

        with open('ind-performance-rf/{}-{}.txt'.format(query_mesh[5:], iteration), 'w') as fo1:
            fo1.write('TP\t{}\n'.format(tp))
            fo1.write('FP\t{}\n'.format(fp))
            fo1.write('FN\t{}\n'.format(fn))
            fo1.write('TN\t{}\n'.format(tn))
            fo1.write('ACC\t{}\n'.format(round((tp+tn) / (tp + fp + fn + tn), 2)))

