From 006302686c6277bd473b7e2e07ae68d90aced775 Mon Sep 17 00:00:00 2001 From: Ben Barsdell Date: Thu, 19 Oct 2023 11:57:54 +1100 Subject: [PATCH 01/47] Update jitify2 copyright years --- LICENSE | 2 +- example_headers/class_arg_kernel.cuh | 2 +- example_headers/constant_header.cuh | 2 +- example_headers/my_header1.cuh | 2 +- example_headers/my_header2.cuh | 2 +- example_headers/my_header3.cuh | 2 +- jitify2.hpp | 2 +- jitify2_preprocess.cpp | 2 +- jitify2_test.cu | 2 +- jitify2_test_kernels.cu | 2 +- 10 files changed, 10 insertions(+), 10 deletions(-) diff --git a/LICENSE b/LICENSE index b678a46..a4d873b 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ BSD 3-Clause License -Copyright (c) 2017-2024, NVIDIA Corporation +Copyright (c) 2017-2025, NVIDIA Corporation All rights reserved. Redistribution and use in source and binary forms, with or without diff --git a/example_headers/class_arg_kernel.cuh b/example_headers/class_arg_kernel.cuh index 19dd48a..b452ba3 100644 --- a/example_headers/class_arg_kernel.cuh +++ b/example_headers/class_arg_kernel.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2024, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2025, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions diff --git a/example_headers/constant_header.cuh b/example_headers/constant_header.cuh index f3f1cc9..0eaf9bf 100644 --- a/example_headers/constant_header.cuh +++ b/example_headers/constant_header.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2024, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2025, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions diff --git a/example_headers/my_header1.cuh b/example_headers/my_header1.cuh index 7f07df7..38027c9 100644 --- a/example_headers/my_header1.cuh +++ b/example_headers/my_header1.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2024, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2025, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions diff --git a/example_headers/my_header2.cuh b/example_headers/my_header2.cuh index f5a90c2..c776fae 100644 --- a/example_headers/my_header2.cuh +++ b/example_headers/my_header2.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2024, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2025, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions diff --git a/example_headers/my_header3.cuh b/example_headers/my_header3.cuh index 4933de5..e5f3cc7 100644 --- a/example_headers/my_header3.cuh +++ b/example_headers/my_header3.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2024, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2025, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions diff --git a/jitify2.hpp b/jitify2.hpp index d5d379d..25592c4 100644 --- a/jitify2.hpp +++ b/jitify2.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2024, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2025, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions diff --git a/jitify2_preprocess.cpp b/jitify2_preprocess.cpp index 575efe9..93ffdbe 100644 --- a/jitify2_preprocess.cpp +++ b/jitify2_preprocess.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2024, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2025, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions diff --git a/jitify2_test.cu b/jitify2_test.cu index f22c684..e53bc96 100644 --- a/jitify2_test.cu +++ b/jitify2_test.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2024, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2025, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions diff --git a/jitify2_test_kernels.cu b/jitify2_test_kernels.cu index 8dbcbab..d2681b3 100644 --- a/jitify2_test_kernels.cu +++ b/jitify2_test_kernels.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2024, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2025, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions From f5cd6e1aa6dc03377656a293b3edbf55ccaba9a7 Mon Sep 17 00:00:00 2001 From: Ben Barsdell Date: Mon, 6 Nov 2023 14:39:47 +1100 Subject: [PATCH 02/47] Overhaul parsing and preprocessing - Replaces C++ lexing/parsing/patching code with a proper lexer implementation, which significantly improves robustness and maintainability. - Replaces minification logic with robust token-based minification. - Replaces preprocessing logic with a new approach that uses custom parsing to find include directives. This only requires invoking NVRTC (and only its preprocessor) once per preprocess, which speeds up preprocessing by 50x in some cases. - Fixes include directory handling. Relative include paths are now handled robustly, and there is no longer any ambiguity between external and built-in headers. Note that relative paths (including in -I options) now start from the current executable directory instead of the current working directory. - These changes should be almost completely backwards compatible. --- jitify2.hpp | 2303 ++++++++++++++++++++++++++++++++++++----------- jitify2_test.cu | 549 ++++++++++- 2 files changed, 2296 insertions(+), 556 deletions(-) diff --git a/jitify2.hpp b/jitify2.hpp index 25592c4..be1cb51 100644 --- a/jitify2.hpp +++ b/jitify2.hpp @@ -110,6 +110,7 @@ #include #include #include +#include #include #include #include @@ -2294,6 +2295,15 @@ inline std::string path_base(const std::string& p) { } } +inline bool path_is_absolute(const std::string& p) { +#if defined _WIN32 || defined _WIN64 + return (p.size() >= 1 && (p[0] == '\\' || p[0] == '/')) || + (p.size() >= 3 && p[1] == ':' && (p[2] == '\\' || p[2] == '/')); +#else + return p.size() >= 1 && p[0] == '/'; +#endif +} + inline std::string path_join(StringRef p1, StringRef p2) { #if defined _WIN32 || defined _WIN64 // Note that Windows supports both forward and backslash path separators. @@ -2301,7 +2311,7 @@ inline std::string path_join(StringRef p1, StringRef p2) { #else const char* sep = "/"; #endif - if (p1.size() && p2.size() && std::strchr(sep, p2[0])) { + if (p1.size() && p2.size() && path_is_absolute(p2)) { return {}; // Error, cannot join to absolute path } std::string result; @@ -3366,6 +3376,15 @@ inline void add_default_device_flag_if_not_specified(OptionsVec* options) { } } +inline void add_no_source_include_flag_if_not_specified(OptionsVec* options) { + // This prevents NVRTC's preprocessor from automatically using the current + // working directory as an include path. We need to do this because we must + // find all includes ourselves so that we can patch them etc. + if (options->find({"--no-source-include", "-no-source-include"}).empty()) { + options->emplace_back("-no-source-include"); + } +} + // Demangles nested variable names using the PTX name mangling scheme // (which mostly follows the Itanium64 ABI). E.g., _ZN1a3Foo2bcE -> a::Foo::bc. inline std::string demangle_ptx_variable_name(const char* mangled_name) { @@ -3463,13 +3482,12 @@ inline void find_lowered_global_variables(StringRef ptx, inline bool ptx_remove_unused_globals(std::string* ptx); // Defined below -// Returns false on error. // Sets *error on failure if provided. // Sets *log if provided. // Sets *ptx on success if provided. // Adds one entry to *lowered_name_map for each entry in name_expressions as // well as any global definitions found in the generated PTX. -inline bool compile_program( +inline nvrtcResult compile_program( const std::string& name, const std::string& source, const StringMap& header_sources, const OptionsVec& options, std::string* error = nullptr, std::string* log = nullptr, @@ -3478,7 +3496,7 @@ inline bool compile_program( StringMap* lowered_name_map = nullptr, bool remove_unused_globals = false) { if (!nvrtc()) { if (error) *error = nvrtc().error(); - return false; + return NVRTC_ERROR_PROGRAM_CREATION_FAILURE; } std::vector header_names_c; @@ -3509,7 +3527,7 @@ inline bool compile_program( nvrtcResult jitify_nvrtc_ret = call; \ if (jitify_nvrtc_ret != NVRTC_SUCCESS) { \ if (error) *error = nvrtc().GetErrorString()(jitify_nvrtc_ret); \ - return false; \ + return jitify_nvrtc_ret; \ } \ } while (0) @@ -3592,7 +3610,7 @@ inline bool compile_program( } #undef JITIFY_CHECK_NVRTC - return true; + return NVRTC_SUCCESS; } inline StringVec split_string(std::string str, long maxsplit = -1, @@ -3743,10 +3761,10 @@ inline CompiledProgram CompiledProgram::compile( {"-remove-unused-globals", "--remove-unused-globals"}); std::string log, ptx, cubin, nvvm; StringMap lowered_name_map; - if (!detail::compile_program(name, source, header_sources, compiler_options, - &error, &log, &ptx, &cubin, &nvvm, - name_expressions, &lowered_name_map, - should_remove_unused_globals)) { + if (detail::compile_program(name, source, header_sources, compiler_options, + &error, &log, &ptx, &cubin, &nvvm, + name_expressions, &lowered_name_map, + should_remove_unused_globals)) { std::string options_str = detail::string_join( compiler_options, " ", "Compiler options: \"", "\"\n"); std::vector header_names; @@ -3986,7 +4004,20 @@ class PreprocessedProgramData } }; -using FileCallback = std::function; +namespace parser { + +class IncludeName; + +} // namespace parser + +using parser::IncludeName; // Pull into main namespace + +using HeaderCallback = + std::function; + +// TODO: Mark with deprecated attribute. +// Deprecated, use HeaderCallback instead. +using FileCallback = HeaderCallback; class PreprocessedProgram : public detail::FallibleObjectBase\n #include INC", which Thrust does in some headers. - size_t beg = find_source_line(source, line_num); - if (beg == std::string::npos) { - if (error) *error = "EOF reached before source line was found"; - return false; - } - // TODO: This is not robust to inline comments, strings etc. - beg = source.find("include", beg); - if (beg == std::string::npos) { - if (error) *error = "Line does not contain 'include'"; - return false; - } - beg += 7; - beg = source.find_first_of("\"<", beg); - if (beg == std::string::npos) { - if (error) *error = "Did not find expected '\"' or '<' character"; - return false; - } - return source[beg] == '"'; -} - // Elides "/." and "/.." tokens from path. Returns empty string if illformed. inline std::string path_simplify(StringRef path) { #if defined _WIN32 || defined _WIN64 @@ -5735,152 +5669,696 @@ inline bool read_text_file(const std::string& fullpath, std::string* content) { return true; } -static const char* const kJitifyBuiltinHeaderPrefix = "__jitify_builtin"; -static const char* const kJitifyCallbackHeaderPrefix = "__jitify_callback"; +// Prepends the current executable dir (instead of the current working dir, +// which is the implicit default) to relative paths. This is expected to be more +// useful than the default because it allows referencing headers that are +// shipped with the application independent of the current working directory. +inline std::string expand_include_path(std::string path) { + if (path.empty()) return ""; + if (!path_is_absolute(path)) { + path = path_join(path_base(get_current_executable_path()), path); + } + // TODO: Consider also expanding "$FOO" and "${FOO}" as environment variables. + return path; +} -// Searches for the specified header and loads its contents into *source and its -// full path into *fullpath. Returns false if not found. -inline bool load_header_impl(const std::string& filename, - const StringVec& include_paths, - StringRef current_dir, bool search_current_dir, - bool search_builtin_headers, - FileCallback header_callback, std::string* source, - std::string* fullpath) { - // Try loading from header callback. - if (header_callback) { - *fullpath = path_join(kJitifyCallbackHeaderPrefix, filename); - if (header_callback(filename, source)) return true; - } - // Try loading from filesystem. - if (search_current_dir) { - *fullpath = path_join(current_dir, filename); - if (read_text_file(*fullpath, source)) return true; - } - // Search include directories. - for (const std::string& include_path : include_paths) { - *fullpath = path_join(include_path, filename); - if (read_text_file(*fullpath, source)) return true; +inline void extract_include_paths(OptionsVec* options, + StringVec* include_paths) { + const std::vector idxs = options->find({"-I"}); + for (int i = (int)idxs.size() - 1; i >= 0; --i) { + const int idx = idxs[i]; + std::string include_path = (*options)[idx].value(); + include_path = expand_include_path(std::move(include_path)); + include_paths->push_back(std::move(include_path)); + options->erase(idx); } - // Try loading from builtin headers. - if (search_builtin_headers) { - *fullpath = path_join(kJitifyBuiltinHeaderPrefix, filename); - auto iter = get_jitsafe_headers_map().find(filename); - if (iter != get_jitsafe_headers_map().end()) { - *source = iter->second; - return true; +} + +// Replaces forward and backward slashes with '|'. +inline std::string sanitize_slashes(std::string s) { + for (std::string::iterator it = s.begin(); it != s.end(); ++it) { + if (*it == '\\' || *it == '/') { + *it = '|'; } } - return false; + return s; } -enum class HeaderLoadStatus { - FAILED = 0, - ALREADY_LOADED = 1, - NEWLY_LOADED = 2, +// Note: When acting as a reference, this behaves like a raw pointer, so the +// referenced value must outlive this class. Caution is advised. +template +class ValueOrRef { + public: + using value_type = T; + using reference = T&; + using const_reference = const T&; + using pointer = T*; + + ValueOrRef() = default; + // Construct as value. Allow implicit conversions. + ValueOrRef(value_type _val) : val_(std::move(_val)) {} + // Construct as reference. + explicit ValueOrRef(pointer _ref) : ref_(_ref) {} + + // Implicit conversion to reference. + operator const_reference() const { return ref_ ? *ref_ : val_; } + operator reference() { return ref_ ? *ref_ : val_; } + + void copy_to_and_reference(T* dst) { + *dst = ref_ ? *ref_ : std::move(val_); + ref_ = dst; + } + + private: + value_type val_; + pointer ref_ = nullptr; }; -// Searches for the specified header and adds its contents to *sources and its -// simplified full path to *fullpaths (if provided). Returns 0 if not found, -1 -// if alreay found, or 1 if successfully loaded. -inline HeaderLoadStatus load_header( - const std::string& filename, const StringVec& include_paths, - StringRef current_dir, bool search_current_dir, bool search_builtin_headers, - FileCallback header_callback, StringMap* sources, StringMap* fullpaths) { - if (sources->count(filename)) { - return HeaderLoadStatus::ALREADY_LOADED; - } - std::string source, fullpath; - if (!load_header_impl(filename, include_paths, current_dir, - search_current_dir, search_builtin_headers, - header_callback, &source, &fullpath)) { - return HeaderLoadStatus::FAILED; - } - sources->emplace(filename, source); - if (fullpaths) { - // Record the full file path corresponding to this include name. - fullpaths->emplace(filename, path_simplify(fullpath)); - } - return HeaderLoadStatus::NEWLY_LOADED; -} - -// Replaces std with cuda::std so that the jit-safe libcudacxx implementations -// are used instead of the unsafe standard implementations. -inline std::string replace_std_with_cuda_std(std::string source) { - static const std::regex re_qualified_name( - R"(::cuda::std::|\bcuda::std::|::std::|\bstd::)", std::regex::optimize); - // TODO: This isn't safe because it might already be ns cuda { ns std { } }. - // static const std::regex re_namespace(R"(\bnamespace\s+std\s*\{)", - // std::regex::optimize); - source = std::regex_replace(source, re_qualified_name, "::cuda::std::"); - // source = std::regex_replace(source, re_namespace, "namespace cuda::std {"); - return source; -} - -// Helper class for basic lexing of C++ source code. -class CppLexer { - const char* current_; +using StringOrRef = ValueOrRef; + +} // namespace detail + +namespace parser { + +// This includes whitespace and comment tokens so that it forms a lossless +// representation of the original source. +class Token { + public: + enum class Type : int { + kInvalid, + kLParen, // ( + kRParen, // ) + kLBracket, // [ <: (if not followed by :: or :>) + kRBracket, // ] :> + kLBrace, // { <% + kRBrace, // } %> + kDot, // . + kDotStar, // .* + kArrow, // -> + kArrowStar, // ->* + kComma, // , + kPlus, // + + kPlusPlus, // ++ + kPlusEq, // += + kMinus, // - + kMinusMinus, // -- + kMinusEq, // -= + kStar, // * + kStarEq, // *= + kSlash, // / + kSlashEq, // /= + kPercent, // % + kPercentEq, // %= + kQuestion, // ? + kColon, // : + kColonColon, // :: + kAmp, // & + kAmpAmp, // && + kAmpEq, // &= + kBar, // | + kBarBar, // || + kBarEq, // |= + kCaret, // ^ + kCaretEq, // ^= + kTilde, // ~ + kEq, // = + kEqEq, // == + kBang, // ! + kBangEq, // != + kLt, // < + kLtLt, // << + kLtEq, // <= + kLtLtEq, // <<= + kGt, // > + kGtGt, // >> + kGtEq, // >= + kGtGtEq, // >>= + kHash, // # %: + kHashHash, // ## %:%: + kSemicolon, // ; + kEndOfDirective, // Newline at end of a preprocessor directive + kWhitespace, // Any sequence of whitespace + kNumber, // Anything beginning with a digit + kString, // "abc" (or after a #include directive) + kRawString, // R"delim(abc)delim" + kCharacter, // 'c' (possibly prefixed) + kIdentifier, // abc_def + kKeyword, // class, using, not, etc. (excludes preproc directives) + kComment, // // or /**/ comment + kEndOfFile, // The end of the file + kNumTokenTypes + }; + + // This is useful for debugging. + friend std::string to_string(Type token_type) { +#define JITIFY_DEFINE_TOKEN_CASE(type) \ + case Type::type: \ + return #type + + switch (token_type) { + JITIFY_DEFINE_TOKEN_CASE(kInvalid); + JITIFY_DEFINE_TOKEN_CASE(kLParen); + JITIFY_DEFINE_TOKEN_CASE(kRParen); + JITIFY_DEFINE_TOKEN_CASE(kLBracket); + JITIFY_DEFINE_TOKEN_CASE(kRBracket); + JITIFY_DEFINE_TOKEN_CASE(kLBrace); + JITIFY_DEFINE_TOKEN_CASE(kRBrace); + JITIFY_DEFINE_TOKEN_CASE(kDot); + JITIFY_DEFINE_TOKEN_CASE(kDotStar); + JITIFY_DEFINE_TOKEN_CASE(kArrow); + JITIFY_DEFINE_TOKEN_CASE(kArrowStar); + JITIFY_DEFINE_TOKEN_CASE(kComma); + JITIFY_DEFINE_TOKEN_CASE(kPlus); + JITIFY_DEFINE_TOKEN_CASE(kPlusPlus); + JITIFY_DEFINE_TOKEN_CASE(kPlusEq); + JITIFY_DEFINE_TOKEN_CASE(kMinus); + JITIFY_DEFINE_TOKEN_CASE(kMinusMinus); + JITIFY_DEFINE_TOKEN_CASE(kMinusEq); + JITIFY_DEFINE_TOKEN_CASE(kStar); + JITIFY_DEFINE_TOKEN_CASE(kStarEq); + JITIFY_DEFINE_TOKEN_CASE(kSlash); + JITIFY_DEFINE_TOKEN_CASE(kSlashEq); + JITIFY_DEFINE_TOKEN_CASE(kPercent); + JITIFY_DEFINE_TOKEN_CASE(kPercentEq); + JITIFY_DEFINE_TOKEN_CASE(kQuestion); + JITIFY_DEFINE_TOKEN_CASE(kColon); + JITIFY_DEFINE_TOKEN_CASE(kColonColon); + JITIFY_DEFINE_TOKEN_CASE(kAmp); + JITIFY_DEFINE_TOKEN_CASE(kAmpAmp); + JITIFY_DEFINE_TOKEN_CASE(kAmpEq); + JITIFY_DEFINE_TOKEN_CASE(kBar); + JITIFY_DEFINE_TOKEN_CASE(kBarBar); + JITIFY_DEFINE_TOKEN_CASE(kBarEq); + JITIFY_DEFINE_TOKEN_CASE(kCaret); + JITIFY_DEFINE_TOKEN_CASE(kCaretEq); + JITIFY_DEFINE_TOKEN_CASE(kTilde); + JITIFY_DEFINE_TOKEN_CASE(kEq); + JITIFY_DEFINE_TOKEN_CASE(kEqEq); + JITIFY_DEFINE_TOKEN_CASE(kBang); + JITIFY_DEFINE_TOKEN_CASE(kBangEq); + JITIFY_DEFINE_TOKEN_CASE(kLt); + JITIFY_DEFINE_TOKEN_CASE(kLtLt); + JITIFY_DEFINE_TOKEN_CASE(kLtEq); + JITIFY_DEFINE_TOKEN_CASE(kLtLtEq); + JITIFY_DEFINE_TOKEN_CASE(kGt); + JITIFY_DEFINE_TOKEN_CASE(kGtGt); + JITIFY_DEFINE_TOKEN_CASE(kGtEq); + JITIFY_DEFINE_TOKEN_CASE(kGtGtEq); + JITIFY_DEFINE_TOKEN_CASE(kHash); + JITIFY_DEFINE_TOKEN_CASE(kHashHash); + JITIFY_DEFINE_TOKEN_CASE(kSemicolon); + JITIFY_DEFINE_TOKEN_CASE(kEndOfDirective); + JITIFY_DEFINE_TOKEN_CASE(kWhitespace); + JITIFY_DEFINE_TOKEN_CASE(kNumber); + JITIFY_DEFINE_TOKEN_CASE(kString); + JITIFY_DEFINE_TOKEN_CASE(kRawString); + JITIFY_DEFINE_TOKEN_CASE(kCharacter); + JITIFY_DEFINE_TOKEN_CASE(kIdentifier); + JITIFY_DEFINE_TOKEN_CASE(kKeyword); + JITIFY_DEFINE_TOKEN_CASE(kComment); + JITIFY_DEFINE_TOKEN_CASE(kEndOfFile); + JITIFY_DEFINE_TOKEN_CASE(kNumTokenTypes); + } +#undef JITIFY_DEFINE_TOKEN_CASE + return ""; + } + + friend std::ostream& operator<<(std::ostream& stream, Type token_type) { + return stream << to_string(token_type); + } + + friend std::ostream& operator<<(std::ostream& stream, const Token& token) { + return stream << token.type() << "(" << token.token_string() << ")"; + } + + static bool TypeIsValid(Token::Type token_type) { + return token_type != Token::Type::kInvalid && + token_type != Token::Type::kEndOfFile; + } + + // Efficiently represents a set of token types. + class TypeSet { + public: + constexpr TypeSet() : data_(0) {} + + template + constexpr TypeSet(Type token_type0, TokenTypes... token_types) + : TypeSet(TypeSet(uint64_t(1) << static_cast(token_type0)) | + TypeSet(token_types...)) {} + + // Tests if token_type is in the set. + constexpr bool count(Type token_type) const { + return data_ & (uint64_t(1) << static_cast(token_type)); + } + + // Combine sets. + friend constexpr TypeSet operator|(TypeSet lhs, TypeSet rhs) { + return TypeSet(lhs.data_ | rhs.data_); + } + friend constexpr TypeSet operator&(TypeSet lhs, TypeSet rhs) { + return TypeSet(lhs.data_ & rhs.data_); + } + + private: + constexpr explicit TypeSet(uint64_t _data) : data_(_data) {} + + uint64_t data_; + static_assert(static_cast(Type::kNumTokenTypes) <= + sizeof(data_) * 8, + "Too many token types to fit in 64-bit set!"); + }; + + Token() = default; + Token(Type _type, const char* _begin, const char* _end, + std::string _token_string = {}) + : begin_(_begin), + size_(static_cast(_end - _begin)), + type_(_type), + token_string_(std::move(_token_string)) {} + Token(Type _type, std::string _token_string) + : type_(_type), token_string_(std::move(_token_string)) {} + + const char* begin() const { return begin_; } + const char* end() const { return begin_ + size_; } + Type type() const { return type_; } + + explicit operator bool() const { return TypeIsValid(type_); } + + friend bool operator==(const Token& lhs, const Token& rhs) { + return lhs.type_ == rhs.type_ && lhs.begin_ == rhs.begin_ && + lhs.size_ == rhs.size_ && lhs.token_string_ == rhs.token_string_; + } + friend bool operator!=(const Token& lhs, const Token& rhs) { + return !(lhs == rhs); + } + + template + bool matches(TokenTypes... token_types) const { + return TypeSet(token_types...).count(type_); + } + + bool matches_identifier(const char* name) const { + return type_ == Token::Type::kIdentifier && token_string() == name; + } + + // Returns the number of newlines in the token's original source string. + // Note that any token can have escaped newlines in it. + int num_newlines() const { + const std::string source = source_string(); + return (int)std::count(source.begin(), source.end(), '\n'); + } + + // Returns the number of newlines (excluding escaped newlines) in the token + // string. + int num_unescaped_newlines() const { + const std::string token = token_string(); + return (int)std::count(token.begin(), token.end(), '\n'); + } + + std::string source_string() const { + return begin_ ? std::string(begin_, size_) : token_string_; + } + + std::string token_string() const { + return token_string_.empty() ? source_string() : token_string_; + } + + private: + // Note: begin_ and end_ point to locations in the original source. In the + // simple case, the string between them is exactly the token string, and + // token_string_ is empty. If the source contains escaped newlines or if + // tokens have been concatenated, begin_ and end_ reference the original + // source string (e.g., "foo ## b\\\nar") and token_string_ is set to the + // logical token string (e.g., "foobar"). + const char* begin_ = nullptr; + uint32_t size_ = 0; + Type type_ = Type::kInvalid; + std::string token_string_; +}; - bool isspace(char c) const { - return std::isspace(static_cast(c)); +// Converts token to a kKeyword token if it matches a language keyword, +// otherwise returns it unchanged. The cxx_standard_year argument is e.g., 11, +// 14, 17, or 20, or -1 for the latest standard including technical +// specifications. +inline bool is_keyword(const std::string& token_string, + int cxx_standard_year = -1) { + static const std::unordered_set keywords = { + "and", "and_eq", "asm", "auto", + "bitand", "bitor", "bool", "break", + "case", "catch", "char", "class", + "compl", "const", "const_cast", "continue", + "default", "delete", "do", "double", + "dynamic_cast", "else", "enum", "explicit", + "export", "extern", "false", "float", + "for", "friend", "goto", "if", + "inline", "int", "long", "mutable", + "namespace", "new", "not", "not_eq", + "operator", "or", "or_eq", "private", + "protected", "public", "register", "reinterpret_cast", + "return", "short", "signed", "sizeof", + "static", "static_cast", "struct", "switch", + "template", "this", "throw", "true", + "try", "typedef", "typeid", "typename", + "union", "unsigned", "using", "virtual", + "void", "volatile", "wchar_t", "while", + "xor", "xor_eq", "__restrict__", "__constant__", + "__device__", "__global__", "__host__", + }; + static const std::unordered_set cxx11_keywords = { + "alignas", "alignof", "char16_t", "char32_t", "constexpr", + "decltype", "noexcept", "nullptr", "static_assert", "thread_local", + }; + static const std::unordered_set cxx20_keywords = { + "char8_t", "concept", "consteval", "constinit", + "co_await", "co_return", "co_yield", "requires", + }; + static const std::unordered_set ts_keywords = { + "atomic_cancel", "atomic_commit", "atomic_noexcept", + "reflexpr", "synchronized", + }; + if (cxx_standard_year == -1) { + cxx_standard_year = 99; } + if (keywords.count(token_string)) return true; + if (cxx_standard_year < 11) return false; + if (cxx11_keywords.count(token_string)) return true; + if (cxx_standard_year < 20) return false; + if (cxx20_keywords.count(token_string)) return true; + if (cxx_standard_year != 99) return false; + return ts_keywords.count(token_string); +} + +class CppLexer { + class Iterator { + public: + using iterator_category = std::input_iterator_tag; + using difference_type = std::ptrdiff_t; + using value_type = Token; + using pointer = const Token*; + using reference = const Token&; + + Iterator() : lexer_(nullptr) {} + explicit Iterator(CppLexer* _lexer) + : lexer_(_lexer), current_(lexer_->next()) {} + + reference operator*() const { return current_; } + pointer operator->() { return ¤t_; } + + Iterator& operator++() { + current_ = lexer_->next(); + return *this; + } + + Iterator operator++(int) { + Iterator tmp = *this; + ++(*this); + return tmp; + } + + friend bool operator==(const Iterator& lhs, const Iterator& rhs) { + return (lhs.lexer_ == rhs.lexer_ && lhs.current_ == rhs.current_) || + (lhs.current_.type() == Token::Type::kEndOfFile && + rhs.lexer_ == nullptr) || + (lhs.lexer_ == nullptr && + rhs.current_.type() == Token::Type::kEndOfFile); + } + friend bool operator!=(const Iterator& lhs, const Iterator& rhs) { + return !(lhs == rhs); + } + + private: + CppLexer* lexer_; + value_type current_; + }; public: - CppLexer(const char* str) : current_(str) {} - const char* current() const { return current_; } - char advance() { return *current_++; } - void skip(int n) { current_ += n; } - char peek(int i = 0) const { return *(current_ + i); } - bool match(char c) { return peek() == c ? advance() : false; } - bool match(const char* s) { + using iterator = Iterator; + + template + static Container tokenize(const char* source, int _cxx_standard_year = -1) { + CppLexer lexer(source, _cxx_standard_year); + Container result; + for (const Token& token : lexer) { + result.push_back(token); + } + return result; + } + + CppLexer(const char* source, int _cxx_standard_year = -1) + : current_(source), cxx_standard_year_(_cxx_standard_year) {} + + iterator begin() { return iterator(this); } + iterator end() { return iterator(); } + + Token next() { + using Tt = Token::Type; + token_start_ = current_; + char c = advance(); + // clang-format off + switch (c) { + case '\0': return token(Tt::kEndOfFile); + // This just handles the very first character being an escaped newline, + // because all other escaped newlines are skipped over. + case '\\': return token(match('\n') ? Tt::kWhitespace : Tt::kInvalid); + case '\n': return in_directive_ + ? (in_directive_ = false, in_include_directive_ = false, + token(Tt::kEndOfDirective)) + : whitespace(); + case '(': return token(Tt::kLParen); + case ')': return token(Tt::kRParen); + case '[': return token(Tt::kLBracket); + case ']': return token(Tt::kRBracket); + case '<': + return in_include_directive_ + ? angle_include() + : token(((peek_match(":") && !peek_match("::")) || + peek_match(":::") || peek_match("::>")) + ? (match(':'), Tt::kLBracket) + : match('%') + ? Tt::kLBrace + : match('<') + ? match('=') ? Tt::kLtLtEq + : Tt::kLtLt + : match('=') ? Tt::kLtEq : Tt::kLt); + case '>': // Note: This does not distinguish template close vs. bitshift + return token(match('>') ? match('=') ? Tt::kGtGtEq : Tt::kGtGt + : match('=') ? Tt::kGtEq : Tt::kGt); + case ':': return token(match('>') + ? Tt::kRBracket + : match(':') ? Tt::kColonColon : Tt::kColon); + case '{': return token(Tt::kLBrace); + case '}': return token(Tt::kRBrace); + case '%': + return token(match('>') + ? Tt::kRBrace + : match(':') + // TODO: Probably need to do the in_directive_ + // etc. logic here too. + ? match("%:") ? Tt::kHashHash : Tt::kHash + : match('=') ? Tt::kPercentEq : Tt::kPercent); + case '.': return token(match('*') ? Tt::kDotStar : Tt::kDot); + case '-': return token(match('>') ? match('*') ? Tt::kArrowStar + : Tt::kArrow + : match('-') ? Tt::kMinusMinus + : match('=') ? Tt::kMinusEq + : Tt::kMinus); + + case ',': return token(Tt::kComma); + case '+': return token(match('+') ? Tt::kPlusPlus + : match('=') ? Tt::kPlusEq + : Tt::kPlus); + case '*': return token(match('=') ? Tt::kStarEq : Tt::kStar); + case '/': return match('/') + ? line_comment() + : match('*') ? block_comment() + : token(match('=') ? Tt::kSlashEq + : Tt::kSlash); + // Note: Trigraphs not supported. + case '?': return token(Tt::kQuestion); + case '&': return token(match('&') ? Tt::kAmpAmp + : match('=') ? Tt::kAmpEq : Tt::kAmp); + case '|': return token(match('|') ? Tt::kBarBar + : match('=') ? Tt::kBarEq : Tt::kBar); + case '^': return token(match('=') ? Tt::kCaretEq : Tt::kCaret); + case '~': return token(Tt::kTilde); + case '=': return token(match('=') ? Tt::kEqEq : Tt::kEq); + case '!': return token(match('=') ? Tt::kBangEq : Tt::kBang); + case '#': + return token(match('#') ? Tt::kHashHash + : (is_start_of_directive_ = !in_directive_, + in_directive_ = true, + Tt::kHash)); + case '\'': return character(); + case '"': return in_include_directive_ ? quote_include() : string(); + case 'u': match('8'); + // fall-through + [[gnu::fallthrough]]; // Not sure why gcc complains here without this + case 'L': + // fall-through + case 'U': + return match('\'') + ? character() + : match('"') ? string() + : match("R\"") ? raw_string() : identifier(); + case 'R': return match('"') ? raw_string() : identifier(); + case ';': return token(Tt::kSemicolon); + default: + if (is_space(c)) return in_directive_ ? whitespace_except_newlines() + : whitespace(); + if (is_digit(c)) return number(); + if (is_alpha(c)) return identifier(); + } + // clang-format on + return token(Tt::kInvalid); + } + + private: + bool is_space_except_newline(char c) const { + return c == ' ' || c == '\f' || c == '\r' || c == '\t' || c == '\v'; + } + bool is_space(char c) const { + // Note: std::isspace is locale-dependent. + return is_space_except_newline(c) || c == '\n'; + } + bool is_digit(char c) const { return c >= '0' && c <= '9'; } + bool is_alpha(char c) const { + // Note: std::isalpha is locale-dependent. + // Also, implementations may accept additional alphabet characters (e.g., + // MSVC accepts '$', and clang accepts things like Greek alphabet unicode + // chars). + return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || c == '_' || + c == '$'; + } + bool is_alnum(char c) const { return is_alpha(c) || is_digit(c); } + + bool contains_escaped_newlines(const char* begin, const char* end) const { + for (const char* ptr = begin; ptr != end; ++ptr) { + if (ptr[0] == '\\' && ptr[1] == '\n') return true; + } + return false; + } + std::string without_escaped_newlines(const char* begin, + const char* end) const { + std::string result; + result.reserve(end - begin); + for (const char* ptr = begin; ptr != end; ++ptr) { + if (ptr[0] == '\\' && ptr[1] == '\n') { + ++ptr; + } else { + result.push_back(*ptr); + } + } + return result; + } + + const char* skip_escaped_newlines(const char* ptr) const { + while (*ptr == '\\' && *(ptr + 1) == '\n') ptr += 2; + return ptr; + } + const char* reverse_skip_escaped_newlines(const char* ptr) const { + while (*ptr == '\n' && *(ptr - 1) == '\\') ptr -= 2; + return ptr; + } + const char* advance_by(const char* ptr, int n) const { + if (n == 0) return ptr; + bool reverse = n < 0; + n = reverse ? -n : n; + // Skip over escaped newlines (which can appear anywhere, even in the middle + // of tokens). + for (int i = 0; i < n; ++i) { + ptr = reverse ? reverse_skip_escaped_newlines(ptr - 1) + : skip_escaped_newlines(ptr + 1); + } + return ptr; + } + + char advance() { + char ret = *current_; + // Only advance if this isn't the end of the string. + if (ret) current_ = advance_by(current_, 1); + return ret; + } + char peek(int i = 0) const { return *advance_by(current_, i); } + int peek_match(const char* s) { int i; for (i = 0; s[i]; ++i) { - if (!peek(i) || peek(i) != s[i]) return false; + if (!peek(i) || peek(i) != s[i]) return 0; } - current_ += i; + return i; + } + bool match(char c) { return peek() == c && advance(); } + bool match(const char* s) { + int n = peek_match(s); + if (!n) return false; + current_ = advance_by(current_, n); return true; } - bool match_whitespace() { - // Includes line continuations. - return (isspace(peek()) || (peek() == '\\' && peek(1) == '\n')) ? advance() - : false; + bool match_literal_suffix() { + if (!is_alpha(peek())) return false; + advance(); + while (is_alnum(peek())) advance(); + return true; } - const char* whitespace() { - while (match_whitespace()) { + Token escapable_char_delimited_span(char delim, Token::Type token_type, + bool include_suffix = true, + bool enable_escapes = true) { + bool in_escape = false; + // Note: We stop if we reach an unescaped newline because it's a syntax + // error and we don't want to run on into the next line. + while (peek() && ((peek() != delim && peek() != '\n') || in_escape)) { + in_escape = enable_escapes && !in_escape && peek() == '\\'; + advance(); } - // while (isspace(peek()) || (peek() == '\\' && peek(1) == '\n')) advance(); - return current_; - } - const char* escapable_char_delimited_span(char delim) { - while (peek() && (peek() != delim || peek(-1) == '\\')) advance(); - if (peek() == delim) { - skip(1); - } else { - // Error, unexpected end of string. - } - return current_; - } - // Excludes the ending newline char. - const char* line() { return escapable_char_delimited_span('\n') - 1; } - // These all include the ending delimiter chars. - const char* string_literal() { return escapable_char_delimited_span('"'); } - const char* char_literal() { return escapable_char_delimited_span('\''); } - const char* delimited_span(const char* delim, int delim_size) { - auto peek_equals_delimiter = [&] { - for (int i = 0; i < delim_size; ++i) { - if (peek(i) != delim[i]) return false; + // Include the ending delimiter. + if (match(delim) && include_suffix) { + // Include literal suffix. + match_literal_suffix(); + } + return token(token_type); + } + + // Constructs a token to represent the current part of the source. + Token token(Token::Type type) const { + std::string token_string; + // If the source string contains escaped newlines, we remove them to + // construct a clean token string. + if (type == Token::Type::kRawString) { + // Special processing for raw strings because we must preserve escaped + // newlines inside them. + // TODO: Check how escaped newlines inside the delimiters should be + // handled. + const char* first_quotes = token_start_; + while (first_quotes != current_ && first_quotes[0] != '"') ++first_quotes; + if (contains_escaped_newlines(token_start_, first_quotes)) { + token_string = without_escaped_newlines(token_start_, first_quotes) + + std::string(first_quotes, current_); } - return true; - }; - while (peek() && !peek_equals_delimiter()) advance(); - if (peek() == delim[0]) { - skip(delim_size); } else { - // Error, unexpected end of string. + if (contains_escaped_newlines(token_start_, current_)) { + token_string = without_escaped_newlines(token_start_, current_); + } } - return current_; + return Token(type, token_start_, current_, std::move(token_string)); } - const char* block_comment() { return delimited_span("*/", 2); } - const char* raw_string_literal() { + + Token whitespace() { + while (is_space(peek())) { + advance(); + } + return token(Token::Type::kWhitespace); + } + Token whitespace_except_newlines() { + while (is_space_except_newline(peek())) advance(); + return token(Token::Type::kWhitespace); + } + Token number() { + while (is_alnum(peek())) advance(); + return token(Token::Type::kNumber); + } + Token string() { + return escapable_char_delimited_span('"', Token::Type::kString); + } + Token raw_string() { const char* delim_beg = current_; while (peek() && peek() != '(') advance(); std::string delim; @@ -5888,236 +6366,899 @@ class CppLexer { delim += ')'; delim.append(delim_beg, current_); delim += '"'; - return delimited_span(delim.c_str(), (int)delim.size()); + while (peek() && !match(delim.c_str())) advance(); + match_literal_suffix(); + return token(Token::Type::kRawString); + } + Token quote_include() { + // Note: Strings in #include directives treat backslashes literally, not as + // escapes. + return escapable_char_delimited_span('"', Token::Type::kString, false, + false); + } + Token angle_include() { + // Note: Strings in #include directives treat backslashes literally, not as + // escapes. + return escapable_char_delimited_span('>', Token::Type::kString, false, + false); + } + Token character() { + return escapable_char_delimited_span('\'', Token::Type::kCharacter); + } + Token identifier() { + while (is_alnum(peek())) advance(); + Token result = token(Token::Type::kIdentifier); + if (!is_start_of_directive_ && + is_keyword(result.token_string(), cxx_standard_year_)) { + result = Token(Token::Type::kKeyword, result.begin(), result.end(), + result.token_string()); + } + if (in_directive_) { + if (is_start_of_directive_ && result.token_string() == "include") { + in_include_directive_ = true; + } + is_start_of_directive_ = false; + } + return result; + } + Token line_comment() { + // Excludes the newline. + while (peek() && peek() != '\n') advance(); + return token(Token::Type::kComment); + } + Token block_comment() { + while (peek() && !match("*/")) advance(); + return token(Token::Type::kComment); } + + const char* current_; + int cxx_standard_year_; + const char* token_start_; + bool in_directive_ = false; + bool is_start_of_directive_ = false; + bool in_include_directive_ = false; }; -inline bool find_pragma_once(const std::string& source, size_t* begin_ptr, - size_t* end_ptr) { - // Match string literals, comments (/), and preprocessor directives (#). - const char* match_chars = "\"'R/#"; - size_t pos = 0; - while ((pos = source.find_first_of(match_chars, pos)) != std::string::npos) { - const char* beg = source.c_str() + pos; - CppLexer lexer(beg); - bool hit = false; - const char* end = [&] { - // clang-format off - switch (lexer.advance()) { - case '"': return lexer.string_literal(); - case '\'': return lexer.char_literal(); - case 'R': return lexer.match('"') ? lexer.raw_string_literal() : - lexer.current(); - case '/': return lexer.match('/') ? lexer.line() : - lexer.match('*') ? lexer.block_comment() : - lexer.current(); - case '#': return (hit = lexer.match("pragma") && - lexer.match_whitespace() && - (lexer.whitespace(), lexer.match("once"))), - lexer.current(); - default: return lexer.current(); // Should never be reached - } - // clang-format on - }(); - if (hit) { - *begin_ptr = pos; - *end_ptr = end - source.c_str(); - return true; +// Pastes two tokens together as per the ## macro operator. +// Returns a kInvalid token if the concatenation does not form a valid token. +inline Token concatenate(const Token& lhs, const Token& rhs, + int cxx_standard_year) { + using Tt = Token::Type; + std::string combined_token_string = lhs.token_string() + rhs.token_string(); + Token::Type type = [&] { + auto match = [&](Token::Type x, Token::Type y) -> bool { + return lhs.type() == x && rhs.type() == y; + }; + if (match(Tt::kLt, Tt::kColon)) return Tt::kLBracket; + if (match(Tt::kColon, Tt::kGt)) return Tt::kRBracket; + if (match(Tt::kLt, Tt::kPercent)) return Tt::kLBrace; + if (match(Tt::kPercent, Tt::kGt)) return Tt::kRBrace; + if (match(Tt::kDot, Tt::kStar)) return Tt::kDotStar; + if (match(Tt::kMinus, Tt::kGt)) return Tt::kArrow; + if (match(Tt::kArrow, Tt::kStar)) return Tt::kArrowStar; + if (match(Tt::kPlus, Tt::kPlus)) return Tt::kPlusPlus; + if (match(Tt::kPlus, Tt::kEq)) return Tt::kPlusEq; + if (match(Tt::kMinus, Tt::kMinus)) return Tt::kMinusMinus; + if (match(Tt::kMinus, Tt::kEq)) return Tt::kMinusEq; + if (match(Tt::kStar, Tt::kEq)) return Tt::kStarEq; + if (match(Tt::kSlash, Tt::kEq)) return Tt::kSlashEq; + if (match(Tt::kPercent, Tt::kEq)) return Tt::kPercentEq; + if (match(Tt::kColon, Tt::kColon)) return Tt::kColonColon; + if (match(Tt::kAmp, Tt::kAmp)) return Tt::kAmpAmp; + if (match(Tt::kAmp, Tt::kEq)) return Tt::kAmpEq; + if (match(Tt::kBar, Tt::kBar)) return Tt::kBarBar; + if (match(Tt::kBar, Tt::kEq)) return Tt::kBarEq; + if (match(Tt::kCaret, Tt::kEq)) return Tt::kCaretEq; + if (match(Tt::kEq, Tt::kEq)) return Tt::kEqEq; + if (match(Tt::kBang, Tt::kEq)) return Tt::kBangEq; + if (match(Tt::kLt, Tt::kLt)) return Tt::kLtLt; + if (match(Tt::kLt, Tt::kEq)) return Tt::kLtEq; + if (match(Tt::kLt, Tt::kLtEq)) return Tt::kLtLtEq; + if (match(Tt::kLtLt, Tt::kEq)) return Tt::kLtLtEq; + if (match(Tt::kGt, Tt::kGt)) return Tt::kGtGt; + if (match(Tt::kGt, Tt::kEq)) return Tt::kGtEq; + if (match(Tt::kGt, Tt::kGtEq)) return Tt::kGtGtEq; + if (match(Tt::kGtGt, Tt::kEq)) return Tt::kGtGtEq; + if (match(Tt::kHash, Tt::kHash)) return Tt::kHashHash; + if (match(Tt::kPercent, Tt::kColon)) return Tt::kHash; + // E.g., 123 ## 456. + if (match(Tt::kNumber, Tt::kNumber)) return Tt::kNumber; + // E.g., 123 ## ull. + if (match(Tt::kNumber, Tt::kIdentifier)) return Tt::kNumber; + // E.g., abc ## 123, class ## 123 + if (lhs.matches(Tt::kIdentifier, Tt::kKeyword) && + rhs.type() == Tt::kNumber) { + return Tt::kIdentifier; } - pos += end - beg; - } - return false; + // E.g., u8 ## 'c'. + if (match(Tt::kIdentifier, Tt::kCharacter)) return Tt::kCharacter; + // E.g., u8 ## "abc" (but not include ## ). + // TODO: Consider using a separate kAngleString instead (it would simplify + // this but slightly complicate parsing of include directives). + if (match(Tt::kIdentifier, Tt::kString) && + lhs.token_string() != "include") { + return Tt::kString; + } + // E.g., u8 ## R"(abc)". + if (match(Tt::kIdentifier, Tt::kRawString)) return Tt::kRawString; + // E.g., 'c' ## _foo. + if (match(Tt::kCharacter, Tt::kIdentifier)) return Tt::kCharacter; + // E.g., "foo" ## s. + if (match(Tt::kString, Tt::kIdentifier)) return Tt::kString; + // E.g., R"(foo)" ## s. + if (match(Tt::kRawString, Tt::kIdentifier)) return Tt::kRawString; + // E.g., abc ## def -> ident, cl ## ass -> kw, not ## using -> ident. + if (lhs.matches(Tt::kIdentifier, Tt::kKeyword) && + rhs.matches(Tt::kIdentifier, Tt::kKeyword)) { + return is_keyword(combined_token_string, cxx_standard_year) + ? Tt::kKeyword + : Tt::kIdentifier; + } + return Tt::kInvalid; + }(); + return Token(type, lhs.begin(), rhs.end(), std::move(combined_token_string)); } -inline std::string remove_cpp_comments_and_line_continuations( - const std::string& source) { - std::string result; - result.reserve(source.size()); - size_t old_pos = 0, pos; - // Match string literals, comments (forward slashes), and line continuations - // (backslashes). - const char* match_chars = "\"'R/\\"; - while ((pos = source.find_first_of(match_chars, old_pos)) != - std::string::npos) { - result.append(source, old_pos, pos - old_pos); - const char* beg = source.c_str() + pos; - CppLexer lexer(beg); - const char* end = [&] { - // clang-format off - switch (lexer.advance()) { - case '"': return lexer.string_literal(); - case '\'': return lexer.char_literal(); - case 'R': return lexer.match('"') ? lexer.raw_string_literal() : - lexer.current(); - case '/': return lexer.match('/') ? lexer.line() : - lexer.match('*') ? lexer.block_comment() : - lexer.current(); - // Match line continuation (escaped newline). - // TODO: Line continuations inside string literals will not be matched - // here. Would need to use a separate pass that only matches them and - // raw strings. - case '\\': return lexer.match('\n'), lexer.current(); - default: return lexer.current(); // Should never be reached +template +class TokenHistoryBuffer { + public: + using value_type = Token; + using reference = Token&; + using const_reference = const Token&; + + constexpr int size() const { return Size; } + + void push(const Token& value) { + if (++head_ == size()) { + head_ = 0; + } + data_[head_] = value; + } + + // Requires i to be in the range (-size(), 0], where + // i=0 corresponds to the most recent value. + const_reference operator[](int i) const { + assert(-size() < i && i <= 0); + int idx = head_ + i; + if (idx < 0) { + idx += size(); + } + return data_[idx]; + } + + bool match(std::initializer_list token_types) const { + assert((int)token_types.size() <= size()); + int i = 0; + for (Token::Type token_type : token_types) { + Token::Type historic_type = + (*this)[-(int)token_types.size() + 1 + i++].type(); + if (historic_type != token_type) { + return false; } - // clang-format on - }(); - old_pos = end - source.c_str(); - if (end - beg == 1 || *beg == '"' || *beg == '\'' || *beg == 'R') { - // Keep single characters ('/') and string literals. - result.append(beg, end); - } else { - // Elide comments and line continuations. } + return true; } - result.append(source, old_pos, std::string::npos); - return result; + + // Removes the most recent entry. + void pop() { + data_[head_--] = value_type(); + if (head_ < 0) head_ += size(); + } + + private: + std::array data_ = {}; + int head_ = -1; +}; + +// This filters out whitespace and comments from an iterator over Tokens, and +// provides several convenience methods to assist parsing. +template +class CppParserIterator { + public: + using token_iterator = TokenIterator; + using iterator_category = typename std::conditional< + std::is_same< + typename std::iterator_traits::iterator_category, + std::input_iterator_tag>::value, + std::input_iterator_tag, std::forward_iterator_tag>::type; + using difference_type = + typename std::iterator_traits::difference_type; + using value_type = typename std::iterator_traits::value_type; + using reference = typename std::iterator_traits::reference; + using pointer = typename std::iterator_traits::pointer; + + explicit CppParserIterator(token_iterator token_iter, token_iterator _end) + : previous_tokens_(), current_(token_iter), end_(_end) { + skip_whitespace_and_comments(); + } + + token_iterator base() const { return current_; } + + // Construct a corresponding end iterator for use with iterator-based + // algorithms. + CppParserIterator end() const { return CppParserIterator(end_, end_); } + + explicit operator bool() const { return current_ != end_; } + + reference operator*() const { return *current_; } + token_iterator operator->() const { return current_; } + + // Advances to the next non-whitespace and non-comment token. + CppParserIterator& operator++() { + previous_tokens_.push(*current_); + ++current_; + skip_whitespace_and_comments(); + return *this; + } + + CppParserIterator operator++(int) { + CppParserIterator tmp(*this); + ++(*this); + return tmp; + } + + // Requires idx to be in the range (-size(), 0], where + // idx=0 corresponds to the most recent value (before current). + const value_type& previous_token(int idx = 0) const { + return previous_tokens_[idx]; + } + + bool match(Token::Type token_type) { + if (current_->type() != token_type) return false; + ++(*this); + return true; + } + + template + bool match(TokenTypes... token_types) { + if (!current_->matches(token_types...)) return false; + ++(*this); + return true; + } + + bool match_identifier(const char* name) { + if (current_->type() != Token::Type::kIdentifier || + current_->token_string() != name) { + return false; + } + ++(*this); + return true; + } + + // Advances to the first token with the given type. + CppParserIterator& advance_to(Token::Type token_type) { + while (*this && (*this)->type() != token_type) ++(*this); + return *this; + } + + // Erases tokens from *token_container in the range [first_to_erase, *this] + // inclusive, and sets *this to point to the next parser token. + template + CppParserIterator& erase_back_to(Container* token_container, + CppParserIterator first_to_erase) { + for (token_iterator it = first_to_erase.base(); it != current_; ++it) { + previous_tokens_.pop(); + } + current_ = token_container->erase(first_to_erase.base(), ++current_); + skip_whitespace_and_comments(); + return *this; + } + + int line_number() const { return line_num_; } + + bool has_whitespace_before() const { return whitespace_before_; } + + private: + void skip_whitespace_and_comments() { + line_num_ += previous_tokens_[0].num_newlines(); + whitespace_before_ = false; + while (current_ != end_ && + current_->matches(Token::Type::kWhitespace, Token::Type::kComment)) { + line_num_ += current_->num_newlines(); + ++current_; + whitespace_before_ = true; + } + using Tt = Token::Type; + // Handle #line preprocessor directives. + if (previous_tokens_.match( + {Tt::kHash, Tt::kIdentifier, Tt::kNumber, Tt::kEndOfDirective}) && + previous_tokens_[-2].matches_identifier("line")) { + // TODO: Should check this for invalid values (non-integer or negative + // integer; strangely, zero is allowed). + line_num_ = std::atoi(previous_tokens_[-1].token_string().c_str()); + } else if (previous_tokens_.match({Tt::kHash, Tt::kIdentifier, Tt::kNumber, + Tt::kString, Tt::kEndOfDirective}) && + previous_tokens_[-3].matches_identifier("line")) { + line_num_ = std::atoi(previous_tokens_[-2].token_string().c_str()); + // TODO: The string token should be used as the new filename. + } + } + + TokenHistoryBuffer<5> previous_tokens_; + token_iterator current_; + token_iterator end_; + int line_num_ = 1; + bool whitespace_before_ = false; +}; + +template +inline CppParserIterator make_cpp_parser_iterator( + TokenIterator iter, TokenIterator end) { + return CppParserIterator(iter, end); } -// This removes most but not all whitespace. Remaining whitespace is tricky to -// handle safely+efficiently. -inline std::string remove_cpp_whitespace(const std::string& source) { - std::string result; - result.reserve(source.size()); - size_t old_pos = 0, pos; - // Match string literals, preprocessor directives, whitespace, and chars that - // can safely have whitespace after them removed. - bool inside_directive = false; - const char* match_chars = "\"'R# \f\n\r\t\v.,;!|~^()[]{}"; - while ((pos = source.find_first_of(match_chars, old_pos)) != - std::string::npos) { - result.append(source, old_pos, pos - old_pos); - const char* beg = source.c_str() + pos; - CppLexer lexer(beg); - bool end_of_directive = false; - bool is_whitespace = false; - const char* end = [&] { - // clang-format off - char c = lexer.advance(); - switch (c) { - case '"': return lexer.string_literal(); - case '\'': return lexer.char_literal(); - case 'R': return lexer.match('"') ? lexer.raw_string_literal() : - lexer.current(); - case '#': return inside_directive = true, lexer.current(); - default: return is_whitespace = true, lexer.whitespace(); +struct SourceLocation { + SourceLocation() = default; + SourceLocation(std::string _filename, int _line = 0) + : filename_(std::move(_filename)), line_(_line) {} + + const std::string& file_name() const noexcept { return filename_; } + int line() const noexcept { return line_; } + + friend std::string to_string(const SourceLocation& location) { + return location.file_name() + ":" + std::to_string(location.line()); + } + + private: + std::string filename_; + int line_ = 0; +}; + +static const char* const kJitifyDirPrefix = "__jitify_rel_inc:"; +static const char* const kJitifyNamePrefix = ":__jitify_name:"; + +class IncludeName { + public: + IncludeName() = default; + /* Construct as a <> include (unless _include_name is a patched name, in which + * case it is parsed into a "" include. + */ + explicit IncludeName(std::string _include_name, SourceLocation _location = {}) + : include_name_(std::move(_include_name)), + location_(std::move(_location)) { + const size_t prefix_len = std::strlen(kJitifyDirPrefix); + if (include_name_.substr(0, prefix_len) == kJitifyDirPrefix) { + // Parse patched name. + const size_t dir_end = include_name_.find(kJitifyNamePrefix, prefix_len); + assert(dir_end != std::string::npos); + current_dir_ = include_name_.substr(prefix_len, dir_end - prefix_len); + include_name_ = + include_name_.substr(dir_end + std::strlen(kJitifyNamePrefix)); + } + } + /* Construct as a "" include. + */ + IncludeName(std::string _include_name, std::string _current_dir, + SourceLocation _location = {}) + : include_name_(std::move(_include_name)), + current_dir_(std::move(_current_dir)), + location_(std::move(_location)) { + // Absolute paths should always be treated like <> includes. + if (jitify2::detail::path_is_absolute(include_name_)) { + current_dir_.clear(); + } + } + /*! Returns the filename of the include directive (the part inside "" or <>). + */ + const std::string& name() const { return include_name_; } + /*! For "" includes, returns the current directory in which the include + * directive was present. For <> includes, returns empty string. + */ + const std::string& current_dir() const { return current_dir_; } + /*! Returns whether this is a "" include (as opposed to a <> include).*/ + bool is_quote_include() const { return !current_dir_.empty(); } + /*! Returns the full path to the header assuming it exists in its current + * directory. Must only be called for "" includes, never <> includes. + */ + std::string local_full_path() const { + assert(is_quote_include()); + return is_quote_include() ? current_dir() + "/" + name() : ""; + } + /*! Returns the full path to the header assuming it exists in the given + * include directory. May be called for either "" or <> includes. + */ + std::string nonlocal_full_path(const std::string& include_dir) const { + return include_dir + "/" + include_name_; + } + // For quote-includes, this returns a modified name that encodes the current + // dir too. + std::string patched_name() const { + if (!is_quote_include()) return name(); + return kJitifyDirPrefix + current_dir() + kJitifyNamePrefix + name(); + } + + friend bool operator==(const IncludeName& lhs, const IncludeName& rhs) { + return lhs.name() == rhs.name() && lhs.current_dir() == rhs.current_dir(); + } + friend bool operator!=(const IncludeName& lhs, const IncludeName& rhs) { + return !(lhs == rhs); + } + + size_t hash() const { + using jitify2::detail::string_concat; + const std::string hash_str = + is_quote_include() + ? string_concat('"', include_name_, '"', current_dir_) + : string_concat('<', include_name_, '>'); + return std::hash()(hash_str); + } + struct Hash { + size_t operator()(const IncludeName& x) const { return x.hash(); } + }; + + // Implicit conversion to string to maintain backwards compatibility with + // FileCallback. + operator const std::string &() const { return name(); } + + friend std::string to_string(const IncludeName& incname) { + using jitify2::detail::string_concat; + return incname.is_quote_include() ? string_concat('"', incname.name(), '"') + : string_concat('<', incname.name(), '>'); + } + + const SourceLocation& location() const { return location_; } + + private: + std::string include_name_; + std::string current_dir_; // Empty for <> includes, non-empty for "" includes + // Informational only. + SourceLocation location_; +}; + +// Visitor must be callable with signature: +// (IncludeName, CppParserIterator) -> void. +template +inline ErrorMsg visit_all_include_directives(TokenIterator begin, + TokenIterator end, + const std::string& full_path, + Visitor visitor) { + auto error_msg = [&](int line_number, const std::string& msg) { + return ErrorMsg(full_path + ":" + std::to_string(line_number) + + ": error: " + msg); + }; + using Tt = Token::Type; + for (auto iter = make_cpp_parser_iterator(begin, end); iter; ++iter) { + if (iter.match(Tt::kHash)) { + if (!iter.match(Tt::kIdentifier)) { + return error_msg( + iter.line_number(), + "invalid preprocessing directive #" + iter->source_string()); } - // clang-format on - }(); - if (inside_directive && is_whitespace && std::find(beg, end, '\n') != end) { - inside_directive = false; - end_of_directive = true; - } - old_pos = end - source.c_str(); - if ((end - beg == 1 && !std::isspace((unsigned char)*beg)) || *beg == '"' || - *beg == '\'' || *beg == 'R' || *beg == '#') { - // Keep single characters ('R'), string literals, and preprocessor - // directives. - result.append(beg, end); - } else { - // Elide or replace whitespace. - bool before_directive = !inside_directive && *end == '#'; - if (!std::isspace((unsigned char)*beg)) { - // Remove whitespace after symbol. - result += *beg; - if (end_of_directive || before_directive) { - result += '\n'; - } - } else { - if (end_of_directive) { - result += '\n'; - } else { - // A newline may already be present from a preprocessor directive. - bool after_newline = result.empty() || result.back() == '\n'; - if (!after_newline || before_directive) { - // Replace whitespace. - result += before_directive ? '\n' : ' '; + if (iter.previous_token().token_string() == "include") { + auto prev_iter = iter; + // Note: It is possible to have macro substitutions here instead of a + // string literal, but it is very rare, and some popular tools are + // known to not support it (e.g., scons). Of course, Thrust does it! + if (!iter.match(Tt::kString)) { + // WAR for Thrust using macro substitutions in an #include directive. + if (iter->matches_identifier("__THRUST_HOST_SYSTEM_TAG_HEADER")) { + *iter = Token(Tt::kString, iter->begin(), iter->end(), + ""); + ++iter; + } else if (iter->matches_identifier( + "__THRUST_DEVICE_SYSTEM_TAG_HEADER")) { + *iter = Token(Tt::kString, iter->begin(), iter->end(), + ""); + ++iter; + + } else { + return error_msg( + iter.line_number(), + "#include expects \"FILENAME\" or , got " + + iter->source_string()); } } + + std::string include_name = iter.previous_token().token_string(); + const bool is_quote_include = include_name[0] == '"'; + // Remove quotes/angles. + include_name = include_name.substr(1, include_name.size() - 2); + const std::string current_dir = jitify2::detail::path_base(full_path); + SourceLocation location(full_path, iter.line_number()); + IncludeName include = + is_quote_include + ? IncludeName(std::move(include_name), current_dir, + std::move(location)) + : IncludeName(std::move(include_name), std::move(location)); + visitor(std::move(include), prev_iter); } + iter.advance_to(Tt::kEndOfDirective); + if (!iter) break; } } - result.append(source, old_pos, std::string::npos); - return result; -} - -// WAR for #pragma once not working when there are multiple inclusions of the -// same header from different paths. -inline std::string replace_pragma_once_with_ifndef(const std::string& source) { + return {}; +} + +template +inline Iterator insert_directive_impl(TokenSequence* tokens, Iterator where, + const Token (&directive_tokens)[N]) { + using Tt = Token::Type; + // TODO: Find a safer way to do this. + constexpr int kMaxNewTokens = 1 + 1 + (2 * N - 1) + 1; + Token new_tokens[kMaxNewTokens]; + int j = 0; + Iterator before_where = where; + --before_where; + if (where != tokens->begin() && before_where->type() != Tt::kEndOfDirective && + (before_where->type() != Tt::kWhitespace || + before_where->num_unescaped_newlines() == 0)) { + // Must add newline before new directive. + new_tokens[j++] = Token(Tt::kWhitespace, "\n"); + } + new_tokens[j++] = Token(Tt::kHash, "#"); + for (int i = 0; i < N; ++i) { + if (i > 0) { + new_tokens[j++] = Token(Tt::kWhitespace, " "); + } + new_tokens[j++] = directive_tokens[i]; + } + new_tokens[j++] = Token(Tt::kEndOfDirective, "\n"); + assert(j <= kMaxNewTokens); + return tokens->insert(where, new_tokens, new_tokens + j); +} + +template +inline Iterator insert_directive(TokenSequence* tokens, Iterator where, + const std::string& name, + const DirectiveTokens&... directive_tokens) { + Token directive_tokens_array[] = {Token(Token::Type::kIdentifier, name), + directive_tokens...}; + return insert_directive_impl(tokens, where, directive_tokens_array); +} + +// Note: List seems to be up to 4x faster than deque. +using TokenSequence = std::list; + +// Returns true if a pragma once directive was found. +inline bool replace_pragma_once_with_ifndef(const std::string& unique_source_id, + TokenSequence* tokens) { + using Tt = Token::Type; + // Find and remove all "#pragma once" directives. + bool found = false; + for (auto iter = make_cpp_parser_iterator(tokens->begin(), tokens->end()); + iter;) { + auto start_iter = iter; + if (iter.match(Tt::kHash)) { + if (iter.match_identifier("pragma") && iter.match_identifier("once")) { + iter.advance_to(Tt::kEndOfDirective); + if (!iter) break; + // Note: The ++ here advances to the next _base_ token (because we don't + // want to jump over subsequent comment or whitespace tokens). + iter.erase_back_to(tokens, start_iter); + found = true; + // Note: There can be more than one #pragma once. + continue; + } else { + iter.advance_to(Tt::kEndOfDirective); + if (!iter) break; + } + } + ++iter; + } constexpr const char* const kJitifyIncludeGuardPrefix = "JITIFY_INCLUDE_GUARD_"; - if (startswith(source, std::string("#ifndef ") + kJitifyIncludeGuardPrefix)) { - return source; // Already been processed - } - size_t begin, end; - if (!find_pragma_once(source, &begin, &end)) return source; - // Replace #pragma once with hash-based include guard around source. - std::string include_guard_name = - string_concat(kJitifyIncludeGuardPrefix, sha256(source), "\n"); - // Note: We use `#line 1` to fix the line numbering after adding additional - // code at the beginning of the file. - std::string prefix = string_concat("#ifndef ", include_guard_name, "#define ", - include_guard_name, "#line 1\n"); - std::string suffix = "\n#endif // " + include_guard_name; - std::string result; - result.reserve(prefix.size() + source.size() + suffix.size()); - result += prefix; - result.append(source, 0, begin); - result.append(source, end, std::string::npos); - result += suffix; - return result; + if (found) { + using jitify2::detail::sha256; + using jitify2::detail::string_concat; + // Insert a hash-based include guard around the source. + std::string include_guard_name = + string_concat(kJitifyIncludeGuardPrefix, sha256(unique_source_id)); + Token guard_identifier(Tt::kIdentifier, include_guard_name); + // Note: Reverse order due to insertion at the beginning. + insert_directive(tokens, tokens->begin(), "define", guard_identifier); + insert_directive(tokens, tokens->begin(), "ifndef", guard_identifier); + insert_directive(tokens, tokens->end(), "endif", + Token(Tt::kComment, "// " + include_guard_name)); + } + return found; +} + +// Changes usages of "std::" to "cuda::std::". +// TODO: This isn't completely robust because we don't apply macro +// substitutions. +template +inline void replace_std_with_cuda_std(TokenSequence* tokens) { + using Tt = Token::Type; + for (auto iter = make_cpp_parser_iterator(tokens->begin(), tokens->end()); + iter;) { + if (((iter.previous_token().type() != Tt::kIdentifier && + iter.match(Tt::kColonColon)) || + iter.previous_token().type() != Tt::kColonColon)) { + auto before_std_iter = iter; + if (iter.match_identifier("std") && iter.match(Tt::kColonColon)) { + tokens->insert(before_std_iter.base(), Token(Tt::kIdentifier, "cuda")); + tokens->insert(before_std_iter.base(), Token(Tt::kColonColon, "::")); + } else if (iter.previous_token().type() != Tt::kColonColon) { + ++iter; + } + } else { + ++iter; + } + } } -inline std::string patch_cuda_source(std::string source, bool use_cuda_std, - bool replace_pragma_once) { - if (use_cuda_std) { - source = detail::replace_std_with_cuda_std(std::move(source)); +inline bool must_separate_tokens(const Token& lhs, const Token& rhs, + int cxx_standard_year) { + using Tt = Token::Type; + // Check if concatenating them would form a new token. + return concatenate(lhs, rhs, cxx_standard_year) || + // These are parsed greedily, so lhs/rhs would become reversed. + // E.g., a+++b == a++ +b. + // Note: It's very important to get these right, because otherwise it + // will silently introduce bugs in the minified source. + (lhs.matches(Tt::kPlus) && rhs.matches(Tt::kPlusPlus)) || + (lhs.matches(Tt::kMinus) && rhs.matches(Tt::kMinusMinus)) || + (lhs.matches(Tt::kColon) && rhs.matches(Tt::kColonColon)) || + (lhs.matches(Tt::kGt) && rhs.matches(Tt::kGtGt)); +} + +template +inline void minify_cuda_source(TokenIterator begin, TokenIterator end, + int cxx_standard_year, + std::string* minified_source) { + using Tt = Token::Type; + minified_source->clear(); + bool in_directive = false; + for (auto iter = make_cpp_parser_iterator(begin, end); iter; ++iter) { + if (iter.previous_token() && + must_separate_tokens(iter.previous_token(), *iter, cxx_standard_year)) { + minified_source->push_back(' '); + // TODO: The below condition should really check that the hash is the + // start of a directive (and not another hash inside a directive), but + // there's not an easy way to do it here. Using a new kStartOfIdentifier + // type is a possibility, but it complicates other things. + } else if (!iter->matches(Tt::kEndOfDirective) && + iter.has_whitespace_before() && + iter.previous_token().matches(Tt::kIdentifier) && + iter.previous_token(-1).matches_identifier("define") && + iter.previous_token(-2).matches(Tt::kHash)) { + // Must separate macro name and definition with whitespace. + // E.g., `FOO-123` is OK, but `#define FOO-123` is not. + // E.g., `FOO(bar)` is OK, but `#define FOO(bar)` is different to + // `#define FOO (bar)`. + minified_source->push_back(' '); + } + if (!in_directive && iter->type() == Tt::kHash) { + in_directive = true; + if (iter.previous_token() && + !iter.previous_token().matches(Tt::kEndOfDirective)) { + // Must start directives on a new line. + minified_source->push_back('\n'); + } + } else if (in_directive && iter->type() == Tt::kEndOfDirective) { + in_directive = false; + } + // Note: Using token_string() means that escaped newlines are elided. + minified_source->append(iter->token_string()); } - if (replace_pragma_once) { - source = detail::replace_pragma_once_with_ifndef(std::move(source)); +} + +enum class ProcessFlags : unsigned { + kNone = 0, + kReplacePragmaOnce = 1 << 0, + kReplaceStd = 1 << 1, + kMinify = 1 << 2, + kAddUsedHeaderWarning = 1 << 3, +}; +inline ProcessFlags operator|(ProcessFlags lhs, ProcessFlags rhs) { + using T = typename std::underlying_type::type; + return static_cast(static_cast(lhs) | static_cast(rhs)); +} +inline ProcessFlags& operator|=(ProcessFlags& lhs, ProcessFlags rhs) { + lhs = lhs | rhs; + return lhs; +} +inline bool operator&(ProcessFlags lhs, ProcessFlags rhs) { + using T = typename std::underlying_type::type; + return static_cast(lhs) & static_cast(rhs); +} + +// Note: The returned includes are _all_ the includes in the source, even if +// they end up not being reachable due to #if[def] directives. +// Note: It is OK if source and *processed_source are the same underlying memory +// (i.e., in-place operation is OK). +template +inline ErrorMsg process_cuda_source(const std::string& source, + const std::string& full_path, + ProcessFlags flags, int cxx_standard_year, + std::string* processed_source, + IncludeVisitor include_visitor) { + using Tt = Token::Type; + auto tokens = CppLexer::tokenize(source.c_str()); + using TokenIterator = TokenSequence::iterator; + ErrorMsg err = visit_all_include_directives( + tokens.begin(), tokens.end(), full_path, + [&](IncludeName include, CppParserIterator iter) { + if (include.is_quote_include()) { + // Change `#include "name"` to `#include `, where + // patched_name encodes the current dir as well as the name. + *iter = Token(Tt::kString, "<" + include.patched_name() + ">"); + } + include_visitor(std::move(include)); + }); + if (err) return err; + // Insert "#line 1" at the beginning of the file so that line numbering is + // not messed up by subsequent line insertions at the beginning. + // Note: Reverse order due to insertion at the beginning. + insert_directive(&tokens, tokens.begin(), "line", Token(Tt::kNumber, "1")); + if (flags & ProcessFlags::kAddUsedHeaderWarning) { + // Insert a guarded #warning that we can use to see if this header was + // actually included during compilation. + insert_directive(&tokens, tokens.begin(), "endif"); + insert_directive(&tokens, tokens.begin(), "warning", + Token(Tt::kIdentifier, "JITIFY_USED_HEADER"), + Token(Tt::kString, "\"" + full_path + "\"")); + insert_directive(&tokens, tokens.begin(), "ifdef", + Token(Tt::kIdentifier, "JITIFY_USED_HEADER_WARNINGS")); + } + if (flags & ProcessFlags::kReplacePragmaOnce) { + // Note: Must use source itself as unique idenfitier because multiple + // filenames may refer to the same file (via copy/symlink/hardlink). + replace_pragma_once_with_ifndef(source, &tokens); + } + if (flags & ProcessFlags::kReplaceStd) { + replace_std_with_cuda_std(&tokens); + } + if (flags & ProcessFlags::kMinify) { + // Reconstruct minified source. + minify_cuda_source(tokens.begin(), tokens.end(), cxx_standard_year, + processed_source); + } else { + processed_source->clear(); + // Reconstruct source. + for (const Token& token : tokens) { + processed_source->append(token.source_string()); + } } - // HACK This is a WAR for some CUB sources including a header they shouldn't. - size_t pos = source.find("#include \"../util_device.cuh\""); - if (pos != std::string::npos) { - source[pos] = '/'; // Comment out the line - source[pos + 1] = '/'; + return {}; +} + +} // namespace parser + +namespace detail { + +static const char* const kJitifyBuiltinHeaderPrefix = "__jitify_builtin"; +static const char* const kJitifyCallbackHeaderPrefix = "__jitify_callback"; + +enum class HeaderLoadStatus { + kFailed = 0, + kAlreadyLoaded = 1, + kNewlyLoaded = 2, +}; + +// Note: StringMapT is to allow the caller to use StringOrRef instead of +// std::string in the map. +template +HeaderLoadStatus load_header(const parser::IncludeName& include, + HeaderCallback header_callback, + const std::vector& include_paths, + bool use_builtin_headers, std::string* full_path, + StringMapT* fullpath_to_source) { + auto already_loaded = [&](const std::string& fp) { + return fullpath_to_source->count(fp); + }; + auto newly_loaded = [&](std::string source) { + fullpath_to_source->emplace(*full_path, std::move(source)); + return HeaderLoadStatus::kNewlyLoaded; + }; + std::string source; + // Try loading via callback. + *full_path = include.nonlocal_full_path(kJitifyCallbackHeaderPrefix); + *full_path = path_simplify(*full_path); + if (already_loaded(*full_path)) return HeaderLoadStatus::kAlreadyLoaded; + if (header_callback and header_callback(include, &source)) { + return newly_loaded(std::move(source)); + } + // Try loading from current directory. + if (include.is_quote_include()) { + *full_path = include.local_full_path(); + *full_path = path_simplify(*full_path); + if (already_loaded(*full_path)) return HeaderLoadStatus::kAlreadyLoaded; + if (read_text_file(*full_path, &source)) { + return newly_loaded(std::move(source)); + } } - // HACK This is a WAR for Thrust (pre-CUDA-11) using "#define A #pragma B". - pos = source.find("#pragma nv_exec_check_disable"); - if (pos != std::string::npos) { - source[pos] = '/'; // Comment out the (rest of the) line - source[pos + 1] = '/'; + // Try loading from include directories. + for (const std::string& include_path : include_paths) { + *full_path = include.nonlocal_full_path(include_path); + *full_path = path_simplify(*full_path); + if (already_loaded(*full_path)) return HeaderLoadStatus::kAlreadyLoaded; + if (read_text_file(*full_path, &source)) { + return newly_loaded(std::move(source)); + } } - // HACK This is a WAR for Thrust using - pos = source.find("__has_cpp_attribute(gnu::warn_unused_result)"); - if (pos != std::string::npos) { - source[pos + 23] = '_'; // Replace "::" with "__". - source[pos + 24] = '_'; + // Try loading from builtin headers. + if (use_builtin_headers) { + *full_path = include.nonlocal_full_path(kJitifyBuiltinHeaderPrefix); + *full_path = path_simplify(*full_path); + if (already_loaded(*full_path)) return HeaderLoadStatus::kAlreadyLoaded; + auto iter = get_jitsafe_headers_map().find(include.name()); + if (iter != get_jitsafe_headers_map().end()) { + source = iter->second; + return newly_loaded(std::move(source)); + } } - return source; + return HeaderLoadStatus::kFailed; } -// Removes comments and most whitespace from C++ source code. -inline std::string minify_cpp_source(const std::string& source) { - return remove_cpp_whitespace( - remove_cpp_comments_and_line_continuations(source)); +inline bool remove_stop_compilation_error(std::string* compile_log) { + size_t pos = compile_log->find("__JITIFY_STOP_COMPILATION"); + if (pos == std::string::npos) return false; + pos = compile_log->find_last_of('\n', pos); + if (pos == std::string::npos) { + pos = 0; + } + compile_log->resize(pos); + return true; } -inline void extract_include_paths(OptionsVec* options, - StringVec* include_paths) { - const std::vector idxs = options->find({"-I"}); - for (int i = (int)idxs.size() - 1; i >= 0; --i) { - const int idx = idxs[i]; - include_paths->push_back((*options)[idx].value()); - options->erase(idx); +// Finds used header warnings, removes them from the compile log, and adds their +// fullpaths to *used_headers. +inline bool extract_used_header_warnings( + std::string* compile_log, std::unordered_set* used_headers) { + // Remove line containing JITIFY_USED_HEADER and the next two lines. + // If the line after the first one of these contains -diag-suppress, + // remove that line and the one after it. + static const char* const kJitifyUsedHeader = "JITIFY_USED_HEADER"; + int num_found = 0; + size_t pos; + while ((pos = compile_log->find(kJitifyUsedHeader)) != std::string::npos) { + ++num_found; + size_t start = pos + std::strlen(kJitifyUsedHeader) + 2; + size_t end = compile_log->find_first_of('"', start); + assert(end != std::string::npos); + std::string header_fullpath = compile_log->substr(start, end - start); + used_headers->emplace(std::move(header_fullpath)); + start = compile_log->find_last_of('\n', pos); + if (start == std::string::npos) { + start = (size_t)-1; + } + ++start; + // Each full warning message is 4 lines. + for (int i = 0; i < 4; ++i) { + size_t new_end = compile_log->find_first_of('\n', end + 1); + if (new_end == std::string::npos) break; // End of log + end = new_end; + } + ++end; + std::string tail = compile_log->substr(end); + compile_log->resize(start); + *compile_log += tail; + } + const bool found_any = num_found > 0; + if (found_any) { + if (compile_log->find("#warning directive") == std::string::npos) { + // There are no other warnings, remove message about -diag-suppress. + pos = compile_log->find("-diag-suppress"); + if (pos == std::string::npos) return true; + size_t start = compile_log->find_last_of('\n', pos); + if (start == std::string::npos) { + start = (size_t)-1; + } + ++start; + size_t end = + compile_log->find_first_of('\n', pos + std::strlen("-diag-suppress")); + assert(end != std::string::npos); + end = compile_log->find_first_of('\n', end + 1); + std::string tail; + if (end != std::string::npos) { + ++end; + tail = compile_log->substr(end); + } + compile_log->resize(start); + *compile_log += tail; + } } + return found_any; } } // namespace detail inline PreprocessedProgram PreprocessedProgram::preprocess( - std::string name, std::string source, StringMap header_sources, - OptionsVec compiler_options, OptionsVec linker_options, - FileCallback header_callback) { + std::string program_name, std::string program_source, + StringMap header_sources, OptionsVec compiler_options, + OptionsVec linker_options, HeaderCallback header_callback) { // Add pre-include built-in JIT-safe headers. bool use_system_headers_war = !compiler_options.pop( {"-no-system-headers-workaround", "--no-system-headers-workaround"}); @@ -6141,7 +7282,8 @@ inline PreprocessedProgram PreprocessedProgram::preprocess( detail::get_jitsafe_headers_map().at("jitify_preinclude.h")); compiler_options.push_back(Option("-include", "jitify_preinclude.h")); } - detail::add_std_flag_if_not_specified(&compiler_options, 11); + const int cxx_standard_year = + detail::add_std_flag_if_not_specified(&compiler_options, 11); detail::add_default_device_flag_if_not_specified(&compiler_options); bool minify = compiler_options.pop({"-m", "--minify"}); // TODO: This flag is experimental, because the implementation does not @@ -6160,37 +7302,148 @@ inline PreprocessedProgram PreprocessedProgram::preprocess( bool should_remove_unused_globals = compiler_options.pop( {"-remove-unused-globals", "--remove-unused-globals"}); - // Patch all given sources. - source = detail::patch_cuda_source(source, use_cuda_std, replace_pragma_once); - for (auto& name_source : header_sources) { - const std::string& header_name = name_source.first; - std::string& header_source = name_source.second; - bool is_jitify_preinclude = header_name == "jitify_preinclude.h"; - bool is_cuda_std_header = - detail::get_workaround_system_headers().count(header_name); - header_source = detail::patch_cuda_source( - header_source, - use_cuda_std && !is_jitify_preinclude && !is_cuda_std_header, - replace_pragma_once); - } + using parser::IncludeName; + using parser::ProcessFlags; + std::unordered_map + include_to_fullpath; + std::unordered_map fullpath_to_source; + std::queue include_queue; + ProcessFlags process_flags = ProcessFlags::kNone; + if (replace_pragma_once) process_flags |= ProcessFlags::kReplacePragmaOnce; + if (minify) process_flags |= ProcessFlags::kMinify; + const ProcessFlags replace_std_flag_if_enabled = + use_cuda_std ? ProcessFlags::kReplaceStd : ProcessFlags::kNone; + + auto process_cuda_source_fn = + [&](std::string* source_ptr, const std::string& fullpath, + ProcessFlags extra_flags = ProcessFlags::kNone) { + return parser::process_cuda_source( + source_ptr->c_str(), fullpath, process_flags | extra_flags, + cxx_standard_year, source_ptr, [&](IncludeName include) { + if (include_to_fullpath.count(include)) { + return; + } + include_queue.push(std::move(include)); + }); + }; - if (minify) { - source = detail::minify_cpp_source(source); - for (auto& name_source : header_sources) { - std::string* header_source = &name_source.second; - *header_source = detail::minify_cpp_source(*header_source); - } + const std::string current_dir = + detail::path_base(detail::get_current_executable_path()); + const std::string program_fullpath = + detail::path_join(current_dir, detail::sanitize_slashes(program_name)); + ErrorMsg err = process_cuda_source_fn(&program_source, program_fullpath, + replace_std_flag_if_enabled); + if (err) return Error(err); + static const char* const early_stop_code = R"( +#ifdef JITIFY_PREPROCESS_ONLY +#include <__JITIFY_STOP_COMPILATION> +#endif +)"; + program_source += early_stop_code; + + // Put the given header_sources into the include_to_fullpath and + // fullpath_to_source maps. + for (auto& header_source : header_sources) { + const std::string& name = header_source.first; + std::string* source_ptr = &header_source.second; + std::string fullpath = detail::path_is_absolute(name) + ? name + : detail::path_join(current_dir, name); + fullpath = detail::path_simplify(fullpath); + err = process_cuda_source_fn( + source_ptr, fullpath, + replace_std_flag_if_enabled | ProcessFlags::kAddUsedHeaderWarning); + if (err) return Error(err); + // Note: The names (keys) in header_sources will be matched: + // a) directly, for `#include ` directives, and + // b) as if they are filenames (relative to the current exe dir if not + // absolute), for `#include "name"` directives. This will NOT fall back + // to direct matching like <> includes. + // This allows path-based matching. + fullpath_to_source.emplace(fullpath, detail::StringOrRef(source_ptr)); + // This allows direct matching for <> includes. + include_to_fullpath.emplace(IncludeName(name), std::move(fullpath)); } - // Temporarily add the program source to header_sources for easier processing. - header_sources.emplace(name, source); - StringVec include_paths; detail::extract_include_paths(&compiler_options, &include_paths); - std::string include_paths_msg = - detail::string_join(include_paths, "\n", "Include paths:\n", "\n"); + + // Recursively load and process all includes, putting them into the + // include_to_fullpath and fullpath_to_source maps. + std::string header_log; + while (!include_queue.empty()) { + const IncludeName include_name = std::move(include_queue.front()); + include_queue.pop(); + std::string header_fullpath; + using detail::HeaderLoadStatus; + const HeaderLoadStatus status = detail::load_header( + include_name, header_callback, include_paths, use_builtin_headers, + &header_fullpath, &fullpath_to_source); + // Note: We ignore missing headers here because they may not be needed; if + // they are needed, the error will be caught when we invoke the compiler. + if (status == HeaderLoadStatus::kFailed) continue; + header_log += detail::string_concat("Found #include ", include_name, + " from ", include_name.location(), + " at:\n ", header_fullpath, "\n"); + if (status == HeaderLoadStatus::kNewlyLoaded) { + std::string& header_source = fullpath_to_source.at(header_fullpath); + if (detail::endswith(header_fullpath, "cub/util_device.cuh")) { + // WAR for CUB header that is full of host-only code. + header_source = ""; + } else { + ProcessFlags extra_flags = ProcessFlags::kAddUsedHeaderWarning; + const bool is_jitify_preinclude = + include_name.name() == "jitify_preinclude.h"; + const bool is_builtin_header = + header_fullpath.find(detail::kJitifyBuiltinHeaderPrefix) == 0; + const bool is_cuda_std_header = + // TODO: More robust way to detect this? + header_fullpath.find("cuda/std/") != std::string::npos || + header_fullpath.find("cuda\\std\\") != std::string::npos; + if (!is_jitify_preinclude && !is_builtin_header && + !is_cuda_std_header) { + extra_flags |= replace_std_flag_if_enabled; + } + err = process_cuda_source_fn(&header_source, header_fullpath, + extra_flags); + if (!err.empty()) return Error(err); + } + } + include_to_fullpath.emplace(include_name, header_fullpath); + } + + // Put all includes from the maps into header_sources. + for (const auto& include_fullpath : include_to_fullpath) { + const IncludeName include_name = include_fullpath.first; + const std::string& fullpath = include_fullpath.second; + assert(fullpath_to_source.count(fullpath)); + detail::StringOrRef* source_ptr = &fullpath_to_source.at(fullpath); + // Note: This will not replace existing headers that were passed in, giving + // them the priority. This also makes our use of StringOrRef safe, because + // the ones that are references are the ones that are already in + // header_sources. + // Note: We insert an empty string first and then assign to it. + auto iter_inserted = + header_sources.emplace(include_name.patched_name(), std::string()); + auto iter = iter_inserted.first; + std::string* out_source_ptr = &iter->second; + const bool inserted = iter_inserted.second; + if (inserted) { + // This is a cheap string move the first time this source_ptr is used. + // Subsequent times (i.e., if the same header source is mapped to multiple + // include names), it copies the string. + // TODO: In theory we could use StringOrRef in header_sources too to avoid + // needing copies of the same header sources, and I think it would be safe + // as long as we didn't erase any elements from it, but it's a bit risky, + // and would be exposed in the public interface. + source_ptr->copy_to_and_reference(out_source_ptr); + } + } if (!nvrtc()) return Error(nvrtc().error()); + if (nvrtc().get_version() >= 11060) { + detail::add_no_source_include_flag_if_not_specified(&compiler_options); + } // Parse architecture flags for special handling. If specified here, the arch // must be explicit (no auto-detection), and it will not be passed through to // the compile phase. @@ -6248,96 +7501,45 @@ inline PreprocessedProgram PreprocessedProgram::preprocess( // default arch) when none was specified by the user. arch_flags.insert({0, false}); } + // We temporarily enable warnings so that we can parse the ones we added. + const bool disable_warnings = + compiler_options.pop({"--disable-warnings", "-w"}); // Maps header include names to their full file paths. StringMap header_fullpaths; - std::string compile_log, header_log; - // Repeat preprocessing for each specified architecture. + std::string compile_log; + std::unordered_set used_header_fullpaths; + // Repeat preprocessing for each specified architecture, collecting in + // used_header_fullpaths. for (const ArchFlag& arch_flag : arch_flags) { if (arch_flag.cc) { // Temporarily add this arch flag. compiler_options.push_back(static_cast