import * as tf from "@tensorflow/tfjs"
import Block from "../Block";
import Receptor from "../Receptor";
import ContentBlock from "../ContentBlock";
import Konva from "konva";
import { InspectorProps } from "../../Interfaces";
import { tensorToString } from "../../Utils";

/**
 * Block that feeds a connected test set through a connected model and shows
 * the resulting prediction as text inside the block body.
 *
 * Receptor 0 supplies the model, receptor 1 supplies the test set.
 */
class Predictor extends ContentBlock {

    type_id = "predictor"
    /** Last test set read from receptor 1 (see `onInputUpdated`). */
    testSet?: tf.Tensor | null
    /** Text node that displays either a status message or the prediction. */
    textLabel: Konva.Text
    /** Batch size handed to `model.predict`; editable from the inspector. */
    batchSize = 8
    /** Last model read from receptor 0 (see `onInputUpdated`). */
    model?: tf.LayersModel | null

    /**
     * Builds the block visuals (title, centered text label) and its two
     * input receptors.
     *
     * @param id Identifier forwarded to the `ContentBlock` constructor.
     */
    constructor(id: string) {
        super(id, 160, 140, {
            backgroundColor: "#faf7d0"
        })
        this.titleLabel.text("Predictor")
        this.textLabel = new Konva.Text({
            width: this.container.width(),
            height: this.container.height() - this.titleBar.height(),
            verticalAlign: "middle",
            align: "center",
            fontSize: 16,
            y: this.titleBar.height(),
            fill: "#101010",
            text: "No Value",
            padding: 5
        })
        this.element.add(this.textLabel)

        // Receptors are placed at one third and two thirds of the content area.
        this.inputs = [
            new Receptor(this, 0, 0, this.titleBar.height() + this.contentHeight / 3, "Model", "model"),
            new Receptor(this, 0, 1, this.titleBar.height() + this.contentHeight * 2 / 3, "Test Set")
        ]
    }

    /**
     * Describes the inspector panel for this block: a "Batch Size" text field
     * plus "Done" and "Delete" buttons.
     *
     * Edits to the field are applied to `this.batchSize` as the user types;
     * values that are not positive integers are ignored so the previous batch
     * size stays in effect.
     *
     * @returns The inspector configuration for this block.
     */
    onClickMenu(): InspectorProps {
        const editArea = <input type="text" placeholder="Batch Size" defaultValue={this.batchSize} className="custom-textarea" style={{border: "1px solid #f0f0f0", width: '50px', textAlign: 'right', maxHeight: "30px"}} onChange={(e) => {
            const parsed = Number(e.target.value.trim())
            // Reject empty, fractional, non-numeric and non-positive input.
            if (Number.isInteger(parsed) && parsed > 0) {
                this.batchSize = parsed
            }
        }} />
        const table = <table className='info-table'>
            <tbody>
                <tr>
                    <td>{`Batch Size`}</td>
                    <td>{editArea}</td>
                </tr>
            </tbody>
        </table>
        return {
            title: this.displayedName,
            settings: table,
            buttons: [
                {
                    title: "Done",
                    type: "normal"
                },
                {
                    title: "Delete",
                    type: "destructive",
                    // NOTE(review): passed unbound; this relies on `destroy` not needing
                    // a method-call `this` — confirm against its declaration.
                    onClick: this.destroy
                }
            ]
        }
    }

    /**
     * Re-reads both inputs, then either runs a prediction or shows a status
     * message explaining what is missing, and finally re-lays out the block.
     *
     * @param index Index of the input that changed (not used here).
     * @returns Always `true`.
     */
    onInputUpdated(index: number): boolean {
        this.model = this.inputs[0].currentModel
        this.testSet = this.inputs[1].currentValue as tf.Tensor
        if (this.allRequiredInputsProvided) {
            this.predict()
        } else if (this.model === undefined) {
            this.textLabel.text("No Model")
        } else if (this.testSet === undefined) {
            this.textLabel.text("No Test Set")
        } else if (this.model === null || this.testSet === null) {
            this.textLabel.text("Waiting...")
        }
        this.updateReceptorsAndEndpoints()
        return true
    }

    /**
     * Runs the current model on the test set from receptor 1, using the
     * configured `batchSize`, and writes the result into the text label.
     * Multiple output tensors are rendered one per line.
     *
     * Does nothing when the model or the test set is missing.
     */
    predict() {
        const model = this.model
        const testSet = this.inputs[1].currentValue as tf.Tensor
        if (!model || !testSet) {
            return
        }
        // tf.tidy disposes the tensors allocated by predict() once the text has
        // been produced; tensors created outside the callback (the test set,
        // model weights) are untouched. Assumes tensorToString reads the tensor
        // synchronously, which its use as a direct string source here implies.
        tf.tidy(() => {
            const p = model.predict(testSet, { batchSize: this.batchSize })
            if (Array.isArray(p)) {
                this.textLabel.text(p.map(tensorToString).join("\n"))
            } else {
                this.textLabel.text(tensorToString(p))
            }
        })
    }

    /**
     * Resizes the block to fit the label text (label height clamped to
     * 115–350), repositions both receptors, and moves the end point of every
     * attached connection line to its receptor's new vertical position.
     */
    updateReceptorsAndEndpoints(): void {
        // Clear the explicit height so the following height() read reflects the
        // text's own size before it is clamped.
        //@ts-ignore
        this.textLabel.height(null)
        this.textLabel.height(Math.min(350, Math.max(115, this.textLabel.height())))
        this.container.height(this.titleBar.height() + this.textLabel.height())
        this.inputs[0].element.y(this.titleBar.height() + this.contentHeight / 3)
        this.inputs[1].element.y(this.titleBar.height() + this.contentHeight * 2 / 3)

        for (const receptor of this.inputs) {
            const connection = receptor.connection
            if (!connection) { // doesn't exist when creating new
                continue
            }
            // Keep the line's start point and end x; only the end y follows the receptor.
            const [x1, y1, x2] = connection.line.points()
            connection.line.points([x1, y1, x2, receptor.element.absolutePosition().y])
            receptor.element.fire("mouseleave", {evt: { buttons: 0 }})
        }
    }
}

export default Predictor;