emit_x64_vector_floating_point: FPVectorMulAdd: Minimize full fallback
This commit is contained in:
parent
ceea80dd59
commit
adac93f12e
1 changed files with 59 additions and 10 deletions
|
@ -381,7 +381,7 @@ void EmitTwoOpVectorOperation(BlockOfCode& code, EmitContext& ctx, IR::Inst* ins
|
|||
ctx.reg_alloc.DefineValue(inst, result);
|
||||
}
|
||||
|
||||
enum CheckInputNaN {
|
||||
enum class CheckInputNaN {
|
||||
Yes,
|
||||
No,
|
||||
};
|
||||
|
@ -540,7 +540,12 @@ void EmitThreeOpFallback(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst, La
|
|||
ctx.reg_alloc.DefineValue(inst, result);
|
||||
}
|
||||
|
||||
template<typename Lambda>
|
||||
enum class LoadPreviousResult {
|
||||
Yes,
|
||||
No,
|
||||
};
|
||||
|
||||
template<LoadPreviousResult load_previous_result = LoadPreviousResult::No, typename Lambda>
|
||||
void EmitFourOpFallbackWithoutRegAlloc(BlockOfCode& code, EmitContext& ctx, Xbyak::Xmm result, Xbyak::Xmm arg1, Xbyak::Xmm arg2, Xbyak::Xmm arg3, Lambda lambda, bool fpcr_controlled) {
|
||||
const auto fn = static_cast<mcl::equivalent_function_type<Lambda>*>(lambda);
|
||||
|
||||
|
@ -565,6 +570,9 @@ void EmitFourOpFallbackWithoutRegAlloc(BlockOfCode& code, EmitContext& ctx, Xbya
|
|||
code.lea(code.ABI_PARAM6, code.ptr[code.r15 + code.GetJitStateInfo().offsetof_fpsr_exc]);
|
||||
#endif
|
||||
|
||||
if constexpr (load_previous_result == LoadPreviousResult::Yes) {
|
||||
code.movaps(xword[code.ABI_PARAM1], result);
|
||||
}
|
||||
code.movaps(xword[code.ABI_PARAM2], arg1);
|
||||
code.movaps(xword[code.ABI_PARAM3], arg2);
|
||||
code.movaps(xword[code.ABI_PARAM4], arg3);
|
||||
|
@ -1290,6 +1298,31 @@ void EmitX64::EmitFPVectorMul64(EmitContext& ctx, IR::Inst* inst) {
|
|||
EmitThreeOpVectorOperation<64, DefaultIndexer>(code, ctx, inst, &Xbyak::CodeGenerator::mulpd);
|
||||
}
|
||||
|
||||
template<typename FPT, bool needs_rounding_correction, bool needs_nan_correction>
|
||||
static void EmitFPVectorMulAddFallback(VectorArray<FPT>& result, const VectorArray<FPT>& addend, const VectorArray<FPT>& op1, const VectorArray<FPT>& op2, FP::FPCR fpcr, [[maybe_unused]] FP::FPSR& fpsr) {
|
||||
for (size_t i = 0; i < result.size(); i++) {
|
||||
if constexpr (needs_rounding_correction) {
|
||||
constexpr FPT non_sign_mask = FP::FPInfo<FPT>::exponent_mask | FP::FPInfo<FPT>::mantissa_mask;
|
||||
constexpr FPT smallest_normal_number = FP::FPValue<FPT, false, FP::FPInfo<FPT>::exponent_min, 1>();
|
||||
if ((result[i] & non_sign_mask) == smallest_normal_number) {
|
||||
result[i] = FP::FPMulAdd<FPT>(addend[i], op1[i], op2[i], fpcr, fpsr);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if constexpr (needs_nan_correction) {
|
||||
if (FP::IsNaN(result[i])) {
|
||||
if (FP::IsQNaN(addend[i]) && ((FP::IsZero(op1[i], fpcr) && FP::IsInf(op2[i])) || (FP::IsInf(op1[i]) && FP::IsZero(op2[i], fpcr)))) {
|
||||
result[i] = FP::FPInfo<FPT>::DefaultNaN();
|
||||
} else if (auto r = FP::ProcessNaNs(addend[i], op1[i], op2[i])) {
|
||||
result[i] = *r;
|
||||
} else {
|
||||
result[i] = FP::FPInfo<FPT>::DefaultNaN();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<size_t fsize>
|
||||
void EmitFPVectorMulAdd(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst) {
|
||||
using FPT = mcl::unsigned_integer_of_size<fsize>;
|
||||
|
@ -1301,9 +1334,12 @@ void EmitFPVectorMulAdd(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst) {
|
|||
};
|
||||
|
||||
if constexpr (fsize != 16) {
|
||||
if (code.HasHostFeature(HostFeature::FMA | HostFeature::AVX) && ctx.HasOptimization(OptimizationFlag::Unsafe_InaccurateNaN)) {
|
||||
const bool fpcr_controlled = inst->GetArg(3).GetU1();
|
||||
const bool needs_rounding_correction = ctx.FPCR(fpcr_controlled).FZ();
|
||||
const bool needs_nan_correction = !(ctx.FPCR(fpcr_controlled).DN() || ctx.HasOptimization(OptimizationFlag::Unsafe_InaccurateNaN));
|
||||
|
||||
if (code.HasHostFeature(HostFeature::FMA) && !needs_rounding_correction && !needs_nan_correction) {
|
||||
auto args = ctx.reg_alloc.GetArgumentInfo(inst);
|
||||
const bool fpcr_controlled = args[3].GetImmediateU1();
|
||||
|
||||
const Xbyak::Xmm result = ctx.reg_alloc.UseScratchXmm(args[0]);
|
||||
const Xbyak::Xmm xmm_b = ctx.reg_alloc.UseXmm(args[1]);
|
||||
|
@ -1311,6 +1347,7 @@ void EmitFPVectorMulAdd(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst) {
|
|||
|
||||
MaybeStandardFPSCRValue(code, ctx, fpcr_controlled, [&] {
|
||||
FCODE(vfmadd231p)(result, xmm_b, xmm_c);
|
||||
ForceToDefaultNaN<fsize>(code, ctx.FPCR(fpcr_controlled), result);
|
||||
});
|
||||
|
||||
ctx.reg_alloc.DefineValue(inst, result);
|
||||
|
@ -1319,12 +1356,11 @@ void EmitFPVectorMulAdd(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst) {
|
|||
|
||||
if (code.HasHostFeature(HostFeature::FMA | HostFeature::AVX)) {
|
||||
auto args = ctx.reg_alloc.GetArgumentInfo(inst);
|
||||
const bool fpcr_controlled = args[3].GetImmediateU1();
|
||||
|
||||
const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm();
|
||||
const Xbyak::Xmm xmm_a = ctx.reg_alloc.UseXmm(args[0]);
|
||||
const Xbyak::Xmm xmm_b = ctx.reg_alloc.UseXmm(args[1]);
|
||||
const Xbyak::Xmm xmm_c = ctx.reg_alloc.UseXmm(args[2]);
|
||||
const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm();
|
||||
const Xbyak::Xmm tmp = ctx.reg_alloc.ScratchXmm();
|
||||
|
||||
SharedLabel end = GenSharedLabel(), fallback = GenSharedLabel();
|
||||
|
@ -1333,19 +1369,32 @@ void EmitFPVectorMulAdd(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst) {
|
|||
code.movaps(result, xmm_a);
|
||||
FCODE(vfmadd231p)(result, xmm_b, xmm_c);
|
||||
|
||||
code.movaps(tmp, GetNegativeZeroVector<fsize>(code));
|
||||
code.andnps(tmp, result);
|
||||
FCODE(vcmpeq_uqp)(tmp, tmp, GetSmallestNormalVector<fsize>(code));
|
||||
if (needs_rounding_correction && needs_nan_correction) {
|
||||
code.vandps(tmp, result, GetNonSignMaskVector<fsize>(code));
|
||||
FCODE(vcmpeq_uqp)(tmp, tmp, GetSmallestNormalVector<fsize>(code));
|
||||
} else if (needs_rounding_correction) {
|
||||
code.vandps(tmp, result, GetNonSignMaskVector<fsize>(code));
|
||||
ICODE(vpcmpeq)(tmp, tmp, GetSmallestNormalVector<fsize>(code));
|
||||
} else if (needs_nan_correction) {
|
||||
FCODE(vcmpunordp)(tmp, result, result);
|
||||
}
|
||||
code.vptest(tmp, tmp);
|
||||
code.jnz(*fallback, code.T_NEAR);
|
||||
code.L(*end);
|
||||
ForceToDefaultNaN<fsize>(code, ctx.FPCR(fpcr_controlled), result);
|
||||
});
|
||||
|
||||
ctx.deferred_emits.emplace_back([=, &code, &ctx] {
|
||||
code.L(*fallback);
|
||||
code.sub(rsp, 8);
|
||||
ABI_PushCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
|
||||
EmitFourOpFallbackWithoutRegAlloc(code, ctx, result, xmm_a, xmm_b, xmm_c, fallback_fn, fpcr_controlled);
|
||||
if (needs_rounding_correction && needs_nan_correction) {
|
||||
EmitFourOpFallbackWithoutRegAlloc<LoadPreviousResult::Yes>(code, ctx, result, xmm_a, xmm_b, xmm_c, EmitFPVectorMulAddFallback<FPT, true, true>, fpcr_controlled);
|
||||
} else if (needs_rounding_correction) {
|
||||
EmitFourOpFallbackWithoutRegAlloc<LoadPreviousResult::Yes>(code, ctx, result, xmm_a, xmm_b, xmm_c, EmitFPVectorMulAddFallback<FPT, true, false>, fpcr_controlled);
|
||||
} else if (needs_nan_correction) {
|
||||
EmitFourOpFallbackWithoutRegAlloc<LoadPreviousResult::Yes>(code, ctx, result, xmm_a, xmm_b, xmm_c, EmitFPVectorMulAddFallback<FPT, false, true>, fpcr_controlled);
|
||||
}
|
||||
ABI_PopCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
|
||||
code.add(rsp, 8);
|
||||
code.jmp(*end, code.T_NEAR);
|
||||
|
|
Loading…
Add table
Reference in a new issue