import * as tf from "@tensorflow/tfjs"
import { Kwargs } from "@tensorflow/tfjs-layers/dist/types"
import { computeBroadcastShape } from "./Shape Computation";

/**
 * Element-wise sum of two inputs: `inputs[0] + inputs[1]`.
 */
export class AddLayer extends tf.layers.Layer {
    static className = "AddLayer"

    /**
     * @param config optional layer config; only `name` is read (defaults to "Add"). Accepting a config
     *               keeps the constructor compatible with config-based reconstruction, like the other
     *               layers in this file. `new AddLayer()` still works.
     */
    constructor(config?: tf.serialization.ConfigDict) {
        super({name: (config?.name as string) ?? "Add"});
    }

    /**
     * Shape of `inputs[0] + inputs[1]`.
     *
     * NOTE(review): tf.layers.Layer's default implementation echoes its argument. For a two-input call
     * that is a list of two shapes, which describes two outputs although `call` returns one tensor.
     * Assumes computeBroadcastShape implements NumPy-style broadcasting of two shapes.
     */
    computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
        const shapes = inputShape as tf.Shape[]
        if (shapes.length === 2 && Array.isArray(shapes[0]) && Array.isArray(shapes[1])) {
            return computeBroadcastShape(shapes[0], shapes[1])
        }
        // Not a pair of shapes: keep the base-class behavior.
        return super.computeOutputShape(inputShape)
    }

    call(inputs: tf.Tensor<tf.Rank>[]) {
        return tf.add(inputs[0], inputs[1])
    }
}

/**
 * Adds a fixed term to its (single) input: `input + constantTerm`.
 */
export class AddConstant extends tf.layers.Layer {
    static className = "AddConstantLayer"
    /** Term added to the input; a number or a tensor. Undefined means "not configured". */
    constantTerm?: tf.Tensor | number

    /**
     * @param config optional layer config. `name` defaults to "AddConstant". `constantTerm`, when
     *               present, is converted with `tf.tensor`; a conversion failure is logged, not thrown.
     */
    constructor(config?: tf.serialization.ConfigDict) {
        super({name: (config?.name as string) ?? "AddConstant"});

        // Compare against null/undefined instead of truthiness so a constant of 0 is not discarded.
        if (config?.constantTerm != null) {
            try {
                this.constantTerm = tf.tensor(config.constantTerm as any)
            } catch (err) {
                console.warn(err)
            }
        }
    }

    /**
     * Broadcast of the input shape with the constant's shape. A numeric or missing constant is
     * treated as a scalar (shape `[]`). Uses the same approach as MultiplyConstantLayer, so batch
     * (`null`) dims of the input are preserved.
     */
    computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
        const constantShape = this.constantTerm instanceof tf.Tensor ? this.constantTerm.shape : []
        return computeBroadcastShape(inputShape as tf.Shape, constantShape)
    }

    // Override serialization format: store the constant as a plain number / nested array.
    getConfig(): tf.serialization.ConfigDict {
        const base = super.getConfig()
        if (typeof this.constantTerm === "number") {
            return Object.assign({constantTerm: this.constantTerm}, base)
        } else if (this.constantTerm !== undefined) {
            return Object.assign({constantTerm: this.constantTerm.arraySync()}, base)
        } else {
            return base
        }
    }

    /** Adds the constant to the first input; without a constant, warns and returns the input unchanged. */
    call(inputs: tf.Tensor<tf.Rank> | tf.Tensor<tf.Rank>[]) {
        if (inputs instanceof Array) {
            inputs = inputs[0]
        }
        if (this.constantTerm !== undefined) {
            return tf.add(inputs, this.constantTerm)
        } else {
            console.warn(`AddConstant layer ${this.name} does not have a constantTerm set!`)
            return inputs
        }
    }
}

/**
 * Element-wise product of two inputs: `inputs[0] * inputs[1]`.
 */
export class MultiplyLayer extends tf.layers.Layer {
    static className = "MultiplyLayer"

    /** @param config optional layer config; only `name` is read (defaults to "Multiply"). */
    constructor(config?: tf.serialization.ConfigDict) {
        super({name: (config?.name as string) ?? "Multiply"});
    }

    /**
     * Shape of `inputs[0] * inputs[1]`. The base implementation would echo both input shapes,
     * describing two outputs for a layer that returns one tensor.
     * NOTE(review): assumes computeBroadcastShape implements NumPy-style broadcasting of two shapes.
     */
    computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
        const shapes = inputShape as tf.Shape[]
        if (shapes.length === 2 && Array.isArray(shapes[0]) && Array.isArray(shapes[1])) {
            return computeBroadcastShape(shapes[0], shapes[1])
        }
        return super.computeOutputShape(inputShape)
    }

    call(inputs: tf.Tensor<tf.Rank>[]) {
        return tf.mul(inputs[0], inputs[1])
    }
}

/**
 * Element-wise square root of a single input. When given an array, only its first element is used.
 */
export class SquareRootLayer extends tf.layers.Layer {
    static className = "SquareRootLayer"

    /** @param config optional layer config; only `name` is read (defaults to "SquareRoot"). */
    constructor(config?: tf.serialization.ConfigDict) {
        super({name: (config?.name as string) ?? "SquareRoot"});
    }

    /** An element-wise op: the output shape is the input shape, returned as-is. */
    computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
        return inputShape
    }

    call(inputs: tf.Tensor<tf.Rank> | tf.Tensor<tf.Rank>[]) {
        const operand = inputs instanceof Array ? inputs[0] : inputs
        return tf.sqrt(operand)
    }
}


/**
 * Element-wise square of a single input. When given an array, only its first element is used.
 */
export class SquareLayer extends tf.layers.Layer {
    static className = "SquareLayer"

    /** @param config optional layer config; only `name` is read (defaults to "Square"). */
    constructor(config?: tf.serialization.ConfigDict) {
        super({name: (config?.name as string) ?? "Square"});
    }

    /** An element-wise op: the output shape is the input shape, returned as-is. */
    computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
        return inputShape
    }

    call(inputs: tf.Tensor<tf.Rank> | tf.Tensor<tf.Rank>[]) {
        const operand = inputs instanceof Array ? inputs[0] : inputs
        return tf.square(operand)
    }
}

/**
 * Multiplies its (single) input by a fixed term: `input * constantTerm`.
 */
export class MultiplyConstantLayer extends tf.layers.Layer {
    static className = "MultiplyConstantLayer"
    /** Factor applied to the input; a number or a tensor. Undefined means "not configured". */
    constantTerm?: tf.Tensor | number

    /**
     * @param config optional layer config. `name` defaults to "MultiplyConstant". `constantTerm`, when
     *               present, is converted with `tf.tensor`; a conversion failure is logged, not thrown.
     */
    constructor(config?: tf.serialization.ConfigDict) {
        super({name: (config?.name as string) ?? "MultiplyConstant"});

        // Compare against null/undefined instead of truthiness: a factor of 0 is a legitimate constant
        // and must not degrade into "no constant set" (which would pass the input through unchanged).
        if (config?.constantTerm != null) {
            try {
                this.constantTerm = tf.tensor(config.constantTerm as any)
            } catch (err) {
                console.warn(err)
            }
        }
    }

    /** Broadcast of the input shape with the constant's shape (a number or missing constant counts as `[]`). */
    computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
        const constantShape = (typeof this.constantTerm === "number" ? [] : this.constantTerm?.shape ?? [])
        return computeBroadcastShape(inputShape as tf.Shape, constantShape)
    }

    // Override serialization format: store the constant as a plain number / nested array.
    getConfig(): tf.serialization.ConfigDict {
        const base = super.getConfig()
        if (typeof this.constantTerm === "number") {
            return Object.assign({constantTerm: this.constantTerm}, base)
        } else if (this.constantTerm !== undefined) {
            return Object.assign({constantTerm: this.constantTerm.arraySync()}, base)
        } else {
            return base
        }
    }

    /** Multiplies the first input by the constant; without a constant, warns and returns the input unchanged. */
    call(inputs: tf.Tensor<tf.Rank> | tf.Tensor<tf.Rank>[]) {
        if (inputs instanceof Array) {
            inputs = inputs[0]
        }
        if (this.constantTerm !== undefined) {
            return tf.mul(inputs, this.constantTerm)
        } else {
            console.warn(`MultiplyConstant layer ${this.name} does not have a constantTerm set!`)
            return inputs
        }
    }
}

/**
 * Matrix product of two inputs: `inputs[0] @ inputs[1]` via `tf.matMul`.
 */
export class MatMulLayer extends tf.layers.Layer {
    static className = "MatMulLayer"

