module SegTree
( SegTree
, emptyST
, fromListST
, adjustST
, foldRangeST
) where
import Control.DeepSeq
import Control.Monad.State
import Data.Bits
import Misc ( bitLength )
-- | A segment tree of monoidal values over a contiguous, inclusive range of
-- 'Int' indices.
--
-- The first field is @(l, r, p)@: the inclusive index bounds @l..r@ and the
-- leaf count @p@ of the underlying complete binary tree (a power of two, or
-- 0 for an empty range). The root spans positions @l .. l + p - 1@; leaves
-- past @r@ are padding.
data SegTree a
  = SegTree
      !(Int, Int, Int) -- ^ @(l, r, p)@: index bounds and leaf count
      !(SegNode a)     -- ^ root node
  deriving (Show)
-- | A node of a segment tree. An inner node caches the left-to-right '<>'
-- of its children's cached values; a leaf holds a single position's value.
data SegNode a
  = SLeaf !a                          -- ^ value at a single position
  | SBin !a !(SegNode a) !(SegNode a) -- ^ cached aggregate, left child, right child
  deriving (Show)
-- | Allocate a tree covering the inclusive index range @(l, r)@.
--
-- The callback receives a height @h@ and is expected to return a complete
-- node of that height; the tree records @2^h@ as its leaf count.
-- An empty range (@r == l - 1@) yields a single 'mempty' leaf with leaf
-- count 0, and any range with @r < l - 1@ is rejected with 'error'.
--
-- NOTE(review): assumes 'bitLength' returns the number of significant bits
-- of its argument, so that @2^h >= r - l + 1@ — confirm against "Misc".
buildST :: Monoid a => (Int, Int) -> (Int -> SegNode a) -> SegTree a
buildST (l, r) mkRoot =
  case compare lastOffset (-1) of
    LT -> error "invalid range"
    EQ -> SegTree (l, r, 0) (SLeaf mempty)
    GT ->
      let height = bitLength lastOffset
      in  SegTree (l, r, bit height) (mkRoot height)
  where
    -- Offset of the last index from the first, i.e. element count minus one.
    lastOffset = r - l
-- | A tree over the inclusive range @(l, r)@ in which every position holds
-- 'mempty'.
--
-- Both children of each inner node are the same shared subtree, so only one
-- node per level is allocated.
emptyST :: Monoid a => (Int, Int) -> SegTree a
emptyST bnds = buildST bnds perfect
  where
    perfect 0 = SLeaf mempty
    perfect h =
      let child = perfect (h - 1)
      in  SBin mempty child child
-- | Join two children into an inner node whose cached value is the left
-- child's cached value '<>' the right child's.
makeSN :: Monoid a => SegNode a -> SegNode a -> SegNode a
makeSN lt rt = SBin (cached lt <> cached rt) lt rt
  where
    cached node = case node of
      SLeaf v    -> v
      SBin v _ _ -> v
-- | Build a tree over the inclusive index range @(l, r)@, filling positions
-- @l, l+1, ..@ from the list in order.
--
-- If the list is shorter than the range, the remaining positions hold
-- 'mempty'. Elements beyond the first @r - l + 1@ are ignored. A range with
-- @r < l - 1@ is rejected with 'error'.
fromListST :: Monoid a => (Int, Int) -> [a] -> SegTree a
fromListST bnds@(l, r) xs = buildST bnds (\h -> evalState (build h) inRange)
  where
    -- The tree may have more leaves than the range, because it is padded up
    -- to a power of two. Truncating keeps those padding leaves at 'mempty'.
    -- Previously, surplus elements landed in the padding: 'foldr' and
    -- 'adjustST' could not see them, yet they were folded into every
    -- ancestor's cached value.
    inRange = take (r - l + 1) xs

    -- Next leaf value, or 'mempty' once the input is exhausted.
    pop = state next
    next [] = (mempty, [])
    next (y : ys) = (y, ys)

    -- A complete subtree of height h, consuming leaf values left to right.
    build 0 = SLeaf <$> pop
    build h = makeSN <$> build (h - 1) <*> build (h - 1)
