import Konva from "konva";
import Block from "../Block";
import * as tf from "@tensorflow/tfjs"
import { MaxPooling1D, MaxPooling2D, MaxPooling3D, AveragePooling1D, AveragePooling2D, AveragePooling3D, Pooling1D, Pooling2D, Pooling3D, GlobalAveragePooling1D, GlobalAveragePooling2D, GlobalMaxPooling1D, GlobalMaxPooling2D, GlobalPooling1D, GlobalPooling2D } from "@tensorflow/tfjs-layers/dist/layers/pooling"
import Receptor from "../Receptor";
import Endpoint from "../Endpoint";
import { InspectorProps } from "../../Interfaces";
import Dropdown from "rc-dropdown";
import { KonvaEventObject } from "konva/lib/Node";
import { capitalize } from "../../../Utils";

/**
 * Pooling operation performed by a pooling block: local max / average
 * pooling over a window, or global max / average pooling over all positions.
 */
export type PoolingType =
    | "max"
    | "avg"
    | "global_max"
    | "global_avg"

/**
 * Canvas block wrapping a tfjs pooling layer: local max/avg pooling in 1D, 2D
 * or 3D, or global max/avg pooling in 1D or 2D.
 *
 * 2D and 3D layers are always built with `dataFormat: "channelsFirst"`; 1D
 * layers are built without a `dataFormat`. All construction goes through
 * `buildPoolingLayer` so the constructor, the inspector and state loading
 * cannot drift apart.
 */
class Pooling extends Block {
    type_id = "pooling"
    container: Konva.Line
    titleLabel: Konva.Text
    subtitleLabel: Konva.Text
    dim: number // 1, 2 or 3 only (global pooling: 1 or 2)
    poolingType: PoolingType = "max"

    /**
     * Quota key for this block: optional "global" prefix, then "avgpool" or
     * "maxpool", then the dimensionality, e.g. "maxpool2d" or "globalmaxpool1d".
     */
    get quotaId(): string {
        // Both global variants get the prefix (previously only global_avg did).
        const prefix = this.poolingType.startsWith("global") ? "global" : ""
        const kind = this.poolingType.endsWith("avg") ? "avgpool" : "maxpool"
        return `${prefix}${kind}${this.dim}d`
    }
    layer: Pooling1D | Pooling2D | Pooling3D | GlobalPooling1D | GlobalPooling2D
    weightLoaded = false

    get blockName() { return "Pooling Layer" }
    get displayedName() {
        if (this.poolingType.endsWith("max")) {
            return `Max Pool ${this.dim}D`
        } else {
            return `Avg Pool ${this.dim}D`
        }
    }
    /** Subtitle text: the window size for local pooling, "Global" otherwise. */
    get description(): string {
        if (this.poolingType === "avg" || this.poolingType === "max") {
            const poolSize = this.layer.getConfig().poolSize
            return `Size: (${poolSize})`
        } else {
            return "Global"
        }
    }

    /**
     * @param id block identifier forwarded to `Block`
     * @param poolingType pooling operation to perform
     * @param dim requested dimensionality; unsupported values fall back to 3
     *            for local pooling and 2 for global pooling
     */
    constructor(id: string, poolingType: PoolingType = "max", dim = 2) {
        super(id)
        this.poolingType = poolingType
        const built = Pooling.buildPoolingLayer(poolingType, dim, 2)
        this.layer = built.layer
        this.dim = built.dim

        this.element = new Konva.Group({
            draggable: true,
            offsetX: 65,
            offsetY: 35
        })
        this.container = new Konva.Line({
            points: [
                0, 0,
                125, 14,
                125, 54,
                0, 68
            ],
            fill: "#fff7dc",
            stroke: "#f6e060",
            strokeWidth: 3,
            closed: true
        })

        this.titleLabel = new Konva.Text({
            text: this.displayedName,
            fontSize: 16,
            fontStyle: "Bold",
            align: "left",
            width: this.container.width() - 10,
            x: 9,
            y: 19
        })

        this.subtitleLabel = new Konva.Text({
            text: this.description,
            fontSize: 13,
            align: "left",
            width: this.container.width() - 10,
            x: 9,
            y: 39
        })

        this.element.add(this.container)
        this.element.add(this.titleLabel)
        this.element.add(this.subtitleLabel)

        this.inputs = [
            new Receptor(this, 0, 0, this.container.height() / 2),
        ]
        this.outputs = [
            new Endpoint(this, 0, this.container.width(), this.container.height() / 2)
        ]
        this.weightLoaded = true
    }