    /** @param config optional layer config; only `name` is read (defaults to "MatMul"). */
    constructor(config?: tf.serialization.ConfigDict) {
        super({name: (config?.name as string) ?? "MatMul"});
    }

    /**
     * Output shape `[...batch, rowsA, colsB]`. The batch dims are the broadcast of both operands'
     * leading dims.
     * Falls back to the base implementation when not given two shapes of rank >= 2.
     * NOTE(review): assumes computeBroadcastShape broadcasts two (possibly empty) shapes NumPy-style.
     */
    computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
        const shapes = inputShape as tf.Shape[]
        if (shapes.length !== 2 || !Array.isArray(shapes[0]) || !Array.isArray(shapes[1])) {
            return super.computeOutputShape(inputShape)
        }
        const [shapeA, shapeB] = shapes
        if (shapeA.length < 2 || shapeB.length < 2) {
            return super.computeOutputShape(inputShape)
        }
        const batchShape = computeBroadcastShape(shapeA.slice(0, -2), shapeB.slice(0, -2)) as tf.Shape
        return [...batchShape, shapeA[shapeA.length - 2], shapeB[shapeB.length - 1]]
    }

    call(inputs: tf.Tensor<tf.Rank>[]) {
        return tf.matMul(inputs[0], inputs[1])
    }
}

/**
 * Matrix product of the layer input with a fixed tensor.
 *
 * @param constantTerm fixed operand of the product.
 * @param constantFirst when true computes `constantTerm @ input`, otherwise `input @ constantTerm`.
 */
export class MatMulConstantLayer extends tf.layers.Layer {
    constructor(public constantTerm: tf.Tensor<tf.Rank>, public constantFirst: boolean) {
        super({name: "MatMulConstant"})
    }

    /**
     * Output shape of the product. The innermost two dims come from the operands: rows of the left
     * operand, columns of the right. Leading dims are combined pairwise: a missing dim takes the other
     * side's value, `null` with 1 stays `null`, `null` otherwise takes the constant's dim, 1
     * broadcasts, and equal dims are kept. Dims are collected innermost-first and reversed at the end.
     * Incompatible leading dims fall back to the base implementation.
     */
    computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
        const symbolicShape = inputShape as tf.Shape
        const outShape: tf.Shape = []
        if (this.constantFirst) {
            outShape.push(symbolicShape[symbolicShape.length - 1])
            outShape.push(this.constantTerm.shape[this.constantTerm.shape.length - 2])
        } else {
            outShape.push(this.constantTerm.shape[this.constantTerm.shape.length - 1])
            outShape.push(symbolicShape[symbolicShape.length - 2])
        }
        for (let i = 3; i <= Math.max(symbolicShape.length, this.constantTerm.shape.length); i++) {
            const symbolDimSize = i <= symbolicShape.length ? symbolicShape[symbolicShape.length - i] : undefined
            const inputDimSize = i <= this.constantTerm.shape.length ? this.constantTerm.shape[this.constantTerm.shape.length - i] : undefined
            if (symbolDimSize === undefined) {
                outShape.push(inputDimSize!)
            } else if (inputDimSize === undefined) {
                outShape.push(symbolDimSize)
            } else if (symbolDimSize === null && inputDimSize === 1) {
                outShape.push(null)
            } else if (symbolDimSize === null) {
                outShape.push(inputDimSize)
            } else if (symbolDimSize === 1 || inputDimSize === 1) {
                // One side is 1, so the product is the other side's size.
                outShape.push(symbolDimSize * inputDimSize)
            } else if (symbolDimSize === inputDimSize) {
                outShape.push(symbolDimSize)
            } else {
                return super.computeOutputShape(inputShape)
            }
        }
        outShape.reverse()
        return outShape
    }

    /** Multiplies the first input with the constant, in the order selected by `constantFirst`. */
    call(inputs: tf.Tensor<tf.Rank> | tf.Tensor<tf.Rank>[]) {
        const input = inputs instanceof Array ? inputs[0] : inputs
        return this.constantFirst
            ? tf.matMul(this.constantTerm, input)
            : tf.matMul(input, this.constantTerm)
    }
}

/**
 * Element-wise negation of a single input. When given an array, only its first element is negated.
 */
export class NegationLayer extends tf.layers.Layer {
    static className = "NegationLayer"

