import { JSX } from "react/jsx-runtime";
import LessonTemplate, { LessonTemplateProps } from "../../../../Components/LessonTemplate/LessonTemplate";
import { Fragment, useEffect, useRef, useState } from "react";
import MarkdownTextView from "../../../../Components/MarkdownTextView/MarkdownTextView";
import Stack from "@mui/material/Stack"
import MultipleChoiceQuestion from "../../../../Components/MultipleChoiceQuestion/MultipleChoiceQuestion";
import { round } from "../../../Utils";
import { useBoolean } from "../../../../use-boolean";
import Check from "@mui/icons-material/CheckCircle"
import { mnistData } from "../../../../Datasets/MNIST_Dataset";
import { MNIST_Image } from "../../../Interfaces";
import BinaryMatrixImage from "../../Chapter3/BinaryMatrixImage";
import "./Convolutions_2D.css"
import { colorInterpolate } from "../../../../Utils";

/** Small green check-circle icon with a 2px left margin. */
const checkmark = (
    <Check
        htmlColor="#40a845"
        fontSize="small"
        sx={{ marginLeft: "2px" }}
    />
)

/**
 * 3x3 kernels offered by the demos in this file.
 * Index 0 is the Sobel-style horizontal-gradient kernel, index 1 averages the
 * 3x3 neighbourhood (every weight is 1/9).
 */
const availableFilters = [
    [
        [1, 0, -1],
        [2, 0, -2],
        [1, 0, -1],
    ],
    Array.from({ length: 3 }, () => [1 / 9, 1 / 9, 1 / 9]),
]

/**
 * Sixteen random values, each passed through `round(…, 2)`, generated once at
 * module load. `SimpleConv2D1` reads them as a row-major 4x4 grid.
 */
const numberList1 = Array.from({ length: 16 }, () => round(Math.random(), 2))

/** Inline style shared by every cell of the three tables in `SimpleConv2D1`. */
const conv2dDemoCellStyle = {
    fontSize: 15,
    verticalAlign: "middle",
    textAlign: "center",
    width: 50,
    height: 50,
    border: "1px solid #e9e9e9",
} as const

/** Inline style shared by the input, kernel and output tables in `SimpleConv2D1`. */
const conv2dDemoTableStyle = {
    borderCollapse: "collapse",
    borderSpacing: 2,
    gap: "10px",
    margin: "15px auto",
} as const

/**
 * Slides `availableFilters[0]` (3x3) over the 4x4 grid stored row-major in
 * `numberList1` and returns, for each of the 2x2 output positions, the rounded
 * sum of products and the LaTeX source (rows of an `aligned` environment)
 * that spells the sum out term by term.
 */
function computeConv2dDemo(): { values: number[][], formulas: string[][] } {
    const kernel = availableFilters[0]
    const values: number[][] = []
    const formulas: string[][] = []

    for (let i = 0; i < 2; i++) {
        const valueRow: number[] = []
        const formulaRow: string[] = []
        for (let j = 0; j < 2; j++) {
            const terms: string[] = []
            let sum = 0
            for (let row = 0; row < 3; row++) {
                for (let column = 0; column < 3; column++) {
                    const input = numberList1[(i + row) * 4 + j + column]
                    const weight = kernel[row][column]
                    // The very first term gets an invisible "+" so all terms line up.
                    const sign = row === 0 && column === 0 ? '\\phantom{+}' : '+'
                    let term = `${sign} ${input.toFixed(2)} \\times {\\color{#a056ef} ${weight}}`
                    if (column === 2) {
                        // End of a kernel row: break the LaTeX line.
                        term += "\\\\"
                    } else if (column === 0) {
                        // Start of a kernel row: alignment point.
                        term = "&" + term
                    }
                    terms.push(term)
                    sum += input * weight
                }
            }
            const value = round(sum, 2)
            valueRow.push(value)
            formulaRow.push(terms.join(" ") + ` &= ${value}`)
        }
        values.push(valueRow)
        formulas.push(formulaRow)
    }

    return { values, formulas }
}

// Both inputs are module-level constants, so the result never changes between
// renders; compute it once instead of on every render.
const conv2dDemo = computeConv2dDemo()

/** True when cell (row, column) lies inside the 3x3 window whose top-left corner is (windowTop, windowLeft). */
function isInsideConv2dWindow(row: number, column: number, windowTop: number, windowLeft: number): boolean {
    return row >= windowTop && row < windowTop + 3 && column >= windowLeft && column < windowLeft + 3
}

/**
 * Animated walkthrough of a 2D convolution on a 4x4 input with a 3x3 kernel.
 *
 * Every two seconds the highlighted window advances to the next of the four
 * output positions. Hovering an output cell highlights that cell's window
 * instead (in yellow) and shows its formula below the tables.
 */
function SimpleConv2D1() {
    const [counter, setCounter] = useState(0)
    const [hoverIndex, setHoverIndex] = useState<number | undefined>(undefined)

    useEffect(() => {
        const timer = setTimeout(() => setCounter(current => (current + 1) % 4), 2000)
        // Cancel the pending tick on unmount / re-run so no timer outlives the component.
        return () => clearTimeout(timer)
    }, [counter])

    // Output index -> (row, column) of the 2x2 output, which is also the
    // top-left corner of the corresponding window in the input.
    const top = Math.floor(counter / 2)
    const left = counter % 2
    const hoverTop = hoverIndex === undefined ? undefined : Math.floor(hoverIndex / 2)
    const hoverLeft = hoverIndex === undefined ? undefined : hoverIndex % 2

    // The hovered window takes precedence over the animated one.
    const inputCellColor = (row: number, column: number): string => {
        if (hoverTop !== undefined && hoverLeft !== undefined && isInsideConv2dWindow(row, column, hoverTop, hoverLeft)) {
            return "#fffbd6"
        }
        return isInsideConv2dWindow(row, column, top, left) ? "#f0f0f0" : ""
    }

    const outputCellColor = (outputIndex: number): string => {
        if (hoverIndex === outputIndex) {
            return "#fffbd6"
        }
        return counter === outputIndex ? "#f0f0f0" : ""
    }

    const activeFormula = conv2dDemo.formulas[hoverTop ?? top][hoverLeft ?? left]

    return <div>

        <div style={{display: "flex", justifyContent: "space-evenly", alignItems: "center", margin: "10px 0px"}}>
            {/* input */}
            <table style={conv2dDemoTableStyle}>
                <tbody>
                    {[0, 1, 2, 3].map(i => (
                        <tr key={i}>
                            {numberList1.slice(i * 4, (i + 1) * 4).map((v, j) => (
                                <td key={j} style={{...conv2dDemoCellStyle, backgroundColor: inputCellColor(i, j)}}>{v}</td>
                            ))}
                        </tr>
                    ))}
                </tbody>
            </table>

            {/* kernel */}
            <table style={conv2dDemoTableStyle}>
                <tbody>
                    {availableFilters[0].map((row, i) => (
                        <tr key={i}>
                            {row.map((v, j) => (
                                <td key={j} style={{...conv2dDemoCellStyle, backgroundColor: "#f5f3ff"}}>{round(v, 2)}</td>
                            ))}
                        </tr>
                    ))}
                </tbody>
            </table>

            {/* output */}
            <table style={conv2dDemoTableStyle}>
                <tbody>
                    {[0, 1].map(i => (
                        <tr key={i}>
                            {[0, 1].map(j => (
                                <td
                                    key={j}
                                    style={{...conv2dDemoCellStyle, backgroundColor: outputCellColor(i * 2 + j)}}
                                    onMouseEnter={() => setHoverIndex(i * 2 + j)}
                                    onMouseLeave={() => setHoverIndex(undefined)}
                                >{conv2dDemo.values[i][j]}</td>
                            ))}
                        </tr>
                    ))}
                </tbody>
            </table>
        </div>
        <div style={{fontSize: 15}}>
            <h3>Formula for 2D Convolution:</h3>
            <MarkdownTextView rawText={`\$\$\n\\begin{aligned}\n${activeFormula}\n\\end{aligned}\n\$\$`}/>
        </div>
    </div>
}

/** Width and height, in pixels, of the image grid handled by `Conv2dVisualization`. */
const conv2dImageSize = 28

/** Fresh "nothing convolved yet" state: one `undefined` slot per image pixel. */
function emptyConvolvedValues(): (number | undefined)[] {
    return Array(conv2dImageSize * conv2dImageSize).fill(undefined)
}

/**
 * Interactive demo: loads one MNIST image and lets the user drag one of
 * `availableFilters` onto it. `BinaryMatrixImage` reports cells through
 * `onCellConvolved`; the convolved values for the cells around the reported
 * position are then computed lazily and cached until "Reset Image" is pressed.
 */
