import { Box, Stack } from '@mui/material'
import { CellInfo } from 'components/cell-visualizations/tsv/types'
import _ from 'lodash'
import { useMemo } from 'react'
import { LabelColor } from 'redux/slices'
import { useCellVisualizationsSlice } from 'redux/slices/hooks/useCellVisualizationsSlice'
import { usePlotData } from '../plot/usePlotData'
import useDataFieldsInDataset from '../useDataFieldsInDataset'
import DataCategorySelect from './DataCategorySelect'
import DataFieldSelect from './DataFieldSelect'
import DataTypeCompareSelect from './DataTypeCompareSelect'
import { CategoricalDistributionPlotFactory, KDEPlotFactory } from './PlotFactory'
import useDataCategory, { CELL_GROUPS_CATEGORY_KEY } from './useDataCategory'

/**
 * A `CellInfo` row tagged with the name of the comparison group it was assigned to.
 * `GROUP_NAME` is the column the plot factories are told to group by.
 */
type GroupedCellInfo = CellInfo & {
  GROUP_NAME: string
}

/**
 * Panel for comparing feature distributions across groups of cells.
 *
 * How cells are grouped depends on the selected data category:
 * - `CELL_GROUPS_CATEGORY_KEY`: by the merged pinned cell groups flagged `isInSelectedGroup`;
 * - any other category: by the selected data-field values of that category's attribute.
 *
 * One plot is rendered per selected feature: a KDE plot for continuous features, a
 * categorical distribution plot otherwise. The most recently added feature is shown first.
 */
const CompareGroups = (): JSX.Element => {
  const { dataFields } = useDataFieldsInDataset()
  const { cellVisualizations, getCellInfoByCellId, getMergedPinnedCellGroups } =
    useCellVisualizationsSlice()

  const {
    groupCompare,
    cellInfoGroups,
    selectedCellInfoGroupName: selectedCellInfoGroup,
  } = cellVisualizations

  const { cellInfo: cellsData } = usePlotData()

  const {
    selectedFeatures,
    selectedDataCategory,
    selectedDataFields: allSelectedDataFields,
  } = groupCompare
  const { getDataCategoryAttribute } = useDataCategory()

  // The CellInfo attribute used to match cells for the selected category.
  // This should never be undefined.
  const cellInfoKey = getDataCategoryAttribute(selectedDataCategory)

  // Memoized so the downstream memos (cellIdsByGroup -> rawData -> featurePlotFactories)
  // are not invalidated on every render by a freshly filtered array. This only helps when
  // useCellVisualizationsSlice hands back a referentially stable getter between state changes.
  const selectedGroups = useMemo(
    () => getMergedPinnedCellGroups().filter((x) => x.isInSelectedGroup),
    [getMergedPinnedCellGroups]
  )

  const selectedDataFields = allSelectedDataFields[selectedDataCategory]

  /** For each selected cell group, the set of `cellInfoKey` values of the group's member cells. */
  const cellIdsByGroup = useMemo(() => {
    const ids: Record<string, Set<string>> = {}
    selectedGroups.forEach((selectedGroup) => {
      ids[selectedGroup.name] = new Set()
      selectedGroup.cells.points?.forEach(({ cellId }) => {
        const morphoData = getCellInfoByCellId(cellId ?? '')
        const cellInfoData = morphoData[cellInfoKey]

        if (cellInfoData) ids[selectedGroup.name].add(cellInfoData)
      })
    })
    return ids
  }, [cellInfoKey, selectedGroups, getCellInfoByCellId])

  /** Take each group of cells and augment the cell data with a group name.
   *
   * Note that a cell may appear in the result more than once if it belongs to multiple
   * groups, and that's fine.
   */
  const rawData = useMemo(() => {
    const result: GroupedCellInfo[] = []

    if (selectedDataCategory === CELL_GROUPS_CATEGORY_KEY) {
      // Loop-invariant: compute the group names once rather than once per cell
      const groupNames = _.keys(cellIdsByGroup)
      cellsData?.forEach((cell) => {
        const cellVal = cell[cellInfoKey]
        if (!cellVal) return
        groupNames.forEach((groupName) => {
          if (cellIdsByGroup[groupName].has(cellVal)) {
            result.push({ ...cell, GROUP_NAME: groupName })
          }
        })
      })
    } else {
      selectedDataFields?.forEach((dataField: string) => {
        // Append matches one at a time: spreading a large match list into push() passes
        // every element as a call argument and can exceed the engine's argument limit.
        cellsData?.forEach((cell) => {
          // Make sure to convert values to string before comparing with the dataField
          if (String(cell[cellInfoKey]) === dataField) {
            result.push({ ...cell, GROUP_NAME: dataField })
          }
        })
      })
    }
    return result
  }, [cellIdsByGroup, cellInfoKey, cellsData, selectedDataCategory, selectedDataFields])

  // Memoized because it is a dependency of featurePlotFactories; a fresh array on every
  // render would rebuild every plot factory on every render.
  const colors = useMemo(
    () =>
      cellInfoGroups[selectedCellInfoGroup]?.data.map((x) => ({
        color: x.color,
        name: x.value,
        isHidden: x.isHidden,
      })) as LabelColor[],
    [cellInfoGroups, selectedCellInfoGroup]
  )

  const featurePlotFactories = useMemo(() => {
    return dataFields.map((feature) => {
      const params = [rawData, feature.attribute as keyof CellInfo, 'GROUP_NAME', colors] as const

      const plotFactory = feature.isContinuous
        ? new KDEPlotFactory(...params)
        : new CategoricalDistributionPlotFactory(...params)

      return { feature, plotFactory }
    })
  }, [colors, dataFields, rawData])

  const plots = selectedFeatures.flatMap((feature) => {
    const currentFeaturePlotFactory = featurePlotFactories.find(
      (x) => x.feature.attribute === feature.attribute
    )

    if (!currentFeaturePlotFactory) return []

    const { plotFactory } = currentFeaturePlotFactory

    // NOTE(review): this mutates a memoized factory during render. It is assumed to be
    // idempotent for identical arguments; confirm against PlotFactory.setLayout.
    plotFactory.setLayout({
      width: 500,
      height: 400,
      title: feature.label,
    })

    return (
      <Box key={feature.attribute} ml={-2} py={1}>
        {plotFactory.getPlot()}
      </Box>
    )
  })

  // Reverse the plots list so that as you add more plots, the new plot always shows up at the top
  // where it's immediately visible
  plots.reverse()

  return (
    <Stack spacing={2}>
      <DataCategorySelect />
      <DataTypeCompareSelect />
      <DataFieldSelect />

      <Stack>{plots}</Stack>
    </Stack>
  )
}

export default CompareGroups
