import Konva from "konva";
import Block from "../Block";
import * as tf from "@tensorflow/tfjs"
import Receptor from "../Receptor";
import Endpoint from "../Endpoint";
import { InspectorProps } from "../../Interfaces";
import { tensorToString } from "../../Utils";
import { ReshapeLayer } from "../../Layers/Arithmetics";
import { KonvaEventObject } from "konva/lib/Node";

/**
 * Graph block that applies a `tf.layers.reshape` layer to its single input.
 *
 * The target shape is edited from the inspector as a JSON number array
 * (see `onClickMenu`) in which at most one entry may be `-1`. The current
 * shape is rendered under the title and persisted via `getStateDict` /
 * `loadStateDict`.
 */
class Reshape extends Block {

    type_id = "reshape"
    titleLabel: Konva.Text
    subtitleLabel: Konva.Text
    ring: Konva.Circle
    // Default target shape is [-1].
    layer = tf.layers.reshape({targetShape: [-1]})

    /**
     * JSON rendering of the layer's target shape, used both as the subtitle
     * text and as the editor's initial value. The layer config may hold `null`
     * for the inferred dimension, so `null` entries are shown as `-1`.
     */
    get description(): string {
        const targetShape = this.layer.getConfig().targetShape as (number | null)[]
        return JSON.stringify(targetShape.map(d => d ?? -1))
    }

    /**
     * Builds the block's visuals (ring, title, shape subtitle) and its single
     * input receptor and output endpoint.
     *
     * @param id Block identifier forwarded to `Block`.
     */
    constructor(id: string) {
        super(id)
        this.element = new Konva.Group({
            draggable: true,
            width: 90,
            height: 90,
            offsetX: -45,
            offsetY: -45
        })

        // Ring centred on the group origin, filling the group's width.
        this.ring = new Konva.Circle({
            stroke: "#e2a0e4",
            strokeWidth: 3,
            radius: this.element.width() / 2,
            fill: "#fef0ff"
        })
        this.element.add(this.ring)

        // Title box spans the whole ring, nudged 6px above centre.
        this.titleLabel = new Konva.Text({
            text: "Reshape",
            width: this.element.width(),
            height: this.element.height(),
            x: -this.element.width() / 2,
            y: -this.element.height() / 2 - 6,
            align: "center",
            verticalAlign: "middle",
            fontSize: 14,
            fontStyle: "bold"
        })
        this.element.add(this.titleLabel)

        // Current target shape, shown just below the title.
        this.subtitleLabel = new Konva.Text({
            text: this.description,
            width: this.element.width() - 4,
            x: -this.element.width() / 2 + 2,
            y: 4,
            align: "center"
        })
        this.element.add(this.subtitleLabel)

        // One input at the top of the ring, one output at the bottom.
        this.inputs = [
            new Receptor(this, 0, -this.ring.radius(), 0),
        ]
        this.outputs = [
            new Endpoint(this, 0, this.ring.radius(), 0)
        ]
    }

    /**
     * Re-applies the reshape layer to the connected input and pushes the
     * result to the output endpoint.
     *
     * `currentValue` becomes `undefined` when the layer throws (for example,
     * when the input cannot be reshaped to the target shape). It becomes `null`
     * when the required inputs are not all provided.
     *
     * @param index Index of the input that changed (unused: there is only one input).
     * @returns The result of propagating `currentValue` through the output endpoint.
     */
    onInputUpdated(index: number): boolean {
        if (this.allRequiredInputsProvided) {
            try {
                this.currentValue = this.layer.apply(this.inputs[0].currentValue!) as tf.Tensor | tf.SymbolicTensor
            } catch (error) {
                console.warn(error)
                this.currentValue = undefined
            }
        } else {
            this.currentValue = null
        }
        return this.outputs[0].propagate(this.currentValue)
    }