function Conv2dVisualization(props: { }) {
    const [imageData, setImageData] = useState<MNIST_Image | undefined>()
    const [filter, setFilter] = useState<number[][] | undefined>(undefined)
    // Lazy initializer: build the 784-slot array once, not on every render.
    const [convolvedValue, setConvolvedValue] = useState<(number | undefined)[]>(emptyConvolvedValues)
    const [dragOffset, setDragOffset] = useState<{x: number, y: number} | undefined>()

    useEffect(() => {
        let cancelled = false
        mnistData(1)
            .then(value => {
                // Ignore the result if the component unmounted while loading.
                if (!cancelled) { setImageData(value[0]) }
            })
            .catch((error: unknown) => console.error("Failed to load MNIST image", error))
        return () => { cancelled = true }
    }, [])

    /**
     * Fills in the convolved value of every not-yet-computed cell in the
     * kernel-sized neighbourhood associated with (_x, _y).
     *
     * Each (x, y) visited is the top-left corner of a kernel window; the result
     * is stored at the window's anchor cell (x + floor((w-1)/2), y + floor((h-1)/2)).
     * Window pixels that fall outside the image contribute 0. Anchors outside
     * the image are skipped. Does nothing while no kernel is being dragged or
     * before the image has loaded.
     */
    function recalculatePixel(_x: number, _y: number) {
        if (!filter || filter.length === 0 || !imageData) { return }

        const kernel = filter
        const pixels = imageData.image
        const kernelHeight = kernel.length
        const kernelWidth = kernel[0].length
        const firstX = _x - Math.round((kernelWidth - 1) / 2)
        const firstY = _y - Math.round((kernelHeight - 1) / 2)
        const anchorX = Math.floor((kernelWidth - 1) / 2)
        const anchorY = Math.floor((kernelHeight - 1) / 2)

        // Functional update on a copy: never mutate the array React holds, and
        // always build on the latest state even if several cells are reported
        // before the next render.
        setConvolvedValue(previous => {
            const next = previous.slice()
            let changed = false

            for (let x = firstX; x < firstX + kernelWidth; x++) {
                const targetX = x + anchorX
                // Skip anchors left/right of the image instead of letting the
                // flat index wrap into a neighbouring row.
                if (targetX < 0 || targetX >= conv2dImageSize) { continue }

                for (let y = firstY; y < firstY + kernelHeight; y++) {
                    const targetY = y + anchorY
                    // Skip anchors above/below the image so the array never grows.
                    if (targetY < 0 || targetY >= conv2dImageSize) { continue }

                    const lookupIndex = targetY * conv2dImageSize + targetX
                    if (next[lookupIndex] !== undefined) { continue } // already computed

                    let sum = 0
                    for (let row = 0; row < kernelHeight; row++) {
                        for (let col = 0; col < kernelWidth; col++) {
                            const pixelX = x + col
                            const pixelY = y + row
                            if (pixelX >= 0 && pixelX < conv2dImageSize && pixelY >= 0 && pixelY < conv2dImageSize) {
                                // Pixel values are scaled by 1/255 before weighting.
                                sum += (pixels[pixelY * conv2dImageSize + pixelX] ?? 0) / 255 * kernel[row][col]
                            }
                        }
                    }
                    next[lookupIndex] = sum
                    changed = true
                }
            }

            // Returning the same reference lets React skip a needless re-render.
            return changed ? next : previous
        })
    }

    return <>
        <BinaryMatrixImage imageWidth={28} pixelSize={24} matrix={imageData?.image} showPixelValue convolvedValues={convolvedValue} onCellConvolved={(x,y) => {
            recalculatePixel(x, y)
        }} offset={dragOffset} activeFilter={filter} />
        <div style={{height: 10}} />
        <div style={{margin: "12px", fontSize: 15, display: "flex", justifyContent: "space-between"}}>
            Select a kernel and drag it into the image:
            <button onClick={() => setConvolvedValue(emptyConvolvedValues())}>Reset Image</button>
        </div>
        <div style={{backgroundColor: "#f0f0f0", borderRadius: 8, padding: 15, display: "flex", columnGap: 10, alignItems: "center", margin: "5px 12px", justifyContent: "center"}}>
            {availableFilters.map((kernel, i) => {
                return <table key={i} className="filter-table generic-table" draggable onDragStart={e => {
                    // Record where inside the kernel table the pointer grabbed it;
                    // passed to BinaryMatrixImage as `offset` (presumably to align
                    // the kernel with the grid while dragging — confirm there).
                    const rect = e.currentTarget.getBoundingClientRect()
                    setDragOffset({
                        x: (e.clientX - rect.left),
                        y: (e.clientY - rect.top)
                    })
                    setFilter(kernel)
                    // No custom drag image is set; the browser's default drag preview is used.
                }} onDragEnd={() => {
                    setFilter(undefined)
                }}
                >
                    <tbody>
                    {kernel.map((row, rowIndex) => (
                        <tr style={{textAlign: "center"}} key={rowIndex}>
                            {row.map((value, colIndex) => <td key={colIndex}>{round(value, 2)}</td>)}
                        </tr>
                    ))}
                    </tbody>
                </table>
            })}
        </div>
    </>
}

/**
 * Lesson "2D Convolutions": an introductory text followed by the animated
 * `SimpleConv2D1` walkthrough and the drag-and-drop `Conv2dVisualization`.
 */
class C4_Convolutions_2D extends LessonTemplate {

    /**
     * @param props Props forwarded unchanged to `LessonTemplate`.
     */
    constructor(props: LessonTemplateProps) {
        // NOTE(review): the second argument is presumably the number of pages — confirm against LessonTemplate.
        super(props, 3, "2D Convolutions")
    }

    /**
     * Returns the content of one lesson page.
     *
     * @param index Zero-based page index.
     * @returns The page content; only page 0 has content so far, every other
     *          index yields an empty fragment.
     */
    getPageData(index: number): JSX.Element {
        if (index === 0) {
            return <Fragment>
                <MarkdownTextView rawText={"### 2D Convolution\nConvolutions are useful for sequential data like sensor readings, but what really makes them powerful is that they are extremely good at 2D data as well, such as images."} />
                <SimpleConv2D1 />
                <Conv2dVisualization />
            </Fragment>
        }

        // Remaining pages are still empty placeholders.
        return <Fragment />
    }
}

export default C4_Convolutions_2D;