// Patch summary: introduces a vector floating-point "round to integral" IR operation
// (FPVectorRoundInt32 / FPVectorRoundInt64) across four files.
//
// src/backend_x64/emit_x64_vector_floating_point.cpp
//   * Adds #include "common/mp/vlift.h".
//   * Adds the function template EmitFPVectorRoundInt(code, ctx, inst):
//       - reads the rounding mode from instruction argument 1 (GetU8) and the
//         `exact` flag from argument 2 (GetU1);
//       - builds a function-local `static const` lookup table with
//         mp::GenerateLookupTableFromList over an mp::cartesian_product; each entry
//         pairs an mp::to_tuple key with a lambda that walks `output` element by
//         element and stores FP::FPRoundInt(input[i], fpcr, rounding_mode, exact, fpsr);
//       - hands the entry selected by lut.at(std::make_tuple(rounding, exact)) to
//         EmitTwoOpFallback.
//   * Adds EmitX64::EmitFPVectorRoundInt32 / EmitFPVectorRoundInt64, which forward to
//     EmitFPVectorRoundInt<32> / EmitFPVectorRoundInt<64>.
//
// src/frontend/ir/ir_emitter.cpp, src/frontend/ir/ir_emitter.h
//   * Adds IREmitter::FPVectorRoundInt(esize, operand, rounding, exact): switches on
//     esize, emitting Opcode::FPVectorRoundInt32 for 32 and Opcode::FPVectorRoundInt64
//     for 64 with operands (operand, Imm8(rounding), Imm1(exact)); any other esize
//     reaches UNREACHABLE() followed by `return {}`.
//
// src/frontend/ir/opcodes.inc
//   * Registers FPVectorRoundInt32 and FPVectorRoundInt64, each listed with
//     T::U128, T::U128, T::U8, T::U1.
//
// NOTE(review): this copy of the patch looks damaged -- each hunk is collapsed onto a
// single physical line, and template argument lists appear to be missing (e.g.
// `template` directly followed by `void`, `static_cast(` with no target type, bare
// `std::integral_constant,` entries, `mp::cartesian_product{}`), and the first
// opcodes.inc context entry has no closing parenthesis. Presumably an extraction
// artifact; confirm against the original commit before applying.
// NOTE(review): the lookup goes through lut.at(...); what happens for a
// (rounding, exact) pair absent from the table depends on the container returned by
// mp::GenerateLookupTableFromList, which is not visible here -- TODO confirm.
diff --git a/src/backend_x64/emit_x64_vector_floating_point.cpp b/src/backend_x64/emit_x64_vector_floating_point.cpp index 62629557..7a1c1d03 100644 --- a/src/backend_x64/emit_x64_vector_floating_point.cpp +++ b/src/backend_x64/emit_x64_vector_floating_point.cpp @@ -24,6 +24,7 @@ #include "common/mp/list.h" #include "common/mp/lut.h" #include "common/mp/to_tuple.h" +#include "common/mp/vlift.h" #include "common/mp/vllift.h" #include "frontend/ir/basic_block.h" #include "frontend/ir/microinstruction.h" @@ -728,6 +729,55 @@ void EmitX64::EmitFPVectorRecipStepFused64(EmitContext& ctx, IR::Inst* inst) { EmitRecipStepFused(code, ctx, inst); } +template +void EmitFPVectorRoundInt(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst) { + using FPT = mp::unsigned_integer_of_size; + + const auto rounding = static_cast(inst->GetArg(1).GetU8()); + const bool exact = inst->GetArg(2).GetU1(); + + using rounding_list = mp::list< + std::integral_constant, + std::integral_constant, + std::integral_constant, + std::integral_constant, + std::integral_constant + >; + using exact_list = mp::list, mp::vlift>; + + using key_type = std::tuple; + using value_type = void(*)(VectorArray&, const VectorArray&, FP::FPCR, FP::FPSR&); + + static const auto lut = mp::GenerateLookupTableFromList( + [](auto arg) { + return std::pair{ + mp::to_tuple, + static_cast( + [](VectorArray& output, const VectorArray& input, FP::FPCR fpcr, FP::FPSR& fpsr) { + constexpr FP::RoundingMode rounding_mode = std::get<0>(mp::to_tuple); + constexpr bool exact = std::get<1>(mp::to_tuple); + + for (size_t i = 0; i < output.size(); ++i) { + output[i] = static_cast(FP::FPRoundInt(input[i], fpcr, rounding_mode, exact, fpsr)); + } + } + ) + }; + }, + mp::cartesian_product{} + ); + + EmitTwoOpFallback(code, ctx, inst, lut.at(std::make_tuple(rounding, exact))); +} + +void EmitX64::EmitFPVectorRoundInt32(EmitContext& ctx, IR::Inst* inst) { + EmitFPVectorRoundInt<32>(code, ctx, inst); +} + +void 
EmitX64::EmitFPVectorRoundInt64(EmitContext& ctx, IR::Inst* inst) { + EmitFPVectorRoundInt<64>(code, ctx, inst); +} + template static void EmitRSqrtEstimate(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst) { EmitTwoOpFallback(code, ctx, inst, [](VectorArray& result, const VectorArray& operand, FP::FPCR fpcr, FP::FPSR& fpsr) { diff --git a/src/frontend/ir/ir_emitter.cpp b/src/frontend/ir/ir_emitter.cpp index 98aad4c3..399af66d 100644 --- a/src/frontend/ir/ir_emitter.cpp +++ b/src/frontend/ir/ir_emitter.cpp @@ -1932,6 +1932,17 @@ U128 IREmitter::FPVectorRecipStepFused(size_t esize, const U128& a, const U128& return {}; } +U128 IREmitter::FPVectorRoundInt(size_t esize, const U128& operand, FP::RoundingMode rounding, bool exact) { + switch (esize) { + case 32: + return Inst(Opcode::FPVectorRoundInt32, operand, Imm8(static_cast(rounding)), Imm1(exact)); + case 64: + return Inst(Opcode::FPVectorRoundInt64, operand, Imm8(static_cast(rounding)), Imm1(exact)); + } + UNREACHABLE(); + return {}; +} + U128 IREmitter::FPVectorRSqrtEstimate(size_t esize, const U128& a) { switch (esize) { case 32: diff --git a/src/frontend/ir/ir_emitter.h b/src/frontend/ir/ir_emitter.h index 47c005ee..fa7a907e 100644 --- a/src/frontend/ir/ir_emitter.h +++ b/src/frontend/ir/ir_emitter.h @@ -320,6 +320,7 @@ public: U128 FPVectorPairedAddLower(size_t esize, const U128& a, const U128& b); U128 FPVectorRecipEstimate(size_t esize, const U128& a); U128 FPVectorRecipStepFused(size_t esize, const U128& a, const U128& b); + U128 FPVectorRoundInt(size_t esize, const U128& operand, FP::RoundingMode rounding, bool exact); U128 FPVectorRSqrtEstimate(size_t esize, const U128& a); U128 FPVectorRSqrtStepFused(size_t esize, const U128& a, const U128& b); U128 FPVectorS32ToSingle(const U128& a); diff --git a/src/frontend/ir/opcodes.inc b/src/frontend/ir/opcodes.inc index 5890433e..105da9dc 100644 --- a/src/frontend/ir/opcodes.inc +++ b/src/frontend/ir/opcodes.inc @@ -491,6 +491,8 @@ 
OPCODE(FPVectorRecipEstimate32, T::U128, T::U128 OPCODE(FPVectorRecipEstimate64, T::U128, T::U128 ) OPCODE(FPVectorRecipStepFused32, T::U128, T::U128, T::U128 ) OPCODE(FPVectorRecipStepFused64, T::U128, T::U128, T::U128 ) +OPCODE(FPVectorRoundInt32, T::U128, T::U128, T::U8, T::U1 ) +OPCODE(FPVectorRoundInt64, T::U128, T::U128, T::U8, T::U1 ) OPCODE(FPVectorRSqrtEstimate32, T::U128, T::U128 ) OPCODE(FPVectorRSqrtEstimate64, T::U128, T::U128 ) OPCODE(FPVectorRSqrtStepFused32, T::U128, T::U128, T::U128 )