-- Copyright 2021 Google LLC
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

module MLIR.AST where

import qualified Data.ByteString as BS

import Data.Typeable
import Data.Int
import Data.Word
import Data.Coerce
import Data.Ix
import Data.Array.IArray
import Foreign.Ptr
import Foreign.Marshal.Alloc
import Foreign.Marshal.Array
import qualified Language.C.Inline as C
import qualified Data.ByteString.Unsafe as BS
import Control.Monad
import Control.Monad.IO.Class
import Control.Monad.Trans.Cont
import qualified Data.Map.Strict as M

import qualified MLIR.Native as Native
import qualified MLIR.Native.FFI as Native
import qualified MLIR.AST.Dialect.Affine as Affine
import MLIR.AST.Serialize
import MLIR.AST.IStorableArray

-- | Identifiers used throughout the AST: SSA value names, block labels,
-- operation names and attribute keys are all raw byte strings.
type Name = BS.ByteString
-- | Unsigned machine word; used for integer-type bit widths and for
-- line/column numbers in 'FileLocation'.
type UInt = Word

-- | Signedness semantics of an MLIR integer type.  The 'FromAST' instance for
-- 'Type' lowers these to @mlirIntegerTypeSignedGet@,
-- @mlirIntegerTypeUnsignedGet@ and @mlirIntegerTypeGet@ respectively.
data Signedness
  = Signed
  | Unsigned
  | Signless
  deriving (Eq)
-- | MLIR types as represented in the Haskell AST.
--
-- In memref and ranked-tensor shapes, a 'Nothing' dimension is lowered to -1
-- by the 'FromAST' instance.  Dialect-specific types are carried by
-- 'DialectType', which only requires that the payload can lower itself.
data Type
  -- Builtin types
  -- See <https://mlir.llvm.org/docs/Dialects/Builtin/#types>
  = BFloat16Type
  | Float16Type
  | Float32Type
  | Float64Type
  | Float80Type
  | Float128Type
  | ComplexType Type
  | IndexType
  | IntegerType Signedness UInt
  | TupleType [Type]
  | NoneType
  -- Argument types first, then result types.
  | FunctionType [Type] [Type]
  | MemRefType
      { memrefTypeShape       :: [Maybe Int]
      , memrefTypeElement     :: Type
      , memrefTypeLayout      :: Maybe Attribute
      , memrefTypeMemorySpace :: Maybe Attribute
      }
  | RankedTensorType
      { rankedTensorTypeShape    :: [Maybe Int]
      , rankedTensorTypeElement  :: Type
      , rankedTensorTypeEncoding :: Maybe Attribute
      }
  | VectorType
      { vectorTypeShape   :: [Int]
      , vectorTypeElement :: Type
      }
  | UnrankedMemRefType
      { unrankedMemrefTypeElement     :: Type
      , unrankedMemrefTypeMemorySpace :: Attribute
      }
  | UnrankedTensorType
      { unrankedTensorTypeElement :: Type
      }
  | OpaqueType
      { opaqueTypeNamespace :: Name
      , opaqueTypeData      :: BS.ByteString
      }
  | forall t. (Typeable t, Eq t, FromAST t Native.Type) => DialectType t
  -- GHC cannot derive Eq due to the existential case, so Eq is written by
  -- hand in a separate instance.

-- | Structural equality on types.
--
-- Two 'DialectType' values are equal only when their payloads have the same
-- runtime type (checked with 'cast') and are equal at that type.  Values built
-- with different constructors are never equal.
instance Eq Type where
  BFloat16Type == BFloat16Type = True
  Float16Type  == Float16Type  = True
  Float32Type  == Float32Type  = True
  Float64Type  == Float64Type  = True
  Float80Type  == Float80Type  = True
  Float128Type == Float128Type = True
  ComplexType x == ComplexType y = x == y
  IndexType == IndexType = True
  IntegerType sx wx == IntegerType sy wy = sx == sy && wx == wy
  TupleType xs == TupleType ys = xs == ys
  NoneType == NoneType = True
  FunctionType ax rx == FunctionType ay ry = ax == ay && rx == ry
  MemRefType shx elx lx mx == MemRefType shy ely ly my =
    shx == shy && elx == ely && lx == ly && mx == my
  RankedTensorType shx elx enx == RankedTensorType shy ely eny =
    shx == shy && elx == ely && enx == eny
  VectorType shx elx == VectorType shy ely = shx == shy && elx == ely
  UnrankedMemRefType elx mx == UnrankedMemRefType ely my =
    elx == ely && mx == my
  UnrankedTensorType elx == UnrankedTensorType ely = elx == ely
  OpaqueType nsx dx == OpaqueType nsy dy = nsx == nsy && dx == dy
  -- Only payloads of the same concrete type can be compared.
  DialectType x == DialectType y = maybe False (== y) (cast x)
  _ == _ = False

-- | Source locations attached to operations.
--
-- 'CallSiteLocation' and 'OpaqueLocation' have no payload yet; lowering them
-- with 'fromAST' raises an error.
data Location
  = UnknownLocation
  | FileLocation
      { locPath   :: BS.ByteString
      , locLine   :: UInt
      , locColumn :: UInt
      }
  | NameLocation
      { locName  :: BS.ByteString
      , locChild :: Location
      }
  | FusedLocation
      { locLocations :: [Location]
      , locMetadata  :: Maybe Attribute
      }
  -- TODO(jpienaar): Add support C API side and implement these
  | CallSiteLocation
  | OpaqueLocation

-- | One statement of a block body: an operation plus the names bound to its
-- results.
data Binding = Bind [Name] Operation

-- | An operation whose results are not bound to any name.
pattern Do :: Operation -> Binding
pattern Do op = Bind [] op

-- | An operation with exactly one result, bound to the given name.
pattern (:=) :: Name -> Operation -> Binding
pattern name := op = Bind [name] op

