{-|
== Kruskal's algorithm

An algorithm to find the minimum spanning forest of an edge-weighted graph.

Sources:

* https://en.wikipedia.org/wiki/Kruskal%27s_algorithm
* Joseph B. Kruskal, "On the shortest spanning subtree of a graph and the traveling salesman
  problem", 1956
  https://www.ams.org/journals/proc/1956-007-01/S0002-9939-1956-0078686-7/

-}

module Kruskal
    ( WEdge(..)
    , Weight
    , kruskal
    ) where

import Control.DeepSeq
import Control.Monad
import Control.Monad.ST
import Data.Array.ST
import Data.Graph
import Data.Ord

import DSU ( newD, unionD )
import Sort ( sortBy )

-- | The weight carried by an edge.
type Weight = Int

-- | An edge joining vertex 'getU' to vertex 'getV', carrying weight 'getW'.
-- All fields are strict, so an evaluated 'WEdge' holds no thunks.
data WEdge = WEdge
    { getU :: !Vertex  -- ^ First endpoint.
    , getV :: !Vertex  -- ^ Second endpoint.
    , getW :: !Weight  -- ^ Weight of the edge.
    }
    deriving (Eq, Show)

-- | Computes a minimum spanning forest of the graph given by a list of weighted
-- edges, using Kruskal's algorithm.
--
-- A disjoint-set structure ('newD') is allocated over @bnds@. The edges are
-- sorted by weight ('getW') and scanned in that order. An edge is kept exactly
-- when 'unionD' returns 'True' for its two endpoints, presumably meaning they
-- were in different components and have now been merged. The kept edges are
-- returned in the sorted order.
--
-- Every endpoint must be a vertex within @bnds@. Vertices should be
-- non-negative. O(|V| + |E|log|E|).
kruskal :: Bounds -> [WEdge] -> [WEdge]
kruskal bnds es = runST $ do
    dsu <- newD bnds :: ST s (STUArray s Int Int)
    -- Keep an edge only if it connects two components not yet joined.
    filterM (\e -> unionD dsu (getU e) (getV e)) byWeight
  where
    -- Candidate edges, lightest first.
    byWeight = sortBy (comparing getW) es

--------------------------------------------------------------------------------
-- For tests

-- | Full evaluation of an edge. Every field of 'WEdge' is strict, so matching
-- the constructor already forces all three fields; no deeper structure exists.
instance NFData WEdge where
    rnf (WEdge u v w) = u `seq` v `seq` w `seq` ()