{-|
== Fenwick tree, or binary indexed tree

A data structure supporting point updates and range queries, or the opposite.
Large ranges, beyond typical memory limits, are supported.
See FenwickMut.hs for a mutable (and more commonly seen) version.

Sources:

* https://en.wikipedia.org/wiki/Fenwick_tree
* Peter M. Fenwick, "A New Data Structure for Cumulative Frequency Tables", 1994
  https://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.14.8917
* https://hackage.haskell.org/package/binary-indexed-tree

     4
    / \
   /   \
  2     6
 / \   / \
1   3 5   7

Let n = r - l + 1 where (l, r) is the range of the Fenwick tree.
The complexities assume (<>) takes O(1) time.

-}

{-
Implementation notes:
* The implementation here is literally a tree, unlike the usual implementation with an array.
  The responsibilities of the indices remain the same.
* It is a complete binary tree where each node stores the accumulation of values in its left
  subtree and itself.
-}

module Fenwick
    ( FTree
    , emptyF
    , fromListF
    , mappendF
    , foldPrefixF
    , foldRangeF
    , mappendRangeF
    , binSearchF
    , toScanl1F
    ) where

import Control.Applicative
import Control.DeepSeq
import Control.Monad.State
import Data.Bits

import Misc ( Commutative, Group(..), bitLength )

-- | A Fenwick tree. The strict triple holds the inclusive lower bound, the
-- inclusive upper bound and the height of the root node (the value computed
-- by 'buildF'); the second field is the root of the node structure.
data FTree a = FTree !(Int, Int, Int) !(FNode a) deriving Show
-- | A node of the tree: either an empty tip, or a branch carrying an
-- accumulated value together with its left and right children. All fields
-- are strict.
data FNode a = FTip | FBin !a !(FNode a) !(FNode a) deriving Show

-- Wraps a node structure into an 'FTree' for the inclusive range (l, r).
-- The second argument receives the height of the root and must return the
-- root node for that height. Calls 'error' when l > r + 1; the empty range
-- l == r + 1 is accepted.
buildF :: Monoid a => (Int, Int) -> (Int -> FNode a) -> FTree a
buildF (l, r) mkRoot
    | l > r + 1 = error "buildF: invalid range"
    | otherwise = FTree (l, r, ht) (mkRoot ht)
  where
    -- Number of elements in the range.
    n  = r - l + 1
    -- Height of the root node.
    -- NOTE(review): relies on Misc.bitLength, presumably the number of bits
    -- needed to represent n (so an empty range gives -1) -- confirm in Misc.
    ht = bitLength n - 1

-- | Builds a Fenwick tree on range (l, r) where each element is mempty. O(log n).
--
-- Calls 'error' (via 'buildF') when the range is invalid, i.e. l > r + 1.
emptyF :: Monoid a => (Int, Int) -> FTree a
emptyF bnds = buildF bnds go
  where
    -- go h builds a subtree whose root has height h; height -1 is the tip.
    go (-1) = FTip
    -- The same subtree value is used for both children, so only one node is
    -- constructed per level. This sharing is what gives the O(log n) bound.
    go h    = FBin mempty sub sub
      where
        sub = go (h - 1)

-- | Builds a Fenwick tree on (l, r) where the elements are taken from a list. If the list is shorter
-- than the range, the remaining elements are mempty. If it is longer, the excess elements are
-- ignored. O(n).
--
-- Calls 'error' (via 'buildF') when the range is invalid, i.e. l > r + 1.
fromListF :: Monoid a => (Int, Int) -> [a] -> FTree a
fromListF bnds@(l, r) xs = buildF bnds (fst . flip evalState inRange . go)
  where
    -- Only the first (r - l + 1) elements belong to the range. Without this
    -- truncation a longer list would fill the padding nodes past r, and
    -- binSearchF could then report an index outside the range.
    inRange = take (r - l + 1) xs

    -- Takes the next element from the state, or mempty once it is exhausted.
    pop = state next
      where
        next []       = (mempty, [])
        next (z : zs) = (z, zs)

    -- go h builds the subtree whose root has height h by an in-order walk:
    -- left subtree, then the node's own element, then the right subtree.
    -- It returns the subtree together with the fold of every element in it.
    go (-1) = pure (FTip, mempty)
    go h = do
        (lt, lx) <- go (h - 1)
        x        <- pop
        (rt, rx) <- go (h - 1)
        let x'  = lx <> x            -- stored value: left subtree plus this element
            x'' = x' <> rx           -- fold of the whole subtree
            n   = FBin x' lt rt
        -- Force both results so no chain of thunks builds up.
        x'' `seq` n `seq` pure (n, x'')

-- | mappends to the element at an index. O(log n).
--
-- Calls 'error' when the index lies outside the range of the tree.
mappendF :: Commutative a => Int -> a -> FTree a -> FTree a
mappendF i y (FTree lrh@(l, r, ht) rt)
    | i < l || r < i = error "mappendF: outside range"
    | otherwise      = FTree lrh (go rt ht)
  where
    -- 1-based position of i inside the tree.
    i' = i - l + 1
    -- Height of the node that holds position i'.
    h' = countTrailingZeros i'

    -- Walks from the root (height ht) down to the node at height h'.
    go (FBin x lt rgt) h
        -- Reached the node for i' itself.
        | h == h'      = FBin (x <> y) lt rgt
        -- Bit h set: descend right, this node's value is left untouched.
        | testBit i' h = FBin x lt (go rgt (h - 1))
        -- Bit h clear: i' is in the left subtree, which this node's value
        -- accumulates, so update the value here and keep descending.
        | otherwise    = FBin (x <> y) (go lt (h - 1)) rgt
    -- Not expected for an index that passed the range check above.
    go FTip _ = error "unexpected"

