import numpy as np
import plots
import pickle
import torch
import gc
import pandas as pd
import tensorflow as tf
import tensorflow_text
import pprint
import spacy
from transformers import BertTokenizerFast, BertModel
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.utils import shuffle
from sklearn.metrics import make_scorer, precision_score

# Force a garbage-collection pass as soon as this module is executed; presumably meant to reclaim
# memory left over from earlier runs (e.g. previous notebook cells) — TODO confirm it is still needed.
gc.collect()

from sklearn.model_selection import StratifiedKFold

# Root folder for every artifact this script reads or writes: cached ElasticSearch results,
# BERT embeddings, metrics, plot labels and logged utterances.
# NOTE(review): looks like a Google Drive folder mounted in Colab — adjust when running elsewhere.
infosDir = "/content/drive/My Drive/faculdade/fct-miei/04_ano4_(year4)/semestre1/ri/infos_projeto"

# Utterance log file (inside <infosDir>/utterances) that saveUserUtterances appends to, per phase/method.
utterancesPath = {
  phase: infosDir + "/utterances" + "/" + fileName
  for phase, fileName in (
    ("elastic_search", "normal.txt"),
    ("elastic_search_method1", "utterance.txt"),
    ("elastic_search_method2", "entities.txt"),
    ("elastic_search_method3", "t5.txt"),
    ("elastic_search_methodFinal", "final.txt"),
  )
}

# Only the first `turnsPerConv` turns of each conversation are evaluated.
turnsPerConv = 8
# Length of the per-turn precision / recall vectors produced by the test bed.
measuresPerTurn = 10
# Number of (query, passage) pairs sent to BERT in one batch by getBERTResult.
sendToBERT = 8

# save user utterances
def saveUserUtterances(phase, content):
  """Append `content` to the utterance log file registered for `phase` in `utterancesPath`.

  Raises KeyError when `phase` is not a key of `utterancesPath`; writes nothing when the
  registered path is None.
  """
  path = utterancesPath[phase]
  if path is None:
    return
  # `with` closes the file even if the write raises; UTF-8 so non-ASCII utterances never fail
  # under a non-UTF-8 locale.
  with open(path, "a", encoding="utf-8") as f:
    f.write(content)

# save and load metrics
## save
def saveMetrics(fileName, relDocsPerTurn, metrics):
  """Write the "aps", "ndcg5s", "precisions" and "recalls" arrays of `metrics` to
  <infosDir>/metrics/<relDocsPerTurn>/<fileName> with np.savez (one entry per metric name)."""
  target = "{}/metrics/{}/{}".format(infosDir, relDocsPerTurn, fileName)
  arrays = {name: metrics[name] for name in ("aps", "ndcg5s", "precisions", "recalls")}
  np.savez(target, **arrays)

def saveMetricsLMDAndBERT(fileName, relDocsPerTurn, metrics):
  """Write paired LMD/BERT metrics to <infosDir>/metrics/<relDocsPerTurn>/<fileName>.

  `metrics[name]["lmd"]` and `metrics[name]["bert"]` are stored as "<name>LMD" and "<name>BERT"
  for name in aps, ndcg5s, precisions, recalls (the layout loadMetricsLMDAndBERT reads back).
  """
  target = "{}/metrics/{}/{}".format(infosDir, relDocsPerTurn, fileName)
  arrays = {}
  for name in ("aps", "ndcg5s", "precisions", "recalls"):
    arrays[name + "LMD"] = metrics[name]["lmd"]
    arrays[name + "BERT"] = metrics[name]["bert"]
  np.savez(target, **arrays)
## load
def loadMetrics(fileName, relDocsPerTurn):
  """Open <infosDir>/metrics/<relDocsPerTurn>/<fileName> (as written by saveMetrics).

  Returns the np.load result itself; the caller owns the (lazily read) NpzFile.
  """
  source = "{}/metrics/{}/{}".format(infosDir, relDocsPerTurn, fileName)
  archive = np.load(source)
  return archive

def loadMetricsLMDAndBERT(fileName, relDocsPerTurn):
  """Read paired LMD/BERT metrics written by saveMetricsLMDAndBERT.

  Returns {"aps"|"ndcg5s"|"recalls"|"precisions": {"lmd": array, "bert": array}}, built from the
  "<name>LMD" / "<name>BERT" entries of <infosDir>/metrics/<relDocsPerTurn>/<fileName>.
  """
  path = infosDir + "/metrics/" + str(relDocsPerTurn) + "/" + fileName
  # Indexing an NpzFile reads the array eagerly, so the archive can be closed before returning
  # (previously the file handle was left open).
  with np.load(path) as content:
    return {
      name: {"lmd": content[name + "LMD"], "bert": content[name + "BERT"]}
      for name in ("aps", "ndcg5s", "recalls", "precisions")
    }

# save and load labels
## save
def savePlotLabels(labels):
  """Store labels["convNumbers"] and labels["convNames"] in <infosDir>/metrics/labels.npz."""
  labelsFile = infosDir + "/metrics/labels.npz"
  np.savez(labelsFile, **{key: labels[key] for key in ("convNumbers", "convNames")})

## load
def loadPlotLabels():
  """Open <infosDir>/metrics/labels.npz (written by savePlotLabels); keys "convNumbers", "convNames"."""
  labelsFile = infosDir + "/metrics/labels.npz"
  archive = np.load(labelsFile)
  return archive

# get ElasticSearch result
## normal
def getESResultNormal(updateElasticSearchResults, esResultsFolder, es, utterance, topicTurnID, numDocs, setName):
  """Return the ElasticSearch result for one turn, either freshly queried or from the on-disk cache.

  The cache file is <infosDir>/<esResultsFolder>/<setName>/<numDocs>/<topicTurnID>.pkl.
  When `updateElasticSearchResults` is true, `es.search_body(query=utterance, numDocs=numDocs)` is
  run and its result is pickled to that file (overwriting it); otherwise the file is unpickled.
  NOTE: the cache is loaded with pickle, so only point infosDir at trusted data.
  """
  cachePath = infosDir + "/" + esResultsFolder + "/" + setName + "/" + str(numDocs) + "/" + topicTurnID + ".pkl"
  if updateElasticSearchResults:
    esResult = es.search_body(query=utterance, numDocs=numDocs)
    # `with` flushes and closes the file deterministically instead of relying on the GC.
    with open(cachePath, "wb") as f:
      pickle.dump(esResult, f)
  else:
    with open(cachePath, "rb") as f:
      esResult = pickle.load(f)
  return esResult

## with entities
def getESResultEntities(updateElasticSearchResults, esResultsFolder, es, utterance, entities, topicTurnID, numDocs, setName):
  """Like getESResultNormal, but queries ElasticSearch with the given entities boosted.

  With no entities this simply delegates to getESResultNormal. Otherwise, when
  `updateElasticSearchResults` is true, `es.search_with_boosted_entities` is called with
  np.ones(len(entities)) as its third argument (presumably one boost weight per entity) and the
  result is pickled to <infosDir>/<esResultsFolder>/<setName>/<numDocs>/<topicTurnID>.pkl;
  when false, the result is unpickled from that file.
  """
  if len(entities) == 0:
    return getESResultNormal(updateElasticSearchResults, esResultsFolder, es, utterance, topicTurnID, numDocs, setName)

  cachePath = infosDir + "/" + esResultsFolder + "/" + setName + "/" + str(numDocs) + "/" + topicTurnID + ".pkl"
  if updateElasticSearchResults:
    esResult = es.search_with_boosted_entities(utterance, entities, np.ones(len(entities)), numDocs)
    # `with` flushes and closes the file deterministically instead of relying on the GC.
    with open(cachePath, "wb") as f:
      pickle.dump(esResult, f)
  else:
    with open(cachePath, "rb") as f:
      esResult = pickle.load(f)
  return esResult

# get metrics
def getMetrics(testBed, result, topicTurnID):
  """Score one ranked result for turn `topicTurnID`.

  Returns [p10, recall, ap, ndcg5, precisions, recalls]. For an empty `result` the test bed is
  not called: the scalars are 0 and the two vectors are zeros of length measuresPerTurn.
  """
  if not np.size(result):
    return [0, 0, 0, 0, np.zeros(measuresPerTurn), np.zeros(measuresPerTurn)]
  ranking = result[['_id', '_score']]
  p10, recall, ap, ndcg5, precisions, recalls = testBed.eval(ranking, topicTurnID)
  return [p10, recall, ap, ndcg5, precisions, recalls]

