import Graph, { MultiGraph } from 'graphology'

import dfsFromNode from '../utils/dfsFromNode'
import {
  ClainTransactionLayoutEdge,
  ClainTransactionLayoutNode,
  ClainTransactionLayoutNodeAttributes,
  ClainTransactionLayoutOptions,
  ClainTransactionLayoutReturn,
} from './customClainTransactionLayout.types'

// Horizontal distance between adjacent ranks (columns): x = rank * WIDTH_UNIT.
const WIDTH_UNIT = 350
// Vertical distance between adjacent levels (rows): y = level * HEIGHT_UNIT.
// Negative, so a numerically larger level maps to a smaller y coordinate.
const HEIGHT_UNIT = -200

/**
 * Builds the working graph for the layout from raw node keys and edges.
 *
 * Every node starts with `{ rank: 0, level: 0 }`. Duplicate node keys are
 * ignored. Each edge is stored under the key `${from}_${to}`, so repeated
 * `from` -> `to` pairs collapse into a single edge.
 *
 * NOTE(review): if node keys may contain `_`, two different pairs can map to
 * the same edge key (e.g. "a_b"->"c" and "a"->"b_c"), and the second edge is
 * silently skipped. Edges whose endpoints are not in `nodes` presumably make
 * graphology throw, since it requires both endpoints to exist.
 *
 * @param nodes - node keys to add.
 * @param edges - directed edges between those nodes.
 * @returns a new graph containing the nodes and de-duplicated edges.
 */
const createGraph = (
  nodes: Array<ClainTransactionLayoutNode>,
  edges: Array<ClainTransactionLayoutEdge>
): Graph<ClainTransactionLayoutNodeAttributes> => {
  const graph = new MultiGraph<ClainTransactionLayoutNodeAttributes>()

  for (const nodeKey of nodes) {
    if (graph.hasNode(nodeKey)) continue
    graph.addNode(nodeKey, { rank: 0, level: 0 })
  }

  for (const { from, to } of edges) {
    const edgeKey = `${from}_${to}`
    if (graph.hasEdge(edgeKey)) continue
    graph.addEdgeWithKey(edgeKey, from, to)
  }

  return graph
}

/**
 * Assigns each node reached from `rootKey` an integer `rank` (its column in
 * the layout) and stores it in the node's `rank` attribute.
 *
 * The root has rank 0. Every walk is made relative to an already-ranked
 * "anchor" node:
 * - A walk along out-edges gives each visited node
 *   `anchorRank + depth + 1`, where `depth` is as reported by `dfsFromNode`.
 *   If the node already had a rank, the larger value is kept.
 * - A walk along in-edges gives `anchorRank - depth - 1`, keeping the smaller
 *   value.
 *
 * After each walk the traversal turns around. For every node on the walked
 * path, neighbours on the opposite side that have not been visited yet are
 * ranked with the opposite walk, anchored at that path node. The opposite side
 * means in-edge sources after an outward walk, and out-edge targets after an
 * inward walk.
 *
 * Nodes that are never reached are not written and keep their existing `rank`.
 *
 * @param graph - graph to annotate; its `rank` attributes are mutated.
 * @param rootKey - node anchored at rank 0; must exist in `graph`.
 */
const assignRanks = (
  graph: Graph<ClainTransactionLayoutNodeAttributes>,
  rootKey: string
) => {
  const ranks = new Map<string, number>([[rootKey, 0]])
  // Every node covered by a finished walk. The set only ever grows, so it is
  // extended in place instead of being copied into a new Set after each walk
  // (which made the overall work quadratic in the number of visited nodes).
  const visited = new Set<string>([rootKey])

  const markVisited = (path: Set<string>) => {
    path.forEach((key) => {
      visited.add(key)
    })
  }

  // Ranks nodes reachable along out-edges from `startKey`, relative to the
  // already ranked `anchorKey`, then turns around along in-edges.
  const getRanksRecursiveOut = (startKey: string, anchorKey: string) => {
    const path = new Set<string>()

    dfsFromNode(
      graph,
      startKey,
      (key, _, depth) => {
        path.add(key)

        // Re-read on every visit: if the walk loops back to the anchor, the
        // anchor's own rank can change mid-walk. The anchor is always ranked
        // before a walk starts, so `?? 0` only satisfies strict null checks.
        const anchorRank = ranks.get(anchorKey) ?? 0
        const currentRank = depth + anchorRank + 1
        const savedRank = ranks.get(key)

        ranks.set(
          key,
          savedRank === undefined
            ? currentRank
            : Math.max(savedRank, currentRank)
        )
      },
      { mode: 'out' }
    )

    markVisited(path)

    path.forEach((key) => {
      graph.inEdges(key).forEach((inEdgeKey) => {
        const sourceKey = graph.source(inEdgeKey)
        if (sourceKey !== anchorKey && !visited.has(sourceKey)) {
          getRanksRecursiveIn(sourceKey, key)
        }
      })
    })
  }

  // Mirror of getRanksRecursiveOut: ranks nodes reachable along in-edges
  // from `startKey` to the left of `anchorKey`, then turns around along
  // out-edges.
  const getRanksRecursiveIn = (startKey: string, anchorKey: string) => {
    const path = new Set<string>()

    dfsFromNode(
      graph,
      startKey,
      (key, _, depth) => {
        path.add(key)

        const anchorRank = ranks.get(anchorKey) ?? 0
        const currentRank = anchorRank - depth - 1
        const savedRank = ranks.get(key)

        ranks.set(
          key,
          savedRank === undefined
            ? currentRank
            : Math.min(savedRank, currentRank)
        )
      },
      { mode: 'in' }
    )

    markVisited(path)

    path.forEach((key) => {
      graph.outEdges(key).forEach((outEdgeKey) => {
        const targetKey = graph.target(outEdgeKey)
        if (targetKey !== anchorKey && !visited.has(targetKey)) {
          getRanksRecursiveOut(targetKey, key)
        }
      })
    })
  }

  graph.outEdges(rootKey).forEach((outEdge) => {
    getRanksRecursiveOut(graph.target(outEdge), rootKey)
  })

  graph.inEdges(rootKey).forEach((inEdge) => {
    getRanksRecursiveIn(graph.source(inEdge), rootKey)
  })

  ranks.forEach((rank, key) => {
    graph.setNodeAttribute(key, 'rank', rank)
  })
}

