-- Copyright 2021 Google LLC
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# OPTIONS_GHC -Wno-name-shadowing #-}

module MLIR.AST.Dialect.Vector
  ( module MLIR.AST.Dialect.Vector
  , module MLIR.AST.Dialect.Generated.Vector
  ) where

import Data.Typeable
import qualified Data.Map.Strict as M
import qualified Data.ByteString as BS
import qualified Language.C.Inline as C

import MLIR.AST.Dialect.Generated.Vector
import qualified MLIR.AST                as AST
import qualified MLIR.AST.Serialize      as AST
import qualified MLIR.AST.Dialect.Affine as Affine
import qualified MLIR.Native             as Native
import qualified MLIR.Native.FFI         as Native

-- | Iterator kind of one loop dimension, as stored in the @iterator_types@
-- attribute built by 'ContractAttrs'. The textual MLIR form of each
-- constructor is produced by 'showIterator'.
data IteratorKind = Parallel | Reduction
          deriving (Eq, Show)

-- | Vector-dialect attributes representable in the AST.
--
-- Values are embedded into a generic 'AST.Attribute' through
-- 'AST.DialectAttr' (see 'IteratorAttrs') and recovered with
-- 'castVectorAttr'. Currently only iterator-type attributes exist.
data Attribute = IteratorAttr IteratorKind
          deriving (Eq, Show)

-- | Recover a vector-dialect 'Attribute' from a generic AST attribute.
--
-- Returns 'Nothing' when the attribute is not an 'AST.DialectAttr', or when
-- its payload is a dialect attribute of some other Haskell type (the payload
-- type is checked at runtime with 'cast').
castVectorAttr :: AST.Attribute -> Maybe Attribute
castVectorAttr ty = case ty of
  AST.DialectAttr dty -> cast dty
  _                   -> Nothing

-- | Textual MLIR form of an iterator kind.
--
-- This text is handed to the MLIR attribute parser when an 'IteratorAttr'
-- is converted to a native attribute (see the 'AST.FromAST' instance).
showIterator :: IteratorKind -> BS.ByteString
showIterator Parallel  = "#vector.iterator_type<parallel>"
showIterator Reduction = "#vector.iterator_type<reduction>"

-- inline-c setup: the base C type context extended with Native.mlirCtx,
-- which presumably supplies the Mlir* type mappings (MlirContext,
-- MlirAttribute) used by the C.exp quasiquote later in this module.
$(C.context (C.baseCtx <> Native.mlirCtx))

-- Header declaring mlirAttributeParseGet for the generated C code.
$(C.include "mlir-c/IR.h")

-- | Lower a vector-dialect 'Attribute' to a native MLIR attribute.
--
-- The attribute is printed with 'showIterator' and the resulting text is
-- parsed by @mlirAttributeParseGet@ in the given context. The value/block
-- mapping argument is not needed and is ignored.
--
-- NOTE(review): the result of @mlirAttributeParseGet@ is returned without
-- a check; presumably a parse failure yields a null MlirAttribute — confirm
-- against the MLIR C API.
instance AST.FromAST Attribute Native.Attribute where
  fromAST ctx _ ty = case ty of
    IteratorAttr t -> do
      let value = showIterator t
      Native.withStringRef value $ \(Native.StringRef ptr len) ->
        [C.exp| MlirAttribute {
          mlirAttributeParseGet($(MlirContext ctx), (MlirStringRef){$(char* ptr), $(size_t len)})
        } |]

-- | Decode a single iterator-type attribute.
--
-- Succeeds only for an 'AST.DialectAttr' whose payload is a vector
-- 'IteratorAttr'; any other attribute yields 'Nothing'. The dialect-payload
-- extraction is delegated to 'castVectorAttr' so that the two stay in sync.
iterFromAttribute :: AST.Attribute -> Maybe IteratorKind
iterFromAttribute attr = case castVectorAttr attr of
  Just (IteratorAttr kind) -> Just kind
  _                        -> Nothing

-- | Decode an 'AST.ArrayAttr' whose elements are all iterator-type
-- attributes.
--
-- Returns 'Nothing' if the attribute is not an array, or if any element
-- fails to decode with 'iterFromAttribute' (all-or-nothing via 'traverse').
itersFromAttribute :: AST.Attribute -> Maybe [IteratorKind]
itersFromAttribute attr = case attr of
  AST.ArrayAttr subAttrs -> traverse iterFromAttribute subAttrs
  _                      -> Nothing

-- | Bidirectional view of an array of iterator-type attributes.
--
-- Matching succeeds only when 'itersFromAttribute' decodes every element.
-- Construction wraps each kind as @AST.DialectAttr (IteratorAttr kind)@
-- inside an 'AST.ArrayAttr', preserving order.
pattern IteratorAttrs :: [IteratorKind] -> AST.Attribute
pattern IteratorAttrs iterTypes <- (itersFromAttribute -> Just iterTypes)
  where
    IteratorAttrs iterTypes =
      AST.ArrayAttr (fmap (AST.DialectAttr . IteratorAttr) iterTypes)

-- | Bidirectional view of the named attributes of a @vector.contract@.
--
-- Matching looks up the @indexing_maps@ and @iterator_types@ keys. It
-- succeeds only when:
--
-- * @indexing_maps@ is an array of exactly three affine maps (lhs, rhs,
--   acc, in that order), and
-- * @iterator_types@ decodes via 'IteratorAttrs'.
--
-- Any other keys in the map are ignored when matching. The builder emits a
-- map containing only these two keys.
pattern ContractAttrs :: Affine.Map -> Affine.Map -> Affine.Map -> [IteratorKind]
                      -> AST.NamedAttributes
pattern ContractAttrs lhsMap rhsMap accMap iterKinds <-
  ((\m -> (M.lookup "indexing_maps" m, M.lookup "iterator_types" m)) ->
     ( Just (AST.ArrayAttr [ AST.AffineMapAttr lhsMap
                           , AST.AffineMapAttr rhsMap
                           , AST.AffineMapAttr accMap ])
     , Just (IteratorAttrs iterKinds) ))
  where
    ContractAttrs lhsMap rhsMap accMap iterKinds =
      M.fromList
        [ ( "indexing_maps"
          , AST.ArrayAttr [ AST.AffineMapAttr lhsMap
                          , AST.AffineMapAttr rhsMap
                          , AST.AffineMapAttr accMap ] )
        , ("iterator_types", IteratorAttrs iterKinds)
        ]

-- | Bidirectional view of a @vector.contract@ operation.
--
-- Matches (and builds) an operation with:
--
-- * a single explicit result type @resultType@,
-- * exactly three operands @lhs@, @rhs@, @acc@ in that order,
-- * no regions and no successors, and
-- * attributes described by 'ContractAttrs'.
pattern Contract :: AST.Location -> AST.Type -> AST.Name -> AST.Name -> AST.Name
                 -> Affine.Map -> Affine.Map -> Affine.Map -> [IteratorKind]
                 -> AST.Operation
pattern Contract location resultType lhs rhs acc lhsMap rhsMap accMap iterKinds = AST.Operation
  { opName = "vector.contract"
  , opLocation = location
  , opResultTypes = AST.Explicit [resultType]
  , opOperands = [lhs, rhs, acc]
  , opRegions = []
  , opSuccessors = []
  , opAttributes = ContractAttrs lhsMap rhsMap accMap iterKinds
  }