# convert to BERT input
def convertToBERTInput(sentences=None, max_seq_length=512, tokenizer=None, add_cls=True, padding='do_not_pad', truncation=False):
  """Tokenize a single [query_text, doc_text] pair for BERT.

  `sentences` is passed as the text argument of `tokenizer.encode_plus`; only one pair is handled
  per call. Returns the tokenizer's dictionary of PyTorch tensors, including input ids,
  attention mask and token type ids.
  """
  encodeOptions = dict(
    add_special_tokens=add_cls,
    padding=padding,
    max_length=max_seq_length,
    truncation=truncation,
    return_tensors='pt',
    return_token_type_ids=True,
    return_attention_mask=True,
  )
  return tokenizer.encode_plus(sentences, **encodeOptions)

# get passages from ElasticSearch result
def getPassagesFromESResult(pickleFile):
  """Return the passage text ('_source.body') of every hit in an ElasticSearch result, in hit order.

  `pickleFile` is the result DataFrame (it must have '_id' and '_source.body' columns). If an '_id'
  appears more than once, every occurrence yields the body of its first row.
  """
  # One pass to map each id to its first body, instead of re-filtering the whole frame per id (O(n^2)).
  firstBodyById = {}
  for docID, body in zip(pickleFile['_id'], pickleFile['_source.body']):
    firstBodyById.setdefault(docID, body)
  return [firstBodyById[docID] for docID in pickleFile['_id']]

# BERT result
def _bertBatchCLS(model, device, encodedPairs):
  """Run one batch of tokenized pairs through `model` on `device`; return their [CLS] vectors as NumPy."""
  batch = {
    key: torch.cat([encoded[key] for encoded in encodedPairs]).to(device)
    for key in ('input_ids', 'attention_mask', 'token_type_ids')
  }
  # Inference only: skipping the autograd graph saves GPU memory and time.
  with torch.no_grad():
    bertOutput = model(**batch)
  return bertOutput["last_hidden_state"][:, 0].detach().cpu().numpy()

def getBERTResult(updateBERTResults, bertResultsForder, tokenizer, model, device, utterance, passages, topicTurnID, numDocs, setName):
  """Return the BERT [CLS] embedding of every (utterance, passage) pair of one turn.

  The cache file is <infosDir>/<bertResultsForder>/<setName>/<numDocs>/<topicTurnID>.pkl.
  When `updateBERTResults` is false the matrix is simply unpickled from it (the model is not run).
  Otherwise the pairs are encoded in batches of `sendToBERT`, run through `model` on `device`, and
  the resulting (len(passages), 768) float64 matrix is pickled to the cache file and returned.
  """
  cachePath = infosDir + "/" + bertResultsForder + "/" + setName + "/" + str(numDocs) + "/" + topicTurnID + ".pkl"
  if not updateBERTResults:
    with open(cachePath, "rb") as f:
      return pickle.load(f)

  batches = []
  for start in range(0, len(passages), sendToBERT):
    encodedPairs = [
      # 'only_second' truncates the passage (never the utterance) so every pair is exactly 512
      # tokens; untruncated longer pairs break torch.cat and exceed BERT's 512 positions.
      convertToBERTInput(sentences=[utterance, passage], max_seq_length=512, tokenizer=tokenizer,
                         padding='max_length', truncation='only_second')
      for passage in passages[start:start + sendToBERT]
    ]
    batches.append(_bertBatchCLS(model, device, encodedPairs))
    gc.collect()

  # float64 is the dtype the other embedding caches in this file are stored with.
  features = np.concatenate(batches).astype(np.float64) if batches else np.empty((0, 768))
  # Fails loudly if the model's hidden size is not 768.
  features = np.reshape(features, (len(passages), 768))

  with open(cachePath, "wb") as f:
    pickle.dump(features, f)
  return features

# classifier
## train
def trainClassifier(features, classes):
  """Fit the logistic-regression relevance classifier on BERT [CLS] features.

  features: 2-D array with one row per (query, passage) pair.
  classes:  1-D array of 0/1 labels aligned with the rows of `features`.
  Returns the fitted sklearn LogisticRegression (C=0.001, tol=1, max_iter=1000,
  class_weight="balanced" to re-weight the two classes by their frequency).
  """
  print("-- Training classifier")
  # Shuffle rows and labels together; the fixed seed makes runs reproducible, matching the
  # classifier's own random_state=0.
  features, classes = shuffle(features, classes, random_state=0)

  classifier = LogisticRegression(random_state=0, C=0.001, max_iter=1000, tol=1, class_weight="balanced")
  classifier.fit(features, classes)
  return classifier

## predict
def predictClassifier(classifier, features):
  """Return, for every row of `features`, the classifier's probability of class index 1 (relevant)."""
  probabilities = classifier.predict_proba(features)
  relevantColumn = probabilities[:, 1]
  return relevantColumn

## plots
def doPlots(relDocsPerTurn, setName, preName, APs, nDCGs, Recalls, Precisions, methods, convNumbers, convNames):
  """Draw every plot of one experiment through the `plots` module.

  A non-blank `preName` gets " - " appended before being used as the title prefix. In order, it
  plots AP and nDCG along the conversation, AP and nDCG for each conversation, and finally the
  precision-recall curve.
  """
  if preName.strip():
    preName = preName + " - "
  metricSeries = (("Average Precision", APs), ("normalized Discounted Cumulative Gain", nDCGs))
  for metricTitle, values in metricSeries:
    plots.plotMetricAlongConversation(infosDir, relDocsPerTurn, setName, metricTitle, values, methods, convNumbers, preName)
  for metricTitle, values in metricSeries:
    plots.plotMetricEachConversation(infosDir, relDocsPerTurn, setName, metricTitle, values, methods, convNumbers, convNames, preName)
  plots.plotPrecisionRecall(infosDir, relDocsPerTurn, setName, Recalls, Precisions, methods, convNumbers, preName)

# reorder results
def reorderResults(classifier, features, pickleFile):
  """Re-rank an ElasticSearch result by the classifier's probability of relevance.

  features:   one row per hit of `pickleFile`, in the same order.
  pickleFile: the ElasticSearch result DataFrame (its '_id' column is used).
  Returns a DataFrame with columns '_id' and '_score' (the class-1 probability, float64), sorted
  by decreasing score; hits with equal scores keep their original ElasticSearch order.
  """
  prob = np.asarray(classifier.predict_proba(features)[:, 1], dtype=float)
  # Stable sort on the negated scores = descending order with first-stage rank as tie-break.
  newOrder = np.argsort(-prob, kind='stable')
  ids = np.asarray(pickleFile['_id'])
  # Building the frame from separate columns keeps '_score' numeric (column_stack made it object).
  return pd.DataFrame({'_id': ids[newOrder], '_score': prob[newOrder]})






