import * as tf from "@tensorflow/tfjs"
import Block from "../Block";
import Konva from "konva";
import Receptor from "../Receptor";
import { round, tensorToString } from "../../Utils";
import { KonvaEventListener, KonvaEventObject } from "konva/lib/Node";
import { InspectorProps, CustomTFDataset, LossType } from "../../Interfaces";
import Dropdown from "rc-dropdown";
import SymbolicInput from "../Inputs/SymbolicInput";
import Parameter from "./Parameter";
import CustomDataset from "./CustomDataset";
import { axiosInstance } from "../../../Utils";
import { Buffer } from "buffer"

/**
 * Canvas block that builds a `tf.LayersModel` from the circuit connected to
 * its single "Output" receptor and trains it in the browser.
 *
 * The block renders a title label and a play/pause button. Clicking the
 * button while idle starts training; clicking it while training pauses it.
 * Hyper-parameters are edited through the inspector returned by
 * {@link Trainer.onClickMenu} and persisted via
 * {@link Trainer.getStateDict} / {@link Trainer.loadStateDict}.
 */
class Trainer extends Block {

    type_id = "trainer"
    /** Status text shown on the block ("Ready to Train", "Training ...", etc.). */
    titleLabel: Konva.Text
    /** Background rectangle; also carries the selection glow. */
    container: Konva.Rect
    /** Model assembled on the last training start. */
    model?: tf.LayersModel
    /** Dataset used for training; filled from `inputs[0].currentDataset` when unset. */
    dataset?: CustomTFDataset
    /** Optional lazy dataset source, awaited at training start when `dataset` is unset. */
    datasetProvider?: () => Promise<CustomTFDataset | undefined>
    /** Number of rows per training batch. */
    batchSize = 16
    /** Learning rate handed to the Adam optimizer. */
    learningRate = 0.001
    /** Number of passes over the training data. */
    epochs = 3
    /** The progress label is refreshed every `printInterval` iterations. */
    printInterval = 25
    lossType: LossType = "cross-entropy"

    trainButton: Konva.Group
    trainButtonBg: Konva.Rect
    // Both icons are loaded asynchronously, so they are undefined until their
    // image callback has run; accesses outside those callbacks use `?.`.
    trainIcon!: Konva.Image
    pauseIcon!: Konva.Image
    trainingWorker?: Worker

    /**
     * Builds the Konva nodes of the block and registers its single receptor.
     *
     * @param id Block identifier forwarded to `Block`.
     */
    constructor(id: string) {
        super(id)

        this.element = new Konva.Group({
            width: 500,
            height: 80,
            draggable: true
        })
        this.container = new Konva.Rect({
            width: this.element.width(),
            height: this.element.height(),
            fill: "#425ab5",
            cornerRadius: 5
        })
        this.element.add(this.container)

        this.titleLabel = new Konva.Text({
            width: this.element.width() - 60,
            height: this.element.height(),
            x: 50,
            align: "center",
            verticalAlign: "middle",
            fill: "white",
            fontSize: 16,
            text: "Trainer"
        })
        this.element.add(this.titleLabel)

        this.trainButton = new Konva.Group({ x: 10, y: 10 })
        this.trainButton.on("click", (e) => this.onTrainButtonClicked(e))
        this.element.add(this.trainButton)

        this.trainButtonBg = new Konva.Rect({
            fill: "#c0d0f0",
            width: 60,
            height: 60,
            cornerRadius: 5
        })
        this.trainButton.add(this.trainButtonBg)

        // Shared sizing/placement for the two button icons.
        const placeIcon = (img: Konva.Image) => {
            img.width(26)
            img.height(26)
            img.x(this.trainButtonBg.width() / 2 - 12)
            img.y(this.trainButtonBg.height() / 2 - 12)
        }

        Konva.Image.fromURL("/assets/blocks/play.svg", img => {
            placeIcon(img)
            // The label may already say "Ready to Train" by the time the image arrives.
            img.opacity(this.titleLabel.text() === "Ready to Train" ? 1 : 0.5)
            this.trainIcon = img
            this.trainButton.add(this.trainIcon)
        })

        Konva.Image.fromURL("/assets/blocks/pause.svg", img => {
            placeIcon(img)
            img.opacity(0.5)
            img.visible(false)
            this.pauseIcon = img
            this.trainButton.add(this.pauseIcon)
        })

        this.inputs = [
            // new Receptor(this, 1, this.element.width() / 6, 0, "Input"),
            new Receptor(this, 0, this.element.width() / 2, 0, "Output"),
            // new Receptor(this, 1, 0, this.element.height() / 2, "Labels", "tensor", false),
        ]

        this.outputs = [
            // new Endpoint(this, 0, this.element.width() / 2, this.element.height(), "Model", "model")
        ]

        this.onTrainButtonClicked = this.onTrainButtonClicked.bind(this);
        this.stopTraining = this.stopTraining.bind(this);
    }

