##############################################################################
# Prerequisites for all functions                                            #
##############################################################################

from subsignature import *
from compbench import *
import argparse, math

# Run-mode selector read by the __main__ section of this file:
#   RUN_THROUGH = CLI  -> arguments are parsed from the command line
#   RUN_THROUGH = FILE -> the settings written in this file are used as-is
CLI, FILE = 0, 1
RUN_THROUGH = CLI

##############################################################################
# Subsignature functions                                                     #
##############################################################################

def subsig_log_weighted(indicated, non_indicated, cmpd_scores,
                        GO_to_index, cmpd_to_index):
  '''
  CANDO subsignature calculates ranks based on distance to indicated cluster
    distances are weighted by the negative log of cluster centrality

  For every GO term, the cluster's per-protein score
    (get_score() / get_num_prots()) is divided by the largest per-protein
    score over all clusters and floored at 1e-9; the weight of that cluster
    is the negative log2 of this ratio.  The cluster holding the largest
    per-protein score therefore contributes with weight 0.  When the largest
    per-protein score is not positive, every cluster gets weight 1.

  indicated - set of str, cmpds with some commonality
    ranked output is based on similarity to the indicated group
  non_indicated - set of str, cmpds to be ranked; mutually excl w/indicated
    may include indicated cmpds not included in indicated set for benchmarking
  cmpd_scores - matrix (list of list) containing all subsignatures (Subsig objects)
    subsignatures contain all protein interaction scores corresponding to that
    compound & GO combination
  GO_to_index - dict of str:int, maps GO term to index in cmpd_scores matrix
  cmpd_to_index - dict of str:int, maps compound name to index in cmpd_scores matrix

  Returns list of the compounds passed in non_indicated, sorted by ascending
    weighted distance sum (ties keep the iteration order of non_indicated)

  Raises AssertionError if indicated and non_indicated share a compound
  '''

  assert len(indicated.intersection(non_indicated)) == 0,\
         'Overlap found between indicated and non-indicated groups.'

  # clusts is indexed with the same column index that GO_to_index provides
  clusts = generate_clusters(cmpd_scores, GO_to_index,
                             cmpd_to_index, indicated)

  # Freeze an iteration order once so scores and names stay aligned
  non_indicated_order = list(non_indicated)
  # Row lookups do not depend on the GO term, so resolve them a single time
  rows = [cmpd_to_index[cmpd] for cmpd in non_indicated_order]
  totals = [0]*len(non_indicated_order)

  max_score = max(x.get_score()/x.get_num_prots() for x in clusts)

  for j in GO_to_index.values():
    clust = clusts[j]
    dists = clust.get_dists([cmpd_scores[i][j] for i in rows])

    if max_score > 0:
      relative = (clust.get_score()/clust.get_num_prots())/max_score
      # Floor keeps the log finite when the relative score is 0 or negative
      clust_weight = -math.log(max(relative, 0.000000001), 2)
    else:
      clust_weight = 1

    # Indexed (not zipped) so a too-short get_dists result fails loudly
    for k in range(len(totals)):
      totals[k] += dists[k]*clust_weight

  # Empty the cluster list in place, as the original code did
  # NOTE(review): presumably to release the cluster objects early - confirm
  del clusts[:]

  # sorted() is stable, so equal totals keep non_indicated_order's order
  ranked = sorted(zip(non_indicated_order, totals), key=lambda pair: pair[1])
  return [cmpd for cmpd, _ in ranked]

