import { Container } from "@mui/material";
import { round, randomInt, number } from "mathjs";
import React, { Fragment, LegacyRef, RefObject, useEffect, useRef, useState } from "react"

import MultipleChoiceQuestion from "../../../../Components/MultipleChoiceQuestion/MultipleChoiceQuestion";
import LessonTemplate, { LessonTemplateProps } from "../../../../Components/LessonTemplate/LessonTemplate";
import { JSX } from "react/jsx-runtime";
import MarkdownTextView from "../../../../Components/MarkdownTextView/MarkdownTextView";
import Stack from "@mui/material/Stack"
import BinaryMatrixImage from "../BinaryMatrixImage";
import { MNIST_Image } from "../../../Interfaces";
import { mnistData } from "../../../../Datasets/MNIST_Dataset";
import Check from "@mui/icons-material/CheckCircle";
import Cross from "@mui/icons-material/Cancel"

// Lesson state shape holding a `currentPage` index.
// NOTE(review): not referenced anywhere in this file — confirm it is still needed.
type State = {
    currentPage: number
}

// Inline feedback icons for the argmax quiz (green check / red cross), sharing one small left margin.
const feedbackIconSx = { marginLeft: "2px" }
const checkmark = <Check htmlColor="#40a845" fontSize="small" sx={feedbackIconSx} />
const cross = <Cross htmlColor="#c0605a" fontSize="small" sx={feedbackIconSx} />

/** Ten random scores in [-1, 1], rounded to two decimals; drawn once when this module is evaluated. */
const numberList = Array.from({ length: 10 }, () => round(Math.random() * 2 - 1, 2))

/**
 * Interactive argmax quiz. Renders the scores in `numberList` under their
 * digit labels (0–9). Clicking a score highlights it green if it holds the
 * maximum score and red otherwise, and shows a feedback message underneath.
 */
function ArgmaxTable() {
    const [picked, setPicked] = useState<number | undefined>()
    const max = Math.max(...numberList)
    // First index holding the maximum — the digit an argmax would report.
    const argmax = numberList.indexOf(max)
    const pickedIsMax = picked !== undefined && numberList[picked] === max

    return <table style={{borderCollapse: "collapse", margin: "15px auto"}}>
        <tbody>
            <tr>
                <td></td>
                {/* React requires a `key` on every element produced by map(); indices are stable here. */}
                {numberList.map((_, i) => <td key={i} style={{textAlign: "center", fontWeight: 400, fontSize: 14, lineHeight: 1.5, color: "#707070"}}>{i}</td>)}
            </tr>
            <tr>
                <td style={{width: 60, fontSize: 15}}>Score:</td>
                {numberList.map((score, i) => <td key={i} style={{
                    width: 50, height: 50, border: '1px solid #e9e9e9', textAlign: "center",
                    userSelect: 'none',
                    backgroundColor: picked === i ? (score === max ? "#ebffeb" : "#ffd5d5") : "transparent",
                    cursor: "pointer"
                }} onClick={() => setPicked(i)}>{score}</td>)}
                <td style={{width: 40}} />
            </tr>
            <tr>
                {picked !== undefined && <td colSpan={12} style={{textAlign: "center", paddingTop: "5px"}}>
                    <div className="mcq" style={{display: "flex", flexDirection: "row", columnGap: "4px"}}>
                        {pickedIsMax ? checkmark : cross}
                        {pickedIsMax ? `You got it! In this case, the model predicted ${argmax} because it has the highest score (${max}).` : "Try again!"}
                    </div>
                </td>}
            </tr>
        </tbody>
    </table>
}

/**
 * Chapter 3 lesson "Handwritten Digits". It introduces:
 * - the MNIST dataset;
 * - how images are represented as pixel values;
 * - score-based classification;
 * - the Linear, Argmax and Trainer playground blocks.
 */
class C3_MNIST extends LessonTemplate {

    /** MNIST images returned by `mnistData(1)`; empty until `loadImages` resolves. */
    examples: MNIST_Image[] = []
    // NOTE(review): not read anywhere in this file — confirm whether LessonTemplate or other code uses it.
    currentError = Number.POSITIVE_INFINITY

