Skip to content

Commit dddf9ec

Browse files
committed
RSX: Use AVX-512-ICL in vertex shader hashing and comparisons RPCS3#16780
1 parent 8d51f6d commit dddf9ec

File tree

1 file changed

+191
-20
lines changed

1 file changed

+191
-20
lines changed

rpcs3/Emu/RSX/Program/ProgramStateCache.cpp

Lines changed: 191 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "stdafx.h"
22
#include "ProgramStateCache.h"
33
#include "Emu/system_config.h"
4+
#include "util/sysinfo.hpp"
45

56
#include <stack>
67

@@ -21,31 +22,123 @@
2122
#endif
2223
#endif
2324

25+
// AVX512_ICL_FUNC tags a function so it may use AVX-512 (Icelake-tier) intrinsics:
//  - MSVC: allows intrinsics without per-function target attributes, so expand to nothing.
//  - ARM64: AVX-512 does not exist, so expand to nothing (runtime checks keep these paths cold).
//  - x86 GCC/Clang: enable the required ISA extensions for this function only via
//    __attribute__((target)), so the rest of the TU keeps the baseline ISA.
// NOTE: the previous form defined the macro empty under ARCH_ARM64 and then
// unconditionally redefined it in the #else of an #ifdef _MSC_VER block, which
// applied the x86-only target attribute on ARM64 GCC/Clang builds and broke them.
#if defined(ARCH_ARM64) || defined(_MSC_VER)
#define AVX512_ICL_FUNC
#else
#define AVX512_ICL_FUNC __attribute__((__target__("avx512f,avx512bw,avx512dq,avx512cd,avx512vl,avx512bitalg,avx512ifma,avx512vbmi,avx512vbmi2,avx512vnni,avx512vpopcntdq")))
#endif
34+
35+
2436
using namespace program_hash_util;
2537

