import _ from 'lodash'
import { Annotations, Config, Layout, PlotData } from 'plotly.js'
import Plotly from 'plotly.js/dist/plotly-axon'
import { useMemo } from 'react'
import createPlotlyComponent from 'react-plotly.js/factory'

// Some layout constants, used by HeatmapComparisonPlot to estimate how much vertical space the
// rotated x-axis labels and the title need. They are summed into the Plotly layout's `height`
// and `margin.t`, so they are in pixels.

// Estimated vertical pixels taken up per character of a (45-degree rotated) x-axis label
const LABEL_HEIGHT_PER_CHARACTER = 8
// Label lengths are capped at this many characters when estimating label height. Nothing in
// this file truncates the labels themselves.
const MAX_LABEL_LENGTH = 48 // @TODO actually enforce this elsewhere?
// Vertical pixels reserved for the plot title when computing the overall plot height
const TITLE_HEIGHT = 40

// customizable method: use your own `Plotly` object.
// The react-plotly.js factory binds the React component to a specific Plotly object; here that
// is the `plotly-axon` distribution imported above.
const Plot = createPlotlyComponent(Plotly)

/**
 * A single observation: one row belonging to `category` that has the given `value`.
 * The heatmap shows, per category, the fraction of rows having each value.
 */
export type HeatmapComparisonPlotDatum = {
  category: string // plotted on the x-axis
  value: string // plotted on the y-axis
}

/** Props for the HeatmapComparisonPlot component. */
export type HeatmapComparisonPlotProps = {
  // Observations to aggregate into per-category value distributions
  data: HeatmapComparisonPlotDatum[]
  // Optional plot title, shown centered above the plot
  title?: string
  categoryName?: string // x-axis label
  sortValuesNumerically?: boolean // whether to sort the values numerically
  valueName: string // y-axis label
  width: number // plot width in pixels (passed straight to the Plotly layout `width`)
  rowHeight: number // pixels allotted to each heatmap row when computing the plot height
}

/**
 * Comparator used when values are sorted numerically.
 *
 * Labels that `parseFloat` can read come first, in ascending numeric order. Labels it cannot
 * read (NaN), and labels that parse to the same number (e.g. "1" and "1.0"), fall back to plain
 * string ordering. Using `parseFloat(a) - parseFloat(b)` directly as a comparator returns NaN
 * for non-numeric labels, which makes the resulting sort order implementation-defined.
 */
function compareLabelsNumerically(a: string, b: string): number {
  const numA = parseFloat(a)
  const numB = parseFloat(b)
  const aIsNumeric = !Number.isNaN(numA)
  const bIsNumeric = !Number.isNaN(numB)

  if (aIsNumeric !== bIsNumeric) {
    // Numeric labels sort before non-numeric ones
    return aIsNumeric ? -1 : 1
  }
  if (aIsNumeric && numA !== numB) {
    // Compare rather than subtract so that e.g. Infinity vs Infinity can't yield NaN
    return numA < numB ? -1 : 1
  }
  // Same ordering as the default Array.prototype.sort (UTF-16 code units)
  return a < b ? -1 : a > b ? 1 : 0
}

/**
 * This function calculates and prepares the data for a heatmap plot based on an array of input values.
 *
 * Each heatmap column is a category and each row a distinct value: `z[i][j]` is the fraction of
 * rows in category `x[j]` whose value is `y[i]` (0 when there are none), so every column sums
 * to 1. Categories are sorted alphabetically. Values are sorted alphabetically or, when
 * requested, numerically: numeric labels first, then non-numeric labels alphabetically.
 *
 * @internal Exported only for unit testing
 * @param data - An array of data points, each containing a category and a value.
 * @param sortValuesNumerically - An optional boolean flag indicating whether to sort the values numerically.
 *
 * @returns A partial PlotData object containing the data ready for a heatmap plot.
 */
