diff --git a/src/frontend/translate/translate_arm/multiply.cpp b/src/frontend/translate/translate_arm/multiply.cpp index 33f07025..3e56b6c8 100644 --- a/src/frontend/translate/translate_arm/multiply.cpp +++ b/src/frontend/translate/translate_arm/multiply.cpp @@ -147,25 +147,94 @@ bool ArmTranslatorVisitor::arm_UMULL(Cond cond, bool S, Reg dHi, Reg dLo, Reg m, // Multiply (Halfword) instructions bool ArmTranslatorVisitor::arm_SMLALxy(Cond cond, Reg dHi, Reg dLo, Reg m, bool M, bool N, Reg n) { - return InterpretThisInstruction(); + if (dLo == Reg::PC || dHi == Reg::PC || n == Reg::PC || m == Reg::PC) + return UnpredictableInstruction(); + if (dLo == dHi) + return UnpredictableInstruction(); + if (ConditionPassed(cond)) { + auto n32 = ir.GetRegister(n); + auto m32 = ir.GetRegister(m); + auto n16 = N ? ir.ArithmeticShiftRight(n32, ir.Imm8(16), ir.Imm1(0)).result + : ir.SignExtendHalfToWord(ir.LeastSignificantHalf(n32)); + auto m16 = M ? ir.ArithmeticShiftRight(m32, ir.Imm8(16), ir.Imm1(0)).result + : ir.SignExtendHalfToWord(ir.LeastSignificantHalf(m32)); + auto product = ir.SignExtendWordToLong(ir.Mul(n16, m16)); + auto addend = ir.Pack2x32To1x64(ir.GetRegister(dLo), ir.GetRegister(dHi)); + auto result = ir.Add64(product, addend); + ir.SetRegister(dLo, ir.LeastSignificantWord(result)); + ir.SetRegister(dHi, ir.MostSignificantWord(result).result); + } + return true; } bool ArmTranslatorVisitor::arm_SMLAxy(Cond cond, Reg d, Reg a, Reg m, bool M, bool N, Reg n) { - return InterpretThisInstruction(); + if (d == Reg::PC || n == Reg::PC || m == Reg::PC || a == Reg::PC) + return UnpredictableInstruction(); + if (ConditionPassed(cond)) { + auto n32 = ir.GetRegister(n); + auto m32 = ir.GetRegister(m); + auto n16 = N ? ir.ArithmeticShiftRight(n32, ir.Imm8(16), ir.Imm1(0)).result + : ir.SignExtendHalfToWord(ir.LeastSignificantHalf(n32)); + auto m16 = M ? ir.ArithmeticShiftRight(m32, ir.Imm8(16), ir.Imm1(0)).result + : ir.SignExtendHalfToWord(ir.LeastSignificantHalf(m32)); + auto product = ir.Mul(n16, m16); + auto result_overflow = ir.AddWithCarry(product, ir.GetRegister(a), ir.Imm1(0)); + ir.SetRegister(d, result_overflow.result); + ir.OrQFlag(result_overflow.overflow); + } + return true; } bool ArmTranslatorVisitor::arm_SMULxy(Cond cond, Reg d, Reg m, bool M, bool N, Reg n) { - return InterpretThisInstruction(); + if (d == Reg::PC || n == Reg::PC || m == Reg::PC) + return UnpredictableInstruction(); + if (ConditionPassed(cond)) { + auto n32 = ir.GetRegister(n); + auto m32 = ir.GetRegister(m); + auto n16 = N ? ir.ArithmeticShiftRight(n32, ir.Imm8(16), ir.Imm1(0)).result + : ir.SignExtendHalfToWord(ir.LeastSignificantHalf(n32)); + auto m16 = M ? ir.ArithmeticShiftRight(m32, ir.Imm8(16), ir.Imm1(0)).result + : ir.SignExtendHalfToWord(ir.LeastSignificantHalf(m32)); + auto result = ir.Mul(n16, m16); + ir.SetRegister(d, result); + } + return true; } // Multiply (word by halfword) instructions bool ArmTranslatorVisitor::arm_SMLAWy(Cond cond, Reg d, Reg a, Reg m, bool M, Reg n) { - return InterpretThisInstruction(); + if (d == Reg::PC || n == Reg::PC || m == Reg::PC || a == Reg::PC) + return UnpredictableInstruction(); + if (ConditionPassed(cond)) { + auto n32 = ir.SignExtendWordToLong(ir.GetRegister(n)); + auto m32 = ir.GetRegister(m); + if (M) + m32 = ir.LogicalShiftRight(m32, ir.Imm8(16), ir.Imm1(0)).result; + auto m16 = ir.LeastSignificantHalf(m32); + m16 = ir.SignExtendWordToLong(ir.SignExtendHalfToWord(m16)); + auto product = ir.LeastSignificantWord(ir.LogicalShiftRight64(ir.Mul64(n32, m16), ir.Imm8(16))); + auto result_overflow = ir.AddWithCarry(product, ir.GetRegister(a), ir.Imm1(0)); + ir.SetRegister(d, result_overflow.result); + ir.OrQFlag(result_overflow.overflow); + } + return true; } bool ArmTranslatorVisitor::arm_SMULWy(Cond cond, Reg d, Reg m, bool M, Reg n) { - return InterpretThisInstruction(); + if (d == Reg::PC || n == Reg::PC || m == Reg::PC) + return UnpredictableInstruction(); + if (ConditionPassed(cond)) { + auto n32 = ir.SignExtendWordToLong(ir.GetRegister(n)); + auto m32 = ir.GetRegister(m); + if (M) + m32 = ir.LogicalShiftRight(m32, ir.Imm8(16), ir.Imm1(0)).result; + auto m16 = ir.LeastSignificantHalf(m32); + m16 = ir.SignExtendWordToLong(ir.SignExtendHalfToWord(m16)); + auto result = ir.LogicalShiftRight64(ir.Mul64(n32, m16), ir.Imm8(16)); + ir.SetRegister(d, ir.LeastSignificantWord(result)); + } + return true; } @@ -223,27 +292,142 @@ bool ArmTranslatorVisitor::arm_SMMUL(Cond cond, Reg d, Reg m, bool R, Reg n) { // Multiply (Dual) instructions bool ArmTranslatorVisitor::arm_SMLAD(Cond cond, Reg d, Reg a, Reg m, bool M, Reg n) { - return InterpretThisInstruction(); + if (a == Reg::PC) + return arm_SMUAD(cond, d, m, M, n); + if (d == Reg::PC || n == Reg::PC || m == Reg::PC) + return UnpredictableInstruction(); + if (ConditionPassed(cond)) { + auto n32 = ir.GetRegister(n); + auto m32 = ir.GetRegister(m); + auto n_lo = ir.SignExtendHalfToWord(ir.LeastSignificantHalf(n32)); + auto m_lo = ir.SignExtendHalfToWord(ir.LeastSignificantHalf(m32)); + auto n_hi = ir.ArithmeticShiftRight(n32, ir.Imm8(16), ir.Imm1(0)).result; + auto m_hi = ir.ArithmeticShiftRight(m32, ir.Imm8(16), ir.Imm1(0)).result; + if (M) + std::swap(m_lo, m_hi); + auto product_lo = ir.Mul(n_lo, m_lo); + auto product_hi = ir.Mul(n_hi, m_hi); + auto addend = ir.GetRegister(a); + auto result_overflow = ir.AddWithCarry(product_lo, product_hi, ir.Imm1(0)); + ir.OrQFlag(result_overflow.overflow); + result_overflow = ir.AddWithCarry(result_overflow.result, addend, ir.Imm1(0)); + ir.SetRegister(d, result_overflow.result); + ir.OrQFlag(result_overflow.overflow); + } + return true; } bool ArmTranslatorVisitor::arm_SMLALD(Cond cond, Reg dHi, Reg dLo, Reg m, bool M, Reg n) { - return InterpretThisInstruction(); + if (dLo == Reg::PC || dHi == Reg::PC || n == Reg::PC || m == Reg::PC) + return UnpredictableInstruction(); + if (dLo == dHi) + return UnpredictableInstruction(); + if (ConditionPassed(cond)) { + auto n32 = ir.GetRegister(n); + auto m32 = ir.GetRegister(m); + auto n_lo = ir.SignExtendHalfToWord(ir.LeastSignificantHalf(n32)); + auto m_lo = ir.SignExtendHalfToWord(ir.LeastSignificantHalf(m32)); + auto n_hi = ir.ArithmeticShiftRight(n32, ir.Imm8(16), ir.Imm1(0)).result; + auto m_hi = ir.ArithmeticShiftRight(m32, ir.Imm8(16), ir.Imm1(0)).result; + if (M) + std::swap(m_lo, m_hi); + auto product_lo = ir.SignExtendWordToLong(ir.Mul(n_lo, m_lo)); + auto product_hi = ir.SignExtendWordToLong(ir.Mul(n_hi, m_hi)); + auto addend = ir.Pack2x32To1x64(ir.GetRegister(dLo), ir.GetRegister(dHi)); + auto result = ir.Add64(ir.Add64(product_lo, product_hi), addend); + ir.SetRegister(dLo, ir.LeastSignificantWord(result)); + ir.SetRegister(dHi, ir.MostSignificantWord(result).result); + } + return true; } bool ArmTranslatorVisitor::arm_SMLSD(Cond cond, Reg d, Reg a, Reg m, bool M, Reg n) { - return InterpretThisInstruction(); + if (a == Reg::PC) + return arm_SMUSD(cond, d, m, M, n); + if (d == Reg::PC || n == Reg::PC || m == Reg::PC) + return UnpredictableInstruction(); + if (ConditionPassed(cond)) { + auto n32 = ir.GetRegister(n); + auto m32 = ir.GetRegister(m); + auto n_lo = ir.SignExtendHalfToWord(ir.LeastSignificantHalf(n32)); + auto m_lo = ir.SignExtendHalfToWord(ir.LeastSignificantHalf(m32)); + auto n_hi = ir.ArithmeticShiftRight(n32, ir.Imm8(16), ir.Imm1(0)).result; + auto m_hi = ir.ArithmeticShiftRight(m32, ir.Imm8(16), ir.Imm1(0)).result; + if (M) + std::swap(m_lo, m_hi); + auto product_lo = ir.Mul(n_lo, m_lo); + auto product_hi = ir.Mul(n_hi, m_hi); + auto addend = ir.GetRegister(a); + auto result_overflow = ir.AddWithCarry(ir.Sub(product_lo, product_hi), addend, ir.Imm1(0)); + ir.SetRegister(d, result_overflow.result); + ir.OrQFlag(result_overflow.overflow); + } + return true; } bool ArmTranslatorVisitor::arm_SMLSLD(Cond cond, Reg dHi, Reg dLo, Reg m, bool M, Reg n) { - return InterpretThisInstruction(); + if (dLo == Reg::PC || dHi == Reg::PC || n == Reg::PC || m == Reg::PC) + return UnpredictableInstruction(); + if (dLo == dHi) + return UnpredictableInstruction(); + if (ConditionPassed(cond)) { + auto n32 = ir.GetRegister(n); + auto m32 = ir.GetRegister(m); + auto n_lo = ir.SignExtendHalfToWord(ir.LeastSignificantHalf(n32)); + auto m_lo = ir.SignExtendHalfToWord(ir.LeastSignificantHalf(m32)); + auto n_hi = ir.ArithmeticShiftRight(n32, ir.Imm8(16), ir.Imm1(0)).result; + auto m_hi = ir.ArithmeticShiftRight(m32, ir.Imm8(16), ir.Imm1(0)).result; + if (M) + std::swap(m_lo, m_hi); + auto product_lo = ir.SignExtendWordToLong(ir.Mul(n_lo, m_lo)); + auto product_hi = ir.SignExtendWordToLong(ir.Mul(n_hi, m_hi)); + auto addend = ir.Pack2x32To1x64(ir.GetRegister(dLo), ir.GetRegister(dHi)); + auto result = ir.Add64(ir.Sub64(product_lo, product_hi), addend); + ir.SetRegister(dLo, ir.LeastSignificantWord(result)); + ir.SetRegister(dHi, ir.MostSignificantWord(result).result); + } + return true; } bool ArmTranslatorVisitor::arm_SMUAD(Cond cond, Reg d, Reg m, bool M, Reg n) { - return InterpretThisInstruction(); + if (d == Reg::PC || n == Reg::PC || m == Reg::PC) + return UnpredictableInstruction(); + if (ConditionPassed(cond)) { + auto n32 = ir.GetRegister(n); + auto m32 = ir.GetRegister(m); + auto n_lo = ir.SignExtendHalfToWord(ir.LeastSignificantHalf(n32)); + auto m_lo = ir.SignExtendHalfToWord(ir.LeastSignificantHalf(m32)); + auto n_hi = ir.ArithmeticShiftRight(n32, ir.Imm8(16), ir.Imm1(0)).result; + auto m_hi = ir.ArithmeticShiftRight(m32, ir.Imm8(16), ir.Imm1(0)).result; + if (M) + std::swap(m_lo, m_hi); + auto product_lo = ir.Mul(n_lo, m_lo); + auto product_hi = ir.Mul(n_hi, m_hi); + auto result_overflow = ir.AddWithCarry(product_lo, product_hi, ir.Imm1(0)); + ir.SetRegister(d, result_overflow.result); + ir.OrQFlag(result_overflow.overflow); + } + return true; } bool ArmTranslatorVisitor::arm_SMUSD(Cond cond, Reg d, Reg m, bool M, Reg n) { - return InterpretThisInstruction(); + if (d == Reg::PC || n == Reg::PC || m == Reg::PC) + return UnpredictableInstruction(); + if (ConditionPassed(cond)) { + auto n32 = ir.GetRegister(n); + auto m32 = ir.GetRegister(m); + auto n_lo = ir.SignExtendHalfToWord(ir.LeastSignificantHalf(n32)); + auto m_lo = ir.SignExtendHalfToWord(ir.LeastSignificantHalf(m32)); + auto n_hi = ir.ArithmeticShiftRight(n32, ir.Imm8(16), ir.Imm1(0)).result; + auto m_hi = ir.ArithmeticShiftRight(m32, ir.Imm8(16), ir.Imm1(0)).result; + if (M) + std::swap(m_lo, m_hi); + auto product_lo = ir.Mul(n_lo, m_lo); + auto product_hi = ir.Mul(n_hi, m_hi); + auto result = ir.Sub(product_lo, product_hi); + ir.SetRegister(d, result); + } + return true; } } // namespace Arm diff --git a/tests/arm/fuzz_arm.cpp b/tests/arm/fuzz_arm.cpp index 61db571f..a54e60d2 100644 --- a/tests/arm/fuzz_arm.cpp +++ b/tests/arm/fuzz_arm.cpp @@ -792,7 +792,7 @@ TEST_CASE("Fuzz ARM multiply instructions", "[JitX64]") { Bits<12, 15>(inst) != Bits<16, 19>(inst); }; - const std::array instructions = {{ + const std::array instructions = {{ InstructionGenerator("cccc0000001Sddddaaaammmm1001nnnn", validate_d_a_m_n), // MLA InstructionGenerator("cccc0000000Sdddd0000mmmm1001nnnn", validate_d_m_n), // MUL @@ -802,27 +802,26 @@ TEST_CASE("Fuzz ARM multiply instructions", "[JitX64]") { InstructionGenerator("cccc0000101Sddddaaaammmm1001nnnn", validate_h_l_m_n), // UMLAL InstructionGenerator("cccc0000100Sddddaaaammmm1001nnnn", validate_h_l_m_n), // UMULL - //InstructionGenerator("cccc00010100ddddaaaammmm1xy0nnnn", validate_d_a_m_n), // SMLALxy - //InstructionGenerator("cccc00010000ddddaaaammmm1xy0nnnn", validate_d_a_m_n), // SMLAxy - //InstructionGenerator("cccc00010110dddd0000mmmm1xy0nnnn", validate_d_m_n), // SMULxy + InstructionGenerator("cccc00010100ddddaaaammmm1xy0nnnn", validate_h_l_m_n), // SMLALxy + InstructionGenerator("cccc00010000ddddaaaammmm1xy0nnnn", validate_d_a_m_n), // SMLAxy + InstructionGenerator("cccc00010110dddd0000mmmm1xy0nnnn", validate_d_m_n), // SMULxy - //InstructionGenerator("cccc00010010ddddaaaammmm1y00nnnn", validate_d_a_m_n), // SMLAWy - //InstructionGenerator("cccc00010010dddd0000mmmm1y10nnnn", validate_d_m_n), // SMULWy + InstructionGenerator("cccc00010010ddddaaaammmm1y00nnnn", validate_d_a_m_n), // SMLAWy + InstructionGenerator("cccc00010010dddd0000mmmm1y10nnnn", validate_d_m_n), // SMULWy InstructionGenerator("cccc01110101dddd1111mmmm00R1nnnn", validate_d_m_n), // SMMUL InstructionGenerator("cccc01110101ddddaaaammmm00R1nnnn", validate_d_a_m_n), // SMMLA InstructionGenerator("cccc01110101ddddaaaammmm11R1nnnn", validate_d_a_m_n), // SMMLS - - //InstructionGenerator("cccc01110000ddddaaaammmm00M1nnnn", validate_d_a_m_n), // SMLAD - //InstructionGenerator("cccc01110100ddddaaaammmm00M1nnnn", validate_d_a_m_n), // SMLALD - //InstructionGenerator("cccc01110000ddddaaaammmm01M1nnnn", validate_d_a_m_n), // SMLSD - //InstructionGenerator("cccc01110100ddddaaaammmm01M1nnnn", validate_d_a_m_n), // SMLSLD - //InstructionGenerator("cccc01110000dddd1111mmmm00M1nnnn", validate_d_m_n), // SMUAD - //InstructionGenerator("cccc01110000dddd1111mmmm01M1nnnn", validate_d_m_n), // SMUSD + InstructionGenerator("cccc01110000ddddaaaammmm00M1nnnn", validate_d_a_m_n), // SMLAD + InstructionGenerator("cccc01110100ddddaaaammmm00M1nnnn", validate_h_l_m_n), // SMLALD + InstructionGenerator("cccc01110000ddddaaaammmm01M1nnnn", validate_d_a_m_n), // SMLSD + InstructionGenerator("cccc01110100ddddaaaammmm01M1nnnn", validate_h_l_m_n), // SMLSLD + InstructionGenerator("cccc01110000dddd1111mmmm00M1nnnn", validate_d_m_n), // SMUAD + InstructionGenerator("cccc01110000dddd1111mmmm01M1nnnn", validate_d_m_n), // SMUSD }}; SECTION("Multiply") { - FuzzJitArm(2, 2, 10000, [&]() -> u32 { + FuzzJitArm(1, 1, 10000, [&]() -> u32 { return instructions[RandInt(0, instructions.size() - 1)].Generate(); }); } @@ -852,6 +851,30 @@ TEST_CASE("Fuzz ARM parallel instructions", "[JitX64]") { } } +TEST_CASE( "SMUAD", "[JitX64]" ) { + Dynarmic::Jit jit{GetUserCallbacks()}; + code_mem.fill({}); + code_mem[0] = 0xE700F211; // smuad r0, r1, r2 + + jit.Regs() = { + 0, // Rd + 0x80008000, // Rn + 0x80008000, // Rm + 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + }; + jit.Cpsr() = 0x000001d0; // User-mode + + jit.Run(6); + + REQUIRE(jit.Regs()[0] == 0x80000000); + REQUIRE(jit.Regs()[1] == 0x80008000); + REQUIRE(jit.Regs()[2] == 0x80008000); + REQUIRE(jit.Cpsr() == 0x080001d0); +} + TEST_CASE("VFP: VPUSH, VPOP", "[JitX64][vfp]") { const auto is_valid = [](u32 instr) -> bool { auto regs = (instr & 0x100) ? (Bits<0, 7>(instr) >> 1) : Bits<0, 7>(instr);