emit_x64_floating_point: FPMulAdd: Inline NaN handling
This commit is contained in:
parent
3bf86d0755
commit
92a47c8db2
1 changed files with 153 additions and 61 deletions
|
@ -78,43 +78,48 @@ constexpr u64 f64_max_s64_lim = 0x43e0000000000000u; // 2^63 as a double (actua
|
|||
}
|
||||
|
||||
template<size_t fsize>
|
||||
void DenormalsAreZero(BlockOfCode& code, EmitContext& ctx, std::initializer_list<Xbyak::Xmm> to_daz) {
|
||||
if (ctx.FPCR().FZ()) {
|
||||
if (code.HasHostFeature(HostFeature::AVX512_OrthoFloat)) {
|
||||
constexpr u32 denormal_to_zero = FixupLUT(
|
||||
FpFixup::Norm_Src,
|
||||
FpFixup::Norm_Src,
|
||||
FpFixup::Norm_Src,
|
||||
FpFixup::Norm_Src,
|
||||
FpFixup::Norm_Src,
|
||||
FpFixup::Norm_Src,
|
||||
FpFixup::Norm_Src,
|
||||
FpFixup::Norm_Src);
|
||||
constexpr u64 denormal_to_zero64 = mcl::bit::replicate_element<fsize, u64>(denormal_to_zero);
|
||||
void ForceDenormalsToZero(BlockOfCode& code, std::initializer_list<Xbyak::Xmm> to_daz) {
|
||||
if (code.HasHostFeature(HostFeature::AVX512_OrthoFloat)) {
|
||||
constexpr u32 denormal_to_zero = FixupLUT(
|
||||
FpFixup::Norm_Src,
|
||||
FpFixup::Norm_Src,
|
||||
FpFixup::Norm_Src,
|
||||
FpFixup::Norm_Src,
|
||||
FpFixup::Norm_Src,
|
||||
FpFixup::Norm_Src,
|
||||
FpFixup::Norm_Src,
|
||||
FpFixup::Norm_Src);
|
||||
constexpr u64 denormal_to_zero64 = mcl::bit::replicate_element<fsize, u64>(denormal_to_zero);
|
||||
|
||||
const Xbyak::Xmm tmp = xmm16;
|
||||
FCODE(vmovap)(tmp, code.MConst(xword, u64(denormal_to_zero64), u64(denormal_to_zero64)));
|
||||
|
||||
for (const Xbyak::Xmm& xmm : to_daz) {
|
||||
FCODE(vfixupimms)(xmm, xmm, tmp, u8(0));
|
||||
}
|
||||
return;
|
||||
}
|
||||
const Xbyak::Xmm tmp = xmm16;
|
||||
FCODE(vmovap)(tmp, code.MConst(xword, u64(denormal_to_zero64), u64(denormal_to_zero64)));
|
||||
|
||||
for (const Xbyak::Xmm& xmm : to_daz) {
|
||||
code.movaps(xmm0, code.MConst(xword, fsize == 32 ? f32_non_sign_mask : f64_non_sign_mask));
|
||||
code.andps(xmm0, xmm);
|
||||
if constexpr (fsize == 32) {
|
||||
code.pcmpgtd(xmm0, code.MConst(xword, f32_smallest_normal - 1));
|
||||
} else if (code.HasHostFeature(HostFeature::SSE42)) {
|
||||
code.pcmpgtq(xmm0, code.MConst(xword, f64_smallest_normal - 1));
|
||||
} else {
|
||||
code.pcmpgtd(xmm0, code.MConst(xword, f64_smallest_normal - 1));
|
||||
code.pshufd(xmm0, xmm0, 0b11100101);
|
||||
}
|
||||
code.orps(xmm0, code.MConst(xword, fsize == 32 ? f32_negative_zero : f64_negative_zero));
|
||||
code.andps(xmm, xmm0);
|
||||
FCODE(vfixupimms)(xmm, xmm, tmp, u8(0));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
for (const Xbyak::Xmm& xmm : to_daz) {
|
||||
code.movaps(xmm0, code.MConst(xword, fsize == 32 ? f32_non_sign_mask : f64_non_sign_mask));
|
||||
code.andps(xmm0, xmm);
|
||||
if constexpr (fsize == 32) {
|
||||
code.pcmpgtd(xmm0, code.MConst(xword, f32_smallest_normal - 1));
|
||||
} else if (code.HasHostFeature(HostFeature::SSE42)) {
|
||||
code.pcmpgtq(xmm0, code.MConst(xword, f64_smallest_normal - 1));
|
||||
} else {
|
||||
code.pcmpgtd(xmm0, code.MConst(xword, f64_smallest_normal - 1));
|
||||
code.pshufd(xmm0, xmm0, 0b11100101);
|
||||
}
|
||||
code.orps(xmm0, code.MConst(xword, fsize == 32 ? f32_negative_zero : f64_negative_zero));
|
||||
code.andps(xmm, xmm0);
|
||||
}
|
||||
}
|
||||
|
||||
template<size_t fsize>
|
||||
void DenormalsAreZero(BlockOfCode& code, EmitContext& ctx, std::initializer_list<Xbyak::Xmm> to_daz) {
|
||||
if (ctx.FPCR().FZ()) {
|
||||
ForceDenormalsToZero<fsize>(code, to_daz);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -627,59 +632,146 @@ static void EmitFPMulAdd(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst) {
|
|||
auto args = ctx.reg_alloc.GetArgumentInfo(inst);
|
||||
|
||||
if constexpr (fsize != 16) {
|
||||
if (code.HasHostFeature(HostFeature::FMA) && ctx.HasOptimization(OptimizationFlag::Unsafe_InaccurateNaN)) {
|
||||
const bool needs_rounding_correction = ctx.FPCR().FZ();
|
||||
const bool needs_nan_correction = !ctx.FPCR().DN();
|
||||
|
||||
if (code.HasHostFeature(HostFeature::FMA) && !needs_rounding_correction && !needs_nan_correction) {
|
||||
const Xbyak::Xmm result = ctx.reg_alloc.UseScratchXmm(args[0]);
|
||||
const Xbyak::Xmm operand2 = ctx.reg_alloc.UseXmm(args[1]);
|
||||
const Xbyak::Xmm operand3 = ctx.reg_alloc.UseXmm(args[2]);
|
||||
|
||||
FCODE(vfmadd231s)(result, operand2, operand3);
|
||||
if (ctx.FPCR().DN()) {
|
||||
ForceToDefaultNaN<fsize>(code, result);
|
||||
}
|
||||
|
||||
ctx.reg_alloc.DefineValue(inst, result);
|
||||
return;
|
||||
}
|
||||
|
||||
if (code.HasHostFeature(HostFeature::FMA)) {
|
||||
SharedLabel end = GenSharedLabel(), fallback = GenSharedLabel();
|
||||
if (code.HasHostFeature(HostFeature::FMA | HostFeature::AVX)) {
|
||||
SharedLabel fallback = GenSharedLabel(), end = GenSharedLabel();
|
||||
|
||||
const Xbyak::Xmm operand1 = ctx.reg_alloc.UseXmm(args[0]);
|
||||
const Xbyak::Xmm operand2 = ctx.reg_alloc.UseXmm(args[1]);
|
||||
const Xbyak::Xmm operand3 = ctx.reg_alloc.UseXmm(args[2]);
|
||||
const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm();
|
||||
const Xbyak::Xmm tmp = ctx.reg_alloc.ScratchXmm();
|
||||
|
||||
code.movaps(result, operand1);
|
||||
FCODE(vfmadd231s)(result, operand2, operand3);
|
||||
|
||||
code.movaps(tmp, code.MConst(xword, fsize == 32 ? f32_non_sign_mask : f64_non_sign_mask));
|
||||
code.andps(tmp, result);
|
||||
FCODE(ucomis)(tmp, code.MConst(xword, fsize == 32 ? f32_smallest_normal : f64_smallest_normal));
|
||||
code.jz(*fallback, code.T_NEAR);
|
||||
if (needs_rounding_correction && needs_nan_correction) {
|
||||
code.vandps(xmm0, result, code.MConst(xword, fsize == 32 ? f32_non_sign_mask : f64_non_sign_mask));
|
||||
FCODE(ucomis)(xmm0, code.MConst(xword, fsize == 32 ? f32_smallest_normal : f64_smallest_normal));
|
||||
code.jz(*fallback, code.T_NEAR);
|
||||
} else if (needs_rounding_correction) {
|
||||
code.vandps(xmm0, result, code.MConst(xword, fsize == 32 ? f32_non_sign_mask : f64_non_sign_mask));
|
||||
code.vxorps(xmm0, xmm0, code.MConst(xword, fsize == 32 ? f32_smallest_normal : f64_smallest_normal));
|
||||
code.ptest(xmm0, xmm0);
|
||||
code.jz(*fallback, code.T_NEAR);
|
||||
} else if (needs_nan_correction) {
|
||||
FCODE(ucomis)(result, result);
|
||||
code.jp(*fallback, code.T_NEAR);
|
||||
} else {
|
||||
UNREACHABLE();
|
||||
}
|
||||
if (ctx.FPCR().DN()) {
|
||||
ForceToDefaultNaN<fsize>(code, result);
|
||||
}
|
||||
code.L(*end);
|
||||
|
||||
ctx.deferred_emits.emplace_back([=, &code, &ctx] {
|
||||
code.L(*fallback);
|
||||
|
||||
code.sub(rsp, 8);
|
||||
ABI_PushCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
|
||||
code.movq(code.ABI_PARAM1, operand1);
|
||||
code.movq(code.ABI_PARAM2, operand2);
|
||||
code.movq(code.ABI_PARAM3, operand3);
|
||||
code.mov(code.ABI_PARAM4.cvt32(), ctx.FPCR().Value());
|
||||
#ifdef _WIN32
|
||||
code.sub(rsp, 16 + ABI_SHADOW_SPACE);
|
||||
code.lea(rax, code.ptr[code.r15 + code.GetJitStateInfo().offsetof_fpsr_exc]);
|
||||
code.mov(qword[rsp + ABI_SHADOW_SPACE], rax);
|
||||
code.CallFunction(&FP::FPMulAdd<FPT>);
|
||||
code.add(rsp, 16 + ABI_SHADOW_SPACE);
|
||||
#else
|
||||
code.lea(code.ABI_PARAM5, code.ptr[code.r15 + code.GetJitStateInfo().offsetof_fpsr_exc]);
|
||||
code.CallFunction(&FP::FPMulAdd<FPT>);
|
||||
#endif
|
||||
code.movq(result, code.ABI_RETURN);
|
||||
ABI_PopCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
|
||||
code.add(rsp, 8);
|
||||
Xbyak::Label nan;
|
||||
|
||||
code.jmp(*end, code.T_NEAR);
|
||||
if (needs_rounding_correction && needs_nan_correction) {
|
||||
code.jp(nan, code.T_NEAR);
|
||||
}
|
||||
|
||||
if (needs_rounding_correction) {
|
||||
// x64 rounds before flushing to zero
|
||||
// AArch64 rounds after flushing to zero
|
||||
// This difference of behaviour is noticable if something would round to a smallest normalized number
|
||||
|
||||
code.sub(rsp, 8);
|
||||
ABI_PushCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
|
||||
code.movq(code.ABI_PARAM1, operand1);
|
||||
code.movq(code.ABI_PARAM2, operand2);
|
||||
code.movq(code.ABI_PARAM3, operand3);
|
||||
code.mov(code.ABI_PARAM4.cvt32(), ctx.FPCR().Value());
|
||||
#ifdef _WIN32
|
||||
code.sub(rsp, 16 + ABI_SHADOW_SPACE);
|
||||
code.lea(rax, code.ptr[code.r15 + code.GetJitStateInfo().offsetof_fpsr_exc]);
|
||||
code.mov(qword[rsp + ABI_SHADOW_SPACE], rax);
|
||||
code.CallFunction(&FP::FPMulAdd<FPT>);
|
||||
code.add(rsp, 16 + ABI_SHADOW_SPACE);
|
||||
#else
|
||||
code.lea(code.ABI_PARAM5, code.ptr[code.r15 + code.GetJitStateInfo().offsetof_fpsr_exc]);
|
||||
code.CallFunction(&FP::FPMulAdd<FPT>);
|
||||
#endif
|
||||
code.movq(result, code.ABI_RETURN);
|
||||
ABI_PopCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
|
||||
code.add(rsp, 8);
|
||||
code.jmp(*end);
|
||||
}
|
||||
|
||||
if (needs_nan_correction) {
|
||||
code.L(nan);
|
||||
|
||||
// AArch64 preferentially returns the first SNaN over the first QNaN
|
||||
// For x64 vfmadd231ss, x64 returns the first of {op2, op3, op1} that is a NaN, irregardless of signalling state
|
||||
|
||||
Xbyak::Label has_nan, indeterminate, op1_snan, op1_done, op2_done, op3_done;
|
||||
|
||||
code.vmovaps(xmm0, code.MConst(xword, FP::FPInfo<FPT>::mantissa_msb));
|
||||
|
||||
FCODE(ucomis)(operand2, operand3);
|
||||
code.jp(has_nan);
|
||||
FCODE(ucomis)(operand1, operand1);
|
||||
code.jnp(indeterminate);
|
||||
|
||||
// AArch64 specifically emits a default NaN for the case when the addend is a QNaN and the two other arguments are {inf, zero}
|
||||
code.ptest(operand1, xmm0);
|
||||
code.jz(op1_snan);
|
||||
FCODE(vmuls)(xmm0, operand2, operand3); // check if {op2, op3} are {inf, zero}/{zero, inf}
|
||||
FCODE(ucomis)(xmm0, xmm0);
|
||||
code.jnp(*end);
|
||||
|
||||
code.L(indeterminate);
|
||||
code.vmovaps(result, code.MConst(xword, FP::FPInfo<FPT>::DefaultNaN()));
|
||||
code.jmp(*end);
|
||||
|
||||
code.L(has_nan);
|
||||
|
||||
FCODE(ucomis)(operand1, operand1);
|
||||
code.jnp(op1_done);
|
||||
code.movaps(result, operand1); // this is done because of NaN behavior of vfmadd231s (priority of op2, op3, op1)
|
||||
code.ptest(operand1, xmm0);
|
||||
code.jnz(op1_done);
|
||||
code.L(op1_snan);
|
||||
code.vorps(result, operand1, xmm0);
|
||||
code.jmp(*end);
|
||||
code.L(op1_done);
|
||||
|
||||
FCODE(ucomis)(operand2, operand2);
|
||||
code.jnp(op2_done);
|
||||
code.ptest(operand2, xmm0);
|
||||
code.jnz(op2_done);
|
||||
code.vorps(result, operand2, xmm0);
|
||||
code.jmp(*end);
|
||||
code.L(op2_done);
|
||||
|
||||
FCODE(ucomis)(operand3, operand3);
|
||||
code.jnp(op3_done);
|
||||
code.ptest(operand3, xmm0);
|
||||
code.jnz(op3_done);
|
||||
code.vorps(result, operand3, xmm0);
|
||||
code.jmp(*end);
|
||||
code.L(op3_done);
|
||||
|
||||
code.jmp(*end);
|
||||
}
|
||||
});
|
||||
|
||||
ctx.reg_alloc.DefineValue(inst, result);
|
||||
|
|
Loading…
Add table
Reference in a new issue