-- | An operation whose results are bound to the given list of names.
pattern (::=) :: [Name] -> Operation -> Binding
pattern names ::= op = Bind names op

-- | A basic block: its label, its named and typed arguments, and the
-- bindings that make up its body.
data Block = Block
  { blockName :: Name
  , blockArgs :: [(Name, Type)]
  , blockBody :: [Binding]
  }

-- | A region is a list of blocks.
data Region = Region [Block]

-- | MLIR attributes as represented in the Haskell AST.
data Attribute
  = ArrayAttr         [Attribute]
  | DictionaryAttr    (M.Map Name Attribute)
  | FloatAttr         Type Double
  | IntegerAttr       Type Int
  | BoolAttr          Bool
  | StringAttr        BS.ByteString
  | TypeAttr          Type
  | AffineMapAttr     Affine.Map
  | UnitAttr
  | DenseArrayAttr    DenseElements
  | DenseElementsAttr Type DenseElements
  -- An attribute given by its textual assembly form.
  | AsmTextAttr BS.ByteString
  | forall t. (Typeable t, Eq t, Show t, FromAST t Native.Attribute) => DialectAttr t
  -- GHC cannot derive Eq due to the existential case, so Eq is written by
  -- hand in a separate instance.
  -- TODO(apaszke): (Flat) SymbolRef, IntegerSet, Opaque

-- | Structural equality on attributes.
--
-- 'DialectAttr' payloads are equal only when they share a runtime type
-- (checked with 'cast') and are equal at that type.  'FloatAttr' uses
-- 'Double' equality, so NaN payloads never compare equal.
instance Eq Attribute where
  ArrayAttr xs == ArrayAttr ys = xs == ys
  DictionaryAttr mx == DictionaryAttr my = mx == my
  FloatAttr t1 v1 == FloatAttr t2 v2 = t1 == t2 && v1 == v2
  IntegerAttr t1 v1 == IntegerAttr t2 v2 = t1 == t2 && v1 == v2
  BoolAttr x == BoolAttr y = x == y
  StringAttr x == StringAttr y = x == y
  TypeAttr x == TypeAttr y = x == y
  AffineMapAttr x == AffineMapAttr y = x == y
  UnitAttr == UnitAttr = True
  DenseArrayAttr x == DenseArrayAttr y = x == y
  DenseElementsAttr t1 e1 == DenseElementsAttr t2 e2 = t1 == t2 && e1 == e2
  AsmTextAttr x == AsmTextAttr y = x == y
  -- Only payloads of the same concrete type can be compared.
  DialectAttr x == DialectAttr y = maybe False (== y) (cast x)
  _ == _ = False

-- | A dense buffer of scalars of one element type.  The index type @i@ is
-- existential, so arrays of any shape or rank fit in the same constructor.
data DenseElements
  = forall i. (Show i, Ix i) =>
      DenseUInt8  (IStorableArray i Word8)
  | forall i. (Show i, Ix i) =>
      DenseInt8   (IStorableArray i Int8)
  | forall i. (Show i, Ix i) =>
      DenseUInt32 (IStorableArray i Word32)
  | forall i. (Show i, Ix i) =>
      DenseInt32  (IStorableArray i Int32)
  | forall i. (Show i, Ix i) =>
      DenseUInt64 (IStorableArray i Word64)
  | forall i. (Show i, Ix i) =>
      DenseInt64  (IStorableArray i Int64)
  | forall i. (Show i, Ix i) =>
      DenseFloat  (IStorableArray i Float)
  | forall i. (Show i, Ix i) =>
      DenseDouble (IStorableArray i Double)

-- | Relaxed equality: two buffers are equal when they have the same element
-- type and their 'elems' lists are equal.  Index types and bounds are
-- ignored, so arrays with different shapes but the same elements in the same
-- order compare equal.
-- TODO: Use a faster comparison? We could really just use memcmp here.
instance Eq DenseElements where
  DenseUInt8  xs == DenseUInt8  ys = elems xs == elems ys
  DenseInt8   xs == DenseInt8   ys = elems xs == elems ys
  DenseUInt32 xs == DenseUInt32 ys = elems xs == elems ys
  DenseInt32  xs == DenseInt32  ys = elems xs == elems ys
  DenseUInt64 xs == DenseUInt64 ys = elems xs == elems ys
  DenseInt64  xs == DenseInt64  ys = elems xs == elems ys
  DenseFloat  xs == DenseFloat  ys = elems xs == elems ys
  DenseDouble xs == DenseDouble ys = elems xs == elems ys
  _ == _ = False

-- | Result types of an operation: either an explicit list, or 'Inferred'.
data ResultTypes = Explicit [Type] | Inferred

-- | Attribute dictionary of an operation, keyed by attribute name.
type NamedAttributes = M.Map Name Attribute

-- | A generic operation, parameterised by how its operands are represented.
data AbstractOperation operand = Operation
  { opName        :: Name
  , opLocation    :: Location
  , opResultTypes :: ResultTypes
  , opOperands    :: [operand]
  , opRegions     :: [Region]
  , opSuccessors  :: [Name]
  , opAttributes  :: M.Map Name Attribute
  }

-- | Operations whose operands are referred to by name.
type Operation = AbstractOperation Name

--------------------------------------------------------------------------------
-- Builtin operations

-- | Used as an expression, the empty attribute map.  Used as a pattern, it
-- matches /every/ attribute map without inspecting it.
pattern NoAttrs :: M.Map Name Attribute
pattern NoAttrs <- _
  where
    NoAttrs = M.empty

-- | An attribute dictionary holding exactly one entry.
namedAttribute :: Name -> Attribute -> NamedAttributes
namedAttribute = M.singleton

