-- |
-- Module    : Statistics.Matrix.Mutable
-- Copyright : (c) 2014 Bryan O'Sullivan
-- License   : BSD3
--
-- Basic mutable matrix operations.

module Statistics.Matrix.Mutable
    (
      MMatrix(..)
    , MVector
    , replicate
    , thaw
    , bounds
    , unsafeNew
    , unsafeFreeze
    , unsafeRead
    , unsafeWrite
    , unsafeModify
    , immutably
    , unsafeBounds
    ) where

import Control.Applicative ((<$>))
import Control.DeepSeq (NFData(..))
import Control.Monad.ST (ST)
import Statistics.Matrix.Types (Matrix(..), MMatrix(..), MVector)
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as M
import Prelude hiding (replicate)

-- | Allocate a new matrix with every element set to the given value.
--
-- Raises an error if either dimension is negative, mirroring the checks
-- in 'unsafeNew'. Without them a negative dimension would produce an
-- 'MMatrix' whose recorded shape disagrees with its storage (e.g. rows
-- @-2@ and columns @-3@ would allocate @(-2) * (-3) = 6@ elements).
replicate :: Int     -- ^ Number of rows
          -> Int     -- ^ Number of columns
          -> Double  -- ^ Initial value of every element
          -> ST s (MMatrix s)
replicate r c k
  | r < 0     = error "Statistics.Matrix.Mutable.replicate: negative number of rows"
  | c < 0     = error "Statistics.Matrix.Mutable.replicate: negative number of columns"
  | otherwise = MMatrix r c <$> M.replicate (r * c) k

-- | Produce a mutable matrix from an immutable 'Matrix', keeping its
-- row and column counts. The element storage is obtained with 'U.thaw'
-- (the copying variant, per the vector documentation).
thaw :: Matrix -> ST s (MMatrix s)
thaw (Matrix rows cols elems) = do
  mvec <- U.thaw elems
  return (MMatrix rows cols mvec)

-- | Turn a mutable matrix into an immutable 'Matrix' with the same
-- dimensions, using 'U.unsafeFreeze' on the element storage.
--
-- Unsafe: per the vector documentation no copy is made, so the
-- 'MMatrix' should not be modified after freezing.
unsafeFreeze :: MMatrix s -> ST s Matrix
unsafeFreeze (MMatrix rows cols mvec) = do
  frozen <- U.unsafeFreeze mvec
  return (Matrix rows cols frozen)

-- | Allocate a new matrix whose elements are left uninitialised, which
-- is what makes it unsafe: write elements before reading them.
--
-- Raises an error if either dimension is negative (rows are checked
-- first).
unsafeNew :: Int                -- ^ Number of rows
          -> Int                -- ^ Number of columns
          -> ST s (MMatrix s)
unsafeNew r c
  | r < 0     = error "Statistics.Matrix.Mutable.unsafeNew: negative number of rows"
  | c < 0     = error "Statistics.Matrix.Mutable.unsafeNew: negative number of columns"
  | otherwise = MMatrix r c <$> M.new (r * c)

-- | Read the element at the given row and column. Neither the indices
-- nor the resulting flat offset are checked.
unsafeRead :: MMatrix s -> Int -> Int -> ST s Double
unsafeRead mat row col =
  unsafeBounds mat row col (\vec off -> M.unsafeRead vec off)
{-# INLINE unsafeRead #-}

-- | Store a value at the given row and column. Neither the indices nor
-- the resulting flat offset are checked.
unsafeWrite :: MMatrix s -> Int -> Int -> Double -> ST s ()
unsafeWrite mat row col k = unsafeBounds mat row col store
  where
    store vec off = M.unsafeWrite vec off k
{-# INLINE unsafeWrite #-}

-- | Replace the element at the given row and column with the result of
-- applying @f@ to its current value. Neither the indices nor the
-- resulting flat offset are checked.
unsafeModify :: MMatrix s -> Int -> Int -> (Double -> Double) -> ST s ()
unsafeModify mat row col f = unsafeBounds mat row col update
  where
    -- read the old value, then write back its image under f
    update vec off = M.unsafeRead vec off >>= M.unsafeWrite vec off . f
{-# INLINE unsafeModify #-}

-- | Given row and column numbers, calculate the offset into the flat
-- row-major vector (@row * columns + column@) and pass the matrix's
-- storage together with that offset to the continuation.
--
-- The row is checked first, then the column. An out-of-range index
-- raises an error naming this function, the offending index and the
-- dimension it was checked against.
bounds :: MMatrix s -> Int -> Int -> (MVector s -> Int -> r) -> r
bounds (MMatrix rs cs mv) r c k
  | r < 0 || r >= rs = outOfBounds "row" r rs
  | c < 0 || c >= cs = outOfBounds "column" c cs
  | otherwise        = k mv $! r * cs + c
  where
    -- Keeps the original "<what> out of bounds" wording so existing
    -- message matching still works, but adds context for debugging.
    outOfBounds what i n =
      error ("Statistics.Matrix.Mutable.bounds: " ++ what
             ++ " out of bounds (index " ++ show i
             ++ ", size " ++ show n ++ ")")
{-# INLINE bounds #-}

-- | Given row and column numbers, calculate the offset into the flat
-- row-major vector without any range checks, and hand the storage and
-- that offset to the continuation. The offset is forced to WHNF before
-- the continuation is applied.
unsafeBounds :: MMatrix s -> Int -> Int -> (MVector s -> Int -> r) -> r
unsafeBounds (MMatrix _ cols vec) row col k =
  let off = row * cols + col
  in off `seq` k vec off
{-# INLINE unsafeBounds #-}

-- | Apply a pure function to a frozen view of a mutable matrix and fully
-- evaluate the result (via 'rnf') before returning it.
--
-- The view is obtained with 'unsafeFreeze', so no copy is made; forcing
-- the result to normal form makes @f@ finish reading the matrix before
-- control returns to the caller.
-- NOTE(review): if the result itself retains the 'Matrix' (e.g. @f = id@),
-- later writes to the 'MMatrix' may be visible through it — confirm
-- callers never do this.
immutably :: NFData a => MMatrix s -> (Matrix -> a) -> ST s a
immutably mmat f = do
  frozen <- unsafeFreeze mmat
  let result = f frozen
  rnf result `seq` return result
{-# INLINE immutably #-}