##############################################################################
# General benchmarking protocol                                              #
# Written February 2024, commented/updated November 2024                     #
# Important notes:                                                           #
# For provided mappings, CTD drugs/indications can be id'd by name or id     #
#   i.e. CANDO ID for drugs, MeSH code for indications                       #
# TTD IDs are non-unique so TTD drugs/indications are identified by name     #
#   e.g. "X-rays imaging" and "Herbicide" both have the ID "N.A."            #
#   e.g. "D04XVN" and "D00QDJ" are both "trametinib"                         #
##############################################################################

import random, math

##############################################################################
# Drug-indication mapping creation                                           #
##############################################################################

def extract_drug_indic_mapping(filename, key='iname', val='dname', min_d=1):
  '''
  Read a drug-indication mapping file into an indication -> drugs dict

  Input
    filename - str, path to file with drug-indication mapping
    key - str, "iid" will use indication ID, otherwise name
    val - str, "did" will use drug ID, otherwise name
    min_d - int, minimum # of drugs for included indics; default=1

  mapping file should be a tsv with 4 labelled columns:
    Drug ID  Drug Name  Indication Name  Indication ID
    the first line is treated as the header and skipped; blank lines ignored
    no specific format for drug/indication ID, but should be unique
    drug name & indication name should also be unique, match to drug/indic id
    a drug may have multiple indications & vice versa

  Indication keys and drug IDs are title-cased; drug names are lower-cased
  with spaces replaced by underscores.

  Returns a tuple of 2 items:
    indic_to_drugs - dict of str:list, known drug-indication mapping
      each indication name/id (key) is mapped to 1+ drug name/ids (values)
      drugs are listed in first-seen order without duplicates
    cmpd_set - set of str, all drug name/ids in the drug-indic mapping
      (also includes drugs whose indications were dropped by min_d)
  '''
  with open(filename, 'r') as f:
    lines = f.read().strip().split('\n')

  indic_to_drugs = {}
  seen_pairs = set()  # O(1) duplicate check instead of scanning each list
  cmpd_set = set()
  print('Creating dict translating %s to %s' %
        ('indication ID' if key == 'iid' else 'indication name',
         'drug IDs' if val == 'did' else 'drug names'))

  for line in lines[1:]:
    if not line.strip():
      continue  # tolerate blank lines inside the file
    did, dname, iname, iid = line.split('\t')
    ikey = iid.title() if key == 'iid' else iname.title()
    if val == 'did':
      dval = did.title()
    else:
      dval = dname.lower().replace(' ', '_')

    cmpd_set.add(dval)
    if (ikey, dval) not in seen_pairs:
      seen_pairs.add((ikey, dval))
      indic_to_drugs.setdefault(ikey, []).append(dval)

  if min_d > 1:
    # Distinct loop names so the `key` parameter is not rebound
    indic_to_drugs = {indic: drugs for indic, drugs in indic_to_drugs.items()
                      if len(drugs) >= min_d}

  return indic_to_drugs, cmpd_set

