diff --git a/src/backend_x64/emit_x64_vector.cpp b/src/backend_x64/emit_x64_vector.cpp index a0e31765..fb536b3f 100644 --- a/src/backend_x64/emit_x64_vector.cpp +++ b/src/backend_x64/emit_x64_vector.cpp @@ -222,6 +222,67 @@ void EmitX64::EmitVectorAnd(EmitContext& ctx, IR::Inst* inst) { EmitVectorOperation(code, ctx, inst, &Xbyak::CodeGenerator::pand); } +void EmitX64::EmitVectorArithmeticShiftRight8(EmitContext& ctx, IR::Inst* inst) { + auto args = ctx.reg_alloc.GetArgumentInfo(inst); + + Xbyak::Xmm result = ctx.reg_alloc.UseScratchXmm(args[0]); + Xbyak::Xmm tmp = ctx.reg_alloc.ScratchXmm(); + const u8 shift_amount = args[1].GetImmediateU8(); + + // TODO: Optimize + code.movdqa(tmp, result); + code.pslldq(tmp, 1); + code.psraw(tmp, shift_amount); + code.psraw(result, shift_amount + 8); + code.psllw(result, 8); + code.psrlw(tmp, 8); + code.por(result, tmp); + + ctx.reg_alloc.DefineValue(inst, result); +} + +void EmitX64::EmitVectorArithmeticShiftRight16(EmitContext& ctx, IR::Inst* inst) { + auto args = ctx.reg_alloc.GetArgumentInfo(inst); + + Xbyak::Xmm result = ctx.reg_alloc.UseScratchXmm(args[0]); + const u8 shift_amount = args[1].GetImmediateU8(); + + code.psraw(result, shift_amount); + + ctx.reg_alloc.DefineValue(inst, result); +} + +void EmitX64::EmitVectorArithmeticShiftRight32(EmitContext& ctx, IR::Inst* inst) { + auto args = ctx.reg_alloc.GetArgumentInfo(inst); + + Xbyak::Xmm result = ctx.reg_alloc.UseScratchXmm(args[0]); + const u8 shift_amount = args[1].GetImmediateU8(); + + code.psrad(result, shift_amount); + + ctx.reg_alloc.DefineValue(inst, result); +} + +void EmitX64::EmitVectorArithmeticShiftRight64(EmitContext& ctx, IR::Inst* inst) { + auto args = ctx.reg_alloc.GetArgumentInfo(inst); + + Xbyak::Xmm result = ctx.reg_alloc.UseScratchXmm(args[0]); + Xbyak::Xmm tmp1 = ctx.reg_alloc.ScratchXmm(); + Xbyak::Xmm tmp2 = ctx.reg_alloc.ScratchXmm(); + const u8 shift_amount = std::min(args[1].GetImmediateU8(), u8(63)); + + const u64 sign_bit = 0x80000000'00000000u >> shift_amount; + + code.pxor(tmp2, tmp2); + code.psrlq(result, shift_amount); + code.movdqa(tmp1, code.MConst(sign_bit, sign_bit)); + code.pand(tmp1, result); + code.psubq(tmp2, tmp1); + code.por(result, tmp2); + + ctx.reg_alloc.DefineValue(inst, result); +} + void EmitX64::EmitVectorBroadcastLower8(EmitContext& ctx, IR::Inst* inst) { auto args = ctx.reg_alloc.GetArgumentInfo(inst); diff --git a/src/frontend/ir/ir_emitter.cpp b/src/frontend/ir/ir_emitter.cpp index a56c9d4f..c50e4b37 100644 --- a/src/frontend/ir/ir_emitter.cpp +++ b/src/frontend/ir/ir_emitter.cpp @@ -804,6 +804,21 @@ U128 IREmitter::VectorAnd(const U128& a, const U128& b) { return Inst(Opcode::VectorAnd, a, b); } +U128 IREmitter::VectorArithmeticShiftRight(size_t esize, const U128& a, u8 shift_amount) { + switch (esize) { + case 8: + return Inst(Opcode::VectorArithmeticShiftRight8, a, Imm8(shift_amount)); + case 16: + return Inst(Opcode::VectorArithmeticShiftRight16, a, Imm8(shift_amount)); + case 32: + return Inst(Opcode::VectorArithmeticShiftRight32, a, Imm8(shift_amount)); + case 64: + return Inst(Opcode::VectorArithmeticShiftRight64, a, Imm8(shift_amount)); + } + UNREACHABLE(); + return {}; +} + U128 IREmitter::VectorBroadcastLower(size_t esize, const UAny& a) { switch (esize) { case 8: diff --git a/src/frontend/ir/ir_emitter.h b/src/frontend/ir/ir_emitter.h index e2de2c51..8de49c40 100644 --- a/src/frontend/ir/ir_emitter.h +++ b/src/frontend/ir/ir_emitter.h @@ -209,6 +209,7 @@ public: U128 VectorSetElement(size_t esize, const U128& a, size_t index, const UAny& elem); U128 VectorAdd(size_t esize, const U128& a, const U128& b); U128 VectorAnd(const U128& a, const U128& b); + U128 VectorArithmeticShiftRight(size_t esize, const U128& a, u8 shift_amount); U128 VectorBroadcast(size_t esize, const UAny& a); U128 VectorBroadcastLower(size_t esize, const UAny& a); U128 VectorEor(const U128& a, const U128& b); diff --git a/src/frontend/ir/opcodes.inc b/src/frontend/ir/opcodes.inc index b113549f..807f315f 100644 --- a/src/frontend/ir/opcodes.inc +++ b/src/frontend/ir/opcodes.inc @@ -203,6 +203,10 @@ OPCODE(VectorAdd16, T::U128, T::U128, T::U128 OPCODE(VectorAdd32, T::U128, T::U128, T::U128 ) OPCODE(VectorAdd64, T::U128, T::U128, T::U128 ) OPCODE(VectorAnd, T::U128, T::U128, T::U128 ) +OPCODE(VectorArithmeticShiftRight8, T::U128, T::U128, T::U8 ) +OPCODE(VectorArithmeticShiftRight16,T::U128, T::U128, T::U8 ) +OPCODE(VectorArithmeticShiftRight32,T::U128, T::U128, T::U8 ) +OPCODE(VectorArithmeticShiftRight64,T::U128, T::U128, T::U8 ) OPCODE(VectorBroadcastLower8, T::U128, T::U8 ) OPCODE(VectorBroadcastLower16, T::U128, T::U16 ) OPCODE(VectorBroadcastLower32, T::U128, T::U32 )