-- | A @builtin.module@ operation whose single region holds the given block.
--
-- When used as a pattern, this matches only operations that meet all of these:
--
-- * the location is 'UnknownLocation';
-- * the result list is an explicit empty list;
-- * there are no operands and no successors;
-- * there is exactly one region containing exactly one block.
--
-- Attributes are not inspected, because 'NoAttrs' matches anything.
pattern ModuleOp :: Block -> Operation
pattern ModuleOp body =
  Operation { opName        = "builtin.module"
            , opLocation    = UnknownLocation
            , opResultTypes = Explicit []
            , opOperands    = []
            , opRegions     = [Region [body]]
            , opSuccessors  = []
            , opAttributes  = NoAttrs
            }

-- | Attribute dictionary of a @func.func@ operation: its symbol name and its
-- function type.
--
-- As an expression, this builds a map with exactly two keys, @sym_name@ and
-- @function_type@.  As a pattern, it matches any map that has both:
--
-- * a 'StringAttr' under @sym_name@;
-- * a 'TypeAttr' under @function_type@, or under the legacy @type@ key when
--   @function_type@ is absent.
--
-- Any other entries are ignored.
pattern FuncAttrs :: Name -> Type -> M.Map Name Attribute
pattern FuncAttrs name ty <-
  -- The builder below writes "function_type", so the matcher must read that
  -- key; otherwise maps built by FuncAttrs (and hence FuncOp) never match.
  -- "type" is still accepted as a fallback for maps using the older key.
  ((\d -> ( M.lookup "sym_name" d
          , maybe (M.lookup "type" d) Just (M.lookup "function_type" d) )) ->
   (Just (StringAttr name), Just (TypeAttr ty)))
  where FuncAttrs name ty = M.fromList [("sym_name", StringAttr name),
                                        ("function_type", TypeAttr ty)]

-- | A @func.func@ operation with the given location, symbol name, function
-- type and body region.
--
-- When used as a pattern, the operation must have:
--
-- * an explicit empty result list;
-- * no operands and no successors;
-- * exactly one region;
-- * an attribute map that matches 'FuncAttrs'.
pattern FuncOp :: Location -> Name -> Type -> Region -> Operation
pattern FuncOp loc name ty body =
  Operation { opName        = "func.func"
            , opLocation    = loc
            , opResultTypes = Explicit []
            , opOperands    = []
            , opRegions     = [body]
            , opSuccessors  = []
            , opAttributes  = FuncAttrs name ty
            }

--------------------------------------------------------------------------------
-- AST -> Native translation

-- inline-c setup: extend the base C context with Native.mlirCtx.  This is
-- presumably what lets the quasi-quotes below use MLIR C types such as
-- MlirContext and MlirType in antiquotes (verify in MLIR.Native).
C.context $ C.baseCtx <> Native.mlirCtx

-- C headers for the snippets in this module; the MLIR C API headers declare
-- the mlir*Get functions called below.
C.include "<stdalign.h>"
C.include "mlir-c/IR.h"
C.include "mlir-c/BuiltinTypes.h"
C.include "mlir-c/BuiltinAttributes.h"

-- | Lower an AST 'Location' to a native MLIR location in the given context.
--
-- A 'FusedLocation' without metadata passes a null attribute to
-- @mlirLocationFusedGet@.  The remaining constructors ('CallSiteLocation',
-- 'OpaqueLocation') are not supported yet and raise an 'error'.
instance FromAST Location Native.Location where
  fromAST ctx env loc = case loc of
    UnknownLocation -> Native.getUnknownLocation ctx
    FileLocation path line col ->
      Native.withStringRef path $ \pathRef ->
        Native.getFileLineColLocation ctx pathRef
                                      (fromIntegral line) (fromIntegral col)
    FusedLocation children metadataAttr -> do
      -- TODO: Consider factoring out to convenience function.
      nativeMetadata <-
        maybe [C.exp| MlirAttribute { mlirAttributeGetNull() } |]
              (fromAST ctx env)
              metadataAttr
      evalContT $ do
        (childCount, nativeChildren) <- packFromAST ctx env children
        liftIO [C.exp| MlirLocation {
          mlirLocationFusedGet($(MlirContext ctx),
            $(intptr_t childCount), $(MlirLocation* nativeChildren),
            $(MlirAttribute nativeMetadata))
        } |]
    NameLocation name child ->
      Native.withStringRef name $ \nameRef -> do
        nativeChild <- fromAST ctx env child
        Native.getNameLocation ctx nameRef nativeChild
    -- TODO(jpienaar): Fix
    _ -> error "Unimplemented Location case"

