{-# LANGUAGE FlexibleContexts, QuantifiedConstraints #-}
{-|
== Sparse table

Structure for fast static range fold queries. Useful when the elements do not form a group,
otherwise prefix sums can be used.

Sources:

* https://cp-algorithms.com/data_structures/sparse-table.html
* https://github.com/kth-competitive-programming/kactl/blob/main/content/data-structures/RMQ.h

Let n = r - l + 1 in all instances below.

-}

{-
Implementation notes:
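* Some array elements are left uninitialized (the table is allocated with newArray_), but that is
  fine because those elements are never read: row j is only filled for start indices i with
  i + 2^j - 1 <= r, and for queries within the table's bounds both fold functions only read such
  entries. This relies on the input list supplying at least r - l + 1 elements.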
* I would love to put Semigroup/Idempotent constraints on the unboxed fromListUSP/fromListIUSP
  functions, but deriving IArray and MArray instances for newtypes is an unfair amount of pain, so
  they take the combining function explicitly instead.
-}

module SparseTable
    ( fromListSP
    , fromListISP
    , fromListUSP
    , fromListIUSP
    , buildSP
    , foldSP
    , foldISP
    ) where

import Control.Monad
import Control.Monad.ST
import Data.Array.ST
import Data.Array.Unboxed
import Data.Bits

import Misc ( Idempotent, bitLength )

-- | Constructs a range fold function from a list. O(n log n) to construct the structure and O(log n)
-- for each query.
--
-- @fromListSP (l, r) xs@ stores the elements of @xs@ at indices @l..r@ and returns @qry@, where
-- @qry ql qr@ folds the elements at indices @ql..qr@ from left to right with '<>'. The list must
-- supply at least @r - l + 1@ elements; extra elements are ignored. The table is fully built
-- (forced with '$!') before the query function is returned, so it is shared by all queries.
-- Calls 'error' if @l > r@, or if a query has @ql > qr@; queries should lie within @(l, r)@.
fromListSP :: Semigroup e => (Int, Int) -> [e] -> Int -> Int -> e
fromListSP bnds xs = foldSP (<>) $! runSTArray $ buildSP (<>) bnds xs

-- | Constructs a range fold function from a list, when the semigroup is idempotent. O(n log n) to
-- construct the structure and O(1) for each query.
--
-- Same contract as 'fromListSP', but each query combines two possibly overlapping blocks. This is
-- why @x <> x == x@ is required. The semigroup operation comes from the 'Idempotent' constraint
-- (from "Misc"). Calls 'error' if @l > r@, or if a query has @ql > qr@.
fromListISP :: Idempotent e => (Int, Int) -> [e] -> Int -> Int -> e
fromListISP bnds xs = foldISP (<>) $! runSTArray $ buildSP (<>) bnds xs

-- | Constructs a range fold function from a list. Uses an unboxed array. O(n log n) to construct the
-- structure and O(log n) for each query.
--
-- The combining function @op@ is passed explicitly rather than taken from a 'Semigroup' instance,
-- because unboxed element types cannot easily be wrapped in newtypes carrying such instances.
-- @op@ must be associative. Otherwise the contract is that of 'fromListSP': the list supplies the
-- elements at indices @l..r@, and 'error' is called if @l > r@ or a query has @ql > qr@.
fromListUSP :: (IArray UArray e, forall s. MArray (STUArray s) e (ST s))
            => (e -> e -> e) -> (Int, Int) -> [e] -> Int -> Int -> e
fromListUSP op bnds xs = foldSP op $! runSTUArray $ buildSP op bnds xs

-- | Constructs a range fold function from a list, when the semigroup is idempotent. Uses an unboxed
-- array. O(n log n) to construct the structure and O(1) for each query.
--
-- @op@ is passed explicitly, as in 'fromListUSP'. It must be associative and idempotent
-- (@op x x == x@), because each query combines two possibly overlapping blocks. Calls 'error' if
-- @l > r@, or if a query has @ql > qr@.
fromListIUSP :: (IArray UArray e, forall s. MArray (STUArray s) e (ST s))
             => (e -> e -> e) -> (Int, Int) -> [e] -> Int -> Int -> e
fromListIUSP op bnds xs = foldISP op $! runSTUArray $ buildSP op bnds xs

-- | Builds a sparse table. O(n log n). Prefer the fromList functions.
--
-- For the range @(l, r)@ the resulting table @t@ has bounds @((0, l), (h, r))@, where
-- @h = bitLength n - 1@. This is @floor (log2 n)@, assuming 'bitLength' from "Misc" counts
-- significant bits. Row 0 holds the list elements. For @j >= 1@, @t (j, i)@ is @op@ folded over
-- the @2^j@ elements starting at index @i@. Entries of row @j@ with @i + 2^j - 1 > r@ are never
-- written and must not be read.
--
-- The list must supply at least @n@ elements. Extra elements are ignored (the list is 'zip'ped
-- with @[l..r]@). Missing elements leave row-0 entries uninitialized. Calls 'error' if @l > r@.
buildSP :: MArray a e (ST s) => (e -> e -> e) -> (Int, Int) -> [e] -> ST s (a (Int, Int) e)
buildSP _  (l, r) _ | l > r = error "buildSP: empty range"
buildSP op (l, r) xs = do
    -- Highest level: blocks of length 2^h still fit inside the range.
    let h = bitLength (r - l + 1) - 1
    t <- newArray_ ((0, l), (h, r))
    forM_ (zip [l..r] xs) $ \(i, x) -> writeArray t (0, i) x
    forM_ [1..h] $ \j -> do
        -- A level-j block starting at i is two adjacent level-(j-1) blocks of length d,
        -- starting at i and at i + d.
        let d = bit (j - 1)
        -- Only start indices whose block (of length 2 * d) ends at or before r.
        forM_ [l..r-2*d+1] $ \i ->
            -- The combined value is forced before it is stored, so boxed tables do not
            -- accumulate chains of thunks.
            op <$> readArray t (j-1, i) <*> readArray t (j-1, i+d) >>= (writeArray t (j, i) $!)
    pure t
{-# INLINE buildSP #-}

-- | Folds a range on a sparse table. O(log n). Prefer the fromList functions.
--
-- The range @[l, r]@ is split into blocks whose lengths are distinct powers of two, taken from
-- left to right in increasing order of size. Each step uses the block of length @2^j@, where @j@ is
-- the number of trailing zero bits of the remaining length. That block never extends past @r@, so
-- the table entry it reads was written by 'buildSP'. Partial results are combined strictly, left to
-- right, so @op@ needs to be associative but not commutative or idempotent.
-- Calls 'error' if @l > r@.
foldSP :: IArray a e => (e -> e -> e) -> a (Int, Int) e -> Int -> Int -> e
foldSP op t = qry where
    qry l r | l > r = error "foldSP: empty range"
    qry l r = go (l + bit j) (t!(j, l)) where
        -- Lowest set bit of the range length: size of the first block.
        j = countTrailingZeros (r - l + 1)
        -- l' is the start of the not-yet-folded suffix; acc folds [l, l' - 1].
        go l' acc | l' > r = acc
        go l' acc = go (l' + bit j') $! op acc (t!(j', l')) where
            j' = countTrailingZeros (r - l' + 1)
{-# INLINE foldSP #-}

-- | Folds a range on a sparse table, when the semigroup is idempotent. O(1). Prefer the fromList
-- functions.
--
-- Combines the block of length @2^j@ starting at @l@ with the block of the same length ending at
-- @r@, where @j = bitLength n - 1@. That is the largest power of two not exceeding @n@, assuming
-- 'bitLength' from "Misc" counts significant bits. The two blocks together cover @[l, r]@ but may
-- overlap, so @op@ must be idempotent as well as associative. Calls 'error' if @l > r@.
foldISP :: IArray a e => (e -> e -> e) -> a (Int, Int) e -> Int -> Int -> e
foldISP op t = qry where
    qry l r | l > r = error "foldISP: empty range"
    qry l r = op (t!(j, l)) (t!(j, r + 1 - bit j)) where
        j = bitLength (r - l + 1) - 1
{-# INLINE foldISP #-}

--------------------------------------------------------------------------------
-- For tests

-- The fromList functions are polymorphic in the element type. INLINABLE keeps their unfoldings
-- available, so other modules (such as the tests) can specialize them at concrete element types.
{-# INLINABLE fromListSP #-}
{-# INLINABLE fromListISP #-}
{-# INLINABLE fromListUSP #-}
{-# INLINABLE fromListIUSP #-}