// Hashes the ucode of a vertex program. Only instructions whose bit is set in
// program.instruction_mask contribute, so programs differing only in dead
// (masked-out) slots hash identically. The AVX-512 fast path and the scalar
// fallback below compute the same value.
AVX512_ICL_FUNC usz vertex_program_utils::get_vertex_program_ucode_hash(const RSXVertexProgram &program)
{
#ifdef ARCH_X64
	if (utils::has_avx512_icl())
	{
		// Load all elements of the instruction_mask bitset as a 128-byte lookup
		// table: lowerMask = bytes 0-63, upperMask = bytes 64-79 zero-extended.
		// NOTE(review): this reads 80 bytes starting at &instruction_mask -
		// confirm the bitset's storage really spans at least 80 bytes.
		const __m512i* instMask512 = reinterpret_cast<const __m512i*>(&program.instruction_mask);
		const __m128i* instMask128 = reinterpret_cast<const __m128i*>(&program.instruction_mask);

		const __m512i lowerMask = _mm512_loadu_si512(instMask512);
		const __m128i upper128 = _mm_loadu_si128(instMask128 + 4);
		const __m512i upperMask = _mm512_zextsi128_si512(upper128);

		// maskIndex holds the current mask-byte index broadcast to all 64 byte
		// lanes; it is incremented by 1 per iteration by subtracting -1 below.
		__m512i maskIndex = _mm512_setzero_si512();
		const __m512i negativeOnes = _mm512_set1_epi64(-1);

		// Special masks to test against bitset: once the permute has broadcast
		// one mask byte to every lane, testMask0 extracts its bits 0-3 and
		// testMask1 its bits 4-7 - two adjacent qword lanes per bit, one for
		// each 64-bit half of a 16-byte instruction.
		const __m512i testMask0 = _mm512_set_epi64(
			0x0808080808080808,
			0x0808080808080808,
			0x0404040404040404,
			0x0404040404040404,
			0x0202020202020202,
			0x0202020202020202,
			0x0101010101010101,
			0x0101010101010101);

		const __m512i testMask1 = _mm512_set_epi64(
			0x8080808080808080,
			0x8080808080808080,
			0x4040404040404040,
			0x4040404040404040,
			0x2020202020202020,
			0x2020202020202020,
			0x1010101010101010,
			0x1010101010101010);

		const __m512i* instBuffer = reinterpret_cast<const __m512i*>(program.data.data());
		__m512i acc0 = _mm512_setzero_si512();
		__m512i acc1 = _mm512_setzero_si512();

		// Per-lane rotate amounts replicating the scalar path: the low/high
		// qword of instruction i is rotated right by 2*i / 2*i+1. Eight
		// instructions are consumed per iteration, hence the +16 step.
		__m512i rotMask0 = _mm512_set_epi64(7, 6, 5, 4, 3, 2, 1, 0);
		__m512i rotMask1 = _mm512_set_epi64(15, 14, 13, 12, 11, 10, 9, 8);
		__m512i rotMaskAdd = _mm512_set_epi64(16, 16, 16, 16, 16, 16, 16, 16);

		u32 instIndex = 0;

		// If there is remainder, add an extra (masked) iteration
		u32 extraIteration = (program.data.size() % 32 != 0) ? 1 : 0;
		u32 length = (program.data.size() / 32) + extraIteration;

		// The instruction mask will prevent us from reading out of bounds, we do not need a separate masked loop
		// for the remainder, or a scalar loop. (maskz loads do not fault on
		// elements whose mask bit is clear.)
		while (instIndex < (length))
		{
			// Broadcast mask byte [instIndex] - the bits covering instructions
			// 8*instIndex .. 8*instIndex+7 - into every byte lane of `masks`.
			const __m512i masks = _mm512_permutex2var_epi8(lowerMask, maskIndex, upperMask);
			const __mmask8 result0 = _mm512_test_epi64_mask(masks, testMask0);
			const __mmask8 result1 = _mm512_test_epi64_mask(masks, testMask1);
			// Two 64-byte chunks (4 instructions each) per iteration; lanes of
			// inactive instructions are loaded as zero and add nothing.
			const __m512i load0 = _mm512_maskz_loadu_epi64(result0, (instBuffer + instIndex * 2));
			const __m512i load1 = _mm512_maskz_loadu_epi64(result1, (instBuffer + (instIndex * 2)+ 1));

			const __m512i rotated0 = _mm512_rorv_epi64(load0, rotMask0);
			const __m512i rotated1 = _mm512_rorv_epi64(load1, rotMask1);

			acc0 = _mm512_add_epi64(acc0, rotated0);
			acc1 = _mm512_add_epi64(acc1, rotated1);

			rotMask0 = _mm512_add_epi64(rotMask0, rotMaskAdd);
			rotMask1 = _mm512_add_epi64(rotMask1, rotMaskAdd);
			maskIndex = _mm512_sub_epi8(maskIndex, negativeOnes);

			instIndex++;
		}

		// Horizontal sum of both accumulators gives the final 64-bit hash
		// (wrapping addition, matching the scalar acc0 + acc1).
		const __m512i result = _mm512_add_epi64(acc0, acc1);
		usz hash = _mm512_reduce_add_epi64(result);

		return hash;
	}
#endif

	// Checksum as hash with rotated data (scalar fallback).
	const void* instbuffer = program.data.data();
	u32 instIndex = 0;
	usz acc0 = 0;
	usz acc1 = 0;

	// NOTE(review): the do-while reads instruction_mask[0] before testing the
	// size, so this assumes program.data holds at least one instruction.
	do
	{
		if (program.instruction_mask[instIndex])
		{
			// One instruction = 16 bytes (4 u32 words). Rotating each half by an
			// index-dependent amount makes the hash sensitive to instruction order.
			const auto inst = v128::loadu(instbuffer, instIndex);
			usz tmp0 = std::rotr(inst._u64[0], instIndex * 2);
			acc0 += tmp0;
			usz tmp1 = std::rotr(inst._u64[1], (instIndex * 2) + 1);
			acc1 += tmp1;
		}

		instIndex++;
	} while (instIndex < (program.data.size() / 4));
	u64 hash = acc0 + acc1;

	return hash;
}
49142

50143
vertex_program_utils::vertex_program_metadata vertex_program_utils::analyse_vertex_program(const u32* data, u32 entry, RSXVertexProgram& dst_prog)
51144
{
@@ -350,7 +443,7 @@ usz vertex_program_storage_hash::operator()(const RSXVertexProgram &program) con
350443
return rpcs3::hash64(ucode_hash, metadata_hash);
351444
}
352445

