forked from xai-org/x-algorithm
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_retrieval.py
More file actions
149 lines (126 loc) · 4.71 KB
/
run_retrieval.py
File metadata and controls
149 lines (126 loc) · 4.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# Copyright 2026 X.AI Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
from grok import TransformerConfig
from recsys_model import HashConfig
from recsys_retrieval_model import PhoenixRetrievalModelConfig
from runners import (
RecsysRetrievalInferenceRunner,
RetrievalModelRunner,
create_example_batch,
create_example_corpus,
ACTIONS,
)
def render_score_bar(score, width=20):
    """Return a fixed-width text bar visualizing a similarity score.

    The mapping assumes scores lie roughly in [-1, 1] (cosine-style
    similarity) — TODO confirm against the retrieval model's scoring.
    Out-of-range scores are clamped so the repeat counts can never go
    negative; the previous inline arithmetic produced empty or over-long
    bars for such values.
    """
    # Map [-1, 1] -> [0, width], then clamp to the valid range.
    filled = int((score + 1) * width / 2)
    filled = max(0, min(width, filled))
    return "█" * filled + "░" * (width - filled)


def main():
    """Run a local, end-to-end demo of the Phoenix retrieval model.

    Builds a small retrieval transformer (same architecture family as the
    Phoenix ranker), initializes an inference runner, generates a synthetic
    user batch and candidate corpus, then retrieves and pretty-prints the
    top-k posts per user.
    """
    # Model configuration - same architecture as Phoenix ranker
    emb_size = 128  # Embedding dimension
    num_actions = len(ACTIONS)  # Number of explicit engagement actions
    history_seq_len = 32  # Max history length
    candidate_seq_len = 8  # Max candidates per batch (for training)
    # Shared by the model config and the example batch below; hoisted into
    # one local so the two call sites cannot silently drift apart.
    product_surface_vocab_size = 16

    # Hash configuration
    hash_config = HashConfig(
        num_user_hashes=2,
        num_item_hashes=2,
        num_author_hashes=2,
    )

    # Configure the retrieval model - uses same transformer as Phoenix
    retrieval_model_config = PhoenixRetrievalModelConfig(
        emb_size=emb_size,
        history_seq_len=history_seq_len,
        candidate_seq_len=candidate_seq_len,
        hash_config=hash_config,
        product_surface_vocab_size=product_surface_vocab_size,
        model=TransformerConfig(
            emb_size=emb_size,
            widening_factor=2,
            key_size=64,
            num_q_heads=2,
            num_kv_heads=2,
            num_layers=2,
            attn_output_multiplier=0.125,
        ),
    )

    # Create inference runner
    inference_runner = RecsysRetrievalInferenceRunner(
        runner=RetrievalModelRunner(
            model=retrieval_model_config,
            bs_per_device=0.125,
        ),
        name="retrieval_local",
    )
    print("Initializing retrieval model...")
    inference_runner.initialize()
    print("Model initialized!")

    # Create example batch with simulated user and history
    print("\n" + "=" * 70)
    print("RETRIEVAL SYSTEM DEMO")
    print("=" * 70)
    batch_size = 2  # Two users for demo
    example_batch, example_embeddings = create_example_batch(
        batch_size=batch_size,
        emb_size=emb_size,
        history_len=history_seq_len,
        num_candidates=candidate_seq_len,
        num_actions=num_actions,
        num_user_hashes=hash_config.num_user_hashes,
        num_item_hashes=hash_config.num_item_hashes,
        num_author_hashes=hash_config.num_author_hashes,
        product_surface_vocab_size=product_surface_vocab_size,
    )

    # Count valid history items; a first-hash value of 0 is treated as an
    # empty/padded history slot here (inferred from the != 0 test — confirm
    # against create_example_batch).
    valid_history_count = int((example_batch.history_post_hashes[:, :, 0] != 0).sum())  # type: ignore
    print(f"\nUsers have viewed {valid_history_count} posts total in their history")

    # Step 1: Create a corpus of candidate posts
    print("\n" + "-" * 70)
    print("STEP 1: Creating Candidate Corpus")
    print("-" * 70)
    corpus_size = 1000  # Simulated corpus of 1000 posts
    corpus_embeddings, corpus_post_ids = create_example_corpus(
        corpus_size=corpus_size,
        emb_size=emb_size,
        seed=456,
    )
    print(f"Corpus size: {corpus_size} posts")
    print(f"Corpus embeddings shape: {corpus_embeddings.shape}")

    # Set corpus for retrieval
    inference_runner.set_corpus(corpus_embeddings, corpus_post_ids)

    # Step 2: Retrieve top-k candidates for each user
    print("\n" + "-" * 70)
    print("STEP 2: Retrieving Top-K Candidates")
    print("-" * 70)
    top_k = 10
    retrieval_output = inference_runner.retrieve(
        example_batch,
        example_embeddings,
        top_k=top_k,
    )
    print(f"\nRetrieved top {top_k} candidates for each of {batch_size} users:")
    # Materialize as numpy arrays for plain integer indexing below.
    top_k_indices = np.array(retrieval_output.top_k_indices)
    top_k_scores = np.array(retrieval_output.top_k_scores)
    for user_idx in range(batch_size):
        print(f"\n User {user_idx + 1}:")
        print(f" {'Rank':<6} {'Post ID':<12} {'Score':<12}")
        print(f" {'-' * 30}")
        for rank in range(top_k):
            post_id = top_k_indices[user_idx, rank]
            score = top_k_scores[user_idx, rank]
            bar = render_score_bar(score)
            print(f" {rank + 1:<6} {post_id:<12} {bar} {score:.4f}")

    print("\n" + "=" * 70)
    print("Demo complete!")
    print("=" * 70)
if __name__ == "__main__":
    # Configure root logging before running so INFO output from the runner
    # (initialization, retrieval) is visible on the console.
    logging.basicConfig(level=logging.INFO)
    main()