import { Stack } from '@mui/material'
import { CellInfo } from 'components/cell-visualizations/tsv/types'
import { useMemo } from 'react'
import { CellDataField, PinnedMorphotype } from 'redux/slices'
import { useCellVisualizationsSlice } from 'redux/slices/hooks/useCellVisualizationsSlice'
import { usePlotData } from '../plot/usePlotData'
import ContinuousComparisonPlot from './ContinuousComparisonPlot'
import DataCategorySelect from './DataCategorySelect'
import DataFieldSelect from './DataFieldSelect'
import DataTypeCompareSelect from './DataTypeCompareSelect'
import HeatmapComparisonPlot from './HeatmapComparisonPlot'
import { CELL_GROUPS_CATEGORY_KEY } from './useDataCategory'

/**
 * One observation fed to a comparison plot.
 *
 * `category` is the group label the observation is bucketed under (a morphotype name or a
 * category-field value); `value` is the feature value after the caller's transform.
 */
export type PlotDatum<T> = {
  category: string
  value: T
}

/**
 * @internal - only exposed for unit testing
 * Builds plot data that compares pinned morphotypes on a single feature.
 *
 * Morphotypes are visited in order, and for each one `cellsData` is scanned in order. A datum is
 * emitted for every cell whose `CELL_ID` is among the morphotype's pinned points and whose
 * `feature` value is not `undefined`. A cell belonging to several morphotypes therefore appears
 * once per morphotype. Cells with an empty or missing `CELL_ID` are skipped, as are morphotypes
 * with no `cells.points`.
 *
 * @param feature - `CellInfo` key whose value is plotted.
 * @param cellsData - All cells available for plotting.
 * @param morphotypes - Morphotypes to compare; each one's `name` becomes the datum's `category`.
 * @param dataTransform - Converts the raw string feature value into the plotted value type `T`.
 *
 * @returns The collected data, grouped by morphotype in input order.
 */
export function getPlotDataFromMorphotypes<T>(
  feature: keyof CellInfo,
  cellsData: CellInfo[],
  morphotypes: PinnedMorphotype[],
  dataTransform: (value: string) => T
): PlotDatum<T>[] {
  const plotData: PlotDatum<T>[] = []

  for (const group of morphotypes) {
    // Set lookup keeps the per-cell membership test O(1).
    const memberIds = new Set(group.cells?.points?.map((point) => point.cellId))

    for (const row of cellsData) {
      const id = row.CELL_ID
      if (!id || !memberIds.has(id)) continue

      const raw = row[feature]
      if (raw === undefined) continue

      plotData.push({ category: group.name, value: dataTransform(raw) })
    }
  }

  return plotData
}

/**
 * @internal - only exposed for unit testing
 * Builds plot data that compares cells grouped by the value of a category field.
 *
 * `cellsData` is scanned in order. A datum is emitted for every cell whose `category` value and
 * `feature` value are both defined and whose `category` value is one of `selectedCategoryItems`.
 * That category value becomes the datum's `category`.
 *
 * @param feature - `CellInfo` key whose value is plotted.
 * @param category - `CellInfo` key used to group the cells.
 * @param selectedCategoryItems - Category values to keep; cells with any other value are dropped.
 * @param cellsData - All cells available for plotting.
 * @param dataTransform - Converts the raw string feature value into the plotted value type `T`.
 *
 * @returns The collected data in `cellsData` order.
 */
export function getPlotDataFromCategoryField<T>(
  feature: keyof CellInfo,
  category: keyof CellInfo,
  selectedCategoryItems: string[],
  cellsData: CellInfo[],
  dataTransform: (value: string) => T
): PlotDatum<T>[] {
  const wanted = new Set(selectedCategoryItems)
  const plotData: PlotDatum<T>[] = []

  for (const row of cellsData) {
    const groupLabel = row[category]
    const raw = row[feature]
    if (groupLabel === undefined || raw === undefined) continue
    if (!wanted.has(groupLabel)) continue

    plotData.push({ category: groupLabel, value: dataTransform(raw) })
  }

  return plotData
}

