-- Copyright 2021 Google LLC
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

module MLIR.Native.ExecutionEngine where

import Foreign.Ptr
import Foreign.Storable
import Foreign.Marshal.Alloc
import Foreign.Marshal.Array
import Data.Int
import qualified Language.C.Inline as C

import Control.Exception (bracket)
import Control.Monad

import MLIR.Native
import MLIR.Native.FFI

C.context $ C.baseCtx <> mlirCtx

C.include "mlir-c/ExecutionEngine.h"

-- TODO(apaszke): Flesh this out based on the header

--------------------------------------------------------------------------------
-- Execution engine

-- TODO(apaszke): Make the opt level configurable
-- TODO(apaszke): Allow loading shared libraries
-- | Creates an execution engine for the given module via
-- @mlirExecutionEngineCreate@.
--
-- The engine is requested with optimization level 3, no shared libraries
-- (@0, NULL@) and object dumping disabled (@false@). The raw handle is then
-- passed through 'nullable' (a @Coercible a (Ptr ())@ helper), which
-- presumably turns a null handle into 'Nothing'. Confirm against its definition.
createExecutionEngine :: Module -> IO (Maybe ExecutionEngine)
createExecutionEngine modl = do
  engine <- [C.exp| MlirExecutionEngine { mlirExecutionEngineCreate($(MlirModule modl), 3, 0, NULL, false) } |]
  pure (nullable engine)

-- | Destroys an execution engine handle by calling
-- @mlirExecutionEngineDestroy@ on it.
destroyExecutionEngine :: ExecutionEngine -> IO ()
destroyExecutionEngine = \engine ->
  [C.exp| void { mlirExecutionEngineDestroy($(MlirExecutionEngine engine)) } |]

-- | Runs an action with an execution engine created for @m@.
--
-- The engine comes from 'createExecutionEngine' and is handed to the action
-- as a 'Maybe'. When the action finishes, normally or by throwing, 'bracket'
-- runs the release step. That step calls 'destroyExecutionEngine' for a
-- 'Just' engine and does nothing for 'Nothing'.
withExecutionEngine :: Module -> (Maybe ExecutionEngine -> IO a) -> IO a
withExecutionEngine m = bracket acquire release
  where
    acquire :: IO (Maybe ExecutionEngine)
    acquire = createExecutionEngine m

    release :: Maybe ExecutionEngine -> IO ()
    release = mapM_ destroyExecutionEngine


-- | An existentially wrapped 'Storable' value. It lets arguments of
-- different types travel together in one list, as used by
-- 'executionEngineInvoke' and 'packStruct64'.
data SomeStorable = forall a. Storable a => SomeStorable a

-- | Invokes the named function in the engine using the packed calling
-- convention (@mlirExecutionEngineInvokePacked@).
--
-- A temporary array of @length args + 1@ pointers is built:
--
-- * slot @i@ points to a freshly allocated copy of the @i@-th argument;
-- * the final slot points to a buffer that receives the result.
--
-- All of these buffers come from 'alloca'/'allocaArray', so they only live
-- for the duration of the native call.
--
-- Returns @'Just' r@, with @r@ read back from the result buffer, when the
-- call reports 'Success'. Returns 'Nothing' when it reports 'Failure'.
executionEngineInvoke :: forall result. Storable result
                      => ExecutionEngine -> StringRef -> [SomeStorable] -> IO (Maybe result)
executionEngineInvoke eng (StringRef namePtr nameLen) args =
  allocaArray (argCount + 1) \slots ->
    alloca @result \resultPtr -> do
      -- The result buffer sits in the slot just past the last argument.
      pokeElemOff slots argCount (castPtr resultPtr)
      -- TODO(apaszke): Are tuples exploded, or stored as pointers?
      fillSlots slots args do
        status <- [C.exp| MlirLogicalResult {
          mlirExecutionEngineInvokePacked($(MlirExecutionEngine eng),
                                          (MlirStringRef){$(char* namePtr), $(size_t nameLen)},
                                          $(void** slots))
        } |]
        case status of
          Success -> Just <$> peek resultPtr
          Failure -> pure Nothing
  where
    argCount :: Int
    argCount = length args

    -- Copies each argument into its own temporary buffer and writes that
    -- buffer's address into consecutive slots, starting at the given pointer.
    -- The continuation runs while every buffer is still allocated.
    fillSlots :: Ptr (Ptr ()) -> [SomeStorable] -> IO b -> IO b
    fillSlots _ [] k = k
    fillSlots slot (SomeStorable v : rest) k =
      alloca \cell -> do
        poke cell v
        poke slot (castPtr cell)
        fillSlots (advancePtr slot 1) rest k

-- | Lays out @fields@ in a temporary buffer of 8-byte slots and runs @f@
-- with a pointer to that buffer. Field @i@ is written at element index @i@
-- of an 'Int64'-sized array.
--
-- Each field must have a 'sizeOf' of exactly 8 and an 'alignment' of at
-- most 8. Otherwise 'error' is raised when that field is reached, and any
-- earlier fields will already have been written. The buffer is allocated
-- with 'allocaArray', so it is only valid while @f@ runs.
packStruct64 :: [SomeStorable] -> (Ptr () -> IO a) -> IO a
packStruct64 fields f =
  allocaArray @Int64 (length fields) \buf -> do
    let writeField slot (SomeStorable field) = do
          when (sizeOf field /= 8) $
            error "packStruct64 expects all fields to be exactly 8 bytes in size"
          when (alignment field > 8) $
            error "packStruct64 expects all fields to have an alignment of at most 8 bytes"
          pokeElemOff (castPtr buf) slot field
    zipWithM_ writeField [0 ..] fields
    f (castPtr buf)