import { Tensor } from "@tensorflow/tfjs"
import { Vector2d } from "konva/lib/types"
import Block from "./Nodes/Block"
import { v4 as uuid } from "uuid"
import Invert from "./Nodes/Operations/Invert"
import Add from "./Nodes/Operations/Add"
import Literal from "./Nodes/Inputs/Literal"
import * as tf from "@tensorflow/tfjs"
import SymbolicInput from "./Nodes/Inputs/SymbolicInput"
import OutputBlock from "./Nodes/Outputs/OutputBlock"
import OutputShape from "./Nodes/Outputs/OutputShape"
import ActivationFn from "./Nodes/Operations/ActivationFn"
import MatMul from "./Nodes/Operations/MatMul"
import Parameter from "./Nodes/Training/Parameter"
import Trainer from "./Nodes/Training/Trainer"
import Transpose from "./Nodes/Operations/Transpose"
import Multiply from "./Nodes/Operations/Multiply"
import Predictor from "./Nodes/Outputs/Predictor"
import CustomDataset from "./Nodes/Training/CustomDataset"
import TensorDisplay from "./Nodes/Inputs/TensorDisplay"
import Ones from "./Nodes/Inputs/Ones"
import Reshape from "./Nodes/Operations/Reshape"
import { BLOCK_TYPES, InspectorProps, SaveEntry, SaveState, SerializedTestCase, TestCase, ValueStore } from "./Interfaces"
import Sum from "./Nodes/Operations/Sum"
import Loss from "./Nodes/Losses/Loss"
import SquareRoot from "./Nodes/Operations/Square Root"
import Square from "./Nodes/Operations/Square"
import Mean from "./Nodes/Operations/Mean"
import Divide from "./Nodes/Operations/Divide"
import Linear from "./Nodes/Modules/Linear"
import Argmax from "./Nodes/Operations/ArgMax"
import Conv1D from "./Nodes/Modules/Conv1D"
import { ItemInfo } from "./Exercises/ExerciseData"
import LibraryItem from "../Components/LibraryItem/LibraryItem"
import TabView from "../Components/TabView/TabView"
import Source from "./Nodes/SourceBlock"
import Konva from "konva"
import Conv2D from "./Nodes/Modules/Conv2D"
import Pooling from "./Nodes/Modules/Pooling"
import ImageViewer from "./Nodes/Outputs/ImageViewer"
import Concat from "./Nodes/Operations/Concat"
import Dropout from "./Nodes/Operations/Dropout"
import Subtract from "./Nodes/Operations/Subtract"
import ShapeReader from "./Nodes/Operations/ShapeReader"

/**
 * Render a tensor's values as a compact string.
 *
 * Scalars and flat numeric vectors are rounded to 3 decimal places; any other shape
 * (higher rank, or an empty vector) is JSON-serialized verbatim.
 */
export function tensorToString(tensor: Tensor) {
    const values = tensor.arraySync()

    if (typeof values === "number") {
        return `${round(values, 3)}`
    }

    const isFlatNumeric = values instanceof Array && typeof values[0] === "number"
    if (isFlatNumeric) {
        const rounded = (values as number[]).map(entry => round(entry, 3))
        return JSON.stringify(rounded)
    }

    return JSON.stringify(values)
}

/**
 * Produce a human-readable description of a tensor from its `toString()` output.
 *
 * The four-space continuation indent is removed from each line. When the text starts
 * with "Tensor", that word and the single character after it are dropped. The result
 * is trimmed.
 */
export function tensorDescription(tensor: Tensor) {
    const header = "Tensor"
    const text = tensor.toString().replaceAll("\n    ", "\n")
    const body = text.startsWith(header) ? text.substring(header.length + 1) : text
    return body.trim()
}

/** Euclidean distance between two 2D points. */
export function distance(a: Vector2d, b: Vector2d) {
    const dx = a.x - b.x
    const dy = a.y - b.y
    return Math.sqrt(Math.pow(dx, 2) + Math.pow(dy, 2))
}

