import { Layer } from "konva/lib/Layer";
import TabView from "../../../Components/TabView/TabView";
import { ValueStore, InspectorProps, TestResult, TestCaseResult, SaveState, TestCase } from "../../Interfaces";
import ExerciseData, { ItemInfo } from "../ExerciseData";
import * as tf from "@tensorflow/tfjs"
import MarkdownTextView from "../../../Components/MarkdownTextView/MarkdownTextView";
import SymbolicInput from "../../Nodes/Inputs/SymbolicInput";
import { random, randomInt, round } from "mathjs";
import OutputBlock from "../../Nodes/Outputs/OutputBlock";
import { LossLayer } from "../../Layers/Loss Functions";
import axios from "axios";
import Trainer from "../../Nodes/Training/Trainer";
import { ArgmaxLayer } from "../../Layers/Arithmetics";
import { createBlockFromString } from "../../Utils";
import { cifar10Data } from "../../../Datasets/CIFAR10_Dataset";
import { activityData } from "../../../Datasets/HumanActivity";

// Exercise: build a model that identifies the physical activity a person is performing
// from accelerometer data. Each example is fed to the model through a `[128, 6]`
// symbolic input, and the trainer is configured for 6 classes.
class Chapter4_PhysicalActivities extends ExerciseData {

    /** Blocks placed on the canvas by `setup`; each is marked non-editable there. */
    defaultBlocks: {
        input1?: SymbolicInput
        output?: OutputBlock
        trainer?: Trainer
    } = {}
    title = "Identifying Physical Activities"
    exercise_id = "physical_activities"

    constructor() {
        super()
        this.instructions = <div className="instructions">
            <h3>{this.title}</h3>
            <MarkdownTextView rawText={`Given the accelerometer data, build a model that detects the physical activity the person is performing.`} />
        </div>

        // Bound so it can be handed out as a callback without losing `this`.
        this.onTestCasesUpdated = this.onTestCasesUpdated.bind(this);
    }

    /**
     * Loads `count` examples (default 4) from the activity dataset, stores them as
     * `this.trainingCases`, and returns them.
     *
     * Each raw sample is transposed with a plain 2-D transpose.
     * NOTE(review): this assumes raw samples are `[6, 128]`, so the result matches the
     * `[128, 6]` symbolic input created in `setup`. Confirm against `activityData`.
     */
    async generateTestCases(count?: number): Promise<TestCase[]> {
        const trainData = await activityData(count ?? 4)
        this.trainingCases = trainData.map(x => ({
            // tidy() frees the untransposed intermediate tensor.
            input: [tf.tidy(() => tf.tensor(x.input).transpose())],
            output: tf.tensor(x.label),
            inputLabels: ["Input"],
            outputLabel: "Activity Type",
            inputTypes: ["default"] as "default"[]
        }))
        return this.trainingCases
    }

    /**
     * Block palette for this exercise: block type id -> `ItemInfo`.
     * Presumably `count` limits how many of each block the learner may place.
     */
    getInitialQuota(): Record<string, ItemInfo> {
        return {
            "add": { count: 5 },
            "sum": { count: 5 },
            "negate": { count: 5 },
            "mean": { count: 5 },
            "sqrt": { count: 5 },
            "square": { count: 5 },
            "multiply": { count: 5 },
            "tensor_viewer": { count: 2 },
            "image_viewer": { count: 2 },
            "divide": { count: 5 },
            "parameter": { count: 5 },
            "linear": { count: 5 },
            "activation": { count: 10 },
            "argmax": { count: 5 },
            "conv1d": { count: 5 },
            "maxpool1d": { count: 5 },
            "avgpool1d": { count: 5 },
            "globalavgpool1d": { count: 2 },
            "globalmaxpool1d": { count: 2 },
            "reshape": { count: 8 },
            "concat": { count: 5 },
            "dropout": { count: 3 }
        }
    }

    /**
     * Places the fixed output viewer, the `[128, 6]` symbolic input, and the trainer on the
     * canvas. It then starts loading the trainer's dataset in the background.
     */
    setup(layer: Layer, store: ValueStore, setShowInspector: (value: boolean) => void, setInspectorView: (view?: InspectorProps | undefined) => void): void {
        super.setup(layer, store, setShowInspector, setInspectorView)

        this.defaultBlocks.output = createBlockFromString("tensor_viewer", { customName: "Output" }, "output") as OutputBlock
        this.defaultBlocks.output.editable = false
        this.addBlockAtPosition(this.defaultBlocks.output, 800, 300)

        this.defaultBlocks.input1 = createBlockFromString("symbolic", { value: { shape: [128, 6]}, customName: "Input" }, "input1") as SymbolicInput
        this.defaultBlocks.input1.addAltInput()
        this.defaultBlocks.input1.editable = false
        this.addBlockAtPosition(this.defaultBlocks.input1, 300, 200)

        this.defaultBlocks.trainer = createBlockFromString("trainer", undefined, "trainer") as Trainer
        this.defaultBlocks.trainer.editable = false
        this.addBlockAtPosition(this.defaultBlocks.trainer, 500, 600)

        // Fire-and-forget: setup() is synchronous, but the dataset arrives later. Log failures
        // rather than letting them surface as unhandled promise rejections.
        this.loadTrainerDataset().catch(error => {
            console.warn("Failed to load the physical activity dataset", error)
        })
    }

