{-# LANGUAGE DataKinds, ScopedTypeVariables, TypeFamilies #-}
{-|
== Modular arithmetic

Mod m i is an integer of type i modulo m.
m will usually be a compile-time constant. If m is not known at compile time, Mod can still be
used via GHC.TypeNats.SomeNat.
m must be >= 2 and fit in the type i.
Integer operations on i for values in [0..m-1] must not overflow. Examples to avoid are using (-)
with an unsigned (Word) type, or using (*) with a modulus large enough that (m-1)^2 overflows i.

This is a very general type, for something simpler see MInt.hs.

Instances of Eq, Ord, Show, Num and Fractional exist for Mod m i. All the usual operations take O(1)
time, except for recip, which takes O(log m) time. This assumes Integral i methods take O(1) time.
An instance of Enum exists for Mod m i. The enum is cyclic: it wraps to 0 after m-1.
Unboxed array support is available via Unbox.

-}

module Mod
    ( Mod(..)
    , M7
    , M3
    , invMaybe
    ) where

import Control.DeepSeq
import Data.Maybe
import Data.Proxy
import Data.Ratio
import GHC.TypeNats ( KnownNat, Nat, natVal )

import Array ( Unbox(..) )
import Math ( egcd )

-- | An integer of type @i@ taken modulo the type-level natural @m@.
--
-- The arithmetic instances expect (and produce) a wrapped value in the range
-- [0, m-1]. The derived 'Eq', 'Ord' and 'Show' instances act directly on that
-- underlying value.
newtype Mod (m :: Nat) i = Mod
    { unMod :: i  -- ^ The underlying representative, expected to lie in [0, m-1].
    }
    deriving (Eq, Ord, Show)

-- | Commonly used modulus: the prime 10^9 + 7, with Int as the representation.
-- NOTE(review): (*) needs (m-1)^2 ~ 1.0e18 to fit in Int, i.e. this assumes a
-- 64-bit Int.
type M7 = Mod 1000000007 Int

-- | Commonly used modulus: the prime 998244353, with Int as the representation.
-- NOTE(review): as with M7, (*) assumes a 64-bit Int so (m-1)^2 does not overflow.
type M3 = Mod 998244353 Int

-- | Ring operations on residues modulo m.
--
-- Operands are assumed to already lie in [0, m-1], and every result below stays
-- in that range. Because of this, (+) and (-) need only a single conditional
-- correction instead of a full `mod`.
instance (KnownNat m, Integral i) => Num (Mod m i) where
    -- The sum of two residues is below 2m, so subtracting m once is enough.
    Mod x + Mod y
        | total < modulus = Mod total
        | otherwise       = Mod (total - modulus)
      where
        total   = x + y
        modulus = fromIntegral (natVal (Proxy :: Proxy m))
    -- The difference of two residues is above -m, so adding m once is enough.
    -- This relies on i being signed; with an unsigned i, x - y wraps instead.
    Mod x - Mod y
        | diff < 0  = Mod (diff + modulus)
        | otherwise = Mod diff
      where
        diff    = x - y
        modulus = fromIntegral (natVal (Proxy :: Proxy m))
    -- The raw product x * y must fit in i (see the module notes on overflow).
    Mod x * Mod y = Mod ((x * y) `mod` fromIntegral (natVal (Proxy :: Proxy m)))
    abs = id
    -- For a residue in [0, m-1]: 0 for zero, 1 otherwise.
    signum (Mod x) = Mod (signum x)
    -- Reduce in Integer first, so negative or huge literals land in [0, m-1]
    -- before converting to i.
    fromInteger n = Mod (fromInteger (n `mod` fromIntegral (natVal (Proxy :: Proxy m))))

-- | The multiplicative inverse modulo m, or 'Nothing' if there is none.
--
-- An inverse exists exactly when the value is coprime to m. Documented as
-- O(log m).
--
-- NOTE(review): assumes @egcd a b@ (from "Math") returns @(gcd a b, s)@ with
-- @s * a ≡ gcd a b (mod b)@ and @-b < s < b@. The result is only shifted into
-- [0, m-1] when s is negative.
invMaybe :: forall m i. (KnownNat m, Integral i) => Mod m i -> Maybe (Mod m i)
invMaybe (Mod x)
    | g == 1    = Just (Mod (if s < 0 then s + modulus else s))
    | otherwise = Nothing
  where
    modulus = fromIntegral (natVal (Proxy :: Proxy m))
    (g, s)  = egcd x modulus

-- | Division modulo m, via the modular inverse from 'invMaybe'.
--
-- 'recip', and therefore the default (/), raises an error when the divisor has
-- no inverse, i.e. is not coprime to m. A rational literal p/q is read as
-- @fromInteger p / fromInteger q@.
instance (KnownNat m, Integral i) => Fractional (Mod m i) where
    recip x = fromMaybe (error "recip: no inverse") (invMaybe x)
    fromRational r = fromInteger p / fromInteger q
      where
        p = numerator r
        q = denominator r

-- | @cyclicWalk step start stop@ lists @start@, @start + step@,
-- @start + 2*step@, ... and ends with the first element equal to @stop@.
--
-- If the walk would come back to @start@ before reaching @stop@ (so @stop@ is
-- unreachable with this step), it ends just before repeating. The result is
-- therefore always finite in a cyclic type such as 'Mod'. A zero step yields
-- just @[start]@ when @start /= stop@.
cyclicWalk :: (Eq a, Num a) => a -> a -> a -> [a]
cyclicWalk step start stop = go start
  where
    go x
        | x == stop     = [x]
        | next == start = [x]  -- back at the start: stop is unreachable
        | otherwise     = x : go next
      where
        next = x + step

-- | Cyclic enumeration: the successor of m-1 is 0.
--
-- 'toEnum' goes through 'fromIntegral', so any Int is reduced modulo m.
--
-- @[x .. y]@ starts at x and counts up by 1, wrapping from m-1 to 0, until it
-- has produced y. So @[0 .. m-1]@ lists every residue, @[x .. x] == [x]@, and
-- @[5 .. 3] :: [Mod 7 Int]@ is @[5,6,0,1,2,3]@.
--
-- @[x1, x2 .. y]@ steps by @x2 - x1@ in the same way. If y is never hit, it
-- stops before the walk revisits x1, so the list is always finite.
--
-- enumFrom/enumFromThen use the class defaults.
instance (KnownNat m, Integral i) => Enum (Mod m i) where
    toEnum                 = fromIntegral
    fromEnum               = fromIntegral . unMod
    enumFromTo x y         = cyclicWalk 1 x y
    enumFromThenTo x1 x2 y = cyclicWalk (x2 - x1) x1 y

-- | Unboxed array support: a Mod m i is stored as its underlying representation i.
-- NOTE(review): presumably this relies on Mod being a newtype over i, so the two
-- share a runtime representation. Confirm against the Unbox class in Array.
instance Unbox (Mod m i) where
    type Unboxed (Mod m i) = i

--------------------------------------------------------------------------------
-- For tests

-- | Fully evaluating a Mod means fully evaluating the value it wraps.
instance NFData i => NFData (Mod m i) where
    rnf (Mod x) = rnf x