-- Copyright 2021 Google LLC
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE UndecidableInstances #-}
module MLIR.AST.Builder where

import MLIR.AST
import Data.String
import Data.Functor
import Control.Monad
import Control.Monad.State.Strict
import Control.Monad.Writer
import Control.Monad.Reader

--------------------------------------------------------------------------------
-- Value

-- | A 'Name' paired with a 'Type'. Use 'typeOf' and 'operand' to project the
-- two components.
data Value = Name :> Type

-- | Project the 'Type' component out of a 'Value'.
typeOf :: Value -> Type
typeOf (_ :> ty) = ty

-- | Project the 'Name' component out of a 'Value'.
operand :: Value -> Name
operand (n :> _) = n

-- | Project the 'Name' of every 'Value' in the list, preserving order.
operands :: [Value] -> [Name]
operands = fmap operand

--------------------------------------------------------------------------------
-- Name supply monad

-- | State carried by 'NameSupplyT': the integer that the next call to
-- 'freshName' will render into a 'Name'.
newtype NameSupply = NameSupply { nextName :: Int }
-- | Monad transformer that threads a 'NameSupply' through a computation.
--
-- The 'NameSupply' state is private to the transformer: 'MonadReader' and
-- 'MonadWriter' are derived from the underlying 'StateT', while 'MonadState'
-- is provided by a hand-written instance that forwards to the base monad.
newtype NameSupplyT m a = NameSupplyT (StateT NameSupply m a)
                          deriving (Functor, Applicative, Monad,
                                    MonadTrans, MonadFix,
                                    MonadReader r, MonadWriter w)

-- | 'get' and 'put' are lifted to the underlying monad, so they act on the
-- base monad's state @s@ and never touch the internal 'NameSupply'.
instance MonadState s m => MonadState s (NameSupplyT m) where
  get = lift get
  put = lift . put

-- | Monads that can hand out 'Name's.
class Monad m => MonadNameSupply m where
  -- | Produce a 'Name'. The 'NameSupplyT' instance renders an integer counter
  -- and increments it on every call.
  freshName :: m Name

-- | A 'ReaderT' layer adds nothing to name generation: 'freshName' is lifted
-- from the underlying monad.
instance MonadNameSupply m => MonadNameSupply (ReaderT r m) where
  freshName = lift freshName

-- | Run a 'NameSupplyT' computation with the counter starting at @0@,
-- returning only the computation's result (the final counter is discarded).
evalNameSupplyT :: Monad m => NameSupplyT m a -> m a
evalNameSupplyT (NameSupplyT a) = evalStateT a $ NameSupply 0

-- | 'freshName' reads the current counter, increments it, and returns the
-- value that was read, rendered with 'show' and converted via 'fromString'.
instance Monad m => MonadNameSupply (NameSupplyT m) where
  freshName = NameSupplyT $ do
    curId <- gets nextName
    modify $ \s -> s { nextName = nextName s + 1 }
    return $ fromString $ show curId

-- | Draw a name with 'freshName' and pair it with the given 'Type'.
freshValue :: MonadNameSupply m => Type -> m Value
freshValue ty = freshName <&> (:> ty)

-- | Like 'freshValue', but the drawn name is prefixed with @"arg"@ before it
-- is paired with the given 'Type'.
freshBlockArg :: MonadNameSupply m => Type -> m Value
freshBlockArg ty = (("arg" <>) <$> freshName) <&> (:> ty)

--------------------------------------------------------------------------------
-- Block builder monad

-- TODO(apaszke): Thread locations through
-- TODO(apaszke): Use a writer monad
-- | State accumulated by 'BlockBuilderT' while a block is being built.
data BlockBindings = BlockBindings
  { blockBindings :: SnocList Binding
    -- ^ Bindings emitted so far; 'emitOp' and 'emitOp_' append to the end.
  , blockArguments :: SnocList Value
    -- ^ Arguments declared so far; 'blockArgument' appends to the end.
  , blockDefaultLocation :: Location
    -- ^ Location that 'emitOp' substitutes for 'UnknownLocation'; updated by
    -- 'setDefaultLocation'.
  }

-- | Concatenates bindings and arguments in order. For the default location
-- the right operand wins unless it is 'UnknownLocation', in which case the
-- left operand's location is kept.
--
-- Keeping the left location in that case is what makes 'mempty' (whose
-- location is 'UnknownLocation') a right identity: unconditionally taking the
-- right location would make @x <> mempty@ forget @x@'s default location.
instance Semigroup BlockBindings where
  BlockBindings bs args loc <> BlockBindings bs' args' loc' =
    BlockBindings (bs <> bs') (args <> args') mergedLoc
    where
      -- "Last known location" semantics: associative, with UnknownLocation
      -- acting as the neutral element on both sides.
      mergedLoc = case loc' of
        UnknownLocation -> loc
        _               -> loc'

-- | The empty block state: no bindings, no arguments, and 'UnknownLocation'
-- as the default location.
instance Monoid BlockBindings where
  mempty = BlockBindings mempty mempty UnknownLocation

-- | Monad transformer that accumulates 'BlockBindings' while a block is
-- being built.
--
-- The 'BlockBindings' state is private to the transformer: 'MonadReader' and
-- 'MonadWriter' are derived from the underlying 'StateT', while 'MonadState'
-- is provided by a hand-written instance that forwards to the base monad.
newtype BlockBuilderT m a = BlockBuilderT (StateT BlockBindings m a)
                            deriving (Functor, Applicative, Monad,
                                      MonadTrans, MonadFix,
                                      MonadReader r, MonadWriter w)

-- | 'get' and 'put' are lifted to the underlying monad, so they act on the
-- base monad's state @s@ and never touch the internal 'BlockBindings'.
instance MonadState s m => MonadState s (BlockBuilderT m) where
  get = lift get
  put = lift . put

-- | Monads in which an 'Operation' can be emitted without producing values.
class Monad m => MonadBlockDecl m where
  -- | Emit an operation and return nothing. The 'BlockBuilderT' instance
  -- only accepts operations whose result types are an explicit empty list.
  emitOp_ :: Operation -> m ()
-- | Monads in which a block can be built: operations can be emitted for
-- their result values, arguments declared, and a default location set.
class MonadBlockDecl m => MonadBlockBuilder m where
  -- | Emit an operation and return the 'Value's standing for its results.
  emitOp :: Operation -> m [Value]
  -- | Declare an argument of the given 'Type' and return its 'Value'.
  blockArgument :: Type -> m Value
  -- | Set the default 'Location'. In the 'BlockBuilderT' instance, 'emitOp'
  -- uses it for operations whose own location is 'UnknownLocation'.
  setDefaultLocation :: Location -> m ()

-- | Single-constructor marker type. It is what 'terminateBlock' and
-- 'noTerminator' return, and what a block builder passed to 'appendBlock'
-- must produce as its result.
data EndOfBlock = EndOfBlock

-- | Return the 'EndOfBlock' marker. Performs no effects in @m@; it is
-- implemented identically to 'noTerminator' and differs only by name.
terminateBlock :: Monad m => m EndOfBlock
terminateBlock = return EndOfBlock

-- | Return the 'EndOfBlock' marker. Performs no effects in @m@; it is
-- implemented identically to 'terminateBlock' and differs only by name.
noTerminator :: Monad m => m EndOfBlock
noTerminator = return EndOfBlock

-- | Run a block builder starting from the empty 'BlockBindings'.
--
-- Returns the builder's result together with the declared arguments and the
-- emitted bindings, both converted to lists with 'unsnocList'. The default
-- location held in the final state is discarded.
runBlockBuilder :: Monad m => BlockBuilderT m a -> m (a, ([Value], [Binding]))
runBlockBuilder (BlockBuilderT act) = do
  (result, BlockBindings binds args _) <- runStateT act mempty
  return (result, (unsnocList args, unsnocList binds))

-- | 'emitOp_' appends a @Do op@ binding to the block.
--
-- Only operations with an explicit, empty result-type list are accepted;
-- inferred result types and non-empty result lists both call 'error'.
--
-- NOTE(review): unlike 'emitOp', this does not substitute
-- 'blockDefaultLocation' for an 'UnknownLocation' on the operation — confirm
-- whether that asymmetry is intentional.
instance Monad m => MonadBlockDecl (BlockBuilderT m) where
  emitOp_ op = BlockBuilderT $
    case opResultTypes op of
      Inferred    -> error "Builder doesn't support inferred result types!"
      Explicit [] -> modify $ \s -> s { blockBindings = blockBindings s .:. (Do op) }
      Explicit _  -> error "emitOp_ can only be used on ops that have no results"

-- | Block building on top of any monad that can supply names.
--
-- * 'emitOp' replaces an 'UnknownLocation' on the operation with the current
--   default location, draws one 'freshValue' per explicit result type,
--   appends a @names ::= op@ binding and returns the new values. Inferred
--   result types call 'error'.
-- * 'blockArgument' draws a 'freshValue' of the given type, appends it to the
--   block's arguments and returns it.
-- * 'setDefaultLocation' overwrites the default location used by 'emitOp'.
instance MonadNameSupply m => MonadBlockBuilder (BlockBuilderT m) where
  emitOp opNoLoc = BlockBuilderT $ do
    loc <- gets blockDefaultLocation
    -- Only fill in the location when the caller did not provide one.
    let op = case opLocation opNoLoc of
          UnknownLocation -> opNoLoc { opLocation = loc }
          _ -> opNoLoc
    results <- case opResultTypes op of
      Inferred     -> error "Builder doesn't support inferred result types!"
      Explicit tys -> lift $ mapM freshValue tys
    modify $ \s -> s { blockBindings = blockBindings s .:. (operands results ::= op) }
    return results
  blockArgument ty = BlockBuilderT $ do
    value <- lift $ freshValue ty
    modify $ \s -> s { blockArguments = blockArguments s .:. value }
    return value
  setDefaultLocation loc = BlockBuilderT $ modify $ \s -> s { blockDefaultLocation = loc }

--------------------------------------------------------------------------------
-- Region builder monad

-- TODO(apaszke): Make this a writer, assign block names only at the very end
-- | State accumulated by 'RegionBuilderT' while a region is being built.
data RegionBuilderState = RegionBuilderState
  { blocks      :: SnocList Block
    -- ^ Blocks added so far; 'buildRegion' turns them into the final list.
  , nextBlockId :: Int
    -- ^ Integer counter, initialised to @0@ by 'buildRegion'. Presumably
    -- used to number blocks as they are appended — confirm in 'buildBlock'.
  }
-- | Monad transformer that accumulates 'RegionBuilderState' while a region
-- is being built.
--
-- The 'RegionBuilderState' is private to the transformer: 'MonadReader' and
-- 'MonadWriter' are derived from the underlying 'StateT', while 'MonadState'
-- is provided by a hand-written instance that forwards to the base monad.
newtype RegionBuilderT m a = RegionBuilderT (StateT RegionBuilderState m a)
                             deriving (Functor, Applicative, Monad,
                                       MonadTrans, MonadFix,
                                       MonadReader r, MonadWriter w)

-- | 'get' and 'put' are lifted to the underlying monad, so they act on the
-- base monad's state @s@ and never touch the internal 'RegionBuilderState'.
instance MonadState s m => MonadState s (RegionBuilderT m) where
  get = lift get
  put = lift . put

-- | Name identifying a block, as returned by 'buildBlock' and 'appendBlock'.
type BlockName = Name

-- | Monads that accept a block-building action and hand back a block name.
--
-- NOTE(review): no instance is visible in this part of the file; presumably
-- 'appendBlock' adds the built block to the region under construction --
-- confirm against the instances.
class Monad m => MonadRegionBuilder m where
  -- | Takes a block builder whose result is 'EndOfBlock' and yields the
  -- 'BlockName' associated with it.
  appendBlock :: BlockBuilderT m EndOfBlock -> m BlockName

-- | An action that does nothing and returns @()@.
--
-- NOTE(review): presumably meant as a readable final statement of a
-- region-building @do@ block -- confirm against call sites.
endOfRegion :: Monad m => m ()
endOfRegion = pure ()

-- | Run a region builder and collect the blocks it produced into a 'Region'.
--
-- The builder starts with no blocks and a block counter of 0. Blocks are
-- accumulated in a 'SnocList', so 'unsnocList' returns them in the order in
-- which they were appended.
buildRegion :: Monad m => RegionBuilderT m () -> m Region
buildRegion (RegionBuilderT builder) = do
  finalState <- execStateT builder initialState
  pure (Region (unsnocList (blocks finalState)))
  where
    initialState = RegionBuilderState { blocks = mempty, nextBlockId = 0 }

-- | Build one block and append it to the region under construction.
--
-- The block builder is run in the underlying monad @m@. Its arguments and
-- bindings are wrapped in a 'Block' named @bb\<N\>@, where @N@ is the current
-- value of the region's block counter; the counter is then incremented.
-- Returns the name given to the new block.
buildBlock :: Monad m => BlockBuilderT m EndOfBlock -> RegionBuilderT m BlockName
buildBlock builder = RegionBuilderT $ do
  (EndOfBlock, (args, body)) <- lift (runBlockBuilder builder)
  blockId <- gets nextBlockId
  let blockName = "bb" <> fromString (show blockId)
      -- Block arguments are stored as plain (name, type) pairs.
      block = Block blockName [(n, t) | n :> t <- args] body
  modify $ \s -> s { blocks = blocks s .:. block
                   , nextBlockId = blockId + 1
                   }
  return blockName

--------------------------------------------------------------------------------
-- Builtin dialect

-- | Run a block builder and wrap its arguments and bindings in a single
-- 'Block' named @"0"@.
soleBlock :: Monad m => BlockBuilderT m EndOfBlock -> m Block
soleBlock builder = do
  (EndOfBlock, (args, body)) <- runBlockBuilder builder
  -- Block arguments are stored as plain (name, type) pairs.
  pure (Block "0" [(n, t) | n :> t <- args] body)

-- | Build a 'ModuleOp' whose body is the single block produced by running
-- @build@ followed by 'noTerminator'.
buildModule :: Monad m => BlockBuilderT m () -> m Operation
buildModule build = ModuleOp <$> soleBlock (build >> noTerminator)

-- | Emit a 'FuncOp' with the given name and function type, an
-- 'UnknownLocation', and an empty region (no blocks), i.e. a function with no
-- body.
declareFunction :: MonadBlockDecl m => Name -> Type -> m ()
declareFunction name funcTy =
  emitOp_ (FuncOp UnknownLocation name funcTy (Region []))

-- | Build a function from a region builder and emit it as a 'FuncOp'.
--
-- The body is built with a fresh name supply. The function's argument types
-- are taken from the arguments of the first block of the built region, and
-- @retTypes@ become its result types. @attrs@ are appended (with '<>') to the
-- attributes the 'FuncOp' already carries.
--
-- Calls 'error' if the region builder produces no blocks; use
-- 'declareFunction' for functions without a body.
buildFunction :: MonadBlockDecl m
              => Name -> [Type] -> NamedAttributes
              -> RegionBuilderT (NameSupplyT m) () -> m ()
buildFunction name retTypes attrs bodyBuilder = do
  -- The block list is bound as 'bodyBlocks' so it does not shadow the
  -- 'blocks' field selector of 'RegionBuilderState'.
  body@(Region bodyBlocks) <- evalNameSupplyT (buildRegion bodyBuilder)
  let argTypes = case bodyBlocks of
        Block _ entryArgs _ : _ -> fmap snd entryArgs
        [] -> error $ "buildFunction cannot be used for function declarations! " ++
                      "Build at least one block!"
      op = FuncOp UnknownLocation name (FunctionType argTypes retTypes) body
  emitOp_ (op { opAttributes = opAttributes op <> attrs })

-- | Build a single-block function and emit it as a 'FuncOp'.
--
-- The block is built with 'soleBlock' under a fresh name supply. The
-- function's argument types are the types of that block's arguments, and
-- @retTypes@ become its result types. @attrs@ are appended (with '<>') to the
-- attributes the 'FuncOp' already carries.
buildSimpleFunction :: MonadBlockDecl m
                    => Name -> [Type] -> NamedAttributes
                    -> BlockBuilderT (NameSupplyT m) EndOfBlock -> m ()
buildSimpleFunction name retTypes attrs bodyBuilder = do
  block <- evalNameSupplyT (soleBlock bodyBuilder)
  let fTy = FunctionType (snd <$> blockArgs block) retTypes
      op = FuncOp UnknownLocation name fTy (Region [block])
  emitOp_ (op { opAttributes = opAttributes op <> attrs })

--------------------------------------------------------------------------------
-- Utilities

-- | A list that is extended at its end. The wrapped list holds the elements
-- in reverse order (most recently appended first): '.:.' conses onto it and
-- 'unsnocList' reverses it to recover append order.
newtype SnocList a = SnocList [a]

-- | Append one element to the end of a 'SnocList'. Implemented as a cons
-- onto the reversed underlying list.
(.:.) :: SnocList a -> a -> SnocList a
SnocList t .:. h = SnocList (h : t)

-- | Convert a 'SnocList' to an ordinary list, with elements in the order in
-- which they were appended (the stored list is reversed).
unsnocList :: SnocList a -> [a]
unsnocList (SnocList l) = reverse l

-- | Concatenation in append order: the elements of the right operand follow
-- those of the left. Because the underlying lists are stored reversed, the
-- right operand's list is placed in front.
instance Semigroup (SnocList a) where
  SnocList l <> SnocList r = SnocList (r ++ l)

-- | The empty 'SnocList' is the identity for '<>'.
instance Monoid (SnocList a) where
  mempty = SnocList []