IR: Introduce VectorReduceAdd{8,16,32,64}
opcode
Adds all elements of a vector and puts the result into the lowest element. Accelerates the `addv` instruction with a vectorized implementation rather than a serial one.
This commit is contained in:
parent
69b831d7d2
commit
5e7d2afe0f
5 changed files with 113 additions and 18 deletions
|
@ -2990,6 +2990,98 @@ void EmitX64::EmitVectorReverseBits(EmitContext& ctx, IR::Inst* inst) {
|
|||
ctx.reg_alloc.DefineValue(inst, data);
|
||||
}
|
||||
|
||||
// Sums every 8-bit lane of the operand, leaving the total (mod 256) in the
// lowest byte of the result and zeroing the remaining bytes.
void EmitX64::EmitVectorReduceAdd8(EmitContext& ctx, IR::Inst* inst) {
    auto args = ctx.reg_alloc.GetArgumentInfo(inst);

    const Xbyak::Xmm accum = ctx.reg_alloc.UseScratchXmm(args[0]);
    const Xbyak::Xmm scratch = xmm0;

    // Fold the upper 64 bits onto the lower 64 bits (per-byte, wrapping).
    code.pshufd(scratch, accum, 0b01'00'11'10);
    code.paddb(accum, scratch);

    // PSADBW against zero sums each group of eight bytes into its 64-bit
    // lane, so the low lane now holds the total of all sixteen input bytes.
    code.pxor(scratch, scratch);
    code.psadbw(accum, scratch);

    // Discard everything except the lowest byte.
    code.pslldq(accum, 15);
    code.psrldq(accum, 15);

    ctx.reg_alloc.DefineValue(inst, accum);
}
|
||||
|
||||
// Sums every 16-bit lane of the operand, leaving the total (mod 2^16) in the
// lowest halfword of the result and zeroing the rest of the register.
void EmitX64::EmitVectorReduceAdd16(EmitContext& ctx, IR::Inst* inst) {
    auto args = ctx.reg_alloc.GetArgumentInfo(inst);

    const Xbyak::Xmm accum = ctx.reg_alloc.UseScratchXmm(args[0]);
    const Xbyak::Xmm scratch = xmm0;

    if (code.HasHostFeature(HostFeature::SSSE3)) {
        // Three horizontal adds against a zeroed register collapse all eight
        // halfwords into lane 0 while shifting zeroes in from above.
        code.pxor(scratch, scratch);
        code.phaddw(accum, scratch);
        code.phaddw(accum, scratch);
        code.phaddw(accum, scratch);
    } else {
        // Fold the vector onto itself: reverse the dwords, then add.
        code.pshufd(scratch, accum, 0b00'01'10'11);
        code.paddw(accum, scratch);

        // Multiply-accumulate against all-ones sums each adjacent pair of
        // halfwords into a 32-bit lane.
        code.movdqa(scratch, code.MConst(xword, 0x0001000100010001, 0x0001000100010001));
        code.pmaddwd(accum, scratch);

        // Combine the remaining pair of 32-bit partial sums.
        code.pshufd(scratch, accum, 0b10'11'00'01);
        code.paddd(accum, scratch);

        // Discard everything except the lowest halfword.
        code.pslldq(accum, 14);
        code.psrldq(accum, 14);
    }

    ctx.reg_alloc.DefineValue(inst, accum);
}
|
||||
|
||||
// Sums the four 32-bit lanes of the operand, leaving the total in the lowest
// dword of the result and zeroing the upper three lanes.
void EmitX64::EmitVectorReduceAdd32(EmitContext& ctx, IR::Inst* inst) {
    auto args = ctx.reg_alloc.GetArgumentInfo(inst);

    const Xbyak::Xmm accum = ctx.reg_alloc.UseScratchXmm(args[0]);
    const Xbyak::Xmm scratch = xmm0;

    // Add the dword-reversed vector to itself, so every lane holds the sum
    // of one mirrored pair.
    code.pshufd(scratch, accum, 0b00'01'10'11);
    code.paddd(accum, scratch);

    // Combine the two distinct partial sums into a grand total.
    if (code.HasHostFeature(HostFeature::SSSE3)) {
        code.phaddd(accum, accum);
    } else {
        code.pshufd(scratch, accum, 0b10'11'00'01);
        code.paddd(accum, scratch);
    }

    // The topmost dword now holds the total; shift it down into lane 0,
    // zero-filling the rest.
    code.psrldq(accum, 12);

    ctx.reg_alloc.DefineValue(inst, accum);
}
|
||||
|
||||
// Sums the two 64-bit lanes of the operand, leaving the total in the lowest
// qword of the result and zeroing the upper qword.
void EmitX64::EmitVectorReduceAdd64(EmitContext& ctx, IR::Inst* inst) {
    auto args = ctx.reg_alloc.GetArgumentInfo(inst);

    const Xbyak::Xmm accum = ctx.reg_alloc.UseScratchXmm(args[0]);
    const Xbyak::Xmm scratch = xmm0;

    // Swap the two 64-bit halves and add them together.
    code.pshufd(scratch, accum, 0b01'00'11'10);
    code.paddq(accum, scratch);

    // MOVQ to itself zero-extends: the upper 64 bits are cleared.
    code.movq(accum, accum);

    ctx.reg_alloc.DefineValue(inst, accum);
}
|
||||
|
||||
static void EmitVectorRoundingHalvingAddSigned(size_t esize, EmitContext& ctx, IR::Inst* inst, BlockOfCode& code) {
|
||||
auto args = ctx.reg_alloc.GetArgumentInfo(inst);
|
||||
|
||||
|
|
|
@ -171,27 +171,10 @@ bool TranslatorVisitor::ADDV(bool Q, Imm<2> size, Vec Vn, Vec Vd) {
|
|||
|
||||
const size_t esize = 8 << size.ZeroExtend();
|
||||
const size_t datasize = Q ? 128 : 64;
|
||||
const size_t elements = datasize / esize;
|
||||
|
||||
const IR::U128 operand = V(datasize, Vn);
|
||||
|
||||
const auto get_element = [&](IR::U128 vec, size_t element) {
|
||||
return ir.ZeroExtendToWord(ir.VectorGetElement(esize, vec, element));
|
||||
};
|
||||
|
||||
IR::U32 sum = get_element(operand, 0);
|
||||
for (size_t i = 1; i < elements; i++) {
|
||||
sum = ir.Add(sum, get_element(operand, i));
|
||||
}
|
||||
|
||||
if (size == 0b00) {
|
||||
V(datasize, Vd, ir.ZeroExtendToQuad(ir.LeastSignificantByte(sum)));
|
||||
} else if (size == 0b01) {
|
||||
V(datasize, Vd, ir.ZeroExtendToQuad(ir.LeastSignificantHalf(sum)));
|
||||
} else {
|
||||
V(datasize, Vd, ir.ZeroExtendToQuad(sum));
|
||||
}
|
||||
|
||||
V(128, Vd, ir.VectorReduceAdd(esize, operand));
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -1526,6 +1526,21 @@ U128 IREmitter::VectorReverseBits(const U128& a) {
|
|||
return Inst<U128>(Opcode::VectorReverseBits, a);
|
||||
}
|
||||
|
||||
// Emits the reduce-add opcode matching the element size.
// esize must be one of 8, 16, 32 or 64; anything else is a programmer error.
U128 IREmitter::VectorReduceAdd(size_t esize, const U128& a) {
    if (esize == 8) {
        return Inst<U128>(Opcode::VectorReduceAdd8, a);
    }
    if (esize == 16) {
        return Inst<U128>(Opcode::VectorReduceAdd16, a);
    }
    if (esize == 32) {
        return Inst<U128>(Opcode::VectorReduceAdd32, a);
    }
    if (esize == 64) {
        return Inst<U128>(Opcode::VectorReduceAdd64, a);
    }

    UNREACHABLE();
}
|
||||
|
||||
U128 IREmitter::VectorRotateLeft(size_t esize, const U128& a, u8 amount) {
|
||||
ASSERT(amount < esize);
|
||||
|
||||
|
|
|
@ -294,6 +294,7 @@ public:
|
|||
U128 VectorPolynomialMultiplyLong(size_t esize, const U128& a, const U128& b);
|
||||
U128 VectorPopulationCount(const U128& a);
|
||||
U128 VectorReverseBits(const U128& a);
|
||||
U128 VectorReduceAdd(size_t esize, const U128& a);
|
||||
U128 VectorRotateLeft(size_t esize, const U128& a, u8 amount);
|
||||
U128 VectorRotateRight(size_t esize, const U128& a, u8 amount);
|
||||
U128 VectorRoundingHalvingAddSigned(size_t esize, const U128& a, const U128& b);
|
||||
|
|
|
@ -431,6 +431,10 @@ OPCODE(VectorPolynomialMultiplyLong8, U128, U128
|
|||
OPCODE(VectorPolynomialMultiplyLong64, U128, U128, U128 )
|
||||
OPCODE(VectorPopulationCount, U128, U128 )
|
||||
OPCODE(VectorReverseBits, U128, U128 )
|
||||
OPCODE(VectorReduceAdd8, U128, U128 )
|
||||
OPCODE(VectorReduceAdd16, U128, U128 )
|
||||
OPCODE(VectorReduceAdd32, U128, U128 )
|
||||
OPCODE(VectorReduceAdd64, U128, U128 )
|
||||
OPCODE(VectorRoundingHalvingAddS8, U128, U128, U128 )
|
||||
OPCODE(VectorRoundingHalvingAddS16, U128, U128, U128 )
|
||||
OPCODE(VectorRoundingHalvingAddS32, U128, U128, U128 )
|
||||
|
|
Loading…
Reference in a new issue