diff --git a/src/ir_opt/constant_propagation_pass.cpp b/src/ir_opt/constant_propagation_pass.cpp index 17c59fa4..cf7592cc 100644 --- a/src/ir_opt/constant_propagation_pass.cpp +++ b/src/ir_opt/constant_propagation_pass.cpp @@ -8,6 +8,7 @@ #include "common/assert.h" #include "common/bit_util.h" +#include "common/safe_ops.h" #include "common/common_types.h" #include "frontend/ir/basic_block.h" #include "frontend/ir/ir_emitter.h" @@ -254,7 +255,7 @@ void FoldOR(IR::Inst& inst, bool is_32_bit) { } } -void FoldShifts(IR::Inst& inst) { +bool FoldShifts(IR::Inst& inst) { IR::Inst* carry_inst = inst.GetAssociatedPseudoOperation(Op::GetCarryFromOp); // The 32-bit variants can contain 3 arguments, while the @@ -264,14 +265,19 @@ void FoldShifts(IR::Inst& inst) { } const auto shift_amount = inst.GetArg(1); - if (!shift_amount.IsZero()) { - return; + if (shift_amount.IsZero()) { + if (carry_inst) { + carry_inst->ReplaceUsesWith(inst.GetArg(2)); + } + inst.ReplaceUsesWith(inst.GetArg(0)); + return false; } - if (carry_inst) { - carry_inst->ReplaceUsesWith(inst.GetArg(2)); + if (!inst.AreAllArgsImmediates() || carry_inst) { + return false; } - inst.ReplaceUsesWith(inst.GetArg(0)); + + return true; } void FoldSignExtendXToWord(IR::Inst& inst) { @@ -332,14 +338,84 @@ void ConstantPropagation(IR::Block& block) { FoldMostSignificantBit(inst); break; case Op::LogicalShiftLeft32: + if (FoldShifts(inst)) { + ReplaceUsesWith(inst, true, Safe::LogicalShiftLeft(inst.GetArg(0).GetU32(), inst.GetArg(1).GetU8())); + } + break; case Op::LogicalShiftLeft64: + if (FoldShifts(inst)) { + ReplaceUsesWith(inst, false, Safe::LogicalShiftLeft(inst.GetArg(0).GetU64(), inst.GetArg(1).GetU8())); + } + break; case Op::LogicalShiftRight32: + if (FoldShifts(inst)) { + ReplaceUsesWith(inst, true, Safe::LogicalShiftRight(inst.GetArg(0).GetU32(), inst.GetArg(1).GetU8())); + } + break; case Op::LogicalShiftRight64: + if (FoldShifts(inst)) { + ReplaceUsesWith(inst, false, Safe::LogicalShiftRight(inst.GetArg(0).GetU64(), inst.GetArg(1).GetU8())); + } + break; case Op::ArithmeticShiftRight32: + if (FoldShifts(inst)) { + ReplaceUsesWith(inst, true, Safe::ArithmeticShiftRight(inst.GetArg(0).GetU32(), inst.GetArg(1).GetU8())); + } + break; case Op::ArithmeticShiftRight64: + if (FoldShifts(inst)) { + ReplaceUsesWith(inst, false, Safe::ArithmeticShiftRight(inst.GetArg(0).GetU64(), inst.GetArg(1).GetU8())); + } + break; case Op::RotateRight32: + if (FoldShifts(inst)) { + ReplaceUsesWith(inst, true, Common::RotateRight(inst.GetArg(0).GetU32(), inst.GetArg(1).GetU8())); + } + break; case Op::RotateRight64: - FoldShifts(inst); + if (FoldShifts(inst)) { + ReplaceUsesWith(inst, false, Common::RotateRight(inst.GetArg(0).GetU64(), inst.GetArg(1).GetU8())); + } + break; + case Op::LogicalShiftLeftMasked32: + if (inst.AreAllArgsImmediates()) { + ReplaceUsesWith(inst, true, inst.GetArg(0).GetU32() << (inst.GetArg(1).GetU32() & 0x1f)); + } + break; + case Op::LogicalShiftLeftMasked64: + if (inst.AreAllArgsImmediates()) { + ReplaceUsesWith(inst, false, inst.GetArg(0).GetU64() << (inst.GetArg(1).GetU64() & 0x3f)); + } + break; + case Op::LogicalShiftRightMasked32: + if (inst.AreAllArgsImmediates()) { + ReplaceUsesWith(inst, true, inst.GetArg(0).GetU32() >> (inst.GetArg(1).GetU32() & 0x1f)); + } + break; + case Op::LogicalShiftRightMasked64: + if (inst.AreAllArgsImmediates()) { + ReplaceUsesWith(inst, false, inst.GetArg(0).GetU64() >> (inst.GetArg(1).GetU64() & 0x3f)); + } + break; + case Op::ArithmeticShiftRightMasked32: + if (inst.AreAllArgsImmediates()) { + ReplaceUsesWith(inst, true, static_cast(inst.GetArg(0).GetU32()) >> (inst.GetArg(1).GetU32() & 0x1f)); + } + break; + case Op::ArithmeticShiftRightMasked64: + if (inst.AreAllArgsImmediates()) { + ReplaceUsesWith(inst, false, static_cast(inst.GetArg(0).GetU64()) >> (inst.GetArg(1).GetU64() & 0x3f)); + } + break; + case Op::RotateRightMasked32: + if (inst.AreAllArgsImmediates()) { + ReplaceUsesWith(inst, true, Common::RotateRight(inst.GetArg(0).GetU32(), inst.GetArg(1).GetU32())); + } + break; + case Op::RotateRightMasked64: + if (inst.AreAllArgsImmediates()) { + ReplaceUsesWith(inst, false, Common::RotateRight(inst.GetArg(0).GetU64(), inst.GetArg(1).GetU64())); + } break; case Op::Mul32: case Op::Mul64: