import { density1d } from '@uwdata/kde'
import { group, InternMap, rollup } from 'd3-array'
import Plotly from 'plotly.js/dist/plotly-axon'
import { Data, Layout, PlotData } from 'plotly.js'
import createPlotlyComponent from 'react-plotly.js/factory';
import { LabelColor  } from 'redux/slices'

// Build the React <Plot> component via react-plotly.js's factory so it is bound to the
// custom `plotly-axon` Plotly bundle imported above rather than a default Plotly build.
const Plot = createPlotlyComponent(Plotly)

/** Value type allowed in a data row; intentionally narrower than the data types Plotly accepts */
export type DeepcellDatum = number | string | undefined
/** A DeepcellDatum that is known to be present (i.e. not undefined) */
export type DefinedDeepcellDatum = NonNullable<DeepcellDatum>

/**
 * Abstract base class of a factory that generates plots.
 * Currently this interface is tied to building Plotly plots only, but that could be abstracted away.
 *
 * Input data is an array of row objects of type TDataObj, whose values (where present)
 * are of type DeepcellDatum.
 *
 * @TODO Not sure this is a reasonable assumption for a basic data type, but it's a start.
 * It may be cleaner to accept arrays of DeepcellDatum[] instead of TDataObj[]
 * (and put slightly more responsibility on the caller).
 */
export abstract class PlotFactory<
  TDataObj extends Partial<{ [key in TKey]: DeepcellDatum }>,
  TKey extends keyof TDataObj
> {
  /** Data rows to plot */
  data: TDataObj[]

  /** Extra layout options set via setLayout(); merged with the subclass layout in getPlot() */
  layout?: Partial<Layout>

  /**
   * Creates a plot factory over the given data rows.
   * Extra layout options are supplied separately via setLayout().
   */
  constructor(data: TDataObj[]) {
    this.data = data
  }

  /**
   * Stores extra layout options to be merged into the plot's layout.
   * @param layout Layout options; pass undefined to clear previously set options
   */
  setLayout(layout?: Partial<Layout>): void {
    this.layout = layout
  }

  /**
   * Gets values for the given field from the data
   * @param field Field to access from the data
   * @returns An array with one value per data row (undefined where the row lacks the field)
   */
  getField(field: TKey): DeepcellDatum[] {
    return this.data.map((item) => item[field])
  }

  /**
   * Renders this factory's data as a Plotly `<Plot>` element.
   *
   * The layout is a shallow merge of the options stored via setLayout() and the subclass's
   * getLayout(); on a conflicting top-level key (e.g. `xaxis`) the subclass value wins.
   * The Plotly logo and the zoom/pan/autoscale/select/lasso mode-bar buttons are removed.
   *
   * @returns A Plotly Plot element for this factory's data
   */
  getPlot(): JSX.Element {
    return (
      <Plot
        data={this.getData()}
        layout={{ ...this.layout, ...this.getLayout() }}
        config={{
          displaylogo: false,
          modeBarButtonsToRemove: [
            'zoom2d',
            'zoomIn2d',
            'zoomOut2d',
            'resetScale2d',
            'autoScale2d',
            'lasso2d',
            'pan2d',
            'select2d',
          ],
        }}
      />
    )
  }

  /** Returns the Plotly traces to render for this plot type */
  abstract getData(): Data[]

  /** Returns the layout specific to this plot type; takes precedence over setLayout() options */
  abstract getLayout(): Partial<Layout>
}

/** Plots the distribution of one quantitative variable, optionally grouped by a categorical field,
 * using a smoothed kernel density estimate of the values
 * (similar to seaborn.kdeplot, but only 1D).
 */
export class KDEPlotFactory<
  TDataObj extends Partial<{ [key in TKey]: DeepcellDatum }>,
  TKey extends keyof TDataObj