def phase1(relDocsPerTurn, updateElasticSearchResults, updateUtterances, es, testBed, topics, relevanceJudgments, topicsIDs, setName):
  """Phase 1: evaluate plain ElasticSearch retrieval on the raw user utterances.

  For each topic in `topics` whose number is in `topicsIDs`, the first `turnsPerConv` turns are
  queried through getESResultNormal (cache folder "elastic_search") and scored with `testBed`:
    * turn with no relevant judgment  -> NaN metrics (ignored by the per-topic nan-means);
    * empty ElasticSearch result      -> all metrics 0;
    * otherwise                       -> the metrics returned by getMetrics.
  Topics with fewer than `turnsPerConv` turns (even none) are padded with NaN metrics and
  "NO RESULT" labels so that every topic fills exactly one row. When `updateUtterances` is true,
  the topic title and each evaluated utterance are appended to the "elastic_search" log.

  Returns a dict with
    "aps", "ndcg5s":         arrays of shape (numTopics, turnsPerConv);
    "precisions", "recalls": per-topic nan-mean over turns, shape (numTopics, measuresPerTurn);
    "convNumbers":           list of the evaluated topic numbers;
    "convNames":             labels of shape (numTopics, turnsPerConv + 1): "<topic> <title>"
                             followed by one "<topic>_<turn> <utterance>" label per turn.
  """
  nanMeasures = np.full(measuresPerTurn, np.nan)
  zeroMeasures = np.zeros(measuresPerTurn)

  aps, ndcg5s = [], []
  precisionRows, recallRows = [], []
  convNumbers, convNames = [], []

  for topic in topics:
    convID = topic['number']
    if convID not in topicsIDs:
      continue
    convNumbers.append(convID)
    convNames.append(str(convID) + " " + topic['title'])

    turns = topic['turn'][:turnsPerConv]
    turnPrecisions, turnRecalls = [], []
    firstTurn = True
    for turn in turns:
      turnID = turn['number']
      utterance = turn['raw_utterance']
      topicTurnID = '%d_%d' % (convID, turnID)
      convNames.append(topicTurnID + " " + utterance)

      judged = relevanceJudgments.loc[relevanceJudgments['topic_turn_id'] == topicTurnID]
      numRel = judged.loc[judged['rel'] != 0]['docid'].count()
      if numRel == 0:
        aps.append(np.nan)
        ndcg5s.append(np.nan)
        turnPrecisions.append(nanMeasures)
        turnRecalls.append(nanMeasures)
        continue

      result = getESResultNormal(updateElasticSearchResults, "elastic_search", es, utterance, topicTurnID, relDocsPerTurn, setName)
      if np.size(result) == 0:
        aps.append(0.0)
        ndcg5s.append(0.0)
        turnPrecisions.append(zeroMeasures)
        turnRecalls.append(zeroMeasures)
        continue

      if updateUtterances:
        if firstTurn:
          # The topic header is logged right before its first evaluated utterance.
          firstTurn = False
          saveUserUtterances("elastic_search", "{}: {}\n".format(convID, topic['title']))
        saveUserUtterances("elastic_search", "-- {}: {}\n".format(turnID, utterance))

      [p10, recall, ap, ndcg5, precisions, recalls] = getMetrics(testBed, result, topicTurnID)
      aps.append(ap)
      ndcg5s.append(ndcg5)
      turnPrecisions.append(np.ravel(precisions))
      turnRecalls.append(np.ravel(recalls))

    # Pad per topic (not via a global turn counter) so that topics with fewer turns — even zero —
    # still fill a complete row.
    for _ in range(turnsPerConv - len(turns)):
      aps.append(np.nan)
      ndcg5s.append(np.nan)
      turnPrecisions.append(nanMeasures)
      turnRecalls.append(nanMeasures)
      convNames.append("NO RESULT")

    # Average each precision/recall cut-off over the turns of this conversation.
    precisionRows.append(np.nanmean(np.reshape(np.concatenate(turnPrecisions), (turnsPerConv, measuresPerTurn)), axis=0))
    recallRows.append(np.nanmean(np.reshape(np.concatenate(turnRecalls), (turnsPerConv, measuresPerTurn)), axis=0))

  ntopics = len(convNumbers)
  return {
    "aps": np.reshape(np.asarray(aps, dtype=float), (ntopics, turnsPerConv)),
    "ndcg5s": np.reshape(np.asarray(ndcg5s, dtype=float), (ntopics, turnsPerConv)),
    "precisions": np.reshape(np.asarray(precisionRows, dtype=float), (ntopics, measuresPerTurn)),
    "recalls": np.reshape(np.asarray(recallRows, dtype=float), (ntopics, measuresPerTurn)),
    "convNumbers": convNumbers,
    "convNames": np.reshape(np.array(convNames), (ntopics, turnsPerConv + 1))
  }

def _phase2TrainingPairs(updateElasticSearchResults, es, topics, relevanceJudgments, topicsIDs, relDocsPerTurn, setName):
  """Collect phase-2 training data: (question, passage) pairs, their 0/1 labels and replication counts.

  Every turn of every selected topic is used when it has at least one relevant judgment and a
  non-empty ElasticSearch result; each judged document that appears in that result becomes one
  pair. Relevant pairs get a count of 5 (they are oversampled), non-relevant pairs a count of 1.
  """
  pairs, labels, copies = [], [], []
  for topic in topics:
    convID = topic['number']
    if convID not in topicsIDs:
      continue

    for turn in topic['turn']:
      turnID = turn['number']
      question = turn['raw_utterance']
      topicTurnID = '%d_%d' % (convID, turnID)

      info = relevanceJudgments.loc[relevanceJudgments['topic_turn_id'] == topicTurnID]
      if info.loc[info['rel'] != 0]['docid'].count() == 0:
        continue

      pickleFile = getESResultNormal(updateElasticSearchResults, "elastic_search", es, question, topicTurnID, relDocsPerTurn, setName)
      if np.size(pickleFile) == 0:
        continue

      for docID in info['docid']:
        passageList = pickleFile[pickleFile['_id'] == docID]['_source.body'].values
        if len(passageList) == 0:
          continue  # judged document that ElasticSearch did not retrieve
        docInfo = info.loc[info['docid'] == docID]
        rel = (docInfo['rel'].values[0] > 0) * 1
        pairs.append((str(question), str(passageList[0])))
        labels.append(rel)
        copies.append(5 if rel > 0 else 1)
  return pairs, labels, copies

def _phase2PairCLS(tokenizer, model, device, question, passage):
  """Return the BERT [CLS] embedding (1-D NumPy vector) of one (question, passage) pair."""
  # 'only_second' truncates the passage, never the question, so the pair always fits in 512 tokens.
  bertInput = convertToBERTInput(sentences=[question, passage], max_seq_length=512, tokenizer=tokenizer,
                                 padding='max_length', truncation='only_second')
  bertInput = {key: bertInput[key].to(device) for key in ('input_ids', 'attention_mask', 'token_type_ids')}
  # Inference only: no autograd graph needed.
  with torch.no_grad():
    bertOutput = model(**bertInput)
  return bertOutput["last_hidden_state"][0, 0].detach().cpu().numpy()

