{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}

{-# HLINT ignore "Eta reduce" #-}

-- |
-- Module      :   Grisette.Internal.Unified.Util
-- Copyright   :   (c) Sirui Lu 2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.Unified.Util
  ( DecideEvalMode (..),
    withMode,
    EvalModeConvertible (..),
  )
where

import Data.Typeable (type (:~:) (Refl))
import Grisette.Internal.Unified.EvalModeTag (EvalModeTag (C, S))
import Grisette.Internal.Utils.Parameterized (unsafeAxiom)

-- | A class that provides the mode tag at runtime.
--
-- The method's type does not mention @mode@, so the instance has to be
-- selected with a visible type application (this module enables
-- @TypeApplications@ and @AllowAmbiguousTypes@ for that purpose).
class DecideEvalMode (mode :: EvalModeTag) where
  -- | The value-level 'EvalModeTag' reported for the type-level @mode@.
  decideEvalMode :: EvalModeTag

-- | The type-level tag @'C@ is reported as the value-level tag 'C'.
instance DecideEvalMode 'C where
  decideEvalMode = C
  {-# INLINE decideEvalMode #-}

-- | The type-level tag @'S@ is reported as the value-level tag 'S'.
instance DecideEvalMode 'S where
  decideEvalMode = S
  {-# INLINE decideEvalMode #-}

-- | Case analysis on the mode.
--
-- Evaluates to the first argument when 'decideEvalMode' reports 'C' for
-- @mode@, and to the second argument when it reports 'S'. Each argument may
-- rely on the corresponding type equality (@mode ~ 'C@ or @mode ~ 'S@).
--
-- @mode@ does not appear in the argument or result types other than in the
-- equality constraints, so it is normally supplied with a type application.
withMode ::
  forall mode r.
  (DecideEvalMode mode) =>
  ((mode ~ 'C) => r) ->
  ((mode ~ 'S) => r) ->
  r
withMode con sym =
  case decideEvalMode @mode of
    -- The runtime tag is the only evidence of what @mode@ is, so the equality
    -- proof is manufactured with 'unsafeAxiom' rather than derived. This
    -- relies on the 'DecideEvalMode' instances reporting 'C' only for @'C@
    -- and 'S' only for @'S@, as the instances in this module do.
    C -> case unsafeAxiom @mode @'C of
      Refl -> con
    S -> case unsafeAxiom @mode @'S of
      Refl -> sym
{-# INLINE withMode #-}

-- | A class saying that we can convert a value with one mode to another mode.
--
-- Allowed conversions:
--
-- - 'C' <-> 'C'
-- - 'S' <-> 'S'
-- - 'C' <-> 'S'
--
-- Conversion from left to right uses 'Grisette.ToSym' class, and conversion
-- from right to left uses 'Grisette.ToCon' class.
class
  (DecideEvalMode c, DecideEvalMode s) =>
  EvalModeConvertible (c :: EvalModeTag) (s :: EvalModeTag)
  where
  -- | Two-way case analysis: the result is either the first argument, which
  -- may assume @c ~ 'C@, or the second argument, which may assume @s ~ 'S@.
  withModeConvertible ::
    ((c ~ 'C) => r) ->
    ((s ~ 'S) => r) ->
    r

  -- | Three-way case analysis over the allowed mode pairs. The arguments
  -- handle, in order, @(c, s)@ being @('C, 'C)@, @('C, 'S)@ and @('S, 'S)@,
  -- and each may assume both corresponding equalities.
  withModeConvertible' ::
    ((c ~ 'C, s ~ 'C) => r) ->
    ((c ~ 'C, s ~ 'S) => r) ->
    ((c ~ 'S, s ~ 'S) => r) ->
    r

-- | Both modes are the same type variable, so a single runtime check on @c@
-- decides the case.
--
-- The head of this instance overlaps with the other 'EvalModeConvertible'
-- instances in this module, hence the @INCOHERENT@ pragma.
instance {-# INCOHERENT #-} (DecideEvalMode c) => EvalModeConvertible c c where
  -- With @s@ instantiated to @c@, the two continuations have exactly the
  -- types 'withMode' expects.
  withModeConvertible con sym = withMode @c con sym
  {-# INLINE withModeConvertible #-}

  -- The middle continuation needs @c ~ 'C@ and @c ~ 'S@ at once, so it is
  -- never selected here.
  withModeConvertible' con _ sym = withMode @c con sym
  {-# INLINE withModeConvertible' #-}

-- | The source mode is fixed to @'C@; only the target mode @s@ has to be
-- inspected at runtime.
--
-- The head of this instance overlaps with the other 'EvalModeConvertible'
-- instances in this module, hence the @INCOHERENT@ pragma.
instance {-# INCOHERENT #-} (DecideEvalMode s) => EvalModeConvertible 'C s where
  -- @'C ~ 'C@ holds trivially, so the first continuation is always chosen.
  withModeConvertible con _ = con
  {-# INLINE withModeConvertible #-}

  -- The last continuation would need @'C ~ 'S@ and is therefore never
  -- selected; the remaining two are distinguished by the runtime tag of @s@.
  withModeConvertible' con0 con1 _ = withMode @s con0 con1
  {-# INLINE withModeConvertible' #-}

-- | The target mode is fixed to @'S@; only the source mode @c@ has to be
-- inspected at runtime.
--
-- The head of this instance overlaps with the other 'EvalModeConvertible'
-- instances in this module, hence the @INCOHERENT@ pragma.
instance {-# INCOHERENT #-} (DecideEvalMode c) => EvalModeConvertible c 'S where
  -- @'S ~ 'S@ holds trivially, so the second continuation is always chosen.
  withModeConvertible _ sym = sym
  {-# INLINE withModeConvertible #-}

  -- The first continuation would need @'S ~ 'C@ and is therefore never
  -- selected; the remaining two are distinguished by the runtime tag of @c@.
  withModeConvertible' _ sym0 sym1 = withMode @c sym0 sym1
  {-# INLINE withModeConvertible' #-}