import os
import textwrap as txtwrap

import matplotlib.pyplot as plt
import numpy as np

# Maximum width, in characters, of conversation / turn labels; passed as the
# ``width`` argument of textwrap.shorten, which truncates longer names with "...".
NAME_LENGTH = 40


def plotLines(PROJ_DIR, REL_DOCS_PER_TURN, setName, plotTitle, xs, ys, legend, axis=None):
  """Draw one line per series on a single figure, save it as a PNG and show it.

  Series ``i`` is plotted as ``ys[i]`` against ``xs[i]`` and labelled
  ``legend[i]``; only the first ``len(legend)`` series are drawn.

  Args:
    PROJ_DIR: project root; the image is written to
      ``PROJ_DIR/plots/<setName>/<REL_DOCS_PER_TURN>/<plotTitle>.png``.
    REL_DOCS_PER_TURN: last directory level (converted with ``str``).
    setName: directory level under ``plots``.
    plotTitle: figure title, also used as the PNG file name.
    xs: per-series x values.
    ys: per-series y values.
    legend: per-series labels.
    axis: optional limits passed unchanged to ``plt.axis``.
  """
  fig = plt.figure(figsize=(8, 8))
  plt.title(plotTitle)
  if axis is not None:
    plt.axis(axis)
  plt.grid(True)
  # plt.plot(x, y) returns a one-element list of Line2D handles.
  lines = [plt.plot(xs[idx], ys[idx])[0] for idx in range(len(legend))]
  plt.legend(lines, legend)
  # Shrink margins so the title and legend are not clipped.
  plt.tight_layout()
  outDir = PROJ_DIR + "/plots/" + setName + "/" + str(REL_DOCS_PER_TURN)
  # savefig does not create missing directories.
  os.makedirs(outDir, exist_ok=True)
  plt.savefig(outDir + "/" + plotTitle + ".png")
  plt.show()
  # Release the figure: this module draws one figure per conversation, and
  # with a non-interactive backend show() returns without closing anything,
  # so open figures would otherwise accumulate.
  plt.close(fig)


def plotPoints(PROJ_DIR, REL_DOCS_PER_TURN, setName, plotTitle, xs, ys, legend, axis=None):
  """Scatter one series per legend entry on a single figure, save it as a PNG and show it.

  Series ``i`` is drawn with ``plt.scatter(xs[i], ys[i])`` and labelled
  ``legend[i]``; only the first ``len(legend)`` series are drawn. ``ys`` may
  hold string labels, which matplotlib places on a categorical axis.

  Args:
    PROJ_DIR: project root; the image is written to
      ``PROJ_DIR/plots/<setName>/<REL_DOCS_PER_TURN>/<plotTitle>.png``.
    REL_DOCS_PER_TURN: last directory level (converted with ``str``).
    setName: directory level under ``plots``.
    plotTitle: figure title, also used as the PNG file name.
    xs: per-series x values.
    ys: per-series y values or labels.
    legend: per-series labels.
    axis: optional limits passed unchanged to ``plt.axis``.
  """
  fig = plt.figure(figsize=(8, 8))
  plt.title(plotTitle)
  if axis is not None:
    plt.axis(axis)
  plt.grid(True)
  for idx, label in enumerate(legend):
    plt.scatter(xs[idx], ys[idx], label=label)
  # Legend entries come from the ``label`` given to each scatter call.
  plt.legend()
  # Shrink margins so the title, labels and legend are not clipped.
  plt.tight_layout()
  outDir = PROJ_DIR + "/plots/" + setName + "/" + str(REL_DOCS_PER_TURN)
  # savefig does not create missing directories.
  os.makedirs(outDir, exist_ok=True)
  plt.savefig(outDir + "/" + plotTitle + ".png")
  plt.show()
  # Release the figure: this module draws one figure per conversation, and
  # with a non-interactive backend show() returns without closing anything,
  # so open figures would otherwise accumulate.
  plt.close(fig)


def plotMetricAlongConversation(PROJ_DIR, REL_DOCS_PER_TURN, setName, metricName, matrices, modelsNames, convNumbers, preName=""):
  """Plot a metric turn by turn: averaged over conversations, then for each conversation.

  Produces one line figure of the per-turn mean (NaN entries ignored via
  ``np.nanmean``) and one line figure per conversation, each with one line
  per model.

  Args:
    PROJ_DIR, REL_DOCS_PER_TURN, setName: forwarded to ``plotLines`` to build
      the output path.
    metricName: metric name used in the figure titles.
    matrices: one matrix per model; row ``c`` holds the per-turn scores of
      conversation ``c``.
    modelsNames: legend label for each model.
    convNumbers: identifier printed in the title of conversation ``c``.
    preName: optional prefix for every figure title.
  """
  # Averaging over conversations (axis 0) leaves one value per turn.
  modelsTurnsMean = [np.nanmean(matrix, axis=0).tolist() for matrix in matrices]
  # Turns are numbered from 1. Each model's x range is sized from its own
  # matrix (not the first model's) so x and y lengths always agree.
  turnNumbers = [range(1, len(turnMeans) + 1) for turnMeans in modelsTurnsMean]
  plotLines(PROJ_DIR, REL_DOCS_PER_TURN, setName, preName + "Per turn score for " + metricName, turnNumbers, modelsTurnsMean, modelsNames)
  for convID in range(len(matrices[0])):
    # Row convID of every model: that conversation's per-turn scores.
    modelsConv = [matrix[convID] for matrix in matrices]
    plotLines(PROJ_DIR, REL_DOCS_PER_TURN, setName, preName + "Per turn score for " + metricName + " on conversation " + str(convNumbers[convID]), turnNumbers, modelsConv, modelsNames)


