import { Layer } from "konva/lib/Layer";
import TabView from "../../../Components/TabView/TabView";
import { ValueStore, InspectorProps, TestResult, TestCaseResult, SaveState, TestCase } from "../../Interfaces";
import LibraryItem from "../../../Components/LibraryItem/LibraryItem";
import ExerciseData, { ItemInfo } from "../ExerciseData";
import { createBlockFromString } from "../../Utils";
import * as tf from "@tensorflow/tfjs"
import MarkdownTextView from "../../../Components/MarkdownTextView/MarkdownTextView";
import SymbolicInput from "../../Nodes/Inputs/SymbolicInput";
import OutputBlock from "../../Nodes/Outputs/OutputBlock";
import { LossLayer } from "../../Layers/Loss Functions";
import axios from "axios";
import { mnistData } from "../../../Datasets/MNIST_Dataset";
import Trainer from "../../Nodes/Training/Trainer";
import { ArgmaxLayer } from "../../Layers/Arithmetics";

/** Length of one flattened MNIST image vector fed to the "Image" input block. */
const MNIST_PIXELS = 784

/**
 * Chapter 3, "Recognizing Handwritten Digits (Attempt 2)": the learner builds a
 * digit classifier, now with activation blocks available, aiming for 90%
 * accuracy.
 *
 * The canvas is seeded with three non-editable blocks: a symbolic "Image"
 * input, an "Output" tensor viewer, and a trainer that is fed MNIST data in
 * the background. Grading saves the learner's model to IndexedDB and hands the
 * save path to a Web Worker (`MNIST Grader.ts`).
 */
class Chapter3_MNIST2 extends ExerciseData {

    defaultBlocks: {
        input1?: SymbolicInput
        output?: OutputBlock
        trainer?: Trainer
    } = {}
    title = "Recognizing Handwritten Digits (Attempt 2)"
    exercise_id = "mnist2"

    constructor() {
        super()
        this.instructions = <div className="instructions">
            <h3>{this.title}</h3>
            <MarkdownTextView rawText={`Build another model that recognizes digits. This time, you have access to activation blocks, which can be found in *Basic Operations*. Let's aim for 90% accuracy.`} />
        </div>

        // Bound so `this` survives if the method is passed around as a callback.
        this.onTestCasesUpdated = this.onTestCasesUpdated.bind(this)
    }

    /**
     * Loads MNIST examples and stores them as `this.trainingCases`, one
     * single-image test case per example.
     *
     * @param count number of examples to fetch; defaults to 5.
     * @returns the newly stored training cases.
     */
    async generateTestCases(count?: number): Promise<TestCase[]> {
        const examples = await mnistData(count ?? 5)
        this.trainingCases = examples.map(x => {
            return {
                input: [tf.tensor(x.image)],
                output: tf.tensor(x.label),
                inputLabels: ["Image"],
                outputLabel: "Digit",
                inputTypes: ["grayscale_image"]
            }
        })
        return this.trainingCases
    }

    /** How many blocks of each type the learner may place in this exercise. */
    getInitialQuota(): Record<string, ItemInfo> {
        return {
            "add": {count: 5},
            "sum": {count: 5},
            "negate": { count: 2},
            "mean": { count: 5 },
            "multiply": {count: 2},
            "tensor_viewer": {count: 2},
            "image_viewer": {count: 1},
            "divide": {count: 2},
            "linear": { count: 5},
            "activation": {count: 5},
            "argmax": {count: 2},
            "reshape": {count: 5}
        }
    }

    /**
     * Places the exercise's fixed (non-editable) blocks on the canvas and
     * starts loading the trainer's MNIST dataset in the background.
     */
    setup(layer: Layer, store: ValueStore, setShowInspector: (value: boolean) => void, setInspectorView: (view?: InspectorProps | undefined) => void): void {
        super.setup(layer, store, setShowInspector, setInspectorView)

        const output = createBlockFromString("tensor_viewer", { customName: "Output" }, "output") as OutputBlock
        this.defaultBlocks.output = output
        output.editable = false
        this.addBlockAtPosition(output, 700, 300)

        const input = createBlockFromString("symbolic", { value: { shape: [MNIST_PIXELS] }, customName: "Image" }, "input1") as SymbolicInput
        this.defaultBlocks.input1 = input
        input.addAltInput()
        input.editable = false
        this.addBlockAtPosition(input, 200, 300)

        const trainer = createBlockFromString("trainer", undefined, "trainer") as Trainer
        this.defaultBlocks.trainer = trainer
        trainer.editable = false
        this.addBlockAtPosition(trainer, 400, 500)

        // setup() is synchronous, so the dataset load is fire-and-forget; load
        // failures are logged here rather than becoming unhandled rejections.
        void this.loadMnistIntoTrainer(trainer).catch(error => {
            console.warn("Failed to load the MNIST dataset for the trainer", error)
        })
    }