/**
 * Constructors for every block type, keyed by its type identifier (the same string used
 * as a library/quota key). A Map is used rather than a plain object so that identifiers
 * such as "constructor" or "toString" can never resolve to Object.prototype members.
 */
const BLOCK_FACTORIES = new Map<string, (id: string) => Block>([
    // Basic operations
    ["add", id => new Add(id)],
    ["negate", id => new Invert(id)],
    ["transpose", id => new Transpose(id)],
    ["multiply", id => new Multiply(id)],
    ["subtract", id => new Subtract(id)],
    ["divide", id => new Divide(id)],
    ["matmul", id => new MatMul(id)],
    ["activation", id => new ActivationFn(id)],
    ["argmax", id => new Argmax(id)],
    ["loss_mae", id => new Loss(id, "mae")],
    ["loss_mse", id => new Loss(id, "mse")],
    ["cosine_similarity", id => new Loss(id, "cosine")],
    ["reshape", id => new Reshape(id)],
    ["sum", id => new Sum(id)],
    ["sqrt", id => new SquareRoot(id)],
    ["square", id => new Square(id)],
    ["mean", id => new Mean(id)],
    ["concat", id => new Concat(id)],

    // Inputs
    ["symbolic", id => new SymbolicInput(id)],
    ["constant", id => new Literal(id)],
    ["parameter", id => new Parameter(id)],
    ["ones", id => new Ones(id, [1, 3])],

    // Outputs
    ["tensor_viewer", id => new OutputBlock(id)],
    ["image_viewer", id => new ImageViewer(id)],
    ["shape", id => new OutputShape(id)],
    ["shape_reader", id => new ShapeReader(id)],

    // Training
    ["custom-dataset", id => new CustomDataset(id)],
    ["trainer", id => new Trainer(id)],
    ["predictor", id => new Predictor(id)],

    // Layers
    ["linear", id => new Linear(id)],
    ["conv1d", id => new Conv1D(id)],
    ["conv2d", id => new Conv2D(id)],
    ["maxpool1d", id => new Pooling(id, "max", 1)],
    ["avgpool1d", id => new Pooling(id, "avg", 1)],
    ["maxpool2d", id => new Pooling(id, "max", 2)],
    ["avgpool2d", id => new Pooling(id, "avg", 2)],
    ["globalmaxpool1d", id => new Pooling(id, "global_max", 1)],
    ["globalavgpool1d", id => new Pooling(id, "global_avg", 1)],
    ["globalmaxpool2d", id => new Pooling(id, "global_max", 2)],
    ["globalavgpool2d", id => new Pooling(id, "global_avg", 2)],
    ["dropout", id => new Dropout(id)],
])

/**
 * Instantiate a block from its type identifier.
 *
 * @param blockString Type identifier such as `"add"` or `"conv2d"`. Unknown identifiers
 *   yield a plain base `Block`.
 * @param data Optional save entry. When given, the block is deserialized from it
 *   asynchronously.
 * @param id Id for the new block. A fresh UUID is generated when omitted.
 * @param completion Called with the block once deserialization of `data` has finished.
 *   It is not called when `data` is omitted, or when deserialization fails.
 * @returns The new block, returned immediately. Deserialization may still be in flight.
 */
export function createBlockFromString(blockString: string, data?: SaveEntry, id?: string, completion?: (block: Block) => void) {
    const blockId = id ?? uuid()
    const factory = BLOCK_FACTORIES.get(blockString)
    // Only build the base Block as a fallback, instead of always constructing one and discarding it.
    const block = factory ? factory(blockId) : new Block(blockId)

    if (data) {
        // Fire-and-forget, but never leave a rejected deserialization unhandled.
        void block.deserialize(data).then(
            () => completion?.(block),
            error => console.error(`Failed to deserialize block ${block.id} (${blockString})`, error)
        )
    }

    return block
}

/**
 * Convert a test case into a serializable format for transferring between web workers.
 * @param original Original test case in TF tensor format.
 * @returns 
 */
