diff --git a/src/backend/x64/emit_x64_vector_floating_point.cpp b/src/backend/x64/emit_x64_vector_floating_point.cpp
index 5a5e0629..deb5ab1f 100644
--- a/src/backend/x64/emit_x64_vector_floating_point.cpp
+++ b/src/backend/x64/emit_x64_vector_floating_point.cpp
@@ -908,44 +908,50 @@ void EmitFPVectorMulAdd(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst) {
         }
     };
 
-    if (code.DoesCpuSupport(Xbyak::util::Cpu::tFMA) && code.DoesCpuSupport(Xbyak::util::Cpu::tAVX)) {
-        auto args = ctx.reg_alloc.GetArgumentInfo(inst);
+    if constexpr (fsize != 16) {
+        if (code.DoesCpuSupport(Xbyak::util::Cpu::tFMA) && code.DoesCpuSupport(Xbyak::util::Cpu::tAVX)) {
+            auto args = ctx.reg_alloc.GetArgumentInfo(inst);
 
-        const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm();
-        const Xbyak::Xmm xmm_a = ctx.reg_alloc.UseXmm(args[0]);
-        const Xbyak::Xmm xmm_b = ctx.reg_alloc.UseXmm(args[1]);
-        const Xbyak::Xmm xmm_c = ctx.reg_alloc.UseXmm(args[2]);
-        const Xbyak::Xmm tmp = ctx.reg_alloc.ScratchXmm();
+            const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm();
+            const Xbyak::Xmm xmm_a = ctx.reg_alloc.UseXmm(args[0]);
+            const Xbyak::Xmm xmm_b = ctx.reg_alloc.UseXmm(args[1]);
+            const Xbyak::Xmm xmm_c = ctx.reg_alloc.UseXmm(args[2]);
+            const Xbyak::Xmm tmp = ctx.reg_alloc.ScratchXmm();
 
-        Xbyak::Label end, fallback;
+            Xbyak::Label end, fallback;
 
-        code.movaps(result, xmm_a);
-        FCODE(vfmadd231p)(result, xmm_b, xmm_c);
+            code.movaps(result, xmm_a);
+            FCODE(vfmadd231p)(result, xmm_b, xmm_c);
 
-        code.movaps(tmp, GetNegativeZeroVector(code));
-        code.andnps(tmp, result);
-        FCODE(vcmpeq_uqp)(tmp, tmp, GetSmallestNormalVector(code));
-        code.vptest(tmp, tmp);
-        code.jnz(fallback, code.T_NEAR);
-        code.L(end);
+            code.movaps(tmp, GetNegativeZeroVector(code));
+            code.andnps(tmp, result);
+            FCODE(vcmpeq_uqp)(tmp, tmp, GetSmallestNormalVector(code));
+            code.vptest(tmp, tmp);
+            code.jnz(fallback, code.T_NEAR);
+            code.L(end);
 
-        code.SwitchToFarCode();
-        code.L(fallback);
-        code.sub(rsp, 8);
-        ABI_PushCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
-        EmitFourOpFallbackWithoutRegAlloc(code, ctx, result, xmm_a, xmm_b, xmm_c, fallback_fn);
-        ABI_PopCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
-        code.add(rsp, 8);
-        code.jmp(end, code.T_NEAR);
-        code.SwitchToNearCode();
+            code.SwitchToFarCode();
+            code.L(fallback);
+            code.sub(rsp, 8);
+            ABI_PushCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
+            EmitFourOpFallbackWithoutRegAlloc(code, ctx, result, xmm_a, xmm_b, xmm_c, fallback_fn);
+            ABI_PopCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
+            code.add(rsp, 8);
+            code.jmp(end, code.T_NEAR);
+            code.SwitchToNearCode();
 
-        ctx.reg_alloc.DefineValue(inst, result);
-        return;
+            ctx.reg_alloc.DefineValue(inst, result);
+            return;
+        }
     }
 
     EmitFourOpFallback(code, ctx, inst, fallback_fn);
 }
 
+void EmitX64::EmitFPVectorMulAdd16(EmitContext& ctx, IR::Inst* inst) {
+    EmitFPVectorMulAdd<16>(code, ctx, inst);
+}
+
 void EmitX64::EmitFPVectorMulAdd32(EmitContext& ctx, IR::Inst* inst) {
     EmitFPVectorMulAdd<32>(code, ctx, inst);
 }
diff --git a/src/frontend/ir/ir_emitter.cpp b/src/frontend/ir/ir_emitter.cpp
index 6d333b91..2280a549 100644
--- a/src/frontend/ir/ir_emitter.cpp
+++ b/src/frontend/ir/ir_emitter.cpp
@@ -2173,6 +2173,8 @@ U128 IREmitter::FPVectorMul(size_t esize, const U128& a, const U128& b) {
 
 U128 IREmitter::FPVectorMulAdd(size_t esize, const U128& a, const U128& b, const U128& c) {
     switch (esize) {
+    case 16:
+        return Inst(Opcode::FPVectorMulAdd16, a, b, c);
     case 32:
         return Inst(Opcode::FPVectorMulAdd32, a, b, c);
     case 64:
diff --git a/src/frontend/ir/microinstruction.cpp b/src/frontend/ir/microinstruction.cpp
index d746c9e0..8f1bb3c2 100644
--- a/src/frontend/ir/microinstruction.cpp
+++ b/src/frontend/ir/microinstruction.cpp
@@ -327,6 +327,7 @@ bool Inst::ReadsFromAndWritesToFPSRCumulativeExceptionBits() const {
     case Opcode::FPVectorGreaterEqual64:
     case Opcode::FPVectorMul32:
     case Opcode::FPVectorMul64:
+    case Opcode::FPVectorMulAdd16:
     case Opcode::FPVectorMulAdd32:
     case Opcode::FPVectorMulAdd64:
     case Opcode::FPVectorPairedAddLower32:
diff --git a/src/frontend/ir/opcodes.inc b/src/frontend/ir/opcodes.inc
index 3eaa41ea..bfc95a20 100644
--- a/src/frontend/ir/opcodes.inc
+++ b/src/frontend/ir/opcodes.inc
@@ -553,6 +553,7 @@ OPCODE(FPVectorMin32,                                       U128,           U128,           U128                            )
 OPCODE(FPVectorMin64,                                       U128,           U128,           U128                            )
 OPCODE(FPVectorMul32,                                       U128,           U128,           U128                            )
 OPCODE(FPVectorMul64,                                       U128,           U128,           U128                            )
+OPCODE(FPVectorMulAdd16,                                    U128,           U128,           U128,           U128            )
 OPCODE(FPVectorMulAdd32,                                    U128,           U128,           U128,           U128            )
 OPCODE(FPVectorMulAdd64,                                    U128,           U128,           U128,           U128            )
 OPCODE(FPVectorMulX32,                                      U128,           U128,           U128                            )
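
Note on the fp16 guard: FCODE() expands a mnemonic to its ps or pd form based on fsize, and x64 has no packed half-precision FMA instruction (VFMADD231PH only exists in the much later AVX512-FP16 extension), so the fast path is compiled out with if constexpr (fsize != 16) and the fp16 case always reaches EmitFourOpFallback. The fallback_fn it receives is defined just above the hunk; only its closing braces are visible in the diff. Judging from the 32/64-bit pattern in this file, it is a per-lane soft-float loop along the following lines; treat this as a sketch, with FPT, VectorArray, FP::FPCR, FP::FPSR, and FP::FPMulAdd assumed from the surrounding dynarmic code rather than quoted from the patch:

    // Sketch of the scalar fallback, assuming FPT is u16 when fsize == 16.
    // Each lane is evaluated by the soft-float FPMulAdd helper, which honours
    // the rounding and flush-to-zero controls in fpcr and accumulates IEEE
    // exception flags into fpsr.
    const auto fallback_fn = [](VectorArray<FPT>& result, const VectorArray<FPT>& addend,
                                const VectorArray<FPT>& op1, const VectorArray<FPT>& op2,
                                FP::FPCR fpcr, FP::FPSR& fpsr) {
        for (size_t i = 0; i < result.size(); i++) {
            result[i] = FP::FPMulAdd<FPT>(addend[i], op1[i], op2[i], fpcr, fpsr);
        }
    };

Because the 16-bit element size is rejected at compile time rather than at runtime, no fp16 code path can ever touch the FMA/AVX block above, and the template still instantiates cleanly for fsize == 16.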