import axios from "axios"
import { ClassificationImage } from "../BlockSpace/Interfaces"

/** Decoded CIFAR-10 training images, filled lazily by `cifar10Data`; the array is only ever appended to. */
const CIFAR10_Train_Images: ClassificationImage[] = []
/** Decoded CIFAR-10 test images, filled lazily by `cifar10Data`; the array is only ever appended to. */
const CIFAR10_Test_Images: ClassificationImage[] = []

/** Human-readable class names, indexed by the label byte stored in each CIFAR-10 record. */
export const CIFAR10_LABELS = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]

/** Side length, in pixels, of a CIFAR-10 image. */
const CIFAR10_SIDE = 32
/** Colour planes per image (red, green, blue). */
const CIFAR10_CHANNELS = 3
/** Bytes per binary record: one label byte followed by 3 × 32 × 32 pixel bytes (= 3073). */
const CIFAR10_RECORD_BYTES = 1 + CIFAR10_CHANNELS * CIFAR10_SIDE * CIFAR10_SIDE
/** Records stored in each `.bin` batch file. */
const CIFAR10_IMAGES_PER_FILE = 10000
/** Number of training batch files (`cifar10-train1.bin` … `cifar10-train5.bin`). */
const CIFAR10_TRAIN_FILES = 5

/**
 * Returns the first `count` CIFAR-10 images of the training set (or the test set when
 * `isEval` is true), downloading and decoding batch files on demand.
 *
 * Decoded images are cached in module-level arrays, so repeated calls only fetch the files
 * that are not yet cached. Fewer than `count` images are returned when the dataset is
 * smaller (50 000 training images over 5 files, 10 000 test images in 1 file).
 *
 * @param count  Number of images wanted.
 * @param isEval Read from the test file instead of the training files.
 * @returns A fresh array holding at most `count` images, in file order.
 * @throws Error if a downloaded file is too short to contain a full CIFAR-10 batch.
 */
export async function cifar10Data(count: number, isEval = false): Promise<ClassificationImage[]> {
    const cachedImages = isEval ? CIFAR10_Test_Images : CIFAR10_Train_Images
    if (cachedImages.length >= count) {
        return cachedImages.slice(0, count)
    }

    const cachedCount = cachedImages.length
    // Files are consumed whole and in order, so every file before `firstFile` is already
    // cached and does not need to be downloaded again.
    const firstFile = Math.floor(cachedCount / CIFAR10_IMAGES_PER_FILE)
    const urls = cifar10FileUrls(count, isEval).slice(firstFile)

    // Download the remaining files in parallel; Promise.all keeps the results in file order.
    const batches = await Promise.all(
        urls.map(async (url, i) => {
            const response = await axios.get<ArrayBuffer>(url, { responseType: "arraybuffer" })
            // Only the first file fetched here can already be partly cached.
            const firstImage = i === 0 ? cachedCount % CIFAR10_IMAGES_PER_FILE : 0
            return decodeCifar10File(response.data, url, firstImage)
        })
    )
    const loadedImages = batches.flat()

    // loadedImages[k] is image number cachedCount + k. A concurrent call may have appended
    // some of these same images while this one was awaiting, so append only the ones still
    // missing. A loop is used instead of push(...spread), which can overflow the call stack
    // for tens of thousands of elements.
    for (const image of loadedImages.slice(cachedImages.length - cachedCount)) {
        cachedImages.push(image)
    }
    return cachedImages.slice(0, count)
}

/** URLs of the batch files needed to supply `count` images, in image order. */
function cifar10FileUrls(count: number, isEval: boolean): string[] {
    if (isEval) {
        return [`/datasets/cifar10-test.bin`]
    }
    const fileCount = Math.min(
        CIFAR10_TRAIN_FILES,
        Math.max(1, Math.ceil(count / CIFAR10_IMAGES_PER_FILE))
    )
    // Training files are numbered from 1.
    return Array.from({ length: fileCount }, (_, i) => `/datasets/cifar10-train${i + 1}.bin`)
}

/**
 * Decodes records `firstImage` … 9999 of one CIFAR-10 binary batch file.
 * Each record is one label byte followed by the red, green and blue planes, each stored as
 * 32 rows of 32 bytes. Pixels are returned as `image[channel][y][x]` raw byte values.
 *
 * @throws Error if the buffer is too short to hold a full batch (truncated or wrong file).
 */
function decodeCifar10File(buffer: ArrayBuffer, url: string, firstImage: number): ClassificationImage[] {
    const expectedBytes = CIFAR10_IMAGES_PER_FILE * CIFAR10_RECORD_BYTES
    if (buffer.byteLength < expectedBytes) {
        throw new Error(`${url}: expected ${expectedBytes} bytes of CIFAR-10 data, got ${buffer.byteLength}`)
    }

    const bytes = new Uint8Array(buffer)
    const images: ClassificationImage[] = []
    for (let index = firstImage; index < CIFAR10_IMAGES_PER_FILE; index++) {
        const recordStart = index * CIFAR10_RECORD_BYTES
        // In bounds: the byte length was validated above.
        const label = bytes[recordStart]!
        const image: number[][][] = []
        for (let channel = 0; channel < CIFAR10_CHANNELS; channel++) {
            const plane: number[][] = []
            for (let y = 0; y < CIFAR10_SIDE; y++) {
                // +1 skips the label byte that precedes the pixel data.
                const rowStart = recordStart + 1 + (channel * CIFAR10_SIDE + y) * CIFAR10_SIDE
                // subarray is a view (no copy); Array.from materialises the 32 row values.
                plane.push(Array.from(bytes.subarray(rowStart, rowStart + CIFAR10_SIDE)))
            }
            image.push(plane)
        }
        images.push({ image, label })
    }
    return images
}