from collections import Counter

import nltk
import nltk.translate.bleu_score as bleu
import nltk.translate.gleu_score as gleu
from rouge_score import rouge_scorer

def critera4_5(trans1, trans2, n):
    """Count words of ``trans2`` (from index ``n``) that are not matched, in order, in ``trans1``.

    Each word ``trans2[i]`` for ``i >= n`` is searched for in ``trans1``.  The
    search starts at word index ``n + 1`` and, after every successful match,
    resumes just past the matched word, so matches have to appear in the same
    relative order in both sentences.

    Args:
        trans1: Reference sentence (whitespace-separated words).
        trans2: Sentence whose words are looked up in ``trans1``.
        n: Index of the first word of ``trans2`` to examine.

    Returns:
        The number of examined words that were scored as "not found".

    NOTE: when the search window is empty (the previous match was the last
    word of ``trans1``, or ``n + 1`` is already past its end) no comparison
    takes place and the outcome of the previous word is reused.  This mirrors
    the long-standing behaviour of this function and is kept on purpose.
    """
    # Tokenize once instead of re-splitting both strings inside the loops.
    words1 = trans1.split()
    words2 = trans2.split()
    len1 = len(words1)

    search_from = n + 1  # first index of trans1 still eligible for a match
    penalty = 0          # outcome of the most recent search (0 = found)
    total = 0
    for i in range(n, len(words2)):
        target = words2[i]
        for j in range(search_from, len1):
            if words1[j] == target:
                penalty = 0
                search_from = j + 1
                break
            penalty = 1
        total += penalty
    return total

def diff_score(trans1, trans2):
    """Return a heuristic difference score between two translations.

    One point is added for each of the following criteria that holds:

    1. the sentences have a different number of words;
    2. their first words differ;
    3. some word differs within the length of the shorter sentence;
    4-5. (via ``critera4_5``) one point per word of ``trans2``, starting at the
       first differing position, that is not matched in order in ``trans1``;
    6. the last words or the last characters differ;
    7. the number of commas differs.

    Args:
        trans1: First translation.
        trans2: Second translation.

    Returns:
        The accumulated integer score; ``0`` when either translation contains
        no words (empty or whitespace-only), since nothing can be compared.
    """
    words1 = trans1.split()
    words2 = trans2.split()

    # Guard against blank input: indexing the first/last word would fail.
    if not words1 or not words2:
        return 0

    d = 0

    # criteria1: sentence length
    if len(words1) != len(words2):
        d += 1

    # criteria2: first word
    if words1[0] != words2[0]:
        d += 1

    # criteria3: first mismatch inside the common length
    common = min(len(words1), len(words2))
    first_mismatch = next(
        (i for i in range(common) if words1[i] != words2[i]), None
    )
    if first_mismatch is not None:
        d += 1
        n = first_mismatch
    else:
        # No mismatch: continue from the last index that was compared.
        n = common - 1

    # criteria4_5: add to the running total; a plain assignment here would
    # throw away the points collected by criteria 1-3.
    d += critera4_5(trans1, trans2, n)

    # criteria6: last word and last character
    if not (words1[-1] == words2[-1] and trans1[-1] == trans2[-1]):
        d += 1

    # criteria7: number of commas
    if trans1.count(",") != trans2.count(","):
        d += 1

    return d

def manual_diff_score(trans, sources_name):
    """Pick the translation with the lowest mean ``diff_score`` to all others.

    Args:
        trans: Mapping from ``"0"``, ``"1"``, ... to the translation text of
            each source.
        sources_name: Mapping with the same keys to the source names; its
            length defines how many candidates are compared.

    Returns:
        ``(translation, source_name)`` of the chosen candidate.  Ties are
        resolved in favour of the lowest index.

    Raises:
        ValueError: if ``sources_name`` is empty.
    """
    n = len(sources_name)
    # With a single source there is nothing to average over; avoid 0 / 0.
    pairs = max(n - 1, 1)

    mean_diff = []
    for i in range(n):
        current = trans[str(i)]
        total = 0
        for j in range(n):
            if i == j:
                continue
            other = trans[str(j)]
            # A blank translation cannot be compared word by word; skip the pair.
            if not current.strip() or not other.strip():
                continue
            total += diff_score(current, other)
        mean_diff.append(total / pairs)

    # Blank candidates collect no points at all, so they must not win merely
    # because they were skipped: prefer non-blank candidates when any exist.
    candidates = [i for i in range(n) if trans[str(i)].strip()] or list(range(n))
    chosen = min(candidates, key=mean_diff.__getitem__)
    return trans[str(chosen)], sources_name[str(chosen)]

