diff --git a/src/frontend/A64/decoder/a64.inc b/src/frontend/A64/decoder/a64.inc index 337f20a3..12efc42c 100644 --- a/src/frontend/A64/decoder/a64.inc +++ b/src/frontend/A64/decoder/a64.inc @@ -864,7 +864,7 @@ INST(SHA512SU1, "SHA512SU1", "11001 INST(RAX1, "RAX1", "11001110011mmmmm100011nnnnnddddd") INST(SM3PARTW1, "SM3PARTW1", "11001110011mmmmm110000nnnnnddddd") INST(SM3PARTW2, "SM3PARTW2", "11001110011mmmmm110001nnnnnddddd") -//INST(SM4EKEY, "SM4EKEY", "11001110011mmmmm110010nnnnnddddd") +INST(SM4EKEY, "SM4EKEY", "11001110011mmmmm110010nnnnnddddd") INST(XAR, "XAR", "11001110100mmmmmiiiiiinnnnnddddd") // Data Processing - FP and SIMD - Cryptographic four register diff --git a/src/frontend/A64/translate/impl/simd_sha512.cpp b/src/frontend/A64/translate/impl/simd_sha512.cpp index 85040232..30c48752 100644 --- a/src/frontend/A64/translate/impl/simd_sha512.cpp +++ b/src/frontend/A64/translate/impl/simd_sha512.cpp @@ -99,6 +99,55 @@ IR::U128 SHA512Hash(IREmitter& ir, Vec Vm, Vec Vn, Vec Vd, SHA512HashPart part) return ir.VectorSetElement(64, low_result, 1, Vtmp); } + +enum class SM4RotationType { + SM4E, + SM4EKEY +}; + +IR::U32 SM4Rotation(IREmitter& ir, IR::U32 intval, IR::U32 round_result_low_word, SM4RotationType type) { + if (type == SM4RotationType::SM4E) { + const IR::U32 tmp1 = ir.RotateRight(intval, ir.Imm8(30)); + const IR::U32 tmp2 = ir.RotateRight(intval, ir.Imm8(22)); + const IR::U32 tmp3 = ir.RotateRight(intval, ir.Imm8(14)); + const IR::U32 tmp4 = ir.RotateRight(intval, ir.Imm8(8)); + const IR::U32 tmp5 = ir.Eor(intval, ir.Eor(tmp1, ir.Eor(tmp2, ir.Eor(tmp3, tmp4)))); + + return ir.Eor(tmp5, round_result_low_word); + } + + const IR::U32 tmp1 = ir.RotateRight(intval, ir.Imm8(19)); + const IR::U32 tmp2 = ir.RotateRight(intval, ir.Imm8(9)); + return ir.Eor(round_result_low_word, ir.Eor(intval, ir.Eor(tmp1, tmp2))); +} + +IR::U128 SM4Hash(IREmitter& ir, Vec Vn, Vec Vd, SM4RotationType type) { + const IR::U128 n = ir.GetQ(Vn); + IR::U128 roundresult = ir.GetQ(Vd); + + for (size_t i = 0; i < 4; i++) { + const IR::U32 round_key = ir.VectorGetElement(32, n, i); + + const IR::U32 upper_round = ir.VectorGetElement(32, roundresult, 3); + const IR::U32 before_upper_round = ir.VectorGetElement(32, roundresult, 2); + const IR::U32 after_lower_round = ir.VectorGetElement(32, roundresult, 1); + + IR::U128 intval_vec = ir.ZeroExtendToQuad(ir.Eor(upper_round, ir.Eor(before_upper_round, ir.Eor(after_lower_round, round_key)))); + + for (size_t j = 0; j < 4; j++) { + const IR::U8 byte_element = ir.VectorGetElement(8, intval_vec, j); + intval_vec = ir.VectorSetElement(8, intval_vec, j, ir.SM4AccessSubstitutionBox(byte_element)); + } + + const IR::U32 intval_low_word = ir.VectorGetElement(32, intval_vec, 0); + const IR::U32 round_result_low_word = ir.VectorGetElement(32, roundresult, 0); + const IR::U32 intval = SM4Rotation(ir, intval_low_word, round_result_low_word, type); + roundresult = ir.VectorShuffleWords(roundresult, 0b00111001); + roundresult = ir.VectorSetElement(32, roundresult, 3, intval); + } + + return roundresult; +} } // Anonymous namespace bool TranslatorVisitor::SHA512SU0(Vec Vn, Vec Vd) { @@ -240,39 +289,12 @@ bool TranslatorVisitor::SM3PARTW2(Vec Vm, Vec Vn, Vec Vd) { } bool TranslatorVisitor::SM4E(Vec Vn, Vec Vd) { - const IR::U128 n = ir.GetQ(Vn); - IR::U128 roundresult = ir.GetQ(Vd); + ir.SetQ(Vd, SM4Hash(ir, Vn, Vd, SM4RotationType::SM4E)); + return true; +} - for (size_t i = 0; i < 4; i++) { - const IR::U32 round_key = ir.VectorGetElement(32, n, i); - - const IR::U32 upper_round = ir.VectorGetElement(32, roundresult, 3); - const IR::U32 before_upper_round = ir.VectorGetElement(32, roundresult, 2); - const IR::U32 after_lower_round = ir.VectorGetElement(32, roundresult, 1); - - IR::U128 intval_vec = ir.ZeroExtendToQuad(ir.Eor(upper_round, ir.Eor(before_upper_round, ir.Eor(after_lower_round, round_key)))); - - for (size_t i = 0; i < 4; i++) { - const IR::U8 byte_element = ir.VectorGetElement(8, intval_vec, i); - intval_vec = ir.VectorSetElement(8, intval_vec, i, ir.SM4AccessSubstitutionBox(byte_element)); - } - - const IR::U32 intval = [&] { - const IR::U32 low_word = ir.VectorGetElement(32, intval_vec, 0); - const IR::U32 tmp1 = ir.RotateRight(low_word, ir.Imm8(30)); - const IR::U32 tmp2 = ir.RotateRight(low_word, ir.Imm8(22)); - const IR::U32 tmp3 = ir.RotateRight(low_word, ir.Imm8(14)); - const IR::U32 tmp4 = ir.RotateRight(low_word, ir.Imm8(8)); - - const IR::U32 tmp5 = ir.Eor(low_word, ir.Eor(tmp1, ir.Eor(tmp2, ir.Eor(tmp3, tmp4)))); - return ir.Eor(tmp5, ir.VectorGetElement(32, roundresult, 0)); - }(); - - roundresult = ir.VectorShuffleWords(roundresult, 0b00111001); - roundresult = ir.VectorSetElement(32, roundresult, 3, intval); - } - - ir.SetQ(Vd, roundresult); +bool TranslatorVisitor::SM4EKEY(Vec Vm, Vec Vn, Vec Vd) { + ir.SetQ(Vd, SM4Hash(ir, Vm, Vn, SM4RotationType::SM4EKEY)); return true; }