{-# LANGUAGE BangPatterns, ScopedTypeVariables #-}
{-|
== Heavy-light decomposition

Decomposition of a tree of size n into vertex-disjoint paths, such that the path between any pair of
vertices passes through only O(log n) of them.

HLD is useful for problems that have queries or updates on the path between two vertices.

This implementation maps each node to an integer position in [1..n] such that nodes in a path have
consecutive positions. A path between two vertices decomposes into O(log n) ranges of consecutive
integers, possibly simplifying the problem.

Sources:

* https://en.wikipedia.org/wiki/Heavy_path_decomposition
* https://cp-algorithms.com/graph/hld.html

-}

{-
Implementation notes:
* The size array need not be stored if subtree queries are not required
-}

module HLD
    ( HLD(..)
    , buildHLD
    , posHLD
    , pathHLD
    , edgePathHLD
    , subtreeHLD
    , lcaHLD
    ) where

import Control.DeepSeq
import Control.Monad
import Control.Monad.ST
import Data.Array.ST
import Data.Array.Unboxed
import Data.Array.Unsafe
import Data.Graph
import Data.Ord
import Data.Tree
import GHC.Exts

import Misc ( maximumByMaybe )

-- | The heavy-light decomposition of a rooted tree, as produced by 'buildHLD'.
-- Every array is indexed by vertex over the bounds given to 'buildHLD'.
data HLD = HLD
    { par_ :: !(UArray Vertex Vertex)  -- ^ Parent of each vertex; the root is its own parent.
    , dep_ :: !(UArray Vertex Int)     -- ^ Depth of each vertex; the root has depth 0.
    , hed_ :: !(UArray Vertex Vertex)  -- ^ Topmost vertex of the heavy path containing each vertex.
    , pos_ :: !(UArray Vertex Int)     -- ^ Position of each vertex in [1..n].
    , siz_ :: !(UArray Vertex Int)     -- ^ Number of vertices in the subtree rooted at each vertex.
    }

-- | Builds the HLD structure from a tree. O(n).
--
-- @bnds@ is used to allocate every array, so all vertices of the tree must lie within it.
-- The root is recorded as its own parent, at depth 0, and receives position 1.
buildHLD :: Bounds -> Tree Vertex -> HLD
buildHLD bnds t = HLD par dep hed pos sz where
    (par, dep, sz) = runST sizing
    (hed, pos) = runST layout

    -- Pass 1: parent, depth and subtree size of every vertex.
    sizing :: forall s. ST s (UArray Vertex Vertex, UArray Vertex Int, UArray Vertex Int)
    sizing = do
        pa <- newArray_ bnds :: ST s (STUArray s Vertex Vertex)
        da <- newArray_ bnds :: ST s (STUArray s Vertex Int)
        sa <- newArray_ bnds :: ST s (STUArray s Vertex Int)
        -- go node parent depth: fills the arrays for the subtree and returns its size.
        let go :: Tree Vertex -> Vertex -> Int -> ST s Int
            go (Node u ts) p d = do
                writeArray pa u p
                writeArray da u d
                -- Strict accumulator so the children's sizes don't pile up as thunks.
                sm <- (1+) <$> foldM (\ !acc t1 -> (acc+) <$> go t1 u (d+1)) 0 ts
                sm <$ writeArray sa u sm
        _ <- go t (rootLabel t) 0
        -- The mutable arrays are never touched again, so freezing in place is safe.
        (,,) <$> unsafeFreeze pa <*> unsafeFreeze da <*> unsafeFreeze sa

    -- Pass 2: chain head and position of every vertex. The child with the largest subtree
    -- (the heavy child) is visited first, so it continues its parent's chain at the next
    -- position. Every other child starts a new chain headed by itself.
    layout :: forall s. ST s (UArray Vertex Vertex, UArray Vertex Int)
    layout = do
        ha <- newArray_ bnds :: ST s (STUArray s Vertex Vertex)
        xa <- newArray_ bnds :: ST s (STUArray s Vertex Int)
        -- go node head lastPos: numbers the subtree starting at lastPos+1 and returns the
        -- last position it used.
        let go :: Tree Vertex -> Vertex -> Int -> ST s Int
            go (Node u ts) h x = do
                writeArray ha u h
                writeArray xa u (x+1)
                case maximumByMaybe (comparing ((sz!) . rootLabel)) ts of
                    Nothing -> pure (x+1)
                    Just heavy -> do
                        let lights = filter ((/= rootLabel heavy) . rootLabel) ts
                        x' <- go heavy h (x+1)
                        foldM (\x1 t1 -> go t1 (rootLabel t1) x1) x' lights
        _ <- go t (rootLabel t) 0
        (,) <$> unsafeFreeze ha <*> unsafeFreeze xa

-- | The position for the given vertex, a value in [1..n]. O(1).
-- Calls 'error' (via '!') if the vertex is outside the bounds given to 'buildHLD'.
posHLD :: HLD -> Vertex -> Int
posHLD (HLD _ _ _ pos _) = (pos!)

-- | Shared implementation of 'pathHLD' and 'edgePathHLD'.
--
-- Walks up heavy paths from both endpoints and emits one position range per chain segment.
-- When @keepLca@ is False, the LCA's position is dropped from the final range.
-- Each range (l, r) satisfies l <= r. The list is built with 'build' so it can fuse with its
-- consumer.
pathHLD_ :: Bool -> HLD -> Vertex -> Vertex -> [(Int, Int)]
pathHLD_ keepLca (HLD par dep hed pos _) u0 v0 = build $ \c n ->
    let go u v
            -- Keep v as the vertex whose chain head is at least as deep as u's.
            | dep!hu > dep!hv = go v u
            -- Different chains: hv cannot be an ancestor of u, so v's whole segment
            -- [hv..v] is on the path. Emit it and jump above the head.
            | hu /= hv =
                let !xhv = pos!hv
                    !xv = pos!v
                in (xhv, xv) `c` go u (par!hv)
            -- Same chain: the vertex with the smaller position is the LCA.
            | otherwise =
                let !xu = min (pos!u) (pos!v)
                    !xv = max (pos!u) (pos!v)
                in if keepLca then (xu, xv) `c` n
                   else if xu == xv then n
                   else (xu + 1, xv) `c` n
          where
            !hu = hed!u
            !hv = hed!v
    in go u0 v0
{-# INLINE pathHLD_ #-}

-- | A list of position ranges which make up the path from u to v, both endpoints and the LCA
-- included. Each range (l, r) has l <= r. The ranges are not sorted. O(log n).
pathHLD :: HLD -> Vertex -> Vertex -> [(Int, Int)]
pathHLD = pathHLD_ True
{-# INLINE pathHLD #-}

-- | Like 'pathHLD', but excludes the LCA. Useful when working with edges: each edge can be mapped
-- to the node it leads down into, so the edges of the path are exactly the non-LCA nodes on it.
-- Returns an empty list when u == v. O(log n).
edgePathHLD :: HLD -> Vertex -> Vertex -> [(Int, Int)]
edgePathHLD = pathHLD_ False
{-# INLINE edgePathHLD #-}

-- | A position range covering the subtree of the given node. O(1).
--
-- 'buildHLD' numbers a vertex's whole subtree consecutively, starting at the vertex itself.
-- The range is therefore (pos u, pos u + size u - 1).
subtreeHLD :: HLD -> Vertex -> (Int, Int)
subtreeHLD (HLD _ _ _ pos sz) u = (pos!u, pos!u + sz!u - 1)

-- | The lowest common ancestor of two vertices. O(log n).
--
-- Repeatedly lifts the vertex whose chain head is deeper to just above that head, until both
-- vertices lie on the same heavy path. On that path, the one with the smaller position is the
-- ancestor.
lcaHLD :: HLD -> Vertex -> Vertex -> Vertex
lcaHLD (HLD par dep hed pos _) = go where
    go :: Vertex -> Vertex -> Vertex
    go u v
        | dep!hu > dep!hv = go v u
        | hu /= hv        = go u (par!hv)
        | otherwise       = if pos!u < pos!v then u else v
      where
        !hu = hed!u
        !hv = hed!v

--------------------------------------------------------------------------------
-- For tests

-- | Every field of 'HLD' is a strict, unboxed array. Evaluating the record to WHNF therefore
-- evaluates it completely, and 'rwhnf' is a sufficient 'rnf'.
instance NFData HLD where
    rnf = rwhnf