import * as tf from "@tensorflow/tfjs"
import Konva from "konva";
import Endpoint from "../Endpoint";
import Receptor from "../Receptor";
import { SquareRootLayer } from "../../Layers/Arithmetics";
import Block from "../Block";
import { InspectorProps } from "../../Interfaces";
import { Dense } from "@tensorflow/tfjs-layers/dist/layers/core"
import { KonvaEventObject } from "konva/lib/Node";
import SymbolicInput from "../Inputs/SymbolicInput";

/**
 * Canvas block wrapping a single tf.js `Dense` layer, shown as "Linear".
 *
 * The block has one input receptor and one output endpoint. Its subtitle shows
 * "<input size> x <output units>". Weights are persisted in the IndexedDB
 * database "main", object store "weights", under the key
 * "<active exercise id, or 'global'>/<block id>".
 */
class Linear extends Block {

    type_id = "linear"
    container: Konva.Rect
    titleLabel: Konva.Text
    subtitleLabel: Konva.Text
    /** The wrapped layer; replaced wholesale whenever the sizes change. */
    layer: Dense
    /**
     * False while persisted weights are being fetched; while false,
     * `onInputUpdated` emits no output (null).
     */
    weightLoaded = true

    get blockName() { return "Linear Layer" }
    get displayedName() { return "Linear Layer" }
    /** "<input size> x <output units>", used as the block's subtitle. */
    get description() { return `${this.layer.batchInputShape[1]} x ${this.layer.getConfig().units}` }

    /**
     * Builds the Konva visuals and a default 2 -> 5 dense layer.
     *
     * @param id block identifier, also used in the weight storage key
     */
    constructor(id: string) {
        super(id)

        this.layer = tf.layers.dense({ inputShape: [2], units: 5 })
        this.element = new Konva.Group({
            draggable: true,
            // offsetX: 65,
            // offsetY: 35
            width: 130,
            height: 70
        })
        this.container = new Konva.Rect({
            cornerRadius: 8,
            strokeWidth: 3,
            width: 130,
            height: 70,
            stroke: "#90c76c",
            fill: "#edffe0"
        })

        this.titleLabel = new Konva.Text({
            text: "Linear",
            fontSize: 18,
            fontStyle: "Bold",
            align: "center",
            width: this.container.width(),
            x: 0,
            y: 16
        })

        this.subtitleLabel = new Konva.Text({
            text: this.description,
            fontSize: 13,
            align: "center",
            width: this.container.width(),
            x: 0,
            y: 40
        })

        this.element.add(this.container)
        this.element.add(this.titleLabel)
        this.element.add(this.subtitleLabel)

        // Single input at (0, h/2) and single output at (w, h/2) of the container.
        this.inputs = [
            new Receptor(this, 0, 0, this.container.height() / 2),
        ]
        this.outputs = [
            new Endpoint(this, 0, this.container.width(), this.container.height() / 2)
        ]
    }

    getDocumentation(): string {
        return `Linear`
    }

    /** IndexedDB key under which this block's weights are stored. */
    private weightStorageKey(): string {
        return `${this.globalState.activeExercise?.exercise_id ?? "global"}/${this.id}`
    }

    /**
     * Recomputes the output by applying the layer to input 0, then propagates it.
     *
     * The output is null while required inputs are missing or weights are still
     * loading. It is also null when applying the layer throws (e.g. a size
     * mismatch); in that case the border turns red and the error is logged.
     *
     * @param index index of the input that changed (not read: the block has one input)
     * @returns whatever propagating through output 0 returns
     */
    onInputUpdated(index: number): boolean {
        if (this.allRequiredInputsProvided && this.weightLoaded) {
            try {
                this.currentValue = this.layer.apply(this.inputs[0].currentValue!) as tf.SymbolicTensor
                this.container.stroke("#90c76c")
            } catch (error) {
                // Don't keep forwarding an output computed by a previous layer
                // configuration; downstream blocks should see "no value" instead.
                this.currentValue = null
                this.container.stroke("#f01010")
                console.warn(error)
            }
        } else {
            this.container.stroke("#90c76c")
            this.currentValue = null // this.inputs[0].currentValue
        }
        return this.outputs[0].propagate(this.currentValue)
    }