    /**
     * Fetches the training and evaluation samples and attaches them to the trainer block.
     * Then it notifies the trainer that its input changed.
     */
    private async loadTrainerDataset(): Promise<void> {
        const trainData = await activityData(8000)
        // NOTE(review): the `true` flag presumably selects the evaluation split; confirm in activityData.
        const evalData = await activityData(1, true)

        const trainer = this.defaultBlocks.trainer
        if (trainer === undefined) {
            return
        }

        // Both splits swap the last two axes ([N, a, b] -> [N, b, a]) so the batch axis stays first.
        // A bare transpose() would reverse every axis and move the batch dimension last.
        // tidy() frees the untransposed intermediate tensors.
        trainer.inputs[0].currentDataset = {
            trainX: [tf.tidy(() => tf.tensor(trainData.map(d => d.input)).transpose([0, 2, 1]))],
            trainY: tf.tensor(trainData.map(d => d.label), undefined, "int32"),
            evalX: [tf.tidy(() => tf.tensor(evalData.map(d => d.input)).transpose([0, 2, 1]))],
            evalY: tf.tensor(evalData.map(d => d.label), undefined, "int32"),
            totalClasses: 6
        }

        trainer.onInputUpdated(0)
    }

    /**
     * Applies the currently selected examples to the input block as a concrete batch,
     * stacked along a new leading axis.
     *
     * With no examples selected, an empty `[0, 128, 6]` batch is applied instead. This does
     * nothing until the input block and the training cases both exist.
     */
    onTestCasesUpdated(): void {
        const input = this.defaultBlocks.input1
        const cases = this.trainingCases
        if (input === undefined || cases === undefined) {
            return
        }
        if (this.activeExampleIndices.length === 0) {
            input.applyConcreteValue(tf.zeros([0, 128, 6]))
        } else {
            const selectedInputs = this.activeExampleIndices.map(i => cases[i].input[0])
            input.applyConcreteValue(tf.stack(selectedInputs, 0))
        }
    }

    /**
     * Grades the learner's model.
     *
     * 1. Builds a `tf.LayersModel` from the graph between `input1` and the output block.
     * 2. Saves it to `indexeddb://models/<exercise_id>`.
     * 3. Starts the grader web worker with that path.
     *
     * Every message the worker posts is forwarded to `onProgressUpdated`. Build and save
     * failures are logged and reported to the user with `alert`.
     */
    async assess(onProgressUpdated: (result: TestResult) => void) {
        const input = this.defaultBlocks.input1
        const output = this.defaultBlocks.output
        if (!input || !output) {
            console.warn("Exercise blocks not initialized")
            return
        }
        const savePath = `indexeddb://models/${this.exercise_id}`

        // Get the model by passing in the symbolic input. Clearing the concrete value should
        // leave a symbolic tensor on the output block's input (hence the cast below).
        input.applyConcreteValue(undefined)
        const symbolicOutput = output.inputs[0].currentValue as tf.SymbolicTensor

        let model: tf.LayersModel | undefined
        try {
            model = tf.model({ inputs: [input.value], outputs: symbolicOutput })
        } catch (error) {
            console.warn(error)
        } finally {
            // Restore the concrete example previews whether or not the model could be built.
            // Without this, a failed build leaves the input in symbolic mode.
            this.onTestCasesUpdated()
        }
        if (model === undefined) {
            alert("Output is not connected to inputs")
            return
        }

        try {
            await model.save(savePath)
        } catch (error) {
            console.warn(error)
            alert("Failed to save the model for grading")
            return
        }

        // Stop any grader left over from a previous attempt so its progress messages
        // cannot interleave with this run's.
        this.worker?.terminate()
        this.worker = new Worker(new URL("./Human Activities Grader.ts", import.meta.url))
        this.worker.onmessage = e => {
            onProgressUpdated(e.data)
        }
        this.worker.postMessage({ save_path: savePath })
    }
}

export default Chapter4_PhysicalActivities;