diff --git a/src/backend_x64/emit_x64_vector.cpp b/src/backend_x64/emit_x64_vector.cpp index 08f89f83..f5fe6503 100644 --- a/src/backend_x64/emit_x64_vector.cpp +++ b/src/backend_x64/emit_x64_vector.cpp @@ -101,6 +101,104 @@ void EmitX64::EmitVectorGetElement64(EmitContext& ctx, IR::Inst* inst) { ctx.reg_alloc.DefineValue(inst, dest); } +void EmitX64::EmitVectorSetElement8(EmitContext& ctx, IR::Inst* inst) { + auto args = ctx.reg_alloc.GetArgumentInfo(inst); + ASSERT(args[1].IsImmediate()); + u8 index = args[1].GetImmediateU8(); + + if (code->DoesCpuSupport(Xbyak::util::Cpu::tSSE41)) { + Xbyak::Xmm source_vector = ctx.reg_alloc.UseScratchXmm(args[0]); + Xbyak::Reg8 source_elem = ctx.reg_alloc.UseGpr(args[2]).cvt8(); + + code->pinsrb(source_vector, source_elem.cvt32(), index); + + ctx.reg_alloc.DefineValue(inst, source_vector); + } else { + Xbyak::Xmm source_vector = ctx.reg_alloc.UseScratchXmm(args[0]); + Xbyak::Reg32 source_elem = ctx.reg_alloc.UseScratchGpr(args[2]).cvt32(); + Xbyak::Reg32 tmp = ctx.reg_alloc.ScratchGpr().cvt32(); + + code->pextrw(tmp, source_vector, index / 2); + if (index % 2 == 0) { + code->and_(tmp, 0xFF00); + code->and_(source_elem, 0x00FF); + code->or_(tmp, source_elem); + } else { + code->and_(tmp, 0x00FF); + code->shl(source_elem, 8); + code->or_(tmp, source_elem); + } + code->pinsrw(source_vector, tmp, index / 2); + + ctx.reg_alloc.DefineValue(inst, source_vector); + } +} + +void EmitX64::EmitVectorSetElement16(EmitContext& ctx, IR::Inst* inst) { + auto args = ctx.reg_alloc.GetArgumentInfo(inst); + ASSERT(args[1].IsImmediate()); + u8 index = args[1].GetImmediateU8(); + + Xbyak::Xmm source_vector = ctx.reg_alloc.UseScratchXmm(args[0]); + Xbyak::Reg16 source_elem = ctx.reg_alloc.UseGpr(args[2]).cvt16(); + + code->pinsrw(source_vector, source_elem.cvt32(), index); + + ctx.reg_alloc.DefineValue(inst, source_vector); +} + +void EmitX64::EmitVectorSetElement32(EmitContext& ctx, IR::Inst* inst) { + auto args = 
ctx.reg_alloc.GetArgumentInfo(inst); + ASSERT(args[1].IsImmediate()); + u8 index = args[1].GetImmediateU8(); + + if (code->DoesCpuSupport(Xbyak::util::Cpu::tSSE41)) { + Xbyak::Xmm source_vector = ctx.reg_alloc.UseScratchXmm(args[0]); + Xbyak::Reg32 source_elem = ctx.reg_alloc.UseGpr(args[2]).cvt32(); + + code->pinsrd(source_vector, source_elem, index); + + ctx.reg_alloc.DefineValue(inst, source_vector); + } else { + Xbyak::Xmm source_vector = ctx.reg_alloc.UseScratchXmm(args[0]); + Xbyak::Reg32 source_elem = ctx.reg_alloc.UseScratchGpr(args[2]).cvt32(); + + code->pinsrw(source_vector, source_elem, index * 2); + code->shr(source_elem, 16); + code->pinsrw(source_vector, source_elem, index * 2 + 1); + + ctx.reg_alloc.DefineValue(inst, source_vector); + } +} + +void EmitX64::EmitVectorSetElement64(EmitContext& ctx, IR::Inst* inst) { + auto args = ctx.reg_alloc.GetArgumentInfo(inst); + ASSERT(args[1].IsImmediate()); + u8 index = args[1].GetImmediateU8(); + + if (code->DoesCpuSupport(Xbyak::util::Cpu::tSSE41)) { + Xbyak::Xmm source_vector = ctx.reg_alloc.UseScratchXmm(args[0]); + Xbyak::Reg64 source_elem = ctx.reg_alloc.UseGpr(args[2]); + + code->pinsrq(source_vector, source_elem, index); + + ctx.reg_alloc.DefineValue(inst, source_vector); + } else { + Xbyak::Xmm source_vector = ctx.reg_alloc.UseScratchXmm(args[0]); + Xbyak::Reg64 source_elem = ctx.reg_alloc.UseScratchGpr(args[2]); + + code->pinsrw(source_vector, source_elem.cvt32(), index * 4); + code->shr(source_elem, 16); + code->pinsrw(source_vector, source_elem.cvt32(), index * 4 + 1); + code->shr(source_elem, 16); + code->pinsrw(source_vector, source_elem.cvt32(), index * 4 + 2); + code->shr(source_elem, 16); + code->pinsrw(source_vector, source_elem.cvt32(), index * 4 + 3); + + ctx.reg_alloc.DefineValue(inst, source_vector); + } +} + void EmitX64::EmitVectorAdd8(EmitContext& ctx, IR::Inst* inst) { EmitVectorOperation(code, ctx, inst, &Xbyak::CodeGenerator::paddb); } diff --git a/src/frontend/ir/ir_emitter.cpp 
b/src/frontend/ir/ir_emitter.cpp index 34c2f27e..be072f12 100644 --- a/src/frontend/ir/ir_emitter.cpp +++ b/src/frontend/ir/ir_emitter.cpp @@ -740,6 +740,23 @@ UAny IREmitter::VectorGetElement(size_t esize, const U128& a, size_t index) { } } +U128 IREmitter::VectorSetElement(size_t esize, const U128& a, size_t index, const IR::UAny& elem) { + ASSERT_MSG(esize * index < 128, "Invalid index"); + switch (esize) { + case 8: + return Inst(Opcode::VectorSetElement8, a, Imm8(static_cast<u8>(index)), elem); + case 16: + return Inst(Opcode::VectorSetElement16, a, Imm8(static_cast<u8>(index)), elem); + case 32: + return Inst(Opcode::VectorSetElement32, a, Imm8(static_cast<u8>(index)), elem); + case 64: + return Inst(Opcode::VectorSetElement64, a, Imm8(static_cast<u8>(index)), elem); + default: + ASSERT_MSG(false, "Unreachable"); + return {}; + } +} + U128 IREmitter::VectorAdd8(const U128& a, const U128& b) { return Inst(Opcode::VectorAdd8, a, b); } diff --git a/src/frontend/ir/ir_emitter.h b/src/frontend/ir/ir_emitter.h index 94cae668..4fd58483 100644 --- a/src/frontend/ir/ir_emitter.h +++ b/src/frontend/ir/ir_emitter.h @@ -199,6 +199,7 @@ public: U128 AESMixColumns(const U128& a); UAny VectorGetElement(size_t esize, const U128& a, size_t index); + U128 VectorSetElement(size_t esize, const U128& a, size_t index, const UAny& elem); U128 VectorAdd8(const U128& a, const U128& b); U128 VectorAdd16(const U128& a, const U128& b); U128 VectorAdd32(const U128& a, const U128& b); diff --git a/src/frontend/ir/opcodes.inc b/src/frontend/ir/opcodes.inc index 8bce8f76..bd105669 100644 --- a/src/frontend/ir/opcodes.inc +++ b/src/frontend/ir/opcodes.inc @@ -187,6 +187,10 @@ OPCODE(VectorGetElement8, T::U8, T::U128, T::U8 OPCODE(VectorGetElement16, T::U16, T::U128, T::U8 ) OPCODE(VectorGetElement32, T::U32, T::U128, T::U8 ) OPCODE(VectorGetElement64, T::U64, T::U128, T::U8 ) +OPCODE(VectorSetElement8, T::U128, T::U128, T::U8, T::U8 ) +OPCODE(VectorSetElement16, T::U128, T::U128, T::U8, T::U16 ) 
+OPCODE(VectorSetElement32, T::U128, T::U128, T::U8, T::U32 ) +OPCODE(VectorSetElement64, T::U128, T::U128, T::U8, T::U64 ) OPCODE(VectorAdd8, T::U128, T::U128, T::U128 ) OPCODE(VectorAdd16, T::U128, T::U128, T::U128 ) OPCODE(VectorAdd32, T::U128, T::U128, T::U128 )