    /**
     * Builds the inspector panel: input/output size fields and, when editable,
     * a "Reinitialize Weights" button that rebuilds the layer with the new sizes.
     */
    onClickMenu(): InspectorProps {
        let inputRef: HTMLInputElement | null = null
        let outputRef: HTMLInputElement | null = null
        const inField = <input type="text" placeholder="Enter Integer" disabled={!this.editable} defaultValue={`${this.layer.batchInputShape[1]}`} className="custom-textfield" style={{border: "1px solid #f0f0f0", width: 'calc(100% - 40px)', textAlign: 'right'}} id="linear-input-shape" ref={(e) => {
            inputRef = e
            if (e) { e.value = e.defaultValue }
        }} />
        const outField = <input type="text" placeholder="Enter Integer" disabled={!this.editable} defaultValue={`${this.layer.getConfig().units}`} className="custom-textfield" style={{border: "1px solid #f0f0f0", width: 'calc(100% - 40px)', textAlign: 'right'}} id="linear-output-shape" ref={(e) => {
            outputRef = e
            if (e) { e.value = e.defaultValue }
        }} />
        const table = <table className='info-table'>
            <tbody>
                <tr>
                    <td>Input Size</td>
                    <td>{inField}</td>
                </tr>
                <tr>
                    <td>Output Size</td>
                    <td>{outField}</td>
                </tr>
            </tbody>
        </table>
        return {
            title: this.displayedName,
            settings: table,
            buttons: this.editable ? [
                {
                    title: "Reinitialize Weights",
                    type: "normal",
                    onClick: () => {
                        const oldInputShape = this.layer.batchInputShape[1]
                        const oldOutputShape = this.layer.getConfig().units
                        try {
                            const inputShape = Number(inputRef!.value)
                            const outputShape = Number(outputRef!.value)
                            // Sizes must be positive integers. Number.isInteger rejects NaN,
                            // and Number("") is 0, so an empty field is rejected as well.
                            if ([inputShape, outputShape].some(x => !Number.isInteger(x) || x <= 0)) {
                                alert("One of the numbers is invalid. Try again.")
                                return false
                            }
                            this.layer = tf.layers.dense({ units: outputShape, inputShape: [inputShape] })

                            const input = this.inputs[0].currentValue
                            if (input instanceof tf.Tensor || input instanceof tf.SymbolicTensor) {
                                // NOTE(review): this self-assignment looks like a no-op unless
                                // Receptor.currentValue is an accessor with side effects — confirm
                                // before removing.
                                this.inputs[0].currentValue = input
                                this.onInputUpdated(0)
                            } else {
                                this.currentValue = undefined
                            }
                            this.subtitleLabel.text(this.description)
                            this.globalState.visitedReceptorCount.clear()
                            this.saveWeights()
                            return true
                        } catch (error) {
                            console.warn(error)
                            this.globalState.visitedReceptorCount.clear()
                            alert("Invalid shape.")
                            inputRef!.value = `${oldInputShape}`
                            outputRef!.value = `${oldOutputShape}`
                            return false
                        }
                    }
                }
            ] : [],
            docs: this.documentation
        }
    }

    /** Adds a blue glow to the container on top of the base selection behavior. */
    select(e: KonvaEventObject<MouseEvent>): void {
        super.select(e)

        this.container.shadowColor("#70a5fd")
        this.container.shadowBlur(8)
        this.container.shadowOpacity(0.9)
    }

    /** Removes the selection glow. */
    unselect(): void {
        super.unselect()

        this.container.shadowOpacity(0)
    }

    /**
     * Restores the layer sizes from `data` (the shape produced by `getStateDict`),
     * then loads the persisted weights.
     *
     * The subtitle reads "loading weights..." until `loadWeights` finishes and
     * refreshes it. Failures are logged and the block is left usable.
     */
    async loadStateDict(data: Record<string, any>): Promise<void> {
        this.subtitleLabel.text("loading weights...")
        try {
            const { inputShape, outputShape } = data
            if (inputShape && outputShape) {
                // Same construction form as the constructor and the reinitialize button.
                this.layer = tf.layers.dense({ units: outputShape, inputShape: [inputShape] })
            }
            this.weightLoaded = false
            // Awaited so that a failure inside loadWeights (e.g. IndexedDB unavailable)
            // lands in the catch below instead of becoming an unhandled rejection.
            await this.loadWeights()
        } catch (error) {
            console.warn("Failed to load state dict for linear layer", data, error)
            // Don't leave the block waiting forever for weights that will never arrive.
            this.weightLoaded = true
            this.subtitleLabel.text(this.description)
        }
    }

