diff --git a/src/backend/x64/emit_x64_saturation.cpp b/src/backend/x64/emit_x64_saturation.cpp
index 2f5a2879..91ff6ff4 100644
--- a/src/backend/x64/emit_x64_saturation.cpp
+++ b/src/backend/x64/emit_x64_saturation.cpp
@@ -134,6 +134,66 @@ void EmitX64::EmitSignedSaturatedAdd64(EmitContext& ctx, IR::Inst* inst) {
     EmitSignedSaturatedOp(code, ctx, inst);
 }
 
+void EmitX64::EmitSignedSaturatedDoublingMultiplyReturnHigh16(EmitContext& ctx, IR::Inst* inst) { // Emits x64 code producing ((x * y) >> 15) for sign-extended 16-bit operands, or 0x7FFF when doubling the product flips its sign bit.
+    auto overflow_inst = inst->GetAssociatedPseudoOperation(IR::Opcode::GetOverflowFromOp); // Null-checked below: the overflow value is only materialised when this is non-null.
+
+    auto args = ctx.reg_alloc.GetArgumentInfo(inst);
+
+    const Xbyak::Reg32 x = ctx.reg_alloc.UseScratchGpr(args[0]).cvt32(); // First operand, viewed as a 32-bit register; overwritten below.
+    const Xbyak::Reg32 y = ctx.reg_alloc.UseScratchGpr(args[1]).cvt32(); // Second operand, viewed as a 32-bit register; overwritten below and finally holds the result.
+    const Xbyak::Reg32 tmp = ctx.reg_alloc.ScratchGpr().cvt32(); // Temporary: holds the shifted product, then the overflow byte.
+
+    code.movsx(x, x.cvt16()); // Sign-extend the low 16 bits of each operand to 32 bits.
+    code.movsx(y, y.cvt16());
+
+    code.imul(x, y); // x = product of the two sign-extended operands.
+    code.lea(y, ptr[x.cvt64() + x.cvt64()]); // y = x + x (the doubled product); LEA leaves the flags untouched.
+    code.mov(tmp, x);
+    code.shr(tmp, 15); // tmp = product >> 15 (logical shift), i.e. the doubled product shifted right by 16.
+    code.xor_(y, x); // Only the flags matter here: SF = bit 31 of (2 * product) XOR product, set when doubling changed the sign bit. The value in y is overwritten next.
+    code.mov(y, 0x7FFF); // Preload the saturated result; MOV does not modify flags, so SF from the XOR survives.
+    code.cmovns(y, tmp); // SF clear (no sign change): replace the saturated value with the shifted product.
+
+    if (overflow_inst) {
+        code.sets(tmp.cvt8()); // Low byte of tmp = SF from the XOR above (1 when the result was saturated). Done after CMOVNS, which has already consumed tmp.
+
+        ctx.reg_alloc.DefineValue(overflow_inst, tmp); // Bind the overflow pseudo-operation to tmp, then drop it so it is not emitted separately.
+        ctx.EraseInstruction(overflow_inst);
+    }
+
+    ctx.reg_alloc.DefineValue(inst, y); // y carries the instruction's result.
+}
+
+void EmitX64::EmitSignedSaturatedDoublingMultiplyReturnHigh32(EmitContext& ctx, IR::Inst* inst) { // 32-bit counterpart: produces ((x * y) >> 31) from a 64-bit product, or 0x7FFFFFFF when doubling the product flips its sign bit.
+    auto overflow_inst = inst->GetAssociatedPseudoOperation(IR::Opcode::GetOverflowFromOp); // Null-checked below: the overflow value is only materialised when this is non-null.
+
+    auto args = ctx.reg_alloc.GetArgumentInfo(inst);
+
+    const Xbyak::Reg64 x = ctx.reg_alloc.UseScratchGpr(args[0]); // First operand; overwritten below.
+    const Xbyak::Reg64 y = ctx.reg_alloc.UseScratchGpr(args[1]); // Second operand; overwritten below and finally holds the result.
+    const Xbyak::Reg64 tmp = ctx.reg_alloc.ScratchGpr(); // Temporary: holds the shifted product, then the overflow byte.
+
+    code.movsxd(x, x.cvt32()); // Sign-extend the low 32 bits of each operand to 64 bits.
+    code.movsxd(y, y.cvt32());
+
+    code.imul(x, y); // x = 64-bit product of the two sign-extended operands.
+    code.lea(y, ptr[x + x]); // y = x + x (the doubled product); LEA leaves the flags untouched.
+    code.mov(tmp, x);
+    code.shr(tmp, 31); // tmp = product >> 31 (logical shift), i.e. the doubled product shifted right by 32.
+    code.xor_(y, x); // Only the flags matter here: SF = bit 63 of (2 * product) XOR product, set when doubling changed the sign bit. The value in y is overwritten next.
+    code.mov(y.cvt32(), 0x7FFFFFFF); // Preload the saturated result (largest positive signed 32-bit value); MOV does not modify flags.
+    code.cmovns(y.cvt32(), tmp.cvt32()); // SF clear (no sign change): take the low 32 bits of the shifted product instead.
+
+    if (overflow_inst) {
+        code.sets(tmp.cvt8()); // Low byte of tmp = SF from the XOR above (1 when the result was saturated). Done after CMOVNS, which has already consumed tmp.
+
+        ctx.reg_alloc.DefineValue(overflow_inst, tmp); // Bind the overflow pseudo-operation to tmp, then drop it so it is not emitted separately.
+        ctx.EraseInstruction(overflow_inst);
+    }
+
+    ctx.reg_alloc.DefineValue(inst, y); // y carries the instruction's result.
+}
+
 void EmitX64::EmitSignedSaturatedSub8(EmitContext& ctx, IR::Inst* inst) {
     EmitSignedSaturatedOp(code, ctx, inst);
 }
diff --git a/src/frontend/ir/ir_emitter.cpp b/src/frontend/ir/ir_emitter.cpp
index 62bf0685..53025eac 100644
--- a/src/frontend/ir/ir_emitter.cpp
+++ b/src/frontend/ir/ir_emitter.cpp
@@ -501,6 +501,24 @@ ResultAndOverflow IREmitter::SignedSaturatedAdd(const UAny& a, const UAny&
     return {result, overflow};
 }
 
+ResultAndOverflow IREmitter::SignedSaturatedDoublingMultiplyReturnHigh(const UAny& a, const UAny& b) { // Builds the IR for a signed saturated doubling multiply (high half) and returns the result together with its overflow value.
+    ASSERT(a.GetType() == b.GetType()); // Both operands must have the same IR type.
+    const auto result = [&]() -> IR::UAny { // Pick the opcode that matches the operand width.
+        switch (a.GetType()) {
+        case IR::Type::U16:
+            return Inst(Opcode::SignedSaturatedDoublingMultiplyReturnHigh16, a, b);
+        case IR::Type::U32:
+            return Inst(Opcode::SignedSaturatedDoublingMultiplyReturnHigh32, a, b);
+        default: // Only U16 and U32 operands are handled; any other type is treated as unreachable.
+            UNREACHABLE();
+            return IR::UAny{};
+        }
+    }();
+
+    const auto overflow = Inst(Opcode::GetOverflowFromOp, result); // Pseudo-operation tied to `result`; the x64 emitters above look it up and define its value.
+    return {result, overflow};
+}
+
 ResultAndOverflow IREmitter::SignedSaturatedSub(const UAny& a, const UAny& b) {
     ASSERT(a.GetType() == b.GetType());
     const auto result = [&]() -> IR::UAny {
diff --git a/src/frontend/ir/ir_emitter.h b/src/frontend/ir/ir_emitter.h
index d500f749..6fa579aa 100644
--- a/src/frontend/ir/ir_emitter.h
+++ b/src/frontend/ir/ir_emitter.h
@@ -143,6 +143,7 @@ public:
     U32U64 MinUnsigned(const U32U64& a, const U32U64& b);
 
     ResultAndOverflow SignedSaturatedAdd(const UAny& a, const UAny& b);
+    ResultAndOverflow SignedSaturatedDoublingMultiplyReturnHigh(const UAny& a, const UAny& b); // Operands must share one type; the definition handles U16 and U32 only.
     ResultAndOverflow SignedSaturatedSub(const UAny& a, const UAny& b);
     ResultAndOverflow SignedSaturation(const U32& a, size_t bit_size_to_saturate_to);
     ResultAndOverflow UnsignedSaturatedAdd(const UAny& a, const UAny& b);
diff --git a/src/frontend/ir/opcodes.inc b/src/frontend/ir/opcodes.inc
index 8f01b7d5..9532f38a 100644
--- a/src/frontend/ir/opcodes.inc
+++ b/src/frontend/ir/opcodes.inc
@@ -157,6 +157,8 @@ OPCODE(SignedSaturatedAdd8, U8, U8,
 OPCODE(SignedSaturatedAdd16, U16, U16, U16 )
 OPCODE(SignedSaturatedAdd32, U32, U32, U32 )
 OPCODE(SignedSaturatedAdd64, U64, U64, U64 )
+OPCODE(SignedSaturatedDoublingMultiplyReturnHigh16, U16, U16, U16 )
+OPCODE(SignedSaturatedDoublingMultiplyReturnHigh32, U32, U32, U32 )
 OPCODE(SignedSaturatedSub8, U8, U8, U8 )
 OPCODE(SignedSaturatedSub16, U16, U16, U16 )
 OPCODE(SignedSaturatedSub32, U32, U32, U32 )