package de.geomobile.common.utils

import kotlinx.serialization.KSerializer
import kotlinx.serialization.Serializable


/**
 * An immutable forest: an ordered list of root [Node]s, each of which owns an ordered list of children.
 *
 * All operations are recursive and return new trees; the receiver is never modified.
 *
 * @property rootNodes the top-level nodes of this tree, in order.
 */
@Serializable
data class Tree<T>(
    val rootNodes: List<Node<T>>
) {

    /**
     * A single tree node holding a [value] and its direct [children].
     *
     * @property value the payload of this node.
     * @property children the direct children of this node; empty for a leaf.
     */
    @Serializable
    data class Node<T>(
        val value: T,
        val children: List<Node<T>> = emptyList()
    ) {
        /** The [children] of this node wrapped in a [Tree]. A new [Tree] is created on every access. */
        val subTree: Tree<T> get() = Tree(children)
    }

    /**
     * Returns `true` if [predicate] holds for at least one value in the tree.
     * Each node is tested before its descendants. An empty tree yields `false`.
     */
    fun any(predicate: (T) -> Boolean): Boolean = rootNodes.any { node ->
        predicate(node.value) || node.subTree.any(predicate)
    }

    /**
     * Returns `true` if [predicate] holds for every value in the tree.
     * Each node is tested before its descendants. An empty tree yields `true`.
     */
    fun all(predicate: (T) -> Boolean): Boolean = rootNodes.all { node ->
        predicate(node.value) && node.subTree.all(predicate)
    }

    /**
     * Builds a tree of the same shape whose values are computed from each node together with its branch.
     *
     * @param transform receives the branch (all nodes from the top-level root down to and including
     * the selected node) and the selected node itself.
     */
    fun <R> mapByBranch(transform: (branch: List<Node<T>>, selected: Node<T>) -> R): Tree<R> = Tree(
        rootNodes.map { node ->
            val localBranch = listOf(node)
            Node(
                value = transform(localBranch, node),
                // Prefix the branch reported by the subtree with this node to rebuild the full path.
                children = node.subTree.mapByBranch { branch, selected ->
                    transform(localBranch + branch, selected)
                }.rootNodes
            )
        }
    )

    /**
     * Builds a tree of the same shape whose values are computed from each node and its direct parent.
     *
     * @param transform receives the direct parent of the selected node (`null` for the root nodes of
     * this tree) and the selected node itself.
     */
    fun <R> map(transform: (parent: Node<T>?, selected: Node<T>) -> R): Tree<R> = Tree(
        rootNodes.map { node ->
            Node(
                value = transform(null, node),
                // Inside the subtree the direct children of `node` are reported with a null parent,
                // so substitute `node` there; deeper nodes already carry their real parent.
                children = node.subTree
                    .map { parent, selected -> transform(parent ?: node, selected) }
                    .rootNodes
            )
        }
    )

    /** Builds a tree of the same shape with every value replaced by the result of [transform]. */
    fun <R> mapValues(transform: (T) -> R): Tree<R> = Tree(
        rootNodes.map { node ->
            Node(
                value = transform(node.value),
                children = node.subTree.mapValues(transform).rootNodes
            )
        }
    )

    /**
     * Returns the first node whose value matches [predicate], or `null` if there is none.
     * All nodes of the current level are checked before descending into the subtrees.
     */
    fun findNode(predicate: (T) -> Boolean): Node<T>? =
        rootNodes.firstOrNull { predicate(it.value) }
            ?: rootNodes.asSequence().mapNotNull { it.subTree.findNode(predicate) }.firstOrNull()

    /**
     * Returns every node whose value matches [predicate].
     * Matches of the current level come first, followed by the matches of each subtree in order.
     */
    fun findNodes(predicate: (T) -> Boolean): List<Node<T>> =
        rootNodes.filter { predicate(it.value) } + rootNodes.flatMap { it.subTree.findNodes(predicate) }

    /**
     * Returns the values on the path from a root node down to the first node matching [predicate]
     * (inclusive), or an empty list if nothing matches.
     * All nodes of the current level are checked before descending into the subtrees.
     */
    fun findBranch(predicate: (T) -> Boolean): List<T> {
        val match = rootNodes.firstOrNull { predicate(it.value) }
        if (match != null) return listOf(match.value)

        for (node in rootNodes) {
            val subMatch = node.subTree.findBranch(predicate)
            if (subMatch.isNotEmpty()) return listOf(node.value) + subMatch
        }
        return emptyList()
    }

    /**
     * Returns the value of the parent of the first node matching [predicate].
     *
     * The result is `null` both when the match is a root node of this tree and when nothing
     * matches at all; the two cases cannot be told apart from the return value.
     */
    fun findParent(predicate: (T) -> Boolean): T? = findParent(null, predicate)

    /** Recursive worker for [findParent]; [parent] is the value of the node owning [rootNodes]. */
    private fun findParent(parent: T?, predicate: (T) -> Boolean): T? {
        if (rootNodes.any { predicate(it.value) }) return parent

        return rootNodes.asSequence()
            .map { node -> node.subTree.findParent(node.value, predicate) }
            .firstOrNull { it != null }
    }

    /**
     * Returns the values of all nodes on the level where the first match of [predicate] is found
     * (the matching value included), or `null` if nothing matches.
     */
    fun findSiblings(predicate: (T) -> Boolean): List<T>? {
        val values = rootNodes.map { it.value }
        if (values.any(predicate)) return values

        return rootNodes.asSequence().mapNotNull { it.subTree.findSiblings(predicate) }.firstOrNull()
    }

    /**
     * Returns a tree containing only the nodes whose value matches [predicate].
     * A node that does not match is dropped together with its entire subtree.
     */
    fun filter(predicate: (T) -> Boolean): Tree<T> = Tree(
        rootNodes = rootNodes
            .filter { predicate(it.value) }
            .map { it.copy(children = it.subTree.filter(predicate).rootNodes) }
    )

    /**
     * Replaces the sibling list of every level that contains a value accepted by [select] with the
     * result of [transform].
     *
     * The recursion continues into the children of the resulting nodes, so nodes introduced by
     * [transform] are processed as well.
     */
    fun transformSiblings(
        select: (T) -> Boolean,
        transform: (siblings: List<Node<T>>) -> List<Node<T>>
    ): Tree<T> {
        val newRootNodes =
            if (rootNodes.any { select(it.value) }) transform(rootNodes)
            else rootNodes

        return Tree(
            rootNodes = newRootNodes.map {
                it.copy(children = it.subTree.transformSiblings(select, transform).rootNodes)
            }
        )
    }

    /**
     * Appends [newNode] as the last child of every parent accepted by [selectParent].
     *
     * [selectParent] is called with `null` for the top level of this tree; if it accepts `null`,
     * [newNode] becomes a new root node. The descent stops at an accepted parent, so nodes below
     * it are not examined.
     */
    fun plusChild(selectParent: (T?) -> Boolean, newNode: Node<T>): Tree<T> =
        plusChild(null, selectParent, newNode)

    /** Recursive worker for [plusChild]; [parent] is the value of the node owning [rootNodes]. */
    private fun plusChild(parent: T?, selectParent: (T?) -> Boolean, newNode: Node<T>): Tree<T> =
        if (selectParent(parent)) {
            Tree(rootNodes.plus(newNode))
        } else {
            Tree(
                rootNodes.map {
                    it.copy(children = it.subTree.plusChild(it.value, selectParent, newNode).rootNodes)
                }
            )
        }

    /**
     * Inserts [newNode] directly after the first node of a level whose value is accepted by [select].
     *
     * If the current level has no match, every subtree is searched independently, so [newNode]
     * may be inserted in more than one place.
     */
    fun plusSibling(select: (T) -> Boolean, newNode: Node<T>): Tree<T> {
        val index = rootNodes.indexOfFirst { select(it.value) }

        return Tree(
            rootNodes = if (index >= 0) {
                rootNodes.toMutableList().apply { add(index + 1, newNode) }
            } else {
                rootNodes.map {
                    it.copy(children = it.subTree.plusSibling(select, newNode).rootNodes)
                }
            }
        )
    }
}

/** Creates a [Tree] that has no root nodes. */
fun <T> emptyTree(): Tree<T> {
    return Tree(rootNodes = emptyList())
}

/**
 * A serializer for [Tree]s whose values are handled by this serializer,
 * obtained by passing the receiver to [Tree.serializer].
 */
val <T : Any> KSerializer<T>.tree: KSerializer<Tree<T>>
    get() {
        val valueSerializer = this
        return Tree.serializer(valueSerializer)
    }