export function getPlotData(
  data: HeatmapComparisonPlotDatum[],
  sortValuesNumerically?: boolean
): Partial<PlotData> {
  // Tally rows per category and per (category, value) pair. Maps are used rather than plain
  // objects so that labels such as "constructor" or "toString" can't resolve to inherited
  // Object.prototype members on lookup (which would put functions, not numbers, into z).
  const rowCountByCategory = new Map<string, number>()
  const valueCountsByCategory = new Map<string, Map<string, number>>()
  for (const { category, value } of data) {
    rowCountByCategory.set(category, (rowCountByCategory.get(category) ?? 0) + 1)

    let valueCounts = valueCountsByCategory.get(category)
    if (valueCounts === undefined) {
      valueCounts = new Map<string, number>()
      valueCountsByCategory.set(category, valueCounts)
    }
    valueCounts.set(value, (valueCounts.get(value) ?? 0) + 1)
  }

  // Get distinct values, sorted alphabetically or numerically (as requested).
  // `_.uniq` returns a fresh array, so sorting it in place does not touch `data`.
  const uniqueValues = _.uniq(data.map((d) => d.value))
  const distinctValues = sortValuesNumerically
    ? uniqueValues.sort(compareLabelsNumerically)
    : uniqueValues.sort()

  // Get distinct categories, sorted alphabetically
  const distinctCategories = _.uniq(data.map((d) => d.category)).sort()

  // Build the 2D z array: one row per value, one column per category
  const plotDataArray = distinctValues.map((value): number[] =>
    distinctCategories.map((category): number => {
      const count = valueCountsByCategory.get(category)?.get(value) ?? 0
      const total = rowCountByCategory.get(category) ?? 0
      return total > 0 ? count / total : 0
    })
  )

  // Plotly treats labels that look "numerical" differently from strings, but we want numerical
  // labels (e.g. Leiden cluster) to lay out the same way as any other category. Prefixing every
  // label with the invisible U+2062 character makes Plotly believe they're strings.
  // (Yes, this is hacky, but after poking around for a while, it was the best solution found to
  // make Plotly turn off its internal magic.)
  const toPlotlyLabel = (label: string): string => `\u2062${label}`

  return {
    type: 'heatmap',
    x: distinctCategories.map(toPlotlyLabel),
    y: distinctValues.map(toPlotlyLabel),
    z: plotDataArray,

    // Use a heatmap color scale that starts at 0; no zmax is set, so the top of the scale
    // follows the data
    colorscale: 'YlOrRd',
    reversescale: true,
    showscale: true,
    zmin: 0,

    // The default hover info is not useful and clutters up the plot
    hoverinfo: 'none',

    // Add some grid lines for better readability
    xgap: 3,
    ygap: 3,
  }
}

/**
 * Builds one text annotation per heatmap cell, labelling the cell with its z value formatted
 * to two decimal places.
 *
 * Expects a heatmap trace shaped like the output of `getPlotData`: `x` holds the column
 * labels, `y` the row labels, and `z[i][j]` the value for row `y[i]`, column `x[j]`.
 *
 * @param plotData - Heatmap trace whose cells should be labelled.
 * @returns Annotations in row-major order (every column of `y[0]`, then `y[1]`, ...); empty
 *   when `x` or `y` is missing.
 */
function getAnnotations(plotData: Partial<PlotData>): Partial<Annotations>[] {
  const densities = plotData.z as number[][]
  const rowLabels = (plotData.y ?? []) as string[]
  const columnLabels = (plotData.x ?? []) as string[]

  return rowLabels.flatMap((rowLabel, rowIndex) =>
    columnLabels.map(
      (columnLabel, columnIndex): Partial<Annotations> => ({
        xref: 'x',
        yref: 'y',
        x: columnLabel,
        y: rowLabel,
        text: densities[rowIndex][columnIndex]?.toFixed(2),
        showarrow: false,
        font: { color: 'black' },
        xanchor: 'center',
        yanchor: 'middle',
      })
    )
  )
}

/**
 * Plotly config shared by every HeatmapComparisonPlot: hides the Plotly logo and removes the
 * zoom, pan, autoscale and selection tools from the mode bar.
 *
 * Hoisted to module scope so the object keeps the same identity across renders.
 * react-plotly.js re-runs its Plotly update whenever the `data`, `layout` or `config` prop
 * changes identity.
 */
