import ContentBlock from "../ContentBlock";
import * as tf from "@tensorflow/tfjs"
import Konva from "konva"
import Receptor from "../Receptor";
import Endpoint from "../Endpoint";
import { CustomTFDataset } from "../../Interfaces";

/**
 * Graph block that assembles a dataset from four tensor inputs
 * (training inputs, training labels, validation inputs, validation labels).
 * When the inputs are mutually consistent, it publishes them as a
 * `CustomTFDataset` on its single "dataset" output. The block's text area
 * shows either a status summary or the reason validation failed.
 */
class CustomDataset extends ContentBlock {

    type_id = "custom_dataset"
    text: Konva.Text
    /** Last successfully validated dataset; `undefined` while inputs are missing or inconsistent. */
    customTFDataset?: CustomTFDataset

    constructor(id: string) {
        super(id, 150, 200)
        
        this.titleLabel.text("Custom Dataset")
        // Status text fills the area below the title bar.
        this.text = new Konva.Text({
            width: this.container.width(),
            height: this.container.height() - this.titleBar.height(),
            verticalAlign: "middle",
            align: "center",
            fontSize: 16,
            padding: 5,
            y: this.titleBar.height(),
            fill: "#101010",
            text: "Missing Input(s)"
        })
        this.element.add(this.text)

        // Four receptors spaced evenly down the left edge (at 1/5 .. 4/5 of the content height).
        this.inputs = [
            new Receptor(this, 0, 0, this.titleBar.height() + this.contentHeight / 5, "Training Set Input"),
            new Receptor(this, 1, 0, this.titleBar.height() + this.contentHeight * 2 / 5, "Training Set Labels"),
            new Receptor(this, 2, 0, this.titleBar.height() + this.contentHeight * 3 / 5, "Validation Set Inputs"),
            new Receptor(this, 3, 0, this.titleBar.height() + this.contentHeight * 4 / 5, "Validation Set Labels")
        ]
        // Single output on the right edge, vertically centered in the content area.
        this.outputs = [
            new Endpoint(this, 0, this.container.width(), this.titleBar.height() + this.contentHeight / 2, "Dataset", "dataset")
        ]
    }

    /**
     * Returns true when two tensor shapes have the same rank and agree on every
     * dimension after the first. The first dimension is treated as the sample
     * count, which may legitimately differ between training and validation sets.
     */
    private static sameSampleShape(a: readonly number[], b: readonly number[]): boolean {
        return a.length === b.length && a.every((dim, i) => i === 0 || dim === b[i])
    }

    /**
     * Re-validates all four inputs whenever any one of them changes.
     *
     * The inputs are treated as consistent when:
     * - training and validation inputs have the same per-sample shape,
     * - training and validation labels have the same per-sample shape,
     * - each inputs/labels pair has the same number of samples (first dimension).
     *
     * On success, stores and propagates the dataset. Otherwise it clears any
     * previously stored dataset and shows the reason in the block's text.
     *
     * @param index - index of the receptor that changed (unused; every change re-validates everything)
     * @returns always `true`
     */
    onInputUpdated(index: number): boolean {
        // NOTE(review): read as a property, as in the original code; confirm on
        // ContentBlock that this is a getter rather than a method.
        if (!this.allRequiredInputsProvided) {
            this.customTFDataset = undefined
            this.text.text("Missing Input(s)")
            return true
        }

        const trainX = this.inputs[0].currentValue as tf.Tensor
        const trainY = this.inputs[1].currentValue as tf.Tensor
        const evalX = this.inputs[2].currentValue as tf.Tensor
        const evalY = this.inputs[3].currentValue as tf.Tensor

        // Shapes are arrays, so compare them element-wise; `!==` would only compare references.
        let error: string | undefined
        if (!CustomDataset.sameSampleShape(trainX.shape, evalX.shape)) {
            error = `Shape mismatch between training input (${trainX.shape.join(" x ")}) and eval input (${evalX.shape.join(" x ")})`
        } else if (!CustomDataset.sameSampleShape(trainY.shape, evalY.shape)) {
            error = `Shape mismatch between training labels (${trainY.shape.join(" x ")}) and eval labels (${evalY.shape.join(" x ")})`
        } else if (trainX.shape[0] !== trainY.shape[0]) {
            error = `Training input and labels differ in length`
        } else if (evalX.shape[0] !== evalY.shape[0]) {
            error = `Validation input and labels differ in length`
        }

        if (error !== undefined) {
            // Don't keep a dataset built from an earlier, now-replaced set of inputs.
            this.customTFDataset = undefined
            this.text.text(error)
            return true
        }

        this.text.text(`Received ${trainX.shape[0]} training samples, ${evalX.shape[0]} eval samples`)
        this.customTFDataset = { trainX, trainY, evalX, evalY }
        this.outputs[0].propagateDataset(this.customTFDataset)
        return true
    }
}

export default CustomDataset