emit_x64_floating_point: Simplify indexers

This commit is contained in:
MerryMage 2018-07-25 12:05:41 +01:00
parent 25b28bb234
commit 8f72be0a02

View file

@ -5,12 +5,14 @@
*/ */
#include <array> #include <array>
#include <tuple>
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include "backend_x64/abi.h" #include "backend_x64/abi.h"
#include "backend_x64/block_of_code.h" #include "backend_x64/block_of_code.h"
#include "backend_x64/emit_x64.h" #include "backend_x64/emit_x64.h"
#include "common/assert.h"
#include "common/bit_util.h" #include "common/bit_util.h"
#include "common/fp/fpcr.h" #include "common/fp/fpcr.h"
#include "common/fp/info.h" #include "common/fp/info.h"
@ -39,12 +41,12 @@ static T ChooseOnFsize(T f32, T f64) {
#define FCODE(NAME) (code.*ChooseOnFsize<fsize>(&Xbyak::CodeGenerator::NAME##s, &Xbyak::CodeGenerator::NAME##d)) #define FCODE(NAME) (code.*ChooseOnFsize<fsize>(&Xbyak::CodeGenerator::NAME##s, &Xbyak::CodeGenerator::NAME##d))
template<typename T, auto IndexFunction, size_t... argi> template<typename T, template<typename> class Indexer, size_t... argi>
static auto GetRuntimeNaNFunction(std::index_sequence<argi...>) { static auto GetRuntimeNaNFunction(std::index_sequence<argi...>) {
auto result = [](std::array<VectorArray<T>, sizeof...(argi) + 1>& values) { auto result = [](std::array<VectorArray<T>, sizeof...(argi) + 1>& values) {
VectorArray<T>& result = values[0]; VectorArray<T>& result = values[0];
for (size_t elementi = 0; elementi < result.size(); ++elementi) { for (size_t elementi = 0; elementi < result.size(); ++elementi) {
const auto current_values = IndexFunction(elementi, values[argi + 1]...); const auto current_values = Indexer<T>{}(elementi, values[argi + 1]...);
if (auto r = FP::ProcessNaNs(std::get<argi>(current_values)...)) { if (auto r = FP::ProcessNaNs(std::get<argi>(current_values)...)) {
result[elementi] = *r; result[elementi] = *r;
} else if (FP::IsNaN(result[elementi])) { } else if (FP::IsNaN(result[elementi])) {
@ -55,7 +57,7 @@ static auto GetRuntimeNaNFunction(std::index_sequence<argi...>) {
return static_cast<mp::equivalent_function_type_t<decltype(result)>*>(result); return static_cast<mp::equivalent_function_type_t<decltype(result)>*>(result);
} }
template <size_t fsize, size_t nargs, auto IndexFunction> template<size_t fsize, size_t nargs, template<typename> class Indexer>
static void HandleNaNs(BlockOfCode& code, EmitContext& ctx, std::array<Xbyak::Xmm, nargs + 1> xmms, const Xbyak::Xmm& nan_mask) { static void HandleNaNs(BlockOfCode& code, EmitContext& ctx, std::array<Xbyak::Xmm, nargs + 1> xmms, const Xbyak::Xmm& nan_mask) {
static_assert(fsize == 32 || fsize == 64, "fsize must be either 32 or 64"); static_assert(fsize == 32 || fsize == 64, "fsize must be either 32 or 64");
@ -90,7 +92,7 @@ static void HandleNaNs(BlockOfCode& code, EmitContext& ctx, std::array<Xbyak::Xm
code.lea(code.ABI_PARAM1, ptr[rsp + ABI_SHADOW_SPACE + 0 * 16]); code.lea(code.ABI_PARAM1, ptr[rsp + ABI_SHADOW_SPACE + 0 * 16]);
using T = mp::unsigned_integer_of_size<fsize>; using T = mp::unsigned_integer_of_size<fsize>;
code.CallFunction(GetRuntimeNaNFunction<T, IndexFunction>(std::make_index_sequence<nargs>{})); code.CallFunction(GetRuntimeNaNFunction<T, Indexer>(std::make_index_sequence<nargs>{}));
code.movaps(result, xword[rsp + ABI_SHADOW_SPACE + 0 * 16]); code.movaps(result, xword[rsp + ABI_SHADOW_SPACE + 0 * 16]);
code.add(rsp, stack_space + ABI_SHADOW_SPACE); code.add(rsp, stack_space + ABI_SHADOW_SPACE);
@ -100,41 +102,59 @@ static void HandleNaNs(BlockOfCode& code, EmitContext& ctx, std::array<Xbyak::Xm
code.SwitchToNearCode(); code.SwitchToNearCode();
} }
static std::tuple<u32, u32> DefaultIndexFunction32(size_t i, const VectorArray<u32>& a, const VectorArray<u32>& b) { template<typename T>
struct DefaultIndexer {
std::tuple<T, T> operator()(size_t i, const VectorArray<T>& a, const VectorArray<T>& b) {
return std::make_tuple(a[i], b[i]); return std::make_tuple(a[i], b[i]);
} }
static std::tuple<u64, u64> DefaultIndexFunction64(size_t i, const VectorArray<u64>& a, const VectorArray<u64>& b) { std::tuple<T, T, T> operator()(size_t i, const VectorArray<T>& a, const VectorArray<T>& b, const VectorArray<T>& c) {
return std::make_tuple(a[i], b[i]); return std::make_tuple(a[i], b[i], c[i]);
} }
};
static std::tuple<u32, u32> PairedIndexFunction32(size_t i, const VectorArray<u32>& a, const VectorArray<u32>& b) { template<typename T>
if (i < 2) { struct PairedIndexer {
std::tuple<T, T> operator()(size_t i, const VectorArray<T>& a, const VectorArray<T>& b) {
constexpr size_t halfway = std::tuple_size_v<VectorArray<T>> / 2;
const size_t which_array = i / halfway;
i %= halfway;
switch (which_array) {
case 0:
return std::make_tuple(a[2 * i], a[2 * i + 1]); return std::make_tuple(a[2 * i], a[2 * i + 1]);
case 1:
return std::make_tuple(b[2 * i], b[2 * i + 1]);
} }
return std::make_tuple(b[2 * (i - 2)], b[2 * (i - 2) + 1]); UNREACHABLE();
return {};
} }
};
static std::tuple<u64, u64> PairedIndexFunction64(size_t i, const VectorArray<u64>& a, const VectorArray<u64>& b) { template<typename T>
return i == 0 ? std::make_tuple(a[0], a[1]) : std::make_tuple(b[0], b[1]); struct PairedLowerIndexer {
} std::tuple<T, T> operator()(size_t i, const VectorArray<T>& a, const VectorArray<T>& b) {
constexpr size_t array_size = std::tuple_size_v<VectorArray<T>>;
static std::tuple<u32, u32> PairedLowerIndexFunction32(size_t i, const VectorArray<u32>& a, const VectorArray<u32>& b) { if constexpr (array_size == 4) {
switch (i) { switch (i) {
case 0: case 0:
return std::make_tuple(a[0], a[1]); return std::make_tuple(a[0], a[1]);
case 1: case 1:
return std::make_tuple(b[0], b[1]); return std::make_tuple(b[0], b[1]);
default: default:
return std::make_tuple(u32(0), u32(0)); return std::make_tuple(0, 0);
} }
} else if constexpr (array_size == 2) {
if (i == 0) {
return std::make_tuple(a[0], b[0]);
} }
return std::make_tuple(0, 0);
}
UNREACHABLE();
return {};
}
};
static std::tuple<u64, u64> PairedLowerIndexFunction64(size_t i, const VectorArray<u64>& a, const VectorArray<u64>& b) { template<size_t fsize, template<typename> class Indexer, typename Function>
return i == 0 ? std::make_tuple(a[0], b[0]) : std::make_tuple(u64(0), u64(0));
}
template <size_t fsize, auto IndexFunction, typename Function>
static void EmitThreeOpVectorOperation(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst, Function fn) { static void EmitThreeOpVectorOperation(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst, Function fn) {
static_assert(fsize == 32 || fsize == 64, "fsize must be either 32 or 64"); static_assert(fsize == 32 || fsize == 64, "fsize must be either 32 or 64");
@ -182,7 +202,7 @@ static void EmitThreeOpVectorOperation(BlockOfCode& code, EmitContext& ctx, IR::
} }
FCODE(cmpunordp)(nan_mask, result); FCODE(cmpunordp)(nan_mask, result);
HandleNaNs<fsize, 2, IndexFunction>(code, ctx, {result, xmm_a, xmm_b}, nan_mask); HandleNaNs<fsize, 2, Indexer>(code, ctx, {result, xmm_a, xmm_b}, nan_mask);
ctx.reg_alloc.DefineValue(inst, result); ctx.reg_alloc.DefineValue(inst, result);
} }
@ -290,19 +310,19 @@ void EmitX64::EmitFPVectorAbs64(EmitContext& ctx, IR::Inst* inst) {
} }
void EmitX64::EmitFPVectorAdd32(EmitContext& ctx, IR::Inst* inst) { void EmitX64::EmitFPVectorAdd32(EmitContext& ctx, IR::Inst* inst) {
EmitThreeOpVectorOperation<32, DefaultIndexFunction32>(code, ctx, inst, &Xbyak::CodeGenerator::addps); EmitThreeOpVectorOperation<32, DefaultIndexer>(code, ctx, inst, &Xbyak::CodeGenerator::addps);
} }
void EmitX64::EmitFPVectorAdd64(EmitContext& ctx, IR::Inst* inst) { void EmitX64::EmitFPVectorAdd64(EmitContext& ctx, IR::Inst* inst) {
EmitThreeOpVectorOperation<64, DefaultIndexFunction64>(code, ctx, inst, &Xbyak::CodeGenerator::addpd); EmitThreeOpVectorOperation<64, DefaultIndexer>(code, ctx, inst, &Xbyak::CodeGenerator::addpd);
} }
void EmitX64::EmitFPVectorDiv32(EmitContext& ctx, IR::Inst* inst) { void EmitX64::EmitFPVectorDiv32(EmitContext& ctx, IR::Inst* inst) {
EmitThreeOpVectorOperation<32, DefaultIndexFunction32>(code, ctx, inst, &Xbyak::CodeGenerator::divps); EmitThreeOpVectorOperation<32, DefaultIndexer>(code, ctx, inst, &Xbyak::CodeGenerator::divps);
} }
void EmitX64::EmitFPVectorDiv64(EmitContext& ctx, IR::Inst* inst) { void EmitX64::EmitFPVectorDiv64(EmitContext& ctx, IR::Inst* inst) {
EmitThreeOpVectorOperation<64, DefaultIndexFunction64>(code, ctx, inst, &Xbyak::CodeGenerator::divpd); EmitThreeOpVectorOperation<64, DefaultIndexer>(code, ctx, inst, &Xbyak::CodeGenerator::divpd);
} }
void EmitX64::EmitFPVectorEqual32(EmitContext& ctx, IR::Inst* inst) { void EmitX64::EmitFPVectorEqual32(EmitContext& ctx, IR::Inst* inst) {
@ -366,23 +386,23 @@ void EmitX64::EmitFPVectorGreaterEqual64(EmitContext& ctx, IR::Inst* inst) {
} }
void EmitX64::EmitFPVectorMul32(EmitContext& ctx, IR::Inst* inst) { void EmitX64::EmitFPVectorMul32(EmitContext& ctx, IR::Inst* inst) {
EmitThreeOpVectorOperation<32, DefaultIndexFunction32>(code, ctx, inst, &Xbyak::CodeGenerator::mulps); EmitThreeOpVectorOperation<32, DefaultIndexer>(code, ctx, inst, &Xbyak::CodeGenerator::mulps);
} }
void EmitX64::EmitFPVectorMul64(EmitContext& ctx, IR::Inst* inst) { void EmitX64::EmitFPVectorMul64(EmitContext& ctx, IR::Inst* inst) {
EmitThreeOpVectorOperation<64, DefaultIndexFunction64>(code, ctx, inst, &Xbyak::CodeGenerator::mulpd); EmitThreeOpVectorOperation<64, DefaultIndexer>(code, ctx, inst, &Xbyak::CodeGenerator::mulpd);
} }
void EmitX64::EmitFPVectorPairedAdd32(EmitContext& ctx, IR::Inst* inst) { void EmitX64::EmitFPVectorPairedAdd32(EmitContext& ctx, IR::Inst* inst) {
EmitThreeOpVectorOperation<32, PairedIndexFunction32>(code, ctx, inst, &Xbyak::CodeGenerator::haddps); EmitThreeOpVectorOperation<32, PairedIndexer>(code, ctx, inst, &Xbyak::CodeGenerator::haddps);
} }
void EmitX64::EmitFPVectorPairedAdd64(EmitContext& ctx, IR::Inst* inst) { void EmitX64::EmitFPVectorPairedAdd64(EmitContext& ctx, IR::Inst* inst) {
EmitThreeOpVectorOperation<64, PairedIndexFunction64>(code, ctx, inst, &Xbyak::CodeGenerator::haddpd); EmitThreeOpVectorOperation<64, PairedIndexer>(code, ctx, inst, &Xbyak::CodeGenerator::haddpd);
} }
void EmitX64::EmitFPVectorPairedAddLower32(EmitContext& ctx, IR::Inst* inst) { void EmitX64::EmitFPVectorPairedAddLower32(EmitContext& ctx, IR::Inst* inst) {
EmitThreeOpVectorOperation<32, PairedLowerIndexFunction32>(code, ctx, inst, [&](Xbyak::Xmm result, Xbyak::Xmm xmm_b) { EmitThreeOpVectorOperation<32, PairedLowerIndexer>(code, ctx, inst, [&](Xbyak::Xmm result, Xbyak::Xmm xmm_b) {
const Xbyak::Xmm zero = ctx.reg_alloc.ScratchXmm(); const Xbyak::Xmm zero = ctx.reg_alloc.ScratchXmm();
code.xorps(zero, zero); code.xorps(zero, zero);
code.punpcklqdq(result, xmm_b); code.punpcklqdq(result, xmm_b);
@ -391,7 +411,7 @@ void EmitX64::EmitFPVectorPairedAddLower32(EmitContext& ctx, IR::Inst* inst) {
} }
void EmitX64::EmitFPVectorPairedAddLower64(EmitContext& ctx, IR::Inst* inst) { void EmitX64::EmitFPVectorPairedAddLower64(EmitContext& ctx, IR::Inst* inst) {
EmitThreeOpVectorOperation<64, PairedLowerIndexFunction64>(code, ctx, inst, [&](Xbyak::Xmm result, Xbyak::Xmm xmm_b) { EmitThreeOpVectorOperation<64, PairedLowerIndexer>(code, ctx, inst, [&](Xbyak::Xmm result, Xbyak::Xmm xmm_b) {
const Xbyak::Xmm zero = ctx.reg_alloc.ScratchXmm(); const Xbyak::Xmm zero = ctx.reg_alloc.ScratchXmm();
code.xorps(zero, zero); code.xorps(zero, zero);
code.punpcklqdq(result, xmm_b); code.punpcklqdq(result, xmm_b);
@ -484,11 +504,11 @@ void EmitX64::EmitFPVectorS64ToDouble(EmitContext& ctx, IR::Inst* inst) {
} }
void EmitX64::EmitFPVectorSub32(EmitContext& ctx, IR::Inst* inst) { void EmitX64::EmitFPVectorSub32(EmitContext& ctx, IR::Inst* inst) {
EmitThreeOpVectorOperation<32, DefaultIndexFunction32>(code, ctx, inst, &Xbyak::CodeGenerator::subps); EmitThreeOpVectorOperation<32, DefaultIndexer>(code, ctx, inst, &Xbyak::CodeGenerator::subps);
} }
void EmitX64::EmitFPVectorSub64(EmitContext& ctx, IR::Inst* inst) { void EmitX64::EmitFPVectorSub64(EmitContext& ctx, IR::Inst* inst) {
EmitThreeOpVectorOperation<64, DefaultIndexFunction64>(code, ctx, inst, &Xbyak::CodeGenerator::subpd); EmitThreeOpVectorOperation<64, DefaultIndexer>(code, ctx, inst, &Xbyak::CodeGenerator::subpd);
} }
void EmitX64::EmitFPVectorU32ToSingle(EmitContext& ctx, IR::Inst* inst) { void EmitX64::EmitFPVectorU32ToSingle(EmitContext& ctx, IR::Inst* inst) {