import * as tf from "@tensorflow/tfjs"

/**
 * Merges two aligned dimension sizes using NumPy-style broadcasting rules,
 * extended with `null` as an unknown size.
 *
 * - If either size is 1, the other size is returned (which may be `null`).
 * - Otherwise, if either size is `null`, the result is `null`.
 * - Otherwise the two sizes must be equal.
 *
 * @param a - Size from the first shape; `null` means unknown.
 * @param b - Size from the second shape; `null` means unknown.
 * @param context - Optional suffix appended to the error message, e.g. to
 *   name the shapes being broadcast.
 * @returns The broadcast size for this axis.
 * @throws Error whose message starts with "Shape mismatch" when both sizes are
 *   known, differ, and neither is 1.
 */
function mergedSize(a: number | null, b: number | null, context = ""): number | null {
    if (a === 1) {
        return b
    } else if (b === 1) {
        return a
    } else if (a === null || b === null) {
        // An unknown size cannot be checked, so the merged size is unknown too.
        return null
    } else if (a === b) {
        return a
    } else {
        throw new Error(`Shape mismatch: ${a} vs ${b}${context}`)
    }
}

/**
 * Reads `shape[index]`, treating indices that fall before the start of the
 * shape (absent leading axes) as size 1, which broadcasts to any size.
 */
function alignedDim(shape: readonly (number | null)[], index: number): number | null {
    // Negative indices read as `undefined` on a JS array.
    const dim = shape[index]
    return dim === undefined ? 1 : dim
}

/**
 * Computes the shape that results from broadcasting `shapeA` against `shapeB`.
 *
 * The shapes are aligned at their trailing axes. Missing leading axes of the
 * shorter shape behave like size 1. Any axis where either input is `null`
 * (unknown) is `null` in the output. Neither input array is mutated.
 *
 * @param shapeA - First shape; `null` entries mark unknown sizes.
 * @param shapeB - Second shape; `null` entries mark unknown sizes.
 * @returns A new array of length `max(shapeA.length, shapeB.length)`.
 * @throws Error whose message starts with "Shape mismatch" if some aligned
 *   pair of known sizes differs and neither size is 1. The message names the
 *   conflicting sizes and both input shapes.
 */
export function computeBroadcastShape(
    shapeA: readonly (number | null)[],
    shapeB: readonly (number | null)[],
): (number | null)[] {
    const rank = Math.max(shapeA.length, shapeB.length)
    // Shift each shape right so that its trailing axis lines up with `rank - 1`.
    const offsetA = rank - shapeA.length
    const offsetB = rank - shapeB.length
    // Built once, not per axis; only used if an axis turns out to be incompatible.
    const context = ` when broadcasting ${JSON.stringify(shapeA)} with ${JSON.stringify(shapeB)}`
    return Array.from({ length: rank }, (_, axis) =>
        mergedSize(alignedDim(shapeA, axis - offsetA), alignedDim(shapeB, axis - offsetB), context),
    )
}