{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
-- Module      :   Grisette.Internal.Internal.Decl.Core.Data.Class.SafeDiv
-- Copyright   :   (c) Sirui Lu 2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.Internal.Decl.Core.Data.Class.SafeDiv
  ( ArithException (..),
    SafeDiv (..),
    DivOr (..),
    divOrZero,
    modOrDividend,
    quotOrZero,
    remOrDividend,
    divModOrZeroDividend,
    quotRemOrZeroDividend,
  )
where

import Control.Exception (ArithException (DivideByZero, Overflow, Underflow))
import Control.Monad.Except (MonadError)
import Grisette.Internal.Core.Data.Class.Mergeable (Mergeable)
import Grisette.Internal.Core.Data.Class.TryMerge
  ( TryMerge,
    mrgSingle,
  )
import Grisette.Lib.Data.Functor (mrgFmap)

-- $setup
-- >>> import Grisette.Core
-- >>> import Grisette.SymPrim
-- >>> import Control.Monad.Except
-- >>> import Control.Exception

-- | Safe division handling with default values returned on exception.
--
-- Every method takes the default result first, then the dividend, then the
-- divisor. As the 'SymInteger' examples below show, the default is chosen
-- when the divisor is zero.
--
-- NOTE(review): only 'divOr', 'modOr' and 'divModOr' carry examples. The
-- same argument order is assumed for the @quot@/@rem@ family; confirm this
-- against the instances.
class DivOr a where
  -- | Safe 'div' with default value returned on exception.
  --
  -- @'divOr' d a b@ yields @d@ when @b@ is zero, and the 'div' of @a@ by @b@
  -- otherwise.
  --
  -- >>> divOr "d" "a" "b" :: SymInteger
  -- (ite (= b 0) d (div a b))
  divOr :: a -> a -> a -> a

  -- | Safe 'mod' with default value returned on exception.
  --
  -- @'modOr' d a b@ yields @d@ when @b@ is zero, and the 'mod' of @a@ by @b@
  -- otherwise.
  --
  -- >>> modOr "d" "a" "b" :: SymInteger
  -- (ite (= b 0) d (mod a b))
  modOr :: a -> a -> a -> a

  -- | Safe 'divMod' with default value returned on exception.
  --
  -- The first argument is a pair of defaults: the first component replaces
  -- the quotient and the second replaces the modulus.
  --
  -- >>> divModOr ("d", "m") "a" "b" :: (SymInteger, SymInteger)
  -- ((ite (= b 0) d (div a b)),(ite (= b 0) m (mod a b)))
  divModOr :: (a, a) -> a -> a -> (a, a)

  -- | Safe 'quot' with default value returned on exception.
  --
  -- Arguments are presumably (default, dividend, divisor), as for 'divOr'.
  quotOr :: a -> a -> a -> a

  -- | Safe 'rem' with default value returned on exception.
  --
  -- Arguments are presumably (default, dividend, divisor), as for 'modOr'.
  remOr :: a -> a -> a -> a

  -- | Safe 'quotRem' with default value returned on exception.
  --
  -- The first argument is presumably a pair of defaults for the quotient and
  -- remainder, as for 'divModOr'.
  quotRemOr :: (a, a) -> a -> a -> (a, a)

-- | Safe 'div' with 0 returned on exception.
--
-- @'divOrZero' l r@ is @'divOr' (l - l) l r@. It yields the 'div' of @l@ by
-- @r@, or zero whenever 'divOr' falls back to its default (e.g. when @r@ is
-- zero).
divOrZero :: (DivOr a, Num a) => a -> a -> a
-- The zero is built as @l - l@ rather than the literal @0@.
-- NOTE(review): presumably this is so the zero shares @l@'s representation
-- (e.g. a bit-width only known at runtime). Confirm before replacing it with
-- 'fromInteger'.
divOrZero l = divOr (l - l) l
{-# INLINE divOrZero #-}

-- | Safe 'mod' with dividend returned on exception.
--
-- @'modOrDividend' l r@ is @'modOr' l l r@. It yields the 'mod' of @l@ by
-- @r@, or the dividend @l@ itself whenever 'modOr' falls back to its default
-- (e.g. when @r@ is zero).
modOrDividend :: (DivOr a, Num a) => a -> a -> a
modOrDividend l = modOr l l
{-# INLINE modOrDividend #-}

-- | Safe 'quot' with 0 returned on exception.
--
-- @'quotOrZero' l r@ is @'quotOr' (l - l) l r@. It yields the 'quot' of @l@
-- by @r@, or zero whenever 'quotOr' falls back to its default.
quotOrZero :: (DivOr a, Num a) => a -> a -> a
-- The zero is built as @l - l@ rather than the literal @0@.
-- NOTE(review): presumably this is so the zero shares @l@'s representation;
-- confirm before replacing it with 'fromInteger'.
quotOrZero l = quotOr (l - l) l
{-# INLINE quotOrZero #-}

-- | Safe 'rem' with dividend returned on exception.
--
-- @'remOrDividend' l r@ is @'remOr' l l r@. It yields the 'rem' of @l@ by
-- @r@, or the dividend @l@ itself whenever 'remOr' falls back to its
-- default.
remOrDividend :: (DivOr a, Num a) => a -> a -> a
remOrDividend l = remOr l l
{-# INLINE remOrDividend #-}

-- | Safe 'divMod' with @(0, dividend)@ returned on exception.
--
-- @'divModOrZeroDividend' l r@ is @'divModOr' (l - l, l) l r@. When
-- 'divModOr' falls back to its defaults, the quotient becomes zero and the
-- modulus becomes the dividend @l@.
divModOrZeroDividend :: (DivOr a, Num a) => a -> a -> (a, a)
-- The zero is built as @l - l@ rather than the literal @0@.
-- NOTE(review): presumably this is so the zero shares @l@'s representation;
-- confirm before replacing it with 'fromInteger'.
divModOrZeroDividend l = divModOr (l - l, l) l
{-# INLINE divModOrZeroDividend #-}

-- | Safe 'quotRem' with @(0, dividend)@ returned on exception.
--
-- @'quotRemOrZeroDividend' l r@ is @'quotRemOr' (l - l, l) l r@. When
-- 'quotRemOr' falls back to its defaults, the quotient becomes zero and the
-- remainder becomes the dividend @l@.
quotRemOrZeroDividend :: (DivOr a, Num a) => a -> a -> (a, a)
-- The zero is built as @l - l@ rather than the literal @0@.
-- NOTE(review): presumably this is so the zero shares @l@'s representation;
-- confirm before replacing it with 'fromInteger'.
quotRemOrZeroDividend l = quotRemOr (l - l, l) l
{-# INLINE quotRemOrZeroDividend #-}

-- | Safe division with monadic error handling in multi-path execution.
--
-- These operations report failures through 'MonadError' rather than
-- returning a plain value. For example, a zero divisor produces
-- 'DivideByZero' in the examples below.
--
-- The default methods are defined in terms of each other:
--
-- * 'safeDiv' and 'safeMod' project out of 'safeDivMod'.
-- * 'safeDivMod' combines 'safeDiv' and 'safeMod'.
-- * The @quot@/@rem@ family works the same way.
--
-- An instance must therefore define one side of each pair, as required by
-- the @MINIMAL@ pragma. Otherwise the defaults would recurse forever.
class (MonadError e m, TryMerge m, Mergeable a, DivOr a) => SafeDiv e a m where
  -- | Safe 'div' with monadic error handling in multi-path execution.
  --
  -- >>> safeDiv "a" "b" :: ExceptT ArithException Union SymInteger
  -- ExceptT {If (= b 0) (Left divide by zero) (Right (div a b))}
  safeDiv :: a -> a -> m a
  -- Default: take the quotient from 'safeDivMod'.
  safeDiv l r = mrgFmap fst $ safeDivMod l r
  {-# INLINE safeDiv #-}

  -- | Safe 'mod' with monadic error handling in multi-path execution.
  --
  -- >>> safeMod "a" "b" :: ExceptT ArithException Union SymInteger
  -- ExceptT {If (= b 0) (Left divide by zero) (Right (mod a b))}
  safeMod :: a -> a -> m a
  -- Default: take the modulus from 'safeDivMod'.
  safeMod l r = mrgFmap snd $ safeDivMod l r
  {-# INLINE safeMod #-}

  -- | Safe 'divMod' with monadic error handling in multi-path execution.
  --
  -- >>> safeDivMod "a" "b" :: ExceptT ArithException Union (SymInteger, SymInteger)
  -- ExceptT {If (= b 0) (Left divide by zero) (Right ((div a b),(mod a b)))}
  safeDivMod :: a -> a -> m (a, a)
  -- Default: run 'safeDiv' and then 'safeMod', and return both results
  -- through 'mrgSingle'.
  safeDivMod l r = do
    dv <- safeDiv l r
    md <- safeMod l r
    mrgSingle (dv, md)
  {-# INLINE safeDivMod #-}

  -- | Safe 'quot' with monadic error handling in multi-path execution.
  safeQuot :: a -> a -> m a
  -- Default: take the quotient from 'safeQuotRem'.
  safeQuot l r = mrgFmap fst $ safeQuotRem l r
  {-# INLINE safeQuot #-}

  -- | Safe 'rem' with monadic error handling in multi-path execution.
  safeRem :: a -> a -> m a
  -- Default: take the remainder from 'safeQuotRem'.
  safeRem l r = mrgFmap snd $ safeQuotRem l r
  {-# INLINE safeRem #-}

  -- | Safe 'quotRem' with monadic error handling in multi-path execution.
  safeQuotRem :: a -> a -> m (a, a)
  -- Default: run 'safeQuot' and then 'safeRem', and return both results
  -- through 'mrgSingle'.
  safeQuotRem l r = do
    qt <- safeQuot l r
    rm <- safeRem l r
    mrgSingle (qt, rm)
  {-# INLINE safeQuotRem #-}

  {-# MINIMAL
    ((safeDiv, safeMod) | safeDivMod),
    ((safeQuot, safeRem) | safeQuotRem)
    #-}