From 51d541f1a1fae66ba5dbd7816e283c8a723690db Mon Sep 17 00:00:00 2001 From: ReinUsesLisp Date: Sun, 11 Apr 2021 02:02:40 -0300 Subject: [PATCH] Remove forward references and add phi node patching The previous API for forward declarations broke when more than one definition was done. Forward references on instructions that are not labels were only needed for phi nodes, so it has been replaced with a deferred phi node instruction and a method to patch these after everything has been defined. --- include/sirit/sirit.h | 20 +++++++++++--------- src/instructions/flow.cpp | 10 ++++++++++ src/sirit.cpp | 27 ++++++++++++++------------- src/stream.h | 14 +++++++++++++- 4 files changed, 48 insertions(+), 23 deletions(-) diff --git a/include/sirit/sirit.h b/include/sirit/sirit.h index f3925bb..d442191 100644 --- a/include/sirit/sirit.h +++ b/include/sirit/sirit.h @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -52,6 +53,9 @@ public: */ std::vector Assemble() const; + /// Patches deferred phi nodes calling the passed function on each phi argument + void PatchDeferredPhi(const std::function& func); + /// Adds a SPIR-V extension. void AddExtension(std::string extension_name); @@ -87,15 +91,6 @@ public: AddExecutionMode(entry_point, mode, std::span({literals...})); } - /// Generate a new id for forward declarations - [[nodiscard]] Id ForwardDeclarationId(); - - /// Returns the current generator id, useful for self-referencing phi nodes - [[nodiscard]] Id CurrentId() const noexcept; - - /// Assign a new id and return the old one, useful for defining forward declarations - Id ExchangeCurrentId(Id new_current_id); - /** * Adds an existing label to the code * @param label Label to insert into code. @@ -253,6 +248,12 @@ public: */ Id OpPhi(Id result_type, std::span operands); + /** + * The SSA phi function. This instruction will be revisited when patching phi nodes. + * @param operands An immutable span of block pairs + */ + Id DeferredOpPhi(Id result_type, std::span blocks); + /// Declare a structured loop. Id OpLoopMerge(Id merge_block, Id continue_target, spv::LoopControlMask loop_control, std::span literals = {}); @@ -1236,6 +1237,7 @@ private: std::unique_ptr declarations; std::unique_ptr global_variables; std::unique_ptr code; + std::vector deferred_phi_nodes; }; } // namespace Sirit diff --git a/src/instructions/flow.cpp b/src/instructions/flow.cpp index e13caa9..e462dcb 100644 --- a/src/instructions/flow.cpp +++ b/src/instructions/flow.cpp @@ -18,6 +18,16 @@ Id Module::OpPhi(Id result_type, std::span operands) { return *code << OpId{spv::Op::OpPhi, result_type} << operands << EndOp{}; } +Id Module::DeferredOpPhi(Id result_type, std::span blocks) { + deferred_phi_nodes.push_back(code->LocalAddress()); + code->Reserve(3 + blocks.size() * 2); + *code << OpId{spv::Op::OpPhi, result_type}; + for (const Id block : blocks) { + *code << u32{0} << block; + } + return *code << EndOp{}; +} + Id Module::OpLoopMerge(Id merge_block, Id continue_target, spv::LoopControlMask loop_control, std::span literals) { code->Reserve(4 + literals.size()); diff --git a/src/sirit.cpp b/src/sirit.cpp index 22a4570..7075f23 100644 --- a/src/sirit.cpp +++ b/src/sirit.cpp @@ -68,6 +68,20 @@ std::vector Module::Assemble() const { return words; } +void Module::PatchDeferredPhi(const std::function& func) { + for (const u32 phi_index : deferred_phi_nodes) { + const u32 first_word = code->Value(phi_index); + [[maybe_unused]] const spv::Op op = static_cast(first_word & 0xffff); + assert(op == spv::Op::OpPhi); + const u32 num_words = first_word >> 16; + const u32 num_args = (num_words - 3) / 2; + u32 cursor = phi_index + 3; + for (u32 arg = 0; arg < num_args; ++arg, cursor += 2) { + code->SetValue(cursor, func(arg).value); + } + } +} + void Module::AddExtension(std::string extension_name) { extensions.insert(std::move(extension_name)); } @@ -95,19 +109,6 @@ void Module::AddExecutionMode(Id entry_point, spv::ExecutionMode mode, *execution_modes << spv::Op::OpExecutionMode << entry_point << mode << literals << EndOp{}; } -Id Module::ForwardDeclarationId() { - return Id{++bound}; -} - -Id Module::CurrentId() const noexcept { - return Id{bound + 1}; -} - -Id Module::ExchangeCurrentId(Id new_current_id) { - const std::uint32_t old_id = std::exchange(bound, new_current_id.value - 1); - return Id{old_id + 1}; -} - Id Module::AddLabel(Id label) { assert(label.value != 0); code->Reserve(2); diff --git a/src/stream.h b/src/stream.h index ef7f2d2..ee3e7b8 100644 --- a/src/stream.h +++ b/src/stream.h @@ -40,7 +40,7 @@ struct OpId { struct EndOp {}; -constexpr size_t WordsInString(std::string_view string) { +inline size_t WordsInString(std::string_view string) { return string.size() / sizeof(u32) + 1; } @@ -76,6 +76,18 @@ public: return std::span(words.data(), insert_index); } + u32 LocalAddress() const noexcept { + return static_cast(words.size()); + } + + u32 Value(u32 index) const noexcept { + return words[index]; + } + + void SetValue(u32 index, u32 value) noexcept { + words[index] = value; + } + Stream& operator<<(spv::Op op) { op_index = insert_index; words[insert_index++] = static_cast(op);