-- | Lower an AST 'Type' to a native MLIR type in the given context.
--
-- Unsupported cases:
--
-- * 'Float80Type' and 'Float128Type' raise an 'error', because the C API has
--   no constructor for them.
-- * 'OpaqueType' is 'notImplemented'.
--
-- An absent layout, memory space or encoding is passed to MLIR as a null
-- attribute.  A 'Nothing' dimension in a memref or ranked-tensor shape is
-- passed as -1.
--
-- NOTE(review): newer MLIR releases encode dynamic sizes as
-- ShapedType::kDynamic (INT64_MIN) rather than -1; confirm that -1 is right
-- for the MLIR version this module binds.
instance FromAST Type Native.Type where
  fromAST ctx env ty = case ty of
    BFloat16Type -> [C.exp| MlirType { mlirBF16TypeGet($(MlirContext ctx)) } |]
    Float16Type  -> [C.exp| MlirType { mlirF16TypeGet($(MlirContext ctx)) } |]
    Float32Type  -> [C.exp| MlirType { mlirF32TypeGet($(MlirContext ctx)) } |]
    Float64Type  -> [C.exp| MlirType { mlirF64TypeGet($(MlirContext ctx)) } |]
    Float80Type  -> error "Float80Type missing in the MLIR C API!"
    Float128Type -> error "Float128Type missing in the MLIR C API!"
    ComplexType inner -> do
      nativeInner <- lowerType inner
      [C.exp| MlirType { mlirComplexTypeGet($(MlirType nativeInner)) } |]
    FunctionType argTys resTys -> evalContT $ do
      (argCount, argPtr) <- packFromAST ctx env argTys
      (resCount, resPtr) <- packFromAST ctx env resTys
      liftIO [C.exp| MlirType {
        mlirFunctionTypeGet($(MlirContext ctx),
                            $(intptr_t argCount), $(MlirType* argPtr),
                            $(intptr_t resCount), $(MlirType* resPtr))
      } |]
    IndexType -> [C.exp| MlirType { mlirIndexTypeGet($(MlirContext ctx)) } |]
    IntegerType signedness width ->
      let bits = fromIntegral width in
      case signedness of
        Signless -> [C.exp| MlirType {
          mlirIntegerTypeGet($(MlirContext ctx), $(unsigned int bits))
        } |]
        Signed -> [C.exp| MlirType {
          mlirIntegerTypeSignedGet($(MlirContext ctx), $(unsigned int bits))
        } |]
        Unsigned -> [C.exp| MlirType {
          mlirIntegerTypeUnsignedGet($(MlirContext ctx), $(unsigned int bits))
        } |]
    MemRefType shape elTy layout memSpace -> evalContT $ do
      (rank, dims) <- packArray (dimsToI64 shape)
      liftIO $ do
        nativeElTy   <- lowerType elTy
        nativeSpace  <- lowerOptAttr memSpace
        nativeLayout <- lowerOptAttr layout
        [C.exp| MlirType {
          mlirMemRefTypeGet($(MlirType nativeElTy),
                            $(intptr_t rank), $(int64_t* dims),
                            $(MlirAttribute nativeLayout), $(MlirAttribute nativeSpace))
        } |]
    NoneType -> [C.exp| MlirType { mlirNoneTypeGet($(MlirContext ctx)) } |]
    OpaqueType _ _ -> notImplemented
    RankedTensorType shape elTy encoding -> evalContT $ do
      (rank, dims) <- packArray (dimsToI64 shape)
      liftIO $ do
        nativeElTy    <- lowerType elTy
        nativeEncoding <- lowerOptAttr encoding
        [C.exp| MlirType {
          mlirRankedTensorTypeGet($(intptr_t rank), $(int64_t* dims),
                                  $(MlirType nativeElTy), $(MlirAttribute nativeEncoding))
        } |]
    TupleType elemTys -> evalContT $ do
      (elemCount, elemPtr) <- packFromAST ctx env elemTys
      liftIO [C.exp| MlirType {
        mlirTupleTypeGet($(MlirContext ctx), $(intptr_t elemCount), $(MlirType* elemPtr))
      } |]
    UnrankedMemRefType elTy space -> do
      nativeElTy  <- lowerType elTy
      nativeSpace <- lowerAttr space
      [C.exp| MlirType {
        mlirUnrankedMemRefTypeGet($(MlirType nativeElTy), $(MlirAttribute nativeSpace))
      } |]
    UnrankedTensorType elTy -> do
      nativeElTy <- lowerType elTy
      [C.exp| MlirType {
        mlirUnrankedTensorTypeGet($(MlirType nativeElTy))
      } |]
    VectorType shape elTy -> evalContT $ do
      (rank, dims) <- packArray (map fromIntegral shape :: [Int64])
      liftIO $ do
        nativeElTy <- lowerType elTy
        [C.exp| MlirType {
          mlirVectorTypeGet($(intptr_t rank), $(int64_t* dims), $(MlirType nativeElTy))
        } |]
    DialectType payload -> fromAST ctx env payload
    where
      -- Recursively lower a nested AST type.
      lowerType :: Type -> IO Native.Type
      lowerType = fromAST ctx env

      -- Lower a (mandatory) AST attribute.
      lowerAttr :: Attribute -> IO Native.Attribute
      lowerAttr = fromAST ctx env

      -- Lower an optional attribute; absence becomes a null MlirAttribute.
      lowerOptAttr :: Maybe Attribute -> IO Native.Attribute
      lowerOptAttr = maybe (return (coerce nullPtr)) lowerAttr

      -- Shape dimensions as int64_t; unknown ('Nothing') sizes become -1.
      dimsToI64 :: [Maybe Int] -> [Int64]
      dimsToI64 = map (maybe (-1) fromIntegral)


-- | Converts an AST 'Region' into a freshly created native region.
--
-- Conversion happens in two passes. The first pass creates a native block
-- (arguments only, no operations) for every AST block, appends it to the
-- region and records it by name. The second pass converts the body of each
-- block. Because every block of the region exists before any operation is
-- converted, operations can name any sibling block as a successor.
instance FromAST Region Native.Region where
  fromAST ctx env@(valueEnv, _) (Region blocks) = do
    nativeRegion <- [C.exp| MlirRegion { mlirRegionCreate() } |]
    -- The caller's block mapping is dropped: only blocks of this region are
    -- visible while converting its bodies.
    siblingBlocks <- foldM (createAndAppend nativeRegion) M.empty blocks
    forM_ blocks (fromAST ctx (valueEnv, siblingBlocks))
    pure nativeRegion
    where
      -- | Create the native block for one AST block, append it to the region
      -- and extend the name-to-block mapping with it.
      createAndAppend :: Native.Region -> BlockMapping -> Block -> IO BlockMapping
      createAndAppend targetRegion known block = do
        freshBlock <- createEmptyBlock block
        [C.exp| void {
          mlirRegionAppendOwnedBlock($(MlirRegion targetRegion), $(MlirBlock freshBlock))
        } |]
        -- Map's (<>) is left-biased: if a name repeats, the earlier block is kept.
        pure (known <> M.singleton (blockName block) freshBlock)

      -- | Create a native block whose arguments have the types listed in
      -- 'blockArgs'. The block has no operations yet and is not attached to
      -- any region.
      createEmptyBlock :: Block -> IO Native.Block
      createEmptyBlock Block{..} = evalContT $ do
        -- TODO: Use proper locations
        let argTypes     = map snd blockArgs
            argLocations = UnknownLocation <$ blockArgs
        (numArgs, nativeArgTypes) <- packFromAST ctx env argTypes
        (_, nativeArgLocations)   <- packFromAST ctx env argLocations
        liftIO [C.exp| MlirBlock {
          mlirBlockCreate($(intptr_t numArgs), $(MlirType* nativeArgTypes), $(MlirLocation* nativeArgLocations))
        } |]


-- | Fills in the body of a native block that the enclosing 'Region'
-- conversion has already created and registered under 'blockName'.
--
-- The block's native arguments are first bound to the names in 'blockArgs'.
-- Then every 'Binding' of 'blockBody' is converted in order and appended to
-- the block, and its results are bound to the binding's names. An operation
-- can therefore use values bound earlier in the block, the block arguments,
-- and values from enclosing scopes. When names collide, the innermost and
-- newest binding wins.
instance FromAST Block Native.Block where
  fromAST ctx (outerValueEnv, blockEnv) Block{..} = do
    block <- case M.lookup blockName blockEnv of
      Just nativeBlock -> return nativeBlock
      Nothing -> error $ "MLIR.AST: block " ++ show blockName
                      ++ " was not created by its enclosing region"
    nativeBlockArgs <- getBlockArgs block
    let argValueEnv = M.fromList $ zip (fst <$> blockArgs) nativeBlockArgs
    -- Map's (<>) is left-biased, so the block-local bindings go on the left:
    -- a block argument shadows a same-named value from an outer scope.
    foldM_ (appendInstr block) (argValueEnv <> outerValueEnv) blockBody
    return block
    where
      -- | Convert one binding, append the resulting operation to the block and
      -- bind its results to the given names. New names shadow older bindings.
      appendInstr :: Native.Block -> ValueMapping -> Binding -> IO ValueMapping
      appendInstr block valueEnv (Bind names operation) = do
        nativeOperation <- fromAST ctx (valueEnv, blockEnv) operation
        [C.exp| void {
          mlirBlockAppendOwnedOperation($(MlirBlock block),
                                        $(MlirOperation nativeOperation))
        } |]
        nativeResults <- getOperationResults nativeOperation
        return $ M.fromList (zip names nativeResults) <> valueEnv

      -- | Read all arguments of a native block, in order.
      getBlockArgs :: Native.Block -> IO [Native.Value]
      getBlockArgs block = do
        numArgs <- [C.exp| intptr_t { mlirBlockGetNumArguments($(MlirBlock block)) } |]
        allocaArray (fromIntegral numArgs) \nativeArgs -> do
          [C.block| void {
            for (intptr_t i = 0; i < $(intptr_t numArgs); ++i) {
              $(MlirValue* nativeArgs)[i] = mlirBlockGetArgument($(MlirBlock block), i);
            }
          } |]
          unpackArray numArgs nativeArgs

      -- | Read all results of a native operation, in order.
      getOperationResults :: Native.Operation -> IO [Native.Value]
      getOperationResults op = do
        numResults <- [C.exp| intptr_t { mlirOperationGetNumResults($(MlirOperation op)) } |]
        allocaArray (fromIntegral numResults) \nativeResults -> do
          [C.block| void {
            for (intptr_t i = 0; i < $(intptr_t numResults); ++i) {
              $(MlirValue* nativeResults)[i] = mlirOperationGetResult($(MlirOperation op), i);
            }
          } |]
          unpackArray numResults nativeResults


-- | Converts an AST 'Attribute' into a native MLIR attribute.
--
-- Container elements, dialect attributes and the types attached to typed
-- attributes are converted through their own 'FromAST' instances. For
-- 'DenseArrayAttr' and 'DenseElementsAttr', the pointer to the
-- 'IStorableArray' contents and the element count derived from its bounds
-- are passed to the matching C constructor.
instance FromAST Attribute Native.Attribute where
  fromAST ctx env attr = case attr of
    ArrayAttr attrs -> evalContT $ do
      (numAttrs, nativeAttrs) <- packFromAST ctx env attrs
      liftIO [C.exp| MlirAttribute {
        mlirArrayAttrGet($(MlirContext ctx), $(intptr_t numAttrs), $(MlirAttribute* nativeAttrs))
      } |]
    DictionaryAttr dict -> evalContT $ do
      (numAttrs, nativeAttrs) <- packNamedAttrs ctx env dict
      liftIO [C.exp| MlirAttribute {
        mlirDictionaryAttrGet($(MlirContext ctx), $(intptr_t numAttrs), $(MlirNamedAttribute* nativeAttrs))
      } |]
    DialectAttr at -> fromAST ctx env at
    FloatAttr ty value -> do
      nativeType <- fromAST ctx env ty
      -- Zero-cost newtype conversion to the C double expected below.
      let nativeValue = coerce value
      [C.exp| MlirAttribute {
        mlirFloatAttrDoubleGet($(MlirContext ctx), $(MlirType nativeType), $(double nativeValue))
      } |]
    IntegerAttr ty value -> do
      nativeType <- fromAST ctx env ty
      let nativeValue = fromIntegral value
      [C.exp| MlirAttribute {
        mlirIntegerAttrGet($(MlirType nativeType), $(int64_t nativeValue))
      } |]
    BoolAttr value -> do
      let nativeValue = if value then 1 else 0
      [C.exp| MlirAttribute {
        mlirBoolAttrGet($(MlirContext ctx), $(int nativeValue))
      } |]
    StringAttr value ->
      Native.withStringRef value \(Native.StringRef ptr len) ->
        [C.exp| MlirAttribute {
          mlirStringAttrGet($(MlirContext ctx), (MlirStringRef){$(char* ptr), $(size_t len)})
        } |]
    AsmTextAttr value ->
      Native.withStringRef value \(Native.StringRef ptr len) ->
        [C.exp| MlirAttribute {
          mlirAttributeParseGet($(MlirContext ctx), (MlirStringRef){$(char* ptr), $(size_t len)})
        } |]
    TypeAttr ty -> do
      nativeType <- fromAST ctx env ty
      [C.exp| MlirAttribute { mlirTypeAttrGet($(MlirType nativeType)) } |]
    AffineMapAttr afMap -> do
      nativeMap <- fromAST ctx env afMap
      [C.exp| MlirAttribute { mlirAffineMapAttrGet($(MlirAffineMap nativeMap)) } |]
    UnitAttr -> [C.exp| MlirAttribute { mlirUnitAttrGet($(MlirContext ctx)) } |]
    DenseArrayAttr storage -> case storage of
      DenseInt8 arr -> do
        let size = elemCount (bounds arr)
        unsafeWithIStorableArray arr \valuesPtr ->
          [C.exp| MlirAttribute {
            mlirDenseI8ArrayGet($(MlirContext ctx), $(intptr_t size),
                                $(const int8_t* valuesPtr))
          } |]
      DenseInt32 arr -> do
        let size = elemCount (bounds arr)
        unsafeWithIStorableArray arr \valuesPtr ->
          [C.exp| MlirAttribute {
            mlirDenseI32ArrayGet($(MlirContext ctx), $(intptr_t size),
                                 $(const int32_t* valuesPtr))
          } |]
      DenseInt64 arr -> do
        let size = elemCount (bounds arr)
        unsafeWithIStorableArray arr \valuesPtr ->
          [C.exp| MlirAttribute {
            mlirDenseI64ArrayGet($(MlirContext ctx), $(intptr_t size),
                                 $(const int64_t* valuesPtr))
          } |]
      DenseFloat arr -> do
        let size = elemCount (bounds arr)
        unsafeWithIStorableArray arr \valuesPtrHs -> do
          -- Reinterpret the Haskell Float buffer as the C float buffer.
          let valuesPtr = castPtr valuesPtrHs
          [C.exp| MlirAttribute {
            mlirDenseF32ArrayGet($(MlirContext ctx), $(intptr_t size),
                                 $(const float* valuesPtr))
          } |]
      DenseDouble arr -> do
        let size = elemCount (bounds arr)
        unsafeWithIStorableArray arr \valuesPtrHs -> do
          -- Reinterpret the Haskell Double buffer as the C double buffer.
          let valuesPtr = castPtr valuesPtrHs
          [C.exp| MlirAttribute {
            mlirDenseF64ArrayGet($(MlirContext ctx), $(intptr_t size),
                                 $(const double* valuesPtr))
          } |]
      _ -> error "DenseArrayAttr: unsupported element storage (expected DenseInt8, DenseInt32, DenseInt64, DenseFloat or DenseDouble)"
    DenseElementsAttr ty storage -> do
      nativeType <- fromAST ctx env ty
      case storage of
        DenseUInt8 arr -> do
          let size = elemCount (bounds arr)
          unsafeWithIStorableArray arr \valuesPtr ->
            [C.exp| MlirAttribute {
              mlirDenseElementsAttrUInt8Get($(MlirType nativeType), $(intptr_t size),
                                            $(const uint8_t* valuesPtr))
            } |]
        DenseInt8 arr -> do
          let size = elemCount (bounds arr)
          unsafeWithIStorableArray arr \valuesPtr ->
            [C.exp| MlirAttribute {
              mlirDenseElementsAttrInt8Get($(MlirType nativeType), $(intptr_t size),
                                           $(const int8_t* valuesPtr))
            } |]
        DenseUInt32 arr -> do
          let size = elemCount (bounds arr)
          unsafeWithIStorableArray arr \valuesPtr ->
            [C.exp| MlirAttribute {
              mlirDenseElementsAttrUInt32Get($(MlirType nativeType), $(intptr_t size),
                                             $(const uint32_t* valuesPtr))
            } |]
        DenseInt32 arr -> do
          let size = elemCount (bounds arr)
          unsafeWithIStorableArray arr \valuesPtr ->
            [C.exp| MlirAttribute {
              mlirDenseElementsAttrInt32Get($(MlirType nativeType), $(intptr_t size),
                                            $(const int32_t* valuesPtr))
            } |]
        DenseUInt64 arr -> do
          let size = elemCount (bounds arr)
          unsafeWithIStorableArray arr \valuesPtr ->
            [C.exp| MlirAttribute {
              mlirDenseElementsAttrUInt64Get($(MlirType nativeType), $(intptr_t size),
                                             $(const uint64_t* valuesPtr))
            } |]
        DenseInt64 arr -> do
          let size = elemCount (bounds arr)
          unsafeWithIStorableArray arr \valuesPtr ->
            [C.exp| MlirAttribute {
              mlirDenseElementsAttrInt64Get($(MlirType nativeType), $(intptr_t size),
                                            $(const int64_t* valuesPtr))
            } |]
        DenseFloat arr -> do
          let size = elemCount (bounds arr)
          unsafeWithIStorableArray arr \valuesPtrHs -> do
            let valuesPtr = castPtr valuesPtrHs
            [C.exp| MlirAttribute {
              mlirDenseElementsAttrFloatGet($(MlirType nativeType), $(intptr_t size),
                                            $(const float* valuesPtr))
            } |]
        DenseDouble arr -> do
          let size = elemCount (bounds arr)
          unsafeWithIStorableArray arr \valuesPtrHs -> do
            let valuesPtr = castPtr valuesPtrHs
            [C.exp| MlirAttribute {
              mlirDenseElementsAttrDoubleGet($(MlirType nativeType), $(intptr_t size),
                                             $(const double* valuesPtr))
            } |]
    where
      -- | Number of elements described by an array's bounds, converted to the
      -- integer type the C API expects (fixed by the antiquote at each use).
      elemCount :: (Ix i, Num n) => (i, i) -> n
      elemCount = fromIntegral . rangeSize


