diff --git a/package.yaml b/package.yaml index da24a23..85b2236 100644 --- a/package.yaml +++ b/package.yaml @@ -41,10 +41,13 @@ dependencies: - prettyprinter - unification-fd - mtl +- monad-supply - containers - bytestring - optparse-applicative - safe-exceptions +- vinyl +- singletons ghc-options: - -Wall @@ -86,3 +89,5 @@ tests: - -with-rtsopts=-N dependencies: - simpl-lang + - tasty + - tasty-hunit diff --git a/src/Simpl/Ast.hs b/src/Simpl/Ast.hs index 4a32651..9aa5f1c 100644 --- a/src/Simpl/Ast.hs +++ b/src/Simpl/Ast.hs @@ -243,6 +243,11 @@ isComplexType = \case TyFun _ _ -> True _ -> False +functionTypeResult :: Type -> Type +functionTypeResult (Fix ty) = case ty of + TyFun _ res -> functionTypeResult res + _ -> Fix ty + instance Pretty Type where pretty = para go where diff --git a/src/Simpl/AstToJoinIR.hs b/src/Simpl/AstToJoinIR.hs new file mode 100644 index 0000000..66db6df --- /dev/null +++ b/src/Simpl/AstToJoinIR.hs @@ -0,0 +1,191 @@ +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-| +Module : Simpl.AstToJoinIR +Description : Provides a function to normalize SimPL AST, transforming it into + JoinIR. +-} +module Simpl.AstToJoinIR + ( astToJoinIR + ) where + +import Control.Monad.Supply +import Control.Monad.Reader hiding (guard) +import Data.Functor.Foldable (Fix(..), unfix) +import Data.Functor.Identity +import Data.Text (Text) +import Data.String (fromString) + +import Simpl.Ast (Type) +import Simpl.SymbolTable +import qualified Simpl.Ast as A +import qualified Simpl.JoinIR.Syntax as J + +-- * Public API + +astToJoinIR :: SymbolTable (A.AnnExpr Type) -> SymbolTable (J.AnnExpr '[ 'J.ExprType]) +astToJoinIR = runTransform transformTable + +-- * Transformation Monad + +newtype TransformT m a = + TransformT { unTransform :: ReaderT (SymbolTable (A.AnnExpr Type)) (SupplyT Int m) a } + deriving ( Functor + , Applicative + , Monad + , MonadReader (SymbolTable (A.AnnExpr Type)) + , MonadFreshVar) + +type Transform = TransformT Identity + +type MonadFreshVar = MonadSupply Int + +varSupply :: [Int] +varSupply = [0..] + +runTransformT :: Monad m => TransformT m a -> SymbolTable (A.AnnExpr Type) -> m a +runTransformT m table + = fmap fst + . flip runSupplyT varSupply + . flip runReaderT table + . unTransform + $ m + +runTransform :: Transform a -> SymbolTable (A.AnnExpr Type) -> a +runTransform m table = runIdentity (runTransformT m table) + +-- | Generates a fresh name using the given prefix and lookup function +freshName :: (MonadReader (SymbolTable (A.AnnExpr Type)) m, MonadFreshVar m) + => Text -- ^ Prefix + -> (Text -> SymbolTable (A.AnnExpr Type) -> Maybe a) -- ^ Lookup function + -> m Text +freshName prefix lookupFun = do + next <- (prefix <>) . fromString . show <$> supply + asks (lookupFun next) >>= \case + Nothing -> pure next + Just _ -> freshName prefix lookupFun + +freshVar, freshLabel :: (MonadReader (SymbolTable (A.AnnExpr Type)) m, MonadFreshVar m) => m Text +-- | Generate a fresh variable name +freshVar = freshName "var" symTabLookupVar +-- | Generate a fresh join label +freshLabel = freshName "join" symTabLookupFun + +-- * Private utility functions + +makeJexpr :: Type + -> J.JExprF (J.AnnExpr '[ 'J.ExprType]) + -> J.AnnExpr '[ 'J.ExprType] +makeJexpr ty = Fix . J.addField (J.withType ty) . J.toAnnExprF + +astType :: A.AnnExpr Type -> Type +astType = A.annGetAnn . unfix + +-- * ANF Transformation + +-- | Perform ANF transformation on the given symbol table +transformTable :: (MonadReader (SymbolTable (A.AnnExpr Type)) m, MonadFreshVar m) + => m (SymbolTable (J.AnnExpr '[ 'J.ExprType])) +transformTable = do + table <- ask + symTabTraverseExprs (\(args, ty, expr) -> (args, ty, transformExpr expr)) table + +-- | Perform ANF transformation on the given expression +transformExpr :: (MonadReader (SymbolTable (A.AnnExpr Type)) m, MonadFreshVar m) + => A.AnnExpr Type + -> m (J.AnnExpr '[ 'J.ExprType]) +transformExpr expr = anfTransform expr (pure . makeJexpr (astType expr) . J.JVal) + +-- | Perform ANF transformation on the branch, afterwards handling control flow. +transformBranch :: (MonadReader (SymbolTable (A.AnnExpr Type)) m, MonadFreshVar m) + => J.ControlFlow (J.AnnExpr '[ 'J.ExprType]) -- ^ Control flow handler + -> A.Branch (A.AnnExpr Type) -- ^ Branches + -> m (J.JBranch (J.AnnExpr '[ 'J.ExprType])) +transformBranch cf (A.BrAdt adtName argNames expr) = do + jexpr <- anfTransform expr (pure . makeJexpr (astType expr) . J.JVal) + pure $ J.BrAdt adtName argNames (J.Cfe jexpr cf) + + +-- | Main ANF transformation logic. Given the SimPL AST, this function will +-- normalize the AST, and then it will feed the final JValue into the given +-- continuation to produce the resulting JoinIR AST. +anfTransform :: (MonadReader (SymbolTable (A.AnnExpr Type)) m, MonadFreshVar m) + => A.AnnExpr Type -- ^ Expression to translate + -> (J.JValue -> m (J.AnnExpr '[ 'J.ExprType])) -- ^ Continuation + -> m (J.AnnExpr '[ 'J.ExprType]) +anfTransform (Fix (A.AnnExprF ty exprf)) cont = case exprf of + A.Lit lit -> cont (J.JLit lit) + A.Var name -> cont (J.JVar name) + A.Let name bindExpr next -> + anfTransform bindExpr $ \bindVal -> + makeJexpr (A.annGetAnn (unfix bindExpr)) . J.JLet name bindVal <$> + local (symTabInsertVar name ty) (anfTransform next cont) + A.BinOp op left right -> + anfTransform left $ \jleft -> + anfTransform right $ \jright -> do + name <- freshVar + makeJexpr ty . J.JApp name (J.CBinOp op) [jleft, jright] <$> + local (symTabInsertVar name ty) (cont (J.JVar name)) + A.If guard trueBr falseBr -> + anfTransform guard $ \jguard -> do + lbl <- freshLabel + trueBr' <- anfTransform trueBr (pure . makeJexpr (astType trueBr) . J.JVal) + falseBr' <- anfTransform falseBr (pure . makeJexpr (astType falseBr) . J.JVal) + name <- freshVar + let jmp = J.JJump lbl + let guardTy = A.annGetAnn (unfix guard) + let guardCfe = makeJexpr guardTy (J.JVal jguard) + let cfe = J.Cfe guardCfe (J.JIf (J.Cfe trueBr' jmp) (J.Cfe falseBr' jmp)) + -- TODO: Make JJoin node placement more efficient + makeJexpr ty . J.JJoin lbl name cfe <$> + local (symTabInsertVar name ty) (cont (J.JVar name)) + A.Case branches expr -> + anfTransform expr $ \jexpr -> do + lbl <- freshLabel + let jexprTy = A.annGetAnn (unfix expr) + jbranches <- traverse (transformBranch (J.JJump lbl)) branches + let jexprCfe = J.Cfe (makeJexpr jexprTy (J.JVal jexpr)) (J.JCase jbranches) + name <- freshVar + -- TODO: Make JJoin node placement more efficient + makeJexpr ty . J.JJoin lbl name jexprCfe <$> + local (symTabInsertVar name ty) (cont (J.JVar name)) + A.Cons ctorName args -> + collectArgs args $ \argVals -> do + varName <- freshVar + makeJexpr ty . J.JApp varName (J.CCtor ctorName) argVals <$> + local (symTabInsertVar varName ty) (cont (J.JVar varName)) + A.App funcName args -> + collectArgs args $ \argVals -> do + varName <- freshVar + makeJexpr ty . J.JApp varName (J.CFunc funcName) argVals <$> + local (symTabInsertVar varName ty) (cont (J.JVar varName)) + A.Cast expr numTy -> + anfTransform expr $ \jexpr -> do + varName <- freshVar + makeJexpr ty . J.JApp varName (J.CCast numTy) [jexpr] <$> + local (symTabInsertVar varName ty) (cont (J.JVar varName)) + A.Print expr -> + anfTransform expr $ \jexpr -> do + varName <- freshVar + makeJexpr ty . J.JApp varName J.CPrint [jexpr] <$> + local (symTabInsertVar varName ty) (cont (J.JVar varName)) + A.FunRef name -> do + varName <- freshVar + makeJexpr ty . J.JApp varName (J.CFunRef name) [] <$> + local (symTabInsertVar varName ty) (cont (J.JVar varName)) + +-- | Normalize each expression in sequential order, and then run the +-- continuation with the expression values. +collectArgs :: (MonadReader (SymbolTable (A.AnnExpr Type)) m, MonadFreshVar m) + => [A.AnnExpr Type] -- ^ Argument expressions + -> ([J.JValue] -> m (J.AnnExpr '[ 'J.ExprType])) -- ^ Continuation + -> m (J.AnnExpr '[ 'J.ExprType]) +collectArgs = go [] + where + go vals [] mcont = mcont (reverse vals) + go vals (x:xs) mcont = anfTransform x $ \v -> go (v:vals) xs mcont diff --git a/src/Simpl/Backend/Codegen.hs b/src/Simpl/Backend/Codegen.hs index 66d8e4c..a5a79bd 100644 --- a/src/Simpl/Backend/Codegen.hs +++ b/src/Simpl/Backend/Codegen.hs @@ -1,3 +1,5 @@ +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE DataKinds #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE FlexibleContexts #-} {-# OPTIONS_GHC -Wno-incomplete-record-updates #-} -- Suppress LLVM sum type of records AST warnings @@ -17,8 +19,7 @@ import Control.Applicative ((<|>)) import Control.Monad (forM, forM_) import Control.Monad.Reader import Control.Monad.State -import Data.Functor.Foldable (para, unfix) -import Data.Text.Prettyprint.Doc (pretty) +import Data.Functor.Foldable (unfix, Fix(..)) import Data.Char (ord) import Data.Maybe (fromJust) import Data.Text (Text) @@ -27,10 +28,8 @@ import Data.ByteString () import qualified Data.ByteString as BS import Data.Map (Map) import Data.Functor.Identity -import Data.List.NonEmpty (NonEmpty(..)) import Data.String (fromString) -import qualified Data.List.NonEmpty as NE import qualified Data.Text as Text import qualified Data.Map as Map import qualified LLVM.AST as LLVM @@ -45,19 +44,20 @@ import qualified LLVM.IRBuilder.Monad as LLVMIR import qualified LLVM.IRBuilder.Instruction as LLVMIR import qualified LLVM.IRBuilder.Constant as LLVMIR -import Simpl.Ast +import Simpl.Ast (BinaryOp(..), Type, TypeF(..), Constructor(..), Literal(..), Numeric(..)) import Simpl.CompilerOptions import Simpl.SymbolTable -import Simpl.Typing (TypedExpr) +import Simpl.Typing (literalType) import Simpl.Backend.Runtime () +import Simpl.JoinIR.Syntax import qualified Simpl.Backend.Runtime as RT data CodegenTable = - MkCodegenTable { tableVars :: Map Text LLVM.Operand -- ^ Pointer to variables + MkCodegenTable { tableVars :: Map Text (TypeF Type, LLVM.Operand) -- ^ Pointer to variables , tableCtors :: Map Text (LLVM.Name, LLVM.Name, Int) -- ^ Data type name, ctor name, index , tableAdts :: Map Text (LLVM.Name, Type, [Constructor]) , tableFuns :: Map Text LLVM.Operand - , tableCurrentJoin :: LLVM.Name + , tableJoinValues :: Map Text (LLVM.Name, [(LLVM.Operand, LLVM.Name)]) , tablePrintf :: LLVM.Operand , tableOptions :: CompilerOpts } deriving (Show) @@ -70,7 +70,7 @@ emptyCodegenTable = , tableCtors = Map.empty , tableAdts = Map.empty , tableFuns = Map.empty - , tableCurrentJoin = LLVM.mkName "__default_join_point" + , tableJoinValues = Map.empty , tablePrintf = error "printf not set" , tableOptions = defaultCompilerOpts } @@ -88,16 +88,45 @@ deriving instance Monad m => Monad (CodegenT m) instance MonadTrans CodegenT where lift = CodegenT . lift -instance Monad m => MonadReader CodegenTable (CodegenT m) where - ask = local id get - local f action = do - curTable <- get - modify f - result <- action - put curTable - pure result - -initCodegenTable :: CompilerOpts -> SymbolTable TypedExpr -> Codegen () +localCodegenTable :: MonadState CodegenTable m + => (CodegenTable -> CodegenTable) + -> m a + -> m a +localCodegenTable f ma = do + oldTable <- get + put (f oldTable) + res <- ma + modify $ \t -> t + { tableVars = tableVars oldTable + , tableFuns = tableFuns oldTable + } + pure res + +lookupName :: MonadState CodegenTable m + => Text + -> m (Maybe LLVM.Operand) +lookupName name = + gets $ \t -> + (snd <$> Map.lookup name (tableVars t)) <|> Map.lookup name (tableFuns t) + +-- | Looks up the type of the [JValue]. Note: assumes that if the [JValue] is a +-- [JVar], then the variable is actually in the table. +lookupValueType :: MonadState CodegenTable m + => JValue + -> m (TypeF Type) +lookupValueType = \case + JVar name -> gets (fst . fromJust . Map.lookup name . tableVars) + JLit lit -> pure $ literalType lit + +bindVariable :: MonadState CodegenTable m + => Text + -> TypeF Type + -> LLVM.Operand + -> m () +bindVariable name ty oper = + modify (\t -> t { tableVars = Map.insert name (ty, oper) (tableVars t) }) + +initCodegenTable :: CompilerOpts -> SymbolTable (AnnExpr '[ 'ExprType ]) -> Codegen () initCodegenTable options symTab = do let adts = flip Map.mapWithKey (symTabAdts symTab) $ \name (ty, ctors) -> (llvmName name, ty, ctors) ctors <- forM (Map.elems adts) $ \(adtName, _, ctors) -> @@ -107,63 +136,6 @@ initCodegenTable options symTab = do , tableCtors = Map.fromList (join ctors) , tableOptions = options } -data Result = Values (NonEmpty LLVM.Operand) | Branching (NonEmpty (LLVM.Operand, LLVM.Name)) - deriving (Show, Eq) - -instance Ord Result where - compare (Values v1) (Values v2) = compare v1 v2 - compare (Values _) (Branching _) = GT - compare (Branching _) (Values _) = LT - compare (Branching br1) (Branching br2) = compare br1 br2 - -resultHasJump :: Result -> Bool -resultHasJump (Values _) = False -resultHasJump (Branching _) = True - -joinPoint :: (LLVMIR.MonadIRBuilder m, MonadState CodegenTable m) - => m (NonEmpty Result) - -> m (NonEmpty LLVM.Operand) -joinPoint resultsM = do - currentJp <- gets tableCurrentJoin - newJp <- LLVMIR.freshName "join_point" - modify (\t -> t { tableCurrentJoin = newJp }) - results <- resultsM - -- modify (\t -> t { tableCurrentJoin = newJp }) - modify (\t -> t { tableCurrentJoin = currentJp }) - let hasJump = any resultHasJump results - when hasJump $ - LLVMIR.emitBlockStart newJp - let go (Values v) = pure v - go (Branching brs) = - if length brs == 1 - then pure (fst <$> brs) - else (:| []) <$> LLVMIR.phi (NE.toList brs) - join <$> traverse go (NE.sort results) - -joinPoint1 :: (LLVMIR.MonadIRBuilder m, MonadState CodegenTable m) - => m Result - -> m LLVM.Operand -joinPoint1 = fmap NE.head . joinPoint . fmap (:| []) - -resultValue :: LLVM.Operand -> Result -resultValue v = Values (v :| []) - -resultBranching :: NonEmpty (LLVM.Operand, LLVM.Name) -> Result -resultBranching = Branching - -resultEnsureBranch :: LLVMIR.MonadIRBuilder m => Result -> m Result -resultEnsureBranch = \case - Values vs -> LLVMIR.currentBlock >>= \cb -> pure $ Branching ((, cb) <$> vs) - b@Branching {} -> pure b - -resultCombine :: LLVMIR.MonadIRBuilder m => Result -> Result -> m Result -resultCombine (Values v1) (Values v2) = pure $ Values (v1 <> v2) -resultCombine (Values v1) (Branching br2) = - LLVMIR.currentBlock >>= \cb -> pure $ Branching (((, cb) <$> v1) <> br2) -resultCombine (Branching br1) (Values v2) = - LLVMIR.currentBlock >>= \cb -> pure $ Branching (br1 <> ((, cb) <$> v2)) -resultCombine (Branching br1) (Branching br2) = pure $ Branching (br1 <> br2) - llvmByte :: Integer -> LLVMC.Constant llvmByte = LLVMC.Int 8 @@ -198,11 +170,12 @@ staticString name str = do llvmName :: Text -> LLVM.Name llvmName = LLVM.mkName . Text.unpack -literalCodegen :: LLVMIR.MonadIRBuilder m => Literal -> m Result +-- | Generates LLVM code for a literal. +literalCodegen :: LLVMIR.MonadIRBuilder m => Literal -> m LLVM.Operand literalCodegen = \case - LitInt x -> resultValue <$> LLVMIR.int64 (fromIntegral x) - LitDouble x -> resultValue <$> LLVMIR.double x - LitBool b -> resultValue <$> LLVMIR.bit (if b then 1 else 0) + LitInt x -> LLVMIR.int64 (fromIntegral x) + LitDouble x -> LLVMIR.double x + LitBool b -> LLVMIR.bit (if b then 1 else 0) LitString t -> do -- TODO: Store literal strings in global memory let byteS = encodeUtf8 t @@ -215,29 +188,28 @@ literalCodegen = \case byteDataPtr' <- LLVMIR.bitcast byteDataPtr (LLVM.ptr LLVM.i8) bytePtr <- LLVMIR.call RT.mallocRef [(lenOper, [])] _ <- LLVMIR.call RT.memcpyRef [(bytePtr, []), (byteDataPtr', []), (lenOper, [])] - str <- LLVMIR.call RT.stringNewRef [(lenOper, []), (bytePtr, [])] - pure $ resultValue str + LLVMIR.call RT.stringNewRef [(lenOper, []), (bytePtr, [])] -- | Generates code for an arbitrary binary operation -binaryOpCodegen :: (LLVMIR.MonadIRBuilder m, MonadState CodegenTable m) - => m Result - -> m Result +binaryOpCodegen + :: LLVMIR.MonadIRBuilder m + => m LLVM.Operand + -> m LLVM.Operand -> (LLVM.Operand -> LLVM.Operand -> m LLVM.Operand) - -> m Result + -> m LLVM.Operand binaryOpCodegen x y op = do - x' <- joinPoint1 x - y' <- joinPoint1 y - res <- op x' y' - pure $ resultValue res + x' <- x + y' <- y + op x' y' -- | Generates code for a numeric binary operation -numBinopCodegen :: (LLVMIR.MonadIRBuilder m, MonadState CodegenTable m) - => m Result - -> m Result +numBinopCodegen :: LLVMIR.MonadIRBuilder m + => m LLVM.Operand + -> m LLVM.Operand -> Type -> (LLVM.Operand -> LLVM.Operand -> m LLVM.Operand) -- ^ Float operation -> (LLVM.Operand -> LLVM.Operand -> m LLVM.Operand) -- ^ Integer operation - -> m Result + -> m LLVM.Operand numBinopCodegen x y ty opDouble opInt = case unfix ty of TyNumber numTy -> @@ -246,163 +218,200 @@ numBinopCodegen x y ty opDouble opInt = _ -> error "Invariant violated" -- | Generates code for BinOp -binOpCodegen :: (LLVMIR.MonadIRBuilder m, MonadState CodegenTable m) +binOpCodegen :: LLVMIR.MonadIRBuilder m => BinaryOp -- ^ Operation -> Type -- ^ Type of the inputs - -> m Result - -> m Result - -> m Result + -> m LLVM.Operand + -> m LLVM.Operand + -> m LLVM.Operand binOpCodegen op ty x y = let (floatInstr, intInstr) = case op of Add -> (LLVMIR.fadd, LLVMIR.add) Sub -> (LLVMIR.fsub, LLVMIR.sub) Mul -> (LLVMIR.fmul, LLVMIR.mul) Div -> (LLVMIR.fdiv, LLVMIR.sdiv) - Lt -> ((LLVMIR.fcmp LLVMFP.OLT), (LLVMIR.icmp LLVMIP.SLT)) - Lte -> ((LLVMIR.fcmp LLVMFP.OLE), (LLVMIR.icmp LLVMIP.SLE)) - Equal -> ((LLVMIR.fcmp LLVMFP.OEQ), (LLVMIR.icmp LLVMIP.EQ)) + Lt -> (LLVMIR.fcmp LLVMFP.OLT, LLVMIR.icmp LLVMIP.SLT) + Lte -> (LLVMIR.fcmp LLVMFP.OLE, LLVMIR.icmp LLVMIP.SLE) + Equal -> (LLVMIR.fcmp LLVMFP.OEQ, LLVMIR.icmp LLVMIP.EQ) in numBinopCodegen x y ty floatInstr intInstr -exprCodegen :: TypedExpr -> LLVMIR.IRBuilderT (LLVMIR.ModuleBuilderT Codegen) Result -exprCodegen = para (go . annGetExpr) +-- | Generates code for a [JValue]. +jvalueCodegen + :: (LLVMIR.MonadIRBuilder m, MonadState CodegenTable m) + => JValue + -> m LLVM.Operand +jvalueCodegen = \case + JVar name -> gets (snd . fromJust . Map.lookup name . tableVars) + JLit l -> literalCodegen l + +-- | Generates code for a CFE. +controlFlowCodegen + :: JValue -- ^ The value to continue control flow with. + -> LLVM.Operand -- ^ The LLVM operand of the value. + -> ControlFlow (AnnExpr '[ 'ExprType]) -- ^ The control flow continuation. + -> LLVMIR.IRBuilderT (LLVMIR.ModuleBuilderT Codegen) () +controlFlowCodegen val valOper = \case + JIf (Cfe trueBr trueCf) (Cfe falseBr falseCf) -> do + LLVMIR.ensureBlock + trueLabel <- LLVMIR.freshName "if_then" + falseLabel <- LLVMIR.freshName "if_else" + LLVMIR.condBr valOper trueLabel falseLabel + LLVMIR.emitBlockStart trueLabel + (trueVal, trueOper) <- jexprCodegen trueBr + _ <- controlFlowCodegen trueVal trueOper trueCf + LLVMIR.emitBlockStart falseLabel + (falseVal, falseOper) <- jexprCodegen falseBr + _ <- controlFlowCodegen falseVal falseOper falseCf + pure () + JCase branches -> do + LLVMIR.ensureBlock + defLabel <- LLVMIR.freshName "case_default" + allCaseLabels <- forM branches $ \case + BrAdt name _ _ -> + let labelName = "case_" <> fromString (Text.unpack name) in + (name, ) <$> LLVMIR.freshName labelName + -- Assume the symbol table and type information is correct + dataName <- (\case { TyAdt n -> n; _ -> error "" }) <$> lookupValueType val + ctors <- gets ((\(_,_,cs) -> cs) . fromJust . Map.lookup dataName . tableAdts) + let ctorNames = ctorGetName <$> ctors + let usedLabelTriples = filter (\(_, (n, _)) -> n `elem` ctorNames) $ [0..] `zip` allCaseLabels + let jumpTable = [(LLVMC.Int 32 i, l) | (i, (_, l)) <- usedLabelTriples] + tag <- LLVMIR.extractValue valOper [0] + dataPtr <- LLVMIR.extractValue valOper [1] + LLVMIR.switch tag defLabel jumpTable + forM_ (usedLabelTriples `zip` branches) $ \((_, (ctorName, label)), br) -> do + let expr = branchGetExpr br + let cf = branchGetControlFlow br + (_, ctorLLVMName, index) <- gets (fromJust . Map.lookup ctorName . tableCtors) + let Ctor _ argTys = ctors !! index + let bindingPairs = branchGetBindings br `zip` (typeToLLVM <$> argTys) + LLVMIR.emitBlockStart label + ctorPtr <- LLVMIR.bitcast dataPtr (LLVM.ptr (LLVM.NamedTypeReference ctorLLVMName)) + ctorPtrOffset <- LLVMIR.int32 0 + bindings <- forM ([0..] `zip` bindingPairs) $ \(i, (n, llvmTy)) -> do + ctorPtrIndex <- LLVMIR.int32 i + -- Need to bitcast the ptr type because we need a concrete type. We + -- also need to load the data immediately because of how variables + -- are implemented. + v <- LLVMIR.gep ctorPtr [ctorPtrOffset, ctorPtrIndex] + >>= flip LLVMIR.bitcast (LLVM.ptr llvmTy) + >>= flip LLVMIR.load 0 + ty <- lookupValueType (JVar n) + pure (n, (ty, v)) + let updateTable t = t { tableVars = Map.union (tableVars t) (Map.fromList bindings) } + (exprVal, exprOper) <- jexprCodegen expr + localCodegenTable updateTable (controlFlowCodegen exprVal exprOper cf) + LLVMIR.emitBlockStart defLabel + LLVMIR.unreachable + JJump lbl -> do + v <- jvalueCodegen val + jvals <- gets tableJoinValues + block <- LLVMIR.currentBlock + let f (n, jvs) = Just (n, (v, block) : jvs) + let updJvals = Map.update f lbl jvals + modify (\t -> t { tableJoinValues = updJvals }) + llvmLabel <- gets (fst . fromJust . Map.lookup lbl . tableJoinValues) + LLVMIR.br llvmLabel + +-- | Generates code for a given callable (i.e. in a JApp) +callableCodegen + :: Callable -- ^ The callable + -> [JValue] -- ^ The argument values + -> LLVMIR.IRBuilderT (LLVMIR.ModuleBuilderT Codegen) LLVM.Operand +callableCodegen callable args = case callable of + CFunc name -> do + fn <- fromJust <$> lookupName name + ops <- traverse jvalueCodegen args + LLVMIR.call fn [(x, []) | x <- ops] + CBinOp op -> case args of + [x, y] -> do + xTy <- lookupValueType x + binOpCodegen op (Fix xTy) (jvalueCodegen x) (jvalueCodegen y) + _ -> error $ "callableCodegen: expected 2 args to CBinOp, got " ++ show (length args) + CCast targetNum -> case args of + [x] -> do + ty <- lookupValueType x + case ty of + TyNumber srcNum -> jvalueCodegen x >>= castOpCodegen srcNum targetNum + _ -> error $ "callableCodegen: expected CCast argument to be of numeric type, got " ++ show ty + _ -> error $ "callableCodegen: expected 1 args to CCast, got " ++ show (length args) + CCtor name -> do + -- Assume constructor exists, since typechecker should verify it anyways + (dataTy, ctorName, ctorIndex) <- gets (fromJust . Map.lookup name . tableCtors) + ctorIndex' <- LLVMIR.int32 (fromIntegral ctorIndex) + let tagStruct1 = LLVM.ConstantOperand $ LLVMC.Undef (LLVM.NamedTypeReference dataTy) + -- Tag + tagStruct2 <- LLVMIR.insertValue tagStruct1 ctorIndex' [0] + -- Data pointer + let ctorTy = LLVM.NamedTypeReference ctorName + let nullptr = LLVM.ConstantOperand (LLVMC.Null (LLVM.ptr ctorTy)) + -- Use offsets to calculate struct size + ctorStructSize <- flip LLVMIR.ptrtoint LLVM.i64 + =<< LLVMIR.gep nullptr + =<< pure <$> LLVMIR.int32 0 + -- Allocate memory for constructor. + -- For now, use "leak memory" as an implementation strategy for deallocation. + ctorStructPtr <- LLVMIR.call RT.mallocRef [(ctorStructSize, [])] >>= + flip LLVMIR.bitcast (LLVM.ptr ctorTy) + values <- traverse jvalueCodegen args + indices <- traverse LLVMIR.int32 [0..fromIntegral (length values - 1)] + ptrOffset <- LLVMIR.int32 0 + forM_ (indices `zip` values) $ \(index, v) -> do + valuePtr <- LLVMIR.gep ctorStructPtr [ptrOffset, index] + LLVMIR.store valuePtr 0 v + pure () + dataPtr <- LLVMIR.bitcast ctorStructPtr (LLVM.ptr LLVM.i8) + fullyInit <- LLVMIR.insertValue tagStruct2 dataPtr [1] + pure fullyInit + CPrint -> case args of + [val] -> do + let fmtStr = encodeUtf8 (fromString "Print: %s\n\0") + let fmtStrLen = toInteger (BS.length fmtStr) + let fmtStrData = LLVMC.Int 8 . toInteger <$> BS.unpack fmtStr + fmtStrOper <- LLVMIR.array fmtStrData + fmtStrPtr <- LLVMIR.alloca (LLVM.ArrayType (fromInteger fmtStrLen) LLVM.i8) Nothing 0 + _ <- LLVMIR.store fmtStrPtr 0 fmtStrOper + printf <- gets tablePrintf + str <- jvalueCodegen val + exprCstring <- LLVMIR.call RT.stringCstringRef [(str, [])] + fmtStrPtr' <- LLVMIR.bitcast fmtStrPtr (LLVM.ptr LLVM.i8) + _ <- LLVMIR.call printf [(fmtStrPtr', []), (exprCstring, [])] + LLVMIR.int64 0 + _ -> error $ "callableCodegen: expected 1 args to CPrint, got " ++ show (length args) + CFunRef name -> gets (fromJust . Map.lookup name . tableFuns) + +-- | Generates code for a [JExpr] +jexprCodegen + :: AnnExpr '[ 'ExprType] + -> LLVMIR.IRBuilderT (LLVMIR.ModuleBuilderT Codegen) (JValue, LLVM.Operand) +jexprCodegen = (\e -> go (unfix (getType e)) (annGetExpr e)) . unfix where - go :: ExprF (TypedExpr, LLVMIR.IRBuilderT (LLVMIR.ModuleBuilderT Codegen) Result) - -> LLVMIR.IRBuilderT (LLVMIR.ModuleBuilderT Codegen) Result - go = \case - Lit l -> literalCodegen l - BinOp op (ex, x) (_, y) -> binOpCodegen op (annGetAnn (unfix ex)) x y - If (_, condInstr) (_, t1Instr) (_, t2Instr) -> do - LLVMIR.ensureBlock - cond <- joinPoint1 condInstr - trueLabel <- LLVMIR.freshName "if_then" - falseLabel <- LLVMIR.freshName "if_else" - LLVMIR.condBr cond trueLabel falseLabel - LLVMIR.emitBlockStart trueLabel - t1Res <- t1Instr >>= resultEnsureBranch - jp <- gets tableCurrentJoin - LLVMIR.br jp - LLVMIR.emitBlockStart falseLabel - t2Res <- t2Instr >>= resultEnsureBranch - LLVMIR.br jp - resultCombine t1Res t2Res - Cons name argPairs -> do - let args = snd <$> argPairs - -- Assume constructor exists, since typechecker should verify it anyways - (dataTy, ctorName, ctorIndex) <- asks (fromJust . Map.lookup name . tableCtors) - ctorIndex' <- LLVMIR.int32 (fromIntegral ctorIndex) - let tagStruct1 = LLVM.ConstantOperand $ LLVMC.Undef (LLVM.NamedTypeReference dataTy) - -- Tag - tagStruct2 <- LLVMIR.insertValue tagStruct1 ctorIndex' [0] - -- Data pointer - let ctorTy = LLVM.NamedTypeReference ctorName - let nullptr = LLVM.ConstantOperand (LLVMC.Null (LLVM.ptr ctorTy)) - -- Use offsets to calculate struct size - ctorStructSize <- flip LLVMIR.ptrtoint LLVM.i64 - =<< LLVMIR.gep nullptr - =<< pure <$> LLVMIR.int32 0 - -- Allocate memory for constructor. - -- For now, use "leak memory" as an implementation strategy for deallocation. - ctorStructPtr <- LLVMIR.call RT.mallocRef [(ctorStructSize, [])] >>= - flip LLVMIR.bitcast (LLVM.ptr ctorTy) - if null args - then pure () - else do - values <- joinPoint (NE.fromList <$> sequence args) - indices <- traverse LLVMIR.int32 [0..fromIntegral (length values - 1)] - ptrOffset <- LLVMIR.int32 0 - forM_ (indices `zip` NE.toList values) $ \(index, v) -> do - valuePtr <- LLVMIR.gep ctorStructPtr [ptrOffset, index] - LLVMIR.store valuePtr 0 v - pure () - dataPtr <- LLVMIR.bitcast ctorStructPtr (LLVM.ptr LLVM.i8) - fullyInit <- LLVMIR.insertValue tagStruct2 dataPtr [1] - pure $ resultValue fullyInit - Case branches (valExpr, valM) -> do - LLVMIR.ensureBlock - defLabel <- LLVMIR.freshName "case_default" - endLabel <- gets tableCurrentJoin - val <- joinPoint1 valM - allCaseLabels <- forM branches $ \case - BrAdt name _ _ -> - let labelName = "case_" <> fromString (Text.unpack name) in - (name, ) <$> LLVMIR.freshName labelName - -- Assume the symbol table and type information is correct - let dataName = fromJust $ - case unfix . annGetAnn . unfix $ valExpr of { TyAdt n -> Just n; _ -> Nothing } - ctors <- asks ((\(_,_,cs) -> cs) . fromJust . Map.lookup dataName . tableAdts) - let ctorNames = ctorGetName <$> ctors - let usedLabelTriples = filter (\(_, (n, _)) -> n `elem` ctorNames) $ [0..] `zip` allCaseLabels - let jumpTable = [(LLVMC.Int 32 i, l) | (i, (_, l)) <- usedLabelTriples] - tag <- LLVMIR.extractValue val [0] - dataPtr <- LLVMIR.extractValue val [1] - LLVMIR.switch tag defLabel jumpTable - resVals <- forM (usedLabelTriples `zip` branches) $ \((_, (ctorName, label)), br) -> do - let expr = snd (branchGetExpr br) - (_, ctorLLVMName, index) <- asks (fromJust . Map.lookup ctorName . tableCtors) - let Ctor _ argTys = ctors !! index - let bindingPairs = branchGetBindings br `zip` (typeToLLVM <$> argTys) - LLVMIR.emitBlockStart label - ctorPtr <- LLVMIR.bitcast dataPtr (LLVM.ptr (LLVM.NamedTypeReference ctorLLVMName)) - ctorPtrOffset <- LLVMIR.int32 0 - bindings <- forM ([0..] `zip` bindingPairs) $ \(i, (n, ty)) -> do - ctorPtrIndex <- LLVMIR.int32 i - -- Need to bitcast the ptr type because we need a concrete type. We - -- also need to load the data immediately because of how variables - -- are implemented. - v <- LLVMIR.gep ctorPtr [ctorPtrOffset, ctorPtrIndex] - >>= flip LLVMIR.bitcast (LLVM.ptr ty) - >>= flip LLVMIR.load 0 - pure (n, v) - let updateTable t = t { tableVars = Map.union (tableVars t) (Map.fromList bindings) } - res <- local updateTable expr >>= resultEnsureBranch - LLVMIR.br endLabel - pure res - LLVMIR.emitBlockStart defLabel - LLVMIR.unreachable - foldM resultCombine (head resVals) (tail resVals) - -- resultCombine pure (sconcat (NE.fromList resVals)) - Let name (_, valM) (_, exprM) -> do - val <- joinPoint1 valM - local (\t -> t { tableVars = Map.insert name val (tableVars t) }) exprM - Var name -> do - -- Assume codegen table is correct - value <- gets (fromJust . Map.lookup name . tableVars) - pure $ resultValue value - App name argsM -> do - args <- traverse joinPoint1 (snd <$> argsM) - -- Assume that name is callable (e.g. either a static function or a - -- function pointer) - fn <- fromJust <$> lookupName name - resultValue <$> LLVMIR.call fn [(a, []) | a <- args] - FunRef name -> do - fn <- gets (fromJust . Map.lookup name . tableFuns) - pure (resultValue fn) - Cast (origExpr, exprM) num -> do - let origNum = case unfix (annGetAnn (unfix origExpr)) of - TyNumber n -> n - _ -> error "codegen: attempting to cast non-numeric" - expr <- joinPoint1 exprM - resultValue <$> castOpCodegen origNum num expr - Print (_, exprM) -> do - let fmtStr = encodeUtf8 (fromString "Print: %s\n\0") - let fmtStrLen = toInteger (BS.length fmtStr) - let fmtStrData = LLVMC.Int 8 . toInteger <$> BS.unpack fmtStr - fmtStrOper <- LLVMIR.array fmtStrData - fmtStrPtr <- LLVMIR.alloca (LLVM.ArrayType (fromInteger fmtStrLen) LLVM.i8) Nothing 0 - _ <- LLVMIR.store fmtStrPtr 0 fmtStrOper - printf <- gets tablePrintf - str <- joinPoint1 exprM - exprCstring <- LLVMIR.call RT.stringCstringRef [(str, [])] - fmtStrPtr' <- LLVMIR.bitcast fmtStrPtr (LLVM.ptr LLVM.i8) - _ <- LLVMIR.call printf [(fmtStrPtr', []), (exprCstring, [])] - retval <- LLVMIR.int64 0 - pure $ resultValue retval - lookupName :: MonadState CodegenTable m - => Text - -> m (Maybe LLVM.Operand) - lookupName name = - gets $ \t -> - Map.lookup name (tableVars t) <|> Map.lookup name (tableFuns t) + go :: TypeF Type + -> JExprF (AnnExpr '[ 'ExprType ]) + -> LLVMIR.IRBuilderT (LLVMIR.ModuleBuilderT Codegen) (JValue, LLVM.Operand) + go exprTy = \case + JVal v -> (v,) <$> jvalueCodegen v + JLet name val next -> do + oper <- jvalueCodegen val + _ <- bindVariable name exprTy oper + jexprCodegen next + JJoin lbl varName (Cfe expr cf) next -> do + llvmLabel <- LLVMIR.freshName (fromString (Text.unpack lbl)) + let addJoinEntry t = + t { tableJoinValues = Map.insert lbl (llvmLabel, []) (tableJoinValues t) } + oldJoinEntries <- gets tableJoinValues + (lastVal, lastValOper) <- jexprCodegen expr + _ <- localCodegenTable addJoinEntry (controlFlowCodegen lastVal lastValOper cf) + (_, joinValues) <- gets (fromJust . Map.lookup lbl . tableJoinValues) + modify (\t -> t { tableJoinValues = oldJoinEntries }) + LLVMIR.emitBlockStart llvmLabel + op <- LLVMIR.phi joinValues + bindVariable varName exprTy op + jexprCodegen next + JApp varName callable args next -> do + oper <- callableCodegen callable args + bindVariable varName exprTy oper + jexprCodegen next -- | Generates code for numeric casting castOpCodegen :: LLVMIR.MonadIRBuilder m => Numeric -> Numeric -> LLVM.Operand -> m LLVM.Operand @@ -457,7 +466,7 @@ adtToLLVM adtName ctors = do funToLLVM :: Text -> [(Text, Type)] -> Type - -> TypedExpr + -> AnnExpr '[ 'ExprType ] -> LLVMIR.ModuleBuilderT Codegen LLVM.Operand funToLLVM name params ty body = let name' = if name == "main" then "__simpl_main" else name @@ -466,28 +475,25 @@ funToLLVM name params ty body = fparams = [(typeToLLVM t, fromString (Text.unpack n)) | (n, t) <- params] in mdo foper <- LLVMIR.function fname fparams ftype $ \args -> do LLVMIR.ensureBlock - endLabel <- LLVMIR.freshName "function_end" -- We need to make sure we don't pollute other function scopes oldVars <- gets tableVars - let updVars t = tableVars t `Map.union` Map.fromList ((fst <$> params) `zip` args) - modify (\t -> t { tableCurrentJoin = endLabel - , tableFuns = Map.insert name foper (tableFuns t) + let updVars t = tableVars t `Map.union` Map.fromList [(n, (unfix vty, op)) | ((n, vty), op) <- params `zip` args] + modify (\t -> t { tableFuns = Map.insert name foper (tableFuns t) , tableVars = updVars t }) - retval <- joinPoint1 (exprCodegen body) + (_, retval) <- jexprCodegen body -- Restore old scope modify (\t -> t { tableVars = oldVars }) LLVMIR.ret retval pure foper -- | Generate code for the entire module -moduleCodegen :: [Decl Expr] - -> SymbolTable TypedExpr +moduleCodegen :: String + -> SymbolTable (AnnExpr '[ 'ExprType ]) -> LLVMIR.ModuleBuilderT Codegen () -moduleCodegen decls symTab = mdo +moduleCodegen srcCode symTab = mdo -- Message is "Hi\n" (with null terminator) (_, msg) <- staticString ".message" "Hello world!\n" (_, resultFmt) <- staticString ".resultformat" "Result: %i\n" - let srcCode = unlines $ show . pretty <$> decls (_, exprSrc) <- staticString ".sourcecode" $ "Source code: " ++ srcCode ++ "\n" RT.emitRuntimeDecls modify (\t -> t { tablePrintf = RT.printfRef }) @@ -512,10 +518,10 @@ moduleCodegen decls symTab = mdo LLVMIR.ret retcode pure () -runCodegen :: CompilerOpts -> [Decl Expr] -> SymbolTable TypedExpr -> LLVM.Module -runCodegen opts decls symTab +runCodegen :: CompilerOpts -> String -> SymbolTable (AnnExpr '[ 'ExprType]) -> LLVM.Module +runCodegen opts srcCode symTab = runIdentity . flip evalStateT emptyCodegenTable . unCodegen . LLVMIR.buildModuleT "simpl.ll" - $ lift (initCodegenTable opts symTab) >> moduleCodegen decls symTab + $ lift (initCodegenTable opts symTab) >> moduleCodegen srcCode symTab diff --git a/src/Simpl/Compiler.hs b/src/Simpl/Compiler.hs index b635063..491db2b 100644 --- a/src/Simpl/Compiler.hs +++ b/src/Simpl/Compiler.hs @@ -4,16 +4,19 @@ module Simpl.Compiler where import Control.Monad.Except import Control.Monad.State +import Data.Text.Prettyprint.Doc (pretty) import qualified LLVM.AST as LLVM import qualified LLVM.Module as LLVMM import LLVM.Context import Simpl.Ast +import Simpl.AstToJoinIR import Simpl.Backend.Codegen (runCodegen) import Simpl.CompilerOptions import Simpl.SymbolTable import Simpl.Typing (TypeError, runTypecheck, checkType, withExtraVars) +import Simpl.JoinIR.Syntax (unannotate) import Paths_simpl_lang -- | Main error type, aggregating all error types. @@ -53,4 +56,10 @@ fullCompilerPipeline options srcFile@(SourceFile _name decls) = let newSTTypecheck = sequence $ symTabMapExprs tycheckFuns symTable typedSymTable <- MkCompilerMonad . lift . withExceptT ErrTypecheck $ liftEither (runTypecheck symTable newSTTypecheck) - pure $ runCodegen options decls typedSymTable + let jSymTable = astToJoinIR typedSymTable + -- TODO: Add compiler flag to dump JoinIR + -- liftIO $ forM_ (symTabFuns jSymTable) $ \(args, ty, expr) -> do + -- print (pretty args <> pretty " :: " <> pretty ty) + -- print (pretty (unannotate expr)) + let srcCode = unlines [show (pretty d) | d <- decls] + pure $ runCodegen options srcCode jSymTable diff --git a/src/Simpl/JoinIR/Syntax.hs b/src/Simpl/JoinIR/Syntax.hs new file mode 100644 index 0000000..47e1b93 --- /dev/null +++ b/src/Simpl/JoinIR/Syntax.hs @@ -0,0 +1,252 @@ +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE TypeSynonymInstances #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE DeriveFoldable #-} +{-# LANGUAGE DeriveTraversable #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE LambdaCase #-} +-- Vinyl stuff +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeApplications #-} + +{-| +Module : Simpl.JoinIR.Syntax +Description : AST for the JoinIR + +Defines the abstract syntax tree for JoinIR, an IR for SimPL based on the IR +presented in /Compiling without Continuations/ by Luke Maurer, Zena Ariola, Paul +Downen, and Simon Peyton Jones (PLDI '17). +-} +module Simpl.JoinIR.Syntax where + +import Data.Text (Text) +import Data.Text.Prettyprint.Doc (Pretty, pretty, (<>), (<+>)) +import qualified Data.Text.Prettyprint.Doc as PP +import Simpl.Ast (BinaryOp(..), Numeric(..), Literal(..), Type) +import Text.Show.Deriving (deriveShow1) +import Data.Functor.Foldable +import qualified Data.Vinyl as V +import qualified Data.Vinyl.TypeLevel as V +import Data.Singletons.TH (genSingletons) + +type Name = Text + +type Label = Text + +-- | An operation applied to some arguments +data Callable + = CFunc !Name -- ^ Function + | CBinOp !BinaryOp -- ^ Binary operator + | CCast !Numeric -- ^ Numeric cast + | CCtor !Name -- ^ ADT constructor + | CPrint -- ^ Print string (temporary) + | CFunRef !Name -- ^ Static function reference + deriving (Show) + +-- | A value +data JValue + -- | A variable + = JVar !Name + -- | A literal + | JLit !Literal + deriving (Show, Eq) + + +-- | Represents how a value at the end of a control flow branch should be handled. +data ControlFlow a + -- | If expression on the given value, with a true branch and a false branch + = JIf !(Cfe a) !(Cfe a) + + -- | Case expression on the given value + | JCase ![JBranch a] + + -- | Jump to the given enclosing join point with the given value. + | JJump !Text + deriving (Functor, Foldable, Traversable, Show) + + +data JBranch a + = BrAdt Name [Name] !(Cfe a) -- ^ Destructure algebraic data type + deriving (Functor, Foldable, Traversable, Show) + +branchGetExpr :: JBranch a -> a +branchGetExpr = \case + BrAdt _ _ (Cfe e _) -> e + +branchGetBindings :: JBranch a -> [Text] +branchGetBindings = \case + BrAdt _ vars _ -> vars + +branchGetControlFlow :: JBranch a -> ControlFlow a +branchGetControlFlow = \case + BrAdt _ _ (Cfe _ cf) -> cf + + +-- | The JoinIR expression type. Syntactically, it is in ANF-form with explicit +-- join points. +data JExprF a + -- | A value + = JVal !JValue + + -- | A value binding + | JLet !Name !JValue !a + + -- | A join point. Consists of a label, the variable representing the joined + -- value, the expression to join, and the next expression. + | JJoin !Label !Name !(Cfe a) !a + + -- | Apply the callable to the arguments, bind the result to the given name, + -- and continue to the next expression. + | JApp !Name !Callable ![JValue] !a + deriving (Functor, Foldable, Traversable, Show) + + +data Cfe a = Cfe !a !(ControlFlow a) + deriving (Functor, Foldable, Traversable, Show) + +$(deriveShow1 ''JBranch) +$(deriveShow1 ''ControlFlow) +$(deriveShow1 ''Cfe) +$(deriveShow1 ''JExprF) + +type JExpr = Fix JExprF + +jexprGetVal :: JExprF a -> JValue +jexprGetVal = \case + JVal v -> v + JLet n _ _ -> JVar n + JJoin _ n _ _ -> JVar n + JApp n _ _ _ -> JVar n + +instance Pretty Callable where + pretty = \case + CFunc name -> pretty name + CBinOp op -> pretty op + CCast num -> "cast[" <> pretty num <> "]" + CCtor name -> pretty name + CPrint -> "print" + CFunRef name -> "funref[" <> pretty name <> "]" + +instance Pretty JValue where + pretty = \case + JVar n -> pretty n + JLit l -> pretty l + +instance Pretty a => Pretty (JBranch a) where + pretty (BrAdt ctorName varNames cfe) = + PP.hang 2 $ PP.hsep (pretty <$> brPart) <> PP.softline <> pretty cfe + where brPart = [ctorName] ++ varNames ++ ["=>"] + +instance Pretty a => Pretty (ControlFlow a) where + pretty = \case + JIf trueBr falseBr -> + PP.hang 3 $ "if" <+> PP.sep + [ "then" <> PP.softline <> PP.align (flatParens (pretty trueBr)) + , "else" <> PP.softline <> PP.align (flatParens (pretty falseBr)) ] + JCase brs -> + PP.hang 2 $ "case" <+> "of" <> PP.hardline <> + (PP.vsep $ pretty <$> brs) + JJump lbl -> PP.hsep ["jump", pretty lbl] + where + flatParens d = PP.flatAlt d (PP.parens d) + +instance Pretty JExpr where + pretty = f . unfix + where + f :: JExprF JExpr -> PP.Doc ann + f = \case + JVal v -> pretty v + JLet n v next -> PP.hsep ["let", pretty n, "=", pretty v, "in"] <> PP.softline <> pretty next + JJoin lbl n joinbl next -> + (PP.group . PP.hang 2 $ PP.hsep ["join" <> PP.enclose "[" "]" (pretty lbl), pretty n, "="] + <> PP.flatAlt PP.hardline " " <> pretty joinbl) + <> PP.hardline <> "in" <+> pretty next + JApp name clbl args next -> + PP.hsep (["let app", pretty name, "=", pretty clbl] ++ (pretty <$> args) ++ ["in"]) + <> PP.hardline <> pretty next + +instance Pretty a => Pretty (Cfe a) where + pretty (Cfe expr cf) = + PP.align (pretty expr <> ";" <> line <> pretty cf) + where + line = case cf of + JIf _ _ -> PP.hardline + JCase _ -> PP.hardline + JJump _ -> PP.softline + +-- * Annotated [JExpr]s +-- +-- Because it's possible to have many different annotations on a single AST, we +-- define a "single" annotated AST that is annotated with an extensible record +-- type at each node. Thus, we can add annotations by extending the record with +-- more fields. + +-- | Possible annotations on a [JExpr] +data JFields = ExprType deriving (Show) + +genSingletons [ ''JFields ] + +-- | Maps each possible annotation label to a type +type family ElF (f :: JFields) :: * where + ElF 'ExprType = Type + +-- | Wrapper for annotation fields +newtype Attr f = Attr { _unAttr :: ElF f } + +deriving instance Show (Attr 'ExprType) + +-- | Helper function for create annotation fields +(=::) :: sing f -> ElF f -> Attr f +_ =:: x = Attr x + +-- | Annotations for a typed [JExpr] +type Typed = '[ 'ExprType ] + +-- | Creates a type field whose value is the given type +withType :: Type -> Attr 'ExprType +withType ty = SExprType =:: ty + +-- | A [JExprF] annotated with some data. +data AnnExprF fields a = AnnExprF { annGetAnn :: V.Rec Attr fields, annGetExpr :: JExprF a } + deriving (Functor, Foldable, Traversable) + +type AnnExpr fields = Fix (AnnExprF fields) + +-- | Converts a [JExprF] to an "unannotated" [AnnExprF] +toAnnExprF :: JExprF a -> AnnExprF '[] a +toAnnExprF expr = AnnExprF { annGetAnn = V.RNil, annGetExpr = expr } + +-- | Converts a [JExpr] to an "unannotated" [AnnExpr] +toAnnExpr :: JExpr -> AnnExpr '[] +toAnnExpr expr = cata (Fix . toAnnExprF) expr + +-- | Removes all annotations from an [AnnExpr] +unannotate :: AnnExpr fields -> JExpr +unannotate = cata (Fix . annGetExpr) + +-- | Adds the given annotation to the expression +addField :: Attr f -> AnnExprF flds a -> AnnExprF (f ': flds) a +addField attr expr = expr { annGetAnn = attr V.:& annGetAnn expr } + +-- | Retrieves the type information stored in a typed [AnnExprF] +getType :: V.RElem 'ExprType fields (V.RIndex 'ExprType fields) => AnnExprF fields a -> Type +getType = _unAttr . V.rget @'ExprType . annGetAnn + +-- * Misc + +-- | Helper for inspecting an [AnnExpr] +newtype PrettyJExprF a = PrettyJExprF (String, JExprF a) deriving (Show) +type PrettyJExpr = Fix PrettyJExprF +$(deriveShow1 ''PrettyJExprF) + +-- | Converts an [AnnExpr] into a [PrettyJExpr] so that it can be shown. +prettyAnnExpr :: Show (V.Rec Attr fields) => AnnExpr fields -> PrettyJExpr +prettyAnnExpr = cata $ \expr -> + Fix (PrettyJExprF (show (annGetAnn expr), annGetExpr expr)) diff --git a/src/Simpl/JoinIR/Verify.hs b/src/Simpl/JoinIR/Verify.hs new file mode 100644 index 0000000..4b92de7 --- /dev/null +++ b/src/Simpl/JoinIR/Verify.hs @@ -0,0 +1,145 @@ +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE LambdaCase #-} +{-| +Module : Simpl.JoinIR.Verify +Description : Verifies validity of a JoinIR AST +-} +module Simpl.JoinIR.Verify + (verify, VerifyCtx(..), emptyCtx, VerifyError(..)) where + +import Control.Monad.Reader +import Control.Monad.Except +import Data.Functor.Foldable (cata) +import Data.Functor.Identity +import Data.Text (Text) +import Data.Set (Set) +import qualified Data.Set as Set + +import Simpl.JoinIR.Syntax + +-- * Verification Monad + +-- | Information needed to perform verification. +data VerifyCtx = VerifyCtx + { verifyVars :: Set Text + , verifyLabels :: Set Text + } deriving (Eq, Show) + +-- | Default verfication context; has no bound labels or variables. +emptyCtx :: VerifyCtx +emptyCtx = VerifyCtx + { verifyVars = Set.empty + , verifyLabels = Set.empty } + +-- | Adds the given variable to the context. +ctxWithVar :: Text -> VerifyCtx -> VerifyCtx +ctxWithVar name ctx = ctx { verifyVars = Set.insert name (verifyVars ctx) } + +-- | Adds the given label to the context. +ctxWithLabel :: Text -> VerifyCtx -> VerifyCtx +ctxWithLabel lbl ctx = ctx { verifyLabels = Set.insert lbl (verifyLabels ctx) } + +-- | Monad for performing verification. +newtype VerifyT m a = + VerifyT { unVerify :: ReaderT VerifyCtx (ExceptT VerifyError m) a } + deriving ( Functor + , Applicative + , Monad + , MonadReader VerifyCtx + , MonadError VerifyError) + +type Verify = VerifyT Identity + +runVerifyT :: VerifyT m a -> VerifyCtx -> m (Either VerifyError a) +runVerifyT m ctx + = runExceptT + . flip runReaderT ctx + . unVerify + $ m + +runVerify :: Verify a -> VerifyCtx -> Either VerifyError a +runVerify m ctx = runIdentity (runVerifyT m ctx) + +-- * Verification errors + +-- | Errors that cause verification to fail. +data VerifyError = VarRedefinition Text + | LabelRedefinition Text + | NoSuchLabel Text + | NoSuchVar Text + deriving (Show, Eq) + +-- | Verify a JoinIR AST using the given context. +verify :: VerifyCtx -> AnnExpr a -> Either VerifyError () +verify ctx expr = runVerify (doVerifyExpr expr) ctx + +-- | Throw an error if the variable is already bound +checkUnboundVar :: (MonadError VerifyError m, MonadReader VerifyCtx m) + => Text -> m () +checkUnboundVar var = + asks (Set.member var . verifyVars) >>= \case + True -> throwError $ VarRedefinition var + False -> pure () + +-- | Verify a JoinIR value. +doVerifyValue :: (MonadError VerifyError m, MonadReader VerifyCtx m) + => JValue + -> m () +doVerifyValue = \case + JVar name -> asks (Set.member name . verifyVars) >>= \case + True -> pure () + False -> throwError $ NoSuchVar name + JLit _ -> pure () + +-- | Verify a JoinIR expression. +doVerifyExpr :: (MonadError VerifyError m, MonadReader VerifyCtx m) + => AnnExpr a + -> m () +doVerifyExpr = cata (go . annGetExpr) + where + go :: (MonadError VerifyError m, MonadReader VerifyCtx m) => JExprF (m ()) -> m () + go = \case + JVal v -> doVerifyValue v + JLet name val nextM -> do + doVerifyValue val + checkUnboundVar name + local (ctxWithVar name) nextM + JJoin lbl name cfe nextM -> do + asks (Set.member lbl . verifyLabels) >>= \case + True -> throwError $ LabelRedefinition lbl + False -> pure () + local (ctxWithLabel lbl) (doVerifyCfe cfe) + checkUnboundVar name + local (ctxWithVar name) nextM + JApp name _ args nextM -> do + -- Note: we ignore the callable for now + _ <- traverse doVerifyValue args + checkUnboundVar name + local (ctxWithVar name) nextM + +-- | Verify a JoinIR CFE. +doVerifyCfe :: (MonadError VerifyError m, MonadReader VerifyCtx m) + => Cfe (m ()) + -> m () +doVerifyCfe (Cfe exprM cf) = do + exprM + case cf of + JIf trueCfe falseCfe -> do + doVerifyCfe trueCfe + doVerifyCfe falseCfe + JCase branches -> traverse doVerifyBranch branches >> pure () + JJump lbl -> + asks (Set.member lbl . verifyLabels) >>= \case + True -> pure () + False -> throwError $ NoSuchLabel lbl + +-- | Verify a branch in JoinIR. +doVerifyBranch :: (MonadError VerifyError m, MonadReader VerifyCtx m) + => JBranch (m ()) + -> m () +doVerifyBranch (BrAdt name args cfe) = do + checkUnboundVar name + _ <- traverse checkUnboundVar args + doVerifyCfe cfe diff --git a/src/Simpl/SymbolTable.hs b/src/Simpl/SymbolTable.hs index e765f1b..d1061bb 100644 --- a/src/Simpl/SymbolTable.hs +++ b/src/Simpl/SymbolTable.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE TupleSections #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveFoldable #-} @@ -43,6 +44,15 @@ symTabMapExprs :: (([(Text, Type)], Type, e) -> ([(Text, Type)], Type, e')) -- ^ -> SymbolTable e' symTabMapExprs f t = t { symTabFuns = Map.map f (symTabFuns t) } +symTabTraverseExprs + :: Monad m + => (([(Text, Type)], Type, e) -> ([(Text, Type)], Type, m e')) -- ^ Map over functions + -> SymbolTable e + -> m (SymbolTable e') +symTabTraverseExprs f t = do + upd <- traverse ((\(args, ty, me) -> (args, ty,) <$> me) . f) (symTabFuns t) + pure $ t { symTabFuns = upd } + -- | Searches for the given constructor, returning the name of the ADT, the -- constructor, and the index of the constructor. symTabLookupCtor :: Text -> SymbolTable e -> Maybe (Type, Constructor, Int) diff --git a/stack.yaml b/stack.yaml index d65e916..e731951 100644 --- a/stack.yaml +++ b/stack.yaml @@ -43,6 +43,8 @@ packages: extra-deps: - llvm-hs-pure-7.0.0 - llvm-hs-7.0.1 +- vinyl-0.11.0 +- monad-supply-0.7 # Override default flag values for local packages and extra-deps # flags: {} diff --git a/test/JoinVerifySpec.hs b/test/JoinVerifySpec.hs new file mode 100644 index 0000000..4ab05d3 --- /dev/null +++ b/test/JoinVerifySpec.hs @@ -0,0 +1,78 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE OverloadedStrings #-} +module JoinVerifySpec + (joinVerifyTests) where + +import Data.Set (Set) +import qualified Data.Set as Set +import Data.Functor.Foldable (Fix(..)) +import Test.Tasty +import Test.Tasty.HUnit +import Simpl.Ast (Type, TypeF(..), Literal(..), Numeric(..)) +import Simpl.JoinIR.Syntax +import Simpl.JoinIR.Verify + +simpleCtx :: VerifyCtx +simpleCtx = emptyCtx { verifyVars = Set.singleton "x" } + +joinVerifyTests :: TestTree +joinVerifyTests = testGroup "JoinIR verify tests" + [ testCase "Valid if-cfe verifies successfully" $ + assertEqual "" (Right ()) (verify emptyCtx (toAnnExpr goodIfExpr)) + , testCase "Invalid if-cfe fails to verify" $ + assertEqual "" (Left $ NoSuchLabel "badlabel") (verify emptyCtx (toAnnExpr badIfExpr)) + , joinBindingTests + ] + +goodIfExpr :: JExpr +goodIfExpr = Fix $ + JJoin "label" "myvar" ifE (Fix (JVal (JVar "myvar"))) + where + intVal = Fix . JVal . JLit . LitInt + ifE = Cfe (intVal 5) $ + JIf (Cfe (intVal 10) (JJump "label")) + (Cfe (intVal 5) (JJump "label")) + +badIfExpr :: JExpr +badIfExpr = Fix $ + JJoin "label" "myvar" ifE (Fix (JVal (JVar "myvar"))) + where + intVal = Fix . JVal . JLit . LitInt + ifE = Cfe (intVal 5) $ + JIf (Cfe (intVal 10) (JJump "badlabel")) + (Cfe (intVal 5) (JJump "label")) + + +joinBindingTests :: TestTree +joinBindingTests = testGroup "JoinIR variable, label binding tests" + [ testCase "Known variable verifies successfully" $ + assertEqual "" (Right ()) (verify simpleCtx xVal) + , testCase "Unknown variable fails to verify" $ + assertEqual "" (Left $ NoSuchVar "x") (verify emptyCtx xVal) + , testCase "Let binding of unbound variable verifies successfully" $ + assertEqual "" (Right ()) (verify emptyCtx letExpr) + , testCase "Let binding of bound variable fails to verify" $ + assertEqual "" (Left $ VarRedefinition "x") (verify simpleCtx letExpr) + , testCase "App binding of unbound variable verifies successfully" $ + assertEqual "" (Right ()) (verify emptyCtx appPrintExpr) + , testCase "App binding of bound variable fails to verify" $ + assertEqual "" (Left $ VarRedefinition "x") (verify simpleCtx appPrintExpr) + , testCase "Join binding of unbound variable, label verifies successfully" $ + assertEqual "" (Right ()) (verify emptyCtx joinSimpleExpr) + , testCase "Join binding of bound variable, unbound label fails to verify" $ + assertEqual "" (Left $ VarRedefinition "x") (verify simpleCtx joinSimpleExpr) + , testCase "Join binding of unbound variable, bound label fails to verify" $ + assertEqual "" (Left $ LabelRedefinition "lbl") $ + verify (emptyCtx { verifyLabels = Set.singleton "lbl" }) joinSimpleExpr + , testCase "Jump to unknown label fails to verify" $ + assertEqual "" (Left $ NoSuchLabel "badlbl") (verify simpleCtx badJumpExpr) + ] + where + mke = toAnnExpr . Fix + mkef = Fix . toAnnExprF + xVal = mke $ JVal (JVar "x") + yVal = mke $ JVal (JVar "y") + letExpr = mkef $ JLet "x" (JLit (LitInt 5)) xVal + appPrintExpr = mkef $ JApp "x" CPrint [JLit (LitString "foo")] xVal + joinSimpleExpr = mkef $ JJoin "lbl" "x" (Cfe (mke $ JVal (JLit (LitInt 5))) (JJump "lbl")) xVal + badJumpExpr = mkef $ JJoin "lbl" "y" (Cfe xVal (JJump "badlbl")) yVal diff --git a/test/Spec.hs b/test/Spec.hs index cd4753f..66fcf2d 100644 --- a/test/Spec.hs +++ b/test/Spec.hs @@ -1,2 +1,9 @@ +import JoinVerifySpec + +import Test.Tasty + main :: IO () -main = putStrLn "Test suite not yet implemented" +main = defaultMain tests + +tests :: TestTree +tests = testGroup "Tests" [joinVerifyTests]