module MLIR.Native (
Context,
createContext,
destroyContext,
withContext,
HasContext(..),
registerAllDialects,
getNumLoadedDialects,
Type,
Location,
getFileLineColLocation,
getNameLocation,
getUnknownLocation,
Operation,
getOperationName,
showOperation,
showOperationWithLocation,
verifyOperation,
Region,
getOperationRegions,
getRegionBlocks,
Block,
showBlock,
getBlockOperations,
Module,
createEmptyModule,
parseModule,
destroyModule,
getModuleBody,
moduleAsOperation,
moduleFromOperation,
showModule,
StringRef(..),
withStringRef,
Identifier,
createIdentifier,
identifierString,
LogicalResult,
pattern Failure,
pattern Success,
setDebugMode,
HasDump(..),
) where
import Control.Exception (bracket, finally)
import Control.Monad
import Control.Monad.IO.Class
import Foreign.Marshal.Alloc
import Foreign.Marshal.Array
import Foreign.Ptr
import Foreign.Storable

import Control.Monad.Trans.Cont
import qualified Data.ByteString as BS
import qualified Language.C.Inline as C

import MLIR.Native.FFI
-- inline-c configuration: the base C type mappings extended with 'mlirCtx'
-- (from MLIR.Native.FFI), which presumably maps the MLIR handle types used in
-- the antiquotes below (MlirContext, MlirModule, MlirOperation, ...) to their
-- Haskell newtypes. It is set before any of the quasi-quotes in this module.
C.context $ C.baseCtx <> mlirCtx
-- MLIR C API headers needed by the C snippets in this module.
C.include "mlir-c/Support.h"
C.include "mlir-c/Debug.h"
C.include "mlir-c/IR.h"
C.include "mlir-c/Pass.h"
C.include "mlir-c/Conversion.h"
C.include "mlir-c/RegisterEverything.h"
-- Emits C code from MLIR.Native.FFI; presumably the declaration of
-- HaskellMlirStringCallback, which the show* printers below pass to MLIR.
C.verbatim stringCallbackDecl
-- | Create a brand-new MLIR context. Release it with 'destroyContext', or
-- use 'withContext' to have that done automatically.
createContext :: IO Context
createContext =
  [C.exp| MlirContext { mlirContextCreate() } |]
-- | Destroy a context previously obtained from 'createContext'. The context
-- must not be used after this call.
destroyContext :: Context -> IO ()
destroyContext context =
  [C.exp| void { mlirContextDestroy($(MlirContext context)) } |]
-- | Run an action with a freshly created context. The context is destroyed
-- when the action finishes, whether it returns normally or throws.
withContext :: (Context -> IO a) -> IO a
withContext body = bracket createContext destroyContext body
-- | Values from which the owning MLIR 'Context' can be retrieved.
class HasContext a where
  -- | Look up the context associated with the given value.
  getContext :: a -> IO Context
-- | Make every dialect from @mlirRegisterAllDialects@ available in the
-- context, then load all available dialects eagerly. A temporary dialect
-- registry is created for this and destroyed before returning.
registerAllDialects :: Context -> IO ()
registerAllDialects context = [C.block| void {
    MlirDialectRegistry registry = mlirDialectRegistryCreate();
    mlirRegisterAllDialects(registry);
    mlirContextAppendDialectRegistry($(MlirContext context), registry);
    mlirDialectRegistryDestroy(registry);
    mlirContextLoadAllAvailableDialects($(MlirContext context));
  } |]
-- | Number of dialects currently loaded into the context.
getNumLoadedDialects :: Context -> IO Int
getNumLoadedDialects context = do
  loaded <- [C.exp| intptr_t { mlirContextGetNumLoadedDialects($(MlirContext context)) } |]
  pure (fromIntegral loaded)
-- | The context's "unknown" location, for IR with no source position.
getUnknownLocation :: Context -> IO Location
getUnknownLocation context =
  [C.exp| MlirLocation { mlirLocationUnknownGet($(MlirContext context)) } |]
-- | Build a file/line/column location in the given context. The file name is
-- handed to MLIR as a pointer-and-length 'StringRef'.
getFileLineColLocation :: Context -> StringRef -> C.CUInt -> C.CUInt -> IO Location
getFileLineColLocation context (StringRef fileBytes fileLen) lineNo colNo =
  [C.exp| MlirLocation {
    mlirLocationFileLineColGet(
      $(MlirContext context),
      (MlirStringRef){$(char* fileBytes), $(size_t fileLen)},
      $(unsigned int lineNo),
      $(unsigned int colNo)) } |]
