{-|
== Segment tree

A data structure supporting point updates and range queries on a sequence of monoids.
This implementation supports large ranges that do not fit in memory, sometimes called a dynamic or
sparse segment tree.

Sources:

* https://cp-algorithms.com/data_structures/segment_tree.html

The complexities below assume mappend takes O(1) time.
Let n = r - l + 1 in all instances below.

SegTree implements Foldable. Folding over all the elements takes O(n).
-}

{-
Implementation notes:
* The segment tree is implemented as a complete binary tree. This uses more memory when the range
  size is away from the next power of 2, and the full tree is actually constructed, but the symmetry
  allows for an easy O(log n) construction of an empty tree.
* It is possible to build an empty segment tree in O(1) instead of O(log n). This can be done with
  special nodes that indicate that they can be grown, but this complicates the implementation.
  A lazy spine may also give an equivalent result, but structures with strict spines perform better
  in general. The O(log n) implementation seems like a good compromise to me.
* fromListST can be made to run in something like O(length xs + log n) instead of O(n) by falling
  back to the empty tree strategy when the list is exhausted. This will also reduce the memory usage
  mentioned in point 1, but I don't think it is worth the hassle.
* If the range is larger than 32-bits and this is to be run on a 32-bit system, you can replace Int
  with Int64.
-}

module SegTree
    ( SegTree
    , emptyST
    , fromListST
    , adjustST
    , foldRangeST
    ) where

import Control.DeepSeq
import Control.Monad.State
import Data.Bits

import Misc ( bitLength )

-- | A segment tree over an inclusive index range.
--
-- The strict triple is @(l, r, p)@: the inclusive bounds the tree was built on, and the number of
-- leaf slots in the underlying complete binary tree (0 when the range is empty). The root node
-- therefore covers the indices @l .. l + p - 1@.
data SegTree a
    = SegTree !(Int, Int, Int) !(SegNode a)
    deriving Show
-- | A node of the segment tree.
--
-- * 'SLeaf' holds a single element.
-- * 'SBin' holds a cached value for the whole subtree, followed by the left and right children.
--
-- All fields are strict.
data SegNode a
    = SLeaf !a
    | SBin !a !(SegNode a) !(SegNode a)
    deriving Show

-- Shared constructor used by emptyST and fromListST.
--
-- Given the inclusive bounds (l, r) and a function that builds a root node from a tree height:
-- * r - l < -1  : the range is rejected with an error;
-- * r - l == -1 : the range is empty; the tree gets size 0 and a single mempty leaf as root;
-- * otherwise   : the tree gets size 2^height and the root produced by the builder.
--
-- The height is bitLength (r - l). bitLength is imported from Misc; presumably it is the number of
-- bits needed to represent its argument, which would make 2^height >= r - l + 1 -- confirm there.
buildST :: Monoid a => (Int, Int) -> (Int -> SegNode a) -> SegTree a
buildST (l, r) mkRoot = case compare width (-1) of
    LT -> error "invalid range"
    EQ -> SegTree (l, r, 0) (SLeaf mempty)
    GT -> SegTree (l, r, bit height) (mkRoot height)
  where
    width  = r - l
    height = bitLength width

-- | Builds a segment tree on the range (l, r) in which every element is 'mempty'. O(log n).
--
-- Both children of every internal node are the same shared subtree, so only one node per level
-- is allocated.
emptyST :: Monoid a => (Int, Int) -> SegTree a
emptyST bnds = buildST bnds perfect
  where
    -- perfect h: a tree of height h whose leaves are all mempty.
    perfect 0 = SLeaf mempty
    perfect h =
        let sub = perfect (h - 1)
        in  SBin mempty sub sub

-- Joins two subtrees under a new internal node whose cached value is the left subtree's value
-- combined (with <>) with the right subtree's value.
makeSN :: Monoid a => SegNode a -> SegNode a -> SegNode a
makeSN lt rt = SBin (summary lt <> summary rt) lt rt
  where
    -- The value stored at the top of a node, whichever constructor it is.
    summary node = case node of
        SLeaf v    -> v
        SBin v _ _ -> v

