{-# LANGUAGE AllowAmbiguousTypes, BangPatterns, ScopedTypeVariables #-}
{-|
== Math

Sources:

* https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm

-}

module Math
    ( egcd
    , egcd2
    , mkInvFactorials
    , mkFactorials
    , mkBinom
    ) where

import Data.Array.Unboxed
import Data.List

-- | Extended Euclidean algorithm returning only the coefficient of the first argument.
--
-- @egcd a b@ returns @(g, s)@ where @g = gcd(a, b)@ and @a*s + b*t = g@ for some @t@
-- (use 'egcd2' if @t@ is needed too). O(log min(a, b)).
-- As with 'egcd2', the returned gcd may be negative if the inputs are negative,
-- because the remainders come from 'quotRem'.
egcd :: Integral i => i -> i -> (i, i)
egcd = go 1 0 where
    -- Invariant: r = a*s + b*t and r' = a*s' + b*t' for some t, t'.
    -- (r, r') steps through Euclid's remainder sequence, so when r' reaches 0,
    -- r is the gcd and s its coefficient.
    go !s _  !r 0  = (r, s)
    go s  s' r  r' = let (q, r'') = r `quotRem` r' in go s' (s - q * s') r' r''
{-# INLINE egcd #-}

-- | Extended Euclidean algorithm.
--
-- @egcd2 a b@ returns @(g, s, t)@ where @g = gcd(a, b)@ and @a*s + b*t = g@.
-- O(log min(a, b)).
--
-- Note: if the inputs are negative the returned gcd may be negative, since the
-- remainders come from 'quotRem'. Take 'abs' of the inputs, or flip the signs of
-- all outputs, if this is undesirable. The complexity assumes that 'quotRem',
-- (-), (*) and comparison with 0 all take O(1).
egcd2 :: Integral i => i -> i -> (i, i, i)
egcd2 = go 1 0 0 1 where
    -- Invariant: r = a*s + b*t and r' = a*s' + b*t'.
    -- (r, r') steps through Euclid's remainder sequence; when r' reaches 0,
    -- r is the gcd and (s, t) are its Bezout coefficients.
    go !s _  !t _  !r 0  = (r, s, t)
    go s  s' t  t' r  r' =
        let (q, r'') = r `quotRem` r'
        in  go s' (s - q * s') t' (t - q * t') r' r''
{-# INLINE egcd2 #-}

-- | Calculate the factorials @0!, 1!, .., n!@ as an array indexed @0..n@. O(n).
--
-- For @n < 0@ the result is an empty array with bounds @(0, n)@.
mkFactorials :: (IArray a e, Num e) => Int -> a Int e
mkFactorials n = listArray (0, n) $ scanl' (*) 1 $ map fromIntegral [1 .. n]

-- | Given @1/(n!)@, calculate the inverse factorials @1/(0!) .. 1/(n!)@ as an array
-- indexed @0..n@. O(n).
--
-- The table is filled downwards from index @n@ using @1/((i-1)!) = i * (1/(i!))@,
-- so only the single inverse @1/(n!)@ has to be supplied by the caller.
mkInvFactorials :: (IArray a e, Num e) => Int -> e -> a Int e
mkInvFactorials n invfacn =
    array (0, n) $ zip [n, n-1 .. 0] $ scanl' (*) invfacn $ map fromIntegral [n, n-1 .. 1]

-- | Given the maximum value of @n@, calculate binomial coefficients for @n, k@.
-- Apply partially (to @mxn@ only) and reuse the result for multiple queries:
-- O(mxn) to set up, O(1) per query.
--
-- Queries with @0 <= k <= n@ require @n <= mxn@, because the tables are built
-- for indices @0..mxn@. @e@ must be 'Fractional' because @1/(mxn!)@ is obtained
-- with 'recip'. The table array type @a@ does not occur in the result type, so
-- the caller has to choose it explicitly (e.g. with a type application).
mkBinom :: forall a e. (IArray a e, Fractional e) => Int -> Int -> Int -> e
mkBinom mxn = mkBinom_ fac ifac where
    fac, ifac :: a Int e
    fac = mkFactorials mxn
    -- One reciprocal suffices: mkInvFactorials derives the rest from 1/(mxn!).
    ifac = mkInvFactorials mxn (recip (fac ! mxn))

-- | Binomial coefficient @C(n, k) = n! / (k! * (n-k)!)@ computed from a factorial
-- table @fac@ and a matching inverse-factorial table @ifac@. O(1) per query.
--
-- Returns 0 when @k < 0@ or @k > n@. Otherwise @n@ must lie within the tables'
-- bounds, since @fac@ is indexed at @n@ and @ifac@ at @k@ and @n - k@.
-- The bang patterns force both tables (to WHNF) once the function is applied to them.
mkBinom_ :: (IArray a e, Num e) => a Int e -> a Int e -> Int -> Int -> e
mkBinom_ !fac !ifac = go where
    go n k
        | k < 0 || n < k = 0
        | otherwise      = fac ! n * ifac ! k * ifac ! (n - k)

--------------------------------------------------------------------------------
-- For tests

{-# INLINABLE mkFactorials #-}
{-# INLINABLE mkInvFactorials #-}