def subsig_unweighted(indicated, non_indicated, cmpd_scores,
                        GO_to_index, cmpd_to_index):
  '''
  CANDO subsignature calculates ranks based on distance to indicated cluster
    distances are summed with equal weight

  indicated - set of str, cmpds with some commonality
    ranked output is based on similarity to the indicated group
  non_indicated - set of str, cmpds to be ranked; mutually excl w/indicated
    may include indicated cmpds not included in indicated set for benchmarking
  cmpd_scores - matrix (list of list) containing all subsignatures (Subsig objects)
    subsignatures contain all protein interaction scores corresponding to that
    compound & GO combination
  GO_to_index - dict of str:int, maps GO term to index in cmpd_scores matrix
  cmpd_to_index - dict of str:int, maps compound name to index in cmpd_scores matrix

  Returns list of the compounds passed in non_indicated, sorted by ascending
    distance sum (ties keep the iteration order of non_indicated)

  Raises AssertionError if indicated and non_indicated share a compound
  '''
  assert len(indicated.intersection(non_indicated)) == 0,\
         'Overlap found between indicated and non-indicated groups.'

  # clusts is indexed with the same column index that GO_to_index provides
  clusts = generate_clusters(cmpd_scores, GO_to_index,
                             cmpd_to_index, indicated)

  # Freeze an iteration order once so scores and names stay aligned
  non_indicated_order = list(non_indicated)
  # Row lookups do not depend on the GO term, so resolve them a single time
  rows = [cmpd_to_index[cmpd] for cmpd in non_indicated_order]
  totals = [0]*len(non_indicated_order)

  # No cluster score is needed here: every GO term counts equally, so the
  # per-protein score normalisation used by the weighted variant is skipped
  for j in GO_to_index.values():
    dists = clusts[j].get_dists([cmpd_scores[i][j] for i in rows])

    # Indexed (not zipped) so a too-short get_dists result fails loudly
    for k in range(len(totals)):
      totals[k] += dists[k]

  # Empty the cluster list in place, as the original code did
  # NOTE(review): presumably to release the cluster objects early - confirm
  del clusts[:]

  # sorted() is stable, so equal totals keep non_indicated_order's order
  ranked = sorted(zip(non_indicated_order, totals), key=lambda pair: pair[1])
  return [cmpd for cmpd, _ in ranked]

def subsig_25_weighted(indicated, non_indicated, cmpd_scores,
                        GO_to_index, cmpd_to_index):
  '''
  CANDO subsignature calculates ranks based on distance to indicated cluster
    only the 25 most central clusters are included in summed distance

  A cluster contributes to the sum when its get_score() is strictly below
    the 26th-smallest cluster score, or when its get_num_cmpds() is 1.
    When there are 25 or fewer clusters, all of them contribute.

  indicated - set of str, cmpds with some commonality
    ranked output is based on similarity to the indicated group
  non_indicated - set of str, cmpds to be ranked; mutually excl w/indicated
    may include indicated cmpds not included in indicated set for benchmarking
  cmpd_scores - matrix (list of list) containing all subsignatures (Subsig objects)
    subsignatures contain all protein interaction scores corresponding to that
    compound & GO combination
  GO_to_index - dict of str:int, maps GO term to index in cmpd_scores matrix
  cmpd_to_index - dict of str:int, maps compound name to index in cmpd_scores matrix

  Returns list of the compounds passed in non_indicated, sorted by ascending
    distance sum (ties keep the iteration order of non_indicated)

  Raises AssertionError if indicated and non_indicated share a compound
  '''
  assert len(indicated.intersection(non_indicated)) == 0,\
         'Overlap found between indicated and non-indicated groups.'

  # Number of lowest-scoring clusters that may contribute
  top_n = 25

  # clusts is indexed with the same column index that GO_to_index provides
  clusts = generate_clusters(cmpd_scores, GO_to_index,
                             cmpd_to_index, indicated)

  # Freeze an iteration order once so scores and names stay aligned
  non_indicated_order = list(non_indicated)
  # Row lookups do not depend on the GO term, so resolve them a single time
  rows = [cmpd_to_index[cmpd] for cmpd in non_indicated_order]
  totals = [0]*len(non_indicated_order)

  # Cutoff is the (top_n+1)-th smallest score; None means "no cutoff", used
  # when there are too few clusters for that element to exist
  sorted_scores = sorted(x.get_score() for x in clusts)
  cutoff = sorted_scores[top_n] if len(sorted_scores) > top_n else None

  for j in GO_to_index.values():
    clust = clusts[j]

    # The include decision is per cluster, so make it before any
    # per-compound work and skip excluded clusters entirely
    included = (cutoff is None
                or clust.get_num_cmpds() == 1
                or clust.get_score() < cutoff)
    if not included:
      continue

    dists = clust.get_dists([cmpd_scores[i][j] for i in rows])

    # Indexed (not zipped) so a too-short get_dists result fails loudly
    for k in range(len(totals)):
      totals[k] += dists[k]

  # Empty the cluster list in place, as the original code did
  # NOTE(review): presumably to release the cluster objects early - confirm
  del clusts[:]

  # sorted() is stable, so equal totals keep non_indicated_order's order
  ranked = sorted(zip(non_indicated_order, totals), key=lambda pair: pair[1])
  return [cmpd for cmpd, _ in ranked]

##############################################################################
# Test/benchmarking code                                                     #
##############################################################################

# Name of the input-file set to use; matched case-insensitively below
fileset = 'TTD'
# Scoring function passed to benchmarking(); uncomment exactly one
#FUNC = subsig_log_weighted
#FUNC = subsig_unweighted
FUNC = subsig_25_weighted
# Prefix of the results file; '.tsv' (and the split index in CLI mode) is
# appended where benchmarking() is called
out_file = 'subsig_25_weighted_%s' % fileset

# Per-fileset inputs: (drug_indic_map, prot_scores, cross_val_file)
_FILESET_INPUTS = {
  'ttd': ('data/TTD_approved_drugs_cndonly.tsv',
          'subsig_data/labelled_v2.5_all_CxP_matrix_alt.tsv',
          'data/name_10fold_cross_seed0_TTD.tsv'),
  'ctd': ('data/ctd_approved_drugs.tsv',
          'subsig_data/labelled_v2.5_all_CxP_matrix_ctdalt.tsv',
          'data/name_10fold_cross_seed0_ctd_cnd.tsv'),
}

_fileset_key = fileset.lower()
if _fileset_key not in _FILESET_INPUTS:
  raise ValueError('No fileset ' + str(fileset))
drug_indic_map, prot_scores, cross_val_file = _FILESET_INPUTS[_fileset_key]

# Inputs shared by every fileset
GO_term_file = 'subsig_data/GO_term_levels.tsv'
prot_GO_file = 'subsig_data/prot2GO_uniprot.tsv'
# Extra keyword arguments forwarded to generate_data_structs()
GO_rules = dict(level=2)

if __name__ == '__main__':
  # Test code/run in file: benchmark every indication using the module-level
  # settings exactly as written above
  if RUN_THROUGH == FILE:

    indic_mapping, cmpd_set = extract_drug_indic_mapping(drug_indic_map, min_d=2)
    cmpd_scores, GO_to_index, cmpd_to_index = \
                 generate_data_structs(prot_scores, prot_GO_file,
                                  GO_term_file, **GO_rules)

    # Keyword arguments forwarded by benchmarking() to FUNC / the fold reader
    subsig_args = {'cmpd_scores':cmpd_scores, 'GO_to_index':GO_to_index,
          'cmpd_to_index':cmpd_to_index}
    bench_args = {'filename':cross_val_file}

    benchmarking(indic_mapping, cmpd_set, read_strat_xfold_cross,
               FUNC, bench_args, subsig_args,
               out_file + '.tsv')

  # Command line interface: benchmark one subset of the indications, with
  # optional overrides for the input file paths
  elif RUN_THROUGH == CLI:
    parser = argparse.ArgumentParser(prog='SubSignature Benchmarking',
                                     description='Benchmark subsignature scoring methods')
    # set 1 split to benchmark all indications, 10 to benchmark 1/10th of indications, etc
    # type=int lets argparse reject non-numeric input with a usage message
    parser.add_argument('num_splits', type=int,
                        help='number of indication splits')
    # 0-indexed, set 0 to assess 1st of num_splits indication subsets & so on
    parser.add_argument('split_index', type=int,
                        help='which index split is assessed here')
    # Note that matrix needs to contain compound labels
    # Ie can't use same matrix as CANDO primary pipeline
    parser.add_argument('-m', '--matrix', metavar='PATH', type=str,
                        help='prot-drug matrix filepath')
    parser.add_argument('-g', '--go', metavar='PATH', type=str,
                        help='GO filepath')
    parser.add_argument('-t', '--trans', metavar='PATH', type=str,
                        help='prot2GO filepath')
    # Other values are read from the fileset

    args = parser.parse_args()

    num_splits = args.num_splits
    split_index = args.split_index
    standard, cmpd_set = extract_drug_indic_mapping(drug_indic_map, min_d=2)
    indic_splits = split_indics(standard, num_splits, by='pairs')
    this_split = indic_splits[split_index]

    # Generate data structures and begin benchmarking
    # An option left unset (None) keeps the fileset default chosen above
    if args.matrix is not None:
      prot_scores = args.matrix
    if args.go is not None:
      GO_term_file = args.go
    if args.trans is not None:
      prot_GO_file = args.trans

    cmpd_scores, GO_to_index, cmpd_to_index = \
                 generate_data_structs(prot_scores, prot_GO_file,
                                  GO_term_file, **GO_rules)

    # Keyword arguments forwarded by benchmarking() to FUNC / the fold reader
    subsig_args = {'cmpd_scores':cmpd_scores, 'GO_to_index':GO_to_index,
          'cmpd_to_index':cmpd_to_index}
    bench_args = {'filename':cross_val_file}

    # The split index is part of the file name so parallel runs over
    # different splits do not write to the same file
    benchmarking(this_split, cmpd_set, read_strat_xfold_cross,
               FUNC, bench_args, subsig_args,
               out_file + str(split_index) + '.tsv')