-- | Builds a segment tree on (l, r) whose elements are taken, in order, from the given list.
-- Positions for which the list has no element are filled with 'mempty'. O(n).
fromListST :: Monoid a => (Int, Int) -> [a] -> SegTree a
fromListST bnds xs = buildST bnds (\ht -> evalState (build ht) xs)
  where
    -- Consume the next list element, or yield mempty once the list has run out.
    next = state $ \remaining -> case remaining of
        []       -> (mempty, [])
        (y:rest) -> (y, rest)

    -- build h: a tree of height h, filling leaves left to right from the state.
    build 0 = fmap SLeaf next
    build h = do
        lt <- build (h - 1)
        rt <- build (h - 1)
        pure (makeSN lt rt)

-- | Adjusts the element at index i by applying @f@ to it, rebuilding the cached values of the
-- internal nodes on the path to that leaf. Calls 'error' if i is outside the range the tree was
-- built on. O(log n).
adjustST :: Monoid a => (a -> a) -> Int -> SegTree a -> SegTree a
adjustST f i (SegTree lrp@(l0,r0,p) root)
    | i < l0 || r0 < i = error "adjustST: outside range"
    | otherwise        = SegTree lrp (go root l0 (l0+p-1))
  where
    -- go node l r: node covers the indices l..r. Subtrees that do not contain i are returned
    -- unchanged (and hence shared with the old tree).
    go n l r | i < l || r < i = n
    go (SLeaf x)      _ _ = SLeaf (f x)
    go (SBin _ lt rt) l r = makeSN (go lt l m) (go rt (m+1) r)
      where
        -- Equal to (l + r) `div` 2, but does not overflow when l + r exceeds maxBound, which
        -- happens for ranges with large endpoints.
        m = l + (r - l) `div` 2

-- | Folds the elements in the range (ql, qr). Elements outside (l, r) are considered to be mempty.
-- @ql == qr + 1@ denotes an empty query and yields 'mempty'; any smaller @qr@ is rejected with
-- 'error'. O(log n).
foldRangeST :: Monoid a => Int -> Int -> SegTree a -> a
foldRangeST ql qr (SegTree (l0,_,p) root)
    -- Means ql > qr + 1, written so that it cannot overflow: the original form wrapped around
    -- for qr == maxBound and rejected valid queries. ql - 1 is only evaluated when ql > qr, in
    -- which case ql > minBound.
    | ql > qr && ql - 1 > qr = error "foldRangeST: bad range"
    | otherwise              = go root l0 (l0+p-1) mempty
  where
    -- go node l r acc: node covers the indices l..r; acc is the fold of everything to its left
    -- that lies inside the query.
    go _ l r acc | r < ql || qr < l = acc
    go (SLeaf x) _ _ acc = acc <> x
    go (SBin x lt rt) l r acc
        -- Node entirely inside the query: use its cached value.
        | ql <= l && r <= qr = acc <> x
        -- Otherwise fold the left child first, forcing the accumulator before going right.
        | otherwise          = go rt (m+1) r $! go lt l m acc
      where
        -- Equal to (l + r) `div` 2, but does not overflow when l + r exceeds maxBound.
        m = l + (r - l) `div` 2

-- | Folds over the elements in increasing index order. Slots of the underlying complete tree
-- that lie beyond the upper bound of the range are skipped.
instance Foldable SegTree where
    foldr f z (SegTree (l0,r0,p) root) = go root l0 (l0+p-1) z where
        -- go node l r: node covers the indices l..r. A subtree starting past r0 holds no
        -- elements of the range and contributes nothing.
        go _ l _ | l > r0 = id
        go (SLeaf x)      _ _ = f x
        go (SBin _ lt rt) l r = go lt l m . go rt (m+1) r
          where
            -- Equal to (l + r) `div` 2, but does not overflow when l + r exceeds maxBound.
            m = l + (r - l) `div` 2

--------------------------------------------------------------------------------
-- For tests

-- Allows specialization across modules
{-# INLINABLE fromListST #-}
{-# INLINABLE adjustST #-}
{-# INLINABLE foldRangeST #-}

-- Fully evaluates the bounds triple and then the node structure.
instance NFData a => NFData (SegTree a) where
    rnf tree = case tree of
        SegTree meta root -> meta `deepseq` rnf root

-- Fully evaluates a node: its stored value first, then (for internal nodes) the left and right
-- subtrees.
instance NFData a => NFData (SegNode a) where
    rnf node = case node of
        SLeaf v       -> rnf v
        SBin v lt rt  -> v `deepseq` lt `deepseq` rnf rt