{-|
== Lowest common ancestor queries on a tree

Uses an Euler tour and a sparse table for range minimum queries, with a size optimization
(adapted from PyRival) that shrinks the tour stored in the sparse table from about 2n entries to n - 1.

Sources:

* https://en.wikipedia.org/wiki/Lowest_common_ancestor
* Michael Bender and Martin Farach-Colton, "The LCA Problem Revisited", 2000
  https://www.ics.uci.edu/~eppstein/261/BenFar-LCA-00.pdf
* https://github.com/cheran-senthil/PyRival/blob/master/pyrival/graphs/lca.py

-}

{-
Implementation notes:
* For a forest on vertices l..r, vertex l - 1 is used as a dummy root joining all the trees. This turns
  the forest into a single tree, so the tree code can be reused unchanged.
-}

module LCA
    ( LCA
    , buildLCA
    , queryLCA
    , build1LCA
    , query1LCA
    ) where

import Control.DeepSeq
import Data.Array.ST
import Data.Array.Unboxed
import Data.Foldable
import Data.Graph

import SparseTable ( buildSP, foldISP )

-- | Precomputed structure answering LCA queries, built by 'buildLCA' or
-- 'build1LCA'.
data LCA = LCA
    !(UArray (Int, Int) Int) -- ^ sparse table (combined with 'min') over the reduced Euler tour
    !(UArray Vertex Int)     -- ^ preorder time (1-based) of each vertex
    !(UArray Int Vertex)     -- ^ inverse of the previous field: the vertex at each preorder time
    deriving Show

-- | Build a structure for LCA queries on a tree. O(n log n).
--
-- The bounds @(l, r)@ must be non-empty, otherwise this fails with
-- "buildLCA: empty range". NOTE(review): the tree is assumed to contain every
-- vertex in @[l .. r]@ exactly once; this is not checked here.
buildLCA :: Bounds -> Tree Vertex -> LCA
buildLCA (l, r) _ | l > r = error "buildLCA: empty range"
buildLCA (l, r) t = LCA sp time itime where
    n :: Int
    n = r - l + 1

    -- Vertices in preorder: itime ! i is the i-th vertex visited (1-based).
    itime :: UArray Int Vertex
    itime = listArray (1, n) $ toList t

    -- Inverse of itime: the preorder time of each vertex.
    time :: UArray Vertex Int
    time = array (l, r) [(x, i) | (i, x) <- assocs itime]

    -- Reduced Euler tour with n - 1 entries. For each non-root vertex, in
    -- preorder, it holds the preorder time of that vertex's parent.
    -- A vertex's time is emitted once before each of its children, instead
    -- of after every child visit as in a full Euler tour. This is what shrinks
    -- the tour from about 2n entries to n - 1.
    euler :: [Int]
    euler = go t [] where
        go :: Tree Vertex -> [Int] -> [Int]
        go (Node u ts) rest = x `seq` foldr (\c acc -> x : go c acc) rest ts where
            -- Forced eagerly so the tour does not hold array-lookup thunks.
            x = time ! u

    -- Sparse table of range minima over the tour. With a single vertex the
    -- tour is empty and the table is unused, because every query then has
    -- u == v.
    sp :: UArray (Int, Int) Int
    sp | n > 1     = runSTUArray $ buildSP min (1, n - 1) euler
       | otherwise = listArray ((1, 1), (0, 0)) []

-- | Query the LCA of two nodes in a tree. O(1).
--
-- Both nodes must lie within the bounds the structure was built with.
-- Otherwise this fails with "queryLCA: invalid node".
queryLCA :: Vertex -> Vertex -> LCA -> Vertex
queryLCA u v (LCA sp time itime)
    | invalid u || invalid v = error "queryLCA: invalid node"
    | u == v                 = u
    | otherwise              = itime ! foldISP min sp (min fu fv) (max fu fv - 1)
  where
    (l, r) = bounds time
    -- Validate both nodes, so a bad v gets the same error as a bad u
    -- instead of a bare array-index failure.
    invalid w = w < l || r < w
    -- Preorder times of the two nodes. The LCA is the vertex with the smallest
    -- preorder time among the parents of the vertices whose times lie in
    -- (min fu fv, max fu fv]. The sparse table stores those parents at
    -- indices [min fu fv, max fu fv - 1].
    fu = time ! u
    fv = time ! v

-- | Build a structure for LCA queries on a forest. O(n log n).
--
-- All trees are hung under a dummy root @l - 1@, and the resulting single
-- tree is passed to 'buildLCA' with bounds @(l - 1, r)@. Query the result
-- with 'query1LCA', which reports the dummy root as 'Nothing'.
build1LCA :: Bounds -> [Tree Vertex] -> LCA
build1LCA (l, r) ts = buildLCA (l - 1, r) (Node (l - 1) ts)

-- | Query the LCA of two nodes in a forest. O(1).
--
-- The structure's lower bound is the dummy root added by 'build1LCA'.
-- Returns 'Nothing' when the LCA is that dummy root: either the nodes are in
-- different trees, or one of them is the dummy root itself.
query1LCA :: Vertex -> Vertex -> LCA -> Maybe Vertex
query1LCA u v lca@(LCA _ time _)
    | x == root = Nothing
    | otherwise = Just x
  where
    (root, _) = bounds time
    x = queryLCA u v lca

--------------------------------------------------------------------------------
-- For tests

-- | Every field of 'LCA' is a strict, unboxed array. Once the constructor is
-- in weak head normal form, all the data is therefore already fully
-- evaluated, and 'rwhnf' is enough.
instance NFData LCA where
    rnf = rwhnf