def phase2(relDocsPerTurn, updateElasticSearchResults, updateBERTResults, es, testBed, topicsTrain, topicsTest, relevanceJudgmentsTrain, relevanceJudgmentsTest, topicsIDsTrain, topicsIDsTest, setNameTrain, setNameTest):
  """Phase 2: re-rank ElasticSearch results with a BERT-feature logistic-regression classifier.

  Train: every judged document retrieved for a training turn gives one (utterance, passage) pair,
  embedded with BERT's [CLS] vector; relevant pairs are repeated 5 times. Labels are saved to
  <infosDir>/bert/<setNameTrain>/<relDocsPerTurn>/classes.npz and embeddings to all.pkl in the same
  folder; with `updateBERTResults` false both are read back from there instead.

  Test: for the first `turnsPerConv` turns of every selected test topic, the ElasticSearch result is
  re-ordered by the classifier's probability and scored with `testBed` (per-turn embeddings cached
  in <infosDir>/bert/<setNameTest>/<relDocsPerTurn>/<topicTurnID>.pkl). A turn with no relevant
  judgment gets NaN metrics, an empty result gets 0, and topics are padded with NaN up to
  `turnsPerConv` turns.

  Returns {"aps", "ndcg5s": (numTopics, turnsPerConv), "precisions", "recalls": (numTopics, measuresPerTurn)}.
  """
  ## TRAIN
  trainFolder = infosDir + "/bert/" + setNameTrain + "/" + str(relDocsPerTurn)

  print("Train")
  print("-- Building question-utterance pairs")
  if updateBERTResults:
    pairs, labels, copies = _phase2TrainingPairs(updateElasticSearchResults, es, topicsTrain, relevanceJudgmentsTrain, topicsIDsTrain, relDocsPerTurn, setNameTrain)
    copies = np.asarray(copies, dtype=int)
    # Each label is repeated as many times as its pair is oversampled; copies stay adjacent.
    classes = np.repeat(np.asarray(labels, dtype=np.float64), copies)
    np.savez(trainFolder + "/classes.npz", classes=classes)
  else:
    with np.load(trainFolder + "/classes.npz") as saved:
      classes = saved["classes"]

  print("-- Getting BERT embeddings")
  if updateBERTResults:
    bertModelName = 'nboost/pt-bert-base-uncased-msmarco'
    tokenizer = BertTokenizerFast.from_pretrained(bertModelName)
    device = torch.device("cuda")
    model = BertModel.from_pretrained(bertModelName, return_dict=True)
    model = model.to(device)

    # Embed each distinct pair once and replicate its row, instead of re-running BERT on the
    # identical oversampled copies; row order matches `classes`.
    uniqueFeatures = np.asarray([_phase2PairCLS(tokenizer, model, device, question, passage) for question, passage in pairs], dtype=np.float64)
    features = np.repeat(np.reshape(uniqueFeatures, (len(pairs), 768)), copies, axis=0)
    with open(trainFolder + "/all.pkl", "wb") as f:
      pickle.dump(features, f)
  else:
    with open(trainFolder + "/all.pkl", "rb") as f:
      features = pickle.load(f)

  print("-- Training classifier")
  classifier = trainClassifier(features, classes)
  gc.collect()

  ## TEST
  testFolder = infosDir + "/bert/" + setNameTest + "/" + str(relDocsPerTurn)
  nanMeasures = np.full(measuresPerTurn, np.nan)
  zeroMeasures = np.zeros(measuresPerTurn)

  aps, ndcg5s = [], []
  precisionRows, recallRows = [], []
  ntopics = 0

  print("Test")
  for topic in topicsTest:
    convID = topic['number']
    if convID not in topicsIDsTest:
      continue

    turns = topic['turn'][:turnsPerConv]
    turnPrecisions, turnRecalls = [], []
    for turn in turns:
      turnID = turn['number']
      question = turn['raw_utterance']
      topicTurnID = '%d_%d' % (convID, turnID)

      info = relevanceJudgmentsTest.loc[relevanceJudgmentsTest['topic_turn_id'] == topicTurnID]
      if info.loc[info['rel'] != 0]['docid'].count() == 0:
        aps.append(np.nan)
        ndcg5s.append(np.nan)
        turnPrecisions.append(nanMeasures)
        turnRecalls.append(nanMeasures)
        continue

      pickleFile = getESResultNormal(updateElasticSearchResults, "elastic_search", es, question, topicTurnID, relDocsPerTurn, setNameTest)
      if np.size(pickleFile) == 0:
        aps.append(0.0)
        ndcg5s.append(0.0)
        turnPrecisions.append(zeroMeasures)
        turnRecalls.append(zeroMeasures)
        continue

      print("-- {}".format(topicTurnID))

      print("-- -- Getting BERT embeddings")
      cachePath = testFolder + "/" + topicTurnID + ".pkl"
      if updateBERTResults:
        passages = getPassagesFromESResult(pickleFile)
        features = np.asarray([_phase2PairCLS(tokenizer, model, device, question, passage) for passage in passages], dtype=np.float64)
        features = np.reshape(features, (len(passages), 768))
        with open(cachePath, "wb") as f:
          pickle.dump(features, f)
      else:
        with open(cachePath, "rb") as f:
          features = pickle.load(f)

      print("-- -- Classifying")
      prob = predictClassifier(classifier, features)

      print("-- -- Sorting passages")
      # Descending probability; the stable sort keeps ElasticSearch's order among ties, and
      # separate columns keep '_score' numeric.
      newOrder = np.argsort(-np.asarray(prob, dtype=float), kind='stable')
      resultsReordered = pd.DataFrame({'_id': np.asarray(pickleFile['_id'])[newOrder], '_score': np.asarray(prob, dtype=float)[newOrder]})

      print("-- -- Metrics")
      [p10, recall, ap, ndcg5, precisions, recalls] = getMetrics(testBed, resultsReordered, topicTurnID)
      aps.append(ap)
      ndcg5s.append(ndcg5)
      turnPrecisions.append(np.ravel(precisions))
      turnRecalls.append(np.ravel(recalls))

    # Pad per topic so topics with fewer turns (even zero) still fill a complete row.
    for _ in range(turnsPerConv - len(turns)):
      aps.append(np.nan)
      ndcg5s.append(np.nan)
      turnPrecisions.append(nanMeasures)
      turnRecalls.append(nanMeasures)

    # Average each precision/recall cut-off over the turns of this conversation.
    precisionRows.append(np.nanmean(np.reshape(np.concatenate(turnPrecisions), (turnsPerConv, measuresPerTurn)), axis=0))
    recallRows.append(np.nanmean(np.reshape(np.concatenate(turnRecalls), (turnsPerConv, measuresPerTurn)), axis=0))
    ntopics += 1

    gc.collect()

  return {
    "aps": np.reshape(np.asarray(aps, dtype=float), (ntopics, turnsPerConv)),
    "ndcg5s": np.reshape(np.asarray(ndcg5s, dtype=float), (ntopics, turnsPerConv)),
    "precisions": np.reshape(np.asarray(precisionRows, dtype=float), (ntopics, measuresPerTurn)),
    "recalls": np.reshape(np.asarray(recallRows, dtype=float), (ntopics, measuresPerTurn))
  }

