|
1 | 1 | from datetime import datetime, timedelta
|
2 |
| -from typing import Optional, Sequence |
| 2 | +from typing import Generator, Optional, Sequence, Union, NewType, NamedTuple |
3 | 3 | import re
|
4 | 4 |
|
| 5 | +import jsonpath_ng |
5 | 6 | from pydantic.error_wrappers import ErrorWrapper, ValidationError
|
6 | 7 | from sqlalchemy.orm import Session, joinedload
|
7 | 8 |
|
@@ -163,114 +164,177 @@ def get_signal_instances_with_entity(
|
163 | 164 | return signal_instances
|
164 | 165 |
|
165 | 166 |
|
| 167 | +EntityTypePair = NewType( |
| 168 | + "EntityTypePair", |
| 169 | + NamedTuple( |
| 170 | + "EntityTypePairTuple", |
| 171 | + [ |
| 172 | + ("entity_type", EntityType), |
| 173 | + ("regex", Union[re.Pattern[str], None]), |
| 174 | + ("json_path", Union[jsonpath_ng.JSONPath, None]), |
| 175 | + ], |
| 176 | + ), |
| 177 | +) |
| 178 | + |
| 179 | + |
166 | 180 | def find_entities(
|
167 | 181 | db_session: Session, signal_instance: SignalInstance, entity_types: Sequence[EntityType]
|
168 | 182 | ) -> list[Entity]:
|
169 |
| - """Find entities of the given types in the raw data of a signal instance. |
| 183 | + """ |
| 184 | + Find entities in a SignalInstance based on a list of EntityTypes. |
170 | 185 |
|
171 | 186 | Args:
|
172 |
| - db_session (Session): SQLAlchemy database session. |
173 |
| - signal_instance (SignalInstance): SignalInstance to search for entities in. |
174 |
| - entity_types (list[EntityType]): List of EntityType objects to search for. |
| 187 | + db_session (Session): The database session to use for entity creation. |
| 188 | + signal_instance (SignalInstance): The SignalInstance to extract entities from. |
| 189 | + entity_types (Sequence[EntityType]): A list of EntityTypes to search for in the SignalInstance. |
175 | 190 |
|
176 | 191 | Returns:
|
177 |
| - list[Entity]: List of Entity objects found. |
178 |
| -
|
179 |
| - Example: |
180 |
| - >>> signal_instance = SignalInstance( |
181 |
| - ... raw={ |
182 |
| - ... "name": "John Smith", |
183 |
| - |
184 |
| - ... "phone": "555-555-1212", |
185 |
| - ... "address": { |
186 |
| - ... "street": "123 Main St", |
187 |
| - ... "city": "Anytown", |
188 |
| - ... "state": "CA", |
189 |
| - ... "zip": "12345" |
190 |
| - ... }, |
191 |
| - ... "notes": "Customer is interested in buying a product." |
192 |
| - ... } |
193 |
| - ... ) |
194 |
| - >>> entity_types = [ |
195 |
| - ... EntityType(name="Name", field="name", regular_expression=r"\b[A-Z][a-z]+ [A-Z][a-z]+\b"), |
196 |
| - ... EntityType(name="Phone", field=None, regular_expression=r"\b\\d{3}[-.]?\\d{3}[-.]?\\d{4}\b"), |
197 |
| - ... EntityType(name="Street", field="address.street"), |
198 |
| - ... ] |
199 |
| - >>> entities = find_entities(db_session, signal_instance, entity_types) |
200 |
| -
|
201 |
| - Notes: |
202 |
| - This function uses depth-first search to traverse the raw data of the signal instance. It searches for |
203 |
| - the regular expressions specified in the EntityType objects in the values of the dictionary, list, and |
204 |
| - string objects encountered during the traversal. The search can be limited to a specific key in the |
205 |
| - dictionary objects by specifying a value for the 'field' attribute of the EntityType object. |
| 192 | + list[Entity]: A list of entities found in the SignalInstance. |
206 | 193 | """
|
207 | 194 |
|
208 |
| - def _search(key, val, entity_type_pairs): |
209 |
| - # Create a list to hold any entities that are found in this value |
210 |
| - entities = [] |
211 |
| - |
212 |
| - # If this value has been searched before, return the cached entities |
213 |
| - if id(val) in cache: |
214 |
| - return cache[id(val)] |
215 |
| - |
| 195 | + def _find_entites_by_regex( |
| 196 | + val: Union[dict, str, list], |
| 197 | + signal_instance: SignalInstance, |
| 198 | + entity_type_pairs: list[EntityTypePair], |
| 199 | + ) -> Generator[EntityCreate, None, None]: |
| 200 | + """ |
| 201 | + Find entities in a value using regular expressions. |
| 202 | +
|
| 203 | + Args: |
| 204 | + val: The value to search for entities in. |
| 205 | + signal_instance (SignalInstance): The SignalInstance being processed. |
| 206 | + entity_type_pairs (list): A list of (entity_type, entity_regex, field) tuples to search for. |
| 207 | +
|
| 208 | + Yields: |
| 209 | + EntityCreate: An entity found in the value. |
| 210 | +
|
| 211 | + Examples: |
| 212 | + >>> entity_type_pairs = [ |
| 213 | + ... ( |
| 214 | + ... EntityType("PERSON", r"([A-Z][a-z]+)+"), |
| 215 | + ... re.compile(r"([A-Z][a-z]+)+"), |
| 216 | + ... None |
| 217 | + ... ), |
| 218 | + ... ( |
| 219 | + ... EntityType("DATE", r"(\d{4}(-\d{2}){2}|\d{4}\/\d{2}\/\d{2})"), # noqa |
| 220 | + ... re.compile(r"(\d{4}(-\d{2}){2}|\\d{4}\/\d{2}\/\d{2})"), # noqa |
| 221 | + ... None |
| 222 | + ... ) |
| 223 | + ... ] |
| 224 | +
|
| 225 | + >>> signal_instance = SignalInstance(raw={"text": "John Doe was born on 1987-05-12."}) |
| 226 | +
|
| 227 | + >>> entities = list(_find_entites_by_regex(signal_instance.raw, signal_instance, entity_type_pairs)) |
| 228 | +
|
| 229 | + >>> entities[0].value |
| 230 | + 'John Doe' |
| 231 | + >>> entities[0].entity_type.name |
| 232 | + 'PERSON' |
| 233 | + >>> entities[1].value |
| 234 | + '1987-05-12' |
| 235 | + >>> entities[1].entity_type.name |
| 236 | + 'DATE' |
| 237 | + """ |
216 | 238 | # If the value is a dictionary, search its key-value pairs recursively
|
217 | 239 | if isinstance(val, dict):
|
218 |
| - for subkey, subval in val.items(): |
219 |
| - entities.extend(_search(subkey, subval, entity_type_pairs)) |
| 240 | + for _, subval in val.items(): |
| 241 | + yield from _find_entites_by_regex( |
| 242 | + subval, |
| 243 | + signal_instance, |
| 244 | + entity_type_pairs, |
| 245 | + ) |
220 | 246 |
|
221 | 247 | # If the value is a list, search its items recursively
|
222 | 248 | elif isinstance(val, list):
|
223 | 249 | for item in val:
|
224 |
| - entities.extend(_search(None, item, entity_type_pairs)) |
| 250 | + yield from _find_entites_by_regex( |
| 251 | + item, |
| 252 | + signal_instance, |
| 253 | + entity_type_pairs, |
| 254 | + ) |
225 | 255 |
|
226 | 256 | # If the value is a string, search it for entity matches
|
227 | 257 | elif isinstance(val, str):
|
228 |
| - for entity_type, entity_regex, field in entity_type_pairs: |
229 |
| - # If a field was specified for this entity type, only search that field |
230 |
| - if not field or key == field: |
231 |
| - if entity_regex is None: |
232 |
| - # If no regular expression was specified, return the value of the field/key |
233 |
| - entity = EntityCreate( |
234 |
| - value=val, |
235 |
| - entity_type=entity_type, |
236 |
| - project=signal_instance.project, |
237 |
| - ) |
238 |
| - entities.append(entity) |
239 |
| - else: |
240 |
| - # Search the string for matches to the entity type's regular expression |
241 |
| - if match := entity_regex.search(val): |
242 |
| - entity = EntityCreate( |
243 |
| - value=match.group(0), |
244 |
| - entity_type=entity_type, |
245 |
| - project=signal_instance.project, |
246 |
| - ) |
247 |
| - entities.append(entity) |
248 |
| - |
249 |
| - # Cache the entities found for this value |
250 |
| - cache[id(val)] = entities |
251 |
| - return entities |
| 258 | + for entity_type, entity_regex, _ in entity_type_pairs: |
| 259 | + # Search the string for matches to the entity type's regular expression |
| 260 | + if match := entity_regex.search(val): |
| 261 | + yield EntityCreate( |
| 262 | + value=match.group(0), |
| 263 | + entity_type=entity_type, |
| 264 | + project=signal_instance.project, |
| 265 | + ) |
| 266 | + |
| 267 | + def _find_entities_by_regex_and_jsonpath_expression( |
| 268 | + signal_instance: SignalInstance, |
| 269 | + entity_type_pairs: list[EntityTypePair], |
| 270 | + ) -> Generator[EntityCreate, None, None]: |
| 271 | + """ |
| 272 | + Yield entities found in a SignalInstance by searching its fields using regular expressions and JSONPath expressions. |
| 273 | +
|
| 274 | + Args: |
| 275 | + signal_instance: The SignalInstance to extract entities from. |
| 276 | + entity_type_pairs: A list of (entity_type, entity_regex, field) tuples to search for. |
| 277 | +
|
| 278 | + Yields: |
| 279 | + EntityCreate: An entity found in the SignalInstance. |
| 280 | + """ |
| 281 | + for entity_type, entity_regex, field in entity_type_pairs: |
| 282 | + if field: |
| 283 | + try: |
| 284 | + matches = field.find(signal_instance.raw) |
| 285 | + for match in matches: |
| 286 | + if isinstance(match.value, str): |
| 287 | + if entity_regex is None: |
| 288 | + yield EntityCreate( |
| 289 | + value=match.value, |
| 290 | + entity_type=entity_type, |
| 291 | + project=signal_instance.project, |
| 292 | + ) |
| 293 | + else: |
| 294 | + if match := entity_regex.search(match.value): |
| 295 | + yield EntityCreate( |
| 296 | + value=match.group(0), |
| 297 | + entity_type=entity_type, |
| 298 | + project=signal_instance.project, |
| 299 | + ) |
| 300 | + except jsonpath_ng.PathNotFound: |
| 301 | + # field not found in signal_instance.raw |
| 302 | + pass |
252 | 303 |
|
253 | 304 | # Create a list of (entity type, regular expression, field) tuples
|
254 | 305 | entity_type_pairs = [
|
255 |
| - (type, re.compile(type.regular_expression) if type.regular_expression else None, type.field) |
| 306 | + ( |
| 307 | + type, |
| 308 | + re.compile(type.regular_expression) if type.regular_expression else None, |
| 309 | + jsonpath_ng.parse(type.field) if type.field else None, |
| 310 | + ) |
256 | 311 | for type in entity_types
|
257 | 312 | if isinstance(type.regular_expression, str) or type.field is not None
|
258 | 313 | ]
|
259 | 314 |
|
260 |
| - # Initialize a cache of previously searched values |
261 |
| - cache = {} |
| 315 | + # Filter the entity type pairs based on the field |
| 316 | + filtered_entity_type_pairs = [ |
| 317 | + (entity_type, entity_regex, field) |
| 318 | + for entity_type, entity_regex, field in entity_type_pairs |
| 319 | + if not field |
| 320 | + ] |
262 | 321 |
|
263 |
| - # Traverse the signal data using depth-first search |
| 322 | + # Use the recursive search function to find entities in the raw data |
264 | 323 | entities = [
|
265 | 324 | entity
|
266 |
| - for key, val in signal_instance.raw.items() |
267 |
| - for entity in _search(key, val, entity_type_pairs) |
| 325 | + for _, val in signal_instance.raw.items() |
| 326 | + for entity in _find_entites_by_regex(val, signal_instance, filtered_entity_type_pairs) |
268 | 327 | ]
|
269 | 328 |
|
270 |
| - # Create the entities in the database and add them to the signal instance |
| 329 | + entities.extend( |
| 330 | + _find_entities_by_regex_and_jsonpath_expression(signal_instance, entity_type_pairs) |
| 331 | + ) |
| 332 | + |
| 333 | + # Filter out duplicate entities |
| 334 | + entities = list(set(entities)) |
| 335 | + |
271 | 336 | entities_out = [
|
272 | 337 | get_by_value_or_create(db_session=db_session, entity_in=entity_in) for entity_in in entities
|
273 | 338 | ]
|
274 | 339 |
|
275 |
| - # Return the list of entities found |
276 | 340 | return entities_out
|
0 commit comments