{-# LANGUAGE BangPatterns #-}
{-|
== Tree reroot fold

Folds of a tree with every node as root.
Known as "tree rerooting DP", among other names.

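Recall that "Data.Tree" has @foldTree :: (a -> [c] -> c) -> Tree a -> c@. 'foldReroot' is similar,
but instead of passing each node the list @[c]@ of its children's results, it folds that list strictly
with @g :: b -> c -> b@ starting from @y0 :: b@, and passes the resulting @b@ to @f :: a -> b -> c@.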

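@g :: b -> c -> b@ must be commutative, in the sense that @g (g b c1) c2 = g (g b c2) c1@ for all
@b@, @c1@, @c2@, since the order in which a node's neighbours are folded depends on which vertex is
the root.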

Sources:

* pajenegod, "The Ultimate Reroot Template"
  https://codeforces.com/blog/entry/124286

-}

{-
Implementation notes:
* Thanks to Haskell's laziness it is possible to write a more concise implementation with a single
  dfs, but this seems to be slower in practice. Here is a possible implementation:
  https://gist.github.com/meooow25/6460e45327355106cedbcf3bd5166cd6
* Reroot problems can often be solved in O(n) time if there is a way to *take away* the
  contribution of a node from an accumulated value. This implementation does not require such an
  inverse; it takes O(n log n) time instead, but is far simpler.
-}

module RerootFold
    ( foldReroot
    ) where

import Data.Tree
import Data.List

import Misc ( foldExclusive )

-- | Returns the same tree with each vertex paired with the fold of the whole tree computed as if
-- that vertex were the root.
--
-- The fold works like 'foldTree', except that a vertex with label @x@ does not receive the list of
-- its children's results directly. The results are first combined with a strict left fold using
-- @g@, starting from @y0@, and the vertex's result is @f x acc@ for the combined value @acc@. When
-- a vertex is made the root, its former parent counts as one more child.
--
-- @g@ must be commutative in the sense @g (g b c1) c2 = g (g b c2) c1@, since the order in which
-- a vertex's neighbours are folded depends on which vertex is the root.
--
-- @f@ is called O(n) times. @g@ is called O(n log n) times.
--
-- The computation makes two passes over the tree:
--
-- * A bottom-up pass (@go1@) computes every subtree's result with respect to the original root.
-- * A top-down pass (@go2@) pushes down, to each vertex, the result of the part of the tree above
--   it.
foldReroot :: (a -> b -> c) -> (b -> c -> b) -> b -> Tree a -> Tree (a, c)
foldReroot f g y0 t = res where
    -- Answer for the original root, which has no parent. Each child receives, as the result of
    -- the part of the tree above it, f x applied to the fold of its siblings' results.
    res = Node (x,z) (zipWith (go2 . f x) ys ts) where
        !(z, Node (x,zs) ts) = go1 t
        -- Assumed (from its name and how it is used here): one value per element of zs, each
        -- being the fold of all the other elements. TODO confirm against Misc.
        ys = foldExclusive g y0 zs

    -- Bottom-up pass. Returns the result of a subtree taken at its own root, together with the
    -- same subtree in which each label is paired with the list of its children's results.
    go1 (Node x ts) = (z, Node (x,zs) ts') where
        (zs, ts') = unzip (map go1 ts)
        !z = f x (foldl' g y0 zs)

    -- Top-down pass. up is the result of the part of the tree above this vertex, viewed as one
    -- more subtree hanging off it. It is placed first in the list so that:
    --   * the first exclusive fold leaves out up, i.e. covers exactly the children;
    --   * the remaining folds are one per child, each including up.
    go2 !up (Node (x,zs) ts) = case foldExclusive g y0 (up:zs) of
        y:ys ->
            let !z = f x (g y up)
            in Node (x,z) (zipWith (go2 . f x) ys ts)
        -- foldExclusive of a non-empty list is expected to be non-empty. Matching explicitly
        -- replaces the former partial pattern binding (-Wincomplete-uni-patterns).
        [] -> error "RerootFold.foldReroot: impossible, foldExclusive returned an empty list"