def plotMetricEachConversation(PROJ_DIR, REL_DOCS_PER_TURN, setName, metricName, matrices, modelsNames, convNumbers, convNames):
  """Scatter-plot a metric against conversation names, then against turn names.

  The first figure places each conversation's mean score (``np.nanmean``
  over axis 1) against its shortened name. Then there is one figure per
  conversation, placing each turn's score against its shortened turn name.
  Every figure has one series per model.

  Args:
    PROJ_DIR, REL_DOCS_PER_TURN, setName: forwarded to ``plotPoints`` to
      build the output path.
    metricName: metric name used in the figure titles.
    matrices: one matrix per model; row ``c`` holds the per-turn scores of
      conversation ``c``.
    modelsNames: legend label for each model.
    convNumbers: identifier printed in the title of conversation ``c``.
    convNames: 2-D array indexed as ``[:, 0]`` for conversation names and
      ``[c, 1:]`` for the turn names of conversation ``c``.
  """
  def shortLabels(names):
    # Truncate long names so they fit as axis labels.
    return [txtwrap.shorten(name, width=NAME_LENGTH, placeholder="...") for name in names]

  # Labels do not depend on the model, so they are built once and shared.
  convLabels = shortLabels(convNames[:, 0])
  # Mean over turns (axis 1) leaves one value per conversation.
  modelsConvsMean = [np.nanmean(matrix, axis=1) for matrix in matrices]
  plotPoints(PROJ_DIR, REL_DOCS_PER_TURN, setName, metricName + " comparison", modelsConvsMean, [convLabels] * len(matrices), modelsNames)
  for convID in range(len(matrices[0])):
    turnLabels = shortLabels(convNames[convID, 1:])
    # Row convID of every model: that conversation's per-turn scores.
    modelsConv = [matrix[convID] for matrix in matrices]
    plotPoints(PROJ_DIR, REL_DOCS_PER_TURN, setName, metricName + " comparison on conversation " + str(convNumbers[convID]), modelsConv, [turnLabels] * len(matrices), modelsNames)


def _interpolatePrecision(recallLevels, recall, precision):
  """Linearly interpolate ``precision`` at ``recallLevels`` from paired samples.

  ``np.interp`` requires increasing sample points and does not check them.
  Non-finite pairs (e.g. NaN padding) are therefore dropped, and the rest
  are stably sorted by recall first. Already non-decreasing input is
  unaffected. Returns an all-NaN curve when no valid pair remains, instead of
  letting ``np.interp`` raise on empty samples.
  """
  recall = np.asarray(recall, dtype=float)
  precision = np.asarray(precision, dtype=float)
  valid = np.isfinite(recall) & np.isfinite(precision)
  if not valid.any():
    return np.full(len(recallLevels), np.nan)
  recall = recall[valid]
  precision = precision[valid]
  order = np.argsort(recall, kind="stable")
  return np.interp(recallLevels, recall[order], precision[order])


def plotPrecisionRecall(PROJ_DIR, REL_DOCS_PER_TURN, setName, recall_matrices, precision_matrices, modelsNames, convNumbers, preName=""):
  """Plot precision against recall at recall levels 0.0, 0.1, ..., 1.0.

  Draws one figure per conversation, then one for the mean over
  conversations (``np.nanmean`` of the recall and precision matrices).
  Each figure has one line per model, with both axes limited to [0, 1].

  Args:
    PROJ_DIR, REL_DOCS_PER_TURN, setName: forwarded to ``plotLines`` to build
      the output path.
    recall_matrices, precision_matrices: one matrix per model; row ``c`` of
      each holds the paired recall / precision values of conversation ``c``.
    modelsNames: legend label for each model.
    convNumbers: identifier printed in the title of conversation ``c``.
    preName: optional prefix for every figure title.
  """
  # 0.0, 0.1, ..., 1.0 (i / 10 yields exactly the same floats as the literals).
  recallLevels = [i / 10 for i in range(11)]
  modelIDs = range(len(recall_matrices))
  for convID in range(len(recall_matrices[0])):
    convPrecisions = [
      _interpolatePrecision(recallLevels, recall_matrices[modelID][convID], precision_matrices[modelID][convID])
      for modelID in modelIDs
    ]
    plotLines(PROJ_DIR, REL_DOCS_PER_TURN, setName, preName + "Precision-Recall on conversation " + str(convNumbers[convID]), [recallLevels] * len(modelIDs), convPrecisions, modelsNames, [0.0, 1.0, 0.0, 1.0])
  # Mean curve: average recall and precision over conversations, then interpolate.
  meanPrecisions = []
  for modelID in modelIDs:
    precision = np.nanmean(precision_matrices[modelID], axis=0)
    recall = np.nanmean(recall_matrices[modelID], axis=0)
    meanPrecisions.append(_interpolatePrecision(recallLevels, recall, precision))
  plotLines(PROJ_DIR, REL_DOCS_PER_TURN, setName, preName + "Precision-Recall", [recallLevels] * len(modelIDs), meanPrecisions, modelsNames, [0.0, 1.0, 0.0, 1.0])
