import { density1d } from '@uwdata/kde'
import _ from 'lodash'
import { Config, Layout, PlotData } from 'plotly.js'
import Plotly from 'plotly.js/dist/plotly-axon'
import { useMemo } from 'react'
import createPlotlyComponent from 'react-plotly.js/factory'
import { LabelColor } from 'redux/slices'

// Build the React <Plot> component through the react-plotly.js factory so that it
// renders with the `Plotly` object imported from 'plotly.js/dist/plotly-axon'
// (the factory lets the caller supply their own Plotly object).
const Plot = createPlotlyComponent(Plotly)

/**
 * One observation to plot: a numeric value tagged with the category it belongs to.
 * Observations sharing a category are pooled into a single density curve.
 */
export type ContinuousComparisonPlotDatum = {
  category: string // group key; each distinct category gets its own line (and line color)
  value: number // sample value; plotted as a KDE-smoothed distribution along the x-axis
}

/**
 * Props accepted by the ContinuousComparisonPlot component.
 */
export type ContinuousComparisonPlotProps = {
  data: ContinuousComparisonPlotDatum[] // category + value rows; one density curve is drawn per category
  title?: string // optional plot title; when undefined or empty no title is rendered
  titleFontSize?: number // optional font size for the title (defaults to 16)
  labelColors?: LabelColor[] // optional label colors to assign to the categories (matched on `name`)
  width: number // forwarded to the Plotly layout `width`
  height: number // forwarded to the Plotly layout `height`
}

/**
 * @internal (exposed only for unit testing)
 * Builds the Plotly traces for a set of category + value rows: the values of each
 * category are smoothed with a 1D KDE (bandwidth chosen automatically) and the
 * resulting PDF is returned as a line trace.
 *
 * @param data rows of category + value pairs to plot
 * @param labelColors Optional custom label colors; an entry whose `name` equals a
 *    category sets that category's line color, so colors can match those used elsewhere
 * @returns An array of Plotly PlotData objects, defining one scatterplot line graph per category
 */
export function getPlotData(
  data: ContinuousComparisonPlotDatum[],
  labelColors?: LabelColor[]
): Partial<PlotData>[] {
  // Bucket the rows by category; every bucket turns into exactly one trace
  const rowsByCategory = _.groupBy(data, 'category')

  return _.map(rowsByCategory, (rows, category) => {
    // Reduce the bucket to its raw numeric samples
    const samples = rows.map((row) => row.value)

    // Run the KDE on the samples.
    // With no options this uses NRD to choose the bandwidth (smooth, but not over-smoothed)
    // and picks the extents automatically so that 99% of the input data is captured
    const estimate = density1d(samples)

    // Unzip the estimated points into the parallel x / y arrays Plotly expects
    const xs: number[] = []
    const ys: number[] = []
    for (const point of estimate.points()) {
      xs.push(point.x)
      ys.push(point.y)
    }

    // Only override the line color when a label color was supplied for this category
    const styling: Partial<PlotData> = {}
    const matchedColor = labelColors?.find((candidate) => candidate.name === category)
    if (matchedColor) {
      styling.line = { color: matchedColor.color }
    }

    const trace: Partial<PlotData> = {
      ...styling,
      type: 'scatter',
      mode: 'lines',
      x: xs,
      y: ys,
      name: category,
      showlegend: true,

      // Plotly's default hover info adds clutter without adding value here
      hoverinfo: 'none',
    }
    return trace
  })
}

/**
 * Plotly config used for every instance of the plot: hides the Plotly logo and
 * removes the zoom / pan / select mode-bar buttons.
 *
 * Declared once at module scope so the same object is passed on every render.
 * react-plotly.js compares `data`, `layout` and `config` by identity to decide
 * whether to re-plot, so an inline literal would trigger a re-plot on each render.
 */
const PLOT_CONFIG: Partial<Config> = {
  displaylogo: false,
  modeBarButtonsToRemove: [
    'zoom2d',
    'zoomIn2d',
    'zoomOut2d',
    'resetScale2d',
    'autoScale2d',
    'lasso2d',
    'pan2d',
    'select2d',
  ],
}

/**
 * Renders one KDE-smoothed density curve per category so that the distributions
 * of a continuous value can be compared across categories.
 *
 * @param data rows of category + value pairs to plot
 * @param title optional plot title; omitted when undefined or empty
 * @param titleFontSize font size of the title (defaults to 16)
 * @param labelColors optional colors to assign to the categories
 * @param width width forwarded to the Plotly layout
 * @param height height forwarded to the Plotly layout
 * @returns the Plotly plot element
 */
const ContinuousComparisonPlot = ({
  data,
  title,
  titleFontSize = 16,
  labelColors,
  width,
  height,
}: ContinuousComparisonPlotProps): JSX.Element => {
  // The KDE is the expensive part; only recompute it when its inputs change
  const plotData = useMemo(() => getPlotData(data, labelColors), [data, labelColors])

  // Memoized so that the layout keeps its identity across unrelated re-renders;
  // otherwise the plot would be redrawn every time the parent renders
  const layout = useMemo<Partial<Layout>>(
    () => ({
      // An undefined or empty title means "no title"
      title: title
        ? {
            text: title,
            font: {
              size: titleFontSize,
            },
          }
        : undefined,
      legend: {
        orientation: 'h',
      },
      width,
      height,
      yaxis: {
        title: 'Density',
      },
      margin: {
        l: 50, // left margin
        r: 50, // right margin
        t: title ? 80 : 20, // top margin, smaller if no title
        b: 50, // bottom margin
      },
    }),
    [title, titleFontSize, width, height]
  )

  return <Plot data={plotData} layout={layout} config={PLOT_CONFIG} />
}

export default ContinuousComparisonPlot
