import * as Konva from "konva"
import Block from "./Block";
import Receptor from "./Receptor";
import { distance } from "../Utils";
import Connection from "./Connection";
import { LayersModel, SymbolicTensor, Tensor } from "@tensorflow/tfjs"
import Trainer from "./Training/Trainer";
import Predictor from "./Outputs/Predictor";
import { CustomTFDataset, ReceptorType } from "../Interfaces";
import CustomDataset from "./Training/CustomDataset";
import { Line } from "konva/lib/shapes/Line";

/**
 * How many times a receptor may be recorded in `globalState.visitedReceptorCount`
 * before `Endpoint.propagate` stops pushing values into it again. A propagation
 * that is refused this way is reported by `addConnectionToReceptor` as a cycle.
 */
const MAX_RECEPTOR_REUSE = 8

/** Key under which a receptor's visits are counted in `globalState.visitedReceptorCount`. */
function receptorKey(receptor: Receptor): string {
    return receptor.parent.id + "-" + receptor.index
}

/**
 * Output port of a {@link Block}. It draws a small dot on the canvas. Users drag a
 * connection out of it onto a {@link Receptor}, and it forwards values (tensors,
 * models, datasets) to every connected receptor.
 */
export default class Endpoint {
    /** Block that owns this endpoint. */
    parent: Block
    /** Index of this endpoint on its parent; the default label is `Output <index + 1>`. */
    index: number
    /** Last value handed to `propagate`. */
    currentValue?: Tensor | SymbolicTensor | null
    /** Outgoing connections, keyed by connection id. */
    connections: Map<string, Connection> = new Map()
    /** Tooltip label; falls back to `Output <index + 1>` when absent. */
    name?: string
    /** Kind of value produced by this endpoint; defaults to "tensor". */
    type: ReceptorType

    /** Konva group holding the ring, circle, anchor and drag-preview line. */
    element = new Konva.default.Group()
    /** Hover highlight ring, hidden until the pointer enters the port. */
    ring = new Konva.default.Circle({
        radius: 6,
        visible: false,
        stroke: "#e04040",
        opacity: 0.5,
        strokeWidth: 1.2
    })

    /** The visible port dot. */
    circle = new Konva.default.Circle({
        fill: "#e04040",
        opacity: 0.6,
        radius: 4
    })

    /** Dashed preview line drawn from the port to the anchor while dragging. */
    line = new Konva.default.Line({
        stroke: "#a0a0a0",
        opacity: 0.4,
        dash: [10, 10],
        lineCap: "round"
    })

    /** Transparent draggable handle used to drag out a connection; reset to (0, 0) on drop. */
    anchor = new Konva.default.Circle({
        fill: "#e04040",
        opacity: 0,
        radius: 6,
        draggable: true,
        visible: false
    })

    /** Whether the pointer is currently over `element`. */
    mouseIn = false