    /**
     * Fetches 10000 training examples and 10 evaluation examples
     * (`mnistData(10, true)` — presumably the held-out split, confirm) and
     * installs them as the dataset of `trainer`'s first input.
     *
     * @param trainer the block to feed; passed in (not re-read from
     *   `defaultBlocks`) so the data lands on the block this load was started for.
     */
    private async loadMnistIntoTrainer(trainer: Trainer): Promise<void> {
        const trainData = await mnistData(10000)
        const testData = await mnistData(10, true)
        trainer.inputs[0].currentDataset = {
            trainX: tf.tensor(trainData.map(d => d.image)),
            trainY: tf.tensor(trainData.map(d => d.label), undefined, "int32"),
            evalX: tf.tensor(testData.map(d => d.image)),
            evalY: tf.tensor(testData.map(d => d.label), undefined, "int32"),
        }
        trainer.onInputUpdated(0)
    }

    /**
     * Pushes the currently selected examples into the "Image" input as one
     * stacked concrete tensor, or an empty [0, MNIST_PIXELS] tensor when no
     * example is selected. Does nothing while the input block or the training
     * cases are missing.
     */
    onTestCasesUpdated(): void {
        const input = this.defaultBlocks.input1
        const cases = this.trainingCases
        if (input === undefined || cases === undefined) {
            return
        }
        if (this.activeExampleIndices.length === 0) {
            input.applyConcreteValue(tf.zeros([0, MNIST_PIXELS]))
            return
        }
        const images = this.activeExampleIndices.map(i => cases[i].input[0])
        input.applyConcreteValue(tf.stack(images, 0))
    }

    /**
     * Builds a model from the learner's graph, saves it to IndexedDB and starts
     * the MNIST grader worker on it, forwarding every worker message to
     * `onProgressUpdated`. Alerts the user and returns early if the graph
     * cannot be turned into a model or the model cannot be saved.
     */
    async assess(onProgressUpdated: (result: TestResult) => void) {
        const { input1, output } = this.defaultBlocks
        if (!input1 || !output) {
            console.warn("Exercise blocks not initialized")
            return
        }
        const savePath = `indexeddb://models/${this.exercise_id}`

        const model = this.buildModelFromGraph(input1, output)
        if (model === undefined) {
            alert("Output is not connected to inputs")
            return
        }
        // NOTE(review): nothing here copies weights from
        // `this.defaultBlocks.trainer?.model` into `model`. Confirm the trainer
        // trains the same layer objects this graph uses; otherwise the saved
        // model is untrained.
        try {
            await model.save(savePath)
            console.log('saved model', model)
        } catch (error) {
            console.warn(error)
            alert("Could not save the model for grading")
            return
        }

        // Stop a grader left over from a previous assessment so its progress
        // messages cannot interleave with this run's.
        this.worker?.terminate()
        this.worker = new Worker(new URL("./MNIST Grader.ts", import.meta.url))
        this.worker.onmessage = e => {
            onProgressUpdated(e.data)
        }
        this.worker.onerror = e => {
            console.warn("MNIST grader failed", e)
        }
        this.worker.postMessage({ save_path: savePath })
    }

    /**
     * Switches `input` to symbolic mode (`applyConcreteValue(undefined)`) so that
     * `output`'s incoming value is a symbolic tensor, then builds a model from
     * input to that tensor. The selected examples are re-applied afterwards
     * whether or not building succeeded.
     *
     * @returns the model, or undefined (after logging the error) if `tf.model`
     *   or the output read throws, e.g. when the output is not connected to the input.
     */
    private buildModelFromGraph(input: SymbolicInput, output: OutputBlock): tf.LayersModel | undefined {
        input.applyConcreteValue(undefined)
        try {
            const symbolicOutput = output.inputs[0].currentValue as tf.SymbolicTensor
            return tf.model({ inputs: [input.value], outputs: symbolicOutput })
        } catch (error) {
            console.warn(error)
            return undefined
        } finally {
            // Restore concrete example values even on failure; otherwise the
            // input would stay in symbolic mode after an error.
            this.onTestCasesUpdated()
        }
    }
}

export default Chapter3_MNIST2;