import Konva from "konva"
import Endpoint from "../Endpoint"
import Block from '../Block';
import ContentBlock from '../ContentBlock'
import * as tf from "@tensorflow/tfjs"
import { round, tensorToString } from '../../Utils';
import { InspectorProps } from "../../Interfaces";
import { KonvaEventObject } from "konva/lib/Node";
import Source from "../SourceBlock";
import { randomUUID } from "crypto";
import { v4 } from "uuid";

/**
 * Source block that emits a symbolic `tf.input` placeholder (and, once one has been
 * applied, a concrete tensor) into the block graph.
 *
 * Convention used throughout this class: the first dimension of `value.shape` is the
 * batch dimension (`null` when unknown, rendered as "B"). The shape stored by
 * `getStateDict` excludes it, as does the shape shown in the inspector when no
 * concrete value has been applied.
 */
class SymbolicInput extends Block implements Source {
    /** Symbolic placeholder; `shape[0]` is the batch dimension. */
    value: tf.SymbolicTensor
    /** Concrete tensor propagated instead of `value` while not in training mode. */
    concreteValue?: tf.Tensor
    arrow: Konva.Line
    text: Konva.Text
    /** Optional caption that replaces the default "Input". */
    customName?: string
    type_id = "symbolic"
    training = false
    /** Restored from the state dict; when true the arrow is drawn with a different fill. */
    isOutput = false

    get blockName() { return "Input" }

    get displayedName(): string {
        return this.customName ?? "Input"
    }

    /**
     * Caption rendered inside the arrow:
     * - training mode: the symbolic shape, unknown dimensions shown as "B";
     * - single-element concrete value: the element itself, passed through `round(_, 2)`;
     * - otherwise: the concrete shape if one is applied, else the symbolic shape.
     */
    get description(): string {
        if (this.training) {
            return `Training Mode\n(Shape: ${this.value.shape.map(x => x ?? "B").join(" x ")})`
        }
        if (this.concreteValue && this.concreteValue.size === 1) {
            // Read the lone element directly. The previous `squeeze([0]).arraySync()`
            // allocated a tensor that was never disposed (leaking on every redraw),
            // threw for rank-0 tensors and produced an array for shapes like [1, 1].
            const scalar = Number(this.concreteValue.dataSync()[0])
            return `${this.displayedName} (${JSON.stringify(round(scalar, 2))})`
        }
        const shape = this.concreteValue?.shape ?? this.value.shape.map(x => x ?? "B")
        return `${this.displayedName}\n(Shape: ${shape.join(" x ")})`
    }

    getDocumentation(): string {
        return `Input blocks are the sources of information for an AI model. They are the ways in which AI perceives the outside world. For example, if you are building a model on image classification, then the Input to your model would be an RGB image.`
    }

    /**
     * @param id Block id; also used as the name of the initial symbolic input.
     */
    constructor(id: string) {
        super(id)
        // `shape: []` gives a symbolic tensor of shape [null]: only the batch dimension.
        this.value = tf.input({ shape: [], name: id })

        this.element = new Konva.Group({
            draggable: true,
            width: 185,
            height: 60,
        })
        // Arrow-shaped polygon pointing right; its tip sits at (185, 30).
        this.arrow = new Konva.Line({
            points: [
                0, 0,
                150, 0,
                150, -15,
                185, 30,
                150, 75,
                150, 60,
                0, 60
            ],
            fill: "#e4e2f8",
            closed: true
        })
        this.element.add(this.arrow)

        // this.titleLabel.text("Symbolic Input")
        this.text = new Konva.Text({
            width: 160,
            height: 60,
            verticalAlign: "middle",
            fontSize: 15,
            text: this.description,
            padding: 5
        })
        this.element.add(this.text)

        this.outputs = [
            new Endpoint(this, 0, this.arrow.width(), 30)
        ]
        this.outputs.forEach(output => output.currentValue = this.value)
    }

    /** Adds a second output endpoint on the bottom edge that carries the symbolic value. */
    addAltInput() {
        const output2 = new Endpoint(this, 1, (this.arrow.width() - 20) / 2, 60)
        output2.currentValue = this.value
        this.outputs.push(output2)
    }

    /**
     * Switches training mode and rebuilds the symbolic input with the given batch size
     * (`null`, i.e. unknown, when omitted), keeping the per-sample dimensions.
     * Re-propagates only when the mode actually changed; propagation errors are logged.
     */
    setTraining(training = true, batchSize?: number) {
        const oldStatus = this.training
        this.training = training
        this.value = tf.input({ batchShape: [batchSize ?? null, ...this.value.shape.slice(1)] })
        this.text.text(this.description)

        // this.value = tf.input({ name: this.id, batchShape: batchedInputShape })
        if (oldStatus !== training) {
            try {
                this.propagate()
            } catch (error) {
                console.warn(error)
            }
        }
    }

