import * as tf from "@tensorflow/tfjs"
import Block from "../Block";
import Konva from "konva";
import Endpoint from "../Endpoint";
import Receptor from "../Receptor";
import { InspectorProps, LossType, SaveEntry } from "../../Interfaces";
import { AddConstant, AddLayer } from "../../Layers/Arithmetics";
import { KonvaEventObject } from "konva/lib/Node";
import ArithmeticBlock from "../Operations/ArithmeticBlock";
import { LossLayer } from "../../Layers/Loss Functions";

/**
 * Canvas block that feeds its two inputs through a {@link LossLayer} and
 * propagates the result on its single output.
 *
 * The block is drawn as a six-sided outline with the selected loss type shown
 * as a text label in its centre.
 */
class Loss extends ArithmeticBlock {

    /** Circumradius of the outline polygon, in canvas units. */
    private static readonly OUTLINE_RADIUS = 40
    /** Outline colour while the block is in a valid (or idle) state. */
    private static readonly OUTLINE_STROKE = "#50dbaf"
    /** Outline colour used to flag a failed loss computation. */
    private static readonly ERROR_STROKE = "#f01010"

    /** Layer that performs the actual loss computation. */
    layer: LossLayer
    /** Label that displays the currently selected loss type. */
    textLabel: Konva.Text
    /**
     * Outline of the block. Despite the name it is created with six sides;
     * the field name is kept so existing references keep working.
     */
    octagon: Konva.RegularPolygon
    // NOTE(review): not read or written anywhere in this class — confirm it is still needed.
    value?: tf.Tensor | tf.SymbolicTensor | null
    type_id = "loss"

    /** Quota identifier, derived from the selected loss type. */
    get quotaId(): string {
        return "loss_" + this.layer.type
    }

    get blockName() { return "Loss" }
    get displayedName() { return "Loss Block" }

    /**
     * @param id   Identifier forwarded to the base block.
     * @param type Loss type to start with; defaults to `"mae"` when omitted.
     */
    constructor(id: string, type?: LossType) {
        super(id)

        const radius = Loss.OUTLINE_RADIUS
        // For a regular hexagon, sqrt(3)/2 * radius is the distance from the
        // centre to the middle of an edge, and sqrt(3) * radius the distance
        // between two opposite edges.
        const halfWidth = Math.sqrt(3) / 2 * radius
        const fullWidth = Math.sqrt(3) * radius

        this.layer = new LossLayer()
        this.layer.type = type ?? "mae"
        this.element = new Konva.Group({
            draggable: true,
            width: fullWidth,
            height: fullWidth,
            offsetX: -halfWidth,
            offsetY: -halfWidth
        })

        this.octagon = new Konva.RegularPolygon({
            sides: 6,
            radius: radius,
            strokeWidth: 3,
            stroke: Loss.OUTLINE_STROKE
        })
        this.element.add(this.octagon)

        this.textLabel = new Konva.Text({
            // NOTE(review): `zPosition` does not look like a standard Konva attribute — confirm it has an effect.
            zPosition: 2,
            text: this.layer.type,
            fontSize: 15,
            fill: "black",
            align: "center",
            x: -halfWidth,
            y: -halfWidth,
            height: fullWidth,
            width: fullWidth,
            verticalAlign: "middle"
        })
        this.element.add(this.textLabel)

        // Two inputs on the left side, one output on the right side.
        this.inputs = [
            new Receptor(this, 0, -halfWidth, -(radius / 2)),
            new Receptor(this, 1, -halfWidth, radius / 2)
        ]
        this.outputs = [
            new Endpoint(this, 0, halfWidth, 0)
        ]
    }

    /** Returns the help text shown for this block. */
    getDocumentation(): string {
        return "Quantifies how far two tensors are. Takes average over elementwise differences."
    }

    /**
     * Recomputes the block's value from both inputs and propagates it.
     *
     * - Two concrete tensors, or two non-concrete inputs, are passed to the layer.
     * - A concrete tensor paired with a symbolic one yields a scalar `0` placeholder.
     * - If the layer throws or returns an array, the value is cleared and the
     *   outline is drawn in the error colour.
     *
     * @param index Index of the input that changed. Unused: the value is
     *              always recomputed from both inputs.
     * @returns Whatever the output endpoint's `propagate` call returns.
     */
    onInputUpdated(index: number): boolean {
        let failed = false
        let next: tf.Tensor | tf.SymbolicTensor | undefined = undefined

        if (this.allRequiredInputsProvided) {
            const a = this.inputs[0].currentValue!
            const b = this.inputs[1].currentValue!

            try {
                let applied: tf.Tensor | tf.Tensor[] | tf.SymbolicTensor | tf.SymbolicTensor[]
                if (a instanceof tf.Tensor && b instanceof tf.Tensor) {
                    applied = this.layer.apply([a, b])
                } else if (
                    (a instanceof tf.Tensor && b instanceof tf.SymbolicTensor) ||
                    (a instanceof tf.SymbolicTensor && b instanceof tf.Tensor)
                ) {
                    // Mixed concrete/symbolic inputs cannot be combined; emit a placeholder scalar.
                    applied = tf.tensor(0)
                } else {
                    applied = this.layer.apply([a, b] as tf.SymbolicTensor[])
                }

                if (Array.isArray(applied)) {
                    // Log the offending value before discarding it.
                    console.warn("currentValue is array", applied)
                    failed = true
                } else {
                    next = applied
                }
            } catch (error) {
                console.warn(error)
                failed = true
            }
        }

        this.currentValue = next
        // Decide the outline colour once, so an error flagged above is not
        // overwritten by a later "everything is fine" reset.
        this.octagon.stroke(failed ? Loss.ERROR_STROKE : Loss.OUTLINE_STROKE)

        return this.outputs[0].propagate(this.currentValue)
    }

    /** Selects the block and adds a glow around its outline. */
    select(e: KonvaEventObject<MouseEvent>): void {
        super.select(e)

        this.octagon.shadowColor("#70a5fd")
        this.octagon.shadowBlur(6)
        this.octagon.shadowOpacity(0.9)
    }

    /** Unselects the block and hides the selection glow. */
    unselect(): void {
        super.unselect()

        this.octagon.shadowOpacity(0)
    }

    /** Returns the persistable state of this block: the selected loss type. */
    async getStateDict(): Promise<Record<string, any>> {
        return {
            type: this.layer.type
        }
    }

    /**
     * Restores the loss type from a previously saved state.
     *
     * @param data State object as produced by {@link getStateDict}; ignored
     *             when it carries no truthy `type`.
     */
    async loadStateDict(data: Record<string, any>): Promise<void> {
        if (data?.type) {
            this.layer.type = data.type
            // Keep the on-canvas label in sync with the restored type.
            this.textLabel.text(this.layer.type)
        }
    }

    /** Builds the save entry for this block (position, ids, state and custom name). */
    override async serialize(): Promise<SaveEntry> {
        return {
            position: this.element!.position(),
            typeId: this.type_id,
            quotaId: this.quotaId,
            value: await this.getStateDict(),
            customName: this.customName
        }
    }
}

// Expose the Loss block as this module's default export.
export default Loss;