module MLIR.AST.IStorableArray (IStorableArray, unsafeWithIStorableArray) where

import Data.Ix
import Data.Array.Storable
import Data.Array.Base
import Foreign.Ptr
import Foreign.Storable
import System.IO.Unsafe

-- | A 'StorableArray' wrapped so it can be used through the immutable
-- 'IArray' interface (see the instance below). The constructor is not
-- in this module's export list, so code outside the module can only create
-- values through 'IArray' operations (or inspect them via
-- 'unsafeWithIStorableArray').
newtype IStorableArray i e = UnsafeIStorableArray (StorableArray i e)

-- | Run an 'IO' action with a raw pointer to the array's underlying storage.
--
-- The action is passed directly to 'withStorableArray' on the wrapped
-- array. Callers must treat the memory as read-only. Element reads in the
-- 'IArray' instance happen through 'unsafeDupablePerformIO' whenever a
-- result is forced, so writing through this pointer could make the same
-- "immutable" array yield different values to pure code.
--
-- NOTE(review): presumably the pointer is also only valid for the duration
-- of the action, as with 'withStorableArray' — confirm before letting it
-- escape.
unsafeWithIStorableArray :: IStorableArray i e -> (Ptr e -> IO c) -> IO c
unsafeWithIStorableArray wrapped action =
  case wrapped of
    UnsafeIStorableArray storage -> withStorableArray storage action

-- | Immutable-array interface over a 'StorableArray'.
--
-- Every method performs the underlying 'IO' accessor via
-- 'unsafeDupablePerformIO'. This is sound only as long as the wrapped array
-- is not written to after 'unsafeArray' has filled it. Nothing in this
-- module writes to it afterwards, but 'unsafeWithIStorableArray' exposes a
-- raw pointer through which a caller could.
--
-- The "dupable" variant is acceptable here for two reasons. Each action is
-- either a read of unchanging data or, in 'unsafeArray', the construction
-- of a freshly allocated private array. In both cases, running the same
-- action more than once cannot produce an observable difference.
instance Storable e => IArray IStorableArray e where
  bounds wrapped =
    case wrapped of
      UnsafeIStorableArray storage ->
        unsafeDupablePerformIO (getBounds storage)

  numElements arr = rangeSize (bounds arr)

  -- Writes each (offset, value) pair into a fresh array in list order, so a
  -- later pair for the same offset overwrites an earlier one. Offsets that
  -- never appear in the list are left exactly as 'newArray_' allocated
  -- them. NOTE(review): presumably that is uninitialised memory for a
  -- 'StorableArray' — confirm.
  unsafeArray lohi pairs = unsafeDupablePerformIO $ do
    storage <- newArray_ lohi
    let store (offset, value) = unsafeWrite storage offset value
    mapM_ store pairs
    pure (UnsafeIStorableArray storage)

  -- The read is deferred until the returned element is forced.
  unsafeAt (UnsafeIStorableArray storage) offset =
    unsafeDupablePerformIO (unsafeRead storage offset)

-- | Renders the array with the generic 'showsIArray' helper for 'IArray'
-- instances, forwarding the surrounding precedence unchanged.
instance (Ix i, Show i, Show e, Storable e) => Show (IStorableArray i e) where
  showsPrec prec arr = showsIArray prec arr

-- | Structural equality: same bounds and same element at every offset.
--
-- The bounds are compared first. Because '&&' short-circuits, elements are
-- compared only when the bounds (and therefore the element counts) agree.
-- 'and' stops at the first mismatching element. One consequence: two empty
-- arrays with different bounds compare unequal.
instance (Ix i, Eq e, Storable e) => Eq (IStorableArray i e) where
  a == b =
    bounds a == bounds b
      && and [unsafeAt a k == unsafeAt b k | k <- [0 .. numElements a - 1]]