    /** @param config optional layer config; only `name` is read (defaults to "NegationLayer"). */
    constructor(config?: tf.serialization.ConfigDict) {
        super({name: (config?.name as string) ?? "NegationLayer"});
    }

    /** An element-wise op: the output shape is the input shape, returned as-is. */
    computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
        return inputShape
    }

    call(inputs: tf.Tensor<tf.Rank> | tf.Tensor<tf.Rank>[]) {
        const operand = inputs instanceof Array ? inputs[0] : inputs
        return operand.neg()
    }
}

/**
 * Swaps the last two axes of a single input; inputs of rank < 2 pass through unchanged.
 */
export class TransposeLayer extends tf.layers.Layer {
    // Every other serializable layer in this file declares a className. Without it, getClassName()
    // has nothing to report and the layer cannot be registered for (de)serialization.
    static className = "TransposeLayer"

    /**
     * @param config optional layer config; only `name` is read (defaults to "TransposeLayer").
     *               `new TransposeLayer()` still works.
     */
    constructor(config?: tf.serialization.ConfigDict) {
        super({name: (config?.name as string) ?? "TransposeLayer"});
    }

    /** Input shape with its last two entries swapped (unchanged for rank < 2). */
    computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
        const symbolicShape = inputShape as tf.Shape
        if (symbolicShape.length < 2) {
            return symbolicShape
        }
        const newShape = [...symbolicShape]
        const last = newShape.length - 1
        ;[newShape[last - 1], newShape[last]] = [newShape[last], newShape[last - 1]]
        return newShape
    }

    call(inputs: tf.Tensor<tf.Rank> | tf.Tensor<tf.Rank>[]) {
        if (inputs instanceof Array) {
            inputs = inputs[0]
        }
        if (inputs.shape.length < 2) {
            return inputs
        }
        // Identity permutation with the final two axes exchanged.
        const permutation = inputs.shape.map((_, i) => i)
        const last = permutation.length - 1
        permutation[last] = last - 1
        permutation[last - 1] = last
        return inputs.transpose(permutation)
    }
}

/**
 * Base class for reductions over `axis` (a single axis, a list of axes, or all axes when undefined).
 * When `keepDims` is true, reduced axes are kept with size 1.
 */
export abstract class AggregationLayer extends tf.layers.Layer {
    axis?: number[] | number
    keepDims?: boolean

    /**
     * Shape after reducing `axis`. Negative axes count from the end. Reduced axes are dropped, or
     * become 1 when `keepDims` is set. Given a list of shapes, the first one is used, mirroring
     * subclasses that reduce `inputs[0]`.
     */
    computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
        const shape = (Array.isArray(inputShape[0]) ? inputShape[0] : inputShape) as tf.Shape
        const axes: number[] = this.axis === undefined
            ? shape.map((_, i) => i)
            : (this.axis instanceof Array ? this.axis : [this.axis])
        const reduced = new Set(axes.map(a => (a < 0 ? shape.length + a : a)))
        if (this.keepDims) {
            return shape.map((dim, i) => (reduced.has(i) ? 1 : dim))
        }
        return shape.filter((_, i) => !reduced.has(i))
    }

    /** Adds `axis` and `keepDims` to the base config so the reduction round-trips. */
    getConfig(): tf.serialization.ConfigDict {
        const base = super.getConfig()
        return Object.assign({axis: this.axis, keepDims: this.keepDims}, base)
    }
}

/**
 * Sum of a single input over `axis` (see AggregationLayer). When given an array, only its first
 * element is reduced.
 */
export class SumLayer extends AggregationLayer {
    static className = "SumLayer"

    /** @param config optional layer config; reads `name` (default "SumLayer"), `axis` and `keepDims`. */
    constructor(config?: tf.serialization.ConfigDict) {
        super({name: (config?.name as string) ?? "SumLayer"});
        this.keepDims = config?.keepDims as boolean
        this.axis = config?.axis as number[]
    }

    call(inputs: tf.Tensor<tf.Rank> | tf.Tensor<tf.Rank>[]) {
        const operand = inputs instanceof Array ? inputs[0] : inputs
        return tf.sum(operand, this.axis, this.keepDims)
    }
}

/**
 * Mean of a single input over `axis` (see AggregationLayer). When given an array, only its first
 * element is reduced.
 */
export class MeanLayer extends AggregationLayer {
    static className = "MeanLayer"

