From acaceae1a9f8b26b46cd292738f59984ac30f805 Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Thu, 12 Dec 2019 10:28:28 -0800 Subject: [PATCH 01/23] WIP implement runtime support for polymorphism --- runtime/runtime.c | 12 +++ runtime/runtime.h | 40 +++++++++ sample.spl | 154 ++++++++++++++++++----------------- src/Simpl/Backend/Codegen.hs | 15 ++++ src/Simpl/Backend/Runtime.hs | 30 ++++++- src/Simpl/JoinIR/Syntax.hs | 4 + test-suite/polymorphism.spl | 6 ++ 7 files changed, 184 insertions(+), 77 deletions(-) create mode 100644 test-suite/polymorphism.spl diff --git a/runtime/runtime.c b/runtime/runtime.c index bd15016..5702a42 100644 --- a/runtime/runtime.c +++ b/runtime/runtime.c @@ -45,3 +45,15 @@ int simpl_string_print(const struct simpl_string* s) { free(cstring); return 0; } + +int simpl_tagged_size(const struct simpl_type_tag* const tag) { + return tag->size; +} + +const struct simpl_type_tag* const simpl_tagged_tag(struct simpl_tagged_value* value) { + return value->type_tag; +} + +void* simpl_tagged_unbox(struct simpl_tagged_value* value) { + return value->data; +} diff --git a/runtime/runtime.h b/runtime/runtime.h index 68b27c1..0232ca2 100644 --- a/runtime/runtime.h +++ b/runtime/runtime.h @@ -38,4 +38,44 @@ char* simpl_string_cstring(const struct simpl_string* s); int simpl_string_print(const struct simpl_string* s); +/** + * Describes static information about a type. This struct should not be visible + * from SimPL programs. + */ +struct simpl_type_tag { + /** + * The size (e.g. when compiled) of the type, in bytes. + */ + unsigned int size; +}; + +/** + * A value tagged with its type tag. Used for polymorphic functions and + * variables. + */ +struct simpl_tagged_value { + const struct simpl_type_tag* const type_tag; + void* data; +}; + +/** + * Returns the size recorded in the type tag. + */ +int simpl_tag_size(const struct simpl_type_tag* const); + +/** + * Returns the type tag of a tagged value. + */ +const struct simpl_type_tag* const simpl_tagged_tag(struct simpl_tagged_value*); + +/** + * Returns a pointer to the boxed value of a tagged value. + */ +void* simpl_tagged_unbox(struct simpl_tagged_value*); + +/** + * Boxes the given value + */ +struct simpl_tagged_value* simpl_tagged_box(struct simpl_type_tag* tag, void* data); + #endif diff --git a/sample.spl b/sample.spl index cc0f013..07fb8c5 100644 --- a/sample.spl +++ b/sample.spl @@ -5,78 +5,82 @@ data MaybeI = { JustI Double | Nothing } # data Barbar a = { Asdf Int a } # fun id (x : a) : a := { x } - -fun not (b : Bool) : Bool := { if b then false else true } - -fun and (p : Bool, q : Bool) : Bool := { if p then (if q then true else false) else false } - -fun or (p : Bool, q : Bool) : Bool := { if p then true else (if q then true else false) } - -fun factorial (x : Int) : Int := { - if (x <= 0) then 1 - else x * @factorial(x - 1) -} - -fun even (x: Int) : Bool := { - if (x <= 0) then true - else @not(@odd(x - 2)) -} - -fun odd (x: Int) : Bool := { - if (x <= 1) then true - else @not(@even(x - 2)) -} - -fun nested_ifs : Double := { - (if true then (if true then 4.0 else 5.0) else (if false then 2.0 else 3.0) + 1.0) * 2.0 -} - -fun main : Int := { - case JustI 10.0 of - JustI x => - let msg = println("In Just branch") in - let mynum = @abs(-1) in - (if (@even(4)) then @double_me(5) else @factorial(6)) * @asdf * mynum - Nothing => - let msg = println("In Nothing branch") in - let res = 4 in - @double_me(res + 1) -} - -fun asdf : Int := { - (if false then 5 else 10) + 2 -} - -fun double_me (x : Int) : Int := { x * 2 } - -fun foo : Foo := { - Bar 5.0 -} - -fun foo2 : Bool := { - case Bar 5.0 of - Bar x => true -} - -fun lots_of_lets : Double := { - let x = if true then 1.0 else 2.0 in - let y = if true then x * 2.0 else x * 2.0 + 1.0 in - y -} - -fun fun_ptr : Int -> Int := { &double_me } - -fun fun_ptr_test (b : Bool) : Int := { - let f = if b then &double_me else &factorial in - @f(4) -} - -fun cast_test_1 : Int := { - (1 + 1) * (cast 2.0 as Int) -} - -fun cast_test_2 : Int := { - (1 * 1) + (cast 2.0 as Int) -} - -fun abs (x : Int) : Int := extern +# +# fun eqargs (x : a, y : a, z: b) : a := { x } + +fun main : Int := { let x = &asdf in 5 } + +# fun not (b : Bool) : Bool := { if b then false else true } +# +# fun and (p : Bool, q : Bool) : Bool := { if p then (if q then true else false) else false } +# +# fun or (p : Bool, q : Bool) : Bool := { if p then true else (if q then true else false) } +# +# fun factorial (x : Int) : Int := { +# if (x <= 0) then 1 +# else x * @factorial(x - 1) +# } +# +# fun even (x: Int) : Bool := { +# if (x <= 0) then true +# else @not(@odd(x - 2)) +# } +# +# fun odd (x: Int) : Bool := { +# if (x <= 1) then true +# else @not(@even(x - 2)) +# } +# +# fun nested_ifs : Double := { +# (if true then (if true then 4.0 else 5.0) else (if false then 2.0 else 3.0) + 1.0) * 2.0 +# } +# +# fun main : Int := { +# case JustI 10.0 of +# JustI x => +# let msg = println("In Just branch") in +# let mynum = @abs(-1) in +# (if (@even(4)) then @double_me(5) else @factorial(6)) * @asdf * mynum +# Nothing => +# let msg = println("In Nothing branch") in +# let res = 4 in +# @double_me(res + 1) +# } +# +# fun asdf : Int := { +# (if false then 5 else 10) + 2 +# } +# +# fun double_me (x : Int) : Int := { x * 2 } +# +# fun foo : Foo := { +# Bar 5.0 +# } +# +# fun foo2 : Bool := { +# case Bar 5.0 of +# Bar x => true +# } +# +# fun lots_of_lets : Double := { +# let x = if true then 1.0 else 2.0 in +# let y = if true then x * 2.0 else x * 2.0 + 1.0 in +# y +# } +# +# fun fun_ptr : Int -> Int := { &double_me } +# +# fun fun_ptr_test (b : Bool) : Int := { +# let f = if b then &double_me else &factorial in +# @f(4) +# } +# +# fun cast_test_1 : Int := { +# (1 + 1) * (cast 2.0 as Int) +# } +# +# fun cast_test_2 : Int := { +# (1 * 1) + (cast 2.0 as Int) +# } +# +# fun abs (x : Int) : Int := extern diff --git a/src/Simpl/Backend/Codegen.hs b/src/Simpl/Backend/Codegen.hs index 8c9fdc3..edfcf49 100644 --- a/src/Simpl/Backend/Codegen.hs +++ b/src/Simpl/Backend/Codegen.hs @@ -377,6 +377,21 @@ callableCodegen callable args = case callable of pure $ LLVMIR.int64 0 _ -> error $ "callableCodegen: expected 1 args to CPrint, got " ++ show (length args) CFunRef name -> gets (fromJust . Map.lookup name . tableFuns) + CTag -> case args of + [val] -> do + -- TODO: tag lookup + error "CTag compilation: Unimplemented" + _ -> error $ "callableCodegen: expected 1 args to CTag, got " ++ show (length args) + CUntag -> case args of + [jval] -> do + val <- jvalueCodegen jval + typeTag <- LLVMIR.call RT.taggedTagRef [(val, [])] + len <- LLVMIR.call RT.tagSizeRef [(typeTag, [])] + bytesPtr <- LLVMIR.call RT.taggedUnboxRef [(val, [])] + -- TODO: lookup LLVM type + -- LLVMIR.bitcast bytesPtr _ + error "CUntag compilation: Unimplemented" + _ -> error $ "callableCodegen: expected 1 args to CTag, got " ++ show (length args) -- | Generates code for a [JExpr] jexprCodegen diff --git a/src/Simpl/Backend/Runtime.hs b/src/Simpl/Backend/Runtime.hs index e58af3f..efd076f 100644 --- a/src/Simpl/Backend/Runtime.hs +++ b/src/Simpl/Backend/Runtime.hs @@ -97,13 +97,39 @@ stringFuns = [ ("simpl_string_cstring", stringCstringType) stringStructs :: [String] stringStructs = ["simpl_string"] +-- * Tags + +typeTagType :: LLVM.Type +typeTagType = runtimeStruct "simpl_type_tag" + +taggedValueType :: LLVM.Type +taggedValueType = runtimeStruct "simpl_tagged_value" + +tagSizeType, taggedTagType, taggedUnboxType :: FunType +tagSizeType = mkFunType [("t", LLVM.ptr typeTagType)] LLVM.i64 +taggedTagType = mkFunType [("t", LLVM.ptr taggedValueType)] (LLVM.ptr taggedValueType) +taggedUnboxType = mkFunType [("t", LLVM.ptr taggedValueType)] (LLVM.ptr LLVM.void) + +tagSizeRef, taggedTagRef, taggedUnboxRef :: LLVM.Operand +tagSizeRef = runtimeFunRef "simpl_tag_size" tagSizeType +taggedTagRef = runtimeFunRef "simpl_tagged_tag" taggedTagType +taggedUnboxRef = runtimeFunRef "simpl_tagged_unbox" taggedUnboxType + +runtimeTypeFuns :: [(String, FunType)] +runtimeTypeFuns = [ ("simpl_tag_size", tagSizeType) + , ("simpl_tagged_tag", taggedTagType) + , ("simpl_tagged_unbox", taggedUnboxType) ] + +runtimeTypeStructs :: [String] +runtimeTypeStructs = ["simpl_type_tag", "simpl_tagged_value"] + -- * Entire runtime allRuntimeFuns :: [(String, FunType)] -allRuntimeFuns = join [cstdlibFuns, stringFuns] +allRuntimeFuns = join [cstdlibFuns, stringFuns, runtimeTypeFuns] allRuntimeStructs :: [String] -allRuntimeStructs = stringStructs +allRuntimeStructs = stringStructs ++ runtimeTypeStructs emitRuntimeDecls :: LLVMIR.MonadModuleBuilder m => m () emitRuntimeDecls = do diff --git a/src/Simpl/JoinIR/Syntax.hs b/src/Simpl/JoinIR/Syntax.hs index d92e99a..3a6a93b 100644 --- a/src/Simpl/JoinIR/Syntax.hs +++ b/src/Simpl/JoinIR/Syntax.hs @@ -39,6 +39,8 @@ data Callable | CCtor !Name -- ^ ADT constructor | CPrint -- ^ Print string (temporary) | CFunRef !Name -- ^ Static function reference + | CTag -- ^ Create a boxed representation of the given value + | CUntag -- ^ Unbox the value deriving (Show) -- | A value @@ -124,6 +126,8 @@ instance Pretty Callable where CCtor name -> pretty name CPrint -> "print" CFunRef name -> "funref[" <> pretty name <> "]" + CTag -> "tag" + CUntag -> "untag" instance Pretty JValue where pretty = \case diff --git a/test-suite/polymorphism.spl b/test-suite/polymorphism.spl new file mode 100644 index 0000000..1484872 --- /dev/null +++ b/test-suite/polymorphism.spl @@ -0,0 +1,6 @@ +fun id (x: a) : a := { x } + +fun main : Int := { + let _ = println(@id("hi")) in + 0 +} \ No newline at end of file From 74e8ccc1213483be019c7b2dc9a110f15975b20a Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Sun, 15 Dec 2019 00:21:37 -0800 Subject: [PATCH 02/23] Fix test suite linker flags --- runtime/runtime.h | 10 +++++----- test.sh | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/runtime/runtime.h b/runtime/runtime.h index 0232ca2..cfc47ea 100644 --- a/runtime/runtime.h +++ b/runtime/runtime.h @@ -49,6 +49,11 @@ struct simpl_type_tag { unsigned int size; }; +/** + * Returns the size recorded in the type tag. + */ +int simpl_tag_size(const struct simpl_type_tag* const); + /** * A value tagged with its type tag. Used for polymorphic functions and * variables. @@ -58,11 +63,6 @@ struct simpl_tagged_value { void* data; }; -/** - * Returns the size recorded in the type tag. - */ -int simpl_tag_size(const struct simpl_type_tag* const); - /** * Returns the type tag of a tagged value. */ diff --git a/test.sh b/test.sh index a7af0e5..7ef87d5 100755 --- a/test.sh +++ b/test.sh @@ -13,5 +13,5 @@ for src_file in $(find "$TEST_DIR" -type f -name '*.spl'); do out_name="$TEST_BIN_DIR/$out_name" echo "Building ${out_name%.o}..." stack exec simplc -- "$src_file" -o "$out_name" $COMPILER_ARGS - clang -g -pthread runtime/libgc.a "$out_name" -o "${out_name%.o}" -lm + clang -g -pthread "$out_name" runtime/libgc.a -o "${out_name%.o}" -lm done From d7afb51d0a0e07151e1ad5bcaa8b6d7d46bedf98 Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Sun, 15 Dec 2019 00:22:05 -0800 Subject: [PATCH 03/23] WIP More work on boxed/tagged data compilation --- src/Simpl/Backend/Codegen.hs | 14 +++++++++----- src/Simpl/Backend/Runtime.hs | 5 ++++- src/Simpl/Type.hs | 5 +++++ 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/src/Simpl/Backend/Codegen.hs b/src/Simpl/Backend/Codegen.hs index edfcf49..8086460 100644 --- a/src/Simpl/Backend/Codegen.hs +++ b/src/Simpl/Backend/Codegen.hs @@ -378,19 +378,23 @@ callableCodegen callable args = case callable of _ -> error $ "callableCodegen: expected 1 args to CPrint, got " ++ show (length args) CFunRef name -> gets (fromJust . Map.lookup name . tableFuns) CTag -> case args of - [val] -> do + [jval] -> do + ty <- lookupValueType jval + val <- jvalueCodegen jval -- TODO: tag lookup - error "CTag compilation: Unimplemented" + let tag = error "TODO" + bytes <- LLVMIR.bitcast val (LLVM.ptr LLVM.void) + LLVMIR.call RT.taggedBoxRef [(tag, []), (bytes, [])] _ -> error $ "callableCodegen: expected 1 args to CTag, got " ++ show (length args) CUntag -> case args of [jval] -> do + ty <- lookupValueType jval + let llvmTy = typeToLLVM (Fix ty) val <- jvalueCodegen jval typeTag <- LLVMIR.call RT.taggedTagRef [(val, [])] len <- LLVMIR.call RT.tagSizeRef [(typeTag, [])] bytesPtr <- LLVMIR.call RT.taggedUnboxRef [(val, [])] - -- TODO: lookup LLVM type - -- LLVMIR.bitcast bytesPtr _ - error "CUntag compilation: Unimplemented" + LLVMIR.bitcast bytesPtr (LLVM.ptr llvmTy) _ -> error $ "callableCodegen: expected 1 args to CTag, got " ++ show (length args) -- | Generates code for a [JExpr] diff --git a/src/Simpl/Backend/Runtime.hs b/src/Simpl/Backend/Runtime.hs index efd076f..eb16341 100644 --- a/src/Simpl/Backend/Runtime.hs +++ b/src/Simpl/Backend/Runtime.hs @@ -109,16 +109,19 @@ tagSizeType, taggedTagType, taggedUnboxType :: FunType tagSizeType = mkFunType [("t", LLVM.ptr typeTagType)] LLVM.i64 taggedTagType = mkFunType [("t", LLVM.ptr taggedValueType)] (LLVM.ptr taggedValueType) taggedUnboxType = mkFunType [("t", LLVM.ptr taggedValueType)] (LLVM.ptr LLVM.void) +taggedBoxType = mkFunType [("t", LLVM.ptr taggedValueType), ("d", LLVM.ptr LLVM.void)] (LLVM.ptr typeTagType) tagSizeRef, taggedTagRef, taggedUnboxRef :: LLVM.Operand tagSizeRef = runtimeFunRef "simpl_tag_size" tagSizeType taggedTagRef = runtimeFunRef "simpl_tagged_tag" taggedTagType taggedUnboxRef = runtimeFunRef "simpl_tagged_unbox" taggedUnboxType +taggedBoxRef = runtimeFunRef "simpl_tagged_box" taggedBoxType runtimeTypeFuns :: [(String, FunType)] runtimeTypeFuns = [ ("simpl_tag_size", tagSizeType) , ("simpl_tagged_tag", taggedTagType) - , ("simpl_tagged_unbox", taggedUnboxType) ] + , ("simpl_tagged_unbox", taggedUnboxType) + , ("simpl_tagged_box", taggedBoxType) ] runtimeTypeStructs :: [String] runtimeTypeStructs = ["simpl_type_tag", "simpl_tagged_value"] diff --git a/src/Simpl/Type.hs b/src/Simpl/Type.hs index f1ce189..3d718b4 100644 --- a/src/Simpl/Type.hs +++ b/src/Simpl/Type.hs @@ -41,6 +41,7 @@ data TypeF a | TyAdt Text [a] | TyFun [a] a | TyVar Text + | TyBox a -- ^ Boxed polymorphic type deriving (Show, Functor, Foldable, Traversable) type Type = Fix TypeF @@ -63,9 +64,11 @@ instance Unifiable TypeF where else Nothing zipMatch (TyFun as1 r1) (TyFun as2 r2) = if length as1 == length as2 then + -- TODO: Check by alpha-equivalence instead of raw equality Just $ TyFun (zipWith (curry Right) as1 as2) (Right (r1, r2)) else Nothing + zipMatch (TyBox t1) (TyBox t2) = Just $ TyBox (Right (t1, t2)) zipMatch _ _ = Nothing isComplexType :: TypeF a -> Bool @@ -86,6 +89,7 @@ getTypeVars = cata $ \case TyVar v -> Set.singleton v TyFun vparams vret -> Set.unions (vret:vparams) TyAdt _ vargs -> Set.unions vargs + TyBox vs -> vs instance Pretty Type where @@ -101,6 +105,7 @@ instance Pretty Type where go (TyFun args res) = encloseSep mempty mempty " -> " (wrapComplex <$> args ++ [res]) go (TyVar n) = pretty n + go (TyBox b) = "#<" <> snd b <> ">" -- | A universally quantified type. data PolyType a = PolyType (Set Text) a -- ^ The type variables and the Type From b7eaec1bcb50a40a2e0022d24bc3af2ed321c651 Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Sun, 15 Dec 2019 21:05:48 -0800 Subject: [PATCH 04/23] WIP track boxing in AST to JoinIR conversion --- src/Simpl/AstToJoinIR.hs | 78 ++++++++++++++++++++++++++++++------ src/Simpl/Backend/Codegen.hs | 3 +- 2 files changed, 67 insertions(+), 14 deletions(-) diff --git a/src/Simpl/AstToJoinIR.hs b/src/Simpl/AstToJoinIR.hs index 3bb2c5e..84892d2 100644 --- a/src/Simpl/AstToJoinIR.hs +++ b/src/Simpl/AstToJoinIR.hs @@ -22,14 +22,20 @@ import Data.Functor.Foldable (Fix(..), unfix) import Data.Functor.Identity import Data.Text (Text) import Data.String (fromString) +import Data.Maybe (fromJust) +import Data.Map (Map) +import qualified Data.Map as Map import Data.Set (Set) import qualified Data.Set as Set +import Debug.Trace + import Simpl.Annotation import Simpl.SymbolTable import qualified Simpl.Ast as A import qualified Simpl.JoinIR.Syntax as J -import Simpl.Type (Type) +import Simpl.Type (Type, TypeF(TyBox, TyVar)) +import Simpl.Typecheck (literalType) import Simpl.Util.Supply import qualified Simpl.Util.Stream as Stream @@ -40,21 +46,25 @@ astToJoinIR table = runTransform transformTable (defaultCtx table) -- * Transformation Monad +data BoxedVal = Boxed | Unboxed deriving (Show, Eq, Ord) + data TransformCtx fields = TransformCtx { tcSymTab :: SymbolTable (A.AnnExpr fields) , tcJoinLabels :: Set Text + , tcBoxStatus :: Map Text BoxedVal } defaultCtx :: SymbolTable (A.AnnExpr flds) -> TransformCtx flds -defaultCtx table = TransformCtx { tcSymTab = table, tcJoinLabels = Set.empty } - -modifySymTab :: (SymbolTable (A.AnnExpr flds) -> SymbolTable (A.AnnExpr flds)) - -> TransformCtx flds - -> TransformCtx flds -modifySymTab f ctx = ctx { tcSymTab = f (tcSymTab ctx) } +defaultCtx table = TransformCtx + { tcSymTab = table + , tcJoinLabels = Set.empty + , tcBoxStatus = Map.empty } insertVar :: Text -> Type -> TransformCtx flds -> TransformCtx flds -insertVar name ty = modifySymTab (symTabInsertVar name ty) +insertVar name ty ctx = ctx + { tcSymTab = symTabInsertVar name ty (tcSymTab ctx) + , tcBoxStatus = Map.insert name (boxedVal ty) (tcBoxStatus ctx) + } newtype TransformT fields m a = TransformT { unTransform :: ReaderT (TransformCtx fields) (SupplyT Int m) a } @@ -110,6 +120,38 @@ makeJexpr ty = Fix . addField (withType ty) . toAnnExprF astType :: HasType flds => A.AnnExpr flds -> Type astType = getType . unfix +getJvalueType :: MonadReader (TransformCtx flds) m => J.JValue -> m Type +getJvalueType = \case + J.JVar n -> asks (fromJust . symTabLookupVar n . tcSymTab) + J.JLit l -> pure . Fix $ literalType l + +-- | Get boxed type. Left is unboxed, right is boxed. +boxedType :: Type -> Either Type Type +boxedType = \case + t@(Fix (TyVar _)) -> Right t + Fix (TyBox t) -> Right t + t -> Left t + +isBoxed :: Type -> Bool +isBoxed t = case boxedType t of { Left _ -> False; Right _ -> True } + +boxedVal :: Type -> BoxedVal +boxedVal t = if isBoxed t then Boxed else Unboxed + +rebindBoxing :: (HasType flds, MonadFreshVar m, MonadReader (TransformCtx flds) m) + => J.JValue -- ^ Variable name + -> Type -- ^ Variable type + -> BoxedVal -- ^ Whether to ensure boxed or unboxed + -> m (J.JValue, J.AnnExpr '[ 'ExprType] -> J.AnnExpr '[ 'ExprType]) +rebindBoxing val ty b = do + let create ty' action = do + name <- case val of { J.JVar n -> pure n; _ -> freshVar } + local (insertVar name ty') (pure (J.JVar name, makeJexpr ty' . J.JApp name action [val])) + case (boxedType ty, b) of + (Right ty', Unboxed) -> create ty' J.CUntag + (Left ty', Boxed) -> create (Fix (TyBox ty')) J.CTag + _ -> pure (val, id) + -- * ANF Transformation -- | Perform ANF transformation on the given symbol table @@ -171,7 +213,7 @@ anfTransform (Fix ae) cont = let ty = getType ae in case annGetExpr ae of A.Case branches expr -> anfTransform expr $ \jexpr -> do lbl <- freshLabel - let jexprTy = getType (unfix expr) + let jexprTy = astType expr jbranches <- traverse (transformBranch (J.JJump lbl)) branches let jexprCfe = J.Cfe (makeJexpr jexprTy (J.JVal jexpr)) (J.JCase jbranches) name <- freshVar @@ -180,23 +222,33 @@ anfTransform (Fix ae) cont = let ty = getType ae in case annGetExpr ae of local (insertVar name ty) (cont (J.JVar name)) A.Cons ctorName args -> collectArgs args $ \argVals -> do + -- TODO: find and box polymorphic args varName <- freshVar makeJexpr ty . J.JApp varName (J.CCtor ctorName) argVals <$> local (insertVar varName ty) (cont (J.JVar varName)) A.App funcName args -> collectArgs args $ \argVals -> do varName <- freshVar - makeJexpr ty . J.JApp varName (J.CFunc funcName) argVals <$> - local (insertVar varName ty) (cont (J.JVar varName)) + (_, funcArgs, funcRetTy, _) <- asks (fromJust . symTabLookupStaticFun funcName . tcSymTab) + argTys <- traverse getJvalueType argVals + tuples <- sequence [rebindBoxing val aTy (boxedVal faTy) + | (((_, faTy), val), aTy) <- funcArgs `zip` argVals `zip` argTys] + let (argVals', boxConvArgs_) = unzip tuples + let boxConvArgs = foldl (.) id boxConvArgs_ + let ty' = if isBoxed funcRetTy then Fix (TyBox ty) else ty + boxConvArgs . makeJexpr ty' . J.JApp varName (J.CFunc funcName) argVals' <$> + local (insertVar varName ty') (cont (J.JVar varName)) A.Cast expr numTy -> anfTransform expr $ \jexpr -> do varName <- freshVar makeJexpr ty . J.JApp varName (J.CCast numTy) [jexpr] <$> local (insertVar varName ty) (cont (J.JVar varName)) A.Print expr -> - anfTransform expr $ \jexpr -> do + anfTransform expr $ \jval -> do varName <- freshVar - makeJexpr ty . J.JApp varName J.CPrint [jexpr] <$> + valTy <- getJvalueType jval + (jval', boxConv) <- rebindBoxing jval valTy Unboxed + boxConv . makeJexpr ty . J.JApp varName J.CPrint [jval'] <$> local (insertVar varName ty) (cont (J.JVar varName)) A.FunRef name -> do varName <- freshVar diff --git a/src/Simpl/Backend/Codegen.hs b/src/Simpl/Backend/Codegen.hs index 8086460..1872609 100644 --- a/src/Simpl/Backend/Codegen.hs +++ b/src/Simpl/Backend/Codegen.hs @@ -462,7 +462,8 @@ typeToLLVM = go . unfix , LLVM.argumentTypes = typeToLLVM <$> args , LLVM.isVarArg = False } - TyVar _ -> error "compilation of parametrically polymorphic functions not implemented yet" + TyVar _ -> RT.taggedValueType + TyBox _ -> RT.taggedValueType adtToLLVM :: Text -> [Constructor] From 438ead762ea7bfc6eee469740607d01d75966883 Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Mon, 16 Dec 2019 16:28:47 -0800 Subject: [PATCH 05/23] Finish codegen for tag/untagging --- runtime/runtime.c | 10 +++++++- runtime/runtime.h | 7 +++--- src/Simpl/Backend/Codegen.hs | 46 ++++++++++++++++++++++++++---------- src/Simpl/Backend/Runtime.hs | 15 ++++++++---- src/Simpl/Type.hs | 6 +++-- src/Simpl/Typecheck.hs | 1 + 6 files changed, 62 insertions(+), 23 deletions(-) diff --git a/runtime/runtime.c b/runtime/runtime.c index 5702a42..a562980 100644 --- a/runtime/runtime.c +++ b/runtime/runtime.c @@ -46,7 +46,7 @@ int simpl_string_print(const struct simpl_string* s) { return 0; } -int simpl_tagged_size(const struct simpl_type_tag* const tag) { +uint32_t simpl_tagged_size(const struct simpl_type_tag* const tag) { return tag->size; } @@ -57,3 +57,11 @@ const struct simpl_type_tag* const simpl_tagged_tag(struct simpl_tagged_value* v void* simpl_tagged_unbox(struct simpl_tagged_value* value) { return value->data; } + + +struct simpl_tagged_value* simpl_tagged_box(struct simpl_type_tag* tag, void* data) { + struct simpl_tagged_value *value = simpl_malloc(sizeof(struct simpl_tagged_value)); + value->type_tag = tag; + value->data = data; + return value; +} diff --git a/runtime/runtime.h b/runtime/runtime.h index cfc47ea..a23501e 100644 --- a/runtime/runtime.h +++ b/runtime/runtime.h @@ -1,4 +1,5 @@ #include +#include #ifndef RUNTIME_H #define RUNTIME_H @@ -46,20 +47,20 @@ struct simpl_type_tag { /** * The size (e.g. when compiled) of the type, in bytes. */ - unsigned int size; + uint32_t size; }; /** * Returns the size recorded in the type tag. */ -int simpl_tag_size(const struct simpl_type_tag* const); +uint32_t simpl_tag_size(const struct simpl_type_tag* const); /** * A value tagged with its type tag. Used for polymorphic functions and * variables. */ struct simpl_tagged_value { - const struct simpl_type_tag* const type_tag; + const struct simpl_type_tag* type_tag; void* data; }; diff --git a/src/Simpl/Backend/Codegen.hs b/src/Simpl/Backend/Codegen.hs index 1872609..aa02605 100644 --- a/src/Simpl/Backend/Codegen.hs +++ b/src/Simpl/Backend/Codegen.hs @@ -48,7 +48,7 @@ import Simpl.Annotation hiding (AnnExpr, AnnExprF) import Simpl.Ast (BinaryOp(..), Constructor(..), Literal(..)) import Simpl.CompilerOptions import Simpl.SymbolTable -import Simpl.Type (Type, TypeF(..), Numeric(..)) +import Simpl.Type (Type(..), TypeF(..), Numeric(..)) import Simpl.Typecheck (literalType) import Simpl.Backend.Runtime () import Simpl.JoinIR.Syntax @@ -61,7 +61,8 @@ data CodegenTable = , tableFuns :: Map Text LLVM.Operand , tableJoinValues :: Map Text (LLVM.Name, [(LLVM.Operand, LLVM.Name)]) , tablePrintf :: LLVM.Operand - , tableOptions :: CompilerOpts } + , tableOptions :: CompilerOpts + , tableTypeTags :: Map (TypeF Type) LLVM.Operand } deriving (Show) -- | An empty codegen table. This will cause a crash if codegen is run when not @@ -74,7 +75,8 @@ emptyCodegenTable = , tableFuns = Map.empty , tableJoinValues = Map.empty , tablePrintf = error "printf not set" - , tableOptions = defaultCompilerOpts } + , tableOptions = defaultCompilerOpts + , tableTypeTags = Map.empty } newtype CodegenT m a = CodegenT { unCodegen :: StateT CodegenTable m a } @@ -120,6 +122,28 @@ lookupValueType = \case JVar name -> gets (fst . fromJust . Map.lookup name . tableVars) JLit lit -> pure $ literalType lit +-- | If the type tag does not exist, emit it into the IR. Then return the +-- operand of the type tag. +lookupTypeTag :: (LLVMIR.MonadModuleBuilder m, MonadState CodegenTable m) + => TypeF Type + -> m LLVM.Operand +lookupTypeTag ty = + gets (Map.lookup ty . tableTypeTags) >>= \case + Just oper -> pure oper + Nothing -> do + let name = case ty of + TyNumber _ -> "Int" -- TODO: fix this + TyString -> "String" + TyBool -> "Bool" + _ -> error "TODO" + let llvmTy = typeToLLVM (Fix ty) + let nullptr = LLVMC.Null (LLVM.ptr RT.typeTagType) + let size = LLVMC.sizeof llvmTy + let tagContents = LLVMC.Struct { LLVMC.structName = Nothing + , LLVMC.isPacked = False + , LLVMC.memberValues = [size] } + LLVMIR.global (LLVM.mkName $ "simpl.tag." ++ name) RT.typeTagType tagContents + bindVariable :: MonadState CodegenTable m => Text -> TypeF Type @@ -381,20 +405,18 @@ callableCodegen callable args = case callable of [jval] -> do ty <- lookupValueType jval val <- jvalueCodegen jval - -- TODO: tag lookup - let tag = error "TODO" + tag <- lookupTypeTag ty bytes <- LLVMIR.bitcast val (LLVM.ptr LLVM.void) LLVMIR.call RT.taggedBoxRef [(tag, []), (bytes, [])] _ -> error $ "callableCodegen: expected 1 args to CTag, got " ++ show (length args) CUntag -> case args of [jval] -> do - ty <- lookupValueType jval - let llvmTy = typeToLLVM (Fix ty) + ty <- lookupValueType jval >>= \case + TyBox t -> pure t + _ -> error "callableCodegen: untagging non-boxed type" val <- jvalueCodegen jval - typeTag <- LLVMIR.call RT.taggedTagRef [(val, [])] - len <- LLVMIR.call RT.tagSizeRef [(typeTag, [])] bytesPtr <- LLVMIR.call RT.taggedUnboxRef [(val, [])] - LLVMIR.bitcast bytesPtr (LLVM.ptr llvmTy) + LLVMIR.bitcast bytesPtr (typeToLLVM ty) _ -> error $ "callableCodegen: expected 1 args to CTag, got " ++ show (length args) -- | Generates code for a [JExpr] @@ -462,8 +484,8 @@ typeToLLVM = go . unfix , LLVM.argumentTypes = typeToLLVM <$> args , LLVM.isVarArg = False } - TyVar _ -> RT.taggedValueType - TyBox _ -> RT.taggedValueType + TyVar _ -> LLVM.ptr RT.taggedValueType + TyBox _ -> LLVM.ptr RT.taggedValueType adtToLLVM :: Text -> [Constructor] diff --git a/src/Simpl/Backend/Runtime.hs b/src/Simpl/Backend/Runtime.hs index eb16341..343dbc3 100644 --- a/src/Simpl/Backend/Runtime.hs +++ b/src/Simpl/Backend/Runtime.hs @@ -100,18 +100,21 @@ stringStructs = ["simpl_string"] -- * Tags typeTagType :: LLVM.Type -typeTagType = runtimeStruct "simpl_type_tag" +typeTagType = LLVM.StructureType + { LLVM.isPacked = False + -- Size + , LLVM.elementTypes = [LLVM.i32] } taggedValueType :: LLVM.Type taggedValueType = runtimeStruct "simpl_tagged_value" -tagSizeType, taggedTagType, taggedUnboxType :: FunType +tagSizeType, taggedTagType, taggedBoxType, taggedUnboxType :: FunType tagSizeType = mkFunType [("t", LLVM.ptr typeTagType)] LLVM.i64 taggedTagType = mkFunType [("t", LLVM.ptr taggedValueType)] (LLVM.ptr taggedValueType) +taggedBoxType = mkFunType [("t", LLVM.ptr typeTagType), ("d", LLVM.ptr LLVM.void)] (LLVM.ptr taggedValueType) taggedUnboxType = mkFunType [("t", LLVM.ptr taggedValueType)] (LLVM.ptr LLVM.void) -taggedBoxType = mkFunType [("t", LLVM.ptr taggedValueType), ("d", LLVM.ptr LLVM.void)] (LLVM.ptr typeTagType) -tagSizeRef, taggedTagRef, taggedUnboxRef :: LLVM.Operand +tagSizeRef, taggedTagRef, taggedBoxRef, taggedUnboxRef :: LLVM.Operand tagSizeRef = runtimeFunRef "simpl_tag_size" tagSizeType taggedTagRef = runtimeFunRef "simpl_tagged_tag" taggedTagType taggedUnboxRef = runtimeFunRef "simpl_tagged_unbox" taggedUnboxType @@ -137,5 +140,7 @@ allRuntimeStructs = stringStructs ++ runtimeTypeStructs emitRuntimeDecls :: LLVMIR.MonadModuleBuilder m => m () emitRuntimeDecls = do forM_ allRuntimeFuns (uncurry emitRuntimeFun) - forM_ allRuntimeStructs $ \name -> do + forM_ allRuntimeStructs $ \name -> LLVMIR.typedef (LLVM.mkName ("struct." <> name)) Nothing + _ <- LLVMIR.typedef (LLVM.mkName ("struct.simpl_type_tag")) (Just typeTagType) + pure () diff --git a/src/Simpl/Type.hs b/src/Simpl/Type.hs index 3d718b4..d4c76d1 100644 --- a/src/Simpl/Type.hs +++ b/src/Simpl/Type.hs @@ -15,6 +15,7 @@ import Data.Functor.Foldable (Fix(Fix), para, cata) import Data.Text (Text) import Data.Text.Prettyprint.Doc import Data.Eq.Deriving (deriveEq1) +import Data.Ord.Deriving (deriveOrd1) import Data.Set (Set) import qualified Data.Set as Set import Text.Show.Deriving (deriveShow1) @@ -25,7 +26,7 @@ import Text.Show.Deriving (deriveShow1) data Numeric = NumDouble -- ^ 64-bit floating point | NumInt -- ^ 64-bit signed integer | NumUnknown -- ^ Unknown (defaults to 64-bit floating point) - deriving (Show, Eq) + deriving (Show, Eq, Ord) instance Pretty Numeric where pretty = \case @@ -42,12 +43,13 @@ data TypeF a | TyFun [a] a | TyVar Text | TyBox a -- ^ Boxed polymorphic type - deriving (Show, Functor, Foldable, Traversable) + deriving (Show, Eq, Ord, Functor, Foldable, Traversable) type Type = Fix TypeF $(deriveShow1 ''TypeF) $(deriveEq1 ''TypeF) +$(deriveOrd1 ''TypeF) instance Unifiable TypeF where zipMatch (TyNumber n) (TyNumber m) = case (n, m) of diff --git a/src/Simpl/Typecheck.hs b/src/Simpl/Typecheck.hs index cfb3422..63b2faf 100644 --- a/src/Simpl/Typecheck.hs +++ b/src/Simpl/Typecheck.hs @@ -266,6 +266,7 @@ typeToUtype = cata $ \case TyAdt n tparams -> UTerm (TyAdt n tparams) -- TODO: Instantiate variables somewhere TyFun args res -> UTerm (TyFun args res) TyVar n -> UTerm (TyVar n) + TyBox _ -> error "TyBox should not be in SimPL AST" -- | Instantiate the type variables with new unification variables instantiateVars :: Set Text -> Typecheck fields (Map.Map Text UType) From 94e837e013638ff86dacb55627233a4cb7146553 Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Mon, 16 Dec 2019 17:12:43 -0800 Subject: [PATCH 06/23] Add boxing logic to more expression types --- src/Simpl/AstToJoinIR.hs | 23 ++++++++++++++++------- src/Simpl/Backend/Codegen.hs | 2 ++ test-suite/polymorphism.spl | 5 ++++- 3 files changed, 22 insertions(+), 8 deletions(-) diff --git a/src/Simpl/AstToJoinIR.hs b/src/Simpl/AstToJoinIR.hs index 84892d2..26aa9cc 100644 --- a/src/Simpl/AstToJoinIR.hs +++ b/src/Simpl/AstToJoinIR.hs @@ -28,8 +28,6 @@ import qualified Data.Map as Map import Data.Set (Set) import qualified Data.Set as Set -import Debug.Trace - import Simpl.Annotation import Simpl.SymbolTable import qualified Simpl.Ast as A @@ -189,19 +187,28 @@ anfTransform (Fix ae) cont = let ty = getType ae in case annGetExpr ae of A.Var name -> cont (J.JVar name) A.Let name bindExpr next -> anfTransform bindExpr $ \bindVal -> - makeJexpr (getType (unfix bindExpr)) . J.JLet name bindVal <$> + makeJexpr ty . J.JLet name bindVal <$> local (insertVar name ty) (anfTransform next cont) A.BinOp op left right -> anfTransform left $ \jleft -> anfTransform right $ \jright -> do + jlTy <- getJvalueType jleft + jrTy <- getJvalueType jright + (jl, boxlConv) <- rebindBoxing jleft jlTy Unboxed + (jr, boxrConv) <- rebindBoxing jright jrTy Unboxed + let boxConv = boxlConv . boxrConv name <- freshVar - makeJexpr ty . J.JApp name (J.CBinOp op) [jleft, jright] <$> + boxConv . makeJexpr ty . J.JApp name (J.CBinOp op) [jl, jr] <$> local (insertVar name ty) (cont (J.JVar name)) A.If guard trueBr falseBr -> anfTransform guard $ \jguard -> do lbl <- freshLabel - trueBr' <- anfTransform trueBr (pure . makeJexpr (astType trueBr) . J.JVal) - falseBr' <- anfTransform falseBr (pure . makeJexpr (astType falseBr) . J.JVal) + let transformBr br = anfTransform br $ \result -> do + rTy <- getJvalueType result + (result', boxConv) <- rebindBoxing result rTy (boxedVal ty) + pure . boxConv . makeJexpr (astType br) . J.JVal $ result' + trueBr' <- transformBr trueBr + falseBr' <- transformBr falseBr name <- freshVar let jmp = J.JJump lbl let guardTy = getType (unfix guard) @@ -241,7 +248,9 @@ anfTransform (Fix ae) cont = let ty = getType ae in case annGetExpr ae of A.Cast expr numTy -> anfTransform expr $ \jexpr -> do varName <- freshVar - makeJexpr ty . J.JApp varName (J.CCast numTy) [jexpr] <$> + valTy <- getJvalueType jexpr + (jexpr', boxConv) <- rebindBoxing jexpr valTy Unboxed + boxConv . makeJexpr ty . J.JApp varName (J.CCast numTy) [jexpr'] <$> local (insertVar varName ty) (cont (J.JVar varName)) A.Print expr -> anfTransform expr $ \jval -> do diff --git a/src/Simpl/Backend/Codegen.hs b/src/Simpl/Backend/Codegen.hs index aa02605..42983a7 100644 --- a/src/Simpl/Backend/Codegen.hs +++ b/src/Simpl/Backend/Codegen.hs @@ -406,6 +406,8 @@ callableCodegen callable args = case callable of ty <- lookupValueType jval val <- jvalueCodegen jval tag <- lookupTypeTag ty + -- TODO: Box primitive types like int, bool, etc.; should require a malloc + -- and store bytes <- LLVMIR.bitcast val (LLVM.ptr LLVM.void) LLVMIR.call RT.taggedBoxRef [(tag, []), (bytes, [])] _ -> error $ "callableCodegen: expected 1 args to CTag, got " ++ show (length args) diff --git a/test-suite/polymorphism.spl b/test-suite/polymorphism.spl index 1484872..55f9db3 100644 --- a/test-suite/polymorphism.spl +++ b/test-suite/polymorphism.spl @@ -2,5 +2,8 @@ fun id (x: a) : a := { x } fun main : Int := { let _ = println(@id("hi")) in + let y = 5 in + let _ = println(if true then @id("bye") else "sigh") in +# let _ = println(if @id(y) <= 0 then @id("bye") else "sigh") in 0 -} \ No newline at end of file +} From 6525b25ae29946493d01845a0df11ea42ebd9222 Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Tue, 17 Dec 2019 15:24:52 -0800 Subject: [PATCH 07/23] Fix redundant type tags getting created --- src/Simpl/Backend/Codegen.hs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/Simpl/Backend/Codegen.hs b/src/Simpl/Backend/Codegen.hs index 42983a7..2429d31 100644 --- a/src/Simpl/Backend/Codegen.hs +++ b/src/Simpl/Backend/Codegen.hs @@ -137,12 +137,14 @@ lookupTypeTag ty = TyBool -> "Bool" _ -> error "TODO" let llvmTy = typeToLLVM (Fix ty) - let nullptr = LLVMC.Null (LLVM.ptr RT.typeTagType) let size = LLVMC.sizeof llvmTy let tagContents = LLVMC.Struct { LLVMC.structName = Nothing , LLVMC.isPacked = False , LLVMC.memberValues = [size] } - LLVMIR.global (LLVM.mkName $ "simpl.tag." ++ name) RT.typeTagType tagContents + let lname = LLVM.mkName $ "simpl.tag." ++ Text.unpack name + oper <- LLVMIR.global lname RT.typeTagType tagContents + modify (\t -> t { tableTypeTags = Map.insert ty oper (tableTypeTags t) }) + pure oper bindVariable :: MonadState CodegenTable m => Text From 26df57f3c5475cd16fea9ee5ffac84fadb75cab3 Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Tue, 17 Dec 2019 15:29:21 -0800 Subject: [PATCH 08/23] Box arguments to if and case --- src/Simpl/AstToJoinIR.hs | 50 ++++++++++++++++++++++++++----------- test-suite/polymorphism.spl | 14 +++++++++++ 2 files changed, 49 insertions(+), 15 deletions(-) diff --git a/src/Simpl/AstToJoinIR.hs b/src/Simpl/AstToJoinIR.hs index 26aa9cc..4d206da 100644 --- a/src/Simpl/AstToJoinIR.hs +++ b/src/Simpl/AstToJoinIR.hs @@ -150,6 +150,18 @@ rebindBoxing val ty b = do (Left ty', Boxed) -> create (Fix (TyBox ty')) J.CTag _ -> pure (val, id) +-- | If needed, rebinds the given value to match the boxing target. +withRebindBoxing :: (HasType flds, MonadReader (TransformCtx flds) m, MonadFreshVar m) + => BoxedVal -- ^Boxing target + -> (J.JValue -> J.JExprF (J.AnnExpr '[ 'ExprType])) + -> J.JValue -- ^Expression value + -> m (J.AnnExpr '[ 'ExprType]) +withRebindBoxing boxVal f val = do + ty <- getJvalueType val + (val', boxConv) <- rebindBoxing val ty boxVal + ty' <- getJvalueType val' + pure . boxConv . makeJexpr ty' . f $ val' + -- * ANF Transformation -- | Perform ANF transformation on the given symbol table @@ -157,21 +169,26 @@ transformTable :: (HasType flds, MonadReader (TransformCtx flds) m, MonadFreshVa => m (SymbolTable (J.AnnExpr '[ 'ExprType])) transformTable = do table <- asks tcSymTab - symTabTraverseExprs (\(tvars, args, ty, expr) -> (tvars, args, ty, transformExpr expr)) table + flip symTabTraverseExprs table $ \(tvars, args, ty, expr) -> + -- Initialize boxing status first, then transform + let initVars tab = foldl (flip (uncurry insertVar)) tab args + in (tvars, args, ty, local initVars (transformExpr expr (boxedVal ty))) -- | Perform ANF transformation on the given expression transformExpr :: (HasType flds, MonadReader (TransformCtx flds) m, MonadFreshVar m) => A.AnnExpr flds + -> BoxedVal -- ^ Expected boxing of the final value -> m (J.AnnExpr '[ 'ExprType]) -transformExpr expr = anfTransform expr (pure . makeJexpr (astType expr) . J.JVal) +transformExpr expr boxVal = anfTransform expr $ withRebindBoxing boxVal J.JVal -- | Perform ANF transformation on the branch, afterwards handling control flow. transformBranch :: (HasType flds, MonadReader (TransformCtx flds) m, MonadFreshVar m) => J.ControlFlow (J.AnnExpr '[ 'ExprType]) -- ^ Control flow handler + -> BoxedVal -- ^ Whether branch is expected to be boxed -> A.Branch (A.AnnExpr flds) -- ^ Branches -> m (J.JBranch (J.AnnExpr '[ 'ExprType])) -transformBranch cf (A.BrAdt adtName argNames expr) = do - jexpr <- anfTransform expr (pure . makeJexpr (astType expr) . J.JVal) +transformBranch cf boxVal (A.BrAdt adtName argNames expr) = do + jexpr <- anfTransform expr $ withRebindBoxing boxVal J.JVal pure $ J.BrAdt adtName argNames (J.Cfe jexpr cf) @@ -202,27 +219,28 @@ anfTransform (Fix ae) cont = let ty = getType ae in case annGetExpr ae of local (insertVar name ty) (cont (J.JVar name)) A.If guard trueBr falseBr -> anfTransform guard $ \jguard -> do - lbl <- freshLabel - let transformBr br = anfTransform br $ \result -> do - rTy <- getJvalueType result - (result', boxConv) <- rebindBoxing result rTy (boxedVal ty) - pure . boxConv . makeJexpr (astType br) . J.JVal $ result' + -- Guard must be unboxed to compare for truthiness + guardTy <- getJvalueType jguard + (jguard', boxConv) <- rebindBoxing jguard guardTy Unboxed + let guardCfe = boxConv (makeJexpr guardTy (J.JVal jguard')) + -- Handle branches + let transformBr br = anfTransform br $ withRebindBoxing (boxedVal ty) J.JVal trueBr' <- transformBr trueBr falseBr' <- transformBr falseBr - name <- freshVar + lbl <- freshLabel let jmp = J.JJump lbl - let guardTy = getType (unfix guard) - let guardCfe = makeJexpr guardTy (J.JVal jguard) let cfe = J.Cfe guardCfe (J.JIf (J.Cfe trueBr' jmp) (J.Cfe falseBr' jmp)) -- TODO: Make JJoin node placement more efficient + name <- freshVar makeJexpr ty . J.JJoin lbl name cfe <$> local (insertVar name ty) (cont (J.JVar name)) A.Case branches expr -> anfTransform expr $ \jexpr -> do + -- Case value must be unboxed + jexpr' <- withRebindBoxing Unboxed J.JVal jexpr lbl <- freshLabel - let jexprTy = astType expr - jbranches <- traverse (transformBranch (J.JJump lbl)) branches - let jexprCfe = J.Cfe (makeJexpr jexprTy (J.JVal jexpr)) (J.JCase jbranches) + jbranches <- traverse (transformBranch (J.JJump lbl) (boxedVal ty)) branches + let jexprCfe = J.Cfe jexpr' (J.JCase jbranches) name <- freshVar -- TODO: Make JJoin node placement more efficient makeJexpr ty . J.JJoin lbl name jexprCfe <$> @@ -250,6 +268,7 @@ anfTransform (Fix ae) cont = let ty = getType ae in case annGetExpr ae of varName <- freshVar valTy <- getJvalueType jexpr (jexpr', boxConv) <- rebindBoxing jexpr valTy Unboxed + -- Resulting value is unboxed, so use original type boxConv . makeJexpr ty . J.JApp varName (J.CCast numTy) [jexpr'] <$> local (insertVar varName ty) (cont (J.JVar varName)) A.Print expr -> @@ -257,6 +276,7 @@ anfTransform (Fix ae) cont = let ty = getType ae in case annGetExpr ae of varName <- freshVar valTy <- getJvalueType jval (jval', boxConv) <- rebindBoxing jval valTy Unboxed + -- Resulting value is unboxed, so use original type boxConv . makeJexpr ty . J.JApp varName J.CPrint [jval'] <$> local (insertVar varName ty) (cont (J.JVar varName)) A.FunRef name -> do diff --git a/test-suite/polymorphism.spl b/test-suite/polymorphism.spl index 55f9db3..fa46428 100644 --- a/test-suite/polymorphism.spl +++ b/test-suite/polymorphism.spl @@ -1,5 +1,19 @@ fun id (x: a) : a := { x } +data Foo = { Bar Int | Nope } + +fun test_data (f : Foo) : Foo := { + case @id(f) of + Bar y => f + Nope => f +} + +fun test_case (x : Int) : String := { + @id(case Bar x of + Bar y => @id("hi") + Nope => "bye") +} + fun main : Int := { let _ = println(@id("hi")) in let y = 5 in From 64e7f59df67b63044f0948228c9953886314698a Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Tue, 17 Dec 2019 16:19:11 -0800 Subject: [PATCH 09/23] Handle (un)boxing correctly for unboxed types during codegen --- src/Simpl/Backend/Codegen.hs | 32 ++++++++++++++++++++++++++------ src/Simpl/Backend/Runtime.hs | 8 ++++---- 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/src/Simpl/Backend/Codegen.hs b/src/Simpl/Backend/Codegen.hs index 2429d31..8d360ef 100644 --- a/src/Simpl/Backend/Codegen.hs +++ b/src/Simpl/Backend/Codegen.hs @@ -135,7 +135,8 @@ lookupTypeTag ty = TyNumber _ -> "Int" -- TODO: fix this TyString -> "String" TyBool -> "Bool" - _ -> error "TODO" + TyAdt n _ -> "data." <> n + x -> error ("TODO: handle tag type lookup for " ++ show x) let llvmTy = typeToLLVM (Fix ty) let size = LLVMC.sizeof llvmTy let tagContents = LLVMC.Struct { LLVMC.structName = Nothing @@ -146,6 +147,14 @@ lookupTypeTag ty = modify (\t -> t { tableTypeTags = Map.insert ty oper (tableTypeTags t) }) pure oper +-- | Whether a type is represented using a pointer +typeRepIsPtr :: TypeF Type -> Bool +typeRepIsPtr = \case + TyNumber _ -> False + TyBool -> False + TyAdt _ _ -> False + _ -> True + bindVariable :: MonadState CodegenTable m => Text -> TypeF Type @@ -406,11 +415,17 @@ callableCodegen callable args = case callable of CTag -> case args of [jval] -> do ty <- lookupValueType jval - val <- jvalueCodegen jval tag <- lookupTypeTag ty - -- TODO: Box primitive types like int, bool, etc.; should require a malloc - -- and store - bytes <- LLVMIR.bitcast val (LLVM.ptr LLVM.void) + let llvmTy = typeToLLVM (Fix ty) + let size = LLVM.ConstantOperand (LLVMC.ZExt (LLVMC.sizeof llvmTy) LLVM.i64) + -- We only need to allocate if the value is not a pointer + val <- jvalueCodegen jval + bytes <- if typeRepIsPtr ty + then LLVMIR.bitcast val (LLVM.ptr LLVM.i8) + else do b <- LLVMIR.call RT.mallocRef [(size, [])] + allocRef <- LLVMIR.bitcast b (LLVM.ptr llvmTy) + LLVMIR.store allocRef 0 val + pure b LLVMIR.call RT.taggedBoxRef [(tag, []), (bytes, [])] _ -> error $ "callableCodegen: expected 1 args to CTag, got " ++ show (length args) CUntag -> case args of @@ -420,7 +435,12 @@ callableCodegen callable args = case callable of _ -> error "callableCodegen: untagging non-boxed type" val <- jvalueCodegen jval bytesPtr <- LLVMIR.call RT.taggedUnboxRef [(val, [])] - LLVMIR.bitcast bytesPtr (typeToLLVM ty) + let llvmTy = typeToLLVM ty + -- Need to unbox if the stored type is not a pointer + if typeRepIsPtr (unfix ty) + then LLVMIR.bitcast bytesPtr llvmTy + else do ptr <- LLVMIR.bitcast bytesPtr (LLVM.ptr llvmTy) + LLVMIR.load ptr 0 _ -> error $ "callableCodegen: expected 1 args to CTag, got " ++ show (length args) -- | Generates code for a [JExpr] diff --git a/src/Simpl/Backend/Runtime.hs b/src/Simpl/Backend/Runtime.hs index 343dbc3..06915fa 100644 --- a/src/Simpl/Backend/Runtime.hs +++ b/src/Simpl/Backend/Runtime.hs @@ -64,8 +64,8 @@ mallocType = mkFunType [("ptr", LLVM.i64)] (LLVM.ptr LLVM.i8) memcpyType = mkFunType [ ("dest", LLVM.ptr LLVM.i8) , ("src", LLVM.ptr LLVM.i8) , ("len", LLVM.i64) ] - LLVM.void -printfType = ([("", LLVM.ptr LLVM.i8)], LLVM.void, True) + LLVM.i8 +printfType = ([("", LLVM.ptr LLVM.i8)], LLVM.i8, True) mallocRef, memcpyRef, printfRef :: LLVM.Operand mallocRef = runtimeFunRef "simpl_malloc" mallocType @@ -111,8 +111,8 @@ taggedValueType = runtimeStruct "simpl_tagged_value" tagSizeType, taggedTagType, taggedBoxType, taggedUnboxType :: FunType tagSizeType = mkFunType [("t", LLVM.ptr typeTagType)] LLVM.i64 taggedTagType = mkFunType [("t", LLVM.ptr taggedValueType)] (LLVM.ptr taggedValueType) -taggedBoxType = mkFunType [("t", LLVM.ptr typeTagType), ("d", LLVM.ptr LLVM.void)] (LLVM.ptr taggedValueType) -taggedUnboxType = mkFunType [("t", LLVM.ptr taggedValueType)] (LLVM.ptr LLVM.void) +taggedBoxType = mkFunType [("t", LLVM.ptr typeTagType), ("d", LLVM.ptr LLVM.i8)] (LLVM.ptr taggedValueType) +taggedUnboxType = mkFunType [("t", LLVM.ptr taggedValueType)] (LLVM.ptr LLVM.i8) tagSizeRef, taggedTagRef, taggedBoxRef, taggedUnboxRef :: LLVM.Operand tagSizeRef = runtimeFunRef "simpl_tag_size" tagSizeType From 8f79e0fda337f859f7646038e879c542a391d8be Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Wed, 18 Dec 2019 21:14:27 -0800 Subject: [PATCH 10/23] Instantiate type variables for case/ctor in type checker --- src/Simpl/AstToJoinIR.hs | 18 ++++++++++---- src/Simpl/Typecheck.hs | 52 ++++++++++++++++++++++++++++------------ test-suite/poly-data.spl | 8 +++++++ 3 files changed, 59 insertions(+), 19 deletions(-) create mode 100644 test-suite/poly-data.spl diff --git a/src/Simpl/AstToJoinIR.hs b/src/Simpl/AstToJoinIR.hs index 4d206da..a9fc1fa 100644 --- a/src/Simpl/AstToJoinIR.hs +++ b/src/Simpl/AstToJoinIR.hs @@ -20,6 +20,8 @@ module Simpl.AstToJoinIR import Control.Monad.Reader hiding (guard) import Data.Functor.Foldable (Fix(..), unfix) import Data.Functor.Identity +import Data.Foldable (fold) +import Data.Monoid (Endo(..), appEndo) import Data.Text (Text) import Data.String (fromString) import Data.Maybe (fromJust) @@ -188,7 +190,9 @@ transformBranch :: (HasType flds, MonadReader (TransformCtx flds) m, MonadFreshV -> A.Branch (A.AnnExpr flds) -- ^ Branches -> m (J.JBranch (J.AnnExpr '[ 'ExprType])) transformBranch cf boxVal (A.BrAdt adtName argNames expr) = do - jexpr <- anfTransform expr $ withRebindBoxing boxVal J.JVal + (_, A.Ctor _ argTys, _) <- asks (fromJust . symTabLookupCtor adtName . tcSymTab) + let withScope ctx = foldr (uncurry insertVar) ctx (argNames `zip` argTys) + jexpr <- local withScope $ anfTransform expr $ withRebindBoxing boxVal J.JVal pure $ J.BrAdt adtName argNames (J.Cfe jexpr cf) @@ -247,9 +251,15 @@ anfTransform (Fix ae) cont = let ty = getType ae in case annGetExpr ae of local (insertVar name ty) (cont (J.JVar name)) A.Cons ctorName args -> collectArgs args $ \argVals -> do - -- TODO: find and box polymorphic args + (_, A.Ctor _ ctorTyArgs, _) <- asks (fromJust . symTabLookupCtor ctorName . tcSymTab) + -- Box each argument as needed + pairs <- forM (argVals `zip` ctorTyArgs) $ \(jv, cty) -> do + jty <- getJvalueType jv + rebindBoxing jv jty (boxedVal cty) + let (argVals', boxConvs) = unzip pairs + let boxConv = appEndo (fold (fmap Endo boxConvs)) varName <- freshVar - makeJexpr ty . J.JApp varName (J.CCtor ctorName) argVals <$> + boxConv . makeJexpr ty . J.JApp varName (J.CCtor ctorName) argVals' <$> local (insertVar varName ty) (cont (J.JVar varName)) A.App funcName args -> collectArgs args $ \argVals -> do @@ -259,7 +269,7 @@ anfTransform (Fix ae) cont = let ty = getType ae in case annGetExpr ae of tuples <- sequence [rebindBoxing val aTy (boxedVal faTy) | (((_, faTy), val), aTy) <- funcArgs `zip` argVals `zip` argTys] let (argVals', boxConvArgs_) = unzip tuples - let boxConvArgs = foldl (.) id boxConvArgs_ + let boxConvArgs = appEndo (fold (fmap Endo boxConvArgs_)) let ty' = if isBoxed funcRetTy then Fix (TyBox ty) else ty boxConvArgs . makeJexpr ty' . J.JApp varName (J.CFunc funcName) argVals' <$> local (insertVar varName ty') (cont (J.JVar varName)) diff --git a/src/Simpl/Typecheck.hs b/src/Simpl/Typecheck.hs index 63b2faf..64e6144 100644 --- a/src/Simpl/Typecheck.hs +++ b/src/Simpl/Typecheck.hs @@ -16,7 +16,7 @@ import Control.Monad.Except (ExceptT, MonadError, lift, runExceptT, throwError) import Control.Unification import Control.Unification.IntVar import Data.Maybe (fromMaybe) -import Data.Foldable (traverse_, asum) +import Data.Foldable (asum) import Data.Functor.Identity import Data.Functor.Foldable (Fix(..), unfix, cata) import Data.Text (Text) @@ -112,13 +112,20 @@ inferType = cata $ \ae -> case annGetExpr ae of let argTys = extractTy <$> args ctorRes <- asks (symTabLookupCtor name) case ctorRes of - Just (adtTy, Ctor _ ctorArgTys, _) -> do - let conArgs = typeToUtype <$> ctorArgTys - let numConArgs = length conArgs + Just (Fix (TyAdt adtName tparamTys), Ctor _ ctorArgTys, _) -> do + let numConArgs = length ctorArgTys when (numConArgs /= length argTys) $ throwError $ TyErrArgCount numConArgs (length argTys) ctorArgTys - traverse_ (uncurry unifyTy) (zip conArgs argTys) - pure $ annotate (Cons name args) (annGetAnn ae) (typeToUtype adtTy) + -- Instantiate type variables + let tparams = flip fmap tparamTys $ \t -> case unfix t of + TyVar n -> n + _ -> error "Symbol table ADT types should only contain variables" + substMap <- instantiateVars (Set.fromList tparams) + let conArgs = substituteUVars substMap . typeToUtype <$> ctorArgTys + argTys' <- traverse (uncurry unifyTy) (zip conArgs argTys) + let newTy = UTerm (TyAdt adtName argTys') + pure $ annotate (Cons name args) (annGetAnn ae) newTy + Just ty -> error $ "Symbol table contained ADT with invalid type: " ++ show ty Nothing -> throwError $ TyErrNoSuchCtor name Case branchMs valM -> do val <- valM @@ -126,22 +133,36 @@ inferType = cata $ \ae -> case annGetExpr ae of branches <- forM branchMs $ \case BrAdt ctorName bindings exprM -> asks (symTabLookupCtor ctorName) >>= \case - Just (dataTy, Ctor _ ctorArgs, _) -> do + Just (dataTy@(Fix (TyAdt _ tparamTys)), Ctor _ ctorArgs, _) -> do when (length bindings /= length ctorArgs) $ throwError $ TyErrArgCount (length ctorArgs) (length bindings) ctorArgs - _ <- unifyTy valTy (typeToUtype dataTy) - let updatedBinds = Map.fromList (bindings `zip` ctorArgs) + -- Instantiate type variables + let tparams = flip fmap tparamTys $ \t -> case unfix t of + TyVar n -> n + _ -> error "Symbol table ADT types should only contain variables" + substMap <- instantiateVars (Set.fromList tparams) + _ <- unifyTy valTy (substituteUVars substMap (typeToUtype dataTy)) + let substCtorArgs = substituteUVars substMap . typeToUtype <$> ctorArgs + -- Same hack as in let binding + instCtorArgs <- forM substCtorArgs $ \t -> do + t' <- utypeToType <$> forceBindings t + case t' of + Just t'' -> pure t'' + Nothing -> throwError $ TyErrAmbiguousType t + let updatedBinds = Map.fromList (bindings `zip` instCtorArgs) -- Infer result type with ctor args bound expr <- local (\t -> t { symTabVars = Map.union (symTabVars t) updatedBinds }) exprM pure $ BrAdt ctorName bindings expr + Just ty -> error $ "Symbol table contained ADT with invalid type: " ++ show ty Nothing -> throwError $ TyErrNoSuchCtor ctorName let brTys = extractTy . branchGetExpr <$> branches resTy <- mkMetaVar annotate (Case branches val) (annGetAnn ae) <$> foldM unifyTy resTy brTys Let name valM nextM -> do val <- valM + valTy <- forceBindings (extractTy val) -- TODO: Hack, fix this - case utypeToType (extractTy val) of + case utypeToType valTy of Just ty -> do next <- local (symTabInsertVar name ty) nextM pure $ annotate (Let name val next) (annGetAnn ae) (extractTy next) @@ -174,11 +195,12 @@ inferType = cata $ \ae -> case annGetExpr ae of pure $ annotate (App name params') (annGetAnn ae) resTy FunRef name -> asks (symTabLookupStaticFun name) >>= \case - Just (tvars, params, ty, _) -> - -- TODO: How to handle polymorphic function? - let paramTys = snd <$> params - funTy = typeToUtype (Fix $ TyFun paramTys ty) in - pure $ annotate (FunRef name) (annGetAnn ae) funTy + Just (tvars, params, ty, _) -> do + substMap <- instantiateVars tvars + let paramTys = substituteUVars substMap . typeToUtype . snd <$> params + let resTy = substituteUVars substMap (typeToUtype ty) + let funTy = substituteUVars substMap (UTerm (TyFun paramTys resTy)) + pure $ annotate (FunRef name) (annGetAnn ae) funTy Nothing -> throwError $ TyErrNoSuchVar name Cast exprM num -> do expr <- exprM diff --git a/test-suite/poly-data.spl b/test-suite/poly-data.spl new file mode 100644 index 0000000..8e346c3 --- /dev/null +++ b/test-suite/poly-data.spl @@ -0,0 +1,8 @@ +data Maybe a = { Just a | Nothing } + +fun main : Int := { + let foo = Just 5 in # 0 + case foo of + Just x => let _ = println("Just OK") in x + Nothing => let _ = println("ERROR: Got Nothing") in 0 +} From 57294bc9651a96a3f5e80a338b1c90f14a51bc3d Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Wed, 18 Dec 2019 23:28:41 -0800 Subject: [PATCH 11/23] Rewrite boxing logic to use continuation passing style Possibly fixes scoping issues? --- src/Simpl/AstToJoinIR.hs | 159 ++++++++++++++++++++------------------- 1 file changed, 82 insertions(+), 77 deletions(-) diff --git a/src/Simpl/AstToJoinIR.hs b/src/Simpl/AstToJoinIR.hs index a9fc1fa..355bcc2 100644 --- a/src/Simpl/AstToJoinIR.hs +++ b/src/Simpl/AstToJoinIR.hs @@ -20,8 +20,6 @@ module Simpl.AstToJoinIR import Control.Monad.Reader hiding (guard) import Data.Functor.Foldable (Fix(..), unfix) import Data.Functor.Identity -import Data.Foldable (fold) -import Data.Monoid (Endo(..), appEndo) import Data.Text (Text) import Data.String (fromString) import Data.Maybe (fromJust) @@ -140,17 +138,18 @@ boxedVal t = if isBoxed t then Boxed else Unboxed rebindBoxing :: (HasType flds, MonadFreshVar m, MonadReader (TransformCtx flds) m) => J.JValue -- ^ Variable name - -> Type -- ^ Variable type -> BoxedVal -- ^ Whether to ensure boxed or unboxed - -> m (J.JValue, J.AnnExpr '[ 'ExprType] -> J.AnnExpr '[ 'ExprType]) -rebindBoxing val ty b = do + -> (J.JValue -> m (J.AnnExpr '[ 'ExprType])) -- ^ Continuation + -> m (J.AnnExpr '[ 'ExprType]) +rebindBoxing val b cont = do + ty <- getJvalueType val let create ty' action = do name <- case val of { J.JVar n -> pure n; _ -> freshVar } - local (insertVar name ty') (pure (J.JVar name, makeJexpr ty' . J.JApp name action [val])) + makeJexpr ty' . J.JApp name action [val] <$> local (insertVar name ty')(cont (J.JVar name)) case (boxedType ty, b) of (Right ty', Unboxed) -> create ty' J.CUntag (Left ty', Boxed) -> create (Fix (TyBox ty')) J.CTag - _ -> pure (val, id) + _ -> cont val -- | If needed, rebinds the given value to match the boxing target. withRebindBoxing :: (HasType flds, MonadReader (TransformCtx flds) m, MonadFreshVar m) @@ -158,11 +157,10 @@ withRebindBoxing :: (HasType flds, MonadReader (TransformCtx flds) m, MonadFresh -> (J.JValue -> J.JExprF (J.AnnExpr '[ 'ExprType])) -> J.JValue -- ^Expression value -> m (J.AnnExpr '[ 'ExprType]) -withRebindBoxing boxVal f val = do - ty <- getJvalueType val - (val', boxConv) <- rebindBoxing val ty boxVal - ty' <- getJvalueType val' - pure . boxConv . makeJexpr ty' . f $ val' +withRebindBoxing boxVal f val = + rebindBoxing val boxVal $ \val' -> do + ty <- getJvalueType val' + pure . makeJexpr ty . f $ val' -- * ANF Transformation @@ -212,95 +210,102 @@ anfTransform (Fix ae) cont = let ty = getType ae in case annGetExpr ae of local (insertVar name ty) (anfTransform next cont) A.BinOp op left right -> anfTransform left $ \jleft -> - anfTransform right $ \jright -> do - jlTy <- getJvalueType jleft - jrTy <- getJvalueType jright - (jl, boxlConv) <- rebindBoxing jleft jlTy Unboxed - (jr, boxrConv) <- rebindBoxing jright jrTy Unboxed - let boxConv = boxlConv . boxrConv - name <- freshVar - boxConv . makeJexpr ty . J.JApp name (J.CBinOp op) [jl, jr] <$> - local (insertVar name ty) (cont (J.JVar name)) + anfTransform right $ \jright -> + rebindBoxing jleft Unboxed $ \jl -> + rebindBoxing jright Unboxed $ \jr -> do + name <- freshVar + makeJexpr ty . J.JApp name (J.CBinOp op) [jl, jr] <$> + local (insertVar name ty) (cont (J.JVar name)) A.If guard trueBr falseBr -> - anfTransform guard $ \jguard -> do + anfTransform guard $ \jguard -> -- Guard must be unboxed to compare for truthiness - guardTy <- getJvalueType jguard - (jguard', boxConv) <- rebindBoxing jguard guardTy Unboxed - let guardCfe = boxConv (makeJexpr guardTy (J.JVal jguard')) - -- Handle branches - let transformBr br = anfTransform br $ withRebindBoxing (boxedVal ty) J.JVal - trueBr' <- transformBr trueBr - falseBr' <- transformBr falseBr - lbl <- freshLabel - let jmp = J.JJump lbl - let cfe = J.Cfe guardCfe (J.JIf (J.Cfe trueBr' jmp) (J.Cfe falseBr' jmp)) - -- TODO: Make JJoin node placement more efficient - name <- freshVar - makeJexpr ty . J.JJoin lbl name cfe <$> - local (insertVar name ty) (cont (J.JVar name)) + rebindBoxing jguard Unboxed $ \jguard' -> do + guardTy <- getJvalueType jguard' + let guardCfe = makeJexpr guardTy (J.JVal jguard') + -- Handle branches + let transformBr br = anfTransform br $ withRebindBoxing (boxedVal ty) J.JVal + trueBr' <- transformBr trueBr + falseBr' <- transformBr falseBr + lbl <- freshLabel + let jmp = J.JJump lbl + let cfe = J.Cfe guardCfe (J.JIf (J.Cfe trueBr' jmp) (J.Cfe falseBr' jmp)) + -- TODO: Make JJoin node placement more efficient + name <- freshVar + makeJexpr ty . J.JJoin lbl name cfe <$> + local (insertVar name ty) (cont (J.JVar name)) A.Case branches expr -> - anfTransform expr $ \jexpr -> do + anfTransform expr $ \jexpr -> -- Case value must be unboxed - jexpr' <- withRebindBoxing Unboxed J.JVal jexpr - lbl <- freshLabel - jbranches <- traverse (transformBranch (J.JJump lbl) (boxedVal ty)) branches - let jexprCfe = J.Cfe jexpr' (J.JCase jbranches) - name <- freshVar - -- TODO: Make JJoin node placement more efficient - makeJexpr ty . J.JJoin lbl name jexprCfe <$> - local (insertVar name ty) (cont (J.JVar name)) + rebindBoxing jexpr Unboxed $ \jexpr' -> do + jTy <- getJvalueType jexpr' + -- Transform branches + lbl <- freshLabel + jbranches <- traverse (transformBranch (J.JJump lbl) (boxedVal ty)) branches + let jexprCfe = J.Cfe (makeJexpr jTy (J.JVal jexpr')) (J.JCase jbranches) + name <- freshVar + -- TODO: Make JJoin node placement more efficient + makeJexpr ty . J.JJoin lbl name jexprCfe <$> + local (insertVar name ty) (cont (J.JVar name)) A.Cons ctorName args -> collectArgs args $ \argVals -> do (_, A.Ctor _ ctorTyArgs, _) <- asks (fromJust . symTabLookupCtor ctorName . tcSymTab) -- Box each argument as needed - pairs <- forM (argVals `zip` ctorTyArgs) $ \(jv, cty) -> do - jty <- getJvalueType jv - rebindBoxing jv jty (boxedVal cty) - let (argVals', boxConvs) = unzip pairs - let boxConv = appEndo (fold (fmap Endo boxConvs)) - varName <- freshVar - boxConv . makeJexpr ty . J.JApp varName (J.CCtor ctorName) argVals' <$> - local (insertVar varName ty) (cont (J.JVar varName)) + collectRebinds (argVals `zip` (boxedVal <$> ctorTyArgs)) $ \argVals' -> do + varName <- freshVar + makeJexpr ty . J.JApp varName (J.CCtor ctorName) argVals' <$> + local (insertVar varName ty) (cont (J.JVar varName)) A.App funcName args -> collectArgs args $ \argVals -> do varName <- freshVar (_, funcArgs, funcRetTy, _) <- asks (fromJust . symTabLookupStaticFun funcName . tcSymTab) - argTys <- traverse getJvalueType argVals - tuples <- sequence [rebindBoxing val aTy (boxedVal faTy) - | (((_, faTy), val), aTy) <- funcArgs `zip` argVals `zip` argTys] - let (argVals', boxConvArgs_) = unzip tuples - let boxConvArgs = appEndo (fold (fmap Endo boxConvArgs_)) - let ty' = if isBoxed funcRetTy then Fix (TyBox ty) else ty - boxConvArgs . makeJexpr ty' . J.JApp varName (J.CFunc funcName) argVals' <$> - local (insertVar varName ty') (cont (J.JVar varName)) + let valueBoxPairs = [(val, boxedVal fTy) | (val, (_, fTy)) <- argVals `zip` funcArgs] + collectRebinds valueBoxPairs $ \argVals' -> do + let ty' = if isBoxed funcRetTy then Fix (TyBox ty) else ty + makeJexpr ty' . J.JApp varName (J.CFunc funcName) argVals' <$> + local (insertVar varName ty') (cont (J.JVar varName)) A.Cast expr numTy -> - anfTransform expr $ \jexpr -> do - varName <- freshVar - valTy <- getJvalueType jexpr - (jexpr', boxConv) <- rebindBoxing jexpr valTy Unboxed + anfTransform expr $ \jexpr -> -- Resulting value is unboxed, so use original type - boxConv . makeJexpr ty . J.JApp varName (J.CCast numTy) [jexpr'] <$> - local (insertVar varName ty) (cont (J.JVar varName)) + rebindBoxing jexpr Unboxed $ \jexpr' -> do + varName <- freshVar + makeJexpr ty . J.JApp varName (J.CCast numTy) [jexpr'] <$> + local (insertVar varName ty) (cont (J.JVar varName)) A.Print expr -> - anfTransform expr $ \jval -> do - varName <- freshVar - valTy <- getJvalueType jval - (jval', boxConv) <- rebindBoxing jval valTy Unboxed + anfTransform expr $ \jval -> -- Resulting value is unboxed, so use original type - boxConv . makeJexpr ty . J.JApp varName J.CPrint [jval'] <$> - local (insertVar varName ty) (cont (J.JVar varName)) + rebindBoxing jval Unboxed $ \jval' -> do + varName <- freshVar + makeJexpr ty . J.JApp varName J.CPrint [jval'] <$> + local (insertVar varName ty) (cont (J.JVar varName)) A.FunRef name -> do varName <- freshVar makeJexpr ty . J.JApp varName (J.CFunRef name) [] <$> local (insertVar varName ty) (cont (J.JVar varName)) +-- | Utility function for collecting arguments to CPS style functions +collectConts :: (a -> (b -> m c) -> m c) + -> [a] + -> ([b] -> m c) + -> m c +collectConts f as cont = go [] as + where + go vals [] = cont (reverse vals) + go vals (x:xs) = f x $ \v -> go (v:vals) xs + -- | Normalize each expression in sequential order, and then run the -- continuation with the expression values. collectArgs :: (HasType flds, MonadReader (TransformCtx flds) m, MonadFreshVar m) => [A.AnnExpr flds] -- ^ Argument expressions -> ([J.JValue] -> m (J.AnnExpr '[ 'ExprType])) -- ^ Continuation -> m (J.AnnExpr '[ 'ExprType]) -collectArgs = go [] - where - go vals [] mcont = mcont (reverse vals) - go vals (x:xs) mcont = anfTransform x $ \v -> go (v:vals) xs mcont +collectArgs = collectConts anfTransform +-- collectArgs = go [] +-- where +-- go vals [] mcont = mcont (reverse vals) +-- go vals (x:xs) mcont = anfTransform x $ \v -> go (v:vals) xs mcont + +collectRebinds :: (HasType flds, MonadReader (TransformCtx flds) m, MonadFreshVar m) + => [(J.JValue, BoxedVal)] -- ^ Value-boxing pairs + -> ([J.JValue] -> m (J.AnnExpr '[ 'ExprType])) -- ^ Continuation + -> m (J.AnnExpr '[ 'ExprType]) +collectRebinds = collectConts (uncurry rebindBoxing) From 0ed77b2fccf6ffc99462fed469854d590309587a Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Thu, 19 Dec 2019 00:12:07 -0800 Subject: [PATCH 12/23] Fixing scoping errors in typecheck and codegen --- src/Simpl/AstToJoinIR.hs | 5 +++-- src/Simpl/Backend/Codegen.hs | 20 +++++++++++--------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/src/Simpl/AstToJoinIR.hs b/src/Simpl/AstToJoinIR.hs index 355bcc2..c455283 100644 --- a/src/Simpl/AstToJoinIR.hs +++ b/src/Simpl/AstToJoinIR.hs @@ -205,8 +205,9 @@ anfTransform (Fix ae) cont = let ty = getType ae in case annGetExpr ae of A.Lit lit -> cont (J.JLit lit) A.Var name -> cont (J.JVar name) A.Let name bindExpr next -> - anfTransform bindExpr $ \bindVal -> - makeJexpr ty . J.JLet name bindVal <$> + anfTransform bindExpr $ \bindVal -> do + bindTy <- getJvalueType bindVal + makeJexpr bindTy . J.JLet name bindVal <$> local (insertVar name ty) (anfTransform next cont) A.BinOp op left right -> anfTransform left $ \jleft -> diff --git a/src/Simpl/Backend/Codegen.hs b/src/Simpl/Backend/Codegen.hs index 8d360ef..d84bfb9 100644 --- a/src/Simpl/Backend/Codegen.hs +++ b/src/Simpl/Backend/Codegen.hs @@ -308,7 +308,9 @@ controlFlowCodegen val valOper = \case let labelName = "case_" <> fromString (Text.unpack name) in (name, ) <$> LLVMIR.freshName labelName -- Assume the symbol table and type information is correct - dataName <- (\case { TyAdt n _ -> n; _ -> error "" }) <$> lookupValueType val + dataName <- flip fmap (lookupValueType val) $ \case + TyAdt n _ -> n + t -> error $ "controlFlowCodegen: Unexpected case value type: " ++ show t ++ " " ++ show val ctors <- gets ((\(_,_,cs) -> cs) . fromJust . Map.lookup dataName . tableAdts) let ctorNames = ctorGetName <$> ctors let usedLabelTriples = filter (\(_, (n, _)) -> n `elem` ctorNames) $ [0..] `zip` allCaseLabels @@ -321,23 +323,23 @@ controlFlowCodegen val valOper = \case let cf = branchGetControlFlow br (_, ctorLLVMName, index) <- gets (fromJust . Map.lookup ctorName . tableCtors) let Ctor _ argTys = ctors !! index - let bindingPairs = branchGetBindings br `zip` (typeToLLVM <$> argTys) + let bindingPairs = branchGetBindings br `zip` argTys LLVMIR.emitBlockStart label ctorPtr <- LLVMIR.bitcast dataPtr (LLVM.ptr (LLVM.NamedTypeReference ctorLLVMName)) let ctorPtrOffset = LLVMIR.int32 0 - bindings <- forM ([0..] `zip` bindingPairs) $ \(i, (n, llvmTy)) -> do + bindings <- forM ([0..] `zip` bindingPairs) $ \(i, (n, ty)) -> do let ctorPtrIndex = LLVMIR.int32 i -- Need to bitcast the ptr type because we need a concrete type. We -- also need to load the data immediately because of how variables -- are implemented. v <- LLVMIR.gep ctorPtr [ctorPtrOffset, ctorPtrIndex] - >>= flip LLVMIR.bitcast (LLVM.ptr llvmTy) + >>= flip LLVMIR.bitcast (LLVM.ptr (typeToLLVM ty)) >>= flip LLVMIR.load 0 - ty <- lookupValueType (JVar n) - pure (n, (ty, v)) + pure (n, (unfix ty, v)) let updateTable t = t { tableVars = Map.union (tableVars t) (Map.fromList bindings) } - (exprVal, exprOper) <- jexprCodegen expr - localCodegenTable updateTable (controlFlowCodegen exprVal exprOper cf) + localCodegenTable updateTable $ do + (exprVal, exprOper) <- jexprCodegen expr + controlFlowCodegen exprVal exprOper cf LLVMIR.emitBlockStart defLabel LLVMIR.unreachable JJump lbl -> do @@ -432,7 +434,7 @@ callableCodegen callable args = case callable of [jval] -> do ty <- lookupValueType jval >>= \case TyBox t -> pure t - _ -> error "callableCodegen: untagging non-boxed type" + t -> error $ "callableCodegen: untagging non-boxed type " ++ show t val <- jvalueCodegen jval bytesPtr <- LLVMIR.call RT.taggedUnboxRef [(val, [])] let llvmTy = typeToLLVM ty From 829da43eb28cf22e718e050ffb8de1b113ebde11 Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Thu, 19 Dec 2019 13:10:56 -0800 Subject: [PATCH 13/23] Fix JoinIR+codegen for case on polymorphic ADTs; fix scoping issue --- src/Simpl/AstToJoinIR.hs | 24 ++++++++++++++++++------ src/Simpl/Backend/Codegen.hs | 19 +++++-------------- src/Simpl/JoinIR/Syntax.hs | 14 ++++++++------ src/Simpl/JoinIR/Verify.hs | 2 +- src/Simpl/SymbolTable.hs | 13 ++++++------- src/Simpl/Type.hs | 16 ++++++++++++++++ src/Simpl/Typecheck.hs | 19 ++++++------------- 7 files changed, 60 insertions(+), 47 deletions(-) diff --git a/src/Simpl/AstToJoinIR.hs b/src/Simpl/AstToJoinIR.hs index c455283..f78df9d 100644 --- a/src/Simpl/AstToJoinIR.hs +++ b/src/Simpl/AstToJoinIR.hs @@ -32,7 +32,7 @@ import Simpl.Annotation import Simpl.SymbolTable import qualified Simpl.Ast as A import qualified Simpl.JoinIR.Syntax as J -import Simpl.Type (Type, TypeF(TyBox, TyVar)) +import Simpl.Type (Type, TypeF(TyBox, TyVar, TyAdt), substituteTypeVars, typeRepIsPtr) import Simpl.Typecheck (literalType) import Simpl.Util.Supply import qualified Simpl.Util.Stream as Stream @@ -184,14 +184,26 @@ transformExpr expr boxVal = anfTransform expr $ withRebindBoxing boxVal J.JVal -- | Perform ANF transformation on the branch, afterwards handling control flow. transformBranch :: (HasType flds, MonadReader (TransformCtx flds) m, MonadFreshVar m) => J.ControlFlow (J.AnnExpr '[ 'ExprType]) -- ^ Control flow handler + -> Type -- ^ Type of guard expression -> BoxedVal -- ^ Whether branch is expected to be boxed -> A.Branch (A.AnnExpr flds) -- ^ Branches -> m (J.JBranch (J.AnnExpr '[ 'ExprType])) -transformBranch cf boxVal (A.BrAdt adtName argNames expr) = do - (_, A.Ctor _ argTys, _) <- asks (fromJust . symTabLookupCtor adtName . tcSymTab) +transformBranch cf guardTy boxVal (A.BrAdt ctorName argNames expr) = do + (_, A.Ctor _ argTys, _) <- asks (fromJust . symTabLookupCtor ctorName . tcSymTab) let withScope ctx = foldr (uncurry insertVar) ctx (argNames `zip` argTys) jexpr <- local withScope $ anfTransform expr $ withRebindBoxing boxVal J.JVal - pure $ J.BrAdt adtName argNames (J.Cfe jexpr cf) + argTys' <- case unfix guardTy of + TyAdt name tvars -> do + (tvars', _) <- asks (fromJust . symTabLookupAdt name . tcSymTab) + let tvarMapping = Map.fromList (tvars' `zip` tvars) + pure $ flip fmap argTys $ \t -> case unfix t of + TyVar _ -> + let t' = substituteTypeVars tvarMapping t in + if t' == t || typeRepIsPtr (unfix t') + then t' else Fix (TyBox t') + _ -> substituteTypeVars tvarMapping t + _ -> pure argTys + pure $ J.BrAdt ctorName (argNames `zip` argTys') (J.Cfe jexpr cf) -- | Main ANF transformation logic. Given the SimPL AST, this function will @@ -208,7 +220,7 @@ anfTransform (Fix ae) cont = let ty = getType ae in case annGetExpr ae of anfTransform bindExpr $ \bindVal -> do bindTy <- getJvalueType bindVal makeJexpr bindTy . J.JLet name bindVal <$> - local (insertVar name ty) (anfTransform next cont) + local (insertVar name bindTy) (anfTransform next cont) A.BinOp op left right -> anfTransform left $ \jleft -> anfTransform right $ \jright -> @@ -241,7 +253,7 @@ anfTransform (Fix ae) cont = let ty = getType ae in case annGetExpr ae of jTy <- getJvalueType jexpr' -- Transform branches lbl <- freshLabel - jbranches <- traverse (transformBranch (J.JJump lbl) (boxedVal ty)) branches + jbranches <- traverse (transformBranch (J.JJump lbl) jTy (boxedVal ty)) branches let jexprCfe = J.Cfe (makeJexpr jTy (J.JVal jexpr')) (J.JCase jbranches) name <- freshVar -- TODO: Make JJoin node placement more efficient diff --git a/src/Simpl/Backend/Codegen.hs b/src/Simpl/Backend/Codegen.hs index d84bfb9..2afed36 100644 --- a/src/Simpl/Backend/Codegen.hs +++ b/src/Simpl/Backend/Codegen.hs @@ -48,7 +48,7 @@ import Simpl.Annotation hiding (AnnExpr, AnnExprF) import Simpl.Ast (BinaryOp(..), Constructor(..), Literal(..)) import Simpl.CompilerOptions import Simpl.SymbolTable -import Simpl.Type (Type(..), TypeF(..), Numeric(..)) +import Simpl.Type (Type, TypeF(..), Numeric(..), typeRepIsPtr) import Simpl.Typecheck (literalType) import Simpl.Backend.Runtime () import Simpl.JoinIR.Syntax @@ -57,7 +57,7 @@ import qualified Simpl.Backend.Runtime as RT data CodegenTable = MkCodegenTable { tableVars :: Map Text (TypeF Type, LLVM.Operand) -- ^ Pointer to variables , tableCtors :: Map Text (LLVM.Name, LLVM.Name, Int) -- ^ Data type name, ctor name, index - , tableAdts :: Map Text (LLVM.Name, Type, [Constructor]) + , tableAdts :: Map Text (LLVM.Name, [Text], [Constructor]) , tableFuns :: Map Text LLVM.Operand , tableJoinValues :: Map Text (LLVM.Name, [(LLVM.Operand, LLVM.Name)]) , tablePrintf :: LLVM.Operand @@ -147,14 +147,6 @@ lookupTypeTag ty = modify (\t -> t { tableTypeTags = Map.insert ty oper (tableTypeTags t) }) pure oper --- | Whether a type is represented using a pointer -typeRepIsPtr :: TypeF Type -> Bool -typeRepIsPtr = \case - TyNumber _ -> False - TyBool -> False - TyAdt _ _ -> False - _ -> True - bindVariable :: MonadState CodegenTable m => Text -> TypeF Type @@ -321,9 +313,8 @@ controlFlowCodegen val valOper = \case forM_ (usedLabelTriples `zip` branches) $ \((_, (ctorName, label)), br) -> do let expr = branchGetExpr br let cf = branchGetControlFlow br - (_, ctorLLVMName, index) <- gets (fromJust . Map.lookup ctorName . tableCtors) - let Ctor _ argTys = ctors !! index - let bindingPairs = branchGetBindings br `zip` argTys + (_, ctorLLVMName, _) <- gets (fromJust . Map.lookup ctorName . tableCtors) + let bindingPairs = branchGetBindings br LLVMIR.emitBlockStart label ctorPtr <- LLVMIR.bitcast dataPtr (LLVM.ptr (LLVM.NamedTypeReference ctorLLVMName)) let ctorPtrOffset = LLVMIR.int32 0 @@ -584,7 +575,7 @@ moduleCodegen srcCode symTab = mdo -- definition doesn't matter. This works because the codegen monad is lazy. modify (\t -> t { tableFuns = tableFuns t `Map.union` Map.fromList funOpers }) -- TODO: Care about type variables - funOpers <- forM (Map.toList . symTabFuns $ symTab) $ \(name, (tvars, params, ty, body)) -> + funOpers <- forM (Map.toList . symTabFuns $ symTab) $ \(name, (_, params, ty, body)) -> (name, ) <$> funToLLVM name params ty body _ <- LLVMIR.function "main" [] LLVM.i64 $ \_ -> do diff --git a/src/Simpl/JoinIR/Syntax.hs b/src/Simpl/JoinIR/Syntax.hs index 3a6a93b..e02dddd 100644 --- a/src/Simpl/JoinIR/Syntax.hs +++ b/src/Simpl/JoinIR/Syntax.hs @@ -23,7 +23,7 @@ import Data.Text.Prettyprint.Doc (Pretty, pretty, (<>), (<+>)) import qualified Data.Text.Prettyprint.Doc as PP import qualified Simpl.Annotation as Ann import Simpl.Ast (BinaryOp(..), Literal(..)) -import Simpl.Type (Numeric(..)) +import Simpl.Type (Numeric(..), Type) import Text.Show.Deriving (deriveShow1) import Data.Functor.Foldable @@ -66,14 +66,14 @@ data ControlFlow a data JBranch a - = BrAdt Name [Name] !(Cfe a) -- ^ Destructure algebraic data type + = BrAdt Name [(Name, Type)] !(Cfe a) -- ^ Destructure algebraic data type deriving (Functor, Foldable, Traversable, Show) branchGetExpr :: JBranch a -> a branchGetExpr = \case BrAdt _ _ (Cfe e _) -> e -branchGetBindings :: JBranch a -> [Text] +branchGetBindings :: JBranch a -> [(Text, Type)] branchGetBindings = \case BrAdt _ vars _ -> vars @@ -135,9 +135,11 @@ instance Pretty JValue where JLit l -> pretty l instance Pretty a => Pretty (JBranch a) where - pretty (BrAdt ctorName varNames cfe) = - PP.hang 2 $ PP.hsep (pretty <$> brPart) <> PP.softline <> pretty cfe - where brPart = [ctorName] ++ varNames ++ ["=>"] + pretty (BrAdt ctorName varPairs cfe) = + PP.hang 2 $ PP.hsep brPart <> PP.softline <> pretty cfe + where + vars = (\(n, t) -> PP.parens (pretty n <+> ":" <+> pretty t)) <$> varPairs + brPart = [pretty ctorName] ++ vars ++ [pretty ("=>" :: Text)] instance Pretty a => Pretty (ControlFlow a) where pretty = \case diff --git a/src/Simpl/JoinIR/Verify.hs b/src/Simpl/JoinIR/Verify.hs index 6b85d11..152654c 100644 --- a/src/Simpl/JoinIR/Verify.hs +++ b/src/Simpl/JoinIR/Verify.hs @@ -142,5 +142,5 @@ doVerifyBranch :: (MonadError VerifyError m, MonadReader VerifyCtx m) -> m () doVerifyBranch (BrAdt name args cfe) = do checkUnboundVar name - _ <- traverse checkUnboundVar args + _ <- traverse checkUnboundVar (fst <$> args) doVerifyCfe cfe diff --git a/src/Simpl/SymbolTable.hs b/src/Simpl/SymbolTable.hs index 3bbed45..b4ef431 100644 --- a/src/Simpl/SymbolTable.hs +++ b/src/Simpl/SymbolTable.hs @@ -5,7 +5,6 @@ {-# LANGUAGE DeriveTraversable #-} module Simpl.SymbolTable where -import Data.Functor.Foldable (Fix(Fix)) import Data.List (find) import Data.Maybe (mapMaybe, listToMaybe) import Data.Map.Strict (Map) @@ -17,7 +16,7 @@ import Simpl.Ast import Simpl.Type data SymbolTable expr = MkSymbolTable - { symTabAdts :: Map Text (Type, [Constructor]) + { symTabAdts :: Map Text ([Text], [Constructor]) -- ^ ADT definitions: name, type vars, constructors , symTabFuns :: Map Text (Set Text, [(Text, Type)], Type, expr) -- ^ Static functions: free type vars, params, return type, body , symTabVars :: Map Text Type -- ^ Variables , symTabExtern :: Map Text ([(Text, Type)], Type) -- ^ External functions @@ -29,7 +28,7 @@ buildSymbolTable (SourceFile _ decls) = let adts = Map.fromList $ mapMaybe (\case DeclAdt name tparams ctors -> - Just (name, (Fix (TyAdt name (Fix . TyVar <$> tparams)), ctors)) + Just (name, (tparams, ctors)) _ -> Nothing) decls funs = Map.fromList $ mapMaybe (\case @@ -47,7 +46,7 @@ buildSymbolTable (SourceFile _ decls) = , symTabVars = Map.empty , symTabExtern = extern } -symTabModifyAdts :: (Map Text (Type, [Constructor]) -> Map Text (Type, [Constructor])) +symTabModifyAdts :: (Map Text ([Text], [Constructor]) -> Map Text ([Text], [Constructor])) -> SymbolTable e -> SymbolTable e symTabModifyAdts f t = t { symTabAdts = f (symTabAdts t) } @@ -68,13 +67,13 @@ symTabTraverseExprs f t = do -- | Searches for the given constructor, returning the name of the ADT, the -- constructor, and the index of the constructor. -symTabLookupCtor :: Text -> SymbolTable e -> Maybe (Type, Constructor, Int) +symTabLookupCtor :: Text -> SymbolTable e -> Maybe ((Text, [Text]), Constructor, Int) symTabLookupCtor name t = listToMaybe . mapMaybe (find isTheCtor) $ getCtors where - getCtors = (\(ty, cs) -> zip3 (repeat ty) cs [0..]) <$> Map.elems (symTabAdts t) + getCtors = (\(tname, (tvars, cs)) -> zip3 (repeat (tname, tvars)) cs [0..]) <$> Map.toList (symTabAdts t) isTheCtor (_, Ctor ctorName _, _) = ctorName == name -symTabLookupAdt :: Text -> SymbolTable e -> Maybe (Type, [Constructor]) +symTabLookupAdt :: Text -> SymbolTable e -> Maybe ([Text], [Constructor]) symTabLookupAdt name = Map.lookup name . symTabAdts symTabLookupVar :: Text -> SymbolTable e -> Maybe Type diff --git a/src/Simpl/Type.hs b/src/Simpl/Type.hs index d4c76d1..534fab1 100644 --- a/src/Simpl/Type.hs +++ b/src/Simpl/Type.hs @@ -16,6 +16,8 @@ import Data.Text (Text) import Data.Text.Prettyprint.Doc import Data.Eq.Deriving (deriveEq1) import Data.Ord.Deriving (deriveOrd1) +import Data.Map (Map) +import qualified Data.Map as Map import Data.Set (Set) import qualified Data.Set as Set import Text.Show.Deriving (deriveShow1) @@ -78,6 +80,14 @@ isComplexType = \case TyFun _ _ -> True _ -> False +-- | Whether a type is represented using a pointer +typeRepIsPtr :: TypeF Type -> Bool +typeRepIsPtr = \case + TyNumber _ -> False + TyBool -> False + TyAdt _ _ -> False + _ -> True + functionTypeResult :: Type -> Type functionTypeResult (Fix ty) = case ty of TyFun _ res -> functionTypeResult res @@ -93,6 +103,12 @@ getTypeVars = cata $ \case TyAdt _ vargs -> Set.unions vargs TyBox vs -> vs +substituteTypeVars :: Map Text Type -> Type -> Type +substituteTypeVars vars = cata go + where + go = \case + ty@(TyVar n) -> Map.findWithDefault (Fix ty) n vars + ty -> Fix ty instance Pretty Type where pretty = para go diff --git a/src/Simpl/Typecheck.hs b/src/Simpl/Typecheck.hs index 64e6144..c0ca90c 100644 --- a/src/Simpl/Typecheck.hs +++ b/src/Simpl/Typecheck.hs @@ -112,20 +112,16 @@ inferType = cata $ \ae -> case annGetExpr ae of let argTys = extractTy <$> args ctorRes <- asks (symTabLookupCtor name) case ctorRes of - Just (Fix (TyAdt adtName tparamTys), Ctor _ ctorArgTys, _) -> do + Just ((adtName, tvars), Ctor _ ctorArgTys, _) -> do let numConArgs = length ctorArgTys when (numConArgs /= length argTys) $ throwError $ TyErrArgCount numConArgs (length argTys) ctorArgTys -- Instantiate type variables - let tparams = flip fmap tparamTys $ \t -> case unfix t of - TyVar n -> n - _ -> error "Symbol table ADT types should only contain variables" - substMap <- instantiateVars (Set.fromList tparams) + substMap <- instantiateVars (Set.fromList tvars) let conArgs = substituteUVars substMap . typeToUtype <$> ctorArgTys argTys' <- traverse (uncurry unifyTy) (zip conArgs argTys) let newTy = UTerm (TyAdt adtName argTys') pure $ annotate (Cons name args) (annGetAnn ae) newTy - Just ty -> error $ "Symbol table contained ADT with invalid type: " ++ show ty Nothing -> throwError $ TyErrNoSuchCtor name Case branchMs valM -> do val <- valM @@ -133,17 +129,15 @@ inferType = cata $ \ae -> case annGetExpr ae of branches <- forM branchMs $ \case BrAdt ctorName bindings exprM -> asks (symTabLookupCtor ctorName) >>= \case - Just (dataTy@(Fix (TyAdt _ tparamTys)), Ctor _ ctorArgs, _) -> do + Just ((adtName, tvars), Ctor _ ctorArgs, _) -> do when (length bindings /= length ctorArgs) $ throwError $ TyErrArgCount (length ctorArgs) (length bindings) ctorArgs -- Instantiate type variables - let tparams = flip fmap tparamTys $ \t -> case unfix t of - TyVar n -> n - _ -> error "Symbol table ADT types should only contain variables" - substMap <- instantiateVars (Set.fromList tparams) + substMap <- instantiateVars (Set.fromList tvars) + let dataTy = Fix (TyAdt adtName (Fix . TyVar <$> tvars)) _ <- unifyTy valTy (substituteUVars substMap (typeToUtype dataTy)) let substCtorArgs = substituteUVars substMap . typeToUtype <$> ctorArgs - -- Same hack as in let binding + -- TODO: Same hack as in let binding instCtorArgs <- forM substCtorArgs $ \t -> do t' <- utypeToType <$> forceBindings t case t' of @@ -153,7 +147,6 @@ inferType = cata $ \ae -> case annGetExpr ae of -- Infer result type with ctor args bound expr <- local (\t -> t { symTabVars = Map.union (symTabVars t) updatedBinds }) exprM pure $ BrAdt ctorName bindings expr - Just ty -> error $ "Symbol table contained ADT with invalid type: " ++ show ty Nothing -> throwError $ TyErrNoSuchCtor ctorName let brTys = extractTy . branchGetExpr <$> branches resTy <- mkMetaVar From 22a8be64d69528c48f482e55c9bb907d8032c05a Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Thu, 19 Dec 2019 13:14:40 -0800 Subject: [PATCH 14/23] Update parser to handle polymorphic ADTs --- src/Simpl/Parser.hs | 4 ++-- test-suite/poly-data.spl | 31 ++++++++++++++++++++++++++++--- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/src/Simpl/Parser.hs b/src/Simpl/Parser.hs index 534800d..ebc1511 100644 --- a/src/Simpl/Parser.hs +++ b/src/Simpl/Parser.hs @@ -184,7 +184,7 @@ typeVar = Fix . TyVar <$> identifier typeAdt :: Parser m Type typeAdt = do name <- typeIdentifier - tparams <- many typeVar + tparams <- many (parens type' <|> type') pure $ Fix (TyAdt name tparams) typeAtom :: Parser m Type @@ -228,7 +228,7 @@ declFun = lexeme $ constructor :: Parser m Ast.Constructor constructor = lexeme $ do name <- typeIdentifier - args <- many type' + args <- many (parens type' <|> type') pure $ Ast.Ctor name args declAdt :: Parser m (Decl SourcedExpr) diff --git a/test-suite/poly-data.spl b/test-suite/poly-data.spl index 8e346c3..8ea2676 100644 --- a/test-suite/poly-data.spl +++ b/test-suite/poly-data.spl @@ -1,8 +1,33 @@ data Maybe a = { Just a | Nothing } -fun main : Int := { +fun testJust : Int := { let foo = Just 5 in # 0 - case foo of - Just x => let _ = println("Just OK") in x + let asdf = case foo of + Just x => let _ = println("Got Just!") in x Nothing => let _ = println("ERROR: Got Nothing") in 0 + in asdf +} + +data List a = { Nil | Cons a (List a) } + +fun head (xs : List a) : Maybe a := { + case xs of + Nil => Nothing + Cons h t => Just h +} + +fun numberList (n : Int) : List Int := { + if n <= 0 then Nil else Cons n @numberList(n - 1) +} + +fun printList (xs : List a) : Int := { + case xs of + Nil => 0 + Cons h t => + let _ = println("Item") in + @printList(t) +} + +fun main : Int := { + @printList(@numberList(10)) } From 62484bf72dfadbbf0a90d1be731e2594e8a9ff9d Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Thu, 19 Dec 2019 13:15:07 -0800 Subject: [PATCH 15/23] Fix printing of ADT type arguments --- src/Simpl/Type.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Simpl/Type.hs b/src/Simpl/Type.hs index 534fab1..98b0a22 100644 --- a/src/Simpl/Type.hs +++ b/src/Simpl/Type.hs @@ -119,7 +119,7 @@ instance Pretty Type where go (TyNumber n) = pretty n go TyBool = "Bool" go TyString = "String" - go (TyAdt n tparams) = pretty n <> hsep (snd <$> tparams) + go (TyAdt n tparams) = hsep (pretty n : (snd <$> tparams)) go (TyFun args res) = encloseSep mempty mempty " -> " (wrapComplex <$> args ++ [res]) go (TyVar n) = pretty n From 2be73027a4d882d3fad38271c3e452c144ad6e92 Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Fri, 20 Dec 2019 11:58:42 -0800 Subject: [PATCH 16/23] Fix type check regression: function arg count not checked properly --- src/Simpl/Typecheck.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Simpl/Typecheck.hs b/src/Simpl/Typecheck.hs index c0ca90c..e8b16c8 100644 --- a/src/Simpl/Typecheck.hs +++ b/src/Simpl/Typecheck.hs @@ -172,7 +172,7 @@ inferType = cata $ \ae -> case annGetExpr ae of (tvars, params, ty) <- lookupFun name (extractTy <$> argsTc) -- Check parameter count let numParams = length params - let paramCount = length params + let paramCount = length args when (numParams /= paramCount) $ throwError $ TyErrArgCount numParams paramCount params let unifyExprTy expr pTy = From e53d8e89605c3c8b3f1e440a16ce52ca7a6482dc Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Fri, 20 Dec 2019 11:59:34 -0800 Subject: [PATCH 17/23] Fix type variables not parsing correctly in function signatures --- src/Simpl/Parser.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Simpl/Parser.hs b/src/Simpl/Parser.hs index ebc1511..98541ff 100644 --- a/src/Simpl/Parser.hs +++ b/src/Simpl/Parser.hs @@ -188,7 +188,7 @@ typeAdt = do pure $ Fix (TyAdt name tparams) typeAtom :: Parser m Type -typeAtom = typeLit <|> typeAdt +typeAtom = typeLit <|> typeAdt <|> typeVar typeFun :: Parser m Type typeFun = lexeme $ do @@ -198,7 +198,7 @@ typeFun = lexeme $ do pure . Fix $ TyFun (first : init rest) (last rest) type' :: Parser m Type -type' = try typeFun <|> typeAtom <|> typeVar "type" +type' = try typeFun <|> typeAtom "type" declFunParamList :: Parser m [(Text, Type)] declFunParamList = lexeme $ option [] (parens params) From 84bc5b08fa8017ca0b8b0f737bd14e943c126bdb Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Fri, 20 Dec 2019 12:00:27 -0800 Subject: [PATCH 18/23] Fix function ref lookups in AST to JoinIR transformation --- src/Simpl/AstToJoinIR.hs | 12 +++++++++--- src/Simpl/SymbolTable.hs | 6 ++++++ src/Simpl/Typecheck.hs | 10 +++------- 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/src/Simpl/AstToJoinIR.hs b/src/Simpl/AstToJoinIR.hs index f78df9d..7ae7da2 100644 --- a/src/Simpl/AstToJoinIR.hs +++ b/src/Simpl/AstToJoinIR.hs @@ -32,7 +32,7 @@ import Simpl.Annotation import Simpl.SymbolTable import qualified Simpl.Ast as A import qualified Simpl.JoinIR.Syntax as J -import Simpl.Type (Type, TypeF(TyBox, TyVar, TyAdt), substituteTypeVars, typeRepIsPtr) +import Simpl.Type (Type, TypeF(TyBox, TyVar, TyAdt, TyFun), substituteTypeVars, typeRepIsPtr) import Simpl.Typecheck (literalType) import Simpl.Util.Supply import qualified Simpl.Util.Stream as Stream @@ -270,8 +270,14 @@ anfTransform (Fix ae) cont = let ty = getType ae in case annGetExpr ae of A.App funcName args -> collectArgs args $ \argVals -> do varName <- freshVar - (_, funcArgs, funcRetTy, _) <- asks (fromJust . symTabLookupStaticFun funcName . tcSymTab) - let valueBoxPairs = [(val, boxedVal fTy) | (val, (_, fTy)) <- argVals `zip` funcArgs] + -- Function references are stored in variable scope, so we need to perform + -- lookup there. + funcCandidate <- asks (fmap unfix . symTabLookupVar funcName . tcSymTab) >>= \case + Just (TyFun params resTy) -> pure $ Just (params, resTy) + Just t -> error $ "Variable called as function: " ++ show funcName ++ " : " ++ show t + Nothing -> asks (fmap (\(_, p, r) -> (snd <$> p, r)) . symTabLookupFun funcName . tcSymTab) + let (funcArgs, funcRetTy) = fromJust funcCandidate + let valueBoxPairs = [(val, boxedVal fTy) | (val, fTy) <- argVals `zip` funcArgs] collectRebinds valueBoxPairs $ \argVals' -> do let ty' = if isBoxed funcRetTy then Fix (TyBox ty) else ty makeJexpr ty' . J.JApp varName (J.CFunc funcName) argVals' <$> diff --git a/src/Simpl/SymbolTable.hs b/src/Simpl/SymbolTable.hs index b4ef431..2fef48f 100644 --- a/src/Simpl/SymbolTable.hs +++ b/src/Simpl/SymbolTable.hs @@ -5,6 +5,7 @@ {-# LANGUAGE DeriveTraversable #-} module Simpl.SymbolTable where +import Control.Applicative ((<|>)) import Data.List (find) import Data.Maybe (mapMaybe, listToMaybe) import Data.Map.Strict (Map) @@ -90,3 +91,8 @@ symTabLookupStaticFun name = Map.lookup name . symTabFuns symTabLookupExternFun :: Text -> SymbolTable e -> Maybe ([(Text, Type)], Type) symTabLookupExternFun name = Map.lookup name . symTabExtern + +symTabLookupFun :: Text -> SymbolTable e -> Maybe (Set Text, [(Text, Type)], Type) +symTabLookupFun name tab = static tab <|> extern tab + where static = fmap (\(tvars, p, r, _) -> (tvars, p, r)) . symTabLookupStaticFun name + extern = fmap (\(p, r) -> (Set.empty, p, r)) . symTabLookupExternFun name diff --git a/src/Simpl/Typecheck.hs b/src/Simpl/Typecheck.hs index e8b16c8..2e26cc1 100644 --- a/src/Simpl/Typecheck.hs +++ b/src/Simpl/Typecheck.hs @@ -16,7 +16,6 @@ import Control.Monad.Except (ExceptT, MonadError, lift, runExceptT, throwError) import Control.Unification import Control.Unification.IntVar import Data.Maybe (fromMaybe) -import Data.Foldable (asum) import Data.Functor.Identity import Data.Functor.Foldable (Fix(..), unfix, cata) import Data.Text (Text) @@ -229,12 +228,13 @@ inferType = cata $ \ae -> case annGetExpr ae of rTy <- if unifyArgResult then unifyTy yTy resultTy else pure resultTy pure $ annotate (BinOp op x y) annFields rTy + -- | Looks up a function from either variable or function scope. Does not + -- instantiate type variables (should be handled on a per-construct basis). lookupFun :: Text -> [UType] -> Typecheck fields (Set Text, [Type], Type) lookupFun name argTys = asks (symTabLookupVar name) >>= \case Just ty -> case unfix ty of - -- TODO: Currently there is no way to know if the function is polymorphic TyFun params resTy -> pure (Set.empty, params, resTy) _ -> do resTy <- mkMetaVar @@ -242,11 +242,7 @@ inferType = cata $ \ae -> case annGetExpr ae of expected = TyFun argTys resTy throwError $ TyErrMismatch expected got Nothing -> - -- TODO: Instantiate type variables - let lookupStatic n = fmap (\(tvars, p, r, _) -> (tvars, p, r)) . symTabLookupStaticFun n - lookupExtern n = fmap (\(p, r) -> (Set.empty, p, r)) . symTabLookupExternFun n - result = traverse (\f -> asks (f name)) [lookupStatic, lookupExtern] in - asum <$> result >>= \case + asks (symTabLookupFun name) >>= \case Just (tvars, params, resTy) -> pure (tvars, snd <$> params, resTy) Nothing -> throwError (TyErrNoSuchVar name) From 3f8c1b6e8e2f97953a771d15f7410cefc2f39719 Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Fri, 20 Dec 2019 12:24:33 -0800 Subject: [PATCH 19/23] Fix AST to Join IR transformation of case branch bindings --- src/Simpl/AstToJoinIR.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Simpl/AstToJoinIR.hs b/src/Simpl/AstToJoinIR.hs index 7ae7da2..0cd35b0 100644 --- a/src/Simpl/AstToJoinIR.hs +++ b/src/Simpl/AstToJoinIR.hs @@ -190,8 +190,6 @@ transformBranch :: (HasType flds, MonadReader (TransformCtx flds) m, MonadFreshV -> m (J.JBranch (J.AnnExpr '[ 'ExprType])) transformBranch cf guardTy boxVal (A.BrAdt ctorName argNames expr) = do (_, A.Ctor _ argTys, _) <- asks (fromJust . symTabLookupCtor ctorName . tcSymTab) - let withScope ctx = foldr (uncurry insertVar) ctx (argNames `zip` argTys) - jexpr <- local withScope $ anfTransform expr $ withRebindBoxing boxVal J.JVal argTys' <- case unfix guardTy of TyAdt name tvars -> do (tvars', _) <- asks (fromJust . symTabLookupAdt name . tcSymTab) @@ -203,6 +201,8 @@ transformBranch cf guardTy boxVal (A.BrAdt ctorName argNames expr) = do then t' else Fix (TyBox t') _ -> substituteTypeVars tvarMapping t _ -> pure argTys + let withScope ctx = foldr (uncurry insertVar) ctx (argNames `zip` argTys') + jexpr <- local withScope $ anfTransform expr $ withRebindBoxing boxVal J.JVal pure $ J.BrAdt ctorName (argNames `zip` argTys') (J.Cfe jexpr cf) From 0c37cd07de238bf1a4381ae41ff9ae09b0f0f586 Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Fri, 20 Dec 2019 12:58:12 -0800 Subject: [PATCH 20/23] Fix constructor memory allocation `gep null 0` returns 0! Replace the dynamic size calculation with the a correct compile-time size calculation. --- src/Simpl/Backend/Codegen.hs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/Simpl/Backend/Codegen.hs b/src/Simpl/Backend/Codegen.hs index 2afed36..2b56c09 100644 --- a/src/Simpl/Backend/Codegen.hs +++ b/src/Simpl/Backend/Codegen.hs @@ -372,13 +372,10 @@ callableCodegen callable args = case callable of -- Tag (index = 0) tagStruct2 <- LLVMIR.insertValue tagStruct1 (LLVMIR.int32 (fromIntegral ctorIndex)) [0] -- Data pointer (index = 1) - let ctorTy = LLVM.NamedTypeReference ctorName - let nullptr = LLVM.ConstantOperand (LLVMC.Null (LLVM.ptr ctorTy)) - -- Use offsets to calculate struct size - ctorStructSize <- LLVMIR.gep nullptr [LLVMIR.int32 0] - >>= flip LLVMIR.ptrtoint LLVM.i64 -- Allocate memory for constructor. -- For now, use "leak memory" as an implementation strategy for deallocation. + let ctorTy = LLVM.NamedTypeReference ctorName + let ctorStructSize = LLVM.ConstantOperand (LLVMC.ZExt (LLVMC.sizeof ctorTy) LLVM.i64) ctorStructPtr <- LLVMIR.call RT.mallocRef [(ctorStructSize, [])] >>= flip LLVMIR.bitcast (LLVM.ptr ctorTy) values <- traverse jvalueCodegen args From 01a7da41a2f920dc9492862e2d846490f8df1a06 Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Fri, 20 Dec 2019 12:59:47 -0800 Subject: [PATCH 21/23] Update polymorphism test cases --- test-suite/poly-data.spl | 34 ++++++++++++++++++++++++++++++---- test-suite/polymorphism.spl | 4 ++++ 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/test-suite/poly-data.spl b/test-suite/poly-data.spl index 8ea2676..35bf365 100644 --- a/test-suite/poly-data.spl +++ b/test-suite/poly-data.spl @@ -16,18 +16,44 @@ fun head (xs : List a) : Maybe a := { Cons h t => Just h } +# FIXME: passing a function ref w/ unboxed argument to a function +# requiring a function ref w/ boxed argument does not work. +# Passed function ref needs to be automatically "promoted" to a +# boxed version. + +# fun filter (xs : List a, f : a -> Bool) : List a := { +# case xs of +# Nil => Nil +# Cons h t => if @f(h) then Cons h @filter(t, f) else @filter(t, f) +# } + +fun filter (xs : List Int, f : Int -> Bool) : List Int := { + case xs of + Nil => Nil + Cons h t => + let newT = @filter(t, f) in + if @f(h) then Cons h newT else newT +} + +fun lte5 (x: Int) : Bool := { + x <= 5 +} + fun numberList (n : Int) : List Int := { if n <= 0 then Nil else Cons n @numberList(n - 1) } -fun printList (xs : List a) : Int := { +fun printList (xs : List a, msg : String) : Int := { case xs of Nil => 0 Cons h t => - let _ = println("Item") in - @printList(t) + let _ = println(msg) in + @printList(t, msg) } fun main : Int := { - @printList(@numberList(10)) + let nums = @numberList(10) in + let x1 = @printList(nums, "numberList") in + let x2 = @printList(@filter(nums, <e5), "filter(numberList, <= 5)") in + x2 } diff --git a/test-suite/polymorphism.spl b/test-suite/polymorphism.spl index fa46428..85c24f7 100644 --- a/test-suite/polymorphism.spl +++ b/test-suite/polymorphism.spl @@ -14,6 +14,10 @@ fun test_case (x : Int) : String := { Nope => "bye") } +fun hi : String := { + let f = &test_case in @test_case(0) +} + fun main : Int := { let _ = println(@id("hi")) in let y = 5 in From 5b93c6ed557c99df012cd217369c18d72afe7d86 Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Sun, 29 Dec 2019 21:44:18 -0800 Subject: [PATCH 22/23] Fix boxing of doubles --- src/Simpl/Backend/Codegen.hs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/Simpl/Backend/Codegen.hs b/src/Simpl/Backend/Codegen.hs index 2b56c09..951dfb4 100644 --- a/src/Simpl/Backend/Codegen.hs +++ b/src/Simpl/Backend/Codegen.hs @@ -132,9 +132,10 @@ lookupTypeTag ty = Just oper -> pure oper Nothing -> do let name = case ty of - TyNumber _ -> "Int" -- TODO: fix this - TyString -> "String" + TyNumber NumInt -> "Int" + TyNumber NumDouble -> "Double" TyBool -> "Bool" + TyString -> "String" TyAdt n _ -> "data." <> n x -> error ("TODO: handle tag type lookup for " ++ show x) let llvmTy = typeToLLVM (Fix ty) @@ -571,8 +572,8 @@ moduleCodegen srcCode symTab = mdo -- Insert function operands into symbol table before emitting so order of -- definition doesn't matter. This works because the codegen monad is lazy. modify (\t -> t { tableFuns = tableFuns t `Map.union` Map.fromList funOpers }) - -- TODO: Care about type variables funOpers <- forM (Map.toList . symTabFuns $ symTab) $ \(name, (_, params, ty, body)) -> + -- Ignore type variables since they're handled with boxing code (name, ) <$> funToLLVM name params ty body _ <- LLVMIR.function "main" [] LLVM.i64 $ \_ -> do From 481587750bf4b7cead488338e025266ffe4a17ec Mon Sep 17 00:00:00 2001 From: Bryan Tan Date: Sun, 29 Dec 2019 21:44:35 -0800 Subject: [PATCH 23/23] Fix substitution of ADT type variables in type checker --- src/Simpl/Typecheck.hs | 14 ++++++++------ test-suite/poly-data.spl | 16 ++++++++++++++++ 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/src/Simpl/Typecheck.hs b/src/Simpl/Typecheck.hs index 2e26cc1..494ca79 100644 --- a/src/Simpl/Typecheck.hs +++ b/src/Simpl/Typecheck.hs @@ -15,7 +15,8 @@ import Control.Monad.Reader (ReaderT, MonadReader, runReaderT, asks, local) import Control.Monad.Except (ExceptT, MonadError, lift, runExceptT, throwError) import Control.Unification import Control.Unification.IntVar -import Data.Maybe (fromMaybe) +import Data.Maybe (fromMaybe, fromJust) +import Data.Foldable (traverse_) import Data.Functor.Identity import Data.Functor.Foldable (Fix(..), unfix, cata) import Data.Text (Text) @@ -118,8 +119,9 @@ inferType = cata $ \ae -> case annGetExpr ae of -- Instantiate type variables substMap <- instantiateVars (Set.fromList tvars) let conArgs = substituteUVars substMap . typeToUtype <$> ctorArgTys - argTys' <- traverse (uncurry unifyTy) (zip conArgs argTys) - let newTy = UTerm (TyAdt adtName argTys') + traverse_ (uncurry unifyTy) (zip conArgs argTys) + let newTvars = [fromJust $ Map.lookup v substMap | v <- tvars] + let newTy = UTerm (TyAdt adtName newTvars) pure $ annotate (Cons name args) (annGetAnn ae) newTy Nothing -> throwError $ TyErrNoSuchCtor name Case branchMs valM -> do @@ -174,14 +176,14 @@ inferType = cata $ \ae -> case annGetExpr ae of let paramCount = length args when (numParams /= paramCount) $ throwError $ TyErrArgCount numParams paramCount params - let unifyExprTy expr pTy = - annotate (annGetExpr (unfix expr)) (annGetAnn ae) <$> unifyTy (extractTy expr) pTy -- Instantiate type variables substMap <- instantiateVars tvars let instParams = substituteUVars substMap . typeToUtype <$> params let resTy = substituteUVars substMap (typeToUtype ty) -- Check parameter types + let unifyExprTy expr pTy = + annotate (annGetExpr (unfix expr)) (annGetAnn ae) <$> unifyTy (extractTy expr) pTy params' <- zipWithM unifyExprTy argsTc instParams -- Annotate with result type pure $ annotate (App name params') (annGetAnn ae) resTy @@ -274,7 +276,7 @@ typeToUtype = cata $ \case TyNumber n -> UTerm (TyNumber n) TyBool -> UTerm TyBool TyString -> UTerm TyString - TyAdt n tparams -> UTerm (TyAdt n tparams) -- TODO: Instantiate variables somewhere + TyAdt n tparams -> UTerm (TyAdt n tparams) TyFun args res -> UTerm (TyFun args res) TyVar n -> UTerm (TyVar n) TyBox _ -> error "TyBox should not be in SimPL AST" diff --git a/test-suite/poly-data.spl b/test-suite/poly-data.spl index 35bf365..4d4be16 100644 --- a/test-suite/poly-data.spl +++ b/test-suite/poly-data.spl @@ -57,3 +57,19 @@ fun main : Int := { let x2 = @printList(@filter(nums, <e5), "filter(numberList, <= 5)") in x2 } + +data AddTable a = { Add (a -> a -> a) } + +fun add (x : a, y : a, table : AddTable a) : a := { + case table of + Add f => @f(x, y) +} + +fun add_Int (x : Int, y : Int) : Int := { + x + y +} + +fun addTable_Int : AddTable Int := { + let f = &add_Int in + Add f +}