diff --git a/src/backend_x64/emit_x64_vector.cpp b/src/backend_x64/emit_x64_vector.cpp
index 77d4207a..fc4d1ccf 100644
--- a/src/backend_x64/emit_x64_vector.cpp
+++ b/src/backend_x64/emit_x64_vector.cpp
@@ -938,6 +938,61 @@ void EmitX64::EmitVectorPopulationCount(EmitContext& ctx, IR::Inst* inst) {
     });
 }
 
+void EmitX64::EmitVectorSignExtend8(EmitContext& ctx, IR::Inst* inst) {
+    auto args = ctx.reg_alloc.GetArgumentInfo(inst);
+    if (code.DoesCpuSupport(Xbyak::util::Cpu::tSSE41)) {
+        const Xbyak::Xmm a = ctx.reg_alloc.UseScratchXmm(args[0]);
+        code.pmovsxbw(a, a);
+        ctx.reg_alloc.DefineValue(inst, a);
+    } else {
+        const Xbyak::Xmm a = ctx.reg_alloc.UseXmm(args[0]);
+        const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm();
+        code.pxor(result, result);
+        code.punpcklbw(result, a);
+        code.psraw(result, 8);
+        ctx.reg_alloc.DefineValue(inst, result);
+    }
+}
+
+void EmitX64::EmitVectorSignExtend16(EmitContext& ctx, IR::Inst* inst) {
+    auto args = ctx.reg_alloc.GetArgumentInfo(inst);
+    if (code.DoesCpuSupport(Xbyak::util::Cpu::tSSE41)) {
+        const Xbyak::Xmm a = ctx.reg_alloc.UseScratchXmm(args[0]);
+        code.pmovsxwd(a, a);
+        ctx.reg_alloc.DefineValue(inst, a);
+    } else {
+        const Xbyak::Xmm a = ctx.reg_alloc.UseXmm(args[0]);
+        const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm();
+        code.pxor(result, result);
+        code.punpcklwd(result, a);
+        code.psrad(result, 16);
+        ctx.reg_alloc.DefineValue(inst, result);
+    }
+}
+
+void EmitX64::EmitVectorSignExtend32(EmitContext& ctx, IR::Inst* inst) {
+    if (code.DoesCpuSupport(Xbyak::util::Cpu::tSSE41)) {
+        auto args = ctx.reg_alloc.GetArgumentInfo(inst);
+        const Xbyak::Xmm a = ctx.reg_alloc.UseScratchXmm(args[0]);
+        code.pmovsxdq(a, a);
+        ctx.reg_alloc.DefineValue(inst, a);
+        return;
+    }
+
+    EmitOneArgumentFallback(code, ctx, inst, [](std::array<u64, 2>& result, const std::array<u32, 4>& a){
+        for (size_t i = 0; i < 2; ++i) {
+            result[i] = Common::SignExtend<32, u64>(a[i]);
+        }
+    });
+}
+
+void EmitX64::EmitVectorSignExtend64(EmitContext& ctx, IR::Inst* inst) {
+    EmitOneArgumentFallback(code, ctx, inst, [](std::array<u64, 2>& result, const std::array<u64, 2>& a){
+        result[1] = (a[0] >> 63) ? ~u64(0) : 0;
+        result[0] = a[0];
+    });
+}
+
 void EmitX64::EmitVectorSub8(EmitContext& ctx, IR::Inst* inst) {
     EmitVectorOperation(code, ctx, inst, &Xbyak::CodeGenerator::psubb);
 }
diff --git a/src/frontend/ir/ir_emitter.cpp b/src/frontend/ir/ir_emitter.cpp
index 1c79797c..960669b3 100644
--- a/src/frontend/ir/ir_emitter.cpp
+++ b/src/frontend/ir/ir_emitter.cpp
@@ -981,6 +981,21 @@ U128 IREmitter::VectorPopulationCount(const U128& a) {
     return Inst(Opcode::VectorPopulationCount, a);
 }
 
+U128 IREmitter::VectorSignExtend(size_t original_esize, const U128& a) {
+    switch (original_esize) {
+    case 8:
+        return Inst(Opcode::VectorSignExtend8, a);
+    case 16:
+        return Inst(Opcode::VectorSignExtend16, a);
+    case 32:
+        return Inst(Opcode::VectorSignExtend32, a);
+    case 64:
+        return Inst(Opcode::VectorSignExtend64, a);
+    }
+    UNREACHABLE();
+    return {};
+}
+
 U128 IREmitter::VectorSub(size_t esize, const U128& a, const U128& b) {
     switch (esize) {
     case 8:
diff --git a/src/frontend/ir/ir_emitter.h b/src/frontend/ir/ir_emitter.h
index c25d891f..29fcbdf7 100644
--- a/src/frontend/ir/ir_emitter.h
+++ b/src/frontend/ir/ir_emitter.h
@@ -224,6 +224,7 @@ public:
     U128 VectorPairedAdd(size_t esize, const U128& a, const U128& b);
     U128 VectorPairedAddLower(size_t esize, const U128& a, const U128& b);
     U128 VectorPopulationCount(const U128& a);
+    U128 VectorSignExtend(size_t original_esize, const U128& a);
     U128 VectorSub(size_t esize, const U128& a, const U128& b);
     U128 VectorZeroExtend(size_t original_esize, const U128& a);
     U128 VectorZeroUpper(const U128& a);
diff --git a/src/frontend/ir/opcodes.inc b/src/frontend/ir/opcodes.inc
index 85bd88b7..2e469304 100644
--- a/src/frontend/ir/opcodes.inc
+++ b/src/frontend/ir/opcodes.inc
@@ -249,6 +249,10 @@ OPCODE(VectorPairedAdd16,                                   T::U128,        T::U128,        T::U128         )
 OPCODE(VectorPairedAdd32,                                   T::U128,        T::U128,        T::U128         )
 OPCODE(VectorPairedAdd64,                                   T::U128,        T::U128,        T::U128         )
 OPCODE(VectorPopulationCount,                               T::U128,        T::U128                         )
+OPCODE(VectorSignExtend8,                                   T::U128,        T::U128                         )
+OPCODE(VectorSignExtend16,                                  T::U128,        T::U128                         )
+OPCODE(VectorSignExtend32,                                  T::U128,        T::U128                         )
+OPCODE(VectorSignExtend64,                                  T::U128,        T::U128                         )
 OPCODE(VectorSub8,                                          T::U128,        T::U128,        T::U128         )
 OPCODE(VectorSub16,                                         T::U128,        T::U128,        T::U128         )
 OPCODE(VectorSub32,                                         T::U128,        T::U128,        T::U128         )
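
Reviewer note, not part of the patch: the pre-SSE4.1 fallbacks in EmitVectorSignExtend8/16 use the same two-step trick — interleave the low source elements into the high half of each widened lane against a zeroed register (punpcklbw/punpcklwd), then arithmetic-shift right by the element width (psraw/psrad) so the sign bit fills the upper half. Below is a minimal scalar sketch of the 8-to-16-bit case; it assumes nothing from dynarmic, and the helper name and test values are illustrative only.

    #include <array>
    #include <cstdint>
    #include <cstdio>

    // Scalar model of the SSE2 sequence:
    //   pxor      result, result   ; result = 0
    //   punpcklbw result, a        ; word[i] = a.byte[i] << 8
    //   psraw     result, 8        ; arithmetic shift copies the sign bit down
    std::array<int16_t, 8> SignExtendLow8To16(const std::array<uint8_t, 16>& a) {
        std::array<int16_t, 8> result{};
        for (size_t i = 0; i < 8; ++i) {
            // Place the source byte in the high byte of the 16-bit lane...
            const uint16_t lane = static_cast<uint16_t>(a[i]) << 8;
            // ...then shift it back down arithmetically to sign-extend.
            result[i] = static_cast<int16_t>(static_cast<int16_t>(lane) >> 8);
        }
        return result;
    }

    int main() {
        const std::array<uint8_t, 16> in{0x7F, 0x80, 0xFF, 0x01};
        const auto out = SignExtendLow8To16(in);
        std::printf("%d %d %d %d\n", out[0], out[1], out[2], out[3]); // 127 -128 -1 1
    }

On SSE4.1 hardware the whole sequence collapses into a single pmovsxbw/pmovsxwd/pmovsxdq, which is why each emitter checks DoesCpuSupport(tSSE41) first and only falls back to the longer sequence on older CPUs.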