diff --git a/runtime/runtime.c b/runtime/runtime.c index bd15016..a562980 100644 --- a/runtime/runtime.c +++ b/runtime/runtime.c @@ -45,3 +45,23 @@ int simpl_string_print(const struct simpl_string* s) { free(cstring); return 0; } + +uint32_t simpl_tagged_size(const struct simpl_type_tag* const tag) { + return tag->size; +} + +const struct simpl_type_tag* const simpl_tagged_tag(struct simpl_tagged_value* value) { + return value->type_tag; +} + +void* simpl_tagged_unbox(struct simpl_tagged_value* value) { + return value->data; +} + + +struct simpl_tagged_value* simpl_tagged_box(struct simpl_type_tag* tag, void* data) { + struct simpl_tagged_value *value = simpl_malloc(sizeof(struct simpl_tagged_value)); + value->type_tag = tag; + value->data = data; + return value; +} diff --git a/runtime/runtime.h b/runtime/runtime.h index 68b27c1..a23501e 100644 --- a/runtime/runtime.h +++ b/runtime/runtime.h @@ -1,4 +1,5 @@ #include +#include #ifndef RUNTIME_H #define RUNTIME_H @@ -38,4 +39,44 @@ char* simpl_string_cstring(const struct simpl_string* s); int simpl_string_print(const struct simpl_string* s); +/** + * Describes static information about a type. This struct should not be visible + * from SimPL programs. + */ +struct simpl_type_tag { + /** + * The size (e.g. when compiled) of the type, in bytes. + */ + uint32_t size; +}; + +/** + * Returns the size recorded in the type tag. + */ +uint32_t simpl_tag_size(const struct simpl_type_tag* const); + +/** + * A value tagged with its type tag. Used for polymorphic functions and + * variables. + */ +struct simpl_tagged_value { + const struct simpl_type_tag* type_tag; + void* data; +}; + +/** + * Returns the type tag of a tagged value. + */ +const struct simpl_type_tag* const simpl_tagged_tag(struct simpl_tagged_value*); + +/** + * Returns a pointer to the boxed value of a tagged value. + */ +void* simpl_tagged_unbox(struct simpl_tagged_value*); + +/** + * Boxes the given value + */ +struct simpl_tagged_value* simpl_tagged_box(struct simpl_type_tag* tag, void* data); + #endif diff --git a/sample.spl b/sample.spl index cc0f013..07fb8c5 100644 --- a/sample.spl +++ b/sample.spl @@ -5,78 +5,82 @@ data MaybeI = { JustI Double | Nothing } # data Barbar a = { Asdf Int a } # fun id (x : a) : a := { x } - -fun not (b : Bool) : Bool := { if b then false else true } - -fun and (p : Bool, q : Bool) : Bool := { if p then (if q then true else false) else false } - -fun or (p : Bool, q : Bool) : Bool := { if p then true else (if q then true else false) } - -fun factorial (x : Int) : Int := { - if (x <= 0) then 1 - else x * @factorial(x - 1) -} - -fun even (x: Int) : Bool := { - if (x <= 0) then true - else @not(@odd(x - 2)) -} - -fun odd (x: Int) : Bool := { - if (x <= 1) then true - else @not(@even(x - 2)) -} - -fun nested_ifs : Double := { - (if true then (if true then 4.0 else 5.0) else (if false then 2.0 else 3.0) + 1.0) * 2.0 -} - -fun main : Int := { - case JustI 10.0 of - JustI x => - let msg = println("In Just branch") in - let mynum = @abs(-1) in - (if (@even(4)) then @double_me(5) else @factorial(6)) * @asdf * mynum - Nothing => - let msg = println("In Nothing branch") in - let res = 4 in - @double_me(res + 1) -} - -fun asdf : Int := { - (if false then 5 else 10) + 2 -} - -fun double_me (x : Int) : Int := { x * 2 } - -fun foo : Foo := { - Bar 5.0 -} - -fun foo2 : Bool := { - case Bar 5.0 of - Bar x => true -} - -fun lots_of_lets : Double := { - let x = if true then 1.0 else 2.0 in - let y = if true then x * 2.0 else x * 2.0 + 1.0 in - y -} - -fun fun_ptr : Int -> Int := { &double_me } - -fun fun_ptr_test (b : Bool) : Int := { - let f = if b then &double_me else &factorial in - @f(4) -} - -fun cast_test_1 : Int := { - (1 + 1) * (cast 2.0 as Int) -} - -fun cast_test_2 : Int := { - (1 * 1) + (cast 2.0 as Int) -} - -fun abs (x : Int) : Int := extern +# +# fun eqargs (x : a, y : a, z: b) : a := { x } + +fun main : Int := { let x = &asdf in 5 } + +# fun not (b : Bool) : Bool := { if b then false else true } +# +# fun and (p : Bool, q : Bool) : Bool := { if p then (if q then true else false) else false } +# +# fun or (p : Bool, q : Bool) : Bool := { if p then true else (if q then true else false) } +# +# fun factorial (x : Int) : Int := { +# if (x <= 0) then 1 +# else x * @factorial(x - 1) +# } +# +# fun even (x: Int) : Bool := { +# if (x <= 0) then true +# else @not(@odd(x - 2)) +# } +# +# fun odd (x: Int) : Bool := { +# if (x <= 1) then true +# else @not(@even(x - 2)) +# } +# +# fun nested_ifs : Double := { +# (if true then (if true then 4.0 else 5.0) else (if false then 2.0 else 3.0) + 1.0) * 2.0 +# } +# +# fun main : Int := { +# case JustI 10.0 of +# JustI x => +# let msg = println("In Just branch") in +# let mynum = @abs(-1) in +# (if (@even(4)) then @double_me(5) else @factorial(6)) * @asdf * mynum +# Nothing => +# let msg = println("In Nothing branch") in +# let res = 4 in +# @double_me(res + 1) +# } +# +# fun asdf : Int := { +# (if false then 5 else 10) + 2 +# } +# +# fun double_me (x : Int) : Int := { x * 2 } +# +# fun foo : Foo := { +# Bar 5.0 +# } +# +# fun foo2 : Bool := { +# case Bar 5.0 of +# Bar x => true +# } +# +# fun lots_of_lets : Double := { +# let x = if true then 1.0 else 2.0 in +# let y = if true then x * 2.0 else x * 2.0 + 1.0 in +# y +# } +# +# fun fun_ptr : Int -> Int := { &double_me } +# +# fun fun_ptr_test (b : Bool) : Int := { +# let f = if b then &double_me else &factorial in +# @f(4) +# } +# +# fun cast_test_1 : Int := { +# (1 + 1) * (cast 2.0 as Int) +# } +# +# fun cast_test_2 : Int := { +# (1 * 1) + (cast 2.0 as Int) +# } +# +# fun abs (x : Int) : Int := extern diff --git a/src/Simpl/AstToJoinIR.hs b/src/Simpl/AstToJoinIR.hs index 3bb2c5e..0cd35b0 100644 --- a/src/Simpl/AstToJoinIR.hs +++ b/src/Simpl/AstToJoinIR.hs @@ -22,6 +22,9 @@ import Data.Functor.Foldable (Fix(..), unfix) import Data.Functor.Identity import Data.Text (Text) import Data.String (fromString) +import Data.Maybe (fromJust) +import Data.Map (Map) +import qualified Data.Map as Map import Data.Set (Set) import qualified Data.Set as Set @@ -29,7 +32,8 @@ import Simpl.Annotation import Simpl.SymbolTable import qualified Simpl.Ast as A import qualified Simpl.JoinIR.Syntax as J -import Simpl.Type (Type) +import Simpl.Type (Type, TypeF(TyBox, TyVar, TyAdt, TyFun), substituteTypeVars, typeRepIsPtr) +import Simpl.Typecheck (literalType) import Simpl.Util.Supply import qualified Simpl.Util.Stream as Stream @@ -40,21 +44,25 @@ astToJoinIR table = runTransform transformTable (defaultCtx table) -- * Transformation Monad +data BoxedVal = Boxed | Unboxed deriving (Show, Eq, Ord) + data TransformCtx fields = TransformCtx { tcSymTab :: SymbolTable (A.AnnExpr fields) , tcJoinLabels :: Set Text + , tcBoxStatus :: Map Text BoxedVal } defaultCtx :: SymbolTable (A.AnnExpr flds) -> TransformCtx flds -defaultCtx table = TransformCtx { tcSymTab = table, tcJoinLabels = Set.empty } - -modifySymTab :: (SymbolTable (A.AnnExpr flds) -> SymbolTable (A.AnnExpr flds)) - -> TransformCtx flds - -> TransformCtx flds -modifySymTab f ctx = ctx { tcSymTab = f (tcSymTab ctx) } +defaultCtx table = TransformCtx + { tcSymTab = table + , tcJoinLabels = Set.empty + , tcBoxStatus = Map.empty } insertVar :: Text -> Type -> TransformCtx flds -> TransformCtx flds -insertVar name ty = modifySymTab (symTabInsertVar name ty) +insertVar name ty ctx = ctx + { tcSymTab = symTabInsertVar name ty (tcSymTab ctx) + , tcBoxStatus = Map.insert name (boxedVal ty) (tcBoxStatus ctx) + } newtype TransformT fields m a = TransformT { unTransform :: ReaderT (TransformCtx fields) (SupplyT Int m) a } @@ -110,6 +118,50 @@ makeJexpr ty = Fix . addField (withType ty) . toAnnExprF astType :: HasType flds => A.AnnExpr flds -> Type astType = getType . unfix +getJvalueType :: MonadReader (TransformCtx flds) m => J.JValue -> m Type +getJvalueType = \case + J.JVar n -> asks (fromJust . symTabLookupVar n . tcSymTab) + J.JLit l -> pure . Fix $ literalType l + +-- | Get boxed type. Left is unboxed, right is boxed. +boxedType :: Type -> Either Type Type +boxedType = \case + t@(Fix (TyVar _)) -> Right t + Fix (TyBox t) -> Right t + t -> Left t + +isBoxed :: Type -> Bool +isBoxed t = case boxedType t of { Left _ -> False; Right _ -> True } + +boxedVal :: Type -> BoxedVal +boxedVal t = if isBoxed t then Boxed else Unboxed + +rebindBoxing :: (HasType flds, MonadFreshVar m, MonadReader (TransformCtx flds) m) + => J.JValue -- ^ Variable name + -> BoxedVal -- ^ Whether to ensure boxed or unboxed + -> (J.JValue -> m (J.AnnExpr '[ 'ExprType])) -- ^ Continuation + -> m (J.AnnExpr '[ 'ExprType]) +rebindBoxing val b cont = do + ty <- getJvalueType val + let create ty' action = do + name <- case val of { J.JVar n -> pure n; _ -> freshVar } + makeJexpr ty' . J.JApp name action [val] <$> local (insertVar name ty')(cont (J.JVar name)) + case (boxedType ty, b) of + (Right ty', Unboxed) -> create ty' J.CUntag + (Left ty', Boxed) -> create (Fix (TyBox ty')) J.CTag + _ -> cont val + +-- | If needed, rebinds the given value to match the boxing target. +withRebindBoxing :: (HasType flds, MonadReader (TransformCtx flds) m, MonadFreshVar m) + => BoxedVal -- ^Boxing target + -> (J.JValue -> J.JExprF (J.AnnExpr '[ 'ExprType])) + -> J.JValue -- ^Expression value + -> m (J.AnnExpr '[ 'ExprType]) +withRebindBoxing boxVal f val = + rebindBoxing val boxVal $ \val' -> do + ty <- getJvalueType val' + pure . makeJexpr ty . f $ val' + -- * ANF Transformation -- | Perform ANF transformation on the given symbol table @@ -117,22 +169,41 @@ transformTable :: (HasType flds, MonadReader (TransformCtx flds) m, MonadFreshVa => m (SymbolTable (J.AnnExpr '[ 'ExprType])) transformTable = do table <- asks tcSymTab - symTabTraverseExprs (\(tvars, args, ty, expr) -> (tvars, args, ty, transformExpr expr)) table + flip symTabTraverseExprs table $ \(tvars, args, ty, expr) -> + -- Initialize boxing status first, then transform + let initVars tab = foldl (flip (uncurry insertVar)) tab args + in (tvars, args, ty, local initVars (transformExpr expr (boxedVal ty))) -- | Perform ANF transformation on the given expression transformExpr :: (HasType flds, MonadReader (TransformCtx flds) m, MonadFreshVar m) => A.AnnExpr flds + -> BoxedVal -- ^ Expected boxing of the final value -> m (J.AnnExpr '[ 'ExprType]) -transformExpr expr = anfTransform expr (pure . makeJexpr (astType expr) . J.JVal) +transformExpr expr boxVal = anfTransform expr $ withRebindBoxing boxVal J.JVal -- | Perform ANF transformation on the branch, afterwards handling control flow. transformBranch :: (HasType flds, MonadReader (TransformCtx flds) m, MonadFreshVar m) => J.ControlFlow (J.AnnExpr '[ 'ExprType]) -- ^ Control flow handler + -> Type -- ^ Type of guard expression + -> BoxedVal -- ^ Whether branch is expected to be boxed -> A.Branch (A.AnnExpr flds) -- ^ Branches -> m (J.JBranch (J.AnnExpr '[ 'ExprType])) -transformBranch cf (A.BrAdt adtName argNames expr) = do - jexpr <- anfTransform expr (pure . makeJexpr (astType expr) . J.JVal) - pure $ J.BrAdt adtName argNames (J.Cfe jexpr cf) +transformBranch cf guardTy boxVal (A.BrAdt ctorName argNames expr) = do + (_, A.Ctor _ argTys, _) <- asks (fromJust . symTabLookupCtor ctorName . tcSymTab) + argTys' <- case unfix guardTy of + TyAdt name tvars -> do + (tvars', _) <- asks (fromJust . symTabLookupAdt name . tcSymTab) + let tvarMapping = Map.fromList (tvars' `zip` tvars) + pure $ flip fmap argTys $ \t -> case unfix t of + TyVar _ -> + let t' = substituteTypeVars tvarMapping t in + if t' == t || typeRepIsPtr (unfix t') + then t' else Fix (TyBox t') + _ -> substituteTypeVars tvarMapping t + _ -> pure argTys + let withScope ctx = foldr (uncurry insertVar) ctx (argNames `zip` argTys') + jexpr <- local withScope $ anfTransform expr $ withRebindBoxing boxVal J.JVal + pure $ J.BrAdt ctorName (argNames `zip` argTys') (J.Cfe jexpr cf) -- | Main ANF transformation logic. Given the SimPL AST, this function will @@ -146,70 +217,114 @@ anfTransform (Fix ae) cont = let ty = getType ae in case annGetExpr ae of A.Lit lit -> cont (J.JLit lit) A.Var name -> cont (J.JVar name) A.Let name bindExpr next -> - anfTransform bindExpr $ \bindVal -> - makeJexpr (getType (unfix bindExpr)) . J.JLet name bindVal <$> - local (insertVar name ty) (anfTransform next cont) + anfTransform bindExpr $ \bindVal -> do + bindTy <- getJvalueType bindVal + makeJexpr bindTy . J.JLet name bindVal <$> + local (insertVar name bindTy) (anfTransform next cont) A.BinOp op left right -> anfTransform left $ \jleft -> - anfTransform right $ \jright -> do + anfTransform right $ \jright -> + rebindBoxing jleft Unboxed $ \jl -> + rebindBoxing jright Unboxed $ \jr -> do + name <- freshVar + makeJexpr ty . J.JApp name (J.CBinOp op) [jl, jr] <$> + local (insertVar name ty) (cont (J.JVar name)) + A.If guard trueBr falseBr -> + anfTransform guard $ \jguard -> + -- Guard must be unboxed to compare for truthiness + rebindBoxing jguard Unboxed $ \jguard' -> do + guardTy <- getJvalueType jguard' + let guardCfe = makeJexpr guardTy (J.JVal jguard') + -- Handle branches + let transformBr br = anfTransform br $ withRebindBoxing (boxedVal ty) J.JVal + trueBr' <- transformBr trueBr + falseBr' <- transformBr falseBr + lbl <- freshLabel + let jmp = J.JJump lbl + let cfe = J.Cfe guardCfe (J.JIf (J.Cfe trueBr' jmp) (J.Cfe falseBr' jmp)) + -- TODO: Make JJoin node placement more efficient name <- freshVar - makeJexpr ty . J.JApp name (J.CBinOp op) [jleft, jright] <$> + makeJexpr ty . J.JJoin lbl name cfe <$> local (insertVar name ty) (cont (J.JVar name)) - A.If guard trueBr falseBr -> - anfTransform guard $ \jguard -> do - lbl <- freshLabel - trueBr' <- anfTransform trueBr (pure . makeJexpr (astType trueBr) . J.JVal) - falseBr' <- anfTransform falseBr (pure . makeJexpr (astType falseBr) . J.JVal) - name <- freshVar - let jmp = J.JJump lbl - let guardTy = getType (unfix guard) - let guardCfe = makeJexpr guardTy (J.JVal jguard) - let cfe = J.Cfe guardCfe (J.JIf (J.Cfe trueBr' jmp) (J.Cfe falseBr' jmp)) - -- TODO: Make JJoin node placement more efficient - makeJexpr ty . J.JJoin lbl name cfe <$> - local (insertVar name ty) (cont (J.JVar name)) A.Case branches expr -> - anfTransform expr $ \jexpr -> do - lbl <- freshLabel - let jexprTy = getType (unfix expr) - jbranches <- traverse (transformBranch (J.JJump lbl)) branches - let jexprCfe = J.Cfe (makeJexpr jexprTy (J.JVal jexpr)) (J.JCase jbranches) - name <- freshVar - -- TODO: Make JJoin node placement more efficient - makeJexpr ty . J.JJoin lbl name jexprCfe <$> - local (insertVar name ty) (cont (J.JVar name)) + anfTransform expr $ \jexpr -> + -- Case value must be unboxed + rebindBoxing jexpr Unboxed $ \jexpr' -> do + jTy <- getJvalueType jexpr' + -- Transform branches + lbl <- freshLabel + jbranches <- traverse (transformBranch (J.JJump lbl) jTy (boxedVal ty)) branches + let jexprCfe = J.Cfe (makeJexpr jTy (J.JVal jexpr')) (J.JCase jbranches) + name <- freshVar + -- TODO: Make JJoin node placement more efficient + makeJexpr ty . J.JJoin lbl name jexprCfe <$> + local (insertVar name ty) (cont (J.JVar name)) A.Cons ctorName args -> collectArgs args $ \argVals -> do - varName <- freshVar - makeJexpr ty . J.JApp varName (J.CCtor ctorName) argVals <$> - local (insertVar varName ty) (cont (J.JVar varName)) + (_, A.Ctor _ ctorTyArgs, _) <- asks (fromJust . symTabLookupCtor ctorName . tcSymTab) + -- Box each argument as needed + collectRebinds (argVals `zip` (boxedVal <$> ctorTyArgs)) $ \argVals' -> do + varName <- freshVar + makeJexpr ty . J.JApp varName (J.CCtor ctorName) argVals' <$> + local (insertVar varName ty) (cont (J.JVar varName)) A.App funcName args -> collectArgs args $ \argVals -> do varName <- freshVar - makeJexpr ty . J.JApp varName (J.CFunc funcName) argVals <$> - local (insertVar varName ty) (cont (J.JVar varName)) + -- Function references are stored in variable scope, so we need to perform + -- lookup there. + funcCandidate <- asks (fmap unfix . symTabLookupVar funcName . tcSymTab) >>= \case + Just (TyFun params resTy) -> pure $ Just (params, resTy) + Just t -> error $ "Variable called as function: " ++ show funcName ++ " : " ++ show t + Nothing -> asks (fmap (\(_, p, r) -> (snd <$> p, r)) . symTabLookupFun funcName . tcSymTab) + let (funcArgs, funcRetTy) = fromJust funcCandidate + let valueBoxPairs = [(val, boxedVal fTy) | (val, fTy) <- argVals `zip` funcArgs] + collectRebinds valueBoxPairs $ \argVals' -> do + let ty' = if isBoxed funcRetTy then Fix (TyBox ty) else ty + makeJexpr ty' . J.JApp varName (J.CFunc funcName) argVals' <$> + local (insertVar varName ty') (cont (J.JVar varName)) A.Cast expr numTy -> - anfTransform expr $ \jexpr -> do - varName <- freshVar - makeJexpr ty . J.JApp varName (J.CCast numTy) [jexpr] <$> - local (insertVar varName ty) (cont (J.JVar varName)) + anfTransform expr $ \jexpr -> + -- Resulting value is unboxed, so use original type + rebindBoxing jexpr Unboxed $ \jexpr' -> do + varName <- freshVar + makeJexpr ty . J.JApp varName (J.CCast numTy) [jexpr'] <$> + local (insertVar varName ty) (cont (J.JVar varName)) A.Print expr -> - anfTransform expr $ \jexpr -> do - varName <- freshVar - makeJexpr ty . J.JApp varName J.CPrint [jexpr] <$> - local (insertVar varName ty) (cont (J.JVar varName)) + anfTransform expr $ \jval -> + -- Resulting value is unboxed, so use original type + rebindBoxing jval Unboxed $ \jval' -> do + varName <- freshVar + makeJexpr ty . J.JApp varName J.CPrint [jval'] <$> + local (insertVar varName ty) (cont (J.JVar varName)) A.FunRef name -> do varName <- freshVar makeJexpr ty . J.JApp varName (J.CFunRef name) [] <$> local (insertVar varName ty) (cont (J.JVar varName)) +-- | Utility function for collecting arguments to CPS style functions +collectConts :: (a -> (b -> m c) -> m c) + -> [a] + -> ([b] -> m c) + -> m c +collectConts f as cont = go [] as + where + go vals [] = cont (reverse vals) + go vals (x:xs) = f x $ \v -> go (v:vals) xs + -- | Normalize each expression in sequential order, and then run the -- continuation with the expression values. collectArgs :: (HasType flds, MonadReader (TransformCtx flds) m, MonadFreshVar m) => [A.AnnExpr flds] -- ^ Argument expressions -> ([J.JValue] -> m (J.AnnExpr '[ 'ExprType])) -- ^ Continuation -> m (J.AnnExpr '[ 'ExprType]) -collectArgs = go [] - where - go vals [] mcont = mcont (reverse vals) - go vals (x:xs) mcont = anfTransform x $ \v -> go (v:vals) xs mcont +collectArgs = collectConts anfTransform +-- collectArgs = go [] +-- where +-- go vals [] mcont = mcont (reverse vals) +-- go vals (x:xs) mcont = anfTransform x $ \v -> go (v:vals) xs mcont + +collectRebinds :: (HasType flds, MonadReader (TransformCtx flds) m, MonadFreshVar m) + => [(J.JValue, BoxedVal)] -- ^ Value-boxing pairs + -> ([J.JValue] -> m (J.AnnExpr '[ 'ExprType])) -- ^ Continuation + -> m (J.AnnExpr '[ 'ExprType]) +collectRebinds = collectConts (uncurry rebindBoxing) diff --git a/src/Simpl/Backend/Codegen.hs b/src/Simpl/Backend/Codegen.hs index 8c9fdc3..951dfb4 100644 --- a/src/Simpl/Backend/Codegen.hs +++ b/src/Simpl/Backend/Codegen.hs @@ -48,7 +48,7 @@ import Simpl.Annotation hiding (AnnExpr, AnnExprF) import Simpl.Ast (BinaryOp(..), Constructor(..), Literal(..)) import Simpl.CompilerOptions import Simpl.SymbolTable -import Simpl.Type (Type, TypeF(..), Numeric(..)) +import Simpl.Type (Type, TypeF(..), Numeric(..), typeRepIsPtr) import Simpl.Typecheck (literalType) import Simpl.Backend.Runtime () import Simpl.JoinIR.Syntax @@ -57,11 +57,12 @@ import qualified Simpl.Backend.Runtime as RT data CodegenTable = MkCodegenTable { tableVars :: Map Text (TypeF Type, LLVM.Operand) -- ^ Pointer to variables , tableCtors :: Map Text (LLVM.Name, LLVM.Name, Int) -- ^ Data type name, ctor name, index - , tableAdts :: Map Text (LLVM.Name, Type, [Constructor]) + , tableAdts :: Map Text (LLVM.Name, [Text], [Constructor]) , tableFuns :: Map Text LLVM.Operand , tableJoinValues :: Map Text (LLVM.Name, [(LLVM.Operand, LLVM.Name)]) , tablePrintf :: LLVM.Operand - , tableOptions :: CompilerOpts } + , tableOptions :: CompilerOpts + , tableTypeTags :: Map (TypeF Type) LLVM.Operand } deriving (Show) -- | An empty codegen table. This will cause a crash if codegen is run when not @@ -74,7 +75,8 @@ emptyCodegenTable = , tableFuns = Map.empty , tableJoinValues = Map.empty , tablePrintf = error "printf not set" - , tableOptions = defaultCompilerOpts } + , tableOptions = defaultCompilerOpts + , tableTypeTags = Map.empty } newtype CodegenT m a = CodegenT { unCodegen :: StateT CodegenTable m a } @@ -120,6 +122,32 @@ lookupValueType = \case JVar name -> gets (fst . fromJust . Map.lookup name . tableVars) JLit lit -> pure $ literalType lit +-- | If the type tag does not exist, emit it into the IR. Then return the +-- operand of the type tag. +lookupTypeTag :: (LLVMIR.MonadModuleBuilder m, MonadState CodegenTable m) + => TypeF Type + -> m LLVM.Operand +lookupTypeTag ty = + gets (Map.lookup ty . tableTypeTags) >>= \case + Just oper -> pure oper + Nothing -> do + let name = case ty of + TyNumber NumInt -> "Int" + TyNumber NumDouble -> "Double" + TyBool -> "Bool" + TyString -> "String" + TyAdt n _ -> "data." <> n + x -> error ("TODO: handle tag type lookup for " ++ show x) + let llvmTy = typeToLLVM (Fix ty) + let size = LLVMC.sizeof llvmTy + let tagContents = LLVMC.Struct { LLVMC.structName = Nothing + , LLVMC.isPacked = False + , LLVMC.memberValues = [size] } + let lname = LLVM.mkName $ "simpl.tag." ++ Text.unpack name + oper <- LLVMIR.global lname RT.typeTagType tagContents + modify (\t -> t { tableTypeTags = Map.insert ty oper (tableTypeTags t) }) + pure oper + bindVariable :: MonadState CodegenTable m => Text -> TypeF Type @@ -273,7 +301,9 @@ controlFlowCodegen val valOper = \case let labelName = "case_" <> fromString (Text.unpack name) in (name, ) <$> LLVMIR.freshName labelName -- Assume the symbol table and type information is correct - dataName <- (\case { TyAdt n _ -> n; _ -> error "" }) <$> lookupValueType val + dataName <- flip fmap (lookupValueType val) $ \case + TyAdt n _ -> n + t -> error $ "controlFlowCodegen: Unexpected case value type: " ++ show t ++ " " ++ show val ctors <- gets ((\(_,_,cs) -> cs) . fromJust . Map.lookup dataName . tableAdts) let ctorNames = ctorGetName <$> ctors let usedLabelTriples = filter (\(_, (n, _)) -> n `elem` ctorNames) $ [0..] `zip` allCaseLabels @@ -284,25 +314,24 @@ controlFlowCodegen val valOper = \case forM_ (usedLabelTriples `zip` branches) $ \((_, (ctorName, label)), br) -> do let expr = branchGetExpr br let cf = branchGetControlFlow br - (_, ctorLLVMName, index) <- gets (fromJust . Map.lookup ctorName . tableCtors) - let Ctor _ argTys = ctors !! index - let bindingPairs = branchGetBindings br `zip` (typeToLLVM <$> argTys) + (_, ctorLLVMName, _) <- gets (fromJust . Map.lookup ctorName . tableCtors) + let bindingPairs = branchGetBindings br LLVMIR.emitBlockStart label ctorPtr <- LLVMIR.bitcast dataPtr (LLVM.ptr (LLVM.NamedTypeReference ctorLLVMName)) let ctorPtrOffset = LLVMIR.int32 0 - bindings <- forM ([0..] `zip` bindingPairs) $ \(i, (n, llvmTy)) -> do + bindings <- forM ([0..] `zip` bindingPairs) $ \(i, (n, ty)) -> do let ctorPtrIndex = LLVMIR.int32 i -- Need to bitcast the ptr type because we need a concrete type. We -- also need to load the data immediately because of how variables -- are implemented. v <- LLVMIR.gep ctorPtr [ctorPtrOffset, ctorPtrIndex] - >>= flip LLVMIR.bitcast (LLVM.ptr llvmTy) + >>= flip LLVMIR.bitcast (LLVM.ptr (typeToLLVM ty)) >>= flip LLVMIR.load 0 - ty <- lookupValueType (JVar n) - pure (n, (ty, v)) + pure (n, (unfix ty, v)) let updateTable t = t { tableVars = Map.union (tableVars t) (Map.fromList bindings) } - (exprVal, exprOper) <- jexprCodegen expr - localCodegenTable updateTable (controlFlowCodegen exprVal exprOper cf) + localCodegenTable updateTable $ do + (exprVal, exprOper) <- jexprCodegen expr + controlFlowCodegen exprVal exprOper cf LLVMIR.emitBlockStart defLabel LLVMIR.unreachable JJump lbl -> do @@ -344,13 +373,10 @@ callableCodegen callable args = case callable of -- Tag (index = 0) tagStruct2 <- LLVMIR.insertValue tagStruct1 (LLVMIR.int32 (fromIntegral ctorIndex)) [0] -- Data pointer (index = 1) - let ctorTy = LLVM.NamedTypeReference ctorName - let nullptr = LLVM.ConstantOperand (LLVMC.Null (LLVM.ptr ctorTy)) - -- Use offsets to calculate struct size - ctorStructSize <- LLVMIR.gep nullptr [LLVMIR.int32 0] - >>= flip LLVMIR.ptrtoint LLVM.i64 -- Allocate memory for constructor. -- For now, use "leak memory" as an implementation strategy for deallocation. + let ctorTy = LLVM.NamedTypeReference ctorName + let ctorStructSize = LLVM.ConstantOperand (LLVMC.ZExt (LLVMC.sizeof ctorTy) LLVM.i64) ctorStructPtr <- LLVMIR.call RT.mallocRef [(ctorStructSize, [])] >>= flip LLVMIR.bitcast (LLVM.ptr ctorTy) values <- traverse jvalueCodegen args @@ -377,6 +403,36 @@ callableCodegen callable args = case callable of pure $ LLVMIR.int64 0 _ -> error $ "callableCodegen: expected 1 args to CPrint, got " ++ show (length args) CFunRef name -> gets (fromJust . Map.lookup name . tableFuns) + CTag -> case args of + [jval] -> do + ty <- lookupValueType jval + tag <- lookupTypeTag ty + let llvmTy = typeToLLVM (Fix ty) + let size = LLVM.ConstantOperand (LLVMC.ZExt (LLVMC.sizeof llvmTy) LLVM.i64) + -- We only need to allocate if the value is not a pointer + val <- jvalueCodegen jval + bytes <- if typeRepIsPtr ty + then LLVMIR.bitcast val (LLVM.ptr LLVM.i8) + else do b <- LLVMIR.call RT.mallocRef [(size, [])] + allocRef <- LLVMIR.bitcast b (LLVM.ptr llvmTy) + LLVMIR.store allocRef 0 val + pure b + LLVMIR.call RT.taggedBoxRef [(tag, []), (bytes, [])] + _ -> error $ "callableCodegen: expected 1 args to CTag, got " ++ show (length args) + CUntag -> case args of + [jval] -> do + ty <- lookupValueType jval >>= \case + TyBox t -> pure t + t -> error $ "callableCodegen: untagging non-boxed type " ++ show t + val <- jvalueCodegen jval + bytesPtr <- LLVMIR.call RT.taggedUnboxRef [(val, [])] + let llvmTy = typeToLLVM ty + -- Need to unbox if the stored type is not a pointer + if typeRepIsPtr (unfix ty) + then LLVMIR.bitcast bytesPtr llvmTy + else do ptr <- LLVMIR.bitcast bytesPtr (LLVM.ptr llvmTy) + LLVMIR.load ptr 0 + _ -> error $ "callableCodegen: expected 1 args to CTag, got " ++ show (length args) -- | Generates code for a [JExpr] jexprCodegen @@ -443,7 +499,8 @@ typeToLLVM = go . unfix , LLVM.argumentTypes = typeToLLVM <$> args , LLVM.isVarArg = False } - TyVar _ -> error "compilation of parametrically polymorphic functions not implemented yet" + TyVar _ -> LLVM.ptr RT.taggedValueType + TyBox _ -> LLVM.ptr RT.taggedValueType adtToLLVM :: Text -> [Constructor] @@ -515,8 +572,8 @@ moduleCodegen srcCode symTab = mdo -- Insert function operands into symbol table before emitting so order of -- definition doesn't matter. This works because the codegen monad is lazy. modify (\t -> t { tableFuns = tableFuns t `Map.union` Map.fromList funOpers }) - -- TODO: Care about type variables - funOpers <- forM (Map.toList . symTabFuns $ symTab) $ \(name, (tvars, params, ty, body)) -> + funOpers <- forM (Map.toList . symTabFuns $ symTab) $ \(name, (_, params, ty, body)) -> + -- Ignore type variables since they're handled with boxing code (name, ) <$> funToLLVM name params ty body _ <- LLVMIR.function "main" [] LLVM.i64 $ \_ -> do diff --git a/src/Simpl/Backend/Runtime.hs b/src/Simpl/Backend/Runtime.hs index e58af3f..06915fa 100644 --- a/src/Simpl/Backend/Runtime.hs +++ b/src/Simpl/Backend/Runtime.hs @@ -64,8 +64,8 @@ mallocType = mkFunType [("ptr", LLVM.i64)] (LLVM.ptr LLVM.i8) memcpyType = mkFunType [ ("dest", LLVM.ptr LLVM.i8) , ("src", LLVM.ptr LLVM.i8) , ("len", LLVM.i64) ] - LLVM.void -printfType = ([("", LLVM.ptr LLVM.i8)], LLVM.void, True) + LLVM.i8 +printfType = ([("", LLVM.ptr LLVM.i8)], LLVM.i8, True) mallocRef, memcpyRef, printfRef :: LLVM.Operand mallocRef = runtimeFunRef "simpl_malloc" mallocType @@ -97,16 +97,50 @@ stringFuns = [ ("simpl_string_cstring", stringCstringType) stringStructs :: [String] stringStructs = ["simpl_string"] +-- * Tags + +typeTagType :: LLVM.Type +typeTagType = LLVM.StructureType + { LLVM.isPacked = False + -- Size + , LLVM.elementTypes = [LLVM.i32] } + +taggedValueType :: LLVM.Type +taggedValueType = runtimeStruct "simpl_tagged_value" + +tagSizeType, taggedTagType, taggedBoxType, taggedUnboxType :: FunType +tagSizeType = mkFunType [("t", LLVM.ptr typeTagType)] LLVM.i64 +taggedTagType = mkFunType [("t", LLVM.ptr taggedValueType)] (LLVM.ptr taggedValueType) +taggedBoxType = mkFunType [("t", LLVM.ptr typeTagType), ("d", LLVM.ptr LLVM.i8)] (LLVM.ptr taggedValueType) +taggedUnboxType = mkFunType [("t", LLVM.ptr taggedValueType)] (LLVM.ptr LLVM.i8) + +tagSizeRef, taggedTagRef, taggedBoxRef, taggedUnboxRef :: LLVM.Operand +tagSizeRef = runtimeFunRef "simpl_tag_size" tagSizeType +taggedTagRef = runtimeFunRef "simpl_tagged_tag" taggedTagType +taggedUnboxRef = runtimeFunRef "simpl_tagged_unbox" taggedUnboxType +taggedBoxRef = runtimeFunRef "simpl_tagged_box" taggedBoxType + +runtimeTypeFuns :: [(String, FunType)] +runtimeTypeFuns = [ ("simpl_tag_size", tagSizeType) + , ("simpl_tagged_tag", taggedTagType) + , ("simpl_tagged_unbox", taggedUnboxType) + , ("simpl_tagged_box", taggedBoxType) ] + +runtimeTypeStructs :: [String] +runtimeTypeStructs = ["simpl_type_tag", "simpl_tagged_value"] + -- * Entire runtime allRuntimeFuns :: [(String, FunType)] -allRuntimeFuns = join [cstdlibFuns, stringFuns] +allRuntimeFuns = join [cstdlibFuns, stringFuns, runtimeTypeFuns] allRuntimeStructs :: [String] -allRuntimeStructs = stringStructs +allRuntimeStructs = stringStructs ++ runtimeTypeStructs emitRuntimeDecls :: LLVMIR.MonadModuleBuilder m => m () emitRuntimeDecls = do forM_ allRuntimeFuns (uncurry emitRuntimeFun) - forM_ allRuntimeStructs $ \name -> do + forM_ allRuntimeStructs $ \name -> LLVMIR.typedef (LLVM.mkName ("struct." <> name)) Nothing + _ <- LLVMIR.typedef (LLVM.mkName ("struct.simpl_type_tag")) (Just typeTagType) + pure () diff --git a/src/Simpl/JoinIR/Syntax.hs b/src/Simpl/JoinIR/Syntax.hs index d92e99a..e02dddd 100644 --- a/src/Simpl/JoinIR/Syntax.hs +++ b/src/Simpl/JoinIR/Syntax.hs @@ -23,7 +23,7 @@ import Data.Text.Prettyprint.Doc (Pretty, pretty, (<>), (<+>)) import qualified Data.Text.Prettyprint.Doc as PP import qualified Simpl.Annotation as Ann import Simpl.Ast (BinaryOp(..), Literal(..)) -import Simpl.Type (Numeric(..)) +import Simpl.Type (Numeric(..), Type) import Text.Show.Deriving (deriveShow1) import Data.Functor.Foldable @@ -39,6 +39,8 @@ data Callable | CCtor !Name -- ^ ADT constructor | CPrint -- ^ Print string (temporary) | CFunRef !Name -- ^ Static function reference + | CTag -- ^ Create a boxed representation of the given value + | CUntag -- ^ Unbox the value deriving (Show) -- | A value @@ -64,14 +66,14 @@ data ControlFlow a data JBranch a - = BrAdt Name [Name] !(Cfe a) -- ^ Destructure algebraic data type + = BrAdt Name [(Name, Type)] !(Cfe a) -- ^ Destructure algebraic data type deriving (Functor, Foldable, Traversable, Show) branchGetExpr :: JBranch a -> a branchGetExpr = \case BrAdt _ _ (Cfe e _) -> e -branchGetBindings :: JBranch a -> [Text] +branchGetBindings :: JBranch a -> [(Text, Type)] branchGetBindings = \case BrAdt _ vars _ -> vars @@ -124,6 +126,8 @@ instance Pretty Callable where CCtor name -> pretty name CPrint -> "print" CFunRef name -> "funref[" <> pretty name <> "]" + CTag -> "tag" + CUntag -> "untag" instance Pretty JValue where pretty = \case @@ -131,9 +135,11 @@ instance Pretty JValue where JLit l -> pretty l instance Pretty a => Pretty (JBranch a) where - pretty (BrAdt ctorName varNames cfe) = - PP.hang 2 $ PP.hsep (pretty <$> brPart) <> PP.softline <> pretty cfe - where brPart = [ctorName] ++ varNames ++ ["=>"] + pretty (BrAdt ctorName varPairs cfe) = + PP.hang 2 $ PP.hsep brPart <> PP.softline <> pretty cfe + where + vars = (\(n, t) -> PP.parens (pretty n <+> ":" <+> pretty t)) <$> varPairs + brPart = [pretty ctorName] ++ vars ++ [pretty ("=>" :: Text)] instance Pretty a => Pretty (ControlFlow a) where pretty = \case diff --git a/src/Simpl/JoinIR/Verify.hs b/src/Simpl/JoinIR/Verify.hs index 6b85d11..152654c 100644 --- a/src/Simpl/JoinIR/Verify.hs +++ b/src/Simpl/JoinIR/Verify.hs @@ -142,5 +142,5 @@ doVerifyBranch :: (MonadError VerifyError m, MonadReader VerifyCtx m) -> m () doVerifyBranch (BrAdt name args cfe) = do checkUnboundVar name - _ <- traverse checkUnboundVar args + _ <- traverse checkUnboundVar (fst <$> args) doVerifyCfe cfe diff --git a/src/Simpl/Parser.hs b/src/Simpl/Parser.hs index 534800d..98541ff 100644 --- a/src/Simpl/Parser.hs +++ b/src/Simpl/Parser.hs @@ -184,11 +184,11 @@ typeVar = Fix . TyVar <$> identifier typeAdt :: Parser m Type typeAdt = do name <- typeIdentifier - tparams <- many typeVar + tparams <- many (parens type' <|> type') pure $ Fix (TyAdt name tparams) typeAtom :: Parser m Type -typeAtom = typeLit <|> typeAdt +typeAtom = typeLit <|> typeAdt <|> typeVar typeFun :: Parser m Type typeFun = lexeme $ do @@ -198,7 +198,7 @@ typeFun = lexeme $ do pure . Fix $ TyFun (first : init rest) (last rest) type' :: Parser m Type -type' = try typeFun <|> typeAtom <|> typeVar "type" +type' = try typeFun <|> typeAtom "type" declFunParamList :: Parser m [(Text, Type)] declFunParamList = lexeme $ option [] (parens params) @@ -228,7 +228,7 @@ declFun = lexeme $ constructor :: Parser m Ast.Constructor constructor = lexeme $ do name <- typeIdentifier - args <- many type' + args <- many (parens type' <|> type') pure $ Ast.Ctor name args declAdt :: Parser m (Decl SourcedExpr) diff --git a/src/Simpl/SymbolTable.hs b/src/Simpl/SymbolTable.hs index 3bbed45..2fef48f 100644 --- a/src/Simpl/SymbolTable.hs +++ b/src/Simpl/SymbolTable.hs @@ -5,7 +5,7 @@ {-# LANGUAGE DeriveTraversable #-} module Simpl.SymbolTable where -import Data.Functor.Foldable (Fix(Fix)) +import Control.Applicative ((<|>)) import Data.List (find) import Data.Maybe (mapMaybe, listToMaybe) import Data.Map.Strict (Map) @@ -17,7 +17,7 @@ import Simpl.Ast import Simpl.Type data SymbolTable expr = MkSymbolTable - { symTabAdts :: Map Text (Type, [Constructor]) + { symTabAdts :: Map Text ([Text], [Constructor]) -- ^ ADT definitions: name, type vars, constructors , symTabFuns :: Map Text (Set Text, [(Text, Type)], Type, expr) -- ^ Static functions: free type vars, params, return type, body , symTabVars :: Map Text Type -- ^ Variables , symTabExtern :: Map Text ([(Text, Type)], Type) -- ^ External functions @@ -29,7 +29,7 @@ buildSymbolTable (SourceFile _ decls) = let adts = Map.fromList $ mapMaybe (\case DeclAdt name tparams ctors -> - Just (name, (Fix (TyAdt name (Fix . TyVar <$> tparams)), ctors)) + Just (name, (tparams, ctors)) _ -> Nothing) decls funs = Map.fromList $ mapMaybe (\case @@ -47,7 +47,7 @@ buildSymbolTable (SourceFile _ decls) = , symTabVars = Map.empty , symTabExtern = extern } -symTabModifyAdts :: (Map Text (Type, [Constructor]) -> Map Text (Type, [Constructor])) +symTabModifyAdts :: (Map Text ([Text], [Constructor]) -> Map Text ([Text], [Constructor])) -> SymbolTable e -> SymbolTable e symTabModifyAdts f t = t { symTabAdts = f (symTabAdts t) } @@ -68,13 +68,13 @@ symTabTraverseExprs f t = do -- | Searches for the given constructor, returning the name of the ADT, the -- constructor, and the index of the constructor. -symTabLookupCtor :: Text -> SymbolTable e -> Maybe (Type, Constructor, Int) +symTabLookupCtor :: Text -> SymbolTable e -> Maybe ((Text, [Text]), Constructor, Int) symTabLookupCtor name t = listToMaybe . mapMaybe (find isTheCtor) $ getCtors where - getCtors = (\(ty, cs) -> zip3 (repeat ty) cs [0..]) <$> Map.elems (symTabAdts t) + getCtors = (\(tname, (tvars, cs)) -> zip3 (repeat (tname, tvars)) cs [0..]) <$> Map.toList (symTabAdts t) isTheCtor (_, Ctor ctorName _, _) = ctorName == name -symTabLookupAdt :: Text -> SymbolTable e -> Maybe (Type, [Constructor]) +symTabLookupAdt :: Text -> SymbolTable e -> Maybe ([Text], [Constructor]) symTabLookupAdt name = Map.lookup name . symTabAdts symTabLookupVar :: Text -> SymbolTable e -> Maybe Type @@ -91,3 +91,8 @@ symTabLookupStaticFun name = Map.lookup name . symTabFuns symTabLookupExternFun :: Text -> SymbolTable e -> Maybe ([(Text, Type)], Type) symTabLookupExternFun name = Map.lookup name . symTabExtern + +symTabLookupFun :: Text -> SymbolTable e -> Maybe (Set Text, [(Text, Type)], Type) +symTabLookupFun name tab = static tab <|> extern tab + where static = fmap (\(tvars, p, r, _) -> (tvars, p, r)) . symTabLookupStaticFun name + extern = fmap (\(p, r) -> (Set.empty, p, r)) . symTabLookupExternFun name diff --git a/src/Simpl/Type.hs b/src/Simpl/Type.hs index f1ce189..98b0a22 100644 --- a/src/Simpl/Type.hs +++ b/src/Simpl/Type.hs @@ -15,6 +15,9 @@ import Data.Functor.Foldable (Fix(Fix), para, cata) import Data.Text (Text) import Data.Text.Prettyprint.Doc import Data.Eq.Deriving (deriveEq1) +import Data.Ord.Deriving (deriveOrd1) +import Data.Map (Map) +import qualified Data.Map as Map import Data.Set (Set) import qualified Data.Set as Set import Text.Show.Deriving (deriveShow1) @@ -25,7 +28,7 @@ import Text.Show.Deriving (deriveShow1) data Numeric = NumDouble -- ^ 64-bit floating point | NumInt -- ^ 64-bit signed integer | NumUnknown -- ^ Unknown (defaults to 64-bit floating point) - deriving (Show, Eq) + deriving (Show, Eq, Ord) instance Pretty Numeric where pretty = \case @@ -41,12 +44,14 @@ data TypeF a | TyAdt Text [a] | TyFun [a] a | TyVar Text - deriving (Show, Functor, Foldable, Traversable) + | TyBox a -- ^ Boxed polymorphic type + deriving (Show, Eq, Ord, Functor, Foldable, Traversable) type Type = Fix TypeF $(deriveShow1 ''TypeF) $(deriveEq1 ''TypeF) +$(deriveOrd1 ''TypeF) instance Unifiable TypeF where zipMatch (TyNumber n) (TyNumber m) = case (n, m) of @@ -63,9 +68,11 @@ instance Unifiable TypeF where else Nothing zipMatch (TyFun as1 r1) (TyFun as2 r2) = if length as1 == length as2 then + -- TODO: Check by alpha-equivalence instead of raw equality Just $ TyFun (zipWith (curry Right) as1 as2) (Right (r1, r2)) else Nothing + zipMatch (TyBox t1) (TyBox t2) = Just $ TyBox (Right (t1, t2)) zipMatch _ _ = Nothing isComplexType :: TypeF a -> Bool @@ -73,6 +80,14 @@ isComplexType = \case TyFun _ _ -> True _ -> False +-- | Whether a type is represented using a pointer +typeRepIsPtr :: TypeF Type -> Bool +typeRepIsPtr = \case + TyNumber _ -> False + TyBool -> False + TyAdt _ _ -> False + _ -> True + functionTypeResult :: Type -> Type functionTypeResult (Fix ty) = case ty of TyFun _ res -> functionTypeResult res @@ -86,7 +101,14 @@ getTypeVars = cata $ \case TyVar v -> Set.singleton v TyFun vparams vret -> Set.unions (vret:vparams) TyAdt _ vargs -> Set.unions vargs + TyBox vs -> vs +substituteTypeVars :: Map Text Type -> Type -> Type +substituteTypeVars vars = cata go + where + go = \case + ty@(TyVar n) -> Map.findWithDefault (Fix ty) n vars + ty -> Fix ty instance Pretty Type where pretty = para go @@ -97,10 +119,11 @@ instance Pretty Type where go (TyNumber n) = pretty n go TyBool = "Bool" go TyString = "String" - go (TyAdt n tparams) = pretty n <> hsep (snd <$> tparams) + go (TyAdt n tparams) = hsep (pretty n : (snd <$> tparams)) go (TyFun args res) = encloseSep mempty mempty " -> " (wrapComplex <$> args ++ [res]) go (TyVar n) = pretty n + go (TyBox b) = "#<" <> snd b <> ">" -- | A universally quantified type. data PolyType a = PolyType (Set Text) a -- ^ The type variables and the Type diff --git a/src/Simpl/Typecheck.hs b/src/Simpl/Typecheck.hs index cfb3422..494ca79 100644 --- a/src/Simpl/Typecheck.hs +++ b/src/Simpl/Typecheck.hs @@ -15,8 +15,8 @@ import Control.Monad.Reader (ReaderT, MonadReader, runReaderT, asks, local) import Control.Monad.Except (ExceptT, MonadError, lift, runExceptT, throwError) import Control.Unification import Control.Unification.IntVar -import Data.Maybe (fromMaybe) -import Data.Foldable (traverse_, asum) +import Data.Maybe (fromMaybe, fromJust) +import Data.Foldable (traverse_) import Data.Functor.Identity import Data.Functor.Foldable (Fix(..), unfix, cata) import Data.Text (Text) @@ -112,13 +112,17 @@ inferType = cata $ \ae -> case annGetExpr ae of let argTys = extractTy <$> args ctorRes <- asks (symTabLookupCtor name) case ctorRes of - Just (adtTy, Ctor _ ctorArgTys, _) -> do - let conArgs = typeToUtype <$> ctorArgTys - let numConArgs = length conArgs + Just ((adtName, tvars), Ctor _ ctorArgTys, _) -> do + let numConArgs = length ctorArgTys when (numConArgs /= length argTys) $ throwError $ TyErrArgCount numConArgs (length argTys) ctorArgTys + -- Instantiate type variables + substMap <- instantiateVars (Set.fromList tvars) + let conArgs = substituteUVars substMap . typeToUtype <$> ctorArgTys traverse_ (uncurry unifyTy) (zip conArgs argTys) - pure $ annotate (Cons name args) (annGetAnn ae) (typeToUtype adtTy) + let newTvars = [fromJust $ Map.lookup v substMap | v <- tvars] + let newTy = UTerm (TyAdt adtName newTvars) + pure $ annotate (Cons name args) (annGetAnn ae) newTy Nothing -> throwError $ TyErrNoSuchCtor name Case branchMs valM -> do val <- valM @@ -126,11 +130,21 @@ inferType = cata $ \ae -> case annGetExpr ae of branches <- forM branchMs $ \case BrAdt ctorName bindings exprM -> asks (symTabLookupCtor ctorName) >>= \case - Just (dataTy, Ctor _ ctorArgs, _) -> do + Just ((adtName, tvars), Ctor _ ctorArgs, _) -> do when (length bindings /= length ctorArgs) $ throwError $ TyErrArgCount (length ctorArgs) (length bindings) ctorArgs - _ <- unifyTy valTy (typeToUtype dataTy) - let updatedBinds = Map.fromList (bindings `zip` ctorArgs) + -- Instantiate type variables + substMap <- instantiateVars (Set.fromList tvars) + let dataTy = Fix (TyAdt adtName (Fix . TyVar <$> tvars)) + _ <- unifyTy valTy (substituteUVars substMap (typeToUtype dataTy)) + let substCtorArgs = substituteUVars substMap . typeToUtype <$> ctorArgs + -- TODO: Same hack as in let binding + instCtorArgs <- forM substCtorArgs $ \t -> do + t' <- utypeToType <$> forceBindings t + case t' of + Just t'' -> pure t'' + Nothing -> throwError $ TyErrAmbiguousType t + let updatedBinds = Map.fromList (bindings `zip` instCtorArgs) -- Infer result type with ctor args bound expr <- local (\t -> t { symTabVars = Map.union (symTabVars t) updatedBinds }) exprM pure $ BrAdt ctorName bindings expr @@ -140,8 +154,9 @@ inferType = cata $ \ae -> case annGetExpr ae of annotate (Case branches val) (annGetAnn ae) <$> foldM unifyTy resTy brTys Let name valM nextM -> do val <- valM + valTy <- forceBindings (extractTy val) -- TODO: Hack, fix this - case utypeToType (extractTy val) of + case utypeToType valTy of Just ty -> do next <- local (symTabInsertVar name ty) nextM pure $ annotate (Let name val next) (annGetAnn ae) (extractTy next) @@ -158,27 +173,28 @@ inferType = cata $ \ae -> case annGetExpr ae of (tvars, params, ty) <- lookupFun name (extractTy <$> argsTc) -- Check parameter count let numParams = length params - let paramCount = length params + let paramCount = length args when (numParams /= paramCount) $ throwError $ TyErrArgCount numParams paramCount params - let unifyExprTy expr pTy = - annotate (annGetExpr (unfix expr)) (annGetAnn ae) <$> unifyTy (extractTy expr) pTy -- Instantiate type variables substMap <- instantiateVars tvars let instParams = substituteUVars substMap . typeToUtype <$> params let resTy = substituteUVars substMap (typeToUtype ty) -- Check parameter types + let unifyExprTy expr pTy = + annotate (annGetExpr (unfix expr)) (annGetAnn ae) <$> unifyTy (extractTy expr) pTy params' <- zipWithM unifyExprTy argsTc instParams -- Annotate with result type pure $ annotate (App name params') (annGetAnn ae) resTy FunRef name -> asks (symTabLookupStaticFun name) >>= \case - Just (tvars, params, ty, _) -> - -- TODO: How to handle polymorphic function? - let paramTys = snd <$> params - funTy = typeToUtype (Fix $ TyFun paramTys ty) in - pure $ annotate (FunRef name) (annGetAnn ae) funTy + Just (tvars, params, ty, _) -> do + substMap <- instantiateVars tvars + let paramTys = substituteUVars substMap . typeToUtype . snd <$> params + let resTy = substituteUVars substMap (typeToUtype ty) + let funTy = substituteUVars substMap (UTerm (TyFun paramTys resTy)) + pure $ annotate (FunRef name) (annGetAnn ae) funTy Nothing -> throwError $ TyErrNoSuchVar name Cast exprM num -> do expr <- exprM @@ -214,12 +230,13 @@ inferType = cata $ \ae -> case annGetExpr ae of rTy <- if unifyArgResult then unifyTy yTy resultTy else pure resultTy pure $ annotate (BinOp op x y) annFields rTy + -- | Looks up a function from either variable or function scope. Does not + -- instantiate type variables (should be handled on a per-construct basis). lookupFun :: Text -> [UType] -> Typecheck fields (Set Text, [Type], Type) lookupFun name argTys = asks (symTabLookupVar name) >>= \case Just ty -> case unfix ty of - -- TODO: Currently there is no way to know if the function is polymorphic TyFun params resTy -> pure (Set.empty, params, resTy) _ -> do resTy <- mkMetaVar @@ -227,11 +244,7 @@ inferType = cata $ \ae -> case annGetExpr ae of expected = TyFun argTys resTy throwError $ TyErrMismatch expected got Nothing -> - -- TODO: Instantiate type variables - let lookupStatic n = fmap (\(tvars, p, r, _) -> (tvars, p, r)) . symTabLookupStaticFun n - lookupExtern n = fmap (\(p, r) -> (Set.empty, p, r)) . symTabLookupExternFun n - result = traverse (\f -> asks (f name)) [lookupStatic, lookupExtern] in - asum <$> result >>= \case + asks (symTabLookupFun name) >>= \case Just (tvars, params, resTy) -> pure (tvars, snd <$> params, resTy) Nothing -> throwError (TyErrNoSuchVar name) @@ -263,9 +276,10 @@ typeToUtype = cata $ \case TyNumber n -> UTerm (TyNumber n) TyBool -> UTerm TyBool TyString -> UTerm TyString - TyAdt n tparams -> UTerm (TyAdt n tparams) -- TODO: Instantiate variables somewhere + TyAdt n tparams -> UTerm (TyAdt n tparams) TyFun args res -> UTerm (TyFun args res) TyVar n -> UTerm (TyVar n) + TyBox _ -> error "TyBox should not be in SimPL AST" -- | Instantiate the type variables with new unification variables instantiateVars :: Set Text -> Typecheck fields (Map.Map Text UType) diff --git a/test-suite/poly-data.spl b/test-suite/poly-data.spl new file mode 100644 index 0000000..4d4be16 --- /dev/null +++ b/test-suite/poly-data.spl @@ -0,0 +1,75 @@ +data Maybe a = { Just a | Nothing } + +fun testJust : Int := { + let foo = Just 5 in # 0 + let asdf = case foo of + Just x => let _ = println("Got Just!") in x + Nothing => let _ = println("ERROR: Got Nothing") in 0 + in asdf +} + +data List a = { Nil | Cons a (List a) } + +fun head (xs : List a) : Maybe a := { + case xs of + Nil => Nothing + Cons h t => Just h +} + +# FIXME: passing a function ref w/ unboxed argument to a function +# requiring a function ref w/ boxed argument does not work. +# Passed function ref needs to be automatically "promoted" to a +# boxed version. + +# fun filter (xs : List a, f : a -> Bool) : List a := { +# case xs of +# Nil => Nil +# Cons h t => if @f(h) then Cons h @filter(t, f) else @filter(t, f) +# } + +fun filter (xs : List Int, f : Int -> Bool) : List Int := { + case xs of + Nil => Nil + Cons h t => + let newT = @filter(t, f) in + if @f(h) then Cons h newT else newT +} + +fun lte5 (x: Int) : Bool := { + x <= 5 +} + +fun numberList (n : Int) : List Int := { + if n <= 0 then Nil else Cons n @numberList(n - 1) +} + +fun printList (xs : List a, msg : String) : Int := { + case xs of + Nil => 0 + Cons h t => + let _ = println(msg) in + @printList(t, msg) +} + +fun main : Int := { + let nums = @numberList(10) in + let x1 = @printList(nums, "numberList") in + let x2 = @printList(@filter(nums, <e5), "filter(numberList, <= 5)") in + x2 +} + +data AddTable a = { Add (a -> a -> a) } + +fun add (x : a, y : a, table : AddTable a) : a := { + case table of + Add f => @f(x, y) +} + +fun add_Int (x : Int, y : Int) : Int := { + x + y +} + +fun addTable_Int : AddTable Int := { + let f = &add_Int in + Add f +} diff --git a/test-suite/polymorphism.spl b/test-suite/polymorphism.spl new file mode 100644 index 0000000..85c24f7 --- /dev/null +++ b/test-suite/polymorphism.spl @@ -0,0 +1,27 @@ +fun id (x: a) : a := { x } + +data Foo = { Bar Int | Nope } + +fun test_data (f : Foo) : Foo := { + case @id(f) of + Bar y => f + Nope => f +} + +fun test_case (x : Int) : String := { + @id(case Bar x of + Bar y => @id("hi") + Nope => "bye") +} + +fun hi : String := { + let f = &test_case in @test_case(0) +} + +fun main : Int := { + let _ = println(@id("hi")) in + let y = 5 in + let _ = println(if true then @id("bye") else "sigh") in +# let _ = println(if @id(y) <= 0 then @id("bye") else "sigh") in + 0 +} diff --git a/test.sh b/test.sh index a7af0e5..7ef87d5 100755 --- a/test.sh +++ b/test.sh @@ -13,5 +13,5 @@ for src_file in $(find "$TEST_DIR" -type f -name '*.spl'); do out_name="$TEST_BIN_DIR/$out_name" echo "Building ${out_name%.o}..." stack exec simplc -- "$src_file" -o "$out_name" $COMPILER_ARGS - clang -g -pthread runtime/libgc.a "$out_name" -o "${out_name%.o}" -lm + clang -g -pthread "$out_name" runtime/libgc.a -o "${out_name%.o}" -lm done