import axios from "axios"
import { MNIST_Image } from "../BlockSpace/Interfaces"

// Decoded images per split. They only ever grow: each download appends the newly decoded
// images after those already cached, so index i always refers to image i of the split.
let MNIST_Train_Images: MNIST_Image[] = []
let MNIST_Test_Images: MNIST_Image[] = []

// Splits (keyed by base URL) whose every image is already cached, so requests for more
// images than the split holds are answered from the cache instead of re-downloading.
const fullyCachedSplits = new Set<string>()

// IDX file format: big-endian uint32 header fields followed by unsigned-byte data.
const IDX_IMAGES_MAGIC = 0x00000803 // unsigned bytes, 3 dimensions (count, rows, cols)
const IDX_LABELS_MAGIC = 0x00000801 // unsigned bytes, 1 dimension (count)
const IDX_IMAGES_HEADER_BYTES = 16 // magic, count, rows, cols
const IDX_LABELS_HEADER_BYTES = 8 // magic, count

// Decode at least this many images per download so small requests don't refetch repeatedly.
const MIN_IMAGES_PER_LOAD = 10000

/** Raw image/label bytes (headers stripped) plus the number of complete, labelled images. */
interface MnistIdx {
    images: Uint8Array
    labels: Uint8Array
    count: number
    pixelsPerImage: number
}

/**
 * Validates the IDX headers of an image file and a label file and returns views of their data.
 *
 * @throws Error if either buffer is shorter than its header or carries the wrong magic number.
 */
function parseMnistIdx(imageBytes: Uint8Array, labelBytes: Uint8Array): MnistIdx {
    if (imageBytes.length < IDX_IMAGES_HEADER_BYTES || labelBytes.length < IDX_LABELS_HEADER_BYTES) {
        throw new Error("MNIST data is too short to contain an IDX header")
    }
    // DataView reads big-endian by default, which is what IDX uses.
    const imageHeader = new DataView(imageBytes.buffer, imageBytes.byteOffset, IDX_IMAGES_HEADER_BYTES)
    const labelHeader = new DataView(labelBytes.buffer, labelBytes.byteOffset, IDX_LABELS_HEADER_BYTES)
    if (imageHeader.getUint32(0) !== IDX_IMAGES_MAGIC || labelHeader.getUint32(0) !== IDX_LABELS_MAGIC) {
        throw new Error("MNIST data is not in IDX format (unexpected magic number)")
    }
    const pixelsPerImage = imageHeader.getUint32(8) * imageHeader.getUint32(12)
    const images = imageBytes.subarray(IDX_IMAGES_HEADER_BYTES)
    const labels = labelBytes.subarray(IDX_LABELS_HEADER_BYTES)
    // Trust the header counts only as far as the buffers actually extend.
    const count = Math.min(
        imageHeader.getUint32(4),
        labelHeader.getUint32(4),
        pixelsPerImage > 0 ? Math.floor(images.length / pixelsPerImage) : 0,
        labels.length,
    )
    return { images, labels, count, pixelsPerImage }
}

/**
 * Decodes images with indices in [start, min(end, idx.count)) into row-major pixel arrays.
 * Returns an empty array when the range is empty.
 */
function decodeMnistRange(idx: MnistIdx, start: number, end: number): MNIST_Image[] {
    const { images, labels, pixelsPerImage } = idx
    return Array.from(labels.subarray(start, Math.min(end, idx.count)), (label, i) => {
        const offset = (start + i) * pixelsPerImage
        return { image: Array.from(images.subarray(offset, offset + pixelsPerImage)), label }
    })
}

/**
 * Returns the first `count` MNIST images of the training split, or of the evaluation split when
 * `isEval` is true, downloading and decoding the IDX files under `/datasets/` when needed.
 *
 * A download decodes at least `MIN_IMAGES_PER_LOAD` images (or `count`, if larger), capped at the
 * number of images the files contain; the results are cached per split for later calls.
 *
 * @param count  Number of images wanted. Values <= 0 (or NaN) yield an empty array.
 * @param isEval Use the evaluation split instead of the training split.
 * @returns Up to `count` images; fewer when the split holds fewer than `count` images.
 * @throws Error when a downloaded file is not IDX-formatted; axios request errors propagate.
 */
export async function mnistData(count: number, isEval = false): Promise<MNIST_Image[]> {
    if (!(count > 0)) {
        return []
    }
    const cachedImages = isEval ? MNIST_Test_Images : MNIST_Train_Images
    const basePath = isEval ? "/datasets/mnist-eval" : "/datasets/mnist-train"
    if (cachedImages.length >= count || fullyCachedSplits.has(basePath)) {
        return cachedImages.slice(0, count)
    }

    const [imagesRes, labelsRes] = await Promise.all([
        axios(basePath, isEval
            ? { responseType: "arraybuffer" }
            : {
                responseType: "arraybuffer",
                onDownloadProgress: (event: unknown) => {
                    console.log('mnist download progress', event)
                },
            }),
        axios(`${basePath}-labels`, { responseType: "arraybuffer" }),
    ])
    const idx = parseMnistIdx(
        new Uint8Array(imagesRes.data as ArrayBuffer),
        new Uint8Array(labelsRes.data as ArrayBuffer),
    )

    // Read the cache length only now, after the await, and append without yielding, so two
    // overlapping calls can never decode and append the same images twice.
    const wanted = Math.max(MIN_IMAGES_PER_LOAD, count)
    // Push one at a time: spreading tens of thousands of arguments can exceed engine limits.
    for (const image of decodeMnistRange(idx, cachedImages.length, wanted)) {
        cachedImages.push(image)
    }
    if (cachedImages.length >= idx.count) {
        fullyCachedSplits.add(basePath)
    }
    return cachedImages.slice(0, count)
}