def split_indics(indic_to_drugs, num_splits, by='indics'):
  '''
  Partition indications (each with all its drugs) into separate dicts

  indic_to_drugs - dict, indic (str) : drugs (list of str) pairs
    indication/drug pairs to be split
  num_splits - int, number of dicts to split pairs into (must be >= 1)
  by - str, indics or pairs
    indics - even # indications, eg for Xfold validation
    pairs - even # of indic-drug pairs, eg for leave-one-out

  Returns list of indic_to_drugs dicts (drug lists are copies)
    normally num_splits dicts; if there are fewer indications than
      num_splits, one dict per indication instead (empty list if none)
    if by == indics, splits will be even # indics +- one, in input order
    if by == pairs, splits will be as even as possible given indic lens
      (greedy: largest remaining indication goes to the smallest split)
    note that indics will NEVER be divided into multiple splits

  Raises ValueError if num_splits < 1
  '''
  if num_splits < 1:
    raise ValueError('num_splits must be at least 1, got %r' % (num_splits,))

  if len(indic_to_drugs) < num_splits:
    print('Not enough indications to split into %d' % num_splits)
    print('Splitting into %d (# indications) instead' % len(indic_to_drugs))
    num_splits = len(indic_to_drugs)
  else:
    print('Splitting %d indications into %d splits by %s' %
          (len(indic_to_drugs), num_splits,
           'indics' if by == 'indics' else 'pairs'))

  # Built only after num_splits may have been reduced, so no empty
  # placeholder splits are returned
  splits = [{} for _ in range(num_splits)]
  if num_splits == 0:
    return splits

  if by == 'indics':
    num_per, extra = divmod(len(indic_to_drugs), num_splits)
    indics = tuple(indic_to_drugs)
    start = 0
    for i, split in enumerate(splits):
      num_incl = num_per + (1 if i < extra else 0)
      for indic in indics[start:start + num_incl]:
        split[indic] = indic_to_drugs[indic][:]
      start += num_incl
    return splits

  # by == 'pairs': indications sorted by # drugs (stable), largest popped first
  indics = sorted(indic_to_drugs, key=lambda x: len(indic_to_drugs[x]))
  lens = {}  # current split size (# pairs) -> indices of splits of that size

  # Seed every split with one of the largest indications
  for i in range(num_splits):
    indic = indics.pop()
    splits[i][indic] = indic_to_drugs[indic][:]
    lens.setdefault(len(indic_to_drugs[indic]), []).append(i)

  # Give each remaining (largest-first) indication to a smallest split;
  # sizes only grow, so the smallest size never decreases
  min_len = 0
  while indics:
    if min_len not in lens:
      min_len += 1
      continue
    indic = indics.pop()
    i = lens[min_len].pop()
    splits[i][indic] = indic_to_drugs[indic][:]
    lens.setdefault(min_len + len(indic_to_drugs[indic]), []).append(i)
    if not lens[min_len]:
      del lens[min_len]

  lo, hi = min(lens), max(lens)
  print('Final split lengths range from %d to %d pairs (%s fold diff)' %
        (lo, hi, hi // lo if lo else 'inf'))

  return splits
  


##############################################################################
# Data-splitting generators                                                  #
##############################################################################

def leave_one_out(drug_indic_mapping):
  '''
  Leave-one-out generator over every indication:drug association

  drug_indic_mapping - dict, indic (str) : drugs (list of str) pairs

  Returns generator object
    Yields indic (str), indic_drugs (tuple of str), left_out (tuple of str)
      indic_drugs is the indication's drug list with (the first occurrence
      of) the withheld drug removed; left_out is a 1-tuple of that drug
    Usage: for indic, drugs, left_out in leave_one_out(standard):
    Will iterate thru every indic:drug pair in the dict, in dict/list order
    To exclude indic:drug pairs, remove them from the dict
  '''
  for indic, drugs in drug_indic_mapping.items():
    for withheld in drugs:
      kept = drugs[:]
      kept.remove(withheld)
      yield (indic, tuple(kept), (withheld,))

def strat_xfold_cross(drug_indic_mapping, folds=10, seed=0):
  '''
  Stratified X-fold cross-validation generator (stratified by indication)

  drug_indic_mapping - dict, indic (str) : drugs (list of str) pairs
  folds - int, # of folds the indic:drug pairs will be divided into
                  default 10 (10-fold cross validation)
  seed - int, seed for shuffling drugs & choosing which folds get extras
    a private random.Random(seed) is used, so the folds neither depend on
    nor disturb the global random state (e.g. a random bench_func consuming
    the global RNG between yields no longer changes later folds)

  Will divide into folds stratified by indic (indic roughly evenly distrib)
    Note for indics of len < folds, some folds will be empty
    Same drug shuffles/fold sizes as save_strat_xfold_cross for equal
    folds & seed

  Returns generator object
    Yields indic (str), indic_drugs (tuple of str), left_outs (tuple of str)
    Usage: for indic, drugs, left_out in strat_xfold_cross(standard):
    Will iterate thru each fold's left out in order per indic
      ie will go indic 1 fold 1, indic 1 fold 2, indic 1 fold 3, etc
  '''
  rng = random.Random(seed)

  for indic in drug_indic_mapping:
    drug_list = drug_indic_mapping[indic][:]
    drugs_per, extra = divmod(len(drug_list), folds)
    # Three shuffles kept so the folds match previously generated ones
    for _ in range(3):
      rng.shuffle(drug_list)

    # Folds that receive one extra drug when len isn't divisible by folds
    extra_folds = set(rng.sample(range(folds), extra))

    start_i = 0
    for i in range(folds):
      num_incl = drugs_per + (1 if i in extra_folds else 0)
      end_i = start_i + num_incl
      left_outs = drug_list[start_i:end_i]
      indic_drugs = drug_list[:start_i] + drug_list[end_i:]
      start_i = end_i

      yield indic, tuple(indic_drugs), tuple(left_outs)

def save_strat_xfold_cross(drug_indic_mapping, folds=10, seed=0, name=''):
  '''
  Generate stratified X-fold splits and save them to a TSV file

  drug_indic_mapping - dict, indic (str) : drugs (list of str) pairs
  folds - int, # of folds the indic:drug pairs will be divided into
                  default 10 (10-fold cross validation)
  seed - int, seed for shuffling; a private random.Random(seed) is used so
    the global random state is left untouched
  name - str, optional suffix added to the output filename

  Will divide into folds stratified by indic (indic roughly evenly distrib)
    Note for indics of len < folds, some folds will be empty

  Creates file saving split folds w/filename based on fold #, seed, & name
    ('<folds>fold_cross_seed<seed>[_<name>].tsv' in the working directory)
    Format: Indic  Fold  In Indic  Left Out
    Indic    - str, indication name
    Fold     - int, fold of that indication the following are included in
    In Indic - comma-separated list of str, drugs considered "in" indication
    Left Out - comma-separated list of str, "new" drugs being assessed
    (drug names containing commas would not round-trip through this format)

  Returns filename (str) of the file written
  '''
  rng = random.Random(seed)
  if not name:
    filename = '%dfold_cross_seed%d.tsv' % (folds, seed)
  else:
    filename = '%dfold_cross_seed%d_%s.tsv' % (folds, seed, name)

  rows = ['Indic\tFold\tIn Indic\tLeft Out\n']
  for indic in drug_indic_mapping:
    drug_list = drug_indic_mapping[indic][:]
    drugs_per, extra = divmod(len(drug_list), folds)
    for _ in range(3):
      rng.shuffle(drug_list)

    # Folds that receive one extra drug when len isn't divisible by folds
    extra_folds = set(rng.sample(range(folds), extra))

    start_i = 0
    for i in range(folds):
      num_incl = drugs_per + (1 if i in extra_folds else 0)
      end_i = start_i + num_incl
      left_outs = drug_list[start_i:end_i]
      indic_drugs = drug_list[:start_i] + drug_list[end_i:]
      start_i = end_i

      rows.append('%s\t%d\t%s\t%s\n' % (indic, i + 1,
                                        ','.join(indic_drugs),
                                        ','.join(left_outs)))

  # Single write instead of reopening the file once per fold
  with open(filename, 'w') as f:
    f.writelines(rows)

  return filename

def read_strat_xfold_cross(drug_indic_mapping, filename):
  '''
  Generator over stratified folds previously saved to a TSV file

  drug_indic_mapping - dict, indic (str) : drugs (list of str) pairs
    only its keys are used, to decide which indications to yield
  filename - str, file with saved stratified folds
    tsv: Indic  Fold  In Indic  Left Out (comma-separated drug lists)

  Returns generator object
    Yields indic (str), indic_drugs (tuple of str), left_outs (tuple of str)
      indic is title-cased; an empty drug column yields an empty tuple
    Usage: for indic, drugs, left_out in read_strat_xfold_cross(standard, filename):
    Will iterate thru left out cmpds in same order as input file
    Note: will skip indications not in drug_indic_mapping, the header row,
      and blank lines
    Note: re-seeds the global random module with 0 when iteration starts
  '''
  # Kept for backward compatibility: presumably intended to make bench funcs
  # that use the global RNG (e.g. random_control) repeatable; note that set
  # iteration order of str can still vary between runs (PYTHONHASHSEED)
  random.seed(0)

  with open(filename, 'r') as f:
    lines = f.read().strip().split('\n')

  for line in lines:
    if not line:
      continue
    sections = line.split('\t')
    if sections[:2] == ['Indic', 'Fold']:
      continue  # header row

    # Trailing empty columns may have been stripped from the file's last line
    if len(sections) < 4:
      sections += [''] * (4 - len(sections))
    indic, _fold, indic_drugs, left_outs = sections
    indic = indic.title()

    if indic not in drug_indic_mapping:
      continue

    # ''.split(',') would give ('',); an empty column means no drugs
    indic_drugs = tuple(indic_drugs.split(',')) if indic_drugs else ()
    left_outs = tuple(left_outs.split(',')) if left_outs else ()

    yield indic, indic_drugs, left_outs
      

##############################################################################
# Scoring functions (rank, AUROC, NDCG)                                      #
##############################################################################

def get_rank(ranked_list, true_cmpds):
  '''
  Find the 1-based positions of given cmpds in a ranked list

  ranked_list - list of str, cmpd names/ids in a specific predicted order
  true_cmpds - str or list of str, cmpd name/id(s) to be matched to ranks
    all true_cmpds should appear at least once in ranked_list
    (in this case, the list of withheld compounds)

  Returns list of int representing ranks of true_cmpds in ranked_list
    (first occurrence, in the same order as true_cmpds)
  Raises ValueError if a true_cmpd is not in ranked_list
  '''
  # isinstance so str subclasses are also treated as a single cmpd
  if isinstance(true_cmpds, str):
    true_cmpds = [true_cmpds]

  ranks = []
  for cmpd in true_cmpds:
    try:
      ranks.append(ranked_list.index(cmpd) + 1)
    except ValueError:
      raise ValueError('%r was not found in the ranked list returned by the '
                       'benchmarked function' % (cmpd,)) from None

  return ranks

def get_auroc(filename, max_fpr=2.0):
  '''
  Compute the (optionally partial) area under the ROC curve of a results file

  filename - str, name of the file from which AUROC will be calculated
    should be tsv with 5 columns: Indic  Split  Cmpd  Rank  OutOf
    (file generated using benchmarking function); a header row is skipped
  max_fpr (optional) - float, max FPR for AUROC to be calculated up through
    any number >= 1 or no number will calculate full AUROC

  The ROC curve has one point per rank cutoff r. TPR is the fraction of all
  withheld cmpds ranked <= r. FPR is the fraction of all other ranked cmpds
  (over every split) ranked <= r. The curve starts at (0, 0). The area uses
  the trapezoid rule; for a partial AUROC the last segment is linearly
  interpolated up to max_fpr.

  Prints and returns area under receiver operator curve metric for given data
  Raises ValueError if there are no withheld or no non-withheld cmpds
  '''
  with open(filename, 'r') as f:
    lines = f.read().strip().split('\n')

  # Skip the header row if present (its Rank column is not numeric)
  if not lines[0].split('\t')[3].isdigit():
    lines = lines[1:]

  # Extract data: # of withheld cmpds at each rank, and # of distinct
  # (indic, split) pairs ranked against each candidate-list size (OutOf)
  seen_splits = set()
  out_of_counts = {}
  rank_counts = {}
  for line in lines:
    cells = line.split('\t')
    rank = int(cells[3])
    rank_counts[rank] = rank_counts.get(rank, 0) + 1
    split = (cells[0], cells[1])
    if split in seen_splits:
      continue
    seen_splits.add(split)
    out_of = int(cells[4])
    out_of_counts[out_of] = out_of_counts.get(out_of, 0) + 1

  max_rank = max(out_of_counts)
  active_splits = sum(out_of_counts.values())

  # Cumulative true/false positive counts per rank cutoff; index 0 is the
  # (0, 0) origin of the ROC curve, index r is the cutoff "rank <= r"
  tps = [0]
  fps = [0]
  for rank in range(1, max_rank + 1):
    hits = rank_counts.get(rank, 0)
    # each split still holding a cmpd at this rank contributes one cmpd;
    # those that are not withheld drugs are false positives
    tps.append(tps[-1] + hits)
    fps.append(fps[-1] + active_splits - hits)
    # splits whose candidate lists end at this rank drop out afterwards
    active_splits -= out_of_counts.get(rank, 0)

  if tps[-1] == 0 or fps[-1] == 0:
    raise ValueError('AUROC is undefined for %s: need at least one withheld '
                     'and one non-withheld cmpd' % filename)

  fpr = [x / fps[-1] for x in fps]
  tpr = [x / tps[-1] for x in tps]

  # Trapezoid-rule area under the ROC curve, stopping at max_fpr
  total_area = 0
  for i in range(1, len(fpr)):
    if fpr[i] > max_fpr:
      # Partial final segment from fpr[i-1] to max_fpr: rectangle at
      # tpr[i-1] plus the triangle of the linearly interpolated rise
      run = max_fpr - fpr[i-1]
      rise = (tpr[i] - tpr[i-1]) / (fpr[i] - fpr[i-1]) * run
      total_area += tpr[i-1] * run + rise * run / 2
      break
    total_area += (tpr[i] + tpr[i-1]) / 2 * (fpr[i] - fpr[i-1])

  print('AUROC', total_area, ('at max FPR ' + \
        str(max_fpr)) if max_fpr < 1 else '')

  return total_area

def write_auroc_graphable(filename, new_filename):
  '''
  Write the ROC curve of a results file as a rank-by-rank FPR/TPR table

  filename - str, name of the results file the curve is computed from
    should be tsv with 5 columns: Indic  Split  Cmpd  Rank  OutOf
    (file generated using benchmarking function); a header row is skipped
  new_filename - str, name of output file to be created (overwritten)

  Creates TSV file named after new_filename with a header row and 3 columns:
      rank - int, rank threshold at which FPR/TPR was calculated
      FPR - float with 3 decimal places, false positive rate through rank
      TPR - float with 3 decimal places, true positive rate through rank
    One row per rank from 1 to the largest OutOf value in the input
    Data can be used to create a ROC graph (FPR vs. TPR)
  '''
  with open(filename, 'r') as handle:
    rows = handle.read().strip().split('\n')

  # A non-numeric Rank column marks the header row
  if not rows[0].split('\t')[3].isdigit():
    rows = rows[1:]

  # Every withheld-cmpd rank, plus how many distinct (indic, split) pairs
  # were ranked against each candidate-list size
  seen_splits = set()
  size_counts = {}
  hit_ranks = []
  for row in rows:
    cells = row.split('\t')
    hit_ranks.append(int(cells[3]))
    split_id = (cells[0], cells[1])
    if split_id not in seen_splits:
      seen_splits.add(split_id)
      size = int(cells[4])
      size_counts[size] = size_counts.get(size, 0) + 1

  hit_ranks.sort()
  deepest = max(size_counts.keys())
  live_splits = sum(size_counts.values())

  # Running true/false positive totals at each rank cutoff
  tp_running, fp_running = [], []
  tp_total = fp_total = 0
  cursor = 0
  for rank in range(1, deepest + 1):
    hits = 0
    while cursor < len(hit_ranks) and hit_ranks[cursor] == rank:
      cursor += 1
      hits += 1
    tp_total += hits
    fp_total += live_splits - hits
    tp_running.append(tp_total)
    fp_running.append(fp_total)
    # splits whose candidate lists end here contribute nothing further
    live_splits -= size_counts.get(rank, 0)

  fpr = [count / fp_running[-1] for count in fp_running]
  tpr = [count / tp_running[-1] for count in tp_running]

  with open(new_filename, 'w') as handle:
    handle.write('Rank\tFPR\tTPR\n')
    for rank, (fp_rate, tp_rate) in enumerate(zip(fpr, tpr), start=1):
      handle.write('%d\t%.3f\t%.3f\n' % (rank, fp_rate, tp_rate))
    

def get_ndcg(filename, rank_cutoff=None):
  '''
  Compute normalized discounted cumulative gain (NDCG) of a results file

  filename - str, name of the file from which ndcg will be calculated
    should be tsv with 5 columns: Indic  Split  Cmpd  Rank  OutOf
    (file generated using benchmarking function); a header row is skipped
  rank_cutoff  - int, max rank at or below which NDCG will be calculated
    None will result in no rank cutoff being used

  Each withheld cmpd at rank r (<= rank_cutoff) adds 1/log2(r + 1) to the
  DCG. The ideal DCG places each split's k withheld cmpds at ranks 1..k.
  DCG and ideal DCG are summed over all splits before dividing.

  Prints and returns normalized discounted cumulative gain for given data
  '''
  with open(filename, 'r') as handle:
    rows = handle.read().strip().split('\n')

  # A non-numeric Rank column marks the header row
  if not rows[0].split('\t')[3].isdigit():
    rows = rows[1:]

  # Withheld-cmpd ranks grouped by (indic, split), in file order
  ranks_by_split = {}
  for row in rows:
    cells = row.split('\t')
    ranks_by_split.setdefault((cells[0], cells[1]), []).append(int(cells[3]))

  if rank_cutoff is None:
    rank_cutoff = max([max(ranks) for ranks in ranks_by_split.values()])

  def discount(position):
    # gain contributed by a hit at the given 1-based position
    return 1 / math.log(position + 1, 2)

  dcg = 0
  idcg = 0
  for ranks in ranks_by_split.values():
    for rank in ranks:
      if rank <= rank_cutoff:
        dcg += discount(rank)
    for slot in range(1, len(ranks) + 1):
      if slot <= rank_cutoff:
        idcg += discount(slot)

  score = dcg / idcg
  deepest = max([max(ranks) for ranks in ranks_by_split.values()])
  print('NDCG', score, ('at cutoff ' + \
        str(rank_cutoff)) if rank_cutoff < deepest else '')
  return score
  

##############################################################################
# Main benchmarking function                                                 #
##############################################################################

def benchmarking(drug_indic_mapping, cmpd_set, gen_func, bench_func,
                 gen_func_args=None, bench_func_args=None, out_file='',
                 pass_indic_id=False):
  '''
  Primary benchmarking function. It runs bench_func on every split produced
  by gen_func and records the rank of each withheld drug in a results file.

  drug_indic_mapping - dict of str:list, known drug-indication mapping
    each indication name/id (key) is mapped to 1+ drug name/ids (values)
  cmpd_set - set of str, all drug name/ids in the drug-indic mapping
    compounds not in the drug-indication mappings may also be included
    every withheld drug must be in cmpd_set, otherwise it cannot be ranked
  gen_func - func, splitting function (from data-splitting generators)
    must take as input the drug-indication mapping (drug_indic_mapping)
      may take additional input stored in gen_func_args dictionary
    must yield indiv. indication name/ids, associated drugs, & withheld drug(s)
  bench_func - func, the function to be benchmarked
    generally a wrapper function that interfaces with the actual platform
    must take as input a set of indicated drugs and a set of non-indicated drugs
      if pass_indic_id is True, must also take as input the indication name/id
      may take additional arguments stored in bench_func_args dictionary
    must output a sorted list of all compounds in non-indicated drug list
      sorting order: most likely to be effective for that indication to least
  gen_func_args (optional) - dict of str:any, additional args for gen_func
    ex: number of splits, file for splits to be read from
    to use read_strat_xfold_cross, would pass gen_func_args={'filename':name}
    default None (no extra args)
  bench_func_args (optional) - dict of str:any, additional args for bench_func
    ex: pre-initialized objects to be used (see cando_wrapper.py)
    default None (no extra args)
  out_file (optional) - str, results path; default 'unnamed_benchmarking.tsv'
  pass_indic_id (optional) - bool, default False
    if True, bench_func will receive indication name/id as its third argument

  Splits yielding no withheld drugs are skipped, but still count toward the
  per-indication split number.

  Creates results file in TSV format (header row, then 5 columns):
    indic - str, indication name/id
    split - int, number of split in which association was assessed
    cmpd - str, compound name/id
    rank - int, the predicted rank of the withheld drug for the indication
    out of - int, the number of compounds the drug was ranked against
  Returns out_file (str), the path of the results file written
  Raises ValueError (via get_rank) if a withheld drug is missing from the
    list returned by bench_func
  '''
  # None defaults avoid sharing one mutable dict across calls
  gen_func_args = {} if gen_func_args is None else gen_func_args
  bench_func_args = {} if bench_func_args is None else bench_func_args

  print('Benchmarking...')

  if not out_file:
    out_file = 'unnamed_benchmarking.tsv'

  with open(out_file, 'w') as f:
    f.write('Indic\tSplit\tCmpd\tRank\tOut Of\n')

  indic_counters = {}
  for indic, drugs, left_outs in gen_func(drug_indic_mapping, **gen_func_args):
    # create indication based on drugs, assess ranks of left_outs
    indic_counters[indic] = indic_counters.get(indic, 0) + 1

    if not left_outs:
      continue

    indic_set = set(drugs)
    non_indic = cmpd_set - indic_set
    if pass_indic_id:
      ranked_list = bench_func(indic_set, non_indic, indic, **bench_func_args)
    else:
      ranked_list = bench_func(indic_set, non_indic, **bench_func_args)

    ranks = get_rank(ranked_list, left_outs)

    rows = ['%s\t%d\t%s\t%d\t%d\n' % (indic, indic_counters[indic],
                                      left_drug, rank, len(non_indic))
            for left_drug, rank in zip(left_outs, ranks)]

    # Appended after every split so results so far are on disk if the run
    # is interrupted
    with open(out_file, 'a') as f:
      f.writelines(rows)

  return out_file
    
##############################################################################
# Test/sample code                                                           #
##############################################################################

def random_control(indic, non_indic, indic_id=None):
  '''
  Random-ranking baseline; also an example of how prediction functions
  receive data

  indic - set of str, names or ids of all drugs associated with an indication
    (predictions should be made based on this set; unused here)
  non_indic - set of str, names/ids of all other, unassociated drugs
    (will contain 1+ withheld indication-associated drugs for assessment)
  indic_id (optional) - str, the id of the indication being assessed
    (can be used when the benchmarked function requires indication information)
    (will only be passed if pass_indic_id=True in benchmarking protocol)
  (Additional parameters, if necessary, may be passed through bench_func_args
    argument of benchmarking protocol)

  Returns ranks (list of str), the non_indic drugs in a random order
    (drawn from the global random module, so results depend on its state)
    (Should contain all drugs from non_indic set and none from indic set)
    (Should be ordered most-least likely to be effective for this indication)
  '''
  shuffled = [*non_indic]
  random.shuffle(shuffled)
  return shuffled

def alphabetical_test(indic, non_indic, indic_id=None):
  '''
  Serves as a deterministic test to check that this module is functioning as expected

  indic - set of str, names or ids of all drugs associated with an indication
    (unused)
  non_indic - set of str, names/ids of all other, unassociated drugs
  indic_id (optional) - str, the id of the indication being assessed (unused)

  Returns ranks (list of str), alphabetically ordered list of drug names/ids
  '''
  return sorted(non_indic)

if __name__ == '__main__':
  # Prepare the drug-indication mapping: indication IDs -> drug IDs, keeping
  # only indications with at least 2 drugs
  indic_drug_map, cmpd_set = \
                  extract_drug_indic_mapping('data/ctd_approved_drugs.tsv',\
                                   min_d=2,key='iid',val='did')

  # Pre-computed stratified 10-fold splits used by both runs below
  folds_file = 'data/id_10fold_cross_seed0_ctd_cnd.tsv'

  # Benchmark a random ranker on the given mapping and training/testing splits
  random_results = 'randomized_results.tsv'
  benchmarking(indic_drug_map, cmpd_set, read_strat_xfold_cross, random_control,\
               {'filename': folds_file}, {}, random_results,\
               pass_indic_id=True)

  # Calculate AUROC and NDCG from the results; will print out to terminal
  get_auroc(random_results, max_fpr=0.05) # theoretically 0.00125
  get_auroc(random_results) # theoretically 0.5
  get_ndcg(random_results, rank_cutoff=10)
  get_ndcg(random_results)
  write_auroc_graphable(random_results, 'random_TPR_FPR_by_rank.tsv')
  print()

  # Benchmark a deterministic (alphabetical) ranker on the same splits
  alpha_results = 'alphabetical_results.tsv'
  benchmarking(indic_drug_map, cmpd_set, read_strat_xfold_cross, alphabetical_test,\
               {'filename': folds_file}, {}, alpha_results,\
               pass_indic_id=True)

  # Score the file the alphabetical run just wrote; presumably near the
  # random baseline's values if drug names/IDs carry no indication signal
  get_auroc(alpha_results, max_fpr=0.05)
  get_auroc(alpha_results)
  get_ndcg(alpha_results, rank_cutoff=10)
  get_ndcg(alpha_results)
  write_auroc_graphable(alpha_results, 'alphabetical_TPR_FPR_by_rank.tsv')
