IR: Implement VectorMultiply

This commit is contained in:
MerryMage 2018-02-11 10:18:29 +00:00
parent 90a053a5e4
commit b6de612e01
4 changed files with 92 additions and 0 deletions

View file

@ -4,6 +4,7 @@
* General Public License version 2 or any later version. * General Public License version 2 or any later version.
*/ */
#include "backend_x64/abi.h"
#include "backend_x64/block_of_code.h" #include "backend_x64/block_of_code.h"
#include "backend_x64/emit_x64.h" #include "backend_x64/emit_x64.h"
#include "common/assert.h" #include "common/assert.h"
@ -28,6 +29,31 @@ static void EmitVectorOperation(BlockOfCode& code, EmitContext& ctx, IR::Inst* i
ctx.reg_alloc.DefineValue(inst, xmm_a); ctx.reg_alloc.DefineValue(inst, xmm_a);
} }
template <typename Lambda>
static void EmitTwoArgumentFallback(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst, Lambda lambda) {
const auto fn = +lambda; // Force decay of lambda to function pointer
constexpr u32 stack_space = 3 * 16;
auto args = ctx.reg_alloc.GetArgumentInfo(inst);
const Xbyak::Xmm arg1 = ctx.reg_alloc.UseXmm(args[0]);
const Xbyak::Xmm arg2 = ctx.reg_alloc.UseXmm(args[1]);
ctx.reg_alloc.EndOfAllocScope();
ctx.reg_alloc.HostCall(nullptr);
code.sub(rsp, stack_space + ABI_SHADOW_SPACE);
code.lea(code.ABI_PARAM1, ptr[rsp + ABI_SHADOW_SPACE + 0 * 16]);
code.lea(code.ABI_PARAM2, ptr[rsp + ABI_SHADOW_SPACE + 1 * 16]);
code.lea(code.ABI_PARAM3, ptr[rsp + ABI_SHADOW_SPACE + 2 * 16]);
code.movaps(xword[code.ABI_PARAM2], arg1);
code.movaps(xword[code.ABI_PARAM3], arg2);
code.CallFunction(+fn);
code.movaps(xmm0, xword[rsp + ABI_SHADOW_SPACE + 0 * 16]);
code.add(rsp, stack_space + ABI_SHADOW_SPACE);
ctx.reg_alloc.DefineValue(inst, xmm0);
}
void EmitX64::EmitVectorGetElement8(EmitContext& ctx, IR::Inst* inst) { void EmitX64::EmitVectorGetElement8(EmitContext& ctx, IR::Inst* inst) {
auto args = ctx.reg_alloc.GetArgumentInfo(inst); auto args = ctx.reg_alloc.GetArgumentInfo(inst);
ASSERT(args[1].IsImmediate()); ASSERT(args[1].IsImmediate());
@ -575,6 +601,52 @@ void EmitX64::EmitVectorLogicalShiftRight64(EmitContext& ctx, IR::Inst* inst) {
ctx.reg_alloc.DefineValue(inst, result); ctx.reg_alloc.DefineValue(inst, result);
} }
void EmitX64::EmitVectorMultiply8(EmitContext& ctx, IR::Inst* inst) {
auto args = ctx.reg_alloc.GetArgumentInfo(inst);
Xbyak::Xmm a = ctx.reg_alloc.UseScratchXmm(args[0]);
Xbyak::Xmm b = ctx.reg_alloc.UseScratchXmm(args[1]);
Xbyak::Xmm tmp_a = ctx.reg_alloc.ScratchXmm();
Xbyak::Xmm tmp_b = ctx.reg_alloc.ScratchXmm();
// TODO: Optimize
code.movdqa(tmp_a, a);
code.movdqa(tmp_b, b);
code.pmullw(a, b);
code.psrlw(tmp_a, 8);
code.psrlw(tmp_b, 8);
code.pmullw(tmp_a, tmp_b);
code.pand(a, code.MConst(0x00FF00FF00FF00FF, 0x00FF00FF00FF00FF));
code.psllw(tmp_a, 8);
code.por(a, tmp_a);
ctx.reg_alloc.DefineValue(inst, a);
}
void EmitX64::EmitVectorMultiply16(EmitContext& ctx, IR::Inst* inst) {
EmitVectorOperation(code, ctx, inst, &Xbyak::CodeGenerator::pmullw);
}
void EmitX64::EmitVectorMultiply32(EmitContext& ctx, IR::Inst* inst) {
if (code.DoesCpuSupport(Xbyak::util::Cpu::tSSE41)) {
EmitVectorOperation(code, ctx, inst, &Xbyak::CodeGenerator::pmulld);
return;
}
EmitTwoArgumentFallback(code, ctx, inst, [](std::array<u32, 4>& result, const std::array<u32, 4>& a, const std::array<u32, 4>& b){
for (size_t i = 0; i < 4; ++i) {
result[i] = a[i] * b[i];
}
});
}
void EmitX64::EmitVectorMultiply64(EmitContext& ctx, IR::Inst* inst) {
EmitTwoArgumentFallback(code, ctx, inst, [](std::array<u64, 2>& result, const std::array<u64, 2>& a, const std::array<u64, 2>& b){
for (size_t i = 0; i < 2; ++i) {
result[i] = a[i] * b[i];
}
});
}
void EmitX64::EmitVectorNarrow16(EmitContext& ctx, IR::Inst* inst) { void EmitX64::EmitVectorNarrow16(EmitContext& ctx, IR::Inst* inst) {
auto args = ctx.reg_alloc.GetArgumentInfo(inst); auto args = ctx.reg_alloc.GetArgumentInfo(inst);
Xbyak::Xmm a = ctx.reg_alloc.UseScratchXmm(args[0]); Xbyak::Xmm a = ctx.reg_alloc.UseScratchXmm(args[0]);

View file

@ -913,6 +913,21 @@ U128 IREmitter::VectorLogicalShiftRight(size_t esize, const U128& a, u8 shift_am
return {}; return {};
} }
U128 IREmitter::VectorMultiply(size_t esize, const U128& a, const U128& b) {
switch (esize) {
case 8:
return Inst<U128>(Opcode::VectorMultiply8, a, b);
case 16:
return Inst<U128>(Opcode::VectorMultiply16, a, b);
case 32:
return Inst<U128>(Opcode::VectorMultiply32, a, b);
case 64:
return Inst<U128>(Opcode::VectorMultiply64, a, b);
}
UNREACHABLE();
return {};
}
U128 IREmitter::VectorNarrow(size_t original_esize, const U128& a) { U128 IREmitter::VectorNarrow(size_t original_esize, const U128& a) {
switch (original_esize) { switch (original_esize) {
case 16: case 16:

View file

@ -217,6 +217,7 @@ public:
U128 VectorInterleaveLower(size_t esize, const U128& a, const U128& b); U128 VectorInterleaveLower(size_t esize, const U128& a, const U128& b);
U128 VectorLogicalShiftLeft(size_t esize, const U128& a, u8 shift_amount); U128 VectorLogicalShiftLeft(size_t esize, const U128& a, u8 shift_amount);
U128 VectorLogicalShiftRight(size_t esize, const U128& a, u8 shift_amount); U128 VectorLogicalShiftRight(size_t esize, const U128& a, u8 shift_amount);
U128 VectorMultiply(size_t esize, const U128& a, const U128& b);
U128 VectorNarrow(size_t original_esize, const U128& a); U128 VectorNarrow(size_t original_esize, const U128& a);
U128 VectorNot(const U128& a); U128 VectorNot(const U128& a);
U128 VectorOr(const U128& a, const U128& b); U128 VectorOr(const U128& a, const U128& b);

View file

@ -232,6 +232,10 @@ OPCODE(VectorLogicalShiftRight8, T::U128, T::U128, T::U8
OPCODE(VectorLogicalShiftRight16, T::U128, T::U128, T::U8 ) OPCODE(VectorLogicalShiftRight16, T::U128, T::U128, T::U8 )
OPCODE(VectorLogicalShiftRight32, T::U128, T::U128, T::U8 ) OPCODE(VectorLogicalShiftRight32, T::U128, T::U128, T::U8 )
OPCODE(VectorLogicalShiftRight64, T::U128, T::U128, T::U8 ) OPCODE(VectorLogicalShiftRight64, T::U128, T::U128, T::U8 )
OPCODE(VectorMultiply8, T::U128, T::U128, T::U128 )
OPCODE(VectorMultiply16, T::U128, T::U128, T::U128 )
OPCODE(VectorMultiply32, T::U128, T::U128, T::U128 )
OPCODE(VectorMultiply64, T::U128, T::U128, T::U128 )
OPCODE(VectorNarrow16, T::U128, T::U128 ) OPCODE(VectorNarrow16, T::U128, T::U128 )
OPCODE(VectorNarrow32, T::U128, T::U128 ) OPCODE(VectorNarrow32, T::U128, T::U128 )
OPCODE(VectorNarrow64, T::U128, T::U128 ) OPCODE(VectorNarrow64, T::U128, T::U128 )