From 9ad70ec345180f764189153f1d472aea73ca7044 Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Fri, 22 Feb 2019 23:22:43 -0800 Subject: [PATCH 01/24] Create JoinIR AST --- src/Simpl/JoinIR.hs | 93 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) create mode 100644 src/Simpl/JoinIR.hs diff --git a/src/Simpl/JoinIR.hs b/src/Simpl/JoinIR.hs new file mode 100644 index 0000000..fcb7cc2 --- /dev/null +++ b/src/Simpl/JoinIR.hs @@ -0,0 +1,93 @@ +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE DeriveFoldable #-} +{-# LANGUAGE DeriveTraversable #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE OverloadedStrings #-} + +{-| +Module : Simpl.JoinIR +Description : AST for the JoinIR + +Defines the abstract syntax tree for JoinIR, an IR for SimPL based on the IR +presented in /Compiling without Continuations/ by Luke Maurer, Zena Ariola, Paul +Downen, and Simon Peyton Jones (PLDI '17). +-} +module Simpl.JoinIR where + +import Data.Text (Text) +import Simpl.Ast (BinaryOp(..), Numeric(..), Literal(..)) +import Text.Show.Deriving (deriveShow1) +import Data.Functor.Foldable + +type Name = Text + +type Label = Text + +-- | An operation applied to some arguments +data Callable + = CFunc !Name -- ^ Function + | CBinOp !BinaryOp -- ^ Binary operator + | CCast !Numeric -- ^ Numeric cast + | CCtor !Name -- ^ ADT constructor + | CPrint !Name -- ^ Print string (temporary) + deriving (Show) + +-- | A value +data JValue + -- | A variable + = JVar !Name + -- | A literal + | JLit !Literal + deriving (Show, Eq) + +data JBranch a + = BrAdt Name [Name] !a -- ^ Destructure algebraic data type + deriving (Functor, Foldable, Traversable, Show) + + +-- | Represents expressions that must be bound to a join point. +data Joinable a + -- | If expression on the given variable, with a true branch and a false + -- branch + = JIf !JValue !a !a + + -- | Case expression on the given variable + | JCase !JValue ![JBranch a] + deriving (Functor, Foldable, Traversable, Show) + + +-- | The JoinIR expression type. Syntactically, it is in ANF-form with explicit +-- join points. +data JExprF a + -- | A value + = JVal !JValue + + -- | A value binding + | JLet !Name !JValue + + -- | A join point. Consists of a label, the variable representing the joined + -- value, the expression to join, and the next expression. + | JJoin !Label !Name !(Joinable a) !a + + -- | Jump to the given enclosing join point with the given value. + | JJump !Label !JValue + + -- | Apply the callable to the arguments, bind the result to the given name, + -- and continue to the next expression. + | JApp !Name !Callable ![Name] !a + deriving (Functor, Foldable, Traversable, Show) + +$(deriveShow1 ''JBranch) +$(deriveShow1 ''Joinable) +$(deriveShow1 ''JExprF) + +type JExpr = Fix JExprF + +exampleJExpr :: JExpr +exampleJExpr = Fix $ + JJoin "label" "myvar" ifE (Fix (JVal (JVar "myvar"))) + where + ifE = + JIf (JLit (LitInt 5)) + (Fix (JJump "label" (JLit (LitInt 10)))) + (Fix (JJump "label" (JLit (LitInt 5)))) From e2448643fed83c8752c2be865b6b05cd851bc91e Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Sat, 23 Feb 2019 23:24:23 -0800 Subject: [PATCH 02/24] Add an annotated JoinIR AST that uses extensible records --- package.yaml | 2 ++ src/Simpl/JoinIR.hs | 73 ++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 74 insertions(+), 1 deletion(-) diff --git a/package.yaml b/package.yaml index da24a23..3f33e4c 100644 --- a/package.yaml +++ b/package.yaml @@ -45,6 +45,8 @@ dependencies: - bytestring - optparse-applicative - safe-exceptions +- vinyl +- singletons ghc-options: - -Wall diff --git a/src/Simpl/JoinIR.hs b/src/Simpl/JoinIR.hs index fcb7cc2..b9d6470 100644 --- a/src/Simpl/JoinIR.hs +++ b/src/Simpl/JoinIR.hs @@ -3,6 +3,13 @@ {-# LANGUAGE DeriveTraversable #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE OverloadedStrings #-} +-- Vinyl stuff +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE TypeOperators #-} {-| Module : Simpl.JoinIR @@ -15,9 +22,11 @@ Downen, and Simon Peyton Jones (PLDI '17). module Simpl.JoinIR where import Data.Text (Text) -import Simpl.Ast (BinaryOp(..), Numeric(..), Literal(..)) +import Simpl.Ast (BinaryOp(..), Numeric(..), Literal(..), Type, TypeF(..)) import Text.Show.Deriving (deriveShow1) import Data.Functor.Foldable +import qualified Data.Vinyl as V +import Data.Singletons.TH (genSingletons) type Name = Text @@ -83,6 +92,53 @@ $(deriveShow1 ''JExprF) type JExpr = Fix JExprF +-- * Annotated [JExpr]s +-- +-- Because it's possible to have many different annotations on a single AST, we +-- define a "single" annotated AST that is annotated with an extensible record +-- type at each node. Thus, we can add annotations by extending the record with +-- more fields. + +-- | Possible annotations on a [JExpr] +data JFields = ExprType deriving (Show) + +genSingletons [ ''JFields ] + +-- | Maps each possible annotation label to a type +type family ElF (f :: JFields) :: * where + ElF 'ExprType = Type + +-- | Wrapper for annotation fields +newtype Attr f = Attr { _unAttr :: ElF f } + +-- | Helper function for create annotation fields +(=::) :: sing f -> ElF f -> Attr f +_ =:: x = Attr x + +-- | Annotations for a typed [JExpr] +type Typed = '[ 'ExprType ] + +-- | Creates a type field whose value is the given type +withType :: Type -> Attr 'ExprType +withType ty = SExprType =:: ty + +-- | A [JExprF] annotated with some data. +data AnnExprF fields a = AnnExprF { annGetAnn :: V.Rec Attr fields, annGetExpr :: JExprF a } + deriving (Show, Functor, Foldable, Traversable) + +type AnnExpr fields = Fix (AnnExprF fields) + +-- | Converts a [JExprF] to an "unannotated" [AnnExprF] +toAnnExprF :: JExprF a -> AnnExprF '[] a +toAnnExprF expr = AnnExprF { annGetAnn = V.RNil, annGetExpr = expr } + +-- | Adds the given annotation to the expression +addField :: Attr f -> AnnExprF flds a -> AnnExprF (f ': flds) a +addField attr expr = expr { annGetAnn = attr V.:& annGetAnn expr } + +exampleAnnotation :: V.Rec Attr '[ 'ExprType ] +exampleAnnotation = (SExprType =:: Fix (TyNumber NumInt)) V.:& V.RNil + exampleJExpr :: JExpr exampleJExpr = Fix $ JJoin "label" "myvar" ifE (Fix (JVal (JVar "myvar"))) @@ -91,3 +147,18 @@ exampleJExpr = Fix $ JIf (JLit (LitInt 5)) (Fix (JJump "label" (JLit (LitInt 10)))) (Fix (JJump "label" (JLit (LitInt 5)))) + +exampleTypedJExpr :: AnnExpr '[ 'ExprType ] +exampleTypedJExpr = Fix $ AnnExprF + { annGetAnn = withType tyInt V.:& V.RNil + , annGetExpr = JJoin "label" "myvar" ifE varE } + where + tyInt = Fix (TyNumber NumInt) + varE = Fix $ AnnExprF + { annGetAnn = withType tyInt V.:& V.RNil + , annGetExpr = JVal (JVar "myvar") } + ifE = + JIf (JLit (LitInt 5)) (jmpE 10) (jmpE 5) + jmpE v = Fix $ AnnExprF + { annGetAnn = withType tyInt V.:& V.RNil + , annGetExpr = JJump "label" (JLit (LitInt v)) } From daa628bec3e1663f435ec94c32cbc35ef36dc867 Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Sun, 24 Feb 2019 00:09:26 -0800 Subject: [PATCH 03/24] Implement pretty printing for JoinIR --- src/Simpl/JoinIR.hs | 51 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/src/Simpl/JoinIR.hs b/src/Simpl/JoinIR.hs index b9d6470..eed6dde 100644 --- a/src/Simpl/JoinIR.hs +++ b/src/Simpl/JoinIR.hs @@ -1,8 +1,11 @@ +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE TypeSynonymInstances #-} {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveFoldable #-} {-# LANGUAGE DeriveTraversable #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE LambdaCase #-} -- Vinyl stuff {-# LANGUAGE DataKinds #-} {-# LANGUAGE PolyKinds #-} @@ -22,6 +25,8 @@ Downen, and Simon Peyton Jones (PLDI '17). module Simpl.JoinIR where import Data.Text (Text) +import Data.Text.Prettyprint.Doc (Pretty, pretty, (<>), (<+>)) +import qualified Data.Text.Prettyprint.Doc as PP import Simpl.Ast (BinaryOp(..), Numeric(..), Literal(..), Type, TypeF(..)) import Text.Show.Deriving (deriveShow1) import Data.Functor.Foldable @@ -38,7 +43,7 @@ data Callable | CBinOp !BinaryOp -- ^ Binary operator | CCast !Numeric -- ^ Numeric cast | CCtor !Name -- ^ ADT constructor - | CPrint !Name -- ^ Print string (temporary) + | CPrint -- ^ Print string (temporary) deriving (Show) -- | A value @@ -92,6 +97,50 @@ $(deriveShow1 ''JExprF) type JExpr = Fix JExprF +instance Pretty Callable where + pretty = \case + CFunc name -> pretty name + CBinOp op -> pretty op + CCast num -> "cast[" <> pretty num <> "]" + CCtor name -> pretty name + CPrint -> "print" + +instance Pretty JValue where + pretty = \case + JVar n -> pretty n + JLit l -> pretty l + +instance Pretty a => Pretty (JBranch a) where + pretty (BrAdt ctorName varNames expr) = + PP.hang 2 $ PP.hsep (pretty <$> brPart) <> PP.softline <> pretty expr + where brPart = [ctorName] ++ varNames ++ ["=>"] + +instance Pretty a => Pretty (Joinable a) where + pretty = \case + JIf guard trueBr falseBr -> + PP.hang 2 $ "if" <+> pretty guard <+> PP.group ( + "then" <> PP.softline <> pretty trueBr <> PP.softline + <> "else" <> PP.softline <> pretty falseBr) + JCase expr brs -> + PP.hang 2 $ "case" <+> pretty expr <+> "of" <> PP.hardline <> + (PP.vsep $ pretty <$> brs) + +instance Pretty JExpr where + pretty = f . unfix + where + f :: JExprF JExpr -> PP.Doc ann + f = \case + JVal v -> pretty v + JLet n v -> PP.hsep ["let", pretty n, "=", pretty v] + JJoin lbl n joinbl next -> PP.align $ + PP.hang 2 (PP.hsep ["join", pretty lbl, "bind", pretty n, "="] + <> PP.softline <> pretty joinbl) + <> PP.hardline <> "in" <+> pretty next + JJump lbl v -> PP.hsep ["jump", pretty lbl, "with", pretty v] + JApp name clbl args next -> + PP.hsep (["let app", pretty name, "=", pretty clbl] ++ (pretty <$> args)) + <> PP.hardline <> pretty next + -- * Annotated [JExpr]s -- -- Because it's possible to have many different annotations on a single AST, we From 0ddb2bb2f4d0af1866df830d30ab0ed76ab48c1c Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Mon, 25 Feb 2019 16:00:28 -0800 Subject: [PATCH 04/24] Use a newer version of vinyl --- stack.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/stack.yaml b/stack.yaml index d65e916..e758009 100644 --- a/stack.yaml +++ b/stack.yaml @@ -43,6 +43,7 @@ packages: extra-deps: - llvm-hs-pure-7.0.0 - llvm-hs-7.0.1 +- vinyl-0.11.0 # Override default flag values for local packages and extra-deps # flags: {} From fd6b30d4f8913959aabb434f6b80a227bce8adb3 Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Mon, 25 Feb 2019 16:00:37 -0800 Subject: [PATCH 05/24] Add a function to extract types from annotated JoinIR ASTs --- src/Simpl/JoinIR.hs | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/Simpl/JoinIR.hs b/src/Simpl/JoinIR.hs index eed6dde..a996bc1 100644 --- a/src/Simpl/JoinIR.hs +++ b/src/Simpl/JoinIR.hs @@ -1,3 +1,5 @@ +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE TypeSynonymInstances #-} {-# LANGUAGE DeriveFunctor #-} @@ -13,6 +15,7 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeApplications #-} {-| Module : Simpl.JoinIR @@ -31,6 +34,7 @@ import Simpl.Ast (BinaryOp(..), Numeric(..), Literal(..), Type, TypeF(..)) import Text.Show.Deriving (deriveShow1) import Data.Functor.Foldable import qualified Data.Vinyl as V +import qualified Data.Vinyl.TypeLevel as V import Data.Singletons.TH (genSingletons) type Name = Text @@ -160,6 +164,8 @@ type family ElF (f :: JFields) :: * where -- | Wrapper for annotation fields newtype Attr f = Attr { _unAttr :: ElF f } +deriving instance Show (Attr 'ExprType) + -- | Helper function for create annotation fields (=::) :: sing f -> ElF f -> Attr f _ =:: x = Attr x @@ -173,7 +179,7 @@ withType ty = SExprType =:: ty -- | A [JExprF] annotated with some data. data AnnExprF fields a = AnnExprF { annGetAnn :: V.Rec Attr fields, annGetExpr :: JExprF a } - deriving (Show, Functor, Foldable, Traversable) + deriving (Functor, Foldable, Traversable) type AnnExpr fields = Fix (AnnExprF fields) @@ -185,6 +191,10 @@ toAnnExprF expr = AnnExprF { annGetAnn = V.RNil, annGetExpr = expr } addField :: Attr f -> AnnExprF flds a -> AnnExprF (f ': flds) a addField attr expr = expr { annGetAnn = attr V.:& annGetAnn expr } +-- | Retrieves the type information stored in a typed [AnnExprF] +getType :: V.RElem 'ExprType fields (V.RIndex 'ExprType fields) => AnnExprF fields a -> Type +getType = _unAttr . V.rget @'ExprType . annGetAnn + exampleAnnotation :: V.Rec Attr '[ 'ExprType ] exampleAnnotation = (SExprType =:: Fix (TyNumber NumInt)) V.:& V.RNil From 67e9e5d903add1cc592814bb05fc95a0b35c40b7 Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Mon, 25 Feb 2019 21:32:59 -0800 Subject: [PATCH 06/24] Add some helper functions for dealing with annotated JoinIR AST --- src/Simpl/JoinIR.hs | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/src/Simpl/JoinIR.hs b/src/Simpl/JoinIR.hs index a996bc1..cc90742 100644 --- a/src/Simpl/JoinIR.hs +++ b/src/Simpl/JoinIR.hs @@ -11,7 +11,6 @@ -- Vinyl stuff {-# LANGUAGE DataKinds #-} {-# LANGUAGE PolyKinds #-} -{-# LANGUAGE PolyKinds #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE TypeOperators #-} @@ -187,6 +186,10 @@ type AnnExpr fields = Fix (AnnExprF fields) toAnnExprF :: JExprF a -> AnnExprF '[] a toAnnExprF expr = AnnExprF { annGetAnn = V.RNil, annGetExpr = expr } +-- | Removes all annotations from an [AnnExpr] +unannotate :: AnnExpr fields -> JExpr +unannotate = cata (Fix . annGetExpr) + -- | Adds the given annotation to the expression addField :: Attr f -> AnnExprF flds a -> AnnExprF (f ': flds) a addField attr expr = expr { annGetAnn = attr V.:& annGetAnn expr } @@ -195,6 +198,18 @@ addField attr expr = expr { annGetAnn = attr V.:& annGetAnn expr } getType :: V.RElem 'ExprType fields (V.RIndex 'ExprType fields) => AnnExprF fields a -> Type getType = _unAttr . V.rget @'ExprType . annGetAnn +-- * Misc + +-- | Helper for inspecting an [AnnExpr] +newtype PrettyJExprF a = PrettyJExprF (String, JExprF a) deriving (Show) +type PrettyJExpr = Fix PrettyJExprF +$(deriveShow1 ''PrettyJExprF) + +-- | Converts an [AnnExpr] into a [PrettyJExpr] so that it can be shown. +prettyAnnExpr :: Show (V.Rec Attr fields) => AnnExpr fields -> PrettyJExpr +prettyAnnExpr = cata $ \expr -> + Fix (PrettyJExprF (show (annGetAnn expr), annGetExpr expr)) + exampleAnnotation :: V.Rec Attr '[ 'ExprType ] exampleAnnotation = (SExprType =:: Fix (TyNumber NumInt)) V.:& V.RNil From aa24be09968aef0f1f4b937cab24573c5ca23e55 Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Wed, 27 Feb 2019 20:06:48 -0800 Subject: [PATCH 07/24] Fix let expressions in JoinIR to include next expression --- src/Simpl/JoinIR.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Simpl/JoinIR.hs b/src/Simpl/JoinIR.hs index cc90742..567d7b5 100644 --- a/src/Simpl/JoinIR.hs +++ b/src/Simpl/JoinIR.hs @@ -80,7 +80,7 @@ data JExprF a = JVal !JValue -- | A value binding - | JLet !Name !JValue + | JLet !Name !JValue !a -- | A join point. Consists of a label, the variable representing the joined -- value, the expression to join, and the next expression. @@ -134,7 +134,7 @@ instance Pretty JExpr where f :: JExprF JExpr -> PP.Doc ann f = \case JVal v -> pretty v - JLet n v -> PP.hsep ["let", pretty n, "=", pretty v] + JLet n v next -> PP.hsep ["let", pretty n, "=", pretty v, "in"] <> PP.softline <> pretty next JJoin lbl n joinbl next -> PP.align $ PP.hang 2 (PP.hsep ["join", pretty lbl, "bind", pretty n, "="] <> PP.softline <> pretty joinbl) From f9c12805efa83ec4c6a0d313d509291426829347 Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Wed, 27 Feb 2019 20:40:45 -0800 Subject: [PATCH 08/24] [skip-ci] Non-working attempt at JoinIR codegen Codegen for JoinIR currently can't be implemented because annotations are not attached for variable bindings, only for next expressions. The IR needs to be modified so it is a linear list of statements rather than expressions in linked-list firm. --- src/Simpl/Ast.hs | 5 + src/Simpl/Backend/CodegenJoin.hs | 459 +++++++++++++++++++++++++++++++ 2 files changed, 464 insertions(+) create mode 100644 src/Simpl/Backend/CodegenJoin.hs diff --git a/src/Simpl/Ast.hs b/src/Simpl/Ast.hs index 4a32651..9aa5f1c 100644 --- a/src/Simpl/Ast.hs +++ b/src/Simpl/Ast.hs @@ -243,6 +243,11 @@ isComplexType = \case TyFun _ _ -> True _ -> False +functionTypeResult :: Type -> Type +functionTypeResult (Fix ty) = case ty of + TyFun _ res -> functionTypeResult res + _ -> Fix ty + instance Pretty Type where pretty = para go where diff --git a/src/Simpl/Backend/CodegenJoin.hs b/src/Simpl/Backend/CodegenJoin.hs new file mode 100644 index 0000000..e9c3a8c --- /dev/null +++ b/src/Simpl/Backend/CodegenJoin.hs @@ -0,0 +1,459 @@ +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE FlexibleContexts #-} +{-# OPTIONS_GHC -Wno-incomplete-record-updates #-} -- Suppress LLVM sum type of records AST warnings +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE DeriveFoldable #-} +{-# LANGUAGE DeriveTraversable #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE RecursiveDo #-} +module Simpl.Backend.CodegenJoin where + +import Control.Applicative ((<|>)) +import Control.Monad (forM, forM_, liftM2) +import Control.Monad.Reader +import Control.Monad.State +import Data.Functor.Foldable (para, unfix, Fix(..)) +import Data.Text.Prettyprint.Doc (pretty) +import Data.Char (ord) +import Data.Maybe (fromJust) +import Data.Text (Text) +import Data.Text.Encoding (encodeUtf8) +import Data.ByteString () +import qualified Data.ByteString as BS +import Data.Map (Map) +import Data.Functor.Identity +import Data.List.NonEmpty (NonEmpty(..)) + +import Data.String (fromString) +import qualified Data.List.NonEmpty as NE +import qualified Data.Text as Text +import qualified Data.Map as Map +import qualified LLVM.AST as LLVM +import qualified LLVM.AST.Linkage as LLVM +import qualified LLVM.AST.FloatingPointPredicate as LLVMFP +import qualified LLVM.AST.IntegerPredicate as LLVMIP +import qualified LLVM.AST.Constant as LLVMC +import qualified LLVM.AST.Global as LLVMG +import qualified LLVM.AST.Type as LLVM +import qualified LLVM.IRBuilder.Module as LLVMIR +import qualified LLVM.IRBuilder.Monad as LLVMIR +import qualified LLVM.IRBuilder.Instruction as LLVMIR +import qualified LLVM.IRBuilder.Constant as LLVMIR + +import Simpl.Ast (BinaryOp(..), Type, TypeF(..), Constructor(..), Literal(..), Numeric(..), Decl(..)) +import Simpl.CompilerOptions +import Simpl.SymbolTable +import Simpl.Typing (literalType) +import Simpl.Backend.Runtime () +import Simpl.JoinIR +import qualified Simpl.Backend.Runtime as RT + +data CodegenTable = + MkCodegenTable { tableVars :: Map Text (TypeF Type, LLVM.Operand) -- ^ Pointer to variables + , tableCtors :: Map Text (LLVM.Name, LLVM.Name, Int) -- ^ Data type name, ctor name, index + , tableAdts :: Map Text (LLVM.Name, Type, [Constructor]) + , tableFuns :: Map Text LLVM.Operand + , tableJoinValues :: Map Text (LLVM.Name, [(LLVM.Operand, LLVM.Name)]) + , tablePrintf :: LLVM.Operand + , tableOptions :: CompilerOpts } + deriving (Show) + +-- | An empty codegen table. This will cause a crash if codegen is run when not +-- initialized! +emptyCodegenTable :: CodegenTable +emptyCodegenTable = + MkCodegenTable { tableVars = Map.empty + , tableCtors = Map.empty + , tableAdts = Map.empty + , tableFuns = Map.empty + , tableJoinValues = Map.empty + , tablePrintf = error "printf not set" + , tableOptions = defaultCompilerOpts } + +newtype CodegenT m a = + CodegenT { unCodegen :: StateT CodegenTable m a } + deriving ( Functor + , Applicative + , MonadState CodegenTable + , MonadFix) + +type Codegen = CodegenT Identity + +deriving instance Monad m => Monad (CodegenT m) + +instance MonadTrans CodegenT where + lift = CodegenT . lift + +localCodegenTable :: MonadState CodegenTable m + => (CodegenTable -> CodegenTable) + -> m a + -> m a +localCodegenTable f ma = do + oldTable <- get + res <- ma + modify $ \t -> t + { tableVars = tableVars oldTable + , tableFuns = tableFuns oldTable + , tableJoinValues = tableJoinValues oldTable + } + pure res + +lookupName :: MonadState CodegenTable m + => Text + -> m (Maybe LLVM.Operand) +lookupName name = + gets $ \t -> + (snd <$> Map.lookup name (tableVars t)) <|> Map.lookup name (tableFuns t) + +-- | Looks up the type of the [JValue]. Note: assumes that if the [JValue] is a +-- [JVar], then the variable is actually in the table. +lookupValueType :: MonadState CodegenTable m + => JValue + -> m (TypeF Type) +lookupValueType = \case + JVar name -> gets (fst . fromJust . Map.lookup name . tableVars) + JLit lit -> pure $ literalType lit + +bindVariable :: MonadState CodegenTable m + => Text + -> TypeF Type + -> LLVM.Operand + -> m () +bindVariable name ty oper = do + modify (\t -> t { tableVars = Map.insert name (ty, oper) (tableVars t) }) + +initCodegenTable :: CompilerOpts -> SymbolTable (AnnExpr '[ 'ExprType ]) -> Codegen () +initCodegenTable options symTab = do + let adts = flip Map.mapWithKey (symTabAdts symTab) $ \name (ty, ctors) -> (llvmName name, ty, ctors) + ctors <- forM (Map.elems adts) $ \(adtName, _, ctors) -> + forM ([0..] `zip` ctors) $ \(i, Ctor ctorName _) -> + pure (ctorName, (adtName, llvmName ctorName, i)) + modify $ \t -> t { tableAdts = adts + , tableCtors = Map.fromList (join ctors) + , tableOptions = options } + +llvmByte :: Integer -> LLVMC.Constant +llvmByte = LLVMC.Int 8 + +llvmString :: String -> (Int, LLVMC.Constant) +llvmString s = + let s' = s ++ "\0" + arr = LLVMC.Array LLVM.i8 (fmap (llvmByte . toInteger . ord) s') + in (length s', arr) + +staticString :: LLVMIR.MonadModuleBuilder m => LLVM.Name -> String -> m (Int, LLVM.Operand) +staticString name str = do + let (messageLen, bytes) = llvmString str + messageTy = LLVM.ArrayType { LLVM.nArrayElements = toEnum messageLen + , LLVM.elementType = LLVM.i8 } + LLVMIR.emitDefn (LLVM.GlobalDefinition $ LLVM.globalVariableDefaults + { LLVMG.name = name + , LLVMG.isConstant = True + , LLVMG.unnamedAddr = Just LLVMG.GlobalAddr + , LLVMG.type' = messageTy + , LLVMG.linkage = LLVM.Private + , LLVMG.initializer = Just bytes + , LLVMG.alignment = 1 + }) + let msgPtrTy = LLVM.ptr messageTy + msgPtr = LLVM.ConstantOperand $ + LLVMC.GetElementPtr + { LLVMC.inBounds = True + , LLVMC.address = LLVMC.GlobalReference msgPtrTy name + , LLVMC.indices = [LLVMC.Int 32 0, LLVMC.Int 32 0] } + pure (messageLen, msgPtr) + +llvmName :: Text -> LLVM.Name +llvmName = LLVM.mkName . Text.unpack + +literalCodegen :: LLVMIR.MonadIRBuilder m => Literal -> m LLVM.Operand +literalCodegen = \case + LitInt x -> LLVMIR.int64 (fromIntegral x) + LitDouble x -> LLVMIR.double x + LitBool b -> LLVMIR.bit (if b then 1 else 0) + LitString t -> do + -- TODO: Store literal strings in global memory + let byteS = encodeUtf8 t + let len = toInteger (BS.length byteS) + lenOper <- LLVMIR.int64 len + let byteData = LLVMC.Int 8 . toInteger <$> BS.unpack byteS + byteDataOper <- LLVMIR.array byteData + byteDataPtr <- LLVMIR.alloca (LLVM.ArrayType (fromInteger len) LLVM.i8) Nothing 0 + _ <- LLVMIR.store byteDataPtr 0 byteDataOper + byteDataPtr' <- LLVMIR.bitcast byteDataPtr (LLVM.ptr LLVM.i8) + bytePtr <- LLVMIR.call RT.mallocRef [(lenOper, [])] + _ <- LLVMIR.call RT.memcpyRef [(bytePtr, []), (byteDataPtr', []), (lenOper, [])] + LLVMIR.call RT.stringNewRef [(lenOper, []), (bytePtr, [])] + +-- | Generates code for an arbitrary binary operation +binaryOpCodegen :: (LLVMIR.MonadIRBuilder m, MonadState CodegenTable m) + => m LLVM.Operand + -> m LLVM.Operand + -> (LLVM.Operand -> LLVM.Operand -> m LLVM.Operand) + -> m LLVM.Operand +binaryOpCodegen x y op = do + x' <- x + y' <- y + op x' y' + +-- | Generates code for a numeric binary operation +numBinopCodegen :: (LLVMIR.MonadIRBuilder m, MonadState CodegenTable m) + => m LLVM.Operand + -> m LLVM.Operand + -> Type + -> (LLVM.Operand -> LLVM.Operand -> m LLVM.Operand) -- ^ Float operation + -> (LLVM.Operand -> LLVM.Operand -> m LLVM.Operand) -- ^ Integer operation + -> m LLVM.Operand +numBinopCodegen x y ty opDouble opInt = + case unfix ty of + TyNumber numTy -> + if numTy == NumInt then binaryOpCodegen x y opInt + else binaryOpCodegen x y opDouble + _ -> error "Invariant violated" + +-- | Generates code for BinOp +binOpCodegen :: (LLVMIR.MonadIRBuilder m, MonadState CodegenTable m) + => BinaryOp -- ^ Operation + -> Type -- ^ Type of the inputs + -> m LLVM.Operand + -> m LLVM.Operand + -> m LLVM.Operand +binOpCodegen op ty x y = + let (floatInstr, intInstr) = case op of + Add -> (LLVMIR.fadd, LLVMIR.add) + Sub -> (LLVMIR.fsub, LLVMIR.sub) + Mul -> (LLVMIR.fmul, LLVMIR.mul) + Div -> (LLVMIR.fdiv, LLVMIR.sdiv) + Lt -> (LLVMIR.fcmp LLVMFP.OLT, LLVMIR.icmp LLVMIP.SLT) + Lte -> (LLVMIR.fcmp LLVMFP.OLE, LLVMIR.icmp LLVMIP.SLE) + Equal -> (LLVMIR.fcmp LLVMFP.OEQ, LLVMIR.icmp LLVMIP.EQ) + in numBinopCodegen x y ty floatInstr intInstr + +jvalueCodegen + :: (LLVMIR.MonadIRBuilder m, MonadState CodegenTable m) + => JValue + -> m LLVM.Operand +jvalueCodegen = \case + JVar name -> gets (snd . fromJust . Map.lookup name . tableVars) + JLit l -> literalCodegen l + +-- | Looks up the type of the result of the [Callable] +lookupCallableType :: MonadState CodegenTable m + => Callable + -> m (TypeF Type) +lookupCallableType = _ + + +joinableCodegen + :: Joinable (AnnExpr '[ 'ExprType]) + -> LLVMIR.IRBuilderT (LLVMIR.ModuleBuilderT Codegen) LLVM.Operand +joinableCodegen = \case + JIf guard trueBr falseBr -> do + LLVMIR.ensureBlock + cond <- jvalueCodegen guard + trueLabel <- LLVMIR.freshName "if_then" + falseLabel <- LLVMIR.freshName "if_else" + LLVMIR.condBr cond trueLabel falseLabel + LLVMIR.emitBlockStart trueLabel + _ <- jexprCodegen trueBr + LLVMIR.emitBlockStart falseLabel + jexprCodegen falseBr + JCase val branches -> _ + +callableCodegen + :: Callable + -> [JValue] + -> LLVMIR.IRBuilderT (LLVMIR.ModuleBuilderT Codegen) LLVM.Operand +callableCodegen callable args = case callable of + CFunc name -> do + fn <- fromJust <$> lookupName name + ops <- traverse jvalueCodegen args + LLVMIR.call fn [(x, []) | x <- ops] + CBinOp op -> case args of + [x, y] -> do + binOpCodegen op _ (jvalueCodegen x) (jvalueCodegen y) + _ -> error $ "callableCodegen: expected 2 args to CBinOp, got " ++ show (length args) + CCast targetNum -> case args of + [x] -> do + ty <- lookupValueType x + case ty of + TyNumber srcNum -> jvalueCodegen x >>= castOpCodegen srcNum targetNum + _ -> error $ "callableCodegen: expected CCast argument to be of numeric type" + _ -> error $ "callableCodegen: expected 1 args to CCast, got " ++ show (length args) + CCtor name -> _ + CPrint -> case args of + [val] -> do + let fmtStr = encodeUtf8 (fromString "Print: %s\n\0") + let fmtStrLen = toInteger (BS.length fmtStr) + let fmtStrData = LLVMC.Int 8 . toInteger <$> BS.unpack fmtStr + fmtStrOper <- LLVMIR.array fmtStrData + fmtStrPtr <- LLVMIR.alloca (LLVM.ArrayType (fromInteger fmtStrLen) LLVM.i8) Nothing 0 + _ <- LLVMIR.store fmtStrPtr 0 fmtStrOper + printf <- gets tablePrintf + str <- jvalueCodegen val + exprCstring <- LLVMIR.call RT.stringCstringRef [(str, [])] + fmtStrPtr' <- LLVMIR.bitcast fmtStrPtr (LLVM.ptr LLVM.i8) + _ <- LLVMIR.call printf [(fmtStrPtr', []), (exprCstring, [])] + LLVMIR.int64 0 + _ -> error $ "callableCodegen: expected 1 args to CPrint, got " ++ show (length args) + +jexprCodegen + :: AnnExpr '[ 'ExprType] + -> LLVMIR.IRBuilderT (LLVMIR.ModuleBuilderT Codegen) LLVM.Operand +jexprCodegen = (\e -> go (unfix (getType e)) (annGetExpr e)) . unfix + where + go :: TypeF Type + -> JExprF (AnnExpr '[ 'ExprType ]) + -> LLVMIR.IRBuilderT (LLVMIR.ModuleBuilderT Codegen) LLVM.Operand + go exprTy = \case + JVal v -> jvalueCodegen v + JLet name val next -> do + oper <- jvalueCodegen val + ty <- lookupValueType val + _ <- bindVariable name ty oper + jexprCodegen next + JJoin lbl varName joinable next -> do + llvmLabel <- LLVMIR.freshName (fromString (Text.unpack lbl)) + let addJoinEntry = \t -> + t { tableJoinValues = Map.insert lbl (llvmLabel, []) (tableJoinValues t) } + _ <- localCodegenTable addJoinEntry (joinableCodegen joinable) + (_, joinValues) <- gets (fromJust . Map.lookup lbl . tableJoinValues) + LLVMIR.emitBlockStart llvmLabel + op <- LLVMIR.phi joinValues + modify (\t -> t { tableVars = Map.insert varName (exprTy, op) (tableVars t) }) + jexprCodegen next + JJump lbl val -> do + v <- jvalueCodegen val + jvals <- gets tableJoinValues + block <- LLVMIR.currentBlock + let f = (\(n, jvs) -> Just $ (n, (v, block) : jvs)) + let updJvals = Map.update f lbl jvals + modify (\t -> t { tableJoinValues = updJvals }) + pure v + JApp varName callable args next -> do + oper <- callableCodegen callable args + ty <- lookupCallableType callable + bindVariable varName ty oper + jexprCodegen next + +-- | Generates code for numeric casting +castOpCodegen :: LLVMIR.MonadIRBuilder m => Numeric -> Numeric -> LLVM.Operand -> m LLVM.Operand +castOpCodegen source castTo oper = case (source, castTo) of + (NumDouble, NumDouble) -> pure oper + (NumInt, NumInt) -> pure oper + (NumDouble, NumInt) -> LLVMIR.fptosi oper LLVM.i64 + (NumInt, NumDouble) -> LLVMIR.sitofp oper LLVM.double + (NumUnknown, _) -> castOpCodegen NumDouble castTo oper + (_, NumUnknown) -> error "castOpCodegen: attempting to cast to NumUnknown" + +ctorToLLVM :: Constructor -> [LLVM.Type] +ctorToLLVM (Ctor _ args) = typeToLLVM <$> args + +typeToLLVM :: Type -> LLVM.Type +typeToLLVM = go . unfix + where + go = \case + TyNumber n -> case n of + NumDouble -> LLVM.double + NumInt -> LLVM.i64 + NumUnknown -> LLVM.double + TyBool -> LLVM.i1 + TyString -> LLVM.ptr RT.stringType + TyAdt name -> LLVM.NamedTypeReference (llvmName name) + TyFun args res -> + LLVM.ptr $ LLVM.FunctionType + { LLVM.resultType = typeToLLVM res + , LLVM.argumentTypes = typeToLLVM <$> args + , LLVM.isVarArg = False + } + +adtToLLVM :: Text + -> [Constructor] + -> LLVMIR.ModuleBuilderT Codegen () +adtToLLVM adtName ctors = do + let adtType = LLVM.StructureType + { LLVM.isPacked = True + , LLVM.elementTypes = [LLVM.i32, LLVM.ptr LLVM.i8] } + adtLLVMName <- gets ((\(n,_,_) -> n) . fromJust . Map.lookup adtName . tableAdts) + -- TODO: Store returned type in symbol table to avoid error-prone type + -- reconstruction + _ <- LLVMIR.typedef adtLLVMName (Just adtType) + forM_ ctors $ \(Ctor ctorName args) -> do + let ctorType = LLVM.StructureType + { LLVM.isPacked = True + , LLVM.elementTypes = typeToLLVM <$> args } + ctorLLVMName <- gets ((\(_,n,_) -> n) . fromJust . Map.lookup ctorName . tableCtors) + LLVMIR.typedef ctorLLVMName (Just ctorType) + +-- | Emits the given function definition +funToLLVM :: Text + -> [(Text, Type)] + -> Type + -> AnnExpr '[ 'ExprType ] + -> LLVMIR.ModuleBuilderT Codegen LLVM.Operand +funToLLVM name params ty body = + let name' = if name == "main" then "__simpl_main" else name + ftype = typeToLLVM ty + fname = llvmName name' + fparams = [(typeToLLVM t, fromString (Text.unpack n)) | (n, t) <- params] + in mdo foper <- LLVMIR.function fname fparams ftype $ \args -> do + LLVMIR.ensureBlock + endLabel <- LLVMIR.freshName "function_end" + -- We need to make sure we don't pollute other function scopes + oldVars <- gets tableVars + let updVars t = tableVars t `Map.union` Map.fromList ((fst <$> params) `zip` args) + modify (\t -> t { tableFuns = Map.insert name foper (tableFuns t) + , tableVars = updVars t }) + retval <- jexprCodegen body + -- Restore old scope + modify (\t -> t { tableVars = oldVars }) + LLVMIR.ret retval + pure foper + +-- | Generate code for the entire module +moduleCodegen :: [Decl JExpr] + -> SymbolTable (AnnExpr '[ 'ExprType ]) + -> LLVMIR.ModuleBuilderT Codegen () +moduleCodegen decls symTab = mdo + -- Message is "Hi\n" (with null terminator) + (_, msg) <- staticString ".message" "Hello world!\n" + (_, resultFmt) <- staticString ".resultformat" "Result: %i\n" + let srcCode = unlines $ show . pretty <$> decls + (_, exprSrc) <- staticString ".sourcecode" $ "Source code: " ++ srcCode ++ "\n" + RT.emitRuntimeDecls + modify (\t -> t { tablePrintf = RT.printfRef }) + forM_ (Map.toList . symTabAdts $ symTab) $ \(name, (_, ctors)) -> + adtToLLVM name ctors + -- Insert function operands into symbol table before emitting so order of + -- definition doesn't matter. This works because the codegen monad is lazy. + modify (\t -> t { tableFuns = tableFuns t `Map.union` Map.fromList funOpers }) + funOpers <- forM (Map.toList . symTabFuns $ symTab) $ \(name, (params, ty, body)) -> + (name, ) <$> funToLLVM name params ty body + + _ <- LLVMIR.function "main" [] LLVM.i64 $ \_ -> do + diagnosticsEnabled <- gets (enableDiagnostics . tableOptions) + when diagnosticsEnabled $ LLVMIR.call RT.printfRef [(msg, [])] >> pure () + let mainTy = LLVM.ptr (LLVM.FunctionType LLVM.i64 [] False) + let mainName = LLVM.mkName "__simpl_main" + let mainRef = LLVM.ConstantOperand (LLVMC.GlobalReference mainTy mainName) + when diagnosticsEnabled $ LLVMIR.call RT.printfRef [(exprSrc, [])] >> pure () + exprResult <- LLVMIR.call mainRef [] + when diagnosticsEnabled $ LLVMIR.call RT.printfRef [(resultFmt, []), (exprResult, [])] >> pure () + retcode <- LLVMIR.int64 1 + LLVMIR.ret retcode + pure () + +runCodegen :: CompilerOpts -> [Decl JExpr] -> SymbolTable (AnnExpr '[ 'ExprType]) -> LLVM.Module +runCodegen opts decls symTab + = runIdentity + . flip evalStateT emptyCodegenTable + . unCodegen + . LLVMIR.buildModuleT "simpl.ll" + $ lift (initCodegenTable opts symTab) >> moduleCodegen decls symTab From d43b0f409db1c8cece20ccc13b0240b2be2fc39c Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Tue, 5 Mar 2019 10:48:30 -0800 Subject: [PATCH 09/24] [skip ci] Finish JoinIR codegen Not tested yet --- src/Simpl/Backend/CodegenJoin.hs | 115 +++++++++++++++++++++++-------- src/Simpl/JoinIR.hs | 10 ++- 2 files changed, 95 insertions(+), 30 deletions(-) diff --git a/src/Simpl/Backend/CodegenJoin.hs b/src/Simpl/Backend/CodegenJoin.hs index e9c3a8c..58b0cbf 100644 --- a/src/Simpl/Backend/CodegenJoin.hs +++ b/src/Simpl/Backend/CodegenJoin.hs @@ -16,10 +16,10 @@ module Simpl.Backend.CodegenJoin where import Control.Applicative ((<|>)) -import Control.Monad (forM, forM_, liftM2) +import Control.Monad (forM, forM_) import Control.Monad.Reader import Control.Monad.State -import Data.Functor.Foldable (para, unfix, Fix(..)) +import Data.Functor.Foldable (unfix, Fix(..)) import Data.Text.Prettyprint.Doc (pretty) import Data.Char (ord) import Data.Maybe (fromJust) @@ -29,10 +29,8 @@ import Data.ByteString () import qualified Data.ByteString as BS import Data.Map (Map) import Data.Functor.Identity -import Data.List.NonEmpty (NonEmpty(..)) import Data.String (fromString) -import qualified Data.List.NonEmpty as NE import qualified Data.Text as Text import qualified Data.Map as Map import qualified LLVM.AST as LLVM @@ -97,6 +95,7 @@ localCodegenTable :: MonadState CodegenTable m -> m a localCodegenTable f ma = do oldTable <- get + put (f oldTable) res <- ma modify $ \t -> t { tableVars = tableVars oldTable @@ -193,7 +192,8 @@ literalCodegen = \case LLVMIR.call RT.stringNewRef [(lenOper, []), (bytePtr, [])] -- | Generates code for an arbitrary binary operation -binaryOpCodegen :: (LLVMIR.MonadIRBuilder m, MonadState CodegenTable m) +binaryOpCodegen + :: LLVMIR.MonadIRBuilder m => m LLVM.Operand -> m LLVM.Operand -> (LLVM.Operand -> LLVM.Operand -> m LLVM.Operand) @@ -204,7 +204,7 @@ binaryOpCodegen x y op = do op x' y' -- | Generates code for a numeric binary operation -numBinopCodegen :: (LLVMIR.MonadIRBuilder m, MonadState CodegenTable m) +numBinopCodegen :: LLVMIR.MonadIRBuilder m => m LLVM.Operand -> m LLVM.Operand -> Type @@ -219,7 +219,7 @@ numBinopCodegen x y ty opDouble opInt = _ -> error "Invariant violated" -- | Generates code for BinOp -binOpCodegen :: (LLVMIR.MonadIRBuilder m, MonadState CodegenTable m) +binOpCodegen :: LLVMIR.MonadIRBuilder m => BinaryOp -- ^ Operation -> Type -- ^ Type of the inputs -> m LLVM.Operand @@ -244,28 +244,60 @@ jvalueCodegen = \case JVar name -> gets (snd . fromJust . Map.lookup name . tableVars) JLit l -> literalCodegen l --- | Looks up the type of the result of the [Callable] -lookupCallableType :: MonadState CodegenTable m - => Callable - -> m (TypeF Type) -lookupCallableType = _ - - joinableCodegen :: Joinable (AnnExpr '[ 'ExprType]) - -> LLVMIR.IRBuilderT (LLVMIR.ModuleBuilderT Codegen) LLVM.Operand + -> LLVMIR.IRBuilderT (LLVMIR.ModuleBuilderT Codegen) () joinableCodegen = \case - JIf guard trueBr falseBr -> do + JIf guardVal trueBr falseBr -> do LLVMIR.ensureBlock - cond <- jvalueCodegen guard + cond <- jvalueCodegen guardVal trueLabel <- LLVMIR.freshName "if_then" falseLabel <- LLVMIR.freshName "if_else" LLVMIR.condBr cond trueLabel falseLabel LLVMIR.emitBlockStart trueLabel _ <- jexprCodegen trueBr LLVMIR.emitBlockStart falseLabel - jexprCodegen falseBr - JCase val branches -> _ + _ <- jexprCodegen falseBr + pure () + JCase val branches -> do + LLVMIR.ensureBlock + defLabel <- LLVMIR.freshName "case_default" + valOper <- jvalueCodegen val + allCaseLabels <- forM branches $ \case + BrAdt name _ _ -> + let labelName = "case_" <> fromString (Text.unpack name) in + (name, ) <$> LLVMIR.freshName labelName + -- Assume the symbol table and type information is correct + dataName <- (\case { TyAdt n -> n; _ -> error "" }) <$> lookupValueType val + ctors <- gets ((\(_,_,cs) -> cs) . fromJust . Map.lookup dataName . tableAdts) + let ctorNames = ctorGetName <$> ctors + let usedLabelTriples = filter (\(_, (n, _)) -> n `elem` ctorNames) $ [0..] `zip` allCaseLabels + let jumpTable = [(LLVMC.Int 32 i, l) | (i, (_, l)) <- usedLabelTriples] + tag <- LLVMIR.extractValue valOper [0] + dataPtr <- LLVMIR.extractValue valOper [1] + LLVMIR.switch tag defLabel jumpTable + forM_ (usedLabelTriples `zip` branches) $ \((_, (ctorName, label)), br) -> do + let expr = branchGetExpr br + (_, ctorLLVMName, index) <- gets (fromJust . Map.lookup ctorName . tableCtors) + let Ctor _ argTys = ctors !! index + let bindingPairs = branchGetBindings br `zip` (typeToLLVM <$> argTys) + LLVMIR.emitBlockStart label + ctorPtr <- LLVMIR.bitcast dataPtr (LLVM.ptr (LLVM.NamedTypeReference ctorLLVMName)) + ctorPtrOffset <- LLVMIR.int32 0 + bindings <- forM ([0..] `zip` bindingPairs) $ \(i, (n, llvmTy)) -> do + ctorPtrIndex <- LLVMIR.int32 i + -- Need to bitcast the ptr type because we need a concrete type. We + -- also need to load the data immediately because of how variables + -- are implemented. + v <- LLVMIR.gep ctorPtr [ctorPtrOffset, ctorPtrIndex] + >>= flip LLVMIR.bitcast (LLVM.ptr llvmTy) + >>= flip LLVMIR.load 0 + ty <- lookupValueType (JVar n) + pure (n, (ty, v)) + let updateTable t = t { tableVars = Map.union (tableVars t) (Map.fromList bindings) } + localCodegenTable updateTable (jexprCodegen expr) + LLVMIR.emitBlockStart defLabel + LLVMIR.unreachable callableCodegen :: Callable @@ -278,16 +310,44 @@ callableCodegen callable args = case callable of LLVMIR.call fn [(x, []) | x <- ops] CBinOp op -> case args of [x, y] -> do - binOpCodegen op _ (jvalueCodegen x) (jvalueCodegen y) + xTy <- lookupValueType x + binOpCodegen op (Fix xTy) (jvalueCodegen x) (jvalueCodegen y) _ -> error $ "callableCodegen: expected 2 args to CBinOp, got " ++ show (length args) CCast targetNum -> case args of [x] -> do ty <- lookupValueType x case ty of TyNumber srcNum -> jvalueCodegen x >>= castOpCodegen srcNum targetNum - _ -> error $ "callableCodegen: expected CCast argument to be of numeric type" + _ -> error $ "callableCodegen: expected CCast argument to be of numeric type, got " ++ show ty _ -> error $ "callableCodegen: expected 1 args to CCast, got " ++ show (length args) - CCtor name -> _ + CCtor name -> do + -- Assume constructor exists, since typechecker should verify it anyways + (dataTy, ctorName, ctorIndex) <- gets (fromJust . Map.lookup name . tableCtors) + ctorIndex' <- LLVMIR.int32 (fromIntegral ctorIndex) + let tagStruct1 = LLVM.ConstantOperand $ LLVMC.Undef (LLVM.NamedTypeReference dataTy) + -- Tag + tagStruct2 <- LLVMIR.insertValue tagStruct1 ctorIndex' [0] + -- Data pointer + let ctorTy = LLVM.NamedTypeReference ctorName + let nullptr = LLVM.ConstantOperand (LLVMC.Null (LLVM.ptr ctorTy)) + -- Use offsets to calculate struct size + ctorStructSize <- flip LLVMIR.ptrtoint LLVM.i64 + =<< LLVMIR.gep nullptr + =<< pure <$> LLVMIR.int32 0 + -- Allocate memory for constructor. + -- For now, use "leak memory" as an implementation strategy for deallocation. + ctorStructPtr <- LLVMIR.call RT.mallocRef [(ctorStructSize, [])] >>= + flip LLVMIR.bitcast (LLVM.ptr ctorTy) + values <- traverse jvalueCodegen args + indices <- traverse LLVMIR.int32 [0..fromIntegral (length values - 1)] + ptrOffset <- LLVMIR.int32 0 + forM_ (indices `zip` values) $ \(index, v) -> do + valuePtr <- LLVMIR.gep ctorStructPtr [ptrOffset, index] + LLVMIR.store valuePtr 0 v + pure () + dataPtr <- LLVMIR.bitcast ctorStructPtr (LLVM.ptr LLVM.i8) + fullyInit <- LLVMIR.insertValue tagStruct2 dataPtr [1] + pure fullyInit CPrint -> case args of [val] -> do let fmtStr = encodeUtf8 (fromString "Print: %s\n\0") @@ -316,8 +376,7 @@ jexprCodegen = (\e -> go (unfix (getType e)) (annGetExpr e)) . unfix JVal v -> jvalueCodegen v JLet name val next -> do oper <- jvalueCodegen val - ty <- lookupValueType val - _ <- bindVariable name ty oper + _ <- bindVariable name exprTy oper jexprCodegen next JJoin lbl varName joinable next -> do llvmLabel <- LLVMIR.freshName (fromString (Text.unpack lbl)) @@ -333,14 +392,13 @@ jexprCodegen = (\e -> go (unfix (getType e)) (annGetExpr e)) . unfix v <- jvalueCodegen val jvals <- gets tableJoinValues block <- LLVMIR.currentBlock - let f = (\(n, jvs) -> Just $ (n, (v, block) : jvs)) + let f = (\(n, jvs) -> Just (n, (v, block) : jvs)) let updJvals = Map.update f lbl jvals modify (\t -> t { tableJoinValues = updJvals }) pure v JApp varName callable args next -> do oper <- callableCodegen callable args - ty <- lookupCallableType callable - bindVariable varName ty oper + bindVariable varName exprTy oper jexprCodegen next -- | Generates code for numeric casting @@ -405,10 +463,9 @@ funToLLVM name params ty body = fparams = [(typeToLLVM t, fromString (Text.unpack n)) | (n, t) <- params] in mdo foper <- LLVMIR.function fname fparams ftype $ \args -> do LLVMIR.ensureBlock - endLabel <- LLVMIR.freshName "function_end" -- We need to make sure we don't pollute other function scopes oldVars <- gets tableVars - let updVars t = tableVars t `Map.union` Map.fromList ((fst <$> params) `zip` args) + let updVars t = tableVars t `Map.union` Map.fromList [(n, (unfix vty, op)) | ((n, vty), op) <- params `zip` args] modify (\t -> t { tableFuns = Map.insert name foper (tableFuns t) , tableVars = updVars t }) retval <- jexprCodegen body diff --git a/src/Simpl/JoinIR.hs b/src/Simpl/JoinIR.hs index 567d7b5..60559ff 100644 --- a/src/Simpl/JoinIR.hs +++ b/src/Simpl/JoinIR.hs @@ -61,6 +61,14 @@ data JBranch a = BrAdt Name [Name] !a -- ^ Destructure algebraic data type deriving (Functor, Foldable, Traversable, Show) +branchGetExpr :: JBranch a -> a +branchGetExpr = \case + BrAdt _ _ e -> e + +branchGetBindings :: JBranch a -> [Text] +branchGetBindings = \case + BrAdt _ vars _ -> vars + -- | Represents expressions that must be bound to a join point. data Joinable a @@ -91,7 +99,7 @@ data JExprF a -- | Apply the callable to the arguments, bind the result to the given name, -- and continue to the next expression. - | JApp !Name !Callable ![Name] !a + | JApp !Name !Callable ![JValue] !a deriving (Functor, Foldable, Traversable, Show) $(deriveShow1 ''JBranch) From bb9235dc83b5bb510ca1a92d8205d3d5ed2f2b0c Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Sat, 23 Mar 2019 16:34:47 -0700 Subject: [PATCH 10/24] [skip ci] WIP start on AST to JoinIR transformation --- src/Simpl/AstToJoinIR.hs | 49 ++++++++++++++++++++++++++++++++++++++++ src/Simpl/JoinIR.hs | 8 +++++++ 2 files changed, 57 insertions(+) create mode 100644 src/Simpl/AstToJoinIR.hs diff --git a/src/Simpl/AstToJoinIR.hs b/src/Simpl/AstToJoinIR.hs new file mode 100644 index 0000000..21ffb04 --- /dev/null +++ b/src/Simpl/AstToJoinIR.hs @@ -0,0 +1,49 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} +module Simpl.AstToJoinIR + ( anfTransform + ) where + +import Data.Functor.Foldable (Fix(..), para, unfix) + +import Simpl.Ast (Type, TypeF(..)) +import qualified Simpl.Ast as A +import qualified Simpl.JoinIR as J + +makeJexpr :: Type + -> J.JExprF (J.AnnExpr '[ 'J.ExprType]) + -> J.AnnExpr '[ 'J.ExprType] +makeJexpr ty = Fix . J.addField (J.withType ty) . J.toAnnExprF + +astType :: A.AnnExpr Type -> Type +astType = A.annGetAnn . unfix + +anfTransform :: A.AnnExpr Type + -> (J.JValue -> J.AnnExpr '[ 'J.ExprType]) + -> J.AnnExpr '[ 'J.ExprType] +anfTransform (Fix (A.AnnExprF ty exprf)) cont = case exprf of + A.Lit lit -> cont (J.JLit lit) + A.Var name -> cont (J.JVar name) + A.Let name bindExpr next -> + anfTransform bindExpr $ \bindVal -> + makeJexpr (A.annGetAnn (unfix bindExpr)) $ + J.JLet name bindVal (anfTransform next cont) + A.BinOp op left right -> + anfTransform left $ \jleft -> + anfTransform right $ \jright -> + let name = "TODO" in + makeJexpr ty (J.JApp name (J.CBinOp op) [jleft, jright] (cont (J.JVar name))) + A.If guard trueBr falseBr -> + anfTransform guard $ \jguard -> + let lbl = "TODO" + name = "TODO" in + makeJexpr ty $ + let trueBr' = anfTransform trueBr (makeJexpr (astType trueBr) . J.JJump lbl) + falseBr' = anfTransform falseBr (makeJexpr (astType falseBr) . J.JJump lbl) in + J.JJoin lbl name (J.JIf jguard trueBr' falseBr') (cont (J.JVar name)) + A.Cons ctorName args -> + let argVals = [] -- TODO + varName = "TODO" in + makeJexpr ty $ + J.JApp varName (J.CCtor ctorName) argVals (cont (J.JVar varName)) diff --git a/src/Simpl/JoinIR.hs b/src/Simpl/JoinIR.hs index 60559ff..4f2cdba 100644 --- a/src/Simpl/JoinIR.hs +++ b/src/Simpl/JoinIR.hs @@ -108,6 +108,14 @@ $(deriveShow1 ''JExprF) type JExpr = Fix JExprF +jexprGetVal :: JExprF a -> JValue +jexprGetVal = \case + JVal v -> v + JLet n _ _ -> JVar n + JJoin _ n _ _ -> JVar n + JJump _ v -> v + JApp n _ _ _ -> JVar n + instance Pretty Callable where pretty = \case CFunc name -> pretty name From 603418f58a956ba3b9021ed55306b418554f2c27 Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Sun, 24 Mar 2019 19:36:01 -0700 Subject: [PATCH 11/24] [skip ci] WIP finish most of AST to Join IR transforms --- src/Simpl/AstToJoinIR.hs | 43 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/src/Simpl/AstToJoinIR.hs b/src/Simpl/AstToJoinIR.hs index 21ffb04..f0b6d53 100644 --- a/src/Simpl/AstToJoinIR.hs +++ b/src/Simpl/AstToJoinIR.hs @@ -6,6 +6,7 @@ module Simpl.AstToJoinIR ) where import Data.Functor.Foldable (Fix(..), para, unfix) +import Data.Text (Text) import Simpl.Ast (Type, TypeF(..)) import qualified Simpl.Ast as A @@ -19,6 +20,13 @@ makeJexpr ty = Fix . J.addField (J.withType ty) . J.toAnnExprF astType :: A.AnnExpr Type -> Type astType = A.annGetAnn . unfix +transformBranch :: Text -- ^ Return label + -> A.Branch (A.AnnExpr Type) -- ^ Branches + -> J.JBranch (J.AnnExpr '[ 'J.ExprType]) +transformBranch retLabel (A.BrAdt adtName argNames expr) = + let jexpr = anfTransform expr (makeJexpr (astType expr) . J.JJump retLabel) in + J.BrAdt adtName argNames jexpr + anfTransform :: A.AnnExpr Type -> (J.JValue -> J.AnnExpr '[ 'J.ExprType]) -> J.AnnExpr '[ 'J.ExprType] @@ -42,8 +50,39 @@ anfTransform (Fix (A.AnnExprF ty exprf)) cont = case exprf of let trueBr' = anfTransform trueBr (makeJexpr (astType trueBr) . J.JJump lbl) falseBr' = anfTransform falseBr (makeJexpr (astType falseBr) . J.JJump lbl) in J.JJoin lbl name (J.JIf jguard trueBr' falseBr') (cont (J.JVar name)) + A.Case branches expr -> + anfTransform expr $ \jexpr -> + let lbl = "TODO" + name = "TODO" + jbranches = [transformBranch lbl b | b <- branches] in + makeJexpr ty (J.JJoin lbl name (J.JCase jexpr jbranches) (cont (J.JVar name))) A.Cons ctorName args -> - let argVals = [] -- TODO - varName = "TODO" in + let varName = "TODO" in + collectArgs args [] $ \argVals -> makeJexpr ty $ J.JApp varName (J.CCtor ctorName) argVals (cont (J.JVar varName)) + A.App funcName args -> + let varName = "TODO" in + collectArgs args [] $ \argVals -> + makeJexpr ty $ + J.JApp varName (J.CFunc funcName) argVals (cont (J.JVar varName)) + A.Cast expr numTy -> + let varName = "TODO" in + anfTransform expr $ \jexpr -> + makeJexpr ty $ + J.JApp varName (J.CCast numTy) [jexpr] (cont (J.JVar varName)) + A.Print expr -> + let varName = "TODO" in + anfTransform expr $ \jexpr -> + makeJexpr ty $ + J.JApp varName J.CPrint [jexpr] (cont (J.JVar varName)) + A.FunRef name -> _ + +-- | Normalize each expression in sequential order, and then run the +-- continuation with the expression values. +collectArgs :: [A.AnnExpr Type] + -> [J.JValue] + -> ([J.JValue] -> J.AnnExpr '[ 'J.ExprType]) + -> J.AnnExpr '[ 'J.ExprType] +collectArgs [] vals mcont = mcont (reverse vals) +collectArgs (x:xs) vals mcont = anfTransform x $ \v -> collectArgs xs (v:vals) mcont From bc719004da7e3fbb37f8f818eb6bcc047a233db8 Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Mon, 25 Mar 2019 18:32:55 -0700 Subject: [PATCH 12/24] Complete AST to Join IR transformation (not tested yet) --- package.yaml | 1 + src/Simpl/AstToJoinIR.hs | 178 +++++++++++++++++++++++-------- src/Simpl/Backend/CodegenJoin.hs | 1 + src/Simpl/JoinIR.hs | 2 + src/Simpl/SymbolTable.hs | 10 ++ stack.yaml | 1 + 6 files changed, 146 insertions(+), 47 deletions(-) diff --git a/package.yaml b/package.yaml index 3f33e4c..3688d3a 100644 --- a/package.yaml +++ b/package.yaml @@ -41,6 +41,7 @@ dependencies: - prettyprinter - unification-fd - mtl +- monad-supply - containers - bytestring - optparse-applicative diff --git a/src/Simpl/AstToJoinIR.hs b/src/Simpl/AstToJoinIR.hs index f0b6d53..bdbd53c 100644 --- a/src/Simpl/AstToJoinIR.hs +++ b/src/Simpl/AstToJoinIR.hs @@ -1,17 +1,74 @@ +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} module Simpl.AstToJoinIR - ( anfTransform + ( astToJoinIR ) where -import Data.Functor.Foldable (Fix(..), para, unfix) +import Control.Monad.Supply +import Control.Monad.Reader hiding (guard) +import Data.Functor.Foldable (Fix(..), unfix) +import Data.Functor.Identity import Data.Text (Text) +import Data.String (fromString) -import Simpl.Ast (Type, TypeF(..)) +import Simpl.Ast (Type) +import Simpl.SymbolTable import qualified Simpl.Ast as A import qualified Simpl.JoinIR as J +-- * Public API + +astToJoinIR :: SymbolTable (A.AnnExpr Type) -> SymbolTable (J.AnnExpr '[ 'J.ExprType]) +astToJoinIR = runTransform transformTable + +-- * Transformation Monad + +newtype TransformT m a = + TransformT { unTransform :: ReaderT (SymbolTable (A.AnnExpr Type)) (SupplyT Int m) a } + deriving ( Functor + , Applicative + , Monad + , MonadReader (SymbolTable (A.AnnExpr Type)) + , MonadFreshVar) + +type Transform = TransformT Identity + +type MonadFreshVar = MonadSupply Int + +varSupply :: [Int] +varSupply = [0..] + +runTransformT :: Monad m => TransformT m a -> SymbolTable (A.AnnExpr Type) -> m a +runTransformT m table + = fmap fst + . flip runSupplyT varSupply + . flip runReaderT table + . unTransform + $ m + +runTransform :: Transform a -> SymbolTable (A.AnnExpr Type) -> a +runTransform m table = runIdentity (runTransformT m table) + +freshName :: (MonadReader (SymbolTable (A.AnnExpr Type)) m, MonadFreshVar m) + => Text -- ^ Prefix + -> m Text +freshName prefix = do + next <- (prefix <>) . fromString . show <$> supply + asks (symTabLookupVar next) >>= \case + Nothing -> pure next + Just _ -> freshName prefix + +freshVar, freshLabel :: (MonadReader (SymbolTable (A.AnnExpr Type)) m, MonadFreshVar m) => m Text +freshVar = freshName "var" +freshLabel = freshName "join" + +-- * Private utility functions + makeJexpr :: Type -> J.JExprF (J.AnnExpr '[ 'J.ExprType]) -> J.AnnExpr '[ 'J.ExprType] @@ -20,69 +77,96 @@ makeJexpr ty = Fix . J.addField (J.withType ty) . J.toAnnExprF astType :: A.AnnExpr Type -> Type astType = A.annGetAnn . unfix -transformBranch :: Text -- ^ Return label +-- * ANF Transformation + +-- | Perform ANF transformation on the given symbol table +transformTable :: (MonadReader (SymbolTable (A.AnnExpr Type)) m, MonadFreshVar m) + => m (SymbolTable (J.AnnExpr '[ 'J.ExprType])) +transformTable = do + table <- ask + symTabTraverseExprs (\(args, ty, expr) -> (args, ty, transformExpr expr)) table + +-- | Perform ANF transformation on the given expression +transformExpr :: (MonadReader (SymbolTable (A.AnnExpr Type)) m, MonadFreshVar m) + => A.AnnExpr Type + -> m (J.AnnExpr '[ 'J.ExprType]) +transformExpr expr = anfTransform expr (pure . makeJexpr (astType expr) . J.JVal) + +-- | Perform ANF transformation on the branch, returning to the jump label at +-- the end. +transformBranch :: (MonadReader (SymbolTable (A.AnnExpr Type)) m, MonadFreshVar m) + => Text -- ^ Return label -> A.Branch (A.AnnExpr Type) -- ^ Branches - -> J.JBranch (J.AnnExpr '[ 'J.ExprType]) + -> m (J.JBranch (J.AnnExpr '[ 'J.ExprType])) transformBranch retLabel (A.BrAdt adtName argNames expr) = - let jexpr = anfTransform expr (makeJexpr (astType expr) . J.JJump retLabel) in - J.BrAdt adtName argNames jexpr + let jexprM = anfTransform expr (pure . makeJexpr (astType expr) . J.JJump retLabel) in + J.BrAdt adtName argNames <$> jexprM + -anfTransform :: A.AnnExpr Type - -> (J.JValue -> J.AnnExpr '[ 'J.ExprType]) - -> J.AnnExpr '[ 'J.ExprType] +-- | Main ANF transformation logic +anfTransform :: (MonadReader (SymbolTable (A.AnnExpr Type)) m, MonadFreshVar m) + => A.AnnExpr Type + -> (J.JValue -> m (J.AnnExpr '[ 'J.ExprType])) + -> m (J.AnnExpr '[ 'J.ExprType]) anfTransform (Fix (A.AnnExprF ty exprf)) cont = case exprf of A.Lit lit -> cont (J.JLit lit) A.Var name -> cont (J.JVar name) A.Let name bindExpr next -> anfTransform bindExpr $ \bindVal -> - makeJexpr (A.annGetAnn (unfix bindExpr)) $ - J.JLet name bindVal (anfTransform next cont) + makeJexpr (A.annGetAnn (unfix bindExpr)) . J.JLet name bindVal <$> + local (symTabInsertVar name ty) (anfTransform next cont) A.BinOp op left right -> anfTransform left $ \jleft -> - anfTransform right $ \jright -> - let name = "TODO" in - makeJexpr ty (J.JApp name (J.CBinOp op) [jleft, jright] (cont (J.JVar name))) + anfTransform right $ \jright -> do + name <- freshVar + makeJexpr ty . J.JApp name (J.CBinOp op) [jleft, jright] <$> + local (symTabInsertVar name ty) (cont (J.JVar name)) A.If guard trueBr falseBr -> - anfTransform guard $ \jguard -> - let lbl = "TODO" - name = "TODO" in - makeJexpr ty $ - let trueBr' = anfTransform trueBr (makeJexpr (astType trueBr) . J.JJump lbl) - falseBr' = anfTransform falseBr (makeJexpr (astType falseBr) . J.JJump lbl) in - J.JJoin lbl name (J.JIf jguard trueBr' falseBr') (cont (J.JVar name)) + anfTransform guard $ \jguard -> do + lbl <- freshLabel + trueBr' <- anfTransform trueBr (pure . makeJexpr (astType trueBr) . J.JJump lbl) + falseBr' <- anfTransform falseBr (pure . makeJexpr (astType falseBr) . J.JJump lbl) + name <- freshVar + makeJexpr ty . J.JJoin lbl name (J.JIf jguard trueBr' falseBr') <$> + local (symTabInsertVar name ty) (cont (J.JVar name)) A.Case branches expr -> - anfTransform expr $ \jexpr -> - let lbl = "TODO" - name = "TODO" - jbranches = [transformBranch lbl b | b <- branches] in - makeJexpr ty (J.JJoin lbl name (J.JCase jexpr jbranches) (cont (J.JVar name))) + anfTransform expr $ \jexpr -> do + lbl <- freshLabel + jbranches <- traverse (transformBranch lbl) branches + name <- freshVar + makeJexpr ty . J.JJoin lbl name (J.JCase jexpr jbranches) <$> + local (symTabInsertVar name ty) (cont (J.JVar name)) A.Cons ctorName args -> - let varName = "TODO" in - collectArgs args [] $ \argVals -> - makeJexpr ty $ - J.JApp varName (J.CCtor ctorName) argVals (cont (J.JVar varName)) + collectArgs args [] $ \argVals -> do + varName <- freshVar + makeJexpr ty . J.JApp varName (J.CCtor ctorName) argVals <$> + local (symTabInsertVar varName ty) (cont (J.JVar varName)) A.App funcName args -> - let varName = "TODO" in - collectArgs args [] $ \argVals -> - makeJexpr ty $ - J.JApp varName (J.CFunc funcName) argVals (cont (J.JVar varName)) + collectArgs args [] $ \argVals -> do + varName <- freshVar + makeJexpr ty . J.JApp varName (J.CFunc funcName) argVals <$> + local (symTabInsertVar varName ty) (cont (J.JVar varName)) A.Cast expr numTy -> - let varName = "TODO" in - anfTransform expr $ \jexpr -> - makeJexpr ty $ - J.JApp varName (J.CCast numTy) [jexpr] (cont (J.JVar varName)) + anfTransform expr $ \jexpr -> do + varName <- freshVar + makeJexpr ty . J.JApp varName (J.CCast numTy) [jexpr] <$> + local (symTabInsertVar varName ty) (cont (J.JVar varName)) A.Print expr -> - let varName = "TODO" in - anfTransform expr $ \jexpr -> - makeJexpr ty $ - J.JApp varName J.CPrint [jexpr] (cont (J.JVar varName)) - A.FunRef name -> _ + anfTransform expr $ \jexpr -> do + varName <- freshVar + makeJexpr ty . J.JApp varName J.CPrint [jexpr] <$> + local (symTabInsertVar varName ty) (cont (J.JVar varName)) + A.FunRef name -> do + varName <- freshVar + makeJexpr ty . J.JApp varName (J.CFunRef name) [] <$> + local (symTabInsertVar varName ty) (cont (J.JVar varName)) -- | Normalize each expression in sequential order, and then run the -- continuation with the expression values. -collectArgs :: [A.AnnExpr Type] +collectArgs :: (MonadReader (SymbolTable (A.AnnExpr Type)) m, MonadFreshVar m) + => [A.AnnExpr Type] -> [J.JValue] - -> ([J.JValue] -> J.AnnExpr '[ 'J.ExprType]) - -> J.AnnExpr '[ 'J.ExprType] + -> ([J.JValue] -> m (J.AnnExpr '[ 'J.ExprType])) + -> m (J.AnnExpr '[ 'J.ExprType]) collectArgs [] vals mcont = mcont (reverse vals) collectArgs (x:xs) vals mcont = anfTransform x $ \v -> collectArgs xs (v:vals) mcont diff --git a/src/Simpl/Backend/CodegenJoin.hs b/src/Simpl/Backend/CodegenJoin.hs index 58b0cbf..edeba11 100644 --- a/src/Simpl/Backend/CodegenJoin.hs +++ b/src/Simpl/Backend/CodegenJoin.hs @@ -363,6 +363,7 @@ callableCodegen callable args = case callable of _ <- LLVMIR.call printf [(fmtStrPtr', []), (exprCstring, [])] LLVMIR.int64 0 _ -> error $ "callableCodegen: expected 1 args to CPrint, got " ++ show (length args) + CFunRef name -> gets (fromJust . Map.lookup name . tableFuns) jexprCodegen :: AnnExpr '[ 'ExprType] diff --git a/src/Simpl/JoinIR.hs b/src/Simpl/JoinIR.hs index 4f2cdba..4ae9d11 100644 --- a/src/Simpl/JoinIR.hs +++ b/src/Simpl/JoinIR.hs @@ -47,6 +47,7 @@ data Callable | CCast !Numeric -- ^ Numeric cast | CCtor !Name -- ^ ADT constructor | CPrint -- ^ Print string (temporary) + | CFunRef !Name -- ^ Static function reference deriving (Show) -- | A value @@ -123,6 +124,7 @@ instance Pretty Callable where CCast num -> "cast[" <> pretty num <> "]" CCtor name -> pretty name CPrint -> "print" + CFunRef name -> "funref[" <> pretty name <> "]" instance Pretty JValue where pretty = \case diff --git a/src/Simpl/SymbolTable.hs b/src/Simpl/SymbolTable.hs index e765f1b..d1061bb 100644 --- a/src/Simpl/SymbolTable.hs +++ b/src/Simpl/SymbolTable.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE TupleSections #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveFoldable #-} @@ -43,6 +44,15 @@ symTabMapExprs :: (([(Text, Type)], Type, e) -> ([(Text, Type)], Type, e')) -- ^ -> SymbolTable e' symTabMapExprs f t = t { symTabFuns = Map.map f (symTabFuns t) } +symTabTraverseExprs + :: Monad m + => (([(Text, Type)], Type, e) -> ([(Text, Type)], Type, m e')) -- ^ Map over functions + -> SymbolTable e + -> m (SymbolTable e') +symTabTraverseExprs f t = do + upd <- traverse ((\(args, ty, me) -> (args, ty,) <$> me) . f) (symTabFuns t) + pure $ t { symTabFuns = upd } + -- | Searches for the given constructor, returning the name of the ADT, the -- constructor, and the index of the constructor. symTabLookupCtor :: Text -> SymbolTable e -> Maybe (Type, Constructor, Int) diff --git a/stack.yaml b/stack.yaml index e758009..e731951 100644 --- a/stack.yaml +++ b/stack.yaml @@ -44,6 +44,7 @@ extra-deps: - llvm-hs-pure-7.0.0 - llvm-hs-7.0.1 - vinyl-0.11.0 +- monad-supply-0.7 # Override default flag values for local packages and extra-deps # flags: {} From 1a42cd9f816057256f1740f2f3185a4c0daced9f Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Mon, 25 Mar 2019 19:21:44 -0700 Subject: [PATCH 13/24] Fix JoinIR code generation variable binding bug --- src/Simpl/Backend/CodegenJoin.hs | 21 +++++++++++---------- src/Simpl/Compiler.hs | 13 +++++++++++-- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/src/Simpl/Backend/CodegenJoin.hs b/src/Simpl/Backend/CodegenJoin.hs index edeba11..fc90bd4 100644 --- a/src/Simpl/Backend/CodegenJoin.hs +++ b/src/Simpl/Backend/CodegenJoin.hs @@ -20,7 +20,6 @@ import Control.Monad (forM, forM_) import Control.Monad.Reader import Control.Monad.State import Data.Functor.Foldable (unfix, Fix(..)) -import Data.Text.Prettyprint.Doc (pretty) import Data.Char (ord) import Data.Maybe (fromJust) import Data.Text (Text) @@ -45,7 +44,7 @@ import qualified LLVM.IRBuilder.Monad as LLVMIR import qualified LLVM.IRBuilder.Instruction as LLVMIR import qualified LLVM.IRBuilder.Constant as LLVMIR -import Simpl.Ast (BinaryOp(..), Type, TypeF(..), Constructor(..), Literal(..), Numeric(..), Decl(..)) +import Simpl.Ast (BinaryOp(..), Type, TypeF(..), Constructor(..), Literal(..), Numeric(..)) import Simpl.CompilerOptions import Simpl.SymbolTable import Simpl.Typing (literalType) @@ -100,7 +99,6 @@ localCodegenTable f ma = do modify $ \t -> t { tableVars = tableVars oldTable , tableFuns = tableFuns oldTable - , tableJoinValues = tableJoinValues oldTable } pure res @@ -383,11 +381,13 @@ jexprCodegen = (\e -> go (unfix (getType e)) (annGetExpr e)) . unfix llvmLabel <- LLVMIR.freshName (fromString (Text.unpack lbl)) let addJoinEntry = \t -> t { tableJoinValues = Map.insert lbl (llvmLabel, []) (tableJoinValues t) } + oldJoinEntries <- gets tableJoinValues _ <- localCodegenTable addJoinEntry (joinableCodegen joinable) (_, joinValues) <- gets (fromJust . Map.lookup lbl . tableJoinValues) + modify (\t -> t { tableJoinValues = oldJoinEntries }) LLVMIR.emitBlockStart llvmLabel op <- LLVMIR.phi joinValues - modify (\t -> t { tableVars = Map.insert varName (exprTy, op) (tableVars t) }) + bindVariable varName exprTy op jexprCodegen next JJump lbl val -> do v <- jvalueCodegen val @@ -396,6 +396,8 @@ jexprCodegen = (\e -> go (unfix (getType e)) (annGetExpr e)) . unfix let f = (\(n, jvs) -> Just (n, (v, block) : jvs)) let updJvals = Map.update f lbl jvals modify (\t -> t { tableJoinValues = updJvals }) + llvmLabel <- gets (fst . fromJust . Map.lookup lbl . tableJoinValues) + LLVMIR.br llvmLabel pure v JApp varName callable args next -> do oper <- callableCodegen callable args @@ -476,14 +478,13 @@ funToLLVM name params ty body = pure foper -- | Generate code for the entire module -moduleCodegen :: [Decl JExpr] +moduleCodegen :: String -> SymbolTable (AnnExpr '[ 'ExprType ]) -> LLVMIR.ModuleBuilderT Codegen () -moduleCodegen decls symTab = mdo +moduleCodegen srcCode symTab = mdo -- Message is "Hi\n" (with null terminator) (_, msg) <- staticString ".message" "Hello world!\n" (_, resultFmt) <- staticString ".resultformat" "Result: %i\n" - let srcCode = unlines $ show . pretty <$> decls (_, exprSrc) <- staticString ".sourcecode" $ "Source code: " ++ srcCode ++ "\n" RT.emitRuntimeDecls modify (\t -> t { tablePrintf = RT.printfRef }) @@ -508,10 +509,10 @@ moduleCodegen decls symTab = mdo LLVMIR.ret retcode pure () -runCodegen :: CompilerOpts -> [Decl JExpr] -> SymbolTable (AnnExpr '[ 'ExprType]) -> LLVM.Module -runCodegen opts decls symTab +runCodegen :: CompilerOpts -> String -> SymbolTable (AnnExpr '[ 'ExprType]) -> LLVM.Module +runCodegen opts srcCode symTab = runIdentity . flip evalStateT emptyCodegenTable . unCodegen . LLVMIR.buildModuleT "simpl.ll" - $ lift (initCodegenTable opts symTab) >> moduleCodegen decls symTab + $ lift (initCodegenTable opts symTab) >> moduleCodegen srcCode symTab diff --git a/src/Simpl/Compiler.hs b/src/Simpl/Compiler.hs index b635063..bb54a36 100644 --- a/src/Simpl/Compiler.hs +++ b/src/Simpl/Compiler.hs @@ -4,16 +4,19 @@ module Simpl.Compiler where import Control.Monad.Except import Control.Monad.State +import Data.Text.Prettyprint.Doc (pretty) import qualified LLVM.AST as LLVM import qualified LLVM.Module as LLVMM import LLVM.Context import Simpl.Ast -import Simpl.Backend.Codegen (runCodegen) +import Simpl.AstToJoinIR +import Simpl.Backend.CodegenJoin (runCodegen) import Simpl.CompilerOptions import Simpl.SymbolTable import Simpl.Typing (TypeError, runTypecheck, checkType, withExtraVars) +import Simpl.JoinIR (unannotate) import Paths_simpl_lang -- | Main error type, aggregating all error types. @@ -53,4 +56,10 @@ fullCompilerPipeline options srcFile@(SourceFile _name decls) = let newSTTypecheck = sequence $ symTabMapExprs tycheckFuns symTable typedSymTable <- MkCompilerMonad . lift . withExceptT ErrTypecheck $ liftEither (runTypecheck symTable newSTTypecheck) - pure $ runCodegen options decls typedSymTable + let jSymTable = astToJoinIR typedSymTable + -- TODO: Add compiler flag to dump JoinIR + -- liftIO $ forM_ (symTabFuns jSymTable) $ \(args, ty, expr) -> do + -- print (pretty args <> pretty " :: " <> pretty ty) + -- print (pretty (unannotate expr)) + let srcCode = unlines [show (pretty d) | d <- decls] + pure $ runCodegen options srcCode jSymTable From cafb6a69baa8ff3a0f26be54932bae353b1d9277 Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Mon, 25 Mar 2019 19:22:29 -0700 Subject: [PATCH 14/24] Remove code generation for Simpl AST --- src/Simpl/Backend/Codegen.hs | 521 ----------------------------------- 1 file changed, 521 deletions(-) delete mode 100644 src/Simpl/Backend/Codegen.hs diff --git a/src/Simpl/Backend/Codegen.hs b/src/Simpl/Backend/Codegen.hs deleted file mode 100644 index 66d8e4c..0000000 --- a/src/Simpl/Backend/Codegen.hs +++ /dev/null @@ -1,521 +0,0 @@ -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE FlexibleContexts #-} -{-# OPTIONS_GHC -Wno-incomplete-record-updates #-} -- Suppress LLVM sum type of records AST warnings -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE TupleSections #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE DeriveFunctor #-} -{-# LANGUAGE DeriveFoldable #-} -{-# LANGUAGE DeriveTraversable #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE MultiParamTypeClasses #-} -{-# LANGUAGE RecursiveDo #-} -module Simpl.Backend.Codegen where - -import Control.Applicative ((<|>)) -import Control.Monad (forM, forM_) -import Control.Monad.Reader -import Control.Monad.State -import Data.Functor.Foldable (para, unfix) -import Data.Text.Prettyprint.Doc (pretty) -import Data.Char (ord) -import Data.Maybe (fromJust) -import Data.Text (Text) -import Data.Text.Encoding (encodeUtf8) -import Data.ByteString () -import qualified Data.ByteString as BS -import Data.Map (Map) -import Data.Functor.Identity -import Data.List.NonEmpty (NonEmpty(..)) - -import Data.String (fromString) -import qualified Data.List.NonEmpty as NE -import qualified Data.Text as Text -import qualified Data.Map as Map -import qualified LLVM.AST as LLVM -import qualified LLVM.AST.Linkage as LLVM -import qualified LLVM.AST.FloatingPointPredicate as LLVMFP -import qualified LLVM.AST.IntegerPredicate as LLVMIP -import qualified LLVM.AST.Constant as LLVMC -import qualified LLVM.AST.Global as LLVMG -import qualified LLVM.AST.Type as LLVM -import qualified LLVM.IRBuilder.Module as LLVMIR -import qualified LLVM.IRBuilder.Monad as LLVMIR -import qualified LLVM.IRBuilder.Instruction as LLVMIR -import qualified LLVM.IRBuilder.Constant as LLVMIR - -import Simpl.Ast -import Simpl.CompilerOptions -import Simpl.SymbolTable -import Simpl.Typing (TypedExpr) -import Simpl.Backend.Runtime () -import qualified Simpl.Backend.Runtime as RT - -data CodegenTable = - MkCodegenTable { tableVars :: Map Text LLVM.Operand -- ^ Pointer to variables - , tableCtors :: Map Text (LLVM.Name, LLVM.Name, Int) -- ^ Data type name, ctor name, index - , tableAdts :: Map Text (LLVM.Name, Type, [Constructor]) - , tableFuns :: Map Text LLVM.Operand - , tableCurrentJoin :: LLVM.Name - , tablePrintf :: LLVM.Operand - , tableOptions :: CompilerOpts } - deriving (Show) - --- | An empty codegen table. This will cause a crash if codegen is run when not --- initialized! -emptyCodegenTable :: CodegenTable -emptyCodegenTable = - MkCodegenTable { tableVars = Map.empty - , tableCtors = Map.empty - , tableAdts = Map.empty - , tableFuns = Map.empty - , tableCurrentJoin = LLVM.mkName "__default_join_point" - , tablePrintf = error "printf not set" - , tableOptions = defaultCompilerOpts } - -newtype CodegenT m a = - CodegenT { unCodegen :: StateT CodegenTable m a } - deriving ( Functor - , Applicative - , MonadState CodegenTable - , MonadFix) - -type Codegen = CodegenT Identity - -deriving instance Monad m => Monad (CodegenT m) - -instance MonadTrans CodegenT where - lift = CodegenT . lift - -instance Monad m => MonadReader CodegenTable (CodegenT m) where - ask = local id get - local f action = do - curTable <- get - modify f - result <- action - put curTable - pure result - -initCodegenTable :: CompilerOpts -> SymbolTable TypedExpr -> Codegen () -initCodegenTable options symTab = do - let adts = flip Map.mapWithKey (symTabAdts symTab) $ \name (ty, ctors) -> (llvmName name, ty, ctors) - ctors <- forM (Map.elems adts) $ \(adtName, _, ctors) -> - forM ([0..] `zip` ctors) $ \(i, Ctor ctorName _) -> - pure (ctorName, (adtName, llvmName ctorName, i)) - modify $ \t -> t { tableAdts = adts - , tableCtors = Map.fromList (join ctors) - , tableOptions = options } - -data Result = Values (NonEmpty LLVM.Operand) | Branching (NonEmpty (LLVM.Operand, LLVM.Name)) - deriving (Show, Eq) - -instance Ord Result where - compare (Values v1) (Values v2) = compare v1 v2 - compare (Values _) (Branching _) = GT - compare (Branching _) (Values _) = LT - compare (Branching br1) (Branching br2) = compare br1 br2 - -resultHasJump :: Result -> Bool -resultHasJump (Values _) = False -resultHasJump (Branching _) = True - -joinPoint :: (LLVMIR.MonadIRBuilder m, MonadState CodegenTable m) - => m (NonEmpty Result) - -> m (NonEmpty LLVM.Operand) -joinPoint resultsM = do - currentJp <- gets tableCurrentJoin - newJp <- LLVMIR.freshName "join_point" - modify (\t -> t { tableCurrentJoin = newJp }) - results <- resultsM - -- modify (\t -> t { tableCurrentJoin = newJp }) - modify (\t -> t { tableCurrentJoin = currentJp }) - let hasJump = any resultHasJump results - when hasJump $ - LLVMIR.emitBlockStart newJp - let go (Values v) = pure v - go (Branching brs) = - if length brs == 1 - then pure (fst <$> brs) - else (:| []) <$> LLVMIR.phi (NE.toList brs) - join <$> traverse go (NE.sort results) - -joinPoint1 :: (LLVMIR.MonadIRBuilder m, MonadState CodegenTable m) - => m Result - -> m LLVM.Operand -joinPoint1 = fmap NE.head . joinPoint . fmap (:| []) - -resultValue :: LLVM.Operand -> Result -resultValue v = Values (v :| []) - -resultBranching :: NonEmpty (LLVM.Operand, LLVM.Name) -> Result -resultBranching = Branching - -resultEnsureBranch :: LLVMIR.MonadIRBuilder m => Result -> m Result -resultEnsureBranch = \case - Values vs -> LLVMIR.currentBlock >>= \cb -> pure $ Branching ((, cb) <$> vs) - b@Branching {} -> pure b - -resultCombine :: LLVMIR.MonadIRBuilder m => Result -> Result -> m Result -resultCombine (Values v1) (Values v2) = pure $ Values (v1 <> v2) -resultCombine (Values v1) (Branching br2) = - LLVMIR.currentBlock >>= \cb -> pure $ Branching (((, cb) <$> v1) <> br2) -resultCombine (Branching br1) (Values v2) = - LLVMIR.currentBlock >>= \cb -> pure $ Branching (br1 <> ((, cb) <$> v2)) -resultCombine (Branching br1) (Branching br2) = pure $ Branching (br1 <> br2) - -llvmByte :: Integer -> LLVMC.Constant -llvmByte = LLVMC.Int 8 - -llvmString :: String -> (Int, LLVMC.Constant) -llvmString s = - let s' = s ++ "\0" - arr = LLVMC.Array LLVM.i8 (fmap (llvmByte . toInteger . ord) s') - in (length s', arr) - -staticString :: LLVMIR.MonadModuleBuilder m => LLVM.Name -> String -> m (Int, LLVM.Operand) -staticString name str = do - let (messageLen, bytes) = llvmString str - messageTy = LLVM.ArrayType { LLVM.nArrayElements = toEnum messageLen - , LLVM.elementType = LLVM.i8 } - LLVMIR.emitDefn (LLVM.GlobalDefinition $ LLVM.globalVariableDefaults - { LLVMG.name = name - , LLVMG.isConstant = True - , LLVMG.unnamedAddr = Just LLVMG.GlobalAddr - , LLVMG.type' = messageTy - , LLVMG.linkage = LLVM.Private - , LLVMG.initializer = Just bytes - , LLVMG.alignment = 1 - }) - let msgPtrTy = LLVM.ptr messageTy - msgPtr = LLVM.ConstantOperand $ - LLVMC.GetElementPtr - { LLVMC.inBounds = True - , LLVMC.address = LLVMC.GlobalReference msgPtrTy name - , LLVMC.indices = [LLVMC.Int 32 0, LLVMC.Int 32 0] } - pure (messageLen, msgPtr) - -llvmName :: Text -> LLVM.Name -llvmName = LLVM.mkName . Text.unpack - -literalCodegen :: LLVMIR.MonadIRBuilder m => Literal -> m Result -literalCodegen = \case - LitInt x -> resultValue <$> LLVMIR.int64 (fromIntegral x) - LitDouble x -> resultValue <$> LLVMIR.double x - LitBool b -> resultValue <$> LLVMIR.bit (if b then 1 else 0) - LitString t -> do - -- TODO: Store literal strings in global memory - let byteS = encodeUtf8 t - let len = toInteger (BS.length byteS) - lenOper <- LLVMIR.int64 len - let byteData = LLVMC.Int 8 . toInteger <$> BS.unpack byteS - byteDataOper <- LLVMIR.array byteData - byteDataPtr <- LLVMIR.alloca (LLVM.ArrayType (fromInteger len) LLVM.i8) Nothing 0 - _ <- LLVMIR.store byteDataPtr 0 byteDataOper - byteDataPtr' <- LLVMIR.bitcast byteDataPtr (LLVM.ptr LLVM.i8) - bytePtr <- LLVMIR.call RT.mallocRef [(lenOper, [])] - _ <- LLVMIR.call RT.memcpyRef [(bytePtr, []), (byteDataPtr', []), (lenOper, [])] - str <- LLVMIR.call RT.stringNewRef [(lenOper, []), (bytePtr, [])] - pure $ resultValue str - --- | Generates code for an arbitrary binary operation -binaryOpCodegen :: (LLVMIR.MonadIRBuilder m, MonadState CodegenTable m) - => m Result - -> m Result - -> (LLVM.Operand -> LLVM.Operand -> m LLVM.Operand) - -> m Result -binaryOpCodegen x y op = do - x' <- joinPoint1 x - y' <- joinPoint1 y - res <- op x' y' - pure $ resultValue res - --- | Generates code for a numeric binary operation -numBinopCodegen :: (LLVMIR.MonadIRBuilder m, MonadState CodegenTable m) - => m Result - -> m Result - -> Type - -> (LLVM.Operand -> LLVM.Operand -> m LLVM.Operand) -- ^ Float operation - -> (LLVM.Operand -> LLVM.Operand -> m LLVM.Operand) -- ^ Integer operation - -> m Result -numBinopCodegen x y ty opDouble opInt = - case unfix ty of - TyNumber numTy -> - if numTy == NumInt then binaryOpCodegen x y opInt - else binaryOpCodegen x y opDouble - _ -> error "Invariant violated" - --- | Generates code for BinOp -binOpCodegen :: (LLVMIR.MonadIRBuilder m, MonadState CodegenTable m) - => BinaryOp -- ^ Operation - -> Type -- ^ Type of the inputs - -> m Result - -> m Result - -> m Result -binOpCodegen op ty x y = - let (floatInstr, intInstr) = case op of - Add -> (LLVMIR.fadd, LLVMIR.add) - Sub -> (LLVMIR.fsub, LLVMIR.sub) - Mul -> (LLVMIR.fmul, LLVMIR.mul) - Div -> (LLVMIR.fdiv, LLVMIR.sdiv) - Lt -> ((LLVMIR.fcmp LLVMFP.OLT), (LLVMIR.icmp LLVMIP.SLT)) - Lte -> ((LLVMIR.fcmp LLVMFP.OLE), (LLVMIR.icmp LLVMIP.SLE)) - Equal -> ((LLVMIR.fcmp LLVMFP.OEQ), (LLVMIR.icmp LLVMIP.EQ)) - in numBinopCodegen x y ty floatInstr intInstr - -exprCodegen :: TypedExpr -> LLVMIR.IRBuilderT (LLVMIR.ModuleBuilderT Codegen) Result -exprCodegen = para (go . annGetExpr) - where - go :: ExprF (TypedExpr, LLVMIR.IRBuilderT (LLVMIR.ModuleBuilderT Codegen) Result) - -> LLVMIR.IRBuilderT (LLVMIR.ModuleBuilderT Codegen) Result - go = \case - Lit l -> literalCodegen l - BinOp op (ex, x) (_, y) -> binOpCodegen op (annGetAnn (unfix ex)) x y - If (_, condInstr) (_, t1Instr) (_, t2Instr) -> do - LLVMIR.ensureBlock - cond <- joinPoint1 condInstr - trueLabel <- LLVMIR.freshName "if_then" - falseLabel <- LLVMIR.freshName "if_else" - LLVMIR.condBr cond trueLabel falseLabel - LLVMIR.emitBlockStart trueLabel - t1Res <- t1Instr >>= resultEnsureBranch - jp <- gets tableCurrentJoin - LLVMIR.br jp - LLVMIR.emitBlockStart falseLabel - t2Res <- t2Instr >>= resultEnsureBranch - LLVMIR.br jp - resultCombine t1Res t2Res - Cons name argPairs -> do - let args = snd <$> argPairs - -- Assume constructor exists, since typechecker should verify it anyways - (dataTy, ctorName, ctorIndex) <- asks (fromJust . Map.lookup name . tableCtors) - ctorIndex' <- LLVMIR.int32 (fromIntegral ctorIndex) - let tagStruct1 = LLVM.ConstantOperand $ LLVMC.Undef (LLVM.NamedTypeReference dataTy) - -- Tag - tagStruct2 <- LLVMIR.insertValue tagStruct1 ctorIndex' [0] - -- Data pointer - let ctorTy = LLVM.NamedTypeReference ctorName - let nullptr = LLVM.ConstantOperand (LLVMC.Null (LLVM.ptr ctorTy)) - -- Use offsets to calculate struct size - ctorStructSize <- flip LLVMIR.ptrtoint LLVM.i64 - =<< LLVMIR.gep nullptr - =<< pure <$> LLVMIR.int32 0 - -- Allocate memory for constructor. - -- For now, use "leak memory" as an implementation strategy for deallocation. - ctorStructPtr <- LLVMIR.call RT.mallocRef [(ctorStructSize, [])] >>= - flip LLVMIR.bitcast (LLVM.ptr ctorTy) - if null args - then pure () - else do - values <- joinPoint (NE.fromList <$> sequence args) - indices <- traverse LLVMIR.int32 [0..fromIntegral (length values - 1)] - ptrOffset <- LLVMIR.int32 0 - forM_ (indices `zip` NE.toList values) $ \(index, v) -> do - valuePtr <- LLVMIR.gep ctorStructPtr [ptrOffset, index] - LLVMIR.store valuePtr 0 v - pure () - dataPtr <- LLVMIR.bitcast ctorStructPtr (LLVM.ptr LLVM.i8) - fullyInit <- LLVMIR.insertValue tagStruct2 dataPtr [1] - pure $ resultValue fullyInit - Case branches (valExpr, valM) -> do - LLVMIR.ensureBlock - defLabel <- LLVMIR.freshName "case_default" - endLabel <- gets tableCurrentJoin - val <- joinPoint1 valM - allCaseLabels <- forM branches $ \case - BrAdt name _ _ -> - let labelName = "case_" <> fromString (Text.unpack name) in - (name, ) <$> LLVMIR.freshName labelName - -- Assume the symbol table and type information is correct - let dataName = fromJust $ - case unfix . annGetAnn . unfix $ valExpr of { TyAdt n -> Just n; _ -> Nothing } - ctors <- asks ((\(_,_,cs) -> cs) . fromJust . Map.lookup dataName . tableAdts) - let ctorNames = ctorGetName <$> ctors - let usedLabelTriples = filter (\(_, (n, _)) -> n `elem` ctorNames) $ [0..] `zip` allCaseLabels - let jumpTable = [(LLVMC.Int 32 i, l) | (i, (_, l)) <- usedLabelTriples] - tag <- LLVMIR.extractValue val [0] - dataPtr <- LLVMIR.extractValue val [1] - LLVMIR.switch tag defLabel jumpTable - resVals <- forM (usedLabelTriples `zip` branches) $ \((_, (ctorName, label)), br) -> do - let expr = snd (branchGetExpr br) - (_, ctorLLVMName, index) <- asks (fromJust . Map.lookup ctorName . tableCtors) - let Ctor _ argTys = ctors !! index - let bindingPairs = branchGetBindings br `zip` (typeToLLVM <$> argTys) - LLVMIR.emitBlockStart label - ctorPtr <- LLVMIR.bitcast dataPtr (LLVM.ptr (LLVM.NamedTypeReference ctorLLVMName)) - ctorPtrOffset <- LLVMIR.int32 0 - bindings <- forM ([0..] `zip` bindingPairs) $ \(i, (n, ty)) -> do - ctorPtrIndex <- LLVMIR.int32 i - -- Need to bitcast the ptr type because we need a concrete type. We - -- also need to load the data immediately because of how variables - -- are implemented. - v <- LLVMIR.gep ctorPtr [ctorPtrOffset, ctorPtrIndex] - >>= flip LLVMIR.bitcast (LLVM.ptr ty) - >>= flip LLVMIR.load 0 - pure (n, v) - let updateTable t = t { tableVars = Map.union (tableVars t) (Map.fromList bindings) } - res <- local updateTable expr >>= resultEnsureBranch - LLVMIR.br endLabel - pure res - LLVMIR.emitBlockStart defLabel - LLVMIR.unreachable - foldM resultCombine (head resVals) (tail resVals) - -- resultCombine pure (sconcat (NE.fromList resVals)) - Let name (_, valM) (_, exprM) -> do - val <- joinPoint1 valM - local (\t -> t { tableVars = Map.insert name val (tableVars t) }) exprM - Var name -> do - -- Assume codegen table is correct - value <- gets (fromJust . Map.lookup name . tableVars) - pure $ resultValue value - App name argsM -> do - args <- traverse joinPoint1 (snd <$> argsM) - -- Assume that name is callable (e.g. either a static function or a - -- function pointer) - fn <- fromJust <$> lookupName name - resultValue <$> LLVMIR.call fn [(a, []) | a <- args] - FunRef name -> do - fn <- gets (fromJust . Map.lookup name . tableFuns) - pure (resultValue fn) - Cast (origExpr, exprM) num -> do - let origNum = case unfix (annGetAnn (unfix origExpr)) of - TyNumber n -> n - _ -> error "codegen: attempting to cast non-numeric" - expr <- joinPoint1 exprM - resultValue <$> castOpCodegen origNum num expr - Print (_, exprM) -> do - let fmtStr = encodeUtf8 (fromString "Print: %s\n\0") - let fmtStrLen = toInteger (BS.length fmtStr) - let fmtStrData = LLVMC.Int 8 . toInteger <$> BS.unpack fmtStr - fmtStrOper <- LLVMIR.array fmtStrData - fmtStrPtr <- LLVMIR.alloca (LLVM.ArrayType (fromInteger fmtStrLen) LLVM.i8) Nothing 0 - _ <- LLVMIR.store fmtStrPtr 0 fmtStrOper - printf <- gets tablePrintf - str <- joinPoint1 exprM - exprCstring <- LLVMIR.call RT.stringCstringRef [(str, [])] - fmtStrPtr' <- LLVMIR.bitcast fmtStrPtr (LLVM.ptr LLVM.i8) - _ <- LLVMIR.call printf [(fmtStrPtr', []), (exprCstring, [])] - retval <- LLVMIR.int64 0 - pure $ resultValue retval - lookupName :: MonadState CodegenTable m - => Text - -> m (Maybe LLVM.Operand) - lookupName name = - gets $ \t -> - Map.lookup name (tableVars t) <|> Map.lookup name (tableFuns t) - --- | Generates code for numeric casting -castOpCodegen :: LLVMIR.MonadIRBuilder m => Numeric -> Numeric -> LLVM.Operand -> m LLVM.Operand -castOpCodegen source castTo oper = case (source, castTo) of - (NumDouble, NumDouble) -> pure oper - (NumInt, NumInt) -> pure oper - (NumDouble, NumInt) -> LLVMIR.fptosi oper LLVM.i64 - (NumInt, NumDouble) -> LLVMIR.sitofp oper LLVM.double - (NumUnknown, _) -> castOpCodegen NumDouble castTo oper - (_, NumUnknown) -> error "castOpCodegen: attempting to cast to NumUnknown" - -ctorToLLVM :: Constructor -> [LLVM.Type] -ctorToLLVM (Ctor _ args) = typeToLLVM <$> args - -typeToLLVM :: Type -> LLVM.Type -typeToLLVM = go . unfix - where - go = \case - TyNumber n -> case n of - NumDouble -> LLVM.double - NumInt -> LLVM.i64 - NumUnknown -> LLVM.double - TyBool -> LLVM.i1 - TyString -> LLVM.ptr RT.stringType - TyAdt name -> LLVM.NamedTypeReference (llvmName name) - TyFun args res -> - LLVM.ptr $ LLVM.FunctionType - { LLVM.resultType = typeToLLVM res - , LLVM.argumentTypes = typeToLLVM <$> args - , LLVM.isVarArg = False - } - -adtToLLVM :: Text - -> [Constructor] - -> LLVMIR.ModuleBuilderT Codegen () -adtToLLVM adtName ctors = do - let adtType = LLVM.StructureType - { LLVM.isPacked = True - , LLVM.elementTypes = [LLVM.i32, LLVM.ptr LLVM.i8] } - adtLLVMName <- gets ((\(n,_,_) -> n) . fromJust . Map.lookup adtName . tableAdts) - -- TODO: Store returned type in symbol table to avoid error-prone type - -- reconstruction - _ <- LLVMIR.typedef adtLLVMName (Just adtType) - forM_ ctors $ \(Ctor ctorName args) -> do - let ctorType = LLVM.StructureType - { LLVM.isPacked = True - , LLVM.elementTypes = typeToLLVM <$> args } - ctorLLVMName <- gets ((\(_,n,_) -> n) . fromJust . Map.lookup ctorName . tableCtors) - LLVMIR.typedef ctorLLVMName (Just ctorType) - --- | Emits the given function definition -funToLLVM :: Text - -> [(Text, Type)] - -> Type - -> TypedExpr - -> LLVMIR.ModuleBuilderT Codegen LLVM.Operand -funToLLVM name params ty body = - let name' = if name == "main" then "__simpl_main" else name - ftype = typeToLLVM ty - fname = llvmName name' - fparams = [(typeToLLVM t, fromString (Text.unpack n)) | (n, t) <- params] - in mdo foper <- LLVMIR.function fname fparams ftype $ \args -> do - LLVMIR.ensureBlock - endLabel <- LLVMIR.freshName "function_end" - -- We need to make sure we don't pollute other function scopes - oldVars <- gets tableVars - let updVars t = tableVars t `Map.union` Map.fromList ((fst <$> params) `zip` args) - modify (\t -> t { tableCurrentJoin = endLabel - , tableFuns = Map.insert name foper (tableFuns t) - , tableVars = updVars t }) - retval <- joinPoint1 (exprCodegen body) - -- Restore old scope - modify (\t -> t { tableVars = oldVars }) - LLVMIR.ret retval - pure foper - --- | Generate code for the entire module -moduleCodegen :: [Decl Expr] - -> SymbolTable TypedExpr - -> LLVMIR.ModuleBuilderT Codegen () -moduleCodegen decls symTab = mdo - -- Message is "Hi\n" (with null terminator) - (_, msg) <- staticString ".message" "Hello world!\n" - (_, resultFmt) <- staticString ".resultformat" "Result: %i\n" - let srcCode = unlines $ show . pretty <$> decls - (_, exprSrc) <- staticString ".sourcecode" $ "Source code: " ++ srcCode ++ "\n" - RT.emitRuntimeDecls - modify (\t -> t { tablePrintf = RT.printfRef }) - forM_ (Map.toList . symTabAdts $ symTab) $ \(name, (_, ctors)) -> - adtToLLVM name ctors - -- Insert function operands into symbol table before emitting so order of - -- definition doesn't matter. This works because the codegen monad is lazy. - modify (\t -> t { tableFuns = tableFuns t `Map.union` Map.fromList funOpers }) - funOpers <- forM (Map.toList . symTabFuns $ symTab) $ \(name, (params, ty, body)) -> - (name, ) <$> funToLLVM name params ty body - - _ <- LLVMIR.function "main" [] LLVM.i64 $ \_ -> do - diagnosticsEnabled <- gets (enableDiagnostics . tableOptions) - when diagnosticsEnabled $ LLVMIR.call RT.printfRef [(msg, [])] >> pure () - let mainTy = LLVM.ptr (LLVM.FunctionType LLVM.i64 [] False) - let mainName = LLVM.mkName "__simpl_main" - let mainRef = LLVM.ConstantOperand (LLVMC.GlobalReference mainTy mainName) - when diagnosticsEnabled $ LLVMIR.call RT.printfRef [(exprSrc, [])] >> pure () - exprResult <- LLVMIR.call mainRef [] - when diagnosticsEnabled $ LLVMIR.call RT.printfRef [(resultFmt, []), (exprResult, [])] >> pure () - retcode <- LLVMIR.int64 1 - LLVMIR.ret retcode - pure () - -runCodegen :: CompilerOpts -> [Decl Expr] -> SymbolTable TypedExpr -> LLVM.Module -runCodegen opts decls symTab - = runIdentity - . flip evalStateT emptyCodegenTable - . unCodegen - . LLVMIR.buildModuleT "simpl.ll" - $ lift (initCodegenTable opts symTab) >> moduleCodegen decls symTab From 3c48d346b1d5c82bab7209552492ee746b13fdb1 Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Wed, 27 Mar 2019 22:24:20 -0700 Subject: [PATCH 15/24] Improve JoinIR pretty printing --- src/Simpl/JoinIR.hs | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/Simpl/JoinIR.hs b/src/Simpl/JoinIR.hs index 4ae9d11..1f2d04c 100644 --- a/src/Simpl/JoinIR.hs +++ b/src/Simpl/JoinIR.hs @@ -139,9 +139,10 @@ instance Pretty a => Pretty (JBranch a) where instance Pretty a => Pretty (Joinable a) where pretty = \case JIf guard trueBr falseBr -> - PP.hang 2 $ "if" <+> pretty guard <+> PP.group ( - "then" <> PP.softline <> pretty trueBr <> PP.softline - <> "else" <> PP.softline <> pretty falseBr) + PP.hang 2 $ PP.sep + [ "if" <+> pretty guard + , "then" <> PP.softline <> PP.align (pretty trueBr) + , "else" <> PP.softline <> PP.align (pretty falseBr) ] JCase expr brs -> PP.hang 2 $ "case" <+> pretty expr <+> "of" <> PP.hardline <> (PP.vsep $ pretty <$> brs) @@ -153,13 +154,13 @@ instance Pretty JExpr where f = \case JVal v -> pretty v JLet n v next -> PP.hsep ["let", pretty n, "=", pretty v, "in"] <> PP.softline <> pretty next - JJoin lbl n joinbl next -> PP.align $ - PP.hang 2 (PP.hsep ["join", pretty lbl, "bind", pretty n, "="] - <> PP.softline <> pretty joinbl) + JJoin lbl n joinbl next -> + (PP.group . PP.hang 2 $ PP.hsep ["join" <> PP.enclose "[" "]" (pretty lbl), pretty n, "="] + <> PP.flatAlt PP.hardline " " <> pretty joinbl) <> PP.hardline <> "in" <+> pretty next JJump lbl v -> PP.hsep ["jump", pretty lbl, "with", pretty v] JApp name clbl args next -> - PP.hsep (["let app", pretty name, "=", pretty clbl] ++ (pretty <$> args)) + PP.hsep (["let app", pretty name, "=", pretty clbl] ++ (pretty <$> args) ++ ["in"]) <> PP.hardline <> pretty next -- * Annotated [JExpr]s From e14de8201534d4b7338fd27453472cbeb8ee927e Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Sun, 31 Mar 2019 12:39:12 -0700 Subject: [PATCH 16/24] [skip ci] WIP modify Join IR to support more efficient control flow --- src/Simpl/Backend/CodegenJoin.hs | 58 ++++++++++---------- src/Simpl/JoinIR.hs | 93 +++++++++++++++++++------------- 2 files changed, 87 insertions(+), 64 deletions(-) diff --git a/src/Simpl/Backend/CodegenJoin.hs b/src/Simpl/Backend/CodegenJoin.hs index fc90bd4..d63888b 100644 --- a/src/Simpl/Backend/CodegenJoin.hs +++ b/src/Simpl/Backend/CodegenJoin.hs @@ -242,25 +242,27 @@ jvalueCodegen = \case JVar name -> gets (snd . fromJust . Map.lookup name . tableVars) JLit l -> literalCodegen l -joinableCodegen - :: Joinable (AnnExpr '[ 'ExprType]) +controlFlowCodegen + :: JValue + -> LLVM.Operand + -> ControlFlow (AnnExpr '[ 'ExprType]) -> LLVMIR.IRBuilderT (LLVMIR.ModuleBuilderT Codegen) () -joinableCodegen = \case - JIf guardVal trueBr falseBr -> do +controlFlowCodegen val valOper = \case + JIf (Cfe trueBr trueCf) (Cfe falseBr falseCf) -> do LLVMIR.ensureBlock - cond <- jvalueCodegen guardVal trueLabel <- LLVMIR.freshName "if_then" falseLabel <- LLVMIR.freshName "if_else" - LLVMIR.condBr cond trueLabel falseLabel + LLVMIR.condBr valOper trueLabel falseLabel LLVMIR.emitBlockStart trueLabel - _ <- jexprCodegen trueBr + (trueVal, trueOper) <- jexprCodegen trueBr + _ <- controlFlowCodegen trueVal trueOper trueCf LLVMIR.emitBlockStart falseLabel - _ <- jexprCodegen falseBr + (falseVal, falseOper) <- jexprCodegen falseBr + _ <- controlFlowCodegen falseVal falseOper falseCf pure () - JCase val branches -> do + JCase branches -> do LLVMIR.ensureBlock defLabel <- LLVMIR.freshName "case_default" - valOper <- jvalueCodegen val allCaseLabels <- forM branches $ \case BrAdt name _ _ -> let labelName = "case_" <> fromString (Text.unpack name) in @@ -276,6 +278,7 @@ joinableCodegen = \case LLVMIR.switch tag defLabel jumpTable forM_ (usedLabelTriples `zip` branches) $ \((_, (ctorName, label)), br) -> do let expr = branchGetExpr br + let cf = branchGetControlFlow br (_, ctorLLVMName, index) <- gets (fromJust . Map.lookup ctorName . tableCtors) let Ctor _ argTys = ctors !! index let bindingPairs = branchGetBindings br `zip` (typeToLLVM <$> argTys) @@ -293,9 +296,19 @@ joinableCodegen = \case ty <- lookupValueType (JVar n) pure (n, (ty, v)) let updateTable t = t { tableVars = Map.union (tableVars t) (Map.fromList bindings) } - localCodegenTable updateTable (jexprCodegen expr) + (exprVal, exprOper) <- jexprCodegen expr + localCodegenTable updateTable (controlFlowCodegen exprVal exprOper cf) LLVMIR.emitBlockStart defLabel LLVMIR.unreachable + JJump lbl -> do + v <- jvalueCodegen val + jvals <- gets tableJoinValues + block <- LLVMIR.currentBlock + let f = (\(n, jvs) -> Just (n, (v, block) : jvs)) + let updJvals = Map.update f lbl jvals + modify (\t -> t { tableJoinValues = updJvals }) + llvmLabel <- gets (fst . fromJust . Map.lookup lbl . tableJoinValues) + LLVMIR.br llvmLabel callableCodegen :: Callable @@ -365,40 +378,31 @@ callableCodegen callable args = case callable of jexprCodegen :: AnnExpr '[ 'ExprType] - -> LLVMIR.IRBuilderT (LLVMIR.ModuleBuilderT Codegen) LLVM.Operand + -> LLVMIR.IRBuilderT (LLVMIR.ModuleBuilderT Codegen) (JValue, LLVM.Operand) jexprCodegen = (\e -> go (unfix (getType e)) (annGetExpr e)) . unfix where go :: TypeF Type -> JExprF (AnnExpr '[ 'ExprType ]) - -> LLVMIR.IRBuilderT (LLVMIR.ModuleBuilderT Codegen) LLVM.Operand + -> LLVMIR.IRBuilderT (LLVMIR.ModuleBuilderT Codegen) (JValue, LLVM.Operand) go exprTy = \case - JVal v -> jvalueCodegen v + JVal v -> (v,) <$> jvalueCodegen v JLet name val next -> do oper <- jvalueCodegen val _ <- bindVariable name exprTy oper jexprCodegen next - JJoin lbl varName joinable next -> do + JJoin lbl varName (Cfe expr cf) next -> do llvmLabel <- LLVMIR.freshName (fromString (Text.unpack lbl)) let addJoinEntry = \t -> t { tableJoinValues = Map.insert lbl (llvmLabel, []) (tableJoinValues t) } oldJoinEntries <- gets tableJoinValues - _ <- localCodegenTable addJoinEntry (joinableCodegen joinable) + (lastVal, lastValOper) <- jexprCodegen expr + _ <- localCodegenTable addJoinEntry (controlFlowCodegen lastVal lastValOper cf) (_, joinValues) <- gets (fromJust . Map.lookup lbl . tableJoinValues) modify (\t -> t { tableJoinValues = oldJoinEntries }) LLVMIR.emitBlockStart llvmLabel op <- LLVMIR.phi joinValues bindVariable varName exprTy op jexprCodegen next - JJump lbl val -> do - v <- jvalueCodegen val - jvals <- gets tableJoinValues - block <- LLVMIR.currentBlock - let f = (\(n, jvs) -> Just (n, (v, block) : jvs)) - let updJvals = Map.update f lbl jvals - modify (\t -> t { tableJoinValues = updJvals }) - llvmLabel <- gets (fst . fromJust . Map.lookup lbl . tableJoinValues) - LLVMIR.br llvmLabel - pure v JApp varName callable args next -> do oper <- callableCodegen callable args bindVariable varName exprTy oper @@ -471,7 +475,7 @@ funToLLVM name params ty body = let updVars t = tableVars t `Map.union` Map.fromList [(n, (unfix vty, op)) | ((n, vty), op) <- params `zip` args] modify (\t -> t { tableFuns = Map.insert name foper (tableFuns t) , tableVars = updVars t }) - retval <- jexprCodegen body + (_, retval) <- jexprCodegen body -- Restore old scope modify (\t -> t { tableVars = oldVars }) LLVMIR.ret retval diff --git a/src/Simpl/JoinIR.hs b/src/Simpl/JoinIR.hs index 1f2d04c..ef3d76b 100644 --- a/src/Simpl/JoinIR.hs +++ b/src/Simpl/JoinIR.hs @@ -58,28 +58,35 @@ data JValue | JLit !Literal deriving (Show, Eq) + +-- | Represents how a value at the end of a control flow branch should be handled. +data ControlFlow a + -- | If expression on the given value, with a true branch and a false branch + = JIf !(Cfe a) !(Cfe a) + + -- | Case expression on the given value + | JCase ![JBranch a] + + -- | Jump to the given enclosing join point with the given value. + | JJump !Text + deriving (Functor, Foldable, Traversable, Show) + + data JBranch a - = BrAdt Name [Name] !a -- ^ Destructure algebraic data type + = BrAdt Name [Name] !(Cfe a) -- ^ Destructure algebraic data type deriving (Functor, Foldable, Traversable, Show) branchGetExpr :: JBranch a -> a branchGetExpr = \case - BrAdt _ _ e -> e + BrAdt _ _ (Cfe e _) -> e branchGetBindings :: JBranch a -> [Text] branchGetBindings = \case BrAdt _ vars _ -> vars - --- | Represents expressions that must be bound to a join point. -data Joinable a - -- | If expression on the given variable, with a true branch and a false - -- branch - = JIf !JValue !a !a - - -- | Case expression on the given variable - | JCase !JValue ![JBranch a] - deriving (Functor, Foldable, Traversable, Show) +branchGetControlFlow :: JBranch a -> ControlFlow a +branchGetControlFlow = \case + BrAdt _ _ (Cfe _ cf) -> cf -- | The JoinIR expression type. Syntactically, it is in ANF-form with explicit @@ -93,18 +100,20 @@ data JExprF a -- | A join point. Consists of a label, the variable representing the joined -- value, the expression to join, and the next expression. - | JJoin !Label !Name !(Joinable a) !a - - -- | Jump to the given enclosing join point with the given value. - | JJump !Label !JValue + | JJoin !Label !Name !(Cfe a) !a -- | Apply the callable to the arguments, bind the result to the given name, -- and continue to the next expression. | JApp !Name !Callable ![JValue] !a deriving (Functor, Foldable, Traversable, Show) + +data Cfe a = Cfe !a !(ControlFlow a) + deriving (Functor, Foldable, Traversable, Show) + $(deriveShow1 ''JBranch) -$(deriveShow1 ''Joinable) +$(deriveShow1 ''ControlFlow) +$(deriveShow1 ''Cfe) $(deriveShow1 ''JExprF) type JExpr = Fix JExprF @@ -114,7 +123,6 @@ jexprGetVal = \case JVal v -> v JLet n _ _ -> JVar n JJoin _ n _ _ -> JVar n - JJump _ v -> v JApp n _ _ _ -> JVar n instance Pretty Callable where @@ -132,20 +140,22 @@ instance Pretty JValue where JLit l -> pretty l instance Pretty a => Pretty (JBranch a) where - pretty (BrAdt ctorName varNames expr) = - PP.hang 2 $ PP.hsep (pretty <$> brPart) <> PP.softline <> pretty expr + pretty (BrAdt ctorName varNames cfe) = + PP.hang 2 $ PP.hsep (pretty <$> brPart) <> PP.softline <> pretty cfe where brPart = [ctorName] ++ varNames ++ ["=>"] -instance Pretty a => Pretty (Joinable a) where +instance Pretty a => Pretty (ControlFlow a) where pretty = \case - JIf guard trueBr falseBr -> - PP.hang 2 $ PP.sep - [ "if" <+> pretty guard - , "then" <> PP.softline <> PP.align (pretty trueBr) - , "else" <> PP.softline <> PP.align (pretty falseBr) ] - JCase expr brs -> - PP.hang 2 $ "case" <+> pretty expr <+> "of" <> PP.hardline <> + JIf trueBr falseBr -> + PP.hang 2 $ "if" <+> PP.sep + [ "then" <> PP.softline <> PP.align (flatParens (pretty trueBr)) + , "else" <> PP.softline <> PP.align (flatParens (pretty falseBr)) ] + JCase brs -> + PP.hang 2 $ "case" <+> "of" <> PP.hardline <> (PP.vsep $ pretty <$> brs) + JJump lbl -> PP.hsep ["jump", pretty lbl] + where + flatParens d = PP.flatAlt d (PP.parens d) instance Pretty JExpr where pretty = f . unfix @@ -158,11 +168,19 @@ instance Pretty JExpr where (PP.group . PP.hang 2 $ PP.hsep ["join" <> PP.enclose "[" "]" (pretty lbl), pretty n, "="] <> PP.flatAlt PP.hardline " " <> pretty joinbl) <> PP.hardline <> "in" <+> pretty next - JJump lbl v -> PP.hsep ["jump", pretty lbl, "with", pretty v] JApp name clbl args next -> PP.hsep (["let app", pretty name, "=", pretty clbl] ++ (pretty <$> args) ++ ["in"]) <> PP.hardline <> pretty next +instance Pretty a => Pretty (Cfe a) where + pretty (Cfe expr cf) = + PP.align (pretty expr <> ";" <> line <> pretty cf) + where + line = case cf of + JIf _ _ -> PP.hardline + JCase _ -> PP.hardline + JJump _ -> PP.softline + -- * Annotated [JExpr]s -- -- Because it's possible to have many different annotations on a single AST, we @@ -236,10 +254,10 @@ exampleJExpr :: JExpr exampleJExpr = Fix $ JJoin "label" "myvar" ifE (Fix (JVal (JVar "myvar"))) where - ifE = - JIf (JLit (LitInt 5)) - (Fix (JJump "label" (JLit (LitInt 10)))) - (Fix (JJump "label" (JLit (LitInt 5)))) + intVal = Fix . JVal . JLit . LitInt + ifE = Cfe (intVal 5) $ + JIf (Cfe (intVal 10) (JJump "label")) + (Cfe (intVal 5) (JJump "label")) exampleTypedJExpr :: AnnExpr '[ 'ExprType ] exampleTypedJExpr = Fix $ AnnExprF @@ -250,8 +268,9 @@ exampleTypedJExpr = Fix $ AnnExprF varE = Fix $ AnnExprF { annGetAnn = withType tyInt V.:& V.RNil , annGetExpr = JVal (JVar "myvar") } - ifE = - JIf (JLit (LitInt 5)) (jmpE 10) (jmpE 5) - jmpE v = Fix $ AnnExprF + intVal x = Fix $ AnnExprF { annGetAnn = withType tyInt V.:& V.RNil - , annGetExpr = JJump "label" (JLit (LitInt v)) } + , annGetExpr = JVal (JLit (LitInt x)) } + jmpCfe x = Cfe (intVal x) (JJump "label") + ifE = Cfe (intVal 5) $ + JIf (jmpCfe 10) (jmpCfe 5) From c4b4d262e9755172d75a4cac2faec4fcadebfd6d Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Mon, 1 Apr 2019 11:24:02 -0700 Subject: [PATCH 17/24] Update AST to JoinIR transformation --- src/Simpl/AstToJoinIR.hs | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/src/Simpl/AstToJoinIR.hs b/src/Simpl/AstToJoinIR.hs index bdbd53c..94423a5 100644 --- a/src/Simpl/AstToJoinIR.hs +++ b/src/Simpl/AstToJoinIR.hs @@ -92,15 +92,14 @@ transformExpr :: (MonadReader (SymbolTable (A.AnnExpr Type)) m, MonadFreshVar m) -> m (J.AnnExpr '[ 'J.ExprType]) transformExpr expr = anfTransform expr (pure . makeJexpr (astType expr) . J.JVal) --- | Perform ANF transformation on the branch, returning to the jump label at --- the end. +-- | Perform ANF transformation on the branch, afterwards handling control flow. transformBranch :: (MonadReader (SymbolTable (A.AnnExpr Type)) m, MonadFreshVar m) - => Text -- ^ Return label + => J.ControlFlow (J.AnnExpr '[ 'J.ExprType]) -- ^ Control flow handler -> A.Branch (A.AnnExpr Type) -- ^ Branches -> m (J.JBranch (J.AnnExpr '[ 'J.ExprType])) -transformBranch retLabel (A.BrAdt adtName argNames expr) = - let jexprM = anfTransform expr (pure . makeJexpr (astType expr) . J.JJump retLabel) in - J.BrAdt adtName argNames <$> jexprM +transformBranch cf (A.BrAdt adtName argNames expr) = do + jexpr <- anfTransform expr (pure . makeJexpr (astType expr) . J.JVal) + pure $ J.BrAdt adtName argNames (J.Cfe jexpr cf) -- | Main ANF transformation logic @@ -124,17 +123,25 @@ anfTransform (Fix (A.AnnExprF ty exprf)) cont = case exprf of A.If guard trueBr falseBr -> anfTransform guard $ \jguard -> do lbl <- freshLabel - trueBr' <- anfTransform trueBr (pure . makeJexpr (astType trueBr) . J.JJump lbl) - falseBr' <- anfTransform falseBr (pure . makeJexpr (astType falseBr) . J.JJump lbl) + trueBr' <- anfTransform trueBr (pure . makeJexpr (astType trueBr) . J.JVal) + falseBr' <- anfTransform falseBr (pure . makeJexpr (astType falseBr) . J.JVal) name <- freshVar - makeJexpr ty . J.JJoin lbl name (J.JIf jguard trueBr' falseBr') <$> + let jmp = J.JJump lbl + let guardTy = A.annGetAnn (unfix guard) + let guardCfe = makeJexpr guardTy (J.JVal jguard) + let cfe = J.Cfe guardCfe (J.JIf (J.Cfe trueBr' jmp) (J.Cfe falseBr' jmp)) + -- TODO: Make JJoin node placement more efficient + makeJexpr ty . J.JJoin lbl name cfe <$> local (symTabInsertVar name ty) (cont (J.JVar name)) A.Case branches expr -> anfTransform expr $ \jexpr -> do lbl <- freshLabel - jbranches <- traverse (transformBranch lbl) branches + let jexprTy = A.annGetAnn (unfix expr) + jbranches <- traverse (transformBranch (J.JJump lbl)) branches + let jexprCfe = J.Cfe (makeJexpr jexprTy (J.JVal jexpr)) (J.JCase jbranches) name <- freshVar - makeJexpr ty . J.JJoin lbl name (J.JCase jexpr jbranches) <$> + -- TODO: Make JJoin node placement more efficient + makeJexpr ty . J.JJoin lbl name jexprCfe <$> local (symTabInsertVar name ty) (cont (J.JVar name)) A.Cons ctorName args -> collectArgs args [] $ \argVals -> do From 28c469fd793d51ebd5e80aa9617c4309a5fe8c63 Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Tue, 2 Apr 2019 10:58:16 -0700 Subject: [PATCH 18/24] Rename CodegenJoin to Codegen --- src/Simpl/Backend/{CodegenJoin.hs => Codegen.hs} | 2 +- src/Simpl/Compiler.hs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) rename src/Simpl/Backend/{CodegenJoin.hs => Codegen.hs} (99%) diff --git a/src/Simpl/Backend/CodegenJoin.hs b/src/Simpl/Backend/Codegen.hs similarity index 99% rename from src/Simpl/Backend/CodegenJoin.hs rename to src/Simpl/Backend/Codegen.hs index d63888b..63e7d6b 100644 --- a/src/Simpl/Backend/CodegenJoin.hs +++ b/src/Simpl/Backend/Codegen.hs @@ -13,7 +13,7 @@ {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE RecursiveDo #-} -module Simpl.Backend.CodegenJoin where +module Simpl.Backend.Codegen where import Control.Applicative ((<|>)) import Control.Monad (forM, forM_) diff --git a/src/Simpl/Compiler.hs b/src/Simpl/Compiler.hs index bb54a36..9ddc60f 100644 --- a/src/Simpl/Compiler.hs +++ b/src/Simpl/Compiler.hs @@ -12,7 +12,7 @@ import LLVM.Context import Simpl.Ast import Simpl.AstToJoinIR -import Simpl.Backend.CodegenJoin (runCodegen) +import Simpl.Backend.Codegen (runCodegen) import Simpl.CompilerOptions import Simpl.SymbolTable import Simpl.Typing (TypeError, runTypecheck, checkType, withExtraVars) From 38ab68aa8fb5da728f0b9ee916fe59eaa2ac8a43 Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Tue, 2 Apr 2019 12:16:41 -0700 Subject: [PATCH 19/24] Implement JoinIR AST verification algorithm --- src/Simpl/JoinIR/Verify.hs | 134 +++++++++++++++++++++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 src/Simpl/JoinIR/Verify.hs diff --git a/src/Simpl/JoinIR/Verify.hs b/src/Simpl/JoinIR/Verify.hs new file mode 100644 index 0000000..d9cf303 --- /dev/null +++ b/src/Simpl/JoinIR/Verify.hs @@ -0,0 +1,134 @@ +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE LambdaCase #-} +{-| +Module : Simpl.JoinIR.Verify +Description : Verifies validity of a JoinIR AST +-} +module Simpl.JoinIR.Verify + (verify, VerifyCtx(..), emptyCtx) where + +import Control.Monad.Reader +import Control.Monad.Except +import Data.Functor.Foldable (cata) +import Data.Functor.Identity +import Data.Text (Text) +import Data.Set (Set) +import qualified Data.Set as Set + +import Simpl.JoinIR + +-- * Verification Monad + +data VerifyCtx = VerifyCtx + { verifyVars :: Set Text + , verifyLabels :: Set Text + } deriving (Eq, Show) + +emptyCtx :: VerifyCtx +emptyCtx = VerifyCtx + { verifyVars = Set.empty + , verifyLabels = Set.empty } + +ctxWithVar :: Text -> VerifyCtx -> VerifyCtx +ctxWithVar name ctx = ctx { verifyVars = Set.insert name (verifyVars ctx) } + +ctxWithLabel :: Text -> VerifyCtx -> VerifyCtx +ctxWithLabel lbl ctx = ctx { verifyLabels = Set.insert lbl (verifyLabels ctx) } + +newtype VerifyT m a = + VerifyT { unVerify :: ReaderT VerifyCtx (ExceptT VerifyError m) a } + deriving ( Functor + , Applicative + , Monad + , MonadReader VerifyCtx + , MonadError VerifyError) + +type Verify = VerifyT Identity + +runVerifyT :: VerifyT m a -> VerifyCtx -> m (Either VerifyError a) +runVerifyT m ctx + = runExceptT + . flip runReaderT ctx + . unVerify + $ m + +runVerify :: Verify a -> VerifyCtx -> Either VerifyError a +runVerify m ctx = runIdentity (runVerifyT m ctx) + +-- * Verification errors + +data VerifyError = VarRedefinition Text + | LabelRedefinition Text + | NoSuchLabel Text + | NoSuchVar Text + deriving (Show, Eq) + +verify :: AnnExpr a -> Either VerifyError () +verify expr = runVerify (doVerifyExpr expr) emptyCtx + +-- | Throw an error if the variable is already bound +checkUnboundVar :: (MonadError VerifyError m, MonadReader VerifyCtx m) + => Text -> m () +checkUnboundVar var = + asks (Set.member var . verifyVars) >>= \case + True -> throwError $ VarRedefinition var + False -> pure () + +doVerifyValue :: (MonadError VerifyError m, MonadReader VerifyCtx m) + => JValue + -> m () +doVerifyValue = \case + JVar name -> asks (Set.member name . verifyVars) >>= \case + True -> pure () + False -> throwError $ NoSuchVar name + JLit _ -> pure () + +doVerifyExpr :: (MonadError VerifyError m, MonadReader VerifyCtx m) + => AnnExpr a + -> m () +doVerifyExpr = cata (go . annGetExpr) + where + go :: (MonadError VerifyError m, MonadReader VerifyCtx m) => JExprF (m ()) -> m () + go = \case + JVal v -> doVerifyValue v + JLet name val nextM -> do + doVerifyValue val + checkUnboundVar name + local (ctxWithVar name) nextM + JJoin lbl name cfe nextM -> do + asks (Set.member lbl . verifyLabels) >>= \case + True -> throwError $ LabelRedefinition lbl + False -> pure () + local (ctxWithLabel lbl) (doVerifyCfe cfe) + checkUnboundVar name + local (ctxWithVar name) nextM + JApp name _ args nextM -> do + -- Note: we ignore the callable for now + _ <- traverse doVerifyValue args + checkUnboundVar name + local (ctxWithVar name) nextM + +doVerifyCfe :: (MonadError VerifyError m, MonadReader VerifyCtx m) + => Cfe (m ()) + -> m () +doVerifyCfe (Cfe exprM cf) = do + exprM + case cf of + JIf trueCfe falseCfe -> do + doVerifyCfe trueCfe + doVerifyCfe falseCfe + JCase branches -> traverse doVerifyBranch branches >> pure () + JJump lbl -> + asks (Set.member lbl . verifyLabels) >>= \case + True -> pure () + False -> throwError $ NoSuchLabel lbl + +doVerifyBranch :: (MonadError VerifyError m, MonadReader VerifyCtx m) + => JBranch (m ()) + -> m () +doVerifyBranch (BrAdt name args cfe) = do + checkUnboundVar name + _ <- traverse checkUnboundVar args + doVerifyCfe cfe From 001bcd56c56c02c41e87835f006e6b7a00f1a66c Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Tue, 2 Apr 2019 16:08:08 -0700 Subject: [PATCH 20/24] Refactor JoinIR syntax definitions to JoinIR subfolder --- src/Simpl/AstToJoinIR.hs | 2 +- src/Simpl/Backend/Codegen.hs | 2 +- src/Simpl/Compiler.hs | 2 +- src/Simpl/{JoinIR.hs => JoinIR/Syntax.hs} | 4 ++-- src/Simpl/JoinIR/Verify.hs | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) rename src/Simpl/{JoinIR.hs => JoinIR/Syntax.hs} (99%) diff --git a/src/Simpl/AstToJoinIR.hs b/src/Simpl/AstToJoinIR.hs index 94423a5..0be2d81 100644 --- a/src/Simpl/AstToJoinIR.hs +++ b/src/Simpl/AstToJoinIR.hs @@ -19,7 +19,7 @@ import Data.String (fromString) import Simpl.Ast (Type) import Simpl.SymbolTable import qualified Simpl.Ast as A -import qualified Simpl.JoinIR as J +import qualified Simpl.JoinIR.Syntax as J -- * Public API diff --git a/src/Simpl/Backend/Codegen.hs b/src/Simpl/Backend/Codegen.hs index 63e7d6b..4a0d2c8 100644 --- a/src/Simpl/Backend/Codegen.hs +++ b/src/Simpl/Backend/Codegen.hs @@ -49,7 +49,7 @@ import Simpl.CompilerOptions import Simpl.SymbolTable import Simpl.Typing (literalType) import Simpl.Backend.Runtime () -import Simpl.JoinIR +import Simpl.JoinIR.Syntax import qualified Simpl.Backend.Runtime as RT data CodegenTable = diff --git a/src/Simpl/Compiler.hs b/src/Simpl/Compiler.hs index 9ddc60f..491db2b 100644 --- a/src/Simpl/Compiler.hs +++ b/src/Simpl/Compiler.hs @@ -16,7 +16,7 @@ import Simpl.Backend.Codegen (runCodegen) import Simpl.CompilerOptions import Simpl.SymbolTable import Simpl.Typing (TypeError, runTypecheck, checkType, withExtraVars) -import Simpl.JoinIR (unannotate) +import Simpl.JoinIR.Syntax (unannotate) import Paths_simpl_lang -- | Main error type, aggregating all error types. diff --git a/src/Simpl/JoinIR.hs b/src/Simpl/JoinIR/Syntax.hs similarity index 99% rename from src/Simpl/JoinIR.hs rename to src/Simpl/JoinIR/Syntax.hs index ef3d76b..e399b46 100644 --- a/src/Simpl/JoinIR.hs +++ b/src/Simpl/JoinIR/Syntax.hs @@ -17,14 +17,14 @@ {-# LANGUAGE TypeApplications #-} {-| -Module : Simpl.JoinIR +Module : Simpl.JoinIR.Syntax Description : AST for the JoinIR Defines the abstract syntax tree for JoinIR, an IR for SimPL based on the IR presented in /Compiling without Continuations/ by Luke Maurer, Zena Ariola, Paul Downen, and Simon Peyton Jones (PLDI '17). -} -module Simpl.JoinIR where +module Simpl.JoinIR.Syntax where import Data.Text (Text) import Data.Text.Prettyprint.Doc (Pretty, pretty, (<>), (<+>)) diff --git a/src/Simpl/JoinIR/Verify.hs b/src/Simpl/JoinIR/Verify.hs index d9cf303..4fa1158 100644 --- a/src/Simpl/JoinIR/Verify.hs +++ b/src/Simpl/JoinIR/Verify.hs @@ -17,7 +17,7 @@ import Data.Text (Text) import Data.Set (Set) import qualified Data.Set as Set -import Simpl.JoinIR +import Simpl.JoinIR.Syntax -- * Verification Monad From d24ae79a675702e1e1abf69994337789ad64bd6d Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Tue, 2 Apr 2019 16:21:39 -0700 Subject: [PATCH 21/24] Fix JoinIR if cfe pretty printing --- src/Simpl/JoinIR/Syntax.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Simpl/JoinIR/Syntax.hs b/src/Simpl/JoinIR/Syntax.hs index e399b46..7ded960 100644 --- a/src/Simpl/JoinIR/Syntax.hs +++ b/src/Simpl/JoinIR/Syntax.hs @@ -147,7 +147,7 @@ instance Pretty a => Pretty (JBranch a) where instance Pretty a => Pretty (ControlFlow a) where pretty = \case JIf trueBr falseBr -> - PP.hang 2 $ "if" <+> PP.sep + PP.hang 3 $ "if" <+> PP.sep [ "then" <> PP.softline <> PP.align (flatParens (pretty trueBr)) , "else" <> PP.softline <> PP.align (flatParens (pretty falseBr)) ] JCase brs -> From 247439c8db21f7e936f138226497d66bb97fedcf Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Wed, 3 Apr 2019 00:10:00 -0700 Subject: [PATCH 22/24] For JoinIR Verify: Export VerifyError, take VerifyCtx as argument --- src/Simpl/JoinIR/Verify.hs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Simpl/JoinIR/Verify.hs b/src/Simpl/JoinIR/Verify.hs index 4fa1158..16d6784 100644 --- a/src/Simpl/JoinIR/Verify.hs +++ b/src/Simpl/JoinIR/Verify.hs @@ -7,7 +7,7 @@ Module : Simpl.JoinIR.Verify Description : Verifies validity of a JoinIR AST -} module Simpl.JoinIR.Verify - (verify, VerifyCtx(..), emptyCtx) where + (verify, VerifyCtx(..), emptyCtx, VerifyError(..)) where import Control.Monad.Reader import Control.Monad.Except @@ -65,8 +65,8 @@ data VerifyError = VarRedefinition Text | NoSuchVar Text deriving (Show, Eq) -verify :: AnnExpr a -> Either VerifyError () -verify expr = runVerify (doVerifyExpr expr) emptyCtx +verify :: VerifyCtx -> AnnExpr a -> Either VerifyError () +verify ctx expr = runVerify (doVerifyExpr expr) ctx -- | Throw an error if the variable is already bound checkUnboundVar :: (MonadError VerifyError m, MonadReader VerifyCtx m) From fa650af7bb0fc350fd5b742716cbc549879d88b6 Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Wed, 3 Apr 2019 00:10:43 -0700 Subject: [PATCH 23/24] Add unit tests for JoinIR verification --- package.yaml | 2 + src/Simpl/JoinIR/Syntax.hs | 32 ++-------------- test/JoinVerifySpec.hs | 78 ++++++++++++++++++++++++++++++++++++++ test/Spec.hs | 9 ++++- 4 files changed, 92 insertions(+), 29 deletions(-) create mode 100644 test/JoinVerifySpec.hs diff --git a/package.yaml b/package.yaml index 3688d3a..85b2236 100644 --- a/package.yaml +++ b/package.yaml @@ -89,3 +89,5 @@ tests: - -with-rtsopts=-N dependencies: - simpl-lang + - tasty + - tasty-hunit diff --git a/src/Simpl/JoinIR/Syntax.hs b/src/Simpl/JoinIR/Syntax.hs index 7ded960..a0b0ad3 100644 --- a/src/Simpl/JoinIR/Syntax.hs +++ b/src/Simpl/JoinIR/Syntax.hs @@ -223,6 +223,10 @@ type AnnExpr fields = Fix (AnnExprF fields) toAnnExprF :: JExprF a -> AnnExprF '[] a toAnnExprF expr = AnnExprF { annGetAnn = V.RNil, annGetExpr = expr } +-- | Converts a [JExpr] to an "unannotated" [AnnExpr] +toAnnExpr :: JExpr -> AnnExpr '[] +toAnnExpr expr = cata (Fix . toAnnExprF) expr + -- | Removes all annotations from an [AnnExpr] unannotate :: AnnExpr fields -> JExpr unannotate = cata (Fix . annGetExpr) @@ -246,31 +250,3 @@ $(deriveShow1 ''PrettyJExprF) prettyAnnExpr :: Show (V.Rec Attr fields) => AnnExpr fields -> PrettyJExpr prettyAnnExpr = cata $ \expr -> Fix (PrettyJExprF (show (annGetAnn expr), annGetExpr expr)) - -exampleAnnotation :: V.Rec Attr '[ 'ExprType ] -exampleAnnotation = (SExprType =:: Fix (TyNumber NumInt)) V.:& V.RNil - -exampleJExpr :: JExpr -exampleJExpr = Fix $ - JJoin "label" "myvar" ifE (Fix (JVal (JVar "myvar"))) - where - intVal = Fix . JVal . JLit . LitInt - ifE = Cfe (intVal 5) $ - JIf (Cfe (intVal 10) (JJump "label")) - (Cfe (intVal 5) (JJump "label")) - -exampleTypedJExpr :: AnnExpr '[ 'ExprType ] -exampleTypedJExpr = Fix $ AnnExprF - { annGetAnn = withType tyInt V.:& V.RNil - , annGetExpr = JJoin "label" "myvar" ifE varE } - where - tyInt = Fix (TyNumber NumInt) - varE = Fix $ AnnExprF - { annGetAnn = withType tyInt V.:& V.RNil - , annGetExpr = JVal (JVar "myvar") } - intVal x = Fix $ AnnExprF - { annGetAnn = withType tyInt V.:& V.RNil - , annGetExpr = JVal (JLit (LitInt x)) } - jmpCfe x = Cfe (intVal x) (JJump "label") - ifE = Cfe (intVal 5) $ - JIf (jmpCfe 10) (jmpCfe 5) diff --git a/test/JoinVerifySpec.hs b/test/JoinVerifySpec.hs new file mode 100644 index 0000000..4ab05d3 --- /dev/null +++ b/test/JoinVerifySpec.hs @@ -0,0 +1,78 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE OverloadedStrings #-} +module JoinVerifySpec + (joinVerifyTests) where + +import Data.Set (Set) +import qualified Data.Set as Set +import Data.Functor.Foldable (Fix(..)) +import Test.Tasty +import Test.Tasty.HUnit +import Simpl.Ast (Type, TypeF(..), Literal(..), Numeric(..)) +import Simpl.JoinIR.Syntax +import Simpl.JoinIR.Verify + +simpleCtx :: VerifyCtx +simpleCtx = emptyCtx { verifyVars = Set.singleton "x" } + +joinVerifyTests :: TestTree +joinVerifyTests = testGroup "JoinIR verify tests" + [ testCase "Valid if-cfe verifies successfully" $ + assertEqual "" (Right ()) (verify emptyCtx (toAnnExpr goodIfExpr)) + , testCase "Invalid if-cfe fails to verify" $ + assertEqual "" (Left $ NoSuchLabel "badlabel") (verify emptyCtx (toAnnExpr badIfExpr)) + , joinBindingTests + ] + +goodIfExpr :: JExpr +goodIfExpr = Fix $ + JJoin "label" "myvar" ifE (Fix (JVal (JVar "myvar"))) + where + intVal = Fix . JVal . JLit . LitInt + ifE = Cfe (intVal 5) $ + JIf (Cfe (intVal 10) (JJump "label")) + (Cfe (intVal 5) (JJump "label")) + +badIfExpr :: JExpr +badIfExpr = Fix $ + JJoin "label" "myvar" ifE (Fix (JVal (JVar "myvar"))) + where + intVal = Fix . JVal . JLit . LitInt + ifE = Cfe (intVal 5) $ + JIf (Cfe (intVal 10) (JJump "badlabel")) + (Cfe (intVal 5) (JJump "label")) + + +joinBindingTests :: TestTree +joinBindingTests = testGroup "JoinIR variable, label binding tests" + [ testCase "Known variable verifies successfully" $ + assertEqual "" (Right ()) (verify simpleCtx xVal) + , testCase "Unknown variable fails to verify" $ + assertEqual "" (Left $ NoSuchVar "x") (verify emptyCtx xVal) + , testCase "Let binding of unbound variable verifies successfully" $ + assertEqual "" (Right ()) (verify emptyCtx letExpr) + , testCase "Let binding of bound variable fails to verify" $ + assertEqual "" (Left $ VarRedefinition "x") (verify simpleCtx letExpr) + , testCase "App binding of unbound variable verifies successfully" $ + assertEqual "" (Right ()) (verify emptyCtx appPrintExpr) + , testCase "App binding of bound variable fails to verify" $ + assertEqual "" (Left $ VarRedefinition "x") (verify simpleCtx appPrintExpr) + , testCase "Join binding of unbound variable, label verifies successfully" $ + assertEqual "" (Right ()) (verify emptyCtx joinSimpleExpr) + , testCase "Join binding of bound variable, unbound label fails to verify" $ + assertEqual "" (Left $ VarRedefinition "x") (verify simpleCtx joinSimpleExpr) + , testCase "Join binding of unbound variable, bound label fails to verify" $ + assertEqual "" (Left $ LabelRedefinition "lbl") $ + verify (emptyCtx { verifyLabels = Set.singleton "lbl" }) joinSimpleExpr + , testCase "Jump to unknown label fails to verify" $ + assertEqual "" (Left $ NoSuchLabel "badlbl") (verify simpleCtx badJumpExpr) + ] + where + mke = toAnnExpr . Fix + mkef = Fix . toAnnExprF + xVal = mke $ JVal (JVar "x") + yVal = mke $ JVal (JVar "y") + letExpr = mkef $ JLet "x" (JLit (LitInt 5)) xVal + appPrintExpr = mkef $ JApp "x" CPrint [JLit (LitString "foo")] xVal + joinSimpleExpr = mkef $ JJoin "lbl" "x" (Cfe (mke $ JVal (JLit (LitInt 5))) (JJump "lbl")) xVal + badJumpExpr = mkef $ JJoin "lbl" "y" (Cfe xVal (JJump "badlbl")) yVal diff --git a/test/Spec.hs b/test/Spec.hs index cd4753f..66fcf2d 100644 --- a/test/Spec.hs +++ b/test/Spec.hs @@ -1,2 +1,9 @@ +import JoinVerifySpec + +import Test.Tasty + main :: IO () -main = putStrLn "Test suite not yet implemented" +main = defaultMain tests + +tests :: TestTree +tests = testGroup "Tests" [joinVerifyTests] From 757acbad8e8f44936ef61f7429177c9b3ca7c5f8 Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Wed, 3 Apr 2019 15:16:31 -0700 Subject: [PATCH 24/24] Documentation and code cleanup for JoinIR --- src/Simpl/AstToJoinIR.hs | 42 +++++++++++++++++++++++------------- src/Simpl/Backend/Codegen.hs | 21 +++++++++++------- src/Simpl/JoinIR/Syntax.hs | 2 +- src/Simpl/JoinIR/Verify.hs | 11 ++++++++++ 4 files changed, 52 insertions(+), 24 deletions(-) diff --git a/src/Simpl/AstToJoinIR.hs b/src/Simpl/AstToJoinIR.hs index 0be2d81..66db6df 100644 --- a/src/Simpl/AstToJoinIR.hs +++ b/src/Simpl/AstToJoinIR.hs @@ -5,6 +5,11 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-| +Module : Simpl.AstToJoinIR +Description : Provides a function to normalize SimPL AST, transforming it into + JoinIR. +-} module Simpl.AstToJoinIR ( astToJoinIR ) where @@ -54,18 +59,22 @@ runTransformT m table runTransform :: Transform a -> SymbolTable (A.AnnExpr Type) -> a runTransform m table = runIdentity (runTransformT m table) +-- | Generates a fresh name using the given prefix and lookup function freshName :: (MonadReader (SymbolTable (A.AnnExpr Type)) m, MonadFreshVar m) => Text -- ^ Prefix + -> (Text -> SymbolTable (A.AnnExpr Type) -> Maybe a) -- ^ Lookup function -> m Text -freshName prefix = do +freshName prefix lookupFun = do next <- (prefix <>) . fromString . show <$> supply - asks (symTabLookupVar next) >>= \case + asks (lookupFun next) >>= \case Nothing -> pure next - Just _ -> freshName prefix + Just _ -> freshName prefix lookupFun freshVar, freshLabel :: (MonadReader (SymbolTable (A.AnnExpr Type)) m, MonadFreshVar m) => m Text -freshVar = freshName "var" -freshLabel = freshName "join" +-- | Generate a fresh variable name +freshVar = freshName "var" symTabLookupVar +-- | Generate a fresh join label +freshLabel = freshName "join" symTabLookupFun -- * Private utility functions @@ -102,10 +111,12 @@ transformBranch cf (A.BrAdt adtName argNames expr) = do pure $ J.BrAdt adtName argNames (J.Cfe jexpr cf) --- | Main ANF transformation logic +-- | Main ANF transformation logic. Given the SimPL AST, this function will +-- normalize the AST, and then it will feed the final JValue into the given +-- continuation to produce the resulting JoinIR AST. anfTransform :: (MonadReader (SymbolTable (A.AnnExpr Type)) m, MonadFreshVar m) - => A.AnnExpr Type - -> (J.JValue -> m (J.AnnExpr '[ 'J.ExprType])) + => A.AnnExpr Type -- ^ Expression to translate + -> (J.JValue -> m (J.AnnExpr '[ 'J.ExprType])) -- ^ Continuation -> m (J.AnnExpr '[ 'J.ExprType]) anfTransform (Fix (A.AnnExprF ty exprf)) cont = case exprf of A.Lit lit -> cont (J.JLit lit) @@ -144,12 +155,12 @@ anfTransform (Fix (A.AnnExprF ty exprf)) cont = case exprf of makeJexpr ty . J.JJoin lbl name jexprCfe <$> local (symTabInsertVar name ty) (cont (J.JVar name)) A.Cons ctorName args -> - collectArgs args [] $ \argVals -> do + collectArgs args $ \argVals -> do varName <- freshVar makeJexpr ty . J.JApp varName (J.CCtor ctorName) argVals <$> local (symTabInsertVar varName ty) (cont (J.JVar varName)) A.App funcName args -> - collectArgs args [] $ \argVals -> do + collectArgs args $ \argVals -> do varName <- freshVar makeJexpr ty . J.JApp varName (J.CFunc funcName) argVals <$> local (symTabInsertVar varName ty) (cont (J.JVar varName)) @@ -171,9 +182,10 @@ anfTransform (Fix (A.AnnExprF ty exprf)) cont = case exprf of -- | Normalize each expression in sequential order, and then run the -- continuation with the expression values. collectArgs :: (MonadReader (SymbolTable (A.AnnExpr Type)) m, MonadFreshVar m) - => [A.AnnExpr Type] - -> [J.JValue] - -> ([J.JValue] -> m (J.AnnExpr '[ 'J.ExprType])) + => [A.AnnExpr Type] -- ^ Argument expressions + -> ([J.JValue] -> m (J.AnnExpr '[ 'J.ExprType])) -- ^ Continuation -> m (J.AnnExpr '[ 'J.ExprType]) -collectArgs [] vals mcont = mcont (reverse vals) -collectArgs (x:xs) vals mcont = anfTransform x $ \v -> collectArgs xs (v:vals) mcont +collectArgs = go [] + where + go vals [] mcont = mcont (reverse vals) + go vals (x:xs) mcont = anfTransform x $ \v -> go (v:vals) xs mcont diff --git a/src/Simpl/Backend/Codegen.hs b/src/Simpl/Backend/Codegen.hs index 4a0d2c8..a5a79bd 100644 --- a/src/Simpl/Backend/Codegen.hs +++ b/src/Simpl/Backend/Codegen.hs @@ -123,7 +123,7 @@ bindVariable :: MonadState CodegenTable m -> TypeF Type -> LLVM.Operand -> m () -bindVariable name ty oper = do +bindVariable name ty oper = modify (\t -> t { tableVars = Map.insert name (ty, oper) (tableVars t) }) initCodegenTable :: CompilerOpts -> SymbolTable (AnnExpr '[ 'ExprType ]) -> Codegen () @@ -170,6 +170,7 @@ staticString name str = do llvmName :: Text -> LLVM.Name llvmName = LLVM.mkName . Text.unpack +-- | Generates LLVM code for a literal. literalCodegen :: LLVMIR.MonadIRBuilder m => Literal -> m LLVM.Operand literalCodegen = \case LitInt x -> LLVMIR.int64 (fromIntegral x) @@ -234,6 +235,7 @@ binOpCodegen op ty x y = Equal -> (LLVMIR.fcmp LLVMFP.OEQ, LLVMIR.icmp LLVMIP.EQ) in numBinopCodegen x y ty floatInstr intInstr +-- | Generates code for a [JValue]. jvalueCodegen :: (LLVMIR.MonadIRBuilder m, MonadState CodegenTable m) => JValue @@ -242,10 +244,11 @@ jvalueCodegen = \case JVar name -> gets (snd . fromJust . Map.lookup name . tableVars) JLit l -> literalCodegen l +-- | Generates code for a CFE. controlFlowCodegen - :: JValue - -> LLVM.Operand - -> ControlFlow (AnnExpr '[ 'ExprType]) + :: JValue -- ^ The value to continue control flow with. + -> LLVM.Operand -- ^ The LLVM operand of the value. + -> ControlFlow (AnnExpr '[ 'ExprType]) -- ^ The control flow continuation. -> LLVMIR.IRBuilderT (LLVMIR.ModuleBuilderT Codegen) () controlFlowCodegen val valOper = \case JIf (Cfe trueBr trueCf) (Cfe falseBr falseCf) -> do @@ -304,15 +307,16 @@ controlFlowCodegen val valOper = \case v <- jvalueCodegen val jvals <- gets tableJoinValues block <- LLVMIR.currentBlock - let f = (\(n, jvs) -> Just (n, (v, block) : jvs)) + let f (n, jvs) = Just (n, (v, block) : jvs) let updJvals = Map.update f lbl jvals modify (\t -> t { tableJoinValues = updJvals }) llvmLabel <- gets (fst . fromJust . Map.lookup lbl . tableJoinValues) LLVMIR.br llvmLabel +-- | Generates code for a given callable (i.e. in a JApp) callableCodegen - :: Callable - -> [JValue] + :: Callable -- ^ The callable + -> [JValue] -- ^ The argument values -> LLVMIR.IRBuilderT (LLVMIR.ModuleBuilderT Codegen) LLVM.Operand callableCodegen callable args = case callable of CFunc name -> do @@ -376,6 +380,7 @@ callableCodegen callable args = case callable of _ -> error $ "callableCodegen: expected 1 args to CPrint, got " ++ show (length args) CFunRef name -> gets (fromJust . Map.lookup name . tableFuns) +-- | Generates code for a [JExpr] jexprCodegen :: AnnExpr '[ 'ExprType] -> LLVMIR.IRBuilderT (LLVMIR.ModuleBuilderT Codegen) (JValue, LLVM.Operand) @@ -392,7 +397,7 @@ jexprCodegen = (\e -> go (unfix (getType e)) (annGetExpr e)) . unfix jexprCodegen next JJoin lbl varName (Cfe expr cf) next -> do llvmLabel <- LLVMIR.freshName (fromString (Text.unpack lbl)) - let addJoinEntry = \t -> + let addJoinEntry t = t { tableJoinValues = Map.insert lbl (llvmLabel, []) (tableJoinValues t) } oldJoinEntries <- gets tableJoinValues (lastVal, lastValOper) <- jexprCodegen expr diff --git a/src/Simpl/JoinIR/Syntax.hs b/src/Simpl/JoinIR/Syntax.hs index a0b0ad3..47e1b93 100644 --- a/src/Simpl/JoinIR/Syntax.hs +++ b/src/Simpl/JoinIR/Syntax.hs @@ -29,7 +29,7 @@ module Simpl.JoinIR.Syntax where import Data.Text (Text) import Data.Text.Prettyprint.Doc (Pretty, pretty, (<>), (<+>)) import qualified Data.Text.Prettyprint.Doc as PP -import Simpl.Ast (BinaryOp(..), Numeric(..), Literal(..), Type, TypeF(..)) +import Simpl.Ast (BinaryOp(..), Numeric(..), Literal(..), Type) import Text.Show.Deriving (deriveShow1) import Data.Functor.Foldable import qualified Data.Vinyl as V diff --git a/src/Simpl/JoinIR/Verify.hs b/src/Simpl/JoinIR/Verify.hs index 16d6784..4b92de7 100644 --- a/src/Simpl/JoinIR/Verify.hs +++ b/src/Simpl/JoinIR/Verify.hs @@ -21,22 +21,27 @@ import Simpl.JoinIR.Syntax -- * Verification Monad +-- | Information needed to perform verification. data VerifyCtx = VerifyCtx { verifyVars :: Set Text , verifyLabels :: Set Text } deriving (Eq, Show) +-- | Default verfication context; has no bound labels or variables. emptyCtx :: VerifyCtx emptyCtx = VerifyCtx { verifyVars = Set.empty , verifyLabels = Set.empty } +-- | Adds the given variable to the context. ctxWithVar :: Text -> VerifyCtx -> VerifyCtx ctxWithVar name ctx = ctx { verifyVars = Set.insert name (verifyVars ctx) } +-- | Adds the given label to the context. ctxWithLabel :: Text -> VerifyCtx -> VerifyCtx ctxWithLabel lbl ctx = ctx { verifyLabels = Set.insert lbl (verifyLabels ctx) } +-- | Monad for performing verification. newtype VerifyT m a = VerifyT { unVerify :: ReaderT VerifyCtx (ExceptT VerifyError m) a } deriving ( Functor @@ -59,12 +64,14 @@ runVerify m ctx = runIdentity (runVerifyT m ctx) -- * Verification errors +-- | Errors that cause verification to fail. data VerifyError = VarRedefinition Text | LabelRedefinition Text | NoSuchLabel Text | NoSuchVar Text deriving (Show, Eq) +-- | Verify a JoinIR AST using the given context. verify :: VerifyCtx -> AnnExpr a -> Either VerifyError () verify ctx expr = runVerify (doVerifyExpr expr) ctx @@ -76,6 +83,7 @@ checkUnboundVar var = True -> throwError $ VarRedefinition var False -> pure () +-- | Verify a JoinIR value. doVerifyValue :: (MonadError VerifyError m, MonadReader VerifyCtx m) => JValue -> m () @@ -85,6 +93,7 @@ doVerifyValue = \case False -> throwError $ NoSuchVar name JLit _ -> pure () +-- | Verify a JoinIR expression. doVerifyExpr :: (MonadError VerifyError m, MonadReader VerifyCtx m) => AnnExpr a -> m () @@ -110,6 +119,7 @@ doVerifyExpr = cata (go . annGetExpr) checkUnboundVar name local (ctxWithVar name) nextM +-- | Verify a JoinIR CFE. doVerifyCfe :: (MonadError VerifyError m, MonadReader VerifyCtx m) => Cfe (m ()) -> m () @@ -125,6 +135,7 @@ doVerifyCfe (Cfe exprM cf) = do True -> pure () False -> throwError $ NoSuchLabel lbl +-- | Verify a branch in JoinIR. doVerifyBranch :: (MonadError VerifyError m, MonadReader VerifyCtx m) => JBranch (m ()) -> m ()