import * as tf from "@tensorflow/tfjs"
import Konva from "konva";
import Endpoint from "../Endpoint";
import Receptor from "../Receptor";
import { SquareRootLayer, TransposeLayer } from "../../Layers/Arithmetics";
import Block from "../Block";
import { InspectorProps } from "../../Interfaces";
import { Conv1D as TFConv1D, conv1d } from "@tensorflow/tfjs-layers/dist/layers/convolutional"
import { KonvaEventObject } from "konva/lib/Node";
import SymbolicInput from "../Inputs/SymbolicInput";
import Dropdown from "rc-dropdown";

/**
 * Canvas block wrapping a TensorFlow.js `Conv1D` layer.
 *
 * The block has one input receptor and one output endpoint. When its input changes
 * it applies the wrapped layer (adding a leading batch dimension of 1 to rank-2
 * tensor inputs) and propagates the result to the output. The inspector lets the
 * user edit kernel size, input/output channels and padding; "Reinitialize Weights"
 * rebuilds the layer from those values.
 */
class Conv1D extends Block {

    type_id = "conv1d"
    container: Konva.Rect
    titleLabel: Konva.Text
    subtitleLabel: Konva.Text
    /** Wrapped TensorFlow.js layer; replaced on reinitialization and on state restore. */
    layer: TFConv1D
    /** While false, `onInputUpdated` does not evaluate the layer and clears the output. */
    weightLoaded = false

    get blockName() { return "Conv1D Layer" }
    get displayedName() { return "Conv1D Layer" }

    /** Summary shown under the title: `<input channels> → <filters> [<kernel size>]`. */
    get description(): string {
        const config = this.layer.getConfig()
        const inputChannels = this.layer.batchInputShape[this.layer.batchInputShape.length - 1]
        return `${inputChannels} → ${config.filters} [${config.kernelSize}]`
    }

    /**
     * Creates the block with a default layer (5 filters, kernel size 3, 1 input channel,
     * any sequence length) and builds its Konva visuals and connectors.
     *
     * @param id block identifier passed to `Block`
     */
    constructor(id: string) {
        super(id)
        
        this.layer = tf.layers.conv1d({ filters: 5, kernelSize: 3, inputShape: [null, 1] })
        this.element = new Konva.Group({
            draggable: true,
            offsetX: 65,
            offsetY: 35
        })
        this.container = new Konva.Rect({
            cornerRadius: 8,
            strokeWidth: 3,
            width: 130,
            height: 70,
            stroke: "#e8e063",
            fill: "#fffede"
        })

        this.titleLabel = new Konva.Text({
            text: "Conv 1D",
            fontSize: 17,
            fontStyle: "Bold",
            align: "center",
            width: this.container.width(),
            x: 0,
            y: 18
        })

        this.subtitleLabel = new Konva.Text({
            text: this.description,
            fontSize: 13,
            align: "center",
            width: this.container.width(),
            x: 0,
            y: 39
        })

        this.element.add(this.container)
        this.element.add(this.titleLabel)
        this.element.add(this.subtitleLabel)

        // One input on the left edge, one output on the right edge, both vertically centred.
        this.inputs = [
            new Receptor(this, 0, 0, this.container.height() / 2),
        ]
        this.outputs = [
            new Endpoint(this, 0, this.container.width(), this.container.height() / 2)
        ]
        this.weightLoaded = true
    }

    getDocumentation(): string {
        return `Convolution`
    }

    /**
     * Re-evaluates the layer after an input changed and propagates the result.
     *
     * Rank-2 tensor inputs are given a leading batch dimension of 1 before being fed
     * to the layer. If required inputs are missing or weights are not loaded, the
     * output is cleared instead. Evaluation errors paint the border red and are logged.
     *
     * @param index index of the changed input (unused; the block has a single input)
     * @returns the value returned by the output endpoint's `propagate`
     */
    onInputUpdated(index: number): boolean {
        if (this.allRequiredInputsProvided && this.weightLoaded) {
            try {
                const raw = this.inputs[0].currentValue
                const batched = raw instanceof tf.Tensor && raw.shape.length === 2
                    ? raw.reshape([1, ...raw.shape])
                    : raw
                try {
                    this.currentValue = this.layer.apply(batched!) as (tf.Tensor | tf.SymbolicTensor)
                } finally {
                    // The reshaped tensor is a temporary created here; free it so repeated
                    // evaluations don't leak. `raw` is owned by the upstream block.
                    if (batched !== raw && batched instanceof tf.Tensor) {
                        batched.dispose()
                    }
                }

                this.container.stroke("#e8e063")
            } catch (error) {
                this.container.stroke("#f01010")
                console.warn(error)
            }
        } else {
            this.container.stroke("#e8e063")
            this.currentValue = undefined
        }
        return this.outputs[0].propagate(this.currentValue)
    }

    /**
     * Builds the inspector panel: kernel size, input channels, output channels and
     * padding controls, plus (when editable) a "Reinitialize Weights" button that
     * rebuilds the layer from the entered values.
     */
    onClickMenu(): InspectorProps {
        const oldInputChannels = this.layer.batchInputShape[this.layer.batchInputShape.length - 1]
        const oldOutputChannels = this.layer.getConfig().filters as number
        const oldKernelSize = this.layer.getConfig().kernelSize as number
        // Padding picked in the dropdown. Tracked as a value instead of being read back
        // from the button's rendered text, which need not match the config string exactly.
        let selectedPadding = (this.layer.getConfig().padding ?? "valid") as 'valid' | 'same' | 'causal'
        let selectionButtonRef: HTMLButtonElement | null = null


        let kernelRef: HTMLInputElement | null = null
        let inputRef: HTMLInputElement | null = null
        let outputRef: HTMLInputElement | null = null
        const kernelField = <input type="text" placeholder="Enter Integer" disabled={!this.editable} defaultValue={`${oldKernelSize}`} className="custom-textfield" style={{border: "1px solid #f0f0f0", width: 'calc(100% - 40px)', textAlign: 'right'}} id="linear-input-shape" ref={(e) => {
            kernelRef = e
            if (e) { e.value = e.defaultValue }
        }} />
        const inputField = <input type="text" placeholder="Enter Integer" disabled={!this.editable} defaultValue={`${oldInputChannels}`} className="custom-textfield" style={{border: "1px solid #f0f0f0", width: 'calc(100% - 40px)', textAlign: 'right'}} id="linear-input-shape" ref={(e) => {
            inputRef = e
            if (e) { e.value = e.defaultValue }
        }} />
        const outField = <input type="text" placeholder="Enter Integer" disabled={!this.editable} defaultValue={`${oldOutputChannels}`} className="custom-textfield" style={{border: "1px solid #f0f0f0", width: 'calc(100% - 40px)', textAlign: 'right'}} id="linear-output-shape" ref={(e) => {
            outputRef = e
            if (e) { e.value = e.defaultValue }
        }} />

        const menu = <div className="menu">
            {(["valid", "same"] as const).map(name => <button key={name} onClick={() => {
                selectedPadding = name
                if (selectionButtonRef) {
                    selectionButtonRef.innerText = name
                }
            }}>{name}</button>)}
        </div>

        const table = <table className='info-table'>
            <tbody>
                <tr>
                    <td>Kernel Size</td>
                    <td>{kernelField}</td>
                </tr>
                <tr>
                    <td>Input Channels</td>
                    <td>{inputField}</td>
                </tr>
                <tr>
                    <td>Output Channels</td>
                    <td>{outField}</td>
                </tr>
                <tr>
                    <td>Padding</td>
                    <td>
                        <Dropdown trigger={['click']} overlay={menu} animation="slide-up">
                            <button className="menu-select" ref={e => selectionButtonRef = e}>
                                {this.layer.getConfig().padding as string ?? "Select..."}
                            </button>
                        </Dropdown>
                    </td>
                </tr>
            </tbody>
        </table>
        return {
            title: this.displayedName,
            settings: table,
            buttons: this.editable ? [
                {
                    title: "Reinitialize Weights",
                    type: "normal",
                    onClick: () => {
                        try {
                            const kernelSize = Number(kernelRef!.value)
                            const inputChannels = Number(inputRef!.value)
                            const outputShape = Number(outputRef!.value)
                            // Number("") is 0, so require strictly positive integers
                            // (Number.isInteger also rejects NaN).
                            if ([kernelSize, inputChannels, outputShape].some(x => !Number.isInteger(x) || x <= 0)) {
                                alert("One of the numbers is invalid. Try again.")
                                return false
                            }
                            // Construct the replacement first: if tfjs rejects the config,
                            // the current layer is left intact instead of already disposed.
                            const newLayer = tf.layers.conv1d({ filters: outputShape, kernelSize, inputShape: [null, inputChannels], padding: selectedPadding })
                            try {
                                // Best effort: dispose throws for layers that were never built.
                                this.layer.dispose()
                            } catch {}
                            this.layer = newLayer

                            const input = this.inputs[0].currentValue
                            if (input instanceof tf.Tensor || input instanceof tf.SymbolicTensor) {
                                this.onInputUpdated(0)
                            } else {
                                this.currentValue = undefined
                            }
                            this.subtitleLabel.text(this.description)
                            this.globalState.visitedReceptorCount.clear()
                            this.saveWeights().catch(error => console.warn(error))
                            return true
                        } catch (error) {
                            console.warn(error)
                            this.globalState.visitedReceptorCount.clear()
                            alert("Invalid shape.")
                            kernelRef!.value = `${oldKernelSize}`
                            inputRef!.value = `${oldInputChannels}`
                            outputRef!.value = `${oldOutputChannels}`
                            return false
                        }
                    }
                }
            ] : [],
            docs: this.documentation
        }
    }