/**
 * Assigns every placed node a vertical `level` (row) and stores it in the
 * node's `level` attribute. It relies on the `rank` attribute written by
 * `assignRanks`.
 *
 * Placement starts at `rootKey`, which is first placed at level 0, and
 * alternates between two mirrored passes:
 * - Outward pass (`getLevelsRecursiveOut`) walks out-edges. A node's level is
 *   at least the largest level among its in-edge sources. It is also at least
 *   one more than the top already recorded for its rank.
 * - Inward pass (`getLevelsRecursiveIn`) walks in-edges. A node's level is at
 *   most the smallest level among its out-edge targets. It is also at most one
 *   less than the bottom already recorded for its rank.
 *
 * After each pass, still-unplaced neighbours on the opposite side are placed
 * by the mirrored pass, seeded with the level of the node they attach to.
 * Right after each node is placed, `adjustBranch` may move neighbouring nodes
 * onto that node's level.
 *
 * Only nodes written into `levels` get their `level` attribute updated; any
 * other node keeps the value it already had.
 *
 * @param graph - graph whose nodes already carry `rank`; mutated in place.
 * @param rootKey - node the placement starts from.
 */
const assignLevels = (
  graph: Graph<ClainTransactionLayoutNodeAttributes>,
  rootKey: string
) => {
  // Level per node key; copied onto the `level` attributes at the end.
  const levels = new Map<string, number>()
  // Per-rank top/bottom levels used so far. New nodes of a rank are stacked
  // one above the top (outward pass) or one below the bottom (inward pass).
  const levelsCapacity = new LevelsCapacity()
  // Nodes already placed by either pass (shared by both directions).
  const path = new Set<string>()

  // Outward pass. Places nodes reachable from `nodeKey` along out-edges, then
  // hands their unplaced in-edge sources to the inward pass. `rootLevel` is
  // the fallback for nodes with no in-edges or with unplaced sources.
  const getLevelsRecursiveOut = (nodeKey: string, rootLevel = 0) => {
    dfsFromNode(
      graph,
      nodeKey,
      (key, { rank }) => {
        // Already placed. NOTE(review): returning true presumably tells
        // dfsFromNode not to descend from this node; confirm in
        // ../utils/dfsFromNode.
        if (path.has(key)) return true

        // Lower bound for this node: the highest level among its in-edge
        // sources (unplaced sources count as rootLevel).
        const inEdges = graph.inEdges(key)
        const prevNodeLevel = !inEdges.length
          ? rootLevel
          : Math.max(
              ...inEdges.map(
                (inEdge) => levels.get(graph.source(inEdge)) ?? rootLevel
              )
            )

        // Despite the name, this is a level: one above the current top of
        // this rank, or prevNodeLevel if nothing is recorded for the rank.
        const currentTopRank =
          levelsCapacity.getTop(rank) !== undefined
            ? levelsCapacity.getTop(rank) + 1
            : prevNodeLevel
        const level = Math.max(prevNodeLevel, currentTopRank)

        levelsCapacity.setTop(rank, level)
        // If no bottom is recorded yet for this rank, this level becomes it.
        if (levelsCapacity.getBottom(rank) === undefined) {
          levelsCapacity.setBottom(rank, level)
        }
        levels.set(key, level)
        path.add(key)

        adjustBranch(key, { mode: 'out' })
      },
      { mode: 'out' }
    )

    // Second walk over the same out-reachable nodes, each handled once. It
    // places their still-unplaced in-edge sources with the inward pass.
    const localPath = new Set<string>()

    dfsFromNode(
      graph,
      nodeKey,
      (key) => {
        if (localPath.has(key)) return true

        const inEdges = graph.inEdges(key)
        // Sorted by edge key for a deterministic order. Sorts in place; this
        // assumes graph.inEdges returns a fresh array.
        inEdges.sort().forEach((inEdge) => {
          const source = graph.source(inEdge)

          if (!path.has(source)) {
            // If `key` has no level, `undefined` selects the default
            // rootLevel of 0 in the callee.
            getLevelsRecursiveIn(source, levels.get(key))
          }
        })

        localPath.add(key)
      },
      { mode: 'out' }
    )
  }

  // Inward pass: mirror of getLevelsRecursiveOut. Walks in-edges, keeps each
  // node at or below its out-edge targets and below its rank's bottom, then
  // hands unplaced out-edge targets to the outward pass.
  const getLevelsRecursiveIn = (nodeKey: string, rootLevel = 0) => {
    dfsFromNode(
      graph,
      nodeKey,
      (key, { rank }) => {
        if (path.has(key)) return true

        // Upper bound for this node: the lowest level among its out-edge
        // targets (unplaced targets count as rootLevel).
        const outEdges = graph.outEdges(key)
        const prevNodeLevel = !outEdges.length
          ? rootLevel
          : Math.min(
              ...outEdges.map(
                (outEdge) => levels.get(graph.target(outEdge)) ?? rootLevel
              )
            )

        // Despite the name, this is a level: one below the current bottom of
        // this rank, or prevNodeLevel if nothing is recorded for the rank.
        const currentBottomRank =
          levelsCapacity.getBottom(rank) !== undefined
            ? levelsCapacity.getBottom(rank) - 1
            : prevNodeLevel
        const level = Math.min(prevNodeLevel, currentBottomRank)

        levelsCapacity.setBottom(rank, level)
        // If no top is recorded yet for this rank, this level becomes it.
        if (levelsCapacity.getTop(rank) === undefined) {
          levelsCapacity.setTop(rank, level)
        }
        levels.set(key, level)
        path.add(key)

        adjustBranch(key, { mode: 'in' })
      },
      { mode: 'in' }
    )

    // Second walk over the same in-reachable nodes, each handled once. It
    // places their still-unplaced out-edge targets with the outward pass.
    const localPath = new Set<string>()

    dfsFromNode(
      graph,
      nodeKey,
      (key) => {
        if (localPath.has(key)) return true

        const outEdges = graph.outEdges(key)
        // Sorted by edge key for a deterministic order (in place).
        outEdges.sort().forEach((outEdge) => {
          const target = graph.target(outEdge)

          if (!path.has(target)) {
            getLevelsRecursiveOut(target, levels.get(key))
          }
        })

        localPath.add(key)
      },
      { mode: 'in' }
    )
  }

  // Called right after `nodeKey` is placed; pulls neighbouring nodes onto
  // `nodeKey`'s level.
  //
  // 'out' mode: consider each in-edge source of `nodeKey`. If none of that
  // source's out-edge targets is on the source's own level, move the source
  // onto `nodeKey`'s level and record that level as its rank's top. Then
  // apply the same check to the source's own sources, recursively.
  //
  // 'in' mode mirrors this over out-edge targets and updates the rank's
  // bottom.
  //
  // Note: levels.get returns undefined for unplaced nodes, and
  // undefined === undefined, so two unplaced nodes count as "same level".
  const adjustBranch = (nodeKey: string, options?: { mode: 'in' | 'out' }) => {
    const mode = options?.mode || 'out'
    const requiredLevel = levels.get(nodeKey)

    if (mode === 'out') {
      // Moves `nodeKey` (here: an upstream node) onto requiredLevel, then
      // continues upstream through its own in-edge sources.
      const adjustBranchRecursive = (nodeKey: string) => {
        const rank = graph.getNodeAttribute(nodeKey, 'rank')
        levelsCapacity.setTop(rank, requiredLevel)
        levels.set(nodeKey, requiredLevel)

        const inEdges = graph.inEdges(nodeKey)
        if (inEdges.length) {
          inEdges.forEach((inEdge) => {
            const source = graph.source(inEdge)
            const sourceOutEdges = graph.outEdges(source)
            if (
              !sourceOutEdges.some((sourceOutEdge) => {
                const sourceOutEdgeTarget = graph.target(sourceOutEdge)
                return levels.get(source) === levels.get(sourceOutEdgeTarget)
              })
            ) {
              adjustBranchRecursive(source)
            }
          })
        }
      }

      // Same check as inside adjustBranchRecursive, applied to the sources
      // of the freshly placed node (which itself is not moved).
      const inEdges = graph.inEdges(nodeKey)
      if (inEdges.length) {
        inEdges.forEach((inEdge) => {
          const source = graph.source(inEdge)
          const sourceOutEdges = graph.outEdges(source)
          if (
            !sourceOutEdges.some((sourceOutEdge) => {
              const sourceOutEdgeTarget = graph.target(sourceOutEdge)
              return levels.get(source) === levels.get(sourceOutEdgeTarget)
            })
          ) {
            adjustBranchRecursive(source)
          }
        })
      }
    }

    if (mode === 'in') {
      // Moves `nodeKey` (here: a downstream node) onto requiredLevel, then
      // continues downstream through its own out-edge targets.
      const adjustBranchRecursive = (nodeKey: string) => {
        const rank = graph.getNodeAttribute(nodeKey, 'rank')
        levelsCapacity.setBottom(rank, requiredLevel)
        levels.set(nodeKey, requiredLevel)

        const outEdges = graph.outEdges(nodeKey)
        if (outEdges.length) {
          outEdges.forEach((outEdge) => {
            const target = graph.target(outEdge)
            const targetInEdges = graph.inEdges(target)
            if (
              !targetInEdges.some((targetInEdge) => {
                const targetInEdgeSource = graph.source(targetInEdge)
                return levels.get(target) === levels.get(targetInEdgeSource)
              })
            ) {
              adjustBranchRecursive(target)
            }
          })
        }
      }

      // Same check as inside adjustBranchRecursive, applied to the targets
      // of the freshly placed node (which itself is not moved).
      const outEdges = graph.outEdges(nodeKey)
      if (outEdges.length) {
        outEdges.forEach((outEdge) => {
          const target = graph.target(outEdge)
          const targetInEdges = graph.inEdges(target)
          if (
            !targetInEdges.some((targetInEdge) => {
              const targetInEdgeSource = graph.source(targetInEdge)
              return levels.get(target) === levels.get(targetInEdgeSource)
            })
          ) {
            adjustBranchRecursive(target)
          }
        })
      }
    }
  }

  // Placement starts from the root with the default rootLevel of 0.
  getLevelsRecursiveOut(rootKey)

  levels.forEach((level, key) => {
    graph.setNodeAttribute(key, 'level', level)
  })
}