    /**
     * Leaves training mode: clears the global training flag, restores the play
     * icon, switches every `SymbolicInput` out of training, re-propagates every
     * `Parameter`, saves the weights of all blocks and saves the active
     * exercise to the cloud.
     */
    stopTraining() {
        this.globalState.isTraining = false
        // The icons load asynchronously and may not exist yet.
        this.trainIcon?.visible(true)
        this.pauseIcon?.visible(false)
        const symInputs = Object.values(this.globalState.everything).filter(block => block instanceof SymbolicInput) as SymbolicInput[]
        symInputs.forEach(i => {
            this.globalState.visitedReceptorCount.clear()
            i.setTraining(false)
        })
        Object.values(this.globalState.everything).filter(block => block instanceof Parameter).forEach(block => {
            (block as Parameter).propagate()
        })
        Object.values(this.globalState.everything).forEach(block => block.saveWeights())
        this.globalState.activeExercise?.saveToCloud(this.globalState)
    }

    /**
     * Computes the loss tensor for one batch according to `lossType`.
     *
     * @param labels Batch labels.
     * @param prediction Model output for the batch.
     * @param totalClasses One-hot depth for "cross-entropy"; defaults to 10.
     * @returns The loss tensor produced by the selected `tf.losses` function.
     */
    private computeBatchLoss(labels: tf.Tensor, prediction: tf.Tensor, totalClasses?: number): tf.Tensor {
        switch (this.lossType) {
            case "mse":
                return tf.losses.meanSquaredError(labels, prediction)
            case "mae":
                // Mean absolute error; huberLoss (used previously) is a different loss.
                return tf.losses.absoluteDifference(labels, prediction)
            case "cross-entropy":
                return tf.losses.softmaxCrossEntropy(tf.oneHot(labels, totalClasses ?? 10), prediction)
            default:
                return tf.losses.cosineDistance(labels, prediction, -1)
        }
    }

    /** Copies the current value of every `Parameter` block into its displayed value and refreshes its connections. */
    private syncParameterBlocks() {
        Object.values(this.globalState.everything).forEach(block => {
            if (block instanceof Parameter) {
                block.value = block.currentValue // For parent class
                block.nativeValue = block.value.dataSync() // For parent class
                block.updateReceptorsAndEndpoints()
            }
        })
    }

