{-
Part of the code in this file comes from the parameterized-utils package:

Copyright (c) 2013-2022 Galois Inc.
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:

  * Redistributions of source code must retain the above copyright
    notice, this list of conditions and the following disclaimer.

  * Redistributions in binary form must reproduce the above copyright
    notice, this list of conditions and the following disclaimer in
    the documentation and/or other materials provided with the
    distribution.

  * Neither the name of Galois, Inc. nor the names of its contributors
    may be used to endorse or promote products derived from this
    software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}

-- |
-- Module      :   Grisette.Internal.Utils.Parameterized
-- Copyright   :   (c) Sirui Lu 2021-2023
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.Utils.Parameterized
  ( -- * Unsafe axiom
    unsafeAxiom,

    -- * Unparameterized type
    SomeNatRepr (..),
    SomePositiveNatRepr (..),

    -- * Runtime representation of type-level natural numbers
    NatRepr,
    withKnownNat,
    natValue,
    mkNatRepr,
    mkPositiveNatRepr,
    natRepr,
    decNat,
    predNat,
    incNat,
    addNat,
    subNat,
    divNat,
    halfNat,

    -- * Proof of KnownNat
    KnownProof (..),
    hasRepr,
    withKnownProof,
    unsafeKnownProof,
    knownAdd,

    -- * Proof of CmpNat
    CmpNatProof (..),
    unsafeCmpNatProof,
    withCmpNatProof,

    -- * Proof of (<=) for type-level natural numbers
    LeqProof (..),
    withLeqProof,
    unsafeLeqProof,
    testLeq,
    leqRefl,
    leqSucc,
    leqTrans,
    leqZero,
    leqAdd2,
    leqAdd,
    leqAddPos,
  )
where

import Data.Type.Equality (type (==))
import Data.Typeable (Proxy (Proxy), type (:~:) (Refl))
import GHC.TypeNats
  ( CmpNat,
    Div,
    KnownNat,
    Nat,
    SomeNat (SomeNat),
    natVal,
    someNatVal,
    type (+),
    type (-),
    type (<=),
  )
import Numeric.Natural (Natural)
import Unsafe.Coerce (unsafeCoerce)

