module Vectorise.Generic.PADict ( buildPADict ) where import Vectorise.Monad import Vectorise.Builtins import Vectorise.Generic.Description import Vectorise.Generic.PAMethods ( buildPAScAndMethods ) import Vectorise.Utils import BasicTypes import CoreSyn import CoreUtils import CoreUnfold import DsMonad import TyCon import Type import Id import Var import Name -- |Build the PA dictionary function for some type and hoist it to top level. -- -- The PA dictionary holds fns that convert values to and from their vectorised representations. -- -- @Recall the definition: -- class class PR (PRepr a) => PA a where -- toPRepr :: a -> PRepr a -- fromPRepr :: PRepr a -> a -- toArrPRepr :: PData a -> PData (PRepr a) -- fromArrPRepr :: PData (PRepr a) -> PData a -- -- Example: -- df :: forall a. PA a -> PA (T a) -- df = /\a. \(d:PA a). MkPA ($PR_df a d) ($toPRepr a d) ... -- $dPR_df :: forall a. PA a -> PR (PRepr (T a)) -- $dPR_df = .... -- $toRepr :: forall a. PA a -> T a -> PRepr (T a) -- $toPRepr = ... -- The "..." stuff is filled in by buildPAScAndMethods -- @ -- buildPADict :: TyCon -- ^ tycon of the type being vectorised. -> TyCon -- ^ tycon of the type used for the vectorised representation. -> TyCon -- ^ PData instance tycon -> TyCon -- ^ PDatas instance tycon -> SumRepr -- ^ representation used for the type being vectorised. -> VM Var -- ^ name of the top-level dictionary function. buildPADict vect_tc prepr_tc pdata_tc pdatas_tc repr = polyAbstract tvs $ \args -> -- The args are the dictionaries we lambda -- abstract over; and they are put in the -- envt, so when we need a (PA a) we can -- find it in the envt do { mod <- liftDs getModuleDs ; let dfun_name = mkLocalisedOccName mod mkPADFunOcc vect_tc_name -- Get ids for each of the methods in the dictionary, including superclass ; paMethodBuilders <- buildPAScAndMethods ; method_ids <- mapM (method args dfun_name) paMethodBuilders -- Expression to build the dictionary. ; pa_dc <- builtin paDataCon ; let dict = mkLams (tvs ++ args) $ mkConApp pa_dc $ Type inst_ty : map (method_call args) method_ids -- Build the type of the dictionary function. ; pa_cls <- builtin paClass ; let dfun_ty = mkForAllTys tvs $ mkFunTys (map varType args) (mkClassPred pa_cls [inst_ty]) -- Set the unfolding for the inliner. ; raw_dfun <- newExportedVar dfun_name dfun_ty ; let dfun_unf = mkDFunUnfolding dfun_ty $ map Var method_ids dfun = raw_dfun `setIdUnfolding` dfun_unf `setInlinePragma` dfunInlinePragma -- Add the new binding to the top-level environment. ; hoistBinding dfun dict ; return dfun } where tvs = tyConTyVars vect_tc arg_tys = mkTyVarTys tvs inst_ty = mkTyConApp vect_tc arg_tys vect_tc_name = getName vect_tc method args dfun_name (name, build) = localV $ do expr <- build vect_tc prepr_tc pdata_tc pdatas_tc repr let body = mkLams (tvs ++ args) expr raw_var <- newExportedVar (method_name dfun_name name) (exprType body) let var = raw_var `setIdUnfolding` mkInlineUnfolding (Just (length args)) body `setInlinePragma` alwaysInlinePragma hoistBinding var body return var method_call args id = mkApps (Var id) (map Type arg_tys ++ map Var args) method_name dfun_name name = mkVarOcc $ occNameString dfun_name ++ ('$' : name)