    constructor(props: LessonTemplateProps) {
        // 3 and "Handwritten Digits" are presumably the chapter number and lesson title — confirm against LessonTemplate.
        super(props, 3, "Handwritten Digits")
        // NOTE(review): async work started from the constructor can call forceUpdate on an unmounted
        // component; consider moving it to componentDidMount once LessonTemplate's lifecycle usage is known.
        this.loadImages()
    }

    /**
     * Loads MNIST images, stores them in `examples`, and forces a re-render.
     * A failed load is logged instead of surfacing as an unhandled promise
     * rejection. In that case `examples` stays empty.
     */
    loadImages() {
        mnistData(1)
            .then(images => {
                this.examples = images
                this.forceUpdate()
            })
            .catch((error: unknown) => {
                console.error("Failed to load MNIST images:", error)
            })
    }

    /**
     * Builds the content of lesson page `index`. Pages 0–4 are defined; any
     * other index yields an empty fragment.
     * NOTE(review): page 3 is currently empty, and pages 2 and 4 both link to exercise 1.
     */
    getPageData(index: number): JSX.Element {
        if (index === 0) {
            return <Fragment>
                <MarkdownTextView rawText={"We are now finally ready to move on to images! To start off, we will work with the [MNIST Handwritten Digit Database](http://yann.lecun.com/exdb/mnist/) to build our first model for handwritten digit recognition. Given an image, it will predict what digit it contains.\n\n#### How are images represented on a computer?\nFor this task, we will use grayscale images. Each image contains 28x28 pixels whose values range from 0 to 255: 0 means completely black, and 255 means completely white.\nHere is an image from the MNIST dataset, with all the pixels labeled."} />
                <Stack direction="row" justifyContent="center" my={3}>
                    <BinaryMatrixImage imageWidth={28} pixelSize={21} matrix={this.examples[0]?.image} showPixelValue />
                </Stack>
                {/* correctIndex 5 assumes this.examples[0] depicts a 5 — TODO confirm against the dataset order. */}
                <MultipleChoiceQuestion prompt="Can you tell what number is in this image?" correctIndex={5} options={["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]} explanation="Correct!" />
            </Fragment>
        } else if (index === 1) {
            return <Fragment>
                <MarkdownTextView rawText={"We, as humans, are very good at recognizing symbols. But it is challenging for a computer to do the same. For the computer, each image is just a sequence (or vector) of $28 \\times 28 = 784$ numbers, and it needs to use only this information to classify it into one of the 10 categories."} />

                <MarkdownTextView rawText={"But mathematically, how does a computer do classification? In the case of handwritten digit recognition, how do we force the model to output one of the 10 possible digits 0-9?\n\nIn machine learning, a model performs classification by calculating a score for each category the input could belong to. Then the category with the *highest* score is treated as the model's prediction. It's OK for some categories' scores to be negative, because we just care about locating the biggest value.\n\nSuppose a handwritten digit recognition model outputs the following scores. Which digit is the model predicting? **Click on a score to check your answer**."} />
                <ArgmaxTable />

                <MarkdownTextView rawText="In general, if we want to build a model that can do classification for $N$ classes, we need to make sure that it outputs **exactly** $N$ numbers, representing the raw scores of each class."/>

            </Fragment>
        } else if (index === 2) {
            return <Fragment>
                {/* TODO: higher resolution */}
                <MarkdownTextView rawText={"### New Blocks for Neural Networks\nWhen it comes to neural networks, we use a new set of blocks. Here are the three most important blocks to get you started."} />
                <MarkdownTextView rawText={"#### Linear Block\nLinear blocks are your best friend moving forward. They are highly versatile operations that can take inputs of any length and transform them to another length of your choice. A Linear block is described by $(n \\times m)$, where $n$ is the input length and $m$ is the output length. For example, a Linear block of shape $(3 \\times 4)$ can help us transform a tensor of shape $(k \\times 3)$ to $(k \\times 4)$ for an arbitrary $k$."} />
                <img src="/assets/chapter3/linear.png" alt="Linear block" width={550} className="centered" />
                <MarkdownTextView rawText="In textbooks and other places, Linear blocks are also called *Dense*, *Fully Connected*, and *FeedForward* layers. **They are the same**. The fact that Linear blocks have so many nicknames shows their central position in the kingdom of machine learning." />
                <MarkdownTextView rawText={"#### Argmax Block\nGiven a list of $N$ numbers representing the scores of the $N$ classes, how can we find the most likely class? We obtain the position with the maximum value using the *argmax* operation. The axis tells us which dimension the argmax is calculated for. -1 just means the last dimension."}/>
                <img src="/assets/chapter3/argmax.png" alt="Argmax block example" width={500} className="centered" />

                <MarkdownTextView rawText="In the above case, -1 means the second dimension, which is the horizontal axis. Hence, the output $[1,2]$ means that the maximal value in the first row is at index 1 (second from left, namely 2), and the maximal value in the second row is at index 2 (third from left, namely 5)." />

                <MarkdownTextView rawText={"#### Trainer Block\nThe trainer is what makes training happen. To train your model, point the output to the trainer and click the run button. The training configuration can be customized in the sidebar."} />
                <img src="/assets/chapter3/trainer.png" alt="Trainer block" width={500} className="centered" />
                <MarkdownTextView rawText={"That's it! Are you ready to build your first neural network?"} />

                <a href="/chapters/3/exercises/1" style={{textDecoration: "none"}}>
                    <button className="next-button">Open Exercise</button>
                </a>
            </Fragment>
        } else if (index === 3) {
            return <Fragment>
            </Fragment>
        } else if (index === 4) {
            return <Fragment>
                <MarkdownTextView rawText="So a handwritten digit recognition model reads a length-784 vector containing values $I_1, \dots, I_{784}$, and produces a length-10 “score” vector $S$. Those scores are called *logits*. There are many ways to achieve this. For this exercise, we will start with one of the simplest yet most powerful methods to build a machine learning model -- dot products." />

                <MarkdownTextView rawText={`As a review, when two vectors of the same length are *dotted*, you pair up the numbers by position, multiply them, and sum them up into one number. We could imagine that given an image vector $I$ of length 784, we use 10 vectors $\\vec w_0, \\dots, \\vec w_9$ (of length 784) to calculate the logit value for each digit, resulting in 10 numbers $S = \\left[S_0, \\dots, S_9 \\right]$. The vectors $\\vec w_i$ are called **weight vectors** because they are to be learned from data by the model. So our 10 weight vectors contain a total of $10 \\times 784 = 7840$ *weights*.\n\nTo perform this calculation efficiently, instead of doing the dot product 10 times, once for each $\\vec w_i$, we stack the weight vectors up into a matrix and multiply the image by the weight matrix.

$$
\\begin{aligned}
\\left[
    \\begin{array}{cccc}
        I_1 & I_2 & \\cdots & I_{784}
    \\end{array}
\\right]
\\cdot
\\underbrace{\\left[
    \\begin{array}{cccc}
        \\uparrow & \\uparrow & \\cdots & \\uparrow \\\\
        \\vec{w}_0 & \\vec w_1 & \\cdots & \\vec w_9 \\\\
        \\downarrow & \\downarrow & \\cdots & \\downarrow
    \\end{array}
\\right]}_{W}

&= \\left[
    \\begin{array}{cccc}
        I \\cdot \\vec w_0 & I \\cdot \\vec w_1 & \\cdots & I \\cdot \\vec w_9
    \\end{array}
\\right] \\\\
&= 
\\left[
    \\begin{array}{cccc}
        S_0 & S_1 & \\cdots & S_9
    \\end{array}
\\right]
\\end{aligned}
$$
                `} />

                <MarkdownTextView rawText="In the playground, we use Linear blocks to represent matrix multiplication. Are you ready to build your first handwritten digit recognition model?" />

                <img src="/assets/chapter3/linear.png" alt="Linear block" width={180} className="centered" />

                <a href="/chapters/3/exercises/1" style={{textDecoration: "none"}}>
                    <button className="next-button">Open Exercise</button>
                </a>
            </Fragment>
        } else {
            return <Fragment />
        }
    }
}

export default C3_MNIST;