-- | Assert, without any evidence, that two types are equal.
--
-- The proof is fabricated by coercing the reflexivity proof @a :~: a@ with
-- 'unsafeCoerce'; nothing checks that @a@ and @b@ actually coincide.
-- This is unsafe if used improperly, so use this with caution!
unsafeAxiom :: forall a b. a :~: b
unsafeAxiom = unsafeCoerce (Refl @a)
{-# INLINE unsafeAxiom #-}

-- | Construct the 'KnownNat' constraint when the runtime value is known.
--
-- The stored 'Natural' is promoted with 'someNatVal' to some existential
-- @n'@ carrying a 'KnownNat' instance; 'unsafeAxiom' then identifies @n'@ with
-- the index @n@ so that the continuation can be run with @KnownNat n@ in
-- scope. This is only sound when the value inside the t'NatRepr' really is
-- the value of @n@.
withKnownNat :: forall n r. NatRepr n -> ((KnownNat n) => r) -> r
withKnownNat (NatRepr nVal) v =
  case someNatVal nVal of
    SomeNat (Proxy :: Proxy n') ->
      -- No runtime check is made here: the equality is simply asserted.
      case unsafeAxiom :: n :~: n' of
        Refl -> v
{-# INLINE withKnownNat #-}

-- | A runtime representation of type-level natural numbers.
-- This can be used for performing dynamic checks on type-level natural numbers.
--
-- The index @n@ is a phantom parameter: the newtype only wraps a 'Natural',
-- and nothing in the constructor itself ties that value to @n@.
newtype NatRepr (n :: Nat) = NatRepr Natural

-- | The underlying runtime natural number value of a type-level natural number.
--
-- Simply unwraps the 'Natural' stored inside the t'NatRepr'.
natValue :: NatRepr n -> Natural
natValue (NatRepr n) = n
{-# INLINE natValue #-}

-- | Existential wrapper for t'NatRepr' that hides the index @n@ and carries
-- no 'KnownNat' evidence. 'mkNatRepr' uses it to obtain a fresh type-level
-- index for a runtime 'Natural'.
data SomeNatReprHelper where
  SomeNatReprHelper :: NatRepr n -> SomeNatReprHelper

-- | Existential wrapper for t'NatRepr'.
--
-- Pattern matching on 'SomeNatRepr' brings the hidden index @n@ and its
-- 'KnownNat' instance into scope.
data SomeNatRepr where
  SomeNatRepr :: (KnownNat n) => NatRepr n -> SomeNatRepr

-- | Turn a @Natural@ into the corresponding @NatRepr@ with the KnownNat
-- constraint.
--
-- The value is first wrapped in 'SomeNatReprHelper' to pick an existential
-- index for it; 'withKnownNat' then supplies the 'KnownNat' instance that the
-- 'SomeNatRepr' constructor requires.
mkNatRepr :: Natural -> SomeNatRepr
mkNatRepr n = case SomeNatReprHelper (NatRepr n) of
  SomeNatReprHelper repr -> withKnownNat repr $ SomeNatRepr repr
{-# INLINE mkNatRepr #-}

-- | Existential wrapper for t'NatRepr' with the constraint that the natural
-- number is greater than 0.
--
-- Pattern matching on 'SomePositiveNatRepr' brings both @KnownNat n@ and
-- @1 <= n@ into scope for the hidden index @n@.
data SomePositiveNatRepr where
  SomePositiveNatRepr ::
    (KnownNat n, 1 <= n) => NatRepr n -> SomePositiveNatRepr

-- | Turn a @Natural@ into the corresponding @NatRepr@ with the KnownNat
-- constraint, and assert that it is greater than 0.
--
-- Calls 'error' when given @0@. For any other input the @1 <= n@ evidence is
-- produced with 'unsafeLeqProof'; the preceding equation has already ruled
-- out zero.
mkPositiveNatRepr :: Natural -> SomePositiveNatRepr
mkPositiveNatRepr 0 = error "mkPositiveNatRepr: 0 is not a positive number"
mkPositiveNatRepr n = case mkNatRepr n of
  SomeNatRepr (repr :: NatRepr n) -> case unsafeLeqProof @1 @n of
    LeqProof -> SomePositiveNatRepr repr
{-# INLINE mkPositiveNatRepr #-}

-- | Construct a runtime representation of a type-level natural number when its
-- runtime value is known.
--
-- The value is read from the 'KnownNat' instance with 'natVal'.
natRepr :: forall n. (KnownNat n) => NatRepr n
natRepr = NatRepr (natVal (Proxy @n))
{-# INLINE natRepr #-}

-- | Decrement a t'NatRepr' by 1.
--
-- The @1 <= n@ constraint is what rules out decrementing zero at the type
-- level; the runtime value is computed with plain 'Natural' subtraction.
decNat :: (1 <= n) => NatRepr n -> NatRepr (n - 1)
decNat (NatRepr n) = NatRepr (n - 1)
{-# INLINE decNat #-}

-- | Predecessor of a t'NatRepr'.
--
-- Unlike 'decNat', positivity is expressed through the shape of the index
-- (@n + 1@) rather than through a @1 <= n@ constraint.
predNat :: NatRepr (n + 1) -> NatRepr n
predNat (NatRepr n) = NatRepr (n - 1)
{-# INLINE predNat #-}

-- | Increment a t'NatRepr' by 1.
--
-- Both the runtime value and the type-level index are bumped by one.
incNat :: NatRepr n -> NatRepr (n + 1)
incNat (NatRepr n) = NatRepr (n + 1)
{-# INLINE incNat #-}

-- | Addition of two t'NatRepr's.
--
-- The result's runtime value is the sum of the two underlying 'Natural's and
-- its index is @m + n@.
addNat :: NatRepr m -> NatRepr n -> NatRepr (m + n)
addNat (NatRepr m) (NatRepr n) = NatRepr (m + n)
{-# INLINE addNat #-}

-- | Subtraction of two t'NatRepr's.
--
-- The @n <= m@ constraint states at the type level that the subtrahend does
-- not exceed the minuend; the runtime value is computed with plain 'Natural'
-- subtraction.
subNat :: (n <= m) => NatRepr m -> NatRepr n -> NatRepr (m - n)
subNat (NatRepr m) (NatRepr n) = NatRepr (m - n)
{-# INLINE subNat #-}

-- | Division of two t'NatRepr's.
--
-- The @1 <= n@ constraint excludes a zero divisor at the type level. The
-- runtime value is computed with 'div', matching the type-level 'Div'.
divNat :: (1 <= n) => NatRepr m -> NatRepr n -> NatRepr (Div m n)
divNat (NatRepr m) (NatRepr n) = NatRepr (m `div` n)
{-# INLINE divNat #-}

-- | Half of a t'NatRepr'.
--
-- The argument's index has the shape @n + n@, and the result (index @n@) is
-- obtained by dividing the runtime value by 2.
halfNat :: NatRepr (n + n) -> NatRepr n
halfNat (NatRepr n) = NatRepr (n `div` 2)
{-# INLINE halfNat #-}

-- | @'KnownProof n'@ is a type whose values are only inhabited when @n@ has
-- a known runtime value.
--
-- Pattern matching on the 'KnownProof' constructor brings @KnownNat n@ into
-- scope.
data KnownProof (n :: Nat) where
  KnownProof :: (KnownNat n) => KnownProof n

-- | Introduces the 'KnownNat' constraint when it's proven.
--
-- Matches on the proof to bring @KnownNat n@ into scope and then returns the
-- continuation evaluated under that constraint.
withKnownProof :: KnownProof n -> ((KnownNat n) => r) -> r
withKnownProof p r = case p of KnownProof -> r
{-# INLINE withKnownProof #-}

-- | Construct a t'KnownProof' given the runtime value.
--
-- The value is wrapped in a t'NatRepr' at the caller-chosen index @n@ and
-- handed to 'hasRepr'.
--
-- __Note:__ This function is unsafe, as it does not check that the runtime
-- representation is consistent with the type-level representation.
-- You should ensure the consistency yourself or the program can crash or
-- generate incorrect results.
unsafeKnownProof :: Natural -> KnownProof n
unsafeKnownProof nVal = hasRepr (NatRepr nVal)
{-# INLINE unsafeKnownProof #-}

-- | Construct a t'KnownProof' given the runtime representation.
--
-- The stored 'Natural' is promoted with 'someNatVal' to an existential @n'@
-- that has a 'KnownNat' instance, and 'unsafeAxiom' identifies @n'@ with @n@
-- so the 'KnownProof' constructor can be used at index @n@.
hasRepr :: forall n. NatRepr n -> KnownProof n
hasRepr (NatRepr nVal) =
  case someNatVal nVal of
    SomeNat (Proxy :: Proxy n') ->
      -- No runtime check is made here: the equality is simply asserted.
      case unsafeAxiom :: n :~: n' of
        Refl -> KnownProof
{-# INLINE hasRepr #-}

-- | Adding two type-level natural numbers with known runtime values gives a
-- type-level natural number with a known runtime value.
--
-- Both proofs are matched to obtain @KnownNat m@ and @KnownNat n@; the two
-- values read with 'natVal' are summed and passed to 'hasRepr' at index
-- @m + n@.
knownAdd :: forall m n. KnownProof m -> KnownProof n -> KnownProof (m + n)
knownAdd KnownProof KnownProof =
  hasRepr @(m + n) (NatRepr (natVal (Proxy @m) + natVal (Proxy @n)))
{-# INLINE knownAdd #-}

-- | @'LeqProof m n'@ is a type whose values are only inhabited when @m <= n@.
--
-- Pattern matching on the 'LeqProof' constructor brings @m <= n@ into scope.
data LeqProof (m :: Nat) (n :: Nat) where
  LeqProof :: (m <= n) => LeqProof m n

-- | Introduces the @m <= n@ constraint when it's proven.
--
-- Matches on the proof to bring @m <= n@ into scope and then returns the
-- continuation evaluated under that constraint.
withLeqProof :: LeqProof m n -> ((m <= n) => r) -> r
withLeqProof p r = case p of LeqProof -> r
{-# INLINE withLeqProof #-}

-- | Construct a t'LeqProof'.
--
-- The proof is fabricated by coercing the trivially valid @LeqProof 0 0@ to
-- the requested indices with 'unsafeCoerce'.
--
-- __Note:__ This function is unsafe, as it does not check that the left-hand
-- side is less than or equal to the right-hand side.
-- You should ensure the consistency yourself or the program can crash or
-- generate incorrect results.
unsafeLeqProof :: forall m n. LeqProof m n
unsafeLeqProof = unsafeCoerce (LeqProof @0 @0)
{-# INLINE unsafeLeqProof #-}

-- | Proof that the comparison of two type-level natural numbers is consistent
-- with the runtime comparison.
--
-- Pattern matching on the 'CmpNatProof' constructor brings
-- @(CmpNat m n == o) ~ 'True@ into scope.
data CmpNatProof (m :: Nat) (n :: Nat) (o :: Ordering) where
  CmpNatProof :: ((CmpNat m n == o) ~ 'True) => CmpNatProof m n o

-- | Construct a t'CmpNatProof'.
--
-- The proof is fabricated by coercing the trivially valid
-- @CmpNatProof 0 0 'EQ@ to the requested indices with 'unsafeCoerce'.
--
-- __Note:__ This function is unsafe, as it does not check that comparing @m@
-- with @n@ really yields @o@. You should ensure the consistency yourself.
unsafeCmpNatProof :: forall m n o. CmpNatProof m n o
unsafeCmpNatProof = unsafeCoerce (CmpNatProof @0 @0 @'EQ)
{-# INLINE unsafeCmpNatProof #-}

-- | Introduces the @(CmpNat m n == o) ~ 'True@ constraint when it's proven.
--
-- Matches on the proof to bring the constraint into scope and then returns
-- the continuation evaluated under it.
withCmpNatProof :: CmpNatProof m n o -> (((CmpNat m n == o) ~ 'True) => r) -> r
withCmpNatProof p r = case p of CmpNatProof -> r
{-# INLINE withCmpNatProof #-}

-- | Checks if a t'NatRepr' is less than or equal to another t'NatRepr'.
--
-- Returns @'Just' proof@ exactly when the runtime value of the first argument
-- is less than or equal to that of the second, and 'Nothing' otherwise. The
-- proof itself is produced with 'unsafeLeqProof', justified by the runtime
-- comparison performed here.
testLeq :: NatRepr m -> NatRepr n -> Maybe (LeqProof m n)
testLeq (NatRepr m) (NatRepr n)
  -- Only hand out evidence of @m <= n@ when the comparison actually holds.
  | m <= n = Just unsafeLeqProof
  | otherwise = Nothing
{-# INLINE testLeq #-}

-- | Apply reflexivity to t'LeqProof'.
--
-- The argument is ignored; it only fixes the index @n@. The @n <= n@
-- constraint is discharged by the type checker, so no unsafe coercion is
-- involved.
leqRefl :: f n -> LeqProof n n
leqRefl _ = LeqProof
{-# INLINE leqRefl #-}

-- | A natural number is less than or equal to its successor.
--
-- The argument is ignored; it only fixes the index @n@. The proof is produced
-- with 'unsafeLeqProof'.
leqSucc :: f n -> LeqProof n (n + 1)
leqSucc _ = unsafeLeqProof
{-# INLINE leqSucc #-}

-- | Apply transitivity to t'LeqProof'.
--
-- The two input proofs are not inspected; they only fix the indices. The
-- resulting proof is produced with 'unsafeLeqProof'.
leqTrans :: LeqProof a b -> LeqProof b c -> LeqProof a c
leqTrans _ _ = unsafeLeqProof
{-# INLINE leqTrans #-}

-- | Zero is less than or equal to any natural number.
--
-- The proof is produced with 'unsafeLeqProof'.
leqZero :: LeqProof 0 n
leqZero = unsafeLeqProof
{-# INLINE leqZero #-}

-- | Add both sides of two inequalities.
--
-- The two input proofs are not inspected; they only fix the indices. The
-- resulting proof is produced with 'unsafeLeqProof'.
leqAdd2 :: LeqProof xl xh -> LeqProof yl yh -> LeqProof (xl + yl) (xh + yh)
leqAdd2 _ _ = unsafeLeqProof
{-# INLINE leqAdd2 #-}

-- | Produce proof that adding a value to the larger element in a t'LeqProof'
-- keeps it larger.
--
-- Neither argument is inspected; they only fix the indices @m@, @n@ and @o@.
-- The resulting proof is produced with 'unsafeLeqProof'.
leqAdd :: LeqProof m n -> f o -> LeqProof m (n + o)
leqAdd _ _ = unsafeLeqProof
{-# INLINE leqAdd #-}

-- | Adding two positive natural numbers is positive.
--
-- The @1 <= m@ and @1 <= n@ constraints state the premises; the arguments
-- themselves are ignored and only fix the indices. The resulting proof is
-- produced with 'unsafeLeqProof'.
leqAddPos :: (1 <= m, 1 <= n) => p m -> q n -> LeqProof 1 (m + n)
leqAddPos _ _ = unsafeLeqProof
{-# INLINE leqAddPos #-}