-- Copyright 2021 Google LLC
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

module MLIR.Native.Pass where

import qualified Language.C.Inline as C

import Control.Exception (bracket)

import MLIR.Native.FFI

C.context $ C.baseCtx <> mlirCtx

C.include "mlir-c/IR.h"
C.include "mlir-c/Pass.h"
C.include "mlir-c/Conversion.h"

-- TODO(apaszke): Flesh this out based on the header

--------------------------------------------------------------------------------
-- Pass manager

-- | Create a new top-level pass manager in the given context by calling
-- @mlirPassManagerCreate@.
--
-- The result should eventually be released with 'destroyPassManager'
-- (or use 'withPassManager', which does this automatically).
createPassManager :: Context -> IO PassManager
createPassManager ctx =
  [C.exp| MlirPassManager { mlirPassManagerCreate($(MlirContext ctx)) } |]

-- | Release a pass manager obtained from 'createPassManager' by calling
-- @mlirPassManagerDestroy@.
destroyPassManager :: PassManager -> IO ()
destroyPassManager pm =
  [C.exp| void { mlirPassManagerDestroy($(MlirPassManager pm)) } |]

-- | Run an action with a freshly created pass manager.
--
-- The pass manager is created with 'createPassManager' and released with
-- 'destroyPassManager'. Because this uses 'bracket', the release also happens
-- when the action throws an exception.
withPassManager :: Context -> (PassManager -> IO a) -> IO a
withPassManager ctx = bracket (createPassManager ctx) destroyPassManager

-- | Run every pass registered on the pass manager over the given operation.
--
-- Returns the 'LogicalResult' reported by @mlirPassManagerRunOnOp@.
runPasses :: PassManager -> Operation -> IO LogicalResult
runPasses pm op =
  [C.exp| MlirLogicalResult { mlirPassManagerRunOnOp($(MlirPassManager pm), $(MlirOperation op)) } |]

--------------------------------------------------------------------------------
-- Transform passes

--------------------------------------------------------------------------------
-- Conversion passes

-- | Append the MemRef-to-LLVM conversion pass to the pass manager.
--
-- The pass is created by
-- @mlirCreateConversionFinalizeMemRefToLLVMConversionPass@ and added with
-- @mlirPassManagerAddOwnedPass@, which hands ownership of the new pass to the
-- pass manager (see @mlir-c/Pass.h@).
addConvertMemRefToLLVMPass :: PassManager -> IO ()
addConvertMemRefToLLVMPass pm =
  [C.exp| void {
    mlirPassManagerAddOwnedPass($(MlirPassManager pm), mlirCreateConversionFinalizeMemRefToLLVMConversionPass())
  } |]

-- | Append the Arith-to-LLVM conversion pass to the pass manager.
--
-- The pass is created by @mlirCreateConversionArithToLLVMConversionPass@ and
-- added with @mlirPassManagerAddOwnedPass@, which hands ownership of the new
-- pass to the pass manager (see @mlir-c/Pass.h@).
addConvertArithToLLVMPass :: PassManager -> IO ()
addConvertArithToLLVMPass pm =
  [C.exp| void {
    mlirPassManagerAddOwnedPass($(MlirPassManager pm), mlirCreateConversionArithToLLVMConversionPass())
  } |]

-- | Append the ControlFlow-to-LLVM conversion pass to the pass manager.
--
-- The pass is created by @mlirCreateConversionConvertControlFlowToLLVMPass@
-- and added with @mlirPassManagerAddOwnedPass@, which hands ownership of the
-- new pass to the pass manager (see @mlir-c/Pass.h@).
addConvertControlFlowToLLVMPass :: PassManager -> IO ()
addConvertControlFlowToLLVMPass pm =
  [C.exp| void {
    mlirPassManagerAddOwnedPass($(MlirPassManager pm), mlirCreateConversionConvertControlFlowToLLVMPass())
  } |]

-- | Append the Func-to-LLVM conversion pass to the pass manager.
--
-- The pass is created by @mlirCreateConversionConvertFuncToLLVMPass@ and
-- added with @mlirPassManagerAddOwnedPass@, which hands ownership of the new
-- pass to the pass manager (see @mlir-c/Pass.h@).
addConvertFuncToLLVMPass :: PassManager -> IO ()
addConvertFuncToLLVMPass pm =
  [C.exp| void {
    mlirPassManagerAddOwnedPass($(MlirPassManager pm), mlirCreateConversionConvertFuncToLLVMPass())
  } |]

-- | Append the Vector-to-LLVM conversion pass to the pass manager.
--
-- The pass is created by @mlirCreateConversionConvertVectorToLLVMPass@ and
-- added with @mlirPassManagerAddOwnedPass@, which hands ownership of the new
-- pass to the pass manager (see @mlir-c/Pass.h@).
addConvertVectorToLLVMPass :: PassManager -> IO ()
addConvertVectorToLLVMPass pm =
  [C.exp| void {
    mlirPassManagerAddOwnedPass($(MlirPassManager pm), mlirCreateConversionConvertVectorToLLVMPass())
  } |]

-- | Append the reconcile-unrealized-casts pass to the pass manager.
--
-- The pass is created by @mlirCreateConversionReconcileUnrealizedCasts@ and
-- added with @mlirPassManagerAddOwnedPass@, which hands ownership of the new
-- pass to the pass manager (see @mlir-c/Pass.h@).
addConvertReconcileUnrealizedCastsPass :: PassManager -> IO ()
addConvertReconcileUnrealizedCastsPass pm =
  [C.exp| void {
    mlirPassManagerAddOwnedPass($(MlirPassManager pm), mlirCreateConversionReconcileUnrealizedCasts())
  } |]