    /** @param config optional layer config; reads `name` (default "MeanLayer"), `axis` and `keepDims`. */
    constructor(config?: tf.serialization.ConfigDict) {
        super({name: (config?.name as string) ?? "MeanLayer"});
        this.keepDims = config?.keepDims as boolean
        this.axis = config?.axis as number
    }

    call(inputs: tf.Tensor<tf.Rank> | tf.Tensor<tf.Rank>[]) {
        const operand = inputs instanceof Array ? inputs[0] : inputs
        return tf.mean(operand, this.axis, this.keepDims)
    }
}

/**
 * Reshapes a single input to the `shape` passed in the call kwargs.
 */
export class ReshapeLayer extends tf.layers.Layer {
    static className = "ReshapeLayer"

    /** @param config optional layer config; only `name` is read (defaults to "ReshapeLayer"). */
    constructor(config?: tf.serialization.ConfigDict) {
        super({name: (config?.name as string) ?? "ReshapeLayer"});
    }

    /**
     * @param inputs tensor to reshape; when given an array, its first element is used.
     * @param kwargs must contain `shape`, the target shape handed to `tf.reshape`.
     * @throws Error when `kwargs.shape` is missing.
     */
    call(inputs: tf.Tensor<tf.Rank> | tf.Tensor<tf.Rank>[], kwargs: Kwargs) {
        const input = inputs instanceof Array ? inputs[0] : inputs
        const targetShape = kwargs?.shape
        if (targetShape == null) {
            throw new Error(`ReshapeLayer ${this.name} requires a 'shape' kwarg`)
        }
        return tf.reshape(input, targetShape)
    }
}

/**
 * Mean absolute error. With two inputs it computes `mean(|inputs[0] - inputs[1]|)`; with one input it
 * computes `mean(|input|)`. The result is a scalar.
 */
export class MAELayer extends tf.layers.Layer {
    static className = "MAELayer"

    /** @param config optional layer config; only `name` is read (defaults to "MAELayer"). */
    constructor(config?: tf.serialization.ConfigDict) {
        super({name: (config?.name as string) ?? "MAELayer"});
    }

    /** The mean is taken over every element, so the output is always a scalar. */
    computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
        return []
    }

    call(inputs: tf.Tensor<tf.Rank> | tf.Tensor<tf.Rank>[], kwargs: Kwargs) {
        if (inputs instanceof Array) {
            if (inputs.length >= 2) {
                return tf.mean(tf.abs(tf.sub(inputs[0], inputs[1])))
            }
            // A single-element array is a single input, not a pair with a missing second operand.
            inputs = inputs[0]
        }
        return tf.mean(tf.abs(inputs))
    }
}

/**
 * Element-wise quotient of two inputs: `inputs[0] / inputs[1]`.
 */
export class DivideLayer extends tf.layers.Layer {
    static className = "DivideLayer"

    /** @param config optional layer config; only `name` is read (defaults to "Divide"). */
    constructor(config?: tf.serialization.ConfigDict) {
        super({name: (config?.name as string) ?? "Divide"});
    }

    /**
     * Shape of `inputs[0] / inputs[1]`. The base implementation would echo both input shapes,
     * describing two outputs for a layer that returns one tensor.
     * NOTE(review): assumes computeBroadcastShape implements NumPy-style broadcasting of two shapes.
     */
    computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
        const shapes = inputShape as tf.Shape[]
        if (shapes.length === 2 && Array.isArray(shapes[0]) && Array.isArray(shapes[1])) {
            return computeBroadcastShape(shapes[0], shapes[1])
        }
        return super.computeOutputShape(inputShape)
    }

    call(inputs: tf.Tensor<tf.Rank>[]) {
        return tf.div(inputs[0], inputs[1])
    }
}

/**
 * Divides its (single) input by a fixed term: `input / constantTerm`.
 */
export class DivideConstantLayer extends tf.layers.Layer {
    static className = "DivideConstantLayer"
    /** Divisor applied to the input; a number or a tensor. Undefined means "not configured". */
    constantTerm?: tf.Tensor | number

    /**
     * @param config optional layer config. `name` defaults to "DivideConstant". `constantTerm`, when
     *               present, is converted with `tf.tensor`; a conversion failure is logged, not thrown.
     */
    constructor(config?: tf.serialization.ConfigDict) {
        super({name: (config?.name as string) ?? "DivideConstant"});

        // Compare against null/undefined instead of truthiness so a constant of 0 is not silently
        // treated as "no constant set".
        if (config?.constantTerm != null) {
            try {
                this.constantTerm = tf.tensor(config.constantTerm as any)
            } catch (err) {
                console.warn(err)
            }
        }
    }

