##############################################################################
# Prerequisites for all functions                                            #
##############################################################################

import cando as cnd
from compbench import *
import argparse, os, random

# Execution mode of this script when it is run directly:
#   RUN_THROUGH = CLI  -> use the command-line interface
#   RUN_THROUGH = FILE -> run the commands written in this file
CLI, FILE = 0, 1
RUN_THROUGH = FILE

def generate_cando_obj(matrix, cmpds, c2i, cmpd_set, ncpus=1, val='name'):
  '''
  Build the CANDO object used for benchmarking, attach a scratch indication
  ('test_indic') to it and look up every compound of cmpd_set

  matrix - str, name of protein interaction signature matrix
  cmpds - str, name of drug file
  c2i - str, name of drug-indication mapping file
  cmpd_set - set (or list), set of compounds to be included
    all-digit strings are converted to int before the lookup
  ncpus - int, number of CPUs available for use by CANDO object
  val - not used by this function

  Prints the number of compounds held by the CANDO object

  Returns the CANDO object, an indication object, and cmpd_object dict
    CANDO object is created with compute_distance=True, dist_metric='cosine'
    Indication object can be used to rapidly create/test indications
    cmpd_object dict maps each entry of cmpd_set to the result of
      get_compound for that entry
  '''
  platform = cnd.CANDO(cmpds, c2i, matrix=matrix, ncpus=ncpus,
                       compute_distance=True, dist_metric='cosine')

  # Register the scratch indication under the id 'test_indic'
  scratch_indic = cnd.Indication('test_indic', 'Temp Test Indic')
  platform.indications.append(scratch_indic)
  platform.indication_ids.append('test_indic')
  print(len(platform.compounds))

  def _as_query(label):
    # All-digit strings are looked up as ints, everything else unchanged
    return int(label) if type(label) is str and label.isdigit() else label

  lookup = {label: platform.get_compound(_as_query(label), quiet=True)
            for label in cmpd_set}

  return platform, scratch_indic, lookup

##############################################################################
# CANDO primary pipeline variants                                            #
##############################################################################


def cando_all_similar(indicated, non_indicated, cando_obj, test_indic, cmpd_to_obj,
                 filename):
  '''
  CANDO calculates ranks with similarity list cutoff (n) = -1
    (compounds sorted by overall average rank, not consensus score)

  indicated - set of str, cmpds with some commonality
    ranked output is based on similarity to the indicated group
  non_indicated - set of str, cmpds to be ranked; mutually excl w/indicated
    may include indicated cmpds not included in indicated set for benchmarking
  cando_obj - CANDO object, instantiated with correct matrix/drug file
    indication mapping file does not matter b/c will use "indicated" set
  test_indic - Indication object, part of cando_obj
    its compounds attribute is overwritten with the indicated compounds
  cmpd_to_obj - dictionary, str:Compound object
    used to populate test_indic with Compound objects based on indicated set
  filename - str, inserted into the output file name
    'temp_rank_<filename>_alt.tsv' (to avoid filename clashing)

  The prediction file is written to the working directory and is left
    there after the call so that it can be inspected

  returns ranks - list of str, members of non_indicated in the order in
    which they appear in the prediction file

  raises AssertionError if indicated and non_indicated overlap
  raises KeyError if an indicated cmpd is missing from cmpd_to_obj
  '''
  assert len(indicated.intersection(non_indicated)) == 0,\
         'Overlap found between indicated and non-indicated groups.'

  test_indic.compounds = [cmpd_to_obj[x] for x in indicated]

  temp_file = 'temp_rank_%s_alt.tsv' % filename

  cando_obj.canpredict_compounds('test_indic', topX=-1, n=-1,
                                 consensus=False, save=temp_file)

  # First line of the file is skipped (header row)
  with open(temp_file, 'r') as f:
    lines = f.read().strip().split('\n')[1:]

  # The column holding the cmpd label depends on how the cmpds are labelled:
  # column 4 for all-digit strings, column 6 otherwise (presumably the id
  # and name columns of the canpredict output - confirm against CANDO).
  # Decided once from a single member instead of once per row; an empty
  # indicated set falls back to column 6.
  probe = next(iter(indicated), None)
  cmpd_col = 4 if isinstance(probe, str) and probe.isdigit() else 6

  ranks = []
  for line in lines:
    cmpd = line.split('\t')[cmpd_col]
    if cmpd in non_indicated:
      ranks.append(cmpd)

  # Output file is kept on purpose; uncomment to clean up after each call
  #os.remove(temp_file)

  return ranks