    /**
     * Stores (or clears, when `tensor` is undefined) the concrete value and pushes it
     * downstream through output 0; output 1, if present, always receives the symbolic value.
     * Ignored in training mode unless `force` is set.
     *
     * @param tensor Concrete tensor to feed, or undefined to fall back to the symbolic value.
     * @param dontPropagate Forwarded to the primary endpoint's `propagate`.
     * @param force Apply even while in training mode.
     */
    applyConcreteValue(tensor?: tf.Tensor, dontPropagate?: boolean, force = false) {
        if (this.training && !force) return
        this.concreteValue = tensor
        this.globalState.visitedReceptorCount.clear()
        try {
            this.text.text(this.description)
            // Keep an open inspector's shape field in sync with the new value.
            const shapeInput = document.querySelector("#symbolic-input-shape") as HTMLInputElement
            if (shapeInput) {
                shapeInput.value = JSON.stringify(this.concreteValue?.shape ?? this.value.shape.slice(1))
            }
            this.outputs[0].propagate(this.concreteValue ?? this.value, dontPropagate)
            this.outputs[1]?.propagate(this.value, true)
        } finally {
            // Reset traversal bookkeeping even when propagation throws.
            this.globalState.visitedReceptorCount.clear()
        }
    }

    /** Pushes the current value to every output: symbolic in training, else concrete if set. */
    propagate() {
        const outgoing = this.training ? this.value : (this.concreteValue ?? this.value)
        this.outputs.forEach(o => o.propagate(outgoing))
    }

    /**
     * Parses the inspector's shape field: a JSON array of positive integers that excludes
     * the batch dimension. Throws on anything else so the caller can report it.
     */
    private static parseShape(raw: string): number[] {
        const parsed: unknown = JSON.parse(raw)
        if (!Array.isArray(parsed) || !parsed.every(dim => Number.isInteger(dim) && dim > 0)) {
            throw new Error(`Invalid shape: ${raw}`)
        }
        return parsed
    }

    /** Builds the inspector panel: a shape editor plus, when editable, a "Save" button. */
    onClickMenu(): InspectorProps {
        let ref: HTMLInputElement | null = null
        const editArea = <input type="text" placeholder="Shape as 1D Array" disabled={!this.editable} defaultValue={JSON.stringify(this.concreteValue?.shape ?? this.value.shape.slice(1))} className="custom-textarea" style={{ width: 'calc(100% - 40px)', textAlign: 'right'}} id="symbolic-input-shape" ref={(e) => {
            ref = e
            if (e) { e.value = e.defaultValue }
        }} />
        const table = <table className='info-table'>
            <tbody>
                <tr>
                    <td>{`Shape`}</td>
                    <td>{editArea}</td>
                </tr>
            </tbody>
        </table>
        return {
            title: this.displayedName,
            settings: table,
            buttons: this.editable ? [
                {
                    title: "Save",
                    type: "normal",
                    onClick: () => {
                        try {
                            // A missing field yields "" and fails to parse -> "Invalid shape."
                            const shape = SymbolicInput.parseShape(ref?.value ?? "")
                            // The edited shape excludes the batch dimension (see getStateDict /
                            // setTraining), so keep the current batch dimension rather than
                            // treating the user's array as the full batch shape.
                            this.value = tf.input({ batchShape: [this.value.shape[0] ?? null, ...shape] })
                            this.text.text(this.description)
                            this.outputs[0].propagate(this.value)
                            return true
                        } catch (error) {
                            alert("Invalid shape.")
                            return false
                        } finally {
                            this.globalState.visitedReceptorCount.clear()
                        }
                    }
                }
            ] : [],
            docs: this.documentation
        }
    }

    select(e: KonvaEventObject<MouseEvent>): void {
        super.select(e)

        // Highlight the arrow with a blue glow while selected.
        this.arrow.shadowColor("#70a5fd")
        this.arrow.shadowBlur(8)
        this.arrow.shadowOpacity(0.9)
    }

    unselect(): void {
        super.unselect()

        this.arrow.shadowOpacity(0)
    }

    /**
     * Serializes the per-sample shape (batch dimension excluded), the training flag and,
     * when set, the custom caption and output flag read back by `loadStateDict`.
     */
    async getStateDict() {
        return {
            shape: this.value.shape.slice(1),
            training: this.training,
            ...(this.customName ? { customName: this.customName } : {}),
            ...(this.isOutput ? { isOutput: true } : {}),
        }
    }

    /**
     * Restores state written by `getStateDict`. Failures (e.g. a missing or malformed
     * `shape`) are logged and leave the remaining fields untouched.
     */
    async loadStateDict(data: Record<string, any>) {
        try {
            this.value = tf.input({ shape: data.shape })
            // Endpoints were seeded with the constructor's placeholder; point them at the new one.
            this.outputs.forEach(output => output.currentValue = this.value)
            if (data.customName) {
                this.customName = data.customName
            }
            if (data.isOutput) {
                this.arrow.fill("#d2e8ff")
                this.isOutput = true
            }
            // this.training = data.training ?? false
            this.text.text(this.description)
        } catch (error) {
            console.warn("Failed to load state dict for symbolic input", data, error)
        }
    }
}

export default SymbolicInput;