import { JSX } from "react/jsx-runtime";
import LessonTemplate, { LessonTemplateProps } from "../../../../Components/LessonTemplate/LessonTemplate";
import { Fragment, useRef, useState } from "react";
import MarkdownTextView from "../../../../Components/MarkdownTextView/MarkdownTextView";
import Stack from "@mui/material/Stack"
import MultipleChoiceQuestion from "../../../../Components/MultipleChoiceQuestion/MultipleChoiceQuestion";
import { round } from "../../../Utils";
import { useBoolean } from "../../../../use-boolean";
import Check from "@mui/icons-material/CheckCircle"

// Nine random sample values, each passed through `round(…, 2)`. Generated once at
// module load so every example on the page is fed the same input sequence.
const numberList1 = Array.from({ length: 9 }, () => round(Math.random(), 2))

// Small green check icon shown next to an example's completion message.
const checkmark = (
    <Check
        htmlColor="#40a845"
        fontSize="small"
        sx={{ marginLeft: "2px" }}
    />
)

/**
 * Interactive 1D convolution demo.
 *
 * Renders the (optionally zero-padded) input row, a draggable three-tap kernel
 * with editable weights, the dot product at the active position, and an output
 * row whose cells are revealed one by one as the kernel visits each position.
 * Hovering an already revealed output cell highlights the three input cells it
 * was computed from and shows that cell's dot product instead.
 *
 * @param props.input    Values to convolve over.
 * @param props.padLeft  Number of zero cells prepended to the input.
 * @param props.padRight Number of zero cells appended to the input.
 * @param props.message  Optional banner shown once every output cell is revealed.
 */
