import { Edge, FitView } from 'reactflow'
import Dagre, { Label, graphlib } from '@dagrejs/dagre'
import _uniq from 'lodash/uniq'

import {
	EdgeData,
	BasicNodeData,
	CangoChainNode,
	NodeData,
	NodeDataWithParents,
	ListedStep,
	StepCycles,
} from './types'

// Module-level directed dagre graph used by getLayoutedElements for layout and topological
// sorting. Edges added without an explicit label receive an empty-object label.
const g = new Dagre.graphlib.Graph({ directed: true })
g.setDefaultEdgeLabel(() => ({}))

/**
 * Orders edges by the position of their endpoints in `orderedNodes`.
 *
 * Edges are compared by the index of their source first and, for equal source indexes,
 * by the index of their target. Ids missing from `orderedNodes` rank after every listed id.
 *
 * @param edges - Edges to order; the array itself is not modified.
 * @param orderedNodes - Node ids in the desired order (e.g. a topological order).
 * @returns A new array holding the same edge objects in sorted order.
 */
export const sortEdgesByNodes = (edges: Edge[], orderedNodes: string[]): Edge[] => {
	// Index of each node id; if an id is repeated, its last index wins.
	const nodeOrder = new Map<string, number>()
	orderedNodes.forEach((node, index) => {
		nodeOrder.set(node, index)
	})
	const rank = (nodeId: string): number => nodeOrder.get(nodeId) ?? Infinity

	// Copy before sorting: Array.prototype.sort works in place, and callers should not
	// see their input array reordered.
	return [...edges].sort((a, b) => {
		const sourceA = rank(a.source)
		const sourceB = rank(b.source)
		if (sourceA !== sourceB) {
			return sourceA - sourceB
		}
		// Same source rank (including both unknown): fall back to target rank. Two unknown
		// targets give Infinity - Infinity = NaN, which sort treats as "equal".
		return rank(a.target) - rank(b.target)
	})
}

/**
 * Propagates section membership and ancestry from parents to children along `edges`.
 *
 * Every node starts with `parents: []`. Its `sections` start as `[name]` when it is a
 * section node and `[]` otherwise. Edges are then visited once, in the given order. For
 * each edge whose endpoints are both in `nodes`:
 * - the target inherits the source's `parents` and `sections`;
 * - the source id is prepended to the target's `parents`;
 * - a section node keeps only its own name in `sections`.
 *
 * Because this is a single pass, ancestry is only complete when every edge into a node
 * is visited before the edges leaving it. One way to get that is to sort edges by a
 * topological order of their sources.
 *
 * @param nodes - Nodes to enrich; the input objects are not mutated.
 * @param edges - Edges to walk; returned unchanged.
 * @returns The enriched nodes, the same `edges` array, and an adjacency list mapping each
 *   source id to its outgoing edges in input order.
 */
const propagateData = (
	nodes: CangoChainNode<BasicNodeData>[],
	edges: Edge<EdgeData>[],
): {
	nodes: CangoChainNode<NodeDataWithParents>[]
	edges: Edge<EdgeData>[]
	adjacencyList: Map<string, Edge<EdgeData>[]>
} => {
	const nodeMap = new Map<string, CangoChainNode<NodeDataWithParents>>(
		nodes.map((node) => [
			node.id,
			{
				...node,
				data: { ...node.data, sections: node.data.isSection ? [node.data.name] : [], parents: [] },
			},
		]),
	)

	// Outgoing edges per source id, in input order, for quick look-up by callers.
	const adjacencyList = new Map<string, Edge<EdgeData>[]>()
	for (const edge of edges) {
		const outgoing = adjacencyList.get(edge.source)
		if (outgoing) {
			outgoing.push(edge)
		} else {
			adjacencyList.set(edge.source, [edge])
		}
	}

	for (const { source: parentNodeId, target: nodeId } of edges) {
		const node = nodeMap.get(nodeId)
		const parentNode = nodeMap.get(parentNodeId)
		// Edges touching ids that are not in `nodes` have nothing to propagate.
		if (!node || !parentNode) {
			continue
		}

		const { parents, sections } = parentNode.data

		// A section node starts a new section scope; other nodes accumulate their parents' sections.
		const newSections = node.data.isSection
			? [node.data.name]
			: _uniq([...node.data.sections, ...sections])
		const newParents = _uniq(
			parentNodeId ? [parentNodeId, ...node.data.parents, ...parents] : parents,
		)

		nodeMap.set(nodeId, {
			...node,
			data: {
				...node.data,
				sections: newSections,
				parents: newParents,
			},
		})
	}

	return { nodes: Array.from(nodeMap.values()), edges, adjacencyList }
}

/**
 * Lists the cycles in the directed graph formed by `edges`.
 *
 * A throwaway graphlib graph is built from the edge endpoints, and
 * `graphlib.alg.findCycles` is applied to it.
 *
 * @param edges - Edges whose `source`/`target` ids define the graph.
 * @returns One array of node ids per cycle found.
 */