    /**
     * Builds the tfjs layer for a pooling type and dimensionality.
     *
     * Unsupported dimensionalities fall back to the largest supported one:
     * 3 for local max/avg pooling, 2 for global pooling. Any pooling type other
     * than "max", "avg" or "global_avg" is built as global max pooling.
     *
     * @param poolingType pooling operation to build
     * @param dim requested dimensionality
     * @param poolSize square window edge length (ignored by global pooling)
     * @returns the new layer together with the dimensionality it actually uses
     */
    private static buildPoolingLayer(poolingType: PoolingType, dim: number, poolSize: number): { layer: Pooling["layer"], dim: number } {
        if (poolingType === "max" || poolingType === "avg") {
            const isMax = poolingType === "max"
            if (dim === 1) {
                return { layer: new (isMax ? MaxPooling1D : AveragePooling1D)({ poolSize }), dim: 1 }
            }
            if (dim === 2) {
                return { layer: new (isMax ? MaxPooling2D : AveragePooling2D)({ poolSize, dataFormat: "channelsFirst" }), dim: 2 }
            }
            return { layer: new (isMax ? MaxPooling3D : AveragePooling3D)({ poolSize, dataFormat: "channelsFirst" }), dim: 3 }
        }
        const isAvg = poolingType === "global_avg"
        if (dim === 1) {
            return { layer: new (isAvg ? GlobalAveragePooling1D : GlobalMaxPooling1D)({}), dim: 1 }
        }
        return { layer: new (isAvg ? GlobalAveragePooling2D : GlobalMaxPooling2D)({ dataFormat: "channelsFirst" }), dim: 2 }
    }

    getDocumentation(): string {
        return this.poolingType
    }

    /**
     * Re-runs the pooling layer on the current input and forwards the result.
     *
     * Concrete tensors that lack a batch axis (rank dim + 1) or both batch and
     * channel axes (rank dim) are reshaped first; symbolic tensors are applied
     * unchanged. On failure the block outline turns red and the error is logged.
     * When inputs are missing or the layer is not ready, the input is passed
     * through untouched.
     *
     * @returns whatever the output endpoint's `propagate` returns
     */
    onInputUpdated(index: number): boolean {
        if (this.allRequiredInputsProvided && this.weightLoaded) {
            try {
                let input = this.inputs[0].currentValue

                if (input instanceof tf.Tensor) {
                    if (input.shape.length - 1 === this.dim) {
                        // Channel axis already present: only the batch axis is missing.
                        input = input.reshape([1, ...input.shape])
                    } else if (input.shape.length === this.dim) {
                        // Bare spatial input: add batch plus a single channel. 1D layers
                        // are built without a dataFormat (tfjs 1D pooling expects
                        // [batch, steps, channels]); 2D/3D layers here are channels-first.
                        input = this.dim === 1
                            ? input.reshape([1, ...input.shape, 1])
                            : input.reshape([1, 1, ...input.shape])
                    }
                }
                this.currentValue = this.layer.apply(input!) as (tf.Tensor | tf.SymbolicTensor)

                this.container.stroke("#f6e060")
            } catch (error) {
                this.container.stroke("#f01010")
                console.warn(error)
            }
        } else {
            this.container.stroke("#f6e060")
            this.currentValue = this.inputs[0].currentValue
        }
        return this.outputs[0].propagate(this.currentValue)
    }