def phase3(updateUtterances, updatePhase03Metrics, updatePhase03Method01Metrics, updatePhase03Method02Metrics, updatePhase03Method03Metrics, updatePhase03MethodFinalMetrics, relDocsPerTurn, updateElasticSearchResults, updateBERTResults, es, testBed, topicsTrain, topicsTest, relevanceJudgmentsTrain, relevanceJudgmentsTest, topicsIDsTrain, topicsIDsTest, setNameTrain, setNameTest):
  """Phase 3: compare LMD and BERT re-ranking for four query-processing methods.

  When `updatePhase03Metrics` is false, every method's metrics are loaded from
  <infosDir>/metrics/<relDocsPerTurn>/phase03Method{1,2,3,Final}.npz. Otherwise the BERT model,
  the spaCy pipeline and the T5 query rewriter are loaded, the classifier is trained on the
  embeddings/labels cached by phase2 (<infosDir>/bert/<setNameTrain>/<relDocsPerTurn>/), and each
  method whose own update flag is set is re-run, saved and plotted; the others are loaded from disk.

  Returns [method1Metrics, method2Metrics, method3Metrics, methodFinalMetrics], each a dict with
  keys "aps", "ndcg5s", "recalls", "precisions", each holding {"lmd": ..., "bert": ...}.
  """
  if updatePhase03Metrics:
    bertModelName = 'nboost/pt-bert-base-uncased-msmarco'
    tokenizer = BertTokenizerFast.from_pretrained(bertModelName)
    device = torch.device("cuda")
    model = BertModel.from_pretrained(bertModelName, return_dict=True)
    model = model.to(device)

    class QueryRewriterT5:
      """Conversational query rewriter backed by a T5 SavedModel."""

      def __init__(self, model_path="/content/t5-canard-v2"):
        """Load the T5 SavedModel at `model_path` as self.t5_model(list_of_str) -> model outputs."""
        if tf.executing_eagerly():
          print("Loading SavedModel in eager mode.")
          imported = tf.saved_model.load(model_path, ["serve"])
          self.t5_model = lambda x: imported.signatures['serving_default'](tf.constant(x))['outputs'].numpy()
        else:
          print("Loading SavedModel in tf 1.x graph mode.")
          tf.compat.v1.reset_default_graph()
          sess = tf.compat.v1.Session()
          meta_graph_def = tf.compat.v1.saved_model.load(sess, ["serve"], model_path)
          signature_def = meta_graph_def.signature_def["serving_default"]
          self.t5_model = lambda x: sess.run(
            fetches=signature_def.outputs["outputs"].name,
            feed_dict={signature_def.inputs["input"].name: x}
          )

      def rewrite_query_with_T5(self, _curr_query, _ctx_list):
        """Rewrite one query with T5.

        _curr_query: str - the query string to be rewritten.
        _ctx_list: list - strings with the previous turns, giving context to T5.
        The model input is "<query> [CTX] <turn1> [TURN] <turn2> ..."; returns the rewritten query.
        """
        _t5_query = '{} [CTX] '.format(_curr_query) + ' [TURN] '.join(_ctx_list)
        print("Query and context: {}".format(_t5_query))
        return self.t5_model([_t5_query])[0].decode('utf-8')

      def rewrite_dialog_with_T5(self, _queries_list):
        """Rewrite a whole conversation.

        _queries_list: list - raw utterances ordered from first to last; each one is rewritten
        with all earlier utterances as context. Returns the list of rewritten queries.
        """
        _rewritten_queries_list = []
        for i in range(len(_queries_list)):
          _current_query = _queries_list[i]
          _rewritten_query = self.rewrite_query_with_T5(_current_query, _queries_list[:i])
          print("Rewritten query: {}\n".format(_rewritten_query))
          _rewritten_queries_list.append(_rewritten_query)
        return _rewritten_queries_list

    nlp = spacy.load('en_core_web_sm')
    rewriter = QueryRewriterT5('/content/t5-canard-v2')

    ## TRAIN
    # The classifier is only used when metrics are recomputed, so it is trained only in that case,
    # reusing the training embeddings and labels cached by phase2.
    trainFolder = infosDir + "/bert/" + setNameTrain + "/" + str(relDocsPerTurn)
    with open(trainFolder + "/all.pkl", "rb") as f:
      features = pickle.load(f)
    with np.load(trainFolder + "/classes.npz") as saved:
      classes = saved["classes"]
    classifier = trainClassifier(features, classes)
    gc.collect()

  def plotLMDAndBERT(metrics, title):
    """Plot the LMD and BERT curves of one method using the saved conversation labels."""
    labels = loadPlotLabels()
    convNumbers = labels["convNumbers"]
    convNames = labels["convNames"]
    APs = [metrics["aps"]["lmd"], metrics["aps"]["bert"]]
    nDCGs = [metrics["ndcg5s"]["lmd"], metrics["ndcg5s"]["bert"]]
    Recalls = [metrics["recalls"]["lmd"], metrics["recalls"]["bert"]]
    Precisions = [metrics["precisions"]["lmd"], metrics["precisions"]["bert"]]
    doPlots(relDocsPerTurn, setNameTest, title, APs, nDCGs, Recalls, Precisions, ["LMD", "BERT"], convNumbers, convNames)

  ## TEST
  # method 1
  if updatePhase03Metrics and updatePhase03Method01Metrics:
    method1Metrics = phase03Method1(relDocsPerTurn, updateUtterances, updateBERTResults, updateElasticSearchResults, tokenizer, model, device, classifier, es, testBed, topicsTest, relevanceJudgmentsTest, topicsIDsTest, setNameTest)
    saveMetricsLMDAndBERT("phase03Method1.npz", relDocsPerTurn, method1Metrics)
    plotLMDAndBERT(method1Metrics, "Phase 3 Method 1")
  else:
    method1Metrics = loadMetricsLMDAndBERT("phase03Method1.npz", relDocsPerTurn)

  # method 2
  if updatePhase03Metrics and updatePhase03Method02Metrics:
    method2Metrics = phase03Method2(relDocsPerTurn, updateUtterances, updateBERTResults, updateElasticSearchResults, nlp, tokenizer, model, device, classifier, es, testBed, topicsTest, relevanceJudgmentsTest, topicsIDsTest, setNameTest)
    saveMetricsLMDAndBERT("phase03Method2.npz", relDocsPerTurn, method2Metrics)
    plotLMDAndBERT(method2Metrics, "Phase 3 Method 2")
  else:
    method2Metrics = loadMetricsLMDAndBERT("phase03Method2.npz", relDocsPerTurn)

  # method 3
  if updatePhase03Metrics and updatePhase03Method03Metrics:
    method3Metrics = phase03Method3(relDocsPerTurn, updateUtterances, updateBERTResults, updateElasticSearchResults, rewriter, nlp, tokenizer, model, device, classifier, es, testBed, topicsTest, relevanceJudgmentsTest, topicsIDsTest, setNameTest)
    saveMetricsLMDAndBERT("phase03Method3.npz", relDocsPerTurn, method3Metrics)
    plotLMDAndBERT(method3Metrics, "Phase 3 Method 3")
  else:
    method3Metrics = loadMetricsLMDAndBERT("phase03Method3.npz", relDocsPerTurn)

  # method final
  if updatePhase03Metrics and updatePhase03MethodFinalMetrics:
    methodFinalMetrics = phase03MethodFinal(relDocsPerTurn, updateUtterances, updateBERTResults, updateElasticSearchResults, rewriter, nlp, tokenizer, model, device, classifier, es, testBed, topicsTrain, topicsTest, relevanceJudgmentsTrain, relevanceJudgmentsTest, topicsIDsTrain, topicsIDsTest, setNameTrain, setNameTest)
    saveMetricsLMDAndBERT("phase03MethodFinal.npz", relDocsPerTurn, methodFinalMetrics)
    plotLMDAndBERT(methodFinalMetrics, "Phase 3 Method Final")
  else:
    methodFinalMetrics = loadMetricsLMDAndBERT("phase03MethodFinal.npz", relDocsPerTurn)

  return [method1Metrics, method2Metrics, method3Metrics, methodFinalMetrics]

def phase03Method1(relDocsPerTurn, updateUtterances, updateBERTResults, updateElasticSearchResults, tokenizer, model, device, classifier, es, testBed, topics, relevanceJudgments, topicsIDs, setName):
  """Evaluate phase 3, method 1: append the conversation's first utterance to later turns.

  Each conversation in `topics` whose 'number' is in `topicsIDs` is evaluated
  on its first `turnsPerConv` turns. A turn is queried only when
  `relevanceJudgments` has at least one row with non-zero 'rel' for it;
  otherwise (and for turns a short conversation lacks) its metrics stay NaN.
  Every queried turn is retrieved with `getESResultNormal` (the "lmd" system)
  and the same results are reordered with `reorderResults(classifier, ...)`
  over BERT features (the "bert" system).

  Args:
    relDocsPerTurn: forwarded to the retrieval and BERT helpers.
    updateUtterances: when true, original/new utterances are appended to the
      method-1 utterances file via `saveUserUtterances`.
    updateBERTResults: forwarded to `getBERTResult`.
    updateElasticSearchResults: forwarded to `getESResultNormal`.
    tokenizer, model, device: BERT components forwarded to `getBERTResult`.
    classifier: model passed to `reorderResults` for the "bert" ranking.
    es: search client forwarded to `getESResultNormal`.
    testBed: evaluation data passed to `getMetrics`.
    topics: conversations with 'number', 'title' and 'turn' entries.
    relevanceJudgments: DataFrame with 'topic_turn_id', 'docid', 'rel' columns.
    topicsIDs: conversation numbers to evaluate.
    setName: data-set name forwarded to the retrieval and BERT helpers.

  Returns:
    dict mapping "aps", "ndcg5s", "recalls", "precisions" to
    {"lmd": ndarray, "bert": ndarray}. "aps"/"ndcg5s" have shape
    (nConversations, turnsPerConv); "recalls"/"precisions" are the NaN-mean
    over each conversation's turns, shape (nConversations, measuresPerTurn).
  """
  ## METHOD 1
  print("Method 1")

  systems = ("lmd", "bert")
  # row width of every returned matrix; key order matches the returned dict
  rowWidths = {"aps": turnsPerConv, "ndcg5s": turnsPerConv, "recalls": measuresPerTurn, "precisions": measuresPerTurn}
  # one row per evaluated conversation, for every (system, metric) pair
  rows = {system: {metric: [] for metric in rowWidths} for system in systems}

  def newConvBuffers():
    # every turn slot starts as NaN, so unjudged or missing turns need no padding
    return {system: {
      "aps": np.full(turnsPerConv, np.nan),
      "ndcg5s": np.full(turnsPerConv, np.nan),
      "precisions": np.full((turnsPerConv, measuresPerTurn), np.nan),
      "recalls": np.full((turnsPerConv, measuresPerTurn), np.nan)
    } for system in systems}

  def record(buffers, slot, metrics):
    # metrics is getMetrics' [p10, recall, ap, ndcg5, precisions, recalls]; p10/recall are not reported
    _, _, ap, ndcg5, precisions, recalls = metrics
    buffers["aps"][slot] = ap
    buffers["ndcg5s"][slot] = ndcg5
    buffers["precisions"][slot] = precisions
    buffers["recalls"][slot] = recalls

  for topic in topics:
    convID = topic['number']
    if convID not in topicsIDs:
      continue

    print("-- {}".format(convID))
    conv = newConvBuffers()
    turns = topic['turn'][:turnsPerConv]
    # context is the conversation's real first turn, even when that turn has no relevant judgments
    firstUtterance = turns[0]['raw_utterance'] if turns else ""
    headerWritten = False

    for slot, turn in enumerate(turns):
      turnID = turn['number']
      topicTurnID = '%d_%d' % (convID, turnID)

      info = relevanceJudgments.loc[relevanceJudgments['topic_turn_id'] == topicTurnID]
      if info.loc[info['rel'] != 0]['docid'].count() == 0:
        continue  # no relevant documents: metrics stay NaN

      utterance = turn['raw_utterance']
      if slot > 0:
        utterance += " " + firstUtterance

      if updateUtterances and not headerWritten:
        headerWritten = True
        saveUserUtterances("elastic_search_method1", "{}: {}\n".format(convID, topic['title']))
      print("-- -- {}: {}".format(turnID, utterance))

      if updateUtterances:
        saveUserUtterances("elastic_search_method1", "-- {}:\n\tOriginal: {}\n\tNew: {}\n".format(turnID, turn['raw_utterance'], utterance))
      pickleFile = getESResultNormal(updateElasticSearchResults, "elastic_search_method1", es, utterance, topicTurnID, relDocsPerTurn, setName)

      # LMD: ranking as returned by the search
      record(conv["lmd"], slot, getMetrics(testBed, pickleFile, topicTurnID))

      # BERT: the same candidates reordered by the classifier
      passages = getPassagesFromESResult(pickleFile)
      features = getBERTResult(updateBERTResults, "bert_method1", tokenizer, model, device, utterance, passages, topicTurnID, relDocsPerTurn, setName)
      result = reorderResults(classifier, features, pickleFile)
      record(conv["bert"], slot, getMetrics(testBed, result, topicTurnID))

    for system in systems:
      rows[system]["aps"].append(conv[system]["aps"])
      rows[system]["ndcg5s"].append(conv[system]["ndcg5s"])
      # mean over judged turns; an all-NaN column stays NaN (numpy warns)
      rows[system]["recalls"].append(np.nanmean(conv[system]["recalls"], axis=0))
      rows[system]["precisions"].append(np.nanmean(conv[system]["precisions"], axis=0))

  return {
    metric: {system: np.reshape(rows[system][metric], (len(rows[system][metric]), width)) for system in systems}
    for metric, width in rowWidths.items()
  }

