diff --git a/src/backend/x64/emit_x64_floating_point.cpp b/src/backend/x64/emit_x64_floating_point.cpp
index 5df30955..3fd10ced 100644
--- a/src/backend/x64/emit_x64_floating_point.cpp
+++ b/src/backend/x64/emit_x64_floating_point.cpp
@@ -608,54 +608,56 @@ template<size_t fsize>
 static void EmitFPMulAdd(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst) {
     using FPT = mp::unsigned_integer_of_size<fsize>;
 
-    if (code.DoesCpuSupport(Xbyak::util::Cpu::tFMA)) {
-        auto args = ctx.reg_alloc.GetArgumentInfo(inst);
+    if constexpr (fsize != 16) {
+        if (code.DoesCpuSupport(Xbyak::util::Cpu::tFMA)) {
+            auto args = ctx.reg_alloc.GetArgumentInfo(inst);
 
-        Xbyak::Label end, fallback;
+            Xbyak::Label end, fallback;
 
-        const Xbyak::Xmm operand1 = ctx.reg_alloc.UseXmm(args[0]);
-        const Xbyak::Xmm operand2 = ctx.reg_alloc.UseXmm(args[1]);
-        const Xbyak::Xmm operand3 = ctx.reg_alloc.UseXmm(args[2]);
-        const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm();
-        const Xbyak::Xmm tmp = ctx.reg_alloc.ScratchXmm();
+            const Xbyak::Xmm operand1 = ctx.reg_alloc.UseXmm(args[0]);
+            const Xbyak::Xmm operand2 = ctx.reg_alloc.UseXmm(args[1]);
+            const Xbyak::Xmm operand3 = ctx.reg_alloc.UseXmm(args[2]);
+            const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm();
+            const Xbyak::Xmm tmp = ctx.reg_alloc.ScratchXmm();
 
-        code.movaps(result, operand1);
-        FCODE(vfmadd231s)(result, operand2, operand3);
+            code.movaps(result, operand1);
+            FCODE(vfmadd231s)(result, operand2, operand3);
 
-        code.movaps(tmp, code.MConst(xword, fsize == 32 ? f32_non_sign_mask : f64_non_sign_mask));
-        code.andps(tmp, result);
-        FCODE(ucomis)(tmp, code.MConst(xword, fsize == 32 ? f32_smallest_normal : f64_smallest_normal));
-        code.jz(fallback, code.T_NEAR);
-        code.L(end);
+            code.movaps(tmp, code.MConst(xword, fsize == 32 ? f32_non_sign_mask : f64_non_sign_mask));
+            code.andps(tmp, result);
+            FCODE(ucomis)(tmp, code.MConst(xword, fsize == 32 ? f32_smallest_normal : f64_smallest_normal));
+            code.jz(fallback, code.T_NEAR);
+            code.L(end);
 
-        code.SwitchToFarCode();
-        code.L(fallback);
+            code.SwitchToFarCode();
+            code.L(fallback);
 
-        code.sub(rsp, 8);
-        ABI_PushCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
-        code.movq(code.ABI_PARAM1, operand1);
-        code.movq(code.ABI_PARAM2, operand2);
-        code.movq(code.ABI_PARAM3, operand3);
-        code.mov(code.ABI_PARAM4.cvt32(), ctx.FPCR().Value());
+            code.sub(rsp, 8);
+            ABI_PushCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
+            code.movq(code.ABI_PARAM1, operand1);
+            code.movq(code.ABI_PARAM2, operand2);
+            code.movq(code.ABI_PARAM3, operand3);
+            code.mov(code.ABI_PARAM4.cvt32(), ctx.FPCR().Value());
 #ifdef _WIN32
-        code.sub(rsp, 16 + ABI_SHADOW_SPACE);
-        code.lea(rax, code.ptr[code.r15 + code.GetJitStateInfo().offsetof_fpsr_exc]);
-        code.mov(qword[rsp + ABI_SHADOW_SPACE], rax);
-        code.CallFunction(&FP::FPMulAdd<FPT>);
-        code.add(rsp, 16 + ABI_SHADOW_SPACE);
+            code.sub(rsp, 16 + ABI_SHADOW_SPACE);
+            code.lea(rax, code.ptr[code.r15 + code.GetJitStateInfo().offsetof_fpsr_exc]);
+            code.mov(qword[rsp + ABI_SHADOW_SPACE], rax);
+            code.CallFunction(&FP::FPMulAdd<FPT>);
+            code.add(rsp, 16 + ABI_SHADOW_SPACE);
 #else
-        code.lea(code.ABI_PARAM5, code.ptr[code.r15 + code.GetJitStateInfo().offsetof_fpsr_exc]);
-        code.CallFunction(&FP::FPMulAdd<FPT>);
+            code.lea(code.ABI_PARAM5, code.ptr[code.r15 + code.GetJitStateInfo().offsetof_fpsr_exc]);
+            code.CallFunction(&FP::FPMulAdd<FPT>);
 #endif
-        code.movq(result, code.ABI_RETURN);
-        ABI_PopCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
-        code.add(rsp, 8);
+            code.movq(result, code.ABI_RETURN);
+            ABI_PopCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
+            code.add(rsp, 8);
 
-        code.jmp(end, code.T_NEAR);
-        code.SwitchToNearCode();
+            code.jmp(end, code.T_NEAR);
+            code.SwitchToNearCode();
 
-        ctx.reg_alloc.DefineValue(inst, result);
-        return;
+            ctx.reg_alloc.DefineValue(inst, result);
+            return;
+        }
     }
 
     auto args = ctx.reg_alloc.GetArgumentInfo(inst);
@@ -673,6 +675,10 @@ static void EmitFPMulAdd(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst) {
 #endif
 }
 
+void EmitX64::EmitFPMulAdd16(EmitContext& ctx, IR::Inst* inst) {
+    EmitFPMulAdd<16>(code, ctx, inst);
+}
+
 void EmitX64::EmitFPMulAdd32(EmitContext& ctx, IR::Inst* inst) {
     EmitFPMulAdd<32>(code, ctx, inst);
 }
diff --git a/src/backend/x64/emit_x64_vector_floating_point.cpp b/src/backend/x64/emit_x64_vector_floating_point.cpp
index 5a5e0629..deb5ab1f 100644
--- a/src/backend/x64/emit_x64_vector_floating_point.cpp
+++ b/src/backend/x64/emit_x64_vector_floating_point.cpp
@@ -908,44 +908,50 @@ void EmitFPVectorMulAdd(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst) {
         }
     };
 
-    if (code.DoesCpuSupport(Xbyak::util::Cpu::tFMA) && code.DoesCpuSupport(Xbyak::util::Cpu::tAVX)) {
-        auto args = ctx.reg_alloc.GetArgumentInfo(inst);
+    if constexpr (fsize != 16) {
+        if (code.DoesCpuSupport(Xbyak::util::Cpu::tFMA) && code.DoesCpuSupport(Xbyak::util::Cpu::tAVX)) {
+            auto args = ctx.reg_alloc.GetArgumentInfo(inst);
 
-        const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm();
-        const Xbyak::Xmm xmm_a = ctx.reg_alloc.UseXmm(args[0]);
-        const Xbyak::Xmm xmm_b = ctx.reg_alloc.UseXmm(args[1]);
-        const Xbyak::Xmm xmm_c = ctx.reg_alloc.UseXmm(args[2]);
-        const Xbyak::Xmm tmp = ctx.reg_alloc.ScratchXmm();
+            const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm();
+            const Xbyak::Xmm xmm_a = ctx.reg_alloc.UseXmm(args[0]);
+            const Xbyak::Xmm xmm_b = ctx.reg_alloc.UseXmm(args[1]);
+            const Xbyak::Xmm xmm_c = ctx.reg_alloc.UseXmm(args[2]);
+            const Xbyak::Xmm tmp = ctx.reg_alloc.ScratchXmm();
 
-        Xbyak::Label end, fallback;
+            Xbyak::Label end, fallback;
 
-        code.movaps(result, xmm_a);
-        FCODE(vfmadd231p)(result, xmm_b, xmm_c);
+            code.movaps(result, xmm_a);
+            FCODE(vfmadd231p)(result, xmm_b, xmm_c);
 
-        code.movaps(tmp, GetNegativeZeroVector<fsize>(code));
-        code.andnps(tmp, result);
-        FCODE(vcmpeq_uqp)(tmp, tmp, GetSmallestNormalVector<fsize>(code));
-        code.vptest(tmp, tmp);
-        code.jnz(fallback, code.T_NEAR);
-        code.L(end);
+            code.movaps(tmp, GetNegativeZeroVector<fsize>(code));
+            code.andnps(tmp, result);
+            FCODE(vcmpeq_uqp)(tmp, tmp, GetSmallestNormalVector<fsize>(code));
+            code.vptest(tmp, tmp);
+            code.jnz(fallback, code.T_NEAR);
+            code.L(end);
 
-        code.SwitchToFarCode();
-        code.L(fallback);
-        code.sub(rsp, 8);
-        ABI_PushCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
-        EmitFourOpFallbackWithoutRegAlloc(code, ctx, result, xmm_a, xmm_b, xmm_c, fallback_fn);
-        ABI_PopCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
-        code.add(rsp, 8);
-        code.jmp(end, code.T_NEAR);
-        code.SwitchToNearCode();
+            code.SwitchToFarCode();
+            code.L(fallback);
+            code.sub(rsp, 8);
+            ABI_PushCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
+            EmitFourOpFallbackWithoutRegAlloc(code, ctx, result, xmm_a, xmm_b, xmm_c, fallback_fn);
+            ABI_PopCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
+            code.add(rsp, 8);
+            code.jmp(end, code.T_NEAR);
+            code.SwitchToNearCode();
 
-        ctx.reg_alloc.DefineValue(inst, result);
-        return;
+            ctx.reg_alloc.DefineValue(inst, result);
+            return;
+        }
     }
 
     EmitFourOpFallback(code, ctx, inst, fallback_fn);
 }
 
+void EmitX64::EmitFPVectorMulAdd16(EmitContext& ctx, IR::Inst* inst) {
+    EmitFPVectorMulAdd<16>(code, ctx, inst);
+}
+
 void EmitX64::EmitFPVectorMulAdd32(EmitContext& ctx, IR::Inst* inst) {
     EmitFPVectorMulAdd<32>(code, ctx, inst);
 }
diff --git a/src/common/fp/op/FPMulAdd.cpp b/src/common/fp/op/FPMulAdd.cpp
index 5cabe374..1e994820 100644
--- a/src/common/fp/op/FPMulAdd.cpp
+++ b/src/common/fp/op/FPMulAdd.cpp
@@ -35,7 +35,7 @@ FPT FPMulAdd(FPT addend, FPT op1, FPT op2, FPCR fpcr, FPSR& fpsr) {
 
     if (typeA == FPType::QNaN && ((inf1 && zero2) || (zero1 && inf2))) {
         FPProcessException(FPExc::InvalidOp, fpcr, fpsr);
-        return FPInfo<FPT>::DefaultNaN();
+        return FPT(FPInfo<FPT>::DefaultNaN());
     }
 
     if (maybe_nan) {
@@ -50,29 +50,30 @@ FPT FPMulAdd(FPT addend, FPT op1, FPT op2, FPCR fpcr, FPSR& fpsr) {
     // Raise NaN on (inf * inf) of opposite signs or (inf * zero).
     if ((inf1 && zero2) || (zero1 && inf2) || (infA && infP && signA != signP)) {
         FPProcessException(FPExc::InvalidOp, fpcr, fpsr);
-        return FPInfo<FPT>::DefaultNaN();
+        return FPT(FPInfo<FPT>::DefaultNaN());
     }
 
     // Handle infinities
     if ((infA && !signA) || (infP && !signP)) {
-        return FPInfo<FPT>::Infinity(false);
+        return FPT(FPInfo<FPT>::Infinity(false));
     }
     if ((infA && signA) || (infP && signP)) {
-        return FPInfo<FPT>::Infinity(true);
+        return FPT(FPInfo<FPT>::Infinity(true));
     }
 
     // Result is exactly zero
     if (zeroA && zeroP && signA == signP) {
-        return FPInfo<FPT>::Zero(signA);
+        return FPT(FPInfo<FPT>::Zero(signA));
     }
 
     const FPUnpacked result_value = FusedMulAdd(valueA, value1, value2);
     if (result_value.mantissa == 0) {
-        return FPInfo<FPT>::Zero(rounding == RoundingMode::TowardsMinusInfinity);
+        return FPT(FPInfo<FPT>::Zero(rounding == RoundingMode::TowardsMinusInfinity));
     }
     return FPRound<FPT>(result_value, fpcr, fpsr);
 }
 
+template u16 FPMulAdd<u16>(u16 addend, u16 op1, u16 op2, FPCR fpcr, FPSR& fpsr);
 template u32 FPMulAdd<u32>(u32 addend, u32 op1, u32 op2, FPCR fpcr, FPSR& fpsr);
 template u64 FPMulAdd<u64>(u64 addend, u64 op1, u64 op2, FPCR fpcr, FPSR& fpsr);
 
diff --git a/src/frontend/A64/decoder/a64.inc b/src/frontend/A64/decoder/a64.inc
index 50867ee8..36625453 100644
--- a/src/frontend/A64/decoder/a64.inc
+++ b/src/frontend/A64/decoder/a64.inc
@@ -538,9 +538,9 @@ INST(FCVTZU_fix_1,          "FCVTZU (vector, fixed-point)",                "01111
 INST(SQDMULL_elt_1,         "SQDMULL, SQDMULL2 (by element)",              "01011111zzLMmmmm1011H0nnnnnddddd")
 INST(SQDMULH_elt_1,         "SQDMULH (by element)",                        "01011111zzLMmmmm1100H0nnnnnddddd")
 INST(SQRDMULH_elt_1,        "SQRDMULH (by element)",                       "01011111zzLMmmmm1101H0nnnnnddddd")
-//INST(FMLA_elt_1,          "FMLA (by element)",                           "0101111100LMmmmm0001H0nnnnnddddd")
+INST(FMLA_elt_1,            "FMLA (by element)",                           "0101111100LMmmmm0001H0nnnnnddddd")
 INST(FMLA_elt_2,            "FMLA (by element)",                           "010111111zLMmmmm0001H0nnnnnddddd")
-//INST(FMLS_elt_1,          "FMLS (by element)",                           "0101111100LMmmmm0101H0nnnnnddddd")
+INST(FMLS_elt_1,            "FMLS (by element)",                           "0101111100LMmmmm0101H0nnnnnddddd")
 INST(FMLS_elt_2,            "FMLS (by element)",                           "010111111zLMmmmm0101H0nnnnnddddd")
 //INST(FMUL_elt_1,          "FMUL (by element)",                           "0101111100LMmmmm1001H0nnnnnddddd")
 INST(FMUL_elt_2,            "FMUL (by element)",                           "010111111zLMmmmm1001H0nnnnnddddd")
@@ -583,11 +583,11 @@ INST(INS_elt,               "INS (element)",                               "01101
 //INST(FCMGT_reg_3,         "FCMGT (register)",                            "0Q101110110mmmmm001001nnnnnddddd")
 //INST(FACGT_3,             "FACGT",                                       "0Q101110110mmmmm001011nnnnnddddd")
 //INST(FMAXNM_1,            "FMAXNM (vector)",                             "0Q001110010mmmmm000001nnnnnddddd")
-//INST(FMLA_vec_1,          "FMLA (vector)",                               "0Q001110010mmmmm000011nnnnnddddd")
+INST(FMLA_vec_1,            "FMLA (vector)",                               "0Q001110010mmmmm000011nnnnnddddd")
 //INST(FADD_1,              "FADD (vector)",                               "0Q001110010mmmmm000101nnnnnddddd")
 //INST(FMAX_1,              "FMAX (vector)",                               "0Q001110010mmmmm001101nnnnnddddd")
 //INST(FMINNM_1,            "FMINNM (vector)",                             "0Q001110110mmmmm000001nnnnnddddd")
-//INST(FMLS_vec_1,          "FMLS (vector)",                               "0Q001110110mmmmm000011nnnnnddddd")
+INST(FMLS_vec_1,            "FMLS (vector)",                               "0Q001110110mmmmm000011nnnnnddddd")
 //INST(FSUB_1,              "FSUB (vector)",                               "0Q001110110mmmmm000101nnnnnddddd")
 //INST(FMIN_1,              "FMIN (vector)",                               "0Q001110110mmmmm001101nnnnnddddd")
 //INST(FMAXNMP_vec_1,       "FMAXNMP (vector)",                            "0Q101110010mmmmm000001nnnnnddddd")
@@ -876,9 +876,9 @@ INST(SQDMULL_elt_2,         "SQDMULL, SQDMULL2 (by element)",              "0Q001
 INST(SQDMULH_elt_2,         "SQDMULH (by element)",                        "0Q001111zzLMmmmm1100H0nnnnnddddd")
 INST(SQRDMULH_elt_2,        "SQRDMULH (by element)",                       "0Q001111zzLMmmmm1101H0nnnnnddddd")
 INST(SDOT_elt,              "SDOT (by element)",                           "0Q001111zzLMmmmm1110H0nnnnnddddd")
-//INST(FMLA_elt_3,          "FMLA (by element)",                           "0Q00111100LMmmmm0001H0nnnnnddddd")
+INST(FMLA_elt_3,            "FMLA (by element)",                           "0Q00111100LMmmmm0001H0nnnnnddddd")
 INST(FMLA_elt_4,            "FMLA (by element)",                           "0Q0011111zLMmmmm0001H0nnnnnddddd")
-//INST(FMLS_elt_3,          "FMLS (by element)",                           "0Q00111100LMmmmm0101H0nnnnnddddd")
+INST(FMLS_elt_3,            "FMLS (by element)",                           "0Q00111100LMmmmm0101H0nnnnnddddd")
 INST(FMLS_elt_4,            "FMLS (by element)",                           "0Q0011111zLMmmmm0101H0nnnnnddddd")
 //INST(FMUL_elt_3,          "FMUL (by element)",                           "0Q00111100LMmmmm1001H0nnnnnddddd")
 INST(FMUL_elt_4,            "FMUL (by element)",                           "0Q0011111zLMmmmm1001H0nnnnnddddd")
diff --git a/src/frontend/A64/location_descriptor.h b/src/frontend/A64/location_descriptor.h
index 4d19bd05..d2ff4de0 100644
--- a/src/frontend/A64/location_descriptor.h
+++ b/src/frontend/A64/location_descriptor.h
@@ -25,7 +25,7 @@ namespace Dynarmic::A64 {
 class LocationDescriptor {
 public:
     static constexpr u64 PC_MASK = 0x00FF'FFFF'FFFF'FFFFull;
-    static constexpr u32 FPCR_MASK = 0x07C0'0000;
+    static constexpr u32 FPCR_MASK = 0x07C8'0000;
 
     LocationDescriptor(u64 pc, FP::FPCR fpcr) : pc(pc & PC_MASK), fpcr(fpcr.Value() & FPCR_MASK) {}
 
diff --git a/src/frontend/A64/translate/impl/floating_point_data_processing_three_register.cpp b/src/frontend/A64/translate/impl/floating_point_data_processing_three_register.cpp
index 0401cf9f..db23bf10 100644
--- a/src/frontend/A64/translate/impl/floating_point_data_processing_three_register.cpp
+++ b/src/frontend/A64/translate/impl/floating_point_data_processing_three_register.cpp
@@ -12,56 +12,56 @@ namespace Dynarmic::A64 {
 
 bool TranslatorVisitor::FMADD_float(Imm<2> type, Vec Vm, Vec Va, Vec Vn, Vec Vd) {
     const auto datasize = FPGetDataSize(type);
-    if (!datasize || *datasize == 16) {
+    if (!datasize) {
         return UnallocatedEncoding();
     }
 
-    const IR::U32U64 operanda = V_scalar(*datasize, Va);
-    const IR::U32U64 operand1 = V_scalar(*datasize, Vn);
-    const IR::U32U64 operand2 = V_scalar(*datasize, Vm);
-    const IR::U32U64 result = ir.FPMulAdd(operanda, operand1, operand2, true);
+    const IR::U16U32U64 operanda = V_scalar(*datasize, Va);
+    const IR::U16U32U64 operand1 = V_scalar(*datasize, Vn);
+    const IR::U16U32U64 operand2 = V_scalar(*datasize, Vm);
+    const IR::U16U32U64 result = ir.FPMulAdd(operanda, operand1, operand2, true);
     V_scalar(*datasize, Vd, result);
     return true;
 }
 
 bool TranslatorVisitor::FMSUB_float(Imm<2> type, Vec Vm, Vec Va, Vec Vn, Vec Vd) {
     const auto datasize = FPGetDataSize(type);
-    if (!datasize || *datasize == 16) {
+    if (!datasize) {
         return UnallocatedEncoding();
     }
 
-    const IR::U32U64 operanda = V_scalar(*datasize, Va);
-    const IR::U32U64 operand1 = V_scalar(*datasize, Vn);
-    const IR::U32U64 operand2 = V_scalar(*datasize, Vm);
-    const IR::U32U64 result = ir.FPMulAdd(operanda, ir.FPNeg(operand1), operand2, true);
+    const IR::U16U32U64 operanda = V_scalar(*datasize, Va);
+    const IR::U16U32U64 operand1 = V_scalar(*datasize, Vn);
+    const IR::U16U32U64 operand2 = V_scalar(*datasize, Vm);
+    const IR::U16U32U64 result = ir.FPMulAdd(operanda, ir.FPNeg(operand1), operand2, true);
     V_scalar(*datasize, Vd, result);
     return true;
 }
 
 bool TranslatorVisitor::FNMADD_float(Imm<2> type, Vec Vm, Vec Va, Vec Vn, Vec Vd) {
     const auto datasize = FPGetDataSize(type);
-    if (!datasize || *datasize == 16) {
+    if (!datasize) {
         return UnallocatedEncoding();
     }
 
-    const IR::U32U64 operanda = V_scalar(*datasize, Va);
-    const IR::U32U64 operand1 = V_scalar(*datasize, Vn);
-    const IR::U32U64 operand2 = V_scalar(*datasize, Vm);
-    const IR::U32U64 result = ir.FPMulAdd(ir.FPNeg(operanda), ir.FPNeg(operand1), operand2, true);
+    const IR::U16U32U64 operanda = V_scalar(*datasize, Va);
+    const IR::U16U32U64 operand1 = V_scalar(*datasize, Vn);
+    const IR::U16U32U64 operand2 = V_scalar(*datasize, Vm);
+    const IR::U16U32U64 result = ir.FPMulAdd(ir.FPNeg(operanda), ir.FPNeg(operand1), operand2, true);
     V_scalar(*datasize, Vd, result);
     return true;
 }
 
 bool TranslatorVisitor::FNMSUB_float(Imm<2> type, Vec Vm, Vec Va, Vec Vn, Vec Vd) {
     const auto datasize = FPGetDataSize(type);
-    if (!datasize || *datasize == 16) {
+    if (!datasize) {
         return UnallocatedEncoding();
     }
 
-    const IR::U32U64 operanda = V_scalar(*datasize, Va);
-    const IR::U32U64 operand1 = V_scalar(*datasize, Vn);
-    const IR::U32U64 operand2 = V_scalar(*datasize, Vm);
-    const IR::U32U64 result = ir.FPMulAdd(ir.FPNeg(operanda), operand1, operand2, true);
+    const IR::U16U32U64 operanda = V_scalar(*datasize, Va);
+    const IR::U16U32U64 operand1 = V_scalar(*datasize, Vn);
+    const IR::U16U32U64 operand2 = V_scalar(*datasize, Vm);
+    const IR::U16U32U64 result = ir.FPMulAdd(ir.FPNeg(operanda), operand1, operand2, true);
     V_scalar(*datasize, Vd, result);
     return true;
 }
diff --git a/src/frontend/A64/translate/impl/simd_scalar_x_indexed_element.cpp b/src/frontend/A64/translate/impl/simd_scalar_x_indexed_element.cpp
index 693cf3d3..350eb2fe 100644
--- a/src/frontend/A64/translate/impl/simd_scalar_x_indexed_element.cpp
+++ b/src/frontend/A64/translate/impl/simd_scalar_x_indexed_element.cpp
@@ -36,7 +36,7 @@ bool MultiplyByElement(TranslatorVisitor& v, bool sz, Imm<1> L, Imm<1> M, Imm<4>
     const size_t esize = sz ? 64 : 32;
 
     const IR::U32U64 element = v.ir.VectorGetElement(esize, v.V(idxdsize, Vm), index);
-    const IR::U32U64 result = [&] {
+    const IR::U32U64 result = [&]() -> IR::U32U64 {
         IR::U32U64 operand1 = v.V_scalar(esize, Vn);
 
         if (extra_behavior == ExtraBehavior::None) {
@@ -58,12 +58,54 @@ bool MultiplyByElement(TranslatorVisitor& v, bool sz, Imm<1> L, Imm<1> M, Imm<4>
     v.V_scalar(esize, Vd, result);
     return true;
 }
+
+bool MultiplyByElementHalfPrecision(TranslatorVisitor& v, Imm<1> L, Imm<1> M, Imm<4> Vmlo, Imm<1> H,
+                                    Vec Vn, Vec Vd, ExtraBehavior extra_behavior) {
+    const size_t esize = 16;
+    const size_t idxsize = H == 1 ? 128 : 64;
+    const size_t index = concatenate(H, L, M).ZeroExtend();
+
+    const auto Vm = Vmlo.ZeroExtend<Vec>();
+    const IR::U16 element = v.ir.VectorGetElement(esize, v.V(idxsize, Vm), index);
+    const IR::U16 result = [&]() -> IR::U16 {
+        IR::U16 operand1 = v.V_scalar(esize, Vn);
+
+        // TODO: Currently we don't implement half-precision paths
+        // for regular multiplication and extended multiplication.
+
+        if (extra_behavior == ExtraBehavior::None) {
+            UNIMPLEMENTED();
+        }
+
+        if (extra_behavior == ExtraBehavior::MultiplyExtended) {
+            UNIMPLEMENTED();
+        }
+
+        if (extra_behavior == ExtraBehavior::Subtract) {
+            operand1 = v.ir.FPNeg(operand1);
+        }
+
+        const IR::U16 operand2 = v.V_scalar(esize, Vd);
+        return v.ir.FPMulAdd(operand2, operand1, element, true);
+    }();
+
+    v.V_scalar(esize, Vd, result);
+    return true;
+}
 } // Anonymous namespace
 
+bool TranslatorVisitor::FMLA_elt_1(Imm<1> L, Imm<1> M, Imm<4> Vmlo, Imm<1> H, Vec Vn, Vec Vd) {
+    return MultiplyByElementHalfPrecision(*this, L, M, Vmlo, H, Vn, Vd, ExtraBehavior::Accumulate);
+}
+
 bool TranslatorVisitor::FMLA_elt_2(bool sz, Imm<1> L, Imm<1> M, Imm<4> Vmlo, Imm<1> H, Vec Vn, Vec Vd) {
     return MultiplyByElement(*this, sz, L, M, Vmlo, H, Vn, Vd, ExtraBehavior::Accumulate);
 }
 
+bool TranslatorVisitor::FMLS_elt_1(Imm<1> L, Imm<1> M, Imm<4> Vmlo, Imm<1> H, Vec Vn, Vec Vd) {
+    return MultiplyByElementHalfPrecision(*this, L, M, Vmlo, H, Vn, Vd, ExtraBehavior::Subtract);
+}
+
 bool TranslatorVisitor::FMLS_elt_2(bool sz, Imm<1> L, Imm<1> M, Imm<4> Vmlo, Imm<1> H, Vec Vn, Vec Vd) {
     return MultiplyByElement(*this, sz, L, M, Vmlo, H, Vn, Vd, ExtraBehavior::Subtract);
 }
diff --git a/src/frontend/A64/translate/impl/simd_three_same.cpp b/src/frontend/A64/translate/impl/simd_three_same.cpp
index a4492165..a484dc17 100644
--- a/src/frontend/A64/translate/impl/simd_three_same.cpp
+++ b/src/frontend/A64/translate/impl/simd_three_same.cpp
@@ -680,10 +680,24 @@ bool TranslatorVisitor::FADD_2(bool Q, bool sz, Vec Vm, Vec Vn, Vec Vd) {
     return true;
 }
 
+bool TranslatorVisitor::FMLA_vec_1(bool Q, Vec Vm, Vec Vn, Vec Vd) {
+    const size_t datasize = Q ? 128 : 64;
+    const size_t esize = 16;
+
+    const IR::U128 operand1 = V(datasize, Vn);
+    const IR::U128 operand2 = V(datasize, Vm);
+    const IR::U128 operand3 = V(datasize, Vd);
+    const IR::U128 result = ir.FPVectorMulAdd(esize, operand3, operand1, operand2);
+
+    V(datasize, Vd, result);
+    return true;
+}
+
 bool TranslatorVisitor::FMLA_vec_2(bool Q, bool sz, Vec Vm, Vec Vn, Vec Vd) {
     if (sz && !Q) {
         return ReservedValue();
     }
+
     const size_t esize = sz ? 64 : 32;
     const size_t datasize = Q ? 128 : 64;
 
@@ -691,6 +705,20 @@ bool TranslatorVisitor::FMLA_vec_2(bool Q, bool sz, Vec Vm, Vec Vn, Vec Vd) {
     const IR::U128 operand2 = V(datasize, Vm);
     const IR::U128 operand3 = V(datasize, Vd);
     const IR::U128 result = ir.FPVectorMulAdd(esize, operand3, operand1, operand2);
+
+    V(datasize, Vd, result);
+    return true;
+}
+
+bool TranslatorVisitor::FMLS_vec_1(bool Q, Vec Vm, Vec Vn, Vec Vd) {
+    const size_t datasize = Q ? 128 : 64;
+    const size_t esize = 16;
+
+    const IR::U128 operand1 = V(datasize, Vn);
+    const IR::U128 operand2 = V(datasize, Vm);
+    const IR::U128 operand3 = V(datasize, Vd);
+    const IR::U128 result = ir.FPVectorMulAdd(esize, operand3, ir.FPVectorNeg(esize, operand1), operand2);
+
     V(datasize, Vd, result);
     return true;
 }
@@ -699,6 +727,7 @@ bool TranslatorVisitor::FMLS_vec_2(bool Q, bool sz, Vec Vm, Vec Vn, Vec Vd) {
     if (sz && !Q) {
         return ReservedValue();
     }
+
     const size_t esize = sz ? 64 : 32;
     const size_t datasize = Q ? 128 : 64;
 
@@ -706,6 +735,7 @@ bool TranslatorVisitor::FMLS_vec_2(bool Q, bool sz, Vec Vm, Vec Vn, Vec Vd) {
     const IR::U128 operand2 = V(datasize, Vm);
     const IR::U128 operand3 = V(datasize, Vd);
     const IR::U128 result = ir.FPVectorMulAdd(esize, operand3, ir.FPVectorNeg(esize, operand1), operand2);
+
     V(datasize, Vd, result);
     return true;
 }
diff --git a/src/frontend/A64/translate/impl/simd_vector_x_indexed_element.cpp b/src/frontend/A64/translate/impl/simd_vector_x_indexed_element.cpp
index 98bee23c..a05ad7d5 100644
--- a/src/frontend/A64/translate/impl/simd_vector_x_indexed_element.cpp
+++ b/src/frontend/A64/translate/impl/simd_vector_x_indexed_element.cpp
@@ -89,6 +89,39 @@ bool FPMultiplyByElement(TranslatorVisitor& v, bool Q, bool sz, Imm<1> L, Imm<1>
     return true;
 }
 
+bool FPMultiplyByElementHalfPrecision(TranslatorVisitor& v, bool Q, Imm<1> L, Imm<1> M, Imm<4> Vmlo, Imm<1> H,
+                                      Vec Vn, Vec Vd, ExtraBehavior extra_behavior) {
+    const size_t idxdsize = H == 1 ? 128 : 64;
+    const size_t index = concatenate(H, L, M).ZeroExtend();
+    const Vec Vm = Vmlo.ZeroExtend<Vec>();
+    const size_t esize = 16;
+    const size_t datasize = Q ? 128 : 64;
+
+    const IR::UAny element2 = v.ir.VectorGetElement(esize, v.V(idxdsize, Vm), index);
+    const IR::U128 operand1 = v.V(datasize, Vn);
+    const IR::U128 operand2 = Q ? v.ir.VectorBroadcast(esize, element2) : v.ir.VectorBroadcastLower(esize, element2);
+    const IR::U128 operand3 = v.V(datasize, Vd);
+
+    // TODO: We currently don't implement half-precision paths for
+    // regular multiplies and extended multiplies.
+    const IR::U128 result = [&]{
+        switch (extra_behavior) {
+        case ExtraBehavior::None:
+            break;
+        case ExtraBehavior::Extended:
+            break;
+        case ExtraBehavior::Accumulate:
+            return v.ir.FPVectorMulAdd(esize, operand3, operand1, operand2);
+        case ExtraBehavior::Subtract:
+            return v.ir.FPVectorMulAdd(esize, operand3, v.ir.FPVectorNeg(esize, operand1), operand2);
+        }
+        UNREACHABLE();
+        return IR::U128{};
+    }();
+    v.V(datasize, Vd, result);
+    return true;
+}
+
 using ExtensionFunction = IR::U32 (IREmitter::*)(const IR::UAny&);
 
 bool DotProduct(TranslatorVisitor& v, bool Q, Imm<2> size, Imm<1> L, Imm<1> M, Imm<4> Vmlo, Imm<1> H,
@@ -276,10 +309,18 @@ bool TranslatorVisitor::FCMLA_elt(bool Q, Imm<2> size, Imm<1> L, Imm<1> M, Imm<4
     return true;
 }
 
+bool TranslatorVisitor::FMLA_elt_3(bool Q, Imm<1> L, Imm<1> M, Imm<4> Vmlo, Imm<1> H, Vec Vn, Vec Vd) {
+    return FPMultiplyByElementHalfPrecision(*this, Q, L, M, Vmlo, H, Vn, Vd, ExtraBehavior::Accumulate);
+}
+
 bool TranslatorVisitor::FMLA_elt_4(bool Q, bool sz, Imm<1> L, Imm<1> M, Imm<4> Vmlo, Imm<1> H, Vec Vn, Vec Vd) {
     return FPMultiplyByElement(*this, Q, sz, L, M, Vmlo, H, Vn, Vd, ExtraBehavior::Accumulate);
 }
 
+bool TranslatorVisitor::FMLS_elt_3(bool Q, Imm<1> L, Imm<1> M, Imm<4> Vmlo, Imm<1> H, Vec Vn, Vec Vd) {
+    return FPMultiplyByElementHalfPrecision(*this, Q, L, M, Vmlo, H, Vn, Vd, ExtraBehavior::Subtract);
+}
+
 bool TranslatorVisitor::FMLS_elt_4(bool Q, bool sz, Imm<1> L, Imm<1> M, Imm<4> Vmlo, Imm<1> H, Vec Vn, Vec Vd) {
     return FPMultiplyByElement(*this, Q, sz, L, M, Vmlo, H, Vn, Vd, ExtraBehavior::Subtract);
 }
diff --git a/src/frontend/ir/ir_emitter.cpp b/src/frontend/ir/ir_emitter.cpp
index e774f25d..53df95b0 100644
--- a/src/frontend/ir/ir_emitter.cpp
+++ b/src/frontend/ir/ir_emitter.cpp
@@ -1882,13 +1882,20 @@ U32U64 IREmitter::FPMul(const U32U64& a, const U32U64& b, bool fpcr_controlled)
     }
 }
 
-U32U64 IREmitter::FPMulAdd(const U32U64& a, const U32U64& b, const U32U64& c, bool fpcr_controlled) {
+U16U32U64 IREmitter::FPMulAdd(const U16U32U64& a, const U16U32U64& b, const U16U32U64& c, bool fpcr_controlled) {
     ASSERT(fpcr_controlled);
     ASSERT(a.GetType() == b.GetType());
-    if (a.GetType() == Type::U32) {
+
+    switch (a.GetType()) {
+    case Type::U16:
+        return Inst<U16>(Opcode::FPMulAdd16, a, b, c);
+    case Type::U32:
         return Inst<U32>(Opcode::FPMulAdd32, a, b, c);
-    } else {
+    case Type::U64:
         return Inst<U64>(Opcode::FPMulAdd64, a, b, c);
+    default:
+        UNREACHABLE();
+        return U16U32U64{};
     }
 }
 
@@ -2181,6 +2188,8 @@ U128 IREmitter::FPVectorMul(size_t esize, const U128& a, const U128& b) {
 
 U128 IREmitter::FPVectorMulAdd(size_t esize, const U128& a, const U128& b, const U128& c) {
     switch (esize) {
+    case 16:
+        return Inst<U128>(Opcode::FPVectorMulAdd16, a, b, c);
     case 32:
         return Inst<U128>(Opcode::FPVectorMulAdd32, a, b, c);
     case 64:
diff --git a/src/frontend/ir/ir_emitter.h b/src/frontend/ir/ir_emitter.h
index a77089fe..32fad5b7 100644
--- a/src/frontend/ir/ir_emitter.h
+++ b/src/frontend/ir/ir_emitter.h
@@ -302,7 +302,7 @@ public:
     U32U64 FPMin(const U32U64& a, const U32U64& b, bool fpcr_controlled);
     U32U64 FPMinNumeric(const U32U64& a, const U32U64& b, bool fpcr_controlled);
     U32U64 FPMul(const U32U64& a, const U32U64& b, bool fpcr_controlled);
-    U32U64 FPMulAdd(const U32U64& addend, const U32U64& op1, const U32U64& op2, bool fpcr_controlled);
+    U16U32U64 FPMulAdd(const U16U32U64& addend, const U16U32U64& op1, const U16U32U64& op2, bool fpcr_controlled);
     U32U64 FPMulX(const U32U64& a, const U32U64& b);
     U16U32U64 FPNeg(const U16U32U64& a);
     U32U64 FPRecipEstimate(const U32U64& a);
diff --git a/src/frontend/ir/microinstruction.cpp b/src/frontend/ir/microinstruction.cpp
index b34b989d..c087b514 100644
--- a/src/frontend/ir/microinstruction.cpp
+++ b/src/frontend/ir/microinstruction.cpp
@@ -269,6 +269,7 @@ bool Inst::ReadsFromAndWritesToFPSRCumulativeExceptionBits() const {
     case Opcode::FPMinNumeric64:
     case Opcode::FPMul32:
    case Opcode::FPMul64:
+    case Opcode::FPMulAdd16:
     case Opcode::FPMulAdd32:
     case Opcode::FPMulAdd64:
     case Opcode::FPRecipEstimate32:
@@ -326,6 +327,7 @@ bool Inst::ReadsFromAndWritesToFPSRCumulativeExceptionBits() const {
     case Opcode::FPVectorGreaterEqual64:
     case Opcode::FPVectorMul32:
     case Opcode::FPVectorMul64:
+    case Opcode::FPVectorMulAdd16:
     case Opcode::FPVectorMulAdd32:
     case Opcode::FPVectorMulAdd64:
     case Opcode::FPVectorPairedAddLower32:
diff --git a/src/frontend/ir/opcodes.inc b/src/frontend/ir/opcodes.inc
index 7723b03e..db3128e8 100644
--- a/src/frontend/ir/opcodes.inc
+++ b/src/frontend/ir/opcodes.inc
@@ -483,6 +483,7 @@ OPCODE(FPMinNumeric32,                                      U32,            U32,
 OPCODE(FPMinNumeric64,                                      U64,            U64,            U64                                          )
 OPCODE(FPMul32,                                             U32,            U32,            U32                                          )
 OPCODE(FPMul64,                                             U64,            U64,            U64                                          )
+OPCODE(FPMulAdd16,                                          U16,            U16,            U16,            U16                          )
 OPCODE(FPMulAdd32,                                          U32,            U32,            U32,            U32                          )
 OPCODE(FPMulAdd64,                                          U64,            U64,            U64,            U64                          )
 OPCODE(FPMulX32,                                            U32,            U32,            U32                                          )
@@ -556,6 +557,7 @@ OPCODE(FPVectorMin32,                                       U128,           U128
 OPCODE(FPVectorMin64,                                       U128,           U128,           U128                                         )
 OPCODE(FPVectorMul32,                                       U128,           U128,           U128                                         )
 OPCODE(FPVectorMul64,                                       U128,           U128,           U128                                         )
+OPCODE(FPVectorMulAdd16,                                    U128,           U128,           U128,           U128                         )
 OPCODE(FPVectorMulAdd32,                                    U128,           U128,           U128,           U128                         )
 OPCODE(FPVectorMulAdd64,                                    U128,           U128,           U128,           U128                         )
 OPCODE(FPVectorMulX32,                                      U128,           U128,           U128                                         )