import { Layer } from "konva/lib/Layer"
import Endpoint from "./Nodes/Endpoint"
import Receptor from "./Nodes/Receptor"
import Connection from "./Nodes/Connection"
import React from "react"
import { Tensor } from "@tensorflow/tfjs"
import { Vector2d } from "konva/lib/types"
import Block from "./Nodes/Block"
import { Stage } from "konva/lib/Stage"
import * as tf from "@tensorflow/tfjs"
import ExerciseData, { ItemInfo } from "./Exercises/ExerciseData"
import { createBlockFromString } from "./Utils"

/*
const outputLibrary = [
    <LibraryItem title='Tensor Reader' type='output'/>,
    <LibraryItem title='Tensor Info' type='shape'/>,
    <LibraryItem title='Trainer' type="trainer" />,
    <LibraryItem title='Predictor' type="predictor" />
]
*/
/**
 * Catalogue of every block kind, grouped into four categories.
 *
 * Each inner map goes from a block key (e.g. `"matmul"`) to its display
 * name (e.g. `"Matrix Multiply"`). Entries are listed in insertion order,
 * which `Object.keys` / `Object.entries` preserve.
 * NOTE(review): presumably the block library UI relies on this order; confirm
 * before reordering entries.
 */
export const BLOCK_TYPES: Record<
    "basicOps" | "inputs" | "outputs" | "layers",
    Record<string, string>
> = {
    basicOps: {
        "add": "Add",
        "multiply": "Multiply",
        "negate": "Negate",
        "transpose": "Transpose",
        "matmul": "Matrix Multiply",
        "activation": "Activation Function",
        "reshape": "Reshape",
        "sum": "Sum",
        // Loss / similarity measures
        "loss_mae": "Mean Absolute Error",
        "loss_mse": "Mean Squared Error",
        "cosine_similarity": "Cosine Similarity",
        "sqrt": "Square Root",
        "square": "Square",
        "mean": "Mean",
        "divide": "Divide",
        "argmax": "ArgMax",
        "concat": "Concatenation",
        "shape_reader": "Shape Reader",
    },
    inputs: {
        "constant": "Constant",
        "symbolic": "Input",
        "parameter": "Parameter",
        "custom_dataset": "Custom Dataset",
        "ones": "Ones",
    },
    outputs: {
        "tensor_viewer": "Tensor Viewer",
        "tensor_info": "Tensor Info",
        "trainer": "Trainer",
        "predictor": "Predictor",
        "image_viewer": "Image Viewer",
    },
    layers: {
        "linear": "Linear Layer",
        "conv1d": "Convolution 1D",
        "conv2d": "Convolution 2D",
        // Pooling
        "maxpool1d": "Max Pool 1D",
        "avgpool1d": "Average Pool 1D",
        "maxpool2d": "Max Pool 2D",
        "avgpool2d": "Average Pool 2D",
        "globalavgpool1d": "Global Average Pool 1D",
        "globalmaxpool1d": "Global Max Pool 1D",
        "globalavgpool2d": "Global Average Pool 2D",
        "globalmaxpool2d": "Global Max Pool 2D",
        "dropout": "Dropout",
    },
}

/** Builds a fresh quota entry with no instance limit (`count: Infinity`). */
const unlimitedQuota = (): ItemInfo => ({ count: Infinity })

/**
 * Block quota for the sample playground, keyed by block key (all keys below
 * also appear in `BLOCK_TYPES`). Every listed block has `count: Infinity`
 * except `trainer`, which has `count: 1`.
 * NOTE(review): `count` presumably caps how many instances may be placed;
 * confirm against the `ItemInfo` consumer.
 */
export const samplePlaygroundQuota: Record<string, ItemInfo> = {
    constant: unlimitedQuota(),
    tensor_viewer: unlimitedQuota(),
    add: unlimitedQuota(),
    multiply: unlimitedQuota(),
    sum: unlimitedQuota(),
    square: unlimitedQuota(),
    sqrt: unlimitedQuota(),
    trainer: { count: 1 },
    parameter: unlimitedQuota(),
    linear: unlimitedQuota(),
    conv1d: unlimitedQuota(),
    conv2d: unlimitedQuota(),
    maxpool1d: unlimitedQuota(),
    avgpool1d: unlimitedQuota(),
    maxpool2d: unlimitedQuota(),
    avgpool2d: unlimitedQuota(),
    globalavgpool2d: unlimitedQuota(),
    globalmaxpool2d: unlimitedQuota(),
    argmax: unlimitedQuota(),
    reshape: unlimitedQuota(),
    concat: unlimitedQuota(),
    negate: unlimitedQuota(),
    image_viewer: unlimitedQuota(),
    activation: unlimitedQuota(),
    dropout: unlimitedQuota(),
}

/**
 * Shared editor state: the Konva layer/stage, the endpoints, receptors,
 * connections and blocks on the canvas, plus optional UI hooks.
 */
export interface ValueStore {
    /** Main Konva layer. */
    mainLayer: Layer
    /** Konva stage, when one is available. */
    stage?: Stage
    /** Endpoint currently marked active, if any. */
    activeEndpoint?: Endpoint
    /** Receptor currently marked active, if any. */
    activeReceptor?: Receptor
    availableReceptors: Receptor[]
    availableEndpoints: Endpoint[]
    /** All connections, keyed by string id (presumably the connection's id — confirm). */
    connections: Map<string, Connection>
    /** Optional UI hook used to show a tooltip. */
    showTooltip?: (tooltip: TooltipProps) => void
    /** Optional UI hook used to hide the tooltip. */
    hideTooltip?: () => void

