{-# LANGUAGE ScopedTypeVariables #-}
{-|
== Dijkstra's algorithm

An algorithm to find shortest-path distances from a set of source vertices in a graph with
non-negative edge weights.

There are a variety of possible implementations depending on what is required, such as finding
parents, early stopping, etc. This is a basic implementation calculating only the distances,
to be modified when required.

Sources:

* Edgar W. Dijkstra, "A note on two problems in connexion with graphs", 1959
  https://www-m3.ma.tum.de/foswiki/pub/MN0506/WebHome/dijkstra.pdf
* Implementation is folklore

-}

{-
Implementation notes:
* dijkstra uses a Set as the priority queue because there is no readily available priority queue
  structure in base (or containers).
* dijkstraH uses a skew heap as priority queue. Why a skew heap? Because it is pretty fast and easy
  to implement. An in-place binary heap would likely perform better, but hasn't been tested.
-}

module Dijkstra
    ( dijkstra
    , dijkstraH
    , Weight
    ) where

import Control.Monad
import Control.Monad.ST
import Data.Array.ST
import Data.Array.Unboxed
import Data.Graph
import Data.List
import qualified Data.Set as S

import LabelledGraph ( LGraph )

-- | Type of edge weights and of the path distances computed in this module.
-- Edge weights are expected to be non-negative; distances to unreachable
-- vertices are reported as 'maxBound'.
type Weight = Int

-- | Runs Dijkstra's algorithm on the given graph from every vertex in @srcs@ at
-- once (each source starts at distance 0). Unreachable vertices have distance
-- 'maxBound'. O((V + E) log V).
--
-- The priority queue is a 'S.Set' of @(distance, vertex)@ pairs holding at most
-- one entry per vertex: when a vertex's tentative distance improves, its old
-- entry is deleted before the new one is inserted.
dijkstra :: LGraph Weight -> [Vertex] -> UArray Vertex Weight
dijkstra g srcs = runSTUArray $ do
    dist :: STUArray s Vertex Weight <- newArray (bounds g) maxBound
    mapM_ (\src -> writeArray dist src 0) srcs
    let -- Settle the closest queued vertex, relax its out-edges, repeat until empty.
        loop queue = case S.minView queue of
            Nothing              -> pure ()
            Just ((du, u), rest) -> foldM (relax du) rest (g ! u) >>= loop
        -- Offer distance du + w to v; update and requeue v only on strict improvement.
        relax du queue (w, v) = do
            dv <- readArray dist v
            let cand = du + w
            if cand < dv
                then do
                    writeArray dist v cand
                    pure (S.insert (cand, v) (S.delete (dv, v) queue))
                else pure queue
    loop (S.fromList [(0, src) | src <- srcs])
    pure dist

-- | Runs Dijkstra's algorithm on the given graph from every vertex in @srcs@ at
-- once, using a skew heap ('DHeap') as the priority queue. Unreachable vertices
-- have distance 'maxBound'. Duplicates in @srcs@ are allowed: they are removed
-- before seeding, so no source is expanded twice. Faster than 'dijkstra',
-- especially for large sparse graphs. O((V + E) log E).
dijkstraH :: LGraph Weight -> [Vertex] -> UArray Vertex Weight
dijkstraH g srcs = runSTUArray $ do
    d :: STUArray s Vertex Weight <- newArray (bounds g) maxBound
    -- Seed each source once: a duplicated source would otherwise get two live
    -- (0, v) heap entries and have its adjacency list scanned twice.
    let uniqSrcs = S.toList (S.fromList srcs)
        -- The heap uses lazy deletion: it may hold stale entries for a vertex
        -- whose distance has since improved. An entry is live only if its key
        -- still equals the vertex's current distance in d.
        go Tip = pure ()
        go (Bin du u ql qr) = do
            du' <- readArray d u
            if du == du' then foldM f qlr (g ! u) >>= go else go qlr
          where
            qlr = unionH ql qr
            -- Relax edge (w, v) out of u; push a new entry only on strict improvement.
            f q (w, v) = do
                dv <- readArray d v
                let dv' = du + w
                if dv <= dv' then pure q else do
                    writeArray d v dv'
                    pure $ unionH q (Bin dv' v Tip Tip)
    forM_ uniqSrcs $ \v -> writeArray d v 0
    go (foldl' unionH Tip [Bin 0 v Tip Tip | v <- uniqSrcs])
    pure d

-- | Skew heap of @(distance, vertex)@ entries keyed on the distance, with the
-- smallest key at the root (maintained by 'unionH'). All fields are strict.
data DHeap
    = Tip
    -- ^ Empty heap.
    | Bin !Weight !Vertex !DHeap !DHeap
    -- ^ Node: key (distance), vertex, left subtree, right subtree.

-- | Merges two skew heaps. The root with the smaller key becomes the new root
-- (the first argument's root wins on equal keys); the other heap is merged into
-- that root's right subtree, and the subtrees are then swapped.
unionH :: DHeap -> DHeap -> DHeap
unionH h1 h2 = case (h1, h2) of
    (Tip, _) -> h2
    (_, Tip) -> h1
    (Bin k1 v1 l1 r1, Bin k2 v2 l2 r2)
        | k2 < k1   -> Bin k2 v2 (unionH r2 h1) l2
        | otherwise -> Bin k1 v1 (unionH r1 h2) l1