def phase03Method2(relDocsPerTurn, updateUtterances, updateBERTResults, updateElasticSearchResults, nlp, tokenizer, model, device, classifier, es, testBed, topics, relevanceJudgments, topicsIDs, setName):
  """Evaluate phase 3, method 2: query each turn together with the first utterance's entities.

  Each conversation in `topics` whose 'number' is in `topicsIDs` is evaluated
  on its first `turnsPerConv` turns. The entities `nlp(...).ents` of the
  conversation's first utterance are passed, with each turn's raw utterance,
  to `getESResultEntities` (the "lmd" system); the same results are then
  reordered with `reorderResults(classifier, ...)` over BERT features (the
  "bert" system). A turn is queried only when `relevanceJudgments` has at
  least one row with non-zero 'rel' for it; otherwise (and for turns a short
  conversation lacks) its metrics stay NaN.

  Args:
    relDocsPerTurn: forwarded to the retrieval and BERT helpers.
    updateUtterances: when true, original/new utterances are appended to the
      method-2 utterances file via `saveUserUtterances`.
    updateBERTResults: forwarded to `getBERTResult`.
    updateElasticSearchResults: forwarded to `getESResultEntities`.
    nlp: callable whose result exposes `.ents` (entity extraction).
    tokenizer, model, device: BERT components forwarded to `getBERTResult`.
    classifier: model passed to `reorderResults` for the "bert" ranking.
    es: search client forwarded to `getESResultEntities`.
    testBed: evaluation data passed to `getMetrics`.
    topics: conversations with 'number', 'title' and 'turn' entries.
    relevanceJudgments: DataFrame with 'topic_turn_id', 'docid', 'rel' columns.
    topicsIDs: conversation numbers to evaluate.
    setName: data-set name forwarded to the retrieval and BERT helpers.

  Returns:
    dict mapping "aps", "ndcg5s", "recalls", "precisions" to
    {"lmd": ndarray, "bert": ndarray}. "aps"/"ndcg5s" have shape
    (nConversations, turnsPerConv); "recalls"/"precisions" are the NaN-mean
    over each conversation's turns, shape (nConversations, measuresPerTurn).
  """
  ## METHOD 2
  print("Method 2")

  systems = ("lmd", "bert")
  # row width of every returned matrix; key order matches the returned dict
  rowWidths = {"aps": turnsPerConv, "ndcg5s": turnsPerConv, "recalls": measuresPerTurn, "precisions": measuresPerTurn}
  # one row per evaluated conversation, for every (system, metric) pair
  rows = {system: {metric: [] for metric in rowWidths} for system in systems}

  def newConvBuffers():
    # every turn slot starts as NaN, so unjudged or missing turns need no padding
    return {system: {
      "aps": np.full(turnsPerConv, np.nan),
      "ndcg5s": np.full(turnsPerConv, np.nan),
      "precisions": np.full((turnsPerConv, measuresPerTurn), np.nan),
      "recalls": np.full((turnsPerConv, measuresPerTurn), np.nan)
    } for system in systems}

  def record(buffers, slot, metrics):
    # metrics is getMetrics' [p10, recall, ap, ndcg5, precisions, recalls]; p10/recall are not reported
    _, _, ap, ndcg5, precisions, recalls = metrics
    buffers["aps"][slot] = ap
    buffers["ndcg5s"][slot] = ndcg5
    buffers["precisions"][slot] = precisions
    buffers["recalls"][slot] = recalls

  for topic in topics:
    convID = topic['number']
    if convID not in topicsIDs:
      continue

    print("-- {}".format(convID))
    conv = newConvBuffers()
    turns = topic['turn'][:turnsPerConv]
    # entities come from the conversation's real first turn, even when that turn has no relevant judgments
    firstEntities = [str(ent) for ent in nlp(turns[0]['raw_utterance']).ents] if turns else []
    headerWritten = False

    for slot, turn in enumerate(turns):
      turnID = turn['number']
      topicTurnID = '%d_%d' % (convID, turnID)

      info = relevanceJudgments.loc[relevanceJudgments['topic_turn_id'] == topicTurnID]
      if info.loc[info['rel'] != 0]['docid'].count() == 0:
        continue  # no relevant documents: metrics stay NaN

      utterance = turn['raw_utterance']
      if updateUtterances and not headerWritten:
        headerWritten = True
        saveUserUtterances("elastic_search_method2", "{}: {}\n".format(convID, topic['title']))
      print("-- -- {}: {} ; {}".format(turnID, utterance, firstEntities))

      if updateUtterances:
        saveUserUtterances("elastic_search_method2", "-- {}:\n\tOriginal: {}\n\tNew: {} ; {}\n".format(turnID, turn['raw_utterance'], utterance, firstEntities))
      pickleFile = getESResultEntities(updateElasticSearchResults, "elastic_search_method2", es, utterance, firstEntities, topicTurnID, relDocsPerTurn, setName)

      # LMD: ranking as returned by the search
      record(conv["lmd"], slot, getMetrics(testBed, pickleFile, topicTurnID))

      # BERT: the same candidates reordered by the classifier
      passages = getPassagesFromESResult(pickleFile)
      features = getBERTResult(updateBERTResults, "bert_method2", tokenizer, model, device, utterance, passages, topicTurnID, relDocsPerTurn, setName)
      result = reorderResults(classifier, features, pickleFile)
      record(conv["bert"], slot, getMetrics(testBed, result, topicTurnID))

    for system in systems:
      rows[system]["aps"].append(conv[system]["aps"])
      rows[system]["ndcg5s"].append(conv[system]["ndcg5s"])
      # mean over judged turns; an all-NaN column stays NaN (numpy warns)
      rows[system]["recalls"].append(np.nanmean(conv[system]["recalls"], axis=0))
      rows[system]["precisions"].append(np.nanmean(conv[system]["precisions"], axis=0))

  return {
    metric: {system: np.reshape(rows[system][metric], (len(rows[system][metric]), width)) for system in systems}
    for metric, width in rowWidths.items()
  }

def phase03Method3(relDocsPerTurn, updateUtterances, updateBERTResults, updateElasticSearchResults, rewriter, nlp, tokenizer, model, device, classifier, es, testBed, topics, relevanceJudgments, topicsIDs, setName):
  """Evaluate phase 3, method 3: rewrite each turn with T5 using the conversation history.

  Each conversation in `topics` whose 'number' is in `topicsIDs` is evaluated
  on its first `turnsPerConv` turns. A turn is rewritten with
  `rewriter.rewrite_query_with_T5(raw utterance, history)`, where the history
  is the raw utterances of ALL earlier turns of the conversation (judged or
  not). The rewrite is retrieved with `getESResultNormal` (the "lmd" system)
  and the same results are reordered with `reorderResults(classifier, ...)`
  over BERT features (the "bert" system). A turn is queried only when
  `relevanceJudgments` has at least one row with non-zero 'rel' for it;
  otherwise (and for turns a short conversation lacks) its metrics stay NaN.

  Args:
    relDocsPerTurn: forwarded to the retrieval and BERT helpers.
    updateUtterances: when true, original/new utterances are appended to the
      method-3 utterances file via `saveUserUtterances`.
    updateBERTResults: forwarded to `getBERTResult`.
    updateElasticSearchResults: forwarded to `getESResultNormal`.
    rewriter: object providing `rewrite_query_with_T5`.
    nlp: unused here; kept for interface compatibility.
    tokenizer, model, device: BERT components forwarded to `getBERTResult`.
    classifier: model passed to `reorderResults` for the "bert" ranking.
    es: search client forwarded to `getESResultNormal`.
    testBed: evaluation data passed to `getMetrics`.
    topics: conversations with 'number', 'title' and 'turn' entries.
    relevanceJudgments: DataFrame with 'topic_turn_id', 'docid', 'rel' columns.
    topicsIDs: conversation numbers to evaluate.
    setName: data-set name forwarded to the retrieval and BERT helpers.

  Returns:
    dict mapping "aps", "ndcg5s", "recalls", "precisions" to
    {"lmd": ndarray, "bert": ndarray}. "aps"/"ndcg5s" have shape
    (nConversations, turnsPerConv); "recalls"/"precisions" are the NaN-mean
    over each conversation's turns, shape (nConversations, measuresPerTurn).
  """
  ## METHOD 3
  print("Method 3")

  systems = ("lmd", "bert")
  # row width of every returned matrix; key order matches the returned dict
  rowWidths = {"aps": turnsPerConv, "ndcg5s": turnsPerConv, "recalls": measuresPerTurn, "precisions": measuresPerTurn}
  # one row per evaluated conversation, for every (system, metric) pair
  rows = {system: {metric: [] for metric in rowWidths} for system in systems}

  def newConvBuffers():
    # every turn slot starts as NaN, so unjudged or missing turns need no padding
    return {system: {
      "aps": np.full(turnsPerConv, np.nan),
      "ndcg5s": np.full(turnsPerConv, np.nan),
      "precisions": np.full((turnsPerConv, measuresPerTurn), np.nan),
      "recalls": np.full((turnsPerConv, measuresPerTurn), np.nan)
    } for system in systems}

  def record(buffers, slot, metrics):
    # metrics is getMetrics' [p10, recall, ap, ndcg5, precisions, recalls]; p10/recall are not reported
    _, _, ap, ndcg5, precisions, recalls = metrics
    buffers["aps"][slot] = ap
    buffers["ndcg5s"][slot] = ndcg5
    buffers["precisions"][slot] = precisions
    buffers["recalls"][slot] = recalls

  for topic in topics:
    convID = topic['number']
    if convID not in topicsIDs:
      continue

    print("-- {}".format(convID))
    conv = newConvBuffers()
    convUtterances = []  # raw utterances of every earlier turn, judged or not
    headerWritten = False

    for slot, turn in enumerate(topic['turn'][:turnsPerConv]):
      turnID = turn['number']
      topicTurnID = '%d_%d' % (convID, turnID)
      # snapshot the history before this turn, then extend it even if the turn is skipped below,
      # so later rewrites still see unjudged turns as context
      history = list(convUtterances)
      convUtterances.append(turn['raw_utterance'])

      info = relevanceJudgments.loc[relevanceJudgments['topic_turn_id'] == topicTurnID]
      if info.loc[info['rel'] != 0]['docid'].count() == 0:
        continue  # no relevant documents: metrics stay NaN

      utterance = rewriter.rewrite_query_with_T5(turn['raw_utterance'], history)
      print("-- -- {}: {}".format(turnID, utterance))

      if updateUtterances:
        if not headerWritten:
          headerWritten = True
          saveUserUtterances("elastic_search_method3", "{}: {}\n".format(convID, topic['title']))
        saveUserUtterances("elastic_search_method3", "-- {}:\n\tOriginal: {}\n\tNew: {}\n".format(turnID, turn['raw_utterance'], utterance))
      pickleFile = getESResultNormal(updateElasticSearchResults, "elastic_search_method3", es, utterance, topicTurnID, relDocsPerTurn, setName)

      # LMD: ranking as returned by the search
      record(conv["lmd"], slot, getMetrics(testBed, pickleFile, topicTurnID))

      # BERT: the same candidates reordered by the classifier
      passages = getPassagesFromESResult(pickleFile)
      features = getBERTResult(updateBERTResults, "bert_method3", tokenizer, model, device, utterance, passages, topicTurnID, relDocsPerTurn, setName)
      result = reorderResults(classifier, features, pickleFile)
      record(conv["bert"], slot, getMetrics(testBed, result, topicTurnID))

    for system in systems:
      rows[system]["aps"].append(conv[system]["aps"])
      rows[system]["ndcg5s"].append(conv[system]["ndcg5s"])
      # mean over judged turns; an all-NaN column stays NaN (numpy warns)
      rows[system]["recalls"].append(np.nanmean(conv[system]["recalls"], axis=0))
      rows[system]["precisions"].append(np.nanmean(conv[system]["precisions"], axis=0))

  return {
    metric: {system: np.reshape(rows[system][metric], (len(rows[system][metric]), width)) for system in systems}
    for metric, width in rowWidths.items()
  }

def _phase03MethodFinalTrainClassifier(relDocsPerTurn, updateBERTResults, updateElasticSearchResults, rewriter, nlp, tokenizer, model, device, es, topics, relevanceJudgments, topicsIDs, setName):
  """Train the method-final re-ranking classifier on the training conversations.

  For every turn with at least one non-zero 'rel' judgment, the raw utterance
  is rewritten with `rewriter.rewrite_query_with_T5` (history = raw utterances
  of all earlier turns, judged or not), its entities `nlp(...).ents` are
  extracted, and results are fetched with `getESResultEntities`. Every judged
  document found in those results becomes an (utterance, passage) pair labelled
  1 when its 'rel' > 0, else 0; positive pairs are repeated 5 times to up-weight
  the positive class. Each pair's feature row is the first-token ([CLS]) vector
  of BERT's last hidden state (768 values), cached in / loaded from
  `infosDir/bert_methodFinal/<setName>/<relDocsPerTurn>/all.pkl`.

  Returns:
    whatever `trainClassifier(features, classes)` returns.
  """
  print("Train")
  print("-- Building question-utterance pairs")
  pairs = []
  labels = []
  for topic in topics:
    convID = topic['number']
    if convID not in topicsIDs:
      continue

    convUtterances = []  # raw utterances of every earlier turn, judged or not
    for turn in topic['turn']:
      turnID = turn['number']
      topicTurnID = '%d_%d' % (convID, turnID)
      # snapshot the history before this turn, then extend it even if the turn is skipped below
      history = list(convUtterances)
      convUtterances.append(turn['raw_utterance'])

      info = relevanceJudgments.loc[relevanceJudgments['topic_turn_id'] == topicTurnID]
      if info.loc[info['rel'] != 0]['docid'].count() == 0:
        continue

      utterance = rewriter.rewrite_query_with_T5(turn['raw_utterance'], history)
      entities = [str(ent) for ent in nlp(utterance).ents]
      pickleFile = getESResultEntities(updateElasticSearchResults, "elastic_search_methodFinal", es, utterance, entities, topicTurnID, relDocsPerTurn, setName)
      if np.size(pickleFile) == 0:
        continue

      for docID in info['docid']:
        passageList = pickleFile[pickleFile['_id'] == docID]['_source.body'].values
        if len(passageList) == 0:
          continue  # judged document is not among the retrieved results
        rel = (info.loc[info['docid'] == docID]['rel'].values[0] > 0) * 1
        # positives appear 5 times (1 + 4 copies) to up-weight the positive class
        copies = 5 if rel > 0 else 1
        pairs.extend([[str(utterance), str(passageList[0])]] * copies)
        labels.extend([rel] * copies)

  totalDocs = len(pairs)
  # rows stay numpy string arrays, as convertToBERTInput received them before
  questionsUtterances = np.array(pairs, dtype=str).reshape(totalDocs, 2)
  classes = np.array(labels, dtype=float)

  print("-- Getting BERT embeddings")
  featuresPath = infosDir + "/bert_methodFinal/" + setName + "/" + str(relDocsPerTurn) + "/all.pkl"
  if updateBERTResults:
    featureRows = []
    # inference only: no autograd graph is needed for the forward passes
    with torch.no_grad():
      for questionUtterance in questionsUtterances:
        bertInput = convertToBERTInput(sentences=questionUtterance, max_seq_length=512, tokenizer=tokenizer, padding='max_length', truncation=False)
        for key in ('input_ids', 'attention_mask', 'token_type_ids'):
          bertInput[key] = bertInput[key].to(device)
        bertOutput = model(**bertInput)
        featureRows.append(bertOutput["last_hidden_state"][0, 0].detach().clone().cpu().numpy())
    features = np.array(featureRows, dtype=float).reshape(totalDocs, 768)
    with open(featuresPath, "wb") as featuresFile:
      pickle.dump(features, featuresFile)
  else:
    # NOTE(review): pickle.load is only safe because this is a locally produced cache file
    with open(featuresPath, "rb") as featuresFile:
      features = pickle.load(featuresFile)

  return trainClassifier(features, classes)

def phase03MethodFinal(relDocsPerTurn, updateUtterances, updateBERTResults, updateElasticSearchResults, rewriter, nlp, tokenizer, model, device, classifier, es, testBed, topicsTrain, topicsTest, relevanceJudgmentsTrain, relevanceJudgmentsTest, topicsIDsTrain, topicsIDsTest, setNameTrain, setNameTest):
  """Evaluate phase 3, final method: T5 rewrite + entities, with a freshly trained re-ranker.

  First a classifier is trained on the training conversations (see
  `_phase03MethodFinalTrainClassifier`); the `classifier` argument is ignored
  and replaced by it. Then each test conversation whose 'number' is in
  `topicsIDsTest` is evaluated on its first `turnsPerConv` turns: the turn is
  rewritten with `rewriter.rewrite_query_with_T5` (history = raw utterances of
  all earlier turns, judged or not), its entities `nlp(...).ents` are
  extracted, results come from `getESResultEntities` (the "lmd" system) and
  are reordered with `reorderResults` over BERT features (the "bert" system).
  Turns without a non-zero 'rel' judgment, and turns a short conversation
  lacks, keep NaN metrics.

  Returns:
    dict mapping "aps", "ndcg5s", "recalls", "precisions" to
    {"lmd": ndarray, "bert": ndarray}. "aps"/"ndcg5s" have shape
    (nConversations, turnsPerConv); "recalls"/"precisions" are the NaN-mean
    over each conversation's turns, shape (nConversations, measuresPerTurn).
  """
  ## METHOD Final
  print("Method Final")

  ## TRAIN
  classifier = _phase03MethodFinalTrainClassifier(relDocsPerTurn, updateBERTResults, updateElasticSearchResults, rewriter, nlp, tokenizer, model, device, es, topicsTrain, relevanceJudgmentsTrain, topicsIDsTrain, setNameTrain)
  gc.collect()

  ## TEST
  systems = ("lmd", "bert")
  # row width of every returned matrix; key order matches the returned dict
  rowWidths = {"aps": turnsPerConv, "ndcg5s": turnsPerConv, "recalls": measuresPerTurn, "precisions": measuresPerTurn}
  # one row per evaluated conversation, for every (system, metric) pair
  rows = {system: {metric: [] for metric in rowWidths} for system in systems}

  def newConvBuffers():
    # every turn slot starts as NaN, so unjudged or missing turns need no padding
    return {system: {
      "aps": np.full(turnsPerConv, np.nan),
      "ndcg5s": np.full(turnsPerConv, np.nan),
      "precisions": np.full((turnsPerConv, measuresPerTurn), np.nan),
      "recalls": np.full((turnsPerConv, measuresPerTurn), np.nan)
    } for system in systems}

  def record(buffers, slot, metrics):
    # metrics is getMetrics' [p10, recall, ap, ndcg5, precisions, recalls]; p10/recall are not reported
    _, _, ap, ndcg5, precisions, recalls = metrics
    buffers["aps"][slot] = ap
    buffers["ndcg5s"][slot] = ndcg5
    buffers["precisions"][slot] = precisions
    buffers["recalls"][slot] = recalls

  for topic in topicsTest:
    convID = topic['number']
    if convID not in topicsIDsTest:
      continue

    print("-- {}".format(convID))
    conv = newConvBuffers()
    convUtterances = []  # raw utterances of every earlier turn, judged or not
    headerWritten = False

    for slot, turn in enumerate(topic['turn'][:turnsPerConv]):
      turnID = turn['number']
      topicTurnID = '%d_%d' % (convID, turnID)
      # snapshot the history before this turn, then extend it even if the turn is skipped below
      history = list(convUtterances)
      convUtterances.append(turn['raw_utterance'])

      info = relevanceJudgmentsTest.loc[relevanceJudgmentsTest['topic_turn_id'] == topicTurnID]
      if info.loc[info['rel'] != 0]['docid'].count() == 0:
        continue  # no relevant documents: metrics stay NaN

      utterance = rewriter.rewrite_query_with_T5(turn['raw_utterance'], history)
      entities = [str(ent) for ent in nlp(utterance).ents]
      print("-- -- {}: {} ; {}".format(turnID, utterance, entities))

      if updateUtterances:
        if not headerWritten:
          headerWritten = True
          saveUserUtterances("elastic_search_methodFinal", "{}: {}\n".format(convID, topic['title']))
        saveUserUtterances("elastic_search_methodFinal", "-- {}:\n\tOriginal: {}\n\tNew: {} ; {}\n".format(turnID, turn['raw_utterance'], utterance, entities))
      pickleFile = getESResultEntities(updateElasticSearchResults, "elastic_search_methodFinal", es, utterance, entities, topicTurnID, relDocsPerTurn, setNameTest)

      # LMD: ranking as returned by the search
      record(conv["lmd"], slot, getMetrics(testBed, pickleFile, topicTurnID))

      # BERT: the same candidates reordered by the freshly trained classifier
      passages = getPassagesFromESResult(pickleFile)
      features = getBERTResult(updateBERTResults, "bert_methodFinal", tokenizer, model, device, utterance, passages, topicTurnID, relDocsPerTurn, setNameTest)
      result = reorderResults(classifier, features, pickleFile)
      record(conv["bert"], slot, getMetrics(testBed, result, topicTurnID))

    for system in systems:
      rows[system]["aps"].append(conv[system]["aps"])
      rows[system]["ndcg5s"].append(conv[system]["ndcg5s"])
      # mean over judged turns; an all-NaN column stays NaN (numpy warns)
      rows[system]["recalls"].append(np.nanmean(conv[system]["recalls"], axis=0))
      rows[system]["precisions"].append(np.nanmean(conv[system]["precisions"], axis=0))

  return {
    metric: {system: np.reshape(rows[system][metric], (len(rows[system][metric]), width)) for system in systems}
    for metric, width in rowWidths.items()
  }