> extends PlotFactory<TDataObj, TKey> {
  /** Field holding the quantitative values to estimate a density for */
  valueField: TKey

  /** Optional categorical field; when set, one density curve is drawn per distinct value */
  categoryField?: TKey

  /** Optional line colors, matched against the stringified group name */
  labelColors?: LabelColor[]

  constructor(
    data: TDataObj[],
    valueField: TKey,
    categoryField?: TKey,
    labelColors?: LabelColor[]
  ) {
    super(data)
    this.valueField = valueField
    this.categoryField = categoryField
    this.labelColors = labelColors
  }

  /**
   * Generates one or more KDE PlotData configurations for the given valueField.
   * @returns One trace per distinct categoryField value (in order of first appearance),
   *   or a single trace over all rows when no categoryField is set
   */
  getData(): Partial<PlotData>[] {
    // Copy to a local so the narrowing from the check below holds inside the callback.
    const { categoryField } = this

    if (categoryField !== undefined) {
      // group() keeps keys in first-appearance order, so iterating it directly yields the
      // categories in the same order as de-duplicating the field's values would.
      const categoryData: InternMap<DeepcellDatum, TDataObj[]> = group(
        this.data,
        (item) => item[categoryField]
      )

      return [...categoryData].map(([category, categoryRows]) =>
        this.getDataForValues(
          categoryRows.map((row) => row[this.valueField]),
          category
        )
      )
    }

    return [this.getDataForValues(this.getField(this.valueField))]
  }

  /**
   * Builds a single KDE line trace from raw values.
   *
   * Undefined values are dropped; strings are parsed with parseFloat. Any value that is not a
   * finite number afterwards (unparseable text, NaN, ±Infinity) is also dropped so that it
   * cannot poison the density estimate.
   *
   * @param values Raw values of the valueField
   * @param groupName Category this trace represents; omit for an ungrouped plot, in which case
   *   the trace gets no explicit name (rather than the string "undefined")
   * @returns A Plotly scatter (lines) trace; its x/y are empty when no usable values remain
   */
  getDataForValues(values: DeepcellDatum[], groupName?: DeepcellDatum): Partial<PlotData> {
    const floatValues: number[] = []
    for (const value of values) {
      if (value === undefined) continue
      const parsed = typeof value === 'number' ? value : parseFloat(value)
      if (Number.isFinite(parsed)) floatValues.push(parsed)
    }

    const name = groupName === undefined ? undefined : String(groupName)
    const lineColor =
      name === undefined ? undefined : this.labelColors?.find((lc) => `${lc.name}` === name)
    const line: Partial<PlotData> = lineColor ? { line: { color: lineColor.color } } : {}

    let x: number[] = []
    let y: number[] = []
    // Skip the estimator entirely when there is nothing to estimate from.
    if (floatValues.length > 0) {
      // Compute the KDE given the float values
      // This defaults to using NRD to pick the 'best' bandwidth to use (so the curve is smooth enough but not too smooth)
      // And also automatically selects the extents to use to capture 99% of the input data
      const density = density1d(floatValues)

      // Convert the output of kde.density1d to x and y values for Plotly
      const points = [...density.points()]
      x = points.map((p) => p.x)
      y = points.map((p) => p.y)
    }

    return {
      ...line,
      type: 'scatter',
      mode: 'lines',
      x,
      y,
      name,
    }
  }

  /** Layout for KDE plots: horizontal legend below the plot, fixed (non-zoomable) axes */
  getLayout(): Partial<Layout> {
    return {
      legend: { x: 0, y: -0.1, orientation: 'h' },
      yaxis: { title: 'Density', automargin: true, fixedrange: true },
      xaxis: { automargin: true, fixedrange: true, autorange: true },
    }
  }
}

/** Plots the distribution of one categorical variable (as a horizontal bar chart of probability values),
 * optionally grouped by another categorical field (as a grouped bar chart whose probabilities sum to 1 per group).
 */
export class CategoricalDistributionPlotFactory<
  TDataObj extends Partial<{ [key in TKey]: DeepcellDatum }>,
  TKey extends keyof TDataObj
> extends PlotFactory<TDataObj, TKey> {
  /** Categorical field whose value distribution is plotted */
  valueField: TKey

  /** Optional categorical field; when set, one bar trace is drawn per distinct value */
  categoryField?: TKey

  /** Optional bar colors, matched against the stringified group name */
  labelColors?: LabelColor[]

  constructor(
    data: TDataObj[],
    valueField: TKey,
    categoryField?: TKey,
    labelColors?: LabelColor[]
  ) {
    // PlotFactory's constructor already stores `data`, so it is not redeclared or reassigned here.
    super(data)
    this.valueField = valueField
    this.categoryField = categoryField
    this.labelColors = labelColors
  }

  /**
   * Generates one or more bar chart PlotData configurations,
   * summarizing the distribution of values in the valueField.
   *
   * @returns One bar trace per distinct categoryField value (in order of first appearance),
   *   or a single trace over all rows when no categoryField is set
   */
  getData(): Partial<PlotData>[] {
    // Copy to a local so the narrowing from the check below holds inside the callback.
    const { categoryField } = this

    if (categoryField === undefined) {
      return [this.getDataForRows(this.data)]
    }

    const categoryData = group(this.data, (row) => row[categoryField])
    return [...categoryData].map(([category, categoryRows]) =>
      this.getDataForRows(categoryRows, category)
    )
  }

  /**
   * Returns a single plot Data for an array of data rows
   *
   * @param rows An array of data rows to calculate a distribution for
   * @param groupName Optional name of the group for this plot
   * @returns A Plotly plot object specifying the bar graph to plot
   */
  getDataForRows(rows: TDataObj[], groupName?: DeepcellDatum): Partial<PlotData> {
    const count = rows.length

    // Note: the histnorm parameter from the docs below does not appear to be available in react-plotly
    // https://plotly.com/javascript/histograms/#normalized-histogram

    // So we use d3.rollup here to convert absolute counts to fractions.
    // Keys are stringified so they display properly as category labels; as a consequence,
    // rows missing the value are counted under the literal label "undefined".
    const groupDensity = rollup(
      rows,
      (v) => v.length / count, // reducer
      (item) => String(item[this.valueField]) // key
    )

    // Read labels and densities from the same entries so the two arrays always stay aligned.
    const entries = [...groupDensity]
    const categories = entries.map(([category]) => category)
    const densities = entries.map(([, density]) => density)

    const name = groupName?.toString()
    const markerColor =
      name === undefined ? undefined : this.labelColors?.find((lc) => `${lc.name}` === name)
    const marker: Partial<PlotData> = markerColor
      ? { marker: { color: markerColor.color } }
      : {}

    return {
      ...marker,
      type: 'bar',
      x: densities,
      y: categories,
      name,
      orientation: 'h', // use horizontal bar charts to make category labels more readable
    }
  }

  /** Layout for categorical distribution plots: horizontal legend below, fixed (non-zoomable) axes */
  getLayout(): Partial<Layout> {
    return {
      legend: { x: 0, y: -0.3, orientation: 'h' },

      // sort category labels on the y axis in descending category order
      yaxis: { categoryorder: 'category descending', automargin: true, fixedrange: true },
      xaxis: { automargin: true, title: 'Density', fixedrange: true },
    }
  }
}