const PLOT_CONFIG: Partial<Config> = {
  displaylogo: false,
  modeBarButtonsToRemove: [
    'zoom2d',
    'zoomIn2d',
    'zoomOut2d',
    'resetScale2d',
    'autoScale2d',
    'lasso2d',
    'pan2d',
    'select2d',
  ],
}

/**
 * Renders a heatmap comparing how each category's rows are distributed over the distinct values.
 *
 * Each column is a category (x-axis, drawn along the top) and each row a distinct value (y-axis).
 * A cell holds the fraction of that category's rows having the value, annotated to two decimal
 * places. The plot height grows with the number of rows (`rowHeight` pixels each) plus an
 * estimate of the space taken by the rotated x-axis labels and the title.
 */
const HeatmapComparisonPlot = ({
  data,
  title,
  categoryName,
  valueName,
  sortValuesNumerically,
  width,
  rowHeight,
}: HeatmapComparisonPlotProps): JSX.Element => {
  const plotData = useMemo(
    () => getPlotData(data, sortValuesNumerically),
    [data, sortValuesNumerically]
  )
  // Memoize the trace list as well: an inline `[plotData]` would be a new array on every render
  // and make react-plotly.js treat the figure as changed.
  const traces = useMemo(() => [plotData], [plotData])
  const annotations = useMemo(() => getAnnotations(plotData), [plotData])

  const layout = useMemo((): Partial<Layout> => {
    // It took some futzing around to make layouts work with a range of different chart configurations.
    // The futzing around isn't ideal...but it's hard to control the black box that is Plotly.
    //
    // For the x-axis labels along the top, Plotly sometimes uses 0 and sometimes 45 degree
    // orientation on the labels. And sometimes it puts the title over the axis labels, so we're
    // adding our own margin.
    //
    // See this issue for more details about how Plotly doesn't shift the title if the axis labels
    // need more space...but Plotly does own some pretty complex axis label logic, which makes it messy:
    // https://community.plotly.com/t/overlap-of-x-axis-label-and-figure-title/12340
    //
    // We also always set the tick angle to 45 degrees to make the layout computation simpler
    // (otherwise, it's harder to guess when Plotly will switch between 0 and 45).
    //
    // The overall plot height is the estimated label + title height, plus rowHeight pixels for
    // each row in the heatmap grid.

    // Get the max label length, and from it the estimated height of the longest label at a
    // 45 degree angle. (The length includes the one-character invisible prefix that getPlotData
    // adds to every label.)
    const maxLabelLength = (plotData.x as string[]).reduce(
      (max, label) => Math.max(max, label.length),
      0
    )
    const estimatedLabelHeight =
      Math.min(maxLabelLength, MAX_LABEL_LENGTH) * LABEL_HEIGHT_PER_CHARACTER

    // Plot height = rows + labels + title, plus a fixed 100px allowance
    // (NOTE(review): presumably for Plotly's remaining margins — confirm)
    const plotHeight =
      rowHeight * (plotData.y as string[]).length + estimatedLabelHeight + TITLE_HEIGHT + 100

    // Estimate the top margin height to make sure the title and labels don't overlap.
    // The minimum of 80 was chosen to make sure labels of 1-5 characters long don't overlap the title.
    const topMargin = Math.max(estimatedLabelHeight, 80)

    return {
      title: {
        text: title,
        font: {
          size: 16,
        },
        x: 0.5,
        y: 0.97, // this is necessary to help avoid the overlap
      },
      margin: {
        t: topMargin,
      },
      width,
      height: plotHeight,
      annotations,
      xaxis: {
        title: categoryName,
        side: 'top',
        type: 'category',
        automargin: true,
        tickangle: 45,
      },
      yaxis: {
        title: valueName,
        type: 'category',
        categoryorder: 'trace', // Preserve the order in the data
        autorange: 'reversed', // Reverse the y-axis so the first value in the data is drawn at the top
        automargin: true,
      },
    }
  }, [plotData, annotations, title, categoryName, valueName, width, rowHeight])

  return <Plot data={traces} layout={layout} config={PLOT_CONFIG} />
}

export default HeatmapComparisonPlot
