Skip to content

Commit 033e8bf

Browse files
committed
Merge remote-tracking branch 'shad_upstream/main' into rb4-new
2 parents d0f91f5 + 23bf8bf commit 033e8bf

File tree

1 file changed

+231
-10
lines changed

1 file changed

+231
-10
lines changed

src/core/cpu_patches.cpp

+231-10
Original file line numberDiff line numberDiff line change
@@ -620,9 +620,6 @@ static void GenerateEXTRQ(const ZydisDecodedOperand* operands, Xbyak::CodeGenera
620620
if (immediateForm) {
621621
u8 length = operands[1].imm.value.u & 0x3F;
622622
u8 index = operands[2].imm.value.u & 0x3F;
623-
if (length == 0) {
624-
length = 64;
625-
}
626623

627624
LOG_DEBUG(Core, "Patching immediate form EXTRQ, length: {}, index: {}", length, index);
628625

@@ -635,7 +632,15 @@ static void GenerateEXTRQ(const ZydisDecodedOperand* operands, Xbyak::CodeGenera
635632
c.push(scratch1);
636633
c.push(scratch2);
637634

638-
u64 mask = (1ULL << length) - 1;
635+
u64 mask;
636+
if (length == 0) {
637+
length = 64; // for the check below
638+
mask = 0xFFFF'FFFF'FFFF'FFFF;
639+
} else {
640+
mask = (1ULL << length) - 1;
641+
}
642+
643+
ASSERT_MSG(length + index <= 64, "length + index must be less than or equal to 64.");
639644

640645
// Get lower qword from xmm register
641646
MAYBE_AVX(movq, scratch1, xmm_dst);
@@ -676,6 +681,8 @@ static void GenerateEXTRQ(const ZydisDecodedOperand* operands, Xbyak::CodeGenera
676681
const Xbyak::Reg64 scratch2 = rcx;
677682
const Xbyak::Reg64 mask = rdx;
678683

684+
Xbyak::Label length_zero, end;
685+
679686
c.lea(rsp, ptr[rsp - 128]);
680687
c.pushfq();
681688
c.push(scratch1);
@@ -686,9 +693,18 @@ static void GenerateEXTRQ(const ZydisDecodedOperand* operands, Xbyak::CodeGenera
686693
MAYBE_AVX(movq, scratch1, xmm_src);
687694
c.mov(scratch2, scratch1);
688695
c.and_(scratch2, 0x3F);
696+
c.jz(length_zero);
697+
698+
// mask = (1ULL << length) - 1
689699
c.mov(mask, 1);
690700
c.shl(mask, cl);
691701
c.dec(mask);
702+
c.jmp(end);
703+
704+
c.L(length_zero);
705+
c.mov(mask, 0xFFFF'FFFF'FFFF'FFFF);
706+
707+
c.L(end);
692708

693709
// Get the shift amount and store it in scratch2
694710
c.shr(scratch1, 8);
@@ -708,6 +724,149 @@ static void GenerateEXTRQ(const ZydisDecodedOperand* operands, Xbyak::CodeGenera
708724
}
709725
}
710726

727+
static void GenerateINSERTQ(const ZydisDecodedOperand* operands, Xbyak::CodeGenerator& c) {
728+
bool immediateForm = operands[2].type == ZYDIS_OPERAND_TYPE_IMMEDIATE &&
729+
operands[3].type == ZYDIS_OPERAND_TYPE_IMMEDIATE;
730+
731+
ASSERT_MSG(operands[0].type == ZYDIS_OPERAND_TYPE_REGISTER &&
732+
operands[1].type == ZYDIS_OPERAND_TYPE_REGISTER,
733+
"operands 0 and 1 must be registers.");
734+
735+
const auto dst = ZydisToXbyakRegisterOperand(operands[0]);
736+
const auto src = ZydisToXbyakRegisterOperand(operands[1]);
737+
738+
ASSERT_MSG(dst.isXMM() && src.isXMM(), "operands 0 and 1 must be xmm registers.");
739+
740+
Xbyak::Xmm xmm_dst = *reinterpret_cast<const Xbyak::Xmm*>(&dst);
741+
Xbyak::Xmm xmm_src = *reinterpret_cast<const Xbyak::Xmm*>(&src);
742+
743+
if (immediateForm) {
744+
u8 length = operands[2].imm.value.u & 0x3F;
745+
u8 index = operands[3].imm.value.u & 0x3F;
746+
747+
const Xbyak::Reg64 scratch1 = rax;
748+
const Xbyak::Reg64 scratch2 = rcx;
749+
const Xbyak::Reg64 mask = rdx;
750+
751+
// Set rsp to before red zone and save scratch registers
752+
c.lea(rsp, ptr[rsp - 128]);
753+
c.pushfq();
754+
c.push(scratch1);
755+
c.push(scratch2);
756+
c.push(mask);
757+
758+
u64 mask_value;
759+
if (length == 0) {
760+
length = 64; // for the check below
761+
mask_value = 0xFFFF'FFFF'FFFF'FFFF;
762+
} else {
763+
mask_value = (1ULL << length) - 1;
764+
}
765+
766+
ASSERT_MSG(length + index <= 64, "length + index must be less than or equal to 64.");
767+
768+
MAYBE_AVX(movq, scratch1, xmm_src);
769+
MAYBE_AVX(movq, scratch2, xmm_dst);
770+
c.mov(mask, mask_value);
771+
772+
// src &= mask
773+
c.and_(scratch1, mask);
774+
775+
// src <<= index
776+
c.shl(scratch1, index);
777+
778+
// dst &= ~(mask << index)
779+
mask_value = ~(mask_value << index);
780+
c.mov(mask, mask_value);
781+
c.and_(scratch2, mask);
782+
783+
// dst |= src
784+
c.or_(scratch2, scratch1);
785+
786+
// Insert scratch2 into low 64 bits of dst, upper 64 bits are unaffected
787+
Cpu cpu;
788+
if (cpu.has(Cpu::tAVX)) {
789+
c.vpinsrq(xmm_dst, xmm_dst, scratch2, 0);
790+
} else {
791+
c.pinsrq(xmm_dst, scratch2, 0);
792+
}
793+
794+
c.pop(mask);
795+
c.pop(scratch2);
796+
c.pop(scratch1);
797+
c.popfq();
798+
c.lea(rsp, ptr[rsp + 128]);
799+
} else {
800+
ASSERT_MSG(operands[2].type == ZYDIS_OPERAND_TYPE_UNUSED &&
801+
operands[3].type == ZYDIS_OPERAND_TYPE_UNUSED,
802+
"operands 2 and 3 must be unused for register form.");
803+
804+
const Xbyak::Reg64 scratch1 = rax;
805+
const Xbyak::Reg64 scratch2 = rcx;
806+
const Xbyak::Reg64 index = rdx;
807+
const Xbyak::Reg64 mask = rbx;
808+
809+
Xbyak::Label length_zero, end;
810+
811+
c.lea(rsp, ptr[rsp - 128]);
812+
c.pushfq();
813+
c.push(scratch1);
814+
c.push(scratch2);
815+
c.push(index);
816+
c.push(mask);
817+
818+
// Get upper 64 bits of src and copy it to mask and index
819+
MAYBE_AVX(pextrq, index, xmm_src, 1);
820+
c.mov(mask, index);
821+
822+
// When length is 0, set it to 64
823+
c.and_(mask, 0x3F); // mask now holds the length
824+
c.jz(length_zero); // Check if length is 0 and set mask to all 1s if it is
825+
826+
// Create a mask out of the length
827+
c.mov(cl, mask.cvt8());
828+
c.mov(mask, 1);
829+
c.shl(mask, cl);
830+
c.dec(mask);
831+
c.jmp(end);
832+
833+
c.L(length_zero);
834+
c.mov(mask, 0xFFFF'FFFF'FFFF'FFFF);
835+
836+
c.L(end);
837+
// Get index to insert at
838+
c.shr(index, 8);
839+
c.and_(index, 0x3F);
840+
841+
// src &= mask
842+
MAYBE_AVX(movq, scratch1, xmm_src);
843+
c.and_(scratch1, mask);
844+
845+
// mask = ~(mask << index)
846+
c.mov(cl, index.cvt8());
847+
c.shl(mask, cl);
848+
c.not_(mask);
849+
850+
// src <<= index
851+
c.shl(scratch1, cl);
852+
853+
// dst = (dst & mask) | src
854+
MAYBE_AVX(movq, scratch2, xmm_dst);
855+
c.and_(scratch2, mask);
856+
c.or_(scratch2, scratch1);
857+
858+
// Upper 64 bits are undefined in insertq
859+
MAYBE_AVX(movq, xmm_dst, scratch2);
860+
861+
c.pop(mask);
862+
c.pop(index);
863+
c.pop(scratch2);
864+
c.pop(scratch1);
865+
c.popfq();
866+
c.lea(rsp, ptr[rsp + 128]);
867+
}
868+
}
869+
711870
using PatchFilter = bool (*)(const ZydisDecodedOperand*);
712871
using InstructionGenerator = void (*)(const ZydisDecodedOperand*, Xbyak::CodeGenerator&);
713872
struct PatchInfo {
@@ -730,6 +889,7 @@ static const std::unordered_map<ZydisMnemonic, PatchInfo> Patches = {
730889
#endif
731890

732891
{ZYDIS_MNEMONIC_EXTRQ, {FilterNoSSE4a, GenerateEXTRQ, true}},
892+
{ZYDIS_MNEMONIC_INSERTQ, {FilterNoSSE4a, GenerateINSERTQ, true}},
733893

734894
#ifdef __APPLE__
735895
// Patches for instruction sets not supported by Rosetta 2.
@@ -859,8 +1019,8 @@ static bool TryExecuteIllegalInstruction(void* ctx, void* code_address) {
8591019
bool immediateForm = operands[1].type == ZYDIS_OPERAND_TYPE_IMMEDIATE &&
8601020
operands[2].type == ZYDIS_OPERAND_TYPE_IMMEDIATE;
8611021
if (immediateForm) {
862-
LOG_ERROR(Core, "EXTRQ immediate form should have been patched at code address: {}",
863-
fmt::ptr(code_address));
1022+
LOG_CRITICAL(Core, "EXTRQ immediate form should have been patched at code address: {}",
1023+
fmt::ptr(code_address));
8641024
return false;
8651025
} else {
8661026
ASSERT_MSG(operands[0].type == ZYDIS_OPERAND_TYPE_REGISTER &&
@@ -883,12 +1043,19 @@ static bool TryExecuteIllegalInstruction(void* ctx, void* code_address) {
8831043
u64 lowQWordDst;
8841044
memcpy(&lowQWordDst, dst, sizeof(lowQWordDst));
8851045

886-
u64 mask = lowQWordSrc & 0x3F;
887-
mask = (1ULL << mask) - 1;
1046+
u64 length = lowQWordSrc & 0x3F;
1047+
u64 mask;
1048+
if (length == 0) {
1049+
length = 64; // for the check below
1050+
mask = 0xFFFF'FFFF'FFFF'FFFF;
1051+
} else {
1052+
mask = (1ULL << length) - 1;
1053+
}
8881054

889-
u64 shift = (lowQWordSrc >> 8) & 0x3F;
1055+
u64 index = (lowQWordSrc >> 8) & 0x3F;
1056+
ASSERT_MSG(length + index <= 64, "length + index must be less than or equal to 64.");
8901057

891-
lowQWordDst >>= shift;
1058+
lowQWordDst >>= index;
8921059
lowQWordDst &= mask;
8931060

8941061
memcpy(dst, &lowQWordDst, sizeof(lowQWordDst));
@@ -899,6 +1066,60 @@ static bool TryExecuteIllegalInstruction(void* ctx, void* code_address) {
8991066
}
9001067
break;
9011068
}
1069+
case ZYDIS_MNEMONIC_INSERTQ: {
1070+
bool immediateForm = operands[2].type == ZYDIS_OPERAND_TYPE_IMMEDIATE &&
1071+
operands[3].type == ZYDIS_OPERAND_TYPE_IMMEDIATE;
1072+
if (immediateForm) {
1073+
LOG_CRITICAL(Core,
1074+
"INSERTQ immediate form should have been patched at code address: {}",
1075+
fmt::ptr(code_address));
1076+
return false;
1077+
} else {
1078+
ASSERT_MSG(operands[2].type == ZYDIS_OPERAND_TYPE_UNUSED &&
1079+
operands[3].type == ZYDIS_OPERAND_TYPE_UNUSED,
1080+
"operands 2 and 3 must be unused for register form.");
1081+
1082+
ASSERT_MSG(operands[0].type == ZYDIS_OPERAND_TYPE_REGISTER &&
1083+
operands[1].type == ZYDIS_OPERAND_TYPE_REGISTER,
1084+
"operands 0 and 1 must be registers.");
1085+
1086+
const auto dstIndex = operands[0].reg.value - ZYDIS_REGISTER_XMM0;
1087+
const auto srcIndex = operands[1].reg.value - ZYDIS_REGISTER_XMM0;
1088+
1089+
const auto dst = Common::GetXmmPointer(ctx, dstIndex);
1090+
const auto src = Common::GetXmmPointer(ctx, srcIndex);
1091+
1092+
u64 lowQWordSrc, highQWordSrc;
1093+
memcpy(&lowQWordSrc, src, sizeof(lowQWordSrc));
1094+
memcpy(&highQWordSrc, (u8*)src + 8, sizeof(highQWordSrc));
1095+
1096+
u64 lowQWordDst;
1097+
memcpy(&lowQWordDst, dst, sizeof(lowQWordDst));
1098+
1099+
u64 length = highQWordSrc & 0x3F;
1100+
u64 mask;
1101+
if (length == 0) {
1102+
length = 64; // for the check below
1103+
mask = 0xFFFF'FFFF'FFFF'FFFF;
1104+
} else {
1105+
mask = (1ULL << length) - 1;
1106+
}
1107+
1108+
u64 index = (highQWordSrc >> 8) & 0x3F;
1109+
ASSERT_MSG(length + index <= 64, "length + index must be less than or equal to 64.");
1110+
1111+
lowQWordSrc &= mask;
1112+
lowQWordDst &= ~(mask << index);
1113+
lowQWordDst |= lowQWordSrc << index;
1114+
1115+
memcpy(dst, &lowQWordDst, sizeof(lowQWordDst));
1116+
1117+
Common::IncrementRip(ctx, instruction.length);
1118+
1119+
return true;
1120+
}
1121+
break;
1122+
}
9021123
default: {
9031124
LOG_ERROR(Core, "Unhandled illegal instruction at code address {}: {}",
9041125
fmt::ptr(code_address), ZydisMnemonicGetString(instruction.mnemonic));

0 commit comments

Comments
 (0)