IR: Introduce VectorReduceAdd{8,16,32,64} opcode

Adds all elements of vector and puts the result into the lowest element.
Accelerates the `addv` instruction into a vectorized implementation
rather than a serial one.
This commit is contained in:
Wunkolo 2021-09-10 18:58:17 -07:00 committed by merry
parent 69b831d7d2
commit 5e7d2afe0f
5 changed files with 113 additions and 18 deletions

View file

@ -2990,6 +2990,98 @@ void EmitX64::EmitVectorReverseBits(EmitContext& ctx, IR::Inst* inst) {
ctx.reg_alloc.DefineValue(inst, data);
}
void EmitX64::EmitVectorReduceAdd8(EmitContext& ctx, IR::Inst* inst) {
auto args = ctx.reg_alloc.GetArgumentInfo(inst);
const Xbyak::Xmm data = ctx.reg_alloc.UseScratchXmm(args[0]);
const Xbyak::Xmm temp = xmm0;
// Add upper elements to lower elements
code.pshufd(temp, data, 0b01'00'11'10);
code.paddb(data, temp);
// Add adjacent 8-bit values into 64-bit lanes
code.pxor(temp, temp);
code.psadbw(data, temp);
// Zero-extend lower 8-bits
code.pslldq(data, 15);
code.psrldq(data, 15);
ctx.reg_alloc.DefineValue(inst, data);
}
void EmitX64::EmitVectorReduceAdd16(EmitContext& ctx, IR::Inst* inst) {
auto args = ctx.reg_alloc.GetArgumentInfo(inst);
const Xbyak::Xmm data = ctx.reg_alloc.UseScratchXmm(args[0]);
const Xbyak::Xmm temp = xmm0;
if (code.HasHostFeature(HostFeature::SSSE3)) {
code.pxor(temp, temp);
code.phaddw(data, xmm0);
code.phaddw(data, xmm0);
code.phaddw(data, xmm0);
} else {
// Add upper elements to lower elements
code.pshufd(temp, data, 0b00'01'10'11);
code.paddw(data, temp);
// Add pairs of 16-bit values into 32-bit lanes
code.movdqa(temp, code.MConst(xword, 0x0001000100010001, 0x0001000100010001));
code.pmaddwd(data, temp);
// Sum adjacent 32-bit lanes
code.pshufd(temp, data, 0b10'11'00'01);
code.paddd(data, temp);
// Zero-extend lower 16-bits
code.pslldq(data, 14);
code.psrldq(data, 14);
}
ctx.reg_alloc.DefineValue(inst, data);
}
void EmitX64::EmitVectorReduceAdd32(EmitContext& ctx, IR::Inst* inst) {
auto args = ctx.reg_alloc.GetArgumentInfo(inst);
const Xbyak::Xmm data = ctx.reg_alloc.UseScratchXmm(args[0]);
const Xbyak::Xmm temp = xmm0;
// Add upper elements to lower elements(reversed)
code.pshufd(temp, data, 0b00'01'10'11);
code.paddd(data, temp);
// Sum adjacent 32-bit lanes
if (code.HasHostFeature(HostFeature::SSSE3)) {
code.phaddd(data, data);
} else {
code.pshufd(temp, data, 0b10'11'00'01);
code.paddd(data, temp);
}
// shift upper-most result into lower-most lane
code.psrldq(data, 12);
ctx.reg_alloc.DefineValue(inst, data);
}
void EmitX64::EmitVectorReduceAdd64(EmitContext& ctx, IR::Inst* inst) {
auto args = ctx.reg_alloc.GetArgumentInfo(inst);
const Xbyak::Xmm data = ctx.reg_alloc.UseScratchXmm(args[0]);
const Xbyak::Xmm temp = xmm0;
// Add upper elements to lower elements
code.pshufd(temp, data, 0b01'00'11'10);
code.paddq(data, temp);
// Zero-extend lower 64-bits
code.movq(data, data);
ctx.reg_alloc.DefineValue(inst, data);
}
static void EmitVectorRoundingHalvingAddSigned(size_t esize, EmitContext& ctx, IR::Inst* inst, BlockOfCode& code) {
auto args = ctx.reg_alloc.GetArgumentInfo(inst);

View file

@ -171,27 +171,10 @@ bool TranslatorVisitor::ADDV(bool Q, Imm<2> size, Vec Vn, Vec Vd) {
const size_t esize = 8 << size.ZeroExtend();
const size_t datasize = Q ? 128 : 64;
const size_t elements = datasize / esize;
const IR::U128 operand = V(datasize, Vn);
const auto get_element = [&](IR::U128 vec, size_t element) {
return ir.ZeroExtendToWord(ir.VectorGetElement(esize, vec, element));
};
IR::U32 sum = get_element(operand, 0);
for (size_t i = 1; i < elements; i++) {
sum = ir.Add(sum, get_element(operand, i));
}
if (size == 0b00) {
V(datasize, Vd, ir.ZeroExtendToQuad(ir.LeastSignificantByte(sum)));
} else if (size == 0b01) {
V(datasize, Vd, ir.ZeroExtendToQuad(ir.LeastSignificantHalf(sum)));
} else {
V(datasize, Vd, ir.ZeroExtendToQuad(sum));
}
V(128, Vd, ir.VectorReduceAdd(esize, operand));
return true;
}

View file

@ -1526,6 +1526,21 @@ U128 IREmitter::VectorReverseBits(const U128& a) {
return Inst<U128>(Opcode::VectorReverseBits, a);
}
U128 IREmitter::VectorReduceAdd(size_t esize, const U128& a) {
switch (esize) {
case 8:
return Inst<U128>(Opcode::VectorReduceAdd8, a);
case 16:
return Inst<U128>(Opcode::VectorReduceAdd16, a);
case 32:
return Inst<U128>(Opcode::VectorReduceAdd32, a);
case 64:
return Inst<U128>(Opcode::VectorReduceAdd64, a);
}
UNREACHABLE();
}
U128 IREmitter::VectorRotateLeft(size_t esize, const U128& a, u8 amount) {
ASSERT(amount < esize);

View file

@ -294,6 +294,7 @@ public:
U128 VectorPolynomialMultiplyLong(size_t esize, const U128& a, const U128& b);
U128 VectorPopulationCount(const U128& a);
U128 VectorReverseBits(const U128& a);
U128 VectorReduceAdd(size_t esize, const U128& a);
U128 VectorRotateLeft(size_t esize, const U128& a, u8 amount);
U128 VectorRotateRight(size_t esize, const U128& a, u8 amount);
U128 VectorRoundingHalvingAddSigned(size_t esize, const U128& a, const U128& b);

View file

@ -431,6 +431,10 @@ OPCODE(VectorPolynomialMultiplyLong8, U128, U128
OPCODE(VectorPolynomialMultiplyLong64, U128, U128, U128 )
OPCODE(VectorPopulationCount, U128, U128 )
OPCODE(VectorReverseBits, U128, U128 )
OPCODE(VectorReduceAdd8, U128, U128 )
OPCODE(VectorReduceAdd16, U128, U128 )
OPCODE(VectorReduceAdd32, U128, U128 )
OPCODE(VectorReduceAdd64, U128, U128 )
OPCODE(VectorRoundingHalvingAddS8, U128, U128, U128 )
OPCODE(VectorRoundingHalvingAddS16, U128, U128, U128 )
OPCODE(VectorRoundingHalvingAddS32, U128, U128, U128 )