IR: Implement SHA256MessageSchedule{0,1}

This commit is contained in:
merry 2022-03-20 09:32:59 +00:00
parent f0a4bf1f6a
commit 98cff8dd0d
5 changed files with 49 additions and 58 deletions

View file

@ -48,4 +48,34 @@ void EmitX64::EmitSHA256Hash(EmitContext& ctx, IR::Inst* inst) {
ctx.reg_alloc.DefineValue(inst, y); ctx.reg_alloc.DefineValue(inst, y);
} }
// Emits host code for the SHA256MessageSchedule0 IR instruction.
// The whole operation is lowered to a single sha256msg1; there is no fallback
// path here, so the host must support the SHA extension (asserted below).
void EmitX64::EmitSHA256MessageSchedule0(EmitContext& ctx, IR::Inst* inst) {
auto args = ctx.reg_alloc.GetArgumentInfo(inst);
ASSERT(code.HasHostFeature(HostFeature::SHA));
// x is taken as a scratch register because it is overwritten with the result;
// y is only read.
const Xbyak::Xmm x = ctx.reg_alloc.UseScratchXmm(args[0]);
const Xbyak::Xmm y = ctx.reg_alloc.UseXmm(args[1]);
code.sha256msg1(x, y);
// The instruction's result is whatever sha256msg1 left in x.
ctx.reg_alloc.DefineValue(inst, x);
}
// Emits host code for the SHA256MessageSchedule1 IR instruction.
// Lowered to a palignr/paddd pre-step followed by sha256msg2; there is no
// fallback path here, so the host must support the SHA extension (asserted below).
void EmitX64::EmitSHA256MessageSchedule1(EmitContext& ctx, IR::Inst* inst) {
auto args = ctx.reg_alloc.GetArgumentInfo(inst);
ASSERT(code.HasHostFeature(HostFeature::SHA));
// x is taken as a scratch register because it accumulates the result;
// y and z are only read.
const Xbyak::Xmm x = ctx.reg_alloc.UseScratchXmm(args[0]);
const Xbyak::Xmm y = ctx.reg_alloc.UseXmm(args[1]);
const Xbyak::Xmm z = ctx.reg_alloc.UseXmm(args[2]);
// Build in xmm0 the vector made of the upper three 32-bit words of y with the
// lowest word of z placed on top (the concatenation z:y shifted right by 4 bytes).
// NOTE(review): xmm0 is clobbered without going through reg_alloc — presumably
// the backend reserves it as a temporary; confirm.
code.movaps(xmm0, z);
code.palignr(xmm0, y, 4);
// Add that vector into x word-by-word before the final schedule step.
code.paddd(x, xmm0);
code.sha256msg2(x, z);
// The instruction's result is whatever sha256msg2 left in x.
ctx.reg_alloc.DefineValue(inst, x);
}
} // namespace Dynarmic::Backend::X64 } // namespace Dynarmic::Backend::X64

View file

