import * as tf from "@tensorflow/tfjs"
import { Kwargs } from "@tensorflow/tfjs-layers/dist/types"
import { LossType } from "../Interfaces";
import { LayerArgs } from "@tensorflow/tfjs-layers/dist/engine/topology";

/**
 * Layer that reduces a (predictions, labels) pair to a scalar loss tensor.
 *
 * Intended to be called with `[predictions, labels]`. The labels tensor is
 * cached on `this.labels`, so a later call with a single tensor logs a warning
 * and reuses whatever labels were last seen (or were assigned externally).
 * If no labels are available, the layer returns a scalar `0`.
 *
 * Loss selected by `type` (config key `type`, default `"mae"`):
 * - `"mse"`           -> `tf.losses.meanSquaredError`
 * - `"mae"`           -> `tf.losses.absoluteDifference` with `Reduction.MEAN`
 * - `"cross-entropy"` -> `tf.losses.softmaxCrossEntropy` (predictions used as logits)
 * - `"cosine"`        -> `tf.losses.cosineDistance` along the last axis
 *   (NOTE(review): cosine distance presumably expects unit-normalized inputs — confirm callers normalize)
 *
 * NOTE(review): deserializing by `className` requires a
 * `tf.serialization.registerClass(LossLayer)` call, which is not visible here.
 */
export class LossLayer extends tf.layers.Layer {

    /** Labels from the most recent array-form call, or set externally. */
    labels?: tf.Tensor
    /** Which loss function `call` computes; persisted by `getConfig`. */
    type: LossType = "mae"
    static className = "LossLayer"

    /**
     * @param config Optional config dict. Reads `name` (default `"LossLayer"`)
     *               and `type` (default `"mae"`); other keys are ignored.
     */
    constructor(config?: tf.serialization.ConfigDict) {
        super({name: (config?.name as string) ?? "LossLayer"});
        this.type = (config?.type as LossType) ?? "mae"
    }

    /**
     * Returns the base layer config plus the loss `type`. Without `type` here,
     * a layer reconstructed from its config would silently revert to `"mae"`,
     * because the constructor falls back to that default.
     */
    getConfig(): tf.serialization.ConfigDict {
        return { ...super.getConfig(), type: this.type }
    }

    /**
     * Computes the configured loss.
     *
     * @param inputs `[predictions, labels]`, or a lone predictions tensor
     *               (in which case previously cached labels are used).
     * @returns The loss tensor produced by the selected `tf.losses` function,
     *          or a scalar `0` when no labels are available or `type` is unknown.
     */
    call(inputs: tf.Tensor<tf.Rank> | tf.Tensor<tf.Rank>[]) {
        if (Array.isArray(inputs)) {
            // Cache labels on the instance; this field is public state that
            // external code may also set before a single-tensor call.
            this.labels = inputs[1]
        } else {
            console.warn("not array", inputs)
        }
        const predictions = Array.isArray(inputs) ? inputs[0] : inputs
        const labels = this.labels
        if (labels === undefined) {
            return tf.tensor(0)
        }
        switch (this.type) {
            case "mse":
                return tf.losses.meanSquaredError(labels, predictions)
            case "mae":
                return tf.losses.absoluteDifference(labels, predictions, undefined, tf.Reduction.MEAN)
            case "cross-entropy":
                return tf.losses.softmaxCrossEntropy(labels, predictions)
            case "cosine":
                return tf.losses.cosineDistance(labels, predictions, -1)
            default:
                console.error("Unexpected type:", this.type)
                return tf.tensor(0)
        }
    }

    /** The layer always reports a scalar (rank-0) output shape. */
    computeOutputShape(inputShape: tf.Shape | tf.Shape[]) {
        return []
    }
}
