diff --git a/src/frontend/A64/translate/impl/simd_shift_by_immediate.cpp b/src/frontend/A64/translate/impl/simd_shift_by_immediate.cpp index a45e614c..8350e485 100644 --- a/src/frontend/A64/translate/impl/simd_shift_by_immediate.cpp +++ b/src/frontend/A64/translate/impl/simd_shift_by_immediate.cpp @@ -15,15 +15,25 @@ enum class ShiftExtraBehavior { Round }; -static void SignedShiftRight(TranslatorVisitor& v, bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd, - ShiftExtraBehavior behavior) { +enum class Signedness { + Signed, + Unsigned +}; + +static void ShiftRight(TranslatorVisitor& v, bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd, + ShiftExtraBehavior behavior, Signedness signedness) { const size_t esize = 8 << Common::HighestSetBit(immh.ZeroExtend()); const size_t datasize = Q ? 128 : 64; const u8 shift_amount = static_cast(2 * esize) - concatenate(immh, immb).ZeroExtend(); const IR::U128 operand = v.V(datasize, Vn); - IR::U128 result = v.ir.VectorArithmeticShiftRight(esize, operand, shift_amount); + IR::U128 result = [&] { + if (signedness == Signedness::Signed) { + return v.ir.VectorArithmeticShiftRight(esize, operand, shift_amount); + } + return v.ir.VectorLogicalShiftRight(esize, operand, shift_amount); + }(); if (behavior == ShiftExtraBehavior::Accumulate) { const IR::U128 accumulator = v.V(datasize, Vd); @@ -33,20 +43,8 @@ static void SignedShiftRight(TranslatorVisitor& v, bool Q, Imm<4> immh, Imm<3> i v.V(datasize, Vd, result); } -bool TranslatorVisitor::SSHR_2(bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd) { - if (immh == 0b0000) { - return DecodeError(); - } - if (immh.Bit<3>() && !Q) { - return ReservedValue(); - } - - SignedShiftRight(*this, Q, immh, immb, Vn, Vd, ShiftExtraBehavior::None); - return true; -} - -static void SignedRoundingShiftRight(TranslatorVisitor& v, bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd, - ShiftExtraBehavior behavior) { +static void RoundingShiftRight(TranslatorVisitor& v, bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd, + ShiftExtraBehavior behavior, Signedness signedness) { const size_t datasize = Q ? 128 : 64; const size_t esize = 8 << Common::HighestSetBit(immh.ZeroExtend()); const u8 shift_amount = static_cast((esize * 2) - concatenate(immh, immb).ZeroExtend()); @@ -56,7 +54,12 @@ static void SignedRoundingShiftRight(TranslatorVisitor& v, bool Q, Imm<4> immh, const IR::U128 round_const = v.ir.VectorBroadcast(esize, v.I(esize, round_value)); const IR::U128 round_correction = v.ir.VectorEqual(esize, v.ir.VectorAnd(operand, round_const), round_const); - const IR::U128 result = v.ir.VectorArithmeticShiftRight(esize, operand, shift_amount); + const IR::U128 result = [&] { + if (signedness == Signedness::Signed) { + return v.ir.VectorArithmeticShiftRight(esize, operand, shift_amount); + } + return v.ir.VectorLogicalShiftRight(esize, operand, shift_amount); + }(); IR::U128 corrected_result = v.ir.VectorSub(esize, result, round_correction); if (behavior == ShiftExtraBehavior::Accumulate) { @@ -67,63 +70,6 @@ static void SignedRoundingShiftRight(TranslatorVisitor& v, bool Q, Imm<4> immh, v.V(datasize, Vd, corrected_result); } -bool TranslatorVisitor::SRSHR_2(bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd) { - if (immh == 0b0000) { - return DecodeError(); - } - - if (!Q && immh.Bit<3>()) { - return ReservedValue(); - } - - SignedRoundingShiftRight(*this, Q, immh, immb, Vn, Vd, ShiftExtraBehavior::None); - return true; -} - -bool TranslatorVisitor::SRSRA_2(bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd) { - if (immh == 0b0000) { - return DecodeError(); - } - - if (!Q && immh.Bit<3>()) { - return ReservedValue(); - } - - SignedRoundingShiftRight(*this, Q, immh, immb, Vn, Vd, ShiftExtraBehavior::Accumulate); - return true; -} - -bool TranslatorVisitor::SSRA_2(bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd) { - if (immh == 0b0000) { - return DecodeError(); - } - if (immh.Bit<3>() && !Q) { - return ReservedValue(); - } - - SignedShiftRight(*this, Q, immh, immb, Vn, Vd, ShiftExtraBehavior::Accumulate); - return true; -} - -bool TranslatorVisitor::SHL_2(bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd) { - if (immh == 0b0000) { - return DecodeError(); - } - if (immh.Bit<3>() && !Q) { - return ReservedValue(); - } - const size_t esize = 8 << Common::HighestSetBit(immh.ZeroExtend()); - const size_t datasize = Q ? 128 : 64; - - const u8 shift_amount = concatenate(immh, immb).ZeroExtend() - static_cast(esize); - - const IR::U128 operand = V(datasize, Vn); - const IR::U128 result = ir.VectorLogicalShiftLeft(esize, operand, shift_amount); - - V(datasize, Vd, result); - return true; -} - static void ShiftRightNarrowing(TranslatorVisitor& v, bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd, ShiftExtraBehavior behavior) { const size_t esize = 8 << Common::HighestSetBit(immh.ZeroExtend()); @@ -146,6 +92,95 @@ static void ShiftRightNarrowing(TranslatorVisitor& v, bool Q, Imm<4> immh, Imm<3 v.Vpart(64, Vd, part, result); } +static void ShiftLeftLong(TranslatorVisitor& v, bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd, + Signedness signedness) { + const size_t esize = 8 << Common::HighestSetBit(immh.ZeroExtend()); + const size_t datasize = 64; + const size_t part = Q ? 1 : 0; + + const u8 shift_amount = concatenate(immh, immb).ZeroExtend() - static_cast(esize); + + const IR::U128 operand = v.Vpart(datasize, Vn, part); + const IR::U128 expanded_operand = [&] { + if (signedness == Signedness::Signed) { + return v.ir.VectorSignExtend(esize, operand); + } + return v.ir.VectorZeroExtend(esize, operand); + }(); + const IR::U128 result = v.ir.VectorLogicalShiftLeft(2 * esize, expanded_operand, shift_amount); + + v.V(2 * datasize, Vd, result); +} + +bool TranslatorVisitor::SSHR_2(bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd) { + if (immh == 0b0000) { + return DecodeError(); + } + if (immh.Bit<3>() && !Q) { + return ReservedValue(); + } + + ShiftRight(*this, Q, immh, immb, Vn, Vd, ShiftExtraBehavior::None, Signedness::Signed); + return true; +} + +bool TranslatorVisitor::SRSHR_2(bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd) { + if (immh == 0b0000) { + return DecodeError(); + } + + if (!Q && immh.Bit<3>()) { + return ReservedValue(); + } + + RoundingShiftRight(*this, Q, immh, immb, Vn, Vd, ShiftExtraBehavior::None, Signedness::Signed); + return true; +} + +bool TranslatorVisitor::SRSRA_2(bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd) { + if (immh == 0b0000) { + return DecodeError(); + } + + if (!Q && immh.Bit<3>()) { + return ReservedValue(); + } + + RoundingShiftRight(*this, Q, immh, immb, Vn, Vd, ShiftExtraBehavior::Accumulate, Signedness::Signed); + return true; +} + +bool TranslatorVisitor::SSRA_2(bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd) { + if (immh == 0b0000) { + return DecodeError(); + } + if (immh.Bit<3>() && !Q) { + return ReservedValue(); + } + + ShiftRight(*this, Q, immh, immb, Vn, Vd, ShiftExtraBehavior::Accumulate, Signedness::Signed); + return true; +} + +bool TranslatorVisitor::SHL_2(bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd) { + if (immh == 0b0000) { + return DecodeError(); + } + if (immh.Bit<3>() && !Q) { + return ReservedValue(); + } + const size_t esize = 8 << Common::HighestSetBit(immh.ZeroExtend()); + const size_t datasize = Q ? 128 : 64; + + const u8 shift_amount = concatenate(immh, immb).ZeroExtend() - static_cast(esize); + + const IR::U128 operand = V(datasize, Vn); + const IR::U128 result = ir.VectorLogicalShiftLeft(esize, operand, shift_amount); + + V(datasize, Vd, result); + return true; +} + bool TranslatorVisitor::SHRN(bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd) { if (immh == 0b0000) { return DecodeError(); @@ -179,42 +214,11 @@ bool TranslatorVisitor::SSHLL(bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd) if (immh.Bit<3>()) { return ReservedValue(); } - const size_t esize = 8 << Common::HighestSetBit(immh.ZeroExtend()); - const size_t datasize = 64; - const size_t part = Q ? 1 : 0; - const u8 shift_amount = concatenate(immh, immb).ZeroExtend() - static_cast(esize); - - const IR::U128 operand = Vpart(datasize, Vn, part); - const IR::U128 expanded_operand = ir.VectorSignExtend(esize, operand); - const IR::U128 result = ir.VectorLogicalShiftLeft(2 * esize, expanded_operand, shift_amount); - - V(2 * datasize, Vd, result); + ShiftLeftLong(*this, Q, immh, immb, Vn, Vd, Signedness::Signed); return true; } -static void UnsignedRoundingShiftRight(TranslatorVisitor& v, bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd, - ShiftExtraBehavior behavior) { - const size_t datasize = Q ? 128 : 64; - const size_t esize = 8 << Common::HighestSetBit(immh.ZeroExtend()); - const u8 shift_amount = static_cast((esize * 2) - concatenate(immh, immb).ZeroExtend()); - const u64 round_value = 1ULL << (shift_amount - 1); - - const IR::U128 operand = v.V(datasize, Vn); - const IR::U128 round_const = v.ir.VectorBroadcast(esize, v.I(esize, round_value)); - const IR::U128 round_correction = v.ir.VectorEqual(esize, v.ir.VectorAnd(operand, round_const), round_const); - - const IR::U128 result = v.ir.VectorLogicalShiftRight(esize, operand, shift_amount); - IR::U128 corrected_result = v.ir.VectorSub(esize, result, round_correction); - - if (behavior == ShiftExtraBehavior::Accumulate) { - const IR::U128 accumulator = v.V(datasize, Vd); - corrected_result = v.ir.VectorAdd(esize, accumulator, corrected_result); - } - - v.V(datasize, Vd, corrected_result); -} - bool TranslatorVisitor::URSHR_2(bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd) { if (immh == 0b0000) { return DecodeError(); @@ -224,7 +228,7 @@ bool TranslatorVisitor::URSHR_2(bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd return ReservedValue(); } - UnsignedRoundingShiftRight(*this, Q, immh, immb, Vn, Vd, ShiftExtraBehavior::None); + RoundingShiftRight(*this, Q, immh, immb, Vn, Vd, ShiftExtraBehavior::None, Signedness::Unsigned); return true; } @@ -237,28 +241,10 @@ bool TranslatorVisitor::URSRA_2(bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd return ReservedValue(); } - UnsignedRoundingShiftRight(*this, Q, immh, immb, Vn, Vd, ShiftExtraBehavior::Accumulate); + RoundingShiftRight(*this, Q, immh, immb, Vn, Vd, ShiftExtraBehavior::Accumulate, Signedness::Unsigned); return true; } -static void UnsignedShiftRight(TranslatorVisitor& v, bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd, - ShiftExtraBehavior behavior) { - const size_t esize = 8 << Common::HighestSetBit(immh.ZeroExtend()); - const size_t datasize = Q ? 128 : 64; - - const u8 shift_amount = static_cast(2 * esize) - concatenate(immh, immb).ZeroExtend(); - - const IR::U128 operand = v.V(datasize, Vn); - IR::U128 result = v.ir.VectorLogicalShiftRight(esize, operand, shift_amount); - - if (behavior == ShiftExtraBehavior::Accumulate) { - const IR::U128 accumulator = v.V(datasize, Vd); - result = v.ir.VectorAdd(esize, accumulator, result); - } - - v.V(datasize, Vd, result); -} - bool TranslatorVisitor::USHR_2(bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd) { if (immh == 0b0000) { return DecodeError(); @@ -267,7 +253,7 @@ bool TranslatorVisitor::USHR_2(bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd) return ReservedValue(); } - UnsignedShiftRight(*this, Q, immh, immb, Vn, Vd, ShiftExtraBehavior::None); + ShiftRight(*this, Q, immh, immb, Vn, Vd, ShiftExtraBehavior::None, Signedness::Unsigned); return true; } @@ -279,7 +265,7 @@ bool TranslatorVisitor::USRA_2(bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd) return ReservedValue(); } - UnsignedShiftRight(*this, Q, immh, immb, Vn, Vd, ShiftExtraBehavior::Accumulate); + ShiftRight(*this, Q, immh, immb, Vn, Vd, ShiftExtraBehavior::Accumulate, Signedness::Unsigned); return true; } @@ -290,17 +276,8 @@ bool TranslatorVisitor::USHLL(bool Q, Imm<4> immh, Imm<3> immb, Vec Vn, Vec Vd) if (immh.Bit<3>()) { return ReservedValue(); } - const size_t esize = 8 << Common::HighestSetBit(immh.ZeroExtend()); - const size_t datasize = 64; - const size_t part = Q ? 1 : 0; - const u8 shift_amount = concatenate(immh, immb).ZeroExtend() - static_cast(esize); - - const IR::U128 operand = Vpart(datasize, Vn, part); - const IR::U128 expanded_operand = ir.VectorZeroExtend(esize, operand); - const IR::U128 result = ir.VectorLogicalShiftLeft(2 * esize, expanded_operand, shift_amount); - - V(2 * datasize, Vd, result); + ShiftLeftLong(*this, Q, immh, immb, Vn, Vd, Signedness::Unsigned); return true; }