import Konva from "konva";
import Block from "../Block";
import * as tf from "@tensorflow/tfjs"
import Receptor from "../Receptor";
import Endpoint from "../Endpoint";
import { Dropout as DropoutLayer } from "@tensorflow/tfjs-layers/dist/layers/core"
import { InspectorProps } from "../../Interfaces";
import Dropdown from "rc-dropdown";
import 'rc-dropdown/assets/index.css';
import { KonvaEventObject } from "konva/lib/Node";
import { ArgmaxLayer } from "../../Layers/Arithmetics";

/**
 * Canvas block that wraps a TensorFlow.js dropout layer.
 *
 * The block is drawn as a ring containing a title and a subtitle showing the
 * current dropout rate. It exposes one input receptor and one output endpoint,
 * and lets the user edit the rate through the inspector menu.
 */
class Dropout extends Block {

    type_id = "dropout"
    /** Circular outline/background of the block; also carries the selection glow. */
    ring: Konva.Circle
    /** Underlying tf.js layer; replaced wholesale whenever the rate changes. */
    layer: DropoutLayer
    titleLabel: Konva.Text
    /** Text node showing {@link description}. */
    subtitleLabel: Konva.Text

    /** Human-readable summary of the layer configuration, shown under the title. */
    get description(): string {
        return `Rate: ${this.layer.getConfig().rate}`
    }

    /**
     * Builds the layer (default rate 0.5), the Konva visuals and the I/O ports.
     *
     * @param id Identifier forwarded unchanged to the `Block` base constructor.
     */
    constructor(id: string) {
        super(id)
        this.layer = tf.layers.dropout({ rate: 0.5 })
        this.element = new Konva.Group({
            draggable: true,
            width: 85,
            height: 85,
            offsetX: -42.5,
            offsetY: -42.5
        })

        this.ring = new Konva.Circle({
            stroke: "#e2a0e4",
            strokeWidth: 3,
            radius: this.element.width() / 2,
            fill: "#fef0ff"
        })

        // Both labels span the ring's width and start at x = -42.5 (half of the
        // 85px group width) so that "center" alignment centers them on the ring.
        this.titleLabel = new Konva.Text({
            text: "Dropout",
            fontSize: 16,
            fontStyle: "Bold",
            align: "center",
            width: this.ring.width(),
            x: -42.5,
            y: -13
        })

        this.subtitleLabel = new Konva.Text({
            text: this.description,
            fontSize: 13,
            align: "center",
            width: this.ring.width(),
            x: -42.5,
            y: 8
        })

        this.element.add(this.ring)
        this.element.add(this.titleLabel)
        this.element.add(this.subtitleLabel)

        // Ports are created at -radius / +radius — presumably the left and right
        // edges of the ring; confirm against the Receptor/Endpoint signatures.
        this.inputs = [
            new Receptor(this, 0, -this.ring.radius(), 0, "Input")
        ]
        this.outputs = [
            new Endpoint(this, 0, this.ring.radius(), 0)
        ]
    }

    /**
     * Recomputes this block's value from its input and pushes it downstream.
     *
     * When all required inputs are present the layer is applied to the first
     * input; if that throws, the error is logged and the value becomes
     * `undefined`. When inputs are missing the value becomes `null`.
     *
     * @param index Index of the input that changed (not used by this block).
     * @returns Whatever `outputs[0].propagate` returns for the new value.
     */
    onInputUpdated(index: number): boolean {
        if (this.allRequiredInputsProvided) {
            try {
                this.currentValue = this.layer.apply(this.inputs[0].currentValue!) as tf.Tensor | tf.SymbolicTensor
            } catch (error) {
                console.warn(error)
                this.currentValue = undefined
            }
        } else {
            this.currentValue = null
        }
        return this.outputs[0].propagate(this.currentValue)
    }

    /**
     * Builds the inspector panel: a single text field for the dropout rate and
     * a "Done" button that validates the field and rebuilds the layer.
     *
     * @returns Inspector description consumed by the surrounding UI.
     */
    onClickMenu(): InspectorProps {
        // Filled in by the ref callback once React mounts the input element.
        let dropoutRate: HTMLInputElement | null = null
        const dimField = <input
            type="text"
            placeholder="Enter 0-1 Number"
            defaultValue={`${this.layer.getConfig().rate ?? 0.5}`}
            className="custom-textarea"
            style={{border: "1px solid #f0f0f0", width: '80px', textAlign: 'right', maxHeight: "30px"}}
            ref={(e) => {
                dropoutRate = e
                // Reset the live value to the default each time the element is attached.
                if (e) { e.value = e.defaultValue }
            }}
        />
        const table = <table className='info-table'>
            <tbody>
                <tr>
                    <td>Dropout Rate</td>
                    <td>
                        {dimField}
                    </td>
                </tr>
            </tbody>
        </table>
        return {
            title: this.displayedName,
            settings: table,
            buttons: [
                {
                    title: "Done",
                    type: "normal",
                    onClick: () => {
                        try {
                            const raw = dropoutRate?.value.trim() ?? ""
                            const rate = Number(raw)
                            // Number("") is 0 and Number("abc") is NaN, and NaN passes
                            // any plain `<` / `>` range test, so blank and non-numeric
                            // input must be rejected explicitly.
                            // NOTE(review): the tf.js dropout op appears to require
                            // rate < 1; 1 is still accepted here to match the "0-1"
                            // prompt — confirm whether it should be excluded.
                            if (raw === "" || !Number.isFinite(rate) || rate < 0 || rate > 1) {
                                throw new Error(`Dropout rate must be a number between 0 and 1, got "${raw}"`)
                            }
                            this.layer = tf.layers.dropout({ rate })
                            this.subtitleLabel.text(this.description)
                            this.onInputUpdated(0)
                            this.globalState.visitedReceptorCount.clear()
                            return true
                        } catch (error) {
                            console.warn(error)
                            this.globalState.visitedReceptorCount.clear()
                            alert("Invalid value.")
                            return false
                        }
                    }
                }
            ]
        }
    }

    /** Marks the block as selected and shows a blue glow around the ring. */
    select(e: KonvaEventObject<MouseEvent>): void {
        super.select(e)

        this.ring.shadowColor("#70a5fd")
        this.ring.shadowBlur(7)
        this.ring.shadowOpacity(0.9)
    }

    /** Clears the selection state and hides the ring's glow. */
    unselect(): void {
        super.unselect()

        this.ring.shadowOpacity(0)
    }

    /**
     * Serializes the block's configurable state.
     *
     * @returns An object holding the layer's current `rate`.
     */
    async getStateDict(): Promise<Record<string, any>> {
        return {
            rate: this.layer.getConfig().rate
        }
    }

    /**
     * Restores state produced by {@link getStateDict}.
     *
     * A missing `rate` falls back to 0.5. Any failure is logged and leaves the
     * current layer in place.
     *
     * @param data Previously serialized state; only `rate` is read.
     */
    async loadStateDict(data: Record<string, any>): Promise<void> {
        try {
            this.layer = tf.layers.dropout({ rate: data.rate ?? 0.5 })
            this.subtitleLabel.text(this.description)
        } catch (error) {
            console.warn("Failed to load state dict for dropout", data, error)
        }
    }

}

export default Dropout;