-- Copyright 2021 Google LLC
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

module MLIR.AST.Serialize (
  ValueMapping,
  BlockMapping,
  ValueAndBlockMapping,
  FromAST(..),
  packFromAST, packArray, unpackArray) where

import Foreign.Ptr
import Foreign.Storable
import Foreign.Marshal.Array
import Control.Monad.IO.Class
import Control.Monad.Trans.Cont
import qualified Language.C.Inline as C
import qualified Data.ByteString as BS
import qualified Data.Map.Strict as M

import qualified MLIR.Native     as Native
import qualified MLIR.Native.FFI as Native

-- | Identifier used as the key of both the value and the block mappings.
type Name = BS.ByteString

-- | Lookup table from names to native MLIR values.
type ValueMapping = M.Map Name Native.Value
-- | Lookup table from names to native MLIR blocks.
type BlockMapping = M.Map Name Native.Block
-- | The value and block tables, passed together to 'fromAST'.
type ValueAndBlockMapping = (ValueMapping, BlockMapping)

-- | Conversion from an AST representation to its native MLIR counterpart.
--
-- The functional dependency @ast -> native@ means each AST type determines
-- a single native result type.
class FromAST ast native | ast -> native where
  -- | Build the native object for an AST node in the given context.
  -- The value/block mapping is the environment available during the
  -- conversion (presumably used to resolve referenced names — confirm in
  -- instances).
  fromAST :: Native.Context -> ValueAndBlockMapping -> ast -> IO native

-- | Convert every AST node in the list to its native counterpart (via
-- 'fromAST') and marshal the results into a temporary C array.
--
-- Returns the number of elements together with a pointer to the first one.
-- The array is allocated by 'packArray' with scoped (@alloca@-style)
-- allocation, so the pointer is only valid inside the continuation of the
-- surrounding 'ContT' computation and must not escape it.
packFromAST :: (FromAST ast native, Storable native)
            => Native.Context -> ValueAndBlockMapping
            -> [ast] -> ContT r IO (C.CIntPtr, Ptr native)
packFromAST ctx env asts = do
  -- All conversions run (in order) before any memory is allocated.
  natives <- liftIO $ mapM (fromAST ctx env) asts
  packArray natives

-- TODO(apaszke): Unify this with packing utilities from ExecutionEngine?
-- | Marshal a list of 'Storable' values into a temporary C array.
--
-- Returns the element count (as a 'C.CIntPtr') together with a pointer to
-- the first element. The buffer is released as soon as the 'ContT'
-- continuation returns, so the pointer must not be used after that point.
-- An empty list yields a count of @0@.
packArray :: Storable a => [a] -> ContT r IO (C.CIntPtr, Ptr a)
packArray xs = do
  -- 'withArrayLen' allocates a buffer sized for the list, pokes every
  -- element into it and hands both the length and the pointer to the
  -- continuation; 'curry' adapts that two-argument callback to the
  -- tuple-taking continuation 'ContT' expects.
  (len, ptr) <- ContT $ \k -> withArrayLen xs (curry k)
  return (fromIntegral len, ptr)

-- | Read @size@ consecutive 'Storable' values starting at @arrPtr@.
--
-- A non-positive @size@ yields the empty list without dereferencing
-- @arrPtr@, so a null pointer is acceptable in that case.
unpackArray :: Storable a => C.CIntPtr -> Ptr a -> IO [a]
unpackArray size arrPtr = peekArray (fromIntegral size) arrPtr