    /**
     * Click handler of the play/pause button.
     *
     * When idle it builds the model from all non-output `SymbolicInput`s and
     * the value connected to `inputs[0]`, posts the serialized model to the
     * exercise's `/train` endpoint and then runs the local training loop.
     * When already training it pauses: the loop observes
     * `globalState.isTraining === false` and exits.
     *
     * @param event Konva click event; it is tagged with `train = 1`.
     */
    async onTrainButtonClicked(event: KonvaEventObject<MouseEvent>) {
        //@ts-ignore
        event.train = 1
        if (!this.allRequiredInputsProvided) {
            console.log("inputs", this.inputs)
            return
        }

        if (this.globalState.isTraining) {
            // Second click: pause. The running loop notices the cleared flag.
            this.titleLabel.text("Ready to Train")
            this.trainIcon?.visible(true)
            this.pauseIcon?.visible(false)
            this.globalState.isTraining = false
            // this.trainingWorker?.postMessage({ pause: true })
            return
        }

        this.globalState.isTraining = true
        const symInputs = Object.values(this.globalState.everything).filter(block => block instanceof SymbolicInput && !block.isOutput) as SymbolicInput[]

        if (this.inputs[0].currentValue instanceof tf.Tensor) {
            symInputs.forEach(i => {
                this.globalState.visitedReceptorCount.clear()
                i.setTraining(true, this.batchSize)
            })
        }

        Object.values(this.globalState.everything).filter(block => block instanceof Parameter).forEach(block => {
            (block as Parameter).propagate()
        })

        const inputTensors = symInputs.map(i => i.value as tf.SymbolicTensor)
        const outputTensor = this.inputs[0].connection?.start.currentValue as tf.SymbolicTensor
        console.log(inputTensors, outputTensor)

        let model: tf.LayersModel
        try {
            model = tf.model({ inputs: inputTensors, outputs: outputTensor })
        } catch (error) {
            console.error(error)
            console.warn('input', inputTensors)
            console.warn('output', outputTensor)
            this.stopTraining()
            this.titleLabel.text("Failed to Start")
            return
        }
        this.model = model

        this.pauseIcon?.visible(true)
        this.pauseIcon?.opacity(1.0)
        this.trainIcon?.visible(false)
        this.titleLabel.text("Training...")

        let dataset: CustomTFDataset | undefined = this.dataset
        if (!dataset && this.datasetProvider) {
            this.titleLabel.text("Loading Dataset...")
            await tf.nextFrame()
            dataset = await this.datasetProvider()
        }

        if (!dataset) {
            alert("No dataset provided")
            this.titleLabel.text("No Dataset")
            return this.stopTraining()
        }
        const trainingData: CustomTFDataset = dataset

        // Serialize the model in memory so it can be uploaded.
        //@ts-ignore
        const result: any = await model.save(tf.io.withSaveHandler(async modelArtifacts => {
            return { modelArtifactsInfo: modelArtifacts }
        }));
        console.log('save result', result)

        const payload = new FormData()
        payload.append("weights", new Blob([result.modelArtifactsInfo.weightData as ArrayBuffer], { type: 'application/octet-stream' }))
        payload.append("model", JSON.stringify(result.modelArtifactsInfo.modelTopology))
        payload.append("weightSpecs", JSON.stringify(result.modelArtifactsInfo.weightSpecs))
        payload.append("config", JSON.stringify(await this.getStateDict()))

        // Fire-and-forget upload; a failed request must not surface as an unhandled rejection.
        axiosInstance
            .postForm(`/chapters/${this.globalState.chapterId}/exercises/${this.globalState.activeExercise?.exercise_id}/train`, payload, { headers: {'Content-Type': 'multipart/form-data'} })
            .catch((error: unknown) => console.warn(error))

        await tf.nextFrame()
        const optimizer = tf.train.adam(this.learningRate)

        // Streams a tensor one slice at a time along its first axis.
        const rowsOf = (tensor: tf.Tensor) => tf.data.generator(function*() {
            for (let row = 0; row < tensor.shape[0]; row++) {
                yield tensor.slice(row, 1).squeeze([0])
            }
        })

        // Keyed "input1", "input2", ...; kept as `any` to match what tf.data.zip accepted before.
        const inputs: any = {}
        if (trainingData.trainX instanceof Array) {
            trainingData.trainX.forEach((input, i) => {
                inputs[`input${i+1}`] = rowsOf(input)
            })
        } else {
            inputs.input1 = rowsOf(trainingData.trainX as tf.Tensor)
        }
        const labels = rowsOf(trainingData.trainY)

        const ds = tf.data.zip({inputs, labels}).batch(this.batchSize)
        const totalIters = Math.ceil((trainingData.trainX instanceof Array ? trainingData.trainX[0] : trainingData.trainX).shape[0] / this.batchSize)
        this.titleLabel.text(`Training (epoch ${1} of ${this.epochs}, iter ${0} of ${totalIters})`)

        const train = async () => {
            await tf.nextFrame()
            for (let epoch = 1; epoch <= this.epochs; epoch++) {
                if (!this.globalState.isTraining) return
                let iter = 0
                const batchIterator = await ds.iterator()
                while (true) {
                    const { value: batchValue, done } = await batchIterator.next()
                    if (batchValue === null || done) {
                        break
                    }
                    if (!this.globalState.isTraining) { return }

                    //@ts-ignore
                    const batchInputs = batchValue.inputs as Record<string, tf.Tensor>
                    //@ts-ignore
                    let batchLabels = batchValue.labels as tf.Tensor

                    if (this.inputs[1]?.currentValue) {
                        try {
                            for (const key of Object.keys(batchInputs)) {
                                const dataInput = this.globalState.everything[key] as SymbolicInput
                                dataInput.applyConcreteValue(batchInputs[key], false, false)
                            }
                            const labelInput = this.globalState.everything["output"] as SymbolicInput
                            labelInput.applyConcreteValue(batchLabels)
                            batchLabels = this.inputs[1].currentValue as tf.Tensor ?? batchLabels
                            this.globalState.visitedReceptorCount.clear()
                        } catch (error) {
                            console.warn(error)
                            break
                        }
                    }

                    // optimizer.minimize runs its callback synchronously, so the
                    // loss value is available as soon as the call returns.
                    let lossValue = Number.NaN
                    try {
                        optimizer.minimize(() => {
                            const output = model.call(Object.values(batchInputs), {})
                            const pred = (Array.isArray(output) ? output[0] : output) as tf.Tensor
                            const loss = this.computeBatchLoss(batchLabels, pred, trainingData.totalClasses)
                            lossValue = round(loss.arraySync() as number, 4)
                            return loss as tf.Scalar
                        })
                    } catch (error) {
                        console.warn(error)
                        this.stopTraining()
                        this.titleLabel.text('Runtime Error')
                        alert("Looks like your choice of loss function didn't match the dataset and model prediction.")
                        return
                    }

                    iter += 1
                    // Refresh on every printInterval-th iteration and on the last one of the epoch.
                    if (iter % this.printInterval === 0 || iter === totalIters) {
                        try {
                            this.titleLabel.text(`Training (epoch ${epoch} of ${this.epochs}, iter ${iter} of ${totalIters}, loss=${lossValue.toFixed(3)})`)
                            this.syncParameterBlocks()
                        } catch (err) {
                            console.warn(err)
                            alert(err)
                            this.stopTraining()
                            return
                        }
                    }
                    await tf.nextFrame()
                }
            }

            await this.model?.save(`indexeddb://models/${this.id}`, {includeOptimizer: true})
            await this.globalState.activeExercise?.saveToCloud(this.globalState)
        }

        train().then(
            () => {
                // stopTraining may change the label indirectly; keep the final status visible.
                const currentText = this.titleLabel.text()
                this.stopTraining()
                this.titleLabel.text(currentText)

                setTimeout(() => {
                    this.trainIcon?.visible(true)
                    this.pauseIcon?.visible(false)
                })
            },
            (error: unknown) => {
                // A failure inside train() (e.g. while saving) must not leave the UI in training state.
                console.warn(error)
                this.stopTraining()
                this.titleLabel.text('Runtime Error')
            }
        )
    }

    /**
     * Reacts to a change on one of the receptors by updating the status label
     * and the play icon. Updates for index 1, or any update while training,
     * are ignored.
     *
     * @param index Index of the receptor that changed.
     * @returns Always `true`.
     */
    onInputUpdated(index: number): boolean {
        if (index === 1 || this.globalState.isTraining) { return true }

        // Detach logic
        if (this.inputs[0].deletedConnection) {
            const inputs = Object.values(this.globalState.everything).filter(block => block instanceof SymbolicInput) as SymbolicInput[]
            inputs.forEach(input => {
                if (input.training) {
                    console.log('disconnecting', input)
                    this.globalState.visitedReceptorCount.clear()
                    input.setTraining(false)
                }
            })
        }

        if (this.allRequiredInputsProvided && this.inputs[0].currentValue) {
            if (!this.globalState.isTraining) {
                try {
                    this.trainIcon?.visible(true)
                    this.pauseIcon?.visible(false)
                    this.titleLabel.text("Ready to Train")
                } catch (error) {
                    console.error(error)
                    this.titleLabel.text("Broken Circuit")
                }
            }
            this.dataset = this.dataset ?? this.inputs[0].currentDataset ?? undefined
            this.trainIcon?.opacity(1.0)
            this.outputs[0]?.propagateModel(this.model)
        } else if (!this.model) {
            this.trainIcon?.opacity(0.5)
            if (this.inputs[0].currentDataset === null || this.dataset === null) {
                this.titleLabel.text("Loading Dataset...")
            } else if (this.inputs[0].currentDataset === undefined && this.dataset === undefined && !this.datasetProvider) {
                this.titleLabel.text("Dataset not provided")
            } else {
                this.titleLabel.text("Connect model output to start")
            }
        } else {
            this.trainIcon?.opacity(0.5)
        }

        return true
    }

