-- Copyright 2021 Google LLC
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

module MLIR.AST.Dialect.LLVM (
  -- * Types
    Type(..)
  , pattern Array
  , pattern Void
  , pattern LiteralStruct
  -- * Operations
  , module MLIR.AST.Dialect.Generated.LLVM
  ) where

import MLIR.AST.Dialect.Generated.LLVM

import Data.Typeable
import Control.Monad.IO.Class
import Control.Monad.Trans.Cont
import qualified Language.C.Inline as C

import qualified MLIR.AST           as AST
import qualified MLIR.AST.Serialize as AST
import qualified MLIR.Native        as Native
import qualified MLIR.Native.FFI    as Native

-- Configure inline-c for the quasi-quotes in this module: the base context is
-- extended with Native.mlirCtx.
-- NOTE(review): presumably mlirCtx supplies the MlirType / MlirContext type
-- mappings used by the snippets below -- confirm in MLIR.Native.FFI.
C.context $ C.baseCtx <> Native.mlirCtx
-- Make the LLVM dialect C API declarations visible to the inline C snippets.
C.include "mlir-c/Dialect/LLVM.h"

-- | LLVM dialect types that can currently be expressed in the Haskell AST.
--
-- * 'ArrayType': an element count followed by the element type.
-- * 'VoidType': the LLVM void type.
-- * 'LiteralStructType': a literal struct described by its field types.
--
-- Values are compared structurally through the derived 'Eq' instance.
data Type = ArrayType Int AST.Type
          | VoidType
          | LiteralStructType [AST.Type]
          -- TODO(apaszke): Structures, functions, vectors, etc.
          deriving Eq

-- | Builds the native MLIR type for an LLVM dialect 'Type' by calling the
-- MLIR C API through inline-c.
--
-- Nested 'AST.Type's (array elements, struct fields) are converted first with
-- their own 'AST.FromAST' instance, using the same context and environment.
instance AST.FromAST Type Native.Type where
  fromAST ctx env ty = case ty of
    ArrayType size t -> do
      nt <- AST.fromAST ctx env t
      -- The C API takes the element count as @unsigned int@.
      -- NOTE(review): 'fromIntegral' does not range-check, so a negative or
      -- oversized 'Int' would wrap silently -- confirm callers never do that.
      let nsize = fromIntegral size
      [C.exp| MlirType { mlirLLVMArrayTypeGet($(MlirType nt), $(unsigned int nsize)) } |]
    VoidType -> [C.exp| MlirType { mlirLLVMVoidTypeGet($(MlirContext ctx)) } |]
    -- 'AST.packFromAST' runs in 'ContT' and hands back a field count together
    -- with a pointer to the converted fields; the C call is made inside that
    -- continuation so the pointer is still in scope when it is used.
    -- The trailing @false@ is the last argument of the C call
    -- (NOTE(review): the "is packed" flag per the MLIR C API -- confirm).
    LiteralStructType fields -> evalContT $ do
      (numFields, nativeFields) <- AST.packFromAST ctx env fields
      liftIO $ [C.exp| MlirType {
        mlirLLVMStructTypeLiteralGet($(MlirContext ctx), $(intptr_t numFields),
                                     $(MlirType* nativeFields), false)
      } |]


-- | Try to view a generic 'AST.Type' as an LLVM dialect 'Type'.
--
-- Returns 'Just' only when the value is an 'AST.DialectType' whose payload
-- is exactly this module's 'Type' (checked at runtime with 'cast'); every
-- other type, including dialect types from other dialects, gives 'Nothing'.
castLLVMType :: AST.Type -> Maybe Type
castLLVMType ty = case ty of
  AST.DialectType dty -> cast dty
  _                   -> Nothing

-- | Bidirectional pattern for LLVM array types embedded in 'AST.Type'.
--
-- Matching succeeds when 'castLLVMType' yields an 'ArrayType', binding its
-- 'Int' and element type; as an expression it wraps an 'ArrayType' in
-- 'AST.DialectType'.
pattern Array :: Int -> AST.Type -> AST.Type
pattern Array n t <- (castLLVMType -> Just (ArrayType n t))
  where Array n t = AST.DialectType (ArrayType n t)

-- | Bidirectional pattern for the LLVM void type embedded in 'AST.Type'.
--
-- Matching succeeds when 'castLLVMType' yields 'VoidType'; as an expression
-- it wraps 'VoidType' in 'AST.DialectType'.
pattern Void :: AST.Type
pattern Void <- (castLLVMType -> Just VoidType)
  where Void = AST.DialectType VoidType

-- | Bidirectional pattern for LLVM literal struct types embedded in
-- 'AST.Type'.
--
-- Matching succeeds when 'castLLVMType' yields a 'LiteralStructType',
-- binding its list of field types; as an expression it wraps a
-- 'LiteralStructType' in 'AST.DialectType'.
pattern LiteralStruct :: [AST.Type] -> AST.Type
pattern LiteralStruct fields <- (castLLVMType -> Just (LiteralStructType fields))
  where LiteralStruct fields = AST.DialectType (LiteralStructType fields)