import * as tf from "@tensorflow/tfjs"
import Block from "../Block";
import Konva from "konva";
import Endpoint from "../Endpoint";
import Receptor from "../Receptor";
import { InspectorProps } from "../../Interfaces";
import { AddConstant, AddLayer, MatMulConstantLayer, MatMulLayer } from "../../Layers/Arithmetics";
import ArithmeticBlock from "./ArithmeticBlock";

/**
 * Arithmetic block that computes the matrix product of its two inputs
 * (input 0 as the left operand, input 1 as the right operand).
 *
 * The layer stored in `this.layer` is rebuilt on every successful input
 * update so that it matches the kinds of the current operands
 * (see `onInputUpdated`).
 */
class MatMul extends ArithmeticBlock {

    type_id = "matmul"
    layer: tf.layers.Layer = new MatMulLayer()

    get displayedName() { return "MatMul Gate" }

    /**
     * Creates the block with two input receptors and one output endpoint.
     *
     * @param id identifier forwarded to the base block.
     */
    constructor(id: string) {
        super(id)

        // The icon is loaded asynchronously and added to the block's element
        // once the SVG is available; it is vertically centred on the triangle.
        Konva.Image.fromURL("/assets/blocks/matmul.svg", img => {
            img.scale({x: 1.4, y: 1.4})
            img.x(11)
            img.y(this.triangle.height() / 2)
            img.offsetY(img.height() / 2)
            this.element?.add(img)
        })

        // Two receptors on the left edge at 1/3 and 2/3 of the triangle height
        // (NOTE(review): argument order presumably (block, index, x, y) — confirm in Receptor).
        this.inputs = [
            new Receptor(this, 0, 0, this.triangle.height() / 3),
            new Receptor(this, 1, 0, this.triangle.height() * 2 / 3)
        ]
        // Single output on the right edge at mid-height.
        this.outputs = [
            new Endpoint(this, 0, this.triangle.width(), this.triangle.height() / 2)
        ]
    }

    /**
     * Recomputes the block's value after one of its inputs changed.
     *
     * - If not all required inputs are provided, the value is cleared, the
     *   stroke reset to black, and `undefined` is propagated.
     * - Otherwise the product is computed (see `applyMatMulLayer`). If the layer
     *   throws, the operand combination is unsupported, or the layer yields a
     *   multi-output (array) result, the block is marked invalid: red stroke,
     *   value cleared, and `true` is returned without propagating.
     * - On success the stroke is reset to black and the product is propagated.
     *
     * @param index index of the input that changed (not used).
     * @returns the result of propagating through output 0, or `true` when the
     *          block was marked invalid.
     */
    onInputUpdated(index: number): boolean {
        if (!this.allRequiredInputsProvided) {
            this.currentValue = undefined
            this.triangle.stroke("black")
            return this.outputs[0].propagate(this.currentValue)
        }

        const a = this.inputs[0].currentValue!
        const b = this.inputs[1].currentValue!

        let result: tf.Tensor | tf.Tensor[] | tf.SymbolicTensor | tf.SymbolicTensor[] | undefined
        try {
            result = this.applyMatMulLayer(a, b)
        } catch (error) {
            console.warn(error)
            return this.markMatMulInvalid()
        }

        // An unsupported operand combination (undefined) or a multi-output result
        // cannot be forwarded through the single output endpoint. Previously the
        // unsupported case silently kept and re-propagated the stale value.
        if (result === undefined || Array.isArray(result)) {
            return this.markMatMulInvalid()
        }

        this.currentValue = result
        this.triangle.stroke("black")
        return this.outputs[0].propagate(this.currentValue)
    }

    /**
     * Builds the layer matching the operand kinds, stores it in `this.layer`
     * and applies it.
     *
     * - both `tf.Tensor` or both `tf.SymbolicTensor`: `MatMulLayer` on `[a, b]`;
     * - one `tf.Tensor` and one `tf.SymbolicTensor`: the concrete tensor is
     *   baked into a `MatMulConstantLayer` and the layer is applied to the
     *   symbolic operand (the boolean is `true` when the constant is input 0 —
     *   presumably "constant is the left operand"; confirm in MatMulConstantLayer).
     *
     * @returns the layer output, or `undefined` for any other operand combination.
     * @throws whatever the layer constructor or `apply` throws (e.g. shape mismatch).
     */
    private applyMatMulLayer(a: unknown, b: unknown): tf.Tensor | tf.Tensor[] | tf.SymbolicTensor | tf.SymbolicTensor[] | undefined {
        if (a instanceof tf.Tensor && b instanceof tf.Tensor) {
            this.layer = new MatMulLayer()
            return this.layer.apply([a, b])
        }
        if (a instanceof tf.SymbolicTensor && b instanceof tf.SymbolicTensor) {
            this.layer = new MatMulLayer()
            return this.layer.apply([a, b])
        }
        if (a instanceof tf.Tensor && b instanceof tf.SymbolicTensor) {
            this.layer = new MatMulConstantLayer(a, true)
            return this.layer.apply(b)
        }
        if (b instanceof tf.Tensor && a instanceof tf.SymbolicTensor) {
            this.layer = new MatMulConstantLayer(b, false)
            return this.layer.apply(a)
        }
        return undefined
    }

    /**
     * Marks the block as being in an error state: red stroke and cleared value.
     * Downstream blocks are not notified.
     *
     * @returns always `true`, mirroring the original error-path return value.
     */
    private markMatMulInvalid(): boolean {
        this.triangle.stroke("#f01010")
        this.currentValue = undefined
        return true
    }

}

export default MatMul;