/**
 * Group-comparison panel.
 *
 * Renders the `DataCategorySelect`, `DataTypeCompareSelect` and `DataFieldSelect` controls,
 * followed by one plot per selected feature. Continuous features get a
 * `ContinuousComparisonPlot`; all others get a `HeatmapComparisonPlot`.
 *
 * When the selected data category is `CELL_GROUPS_CATEGORY_KEY`, the groups are the merged pinned
 * morphotypes flagged `isInSelectedGroup`. Otherwise the groups are the selected values of the
 * `CellInfo` field named by the category. No plots are rendered until at least one feature and
 * at least one group are selected.
 */
const CompareGroups = (): JSX.Element => {
  const { cellVisualizations, getMergedPinnedMorphotypes } = useCellVisualizationsSlice()

  const { groupCompare } = cellVisualizations

  const { cellInfo: cellsData } = usePlotData()

  const {
    selectedFeatures,
    selectedDataCategory,
    selectedDataFields: allSelectedDataFields,
  } = groupCompare

  // NOTE(review): `.filter` returns a new array on every render, so `selectedGroups` never
  // compares equal between renders and the useMemo below recomputes every time. Memoizing it is
  // only safe if `getMergedPinnedMorphotypes` changes identity whenever its data changes —
  // confirm before adding a memo here.
  const selectedGroups = getMergedPinnedMorphotypes().filter((x) => x.isInSelectedGroup)
  const selectedDataFields = allSelectedDataFields[selectedDataCategory]

  const plots = useMemo(() => {
    const isCellGroups = selectedDataCategory === CELL_GROUPS_CATEGORY_KEY
    // A category with no entry in the map yet would otherwise crash on `.length` below.
    const categoryItems = selectedDataFields ?? []
    const cells = cellsData ?? []

    if (
      selectedFeatures.length === 0 ||
      (isCellGroups ? selectedGroups.length === 0 : categoryItems.length === 0)
    ) {
      return []
    }

    // Collects one feature's data for whichever grouping mode is active.
    function collect<T>(feature: keyof CellInfo, transform: (value: string) => T): PlotDatum<T>[] {
      return isCellGroups
        ? getPlotDataFromMorphotypes(feature, cells, selectedGroups, transform)
        : getPlotDataFromCategoryField(
            feature,
            selectedDataCategory as keyof CellInfo,
            categoryItems,
            cells,
            transform
          )
    }

    const rendered = selectedFeatures.map((dataField: CellDataField): JSX.Element => {
      const feature = dataField.attribute as keyof CellInfo
      const baseTitle = `${dataField.label} distribution`

      if (dataField.isContinuous) {
        return (
          <ContinuousComparisonPlot
            key={dataField.attribute}
            data={collect(feature, parseFloat)}
            title={baseTitle}
            width={500}
            height={400}
          />
        )
      }

      return (
        <HeatmapComparisonPlot
          key={dataField.attribute}
          data={collect(feature, (value: string) => value)}
          title={`${baseTitle} by ${selectedDataCategory}`}
          sortValuesNumerically={dataField.isNumerical}
          valueName={dataField.label}
          width={500}
          rowHeight={25}
        />
      )
    })

    // Newest plot first, so a just-added plot shows up at the top where it's immediately visible.
    // Reversing here (on the freshly built array) rather than at render time matters:
    // `Array.prototype.reverse` mutates in place, and reversing the memoized array on every
    // render would flip its order each time the memo is reused.
    return rendered.reverse()
  }, [cellsData, selectedDataCategory, selectedDataFields, selectedFeatures, selectedGroups])

  return (
    <Stack spacing={2}>
      <DataCategorySelect />
      <DataTypeCompareSelect />
      <DataFieldSelect />

      <Stack>{plots}</Stack>
    </Stack>
  )
}

export default CompareGroups
