{-# LANGUAGE FlexibleContexts, QuantifiedConstraints, ScopedTypeVariables #-}
{-|
== Sorting

Data.List.sort is rather inefficient when we don't care about laziness and just want to fully sort a
list. An in-place sort can have much better performance. Benchmarks show that for a list of Ints,
sort and sortU are 4x and 8x faster than Data.List.sort.

sort, sortBy, sortU, sortUBy and sortUABy use merge sort. countingSortUA uses counting sort. Both are
stable sorts.

Sources:

* https://en.wikipedia.org/wiki/Merge_sort
* https://en.wikipedia.org/wiki/Counting_sort

-}

module Sort
    ( sort
    , sortBy
    , sortU
    , sortUBy
    , sortUABy
    , countingSortUA
    ) where

import Control.Monad
import Control.Monad.ST
import Data.Array.Base
import Data.Array.ST

import Misc ( modifyArray )

-- | Sorts a list in ascending order according to the element type's 'Ord' instance. O(n log n).
--
-- This is 'sortBy' specialised to 'compare'.
sort :: Ord e => [e] -> [e]
sort = sortBy compare

-- | Sorts a list with a comparison function. O(n log n).
--
-- The elements are loaded into a boxed mutable array indexed from 1 to @length xs@, sorted there
-- by @mergeSort@, and then read back out as a list with 'elems'.
sortBy :: (e -> e -> Ordering) -> [e] -> [e]
sortBy cmp xs = elems $ runSTArray $ do
    a <- listArrayST (1, length xs) xs
    mergeSort a cmp
    pure a

-- | Sorts a list for an element type that can be put in unboxed arrays. Faster than sort. O(n log n).
--
-- This is 'sortUBy' specialised to 'compare'. The quantified constraint asks for an 'STUArray'
-- instance of the element type at every state thread @s@.
sortU :: (Ord e, forall s. MArray (STUArray s) e (ST s), IArray UArray e) => [e] -> [e]
sortU = sortUBy compare

-- | Sorts a list for an element type that can be put in unboxed arrays with a comparison function.
-- Faster than sortBy. O(n log n).
--
-- The elements are loaded into an unboxed mutable array indexed from 1 to @length xs@, sorted
-- there by @mergeSort@, and then read back out as a list with 'elems'.
sortUBy :: (forall s. MArray (STUArray s) e (ST s), IArray UArray e)
        => (e -> e -> Ordering) -> [e] -> [e]
sortUBy cmp xs = elems $ runSTUArray $ do
    a <- listUArrayST (1, length xs) xs
    mergeSort a cmp
    pure a

-- | Sorts an unboxed array with a comparison function. O(n log n).
--
-- The input array is copied with 'thaw' and the copy is sorted by @mergeSort@, so the argument
-- itself is left untouched.
sortUABy :: (forall s. MArray (STUArray s) e (ST s), IArray UArray e)
         => (e -> e -> Ordering) -> UArray Int e -> UArray Int e
sortUABy cmp a = runSTUArray $ do
    a' <- thaw a
    mergeSort a' cmp
    pure a'

-- Bottom-up merge sort of a mutable array, using a scratch array with the same number of elements.
--
-- Every element access goes through 'unsafeRead' / 'unsafeWrite', which take 0-based offsets
-- rather than values of the index type, so all positions below range over 0 .. n-1 whatever
-- bounds the array was created with.
--
-- For each run width w = 1, 2, 4, ... below n, adjacent runs of width w in the array are merged
-- into the scratch array, which is then copied back.
mergeSort :: forall a e m. (MArray a e m) => a Int e -> (e -> e -> Ordering) -> m ()
mergeSort a cmp = do
    n <- getNumElements a
    -- Scratch buffer; the pattern signature pins it to the same array type as the input.
    b :: a Int e <- newArray_ (1, n)
    -- merge lo mid hi merges the runs at offsets [lo, mid) and [mid, hi) of the input into
    -- offsets [lo, hi) of the scratch buffer. The fold state is the pair of read cursors.
    let merge lo mid hi = foldM_ step (lo, mid) [lo .. hi - 1]
          where
            step (i, j) k
                | i >= mid = takeRight   -- left run exhausted
                | j >= hi = takeLeft     -- right run exhausted
                | otherwise = do
                    o <- cmp <$> unsafeRead a i <*> unsafeRead a j
                    -- On EQ the left element is taken first, which keeps the sort stable.
                    if o /= GT then takeLeft else takeRight
              where
                takeLeft = (i + 1, j) <$ (unsafeWrite b k =<< unsafeRead a i)
                takeRight = (i, j + 1) <$ (unsafeWrite b k =<< unsafeRead a j)
    forM_ (takeWhile (< n) $ iterate (* 2) 1) $ \w -> do
        -- Run boundaries are clamped to n so the final, possibly short, runs are handled too.
        forM_ [0, 2 * w .. n - 1] $ \i -> merge i ((i + w) `min` n) ((i + 2 * w) `min` n)
        -- Copy the merged pass back so the next pass reads from the input array again.
        forM_ [0 .. n - 1] $ \i -> unsafeRead b i >>= unsafeWrite a i
{-# INLINE mergeSort #-}

-- | Sorts an unboxed array using counting sort. f should be a function that maps every element to an Int
-- in [0..b-1]. O(n + b).
--
-- The result has the same bounds as the input. Elements are placed in the order they appear in
-- the input, so elements with equal keys keep their relative order.
countingSortUA :: (IArray UArray e, forall s. MArray (STUArray s) e (ST s))
               => Int -> (e -> Int) -> UArray Int e -> UArray Int e
countingSortUA b f a = runSTUArray $ do
    -- Histogram shifted by one: slot k + 1 counts the elements whose key is k.
    cnt <- newArray (0, b) 0 :: ST s (STUArray s Int Int)
    forM_ (elems a) $ \x -> modifyArray cnt (f x + 1) (+ 1)
    -- Turn the histogram into start positions: key 0 starts at the array's lower bound, and
    -- after the running sum slot k holds the first output index for key k.
    writeArray cnt 0 (fst (bounds a))
    forM_ [1 .. b - 1] $ \i -> do
        prev <- readArray cnt (i - 1)
        modifyArray cnt i (prev +)
    a' <- newArray_ (bounds a)
    -- Place each element at the next free index for its key, then advance that index.
    forM_ (elems a) $ \x -> do
        let y = f x
        i <- readArray cnt y
        writeArray cnt y (i + 1)
        writeArray a' i x
    pure a'

--------------------------------------------------------------------------------
-- For tests

-- Allows specialization across modules: INLINABLE keeps each function's unfolding available to
-- importing modules, so that GHC can specialize these class-constrained functions at the
-- concrete element types used there.
{-# INLINABLE sortU #-}
{-# INLINABLE sortUBy #-}
{-# INLINABLE sortUABy #-}
{-# INLINABLE countingSortUA #-}