{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}

-- |
-- Module      :   Grisette.Internal.TH.Derivation.SerializeCommon
-- Copyright   :   (c) Sirui Lu 2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.TH.Derivation.SerializeCommon
  ( serializeConfig,
    serializeWithSerialConfig,
  )
where

import Control.Monad (zipWithM)
import Data.Bytes.Serial (Serial (deserialize, serialize), Serial1, Serial2)
import qualified Data.Map as M
import Data.Maybe (mapMaybe)
import qualified Data.Set as S
import GHC.Word (Word16, Word32, Word64, Word8)
import Grisette.Internal.TH.Derivation.UnaryOpCommon
  ( UnaryOpClassConfig
      ( UnaryOpClassConfig,
        unaryOpAllowExistential,
        unaryOpConfigs,
        unaryOpContextNames,
        unaryOpExtraVars,
        unaryOpInstanceNames,
        unaryOpInstanceTypeFromConfig
      ),
    UnaryOpConfig (UnaryOpConfig),
    UnaryOpFieldConfig
      ( UnaryOpFieldConfig,
        extraLiftedPatNames,
        extraPatNames,
        fieldCombineFun,
        fieldFunExp,
        fieldResFun
      ),
    UnaryOpFunConfig (genUnaryOpFun),
    defaultFieldFunExp,
    defaultUnaryOpInstanceTypeFromConfig,
  )
import Grisette.Internal.TH.Util (integerE)
import Language.Haskell.TH
  ( Body (NormalB),
    Clause (Clause),
    Dec (FunD),
    Lit (IntegerL),
    Match (Match),
    Name,
    Pat (LitP, VarP, WildP),
    Type (VarT),
    bindS,
    caseE,
    clause,
    conE,
    conT,
    doE,
    funD,
    match,
    mkName,
    newName,
    noBindS,
    normalB,
    sigP,
    varE,
    varP,
    wildP,
  )
import Language.Haskell.TH.Datatype
  ( ConstructorInfo (constructorFields, constructorName),
    TypeSubstitution (freeVariables),
    resolveTypeSynonyms,
  )

-- | Configuration whose generated serialization function is exactly
-- 'serialize' from the 'Serial' class, i.e. the existing 'Serial' instance is
-- reused instead of generating code from scratch.
data UnaryOpSerializeWithSerialConfig = UnaryOpSerializeWithSerialConfig

