-- Copyright 2021 Google LLC
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.


module MLIR.AST.Dialect.Affine where

import Control.Monad.IO.Class
import Control.Monad.Trans.Cont
import qualified Language.C.Inline as C

import qualified MLIR.Native.FFI as Native
import MLIR.AST.Serialize

C.context $ C.baseCtx <> Native.mlirCtx

C.include "mlir-c/AffineExpr.h"
C.include "mlir-c/AffineMap.h"

-- | An affine expression tree. Each constructor is serialized to the matching
-- MLIR C API constructor by the 'FromAST' instance for 'Native.AffineExpr'.
data Expr =
    Dimension Int       -- ^ Dimension identifier at the given position
                        --   (built with @mlirAffineDimExprGet@).
  | Symbol    Int       -- ^ Symbol identifier at the given position
                        --   (built with @mlirAffineSymbolExprGet@).
  | Constant  Int       -- ^ Integer constant, passed to C as @int64_t@.
  | Add       Expr Expr -- ^ Sum of two expressions.
  | Mul       Expr Expr -- ^ Product of two expressions.
  | Mod       Expr Expr -- ^ Left operand modulo the right operand.
  | FloorDiv  Expr Expr -- ^ Left operand divided by the right, rounded down.
  | CeilDiv   Expr Expr -- ^ Left operand divided by the right, rounded up.
  deriving Eq

-- | An affine map: how many dimension and symbol inputs it takes, plus its
-- result expressions. Serialized with @mlirAffineMapGet@ by the 'FromAST'
-- instance for 'Native.AffineMap'.
data Map = Map { mapDimensionCount :: Int    -- ^ Number of dimension inputs.
               , mapSymbolCount    :: Int    -- ^ Number of symbol inputs.
               , mapExprs          :: [Expr] -- ^ Result expressions, in order.
               }
               deriving Eq


-- | Serializes an 'Expr' tree into a native @MlirAffineExpr@ owned by the
-- given MLIR context. Leaves become dimension/symbol/constant expressions;
-- binary nodes serialize their left operand, then their right operand, and
-- combine them with the corresponding @mlirAffine*ExprGet@ C function.
--
-- NOTE(review): indices and constants are converted with 'fromIntegral' and
-- are not range-checked, so a negative 'Dimension'/'Symbol' index is passed
-- straight through to C. Confirm callers never construct one.
instance FromAST Expr Native.AffineExpr where
  fromAST ctx env expr = case expr of
    Dimension idx -> do
      let natIdx = fromIntegral idx
      [C.exp| MlirAffineExpr { mlirAffineDimExprGet($(MlirContext ctx), $(intptr_t natIdx)) } |]
    Symbol idx -> do
      let natIdx = fromIntegral idx
      [C.exp| MlirAffineExpr { mlirAffineSymbolExprGet($(MlirContext ctx), $(intptr_t natIdx)) } |]
    Constant val -> do
      let natVal = fromIntegral val
      [C.exp| MlirAffineExpr { mlirAffineConstantExprGet($(MlirContext ctx), $(int64_t natVal)) } |]
    Add l r -> binary l r $ \natL natR ->
      [C.exp| MlirAffineExpr { mlirAffineAddExprGet($(MlirAffineExpr natL), $(MlirAffineExpr natR)) } |]
    Mul l r -> binary l r $ \natL natR ->
      [C.exp| MlirAffineExpr { mlirAffineMulExprGet($(MlirAffineExpr natL), $(MlirAffineExpr natR)) } |]
    Mod l r -> binary l r $ \natL natR ->
      [C.exp| MlirAffineExpr { mlirAffineModExprGet($(MlirAffineExpr natL), $(MlirAffineExpr natR)) } |]
    FloorDiv l r -> binary l r $ \natL natR ->
      [C.exp| MlirAffineExpr { mlirAffineFloorDivExprGet($(MlirAffineExpr natL), $(MlirAffineExpr natR)) } |]
    CeilDiv l r -> binary l r $ \natL natR ->
      [C.exp| MlirAffineExpr { mlirAffineCeilDivExprGet($(MlirAffineExpr natL), $(MlirAffineExpr natR)) } |]
    where
      -- Serialize both operands (left first, then right) and hand the native
      -- results to the supplied C constructor.
      binary :: Expr
             -> Expr
             -> (Native.AffineExpr -> Native.AffineExpr -> IO Native.AffineExpr)
             -> IO Native.AffineExpr
      binary l r build = do
        natL <- fromAST ctx env l
        natR <- fromAST ctx env r
        build natL natR


-- | Serializes a 'Map' into a native @MlirAffineMap@ in the given context.
-- The result expressions are packed into a temporary native array with
-- 'packFromAST' inside 'ContT'. That array is presumably released when
-- 'evalContT' finishes (TODO confirm against 'packFromAST'). It is then
-- passed to @mlirAffineMapGet@ together with the dimension and symbol counts.
--
-- NOTE(review): 'mapDimensionCount' and 'mapSymbolCount' are converted with
-- 'fromIntegral' and are not validated; negative counts reach C unchanged.
instance FromAST Map Native.AffineMap where
  fromAST ctx env Map{..} = evalContT $ do
    (numExprs, nativeExprs) <- packFromAST ctx env mapExprs
    let nativeDimCount    = fromIntegral mapDimensionCount
        nativeSymbolCount = fromIntegral mapSymbolCount
    liftIO [C.exp| MlirAffineMap {
      mlirAffineMapGet($(MlirContext ctx),
                       $(intptr_t nativeDimCount),
                       $(intptr_t nativeSymbolCount),
                       $(intptr_t numExprs), $(MlirAffineExpr* nativeExprs))
    } |]