diff --git a/include/sirit/sirit.h b/include/sirit/sirit.h index f3925bb..d442191 100644 --- a/include/sirit/sirit.h +++ b/include/sirit/sirit.h @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -52,6 +53,9 @@ public: */ std::vector Assemble() const; + /// Patches deferred phi nodes calling the passed function on each phi argument + void PatchDeferredPhi(const std::function& func); + /// Adds a SPIR-V extension. void AddExtension(std::string extension_name); @@ -87,15 +91,6 @@ public: AddExecutionMode(entry_point, mode, std::span({literals...})); } - /// Generate a new id for forward declarations - [[nodiscard]] Id ForwardDeclarationId(); - - /// Returns the current generator id, useful for self-referencing phi nodes - [[nodiscard]] Id CurrentId() const noexcept; - - /// Assign a new id and return the old one, useful for defining forward declarations - Id ExchangeCurrentId(Id new_current_id); - /** * Adds an existing label to the code * @param label Label to insert into code. @@ -253,6 +248,12 @@ public: */ Id OpPhi(Id result_type, std::span operands); + /** + * The SSA phi function. This instruction will be revisited when patching phi nodes. + * @param operands An immutable span of block pairs + */ + Id DeferredOpPhi(Id result_type, std::span blocks); + /// Declare a structured loop. Id OpLoopMerge(Id merge_block, Id continue_target, spv::LoopControlMask loop_control, std::span literals = {}); @@ -1236,6 +1237,7 @@ private: std::unique_ptr declarations; std::unique_ptr global_variables; std::unique_ptr code; + std::vector deferred_phi_nodes; }; } // namespace Sirit diff --git a/src/instructions/flow.cpp b/src/instructions/flow.cpp index e13caa9..e462dcb 100644 --- a/src/instructions/flow.cpp +++ b/src/instructions/flow.cpp @@ -18,6 +18,16 @@ Id Module::OpPhi(Id result_type, std::span operands) { return *code << OpId{spv::Op::OpPhi, result_type} << operands << EndOp{}; } +Id Module::DeferredOpPhi(Id result_type, std::span blocks) { + deferred_phi_nodes.push_back(code->LocalAddress()); + code->Reserve(3 + blocks.size() * 2); + *code << OpId{spv::Op::OpPhi, result_type}; + for (const Id block : blocks) { + *code << u32{0} << block; + } + return *code << EndOp{}; +} + Id Module::OpLoopMerge(Id merge_block, Id continue_target, spv::LoopControlMask loop_control, std::span literals) { code->Reserve(4 + literals.size()); diff --git a/src/sirit.cpp b/src/sirit.cpp index 22a4570..7075f23 100644 --- a/src/sirit.cpp +++ b/src/sirit.cpp @@ -68,6 +68,20 @@ std::vector Module::Assemble() const { return words; } +void Module::PatchDeferredPhi(const std::function& func) { + for (const u32 phi_index : deferred_phi_nodes) { + const u32 first_word = code->Value(phi_index); + [[maybe_unused]] const spv::Op op = static_cast(first_word & 0xffff); + assert(op == spv::Op::OpPhi); + const u32 num_words = first_word >> 16; + const u32 num_args = (num_words - 3) / 2; + u32 cursor = phi_index + 3; + for (u32 arg = 0; arg < num_args; ++arg, cursor += 2) { + code->SetValue(cursor, func(arg).value); + } + } +} + void Module::AddExtension(std::string extension_name) { extensions.insert(std::move(extension_name)); } @@ -95,19 +109,6 @@ void Module::AddExecutionMode(Id entry_point, spv::ExecutionMode mode, *execution_modes << spv::Op::OpExecutionMode << entry_point << mode << literals << EndOp{}; } -Id Module::ForwardDeclarationId() { - return Id{++bound}; -} - -Id Module::CurrentId() const noexcept { - return Id{bound + 1}; -} - -Id Module::ExchangeCurrentId(Id new_current_id) { - const std::uint32_t old_id = std::exchange(bound, new_current_id.value - 1); - return Id{old_id + 1}; -} - Id Module::AddLabel(Id label) { assert(label.value != 0); code->Reserve(2); diff --git a/src/stream.h b/src/stream.h index ef7f2d2..ee3e7b8 100644 --- a/src/stream.h +++ b/src/stream.h @@ -40,7 +40,7 @@ struct OpId { struct EndOp {}; -constexpr size_t WordsInString(std::string_view string) { +inline size_t WordsInString(std::string_view string) { return string.size() / sizeof(u32) + 1; } @@ -76,6 +76,18 @@ public: return std::span(words.data(), insert_index); } + u32 LocalAddress() const noexcept { + return static_cast(words.size()); + } + + u32 Value(u32 index) const noexcept { + return words[index]; + } + + void SetValue(u32 index, u32 value) noexcept { + words[index] = value; + } + Stream& operator<<(spv::Op op) { op_index = insert_index; words[insert_index++] = static_cast(op);