-- Copyright 2022 Google LLC
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

module MLIR.AST.Dialect.ControlFlow
  ( module MLIR.AST.Dialect.ControlFlow,
    module MLIR.AST.Dialect.Generated.ControlFlow
  ) where

import Prelude hiding (return)
import Data.Array.IArray

import MLIR.AST
import MLIR.AST.Builder

import MLIR.AST.Dialect.Generated.ControlFlow

-- | Bidirectional pattern for a @cf.br@ operation, given its location, its
-- single successor block, and the operand names passed along with it.
--
-- Used as an expression it builds such an operation. Used as a pattern it
-- matches only operations that satisfy all of the following:
--
-- * the name is @"cf.br"@;
-- * the result types are @Explicit []@;
-- * there are no regions;
-- * there is exactly one successor;
-- * there are no attributes ('NoAttrs').
pattern Branch :: Location -> BlockName -> [Name] -> Operation
pattern Branch location target branchArgs = Operation
  { opName        = "cf.br"
  , opLocation    = location
  , opResultTypes = Explicit []
  , opOperands    = branchArgs
  , opRegions     = []
  , opSuccessors  = [target]
  , opAttributes  = NoAttrs
  }

-- | Emits an unconditional @cf.br@ to @block@, with @args@ as its operands,
-- and then terminates the current block.
--
-- The operation is built with 'Branch' at 'UnknownLocation'. Any values that
-- 'emitOp' returns for it are discarded.
br :: MonadBlockBuilder m => BlockName -> [Value] -> m EndOfBlock
br block args = do
  _ <- emitOp branchOp
  terminateBlock
  where
    branchOp = Branch UnknownLocation block (operands args)

-- | Emits a conditional @cf.cond_br@ and then terminates the current block.
--
-- The operands are laid out as @cond@, then @trueArgs@, then @falseArgs@.
-- The @operand_segment_sizes@ attribute records the size of each of those
-- three groups, in that order: @1@, @length trueArgs@, @length falseArgs@.
-- The successors are @[trueBlock, falseBlock]@. The operation has no regions
-- and no result types, and is emitted at 'UnknownLocation'.
cond_br
  :: MonadBlockBuilder m
  => Value -> BlockName -> [Value] -> BlockName -> [Value] -> m EndOfBlock
cond_br cond trueBlock trueArgs falseBlock falseArgs = do
  emitOp_ condBranchOp
  terminateBlock
  where
    condBranchOp = Operation
      { opName        = "cf.cond_br"
      , opLocation    = UnknownLocation
      , opResultTypes = Explicit []
      , opOperands    = operands (cond : trueArgs ++ falseArgs)
      , opRegions     = []
      , opSuccessors  = [trueBlock, falseBlock]
      , opAttributes  = namedAttribute "operand_segment_sizes" segmentSizes
      }
    -- One Int32 entry per operand group, indexed 0..2, matching the order
    -- in which the groups appear in opOperands above.
    segmentSizes =
      DenseArrayAttr . DenseInt32 . listArray (0 :: Int, 2) $
        map fromIntegral [1, length trueArgs, length falseArgs]