Merge pull request #476 from lioncash/frint

A64: Handle half-precision variants of FRINT* instructions
This commit is contained in:
Merry 2019-04-14 10:57:09 +01:00 committed by MerryMage
commit c6e6ec0e69
10 changed files with 115 additions and 42 deletions

View file

@ -843,7 +843,7 @@ static void EmitFPRound(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst, siz
const bool exact = inst->GetArg(2).GetU1();
const auto round_imm = ConvertRoundingModeToX64Immediate(rounding_mode);
if (code.DoesCpuSupport(Xbyak::util::Cpu::tSSE41) && round_imm && !exact) {
if (fsize != 16 && code.DoesCpuSupport(Xbyak::util::Cpu::tSSE41) && round_imm && !exact) {
if (fsize == 64) {
FPTwoOp<64>(code, ctx, inst, [&](Xbyak::Xmm result) {
code.roundsd(result, result, *round_imm);
@ -857,7 +857,9 @@ static void EmitFPRound(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst, siz
return;
}
using fsize_list = mp::list<mp::vlift<size_t(32)>, mp::vlift<size_t(64)>>;
using fsize_list = mp::list<mp::vlift<size_t(16)>,
mp::vlift<size_t(32)>,
mp::vlift<size_t(64)>>;
using rounding_list = mp::list<
std::integral_constant<FP::RoundingMode, FP::RoundingMode::ToNearest_TieEven>,
std::integral_constant<FP::RoundingMode, FP::RoundingMode::TowardsPlusInfinity>,
@ -897,6 +899,10 @@ static void EmitFPRound(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst, siz
code.CallFunction(lut.at(std::make_tuple(fsize, rounding_mode, exact)));
}
void EmitX64::EmitFPRoundInt16(EmitContext& ctx, IR::Inst* inst) {
EmitFPRound(code, ctx, inst, 16);
}
void EmitX64::EmitFPRoundInt32(EmitContext& ctx, IR::Inst* inst) {
EmitFPRound(code, ctx, inst, 32);
}

View file

@ -1160,6 +1160,7 @@ void EmitFPVectorRoundInt(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst) {
const auto rounding = static_cast<FP::RoundingMode>(inst->GetArg(1).GetU8());
const bool exact = inst->GetArg(2).GetU1();
if constexpr (fsize != 16) {
if (code.DoesCpuSupport(Xbyak::util::Cpu::tSSE41) && rounding != FP::RoundingMode::ToNearest_TieAwayFromZero && !exact) {
const u8 round_imm = [&]() -> u8 {
switch (rounding) {
@ -1183,6 +1184,7 @@ void EmitFPVectorRoundInt(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst) {
return;
}
}
using rounding_list = mp::list<
std::integral_constant<FP::RoundingMode, FP::RoundingMode::ToNearest_TieEven>,
@ -1218,6 +1220,10 @@ void EmitFPVectorRoundInt(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst) {
EmitTwoOpFallback(code, ctx, inst, lut.at(std::make_tuple(rounding, exact)));
}
void EmitX64::EmitFPVectorRoundInt16(EmitContext& ctx, IR::Inst* inst) {
EmitFPVectorRoundInt<16>(code, ctx, inst);
}
void EmitX64::EmitFPVectorRoundInt32(EmitContext& ctx, IR::Inst* inst) {
EmitFPVectorRoundInt<32>(code, ctx, inst);
}

View file

@ -31,11 +31,11 @@ u64 FPRoundInt(FPT op, FPCR fpcr, RoundingMode rounding, bool exact, FPSR& fpsr)
}
if (type == FPType::Infinity) {
return FPInfo<FPT>::Infinity(sign);
return FPT(FPInfo<FPT>::Infinity(sign));
}
if (type == FPType::Zero) {
return FPInfo<FPT>::Zero(sign);
return FPT(FPInfo<FPT>::Zero(sign));
}
// Reshift decimal point back to bit zero.
@ -79,7 +79,7 @@ u64 FPRoundInt(FPT op, FPCR fpcr, RoundingMode rounding, bool exact, FPSR& fpsr)
const u64 abs_int_result = new_sign ? Safe::Negate<u64>(int_result) : static_cast<u64>(int_result);
const FPT result = int_result == 0
? FPInfo<FPT>::Zero(sign)
? FPT(FPInfo<FPT>::Zero(sign))
: FPRound<FPT>(FPUnpacked{new_sign, normalized_point_position, abs_int_result}, fpcr, RoundingMode::TowardsZero, fpsr);
if (error != ResidualError::Zero && exact) {
@ -89,6 +89,7 @@ u64 FPRoundInt(FPT op, FPCR fpcr, RoundingMode rounding, bool exact, FPSR& fpsr)
return result;
}
template u64 FPRoundInt<u16>(u16 op, FPCR fpcr, RoundingMode rounding, bool exact, FPSR& fpsr);
template u64 FPRoundInt<u32>(u32 op, FPCR fpcr, RoundingMode rounding, bool exact, FPSR& fpsr);
template u64 FPRoundInt<u64>(u64 op, FPCR fpcr, RoundingMode rounding, bool exact, FPSR& fpsr);

View file

@ -621,9 +621,9 @@ INST(XTN, "XTN, XTN2", "0Q001
INST(SQXTN_2, "SQXTN, SQXTN2", "0Q001110zz100001010010nnnnnddddd")
INST(FCVTN, "FCVTN, FCVTN2", "0Q0011100z100001011010nnnnnddddd")
INST(FCVTL, "FCVTL, FCVTL2", "0Q0011100z100001011110nnnnnddddd")
//INST(FRINTN_1, "FRINTN (vector)", "0Q00111001111001100010nnnnnddddd")
INST(FRINTN_1, "FRINTN (vector)", "0Q00111001111001100010nnnnnddddd")
INST(FRINTN_2, "FRINTN (vector)", "0Q0011100z100001100010nnnnnddddd")
//INST(FRINTM_1, "FRINTM (vector)", "0Q00111001111001100110nnnnnddddd")
INST(FRINTM_1, "FRINTM (vector)", "0Q00111001111001100110nnnnnddddd")
INST(FRINTM_2, "FRINTM (vector)", "0Q0011100z100001100110nnnnnddddd")
//INST(FCVTNS_3, "FCVTNS (vector)", "0Q00111001111001101010nnnnnddddd")
INST(FCVTNS_4, "FCVTNS (vector)", "0Q0011100z100001101010nnnnnddddd")
@ -641,9 +641,9 @@ INST(FCMEQ_zero_4, "FCMEQ (zero)", "0Q001
INST(FCMLT_4, "FCMLT (zero)", "0Q0011101z100000111010nnnnnddddd")
INST(FABS_1, "FABS (vector)", "0Q00111011111000111110nnnnnddddd")
INST(FABS_2, "FABS (vector)", "0Q0011101z100000111110nnnnnddddd")
//INST(FRINTP_1, "FRINTP (vector)", "0Q00111011111001100010nnnnnddddd")
INST(FRINTP_1, "FRINTP (vector)", "0Q00111011111001100010nnnnnddddd")
INST(FRINTP_2, "FRINTP (vector)", "0Q0011101z100001100010nnnnnddddd")
//INST(FRINTZ_1, "FRINTZ (vector)", "0Q00111011111001100110nnnnnddddd")
INST(FRINTZ_1, "FRINTZ (vector)", "0Q00111011111001100110nnnnnddddd")
INST(FRINTZ_2, "FRINTZ (vector)", "0Q0011101z100001100110nnnnnddddd")
//INST(FCVTPS_3, "FCVTPS (vector)", "0Q00111011111001101010nnnnnddddd")
INST(FCVTPS_4, "FCVTPS (vector)", "0Q0011101z100001101010nnnnnddddd")
@ -665,9 +665,9 @@ INST(SQXTUN_2, "SQXTUN, SQXTUN2", "0Q101
INST(SHLL, "SHLL, SHLL2", "0Q101110zz100001001110nnnnnddddd")
INST(UQXTN_2, "UQXTN, UQXTN2", "0Q101110zz100001010010nnnnnddddd")
INST(FCVTXN_2, "FCVTXN, FCVTXN2", "0Q1011100z100001011010nnnnnddddd")
//INST(FRINTA_1, "FRINTA (vector)", "0Q10111001111001100010nnnnnddddd")
INST(FRINTA_1, "FRINTA (vector)", "0Q10111001111001100010nnnnnddddd")
INST(FRINTA_2, "FRINTA (vector)", "0Q1011100z100001100010nnnnnddddd")
//INST(FRINTX_1, "FRINTX (vector)", "0Q10111001111001100110nnnnnddddd")
INST(FRINTX_1, "FRINTX (vector)", "0Q10111001111001100110nnnnnddddd")
INST(FRINTX_2, "FRINTX (vector)", "0Q1011100z100001100110nnnnnddddd")
//INST(FCVTNU_3, "FCVTNU (vector)", "0Q10111001111001101010nnnnnddddd")
INST(FCVTNU_4, "FCVTNU (vector)", "0Q1011100z100001101010nnnnnddddd")
@ -681,7 +681,7 @@ INST(NOT, "NOT", "0Q101
INST(RBIT_asimd, "RBIT (vector)", "0Q10111001100000010110nnnnnddddd")
INST(FNEG_1, "FNEG (vector)", "0Q10111011111000111110nnnnnddddd")
INST(FNEG_2, "FNEG (vector)", "0Q1011101z100000111110nnnnnddddd")
//INST(FRINTI_1, "FRINTI (vector)", "0Q10111011111001100110nnnnnddddd")
INST(FRINTI_1, "FRINTI (vector)", "0Q10111011111001100110nnnnnddddd")
INST(FRINTI_2, "FRINTI (vector)", "0Q1011101z100001100110nnnnnddddd")
//INST(FCMGE_zero_3, "FCMGE (zero)", "0Q10111011111000110010nnnnnddddd")
INST(FCMGE_zero_4, "FCMGE (zero)", "0Q1011101z100000110010nnnnnddddd")

View file

@ -149,12 +149,12 @@ bool TranslatorVisitor::FCVT_float(Imm<2> type, Imm<2> opc, Vec Vn, Vec Vd) {
static bool FloatingPointRoundToIntegral(TranslatorVisitor& v, Imm<2> type, Vec Vn, Vec Vd,
FP::RoundingMode rounding_mode, bool exact) {
const auto datasize = FPGetDataSize(type);
if (!datasize || *datasize == 16) {
if (!datasize) {
return v.UnallocatedEncoding();
}
const IR::U32U64 operand = v.V_scalar(*datasize, Vn);
const IR::U32U64 result = v.ir.FPRoundInt(operand, rounding_mode, exact);
const IR::U16U32U64 operand = v.V_scalar(*datasize, Vn);
const IR::U16U32U64 result = v.ir.FPRoundInt(operand, rounding_mode, exact);
v.V_scalar(*datasize, Vd, result);
return true;
}

View file

@ -138,6 +138,17 @@ bool FloatRoundToIntegral(TranslatorVisitor& v, bool Q, bool sz, Vec Vn, Vec Vd,
return true;
}
bool FloatRoundToIntegralHalfPrecision(TranslatorVisitor& v, bool Q, Vec Vn, Vec Vd, FP::RoundingMode rounding_mode, bool exact) {
const size_t datasize = Q ? 128 : 64;
const size_t esize = 16;
const IR::U128 operand = v.V(datasize, Vn);
const IR::U128 result = v.ir.FPVectorRoundInt(esize, operand, rounding_mode, exact);
v.V(datasize, Vd, result);
return true;
}
bool SaturatedNarrow(TranslatorVisitor& v, bool Q, Imm<2> size, Vec Vn, Vec Vd, IR::U128 (IR::IREmitter::*fn)(size_t, const IR::U128&)) {
if (size == 0b11) {
return v.ReservedValue();
@ -451,30 +462,58 @@ bool TranslatorVisitor::FCVTZU_int_4(bool Q, bool sz, Vec Vn, Vec Vd) {
return FloatConvertToInteger(*this, Q, sz, Vn, Vd, Signedness::Unsigned, FP::RoundingMode::TowardsZero);
}
bool TranslatorVisitor::FRINTN_1(bool Q, Vec Vn, Vec Vd) {
return FloatRoundToIntegralHalfPrecision(*this, Q, Vn, Vd, FP::RoundingMode::ToNearest_TieEven, false);
}
bool TranslatorVisitor::FRINTN_2(bool Q, bool sz, Vec Vn, Vec Vd) {
return FloatRoundToIntegral(*this, Q, sz, Vn, Vd, FP::RoundingMode::ToNearest_TieEven, false);
}
bool TranslatorVisitor::FRINTM_1(bool Q, Vec Vn, Vec Vd) {
return FloatRoundToIntegralHalfPrecision(*this, Q, Vn, Vd, FP::RoundingMode::TowardsMinusInfinity, false);
}
bool TranslatorVisitor::FRINTM_2(bool Q, bool sz, Vec Vn, Vec Vd) {
return FloatRoundToIntegral(*this, Q, sz, Vn, Vd, FP::RoundingMode::TowardsMinusInfinity, false);
}
bool TranslatorVisitor::FRINTP_1(bool Q, Vec Vn, Vec Vd) {
return FloatRoundToIntegralHalfPrecision(*this, Q, Vn, Vd, FP::RoundingMode::TowardsPlusInfinity, false);
}
bool TranslatorVisitor::FRINTP_2(bool Q, bool sz, Vec Vn, Vec Vd) {
return FloatRoundToIntegral(*this, Q, sz, Vn, Vd, FP::RoundingMode::TowardsPlusInfinity, false);
}
bool TranslatorVisitor::FRINTZ_1(bool Q, Vec Vn, Vec Vd) {
return FloatRoundToIntegralHalfPrecision(*this, Q, Vn, Vd, FP::RoundingMode::TowardsZero, false);
}
bool TranslatorVisitor::FRINTZ_2(bool Q, bool sz, Vec Vn, Vec Vd) {
return FloatRoundToIntegral(*this, Q, sz, Vn, Vd, FP::RoundingMode::TowardsZero, false);
}
bool TranslatorVisitor::FRINTA_1(bool Q, Vec Vn, Vec Vd) {
return FloatRoundToIntegralHalfPrecision(*this, Q, Vn, Vd, FP::RoundingMode::ToNearest_TieAwayFromZero, false);
}
bool TranslatorVisitor::FRINTA_2(bool Q, bool sz, Vec Vn, Vec Vd) {
return FloatRoundToIntegral(*this, Q, sz, Vn, Vd, FP::RoundingMode::ToNearest_TieAwayFromZero, false);
}
bool TranslatorVisitor::FRINTX_1(bool Q, Vec Vn, Vec Vd) {
return FloatRoundToIntegralHalfPrecision(*this, Q, Vn, Vd, ir.current_location->FPCR().RMode(), true);
}
bool TranslatorVisitor::FRINTX_2(bool Q, bool sz, Vec Vn, Vec Vd) {
return FloatRoundToIntegral(*this, Q, sz, Vn, Vd, ir.current_location->FPCR().RMode(), true);
}
bool TranslatorVisitor::FRINTI_1(bool Q, Vec Vn, Vec Vd) {
return FloatRoundToIntegralHalfPrecision(*this, Q, Vn, Vd, ir.current_location->FPCR().RMode(), false);
}
bool TranslatorVisitor::FRINTI_2(bool Q, bool sz, Vec Vn, Vec Vd) {
return FloatRoundToIntegral(*this, Q, sz, Vn, Vd,ir.current_location->FPCR().RMode(), false);
}

View file

@ -1950,11 +1950,21 @@ U32U64 IREmitter::FPRecipStepFused(const U32U64& a, const U32U64& b) {
return Inst<U64>(Opcode::FPRecipStepFused64, a, b);
}
U32U64 IREmitter::FPRoundInt(const U32U64& a, FP::RoundingMode rounding, bool exact) {
if (a.GetType() == Type::U32) {
return Inst<U32>(Opcode::FPRoundInt32, a, static_cast<u8>(rounding), Imm1(exact));
U16U32U64 IREmitter::FPRoundInt(const U16U32U64& a, FP::RoundingMode rounding, bool exact) {
const u8 rounding_value = static_cast<u8>(rounding);
const IR::U1 exact_imm = Imm1(exact);
switch (a.GetType()) {
case Type::U16:
return Inst<U16>(Opcode::FPRoundInt16, a, rounding_value, exact_imm);
case Type::U32:
return Inst<U32>(Opcode::FPRoundInt32, a, rounding_value, exact_imm);
case Type::U64:
return Inst<U64>(Opcode::FPRoundInt64, a, rounding_value, exact_imm);
default:
UNREACHABLE();
return U16U32U64{};
}
return Inst<U64>(Opcode::FPRoundInt64, a, static_cast<u8>(rounding), Imm1(exact));
}
U32U64 IREmitter::FPRSqrtEstimate(const U32U64& a) {
@ -2268,11 +2278,16 @@ U128 IREmitter::FPVectorRecipStepFused(size_t esize, const U128& a, const U128&
}
U128 IREmitter::FPVectorRoundInt(size_t esize, const U128& operand, FP::RoundingMode rounding, bool exact) {
const IR::U8 rounding_imm = Imm8(static_cast<u8>(rounding));
const IR::U1 exact_imm = Imm1(exact);
switch (esize) {
case 16:
return Inst<U128>(Opcode::FPVectorRoundInt16, operand, rounding_imm, exact_imm);
case 32:
return Inst<U128>(Opcode::FPVectorRoundInt32, operand, Imm8(static_cast<u8>(rounding)), Imm1(exact));
return Inst<U128>(Opcode::FPVectorRoundInt32, operand, rounding_imm, exact_imm);
case 64:
return Inst<U128>(Opcode::FPVectorRoundInt64, operand, Imm8(static_cast<u8>(rounding)), Imm1(exact));
return Inst<U128>(Opcode::FPVectorRoundInt64, operand, rounding_imm, exact_imm);
}
UNREACHABLE();
return {};

View file

@ -308,7 +308,7 @@ public:
U32U64 FPRecipEstimate(const U32U64& a);
U16U32U64 FPRecipExponent(const U16U32U64& a);
U32U64 FPRecipStepFused(const U32U64& a, const U32U64& b);
U32U64 FPRoundInt(const U32U64& a, FP::RoundingMode rounding, bool exact);
U16U32U64 FPRoundInt(const U16U32U64& a, FP::RoundingMode rounding, bool exact);
U32U64 FPRSqrtEstimate(const U32U64& a);
U32U64 FPRSqrtStepFused(const U32U64& a, const U32U64& b);
U32U64 FPSqrt(const U32U64& a);

View file

@ -279,6 +279,7 @@ bool Inst::ReadsFromAndWritesToFPSRCumulativeExceptionBits() const {
case Opcode::FPRecipExponent64:
case Opcode::FPRecipStepFused32:
case Opcode::FPRecipStepFused64:
case Opcode::FPRoundInt16:
case Opcode::FPRoundInt32:
case Opcode::FPRoundInt64:
case Opcode::FPRSqrtEstimate32:
@ -338,6 +339,9 @@ bool Inst::ReadsFromAndWritesToFPSRCumulativeExceptionBits() const {
case Opcode::FPVectorRecipEstimate64:
case Opcode::FPVectorRecipStepFused32:
case Opcode::FPVectorRecipStepFused64:
case Opcode::FPVectorRoundInt16:
case Opcode::FPVectorRoundInt32:
case Opcode::FPVectorRoundInt64:
case Opcode::FPVectorRSqrtEstimate32:
case Opcode::FPVectorRSqrtEstimate64:
case Opcode::FPVectorRSqrtStepFused32:

View file

@ -498,6 +498,7 @@ OPCODE(FPRecipExponent32, U32, U32
OPCODE(FPRecipExponent64, U64, U64 )
OPCODE(FPRecipStepFused32, U32, U32, U32 )
OPCODE(FPRecipStepFused64, U64, U64, U64 )
OPCODE(FPRoundInt16, U16, U16, U8, U1 )
OPCODE(FPRoundInt32, U32, U32, U8, U1 )
OPCODE(FPRoundInt64, U64, U64, U8, U1 )
OPCODE(FPRSqrtEstimate32, U32, U32 )
@ -573,6 +574,7 @@ OPCODE(FPVectorRecipEstimate32, U128, U128
OPCODE(FPVectorRecipEstimate64, U128, U128 )
OPCODE(FPVectorRecipStepFused32, U128, U128, U128 )
OPCODE(FPVectorRecipStepFused64, U128, U128, U128 )
OPCODE(FPVectorRoundInt16, U128, U128, U8, U1 )
OPCODE(FPVectorRoundInt32, U128, U128, U8, U1 )
OPCODE(FPVectorRoundInt64, U128, U128, U8, U1 )
OPCODE(FPVectorRSqrtEstimate32, U128, U128 )