constant_propagation_pass: Fold shifts

This commit is contained in:
MerryMage 2020-04-21 23:36:55 +01:00
parent 7242388577
commit df1a0eecaf

View file

@ -8,6 +8,7 @@
#include "common/assert.h" #include "common/assert.h"
#include "common/bit_util.h" #include "common/bit_util.h"
#include "common/safe_ops.h"
#include "common/common_types.h" #include "common/common_types.h"
#include "frontend/ir/basic_block.h" #include "frontend/ir/basic_block.h"
#include "frontend/ir/ir_emitter.h" #include "frontend/ir/ir_emitter.h"
@ -254,7 +255,7 @@ void FoldOR(IR::Inst& inst, bool is_32_bit) {
} }
} }
void FoldShifts(IR::Inst& inst) { bool FoldShifts(IR::Inst& inst) {
IR::Inst* carry_inst = inst.GetAssociatedPseudoOperation(Op::GetCarryFromOp); IR::Inst* carry_inst = inst.GetAssociatedPseudoOperation(Op::GetCarryFromOp);
// The 32-bit variants can contain 3 arguments, while the // The 32-bit variants can contain 3 arguments, while the
@ -264,14 +265,19 @@ void FoldShifts(IR::Inst& inst) {
} }
const auto shift_amount = inst.GetArg(1); const auto shift_amount = inst.GetArg(1);
if (!shift_amount.IsZero()) { if (shift_amount.IsZero()) {
return;
}
if (carry_inst) { if (carry_inst) {
carry_inst->ReplaceUsesWith(inst.GetArg(2)); carry_inst->ReplaceUsesWith(inst.GetArg(2));
} }
inst.ReplaceUsesWith(inst.GetArg(0)); inst.ReplaceUsesWith(inst.GetArg(0));
return false;
}
if (!inst.AreAllArgsImmediates() || carry_inst) {
return false;
}
return true;
} }
void FoldSignExtendXToWord(IR::Inst& inst) { void FoldSignExtendXToWord(IR::Inst& inst) {
@ -332,14 +338,84 @@ void ConstantPropagation(IR::Block& block) {
FoldMostSignificantBit(inst); FoldMostSignificantBit(inst);
break; break;
case Op::LogicalShiftLeft32: case Op::LogicalShiftLeft32:
if (FoldShifts(inst)) {
ReplaceUsesWith(inst, true, Safe::LogicalShiftLeft<u32>(inst.GetArg(0).GetU32(), inst.GetArg(1).GetU8()));
}
break;
case Op::LogicalShiftLeft64: case Op::LogicalShiftLeft64:
if (FoldShifts(inst)) {
ReplaceUsesWith(inst, false, Safe::LogicalShiftLeft<u64>(inst.GetArg(0).GetU64(), inst.GetArg(1).GetU8()));
}
break;
case Op::LogicalShiftRight32: case Op::LogicalShiftRight32:
if (FoldShifts(inst)) {
ReplaceUsesWith(inst, true, Safe::LogicalShiftRight<u32>(inst.GetArg(0).GetU32(), inst.GetArg(1).GetU8()));
}
break;
case Op::LogicalShiftRight64: case Op::LogicalShiftRight64:
if (FoldShifts(inst)) {
ReplaceUsesWith(inst, false, Safe::LogicalShiftRight<u64>(inst.GetArg(0).GetU64(), inst.GetArg(1).GetU8()));
}
break;
case Op::ArithmeticShiftRight32: case Op::ArithmeticShiftRight32:
if (FoldShifts(inst)) {
ReplaceUsesWith(inst, true, Safe::ArithmeticShiftRight<u32>(inst.GetArg(0).GetU32(), inst.GetArg(1).GetU8()));
}
break;
case Op::ArithmeticShiftRight64: case Op::ArithmeticShiftRight64:
if (FoldShifts(inst)) {
ReplaceUsesWith(inst, false, Safe::ArithmeticShiftRight<u64>(inst.GetArg(0).GetU64(), inst.GetArg(1).GetU8()));
}
break;
case Op::RotateRight32: case Op::RotateRight32:
if (FoldShifts(inst)) {
ReplaceUsesWith(inst, true, Common::RotateRight<u32>(inst.GetArg(0).GetU32(), inst.GetArg(1).GetU8()));
}
break;
case Op::RotateRight64: case Op::RotateRight64:
FoldShifts(inst); if (FoldShifts(inst)) {
ReplaceUsesWith(inst, false, Common::RotateRight<u64>(inst.GetArg(0).GetU64(), inst.GetArg(1).GetU8()));
}
break;
case Op::LogicalShiftLeftMasked32:
if (inst.AreAllArgsImmediates()) {
ReplaceUsesWith(inst, true, inst.GetArg(0).GetU32() << (inst.GetArg(1).GetU32() & 0x1f));
}
break;
case Op::LogicalShiftLeftMasked64:
if (inst.AreAllArgsImmediates()) {
ReplaceUsesWith(inst, false, inst.GetArg(0).GetU64() << (inst.GetArg(1).GetU64() & 0x3f));
}
break;
case Op::LogicalShiftRightMasked32:
if (inst.AreAllArgsImmediates()) {
ReplaceUsesWith(inst, true, inst.GetArg(0).GetU32() >> (inst.GetArg(1).GetU32() & 0x1f));
}
break;
case Op::LogicalShiftRightMasked64:
if (inst.AreAllArgsImmediates()) {
ReplaceUsesWith(inst, false, inst.GetArg(0).GetU64() >> (inst.GetArg(1).GetU64() & 0x3f));
}
break;
case Op::ArithmeticShiftRightMasked32:
if (inst.AreAllArgsImmediates()) {
ReplaceUsesWith(inst, true, static_cast<s32>(inst.GetArg(0).GetU32()) >> (inst.GetArg(1).GetU32() & 0x1f));
}
break;
case Op::ArithmeticShiftRightMasked64:
if (inst.AreAllArgsImmediates()) {
ReplaceUsesWith(inst, false, static_cast<s64>(inst.GetArg(0).GetU64()) >> (inst.GetArg(1).GetU64() & 0x3f));
}
break;
case Op::RotateRightMasked32:
if (inst.AreAllArgsImmediates()) {
ReplaceUsesWith(inst, true, Common::RotateRight<u32>(inst.GetArg(0).GetU32(), inst.GetArg(1).GetU32()));
}
break;
case Op::RotateRightMasked64:
if (inst.AreAllArgsImmediates()) {
ReplaceUsesWith(inst, false, Common::RotateRight<u64>(inst.GetArg(0).GetU64(), inst.GetArg(1).GetU64()));
}
break; break;
case Op::Mul32: case Op::Mul32:
case Op::Mul64: case Op::Mul64: