diff --git a/lagent/actions/bing_browser.py b/lagent/actions/bing_browser.py index fddaa6b5..db6dbec3 100755 --- a/lagent/actions/bing_browser.py +++ b/lagent/actions/bing_browser.py @@ -373,6 +373,82 @@ def _parse_response(self, response: dict) -> dict: return self._filter_results(raw_results) +class BoChaSearch(BaseSearch): + """ + Wrapper around the BoCha Web Search API. + + To use, you should pass your BoCha API key to the constructor. + + Args: + api_key (str): API KEY to use BoCha web search API. + You can create a API key at https://bochaai.com/. + summary (bool): Indicates whether the content of the website should be summarized LLM. + If True, the summary will be retrieved as part of the result from the web search API. + topk (int): The number of search results returned in response from api search results. + **kwargs: Any other parameters related to the BoCha API. Find more details at + https://bochaai.com/ + """ + + def __init__(self, + api_key: str, + topk: int = 3, + black_list: List[str] = [ + 'enoN', + 'youtube.com', + 'bilibili.com', + 'researchgate.net', + ], + **kwargs): + self.api_key = api_key + self.summary = True + self.proxy = kwargs.get('proxy') + self.kwargs = kwargs + super().__init__(topk, black_list) + + @cached(cache=TTLCache(maxsize=100, ttl=600)) + def search(self, query: str, max_retry: int = 3) -> dict: + for attempt in range(max_retry): + try: + response = self._call_bocha_api(query) + return self._parse_response(response) + except Exception as e: + logging.exception(str(e)) + warnings.warn( + f'Retry {attempt + 1}/{max_retry} due to error: {e}') + time.sleep(random.randint(2, 5)) + raise Exception( + 'Failed to get search results from BoCha Search after retries.') + + def _call_bocha_api(self, query: str) -> dict: + endpoint = 'https://api.bochaai.com/v1/web-search' + params = json.dumps({ + 'query': query, + 'count': self.topk, + 'summary': self.summary, + **{ + key: value + for key, value in self.kwargs.items() if value is not None + }, + }) + headers = { + 'Authorization': f'Bearer {self.api_key}', + 'Content-Type': 'application/json' + } + response = requests.request( + 'POST', endpoint, headers=headers, data=params) + response.raise_for_status() + return response.json() + + def _parse_response(self, response: dict) -> dict: + raw_results = [(w.get('url', + ''), w.get('snippet', '') + w.get('summary', ''), + w.get('name', + '')) for w in response.get('data', {}).get( + 'webPages', {}).get('value', [])] + + return self._filter_results(raw_results) + + class ContentFetcher: def __init__(self, timeout: int = 5):