{-# OPTIONS_GHC -Wno-unused-imports #-}
{-# OPTIONS_HADDOCK hide, prune, not-home #-}

module MLIR.AST.Dialect.Generated.Linalg where

import Prelude (Int, Double, Maybe(..), Bool(..), (++), (<$>), ($), (<>), Show)
import qualified Prelude
import Data.Int (Int64)
import qualified Data.Maybe
import Data.Array (Ix)
import qualified Data.Array.IArray as IArray
import qualified Data.ByteString as BS
import qualified Data.Map.Strict as M
import qualified Control.Monad

import MLIR.AST ( Attribute(..), Type(..), AbstractOperation(..), ResultTypes(..)
                , Location(..), Signedness(..), DenseElements(..)
                , NamedAttributes, Name
                , pattern NoAttrs )
import qualified MLIR.AST as AST
import MLIR.AST.Builder (Value, EndOfBlock, MonadBlockBuilder, RegionBuilderT)
import qualified MLIR.AST.Builder as AST
import qualified MLIR.AST.IStorableArray as AST
import qualified MLIR.AST.PatternUtil as PatternUtil
import qualified MLIR.AST.Dialect.Affine as Affine

-- * index
-- $index
-- 
-- The @linalg.index@ operation returns the iteration index of the immediately
-- enclosing linalg structured operation for the iteration dimension @dim@. The
-- @dim@ attribute specifies the position of the accessed dimension in the
-- indexing map domain.
-- 
-- Example:
-- 
-- @
-- \#map = affine_map\<(i, j) -> (i, j)>
-- linalg.generic {indexing_maps = [\#map, \#map],
--                 iterator_types = [\"parallel\", \"parallel\"]}
--   outs(%I, %J : memref\<?x?xindex>, memref\<?x?xindex>) {
--   ^bb0(%arg0 : index, %arg1 : index):
--   // Access the outer iteration dimension i
--   %i = linalg.index 0 : index
--   // Access the inner iteration dimension j
--   %j = linalg.index 1 : index
--   linalg.yield %i, %j : index, index
-- }
-- @
-- 
-- This may lower to IR resembling:
-- 
-- @
-- %0 = dim %I, %c0 : memref\<?x?xindex>
-- %1 = dim %I, %c1 : memref\<?x?xindex>
-- scf.for %i = %c0 to %0 step %c1 {
--   scf.for %j = %c0 to %1 step %c1 {
--     store %i, %I[%i, %j] : memref\<?x?xindex>
--     store %j, %J[%i, %j] : memref\<?x?xindex>
--   }
-- }
-- @
--   

-- | Internal view of the attribute dictionary of @linalg.index@.
--
-- As a matcher it looks up the @"dim"@ key and succeeds only when that entry
-- is a signless 64-bit 'IntegerAttr'; any other entries are ignored. As a
-- builder it produces a dictionary containing just that one entry.
pattern InternalIndexOpAttributes :: () => () => Int -> NamedAttributes
pattern InternalIndexOpAttributes dim_
  <- (M.lookup "dim" -> Just (IntegerAttr (IntegerType Signless 64) dim_))
  where
    InternalIndexOpAttributes dim_ =
      M.fromList [("dim", IntegerAttr (IntegerType Signless 64) dim_)]

-- | A pattern for @linalg.index@.
--
-- Matches (and builds) an operation named @linalg.index@ with exactly one
-- explicit result type @ty0@, no operands, regions or successors, and a
-- @dim@ attribute decoded through 'InternalIndexOpAttributes'.
pattern Linalg_Index :: () => () => Location -> Type -> Int -> AbstractOperation operand
pattern Linalg_Index loc ty0 dim_ = Operation
  { opName = "linalg.index"
  , opLocation = loc
  , opResultTypes = Explicit [ty0]
  , opOperands = []
  , opRegions = []
  , opSuccessors = []
  , opAttributes = InternalIndexOpAttributes dim_
  }

-- | A builder for @linalg.index@.
--
-- Emits the operation at 'UnknownLocation' with result type @ty0@ and the
-- given @dim@ attribute. It returns the first value produced by
-- 'AST.emitOp'; the operation declares exactly one result type.
index :: () => MonadBlockBuilder m => Type -> Int -> m Value
index ty0 dim_ = Prelude.head <$> AST.emitOp op
  where
    op = Operation
      { opName = "linalg.index"
      , opLocation = UnknownLocation
      , opResultTypes = Explicit [ty0]
      , opOperands = []
      , opRegions = []
      , opSuccessors = []
      , opAttributes = InternalIndexOpAttributes dim_
      }

-- * softmax
-- $softmax
-- 
-- linalg.softmax computes a numerically stable version of softmax.
-- 
-- For a given input tensor and a specified dimension @d@, compute:
--   1. the max @m@ along that dimension @d@
--   2. f(x) = exp(x - m)
--   3. sum f(x) along dimension d to get l(x).
--   4. compute the final result f(x) / l(x).
-- 
-- This is an aggregate linalg operation that further reduces to a small DAG of
-- structured operations.
-- 
-- Warning: Regarding the tiling capabilities, the implementation doesn\'t
-- check that the provided dimensions make sense. It is the responsibility
-- of the transformation calling the tiling to ensure that the provided
-- sizes for each dimension make sense with respect to the semantics of
-- softmax.
--   

-- | Internal view of the attribute dictionary of @linalg.softmax@.
--
-- As a matcher it looks up the @"dimension"@ key and succeeds only when that
-- entry is a signless 64-bit 'IntegerAttr'; any other entries are ignored.
-- As a builder it produces a dictionary containing just that one entry.
pattern InternalSoftmaxOpAttributes :: () => () => Int -> NamedAttributes
pattern InternalSoftmaxOpAttributes dimension_
  <- (M.lookup "dimension" -> Just (IntegerAttr (IntegerType Signless 64) dimension_))
  where
    InternalSoftmaxOpAttributes dimension_ =
      M.fromList [("dimension", IntegerAttr (IntegerType Signless 64) dimension_)]

-- | A builder for @linalg.softmax@.
--
-- Emits @linalg.softmax@ at 'UnknownLocation' with:
--
-- * the result types @ty0@;
-- * the operands @input_@ and @output_@, in that order;
-- * the @dimension@ attribute.
--
-- It returns the first value produced by 'AST.emitOp'.
--
-- @ty0@ is a list and may be empty, in which case there is no result value
-- to return. The operation is still emitted, and the returned 'Value' is a
-- lazy error that fails with a descriptive message only if it is forced.
softmax :: () => MonadBlockBuilder m => [Type] -> Value -> Value -> Int -> m Value
softmax ty0 input_ output_ dimension_ = firstResult <$> AST.emitOp op
  where
    op = Operation
      { opName = "linalg.softmax"
      , opLocation = UnknownLocation
      , opResultTypes = Explicit ty0
      , opOperands = [AST.operand input_, AST.operand output_]
      , opRegions = []
      , opSuccessors = []
      , opAttributes = InternalSoftmaxOpAttributes dimension_
      }
    -- Deliberately lazy (no forcing inside the monad): callers that emit the
    -- op with an empty result-type list and ignore the returned Value must
    -- keep working, exactly as with the previous `head`-based version.
    firstResult results = case results of
      (result : _) -> result
      [] -> Prelude.error
              ("linalg.softmax builder: no results were emitted "
                ++ "(empty result type list), so there is no Value to return")

-- * winograd_filter_transform
-- $winograd_filter_transform
-- 
-- The Winograd Conv2D algorithm converts a linalg Conv2D operator into a
-- batched matrix multiply. Before the matrix multiply, it converts the filter
-- and the input into a format suitable for batched matrix multiply. After the
-- matrix multiply, it converts the output into the final result tensor.
-- 
-- The algorithm F(m x m, r x r) is
-- 
-- Y = A^T x [(G x g x G^T) \@ (B^T x d x B)] x A
-- 
-- The size of output Y is m x m. The size of filter g is r x r. The size of
-- input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
-- transformation matrices.
-- 
-- This operator is defined to represent the high level concept of filter
-- transformation (G x g x G^T) in the Winograd Conv2D algorithm.
--   

-- | Internal view of the attribute dictionary of
-- @linalg.winograd_filter_transform@.
--
-- As a matcher it succeeds only when both the @"m"@ and @"r"@ entries are
-- signless 64-bit 'IntegerAttr's; any other entries are ignored. As a builder
-- it produces a dictionary containing exactly those two entries.
pattern InternalWinogradFilterTransformOpAttributes :: () => () => Int -> Int -> NamedAttributes
pattern InternalWinogradFilterTransformOpAttributes m_ r_
  <- ((\attrs -> (M.lookup "m" attrs, M.lookup "r" attrs))
        -> ( Just (IntegerAttr (IntegerType Signless 64) m_)
           , Just (IntegerAttr (IntegerType Signless 64) r_) ))
  where
    InternalWinogradFilterTransformOpAttributes m_ r_ =
      M.fromList
        [ ("m", IntegerAttr (IntegerType Signless 64) m_)
        , ("r", IntegerAttr (IntegerType Signless 64) r_)
        ]

-- | A pattern for @linalg.winograd_filter_transform@.
--
-- Matches (and builds) the operation with:
--
-- * exactly one explicit result type @ty0@;
-- * the two operands @filter_@ and @output_@, in that order;
-- * no regions or successors;
-- * the @m@ / @r@ attributes decoded through
--   'InternalWinogradFilterTransformOpAttributes'.
pattern Linalg_WinogradFilterTransform :: () => () => Location -> Type -> operand -> operand -> Int -> Int -> AbstractOperation operand
pattern Linalg_WinogradFilterTransform loc ty0 filter_ output_ m_ r_ = Operation
  { opName = "linalg.winograd_filter_transform"
  , opLocation = loc
  , opResultTypes = Explicit [ty0]
  , opOperands = [filter_, output_]
  , opRegions = []
  , opSuccessors = []
  , opAttributes = InternalWinogradFilterTransformOpAttributes m_ r_
  }

-- | A builder for @linalg.winograd_filter_transform@.
--
-- Emits the operation at 'UnknownLocation' with:
--
-- * result type @ty0@;
-- * operands @filter_@ and @output_@;
-- * the @m@ / @r@ attributes.
--
-- It returns the first value produced by 'AST.emitOp'; the operation
-- declares exactly one result type.
winograd_filter_transform :: () => MonadBlockBuilder m => Type -> Value -> Value -> Int -> Int -> m Value
winograd_filter_transform ty0 filter_ output_ m_ r_ = Prelude.head <$> AST.emitOp op
  where
    op = Operation
      { opName = "linalg.winograd_filter_transform"
      , opLocation = UnknownLocation
      , opResultTypes = Explicit [ty0]
      , opOperands = [AST.operand filter_, AST.operand output_]
      , opRegions = []
      , opSuccessors = []
      , opAttributes = InternalWinogradFilterTransformOpAttributes m_ r_
      }

-- * winograd_input_transform
-- $winograd_input_transform
-- 
-- The Winograd Conv2D algorithm converts a linalg Conv2D operator into a
-- batched matrix multiply. Before the matrix multiply, it converts the filter
-- and the input into a format suitable for batched matrix multiply. After the
-- matrix multiply, it converts the output into the final result tensor.
-- 
-- The algorithm F(m x m, r x r) is
-- 
-- Y = A^T x [(G x g x G^T) \@ (B^T x d x B)] x A
-- 
-- The size of output Y is m x m. The size of filter g is r x r. The size of
-- input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
-- transformation matrices.
-- 
-- This operator is defined to represent the high level concept of input
-- transformation (B^T x d x B) in the Winograd Conv2D algorithm.
--   

-- | Internal view of the attribute dictionary of
-- @linalg.winograd_input_transform@.
--
-- As a matcher it succeeds only when both the @"m"@ and @"r"@ entries are
-- signless 64-bit 'IntegerAttr's; any other entries are ignored. As a builder
-- it produces a dictionary containing exactly those two entries.
pattern InternalWinogradInputTransformOpAttributes :: () => () => Int -> Int -> NamedAttributes
pattern InternalWinogradInputTransformOpAttributes m_ r_
  <- ((\attrs -> (M.lookup "m" attrs, M.lookup "r" attrs))
        -> ( Just (IntegerAttr (IntegerType Signless 64) m_)
           , Just (IntegerAttr (IntegerType Signless 64) r_) ))
  where
    InternalWinogradInputTransformOpAttributes m_ r_ =
      M.fromList
        [ ("m", IntegerAttr (IntegerType Signless 64) m_)
        , ("r", IntegerAttr (IntegerType Signless 64) r_)
        ]

-- | A pattern for @linalg.winograd_input_transform@.
--
-- Matches (and builds) the operation with:
--
-- * exactly one explicit result type @ty0@;
-- * the two operands @input_@ and @output_@, in that order;
-- * no regions or successors;
-- * the @m@ / @r@ attributes decoded through
--   'InternalWinogradInputTransformOpAttributes'.
pattern Linalg_WinogradInputTransform :: () => () => Location -> Type -> operand -> operand -> Int -> Int -> AbstractOperation operand
pattern Linalg_WinogradInputTransform loc ty0 input_ output_ m_ r_ = Operation
  { opName = "linalg.winograd_input_transform"
  , opLocation = loc
  , opResultTypes = Explicit [ty0]
  , opOperands = [input_, output_]
  , opRegions = []
  , opSuccessors = []
  , opAttributes = InternalWinogradInputTransformOpAttributes m_ r_
  }

-- | A builder for @linalg.winograd_input_transform@.
--
-- Emits the operation at 'UnknownLocation' with:
--
-- * result type @ty0@;
-- * operands @input_@ and @output_@;
-- * the @m@ / @r@ attributes.
--
-- It returns the first value produced by 'AST.emitOp'; the operation
-- declares exactly one result type.
winograd_input_transform :: () => MonadBlockBuilder m => Type -> Value -> Value -> Int -> Int -> m Value
winograd_input_transform ty0 input_ output_ m_ r_ = Prelude.head <$> AST.emitOp op
  where
    op = Operation
      { opName = "linalg.winograd_input_transform"
      , opLocation = UnknownLocation
      , opResultTypes = Explicit [ty0]
      , opOperands = [AST.operand input_, AST.operand output_]
      , opRegions = []
      , opSuccessors = []
      , opAttributes = InternalWinogradInputTransformOpAttributes m_ r_
      }

-- * winograd_output_transform
-- $winograd_output_transform
-- 
-- The Winograd Conv2D algorithm converts a linalg Conv2D operator into a
-- batched matrix multiply. Before the matrix multiply, it converts the filter
-- and the input into a format suitable for batched matrix multiply. After the
-- matrix multiply, it converts the output into the final result tensor.
-- 
-- The algorithm F(m x m, r x r) is
-- 
-- Y = A^T x [(G x g x G^T) \@ (B^T x d x B)] x A
-- 
-- The size of output Y is m x m. The size of filter g is r x r. The size of
-- input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
-- transformation matrices.
-- 
-- This operator is defined to represent the high level concept of output
-- transformation (A^T x y x A) in the Winograd Conv2D algorithm.
--   

-- | Internal view of the attribute dictionary of
-- @linalg.winograd_output_transform@.
--
-- As a matcher it succeeds only when both the @"m"@ and @"r"@ entries are
-- signless 64-bit 'IntegerAttr's; any other entries are ignored. As a builder
-- it produces a dictionary containing exactly those two entries.
pattern InternalWinogradOutputTransformOpAttributes :: () => () => Int -> Int -> NamedAttributes
pattern InternalWinogradOutputTransformOpAttributes m_ r_
  <- ((\attrs -> (M.lookup "m" attrs, M.lookup "r" attrs))
        -> ( Just (IntegerAttr (IntegerType Signless 64) m_)
           , Just (IntegerAttr (IntegerType Signless 64) r_) ))
  where
    InternalWinogradOutputTransformOpAttributes m_ r_ =
      M.fromList
        [ ("m", IntegerAttr (IntegerType Signless 64) m_)
        , ("r", IntegerAttr (IntegerType Signless 64) r_)
        ]

-- | A pattern for @linalg.winograd_output_transform@.
--
-- Matches (and builds) the operation with:
--
-- * exactly one explicit result type @ty0@;
-- * the two operands @value_@ and @output_@, in that order;
-- * no regions or successors;
-- * the @m@ / @r@ attributes decoded through
--   'InternalWinogradOutputTransformOpAttributes'.
pattern Linalg_WinogradOutputTransform :: () => () => Location -> Type -> operand -> operand -> Int -> Int -> AbstractOperation operand
pattern Linalg_WinogradOutputTransform loc ty0 value_ output_ m_ r_ = Operation
  { opName = "linalg.winograd_output_transform"
  , opLocation = loc
  , opResultTypes = Explicit [ty0]
  , opOperands = [value_, output_]
  , opRegions = []
  , opSuccessors = []
  , opAttributes = InternalWinogradOutputTransformOpAttributes m_ r_
  }

-- | A builder for @linalg.winograd_output_transform@.
--
-- Emits the operation at 'UnknownLocation' with:
--
-- * result type @ty0@;
-- * operands @value_@ and @output_@;
-- * the @m@ / @r@ attributes.
--
-- It returns the first value produced by 'AST.emitOp'; the operation
-- declares exactly one result type.
winograd_output_transform :: () => MonadBlockBuilder m => Type -> Value -> Value -> Int -> Int -> m Value
winograd_output_transform ty0 value_ output_ m_ r_ = Prelude.head <$> AST.emitOp op
  where
    op = Operation
      { opName = "linalg.winograd_output_transform"
      , opLocation = UnknownLocation
      , opResultTypes = Explicit [ty0]
      , opOperands = [AST.operand value_, AST.operand output_]
      , opRegions = []
      , opSuccessors = []
      , opAttributes = InternalWinogradOutputTransformOpAttributes m_ r_
      }

-- * yield
-- $yield
-- 
-- @linalg.yield@ is a special terminator operation for blocks inside regions
-- in @linalg@ generic ops. It returns values to the immediately enclosing
-- @linalg@ generic op.
-- 
-- Example:
-- 
-- @
-- linalg.yield %f0, %f1 : f32, f32
-- @
--   

-- | A pattern for @linalg.yield@.
--
-- Matches (and builds) a @linalg.yield@ with no result types, no regions or
-- successors, no attributes ('NoAttrs'), and the operand list @values_@.
pattern Linalg_Yield :: () => () => Location -> [operand] -> AbstractOperation operand
pattern Linalg_Yield loc values_ = Operation
  { opName = "linalg.yield"
  , opLocation = loc
  , opResultTypes = Explicit []
  , opOperands = values_
  , opRegions = []
  , opSuccessors = []
  , opAttributes = NoAttrs
  }

-- | A builder for @linalg.yield@.
--
-- Emits a result-less, attribute-free @linalg.yield@ at 'UnknownLocation'
-- whose operands are @values_@. It then closes the current block via
-- 'AST.terminateBlock'. Any values returned by 'AST.emitOp' are discarded.
yield :: () => MonadBlockBuilder m => [Value] -> m EndOfBlock
yield values_ = do
  _ <- AST.emitOp op
  AST.terminateBlock
  where
    op = Operation
      { opName = "linalg.yield"
      , opLocation = UnknownLocation
      , opResultTypes = Explicit []
      , opOperands = AST.operands values_
      , opRegions = []
      , opSuccessors = []
      , opAttributes = NoAttrs
      }