    /**
     * Broadcast of the input shape with the constant's shape (a number or missing constant counts as
     * `[]`). Uses the same approach as MultiplyConstantLayer.
     */
    computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
        const constantShape = (typeof this.constantTerm === "number" ? [] : this.constantTerm?.shape ?? [])
        return computeBroadcastShape(inputShape as tf.Shape, constantShape)
    }

    // Override serialization format: store the constant as a plain number / nested array.
    getConfig(): tf.serialization.ConfigDict {
        const base = super.getConfig()
        if (typeof this.constantTerm === "number") {
            return Object.assign({constantTerm: this.constantTerm}, base)
        } else if (this.constantTerm !== undefined) {
            return Object.assign({constantTerm: this.constantTerm.arraySync()}, base)
        } else {
            return base
        }
    }

    /** Divides the first input by the constant; without a constant, warns and returns the input unchanged. */
    call(inputs: tf.Tensor<tf.Rank> | tf.Tensor<tf.Rank>[]) {
        if (inputs instanceof Array) {
            inputs = inputs[0]
        }
        if (this.constantTerm !== undefined) {
            return tf.div(inputs, this.constantTerm)
        } else {
            console.warn(`DivideConstant layer ${this.name} does not have a constantTerm set!`)
            return inputs
        }
    }
}

/**
 * `tf.dot` of two inputs (vectors and/or matrices).
 */
export class DotProductLayer extends tf.layers.Layer {
    static className = "DotProductLayer"

    /** @param config optional layer config; only `name` is read (defaults to "DotProduct"). */
    constructor(config?: tf.serialization.ConfigDict) {
        super({name: (config?.name as string) ?? "DotProduct"});
    }

    /**
     * `tf.dot` contracts the last axis of the first operand with the first axis of the second. The
     * output keeps the remaining axes: vector·vector gives a scalar `[]`, matrix·vector gives `[m]`,
     * vector·matrix gives `[n]` and matrix·matrix gives `[m, n]`. Returns `[1]` when not given two
     * shapes, as before.
     */
    computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
        const shapes = inputShape as tf.Shape[]
        if (shapes.length !== 2 || !Array.isArray(shapes[0]) || !Array.isArray(shapes[1])) {
            return [1]
        }
        const [shapeA, shapeB] = shapes
        return [...shapeA.slice(0, -1), ...shapeB.slice(1)]
    }

    call(inputs: tf.Tensor<tf.Rank>[]) {
        return tf.dot(inputs[0], inputs[1])
    }
}

/**
 * Index of the maximum along `axis` of a single input (via `tf.argMax`). When given an array, only its
 * first element is used.
 */
export class ArgmaxLayer extends tf.layers.Layer {
    static className = "Argmax"
    /** Axis to reduce; may be negative. Left undefined, `tf.argMax` uses its own default (axis 0). */
    axis?: number

    /** @param config optional layer config; reads `name` (default "ArgmaxLayer") and `axis`. */
    constructor(config?: tf.serialization.ConfigDict) {
        super({name: (config?.name as string) ?? "ArgmaxLayer"});
        this.axis = config?.axis as number
    }

    /**
     * Input shape with the reduced axis removed. A missing axis means axis 0, matching `tf.argMax`'s
     * default, so a rank-1 input yields a scalar shape `[]`. Given a list of shapes, the first one is
     * used, mirroring `call`.
     */
    computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
        const shape = (Array.isArray(inputShape[0]) ? inputShape[0] : inputShape) as tf.Shape
        const axis = this.axis ?? 0
        const normalizedAxis = axis < 0 ? shape.length + axis : axis
        return shape.filter((_, i) => i !== normalizedAxis)
    }

    // Override serialization format: persist the reduction axis.
    getConfig(): tf.serialization.ConfigDict {
        const base = super.getConfig()
        return Object.assign({axis: this.axis}, base)
    }

    call(inputs: tf.Tensor<tf.Rank> | tf.Tensor<tf.Rank>[]) {
        if (inputs instanceof Array) {
            inputs = inputs[0]
        }
        return tf.argMax(inputs, this.axis)
    }
}