From 4573511fe3269cdcec0fa905d4290a4ea2eae232 Mon Sep 17 00:00:00 2001 From: MerryMage Date: Mon, 20 Apr 2020 20:05:32 +0100 Subject: [PATCH] constant_propagation_pass: Prepare for IR matchers --- src/CMakeLists.txt | 1 + src/frontend/ir/ir_emitter.cpp | 8 -- src/frontend/ir/ir_emitter.h | 9 +- src/frontend/ir/value.h | 2 + src/ir_opt/constant_propagation_pass.cpp | 128 ++++++++++++----------- src/ir_opt/identity_removal_pass.cpp | 2 +- src/ir_opt/ir_matcher.h | 126 ++++++++++++++++++++++ 7 files changed, 203 insertions(+), 73 deletions(-) create mode 100644 src/ir_opt/ir_matcher.h diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 68dfb6cd..c39f2480 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -96,6 +96,7 @@ add_library(dynarmic ir_opt/constant_propagation_pass.cpp ir_opt/dead_code_elimination_pass.cpp ir_opt/identity_removal_pass.cpp + ir_opt/ir_matcher.h ir_opt/passes.h ir_opt/verification_pass.cpp ) diff --git a/src/frontend/ir/ir_emitter.cpp b/src/frontend/ir/ir_emitter.cpp index 70121fa0..e2f0c25d 100644 --- a/src/frontend/ir/ir_emitter.cpp +++ b/src/frontend/ir/ir_emitter.cpp @@ -2607,12 +2607,4 @@ void IREmitter::SetTerm(const Terminal& terminal) { block.SetTerminal(terminal); } -void IREmitter::SetInsertionPoint(IR::Inst* new_insertion_point) { - insertion_point = IR::Block::iterator{*new_insertion_point}; -} - -void IREmitter::SetInsertionPoint(IR::Block::iterator new_insertion_point) { - insertion_point = new_insertion_point; -} - } // namespace Dynarmic::IR diff --git a/src/frontend/ir/ir_emitter.h b/src/frontend/ir/ir_emitter.h index 02768920..ce5207e7 100644 --- a/src/frontend/ir/ir_emitter.h +++ b/src/frontend/ir/ir_emitter.h @@ -369,8 +369,13 @@ public: void SetTerm(const Terminal& terminal); - void SetInsertionPoint(IR::Inst* new_insertion_point); - void SetInsertionPoint(IR::Block::iterator new_insertion_point); + void SetInsertionPoint(IR::Inst* new_insertion_point) { + insertion_point = IR::Block::iterator{*new_insertion_point}; + } + + void SetInsertionPoint(IR::Block::iterator new_insertion_point) { + insertion_point = new_insertion_point; + } protected: IR::Block::iterator insertion_point; diff --git a/src/frontend/ir/value.h b/src/frontend/ir/value.h index 906428e4..78e6352d 100644 --- a/src/frontend/ir/value.h +++ b/src/frontend/ir/value.h @@ -156,6 +156,8 @@ public: explicit TypedValue(const Value& value) : Value(value) { ASSERT((value.GetType() & type_) != Type::Void); } + + explicit TypedValue(Inst* inst) : TypedValue(Value(inst)) {} }; using U1 = TypedValue; diff --git a/src/ir_opt/constant_propagation_pass.cpp b/src/ir_opt/constant_propagation_pass.cpp index 8a68a06a..9203171a 100644 --- a/src/ir_opt/constant_propagation_pass.cpp +++ b/src/ir_opt/constant_propagation_pass.cpp @@ -4,13 +4,22 @@ * General Public License version 2 or any later version. */ +#include + +#include "common/assert.h" #include "common/bit_util.h" #include "common/common_types.h" #include "frontend/ir/basic_block.h" +#include "frontend/ir/ir_emitter.h" #include "frontend/ir/opcodes.h" +#include "ir_opt/ir_matcher.h" #include "ir_opt/passes.h" namespace Dynarmic::Optimization { + +using namespace IRMatcher; +using Op = Dynarmic::IR::Opcode; + namespace { // Tiny helper to avoid the need to store based off the opcode @@ -89,17 +98,17 @@ void FoldAND(IR::Inst& inst, bool is_32_bit) { // // 1. imm -> swap(imm) // -void FoldByteReverse(IR::Inst& inst, IR::Opcode op) { +void FoldByteReverse(IR::Inst& inst, Op op) { const auto operand = inst.GetArg(0); if (!operand.IsImmediate()) { return; } - if (op == IR::Opcode::ByteReverseWord) { + if (op == Op::ByteReverseWord) { const u32 result = Common::Swap32(static_cast(operand.GetImmediateAsU64())); inst.ReplaceUsesWith(IR::Value{result}); - } else if (op == IR::Opcode::ByteReverseHalf) { + } else if (op == Op::ByteReverseHalf) { const u16 result = Common::Swap16(static_cast(operand.GetImmediateAsU64())); inst.ReplaceUsesWith(IR::Value{result}); } else { @@ -188,7 +197,7 @@ void FoldMostSignificantBit(IR::Inst& inst) { } void FoldMostSignificantWord(IR::Inst& inst) { - IR::Inst* carry_inst = inst.GetAssociatedPseudoOperation(IR::Opcode::GetCarryFromOp); + IR::Inst* carry_inst = inst.GetAssociatedPseudoOperation(Op::GetCarryFromOp); if (!inst.AreAllArgsImmediates()) { return; @@ -239,21 +248,16 @@ void FoldNOT(IR::Inst& inst, bool is_32_bit) { // 3. 0 | y -> y // void FoldOR(IR::Inst& inst, bool is_32_bit) { - const auto lhs = inst.GetArg(0); - const auto rhs = inst.GetArg(1); - - if (lhs.IsImmediate() && rhs.IsImmediate()) { - const u64 result = lhs.GetImmediateAsU64() | rhs.GetImmediateAsU64(); - ReplaceUsesWith(inst, is_32_bit, result); - } else if (lhs.IsZero()) { - inst.ReplaceUsesWith(rhs); - } else if (rhs.IsZero()) { - inst.ReplaceUsesWith(lhs); + if (FoldCommutative(inst, is_32_bit, [](u64 a, u64 b) { return a | b; })) { + const auto rhs = inst.GetArg(1); + if (rhs.IsZero()) { + inst.ReplaceUsesWith(inst.GetArg(0)); + } } } void FoldShifts(IR::Inst& inst) { - IR::Inst* carry_inst = inst.GetAssociatedPseudoOperation(IR::Opcode::GetCarryFromOp); + IR::Inst* carry_inst = inst.GetAssociatedPseudoOperation(Op::GetCarryFromOp); // The 32-bit variants can contain 3 arguments, while the // 64-bit variants only contain 2. @@ -314,80 +318,80 @@ void ConstantPropagation(IR::Block& block) { const auto opcode = inst.GetOpcode(); switch (opcode) { - case IR::Opcode::LeastSignificantWord: + case Op::LeastSignificantWord: FoldLeastSignificantWord(inst); break; - case IR::Opcode::MostSignificantWord: + case Op::MostSignificantWord: FoldMostSignificantWord(inst); break; - case IR::Opcode::LeastSignificantHalf: + case Op::LeastSignificantHalf: FoldLeastSignificantHalf(inst); break; - case IR::Opcode::LeastSignificantByte: + case Op::LeastSignificantByte: FoldLeastSignificantByte(inst); break; - case IR::Opcode::MostSignificantBit: + case Op::MostSignificantBit: FoldMostSignificantBit(inst); break; - case IR::Opcode::LogicalShiftLeft32: - case IR::Opcode::LogicalShiftLeft64: - case IR::Opcode::LogicalShiftRight32: - case IR::Opcode::LogicalShiftRight64: - case IR::Opcode::ArithmeticShiftRight32: - case IR::Opcode::ArithmeticShiftRight64: - case IR::Opcode::RotateRight32: - case IR::Opcode::RotateRight64: + case Op::LogicalShiftLeft32: + case Op::LogicalShiftLeft64: + case Op::LogicalShiftRight32: + case Op::LogicalShiftRight64: + case Op::ArithmeticShiftRight32: + case Op::ArithmeticShiftRight64: + case Op::RotateRight32: + case Op::RotateRight64: FoldShifts(inst); break; - case IR::Opcode::Mul32: - case IR::Opcode::Mul64: - FoldMultiply(inst, opcode == IR::Opcode::Mul32); + case Op::Mul32: + case Op::Mul64: + FoldMultiply(inst, opcode == Op::Mul32); break; - case IR::Opcode::SignedDiv32: - case IR::Opcode::SignedDiv64: - FoldDivide(inst, opcode == IR::Opcode::SignedDiv32, true); + case Op::SignedDiv32: + case Op::SignedDiv64: + FoldDivide(inst, opcode == Op::SignedDiv32, true); break; - case IR::Opcode::UnsignedDiv32: - case IR::Opcode::UnsignedDiv64: - FoldDivide(inst, opcode == IR::Opcode::UnsignedDiv32, false); + case Op::UnsignedDiv32: + case Op::UnsignedDiv64: + FoldDivide(inst, opcode == Op::UnsignedDiv32, false); break; - case IR::Opcode::And32: - case IR::Opcode::And64: - FoldAND(inst, opcode == IR::Opcode::And32); + case Op::And32: + case Op::And64: + FoldAND(inst, opcode == Op::And32); break; - case IR::Opcode::Eor32: - case IR::Opcode::Eor64: - FoldEOR(inst, opcode == IR::Opcode::Eor32); + case Op::Eor32: + case Op::Eor64: + FoldEOR(inst, opcode == Op::Eor32); break; - case IR::Opcode::Or32: - case IR::Opcode::Or64: - FoldOR(inst, opcode == IR::Opcode::Or32); + case Op::Or32: + case Op::Or64: + FoldOR(inst, opcode == Op::Or32); break; - case IR::Opcode::Not32: - case IR::Opcode::Not64: - FoldNOT(inst, opcode == IR::Opcode::Not32); + case Op::Not32: + case Op::Not64: + FoldNOT(inst, opcode == Op::Not32); break; - case IR::Opcode::SignExtendByteToWord: - case IR::Opcode::SignExtendHalfToWord: + case Op::SignExtendByteToWord: + case Op::SignExtendHalfToWord: FoldSignExtendXToWord(inst); break; - case IR::Opcode::SignExtendByteToLong: - case IR::Opcode::SignExtendHalfToLong: - case IR::Opcode::SignExtendWordToLong: + case Op::SignExtendByteToLong: + case Op::SignExtendHalfToLong: + case Op::SignExtendWordToLong: FoldSignExtendXToLong(inst); break; - case IR::Opcode::ZeroExtendByteToWord: - case IR::Opcode::ZeroExtendHalfToWord: + case Op::ZeroExtendByteToWord: + case Op::ZeroExtendHalfToWord: FoldZeroExtendXToWord(inst); break; - case IR::Opcode::ZeroExtendByteToLong: - case IR::Opcode::ZeroExtendHalfToLong: - case IR::Opcode::ZeroExtendWordToLong: + case Op::ZeroExtendByteToLong: + case Op::ZeroExtendHalfToLong: + case Op::ZeroExtendWordToLong: FoldZeroExtendXToLong(inst); break; - case IR::Opcode::ByteReverseWord: - case IR::Opcode::ByteReverseHalf: - case IR::Opcode::ByteReverseDual: + case Op::ByteReverseWord: + case Op::ByteReverseHalf: + case Op::ByteReverseDual: FoldByteReverse(inst, opcode); break; default: diff --git a/src/ir_opt/identity_removal_pass.cpp b/src/ir_opt/identity_removal_pass.cpp index 1d744b48..a52889f0 100644 --- a/src/ir_opt/identity_removal_pass.cpp +++ b/src/ir_opt/identity_removal_pass.cpp @@ -30,7 +30,7 @@ void IdentityRemovalPass(IR::Block& block) { } } - if (inst.GetOpcode() == IR::Opcode::Identity) { + if (inst.GetOpcode() == IR::Opcode::Identity || inst.GetOpcode() == IR::Opcode::Void) { iter = block.Instructions().erase(inst); to_invalidate.push_back(&inst); } else { diff --git a/src/ir_opt/ir_matcher.h b/src/ir_opt/ir_matcher.h new file mode 100644 index 00000000..6069a8e9 --- /dev/null +++ b/src/ir_opt/ir_matcher.h @@ -0,0 +1,126 @@ +/* This file is part of the dynarmic project. + * Copyright (c) 2020 MerryMage + * This software may be used and distributed according to the terms of the GNU + * General Public License version 2 or any later version. + */ + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "frontend/ir/microinstruction.h" +#include "frontend/ir/opcodes.h" +#include "frontend/ir/value.h" + +namespace Dynarmic::Optimization::IRMatcher { + +struct CaptureValue { + using ReturnType = std::tuple; + + static std::optional Match(IR::Value value) { + return std::tuple(value); + } +}; + +struct CaptureInst { + using ReturnType = std::tuple; + + static std::optional Match(IR::Value value) { + if (value.IsImmediate()) + return std::nullopt; + return std::tuple(value.GetInstRecursive()); + } +}; + +struct CaptureUImm { + using ReturnType = std::tuple; + + static std::optional Match(IR::Value value) { + return std::tuple(value.GetImmediateAsU64()); + } +}; + +struct CaptureSImm { + using ReturnType = std::tuple; + + static std::optional Match(IR::Value value) { + return std::tuple(value.GetImmediateAsS64()); + } +}; + +template +struct UImm { + using ReturnType = std::tuple<>; + + static std::optional> Match(IR::Value value) { + if (value.GetImmediateAsU64() == Value) + return std::tuple(); + return std::nullopt; + } +}; + +template +struct SImm { + using ReturnType = std::tuple<>; + + static std::optional> Match(IR::Value value) { + if (value.GetImmediateAsS64() == Value) + return std::tuple(); + return std::nullopt; + } +}; + +template +struct Inst { +public: + using ReturnType = mp::concat, typename Args::ReturnType...>; + + static std::optional Match(const IR::Inst& inst) { + if (inst.GetOpcode() != Opcode) + return std::nullopt; + if (inst.HasAssociatedPseudoOperation()) + return std::nullopt; + return MatchArgs<0>(inst); + } + + static std::optional Match(IR::Value value) { + if (value.IsImmediate()) + return std::nullopt; + return Match(*value.GetInstRecursive()); + } + +private: + template + static auto MatchArgs(const IR::Inst& inst) -> std::optional>, std::tuple<>>>> { + if constexpr (I >= sizeof...(Args)) { + return std::tuple(); + } else { + using Arg = mp::get>; + + if (const auto arg = Arg::Match(inst.GetArg(I))) { + if (const auto rest = MatchArgs(inst)) { + return std::tuple_cat(*arg, *rest); + } + } + + return std::nullopt; + } + } +}; + +inline bool IsSameInst(std::tuple t) { + return std::get<0>(t) == std::get<1>(t); +} + +inline bool IsSameInst(std::tuple t) { + return std::get<0>(t) == std::get<1>(t) && std::get<0>(t) == std::get<2>(t); +} + +} // namespace Dynarmic::Optimization::IRMatcher