instance FromAST Operation Native.Operation where
  fromAST :: Context -> ValueAndBlockMapping -> Operation -> IO Operation
fromAST Context
ctx env :: ValueAndBlockMapping
env@(ValueMapping
valueEnv, BlockMapping
blockEnv) Operation{[Name]
[Region]
Name
Map Name Attribute
ResultTypes
Location
opAttributes :: Map Name Attribute
opSuccessors :: [Name]
opRegions :: [Region]
opOperands :: [Name]
opResultTypes :: ResultTypes
opLocation :: Location
opName :: Name
opAttributes :: forall operand. AbstractOperation operand -> Map Name Attribute
opSuccessors :: forall operand. AbstractOperation operand -> [Name]
opRegions :: forall operand. AbstractOperation operand -> [Region]
opOperands :: forall operand. AbstractOperation operand -> [operand]
opResultTypes :: forall operand. AbstractOperation operand -> ResultTypes
opLocation :: forall operand. AbstractOperation operand -> Location
opName :: forall operand. AbstractOperation operand -> Name
..} = ContT Operation IO Operation -> IO Operation
forall (m :: * -> *) r. Monad m => ContT r m r -> m r
evalContT (ContT Operation IO Operation -> IO Operation)
-> ContT Operation IO Operation -> IO Operation
forall a b. (a -> b) -> a -> b
$ do
    (Ptr CChar
namePtr, Int
nameLen) <- (((Ptr CChar, Int) -> IO Operation) -> IO Operation)
-> ContT Operation IO (Ptr CChar, Int)
forall k (r :: k) (m :: k -> *) a.
((a -> m r) -> m r) -> ContT r m a
ContT ((((Ptr CChar, Int) -> IO Operation) -> IO Operation)
 -> ContT Operation IO (Ptr CChar, Int))
-> (((Ptr CChar, Int) -> IO Operation) -> IO Operation)
-> ContT Operation IO (Ptr CChar, Int)
forall a b. (a -> b) -> a -> b
$ Name -> ((Ptr CChar, Int) -> IO Operation) -> IO Operation
forall a. Name -> ((Ptr CChar, Int) -> IO a) -> IO a
BS.unsafeUseAsCStringLen Name
opName
    let nameLenSizeT :: CSize
nameLenSizeT = Int -> CSize
forall a b. (Integral a, Num b) => a -> b
fromIntegral Int
nameLen
    (CBool
infersResults, (CIntPtr
numResultTypes, Ptr Type
nativeResultTypes)) <- case ResultTypes
opResultTypes of
      ResultTypes
Inferred -> (CBool, (CIntPtr, Ptr Type))
-> ContT Operation IO (CBool, (CIntPtr, Ptr Type))
forall (m :: * -> *) a. Monad m => a -> m a
return (CBool
CTrue, (CIntPtr
0, Ptr Type
forall a. Ptr a
nullPtr))
      Explicit [Type]
types -> (CBool
CFalse,) ((CIntPtr, Ptr Type) -> (CBool, (CIntPtr, Ptr Type)))
-> ContT Operation IO (CIntPtr, Ptr Type)
-> ContT Operation IO (CBool, (CIntPtr, Ptr Type))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Context
-> ValueAndBlockMapping
-> [Type]
-> ContT Operation IO (CIntPtr, Ptr Type)
forall ast native r.
(FromAST ast native, Storable native) =>
Context
-> ValueAndBlockMapping
-> [ast]
-> ContT r IO (CIntPtr, Ptr native)
packFromAST Context
ctx ValueAndBlockMapping
env [Type]
types
    Location
nativeLocation <- IO Location -> ContT Operation IO Location
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO Location -> ContT Operation IO Location)
-> IO Location -> ContT Operation IO Location
forall a b. (a -> b) -> a -> b
$ Context -> ValueAndBlockMapping -> Location -> IO Location
forall ast native.
FromAST ast native =>
Context -> ValueAndBlockMapping -> ast -> IO native
fromAST Context
ctx ValueAndBlockMapping
env Location
opLocation
    (CIntPtr
numOperands, Ptr Value
nativeOperands) <- [Value] -> ContT Operation IO (CIntPtr, Ptr Value)
forall a r. Storable a => [a] -> ContT r IO (CIntPtr, Ptr a)
packArray ([Value] -> ContT Operation IO (CIntPtr, Ptr Value))
-> [Value] -> ContT Operation IO (CIntPtr, Ptr Value)
forall a b. (a -> b) -> a -> b
$ (Name -> Value) -> [Name] -> [Value]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (ValueMapping
valueEnv ValueMapping -> Name -> Value
forall k a. Ord k => Map k a -> k -> a
M.!) [Name]
opOperands
    (CIntPtr
numRegions, Ptr Region
nativeRegions) <- Context
-> ValueAndBlockMapping
-> [Region]
-> ContT Operation IO (CIntPtr, Ptr Region)
forall ast native r.
(FromAST ast native, Storable native) =>
Context
-> ValueAndBlockMapping
-> [ast]
-> ContT r IO (CIntPtr, Ptr native)
packFromAST Context
ctx ValueAndBlockMapping
env [Region]
opRegions
    (CIntPtr
numSuccessors, Ptr Block
nativeSuccessors) <- [Block] -> ContT Operation IO (CIntPtr, Ptr Block)
forall a r. Storable a => [a] -> ContT r IO (CIntPtr, Ptr a)
packArray ([Block] -> ContT Operation IO (CIntPtr, Ptr Block))
-> [Block] -> ContT Operation IO (CIntPtr, Ptr Block)
forall a b. (a -> b) -> a -> b
$ (Name -> Block) -> [Name] -> [Block]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (BlockMapping
blockEnv BlockMapping -> Name -> Block
forall k a. Ord k => Map k a -> k -> a
M.!) [Name]
opSuccessors
    (CIntPtr
numAttributes, Ptr NamedAttribute
nativeAttributes) <- Context
-> ValueAndBlockMapping
-> Map Name Attribute
-> ContT Operation IO (CIntPtr, Ptr NamedAttribute)
forall r.
Context
-> ValueAndBlockMapping
-> Map Name Attribute
-> ContT r IO (CIntPtr, Ptr NamedAttribute)
packNamedAttrs Context
ctx ValueAndBlockMapping
env Map Name Attribute
opAttributes
    -- NB: This is nullable when result type inference is enabled
    Maybe Operation