/**
 * Convert a test case into plain nested arrays so it can be posted between web workers.
 * @param original Test case whose inputs and output are TF tensors.
 * @returns The same test case with every tensor replaced by its `arraySync()` value.
 */
export function serializeTestCase(original: TestCase): SerializedTestCase {
    const input = original.input.map(tensor => tensor.arraySync())
    const output = original.output.arraySync()
    return { input, output }
}

/**
 * Round `x` to `decimalPlaces` decimal places.
 *
 * `Number.EPSILON` is added first so that values like 1.005, which are stored slightly
 * below their decimal form, still round half-up.
 */
export function round(x: number, decimalPlaces: number) {
    const scale = 10 ** decimalPlaces
    const scaled = (x + Number.EPSILON) * scale
    return Math.round(scaled) / scale
}

/**
 * Build the block library as a tabbed view holding only the block types listed in `quota`.
 *
 * Each quota key goes into the first category (basic ops, inputs, outputs, layers) whose
 * catalogue contains it. Keys that match no category are skipped. Tabs with no items
 * are omitted.
 */
export function getLibraryViewWithQuota(quota?: Record<string, ItemInfo>) {
    // Tabs in display order, each paired with the catalogue of block names it owns.
    const sections = [
        { title: "Basic Operations", catalogue: BLOCK_TYPES.basicOps, items: [] as JSX.Element[] },
        { title: "Inputs and Data", catalogue: BLOCK_TYPES.inputs, items: [] as JSX.Element[] },
        { title: "Outputs and Visualization", catalogue: BLOCK_TYPES.outputs, items: [] as JSX.Element[] },
        { title: "Neural Network Layers", catalogue: BLOCK_TYPES.layers, items: [] as JSX.Element[] },
    ]

    for (const key in quota) {
        const { count, data, displayedTitle } = quota[key]
        const section = sections.find(s => key in s.catalogue)
        if (section === undefined) {
            continue
        }
        const title = displayedTitle ?? section.catalogue[key]
        section.items.push(<LibraryItem data={data} title={title} type={key} count={count} key={key} />)
    }

    const tabs = sections
        .filter(s => s.items.length > 0)
        .map(s => ({ title: s.title, view: <div>{s.items}</div> }))

    return <TabView tabs={tabs}/>
}

/**
 * Snapshot the whole canvas: every block's serialized form plus all connections.
 *
 * Blocks are serialized concurrently and keyed by their own `id`. A warning is logged
 * for any block whose `quotaId` is the base type "block".
 */
export async function serializeToJSON(valueStore: ValueStore) {
    const blockEntries = await Promise.all(
        Object.entries(valueStore.everything).map(async ([storeKey, block]) => {
            if (block.quotaId === "block") {
                console.warn(`Encountered base block (${storeKey}) when serializing`)
            }
            const entryKey = block.id
            return [entryKey, await block.serialize()]
        })
    )
    const objects = Object.fromEntries(blockEntries)

    const connections = Array.from(valueStore.connections, ([connectionId, conn]) => ({
        from: { blockId: conn.start.parent.id, index: conn.start.index },
        id: connectionId,
        to: { blockId: conn.end.parent.id, index: conn.end.index },
    }))

    return { objects, connections }
}

/**
 * Place a block's element on `layer`, register the block in `valueStore`, and finish its setup.
 *
 * By default `targetX`/`targetY` are treated as the element's centre. Pass
 * `disableOffset` to treat them as its top-left corner. An omitted coordinate leaves
 * that axis unchanged. Blocks without an element are skipped with a warning.
 */
export function addBlockAtPosition(valueStore: ValueStore, layer: Konva.Layer, block: Block, setInspectorView: (view?: InspectorProps) => void, onShowInspector: (v: boolean) => void, targetX?: number, targetY?: number, disableOffset?: boolean) {
    const element = block.element
    if (!element) {
        console.warn(`Unable to add block ${block.id} because element is empty`)
        return
    }

    const centred = !disableOffset
    if (targetX !== undefined) {
        element.x(targetX - (centred ? element.width() / 2 : 0))
    }
    if (targetY !== undefined) {
        element.y(targetY - (centred ? element.height() / 2 : 0))
    }

    layer.add(element)
    block.globalState = valueStore
    valueStore.everything[block.id] = block

    // Showing an inspector view also opens the inspector panel; clearing it does not.
    block.finishSetup(view => {
        setInspectorView(view)
        if (view !== undefined) {
            onShowInspector(true)
        }
    }, setInspectorView)
}