    /**
     * Builds the inspector panel: numeric fields for batch size, learning
     * rate, epochs and print interval, a loss-function dropdown, a "Done"
     * button that validates and stores the fields, and a "Download Model"
     * button when a model exists.
     *
     * @param e Konva click event that opened the menu (unused).
     * @returns Title, settings table and buttons for the inspector.
     */
    onClickMenu(e: KonvaEventObject<MouseEvent>): InspectorProps {
        let batchSizeRef: HTMLInputElement | null = null
        let learningRateRef: HTMLInputElement | null = null
        let epochsRef: HTMLInputElement | null = null
        let printEveryRef: HTMLInputElement | null = null

        // Builds one numeric input; `bind` receives the DOM node so "Done" can read it later.
        const numberField = (placeholder: string, initial: number, bind: (el: HTMLInputElement | null) => void, step?: number) => (
            <input type="number" placeholder={placeholder} defaultValue={initial} step={step} className="custom-textarea" style={{width: '80px', textAlign: 'right', maxHeight: "30px"}} ref={(el) => {
                bind(el)
                // Reset to the current setting whenever the node is (re)attached.
                if (el) { el.value = el.defaultValue }
            }} />
        )

        const batchSizeField = numberField("Batch Size", this.batchSize, el => { batchSizeRef = el })
        const learningRateField = numberField("Learning Rate", this.learningRate, el => { learningRateRef = el }, 0.001)
        const epochsField = numberField("Epochs", this.epochs, el => { epochsRef = el })
        const printEvery = numberField("# Iterations", this.printInterval, el => { printEveryRef = el })

        // A missing (unmounted) field reads as "", which fails validation below.
        const readValue = (field: HTMLInputElement | null): string => field?.value ?? ""

        let selectionButtonRef: HTMLButtonElement | null = null
        const menu = <div className="menu">
            {["mse", "mae", "cross-entropy", "cosine"].map(name => <button key={name} onClick={ () => {
                const choice = name.toLowerCase()
                if (selectionButtonRef) { selectionButtonRef.innerText = choice }
                this.lossType = choice as LossType
                this.onInputUpdated(0)
            }}>{name}</button>)}
        </div>

        const table = <table className='info-table'>
            <tbody>
                <tr>
                    <td>{`Batch Size`}</td>
                    <td>{batchSizeField}</td>
                </tr>
                <tr>
                    <td>{`Learning Rate`}</td>
                    <td>{learningRateField}</td>
                </tr>
                <tr>
                    <td>{`Epochs`}</td>
                    <td>{epochsField}</td>
                </tr>
                <tr>
                    <td>Loss Function</td>
                    <td>
                        <Dropdown trigger={['click']} overlay={menu} animation="slide-up">
                            <button className="menu-select" ref={el => selectionButtonRef = el}>
                                {this.lossType ?? "Select..."}
                            </button>
                        </Dropdown>
                    </td>
                </tr>
                <tr>
                    <td>{`Print Interval`}</td>
                    <td>{printEvery}</td>
                </tr>
            </tbody>
        </table>
        return {
            title: this.displayedName,
            settings: table,
            buttons: [
                {
                    title: "Done",
                    type: "normal",
                    onClick: () => {
                        // parseFloat/parseInt return NaN instead of throwing, so
                        // every field is checked explicitly and nothing is stored
                        // unless all of them are valid.
                        const learningRate = Number.parseFloat(readValue(learningRateRef))
                        if (!Number.isFinite(learningRate)) {
                            alert("Learning rate must be a float.")
                            return false
                        }
                        const batchSize = Number.parseInt(readValue(batchSizeRef), 10)
                        if (!Number.isInteger(batchSize) || batchSize <= 0) {
                            alert("Batch size must be a positive int.")
                            return false
                        }
                        const epochs = Number.parseInt(readValue(epochsRef), 10)
                        if (!Number.isInteger(epochs)) {
                            alert("Epochs must be an int.")
                            return false
                        }
                        const printInterval = Number.parseInt(readValue(printEveryRef), 10)
                        if (!Number.isInteger(printInterval)) {
                            alert("Print interval must be an int.")
                            return false
                        }

                        this.learningRate = learningRate
                        this.batchSize = batchSize
                        this.epochs = epochs
                        this.printInterval = printInterval
                        this.onInputUpdated(0)
                        return true
                    }
                },
                ...(!this.model ? [] : [{
                    title: "Download Model",
                    type: "normal" as "normal",
                    onClick: () => {
                        // Debug round-trip: save to localStorage, re-encode the weights and log the result.
                        this.model?.save("localstorage://tmp_model").then(async () => {
                            const namedTensorMap: Record<string, tf.Tensor> = {};
                            this.model!.weights.forEach(weight => {
                                namedTensorMap[weight.name] = weight.read();
                            });
                            const encodedWeights = await tf.io.encodeWeights(namedTensorMap)
                            console.log('encoded', encodedWeights)

                            const handler = tf.io.fromMemory({}, encodedWeights.specs, encodedWeights.data)
                            const loaded = await tf.loadLayersModel(handler)
                            console.log('loaded', loaded)
                            console.log('actual', this.model)
                        }).catch((error: unknown) => console.warn(error))
                        // Actual download, triggered through the browser.
                        this.model?.save("downloads://tmp-model").catch((error: unknown) => console.warn(error))
                        return true
                    }
                }])
            ]
        }
    }

