diff --git a/src/backend/x64/emit_x64_floating_point.cpp b/src/backend/x64/emit_x64_floating_point.cpp index 329b492c..80716794 100644 --- a/src/backend/x64/emit_x64_floating_point.cpp +++ b/src/backend/x64/emit_x64_floating_point.cpp @@ -873,8 +873,58 @@ void EmitX64::EmitFPRSqrtEstimate64(EmitContext& ctx, IR::Inst* inst) { EmitFPRSqrtEstimate(code, ctx, inst); } -template +template static void EmitFPRSqrtStepFused(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst) { + using FPT = mp::unsigned_integer_of_size; + + if (code.DoesCpuSupport(Xbyak::util::Cpu::tFMA) && code.DoesCpuSupport(Xbyak::util::Cpu::tAVX)) { + auto args = ctx.reg_alloc.GetArgumentInfo(inst); + + Xbyak::Label end, fallback; + + const Xbyak::Xmm operand1 = ctx.reg_alloc.UseXmm(args[0]); + const Xbyak::Xmm operand2 = ctx.reg_alloc.UseXmm(args[1]); + const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm(); + + code.vmovaps(result, code.MConst(xword, FP::FPValue())); + FCODE(vfnmadd231s)(result, operand1, operand2); + + // Detect if the intermediate result is infinity or NaN or nearly an infinity. + // Why do we need to care about infinities? This is because x86 doesn't allow us + // to fuse the divide-by-two with the rest of the FMA operation. Therefore the + // intermediate value may overflow and we would like to handle this case. + const Xbyak::Reg32 tmp = ctx.reg_alloc.ScratchGpr().cvt32(); + code.vpextrw(tmp, result, fsize == 32 ? 1 : 3); + code.and_(tmp.cvt16(), fsize == 32 ? 0x7f80 : 0x7ff0); + code.cmp(tmp.cvt16(), fsize == 32 ? 0x7f00 : 0x7fe0); + ctx.reg_alloc.Release(tmp); + + code.jae(fallback, code.T_NEAR); + + FCODE(vmuls)(result, result, code.MConst(xword, FP::FPValue())); + code.L(end); + + code.SwitchToFarCode(); + code.L(fallback); + + code.sub(rsp, 8); + ABI_PushCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx())); + code.movq(code.ABI_PARAM1, operand1); + code.movq(code.ABI_PARAM2, operand2); + code.mov(code.ABI_PARAM3.cvt32(), ctx.FPCR()); + code.lea(code.ABI_PARAM4, code.ptr[code.r15 + code.GetJitStateInfo().offsetof_fpsr_exc]); + code.CallFunction(&FP::FPRSqrtStepFused); + code.movq(result, code.ABI_RETURN); + ABI_PopCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx())); + code.add(rsp, 8); + + code.jmp(end, code.T_NEAR); + code.SwitchToNearCode(); + + ctx.reg_alloc.DefineValue(inst, result); + return; + } + auto args = ctx.reg_alloc.GetArgumentInfo(inst); ctx.reg_alloc.HostCall(inst, args[0], args[1]); code.mov(code.ABI_PARAM3.cvt32(), ctx.FPCR()); @@ -883,11 +933,11 @@ static void EmitFPRSqrtStepFused(BlockOfCode& code, EmitContext& ctx, IR::Inst* } void EmitX64::EmitFPRSqrtStepFused32(EmitContext& ctx, IR::Inst* inst) { - EmitFPRSqrtStepFused(code, ctx, inst); + EmitFPRSqrtStepFused<32>(code, ctx, inst); } void EmitX64::EmitFPRSqrtStepFused64(EmitContext& ctx, IR::Inst* inst) { - EmitFPRSqrtStepFused(code, ctx, inst); + EmitFPRSqrtStepFused<64>(code, ctx, inst); } void EmitX64::EmitFPSqrt32(EmitContext& ctx, IR::Inst* inst) { diff --git a/tests/A64/a64.cpp b/tests/A64/a64.cpp index 3dc876b0..0d24fe55 100644 --- a/tests/A64/a64.cpp +++ b/tests/A64/a64.cpp @@ -472,3 +472,25 @@ TEST_CASE("A64: FNEG failed to zero upper", "[a64]") { REQUIRE(jit.GetVector(28) == Vector{0x79ee7a03980db670, 0}); REQUIRE(FP::FPSR{jit.GetFpsr()}.QC() == false); } + +TEST_CASE("A64: FRSQRTS", "[a64]") { + A64TestEnv env; + Dynarmic::A64::Jit jit{Dynarmic::A64::UserConfig{&env}}; + + env.code_mem.emplace_back(0x5eb8fcad); // FRSQRTS S13, S5, S24 + env.code_mem.emplace_back(0x14000000); // B . + + // These particular values result in an intermediate value during + // the calculation that is close to infinity. We want to verify + // that this special case is handled appropriately. + + jit.SetPC(0); + jit.SetVector(5, {0xfc6a0206, 0}); + jit.SetVector(24, {0xfc6a0206, 0}); + jit.SetFpcr(0x00400000); + + env.ticks_left = 2; + jit.Run(); + + REQUIRE(jit.GetVector(13) == Vector{0xff7fffff, 0}); +} diff --git a/tests/fp/FPValue.cpp b/tests/fp/FPValue.cpp index e5c4ee52..fd9a2e2a 100644 --- a/tests/fp/FPValue.cpp +++ b/tests/fp/FPValue.cpp @@ -13,3 +13,4 @@ static_assert(FPValue() == 0x3fc00000); static_assert(FPValue() == 0x4b4264e4); static_assert(FPValue() == 0x3ec80000); static_assert(FPValue() == 0xbf800000); +static_assert(FPValue() == 0x3f000000);