diff --git a/src/frontend/A64/translate/impl/simd_shift_by_immediate.cpp b/src/frontend/A64/translate/impl/simd_shift_by_immediate.cpp index 6c81fe56..9f38b5f7 100644 --- a/src/frontend/A64/translate/impl/simd_shift_by_immediate.cpp +++ b/src/frontend/A64/translate/impl/simd_shift_by_immediate.cpp @@ -9,6 +9,30 @@ namespace Dynarmic::A64 { +enum class ShiftExtraBehavior { + None, + Accumulate, + Round +}; + +static void SignedShiftRight(TranslatorVisitor& v, bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd, + ShiftExtraBehavior behavior) { + const size_t esize = 8 << Common::HighestSetBit(immh.ZeroExtend()); + const size_t datasize = Q ? 128 : 64; + + const u8 shift_amount = static_cast(2 * esize) - concatenate(immh, immb).ZeroExtend(); + + const IR::U128 operand = v.V(datasize, Vn); + IR::U128 result = v.ir.VectorArithmeticShiftRight(esize, operand, shift_amount); + + if (behavior == ShiftExtraBehavior::Accumulate) { + const IR::U128 accumulator = v.V(datasize, Vd); + result = v.ir.VectorAdd(esize, result, accumulator); + } + + v.V(datasize, Vd, result); +} + bool TranslatorVisitor::SSHR_2(bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd) { if (immh == 0b0000) { return DecodeError(); @@ -16,15 +40,8 @@ bool TranslatorVisitor::SSHR_2(bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd) if (immh.Bit<3>() && !Q) { return ReservedValue(); } - const size_t esize = 8 << Common::HighestSetBit(immh.ZeroExtend()); - const size_t datasize = Q ? 128 : 64; - const u8 shift_amount = static_cast(2 * esize) - concatenate(immh, immb).ZeroExtend(); - - const IR::U128 operand = V(datasize, Vn); - const IR::U128 result = ir.VectorArithmeticShiftRight(esize, operand, shift_amount); - - V(datasize, Vd, result); + SignedShiftRight(*this, Q, immh, immb, Vn, Vd, ShiftExtraBehavior::None); return true; } @@ -35,17 +52,8 @@ bool TranslatorVisitor::SSRA_2(bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd) if (immh.Bit<3>() && !Q) { return ReservedValue(); } - const size_t esize = 8 << Common::HighestSetBit(immh.ZeroExtend()); - const size_t datasize = Q ? 128 : 64; - const u8 shift_amount = static_cast(2 * esize) - concatenate(immh, immb).ZeroExtend(); - - const IR::U128 operand = V(datasize, Vn); - const IR::U128 operand2 = V(datasize, Vd); - const IR::U128 shifted_operand = ir.VectorArithmeticShiftRight(esize, operand, shift_amount); - const IR::U128 result = ir.VectorAdd(esize, shifted_operand, operand2); - - V(datasize, Vd, result); + SignedShiftRight(*this, Q, immh, immb, Vn, Vd, ShiftExtraBehavior::Accumulate); return true; } @@ -68,6 +76,28 @@ bool TranslatorVisitor::SHL_2(bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd) return true; } +static void ShiftRightNarrowing(TranslatorVisitor& v, bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd, + ShiftExtraBehavior behavior) { + const size_t esize = 8 << Common::HighestSetBit(immh.ZeroExtend()); + const size_t source_esize = 2 * esize; + const size_t part = Q ? 1 : 0; + + const u8 shift_amount = static_cast(source_esize - concatenate(immh, immb).ZeroExtend()); + + IR::U128 operand = v.ir.GetQ(Vn); + + if (behavior == ShiftExtraBehavior::Round) { + const u64 round_const = 1ULL << (shift_amount - 1); + const IR::U128 round_operand = v.ir.VectorBroadcast(source_esize, v.I(source_esize, round_const)); + operand = v.ir.VectorAdd(source_esize, operand, round_operand); + } + + const IR::U128 result = v.ir.VectorNarrow(source_esize, + v.ir.VectorLogicalShiftRight(source_esize, operand, shift_amount)); + + v.Vpart(64, Vd, part, result); +} + bool TranslatorVisitor::SHRN(bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd) { if (immh == 0b0000) { return DecodeError(); @@ -77,17 +107,7 @@ bool TranslatorVisitor::SHRN(bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd) { return ReservedValue(); } - const size_t esize = 8 << Common::HighestSetBit(immh.ZeroExtend()); - const size_t source_esize = 2 * esize; - const size_t part = Q ? 1 : 0; - - const u8 shift_amount = static_cast(source_esize - concatenate(immh, immb).ZeroExtend()); - - const IR::U128 operand = ir.GetQ(Vn); - const IR::U128 result = ir.VectorNarrow(source_esize, - ir.VectorLogicalShiftRight(source_esize, operand, shift_amount)); - - Vpart(64, Vd, part, result); + ShiftRightNarrowing(*this, Q, immh, immb, Vn, Vd, ShiftExtraBehavior::None); return true; } @@ -100,21 +120,7 @@ bool TranslatorVisitor::RSHRN(bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd) return ReservedValue(); } - const size_t esize = 8 << Common::HighestSetBit(immh.ZeroExtend()); - const size_t source_esize = 2 * esize; - const size_t part = Q ? 1 : 0; - - const u8 shift_amount = static_cast(source_esize - concatenate(immh, immb).ZeroExtend()); - const u64 round_const = 1ULL << (shift_amount - 1); - - const IR::U128 operand = ir.GetQ(Vn); - const IR::U128 round_operand = ir.VectorBroadcast(source_esize, I(source_esize, round_const)); - const IR::U128 rounded_value = ir.VectorAdd(source_esize, operand, round_operand); - - const IR::U128 result = ir.VectorNarrow(source_esize, - ir.VectorLogicalShiftRight(source_esize, rounded_value, shift_amount)); - - Vpart(64, Vd, part, result); + ShiftRightNarrowing(*this, Q, immh, immb, Vn, Vd, ShiftExtraBehavior::Round); return true; } @@ -139,13 +145,8 @@ bool TranslatorVisitor::SSHLL(bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd) return true; } -enum class UnsignedRoundingShiftExtraBehavior { - None, - Accumulate -}; - static void UnsignedRoundingShiftRight(TranslatorVisitor& v, bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd, - UnsignedRoundingShiftExtraBehavior behavior) { + ShiftExtraBehavior behavior) { const size_t datasize = Q ? 128 : 64; const size_t esize = 8 << Common::HighestSetBit(immh.ZeroExtend()); const u8 shift_amount = static_cast((esize * 2) - concatenate(immh, immb).ZeroExtend()); @@ -158,7 +159,7 @@ static void UnsignedRoundingShiftRight(TranslatorVisitor& v, bool Q, Imm<4> immh const IR::U128 result = v.ir.VectorLogicalShiftRight(esize, operand, shift_amount); IR::U128 corrected_result = v.ir.VectorSub(esize, result, round_correction); - if (behavior == UnsignedRoundingShiftExtraBehavior::Accumulate) { + if (behavior == ShiftExtraBehavior::Accumulate) { const IR::U128 accumulator = v.V(datasize, Vd); corrected_result = v.ir.VectorAdd(esize, accumulator, corrected_result); } @@ -175,7 +176,7 @@ bool TranslatorVisitor::URSHR_2(bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd return ReservedValue(); } - UnsignedRoundingShiftRight(*this, Q, immh, immb, Vn, Vd, UnsignedRoundingShiftExtraBehavior::None); + UnsignedRoundingShiftRight(*this, Q, immh, immb, Vn, Vd, ShiftExtraBehavior::None); return true; } @@ -188,10 +189,28 @@ bool TranslatorVisitor::URSRA_2(bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd return ReservedValue(); } - UnsignedRoundingShiftRight(*this, Q, immh, immb, Vn, Vd, UnsignedRoundingShiftExtraBehavior::Accumulate); + UnsignedRoundingShiftRight(*this, Q, immh, immb, Vn, Vd, ShiftExtraBehavior::Accumulate); return true; } +static void UnsignedShiftRight(TranslatorVisitor& v, bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd, + ShiftExtraBehavior behavior) { + const size_t esize = 8 << Common::HighestSetBit(immh.ZeroExtend()); + const size_t datasize = Q ? 128 : 64; + + const u8 shift_amount = static_cast(2 * esize) - concatenate(immh, immb).ZeroExtend(); + + const IR::U128 operand = v.V(datasize, Vn); + IR::U128 result = v.ir.VectorLogicalShiftRight(esize, operand, shift_amount); + + if (behavior == ShiftExtraBehavior::Accumulate) { + const IR::U128 accumulator = v.V(datasize, Vd); + result = v.ir.VectorAdd(esize, accumulator, result); + } + + v.V(datasize, Vd, result); +} + bool TranslatorVisitor::USHR_2(bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd) { if (immh == 0b0000) { return DecodeError(); @@ -199,15 +218,8 @@ bool TranslatorVisitor::USHR_2(bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd) if (immh.Bit<3>() && !Q) { return ReservedValue(); } - const size_t esize = 8 << Common::HighestSetBit(immh.ZeroExtend()); - const size_t datasize = Q ? 128 : 64; - const u8 shift_amount = static_cast(2 * esize) - concatenate(immh, immb).ZeroExtend(); - - const IR::U128 operand = V(datasize, Vn); - const IR::U128 result = ir.VectorLogicalShiftRight(esize, operand, shift_amount); - - V(datasize, Vd, result); + UnsignedShiftRight(*this, Q, immh, immb, Vn, Vd, ShiftExtraBehavior::None); return true; } @@ -218,17 +230,8 @@ bool TranslatorVisitor::USRA_2(bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd) if (immh.Bit<3>() && !Q) { return ReservedValue(); } - const size_t esize = 8 << Common::HighestSetBit(immh.ZeroExtend()); - const size_t datasize = Q ? 128 : 64; - const u8 shift_amount = static_cast(2 * esize) - concatenate(immh, immb).ZeroExtend(); - - const IR::U128 operand = V(datasize, Vn); - const IR::U128 operand2 = V(datasize, Vd); - const IR::U128 shifted_operand = ir.VectorLogicalShiftRight(esize, operand, shift_amount); - const IR::U128 result = ir.VectorAdd(esize, shifted_operand, operand2); - - V(datasize, Vd, result); + UnsignedShiftRight(*this, Q, immh, immb, Vn, Vd, ShiftExtraBehavior::Accumulate); return true; }