-- | Build a named location that wraps another ("child") location. The name
-- is handed to MLIR as a pointer-and-length 'StringRef'.
getNameLocation :: Context -> StringRef -> Location -> IO Location
getNameLocation context (StringRef nameBytes nameLen) inner =
  [C.exp| MlirLocation {
    mlirLocationNameGet(
      $(MlirContext context),
      (MlirStringRef){$(char* nameBytes), $(size_t nameLen)},
      $(MlirLocation inner)) } |]
-- | The name of an operation as an 'Identifier' (read its text with
-- 'identifierString').
getOperationName :: Operation -> IO Identifier
getOperationName operation =
  [C.exp| MlirIdentifier { mlirOperationGetName($(MlirOperation operation)) } |]
-- | Render an operation as text. A fresh set of printing flags is created,
-- used unmodified, and destroyed afterwards.
showOperation :: Operation -> IO BS.ByteString
showOperation op = showSomething printInto
  where
    printInto sink = [C.block| void {
        MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
        mlirOperationPrintWithFlags($(MlirOperation op), flags,
                                    HaskellMlirStringCallback, $(void* sink));
        mlirOpPrintingFlagsDestroy(flags);
      } |]
-- | Render an operation as text with debug info (locations) enabled, in the
-- non-pretty form (@prettyForm=false@).
showOperationWithLocation :: Operation -> IO BS.ByteString
showOperationWithLocation op = showSomething printWithLocs
  where
    printWithLocs sink = [C.block| void {
        MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
        mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true, /*prettyForm=*/false);
        mlirOperationPrintWithFlags($(MlirOperation op), flags,
                                    HaskellMlirStringCallback, $(void* sink));
        mlirOpPrintingFlagsDestroy(flags);
      } |]
-- | Run MLIR's verifier on an operation; 'True' means verification passed.
verifyOperation :: Operation -> IO Bool
verifyOperation op = do
  verdict <- [C.exp| bool { mlirOperationVerify($(MlirOperation op)) } |]
  pure (verdict == 1)
-- | The first region of an operation. 'Nothing' when MLIR returns a null
-- region handle (mapped by 'nullable').
getOperationFirstRegion :: Operation -> IO (Maybe Region)
getOperationFirstRegion op = do
  first <- [C.exp| MlirRegion {
      mlirOperationGetFirstRegion($(MlirOperation op))
    } |]
  pure (nullable first)
-- | The region following this one in its parent operation. 'Nothing' when
-- MLIR returns a null region handle (mapped by 'nullable').
getOperationNextRegion :: Region -> IO (Maybe Region)
getOperationNextRegion current = do
  following <- [C.exp| MlirRegion {
      mlirRegionGetNextInOperation($(MlirRegion current))
    } |]
  pure (nullable following)
-- | All regions of an operation, first to last.
getOperationRegions :: Operation -> IO [Region]
getOperationRegions =
  unrollIOMaybe getOperationNextRegion . getOperationFirstRegion
-- | The first block of a region. 'Nothing' when MLIR returns a null block
-- handle (mapped by 'nullable').
getRegionFirstBlock :: Region -> IO (Maybe Block)
getRegionFirstBlock region = do
  first <- [C.exp| MlirBlock {
      mlirRegionGetFirstBlock($(MlirRegion region))
    } |]
  pure (nullable first)
-- | The block following this one in its parent region. 'Nothing' when MLIR
-- returns a null block handle (mapped by 'nullable').
getRegionNextBlock :: Block -> IO (Maybe Block)
getRegionNextBlock current = do
  following <- [C.exp| MlirBlock {
      mlirBlockGetNextInRegion($(MlirBlock current))
    } |]
  pure (nullable following)
-- | All blocks of a region, first to last.
getRegionBlocks :: Region -> IO [Block]
getRegionBlocks =
  unrollIOMaybe getRegionNextBlock . getRegionFirstBlock
-- | Render a block (with its operations) as text via @mlirBlockPrint@.
showBlock :: Block -> IO BS.ByteString
showBlock block = showSomething printBlock
  where
    printBlock sink = [C.exp| void {
        mlirBlockPrint($(MlirBlock block), HaskellMlirStringCallback, $(void* sink))
      } |]
-- | The first operation in a block. 'Nothing' when MLIR returns a null
-- operation handle (mapped by 'nullable').
getFirstOperationBlock :: Block -> IO (Maybe Operation)
getFirstOperationBlock block = do
  first <- [C.exp| MlirOperation { mlirBlockGetFirstOperation($(MlirBlock block)) } |]
  pure (nullable first)
-- | The operation following this one in its parent block. 'Nothing' when
-- MLIR returns a null operation handle (mapped by 'nullable').
getNextOperationBlock :: Operation -> IO (Maybe Operation)
getNextOperationBlock current = do
  following <- [C.exp| MlirOperation {
      mlirOperationGetNextInBlock($(MlirOperation current)) } |]
  pure (nullable following)
-- | All operations of a block, first to last.
getBlockOperations :: Block -> IO [Operation]
getBlockOperations block = unrollIOMaybe advance start
  where
    start = getFirstOperationBlock block
    advance = getNextOperationBlock
-- | A module's context, as reported by @mlirModuleGetContext@.
instance HasContext Module where
  getContext modl =
    [C.exp| MlirContext { mlirModuleGetContext($(MlirModule modl)) } |]
-- | Create an empty module at the given location.
createEmptyModule :: Location -> IO Module
createEmptyModule location =
  [C.exp| MlirModule { mlirModuleCreateEmpty($(MlirLocation location)) } |]
-- | Parse a module from its textual form. Yields 'Nothing' when MLIR returns
-- a null module handle (mapped by 'nullable'), i.e. the text did not parse.
parseModule :: Context -> StringRef -> IO (Maybe Module)
parseModule context (StringRef srcBytes srcLen) = do
  parsed <- [C.exp| MlirModule {
      mlirModuleCreateParse($(MlirContext context),
                            (MlirStringRef){$(char* srcBytes), $(size_t srcLen)})
    } |]
  pure (nullable parsed)
-- | Destroy a module. It must not be used after this call.
destroyModule :: Module -> IO ()
destroyModule modl =
  [C.exp| void { mlirModuleDestroy($(MlirModule modl)) } |]
-- | The body block of a module.
getModuleBody :: Module -> IO Block
getModuleBody modl =
  [C.exp| MlirBlock { mlirModuleGetBody($(MlirModule modl)) } |]
-- | View a module as its underlying top-level operation.
moduleAsOperation :: Module -> IO Operation
moduleAsOperation modl =
  [C.exp| MlirOperation { mlirModuleGetOperation($(MlirModule modl)) } |]
-- | Reinterpret an operation as a module. 'Nothing' when MLIR returns a null
-- module handle (mapped by 'nullable').
moduleFromOperation :: Operation -> IO (Maybe Module)
moduleFromOperation op = do
  asModule <- [C.exp| MlirModule { mlirModuleFromOperation($(MlirOperation op)) } |]
  pure (nullable asModule)
-- | Render a module as text by printing its top-level operation.
showModule :: Module -> IO BS.ByteString
showModule = showOperation <=< moduleAsOperation
-- | A non-owning view of a byte sequence: a pointer plus a length in bytes.
-- It is converted field-for-field into MLIR's @MlirStringRef@ when passed to
-- the C API. The pointed-to memory must stay alive for as long as the view
-- is used (e.g. only inside the callback of 'withStringRef').
data StringRef = StringRef (Ptr C.CChar) C.CSize
-- | Expose a 'BS.ByteString' as a 'StringRef' for the duration of the
-- callback. 'BS.useAsCString' supplies a temporary NUL-terminated copy; the
-- length stored in the 'StringRef' excludes the terminator. The 'StringRef'
-- must not escape the callback.
withStringRef :: BS.ByteString -> (StringRef -> IO a) -> IO a
withStringRef bytes k =
  BS.useAsCString bytes $ \cstr ->
    k (StringRef cstr (fromIntegral (BS.length bytes)))
-- | Copy the bytes referenced by a 'StringRef' into a fresh, Haskell-owned
-- 'BS.ByteString'.
peekStringRef :: StringRef -> IO BS.ByteString
peekStringRef (StringRef ptr len) =
  let cstrLen = (ptr, fromIntegral len)
  in BS.packCStringLen cstrLen