function OneDimensionalExample(props: { input: number[], padLeft: number, padRight: number, message?: string }) {
    const numberList: number[] = [...new Array<number>(props.padLeft).fill(0), ...props.input, ...new Array<number>(props.padRight).fill(0)]
    // A three-wide kernel fits at (length - 2) positions.
    const outputLength = numberList.length - 2

    const containerRef = useRef<HTMLTableCellElement | null>(null)
    const sliderRef = useRef<HTMLTableElement | null>(null)
    const [startX, setStartX] = useState(0)
    const [offsetX, setOffsetX] = useState(0)
    const hasDragged = useBoolean(false)
    const isDragging = useBoolean(false)

    const [value1, setValue1] = useState(-1)
    const [value2, setValue2] = useState(0)
    const [value3, setValue3] = useState(1)
    // Bit i is set once output cell i has been revealed; position 0 starts revealed.
    const [outputBitmask, setOutputBitmask] = useState(0b1)
    const [hoveredIndex, setHoveredIndex] = useState<number | undefined>(undefined)

    // Index of the input cell under the kernel's first tap, derived from the pixel offset.
    // NOTE(review): 57 looks like the horizontal pitch of one cell in px — confirm against the layout.
    const leftCell = Math.min(Math.round((offsetX + 2) / 57), numberList.length - 3)
    // Position whose dot product is displayed: a hovered output cell wins over the kernel position.
    const activeCell = hoveredIndex ?? leftCell

    const isPadCell = (i: number) => i < props.padLeft || i >= numberList.length - props.padRight
    const isUnderKernel = (i: number) => i >= leftCell && i <= leftCell + 2
    const isRevealed = (i: number) => ((1 << i) & outputBitmask) !== 0

    // Reveal the output cell at the kernel's position. This is a guarded state update
    // during render (it only fires when the bit is not yet set), kept out of
    // getConvOutput so that function stays free of side effects.
    if (hoveredIndex === undefined && !isRevealed(leftCell)) {
        setOutputBitmask(outputBitmask | (1 << leftCell))
    }

    /** Dot product of the three kernel weights with the input window starting at `pos`. */
    function getConvOutput(pos: number) {
        return value1 * numberList[pos] + value2 * numberList[pos + 1] + value3 * numberList[pos + 2]
    }

    /** Background of an input cell: hover highlight first, then the kernel shading. */
    function inputCellBackground(i: number) {
        if (hoveredIndex !== undefined && i >= hoveredIndex && i <= hoveredIndex + 2) {
            return "#fffbd6"
        }
        if (!isUnderKernel(i)) {
            return "transparent"
        }
        return isPadCell(i) ? "#f5f5f5" : "#f0f0f0"
    }

    const rowLabelStyle = {fontSize: 15, textAlign: "right", paddingRight: "12px", verticalAlign: "middle"} as const

    // The three editable kernel weights, left to right, each with its own tint.
    const kernelTaps = [
        { value: value1, setValue: setValue1, tint: "#e4fff0" },
        { value: value2, setValue: setValue2, tint: "#e0eeff" },
        { value: value3, setValue: setValue3, tint: "#ece3ff" },
    ]

    // Parenthesised via activeCell: `hoveredIndex ?? leftCell + 1` would bind as
    // `hoveredIndex ?? (leftCell + 1)` and show the same operand three times while hovering.
    const dotProductTex = `\${\\color{#ABEBCF} ${value1}} \\times ${numberList[activeCell]} + {\\color{#AED5F6} ${value2}} \\times ${numberList[activeCell + 1]} + {\\color{#C8AEFA} ${value3}} \\times ${numberList[activeCell + 2]} = ${round(getConvOutput(activeCell), 4)}\$`

    return <div style={{overflowX: "auto"}}>
        <table style={{borderCollapse: "separate", borderSpacing: 2, gap: "10px", margin: "15px auto"}}>
            <thead>
                <tr>
                    <th style={{width: 75}} />
                    {numberList.map((_, i) => <th key={i} style={{textAlign: "center", fontWeight: 400, fontSize: 14, lineHeight: 1.5, color: "#707070", minWidth: "50px", maxWidth: "50px"}}>{isPadCell(i) ? "pad" : i - props.padLeft}</th>)}
                </tr>
            </thead>
            <tbody>
                <tr>
                    <td style={rowLabelStyle}>Input:</td>
                    {numberList.map((cell, i) => <td style={{
                        height: "52px",
                        boxSizing: "border-box",
                        border: `1px ${isPadCell(i) ? "dashed" : "solid"} #e9e9e9`,
                        textAlign: "center",
                        userSelect: 'none',
                        backgroundColor: inputCellBackground(i),
                        fontWeight: isUnderKernel(i) ? 600 : 500,
                        color: isUnderKernel(i) ? "black" : "#808080",
                        cursor: "pointer"
                    }} key={i}>{cell}</td>)}
                    <td style={{width: 40}} />
                </tr>
                <tr>
                    <td />
                    {numberList.map((_, i) => <td style={{
                        height: 20, textAlign: "center",
                        userSelect: 'none',
                        fontSize: 15,
                        color: "gray",
                    }} key={i}>{isUnderKernel(i) ? "×" : ""}</td>)}
                </tr>
                <tr>
                    <td style={rowLabelStyle}>Kernel:</td>
                    <td colSpan={numberList.length} style={{height: "50px"}} ref={containerRef}>
                        <table style={{ borderCollapse: "collapse", borderRadius: 4, transform: `translateX(${offsetX}px)`, backgroundColor: "#f7f7f7" }}
                            ref={sliderRef}
                        >
                            <tbody>
                                <tr>
                                    {kernelTaps.map((tap, i) => <td key={i} style={{height: 48, verticalAlign: "middle", textAlign: "center"}}>
                                        <input type="text" defaultValue={tap.value} onChange={e => tap.setValue(Number(e.target.value))} style={{width: "50px", lineHeight: "48px", textAlign: "center", border: "1px solid #f0f0f0", backgroundColor: tap.tint, fontSize: 15, outlineOffset: "-5px"}} />
                                    </td>)}
                                </tr>
                                <tr>
                                    <td colSpan={3} style={{height: "20px", cursor: "grab", backgroundImage: hasDragged.value ? "url('/assets/drag.png')" : "", backgroundRepeat: "no-repeat", backgroundPosition: "50%", backgroundSize: "15px 15px", opacity: 0.4, textAlign: "center", verticalAlign: "middle", fontSize: 12, color: "#303030", userSelect: "none"}}
                                    onMouseDown={(e) => {
                                        isDragging.onTrue()
                                        // Remember where inside the handle the pointer grabbed it.
                                        setStartX(e.pageX - offsetX)
                                        hasDragged.onTrue()
                                    }}
                                    onMouseUp={isDragging.onFalse}
                                    onDrag={e => {
                                        if (!isDragging.value) { return }
                                        // NOTE(review): browsers are known to emit a final `drag` event with
                                        // zeroed pointer coordinates just before `dragend`; skipping it keeps
                                        // the kernel from snapping back to the left edge. Verify in target browsers.
                                        if (e.clientX === 0 && e.clientY === 0) { return }
                                        const container = containerRef.current
                                        const slider = sliderRef.current
                                        if (!container || !slider) { return }
                                        // Keep the kernel inside the cell that hosts it.
                                        const maxOffsetX = container.clientWidth - slider.clientWidth
                                        setOffsetX(Math.max(0, Math.min(e.pageX - startX, maxOffsetX)))
                                    }}
                                    // A native drag ends with `dragend`, not `mouseup`, so clear the flag here too.
                                    onDragEnd={isDragging.onFalse}
                                    draggable
                                    onDragStart={e => {
                                        // Swap the default drag ghost for a 1x1 GIF so only the kernel itself appears to move.
                                        const img = new Image();
                                        img.src = 'data:image/gif;base64,R0lGODlhAQABAIAAAAUEBAAAACwAAAAAAQABAAACAkQBADs=';
                                        e.dataTransfer.setDragImage(img, 0, 0);
                                    }}>{hasDragged.value ? "" : "drag me!"}</td>
                                </tr>
                            </tbody>
                        </table>
                    </td>
                </tr>
                <tr style={{height: "80px"}}>
                    <td style={rowLabelStyle}>Dot Product:</td>
                    <td colSpan={numberList.length} style={{textAlign: "center"}}>
                        <MarkdownTextView rawText={dotProductTex}/>
                    </td>
                </tr>
                <tr>
                    <td style={rowLabelStyle}>Output:</td>
                    <td />
                    {Array.from({ length: outputLength }, (_, i) => <td style={{
                        height: 50, border: `1px ${i < props.padLeft || i >= outputLength - props.padRight ? "dashed" : "solid"} #e9e9e9`, textAlign: "center",
                        userSelect: 'none',
                        fontSize: 15,
                        backgroundColor: i === hoveredIndex ? "#fffbd6" : (i === leftCell ? "#f0f0f0" : "transparent"),
                        fontWeight: i === leftCell ? 600 : 500,
                        color: i === leftCell ? "black" : "#808080",
                        cursor: isRevealed(i) ? "default" : "not-allowed"
                    }} key={i} onMouseEnter={() => setHoveredIndex(isRevealed(i) ? i : undefined)} onMouseLeave={() => setHoveredIndex(undefined)}>{isRevealed(i) ? round(getConvOutput(i), 3) : "?"}</td>)}
                    <td style={{width: 40}} />
                    <td />
                </tr>
            </tbody>
        </table>
        {/* The banner appears only once every output bit is set. */}
        {props.message && outputBitmask === (1 << outputLength) - 1 && <Stack direction="row" alignItems="center" columnGap={1} justifyContent="center" py={2} my={2} borderRadius={2} mx={1} sx={{backgroundColor: "#f1fcf0"}}>
            {checkmark} {props.message}  
        </Stack>}
    </div>
}

/**
 * Static 1D convolution table: the input row, an editable three-tap kernel and
 * every output value at once (no dragging, no progressive reveal).
 *
 * Unlike the interactive example, `props.input` is used as-is; `padLeft` and
 * `padRight` only decide which cells are labelled "pad" and drawn with a
 * dashed border.
 *
 * @param props.input    Values shown in the input row and convolved with the kernel.
 * @param props.padLeft  Number of leading cells to mark as padding.
 * @param props.padRight Number of trailing cells to mark as padding.
 */
function Conv1dSimple(props: { input: number[], padLeft: number, padRight: number }) {
    const numberList = props.input
    // A three-wide kernel fits at (length - 2) positions.
    const outputLength = numberList.length - 2
    const [value1, setValue1] = useState(-1)
    const [value2, setValue2] = useState(0)
    const [value3, setValue3] = useState(1)

    const isInputPad = (i: number) => i < props.padLeft || i >= numberList.length - props.padRight
    const isOutputPad = (i: number) => i < props.padLeft || i >= outputLength - props.padRight

    /** Dot product of the three kernel weights with the input window starting at `pos`. */
    function getConvOutput(pos: number) {
        return value1 * numberList[pos] + value2 * numberList[pos + 1] + value3 * numberList[pos + 2]
    }

    const rowLabelStyle = {fontSize: 15, textAlign: "right", paddingRight: "12px", verticalAlign: "middle"} as const
    const indexLabelStyle = {textAlign: "center", fontWeight: 400, fontSize: 14, lineHeight: 1.5, color: "#707070", minWidth: "52px", maxWidth: "52px"} as const

    // The three editable kernel weights, left to right, each with its own tint.
    const kernelTaps = [
        { value: value1, setValue: setValue1, tint: "#e4fff0" },
        { value: value2, setValue: setValue2, tint: "#e0eeff" },
        { value: value3, setValue: setValue3, tint: "#ece3ff" },
    ]

    return <table style={{borderCollapse: "collapse", borderSpacing: 2, gap: "10px", margin: "15px auto"}}>
            <thead>
                <tr>
                    <th style={{width: 75}} />
                    {numberList.map((_, i) => <th key={i} style={indexLabelStyle}>{isInputPad(i) ? "pad" : i - props.padLeft}</th>)}
                </tr>
            </thead>
            <tbody>
                <tr>
                    <td style={rowLabelStyle}>Input:</td>
                    {numberList.map((cell, i) => <td style={{
                        height: "52px",
                        boxSizing: "border-box",
                        border: `1px ${isInputPad(i) ? "dashed" : "solid"} #e9e9e9`,
                        textAlign: "center",
                        userSelect: 'none',
                        color: "#808080",
                        cursor: "pointer"
                    }} key={i}>{cell}</td>)}
                    <td style={{width: 40}} />
                </tr>
                <tr>
                    <td style={rowLabelStyle}>Kernel:</td>
                    <td colSpan={numberList.length} style={{height: "50px"}}>
                        <table style={{ borderCollapse: "collapse", borderRadius: 4, backgroundColor: "#f7f7f7", margin: "12px auto" }}>
                            <tbody>
                                <tr>
                                    {kernelTaps.map((tap, i) => <td key={i} style={{height: 48, verticalAlign: "middle", textAlign: "center"}}>
                                        <input type="text" defaultValue={tap.value} onChange={e => tap.setValue(Number(e.target.value))} style={{width: "50px", lineHeight: "48px", textAlign: "center", border: "1px solid #f0f0f0", backgroundColor: tap.tint, fontSize: 15, outlineOffset: "-5px"}} />
                                    </td>)}
                                </tr>
                            </tbody>
                        </table>
                    </td>
                </tr>
                <tr>
                    <td colSpan={2} />
                    {/* Index labels for the output row: one per output cell, using the same pad rule as the output borders. */}
                    {Array.from({ length: outputLength }, (_, i) => <td key={i} style={indexLabelStyle}>{isOutputPad(i) ? "pad" : i - props.padLeft}</td>)}
                </tr>
                <tr>
                    <td style={rowLabelStyle}>Output:</td>
                    <td />
                    {Array.from({ length: outputLength }, (_, i) => <td style={{
                        height: 50, border: `1px ${isOutputPad(i) ? "dashed" : "solid"} #e9e9e9`, textAlign: "center",
                        userSelect: 'none',
                        fontSize: 15,
                        color: "#808080",
                    }} key={i}>{round(getConvOutput(i), 3)}</td>)}
                    <td style={{width: 40}} />
                    <td />
                </tr>
            </tbody>
        </table>
}

/**
 * Chapter 4 lesson, "Introduction to Convolutions".
 *
 * Supplies three pages to `LessonTemplate`: an interactive introduction to 1D
 * convolution, a page on output length and padding, and a link to the first
 * exercise.
 */
class C4_Convolutions extends LessonTemplate {
    constructor(props: LessonTemplateProps) {
        // The numeric argument matches the three pages returned by getPageData.
        super(props, 3, "Introduction to Convolutions")
    }

    /**
     * Builds the content of one lesson page.
     *
     * @param index Zero-based page index; 0 and 1 are lesson pages, anything else yields the exercise page.
     * @returns The JSX for the requested page.
     */
    getPageData(index: number): JSX.Element {
        if (index === 0) {
            return <Fragment>
                <MarkdownTextView rawText={"### What is Convolution?\n*Convolution* is perhaps the single most powerful algorithm in computer vision, powering most of the AI technology that deals with images, from face recognition to AI art. Don't be intimidated by the word though. Convolution is just a fancy word for *sliding window*. In the following examples, we will explore how it works."} />

                <MarkdownTextView rawText={`### 1D Convolution\n\nConvolutions work in any number of dimensions, but let's start with 1D to keep things simple. Suppose we have a sequence of numbers representing a student's historical grades in a class:

$$
\\begin{bmatrix}
${numberList1.join(" & ")}
\\end{bmatrix}
$$

Convolution is like ironing clothes. To *convolve* the input data, we need to slide over it using a **kernel**, just like how we run the iron over the clothes. Try sliding the kernel below over the input data:`}/>

                <OneDimensionalExample input={numberList1} padLeft={0} padRight={0} message={"Congrats for making your first 1D convolution!"} />

                <MarkdownTextView rawText={"### Understanding the Output\nLet's try to understand what's going on. The kernel is a special vector that we align with the input at every position. For each alignment, we calculate the kernel's dot product with the corresponding segment of the input (shaded). Then we output all the dot products as an ordered sequence. Sometimes, the output sequence is called *activations* because each output value indicates how much a kernel has *activated* at that position. \n\nYou might wonder what is the point of this. It turns out that the activations tell us how much the kernel's pattern is manifested at each location. For example, the kernel $[-1, 0, 1]$ can be used to search for steady increases in the input (i.e. the student's grade). The output value with maximal activation corresponds to the three consecutive assignments when the student has demonstrated the most growth. "} />

                <MultipleChoiceQuestion prompt="Which kernel is most likely to reveal sharp dips in the student's performance?" options={["$[1, -1, 1]$", "$[1, 0, -1]$", "$[0.33, 0.33, 0.33]$", `$[-1, 1, -1]$`]} correctIndex={0} explanation="$[1, -1, 1]$ is able to pickup dips in the student's performance because it is the only kernel whose middle value is the smallest, thus resembling a dip the most." />

            </Fragment>
        } else if (index === 1) {
            return <Fragment>
                <MarkdownTextView rawText={"### Output Shape\nIn the previous example, you may have noticed that the output length could be smaller than the input length, depending on the kernel size. In particular, after the convolution, the length shrank by 2."} />

                <Conv1dSimple input={numberList1} padLeft={0} padRight={0} />

                <MultipleChoiceQuestion prompt="Suppose the kernel length is now 5. Using the method above, how long should the output be?" correctIndex={0} options={["5", "6", "7", "8"]} explanation="If you observe the convolution carefully, you will realize that the length of the output is $n - k + 1$, where $n$ is the input length and $k$ is the kernel length. Since the input length is 9, 9 - 5 + 1 = 5." />

                <MarkdownTextView rawText={"### Padding\nA problem with convolutions is that it's hard to keep track of how long the output is, especially if there are multiple layers. To address this problem, we could pad additional zeros around the input so that the output length remains the same."} />

                <OneDimensionalExample input={numberList1.slice()} padLeft={1} padRight={1} message="You completed a padded 1D convolution." />

                <MarkdownTextView rawText="In AI Playground, these two padding approaches are called `valid` and `same`. `valid` means no padding, `same` means pad the input so that the output length is the same." />
            </Fragment>
        } else {
            // Every remaining index falls through to the exercise page, so no return is needed after this chain.
            return <Fragment>
                <MarkdownTextView rawText={"#### Human Activity Recognition Using Smartphones\nThe first problem we will work on is a [dataset](https://archive.ics.uci.edu/dataset/240/human+activity+recognition+using+smartphones) of sensor data (accelerometer + gyroscope) collected from humans performing 6 activities. Each sample consists of 128 readings spanned over 2-3 seconds. Each reading has 6 values: the XYZ values of the accelerometer, and the XYZ values of the gyroscope."} />
                <a href="/chapters/4/exercises/1" style={{textDecoration: "none"}}>
                    <button className="next-button">Open Exercise</button>
                </a>
            </Fragment>
        }
    }
}

// Expose the lesson class as this module's default export.
export default C4_Convolutions;