{-# LANGUAGE FlexibleContexts #-}
{-|
== Disjoint set union

Data structure to maintain disjoint sets of Ints, supporting find and union.
Uses union by size and path halving.

Sources:

* https://en.wikipedia.org/wiki/Disjoint-set_data_structure
* https://cp-algorithms.com/data_structures/disjoint_set_union.html
* https://github.com/kth-competitive-programming/kactl/blob/main/content/data-structures/UnionFind.h
* Robert E. Tarjan and Jan van Leeuwen, "Worst-Case Analysis of Set Union Algorithms", 1984
  https://dl.acm.org/doi/10.1145/62.2160

Use unboxed arrays (IOUArray/STUArray) for best performance.
Throughout, n = r - l + 1 is the number of elements and α is the inverse
Ackermann function.

-}

{-
Implementation notes:
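* Following KACTL, a single array stores both parents and sizes: a root holds the negated
  size of its set and every other element holds the index of its parent. Element indices
  must therefore be non-negative, which is why newD rejects a negative lower bound.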
* There is no way to make this structure functional without making the complexity worse :(
-}

module DSU
    ( newD
    , sameSetD
    , unionD
    ) where

import Control.Monad
import Data.Array.MArray

-- | Allocates a DSU over the inclusive index range @(l, r)@ in which every
-- element starts out in a singleton set. O(n).
--
-- Every slot is initialised to @-1@: a root whose set has size 1 (sizes are
-- stored negated). Non-negative slot values are read as parent indices, so the
-- lower bound @l@ must be non-negative; a negative @l@ fails with 'error'.
newD :: MArray a Int m => (Int, Int) -> m (a Int Int)
newD bnds@(lo, _)
    | lo < 0    = error "negative range"
    | otherwise = newArray bnds (-1)

-- | Returns the representative (root) of the set containing the given element.
--
-- A slot holding a negative value marks a root. While walking up, the function
-- halves the path: whenever the current element's parent is not a root, the
-- element is re-pointed at its grandparent and the walk continues from there.
findD :: MArray a Int m => a Int Int -> Int -> m Int
findD d = climb
  where
    climb x = do
        parent <- readArray d x
        if parent < 0
            then pure x
            else do
                grand <- readArray d parent
                if grand < 0
                    then pure parent
                    -- Path halving: skip over the parent before continuing.
                    else writeArray d x grand >> climb grand

-- | Tests whether the two elements currently belong to the same set, i.e.
-- whether they share a representative. The lookups may shorten paths as a
-- side effect. Amortized O(α(n)).
sameSetD :: MArray a Int m => a Int Int -> Int -> Int -> m Bool
sameSetD d i j = do
    rootI <- findD d i
    rootJ <- findD d j
    pure (rootI == rootJ)

-- | Merges the sets containing the two elements.
--
-- Returns 'False' when they already share a representative; no union is
-- performed, although the lookups may still shorten paths. Otherwise the root
-- of the smaller set is linked under the root of the larger one (union by
-- size) and 'True' is returned. Amortized O(α(n)).
unionD :: MArray a Int m => a Int Int -> Int -> Int -> m Bool
unionD d i j = do
    ri <- findD d i
    rj <- findD d j
    if ri == rj
        then pure False
        else do
            si <- readArray d ri
            sj <- readArray d rj
            link ri rj si sj
            pure True
  where
    -- Roots store their size negated, so the larger set has the smaller value.
    -- @link r s sr ss@ first swaps so that @r@ is the larger root, then records
    -- the combined (negated) size at @r@ and points @s@ at @r@. On a tie, the
    -- first root wins.
    link r s sr ss
        | sr > ss   = link s r ss sr
        | otherwise = do
            writeArray d r (sr + ss)
            writeArray d s r

--------------------------------------------------------------------------------
-- For tests

-- INLINABLE keeps the unfoldings of these MArray-polymorphic functions in the
-- interface file, so importing modules can specialize them to a concrete array
-- type and monad (e.g. IOUArray in IO) instead of passing class dictionaries.
{-# INLINABLE findD #-}
{-# INLINABLE sameSetD #-}
{-# INLINABLE unionD #-}