-- | The text of an identifier as a 'StringRef'. The returned pointer is the
-- one MLIR hands back from @mlirIdentifierStr@; the bytes are not copied.
-- NOTE(review): presumably valid only while the owning context lives; use
-- 'peekStringRef' to take a copy.
identifierString :: Identifier -> IO StringRef
identifierString ident = evalContT $ do
  -- Out-parameters for the C snippet, freed when evalContT finishes.
  dataOut <- ContT alloca
  lenOut <- ContT alloca
  liftIO $ do
    [C.block| void {
      MlirStringRef identStr = mlirIdentifierStr($(MlirIdentifier ident));
      *$(const char** dataOut) = identStr.data;
      *$(size_t* lenOut) = identStr.length;
    } |]
    ptr <- peek dataOut
    len <- peek lenOut
    pure (StringRef ptr len)
-- | Obtain an 'Identifier' for the given string from the context via
-- @mlirIdentifierGet@. The string is passed by pointer and length.
createIdentifier :: Context -> StringRef -> IO Identifier
createIdentifier context (StringRef nameBytes nameLen) =
  [C.exp| MlirIdentifier {
    mlirIdentifierGet($(MlirContext context),
                      (MlirStringRef){$(char* nameBytes), $(size_t nameLen)})
  } |]
-- | Run a printing action that streams its output through
-- @HaskellMlirStringCallback@ and return everything it printed.
--
-- The action receives an opaque pointer to a two-slot state array, which it
-- must forward to the callback:
--
--   * slot 0: the output buffer pointer (starts as null);
--   * slot 1: a pointer to the accumulated byte count (starts at 0).
--
-- Afterwards the buffer contents are copied into a 'BS.ByteString' and the
-- buffer is released with 'free'. The release happens even if the action or
-- the copy throws, so an exception cannot leak the buffer.
--
-- NOTE(review): this assumes the callback (defined via MLIR.Native.FFI)
-- grows the buffer with malloc/realloc, so that 'free' is the matching
-- deallocator.
showSomething :: (Ptr () -> IO ()) -> IO BS.ByteString
showSomething action =
  allocaArray @(Ptr ()) 2 $ \state ->
  alloca @C.CSize $ \sizePtr -> do
    poke sizePtr 0
    pokeElemOff state 0 nullPtr
    pokeElemOff state 1 (castPtr sizePtr)
    let -- Reads slot 0 at cleanup time, so it frees whatever buffer the
        -- callback last stored there (free on null is a no-op).
        releaseBuffer = peekElemOff state 0 >>= free
        collect = do
          action (castPtr state)
          dataPtr <- castPtr <$> peekElemOff state 0
          size <- peek sizePtr
          -- Nothing printed: the buffer may still be null, so do not hand
          -- it to the copy routine.
          if size == 0
            then pure BS.empty
            else peekStringRef (StringRef dataPtr size)
    collect `finally` releaseBuffer
-- | Walk a linked sequence exposed through IO. Run the seed action, then keep
-- applying the step function to the most recent element until it produces
-- 'Nothing'. The result lists the elements in visit order, seed result
-- first; an initial 'Nothing' yields @[]@.
unrollIOMaybe :: (a -> IO (Maybe a)) -> IO (Maybe a) -> IO [a]
unrollIOMaybe step seed = seed >>= go []
  where
    go visited Nothing = pure (reverse visited)
    go visited (Just x) = step x >>= go (x : visited)
-- | Enable or disable MLIR's global debug flag (@mlirEnableGlobalDebug@).
setDebugMode :: Bool -> IO ()
setDebugMode enable =
  let flag = fromIntegral (fromEnum enable)  -- True -> 1, False -> 0
  in [C.exp| void { mlirEnableGlobalDebug($(bool flag)) } |]
-- | Values that MLIR can dump for debugging purposes.
class HasDump a where
  -- | Emit a debug dump of the value.
  dump :: a -> IO ()
-- | Operations are dumped with @mlirOperationDump@.
instance HasDump Operation where
  dump operation =
    [C.exp| void { mlirOperationDump($(MlirOperation operation)) } |]
-- | Modules are dumped by dumping their top-level operation.
instance HasDump Module where
  dump = dump <=< moduleAsOperation