30 | 30 | },
31 | 31 | "outputs": [],
32 | 32 | "source": [
33 | | - "pip install --quiet --upgrade pymongo cohere"
| 33 | + "pip install --quiet --upgrade voyageai pymongo"
34 | 34 | ]
35 | 35 | },
36 | 36 | {

40 | 40 | "outputs": [],
41 | 41 | "source": [
42 | 42 | "import os\n",
43 | | - "import pymongo\n",
44 | | - "import cohere\n",
| 43 | + "import voyageai\n",
45 | 44 | "from bson.binary import Binary, BinaryVectorDtype\n",
46 | 45 | "\n",
47 | | - "# Specify your Cohere API key\n",
48 | | - "os.environ[\"COHERE_API_KEY\"] = \"<COHERE-API-KEY>\"\n",
49 | | - "cohere_client = cohere.Client(os.environ[\"COHERE_API_KEY\"])\n",
| 46 | + "# Initialize the VoyageAI Client\n",
| 47 | + "os.environ[\"VOYAGE_API_KEY\"] = \"<VOYAGEAI-API-KEY>\"\n",
| 48 | + "vo = voyageai.Client()\n",
50 | 49 | "\n",
51 | | - "# Define function to generate embeddings using the embed-english-v3.0 model\n",
52 | | - "def get_embedding(text):\n",
53 | | - " response = cohere_client.embed(\n",
54 | | - " texts=[text],\n",
55 | | - " model='embed-english-v3.0',\n",
56 | | - " input_type='search_document',\n",
57 | | - " embedding_types=[\"float\"] # Can also be \"int8\" or \"ubinary\" (int1)\n",
58 | | - " )\n",
59 | | - " embedding = response.embeddings.float[0]\n",
| 50 | + "# Define a function to generate embeddings for all strings in `texts`\n",
| 51 | + "def generate_embeddings(texts, model: str, dtype: str, output_dimension: int):\n",
| 52 | + " embeddings = []\n",
| 53 | + " for text in texts: # Process each string in the data list\n",
| 54 | + " embedding = vo.embed(\n",
| 55 | + " texts=[text], # Pass each string as a list with a single item\n",
| 56 | + " model=model,\n",
| 57 | + " output_dtype=dtype,\n",
| 58 | + " output_dimension=output_dimension,\n",
| 59 | + " ).embeddings[0]\n",
| 60 | + " embeddings.append(embedding) # Collect the embedding for the current text\n",
| 61 | + " return embeddings\n",
60 | 62 | "\n",
61 | | - " # If you specified a different data type, uncomment one of following lines and delete the preceding line\n",
62 | | - " # embedding = response.embeddings.int8[0]\n",
63 | | - " # embedding = response.embeddings.ubinary[0] # refers to int1 data type\n",
64 | | - "\n",
65 | | - " return embedding\n",
66 | | - "\n",
67 | | - "# Define function to convert embeddings to BSON-compatible format\n",
| 63 | + "# Convert embeddings to BSON vectors\n",
68 | 64 | "def generate_bson_vector(vector, vector_dtype):\n",
69 | | - " return Binary.from_vector(vector, vector_dtype)"
| 65 | + " return Binary.from_vector(vector, vector_dtype)"
70 | 66 | ]
71 | 67 | },
72 | 68 | {

75 | 71 | "metadata": {},
76 | 72 | "outputs": [],
77 | 73 | "source": [
| 74 | + "import pymongo \n",
| 75 | + "\n",
78 | 76 | "# Connect to your Atlas cluster\n",
79 | | - "mongo_client = pymongo.MongoClient(\"<ATLAS-CONNECTION-STRING>\")\n",
| 77 | + "mongo_client = pymongo.MongoClient(\"<CONNECTION-STRING>\")\n",
80 | 78 | "db = mongo_client[\"sample_airbnb\"]\n",
81 | 79 | "collection = db[\"listingsAndReviews\"]\n",
82 | 80 | "\n",

96 | 94 | "metadata": {},
97 | 95 | "outputs": [],
98 | 96 | "source": [
99 | | - "for doc in documents:\n",
100 | | - " # Generate embeddings based on the summary\n",
101 | | - " summary = doc[\"summary\"]\n",
102 | | - " embedding = get_embedding(summary) # Get float32 embedding\n",
103 | | - "\n",
104 | | - " # Convert float32 embeddings into BSON format\n",
105 | | - " bson_vector = generate_bson_vector(embedding, BinaryVectorDtype.FLOAT32)\n",
106 | | - "\n",
107 | | - " # If you specified a different data type, uncomment one of following lines and delete the preceding line\n",
108 | | - " # bson_vector = generate_bson_vector(embedding, BinaryVectorDtype.INT8)\n",
109 | | - " # bson_vector = generate_bson_vector(embedding, BinaryVectorDtype.PACKED_BIT) # refers to int1 data type\n",
110 | | - "\n",
111 | | - " # Update the document with the BSON embedding\n",
112 | | - " collection.update_one(\n",
113 | | - " {\"_id\": doc[\"_id\"]},\n",
114 | | - " {\"$set\": {\"embedding\": bson_vector}}\n",
115 | | - " )\n",
116 | | - " updated_doc_count += 1\n",
117 | | - "\n",
118 | | - "print(f\"Updated {updated_doc_count} documents with BSON embeddings.\")"
| 97 | + "model_name = \"voyage-3-large\"\n",
| 98 | + "output_dimension = 1024\n",
| 99 | + "float32_field = \"float32-embedding\"\n",
| 100 | + "int8_field = \"int8-embedding\"\n",
| 101 | + "int1_field = \"int1-embedding\"\n",
| 102 | + "\n",
| 103 | + "# Process and update each document\n",
| 104 | + "updated_doc_count = 0 \n",
| 105 | + "for document in documents: \n",
| 106 | + " summary = document.get(\"summary\") \n",
| 107 | + " if not summary: \n",
| 108 | + " continue \n",
| 109 | + " \n",
| 110 | + " # Generate embeddings for the summary field \n",
| 111 | + " float_embeddings = generate_embeddings([summary], model=model_name, dtype=\"float\", output_dimension=output_dimension) \n",
| 112 | + " int8_embeddings = generate_embeddings([summary], model=model_name, dtype=\"int8\", output_dimension=output_dimension) \n",
| 113 | + " ubinary_embeddings = generate_embeddings([summary], model=model_name, dtype=\"ubinary\", output_dimension=output_dimension) \n",
| 114 | + " \n",
| 115 | + " # Convert embeddings to BSON-compatible format \n",
| 116 | + " bson_float = generate_bson_vector(float_embeddings[0], BinaryVectorDtype.FLOAT32) \n",
| 117 | + " bson_int8 = generate_bson_vector(int8_embeddings[0], BinaryVectorDtype.INT8) \n",
| 118 | + " bson_ubinary = generate_bson_vector(ubinary_embeddings[0], BinaryVectorDtype.PACKED_BIT) \n",
| 119 | + " \n",
| 120 | + " # Prepare the updated document \n",
| 121 | + " updated_fields = { \n",
| 122 | + " float32_field: bson_float, \n",
| 123 | + " int8_field: bson_int8, \n",
| 124 | + " int1_field: bson_ubinary,\n",
| 125 | + " } \n",
| 126 | + " \n",
| 127 | + " # Update the document in MongoDB \n",
| 128 | + " result = collection.update_one({\"_id\": document[\"_id\"]}, {\"$set\": updated_fields}) \n",
| 129 | + " if result.modified_count > 0: \n",
| 130 | + " updated_doc_count += 1 \n",
| 131 | + " \n",
| 132 | + "# Print the results \n",
| 133 | + "print(f\"Number of documents updated: {updated_doc_count}\") "
119 | 134 | ]
120 | 135 | },
121 | 136 | {

128 | 143 | "import time\n",
129 | 144 | "\n",
130 | 145 | "# Define and create the vector search index\n",
131 | | - "index_name = \"<INDEX-NAME>\"\n",
| 146 | + "index_name = \"vector_index\"\n",
132 | 147 | "search_index_model = SearchIndexModel(\n",
133 | 148 | " definition={\n",
134 | 149 | " \"fields\": [\n",
135 | 150 | " {\n",
136 | 151 | " \"type\": \"vector\",\n",
137 | | - " \"path\": \"embedding\",\n",
| 152 | + " \"path\": float32_field,\n",
| 153 | + " \"similarity\": \"dotProduct\",\n",
| 154 | + " \"numDimensions\": 1024\n",
| 155 | + " },\n",
| 156 | + " {\n",
| 157 | + " \"type\": \"vector\",\n",
| 158 | + " \"path\": int8_field,\n",
| 159 | + " \"similarity\": \"dotProduct\",\n",
| 160 | + " \"numDimensions\": 1024\n",
| 161 | + " },\n",
| 162 | + " {\n",
| 163 | + " \"type\": \"vector\",\n",
| 164 | + " \"path\": int1_field,\n",
138 | 165 | " \"similarity\": \"euclidean\",\n",
139 | 166 | " \"numDimensions\": 1024\n",
140 | 167 | " }\n",

165 | 192 | "metadata": {},
166 | 193 | "outputs": [],
167 | 194 | "source": [
168 | | - "# Define function to run a vector search query\n",
| 195 | + "import voyageai\n",
| 196 | + "from bson.binary import Binary, BinaryVectorDtype\n",
| 197 | + "\n",
| 198 | + "# Define a function to run a vector search query\n",
169 | 199 | "def run_vector_search(query_text, collection, path):\n",
170 | | - " query_embedding = get_embedding(\"query_text\")\n",
171 | | - " bson_query_vector = generate_bson_vector(query_embedding, BinaryVectorDtype.FLOAT32)\n",
172 | | - "\n",
173 | | - " # If you specified a different data type, uncomment one of following lines and delete the preceding line\n",
174 | | - " # bson_query_vector = generate_bson_vector(query_embedding, BinaryVectorDtype.INT8)\n",
175 | | - " # bson_query_vector = generate_bson_vector(query_embedding, BinaryVectorDtype.PACKED_BIT) # refers to int1 data type\n",
176 | | - "\n",
177 | | - " pipeline = [\n",
178 | | - " {\n",
179 | | - " '$vectorSearch': {\n",
180 | | - " 'index': index_name,\n",
181 | | - " 'path': path,\n",
182 | | - " 'queryVector': bson_query_vector,\n",
183 | | - " 'numCandidates': 20,\n",
184 | | - " 'limit': 5\n",
185 | | - " }\n",
186 | | - " },\n",
187 | | - " {\n",
188 | | - " '$project': {\n",
189 | | - " '_id': 0,\n",
190 | | - " 'name': 1,\n",
191 | | - " 'summary': 1,\n",
192 | | - " 'score': { '$meta': 'vectorSearchScore' }\n",
| 200 | + " # Map path to output dtype and BSON vector type\n",
| 201 | + " path_to_dtype = {\n",
| 202 | + " float32_field: (\"float\", BinaryVectorDtype.FLOAT32),\n",
| 203 | + " int8_field: (\"int8\", BinaryVectorDtype.INT8),\n",
| 204 | + " int1_field: (\"ubinary\", BinaryVectorDtype.PACKED_BIT),\n",
| 205 | + " }\n",
| 206 | + "\n",
| 207 | + " if path not in path_to_dtype:\n",
| 208 | + " raise ValueError(\"Invalid path. Must be one of float32_field, int8_field, int1_field.\")\n",
| 209 | + "\n",
| 210 | + " # Get Voyage AI output dtype and BSON vector type based on the path\n",
| 211 | + " output_dtype, bson_dtype = path_to_dtype[path]\n",
| 212 | + "\n",
| 213 | + " # Generate query embeddings using Voyage AI\n",
| 214 | + " query_vector = vo.embed(\n",
| 215 | + " texts=[query_text],\n",
| 216 | + " model=\"voyage-3-large\",\n",
| 217 | + " input_type=\"query\",\n",
| 218 | + " output_dtype=output_dtype\n",
| 219 | + " ).embeddings[0]\n",
| 220 | + "\n",
| 221 | + " # Convert the query vector to BSON format\n",
| 222 | + " bson_query_vector = Binary.from_vector(query_vector, bson_dtype)\n",
| 223 | + "\n",
| 224 | + " # Define the aggregation pipeline for vector search\n",
| 225 | + " pipeline = [\n",
| 226 | + " {\n",
| 227 | + " \"$vectorSearch\": {\n",
| 228 | + " \"index\": index_name, # Replace with your index name\n",
| 229 | + " \"path\": path, # Path to the embedding field\n",
| 230 | + " \"queryVector\": bson_query_vector, # BSON-encoded query vector\n",
| 231 | + " \"numCandidates\": 20,\n",
| 232 | + " \"limit\": 5\n",
| 233 | + " }\n",
| 234 | + " },\n",
| 235 | + " {\n",
| 236 | + " \"$project\": {\n",
| 237 | + " \"_id\": 0,\n",
| 238 | + " \"summary\": 1,\n",
| 239 | + " \"score\": { \"$meta\": \"vectorSearchScore\" } # Include the similarity score\n",
| 240 | + " }\n",
193 | 241 | " }\n",
194 | | - " }\n",
195 | | - " ]\n",
| 242 | + " ]\n",
196 | 243 | "\n",
197 | | - " return collection.aggregate(pipeline)"
| 244 | + " # Run the aggregation pipeline and return results\n",
| 245 | + " return collection.aggregate(pipeline)"
198 | 246 | ]
199 | 247 | },
200 | 248 | {

205 | 253 | "source": [
206 | 254 | "from pprint import pprint\n",
207 | 255 | "\n",
208 | | - "# Run a vector search query\n",
| 256 | + "# Define a list of embedding fields to query\n",
| 257 | + "embedding_fields = [float32_field, int8_field, int1_field] \n",
| 258 | + "results = {}\n",
| 259 | + "\n",
| 260 | + "# Run vector search queries for each embedding type\n",
209 | 261 | "query_text = \"ocean view\"\n",
210 | | - "query_results = run_vector_search(query_text, collection, \"embedding\")\n",
| 262 | + "for field in embedding_fields:\n",
| 263 | + " results[field] = list(run_vector_search(query_text, collection, field)) \n",
211 | 264 | "\n",
212 | | - "print(\"query results:\")\n",
213 | | - "pprint(list(query_results))"
| 265 | + "# Print the results\n",
| 266 | + "for field, field_results in results.items():\n",
| 267 | + " print(f\"Results from {field}\")\n",
| 268 | + " pprint(field_results)"
214 | 269 | ]
215 | 270 | }
216 | 271 | ],