import * as tf from "@tensorflow/tfjs"
import Block from "../Block";
import Konva from "konva";
import Endpoint from "../Endpoint";
import Receptor from "../Receptor";
import { InspectorProps } from "../../Interfaces";
import { MultiplyConstantLayer, MultiplyLayer } from "../../Layers/Arithmetics";
import ArithmeticBlock from "./ArithmeticBlock";

/**
 * Editor block that multiplies its two inputs elementwise.
 *
 * Each input may be either a concrete `tf.Tensor` or a `tf.SymbolicTensor`;
 * `onInputUpdated` picks the layer implementation that matches the
 * combination currently connected and stores it in `layer`.
 */
class Multiply extends ArithmeticBlock {

    type_id = "multiply"
    /** Layer used for the most recent multiplication; replaced on every successful input update. */
    layer: tf.layers.Layer = new MultiplyLayer()

    /** Human-readable name of this block type. */
    get blockName(): string {
        return "Multiplication Block"
    }
    /** Name shown for this block; identical to `blockName`. */
    get displayedName() { return "Multiplication Block" }

    /**
     * Builds the block's visuals and its connection points.
     *
     * @param id Identifier forwarded to the base class; also used as the
     *           layer name when one operand is a constant tensor.
     */
    constructor(id: string) {
        super(id)

        // The icon is loaded asynchronously; once available it is sized to
        // 30x30 and vertically centred on the triangle via its Y offset.
        Konva.Image.fromURL("/assets/blocks/multiply.svg", img => {
            img.x(13)
            img.y(this.triangle.height() / 2)
            img.width(30)
            img.height(30)
            img.offsetY(img.height() / 2)
            // `element` may be absent by the time the image arrives, hence `?.`.
            this.element?.add(img)
        })

        // Two inputs at one third and two thirds of the triangle's height.
        // NOTE(review): arguments are presumably (owner, index, x, y) — confirm against Receptor.
        this.inputs = [
            new Receptor(this, 0, 0, this.triangle.height() / 3),
            new Receptor(this, 1, 0, this.triangle.height() * 2 / 3)
        ]
        // Single output at mid-height, at the triangle's full width.
        // NOTE(review): arguments are presumably (owner, index, x, y) — confirm against Endpoint.
        this.outputs = [
            new Endpoint(this, 0, this.triangle.width(), this.triangle.height() / 2)
        ]
    }

    /** Returns the Markdown/LaTeX help text describing this block. */
    getDocumentation(): string {
        return `Multiplies two tensors \`A\` and \`B\` elementwise, broadcasting if necessary.

$$
f(A, B) = A \\times B
$$`
    }

    /**
     * Recomputes the block's output after one of its inputs changed.
     *
     * When all required inputs are present, the product is computed with a
     * layer chosen by operand kind:
     * - tensor x tensor: `MultiplyLayer` applied to both tensors;
     * - tensor x symbolic (either order): `MultiplyConstantLayer` with the
     *   concrete tensor as `constantTerm`, named after this block's id and
     *   applied to the symbolic operand;
     * - otherwise: `tf.layers.multiply()` applied to both operands.
     *
     * If the computation throws, or yields an array instead of a single
     * value, the output is cleared and the triangle is outlined in red.
     * In every other case the outline is black.
     *
     * @param index Index of the input that changed (not used here).
     * @returns Whatever propagating the new value through the output returns.
     */
    onInputUpdated(index: number): boolean {
        // Collects both failure sources so the outline colour is decided once,
        // after all checks; previously the red outline set on an exception
        // was immediately overwritten with black.
        let failed = false

        if (this.allRequiredInputsProvided) {
            const a = this.inputs[0].currentValue!
            const b = this.inputs[1].currentValue!

            try {
                if (a instanceof tf.Tensor && b instanceof tf.Tensor) {
                    this.layer = new MultiplyLayer()
                    this.currentValue = this.layer.apply([a, b]) as tf.Tensor
                } else if (a instanceof tf.Tensor && b instanceof tf.SymbolicTensor) {
                    this.layer = new MultiplyConstantLayer();
                    (this.layer as MultiplyConstantLayer).constantTerm = a
                    this.layer.name = this.id
                    this.currentValue = this.layer.apply(b) as tf.SymbolicTensor
                } else if (a instanceof tf.SymbolicTensor && b instanceof tf.Tensor) {
                    this.layer = new MultiplyConstantLayer();
                    (this.layer as MultiplyConstantLayer).constantTerm = b
                    this.layer.name = this.id
                    this.currentValue = this.layer.apply(a) as tf.SymbolicTensor
                } else {
                    this.layer = tf.layers.multiply()
                    this.currentValue = this.layer.apply([a, b] as tf.SymbolicTensor[]) as tf.SymbolicTensor
                }
            } catch (error) {
                // Any failure while applying the layer is surfaced visually
                // instead of being thrown at the caller.
                failed = true
                this.currentValue = undefined
            }

            // A single value is expected; an array result is treated as an error.
            if (this.currentValue instanceof Array) {
                this.currentValue = undefined
                failed = true
            }
        } else {
            // Missing inputs are not an error: just clear the output.
            this.currentValue = undefined
        }

        this.triangle.stroke(failed ? "#f01010" : "black")
        return this.outputs[0].propagate(this.currentValue)
    }

    /** Returns an empty state dictionary: this block saves no state. */
    async getStateDict() {
        return {}
    }

    /**
     * Does nothing: this block restores no state.
     *
     * @param data Saved state; ignored.
     */
    async loadStateDict(data: Record<string, any>): Promise<void> {
        // Intentionally empty.
    }
}

export default Multiply;