import Konva from "konva";
import Block from "../Block";
import * as tf from "@tensorflow/tfjs"
import Receptor from "../Receptor";
import Endpoint from "../Endpoint";
import { ActivationIdentifier } from "@tensorflow/tfjs-layers/dist/keras_format/activation_config"
import { Activation } from "@tensorflow/tfjs-layers/dist/layers/core"
import { InspectorProps } from "../../Interfaces";
import Dropdown from "rc-dropdown";
import 'rc-dropdown/assets/index.css';
import { useState } from "react";
import { KonvaEventObject } from "konva/lib/Node";

/**
 * Block that applies a TensorFlow.js activation layer to its single input.
 *
 * It is drawn as a hexagon with an icon and a text label naming the current
 * activation. It exposes one `Receptor` (input) and one `Endpoint` (output).
 * The activation can be changed from the inspector menu and is saved and
 * restored through the state dict.
 */
class ActivationFn extends Block {

    type_id = "activation"
    // NOTE(review): declared as Konva.Circle, but the constructor assigns a
    // six-sided Konva.RegularPolygon. The declared type is kept so the public
    // field type does not change; consider narrowing it to Konva.RegularPolygon.
    ring: Konva.Circle
    /** tfjs activation layer matching `type`; rebuilt whenever `type` changes. */
    layer: Activation
    /** Identifier of the activation currently applied by `layer`. */
    type: ActivationIdentifier
    /** Text label that displays the current activation identifier. */
    subtitleLabel: Konva.Text

    /**
     * @param id   Block identifier, forwarded to `Block`.
     * @param type Activation to start with; defaults to `"sigmoid"` when omitted.
     */
    constructor(id: string, type?: ActivationIdentifier) {
        super(id)
        this.type = type ?? "sigmoid"
        this.layer = tf.layers.activation({ activation: this.type })

        // Edge length of the block's bounding square, and half of it. The half
        // value is reused for the group offset, the hexagon radius and the
        // label's x position.
        const size = 85
        const half = size / 2

        this.element = new Konva.Group({
            draggable: true,
            width: size,
            height: size,
            offsetX: -half,
            offsetY: -half
        })

        this.ring = new Konva.RegularPolygon({
            stroke: "#a2e5e0",
            sides: 6,
            rotation: 30,
            strokeWidth: 3,
            radius: half,
            fill: "#f0fffe"
        })

        // The label spans the full ring width and starts at its left edge, so
        // the centred text sits under the icon.
        this.subtitleLabel = new Konva.Text({
            text: this.type,
            fontSize: 12,
            align: "center",
            width: this.ring.width(),
            x: -half,
            y: 7
        })

        this.element.add(this.ring)
        this.element.add(this.subtitleLabel)

        // The icon is added from a callback once Konva has created the image.
        // `this.element` is re-read (optionally) at that point rather than captured.
        Konva.Image.fromURL("/assets/blocks/activation.png", img => {
            img.size({ width: 40, height: 40 })
            // Centre horizontally on the group origin...
            img.x(-img.size().width / 2)
            // ...and place it 7px higher than vertically centred, leaving room
            // for the label below.
            img.y(img.x() - 7)
            this.element?.add(img)
        })

        // One input at the left edge of the ring, one output at the right edge.
        this.inputs = [
            new Receptor(this, 0, -this.ring.radius(), 0),
        ]
        this.outputs = [
            new Endpoint(this, 0, this.ring.radius(), 0)
        ]
    }

    /**
     * Rebuilds the layer for `type` and updates the stored identifier and label.
     *
     * The layer is created first, so if tfjs rejects the identifier the throw
     * happens before any field has been modified.
     */
    private applyActivationType(type: ActivationIdentifier): void {
        this.layer = tf.layers.activation({ activation: type })
        this.type = type
        this.subtitleLabel.text(type)
    }

    /**
     * Recomputes this block's value from its first input and forwards it to
     * the first output.
     *
     * When all required inputs are provided, the activation layer is applied
     * to the input value; otherwise the raw input value is passed through.
     *
     * @param index Index of the input that changed (not used here; the block has a single input).
     * @returns Whatever `propagate` on the first output returns.
     */
    onInputUpdated(index: number): boolean {
        if (this.allRequiredInputsProvided) {
            this.currentValue = this.layer.apply(this.inputs[0].currentValue!) as tf.Tensor | tf.SymbolicTensor
        } else {
            this.currentValue = this.inputs[0].currentValue
        }
        return this.outputs[0].propagate(this.currentValue)
    }

    /**
     * Builds the inspector panel: a one-row table with a dropdown for picking
     * the activation function, plus a "Done" button.
     */
    onClickMenu(): InspectorProps {
        // Display names offered in the menu; the lowercased name is used as
        // the activation identifier.
        const choices = ["Elu", "ReLU", "Sigmoid", "SoftSign", "Tanh"]

        // Filled in by the ref callback on the trigger button below, so its
        // caption can be updated in place when a menu entry is chosen.
        let selectionButtonRef: HTMLButtonElement | null = null

        const choose = (name: string) => {
            const identifier = name.toLowerCase() as ActivationIdentifier
            // Update the block itself regardless of whether the trigger button
            // is currently mounted; only the caption mirror depends on the ref.
            this.applyActivationType(identifier)
            if (selectionButtonRef) {
                selectionButtonRef.innerText = identifier
            }
            // Re-run this block with the new activation and push the result downstream.
            this.onInputUpdated(0)
            // NOTE(review): presumably resets per-propagation bookkeeping after
            // the manual re-run above; confirm against globalState.
            this.globalState.visitedReceptorCount.clear()
        }

        const menu = <div className="menu">
            {choices.map(name => <button key={name} onClick={() => choose(name)}>{name}</button>)}
        </div>
        const table = <table className='info-table'>
            <tbody>
                <tr>
                    <td>Activation Function</td>
                    <td>
                        <Dropdown trigger={['click']} overlay={menu} animation="slide-up">
                            <button className="menu-select" ref={e => {
                                selectionButtonRef = e
                            }}>
                                {this.type ?? "Select..."}
                            </button>
                        </Dropdown>
                    </td>
                </tr>
            </tbody>
        </table>
        return {
            title: this.displayedName,
            settings: table,
            buttons: [
                {
                    title: "Done",
                    type: "normal"
                }
            ]
        }
    }

    /** Marks the block as selected and shows a blue glow around the ring. */
    select(e: KonvaEventObject<MouseEvent>): void {
        super.select(e)

        this.ring.shadowColor("#70a5fd")
        this.ring.shadowBlur(8)
        this.ring.shadowOpacity(0.9)
    }

    /** Clears the selection and hides the ring's glow. */
    unselect(): void {
        super.unselect()

        this.ring.shadowOpacity(0)
    }

    /** Serialises the block's configurable state: the activation identifier. */
    async getStateDict(): Promise<Record<string, any>> {
        return {
            type: this.type
        }
    }

    /**
     * Restores the activation from a dict produced by `getStateDict`.
     *
     * Loading is best-effort: invalid data is reported with `console.warn` and
     * the current activation is left unchanged. This method does not throw.
     */
    async loadStateDict(data: Record<string, any>): Promise<void> {
        const type: unknown = data?.type
        // Reject a missing or non-string type up front so `this.type` and the
        // label never end up holding a value that is not an identifier.
        if (typeof type !== "string") {
            console.warn("Failed to load state dict for activation: missing or invalid type", data)
            return
        }
        try {
            this.applyActivationType(type as ActivationIdentifier)
        } catch (error) {
            console.warn("Failed to load state dict for activation", data, error)
        }
    }

}

export default ActivationFn;