    /**
     * Inspector panel with a textarea for editing the target shape as JSON.
     *
     * "Save" accepts only an array whose entries are positive integers or
     * `-1`, with at most one `-1`. On success it rebuilds the layer, refreshes
     * the subtitle, and recomputes and propagates the output. Invalid input is
     * reported with `alert` and leaves the existing layer untouched.
     */
    onClickMenu(): InspectorProps {
        let ref: HTMLTextAreaElement | null = null
        const editArea = <textarea placeholder="Enter Parameter Value" defaultValue={this.description} className="custom-textarea" style={{width: "100%", minHeight: "100px", border: "1px solid #f0f0f0"}} ref={(e) => {
            ref = e
        }} />
        return {
            title: this.displayedName,
            settings: editArea,
            buttons: [
                {
                    title: "Cancel",
                    type: "normal"
                },
                {
                    title: "Save",
                    type: "normal",
                    onClick: () => {
                        try {
                            const textarea = ref
                            if (textarea === null) {
                                throw new Error("Shape editor is not available.")
                            }
                            const parsed: unknown = JSON.parse(textarea.value)
                            if (!Array.isArray(parsed) || parsed.some(x => (typeof x) !== "number")) {
                                throw new Error("Shape must be a number array")
                            }
                            const newShape: number[] = parsed
                            // tf.layers.reshape treats every negative entry as "unknown", so
                            // anything other than -1 (or a non-integer) would be silently
                            // misinterpreted rather than rejected.
                            if (newShape.some(d => !Number.isInteger(d) || (d < 1 && d !== -1))) {
                                throw new Error("Shape entries must be positive integers or -1.")
                            }
                            if (newShape.filter(v => v === -1).length > 1) {
                                throw new Error("Shape can only contain at most one -1.")
                            }
                            this.layer = tf.layers.reshape({ targetShape: newShape })
                            this.subtitleLabel.text(this.description)
                            // onInputUpdated already propagates the new value downstream;
                            // propagating again here would push the same value twice.
                            this.onInputUpdated(0)
                            this.updateReceptorsAndEndpoints()
                            return true
                        } catch (error: unknown) {
                            if (error instanceof Error) {
                                alert(error.message)
                            } else {
                                alert("Invalid shape.")
                            }
                            console.warn(error)
                            return false
                        } finally {
                            // Reset propagation bookkeeping whether or not the save succeeded.
                            this.globalState.visitedReceptorCount.clear()
                        }
                    }
                }
            ]
        }
    }

    /** Selects the block and adds a blue glow to the ring. */
    select(e: KonvaEventObject<MouseEvent>): void {
        super.select(e)

        this.ring.shadowColor("#70a5fd")
        this.ring.shadowBlur(6)
        this.ring.shadowOpacity(0.9)
    }

    /** Deselects the block and hides the ring's glow. */
    unselect(): void {
        super.unselect()

        this.ring.shadowOpacity(0)
    }

    /**
     * Serialises the block's state as `{ shape }`, where `shape` is the layer's
     * target shape as stored in its config (the inferred dimension may appear
     * as `null`). A copy is returned so callers never hold a reference into
     * the layer's config.
     */
    async getStateDict() {
        const targetShape = this.layer.getConfig().targetShape as (number | null)[]
        return {
            shape: [...targetShape]
        }
    }

    /**
     * Restores the target shape produced by `getStateDict`.
     *
     * Data without an array `shape` is ignored. An array containing anything
     * other than numbers or `null` is ignored with a warning. Errors raised
     * while building the layer are logged, not thrown.
     *
     * @param data Saved state; only its `shape` entry is read.
     */
    async loadStateDict(data: Record<string, any>) {
        try {
            const shape: unknown = data?.shape
            if (!Array.isArray(shape)) {
                return
            }
            if (!shape.every(d => d === null || typeof d === "number")) {
                console.warn("Ignoring invalid reshape state:", shape)
                return
            }
            // Pass a copy so the layer never shares (or rewrites) the caller's array.
            this.layer = tf.layers.reshape({ targetShape: [...shape] })
            this.subtitleLabel.text(this.description)
        } catch (error) {
            console.warn(error)
        }
    }
}

export default Reshape;