/*
    console.log("restore from", obj)
    let currentQuota = this.getInitialQuota()
    for (const [id, block] of Object.entries(obj.objects)) {
        if (block.quotaId === "block") {
            console.warn("Do not use `block` as save entry type id")
        }
        if (id in this.defaultBlocks) {
            await this.defaultBlocks[id]?.deserialize(block)
        } else {
            if (block.quotaId && block.quotaId in currentQuota && currentQuota[block.quotaId].count > 0 && block.quotaId) {
                currentQuota[block.quotaId].count -= 1
                const newBlock = createBlockFromString(block.quotaId!, block, id) // if quotaId is nonnull, so must be typeId
                this.addBlockAtPosition(newBlock, block.position?.x, block.position?.y, true)
            } else  {
                console.warn(`Skipping recovering ${block.quotaId} due to insufficient quota`)
                console.warn("Remaining quota", currentQuota)
            } 
        }
    }
    this.quota = currentQuota

    // Restore connections
    for (const connection of obj.connections) {
        const outgoing = this.valueStore.everything[connection.from.blockId]?.outputs[connection.from.index]
        const incoming = this.valueStore.everything[connection.to.blockId]?.inputs[connection.to.index]
        if (outgoing && incoming) {
            outgoing.addConnectionToReceptor(incoming, true)
        }
    }
    
    // Finally, rerun the network
    for (const block of Object.values(this.valueStore.everything)) {
        // Propagate from all sources
        (block as Source)?.propagate?.()
    }

    adjustOffset(this.valueStore)
*/
/**
 * Rebuild the canvas from a save state.
 *
 * Entries whose id is in `defaultBlocks` are deserialized into those existing blocks.
 * Every other entry is recreated from its `quotaId`, consuming one unit of quota. Entries
 * without enough quota are skipped with a warning. Connections are then restored and
 * every block that exposes `propagate` is asked to propagate.
 *
 * @param initialQuota Available quota before restoring. It is not mutated.
 * @returns The quota left after recreating blocks.
 */
export async function restoreFromSaveState(obj: SaveState, layer: Konva.Layer, initialQuota: Record<string, ItemInfo>, valueStore: ValueStore, setInspectorView: (view?: InspectorProps) => void, onShowInspector: (v: boolean) => void, defaultBlocks?: Record<string, Block>): Promise<Record<string, ItemInfo>> {
    console.log("restore from", obj)
    // Copy every entry, not just the outer record. Counts are decremented below, and a
    // shallow copy would mutate the ItemInfo objects owned by the caller's initialQuota.
    const currentQuota: Record<string, ItemInfo> = Object.fromEntries(
        Object.entries(initialQuota).map(([type, info]): [string, ItemInfo] => [type, { ...info }])
    )

    for (const [id, entry] of Object.entries(obj.objects)) {
        const quotaId = entry.quotaId
        if (quotaId === "block") {
            console.warn("Do not use `block` as save entry type id")
        }
        if (defaultBlocks && id in defaultBlocks) {
            await defaultBlocks[id]?.deserialize(entry)
            continue
        }
        if (quotaId && quotaId in currentQuota && currentQuota[quotaId].count > 0) {
            currentQuota[quotaId].count -= 1
            // NOTE(review): createBlockFromString deserializes asynchronously and is not awaited,
            // so connections/propagation below may run before this block finishes restoring.
            const newBlock = createBlockFromString(quotaId, entry, id)
            addBlockAtPosition(valueStore, layer, newBlock, setInspectorView, onShowInspector, entry.position?.x, entry.position?.y, true)
        } else {
            console.warn(`Skipping recovering ${quotaId} due to insufficient quota`)
            console.warn("Remaining quota", currentQuota)
        }
    }

    // Restore connections; any endpoint whose block or port is missing is skipped.
    for (const connection of obj.connections) {
        const outgoing = valueStore.everything[connection.from.blockId]?.outputs[connection.from.index]
        const incoming = valueStore.everything[connection.to.blockId]?.inputs[connection.to.index]
        if (outgoing && incoming) {
            outgoing.addConnectionToReceptor(incoming, true)
        }
    }

    // Finally, rerun the network
    for (const block of Object.values(valueStore.everything)) {
        // Propagate from all sources
        (block as Source)?.propagate?.()
    }

    // setTimeout(() => adjustOffset(valueStore), 1)

    return currentQuota
}