    /**
     * Builds the inspector panel: a pool-size field (disabled for global
     * pooling), a dimension dropdown and, when editable, a button that rebuilds
     * the layer from the chosen settings, re-propagates the current input and
     * saves state.
     */
    onClickMenu(): InspectorProps {
        const isGlobal = this.poolingType.startsWith("global")
        let oldPoolSize = 0
        if (!isGlobal) {
            oldPoolSize = (this.layer.getConfig().poolSize as number[])[0]
        }
        // Text the pool-size field starts with, and is restored to on failure.
        const oldPoolSizeText = isGlobal ? "Global" : `${oldPoolSize}`

        let poolSizeRef: HTMLInputElement | null = null
        let selectionButtonRef: HTMLButtonElement | null = null

        const poolSizeField = <input type="text" placeholder="Enter Integer" disabled={!this.editable || isGlobal} defaultValue={oldPoolSizeText} className="custom-textfield" style={{border: "1px solid #f0f0f0", width: 'calc(100% - 40px)', textAlign: 'right'}} id="linear-input-shape" ref={(e) => {
            poolSizeRef = e
            if (e) { e.value = e.defaultValue }
        }} />
        
        const menu = <div className="menu">
            {["1D", "2D", ...(isGlobal ? [] : ["3D"])].map(name => <button key={name} onClick={ (e) => {
                if (selectionButtonRef) selectionButtonRef.innerText = name;
            }}>{name}</button>)}
        </div>

        const table = <table className='info-table'>
            <tbody>
                {<tr>
                    <td>Pooling Size</td>
                    <td>{poolSizeField}</td>
                </tr>}
                <tr>
                    <td>Dimension</td>
                    <td>
                        <Dropdown trigger={['click']} overlay={menu} animation="slide-up">
                            <button className="menu-select" style={{width: "70px"}} ref={e => {
                                selectionButtonRef = e
                            }}>
                                {this.dim + "D"}
                            </button>
                        </Dropdown>
                    </td>
                </tr>
            </tbody>
        </table>
        return {
            title: this.displayedName,
            settings: table,
            buttons: this.editable ? [
                {
                    title: "Reinitialize Weights",
                    type: "normal",
                    onClick: () => {
                        try {
                            const dimDescription = selectionButtonRef!.innerText
                            let poolSize = 2 // ignored by global pooling
                            if (this.poolingType === "max" || this.poolingType === "avg") {
                                poolSize = Number(poolSizeRef!.value)
                                // Number.isInteger is false for NaN, so this also rejects non-numeric text.
                                if (!Number.isInteger(poolSize) || poolSize < 1) {
                                    alert("Pool size is invalid. Try again.")
                                    return false
                                }
                            }
                            // Build the replacement first so a construction failure leaves
                            // the current layer usable; "1D"/"2D"/"3D" parse to 1/2/3.
                            const rebuilt = Pooling.buildPoolingLayer(this.poolingType, Number.parseInt(dimDescription, 10), poolSize)
                            if (this.layer.built) this.layer.dispose()
                            this.layer = rebuilt.layer
                            this.dim = rebuilt.dim
                            
                            let input = this.inputs[0].currentValue
                            if (input instanceof tf.Tensor || input instanceof tf.SymbolicTensor) {
                                this.onInputUpdated(0)
                            } else {
                                this.currentValue = undefined
                            }
                            this.titleLabel.text(this.displayedName)
                            this.subtitleLabel.text(this.description)
                            this.globalState.visitedReceptorCount.clear()
                            this.saveWeights()
                            return true
                        } catch (error) {
                            console.warn(error)
                            this.globalState.visitedReceptorCount.clear()
                            alert("Invalid shape.")
                            if (poolSizeRef) poolSizeRef.value = oldPoolSizeText
                            return false
                        }
                    }
                }
            ] : [],
            docs: this.documentation
        }
    }

    select(e: KonvaEventObject<MouseEvent>): void {
        super.select(e)

        this.container.shadowColor("#70a5fd")
        this.container.shadowBlur(8)
        this.container.shadowOpacity(0.9)
    }

    unselect(): void {
        super.unselect()

        this.container.shadowOpacity(0)
    }

    /**
     * Restores pooling type, dimensionality and pool size from saved state and
     * rebuilds the layer. Invalid or unusable state falls back to 2D max pooling
     * with a pool size of 2. Both labels are refreshed afterwards.
     */
    async loadStateDict(data: Record<string, any>): Promise<void> {
        try {
            const poolingType: PoolingType = data.poolingType
            if (!["max", "avg", "global_max", "global_avg"].includes(poolingType)) {
                throw new Error(`Unknown pooling type: ${data.poolingType}`)
            }
            // getStateDict stores the layer config's poolSize (treated as number[]
            // elsewhere in this class); the inspector only edits square windows,
            // so the first entry is the edge length.
            const saved = data.poolSize
            const poolSize: number = (Array.isArray(saved) ? saved[0] : saved) ?? 2
            const built = Pooling.buildPoolingLayer(poolingType, data.dim, poolSize)
            this.poolingType = poolingType
            this.layer = built.layer
            this.dim = built.dim
        } catch {
            // Keep poolingType, dim and layer mutually consistent on failure.
            const fallback = Pooling.buildPoolingLayer("max", 2, 2)
            this.poolingType = "max"
            this.layer = fallback.layer
            this.dim = fallback.dim
        }
        this.titleLabel.text(this.displayedName)
        this.subtitleLabel.text(this.description)
    }

    /** Serializes the settings needed by `loadStateDict` (poolSize is undefined for global pooling). */
    async getStateDict(): Promise<Record<string, any>> {
        return {
            dim: this.dim,
            poolingType: this.poolingType,
            poolSize: this.layer.getConfig().poolSize
        }
    }
}

export default Pooling;