emit_x64{_vector}_floating_point: Unsafe AVX512 implementation of Emit{RSqrt,Recip}Estimate

This implementation exists within the unsafe optimization paths and
utilizes the 14-bit-precision `vrsqrt14*` and `vrcp14*`
instructions provided by AVX512F+VL. These are _more_ accurate than
the fallback path and the current `rcp`/`rsqrt`-based unsafe code-path
but still fall in line with what is expected of the
`Unsafe_ReducedErrorFP` optimization flag.

Having AVX512 available will mean these functions have 14 bits of precision.
Not having AVX512 available will mean these functions have 11 bits of precision.
This commit is contained in:
Wunkolo 2021-06-22 23:55:27 -07:00 committed by merry
parent ea02a7d05d
commit 1fc96fd0c2
2 changed files with 36 additions and 22 deletions

View file

@ -766,12 +766,16 @@ static void EmitFPRecipEstimate(BlockOfCode& code, EmitContext& ctx, IR::Inst* i
const Xbyak::Xmm operand = ctx.reg_alloc.UseXmm(args[0]); const Xbyak::Xmm operand = ctx.reg_alloc.UseXmm(args[0]);
const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm(); const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm();
if constexpr (fsize == 32) { if (code.HasHostFeature(HostFeature::AVX512_OrthoFloat)) {
code.rcpss(result, operand); FCODE(vrcp14s)(result, operand, operand);
} else { } else {
code.cvtsd2ss(result, operand); if constexpr (fsize == 32) {
code.rcpss(result, result); code.rcpss(result, operand);
code.cvtss2sd(result, result); } else {
code.cvtsd2ss(result, operand);
code.rcpss(result, result);
code.cvtss2sd(result, result);
}
} }
ctx.reg_alloc.DefineValue(inst, result); ctx.reg_alloc.DefineValue(inst, result);
@ -984,20 +988,22 @@ static void EmitFPRSqrtEstimate(BlockOfCode& code, EmitContext& ctx, IR::Inst* i
const Xbyak::Xmm operand = ctx.reg_alloc.UseXmm(args[0]); const Xbyak::Xmm operand = ctx.reg_alloc.UseXmm(args[0]);
const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm(); const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm();
if constexpr (fsize == 32) { if (code.HasHostFeature(HostFeature::AVX512_OrthoFloat)) {
code.rsqrtss(result, operand); FCODE(vrsqrt14s)(result, operand, operand);
} else { } else {
code.cvtsd2ss(result, operand); if constexpr (fsize == 32) {
code.rsqrtss(result, result); code.rsqrtss(result, operand);
code.cvtss2sd(result, result); } else {
code.cvtsd2ss(result, operand);
code.rsqrtss(result, result);
code.cvtss2sd(result, result);
}
} }
ctx.reg_alloc.DefineValue(inst, result); ctx.reg_alloc.DefineValue(inst, result);
return; return;
} }
// TODO: VRSQRT14SS implementation (AVX512F)
auto args = ctx.reg_alloc.GetArgumentInfo(inst); auto args = ctx.reg_alloc.GetArgumentInfo(inst);
const Xbyak::Xmm operand = ctx.reg_alloc.UseXmm(args[0]); const Xbyak::Xmm operand = ctx.reg_alloc.UseXmm(args[0]);

View file

@ -1288,12 +1288,16 @@ static void EmitRecipEstimate(BlockOfCode& code, EmitContext& ctx, IR::Inst* ins
const Xbyak::Xmm operand = ctx.reg_alloc.UseXmm(args[0]); const Xbyak::Xmm operand = ctx.reg_alloc.UseXmm(args[0]);
const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm(); const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm();
if constexpr (fsize == 32) { if (code.HasHostFeature(HostFeature::AVX512_OrthoFloat)) {
code.rcpps(result, operand); FCODE(vrcp14p)(result, operand);
} else { } else {
code.cvtpd2ps(result, operand); if constexpr (fsize == 32) {
code.rcpps(result, result); code.rcpps(result, operand);
code.cvtps2pd(result, result); } else {
code.cvtpd2ps(result, operand);
code.rcpps(result, result);
code.cvtps2pd(result, result);
}
} }
ctx.reg_alloc.DefineValue(inst, result); ctx.reg_alloc.DefineValue(inst, result);
@ -1502,12 +1506,16 @@ static void EmitRSqrtEstimate(BlockOfCode& code, EmitContext& ctx, IR::Inst* ins
const Xbyak::Xmm operand = ctx.reg_alloc.UseXmm(args[0]); const Xbyak::Xmm operand = ctx.reg_alloc.UseXmm(args[0]);
const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm(); const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm();
if constexpr (fsize == 32) { if (code.HasHostFeature(HostFeature::AVX512_OrthoFloat)) {
code.rsqrtps(result, operand); FCODE(vrsqrt14p)(result, operand);
} else { } else {
code.cvtpd2ps(result, operand); if constexpr (fsize == 32) {
code.rsqrtps(result, result); code.rsqrtps(result, operand);
code.cvtps2pd(result, result); } else {
code.cvtpd2ps(result, operand);
code.rsqrtps(result, result);
code.cvtps2pd(result, result);
}
} }
ctx.reg_alloc.DefineValue(inst, result); ctx.reg_alloc.DefineValue(inst, result);