{-# LANGUAGE EmptyDataDecls #-}
-----------------------------------------------------------------------------
-- |
-- Module      : Data.Array.Parallel.Unlifted.Distributed.Arrays
-- Copyright   : (c) 2006 Roman Leshchinskiy
-- License     : see libraries/ndp/LICENSE
--
-- Maintainer  : Roman Leshchinskiy
-- Stability   : experimental
-- Portability : non-portable (GHC Extensions)
--
-- Operations on distributed arrays.
--

{-# LANGUAGE CPP #-}

#include "fusion-phases.h"

module Data.Array.Parallel.Unlifted.Distributed.Arrays (
  lengthD, splitLenD, splitLengthD, splitAsD, splitD,
  joinLengthD, joinD, splitJoinD,
  splitSegdD, splitSD,

  permuteD, bpermuteD, atomicUpdateD,

  Distribution, balanced, unbalanced
) where

import Data.Array.Parallel.Base (
  (:*:)(..), fstS, sndS, ST, runST)
import Data.Array.Parallel.Unlifted.Sequential
import Data.Array.Parallel.Unlifted.Distributed.Gang (
  Gang, gangSize, seqGang)
import Data.Array.Parallel.Unlifted.Distributed.DistST (
  stToDistST)
import Data.Array.Parallel.Unlifted.Distributed.Types (
  DT, Dist, indexD, lengthD, newD, writeMD, zipD, unzipD, fstD, sndD,
  elementsUSegdD,
  checkGangD)
import Data.Array.Parallel.Unlifted.Distributed.Basics
import Data.Array.Parallel.Unlifted.Distributed.Combinators (
  mapD, zipWithD, scanD, mapAccumLD,
  zipWithDST_, mapDST_)
import Data.Array.Parallel.Unlifted.Distributed.Scalars (
  sumD)

-- | Prefix a location string with this module's qualified name.  Used to
-- build the messages passed to 'error' and 'checkGangD' in this module.
here :: String -> String
here s = "Data.Array.Parallel.Unlifted.Distributed.Arrays." ++ s

-- | Tag type for the distribution of an array.  It has no constructors; its
-- only inhabitants in this module are the marker values 'balanced' and
-- 'unbalanced'.
data Distribution

-- | Marker value for a balanced distribution.  It must never be evaluated:
-- forcing it raises an error.  'splitD' and 'joinD' ignore the value itself.
-- NOTE(review): NOINLINE presumably keeps the name visible so the RULES in
-- this module can match on it - confirm.
balanced :: Distribution
{-# NOINLINE balanced #-}
balanced = error $ here "balanced: touched"

-- | Marker value for an unbalanced distribution.  It must never be
-- evaluated: forcing it raises an error.
-- NOTE(review): NOINLINE presumably keeps the name visible so the RULES in
-- this module can match on it - confirm.
unbalanced :: Distribution
{-# NOINLINE unbalanced #-}
unbalanced = error $ here "unbalanced: touched"

-- | Distribute the length of an array over a 'Gang'.
--
-- This is 'splitLenD' applied to the length of the array.  The definition is
-- deliberately kept with a single argument on the left-hand side so that the
-- INLINE unfolding applies as soon as the gang is supplied.
splitLengthD :: UA a => Gang -> UArr a -> Dist Int
{-# INLINE splitLengthD #-}
splitLengthD g = splitLenD g . lengthU

-- | Distribute the given array length over a 'Gang'.
splitLenD :: Gang -> Int -> Dist Int
{-# NOINLINE PHASE_DIST splitLenD #-}
splitLenD g !n = newD g (`fill` 0)
  where
    p = gangSize g
    l = n `div` p
    m = n `mod` p
    -- The first @m@ gang members receive @l+1@ elements, the remaining ones
    -- @l@, so the lengths sum to @n@.
    fill md i | i < m     = writeMD md i (l+1) >> fill md (i+1)
              | i < p     = writeMD md i l     >> fill md (i+1)
              | otherwise = return ()

-- | Distribute an array over a 'Gang' such that each thread gets the given
-- number of elements.
--
-- The slice offsets are taken from the first component of a 'scanD' with
-- @(+)@ and @0@ over the requested lengths.  The slicing itself runs on the
-- sequential gang ('seqGang').
splitAsD :: UA a => Gang -> Dist Int -> UArr a -> Dist (UArr a)
{-# INLINE_DIST splitAsD #-}
splitAsD g dlen !arr = zipWithD (seqGang g) (sliceU arr) is dlen
  where
    is = fstS $ scanD g (+) 0 dlen

-- | Distribute an array over a 'Gang'.
--
-- The per-thread lengths come from 'splitLengthD'.  The 'Distribution'
-- argument is ignored by the implementation; it only appears in the RULES
-- of this module.
splitD :: UA a => Gang -> Distribution -> UArr a -> Dist (UArr a)
{-# INLINE_DIST splitD #-}
splitD g _ !arr = zipWithD (seqGang g) (sliceU arr) is dlen
  where
    dlen = splitLengthD g arr
    is   = fstS $ scanD g (+) 0 dlen

-- lengthD reexported from types

-- | Overall length of a distributed array.
joinLengthD :: UA a => Gang -> Dist (UArr a) -> Int
{-# INLINE joinLengthD #-}
joinLengthD g = sumD g . lengthD

-- | Join a distributed array.
--
-- After 'checkGangD' has checked the gang against the distributed array, a
-- new array of the total length is created and each chunk is copied into it
-- at its offset.  Offsets and total length both come from one 'scanD' over
-- the chunk lengths.  The 'Distribution' argument is ignored by the
-- implementation.
joinD :: UA a => Gang -> Distribution -> Dist (UArr a) -> UArr a
{-# INLINE_DIST joinD #-}
joinD g _ !darr = checkGangD (here "joinD") g darr $
                  newU n (\ma -> zipWithDST_ g (copy ma) di darr)
  where
    di :*: n = scanD g (+) 0 $ lengthD darr
    --
    copy ma i arr = stToDistST (copyMU ma i arr)

-- | Split an array over a 'Gang', apply a function to the distributed array
-- and join the result again.  Both the split and the join use 'unbalanced'.
splitJoinD :: (UA a, UA b)
           => Gang -> (Dist (UArr a) -> Dist (UArr b)) -> UArr a -> UArr b
{-# INLINE_DIST splitJoinD #-}
splitJoinD g f !xs = joinD g unbalanced (f (splitD g unbalanced xs))

-- | Join a distributed array, yielding a mutable global array
--
-- Same copying scheme as 'joinD', but the freshly filled mutable array is
-- returned in 'ST' instead of being frozen.  Not exported; used by
-- 'atomicUpdateD'.
joinDM :: UA a => Gang -> Dist (UArr a) -> ST s (MUArr a s)
{-# INLINE joinDM #-}
joinDM g darr = checkGangD (here "joinDM") g darr $
                do
                  marr <- newMU n
                  zipWithDST_ g (copy marr) di darr
                  return marr
  where
    di :*: n = scanD g (+) 0 $ lengthD darr
    --
    copy ma i arr = stToDistST (copyMU ma i arr)

-- Rewrite rules for split/join fusion:
--
--   * splitting the result of a join gives back the distributed array;
--   * a split or join adjacent to a 'splitJoinD' is absorbed into it;
--   * consecutive 'splitJoinD's are merged into one;
--   * 'fstU' / 'sndU' are pushed inside 'joinD' and 'splitJoinD'.
--
-- The rules for 'zipU' are commented out.
{-# RULES

"splitD[unbalanced]/joinD" forall g b da.
  splitD g unbalanced (joinD g b da) = da

"splitD[balanced]/joinD" forall g da.
  splitD g balanced (joinD g balanced da) = da

"splitD/splitJoinD" forall g b f xs.
  splitD g b (splitJoinD g f xs) = f (splitD g b xs)

"splitJoinD/joinD" forall g b f da.
  splitJoinD g f (joinD g b da) = joinD g b (f da)

"splitJoinD/splitJoinD" forall g f1 f2 xs.
  splitJoinD g f1 (splitJoinD g f2 xs) = splitJoinD g (f1 . f2) xs

{-
"splitD/zipU" forall g b xs ys.
  splitD g b (zipU xs ys) = zipWithD g zipU (splitD g balanced xs)
                                            (splitD g balanced ys)

"splitJoinD/zipU" forall g f xs ys.
  splitJoinD g f (zipU xs ys)
    = joinD g balanced
        (f (zipWithD g zipU (splitD g balanced xs)
                            (splitD g balanced ys)))

"splitAsD/zipU" forall g dlen xs ys.
  splitAsD g dlen (zipU xs ys)
    = zipWithD g zipU (splitAsD g dlen xs)
                      (splitAsD g dlen ys)
-}

"fstU/joinD" forall g b xs.
  fstU (joinD g b xs) = joinD g b (mapD g fstU xs)

"sndU/joinD" forall g b xs.
  sndU (joinD g b xs) = joinD g b (mapD g sndU xs)

"fstU/splitJoinD" forall g f xs.
  fstU (splitJoinD g f xs) = splitJoinD g (mapD g fstU . f) xs

"sndU/splitJoinD" forall g f xs.
  sndU (splitJoinD g f xs) = splitJoinD g (mapD g sndU . f) xs

  #-}

-- | Permute for distributed arrays.
--
-- The result is a new global array whose length is the joined length of the
-- distributed input.  Each gang member applies 'permuteMU' to its chunk of
-- values and its chunk of indices.
permuteD :: UA a => Gang -> Dist (UArr a) -> Dist (UArr Int) -> UArr a
{-# INLINE_DIST permuteD #-}
permuteD g darr dis = newU n (\ma -> zipWithDST_ g (permute ma) darr dis)
  where
    n = joinLengthD g darr
    --
    permute ma arr is = stToDistST (permuteMU ma arr is)

-- | Backpermute: each gang member applies 'bpermuteU' with the global source
-- array to its own chunk of indices.
--
-- NOTE: The bang is necessary because the array must be fully evaluated
-- before we pass it to the parallel computation.
bpermuteD :: UA a => Gang -> UArr a -> Dist (UArr Int) -> Dist (UArr a)
{-# INLINE bpermuteD #-}
bpermuteD g !as ds = mapD g (bpermuteU as) ds

-- NB: This does not (and cannot) try to prevent two threads from writing to
-- the same position. We probably want to consider this an (unchecked) user
-- error.
-- Joins the distributed array into one mutable global array ('joinDM'), lets
-- every gang member apply its chunk of (index, value) updates with
-- 'atomicUpdateMU', and finally freezes the array with 'unsafeFreezeAllMU'.
atomicUpdateD :: UA a
              => Gang -> Dist (UArr a) -> Dist (UArr (Int :*: a)) -> UArr a
{-# INLINE atomicUpdateD #-}
atomicUpdateD g darr upd = runST (do
    marr <- joinDM g darr
    mapDST_ g (update marr) upd
    unsafeFreezeAllMU marr
  )
  where
    update marr arr = stToDistST (atomicUpdateMU marr arr)

-- | Distribute a segment descriptor over a 'Gang'.
--
-- The total element count is first split with 'splitLenD' to give every
-- thread an element budget.  'mapAccumLD' then walks the segment lengths
-- with a running segment index and turns each budget into a number of whole
-- segments (see @go@).  The segment lengths are split accordingly with
-- 'splitAsD' and each chunk is converted back with 'lengthsToUSegd'.
splitSegdD :: Gang -> USegd -> Dist USegd
{-# NOINLINE splitSegdD #-}
splitSegdD g !segd = mapD g lengthsToUSegd
                   $ splitAsD g d lens
  where
    -- Number of segments assigned to each thread.
    d = sndS
      . mapAccumLD g chunk 0
      . splitLenD g
      $ elementsUSegd segd

    n    = lengthUSegd segd
    lens = lengthsUSegd segd

    -- Starting at segment @i@ with an element budget of @k@, yield the index
    -- of the first segment left over, paired with the number of segments
    -- consumed.
    chunk i k = let j = go i k
                in  j :*: (j-i)

    -- Advance over segments until the budget is used up or the segments run
    -- out.  The empty-segment guard comes before the budget check, so empty
    -- segments are taken even when the budget is already exhausted.
    -- Segments are only ever taken whole.
    go i k | i >= n    = i
           | m == 0    = go (i+1) k
           | k <= 0    = i
           | otherwise = go (i+1) (k-m)
      where
        m = lens !: i

-- | Join a distributed segment descriptor: the per-thread segment lengths
-- are joined with 'joinD' and converted back with 'lengthsToUSegd'.
-- This function is not in the module's export list.
joinSegD :: Gang -> Dist USegd -> USegd
{-# INLINE_DIST joinSegD #-}
joinSegD g = lengthsToUSegd
           . joinD g unbalanced
           . mapD (seqGang g) lengthsUSegd

-- | Distribute a flat data array according to a distributed segment
-- descriptor: each thread receives as many elements as 'elementsUSegdD'
-- reports for its part of the descriptor.
splitSD :: UA a => Gang -> Dist USegd -> UArr a -> Dist (UArr a)
{-# INLINE_DIST splitSD #-}
splitSD g dsegd xs = splitAsD g (elementsUSegdD dsegd) xs

-- Rewrite rules pushing 'splitSD' through 'splitJoinD' and 'zipU'.
{-# RULES

"splitSD/splitJoinD" forall g d f xs.
  splitSD g d (splitJoinD g f xs) = f (splitSD g d xs)

"splitSD/zipU" forall g d xs ys.
  splitSD g d (zipU xs ys) = zipWithD g zipU (splitSD g d xs)
                                             (splitSD g d ys)

  #-}

{- RULES

"splitSD[unbalanced]/joinSD" forall g b da.
  splitSD g unbalanced (joinSD g b da) = da

"splitSD[balanced]/joinSD" forall g da.
  splitSD g balanced (joinSD g balanced da) = da

"splitSD/splitJoinSD" forall g b f xs.
  splitSD g b (splitJoinSD g f xs) = f (splitSD g b xs)

"splitJoinSD/joinSD" forall g b f da.
  splitJoinSD g f (joinSD g b da) = joinSD g b (f da)

"splitJoinSD/splitJoinSD" forall g f1 f2 xs.
  splitJoinSD g f1 (splitJoinSD g f2 xs) = splitJoinSD g (f1 . f2) xs

"fstSU/joinSD" forall g b xs.
  fstSU (joinSD g b xs) = joinSD g b (mapD g fstSU xs)

"sndSU/joinSD" forall g b xs.
  sndSU (joinSD g b xs) = joinSD g b (mapD g sndSU xs)

"splitSD/SUArr" forall g b segd xss.
  splitSD g b (SUArr segd xss) = splitSD' g b segd xss

  -}