    constructor(parent: Block, index: number, x: number, y: number, name?: string, type?: ReceptorType) {
        this.parent = parent
        this.index = index
        this.name = name
        this.type = type ?? "tensor"
        this.element.add(this.ring)
        this.element.add(this.circle)
        this.element.add(this.anchor)
        this.element.add(this.line)
        this.element.x(x)
        this.element.y(y)

        // Hover: highlight the port and make the parent block non-draggable while
        // the (draggable) anchor is shown, so a drag starting here moves the anchor.
        this.element.on("mouseenter", (e) => {
            this.ring.visible(true)
            if (e.evt.buttons === 0) {
                this.ring.opacity(0.5)
            }
            this.circle.opacity(1)
            this.parent.element?.draggable(false)
            this.anchor.visible(true)
            this.mouseIn = true
            this.showToolTip()
        })
        this.element.on("mousedown", () => {
            this.ring.opacity(0.75)
        })
        this.element.on("mouseup", () => {
            this.ring.opacity(0.5)
        })
        this.element.on("mouseleave", (e) => {
            this.ring.visible(false)
            this.anchor.visible(false)
            this.parent.element?.draggable(true)
            if (e.evt.buttons === 0) this.circle.opacity(0.7)
            this.mouseIn = false
            this.parent.globalState.hideTooltip!()
        })

        this.anchor.on("dragstart", () => {
            this.line.visible(true)
            this.parent.globalState.activeEndpoint = this
        })
        // While dragging: draw the preview line and select the closest eligible receptor
        // (on another block, within 1.5x its ring radius) as the drop target.
        this.anchor.on("dragmove", () => {
            const globalState = this.parent.globalState
            this.line.points([0, 0, this.anchor.x(), this.anchor.y()])
            const receptors = globalState.availableReceptors ?? []
            const anchorPosition = this.anchor.absolutePosition()
            const target = receptors.find(r => {
                if (r.parent.id === this.parent.id) { return false }
                const dist = distance(anchorPosition, r.circle.absolutePosition()) // TODO: adjust for offset
                return dist < r.ring.radius() * 1.5
            })
            receptors.forEach(r => r.ring.visible(false))
            if (target) {
                target.ring.visible(true)
                target.ring.opacity(1.0)
                globalState.activeReceptor = target
                target.showToolTip()
            } else {
                globalState.activeReceptor = undefined
            }
        })

        // Drop: reset the preview, then connect to the receptor selected during dragmove (if any).
        this.anchor.on("dragend", () => {
            const globalState = this.parent.globalState
            this.line.points([])
            this.line.visible(false)
            this.anchor.x(0)
            this.anchor.y(0)
            this.ring.visible(false)
            if (!this.mouseIn) this.circle.opacity(0.7)
            const receptor = globalState.activeReceptor
            if (receptor) {
                receptor.ring.visible(false)
                this.addConnectionToReceptor(receptor)
                globalState.activeExercise?.saveToCloud(globalState)
            }
            globalState.activeEndpoint = undefined
        })
    }

    /**
     * Connect this endpoint to `receptor`, or toggle an existing connection off.
     *
     * - If this endpoint already owns the connection, it is removed when `dropIfExists`
     *   is true (the receptor is re-propagated with `undefined`). Otherwise nothing happens.
     * - If the receptor has `allowMultiple`, a matching entry in its `connectionList` is
     *   removed. Otherwise the new connection is appended if the receptor accepts this
     *   endpoint's current value.
     * - Otherwise the receptor's existing connection is replaced. "model" and
     *   "dataset" receptors receive the parent's model / dataset. Other receptors
     *   receive `currentValue` unless `propagate` is false.
     *
     * A connection is registered (endpoint map, global map, canvas) only once the
     * receptor accepts it. A rejection is reported as a cycle and rolled back.
     *
     * @param receptor Target input port.
     * @param propagate When false, a single-input receptor is attached without pushing
     *   `currentValue` into it (checked after the model/dataset cases).
     * @param dropIfExists Whether calling this for an already-existing connection removes it.
     */
    addConnectionToReceptor(receptor: Receptor, propagate = true, dropIfExists = true) {
        const newConnection = new Connection(this, receptor)
        if (this.connections.has(newConnection.id)) {
            if (dropIfExists) {
                this.dropConnection(newConnection.id, receptor)
            }
        } else if (receptor.allowMultiple) {
            this.toggleMultiInputConnection(newConnection, receptor)
            this.parent.globalState.visitedReceptorCount.clear()
        } else {
            this.replaceReceptorConnection(newConnection, receptor, propagate)
            this.parent.globalState.visitedReceptorCount.clear()
        }
    }