-- | The result of folding the prefix up to the given index. Indices outside the tree range are
-- allowed; elements there are assumed to be mempty. O(log n).
foldPrefixF :: Monoid a => Int -> FTree a -> a
foldPrefixF i (FTree (l, r, ht) rt)
    | i' == 0   = mempty
    | otherwise = go rt ht mempty
  where
    -- Length of the prefix in 1-based tree positions: i is capped at the
    -- upper bound, and anything below the lower bound gives 0.
    i' = max 0 (min r i - l + 1)
    -- Height of the node that holds position i'.
    h' = countTrailingZeros i'

    -- Walks from the root (height ht) towards position i', accumulating the
    -- node values along the way.
    go (FBin x lt rgt) h acc
        -- Reached the node for i': its value completes the prefix.
        | h == h'      = acc <> x
        -- Bit h set: take this node's value, then continue to the right
        -- with a strict accumulator.
        | testBit i' h = go rgt (h - 1) $! acc <> x
        -- Bit h clear: continue to the left without taking this node's value.
        | otherwise    = go lt (h - 1) acc
    -- Not expected: i' is at least 1 here and never exceeds the range length.
    go FTip _ _ = error "unexpected"

-- | Folds the elements in the range (l, r). O(log n).
--
-- Computed as the prefix fold up to r combined with the inverse of the
-- prefix fold up to l - 1.
foldRangeF :: (Commutative a, Group a) => Int -> Int -> FTree a -> a
foldRangeF l r ft = upToR <> invert beforeL
  where
    upToR   = foldPrefixF r ft
    beforeL = foldPrefixF (l - 1) ft

-- | mappends to all elements in the range (l, r). Can be used with foldPrefixF for point queries.
-- O(log n).
--
-- Implemented as two point updates: y at l, and the inverse of y at r + 1.
-- The second update is skipped when r is the upper bound of the tree.
-- Calls 'error' (via 'mappendF') when l is outside the tree range, or when
-- r is not the upper bound and r + 1 is outside the tree range.
mappendRangeF :: (Commutative a, Group a) => Int -> Int -> a -> FTree a -> FTree a
mappendRangeF l r y ft@(FTree (_, hi, _) _)
    | r == hi   = started
    | otherwise = mappendF (r + 1) (invert y) started
  where
    -- The tree after applying y at the start of the range.
    started = mappendF l y ft

-- | Binary searches for the shortest prefix such that the fold of all values in it satisfies the given
-- monotonic predicate. Returns the end index and the fold of the found prefix, or Nothing when the
-- search finds no such prefix. O(log n).
binSearchF :: Monoid a => (a -> Bool) -> FTree a -> Maybe (Int, a)
binSearchF f (FTree (l, _, ht) rt) = go rt ht (l - 1) mempty
  where
    -- go node h i acc: i is the index of the last element already folded
    -- into acc, starting at l - 1 with an empty accumulator.
    go FTip _ _ _ = Nothing
    go (FBin x lt rgt) h i acc
        -- The prefix ending at this node satisfies f: prefer a shorter one
        -- from the left subtree, and fall back to this node otherwise.
        | f acc'    = i' `seq` go lt (h - 1) i acc <|> Just (i', acc')
        -- Not satisfied yet: take this node's value and search to the right.
        | otherwise = i' `seq` go rgt (h - 1) i' acc'
      where
        -- Fold of the prefix ending at this node.
        acc' = acc <> x
        -- Index of this node.
        i'   = i + bit h

-- | Converts to a list of prefix accumulated values. O(n).
--
-- The result has one entry per index of the range, in index order.
toScanl1F :: Monoid a => FTree a -> [a]
toScanl1F (FTree (l, r, _) rt) = take (r - l + 1) (go rt mempty [])
  where
    -- In-order walk producing a difference list. acc is the fold of
    -- everything that precedes the current subtree.
    go FTip            _   = id
    go (FBin x lt rgt) acc = go lt acc . (acc' :) . go rgt acc'
      where
        -- Prefix fold ending at this node.
        acc' = acc <> x

--------------------------------------------------------------------------------
-- For tests

-- Allows specialization across modules
{-# INLINABLE fromListF #-}
{-# INLINABLE mappendF #-}
{-# INLINABLE foldPrefixF #-}
{-# INLINABLE binSearchF #-}

-- Fully evaluates the bounds/height triple, then the node structure.
instance NFData a => NFData (FTree a) where
    rnf (FTree lrh rt) = rnf lrh `seq` rnf rt

-- Fully evaluates the stored value of every node, then both subtrees.
instance NFData a => NFData (FNode a) where
    rnf FTip            = ()
    rnf (FBin x lt rgt) = rnf x `seq` rnf lt `seq` rnf rgt