/**
 * Per-rank bookkeeping of the "top" and "bottom" level recorded so far.
 *
 * Both getters return `undefined` for a rank that has never been set.
 * Setters overwrite unconditionally. No ordering between top and bottom is
 * enforced here; callers decide what values to store.
 */
class LevelsCapacity {
  private readonly topByRank = new Map<number, number>()
  private readonly bottomByRank = new Map<number, number>()

  /** Last level stored as the top of `rank`, or `undefined` if none. */
  public getTop(rank: number): number | undefined {
    return this.topByRank.get(rank)
  }

  /** Stores `level` as the top of `rank`, replacing any previous value. */
  public setTop(rank: number, level: number): void {
    this.topByRank.set(rank, level)
  }

  /** Last level stored as the bottom of `rank`, or `undefined` if none. */
  public getBottom(rank: number): number | undefined {
    return this.bottomByRank.get(rank)
  }

  /** Stores `level` as the bottom of `rank`, replacing any previous value. */
  public setBottom(rank: number, level: number): void {
    this.bottomByRank.set(rank, level)
  }
}

/**
 * Lays out a transaction graph. Each node's rank becomes its column and its
 * level becomes its row. All coordinates are then translated so that
 * `rootKey` sits at (0, 0).
 *
 * @param options.nodes - node keys to position.
 * @param options.edges - directed `from` -> `to` edges between those nodes.
 * @param options.rootKey - node anchored at the origin; must be in `nodes`.
 * @returns `positions`, mapping every node key to its `{ x, y }` coordinates.
 */
export const customClainTransactionLayout = ({
  nodes,
  edges,
  rootKey,
}: ClainTransactionLayoutOptions): ClainTransactionLayoutReturn => {
  const graph = createGraph(nodes, edges)
  assignRanks(graph, rootKey)
  assignLevels(graph, rootKey)

  // Untranslated coordinates of the root; every node is shifted by these.
  const originX = graph.getNodeAttribute(rootKey, 'rank') * WIDTH_UNIT
  const originY = graph.getNodeAttribute(rootKey, 'level') * HEIGHT_UNIT

  const positions = {}

  graph.forEachNode((key, { rank, level }) => {
    positions[key] = {
      x: rank * WIDTH_UNIT - originX,
      y: level * HEIGHT_UNIT - originY,
    }
  })

  // The graph is only a working structure for this call; empty it.
  graph.clear()

  return { positions }
}