    /** Remove this endpoint's connection `id` to `receptor` and re-propagate the receptor with no value. */
    private dropConnection(id: string, receptor: Receptor) {
        const globalConnections = this.parent.globalState.connections
        console.assert(globalConnections.has(id))
        // Read from this endpoint's own map: the caller has just checked that it contains `id`.
        const oldConnection = this.connections.get(id)!
        this.connections.delete(id) // Unset endpoint reference
        // deletedConnection is only set for the duration of the propagate(undefined) call below.
        oldConnection.end.deletedConnection = oldConnection
        oldConnection.end.connection = undefined // Unset receptor reference
        globalConnections.delete(id) // Unset global reference
        oldConnection.line.remove() // Remove from canvas
        receptor.connectionList = receptor.connectionList.filter(c => c.id !== id)
        oldConnection.end.propagate(undefined)
        oldConnection.end.deletedConnection = undefined
    }

    /** Toggle `newConnection` in the connection list of a receptor that accepts several inputs. */
    private toggleMultiInputConnection(newConnection: Connection, receptor: Receptor) {
        const globalState = this.parent.globalState
        const existingIndex = receptor.connectionList.findIndex(conn => conn.id === newConnection.id)
        if (existingIndex >= 0) { // connection exists on the receptor: unlink it
            const [oldConnection] = receptor.connectionList.splice(existingIndex, 1)
            globalState.connections.delete(oldConnection.id)
            oldConnection.start.connections.delete(oldConnection.id)
            oldConnection.line.remove()
            return
        }
        const previousConnection = newConnection.end.connection
        newConnection.end.connection = newConnection
        receptor.connectionList.push(newConnection)
        if (receptor.propagate(this.currentValue)) {
            this.attachConnection(newConnection)
            receptor.parent.updateReceptorsAndEndpoints()
        } else {
            // Roll back both receptor-side references so the rejected connection does not linger.
            receptor.connectionList.pop()
            newConnection.end.connection = previousConnection
            console.warn(`Detected cycle, connection ${newConnection.id} not added`)
        }
    }

    /** Replace the single input of `receptor` with `newConnection` and push this endpoint's output into it. */
    private replaceReceptorConnection(newConnection: Connection, receptor: Receptor, propagate: boolean) {
        const globalState = this.parent.globalState
        const visited = globalState.visitedReceptorCount
        // Unlink whatever currently feeds the receptor.
        if (receptor.connection) {
            globalState.connections.delete(receptor.connection.id)
            receptor.connection.start.connections.delete(receptor.connection.id)
            receptor.connectionList = []
            receptor.connection.line.remove()
        }
        // Fresh visit counts for this propagation, with the target receptor counted once.
        visited.clear()
        receptor.connection = newConnection
        receptor.connectionList = [newConnection]
        visited.set(receptorKey(receptor), 1)

        const acceptedNonTensor =
            (receptor.type === "model" && receptor.propagateModel((this.parent as Trainer).model)) ||
            (receptor.type === "dataset" && receptor.propagateDataset((this.parent as CustomDataset).customTFDataset))
        if (acceptedNonTensor) {
            newConnection.end.connection = newConnection // Set receptor reference
            this.attachConnection(newConnection)
        } else if (propagate === false || receptor.propagate(this.currentValue)) {
            newConnection.end.connection = newConnection // Set receptor reference
            this.attachConnection(newConnection)
            receptor.parent.updateReceptorsAndEndpoints()
        } else {
            // Rejected: clear both receptor-side references to the new connection.
            receptor.connection = undefined
            receptor.connectionList = []
            console.warn("Cycle detected, cannot add connection")
            console.log("Visited receptors:", Array.from(visited))
            alert("Cycle detected, cannot add connection")
        }
    }

    /** Register an accepted connection on this endpoint and globally, and draw it at the bottom of the main layer. */
    private attachConnection(connection: Connection) {
        this.connections.set(connection.id, connection) // Set endpoint reference
        this.parent.globalState.connections.set(connection.id, connection) // Set global reference
        this.parent.globalState.mainLayer.add(connection.line) // Add to canvas
        connection.line.moveToBottom()
    }

