diff --git a/src/backend/x64/emit_x64_vector_floating_point.cpp b/src/backend/x64/emit_x64_vector_floating_point.cpp
index 7023ce35..5d31418b 100644
--- a/src/backend/x64/emit_x64_vector_floating_point.cpp
+++ b/src/backend/x64/emit_x64_vector_floating_point.cpp
@@ -1273,51 +1273,57 @@ static void EmitRSqrtStepFused(BlockOfCode& code, EmitContext& ctx, IR::Inst* in
         }
     };
 
-    if (code.DoesCpuSupport(Xbyak::util::Cpu::tFMA) && code.DoesCpuSupport(Xbyak::util::Cpu::tAVX)) {
-        auto args = ctx.reg_alloc.GetArgumentInfo(inst);
+    if constexpr (fsize != 16) {
+        if (code.DoesCpuSupport(Xbyak::util::Cpu::tFMA) && code.DoesCpuSupport(Xbyak::util::Cpu::tAVX)) {
+            auto args = ctx.reg_alloc.GetArgumentInfo(inst);
 
-        const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm();
-        const Xbyak::Xmm operand1 = ctx.reg_alloc.UseXmm(args[0]);
-        const Xbyak::Xmm operand2 = ctx.reg_alloc.UseXmm(args[1]);
-        const Xbyak::Xmm tmp = ctx.reg_alloc.ScratchXmm();
-        const Xbyak::Xmm mask = ctx.reg_alloc.ScratchXmm();
+            const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm();
+            const Xbyak::Xmm operand1 = ctx.reg_alloc.UseXmm(args[0]);
+            const Xbyak::Xmm operand2 = ctx.reg_alloc.UseXmm(args[1]);
+            const Xbyak::Xmm tmp = ctx.reg_alloc.ScratchXmm();
+            const Xbyak::Xmm mask = ctx.reg_alloc.ScratchXmm();
 
-        Xbyak::Label end, fallback;
+            Xbyak::Label end, fallback;
 
-        code.vmovaps(result, GetVectorOf<fsize, false, 0, 3>(code));
-        FCODE(vfnmadd231p)(result, operand1, operand2);
+            code.vmovaps(result, GetVectorOf<fsize, false, 0, 3>(code));
+            FCODE(vfnmadd231p)(result, operand1, operand2);
 
-        // An explanation for this is given in EmitFPRSqrtStepFused.
-        code.vmovaps(mask, GetVectorOf<fsize, fsize == 32 ? 0x7f000000 : 0x7fe0000000000000>(code));
-        FCODE(vandp)(tmp, result, mask);
-        if constexpr (fsize == 32) {
-            code.vpcmpeqd(tmp, tmp, mask);
-        } else {
-            code.vpcmpeqq(tmp, tmp, mask);
+            // An explanation for this is given in EmitFPRSqrtStepFused.
+            code.vmovaps(mask, GetVectorOf<fsize, fsize == 32 ? 0x7f000000 : 0x7fe0000000000000>(code));
+            FCODE(vandp)(tmp, result, mask);
+            if constexpr (fsize == 32) {
+                code.vpcmpeqd(tmp, tmp, mask);
+            } else {
+                code.vpcmpeqq(tmp, tmp, mask);
+            }
+            code.ptest(tmp, tmp);
+            code.jnz(fallback, code.T_NEAR);
+
+            FCODE(vmulp)(result, result, GetVectorOf<fsize, false, -1, 1>(code));
+            code.L(end);
+
+            code.SwitchToFarCode();
+            code.L(fallback);
+            code.sub(rsp, 8);
+            ABI_PushCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
+            EmitThreeOpFallbackWithoutRegAlloc(code, ctx, result, operand1, operand2, fallback_fn);
+            ABI_PopCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
+            code.add(rsp, 8);
+            code.jmp(end, code.T_NEAR);
+            code.SwitchToNearCode();
+
+            ctx.reg_alloc.DefineValue(inst, result);
+            return;
         }
-        code.ptest(tmp, tmp);
-        code.jnz(fallback, code.T_NEAR);
-
-        FCODE(vmulp)(result, result, GetVectorOf<fsize, false, -1, 1>(code));
-        code.L(end);
-
-        code.SwitchToFarCode();
-        code.L(fallback);
-        code.sub(rsp, 8);
-        ABI_PushCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
-        EmitThreeOpFallbackWithoutRegAlloc(code, ctx, result, operand1, operand2, fallback_fn);
-        ABI_PopCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
-        code.add(rsp, 8);
-        code.jmp(end, code.T_NEAR);
-        code.SwitchToNearCode();
-
-        ctx.reg_alloc.DefineValue(inst, result);
-        return;
     }
 
     EmitThreeOpFallback(code, ctx, inst, fallback_fn);
 }
 
+void EmitX64::EmitFPVectorRSqrtStepFused16(EmitContext& ctx, IR::Inst* inst) {
+    EmitRSqrtStepFused<16>(code, ctx, inst);
+}
+
 void EmitX64::EmitFPVectorRSqrtStepFused32(EmitContext& ctx, IR::Inst* inst) {
     EmitRSqrtStepFused<32>(code, ctx, inst);
 }
diff --git a/src/frontend/ir/ir_emitter.cpp b/src/frontend/ir/ir_emitter.cpp
index 4381b3df..6f9b8715 100644
--- a/src/frontend/ir/ir_emitter.cpp
+++ b/src/frontend/ir/ir_emitter.cpp
@@ -2344,6 +2344,8 @@ U128 IREmitter::FPVectorRSqrtEstimate(size_t esize, const U128& a) {
 
 U128 IREmitter::FPVectorRSqrtStepFused(size_t esize, const U128& a, const U128& b) {
     switch (esize) {
+    case 16:
+        return Inst<U128>(Opcode::FPVectorRSqrtStepFused16, a, b);
     case 32:
         return Inst<U128>(Opcode::FPVectorRSqrtStepFused32, a, b);
     case 64:
diff --git a/src/frontend/ir/microinstruction.cpp b/src/frontend/ir/microinstruction.cpp
index 212f0169..a6b6bee7 100644
--- a/src/frontend/ir/microinstruction.cpp
+++ b/src/frontend/ir/microinstruction.cpp
@@ -351,6 +351,7 @@ bool Inst::ReadsFromAndWritesToFPSRCumulativeExceptionBits() const {
     case Opcode::FPVectorRSqrtEstimate16:
     case Opcode::FPVectorRSqrtEstimate32:
     case Opcode::FPVectorRSqrtEstimate64:
+    case Opcode::FPVectorRSqrtStepFused16:
     case Opcode::FPVectorRSqrtStepFused32:
     case Opcode::FPVectorRSqrtStepFused64:
     case Opcode::FPVectorSqrt32:
diff --git a/src/frontend/ir/opcodes.inc b/src/frontend/ir/opcodes.inc
index ea43f4d0..fbe6c303 100644
--- a/src/frontend/ir/opcodes.inc
+++ b/src/frontend/ir/opcodes.inc
@@ -586,6 +586,7 @@ OPCODE(FPVectorRoundInt64,                                  U128,           U128
 OPCODE(FPVectorRSqrtEstimate16,                             U128,           U128                                            )
 OPCODE(FPVectorRSqrtEstimate32,                             U128,           U128                                            )
 OPCODE(FPVectorRSqrtEstimate64,                             U128,           U128                                            )
+OPCODE(FPVectorRSqrtStepFused16,                            U128,           U128,           U128                            )
 OPCODE(FPVectorRSqrtStepFused32,                            U128,           U128,           U128                            )
 OPCODE(FPVectorRSqrtStepFused64,                            U128,           U128,           U128                            )
 OPCODE(FPVectorSqrt32,                                      U128,           U128                                            )
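For context, the FMA fast path in the first hunk implements ARM's FRSQRTS step, which computes (3 - op1*op2) / 2 per lane. x64 offers no way to fuse the final halving into the FMA, so the emitted code forms the unhalved intermediate with vfnmadd231p, tests whether that intermediate's exponent reached the top of its range (a case the truly fused ARM operation would handle differently), and branches to the soft-float fallback when it did; otherwise vmulp halves the result. A minimal scalar sketch of the same arithmetic, with RSqrtStepFloat as a hypothetical name and assuming round-to-nearest:

    #include <cmath>

    // Hypothetical scalar model of the vector fast path (not dynarmic code).
    // std::fmaf provides the single-rounding 3 - a*b that vfnmadd231p computes;
    // the multiply by 0.5f mirrors the final vmulp.
    float RSqrtStepFloat(float a, float b) {
        const float intermediate = std::fmaf(-a, b, 3.0f);  // fused 3 - a*b
        return intermediate * 0.5f;  // exact for normal values: only the exponent changes
    }

The fsize != 16 guard exists because SSE/AVX provide no half-precision arithmetic, so the FP16 variant added by this patch always takes the EmitThreeOpFallback path into FP::FPRSqrtStepFused.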