def cando_multiple_lists(indicated, non_indicated, cando_obj, test_indic, cmpd_to_obj,
                 filename):
  '''
  CANDO calculates ranks with similarity list cutoff (n) = {10, 20, 30...}
    sorted by consensus score & average rank above similarity list cutoff
    compounds retrieved at later/greater similarity list cutoffs ranked after
    those retrieved at earlier/smaller similarity list cutoffs

  indicated - set of str, cmpds with some commonality
    ranked output is based on similarity to the indicated group
  non_indicated - set of str, cmpds to be ranked; mutually excl w/indicated
    may include indicated cmpds not included in indicated set for benchmarking
  cando_obj - CANDO object, instantiated with correct matrix/drug file
    indication mapping file does not matter b/c will use "indicated" set
  test_indic - Indication object, part of cando_obj
    its compounds attribute is overwritten with the indicated compounds
  cmpd_to_obj - dictionary, str:Compound object
    used to populate test_indic with Compound objects based on indicated set
  filename - str, inserted into the temporary file name
    'temp_rank_<filename>.tsv' (to avoid filename clashing)

  returns ranks - list of str, members of non_indicated in the order in
    which they were first retrieved; if an IndexError interrupts the search,
    the cmpds not retrieved so far are appended in random order

  raises AssertionError if indicated and non_indicated overlap
  raises KeyError if an indicated cmpd is missing from cmpd_to_obj

  NOTE(review): the search only stops once every non_indicated cmpd has been
    retrieved or an IndexError is raised while predicting/parsing; a cmpd
    that never shows up in the output keeps the loop running unless
    canpredict_compounds eventually raises for a large n - confirm
  '''
  assert len(indicated.intersection(non_indicated)) == 0,\
         'Overlap found between indicated and non-indicated groups.'

  test_indic.compounds = [cmpd_to_obj[x] for x in indicated]

  temp_file = 'temp_rank_%s.tsv' % filename

  # The column holding the cmpd label depends on how the cmpds are labelled:
  # column 4 for all-digit strings, column 6 otherwise (presumably the id
  # and name columns of the canpredict output - confirm against CANDO).
  # Decided once from a single member instead of once per row; an empty
  # indicated set falls back to column 6.
  probe = next(iter(indicated), None)
  cmpd_col = 4 if isinstance(probe, str) and probe.isdigit() else 6

  n = 10
  ranks = []
  ranked = set()
  try:
    while non_indicated - ranked:
      cando_obj.canpredict_compounds('test_indic', topX=-1, n=n,
                                     consensus=False,
                                     save=temp_file)

      # First line of the file is skipped (header row)
      with open(temp_file, 'r') as f:
        lines = f.read().strip().split('\n')[1:]

      # Only cmpds not seen at a smaller cutoff are appended, so earlier
      # cutoffs always rank ahead of later ones
      for line in lines:
        cmpd = line.split('\t')[cmpd_col]
        if cmpd in non_indicated and cmpd not in ranked:
          ranks.append(cmpd)
          ranked.add(cmpd)

      n += 10
  except IndexError:
    # If something is not captured when iterating in units of 10
    # ie only captured in last 9 n values
    unranked = list(non_indicated - ranked)
    # A single shuffle is already a uniform permutation; the three calls are
    # kept so that seeded runs consume the random stream as before
    random.shuffle(unranked)
    random.shuffle(unranked)
    random.shuffle(unranked)

    ranks += unranked

  # No file exists if nothing had to be ranked or the first prediction
  # failed, so a missing file is not an error here
  try:
    os.remove(temp_file) # Comment this out to view the file generated
  except FileNotFoundError:
    pass

  return ranks

def cando_ten_similar(indicated, non_indicated, cando_obj, test_indic, cmpd_to_obj,
                 filename):
  '''
  CANDO calculates ranks with similarity list cutoff (n) = 10
    sorted by consensus score & average rank above similarity list cutoff
    compounds not retrieved w/n=10 ordered based on overall average rank

  indicated - set of str, cmpds with some commonality
    ranked output is based on similarity to the indicated group
  non_indicated - set of str, cmpds to be ranked; mutually excl w/indicated
    may include indicated cmpds not included in indicated set for benchmarking
  cando_obj - CANDO object, instantiated with correct matrix/drug file
    indication mapping file does not matter b/c will use "indicated" set
  test_indic - Indication object, part of cando_obj
    its compounds attribute is overwritten with the indicated compounds
  cmpd_to_obj - dictionary, str:Compound object
    used to populate test_indic with Compound objects based on indicated set
  filename - str, not used by this function (no file is written)
    kept so that the signature matches the other pipeline variants

  returns ranks - list of str, members of non_indicated in the order
    returned by canpredict_compounds
    cmpds are reported as str(id_) when the indicated cmpds are all-digit
    strings, by name otherwise

  raises AssertionError if indicated and non_indicated overlap
  raises KeyError if an indicated cmpd is missing from cmpd_to_obj
  '''
  assert len(indicated.intersection(non_indicated)) == 0,\
         'Overlap found between indicated and non-indicated groups.'

  test_indic.compounds = [cmpd_to_obj[x] for x in indicated]

  n = 10

  sorted_x = cando_obj.canpredict_compounds('test_indic', topX=-1, n=n,
                                            consensus=False, return_all=True)

  # Same labelling test as the other pipeline variants: only all-digit
  # strings select the id; non-str labels or an empty indicated set fall
  # through to the name instead of raising
  probe = next(iter(indicated), None)
  if isinstance(probe, str) and probe.isdigit():
    ranks = [str(cando_obj.get_compound(x[0]).id_) for x in sorted_x]
  else:
    ranks = [cando_obj.get_compound(x[0]).name for x in sorted_x]

  return [x for x in ranks if x in non_indicated]

##############################################################################
# Test/benchmarking code                                                     #
##############################################################################

  
# Dataset to benchmark on: 'CTD' or 'TTD' (compared case-insensitively)
fileset = 'CTD'

# Pipeline variant handed to benchmarking(); alternatives are listed below
FUNC = cando_ten_similar
#FUNC = cando_all_similar
#FUNC = cando_multiple_lists

# Stem of the benchmark output file name
out_name = 'cando_ten_similar_%s' % fileset

if fileset.lower() == 'ctd':
  matrix_file = 'cando_data/unlabelled_v2.5_all_CxP_matrix_ctdalt.tsv'
  drug_file = 'cando_data/drugbank-v2.5_ctdalt.tsv' # no unapproved/unindicated drugs
  ctd_file = 'cando_data/drugbank2ctd-v2.5.tsv'
  drug_indic_map = 'data/ctd_approved_drugs.tsv'
  # Name-keyed cross-validation folds; id-keyed alternative:
  #cross_val_file = 'data/id_10fold_cross_seed0_ctd_cnd.tsv'
  cross_val_file = 'data/name_10fold_cross_seed0_ctd_cnd.tsv'
  # Columns passed to extract_drug_indic_mapping; id-keyed alternative:
  #key,val='iid','did'
  key = 'iname'
  val = 'dname'
elif fileset.lower() == 'ttd':
  matrix_file = 'cando_data/unlabelled_v2.5_all_CxP_matrix_alt.tsv'
  drug_file = 'cando_data/drugbank-v2.5_alt.tsv'
  ctd_file = 'cando_data/drugbank2ctd-v2.5.tsv' # Req for CANDO obj, not used
  drug_indic_map = 'data/TTD_approved_drugs_cndonly.tsv'
  cross_val_file = 'data/name_10fold_cross_seed0_TTD.tsv'
  key = 'iname'
  val = 'dname'
else:
  raise ValueError('No fileset ' + str(fileset))

if __name__ == '__main__':
  # Test code/run in file: benchmark every indication with the settings
  # chosen by the fileset above
  if RUN_THROUGH == FILE:
    indic_mapping, cmpd_set = extract_drug_indic_mapping(drug_indic_map, min_d=2,
                                                         key=key, val=val)

    cando_obj, test_indic, cmpd_to_obj = generate_cando_obj(matrix_file,
                                                            drug_file,
                                                            ctd_file, cmpd_set)

    # Keyword arguments forwarded to FUNC and to the cross-validation reader
    cando_args = {'cando_obj':cando_obj, 'test_indic':test_indic,
                  'cmpd_to_obj':cmpd_to_obj, 'filename':'0'}
    bench_args = {'filename':cross_val_file}

    benchmarking(indic_mapping, cmpd_set, read_strat_xfold_cross,
                 FUNC, bench_args, cando_args, out_name + '.tsv')

  # Command line interface: benchmark one subset of the indications
  elif RUN_THROUGH == CLI:
    parser = argparse.ArgumentParser(prog='CANDO Benchmarking',
                                     description='Benchmark subsignature scoring methods')
    # set 1 split to benchmark all indications, 10 to benchmark 1/10th of indications, etc
    parser.add_argument('num_splits', metavar='number of indication splits', type=int)
    # 0-indexed, set 0 to assess 1st of num_splits indication subsets & so on
    parser.add_argument('split_index', metavar='which index split is assessed here', type=int)
    parser.add_argument('-c', '--ctd', metavar='ctd filepath', type=str)
    parser.add_argument('-d', '--drug', metavar='drug filepath', type=str)
    # Note that matrix needs to be scrubbed of cmpd labels
    # Ie can't use same matrix as subsignature method
    parser.add_argument('-m', '--matrix', metavar='prot-drug matrix filepath', type=str)
    # Other values are read from the fileset

    args = parser.parse_args()

    # Paths given on the command line override the fileset defaults
    if args.ctd is not None:
      ctd_file = args.ctd
    if args.drug is not None:
      drug_file = args.drug
    if args.matrix is not None:
      matrix_file = args.matrix

    num_splits = args.num_splits
    split_index = args.split_index

    # key/val are passed as in the FILE run so that both modes read the same
    # columns of the drug-indication map for the selected fileset
    standard, cmpd_set = extract_drug_indic_mapping(drug_indic_map, min_d=2,
                                                    key=key, val=val)
    indic_splits = split_indics(standard, num_splits, by='pairs')
    this_split = indic_splits[split_index]

    cando_obj, test_indic, cmpd_to_obj = generate_cando_obj(matrix_file,
                                                            drug_file,
                                                            ctd_file, cmpd_set)

    # The split index keeps temp/output file names of parallel runs apart
    cando_args = {'cando_obj':cando_obj, 'test_indic':test_indic,
                  'cmpd_to_obj':cmpd_to_obj, 'filename':str(split_index)}
    bench_args = {'filename':cross_val_file}

    benchmarking(this_split, cmpd_set, read_strat_xfold_cross,
                 FUNC, bench_args, cando_args,
                 out_name + str(split_index) + '.tsv')