    /**
     * Store `data` as this endpoint's value and push it into every connected receptor.
     *
     * A receptor is skipped once `globalState.visitedReceptorCount` records it
     * `MAX_RECEPTOR_REUSE` times. After a non-training pass in which no receptor was
     * skipped, every receptor reached has its count incremented.
     * NOTE(review): counts persist until `visitedReceptorCount` is cleared.
     * `addConnectionToReceptor` clears it after every attempt; confirm that other
     * callers of `propagate` clear it too, or repeated updates will hit the cap.
     *
     * @param data Value to forward. Connections not highlighted as part of a training
     *   pass are marked active iff `data` is truthy.
     * @param noRecursive When truthy, only `currentValue` is updated and `true` is returned.
     * @returns While training: whether at least one receptor accepted the value.
     *   Otherwise: `false` only if some receptor was skipped because of the cap.
     */
    propagate(data?: Tensor | SymbolicTensor | null, noRecursive?: boolean): boolean {
        this.currentValue = data
        if (noRecursive) { return true }

        // Check that all inputs to this endpoint's parent have up-to-date flowIDs
        // if (this.parent.inputs.some(r => r.flowId !== this.parent.globalState.flowId)) {
        //     console.log(this, 'need update', this.parent.inputs.filter((r, i) => r.flowId !== this.parent.globalState.flowId))
        //     return true
        // }

        const globalState = this.parent.globalState
        const visitCounts = globalState.visitedReceptorCount
        let allSatisfy = true
        let someSatisfy = false
        const reachedKeys = new Set<string>()
        // Only filled by the disabled highlight code below; the reset in the
        // "stopped" branch relies on it if that code is re-enabled.
        const paintedLines = new Set<Line>()

        this.connections.forEach(c => {
            const key = receptorKey(c.end)
            // The parentheses matter: `get(key) ?? 0 < MAX` would parse as `get(key) ?? (0 < MAX)`.
            if ((visitCounts.get(key) ?? 0) >= MAX_RECEPTOR_REUSE) {
                allSatisfy = false
                return
            }
            reachedKeys.add(key)
            // if (c.line.stroke() === "#606060" && data instanceof SymbolicTensor) {
            //     c.line.stroke("#80c0f0")
            //     paintedLines.add(c.line)
            // } else {
            //     c.line.stroke("#606060")
            // }
            const propagationStatus = c.end.propagate(data)
            if (globalState.isTraining && propagationStatus) {
                someSatisfy = true
                c.line.stroke("#80c8f0")
                c.line.pointerAtBeginning(true)
                c.line.fill("#80c8f0")
            } else {
                c.setActive(!!data)
                c.line.stroke("#606060")
                c.line.fill("#606060")
                c.line.pointerAtBeginning(false)
            }
        })

        if (globalState.isTraining) {
            return someSatisfy
        }
        if (allSatisfy) {
            // The parentheses matter: `get(key) ?? 0 + 1` would never increment.
            reachedKeys.forEach(key => visitCounts.set(key, (visitCounts.get(key) ?? 0) + 1))
        } else {
            console.log('stopped propagating at', this)
            paintedLines.forEach(l => {
                l.stroke("#606060")
                l.fill("#606060")
            })
        }
        return allSatisfy
    }

    /**
     * Forward a model to every connected receptor.
     * null means training, undefined means no connection.
     */
    propagateModel(model?: LayersModel | null) {
        this.connections.forEach(c => {
            c.end.propagateModel(model)
        })
    }

    /** Forward a dataset to every connected receptor. */
    propagateDataset(dataset?: CustomTFDataset | null) {
        this.connections.forEach(c => {
            c.end.propagateDataset(dataset)
        })
    }

    /** Show this endpoint's label (or `Output <index + 1>`) at its absolute canvas position. */
    showToolTip() {
        this.parent.globalState.showTooltip!({
            text: this.name ?? `Output ${this.index + 1}`,
            position: this.element.absolutePosition()
        })
    }
}