import * as tf from "@tensorflow/tfjs"
import ContentBlock from "../ContentBlock";
import { InspectorProps } from "../../Interfaces";
import Konva from "konva";
import Endpoint from "../Endpoint";
import { Fragment } from "react";
import { tensorToString } from "../../Utils";
import TensorDisplay from "../Inputs/TensorDisplay";
import Source from "../SourceBlock";

/**
 * A trainable tensor source block.
 *
 * Owns a `tf.Variable` (`currentValue`, created with `trainable = true`), emits it on
 * its single output endpoint, and lets the user edit both the tensor value and the
 * block's display name from the inspector menu.
 */
class Parameter extends TensorDisplay implements Source {

    type_id = "parameter"
    /** The trainable variable this block emits on `outputs[0]`. */
    currentValue: tf.Variable
    
    get blockName(): string {
        return "Parameter"
    }
    /** The user-chosen name when one is set, otherwise the default "Parameter". */
    get displayedName(): string {
        return this.customName ?? "Parameter"
    }

    getDocumentation(): string {
        return `Parameters are trainable constants that are updated as a neural network adjusts itself to fit the training examples. The values they hold represent the intelligence of a neural network.`
    }

    constructor(id: string) {
        super(id)

        // Start as a scalar 0 until the user (or a loaded state dict) provides a value.
        this.currentValue = tf.variable(tf.tensor(0), true)
        this.titleLabel.text(this.displayedName)
        this.container.fill("#e0eeff")
        this.outputs = [
            // Single output on the right edge, vertically centred on the content area.
            new Endpoint(this, 0, this.container.width(), this.titleBar.height() + this.contentHeight / 2)
        ]
        this.outputs[0].currentValue = this.currentValue
    }

    /**
     * Returns true when `value` is a finite number or an arbitrarily nested array whose
     * leaves are all finite numbers, i.e. input that `tf.tensor` turns into a numeric
     * tensor. Strings, booleans, null and plain objects are rejected. Ragged arrays are
     * not detected here; `tf.tensor` performs that shape check itself.
     */
    private static isNumericTensorLike(value: unknown): value is tf.TensorLike {
        if (typeof value === "number") {
            return Number.isFinite(value)
        }
        return Array.isArray(value) && value.every(item => Parameter.isNumericTensorLike(item))
    }

    /** Pushes the current variable to every output endpoint. */
    propagate() {
        this.outputs.forEach(o => o.propagate(this.currentValue))
    }