-- | Apply a function to the element at index @i@, recomputing the cached
-- aggregates on the path from that leaf up to the root.
--
-- Calls 'error' if @i@ lies outside the tree's inclusive index range.
adjustST :: Monoid a => (a -> a) -> Int -> SegTree a -> SegTree a
adjustST f i (SegTree meta@(lo, hi, leaves) root)
  | lo <= i && i <= hi = SegTree meta (descend root lo (lo + leaves - 1))
  | otherwise          = error "adjustST: outside range"
  where
    -- Rebuild only the nodes whose span [from, to] contains i. Every other
    -- subtree is returned as-is and thus shared with the input tree.
    descend node from to
      | i < from || to < i = node
      | otherwise = case node of
          SLeaf x -> SLeaf (f x)
          SBin _ left right ->
            let mid = (from + to) `div` 2
            in  makeSN (descend left from mid) (descend right (mid + 1) to)
-- | Combine ('<>') the elements at indices @ql..qr@ (inclusive), in index
-- order.
--
-- The query is clipped to the tree's index range, so indices outside it
-- contribute nothing. @ql == qr + 1@ denotes the empty query and yields
-- 'mempty'. Any @ql > qr + 1@ is rejected with 'error'.
foldRangeST :: Monoid a => Int -> Int -> SegTree a -> a
foldRangeST ql qr (SegTree (l0, r0, p) root)
  -- Equivalent to ql > qr + 1, but qr + 1 is only evaluated once qr < ql
  -- (hence qr < maxBound). The naive form overflowed for qr == maxBound and
  -- rejected valid "from ql to the end" queries.
  | ql > qr && ql /= qr + 1 = error "foldRangeST: bad range"
  | otherwise = go root l0 (l0 + p - 1) mempty
  where
    -- Clip the query to the populated range so that padding leaves past r0
    -- are never read, whatever they contain.
    lo = max ql l0
    hi = min qr r0

    -- go node l r acc: fold the part of [lo, hi] inside this node's span
    -- [l, r] onto acc. The left subtree is forced before the right one is
    -- visited, so the accumulator does not build up thunks.
    go _ l r acc | r < lo || hi < l = acc
    go (SLeaf x) _ _ acc = acc <> x
    go (SBin x lt rt) l r acc
      | lo <= l && r <= hi = acc <> x
      | otherwise = go rt (m + 1) r $! go lt l m acc
      where m = (l + r) `div` 2
-- | Folds over the elements at indices @l..r@ in index order. Padding
-- leaves past @r@ are skipped.
instance Foldable SegTree where
  foldr f z (SegTree (lo, hi, leaves) root) = walk root lo (lo + leaves - 1) z
    where
      -- walk node from to: prepend this node's in-range elements (span
      -- [from, to]) to the rest of the fold.
      walk node from to
        | from > hi = id
        | otherwise = case node of
            SLeaf x -> f x
            SBin _ left right ->
              let mid = (from + to) `div` 2
              in  walk left from mid . walk right (mid + 1) to
-- Keep the unfoldings of these class-polymorphic exports available, so that
-- importing modules can specialise them to a concrete monoid.
{-# INLINABLE adjustST #-}
{-# INLINABLE foldRangeST #-}
{-# INLINABLE fromListST #-}
-- | Fully evaluates the bounds triple and then every node of the tree.
instance NFData a => NFData (SegTree a) where
  rnf (SegTree meta node) = meta `deepseq` rnf node
-- | Fully evaluates a node's value, then its left and right subtrees.
instance NFData a => NFData (SegNode a) where
  rnf node = case node of
    SLeaf x           -> rnf x
    SBin x left right -> x `deepseq` (left `deepseq` rnf right)