import {
  Cell,
  CellClass,
  CellClassMap,
  Prediction,
  SampleType,
  SampleTypeMap,
} from '@deepcell/dc_core_proto/deepcell_schema2_pb'
import * as jspb from 'google-protobuf'
import { Message } from 'google-protobuf'

/** Value type of the generated `CellClass` enum map (union of its member values). */
export type CellClassValue = CellClassMap[keyof CellClassMap]
/** Value type of the generated `SampleType` enum map (union of its member values). */
export type SampleTypeValue = SampleTypeMap[keyof SampleTypeMap]

/** Generated enum maps that `EnumEncoderDecoder` can wrap. */
type EnumClassType = SampleTypeMap | CellClassMap

/**
 * Converts between the numeric values of a generated protobuf enum map and
 * their member names with a fixed prefix removed, e.g. with prefix
 * `'SAMPLE_TYPE_'` the member `SAMPLE_TYPE_FOO` maps to/from `'FOO'`.
 */
export class EnumEncoderDecoder {
  /** Prefix shared by the enum's member names; stripped on output, prepended on input. */
  prefix: string

  /** The enum map being wrapped (member name -> numeric value). */
  enumClass: EnumClassType

  /** Reverse lookup built from `enumClass` (numeric value -> full member name). */
  invertedMap: { [key: number]: string }

  /**
   * @param prefix - prefix to strip from / prepend to member names
   * @param enumClass - enum map whose entries are inverted for value lookups
   */
  constructor(prefix: string, enumClass: EnumClassType) {
    this.prefix = prefix
    this.enumClass = enumClass
    // Single pass over the entries; if several members share a value, the
    // last one listed wins.
    const inverted: Record<number, string> = {}
    for (const [key, val] of Object.entries(enumClass)) {
      inverted[val] = key
    }
    this.invertedMap = inverted
  }

  /**
   * Looks up the numeric value for an unprefixed member name.
   *
   * @param key - member name without the prefix (e.g. `'FOO'`)
   * @returns the enum value, or -1 if no such member exists
   */
  convertFromString(key: string): number {
    const fullKey = `${this.prefix}${key}`
    // Own-property check: `in` would also match inherited Object.prototype
    // members (e.g. `toString` when the prefix is empty).
    if (Object.prototype.hasOwnProperty.call(this.enumClass, fullKey)) {
      return this.enumClass[fullKey as keyof EnumClassType]
    }
    return -1
  }

  /**
   * Returns the unprefixed member name for a numeric enum value.
   *
   * @param value - enum value, or `undefined`
   * @returns the name without the prefix, or `''` when `value` is
   *   `undefined` or is not a value of the wrapped enum
   */
  convertToString(value: number | undefined): string {
    if (value === undefined) return ''
    // Unknown values are absent from the map; report them as '' rather than
    // leaking `undefined` through the `string` return type.
    const fullKey: string | undefined = this.invertedMap[value]
    return fullKey === undefined ? '' : fullKey.substring(this.prefix.length)
  }
}

/** Codec for `SampleType` values; string forms omit the `SAMPLE_TYPE_` prefix. */
export const SampleTypeEncoderDecoder = new EnumEncoderDecoder(
  'SAMPLE_TYPE_',
  SampleType
)
/** Codec for `CellClass` values; string forms omit the `CLASS_` prefix. */
export const CellClassEncoderDecoder = new EnumEncoderDecoder(
  'CLASS_',
  CellClass
)

/**
 * Decodes a base64 string and deserializes the resulting bytes with the
 * supplied message class.
 *
 * NOTE(review): decoding relies on the Node `Buffer` global (a polyfill is
 * needed in browser builds), whereas `messageToBase64Str` encodes with `btoa`.
 *
 * @param base64Str - base64-encoded serialized message
 * @param messageClass - anything exposing a static-style `deserializeBinary`
 * @returns the deserialized message
 */
export const base64StrToMessage = <T extends Message>(
  base64Str: string,
  messageClass: { deserializeBinary(bytes: Uint8Array): T }
): T => messageClass.deserializeBinary(Buffer.from(base64Str, 'base64'))

/**
 * Serializes a message to its binary wire form and base64-encodes it.
 *
 * `btoa` expects a "binary string" (one char per byte), so each byte is
 * mapped to the char with the same code before encoding.
 *
 * @param message - the message to encode
 * @returns the base64 representation of the serialized bytes
 */
export const messageToBase64Str = (message: jspb.Message): string => {
  const bytes = message.serializeBinary()
  let binary = ''
  for (let i = 0; i < bytes.length; i += 1) {
    binary += String.fromCharCode(bytes[i])
  }
  return btoa(binary)
}

/**
 * Structural equality for two optional protobuf messages.
 *
 * Two nullish inputs are equal, and a nullish input never equals a message.
 * Otherwise this defers to `jspb.Message.equals`, which compares the
 * messages' field data element by element. Comparing `toString()` output is
 * not reliable for this: jspb's `toString` comma-joins the underlying field
 * array, so differently-structured messages can produce the same string.
 *
 * @param a - first message, or `undefined`
 * @param b - second message, or `undefined`
 * @returns whether both are absent or both hold equal messages
 */
export const messagesEqual = <T extends jspb.Message>(
  a: T | undefined,
  b: T | undefined
): boolean => {
  // `== null` treats null and undefined alike, matching the previous `?.` behaviour.
  if (a == null || b == null) return a == b
  return jspb.Message.equals(a, b)
}

/**
 * Returns the prediction used to represent a cell: the first entry of its
 * predictions list.
 *
 * @param cell - the cell whose predictions are read
 * @returns the first prediction, or `undefined` when the list is empty
 */
export function getCellPrediction(cell: Cell): Prediction | undefined {
  const [firstPrediction] = cell.getPredictionsList()
  return firstPrediction
}

/**
 * Returns the predicted class value stored on a prediction.
 *
 * Uses the `average` value when `hasAverage()` reports it is set, and
 * otherwise falls back to `getVoting()`. Accepts `undefined` so it composes
 * directly with `getCellPrediction`, e.g.
 * `getPredictedClass(getCellPrediction(cell))`.
 *
 * @param prediction - the prediction to read, or `undefined`
 * @returns the class value, or `undefined` when no prediction is given
 */
export function getPredictedClass(prediction: Prediction | undefined): number | undefined {
  if (prediction === undefined) return undefined

  return prediction.hasAverage() ? prediction.getAverage() : prediction.getVoting()
}
