module MLIR.AST.Rewrite
( RewriteBuilderT
, OpRewriteM
, OpRewrite
, RewriteResult(..)
, pattern ReplaceOne
, applyClosedOpRewrite
, applyClosedOpRewriteT
) where
import qualified Data.Map.Strict as M
import Control.Monad.Reader
import Control.Monad.Identity
import qualified MLIR.AST as AST
import MLIR.AST hiding (Operation)
import MLIR.AST.Builder
-- | An operation as seen by a rewrite rule: its operands are already-built
-- 'Value's rather than the 'Name's carried by a plain 'AST.Operation'.
type Operation = AST.AbstractOperation Value
-- | Maps value names of the input AST to the values that now stand for them.
type ValueMapping = M.Map Name Value
-- | Maps block names of the input AST to the names of the rebuilt blocks.
type BlockMapping = M.Map BlockName BlockName
-- | The substitution environment carried through a rewrite.
type BlockAndValueMapping = (ValueMapping, BlockMapping)
-- | Reader layer holding the current substitution environment.
type SubstT = ReaderT BlockAndValueMapping
-- | Base monad of a rewrite: the substitution reader over 'NameSupplyT'.
type RewriteT m = SubstT (NameSupplyT m)
-- | Monad in which rewrite rules run: a block builder on top of 'RewriteT'.
type RewriteBuilderT m = BlockBuilderT (RewriteT m)
-- | The decision a rewrite rule makes about the operation it was handed.
data RewriteResult
  = Replace [Value]
    -- ^ Do not emit the operation; bind the binding's result names to these
    -- values instead. There must be exactly one value per result name.
  | Skip
    -- ^ Re-emit the operation (operands renamed) without running the rule on
    -- anything nested in its regions.
  | Traverse
    -- ^ Re-emit the operation and keep applying the rule inside its regions.

-- | 'Replace' with exactly one value; works as an expression and a pattern.
pattern ReplaceOne :: Value -> RewriteResult
pattern ReplaceOne val = Replace [val]
-- | A rewrite rule over base monad @m@. It is given an operation whose
-- operands and successors have already been substituted, runs inside the
-- builder of the block being rebuilt (so it may emit operations there), and
-- returns what to do with the operation.
type OpRewriteM m = Operation -> RewriteBuilderT m RewriteResult
-- | A rewrite rule with no effects beyond building ('Identity' base monad).
type OpRewrite = OpRewriteM Identity
-- | Run an action with additional value substitutions in scope.
--
-- The new bindings take precedence over any existing binding for the same
-- name. 'M.union' (the 'Semigroup' instance of 'M.Map') is left-biased, so
-- @upd@ must be the left operand. Otherwise a name re-bound by an inner scope
-- would keep resolving to the outer value. When no names collide this is
-- identical to a plain union.
extendValueMap :: MonadReader BlockAndValueMapping m => ValueMapping -> m a -> m a
extendValueMap upd = local \(vm, bm) -> (upd <> vm, bm)
-- | Run an action with additional block-name substitutions in scope.
--
-- The new bindings take precedence over existing ones for the same name.
-- 'M.union' is left-biased, so @upd@ goes on the left. In MLIR, block labels
-- are local to their region and successors may only target blocks of the same
-- region. A nested region that reuses a label of an enclosing region (e.g.
-- @bb0@) must therefore see its own blocks, not the outer ones.
--
-- @upd@ may be a lazily defined fixpoint (see 'applyOpRewriteRegion'), so it
-- is only stored in the environment here and never inspected.
extendBlockMap :: MonadReader BlockAndValueMapping m => BlockMapping -> m a -> m a
extendBlockMap upd = local \(vm, bm) -> (vm, upd <> bm)
-- | Apply a pure rewrite rule throughout the regions of a closed operation.
--
-- This is 'applyClosedOpRewriteT' specialised to the 'Identity' monad.
applyClosedOpRewrite :: OpRewrite -> AST.Operation -> AST.Operation
applyClosedOpRewrite rule = runIdentity . applyClosedOpRewriteT rule
-- | Apply a rewrite rule throughout the regions of a closed operation.
--
-- "Closed" means rewriting starts from an empty substitution, so the
-- operation must not refer to values or blocks defined outside of it. The
-- rule is applied to the operations nested in its regions, not to the given
-- operation itself. Fresh names come from a name supply evaluated with
-- 'evalNameSupplyT'.
applyClosedOpRewriteT :: MonadFix m => OpRewriteM m -> AST.Operation -> m AST.Operation
applyClosedOpRewriteT rule = evalNameSupplyT . applyOpRewrite rule
-- | Rebuild every region of an operation with the rule applied, starting from
-- an empty value and block substitution. Every other field of the operation
-- is returned unchanged.
applyOpRewrite :: MonadFix m => OpRewriteM m -> AST.Operation -> NameSupplyT m AST.Operation
applyOpRewrite rule op =
  runReaderT (withRegions <$> mapM (applyOpRewriteRegion rule) (opRegions op)) (mempty, mempty)
  where
    withRegions rebuilt = op { opRegions = rebuilt }
-- | Rebuild all blocks of a region, in order, applying the rule to each.
--
-- Successors may name any block of the region, including blocks that are
-- rebuilt later. The old-name to new-name map is therefore tied into a knot
-- with 'mfix': the map produced after every block has been built is the one
-- placed in scope (via 'extendBlockMap') while the blocks are being built.
-- 'substOp' maps successor names lazily, which is what lets this work.
--
-- NOTE(review): a rule that forces the successor names of the operation it
-- receives would demand the unfinished map; depending on the base monad's
-- 'MonadFix' instance this presumably diverges — TODO confirm.
applyOpRewriteRegion :: MonadFix m => OpRewriteM m -> Region -> RewriteT m Region
applyOpRewriteRegion rule (Region blocks) =
  buildRegion . void . mfix $ \finalBlockMap ->
    extendBlockMap finalBlockMap (renameBlocks mempty blocks)
  where
    -- Build the blocks one by one, accumulating old name -> new name.
    renameBlocks renamed [] = return renamed
    renameBlocks renamed (block@(Block oldName _ _) : remaining) = do
      newName <- applyOpRewriteBlock rule block
      renameBlocks (renamed <> M.singleton oldName newName) remaining
-- | Rebuild one block inside the current region builder and return the name
-- of the new block.
--
-- Fresh block arguments are created with the original argument types, and the
-- original argument names are mapped to them. Each binding is then processed
-- in order:
--
-- * its operation has operands and successors substituted ('substOp');
-- * the rule is run on it;
-- * the binding's names are mapped to the values the 'RewriteResult' yields.
--
-- The block is finished with 'terminateBlock'. Calls 'error' if a 'Replace'
-- result carries a different number of values than the binding has names.
applyOpRewriteBlock :: MonadFix m => OpRewriteM m -> Block -> RegionBuilderT (RewriteT m) BlockName
applyOpRewriteBlock rule Block{..} =
  buildBlock $ do
    let (blockArgNames, blockArgTypes) = unzip blockArgs
    newBlockArgs <- mapM blockArgument blockArgTypes
    extendValueMap (M.fromList $ zip blockArgNames newBlockArgs) $ go blockBody
  where
    -- Process bindings in order. The value map is extended after each one so
    -- that later bindings see the values produced for earlier ones.
    go bs = case bs of
      [] -> terminateBlock
      (Bind names astOp : rest) -> do
        op <- substOp astOp
        answer <- rule op
        results <- case answer of
          Replace replacements -> do
            unless (length names == length replacements) $
              error "Rewrite rule returned an incorrect number of values"
            return replacements
          Traverse -> reemitWith rule op
          Skip -> reemitWith (const $ return Skip) op
        extendValueMap (M.fromList $ zip names results) $ go rest

    -- Emit @op@ itself. Its regions are rebuilt with @regionRule@, and its
    -- operands are converted back to names, as 'emitOp' expects. 'Traverse'
    -- passes the user's rule, so nested operations are rewritten too. 'Skip'
    -- passes a rule that always skips, so nested operations are only copied
    -- with their operands substituted.
    reemitWith regionRule op = do
      newRegions <- lift $ mapM (applyOpRewriteRegion regionRule) $ opRegions op
      emitOp $ op { opRegions = newRegions, opOperands = operands (opOperands op) }
-- | Apply the current substitution to an input operation.
--
-- Operand names are replaced by the 'Value's they map to, and successor block
-- names by the names of the rebuilt blocks. Both are resolved lazily, one
-- element at a time. The block map may be the 'mfix' fixpoint built in
-- 'applyOpRewriteRegion', so it must not be forced here.
--
-- Forcing an operand or successor whose name is not in scope raises an
-- 'error' that says whether a value or a block was missing.
substOp :: MonadReader BlockAndValueMapping m => AST.Operation -> m Operation
substOp op = do
  (valueMap, blockMap) <- ask
  let newOperands = fmap (lookupIn "operand value" valueMap) (opOperands op)
      newSuccessors = fmap (lookupIn "successor block" blockMap) (opSuccessors op)
  return $ op { opOperands = newOperands, opSuccessors = newSuccessors }
  where
    -- Like 'M.!', but the failure message identifies what was being resolved.
    lookupIn :: Ord k => String -> M.Map k a -> k -> a
    lookupIn what env key = case M.lookup key env of
      Just found -> found
      Nothing -> error ("substOp: " ++ what ++ " is not in scope")