353-
bool vertex_program_compare::operator()(const RSXVertexProgram &binary1, const RSXVertexProgram &binary2) const
446+
AVX512_ICL_FUNC bool vertex_program_compare::operator()(const RSXVertexProgram &binary1, const RSXVertexProgram &binary2) const
354447
{
355448
if (binary1.output_mask != binary2.output_mask)
356449
return false;
@@ -363,10 +456,88 @@ bool vertex_program_compare::operator()(const RSXVertexProgram &binary1, const R
363456
if (binary1.jump_table != binary2.jump_table)
364457
return false;
365458

459+
#ifdef ARCH_X64
460+
if (utils::has_avx512_icl())
461+
{
462+
// Load all elements of the instruction_mask bitset
463+
const __m512i* instMask512 = reinterpret_cast<const __m512i*>(&binary1.instruction_mask);
464+
const __m128i* instMask128 = reinterpret_cast<const __m128i*>(&binary1.instruction_mask);
465+
466+
const __m512i lowerMask = _mm512_loadu_si512(instMask512);
467+
const __m128i upper128 = _mm_loadu_si128(instMask128 + 4);
468+
const __m512i upperMask = _mm512_zextsi128_si512(upper128);
469+
470+
__m512i maskIndex = _mm512_setzero_si512();
471+
const __m512i negativeOnes = _mm512_set1_epi64(-1);
472+
473+
// Special masks to test against bitset
474+
const __m512i testMask0 = _mm512_set_epi64(
475+
0x0808080808080808,
476+
0x0808080808080808,
477+
0x0404040404040404,
478+
0x0404040404040404,
479+
0x0202020202020202,
480+
0x0202020202020202,
481+
0x0101010101010101,
482+
0x0101010101010101);
483+
484+
const __m512i testMask1 = _mm512_set_epi64(
485+
0x8080808080808080,
486+
0x8080808080808080,
487+
0x4040404040404040,
488+
0x4040404040404040,
489+
0x2020202020202020,
490+
0x2020202020202020,
491+
0x1010101010101010,
492+
0x1010101010101010);
493+
494+
const __m512i* instBuffer1 = reinterpret_cast<const __m512i*>(binary1.data.data());
495+
const __m512i* instBuffer2 = reinterpret_cast<const __m512i*>(binary2.data.data());
496+
497+
// If there is remainder, add an extra (masked) iteration
498+
u32 extraIteration = (binary1.data.size() % 32 != 0) ? 1 : 0;
499+
u32 length = (binary1.data.size() / 32) + extraIteration;
500+
501+
u32 instIndex = 0;
502+
503+
// The instruction mask will prevent us from reading out of bounds, we do not need a separate masked loop
504+
// for the remainder, or a scalar loop.
505+
while (instIndex < (length))
506+
{
507+
const __m512i masks = _mm512_permutex2var_epi8(lowerMask, maskIndex, upperMask);
508+
509+
const __mmask8 result0 = _mm512_test_epi64_mask(masks, testMask0);
510+
const __mmask8 result1 = _mm512_test_epi64_mask(masks, testMask1);
511+
512+
const __m512i load0 = _mm512_maskz_loadu_epi64(result0, (instBuffer1 + (instIndex * 2)));
513+
const __m512i load1 = _mm512_maskz_loadu_epi64(result0, (instBuffer2 + (instIndex * 2)));
514+
const __m512i load2 = _mm512_maskz_loadu_epi64(result1, (instBuffer1 + (instIndex * 2) + 1));
515+
const __m512i load3 = _mm512_maskz_loadu_epi64(result1, (instBuffer2 + (instIndex * 2)+ 1));
516+
517+
const __mmask8 res0 = _mm512_cmpneq_epi64_mask(load0, load1);
518+
const __mmask8 res1 = _mm512_cmpneq_epi64_mask(load2, load3);
519+
520+
const u8 result = _kortestz_mask8_u8(res0, res1);
521+
522+
// kortestz returns 1 only when the OR of both masks is all zeroes, so a zero result means at least one qword pair differed
523+
if (!result)
524+
{
525+
return false;
526+
}
527+
528+
maskIndex = _mm512_sub_epi8(maskIndex, negativeOnes);
529+
530+
instIndex++;
531+
}
532+
533+
return true;
534+
}
535+
#endif
536+
366537
const void* instBuffer1 = binary1.data.data();
367538
const void* instBuffer2 = binary2.data.data();
368539
usz instIndex = 0;
369-
for (unsigned i = 0; i < binary1.data.size() / 4; i++)
540+
while (instIndex < (binary1.data.size() / 4))
370541
{
371542
if (binary1.instruction_mask[instIndex])
372543
{

0 commit comments

Comments
 (0)