{-|
== Knuth-Morris-Pratt algorithm

A string-matching algorithm built on the prefix function p of an input string s. p(i) is the length
of the longest proper prefix of s[0..i] that is also a suffix of s[0..i]. Put another way, it is the
longest prefix of s that ends at index i, not counting s[0..i] itself.

Sources:

* https://en.wikipedia.org/wiki/Knuth%E2%80%93Morris%E2%80%93Pratt_algorithm
* https://cp-algorithms.com/string/prefix-function.html

-}

module KMP
    ( prefixFunc
    , prefixFuncBS
    ) where

import Control.Monad
import Data.Array.ST
import Data.Array.Unboxed
import qualified Data.ByteString.Char8 as C

-- | Constructs the prefix function of a 0-indexed sequence of length @n@,
-- whose elements are read with @at@.
--
-- Entry @i@ of the result is the length of the longest proper prefix of
-- @s[0..i]@ that is also a suffix of @s[0..i]@. The result has bounds
-- @(0, n - 1)@. For @n <= 0@ it is empty and @at@ is never called.
-- Performs O(n) element comparisons in total.
--
-- >>> elems (prefixFunc 7 ("aabaaab" !!))
-- [0,1,0,1,2,2,3]
prefixFunc :: Eq a => Int -> (Int -> a) -> UArray Int Int
prefixFunc n at = runSTUArray $ do
    p <- newArray (0, n - 1) 0
    -- p ! 0 is always 0 (the only proper prefix of a one-element string is
    -- empty), so the array's initial value already covers index 0.
    forM_ [1 .. n - 1] $ \i -> do
        let c = at i
            -- Try to extend the border of length j with element i. On a
            -- mismatch, fall back to the next-shorter border p ! (j - 1).
            -- Once j reaches 0 there is nothing shorter to fall back to.
            extend 0
                | c == at 0 = pure 1
                | otherwise = pure 0
            extend j
                | c == at j = pure (j + 1)
                | otherwise = readArray p (j - 1) >>= extend
        readArray p (i - 1) >>= extend >>= writeArray p i
    pure p

-- | 'prefixFunc' specialised to a strict 'C.ByteString', comparing bytes as
-- 'Char's via "Data.ByteString.Char8". O(n).
--
-- The result has bounds @(0, C.length bs - 1)@, and it is empty for an empty
-- string.
prefixFuncBS :: C.ByteString -> UArray Int Int
prefixFuncBS bs = prefixFunc (C.length bs) (C.index bs)