    /** Selects the block and shows a blue glow around its container. */
    select(e: KonvaEventObject<MouseEvent>): void {
        super.select(e)

        this.container.shadowColor("#70a5fd")
        this.container.shadowBlur(8)
        this.container.shadowOpacity(0.9)
    }

    /** Deselects the block and hides the glow. */
    unselect(): void {
        super.unselect()

        this.container.shadowOpacity(0)
    }

    /**
     * Restores hyper-parameters from a saved state. Only truthy entries are
     * applied; the rest keep their current value.
     *
     * @param data Object previously produced by `getStateDict`.
     */
    async loadStateDict(data: Record<string, any>) {
        try {
            if (data.lossType) this.lossType = data.lossType
            if (data.batchSize) this.batchSize = data.batchSize
            if (data.epochs) this.epochs = data.epochs
            if (data.learningRate) this.learningRate = data.learningRate
            if (data.printInterval) this.printInterval = data.printInterval
        } catch {
            console.warn("Trainer failed to deserialize from saved state")
        }
    }

    /**
     * Serializes the hyper-parameters of this block.
     *
     * @returns Loss type, batch size, epochs, learning rate and print interval.
     */
    async getStateDict(): Promise<Record<string, any>> {
        return {
            lossType: this.lossType,
            batchSize: this.batchSize,
            epochs: this.epochs,
            learningRate: this.learningRate,
            printInterval: this.printInterval
        }
    }
}

export default Trainer;