{-# OPTIONS_GHC -Wno-name-shadowing #-}
module MLIR.AST.Dialect.Vector
( module MLIR.AST.Dialect.Vector
, module MLIR.AST.Dialect.Generated.Vector
) where
import Data.Typeable
import qualified Data.Map.Strict as M
import qualified Data.ByteString as BS
import qualified Language.C.Inline as C
import MLIR.AST.Dialect.Generated.Vector
import qualified MLIR.AST as AST
import qualified MLIR.AST.Serialize as AST
import qualified MLIR.AST.Dialect.Affine as Affine
import qualified MLIR.Native as Native
import qualified MLIR.Native.FFI as Native
-- | Iterator kind of a single loop dimension of a vector operation.
--
-- Rendered by 'showIterator' as @#vector.iterator_type<parallel>@ or
-- @#vector.iterator_type<reduction>@.
data IteratorKind
  = Parallel
  | Reduction
  deriving (Show, Eq)
-- | Vector-dialect attributes that can be carried inside the generic MLIR AST
-- (wrapped in 'AST.DialectAttr', see 'castVectorAttr').
data Attribute
  = IteratorAttr IteratorKind  -- ^ one @#vector.iterator_type<...>@ attribute
  deriving (Show, Eq)
-- | View a generic AST attribute as a vector-dialect 'Attribute'.
--
-- Only an 'AST.DialectAttr' whose payload is (at runtime, via 'cast') of type
-- 'Attribute' succeeds; every other attribute yields 'Nothing'.
castVectorAttr :: AST.Attribute -> Maybe Attribute
castVectorAttr (AST.DialectAttr payload) = cast payload
castVectorAttr _                         = Nothing
-- | Textual MLIR attribute syntax for an iterator kind.
--
-- The result is passed verbatim to the MLIR attribute parser when lowering
-- an 'IteratorAttr'.
showIterator :: IteratorKind -> BS.ByteString
showIterator kind = case kind of
  Parallel  -> "#vector.iterator_type<parallel>"
  Reduction -> "#vector.iterator_type<reduction>"
-- inline-c setup for this module: extend the base C context with
-- 'Native.mlirCtx' and include the MLIR C API header. The quasi-quote in the
-- 'AST.FromAST' instance below names @MlirContext@, @MlirStringRef@ and
-- @MlirAttribute@; presumably their Haskell mappings come from
-- 'Native.mlirCtx' (TODO confirm).
C.context $ C.baseCtx <> Native.mlirCtx
C.include "mlir-c/IR.h"
-- | Lower a vector-dialect attribute to a native MLIR attribute.
--
-- The attribute is serialised with 'showIterator' and handed to
-- @mlirAttributeParseGet@ in the given context. The value/block mapping
-- argument is not used.
--
-- NOTE(review): the parse result is not checked for a null attribute. The
-- input is always one of the fixed strings from 'showIterator', so a parse
-- failure is not expected, but it would go unnoticed here.
instance AST.FromAST Attribute Native.Attribute where
  fromAST ctx _ (IteratorAttr kind) =
    Native.withStringRef (showIterator kind) $ \(Native.StringRef ptr len) ->
      [C.exp| MlirAttribute {
        mlirAttributeParseGet($(MlirContext ctx), (MlirStringRef){$(char* ptr), $(size_t len)})
      } |]
-- | Extract the 'IteratorKind' from a generic attribute that wraps a
-- vector-dialect 'IteratorAttr'; 'Nothing' for any other attribute.
iterFromAttribute :: AST.Attribute -> Maybe IteratorKind
iterFromAttribute attr = do
  -- Strict bind: the recovered 'Attribute' is forced, as a nested case would.
  IteratorAttr kind <- castVectorAttr attr
  pure kind
-- | Decode an array attribute whose elements are all iterator attributes.
--
-- Returns 'Nothing' if the attribute is not an 'AST.ArrayAttr', or if any
-- element fails 'iterFromAttribute'.
itersFromAttribute :: AST.Attribute -> Maybe [IteratorKind]
itersFromAttribute (AST.ArrayAttr elems) = mapM iterFromAttribute elems
itersFromAttribute _                     = Nothing
-- | Bidirectional pattern for an array of vector iterator attributes.
--
-- Matching succeeds only when every array element is an 'IteratorAttr'
-- (see 'itersFromAttribute'). Construction wraps each kind as
-- @'AST.DialectAttr' ('IteratorAttr' k)@ inside an 'AST.ArrayAttr'.
pattern IteratorAttrs :: [IteratorKind] -> AST.Attribute
pattern IteratorAttrs kinds <- (itersFromAttribute -> Just kinds)
  where
    IteratorAttrs kinds =
      AST.ArrayAttr [AST.DialectAttr (IteratorAttr k) | k <- kinds]
-- | Bidirectional pattern for the named attributes of @vector.contract@.
--
-- Matching requires an @indexing_maps@ entry holding exactly three affine
-- maps (lhs, rhs, accumulator, in that order), and an @iterator_types@ entry
-- accepted by 'IteratorAttrs'. Any other entries in the map are ignored.
-- Construction produces a map containing only these two entries.
pattern ContractAttrs :: Affine.Map -> Affine.Map -> Affine.Map -> [IteratorKind] -> AST.NamedAttributes
pattern ContractAttrs lhsM rhsM accM kinds <-
  ( (\attrs -> (M.lookup "indexing_maps" attrs, M.lookup "iterator_types" attrs))
    -> ( Just (AST.ArrayAttr [ AST.AffineMapAttr lhsM
                             , AST.AffineMapAttr rhsM
                             , AST.AffineMapAttr accM
                             ])
       , Just (IteratorAttrs kinds)
       )
  )
  where
    ContractAttrs lhsM rhsM accM kinds =
      M.fromList
        [ ("indexing_maps", AST.ArrayAttr (map AST.AffineMapAttr [lhsM, rhsM, accM]))
        , ("iterator_types", IteratorAttrs kinds)
        ]
-- | Bidirectional pattern for a @vector.contract@ operation.
--
-- The operation has one explicit result type and exactly three operands
-- (lhs, rhs, accumulator). It has no regions and no successors. Its
-- attributes are matched or built by 'ContractAttrs' from the three indexing
-- maps and the iterator kinds.
pattern Contract
  :: AST.Location
  -> AST.Type
  -> AST.Name
  -> AST.Name
  -> AST.Name
  -> Affine.Map
  -> Affine.Map
  -> Affine.Map
  -> [IteratorKind]
  -> AST.Operation
pattern Contract loc resTy lhsV rhsV accV lhsM rhsM accM kinds = AST.Operation
  { opName        = "vector.contract"
  , opLocation    = loc
  , opResultTypes = AST.Explicit [resTy]
  , opOperands    = [lhsV, rhsV, accV]
  , opRegions     = []
  , opSuccessors  = []
  , opAttributes  = ContractAttrs lhsM rhsM accM kinds
  }