IR: Implement FPVectorMulAdd

This commit is contained in:
MerryMage 2018-07-25 13:19:48 +01:00
parent 3218bb9890
commit 771a4fc20b
4 changed files with 207 additions and 31 deletions

View file

@ -41,24 +41,38 @@ static T ChooseOnFsize(T f32, T f64) {
#define FCODE(NAME) (code.*ChooseOnFsize<fsize>(&Xbyak::CodeGenerator::NAME##s, &Xbyak::CodeGenerator::NAME##d)) #define FCODE(NAME) (code.*ChooseOnFsize<fsize>(&Xbyak::CodeGenerator::NAME##s, &Xbyak::CodeGenerator::NAME##d))
template<typename T, template<typename> class Indexer, size_t... argi> template<size_t fsize, template<typename> class Indexer, size_t narg>
static auto GetRuntimeNaNFunction(std::index_sequence<argi...>) { struct NaNHandler {
auto result = [](std::array<VectorArray<T>, sizeof...(argi) + 1>& values) { private:
VectorArray<T>& result = values[0]; template<size_t... argi>
static auto GetDefaultImpl(std::index_sequence<argi...>) {
using FPT = mp::unsigned_integer_of_size<fsize>;
auto result = [](std::array<VectorArray<FPT>, sizeof...(argi) + 1>& values) {
VectorArray<FPT>& result = values[0];
for (size_t elementi = 0; elementi < result.size(); ++elementi) { for (size_t elementi = 0; elementi < result.size(); ++elementi) {
const auto current_values = Indexer<T>{}(elementi, values[argi + 1]...); const auto current_values = Indexer<FPT>{}(elementi, values[argi + 1]...);
if (auto r = FP::ProcessNaNs(std::get<argi>(current_values)...)) { if (auto r = FP::ProcessNaNs(std::get<argi>(current_values)...)) {
result[elementi] = *r; result[elementi] = *r;
} else if (FP::IsNaN(result[elementi])) { } else if (FP::IsNaN(result[elementi])) {
result[elementi] = FP::FPInfo<T>::DefaultNaN(); result[elementi] = FP::FPInfo<FPT>::DefaultNaN();
} }
} }
}; };
return static_cast<mp::equivalent_function_type_t<decltype(result)>*>(result); return static_cast<mp::equivalent_function_type_t<decltype(result)>*>(result);
} }
template<size_t fsize, size_t nargs, template<typename> class Indexer> public:
static void HandleNaNs(BlockOfCode& code, EmitContext& ctx, std::array<Xbyak::Xmm, nargs + 1> xmms, const Xbyak::Xmm& nan_mask) { static auto GetDefault() {
return GetDefaultImpl(std::make_index_sequence<narg - 1>{});
}
using function_type = mp::return_type_t<decltype(GetDefault)>;
};
template<size_t fsize, size_t nargs, typename NaNHandler>
static void HandleNaNs(BlockOfCode& code, EmitContext& ctx, std::array<Xbyak::Xmm, nargs + 1> xmms, const Xbyak::Xmm& nan_mask, NaNHandler nan_handler) {
static_assert(fsize == 32 || fsize == 64, "fsize must be either 32 or 64"); static_assert(fsize == 32 || fsize == 64, "fsize must be either 32 or 64");
if (code.DoesCpuSupport(Xbyak::util::Cpu::tSSE41)) { if (code.DoesCpuSupport(Xbyak::util::Cpu::tSSE41)) {
@ -91,8 +105,7 @@ static void HandleNaNs(BlockOfCode& code, EmitContext& ctx, std::array<Xbyak::Xm
} }
code.lea(code.ABI_PARAM1, ptr[rsp + ABI_SHADOW_SPACE + 0 * 16]); code.lea(code.ABI_PARAM1, ptr[rsp + ABI_SHADOW_SPACE + 0 * 16]);
using T = mp::unsigned_integer_of_size<fsize>; code.CallFunction(nan_handler);
code.CallFunction(GetRuntimeNaNFunction<T, Indexer>(std::make_index_sequence<nargs>{}));
code.movaps(result, xword[rsp + ABI_SHADOW_SPACE + 0 * 16]); code.movaps(result, xword[rsp + ABI_SHADOW_SPACE + 0 * 16]);
code.add(rsp, stack_space + ABI_SHADOW_SPACE); code.add(rsp, stack_space + ABI_SHADOW_SPACE);
@ -155,13 +168,13 @@ struct PairedLowerIndexer {
}; };
template<size_t fsize, template<typename> class Indexer, typename Function> template<size_t fsize, template<typename> class Indexer, typename Function>
static void EmitThreeOpVectorOperation(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst, Function fn) { static void EmitThreeOpVectorOperation(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst, Function fn, typename NaNHandler<fsize, Indexer, 3>::function_type nan_handler = NaNHandler<fsize, Indexer, 3>::GetDefault()) {
static_assert(fsize == 32 || fsize == 64, "fsize must be either 32 or 64"); static_assert(fsize == 32 || fsize == 64, "fsize must be either 32 or 64");
if (!ctx.AccurateNaN() || ctx.FPSCR_DN()) { if (!ctx.AccurateNaN() || ctx.FPSCR_DN()) {
auto args = ctx.reg_alloc.GetArgumentInfo(inst); auto args = ctx.reg_alloc.GetArgumentInfo(inst);
Xbyak::Xmm xmm_a = ctx.reg_alloc.UseScratchXmm(args[0]); const Xbyak::Xmm xmm_a = ctx.reg_alloc.UseScratchXmm(args[0]);
Xbyak::Xmm xmm_b = ctx.reg_alloc.UseXmm(args[1]); const Xbyak::Xmm xmm_b = ctx.reg_alloc.UseXmm(args[1]);
if constexpr (std::is_member_function_pointer_v<Function>) { if constexpr (std::is_member_function_pointer_v<Function>) {
(code.*fn)(xmm_a, xmm_b); (code.*fn)(xmm_a, xmm_b);
@ -170,8 +183,8 @@ static void EmitThreeOpVectorOperation(BlockOfCode& code, EmitContext& ctx, IR::
} }
if (ctx.FPSCR_DN()) { if (ctx.FPSCR_DN()) {
Xbyak::Xmm nan_mask = ctx.reg_alloc.ScratchXmm(); const Xbyak::Xmm nan_mask = ctx.reg_alloc.ScratchXmm();
Xbyak::Xmm tmp = ctx.reg_alloc.ScratchXmm(); const Xbyak::Xmm tmp = ctx.reg_alloc.ScratchXmm();
code.pcmpeqw(tmp, tmp); code.pcmpeqw(tmp, tmp);
code.movaps(nan_mask, xmm_a); code.movaps(nan_mask, xmm_a);
FCODE(cmpordp)(nan_mask, nan_mask); FCODE(cmpordp)(nan_mask, nan_mask);
@ -187,10 +200,10 @@ static void EmitThreeOpVectorOperation(BlockOfCode& code, EmitContext& ctx, IR::
auto args = ctx.reg_alloc.GetArgumentInfo(inst); auto args = ctx.reg_alloc.GetArgumentInfo(inst);
Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm(); const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm();
Xbyak::Xmm xmm_a = ctx.reg_alloc.UseXmm(args[0]); const Xbyak::Xmm xmm_a = ctx.reg_alloc.UseXmm(args[0]);
Xbyak::Xmm xmm_b = ctx.reg_alloc.UseXmm(args[1]); const Xbyak::Xmm xmm_b = ctx.reg_alloc.UseXmm(args[1]);
Xbyak::Xmm nan_mask = ctx.reg_alloc.ScratchXmm(); const Xbyak::Xmm nan_mask = ctx.reg_alloc.ScratchXmm();
code.movaps(nan_mask, xmm_b); code.movaps(nan_mask, xmm_b);
code.movaps(result, xmm_a); code.movaps(result, xmm_a);
@ -202,7 +215,63 @@ static void EmitThreeOpVectorOperation(BlockOfCode& code, EmitContext& ctx, IR::
} }
FCODE(cmpunordp)(nan_mask, result); FCODE(cmpunordp)(nan_mask, result);
HandleNaNs<fsize, 2, Indexer>(code, ctx, {result, xmm_a, xmm_b}, nan_mask); HandleNaNs<fsize, 2>(code, ctx, {result, xmm_a, xmm_b}, nan_mask, nan_handler);
ctx.reg_alloc.DefineValue(inst, result);
}
template<size_t fsize, template<typename> class Indexer, typename Function>
static void EmitFourOpVectorOperation(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst, Function fn, typename NaNHandler<fsize, Indexer, 4>::function_type nan_handler = NaNHandler<fsize, Indexer, 4>::GetDefault()) {
static_assert(fsize == 32 || fsize == 64, "fsize must be either 32 or 64");
if (!ctx.AccurateNaN() || ctx.FPSCR_DN()) {
auto args = ctx.reg_alloc.GetArgumentInfo(inst);
const Xbyak::Xmm xmm_a = ctx.reg_alloc.UseScratchXmm(args[0]);
const Xbyak::Xmm xmm_b = ctx.reg_alloc.UseXmm(args[1]);
const Xbyak::Xmm xmm_c = ctx.reg_alloc.UseXmm(args[2]);
if constexpr (std::is_member_function_pointer_v<Function>) {
(code.*fn)(xmm_a, xmm_b, xmm_c);
} else {
fn(xmm_a, xmm_b, xmm_c);
}
if (ctx.FPSCR_DN()) {
const Xbyak::Xmm nan_mask = ctx.reg_alloc.ScratchXmm();
const Xbyak::Xmm tmp = ctx.reg_alloc.ScratchXmm();
code.pcmpeqw(tmp, tmp);
code.movaps(nan_mask, xmm_a);
FCODE(cmpordp)(nan_mask, nan_mask);
code.andps(xmm_a, nan_mask);
code.xorps(nan_mask, tmp);
code.andps(nan_mask, fsize == 32 ? code.MConst(xword, 0x7fc0'0000'7fc0'0000, 0x7fc0'0000'7fc0'0000) : code.MConst(xword, 0x7ff8'0000'0000'0000, 0x7ff8'0000'0000'0000));
code.orps(xmm_a, nan_mask);
}
ctx.reg_alloc.DefineValue(inst, xmm_a);
return;
}
auto args = ctx.reg_alloc.GetArgumentInfo(inst);
const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm();
const Xbyak::Xmm xmm_a = ctx.reg_alloc.UseXmm(args[0]);
const Xbyak::Xmm xmm_b = ctx.reg_alloc.UseXmm(args[1]);
const Xbyak::Xmm xmm_c = ctx.reg_alloc.UseXmm(args[2]);
const Xbyak::Xmm nan_mask = ctx.reg_alloc.ScratchXmm();
code.movaps(nan_mask, xmm_b);
code.movaps(result, xmm_a);
FCODE(cmpunordp)(nan_mask, xmm_a);
FCODE(cmpunordp)(nan_mask, xmm_c);
if constexpr (std::is_member_function_pointer_v<Function>) {
(code.*fn)(result, xmm_b, xmm_c);
} else {
fn(result, xmm_b, xmm_c);
}
FCODE(cmpunordp)(nan_mask, result);
HandleNaNs<fsize, 3>(code, ctx, {result, xmm_a, xmm_b, xmm_c}, nan_mask, nan_handler);
ctx.reg_alloc.DefineValue(inst, result); ctx.reg_alloc.DefineValue(inst, result);
} }
@ -250,7 +319,7 @@ inline void EmitThreeOpFallback(BlockOfCode& code, EmitContext& ctx, IR::Inst* i
code.lea(code.ABI_PARAM3, ptr[rsp + ABI_SHADOW_SPACE + 3 * 16]); code.lea(code.ABI_PARAM3, ptr[rsp + ABI_SHADOW_SPACE + 3 * 16]);
code.mov(code.ABI_PARAM4.cvt32(), ctx.FPCR()); code.mov(code.ABI_PARAM4.cvt32(), ctx.FPCR());
code.lea(rax, code.ptr[code.r15 + code.GetJitStateInfo().offsetof_fpsr_exc]); code.lea(rax, code.ptr[code.r15 + code.GetJitStateInfo().offsetof_fpsr_exc]);
code.mov(qword[rsp + ABI_SHADOW_SPACE + 0 * 16], rax); code.mov(qword[rsp + ABI_SHADOW_SPACE + 0], rax);
#else #else
constexpr u32 stack_space = 3 * 16; constexpr u32 stack_space = 3 * 16;
code.sub(rsp, stack_space + ABI_SHADOW_SPACE); code.sub(rsp, stack_space + ABI_SHADOW_SPACE);
@ -276,6 +345,54 @@ inline void EmitThreeOpFallback(BlockOfCode& code, EmitContext& ctx, IR::Inst* i
ctx.reg_alloc.DefineValue(inst, xmm0); ctx.reg_alloc.DefineValue(inst, xmm0);
} }
template<typename Lambda>
inline void EmitFourOpFallback(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst, Lambda lambda) {
const auto fn = static_cast<mp::equivalent_function_type_t<Lambda>*>(lambda);
auto args = ctx.reg_alloc.GetArgumentInfo(inst);
const Xbyak::Xmm arg1 = ctx.reg_alloc.UseXmm(args[0]);
const Xbyak::Xmm arg2 = ctx.reg_alloc.UseXmm(args[1]);
const Xbyak::Xmm arg3 = ctx.reg_alloc.UseXmm(args[2]);
ctx.reg_alloc.EndOfAllocScope();
ctx.reg_alloc.HostCall(nullptr);
#ifdef _WIN32
constexpr u32 stack_space = 5 * 16;
code.sub(rsp, stack_space + ABI_SHADOW_SPACE);
code.lea(code.ABI_PARAM1, ptr[rsp + ABI_SHADOW_SPACE + 1 * 16]);
code.lea(code.ABI_PARAM2, ptr[rsp + ABI_SHADOW_SPACE + 2 * 16]);
code.lea(code.ABI_PARAM3, ptr[rsp + ABI_SHADOW_SPACE + 3 * 16]);
code.lea(code.ABI_PARAM4, ptr[rsp + ABI_SHADOW_SPACE + 4 * 16]);
code.mov(qword[rsp + ABI_SHADOW_SPACE + 0], ctx.FPCR());
code.lea(rax, code.ptr[code.r15 + code.GetJitStateInfo().offsetof_fpsr_exc]);
code.mov(qword[rsp + ABI_SHADOW_SPACE + 8], rax);
#else
constexpr u32 stack_space = 4 * 16;
code.sub(rsp, stack_space + ABI_SHADOW_SPACE);
code.lea(code.ABI_PARAM1, ptr[rsp + ABI_SHADOW_SPACE + 0 * 16]);
code.lea(code.ABI_PARAM2, ptr[rsp + ABI_SHADOW_SPACE + 1 * 16]);
code.lea(code.ABI_PARAM3, ptr[rsp + ABI_SHADOW_SPACE + 2 * 16]);
code.lea(code.ABI_PARAM4, ptr[rsp + ABI_SHADOW_SPACE + 3 * 16]);
code.mov(code.ABI_PARAM5.cvt32(), ctx.FPCR());
code.lea(code.ABI_PARAM6, code.ptr[code.r15 + code.GetJitStateInfo().offsetof_fpsr_exc]);
#endif
code.movaps(xword[code.ABI_PARAM2], arg1);
code.movaps(xword[code.ABI_PARAM3], arg2);
code.movaps(xword[code.ABI_PARAM4], arg3);
code.CallFunction(fn);
#ifdef _WIN32
code.movaps(xmm0, xword[rsp + ABI_SHADOW_SPACE + 1 * 16]);
#else
code.movaps(xmm0, xword[rsp + ABI_SHADOW_SPACE + 0 * 16]);
#endif
code.add(rsp, stack_space + ABI_SHADOW_SPACE);
ctx.reg_alloc.DefineValue(inst, xmm0);
}
void EmitX64::EmitFPVectorAbs16(EmitContext& ctx, IR::Inst* inst) { void EmitX64::EmitFPVectorAbs16(EmitContext& ctx, IR::Inst* inst) {
auto args = ctx.reg_alloc.GetArgumentInfo(inst); auto args = ctx.reg_alloc.GetArgumentInfo(inst);
@ -393,6 +510,51 @@ void EmitX64::EmitFPVectorMul64(EmitContext& ctx, IR::Inst* inst) {
EmitThreeOpVectorOperation<64, DefaultIndexer>(code, ctx, inst, &Xbyak::CodeGenerator::mulpd); EmitThreeOpVectorOperation<64, DefaultIndexer>(code, ctx, inst, &Xbyak::CodeGenerator::mulpd);
} }
template<size_t fsize>
void EmitFPVectorMulAdd(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst) {
using FPT = mp::unsigned_integer_of_size<fsize>;
if (code.DoesCpuSupport(Xbyak::util::Cpu::tFMA)) {
auto x64_instruction = fsize == 32 ? &Xbyak::CodeGenerator::vfmadd231ps : &Xbyak::CodeGenerator::vfmadd231pd;
EmitFourOpVectorOperation<fsize, DefaultIndexer>(code, ctx, inst, x64_instruction,
static_cast<void(*)(std::array<VectorArray<FPT>, 4>& values)>(
[](std::array<VectorArray<FPT>, 4>& values) {
VectorArray<FPT>& result = values[0];
const VectorArray<FPT>& a = values[1];
const VectorArray<FPT>& b = values[2];
const VectorArray<FPT>& c = values[3];
for (size_t i = 0; i < result.size(); i++) {
if (FP::IsQNaN(a[i]) && ((FP::IsInf(b[i]) && FP::IsZero(c[i])) || (FP::IsZero(b[i]) && FP::IsInf(c[i])))) {
result[i] = FP::FPInfo<FPT>::DefaultNaN();
} else if (auto r = FP::ProcessNaNs(a[i], b[i], c[i])) {
result[i] = *r;
} else if (FP::IsNaN(result[i])) {
result[i] = FP::FPInfo<FPT>::DefaultNaN();
}
}
}
)
);
return;
}
EmitFourOpFallback(code, ctx, inst,
[](VectorArray<FPT>& result, const VectorArray<FPT>& addend, const VectorArray<FPT>& op1, const VectorArray<FPT>& op2, FP::FPCR fpcr, FP::FPSR& fpsr) {
for (size_t i = 0; i < result.size(); i++) {
result[i] = FP::FPMulAdd<FPT>(addend[i], op1[i], op2[i], fpcr, fpsr);
}
}
);
}
void EmitX64::EmitFPVectorMulAdd32(EmitContext& ctx, IR::Inst* inst) {
EmitFPVectorMulAdd<32>(code, ctx, inst);
}
void EmitX64::EmitFPVectorMulAdd64(EmitContext& ctx, IR::Inst* inst) {
EmitFPVectorMulAdd<64>(code, ctx, inst);
}
void EmitX64::EmitFPVectorPairedAdd32(EmitContext& ctx, IR::Inst* inst) { void EmitX64::EmitFPVectorPairedAdd32(EmitContext& ctx, IR::Inst* inst) {
EmitThreeOpVectorOperation<32, PairedIndexer>(code, ctx, inst, &Xbyak::CodeGenerator::haddps); EmitThreeOpVectorOperation<32, PairedIndexer>(code, ctx, inst, &Xbyak::CodeGenerator::haddps);
} }

View file

@ -1696,6 +1696,17 @@ U128 IREmitter::FPVectorMul(size_t esize, const U128& a, const U128& b) {
return {}; return {};
} }
U128 IREmitter::FPVectorMulAdd(size_t esize, const U128& a, const U128& b, const U128& c) {
switch (esize) {
case 32:
return Inst<U128>(Opcode::FPVectorMulAdd32, a, b, c);
case 64:
return Inst<U128>(Opcode::FPVectorMulAdd64, a, b, c);
}
UNREACHABLE();
return {};
}
U128 IREmitter::FPVectorPairedAdd(size_t esize, const U128& a, const U128& b) { U128 IREmitter::FPVectorPairedAdd(size_t esize, const U128& a, const U128& b) {
switch (esize) { switch (esize) {
case 32: case 32:

View file

@ -267,7 +267,7 @@ public:
U32U64 FPMin(const U32U64& a, const U32U64& b, bool fpscr_controlled); U32U64 FPMin(const U32U64& a, const U32U64& b, bool fpscr_controlled);
U32U64 FPMinNumeric(const U32U64& a, const U32U64& b, bool fpscr_controlled); U32U64 FPMinNumeric(const U32U64& a, const U32U64& b, bool fpscr_controlled);
U32U64 FPMul(const U32U64& a, const U32U64& b, bool fpscr_controlled); U32U64 FPMul(const U32U64& a, const U32U64& b, bool fpscr_controlled);
U32U64 FPMulAdd(const U32U64& a, const U32U64& b, const U32U64& c, bool fpscr_controlled); U32U64 FPMulAdd(const U32U64& addend, const U32U64& op1, const U32U64& op2, bool fpscr_controlled);
U32U64 FPNeg(const U32U64& a); U32U64 FPNeg(const U32U64& a);
U32U64 FPRoundInt(const U32U64& a, FP::RoundingMode rounding, bool exact); U32U64 FPRoundInt(const U32U64& a, FP::RoundingMode rounding, bool exact);
U32U64 FPRSqrtEstimate(const U32U64& a); U32U64 FPRSqrtEstimate(const U32U64& a);
@ -300,6 +300,7 @@ public:
U128 FPVectorGreater(size_t esize, const U128& a, const U128& b); U128 FPVectorGreater(size_t esize, const U128& a, const U128& b);
U128 FPVectorGreaterEqual(size_t esize, const U128& a, const U128& b); U128 FPVectorGreaterEqual(size_t esize, const U128& a, const U128& b);
U128 FPVectorMul(size_t esize, const U128& a, const U128& b); U128 FPVectorMul(size_t esize, const U128& a, const U128& b);
U128 FPVectorMulAdd(size_t esize, const U128& addend, const U128& op1, const U128& op2);
U128 FPVectorPairedAdd(size_t esize, const U128& a, const U128& b); U128 FPVectorPairedAdd(size_t esize, const U128& a, const U128& b);
U128 FPVectorPairedAddLower(size_t esize, const U128& a, const U128& b); U128 FPVectorPairedAddLower(size_t esize, const U128& a, const U128& b);
U128 FPVectorRSqrtEstimate(size_t esize, const U128& a); U128 FPVectorRSqrtEstimate(size_t esize, const U128& a);

View file

@ -441,6 +441,8 @@ OPCODE(FPVectorGreaterEqual32, T::U128, T::U128,
OPCODE(FPVectorGreaterEqual64, T::U128, T::U128, T::U128 ) OPCODE(FPVectorGreaterEqual64, T::U128, T::U128, T::U128 )
OPCODE(FPVectorMul32, T::U128, T::U128, T::U128 ) OPCODE(FPVectorMul32, T::U128, T::U128, T::U128 )
OPCODE(FPVectorMul64, T::U128, T::U128, T::U128 ) OPCODE(FPVectorMul64, T::U128, T::U128, T::U128 )
OPCODE(FPVectorMulAdd32, T::U128, T::U128, T::U128, T::U128 )
OPCODE(FPVectorMulAdd64, T::U128, T::U128, T::U128, T::U128 )
OPCODE(FPVectorPairedAddLower32, T::U128, T::U128, T::U128 ) OPCODE(FPVectorPairedAddLower32, T::U128, T::U128, T::U128 )
OPCODE(FPVectorPairedAddLower64, T::U128, T::U128, T::U128 ) OPCODE(FPVectorPairedAddLower64, T::U128, T::U128, T::U128 )
OPCODE(FPVectorPairedAdd32, T::U128, T::U128, T::U128 ) OPCODE(FPVectorPairedAdd32, T::U128, T::U128, T::U128 )