maybeOperation <- IO (Maybe Operation) -> ContT Operation IO (Maybe Operation)
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO (Maybe Operation) -> ContT Operation IO (Maybe Operation))
-> IO (Maybe Operation) -> ContT Operation IO (Maybe Operation)
forall a b. (a -> b) -> a -> b
$ Operation -> Maybe Operation
forall a. Coercible a (Ptr ()) => a -> Maybe a
Native.nullable (Operation -> Maybe Operation)
-> IO Operation -> IO (Maybe Operation)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [C.block| MlirOperation {
      MlirOperationState state = mlirOperationStateGet(
        (MlirStringRef){$(char* namePtr), $(size_t nameLenSizeT)},
        $(MlirLocation nativeLocation));
      if ($(bool infersResults)) {
        mlirOperationStateEnableResultTypeInference(&state);
      } else {
        mlirOperationStateAddResults(
            &state, $(intptr_t numResultTypes), $(MlirType* nativeResultTypes));
      }
      mlirOperationStateAddOperands(
          &state, $(intptr_t numOperands), $(MlirValue* nativeOperands));
      mlirOperationStateAddOwnedRegions(
          &state, $(intptr_t numRegions), $(MlirRegion* nativeRegions));
      mlirOperationStateAddSuccessors(
          &state, $(intptr_t numSuccessors), $(MlirBlock* nativeSuccessors));
      mlirOperationStateAddAttributes(
          &state, $(intptr_t numAttributes), $(MlirNamedAttribute* nativeAttributes));
      return mlirOperationCreate(&state);
    } |]
    case Maybe Operation
maybeOperation of
      Just Operation
operation -> Operation -> ContT Operation IO Operation
forall (m :: * -> *) a. Monad m => a -> m a
return Operation
operation
      Maybe Operation
Nothing -> [Char] -> ContT Operation IO Operation
forall a. HasCallStack => [Char] -> a
error ([Char] -> ContT Operation IO Operation)
-> [Char] -> ContT Operation IO Operation
forall a b. (a -> b) -> a -> b
$ [Char]
"Type inference failed for operation " [Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++ Name -> [Char]
forall a. Show a => a -> [Char]
show Name
opName

--------------------------------------------------------------------------------
-- Utilities for AST -> Native translation

-- | Marshal an attribute dictionary into a C array of @MlirNamedAttribute@.
--
-- The array is allocated with 'allocaBytesAligned', using the C-side
-- @sizeof@ / @alignof@ of @MlirNamedAttribute@. Because the allocation is
-- wrapped in 'ContT', the returned pointer is only valid for the remainder of
-- the enclosing continuation.
--
-- Entries are written in 'M.toList' order (ascending by attribute name). Each
-- entry is built as follows:
--
-- * the name is converted with 'Native.withStringRef' and then interned with
--   'Native.createIdentifier';
-- * the value is translated with 'fromAST' using @ctx@ and @env@;
-- * the two are combined with @mlirNamedAttributeGet@.
--
-- Returns the number of entries together with the array pointer.
packNamedAttrs :: Native.Context -> ValueAndBlockMapping
               -> M.Map Name Attribute -> ContT r IO (C.CIntPtr, Ptr Native.NamedAttribute)
packNamedAttrs ctx env attrDict = do
  let count = M.size attrDict
  -- The element layout comes from C, since the struct is not marshalled
  -- through a Haskell Storable instance here.
  (size, align) <- liftIO $ do
    rawSize  <- [C.exp| size_t { sizeof(MlirNamedAttribute) } |]
    rawAlign <- [C.exp| size_t { alignof(MlirNamedAttribute) } |]
    return (fromIntegral rawSize, fromIntegral rawAlign)
  buffer <- ContT $ allocaBytesAligned (count * size) align
  forM_ (zip [0 :: C.CInt ..] (M.toList attrDict)) $ \(idx, (attrName, attrValue)) -> do
    -- The string ref only lives for the rest of this continuation.
    attrNameRef <- ContT $ Native.withStringRef attrName
    liftIO $ do
      nativeValue <- fromAST ctx env attrValue
      nativeIdent <- Native.createIdentifier ctx attrNameRef
      [C.exp| void {
        $(MlirNamedAttribute* buffer)[$(int idx)] =
          mlirNamedAttributeGet($(MlirIdentifier nativeIdent), $(MlirAttribute nativeValue));
      } |]
  return (fromIntegral count, buffer)

-- | The C boolean @true@, encoded as @CBool 1@.
--
-- As a pattern it matches exactly @CBool 1@; as an expression it builds that
-- value.
pattern CTrue :: C.CBool
pattern CTrue <- C.CBool 1
  where CTrue = C.CBool 1

-- | The C boolean @false@, encoded as @CBool 0@.
--
-- As a pattern it matches exactly @CBool 0@; as an expression it builds that
-- value.
pattern CFalse :: C.CBool
pattern CFalse <- C.CBool 0
  where CFalse = C.CBool 0

-- | Placeholder for translations that have not been written yet.
--
-- It inhabits every type, but forcing it raises an 'ErrorCall' carrying the
-- message @"Not implemented"@.
notImplemented :: a
notImplemented = error message
  where
    message = "Not implemented"