/** Expand the canvas size if necessary so that all elements fit on the screen. */
/**
 * Expand the canvas if necessary so that every block's element fits on the stage.
 *
 * The stage offset is only ever moved up/left and the stage size only ever grows, so
 * repeated calls never shrink the canvas. Blocks without an element contribute
 * Infinity to the minimum and 0 to the maximum extent. If the value store has no stage
 * yet, nothing happens.
 *
 * @param valueStore Store holding the blocks, the Konva stage, and optional
 *   `autoOffsetPadding`/`viewSize` hints.
 */
export function adjustOffset(valueStore: ValueStore) {
    const stage = valueStore.stage
    if (!stage) {
        // The original non-null assertion threw here; with no stage there is nothing to resize.
        return
    }

    const blocks = Object.values(valueStore.everything)
    const minX = Math.min(...blocks.map(b => b.element ? (b.element.x() - b.element.offsetX()) : Infinity))
    const maxX = Math.max(...blocks.map(b => b.element ? b.element.width() - b.element.offsetX() + b.element.x() : 0))
    const minY = Math.min(...blocks.map(b => b.element ? (b.element.y() - b.element.offsetY()) : Infinity))
    const maxY = Math.max(...blocks.map(b => b.element ? (b.element.y() - b.element.offsetY() + b.element.height()) : 0))

    // Leading padding defaults to 20 and trailing padding to 100, unless autoOffsetPadding overrides both.
    const offsetX = Math.min(stage.offsetX(), minX - (valueStore.autoOffsetPadding ?? 20))
    const offsetY = Math.min(stage.offsetY(), minY - (valueStore.autoOffsetPadding ?? 20))

    stage.offsetX(offsetX)
    stage.offsetY(offsetY)
    stage.width(Math.max(valueStore.viewSize?.width ?? 1000, stage.width(), maxX + (valueStore.autoOffsetPadding ?? 100) - offsetX))
    stage.height(Math.max(valueStore.viewSize?.height ?? 1000, stage.height(), maxY + (valueStore.autoOffsetPadding ?? 100) - offsetY))
}

/** Resolve after `seconds` seconds (fractions allowed). */
export async function sleep(seconds: number) {
    const delayMs = seconds * 1000
    await new Promise<void>(resolve => {
        setTimeout(resolve, delayMs)
    })
}

/** Arbitrarily nested arrays of booleans, mirroring the nesting of `Tensor.arraySync()`. */
type NestedBoolean = boolean | NestedBoolean[]

/** Recursively map every leaf number to `true` when it is positive. */
function toNestedBoolean(value: unknown): NestedBoolean {
    return Array.isArray(value) ? value.map(toNestedBoolean) : (value as number) > 0
}

/**
 * Return a tensor's contents as native JS values.
 *
 * `arraySync()` yields 0/1 numbers for bool tensors, so those are converted to booleans
 * at every nesting level, for any rank (previously ranks above 5 were returned as raw
 * numbers). Non-bool tensors are returned exactly as `arraySync()` produces them.
 */
export function getNativeValue(tensor: tf.Tensor): ReturnType<tf.Tensor["arraySync"]> | NestedBoolean {
    const raw = tensor.arraySync()
    if (tensor.dtype !== "bool") {
        return raw
    }
    return toNestedBoolean(raw)
}