    /** Selects the block and adds a blue glow around its container. */
    select(e: KonvaEventObject<MouseEvent>): void {
        super.select(e)

        this.container.shadowColor("#70a5fd")
        this.container.shadowBlur(8)
        this.container.shadowOpacity(0.9)
    }

    /** Unselects the block and hides the selection glow. */
    unselect(): void {
        super.unselect()

        this.container.shadowOpacity(0)
    }

    /**
     * Rebuilds the layer from a state dict produced by `getStateDict`.
     *
     * The layer is replaced only when `kernelSize`, `inputShape` and `filters` are all
     * present; `inputShape` holds the full batch input shape. Stored weights are not
     * restored here (`loadWeights` is not called).
     */
    async loadStateDict(data: Record<string, any>): Promise<void> {
        this.subtitleLabel.text("Loading weights...")
        this.weightLoaded = false
        try {
            const { kernelSize, inputShape, filters, padding } = data
            if (kernelSize && inputShape && filters) {
                this.layer = new TFConv1D({ kernelSize, filters, batchInputShape: inputShape, padding })
            }
        } catch (error) {
            console.warn("Failed to load state dict for Conv1D", data, error)
        }
        // Re-enable evaluation even when restoring failed, so the block keeps working
        // with its previous layer rather than staying permanently disabled.
        this.weightLoaded = true
        this.subtitleLabel.text(this.description)
    }

    /** Serializes the layer's structural configuration (not its weights). */
    async getStateDict(): Promise<Record<string, any>> {
        return {
            kernelSize: this.layer.getConfig().kernelSize,
            inputShape: this.layer.getConfig().batchInputShape as number[],
            filters: this.layer.getConfig().filters,
            padding: this.layer.getConfig().padding
        }
    }

    /**
     * Stores the layer's current weights as nested arrays in the IndexedDB "weights"
     * store of database "main", keyed by `<exercise id or "global">/<block id>`.
     */
    async saveWeights() {
        const weightArray = this.layer.getWeights()
        const serializedWeights = await Promise.all(weightArray.map(t => t.array()))

        const req = indexedDB.open("main")
        req.onerror = () => {
            console.warn("Failed to open weight database for", this.id, req.error)
        }
        req.onsuccess = e => {
            const db = (e.target as IDBOpenDBRequest).result
            const tx = db.transaction("weights", "readwrite")
            const store = tx.objectStore("weights")
            store.put(serializedWeights, `${this.globalState.activeExercise?.exercise_id ?? "global"}/${this.id}`)
            tx.oncomplete = () => {
                db.close()
            }
        }
    }

    /**
     * Reads weights stored by `saveWeights` and installs them as the layer's
     * `initialWeights`, then re-evaluates the block. If nothing usable is stored the
     * block re-evaluates with its current weights.
     */
    async loadWeights() {
        const req = indexedDB.open("main")
        req.onerror = () => {
            console.warn("Failed to open weight database for", this.id, req.error)
        }
        req.onsuccess = e => {
            const db = (e.target as IDBOpenDBRequest).result
            const tx = db.transaction("weights", "readonly")
            const store = tx.objectStore("weights")
            const getRequest = store.get(`${this.globalState.activeExercise?.exercise_id ?? "global"}/${this.id}`)
            getRequest.onsuccess = e => {
                const target = e.target as IDBRequest
                const serializedWeights = target.result as number[][] | undefined
                // An empty list can be stored when the layer had not been built yet (it has no
                // weights then). tfjs' setWeights rejects a list whose length differs from the
                // layer's weight count, so installing [] would break the first build; treat
                // it like "nothing stored".
                if (Array.isArray(serializedWeights) && serializedWeights.length > 0) {
                    try {
                        this.layer.initialWeights = serializedWeights.map(a => tf.tensor(a))
                        this.weightLoaded = true
                        this.onInputUpdated(0)
                        console.log('successfully loaded weight for', this.id)
                    } catch (error) {
                        console.warn(error)
                        console.log(this.layer)
                        this.weightLoaded = true
                    }
                } else {
                    console.log('did not find weights for', this.id)
                    this.weightLoaded = true
                    this.onInputUpdated(0)
                }
                this.subtitleLabel.text(this.description)
                db.close()
            }
        }
    }
}

export default Conv1D;