def bleu_diff_score(trans, sources_name):
    """Pick the translation with the lowest mean BLEU distance to all others.

    The distance between candidate ``i`` and candidate ``j`` is
    ``1 - sentence_bleu([tokens_j], tokens_i)``, i.e. ``j`` acts as the single
    reference and ``i`` as the hypothesis.

    Args:
        trans: Mapping from ``"0"``, ``"1"``, ... to the translation text of
            each source.
        sources_name: Mapping with the same keys to the source names; its
            length defines how many candidates are compared.

    Returns:
        ``(translation, source_name)`` of the chosen candidate.  Ties are
        resolved in favour of the lowest index.

    Raises:
        ValueError: if ``sources_name`` is empty.
    """
    n = len(sources_name)
    # Tokenize every candidate once instead of once per pair.
    tokens = [trans[str(i)].split() for i in range(n)]
    # With a single source there is nothing to average over; avoid 0 / 0.
    pairs = max(n - 1, 1)

    mean_dist = []
    for i, hypothesis in enumerate(tokens):
        total = 0
        for j, reference in enumerate(tokens):
            if i != j:
                total += 1 - bleu.sentence_bleu([reference], hypothesis)
        mean_dist.append(total / pairs)

    chosen = mean_dist.index(min(mean_dist))
    return trans[str(chosen)], sources_name[str(chosen)]


def gleu_diff_score(trans, sources_name):
    """Pick the translation with the lowest mean GLEU distance to all others.

    The distance between candidate ``i`` and candidate ``j`` is
    ``1 - sentence_gleu([tokens_j], tokens_i)``, i.e. ``j`` acts as the single
    reference and ``i`` as the hypothesis.

    Args:
        trans: Mapping from ``"0"``, ``"1"``, ... to the translation text of
            each source.
        sources_name: Mapping with the same keys to the source names; its
            length defines how many candidates are compared.

    Returns:
        ``(translation, source_name)`` of the chosen candidate.  Ties are
        resolved in favour of the lowest index.

    Raises:
        ValueError: if ``sources_name`` is empty.
    """
    n = len(sources_name)
    # Tokenize every candidate once instead of once per pair.
    tokens = [trans[str(i)].split() for i in range(n)]
    # With a single source there is nothing to average over; avoid 0 / 0.
    pairs = max(n - 1, 1)

    mean_dist = []
    for i, hypothesis in enumerate(tokens):
        total = 0
        for j, reference in enumerate(tokens):
            if i != j:
                total += 1 - gleu.sentence_gleu([reference], hypothesis)
        mean_dist.append(total / pairs)

    chosen = mean_dist.index(min(mean_dist))
    return trans[str(chosen)], sources_name[str(chosen)]

def meteor_diff_score(trans, sources_name):
    """Pick the translation with the lowest mean METEOR distance to all others.

    The distance between candidate ``i`` and candidate ``j`` is
    ``1 - meteor_score([trans_j], trans_i)``, i.e. ``j`` acts as the single
    reference and ``i`` as the hypothesis.

    Args:
        trans: Mapping from ``"0"``, ``"1"``, ... to the translation text of
            each source.
        sources_name: Mapping with the same keys to the source names; its
            length defines how many candidates are compared.

    Returns:
        ``(translation, source_name)`` of the chosen candidate.  Ties are
        resolved in favour of the lowest index.

    Raises:
        ValueError: if ``sources_name`` is empty.
    """
    n = len(sources_name)
    # With a single source there is nothing to average over; avoid 0 / 0.
    pairs = max(n - 1, 1)

    mean_dist = []
    for i in range(n):
        total = 0
        for j in range(n):
            if i != j:
                # NOTE(review): the texts are passed as plain strings; newer
                # NLTK releases appear to require pre-tokenized input for
                # meteor_score - confirm against the installed version.
                total += 1 - nltk.translate.meteor_score.meteor_score(
                    [trans[str(j)]], trans[str(i)]
                )
        mean_dist.append(total / pairs)

    # The values are distances (1 - score), so the best candidate is the one
    # with the SMALLEST mean, exactly as in the BLEU/GLEU/ROUGE selectors.
    # Taking the maximum here would return the least similar translation.
    chosen = mean_dist.index(min(mean_dist))
    return trans[str(chosen)], sources_name[str(chosen)]

# Module-level ROUGE scorer (ROUGE-2 and ROUGE-L, stemming enabled), created
# once at import time and reused by rouge_diff_score for every pair.
scorer = rouge_scorer.RougeScorer(['rouge2', 'rougeL'], use_stemmer=True)
def rouge_diff_score(trans, sources_name):
    """Pick the best translation by mean ROUGE-2 and by mean ROUGE-L distance.

    For every ordered pair, candidate ``j`` is scored as the target and
    candidate ``i`` as the prediction with the module-level ``scorer``; the
    distance is ``1 - F-measure``.  The selection is done independently for
    ROUGE-2 and for ROUGE-L.

    Args:
        trans: Mapping from ``"0"``, ``"1"``, ... to the translation text of
            each source.
        sources_name: Mapping with the same keys to the source names; its
            length defines how many candidates are compared.

    Returns:
        A 4-tuple ``(translation_rouge2, source_rouge2, translation_rougeL,
        source_rougeL)``.  Ties are resolved in favour of the lowest index.

    Raises:
        ValueError: if ``sources_name`` is empty.
    """
    n = len(sources_name)
    # With a single source there is nothing to average over; avoid 0 / 0.
    pairs = max(n - 1, 1)

    mean_rouge2 = []
    mean_rouge_l = []
    for i in range(n):
        total_rouge2 = 0
        total_rouge_l = 0
        for j in range(n):
            if i == j:
                continue
            scores = scorer.score(trans[str(j)], trans[str(i)])
            # Each score is a (precision, recall, fmeasure) tuple.
            total_rouge2 += 1 - scores['rouge2'].fmeasure
            total_rouge_l += 1 - scores['rougeL'].fmeasure
        mean_rouge2.append(total_rouge2 / pairs)
        mean_rouge_l.append(total_rouge_l / pairs)

    best_rouge2 = mean_rouge2.index(min(mean_rouge2))
    best_rouge_l = mean_rouge_l.index(min(mean_rouge_l))
    return (
        trans[str(best_rouge2)],
        sources_name[str(best_rouge2)],
        trans[str(best_rouge_l)],
        sources_name[str(best_rouge_l)],
    )