instance UnaryOpFunConfig UnaryOpSerializeWithSerialConfig where
  -- Emits the single clause @<funNames !! n> = serialize@; the type and
  -- constructor information is not needed and is ignored.
  genUnaryOpFun _ UnaryOpSerializeWithSerialConfig funNames n _ _ _ _ _ =
    let generatedName = funNames !! n
        delegateBody = normalB (varE 'serialize)
     in funD generatedName [clause [] delegateBody []]

-- | Configuration whose generated deserialization function is exactly
-- 'deserialize' from the 'Serial' class, i.e. the existing 'Serial' instance
-- is reused instead of generating code from scratch.
data UnaryOpDeserializeWithSerialConfig = UnaryOpDeserializeWithSerialConfig

instance UnaryOpFunConfig UnaryOpDeserializeWithSerialConfig where
  -- Emits the single clause @<funNames !! n> = deserialize@; the type and
  -- constructor information is not needed and is ignored.
  genUnaryOpFun _ UnaryOpDeserializeWithSerialConfig funNames n _ _ _ _ _ =
    let generatedName = funNames !! n
        delegateBody = normalB (varE 'deserialize)
     in funD generatedName [clause [] delegateBody []]

-- | Configuration that generates the deserialization function from scratch,
-- rather than delegating to an existing 'Serial' instance.
data UnaryOpDeserializeConfig
  = UnaryOpDeserializeConfig

-- | Choose the type used to encode a constructor index, given the number of
-- constructors of the datatype.
--
-- The smallest of 'Word8', 'Word16', 'Word32' and 'Word64' that can represent
-- every index @0 .. numConstructors - 1@ is chosen, with 'Integer' as the
-- fallback.
--
-- The bounds are compared as 'Integer'. Converting @maxBound :: Word64@ (or
-- @maxBound :: Word32@ on 32-bit platforms) to 'Int' would wrap around to
-- @-1@, turning the corresponding bound into @0@ and skipping that type.
getSerializedType :: Int -> Name
getSerializedType numConstructors
  | fitsIn (maxBound :: Word8) = ''Word8
  | fitsIn (maxBound :: Word16) = ''Word16
  | fitsIn (maxBound :: Word32) = ''Word32
  | fitsIn (maxBound :: Word64) = ''Word64
  | otherwise = ''Integer
  where
    -- A type whose maximum value is @m@ can encode @m + 1@ distinct indices.
    fitsIn :: (Integral w) => w -> Bool
    fitsIn maxTag = toInteger numConstructors <= toInteger maxTag + 1

-- Generates the deserialization function from scratch.
--
-- * For a type without constructors, the generated function is a call to
--   'error'.
-- * Otherwise the generated function takes one argument per entry of
--   @argTypes@: a fresh variable when the entry is a type variable that occurs
--   in some constructor field, a wildcard otherwise. It first reads a
--   constructor tag using the first function name in @funNames@. The tag's
--   type is chosen by 'getSerializedType' from the number of constructors.
--   It then dispatches on the tag: tag @i@ rebuilds the @i@-th constructor
--   by applying it to the field deserializers with '<*>'. Any other tag
--   raises an 'error' that mentions the offending tag.
instance UnaryOpFunConfig UnaryOpDeserializeConfig where
  genUnaryOpFun _ UnaryOpDeserializeConfig funNames n _ _ _ _ [] =
    funD
      (funNames !! n)
      [ clause
          []
          (normalB [|error "deserializing a type without constructors"|])
          []
      ]
  genUnaryOpFun
    _
    UnaryOpDeserializeConfig
    funNames
    n
    _
    _
    argTypes
    _
    constructors = do
      allFields <-
        mapM resolveTypeSynonyms $ concatMap constructorFields constructors
      -- Type variables mentioned by at least one constructor field.
      let usedArgs = S.fromList $ freeVariables allFields
      -- Only used type variables get a pattern variable. The placeholder name
      -- paired with 'Nothing' is dropped below and never looked up.
      args <-
        traverse
          ( \(ty, _) -> case ty of
              VarT nm
                | S.member nm usedArgs -> do
                    pname <- newName "p"
                    return (nm, Just pname)
              _ -> return ('undefined, Nothing)
          )
          argTypes
      let argToFunPat =
            M.fromList $ mapMaybe (\(nm, mpat) -> (nm,) <$> mpat) args
          funPats = maybe WildP VarP . snd <$> args
          selName = mkName "sel"
          -- Case alternative for constructor number @conIdx@.
          genAuxFunMatch conIdx conInfo = do
            fields <- mapM resolveTypeSynonyms $ constructorFields conInfo
            fieldFuns <-
              traverse
                (defaultFieldFunExp funNames argToFunPat M.empty)
                fields
            conExp <-
              foldl
                (\acc fieldFun -> [|$acc <*> $(return fieldFun)|])
                [|return $(conE (constructorName conInfo))|]
                fieldFuns
            return $ Match (LitP (IntegerL conIdx)) (NormalB conExp) []
      auxMatches <- zipWithM genAuxFunMatch [0 ..] constructors
      -- A tag that matches no constructor means the input is malformed:
      -- report the offending value instead of failing with 'undefined'.
      auxFallbackMatch <-
        match
          wildP
          ( normalB
              [|
                error
                  ( "deserializing: invalid constructor tag "
                      ++ show $(varE selName)
                  )
                |]
          )
          []
      body <-
        doE
          [ bindS
              ( sigP
                  (varP selName)
                  (conT (getSerializedType $ length constructors))
              )
              (varE (head funNames)),
            noBindS $
              caseE (varE selName) $
                return <$> auxMatches ++ [auxFallbackMatch]
          ]
      return $ FunD (funNames !! n) [Clause funPats (NormalB body) []]

-- | Build a 'UnaryOpClassConfig' that generates both the serialization and
-- the deserialization functions from scratch.
--
-- The generated serializer first writes the constructor index with the first
-- name in @serializeFunNames@. The index is annotated with the type that
-- 'getSerializedType' picks for the total number of constructors. The
-- per-field expressions (each applies a field serializer to the field
-- pattern) are then chained with '>>'. Deserialization is generated by
-- 'UnaryOpDeserializeConfig' using @deserializeFunNames@.
serializeConfig :: [Name] -> [Name] -> [Name] -> UnaryOpClassConfig
serializeConfig instanceNames serializeFunNames deserializeFunNames =
  UnaryOpClassConfig
    { unaryOpConfigs =
        [ UnaryOpConfig serializeFieldConfig serializeFunNames,
          UnaryOpConfig UnaryOpDeserializeConfig deserializeFunNames
        ],
      unaryOpInstanceNames = instanceNames,
      unaryOpExtraVars = \_ -> return [],
      unaryOpInstanceTypeFromConfig = defaultUnaryOpInstanceTypeFromConfig,
      unaryOpAllowExistential = False,
      unaryOpContextNames = Nothing
    }
  where
    -- Expression that writes constructor index @conIdx@ as a @tagType@ value.
    serializeTag conIdx tagType =
      [|
        $(varE $ head serializeFunNames)
          ($(integerE conIdx) :: $(conT tagType))
        |]
    serializeFieldConfig =
      UnaryOpFieldConfig
        { extraPatNames = [],
          extraLiftedPatNames = const [],
          -- NOTE(review): the fifth argument is matched against '[]'. It
          -- presumably carries the extra patterns, which are never requested
          -- here (extraPatNames = []) — confirm against UnaryOpCommon.
          fieldCombineFun = \totalConNumber conIdx _ _ [] fieldExps -> do
            combined <-
              foldl
                (\acc fieldExp -> [|$acc >> $(return fieldExp)|])
                (serializeTag conIdx (getSerializedType totalConNumber))
                fieldExps
            return (combined, [True]),
          fieldResFun = \_ _ _ _ fieldPat fieldFun -> do
            applied <- [|$(return fieldFun) $(return fieldPat)|]
            return (applied, [True]),
          fieldFunExp = defaultFieldFunExp serializeFunNames
        }

-- | Build a 'UnaryOpClassConfig' whose generated serialization and
-- deserialization functions are simply 'serialize' and 'deserialize' from the
-- 'Serial' class, so the existing 'Serial' instance is reused.
--
-- The context class names are 'Serial', 'Serial1' and 'Serial2', truncated to
-- the number of instance names.
serializeWithSerialConfig :: [Name] -> [Name] -> [Name] -> UnaryOpClassConfig
serializeWithSerialConfig instanceNames serializeFunNames deserializeFunNames =
  UnaryOpClassConfig
    { unaryOpConfigs = [serializeOp, deserializeOp],
      unaryOpInstanceNames = instanceNames,
      unaryOpExtraVars = \_ -> return [],
      unaryOpInstanceTypeFromConfig = defaultUnaryOpInstanceTypeFromConfig,
      unaryOpAllowExistential = False,
      unaryOpContextNames = Just contextNames
    }
  where
    serializeOp =
      UnaryOpConfig UnaryOpSerializeWithSerialConfig serializeFunNames
    deserializeOp =
      UnaryOpConfig UnaryOpDeserializeWithSerialConfig deserializeFunNames
    -- 'zipWith const' keeps at most one context class per instance name.
    contextNames = zipWith const [''Serial, ''Serial1, ''Serial2] instanceNames