    /** Serializable sizes of the layer: `{ inputShape, outputShape }`. */
    async getStateDict(): Promise<Record<string, any>> {
        return { inputShape: this.layer.batchInputShape[1], outputShape: this.layer.getConfig().units }
    }

    /**
     * Persists the layer's current weights, as nested number arrays, to IndexedDB.
     *
     * An unbuilt layer has no weights, so an empty array is stored. Failures are
     * logged rather than thrown, because callers fire this without awaiting it.
     */
    async saveWeights() {
        // Capture the key now, so the weights are stored under the exercise that
        // was active when the save was requested.
        const key = this.weightStorageKey()
        try {
            const serializedWeights = await Promise.all(this.layer.getWeights().map(t => t.array()))

            const req = indexedDB.open("main")
            req.onerror = () => {
                console.warn("Failed to open weight database for", this.id, req.error)
            }
            req.onsuccess = () => {
                const db = req.result
                try {
                    const tx = db.transaction("weights", "readwrite")
                    // Close the connection on every outcome so a failed write does not leak it.
                    tx.oncomplete = () => {
                        db.close()
                    }
                    tx.onabort = () => {
                        console.warn("Failed to save weights for", this.id, tx.error)
                        db.close()
                    }
                    tx.objectStore("weights").put(serializedWeights, key)
                } catch (error) {
                    console.warn("Failed to save weights for", this.id, error)
                    db.close()
                }
            }
        } catch (error) {
            console.warn("Failed to save weights for", this.id, error)
        }
    }

    /**
     * Fetches persisted weights from IndexedDB and installs them into the layer.
     *
     * Resolves as soon as the request has been issued, not when it completes.
     * Whatever the outcome (weights found, not found, or an error), the block ends
     * with `weightLoaded === true`, its output recomputed and its subtitle refreshed.
     */
    async loadWeights() {
        const key = this.weightStorageKey()

        // Common completion path for every outcome.
        const finish = () => {
            this.weightLoaded = true
            try {
                this.onInputUpdated(0)
            } catch (error) {
                console.warn(error)
            }
            this.subtitleLabel.text(this.description)
        }

        const req = indexedDB.open("main")
        req.onerror = () => {
            console.warn("Failed to open weight database for", this.id, req.error)
            finish()
        }
        req.onsuccess = () => {
            const db = req.result
            try {
                const getRequest = db.transaction("weights", "readonly").objectStore("weights").get(key)
                getRequest.onsuccess = () => {
                    db.close()
                    this.installWeights(getRequest.result)
                    finish()
                }
                getRequest.onerror = () => {
                    console.warn("Failed to read weights for", this.id, getRequest.error)
                    db.close()
                    finish()
                }
            } catch (error) {
                // e.g. the "weights" object store does not exist.
                console.warn("Failed to read weights for", this.id, error)
                db.close()
                finish()
            }
        }
    }

    /**
     * Installs serialized weights (as written by `saveWeights`) into the layer.
     * Missing or empty data leaves the layer's own initialization in place.
     * Errors (e.g. a shape mismatch) are logged.
     */
    private installWeights(serializedWeights: number[][] | undefined | null) {
        if (!serializedWeights) {
            console.log('did not find weights for', this.id)
            return
        }
        if (serializedWeights.length === 0) {
            return
        }
        try {
            const tensors = serializedWeights.map(a => tf.tensor(a))
            if (this.layer.built) {
                // tf.js only consults initialWeights when building, so a layer that
                // is already built must be updated directly.
                this.layer.setWeights(tensors)
            } else {
                this.layer.initialWeights = tensors
            }
        } catch (error) {
            console.warn(error)
            console.log(this.layer)
        }
    }
}

export default Linear;