{-|
== Centroid decomposition

A recursive decomposition (divide-and-conquer) of a tree into multiple subtrees.
This allows certain operations involving paths on the original tree to be performed efficiently,
because every path in the tree is accounted for exactly once: at the topmost decomposed subtree
whose root lies on the path (the path cannot cross any higher root, so it lies entirely inside that
subtree). The roots of the subtrees are chosen to be centroids so that the recursive
decomposition has logarithmic depth.

Sources:

* https://petr-mitrichev.blogspot.com/2015/03/this-week-in-competitive-programming_22.html
* https://github.com/cheran-senthil/PyRival/blob/master/pyrival/graphs/centroid_decomposition.py

-}

{-
Implementation notes:
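* The decomposition is done in the usual manner: starting from the root, repeatedly move to a child
  whose subtree holds more than half of the nodes, rerooting the tree (and a parallel tree of
  subtree sizes) at each step. The node where this stops is the centroid; its subtrees are then
  decomposed recursively.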
* Yes, centroidDecompose and centroidDecomposeL are very similar but pulling out the common parts
  makes it messy, so they remain different functions.
-}

module CentroidDecomp
    ( centroidDecompose
    , centroidDecomposeL
    ) where

import Data.Tree

import LabelledGraph ( LTree(..), lTreeToTree )
import Misc ( farthest )

-- | Performs centroid decomposition on a tree of n nodes, returning the decomposition as a tree of
-- n trees. O(n log n).
--
-- The root of the result is the input tree rerooted at its centroid; the children of the result
-- are the decompositions of the subtrees hanging off that centroid.
centroidDecompose :: Tree a -> Tree (Tree a)
centroidDecompose tree = decompose tree (foldTree sizeNode tree)
  where
    -- Label every node with the number of nodes in the subtree rooted at it. The total is
    -- forced with `seq` whenever the node is forced, so sizes are not left as nested thunks.
    sizeNode _ kids =
        let total = 1 + sum (map rootLabel kids) :: Int
        in total `seq` Node total kids

    -- Takes a tree together with its parallel size tree (whose root label n is the node count),
    -- moves the root to the centroid, then recurses into the centroid's subtrees.
    -- NOTE(review): `farthest` is assumed to keep applying `moveDown` until it yields Nothing.
    decompose (Node root rootKids) (Node n rootSizes) =
        finish (farthest moveDown (root, rootKids, rootSizes))
      where
        finish (c, cKids, cSizes) = Node (Node c cKids) (zipWith decompose cKids cSizes)

        -- A child is heavy when its subtree holds more than half of the n nodes. If the current
        -- root has one, reroot there: the old root becomes a child of the heavy node, keeping its
        -- remaining children, and its new subtree size is n minus the heavy child's size.
        moveDown (u, uKids, uSizes) = reroot <$> removeOne isHeavy uKids uSizes
          where
            isHeavy s = rootLabel s * 2 > n
            reroot (Node v vKids, Node vSize vSizes, uKids', uSizes') =
                let uSize = n - vSize
                in (v, Node u uKids' : vKids, uSize `seq` (Node uSize uSizes' : vSizes))

-- | Same as centroidDecompose, for edge-labelled trees. O(n log n).
--
-- When the tree is rerooted, an edge keeps its label but is reversed: the label that led from a
-- parent to a child now leads from that child back to the parent.
centroidDecomposeL :: LTree b a -> Tree (LTree b a)
centroidDecomposeL ltree = decompose ltree (foldTree sizeNode (lTreeToTree ltree))
  where
    -- Subtree sizes are computed on the unlabelled shape of the tree; the total is forced with
    -- `seq` whenever the node is forced, so sizes are not left as nested thunks.
    sizeNode _ kids =
        let total = 1 + sum (map rootLabel kids) :: Int
        in total `seq` Node total kids

    -- Takes a labelled tree with its parallel size tree (whose root label n is the node count),
    -- moves the root to the centroid, then recurses into the centroid's subtrees.
    -- NOTE(review): `farthest` is assumed to keep applying `moveDown` until it yields Nothing.
    decompose (LNode root rootEdges) (Node n rootSizes) =
        finish (farthest moveDown (root, rootEdges, rootSizes))
      where
        finish (c, cEdges, cSizes) =
            Node (LNode c cEdges) (zipWith (decompose . snd) cEdges cSizes)

        -- Reroot at a child holding more than half of the n nodes, if any. The old root hangs
        -- below the new one through the same edge label, keeps its other edges, and gets size
        -- n minus the heavy child's size.
        moveDown (u, uEdges, uSizes) = reroot <$> removeOne isHeavy uEdges uSizes
          where
            isHeavy s = rootLabel s * 2 > n
            reroot ((lbl, LNode v vEdges), Node vSize vSizes, uEdges', uSizes') =
                let uSize = n - vSize
                in ( v
                   , (lbl, LNode u uEdges') : vEdges
                   , uSize `seq` (Node uSize uSizes' : vSizes)
                   )

-- | @removeOne p xs ys@ walks two parallel lists and removes the first position whose @ys@
-- element satisfies @p@. It returns the removed pair together with both lists minus that
-- position, with the remaining order preserved. Returns 'Nothing' if no element matches.
-- Calls 'error' if the lists run out at different lengths before a match is found.
removeOne :: (b -> Bool) -> [a] -> [b] -> Maybe (a, b, [a], [b])
removeOne p = scan [] []
  where
    -- The skipped prefixes are accumulated in reverse and restored when a match is found.
    scan seenA seenB (x:xs) (y:ys)
        | p y       = Just (x, y, reverse seenA ++ xs, reverse seenB ++ ys)
        | otherwise = scan (x:seenA) (y:seenB) xs ys
    scan _ _ [] [] = Nothing
    scan _ _ _ _ = error "bad input"