Skip to content

Commit

Permalink
Merge pull request #9 from AIVIETNAMResearch/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
BachNgoH authored Jun 19, 2024
2 parents a5c7a21 + fd34302 commit b4c3783
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 33 deletions.
44 changes: 28 additions & 16 deletions api/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
from llama_index.core.callbacks import CallbackManager
# from src.agents.assistant_agent import AssistantAgent
from llama_index.agent.openai import OpenAIAgent
from src.agents.gemini_agent import GeminiForFunctionCalling
from llama_index.core.agent import ReActAgent
# from src.agents.gemini_agent import GeminiForFunctionCalling
from llama_index.core import Settings
from src.tools.paper_search_tool import load_paper_search_tool, load_daily_paper_tool, load_get_time_tool
from src.tools.web_search_tool import load_web_search_tool
from src.tools.summarize_tool import load_summarize_tool
from src.tasks.question_recommend_task import QuestionRecommender
from src.constants import SYSTEM_PROMPT
from starlette.responses import StreamingResponse, Response
from llama_index.core.base.llms.types import ChatMessage
Expand All @@ -27,7 +29,9 @@
SERVICE,
TEMPERATURE,
MODEL_ID,
STREAM
STREAM,
ENABLE_QUESTION_RECOMMENDER,
cfg
)
load_dotenv(override=True)

Expand All @@ -40,6 +44,11 @@ def __init__(self, callback_manager: Optional[CallbackManager] = None):
self.callback_manager = callback_manager
self.query_engine = self.create_query_engine()

if ENABLE_QUESTION_RECOMMENDER:
qr_llm = self.load_model(cfg.MODEL.QR_SERVICE, cfg.MODEL.QR_MODEL_ID)
self.question_recommender = QuestionRecommender.from_defaults(llm=qr_llm)


def load_tools(self):
paper_search_tool = load_paper_search_tool()
paper_summarize_tool = load_summarize_tool()
Expand All @@ -64,11 +73,13 @@ def create_query_engine(self):
Settings.llm = llm
self.tools = self.load_tools()

if SERVICE == "gemini":
query_engine = GeminiForFunctionCalling(
if SERVICE != "openai":
query_engine = ReActAgent.from_tools(
tools=self.tools,
api_key=os.getenv("GOOGLE_API_KEY"),
temperature=TEMPERATURE
verbose=True,
llm=llm,
system_prompt = SYSTEM_PROMPT,
callback_manager=self.callback_manager
)
else:
query_engine = OpenAIAgent.from_tools(
Expand Down Expand Up @@ -165,13 +176,14 @@ async def aon_message(self, message: cl.Message):

await msg.send()

history.append({"role": "assistant", "content": res.response})

next_questions = handle_next_question_generation(tools=self.tools, query_str=message.content, llm_response=res.response)
handle_generate_actions(next_questions)

actions = [
cl.Action(name=question, value=question, description=question) for question in next_questions
]
msg.actions = actions
await msg.update()
if ENABLE_QUESTION_RECOMMENDER:
history.append({"role": "assistant", "content": res.response})

next_questions = handle_next_question_generation(tools=self.tools, query_str=message.content, llm_response=res.response)
handle_generate_actions(next_questions)

actions = [
cl.Action(name=question, value=question, description=question) for question in next_questions
]
msg.actions = actions
await msg.update()
6 changes: 5 additions & 1 deletion config/config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,8 @@ MODEL:


VECTOR_STORE: "chroma" # currently support [qdrant, chroma]
PAPER_COLLECTION_NAME: "gemma_assistant_arxiv_papers"
PAPER_COLLECTION_NAME: "gemma_assistant_arxiv_papers"

ENABLE_QUESTION_RECOMMENDER: False
QR_SERVICE: # [ ollama, openai, groq, gemini ]
QR_MODEL_ID:
Binary file added public/web_icon.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 0 additions & 1 deletion src/callbacks/chainlit_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
logger.setLevel(logging.WARNING)

async def run_step(payload):
print(payload)
function_response = ast.literal_eval(payload["function_call_response"])
source_nodes = []

Expand Down
2 changes: 2 additions & 0 deletions src/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
EMBEDDING_SERVICE = cfg.MODEL.EMBEDDING_SERVICE
EMBEDDING_MODEL_NAME = cfg.MODEL.EMBEDDING_MODEL_NAME

ENABLE_QUESTION_RECOMMENDER = cfg.MODEL.ENABLE_QUESTION_RECOMMENDER

DEFAULT_SYSTEM_PROMPT = """
Bạn là chatbot được phát triển bởi team GenAIO thuộc AIVIETNAM.
Bạn được đưa một nội dung từ một văn bản và công việc của bạn là trả lời một câu hỏi của user về nội dung đã được cung cấp
Expand Down
6 changes: 4 additions & 2 deletions src/utils/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@ def setup_history(thread: ThreadDict):
return history


def handle_next_question_generation(tools: list, query_str: str, llm_response: str):
def handle_next_question_generation(
tools: list,
query_str: str,
llm_response: str, question_recommender: QuestionRecommender = QuestionRecommender.from_defaults()):

question_recommender = QuestionRecommender.from_defaults()
recommended_questions = question_recommender.generate(
tools=[tool.metadata for tool in tools],
query_str=query_str,
Expand Down
33 changes: 20 additions & 13 deletions src/utils/misc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import os
import requests
from bs4 import BeautifulSoup
from urllib.parse import urlparse, urljoin

DEFAULT_FAVICON_URL = f"{os.getcwd()}/public/web_icon.png"

def get_website_info(url):
parsed_url = urlparse(url)
domain = parsed_url.netloc
Expand All @@ -10,18 +13,22 @@ def get_website_info(url):
website_name = domain.split('.')[-2] # e.g., "youtube"

# Get the favicon URL
response = requests.get(url)
soup = BeautifulSoup(response.content, 'html.parser')
try:
response = requests.get(url)
soup = BeautifulSoup(response.content, 'html.parser')

# Look for the favicon in the <link> tags
icon_link = soup.find('link', rel=lambda x: x and 'icon' in x.lower())
if icon_link:
favicon_url = icon_link.get('href')
parsed_favicon_url = urlparse(favicon_url)
if not parsed_favicon_url.netloc: # relative URL
favicon_url = urljoin(url, favicon_url)
else:
# Default favicon location
favicon_url = urljoin(url, '/favicon.ico')
# Look for the favicon in the <link> tags
icon_link = soup.find('link', rel=lambda x: x and 'icon' in x.lower())
if icon_link:
favicon_url = icon_link.get('href')
parsed_favicon_url = urlparse(favicon_url)
if not parsed_favicon_url.netloc: # relative URL
favicon_url = urljoin(url, favicon_url)
else:
# Default favicon location
favicon_url = urljoin(url, '/favicon.ico')

return website_name, favicon_url
return website_name, favicon_url
except Exception as _:
print(DEFAULT_FAVICON_URL)
return website_name, DEFAULT_FAVICON_URL

0 comments on commit b4c3783

Please sign in to comment.