export const detectCycles = (edges: Edge[]) => {
	const cycleGraph = new graphlib.Graph({ directed: true })
	for (const { source, target } of edges) {
		cycleGraph.setEdge(source, target)
	}
	return graphlib.alg.findCycles(cycleGraph)
}

/**
 * Picks an edge to start a new route from: one whose source has no incoming edge
 * among `edges`.
 *
 * The chosen source is the first id with in-degree zero. Ids are ordered by first
 * appearance while scanning the edges, with the target seen before the source within
 * each edge.
 *
 * @param edges - Edges still to be visited.
 * @returns The first edge leaving that source, or `null` when `edges` is empty or every
 *   node has an incoming edge (i.e. the remaining edges only form cycles).
 */
const findFirstEdge = (edges: Edge[]): Edge | null => {
	const incomingCount = new Map<string, number>()
	for (const { source, target } of edges) {
		incomingCount.set(target, (incomingCount.get(target) || 0) + 1)
		if (!incomingCount.has(source)) {
			incomingCount.set(source, 0)
		}
	}

	const rootId = Array.from(incomingCount.keys()).find((nodeId) => incomingCount.get(nodeId) === 0)
	if (!rootId) {
		return null
	}

	return edges.find(({ source }) => source === rootId) ?? null
}

/**
 * Builds React Flow nodes and edges for a list of steps and lays them out top-to-bottom
 * with dagre.
 *
 * 1. Builds the graph elements:
 *    - one node per step, of type `chainReference` when the step has a `chain_reference`
 *      and `standard` otherwise;
 *    - one `chainOptionsEdge` per distinct (step, descendant) pair.
 * 2. Walks the edges along routes that start at `initialNodeIds`. When no route can be
 *    extended, a new route starts at a source with no remaining incoming edge.
 *    Combined with `detectCycles`, this decides which edges close a cycle. Those edges
 *    are returned in `cycles` and are excluded from the layout and from the returned
 *    `edges`.
 * 3. Lays out the remaining graph with dagre and sorts edges by a topological order of
 *    their endpoints. Section and parent data are then propagated to the nodes.
 *
 * The dagre graph `g` lives at module level. It is cleared at the start of every call so
 * nodes and edges from a previous call cannot leak into this layout. Stale edges could
 * otherwise re-form a cycle that was never detected and make the topological sort fail.
 *
 * NOTE(review): positions are taken verbatim from dagre, which reports node centres;
 * confirm that is what the React Flow consumers expect.
 *
 * @param steps - Steps to render.
 * @param props - `fitView` is copied into every node's data.
 * @param initialNodeIds - Ids from which the cycle-detection routes start.
 * @returns Positioned nodes (each also receiving the full node list in `data.nodes`),
 *   the kept edges, an adjacency list of the kept edges, and the removed cycle edges by id.
 */
export const getLayoutedElements = (
	steps: ListedStep[],
	props: { fitView: FitView },
	initialNodeIds: string[],
): {
	nodes: CangoChainNode<NodeDataWithParents>[]
	edges: Edge<EdgeData>[]
	adjacencyList: Map<string, Edge<EdgeData>[]>
	cycles: StepCycles
} => {
	// Reset the shared graph; removing a node also removes its incident edges.
	g.nodes().forEach((nodeId) => g.removeNode(nodeId))
	g.setGraph({ rankdir: 'TB', nodesep: 200, edgesep: 300 })

	const { edges, nodes } = steps.reduce(
		(
			acc: {
				edges: Edge[]
				nodes: Label[]
			},
			_step,
		) => {
			const stepChildren = _step.descendants
			const descendantsMap = new Map(stepChildren.map((_desc) => [_desc.step, _desc]))

			const nodeData: BasicNodeData = {
				..._step,
				fitView: props.fitView,
				options: _step.options,
			}

			acc.nodes.push({
				id: _step._id,
				data: nodeData,
				type: _step.chain_reference ? 'chainReference' : 'standard',
				width: 200,
				height: 300,
			})

			const newEdges = stepChildren.reduce((_acc: Edge[], _childStep) => {
				// Only one edge per (step, descendant) pair.
				if (_acc.find((_edge) => _edge.source === _step._id && _edge.target === _childStep.step)) {
					return _acc
				}

				const childMeta = descendantsMap.get(_childStep.step)

				if (!childMeta) {
					return _acc
				}

				const edgeData: EdgeData = {
					isMenu: _step.isMenu,
					child: childMeta,
					options: _step.options,
				}

				_acc.push({
					id: `e${_step._id}--${_childStep.step}`,
					source: _step._id,
					target: _childStep.step,
					type: 'chainOptionsEdge',
					data: edgeData,
				})

				return _acc
			}, [])
			acc.edges.push(...newEdges)
			return acc
		},
		{
			edges: [],
			nodes: [],
		},
	)

	const cycles = detectCycles(edges)
	edges.forEach((edge) => {
		g.setEdge(edge.source, edge.target)
	})
	nodes.forEach((node) => g.setNode(node.id, node))

	// Each route is a path of node ids; one route starts at every initial node.
	const edgeRoutes: string[][] = initialNodeIds.reduce((_acc: string[][], startingId) => {
		_acc.push([startingId])
		return _acc
	}, [])

	let edgesToLoop = [...edges]

	const removedCycles: StepCycles = new Map()
	// Per detected cycle (by index in `cycles`): sources of the edges recorded against it so far.
	const fulfilledCycles: {
		[index: number]: string[]
	} = []

	while (edgesToLoop.length) {
		const edgeToHandle = edgesToLoop.find((_edge) =>
			edgeRoutes.some((_route) => _route[_route.length - 1] === _edge.source),
		)

		if (!edgeToHandle) {
			// No route can be extended: start a new one from a source without incoming edges.
			const firstEdge = findFirstEdge(edgesToLoop)
			if (firstEdge) {
				edgeRoutes.push([firstEdge.source, firstEdge.target])
				edgesToLoop = edgesToLoop.filter((_edge) => _edge.id !== firstEdge.id)
				continue
			}
			break
		}

		// Record this edge against every detected cycle containing both endpoints. Once a
		// cycle already has `cycle.length - 1` recorded entries, this edge is treated as
		// the one closing it.
		cycles.forEach((cycle, cycleIndex) => {
			if (!cycle.includes(edgeToHandle.source) || !cycle.includes(edgeToHandle.target)) {
				return
			}
			if (cycle.length - 1 === fulfilledCycles[cycleIndex]?.length) {
				removedCycles.set(edgeToHandle.id, edgeToHandle)
			} else {
				fulfilledCycles[cycleIndex] = [...(fulfilledCycles[cycleIndex] ?? []), edgeToHandle.source]
			}
		})

		const routeIndex = edgeRoutes.findIndex(
			(_route) => _route[_route.length - 1] === edgeToHandle.source,
		)

		const currentRoute = [...edgeRoutes[routeIndex]]

		const edgesFromSource = edgesToLoop.filter((_edge) => _edge.source === edgeToHandle.source)

		if (edgesFromSource.length > 1) {
			// For each target, except the first one (handled by the current route), create a new route
			edgesFromSource.forEach((branchEdge) => {
				if (branchEdge.id === edgeToHandle.id) {
					// Handle the first target by continuing the current route
					edgeRoutes[routeIndex].push(branchEdge.target)
				} else {
					// For other targets, clone the current route and append the new target
					const newRoute = [...currentRoute, branchEdge.target]
					edgeRoutes.push(newRoute)
				}
			})
			edgesToLoop = edgesToLoop.filter(
				(_edge) => !edgesFromSource.some((_edgeFromSource) => _edgeFromSource.id === _edge.id),
			)
		} else {
			// If no branching, just continue the current route
			edgeRoutes[routeIndex].push(edgeToHandle.target)
		}

		// Remove the handled edges from the edgesToLoop list
		edgesToLoop = edgesToLoop.filter((_edge) => !edgesFromSource.includes(_edge))
	}

	// An edge pointing backwards along any route (its target appears before its source)
	// also closes a cycle.
	edges.forEach((edge) => {
		if (removedCycles.has(edge.id)) {
			return
		}
		edgeRoutes.forEach((_route) => {
			const sourceIndex = _route.indexOf(edge.source)
			const targetIndex = _route.indexOf(edge.target)
			if (sourceIndex === -1 || targetIndex === -1) {
				return
			}
			if (sourceIndex > targetIndex) {
				removedCycles.set(edge.id, edge)
			}
		})
	})

	// Keep non-cycle edges and drop the cycle edges from the layout graph so it is acyclic.
	const filteredEdges = edges.reduce((_acc: Edge<EdgeData>[], _edge) => {
		if (!removedCycles.has(_edge.id)) {
			_acc.push(_edge)
			return _acc
		}
		g.removeEdge(_edge.source, _edge.target)
		return _acc
	}, [])

	Dagre.layout(g)

	const sortedNodes = graphlib.alg.topsort(g)
	const sortedEdges = sortEdgesByNodes(filteredEdges, sortedNodes)

	const prePropagatedNodes = nodes.map((node) => {
		const { x, y } = g.node(node.id)
		return { ...node, position: { x, y } }
	}) as CangoChainNode<NodeData>[]

	const {
		nodes: propagatedNodes,
		edges: propagatedEdges,
		adjacencyList,
	} = propagateData(prePropagatedNodes, sortedEdges)

	return {
		nodes: propagatedNodes.map((node) => ({
			...node,
			data: { ...node.data, nodes: propagatedNodes },
		})),
		edges: propagatedEdges,
		adjacencyList,
		cycles: removedCycles,
	}
}