@ -114,72 +114,21 @@ bool TranslatorVisitor::SHA1H(Vec Vn, Vec Vd) {
} }
bool TranslatorVisitor::SHA256SU0(Vec Vn, Vec Vd) { bool TranslatorVisitor::SHA256SU0(Vec Vn, Vec Vd) {
const IR::U128 d = ir.GetQ(Vd); const IR::U128 x = ir.GetQ(Vd);
const IR::U128 n = ir.GetQ(Vn); const IR::U128 y = ir.GetQ(Vn);
const IR::U128 t = [&] { const IR::U128 result = ir.SHA256MessageSchedule0(x, y);
// Shuffle the upper three elements down: [3, 2, 1, 0] -> [0, 3, 2, 1]
const IR::U128 shuffled = ir.VectorShuffleWords(d, 0b00111001);
return ir.VectorSetElement(32, shuffled, 3, ir.VectorGetElement(32, n, 0));
}();
IR::U128 result = ir.ZeroVector();
for (size_t i = 0; i < 4; i++) {
const IR::U32 modified_element = [&] {
const IR::U32 element = ir.VectorGetElement(32, t, i);
const IR::U32 tmp1 = ir.RotateRight(element, ir.Imm8(7));
const IR::U32 tmp2 = ir.RotateRight(element, ir.Imm8(18));
const IR::U32 tmp3 = ir.LogicalShiftRight(element, ir.Imm8(3));
return ir.Eor(tmp1, ir.Eor(tmp2, tmp3));
}();
const IR::U32 d_element = ir.VectorGetElement(32, d, i);
result = ir.VectorSetElement(32, result, i, ir.Add(modified_element, d_element));
}
ir.SetQ(Vd, result); ir.SetQ(Vd, result);
return true; return true;
} }
bool TranslatorVisitor::SHA256SU1(Vec Vm, Vec Vn, Vec Vd) { bool TranslatorVisitor::SHA256SU1(Vec Vm, Vec Vn, Vec Vd) {
const IR::U128 d = ir.GetQ(Vd); const IR::U128 x = ir.GetQ(Vd);
const IR::U128 m = ir.GetQ(Vm); const IR::U128 y = ir.GetQ(Vn);
const IR::U128 n = ir.GetQ(Vn); const IR::U128 z = ir.GetQ(Vm);
const IR::U128 T0 = [&] { const IR::U128 result = ir.SHA256MessageSchedule1(x, y, z);
const IR::U32 low_m = ir.VectorGetElement(32, m, 0);
const IR::U128 shuffled_n = ir.VectorShuffleWords(n, 0b00111001);
return ir.VectorSetElement(32, shuffled_n, 3, low_m);
}();
const IR::U128 lower_half = [&] {
const IR::U128 T = ir.VectorShuffleWords(m, 0b01001110);
const IR::U128 tmp1 = ir.VectorRotateRight(32, T, 17);
const IR::U128 tmp2 = ir.VectorRotateRight(32, T, 19);
const IR::U128 tmp3 = ir.VectorLogicalShiftRight(32, T, 10);
const IR::U128 tmp4 = ir.VectorEor(tmp1, ir.VectorEor(tmp2, tmp3));
const IR::U128 tmp5 = ir.VectorAdd(32, tmp4, ir.VectorAdd(32, d, T0));
return ir.VectorZeroUpper(tmp5);
}();
const IR::U64 upper_half = [&] {
const IR::U128 tmp1 = ir.VectorRotateRight(32, lower_half, 17);
const IR::U128 tmp2 = ir.VectorRotateRight(32, lower_half, 19);
const IR::U128 tmp3 = ir.VectorLogicalShiftRight(32, lower_half, 10);
const IR::U128 tmp4 = ir.VectorEor(tmp1, ir.VectorEor(tmp2, tmp3));
// Shuffle the top two 32-bit elements downwards [3, 2, 1, 0] -> [1, 0, 3, 2]
const IR::U128 shuffled_d = ir.VectorShuffleWords(d, 0b01001110);
const IR::U128 shuffled_T0 = ir.VectorShuffleWords(T0, 0b01001110);
const IR::U128 tmp5 = ir.VectorAdd(32, tmp4, ir.VectorAdd(32, shuffled_d, shuffled_T0));
return ir.VectorGetElement(64, tmp5, 0);
}();
const IR::U128 result = ir.VectorSetElement(64, lower_half, 1, upper_half);
ir.SetQ(Vd, result); ir.SetQ(Vd, result);
return true; return true;

View file

@ -907,6 +907,14 @@ U128 IREmitter::SHA256Hash(const U128& x, const U128& y, const U128& w, bool par
return Inst<U128>(Opcode::SHA256Hash, x, y, w, Imm1(part1)); return Inst<U128>(Opcode::SHA256Hash, x, y, w, Imm1(part1));
} }
// Builds a SHA256MessageSchedule0 IR instruction from the two 128-bit
// operands x and y and returns its 128-bit result value.
U128 IREmitter::SHA256MessageSchedule0(const U128& x, const U128& y) {
    const Opcode op = Opcode::SHA256MessageSchedule0;
    return Inst<U128>(op, x, y);
}
// Builds a SHA256MessageSchedule1 IR instruction from the three 128-bit
// operands x, y and z and returns its 128-bit result value.
U128 IREmitter::SHA256MessageSchedule1(const U128& x, const U128& y, const U128& z) {
    const Opcode op = Opcode::SHA256MessageSchedule1;
    return Inst<U128>(op, x, y, z);
}
UAny IREmitter::VectorGetElement(size_t esize, const U128& a, size_t index) { UAny IREmitter::VectorGetElement(size_t esize, const U128& a, size_t index) {
ASSERT_MSG(esize * index < 128, "Invalid index"); ASSERT_MSG(esize * index < 128, "Invalid index");
switch (esize) { switch (esize) {

View file

@ -237,6 +237,8 @@ public:
U8 SM4AccessSubstitutionBox(const U8& a); U8 SM4AccessSubstitutionBox(const U8& a);
U128 SHA256Hash(const U128& x, const U128& y, const U128& w, bool part1); U128 SHA256Hash(const U128& x, const U128& y, const U128& w, bool part1);
// SHA-256 message schedule steps. The translator's SHA256SU0 handler calls
// SHA256MessageSchedule0(Vd, Vn) and its SHA256SU1 handler calls
// SHA256MessageSchedule1(Vd, Vn, Vm); each returns the new value for Vd.
U128 SHA256MessageSchedule0(const U128& x, const U128& y);
U128 SHA256MessageSchedule1(const U128& x, const U128& y, const U128& z);
UAny VectorGetElement(size_t esize, const U128& a, size_t index); UAny VectorGetElement(size_t esize, const U128& a, size_t index);
U128 VectorSetElement(size_t esize, const U128& a, size_t index, const UAny& elem); U128 VectorSetElement(size_t esize, const U128& a, size_t index, const UAny& elem);

View file

@ -274,6 +274,8 @@ OPCODE(SM4AccessSubstitutionBox, U8, U8
// SHA instructions // SHA instructions
OPCODE(SHA256Hash, U128, U128, U128, U128, U1 ) OPCODE(SHA256Hash, U128, U128, U128, U128, U1 )
// Message schedule steps: a U128 result followed by two (Schedule0) or three (Schedule1) U128 operands.
OPCODE(SHA256MessageSchedule0, U128, U128, U128 )
OPCODE(SHA256MessageSchedule1, U128, U128, U128, U128 )
// Vector instructions // Vector instructions
OPCODE(VectorGetElement8, U8, U128, U8 ) OPCODE(VectorGetElement8, U8, U128, U8 )