    /**
     * Builds the inspector panel: a name field, a JSON value editor, and
     * Cancel / Save buttons.
     *
     * Save accepts a JSON number or a nested array of numbers. On success it replaces
     * `currentValue` with a new trainable variable, propagates it downstream, and
     * updates the block's name. On failure it alerts the user and keeps the panel open
     * (returns false).
     */
    onClickMenu(): InspectorProps {
        let ref: HTMLTextAreaElement | null = null
        let nameRef: HTMLInputElement | null = null
        const editArea = <Fragment>
            <textarea placeholder="Enter Parameter Value" defaultValue={tensorToString(this.currentValue)} className="custom-textarea" style={{width: "100%", minHeight: "100px", resize: "vertical"}} ref={(e) => {
                ref = e
                // Re-sync on every mount so a reopened panel shows the committed value.
                if (e) { e.value = e.defaultValue }
            }} />
            <br />
        </Fragment>
        const nameField = <input type="text" placeholder="Name" defaultValue={this.displayedName} className="custom-textarea" style={{textAlign: 'right', maxHeight: "30px", fontFamily: "Menlo, monospace", width: "100%"}} ref={(e) => {
            nameRef = e
            if (e) { e.value = e.defaultValue }
        }} />
        const table = <table className='info-table'>
            <tbody>
                <tr>
                    <td>{`Name`}</td>
                    <td>{nameField}</td>
                </tr>
                <tr>
                    <td>{'Value'}</td>
                    <td>{editArea}</td>
                </tr>
            </tbody>
        </table>


        return {
            title: this.displayedName,
            settings: <Fragment>
                {table}
            </Fragment>,
            buttons: [
                {
                    title: "Cancel",
                    type: "normal",
                    onClick: () => {
                        // Discard unsaved edits by restoring the committed state.
                        if (ref) { ref.value = tensorToString(this.currentValue) }
                        if (nameRef) { nameRef.value = this.displayedName }
                        return true
                    }
                },
                {
                    title: "Save",
                    type: "normal",
                    onClick: () => {
                        if (!ref || !nameRef) {
                            console.warn("Parameter editor fields are not mounted")
                            return false
                        }
                        const textArea = ref
                        const nameInput = nameRef
                        try {
                            const parsed: unknown = JSON.parse(textArea.value)
                            // Reject strings/booleans/objects up front so we never create a
                            // non-numeric trainable variable.
                            if (!Parameter.isNumericTensorLike(parsed)) {
                                throw new Error("Parameter value must be a number or a nested array of numbers")
                            }
                            const data = tf.tensor(parsed)
                            // NOTE(review): the previous variable is not disposed here. Other
                            // code may still reference it, so confirm ownership before adding
                            // a dispose() call.
                            this.currentValue = tf.variable(data, true)
                            this.text.text(this.currentValue.shape.join(" x "))
                            this.outputs[0].propagate(this.currentValue)
                            this.value = data // For parent class
                            this.nativeValue = data.dataSync() // For parent class
                            this.updateReceptorsAndEndpoints()
                            // An empty or whitespace-only name falls back to the default rather
                            // than leaving the block with a blank title.
                            const newName = nameInput.value.trim()
                            this.customName = newName === "" ? this.blockName : newName
                            this.titleLabel.text(this.displayedName)
                            return true
                        } catch (error) {
                            alert("Invalid data format.")
                            console.warn(error)
                            return false
                        } finally {
                            this.globalState.visitedReceptorCount.clear()
                        }
                    }
                }
            ],
            docs: this.documentation
        }
    }

    /**
     * Updates the parent display and writes `value` into the existing variable in place.
     * `tf.Variable.assign` requires `value` to match the variable's shape and dtype.
     */
    setNewValue(value: tf.Tensor) {
        super.setNewValue(value)
        this.currentValue.assign(value)
    } 

    /**
     * Repositions the output endpoint after a layout change and re-anchors the start
     * point of every outgoing connection line to the endpoint's new location.
     */
    updateReceptorsAndEndpoints(): void {
        super.updateReceptorsAndEndpoints()

        const output = this.outputs[0]
        output.element.y(this.titleBar.height() + this.contentHeight / 2)
        output.element.x(this.container.width())

        // The new start point is the same for every connection, so compute it once.
        // The stage offset is added to the absolute position to get line coordinates.
        const { x: offsetX, y: offsetY } = this.globalState.stage!.offset()
        const anchor = output.element.absolutePosition()
        const startX = anchor.x + offsetX
        const startY = anchor.y + offsetY
        output.connections.forEach(c => {
            // Only the start point moves; keep each line's far end unchanged.
            const [, , endX, endY] = c.line.points()
            c.line.points([startX, startY, endX, endY])
        })
    }

    /** Serializes the variable's values (as nested arrays) and the display flag. */
    async getStateDict() {
        return {
            value: await this.currentValue.array(),
            displayData: this.displayData
        }
    }
    
    /**
     * Restores the variable and display flag from `getStateDict` output. Failures are
     * logged and the current state is left as far as it got.
     */
    async loadStateDict(data: Record<string, any>) {
        try {
            this.currentValue = tf.variable(tf.tensor(data.value), true)
            // Keep the output endpoint pointing at the live variable, as the constructor does.
            this.outputs[0].currentValue = this.currentValue
            this.displayData = data.displayData ?? true
            this.titleLabel.text(this.displayedName)
            this.setNewValue(this.currentValue)
        } catch (error) {
            console.warn("Failed to load state dict for parameter", data, error)
        }
    }
}

export default Parameter;