def _next_priority_pick(O, priority_list, taken, default):
    """Return ``(output, source)`` of the first priority source not in ``taken``, else ``default``."""
    # ``taken`` may be a single source name; wrap it so that the membership
    # test compares whole names instead of doing a substring search.
    if isinstance(taken, str):
        taken = [taken]
    for position, name in enumerate(priority_list):
        if name not in taken:
            # presumably O is ordered like priority_list - verify against caller
            return O[position], name
    return default


def selection_source_transliteration(sources_name, O, priority_list):
    """Select a first and a second output by majority vote with priority fallback.

    The distinct outputs in ``O`` are ranked by how many sources produced
    them (see ``two_sources_two_outputs``).

    * If one output is strictly the most frequent, it becomes the first
      result.  The second result is the runner-up when it is unambiguous
      (exactly two distinct outputs, or the second count strictly exceeds
      the third); otherwise it is the output of the highest-priority source
      that did not produce the first result.
    * If the top count is tied, the first result is the top or second-ranked
      group that contains ``priority_list[0]``, or else ``O[0]`` attributed to
      ``priority_list[0]``.  The second result is the output of the
      highest-priority source not already used for the first result.

    When every priority source is already used, the second-ranked group
    returned by ``two_sources_two_outputs`` is used as the second result.

    Args:
        sources_name: Mapping from ``"0"``, ``"1"``, ... to source names.
        O: List of outputs, one per source.
        priority_list: Source names in decreasing order of preference.

    Returns:
        ``((output1, source1), (output2, source2))``.  ``source1``/``source2``
        are lists of source names when taken from a vote group and a single
        name when taken from ``priority_list``.

    Side effects:
        Prints the sorted vote counts and some of the chosen results to
        stdout; these prints are kept because callers may rely on the output.
    """
    counts = sorted(Counter(O).values(), reverse=True)
    print(counts)

    # The list is sorted, so comparing with the next count is enough.
    unique_top = len(counts) == 1 or (len(counts) > 1 and counts[0] > counts[1])

    (o1, s1), (o2, s2) = two_sources_two_outputs(sources_name, O)

    if unique_top:
        output1, source1 = o1, s1
        print(output1, source1)
        if len(counts) == 2:
            output2, source2 = o2, s2
            print("1", output2, source2)
        elif len(counts) >= 3 and counts[1] > counts[2]:
            # The runner-up is unambiguous as well.
            output2, source2 = o2, s2
        else:
            output2, source2 = _next_priority_pick(
                O, priority_list, source1, (o2, s2)
            )
    else:
        top_priority = priority_list[0]
        if top_priority in s1:
            output1, source1 = o1, s1
            print(output1, source1)
        elif top_priority in s2:
            output1, source1 = o2, s2
            print(output1, source1)
        else:
            output1, source1 = O[0], top_priority
        output2, source2 = _next_priority_pick(
            O, priority_list, source1, (o2, s2)
        )

    return (output1, source1), (output2, source2)

def two_sources_two_outputs(sources_name, O):
    """Return the two most frequent outputs together with the sources that produced them.

    Args:
        sources_name: Mapping from ``"0"``, ``"1"``, ... (the position of an
            output in ``O``, as a string) to the source name.
        O: Sequence of outputs, one per source.

    Returns:
        ``((out1, sources1), (out2, sources2))`` where ``out1`` and ``out2``
        are the most and second most frequent outputs and ``sources1`` /
        ``sources2`` list the names of the sources that produced them.
        Outputs with equal frequency keep their order of first appearance.
        When all outputs are identical the second pair is ``("", "")``.

    Raises:
        IndexError: if ``O`` is empty.
    """
    # Group the positions of every distinct output in a single pass; dict
    # insertion order records the order of first appearance.
    positions = {}
    for index, output in enumerate(O):
        positions.setdefault(output, []).append(index)

    # sorted() is stable (also with reverse=True), so equally frequent
    # outputs stay in order of first appearance.
    ranked = sorted(positions, key=lambda output: len(positions[output]), reverse=True)

    out1 = ranked[0]
    source1 = [sources_name[str(i)] for i in positions[out1]]
    if len(ranked) == 1:
        return (out1, source1), ("", "")

    out2 = ranked[1]
    source2 = [sources_name[str(i)] for i in positions[out2]]
    return (out1, source1), (out2, source2)