    /**
     * Used for cycle detection. When a connection changes, every downstream
     * receptor that is reached is recorded here with the number of times it
     * has been visited (keys are presumably receptor ids — confirm). If a
     * receptor is visited more than once, the computation graph contains a
     * cycle.
     */
    visitedReceptorCount: Map<string, number>

    /** Blocks keyed by string id (presumably every block in the editor — confirm). */
    everything: Record<string, Block>
    /** Selected items keyed by string id, each with its position and state. */
    selection: Record<string, Arrangement>
    activeExercise?: ExerciseData
    isTraining?: boolean
    viewSize?: { width: number, height: number }
    autoOffsetPadding?: number
    chapterId?: string
}

/** A position together with free-form state; the value type of `ValueStore.selection`. */
export interface Arrangement {
    /** Arbitrary state, keyed by string. */
    state: Record<string, any>
    /** x/y coordinates. */
    position: { x: number; y: number }
}

/** Content rendered by an inspector panel. */
export interface InspectorProps {
    /** Heading text. */
    title: string
    /** Optional settings content. */
    settings?: React.ReactNode
    /** Optional documentation content. */
    docs?: React.ReactNode
    /** Optional action buttons. */
    buttons?: Array<{
        title: string
        type: "normal" | "default" | "default-hollow" | "destructive"
        disabled?: boolean
        /** Click handler returning a boolean (its meaning is defined by the consumer — confirm). */
        onClick?: () => boolean
    }>
}

/** Identifiers of the supported loss functions. */
export type LossType =
    | "mse"
    | "mae"
    | "cross-entropy"
    | "cosine"

/** What a tooltip shows and where it is placed. */
export interface TooltipProps {
    position: Vector2d
    text: string
}

/**
 * A dataset split into training and evaluation tensors.
 * The X sides may be a single tensor or a list of tensors
 * (presumably one per model input — confirm).
 */
export interface CustomTFDataset {
    trainX: Tensor | Array<Tensor>
    trainY: Tensor
    evalX: Tensor | Array<Tensor>
    evalY: Tensor
    /** Number of classes; only relevant for classification tasks. */
    totalClasses?: number
}

/** Kinds of value a receptor can be typed as. */
export type ReceptorType =
    | "tensor"
    | "model"
    | "dataset"

/** How a test-case input/output should be interpreted (e.g. for display). */
export type InputDataType =
    | "tensor"
    | "grayscale_image"
    | "rgb_image"
    | "default"

/** One input/output example held as tf tensors, with optional labels and display types. */
export interface TestCase {
    input: Array<tf.Tensor>
    inputLabels?: Array<string>
    inputTypes?: Array<InputDataType>
    output: tf.Tensor
    outputLabel?: string
    outputTypes?: Array<InputDataType>
}

/** Nested-array form of a tensor: a plain number (rank 0) or a nested array of rank 1 through 6. */
export type ArrayTensor =
    | number
    | number[]
    | number[][]
    | number[][][]
    | number[][][][]
    | number[][][][][]
    | number[][][][][][]

/** Counterpart of `TestCase` with tensors stored as nested number arrays. */
export interface SerializedTestCase {
    input: Array<ArrayTensor>
    output: ArrayTensor
    inputNames?: Array<string>
    inputTypes?: TestCase["inputTypes"]
    outputName?: string
}

/** Status of one serialized test case, plus its prediction when present. */
export interface TestCaseResult {
    status: "waiting" | "running" | "correct" | "incorrect"
    case: SerializedTestCase
    /** Predicted value (presumably filled in once the case has run — confirm). */
    pred?: { value: ArrayTensor }
}

/** Overall state of a test run across its cases. */
export interface TestResult {
    status: "ready" | "waiting" | "running" | "done"
    /** The test case number (whether 0- or 1-based is not established here — confirm). */
    current: number
    result: Array<TestCaseResult>
    passed?: number
    classLabels?: Array<string>
}

/*
objects: Object.fromEntries(
        Object.entries(valueStore.everything).map(([key, block]) => {
            return [
                key,
                {
                    position: block.element?.position(),
                    value: block.serialize(),
                    customName: block.customName
                }
            ]
        })
    )
*/

/** Saved form of a connection: a block id and an index at each end, plus the connection id. */
interface StoredConnection {
    id: string
    from: { blockId: string; index: number }
    to: { blockId: string; index: number }
}

/** Saved editor graph: block entries plus the connections between them. */
export interface SaveState {
    connections: Array<StoredConnection>
    objects: Array<SaveEntry>
}

/** One saved block in a `SaveState`. */
export interface SaveEntry {
    /** Block type. Optional for default blocks; required for additional blocks. */
    typeId?: string

    /**
     * Id used for quota purposes when it differs from the type (e.g. loss
     * blocks). For default blocks, holds the value of the actual id.
     */
    quotaId?: string

    /** Optional for default blocks; required for additional blocks. */
    position?: Vector2d

    value?: any
    customName?: string
}

/** One MNIST sample: a numeric label and its flattened image. */
export interface MNIST_Image {
    /** 784 numbers described as binary (presumably 28×28 pixels, each 0 or 1 — confirm). */
    image: Array<number>
    label: number
}

/** A labelled image for classification. */
export interface ClassificationImage {
    /** Pixel data indexed as [height][width][channel]. */
    image: Array<Array<Array<number>>>
    label: number
}