{-# LANGUAGE BangPatterns #-}

{-|
== Aho-Corasick algorithm

The Aho-Corasick algorithm builds an automaton from a set of pattern strings, and then uses it to
find positions in a search string where each of the pattern strings occur.

This implementation only works on ByteStrings, to keep things fast. If required it can be adapted
to work on other sequence types.

A TrieAC a can be constructed from pattern strings with associated values a, which can then be
turned into an ACRoot a. An ACRoot a can then be run on a search string to find matches.
Construction and matching are both lazy.

Sources:

* Alfred V. Aho and Margaret J. Corasick, "Efficient string matching: An aid to bibliographic
  search", 1975
  https://dl.acm.org/doi/10.1145/360825.360855
* Stanford CS166 Aho-Corasick lecture slides
  https://web.stanford.edu/class/archive/cs/cs166/cs166.1166/lectures/04/Slides04.pdf

For complexities below, k is the alphabet range (max 256).

-}

{-
Implementation notes:
* We have to be lazy in the (Maybe (ACNode a)) and the [a] in fromTrieAC because we build the tree
  depth-first and strictly (due to IntMap.Strict). If we could build it breadth-first, then we
  could be strict in these, but I don't see an easy way to do that.
-}

module AhoCorasick
    ( TrieAC
    , emptyTAC
    , insertTAC
    , fromListTAC
    , ACRoot
    , fromTrieAC
    , matchAC
    ) where

import Control.Applicative
import Control.DeepSeq
import Data.List
import Data.Maybe
import qualified Data.ByteString as B
import qualified Data.IntMap.Strict as IM

-- | The root state of an Aho-Corasick automaton.
data ACRoot a
    = ACRoot
        !(IM.IntMap (ACNode a))  -- ^ transitions out of the root, keyed by byte value
        [a]                      -- ^ values reported while at the root (those of the empty pattern)

-- | A non-root state of an Aho-Corasick automaton. The last two fields are deliberately lazy:
-- they are filled in by thunks that refer back into the automaton being built.
data ACNode a
    = ACNode
        !(IM.IntMap (ACNode a))  -- ^ outgoing trie transitions, keyed by byte value
        (Maybe (ACNode a))       -- ^ suffix (failure) link; 'Nothing' stands for the root
        [a]                      -- ^ values of every pattern ending at this state, including those
                                 --   reached through the suffix link

-- | Builds an Aho-Corasick automaton from a trie.
--
-- The trie's shape is copied eagerly (O(n) for n trie nodes). Suffix links and output lists are
-- computed lazily, on first use. Forcing all of them takes O(L) IntMap lookups (each O(log k)),
-- where L is the total length of the strings the trie was constructed with. The bound is in L,
-- not in n. When many patterns share a long prefix and then diverge, the suffix walk for each
-- branch can be as long as the shared prefix.
fromTrieAC :: TrieAC a -> ACRoot a
fromTrieAC (TrieAC tm routs) = ACRoot rmp routs
  where
    -- Depth-1 states. The suffix-link thunks created below look transitions up in 'rmp' itself,
    -- which ties the knot. Those thunks must stay lazy because the maps are built depth-first
    -- and strictly (Data.IntMap.Strict).
    rmp = IM.map go1 tm

    -- A depth-1 state's longest proper suffix is the empty string, i.e. the root ('Nothing').
    -- Its outputs are therefore its own values followed by the root's.
    go1 (TrieAC m vs) = ACNode (IM.mapWithKey (go Nothing) m) Nothing (vs ++ routs)

    -- A state at depth >= 2, reached from its parent on byte c. psuf is the parent's suffix link.
    go psuf !c (TrieAC m vs) = ACNode (IM.mapWithKey (go suf) m) suf outs
      where
        suf = getSuf psuf

        -- Walk the parent's suffix chain until some state has a transition on c. Reaching the
        -- root ('Nothing') falls back to the root's own transition on c; if that is also
        -- absent, the suffix link is the root.
        getSuf Nothing = IM.lookup c rmp
        getSuf (Just (ACNode mp' suf' _)) = IM.lookup c mp' <|> getSuf suf'

        -- Own values first, then everything the suffix state reports (the root's values when
        -- the suffix link is the root). The tail is shared with the suffix state's list.
        outs = vs ++ maybe routs (\(ACNode _ _ outs') -> outs') suf

-- | Runs the automaton over a search string.
--
-- The result has one entry per position in the string, including the position before the first
-- byte, so a string of length m yields m + 1 entries. Entry i lists the associated values of
-- every pattern found to end right after the first i bytes. The result is produced lazily.
-- O(m log k + z), where m is the length of the string and z is the total number of matches.
matchAC :: ACRoot a -> B.ByteString -> [[a]]
matchAC (ACRoot rootMap rootOuts) !input = rootOuts : fromRoot input
  where
    -- Consume one byte while at the root. Without a transition on that byte we stay at the
    -- root, which reports the root's values for this position.
    fromRoot str = case B.uncons str of
        Nothing -> []
        Just (byte, rest) ->
            maybe (rootOuts : fromRoot rest) (enter rest) (IM.lookup (fromEnum byte) rootMap)

    -- Move into a non-root state: report its outputs, then keep matching from it.
    enter rest (ACNode kids link outs) = outs : fromNode kids link rest

    -- Consume one byte from a non-root state with transitions 'kids' and suffix link 'link'.
    -- On a missing transition the same byte is retried from the suffix state, or the root.
    fromNode kids link str = case B.uncons str of
        Nothing -> []
        Just (byte, rest) -> case IM.lookup (fromEnum byte) kids of
            Just node -> enter rest node
            Nothing -> case link of
                Nothing -> fromRoot str
                Just (ACNode kids' link' _) -> fromNode kids' link' str

-- | A trie of byte strings. The path from the root to a node spells a string; the node stores
-- the values associated with exactly that string.
data TrieAC a
    = TrieAC
        !(IM.IntMap (TrieAC a))  -- ^ children, keyed by byte value
        ![a]                     -- ^ values inserted for this node's string, newest first
    deriving (Show)

-- | An empty trie: the root has no children and no associated values.
emptyTAC :: TrieAC a
emptyTAC = TrieAC mempty mempty

-- | Inserts a string with an associated value into a trie, creating any missing nodes along the
-- string's path. The value is prepended to the values already stored for that string.
-- O(n log k) where n is the length of the string.
insertTAC :: B.ByteString -> a -> TrieAC a -> TrieAC a
insertTAC str val = descend str
  where
    descend rest (TrieAC children vals) = case B.uncons rest of
        -- The whole string has been consumed: this node represents it.
        Nothing -> TrieAC children (val : vals)
        -- Rebuild the child for the next byte, starting from an empty trie if there is none.
        Just (byte, rest') -> TrieAC (IM.alter (step rest') (fromIntegral byte) children) vals

    step rest' existing = Just $! descend rest' (fromMaybe emptyTAC existing)

-- | Builds a trie from a list of strings and associated values, inserting them left to right
-- with a strict fold. O(n log k) where n is the total length of the strings.
fromListTAC :: [(B.ByteString, a)] -> TrieAC a
fromListTAC pairs = foldl' addPair emptyTAC pairs
  where
    addPair trie (str, val) = insertTAC str val trie

--------------------------------------------------------------------------------
-- For tests

-- | Forces the transition maps of this node and, recursively, of every node below it.
--
-- The output list is not forced. Output lists of different nodes share structure (each one ends
-- with its suffix state's list), so forcing them would re-traverse shared parts repeatedly.
-- The suffix link is forced only to WHNF, which computes the link itself. Forcing it fully
-- would re-evaluate other parts of the tree.
instance NFData a => NFData (ACNode a) where
    -- Field order is (transitions, suffix link, outputs). Binding them in a different order
    -- would force the output list instead of the suffix link.
    rnf (ACNode mp suf _outs) = suf `seq` rnf mp

-- | Forces the root's transition map, and through it every node of the automaton via the
-- 'ACNode' instance. The root's output list is left unevaluated.
instance NFData a => NFData (ACRoot a) where
    rnf (ACRoot children _rootOuts) = children `deepseq` ()