{-# LANGUAGE BangPatterns #-}
module RerootFold
( foldReroot
) where
import Data.Tree
import Data.List
import Misc ( foldExclusive )
-- | Rerooting fold: annotate every node with the result of a bottom-up fold
-- computed as if the tree were rooted at that node.
--
-- For a node with label @x@ whose neighbour results are @rs@, the node's
-- result is @f x (fold of g over rs, starting from y0)@.
--
-- * At the root, the neighbours are just the children. The value equals the
--   plain bottom-up fold @f x (foldl' g y0 childResults)@ computed by @go1@.
-- * At any other node, the neighbours also include the "parent side" of the
--   tree. That side's value (@up@) is @f parentLabel@ applied to the
--   parent's fold of all its other neighbours. It is pushed down by @go2@.
--
-- NOTE(review): this assumes @foldExclusive g y0 xs@ returns, for each
-- position @i@, the @g@-fold of @y0@ over @xs@ with element @i@ left out,
-- and has the same length as @xs@. That is inferred from how it is used here
-- (it is defined in "Misc"), so confirm against that module.
--
-- At non-root nodes the children's results are folded first and @up@ is
-- combined last (@g y up@). If @g@ is order-sensitive, the per-node result
-- may therefore differ from a literal re-rooted recomputation.
foldReroot
  :: (a -> b -> c)  -- ^ @f@: combine a node label with its folded neighbour results
  -> (b -> c -> b)  -- ^ @g@: accumulate one neighbour result into the accumulator
  -> b              -- ^ @y0@: initial accumulator for every fold
  -> Tree a         -- ^ input tree
  -> Tree (a, c)    -- ^ same shape; each label paired with its rerooted result
foldReroot f g y0 t = res
  where
    -- Root: its bottom-up result is already final. Each child @i@ receives
    -- @f x (fold of every other child)@ as its parent-side value.
    res = Node (x, z) (zipWith (go2 . f x) ys ts)
      where
        !(z, Node (x, zs) ts) = go1 t
        ys = foldExclusive g y0 zs

    -- Bottom-up pass: compute each subtree's result, and keep the list of
    -- child results at every node so the top-down pass can reuse them.
    go1 (Node x ts) = (z, Node (x, zs) ts')
      where
        (zs, ts') = unzip (map go1 ts)
        !z = f x (foldl' g y0 zs)

    -- Top-down pass: @up@ is the result of the parent side of the tree as
    -- seen from this node.
    go2 !up (Node (x, zs) ts) = Node (x, z) (zipWith (go2 . f x) ys ts)
      where
        -- The first exclusive fold omits @up@ (it covers only the children).
        -- The rest omit one child each (they include @up@), so they are the
        -- fold each child needs for its own parent side.
        -- Lazy tuple binding keeps the original laziness while making the
        -- (never expected) empty-result case an explicit, informative error.
        (y, ys) = case foldExclusive g y0 (up : zs) of
          y' : ys' -> (y', ys')
          [] -> error "RerootFold.foldReroot: foldExclusive returned [] for a non-empty list"
        !z = f x (g y up)