From f9ccf91b941bfd240fa90fec0023789f41914ed9 Mon Sep 17 00:00:00 2001 From: MerryMage Date: Sat, 2 Jan 2021 17:19:08 +0000 Subject: [PATCH] Add Unsafe_InaccurateNaN optimization to all fma instructions --- src/backend/x64/emit_x64_floating_point.cpp | 32 ++++++++-- .../x64/emit_x64_vector_floating_point.cpp | 63 ++++++++++++++----- 2 files changed, 77 insertions(+), 18 deletions(-) diff --git a/src/backend/x64/emit_x64_floating_point.cpp b/src/backend/x64/emit_x64_floating_point.cpp index 9bc8bf6d..c31aed14 100644 --- a/src/backend/x64/emit_x64_floating_point.cpp +++ b/src/backend/x64/emit_x64_floating_point.cpp @@ -590,9 +590,20 @@ static void EmitFPMulAdd(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst) { using FPT = mp::unsigned_integer_of_size; if constexpr (fsize != 16) { - if (code.HasFMA()) { - auto args = ctx.reg_alloc.GetArgumentInfo(inst); + auto args = ctx.reg_alloc.GetArgumentInfo(inst); + if (code.HasFMA() && ctx.HasOptimization(OptimizationFlag::Unsafe_InaccurateNaN)) { + const Xbyak::Xmm result = ctx.reg_alloc.UseScratchXmm(args[0]); + const Xbyak::Xmm operand2 = ctx.reg_alloc.UseXmm(args[1]); + const Xbyak::Xmm operand3 = ctx.reg_alloc.UseXmm(args[2]); + + FCODE(vfmadd231s)(result, operand2, operand3); + + ctx.reg_alloc.DefineValue(inst, result); + return; + } + + if (code.HasFMA()) { Xbyak::Label end, fallback; const Xbyak::Xmm operand1 = ctx.reg_alloc.UseXmm(args[0]); @@ -641,8 +652,6 @@ static void EmitFPMulAdd(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst) { } if (ctx.HasOptimization(OptimizationFlag::Unsafe_UnfuseFMA)) { - auto args = ctx.reg_alloc.GetArgumentInfo(inst); - const Xbyak::Xmm operand1 = ctx.reg_alloc.UseScratchXmm(args[0]); const Xbyak::Xmm operand2 = ctx.reg_alloc.UseScratchXmm(args[1]); const Xbyak::Xmm operand3 = ctx.reg_alloc.UseXmm(args[2]); @@ -1014,6 +1023,21 @@ static void EmitFPRSqrtStepFused(BlockOfCode& code, EmitContext& ctx, IR::Inst* using FPT = mp::unsigned_integer_of_size; if constexpr (fsize != 16) { + if (code.HasFMA() && code.HasAVX() && ctx.HasOptimization(OptimizationFlag::Unsafe_InaccurateNaN)) { + auto args = ctx.reg_alloc.GetArgumentInfo(inst); + + const Xbyak::Xmm operand1 = ctx.reg_alloc.UseXmm(args[0]); + const Xbyak::Xmm operand2 = ctx.reg_alloc.UseXmm(args[1]); + const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm(); + + code.vmovaps(result, code.MConst(xword, FP::FPValue())); + FCODE(vfnmadd231s)(result, operand1, operand2); + FCODE(vmuls)(result, result, code.MConst(xword, FP::FPValue())); + + ctx.reg_alloc.DefineValue(inst, result); + return; + } + if (code.HasFMA() && code.HasAVX()) { auto args = ctx.reg_alloc.GetArgumentInfo(inst); diff --git a/src/backend/x64/emit_x64_vector_floating_point.cpp b/src/backend/x64/emit_x64_vector_floating_point.cpp index 20b1c808..07d4cd4e 100644 --- a/src/backend/x64/emit_x64_vector_floating_point.cpp +++ b/src/backend/x64/emit_x64_vector_floating_point.cpp @@ -985,11 +985,23 @@ void EmitFPVectorMulAdd(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst) { }; if constexpr (fsize != 16) { + auto args = ctx.reg_alloc.GetArgumentInfo(inst); + const bool fpcr_controlled = args[3].GetImmediateU1(); + if (code.HasFMA() && code.HasAVX()) { - auto args = ctx.reg_alloc.GetArgumentInfo(inst); + const Xbyak::Xmm result = ctx.reg_alloc.UseScratchXmm(args[0]); + const Xbyak::Xmm xmm_b = ctx.reg_alloc.UseXmm(args[1]); + const Xbyak::Xmm xmm_c = ctx.reg_alloc.UseXmm(args[2]); - const bool fpcr_controlled = args[3].GetImmediateU1(); + MaybeStandardFPSCRValue(code, ctx, fpcr_controlled, [&]{ + FCODE(vfmadd231p)(result, xmm_b, xmm_c); + }); + ctx.reg_alloc.DefineValue(inst, result); + return; + } + + if (code.HasFMA() && code.HasAVX()) { const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm(); const Xbyak::Xmm xmm_a = ctx.reg_alloc.UseXmm(args[0]); const Xbyak::Xmm xmm_b = ctx.reg_alloc.UseXmm(args[1]); @@ -1025,8 +1037,6 @@ void EmitFPVectorMulAdd(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst) { } if (ctx.HasOptimization(OptimizationFlag::Unsafe_UnfuseFMA)) { - auto args = ctx.reg_alloc.GetArgumentInfo(inst); - const Xbyak::Xmm operand1 = ctx.reg_alloc.UseScratchXmm(args[0]); const Xbyak::Xmm operand2 = ctx.reg_alloc.UseScratchXmm(args[1]); const Xbyak::Xmm operand3 = ctx.reg_alloc.UseXmm(args[2]); @@ -1233,10 +1243,24 @@ static void EmitRecipStepFused(BlockOfCode& code, EmitContext& ctx, IR::Inst* in }; if constexpr (fsize != 16) { - if (code.HasFMA() && code.HasAVX()) { - auto args = ctx.reg_alloc.GetArgumentInfo(inst); - const bool fpcr_controlled = args[2].GetImmediateU1(); + auto args = ctx.reg_alloc.GetArgumentInfo(inst); + const bool fpcr_controlled = args[2].GetImmediateU1(); + if (code.HasFMA() && code.HasAVX() && ctx.HasOptimization(OptimizationFlag::Unsafe_InaccurateNaN)) { + const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm(); + const Xbyak::Xmm operand1 = ctx.reg_alloc.UseXmm(args[0]); + const Xbyak::Xmm operand2 = ctx.reg_alloc.UseXmm(args[1]); + + MaybeStandardFPSCRValue(code, ctx, fpcr_controlled, [&]{ + code.movaps(result, GetVectorOf(code)); + FCODE(vfnmadd231p)(result, operand1, operand2); + }); + + ctx.reg_alloc.DefineValue(inst, result); + return; + } + + if (code.HasFMA() && code.HasAVX()) { const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm(); const Xbyak::Xmm operand1 = ctx.reg_alloc.UseXmm(args[0]); const Xbyak::Xmm operand2 = ctx.reg_alloc.UseXmm(args[1]); @@ -1269,8 +1293,6 @@ static void EmitRecipStepFused(BlockOfCode& code, EmitContext& ctx, IR::Inst* in } if (ctx.HasOptimization(OptimizationFlag::Unsafe_UnfuseFMA)) { - auto args = ctx.reg_alloc.GetArgumentInfo(inst); - const Xbyak::Xmm operand1 = ctx.reg_alloc.UseScratchXmm(args[0]); const Xbyak::Xmm operand2 = ctx.reg_alloc.UseXmm(args[1]); const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm(); @@ -1428,10 +1450,25 @@ static void EmitRSqrtStepFused(BlockOfCode& code, EmitContext& ctx, IR::Inst* in }; if constexpr (fsize != 16) { - if (code.HasFMA() && code.HasAVX()) { - auto args = ctx.reg_alloc.GetArgumentInfo(inst); - const bool fpcr_controlled = args[2].GetImmediateU1(); + auto args = ctx.reg_alloc.GetArgumentInfo(inst); + const bool fpcr_controlled = args[2].GetImmediateU1(); + if (code.HasFMA() && code.HasAVX() && ctx.HasOptimization(OptimizationFlag::Unsafe_InaccurateNaN)) { + const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm(); + const Xbyak::Xmm operand1 = ctx.reg_alloc.UseXmm(args[0]); + const Xbyak::Xmm operand2 = ctx.reg_alloc.UseXmm(args[1]); + + MaybeStandardFPSCRValue(code, ctx, fpcr_controlled, [&]{ + code.vmovaps(result, GetVectorOf(code)); + FCODE(vfnmadd231p)(result, operand1, operand2); + FCODE(vmulp)(result, result, GetVectorOf(code)); + }); + + ctx.reg_alloc.DefineValue(inst, result); + return; + } + + if (code.HasFMA() && code.HasAVX()) { const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm(); const Xbyak::Xmm operand1 = ctx.reg_alloc.UseXmm(args[0]); const Xbyak::Xmm operand2 = ctx.reg_alloc.UseXmm(args[1]); @@ -1470,8 +1507,6 @@ static void EmitRSqrtStepFused(BlockOfCode& code, EmitContext& ctx, IR::Inst* in } if (ctx.HasOptimization(OptimizationFlag::Unsafe_UnfuseFMA)) { - auto args = ctx.reg_alloc.GetArgumentInfo(inst); - const Xbyak::Xmm operand1 = ctx.reg_alloc.UseScratchXmm(args[0]); const Xbyak::Xmm operand2 = ctx.reg_alloc.UseXmm(args[1]); const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm();