"""Streamlit app: upload a CSV, ask a Together AI model about it, run the
Python code the model writes inside an E2B sandbox, and display the results."""

import os
import json
import re
import sys
import io
import contextlib
import warnings
from typing import Optional, List, Any, Tuple
from PIL import Image
import streamlit as st
import pandas as pd
import base64
from io import BytesIO
from together import Together
from e2b_code_interpreter import Sandbox

warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")

# First fenced Python block in an LLM reply. Accepts ```python or ```py in any
# letter case, optional spaces/tabs after the tag, and \n or \r\n line endings.
pattern = re.compile(r"```(?:python|py)[ \t]*\r?\n(.*?)\r?\n```", re.DOTALL | re.IGNORECASE)


def code_interpret(e2b_code_interpreter: Sandbox, code: str) -> Optional[List[Any]]:
    """Execute *code* in the E2B sandbox and return its results.

    Anything written to this process's stdout/stderr while the SDK call runs
    is captured and re-emitted afterwards with a label, and warnings raised
    during the call are suppressed.

    Args:
        e2b_code_interpreter: An open E2B sandbox.
        code: Python source to run in the sandbox.

    Returns:
        ``results`` of the sandbox execution, or ``None`` if the sandbox
        reported an error (the error is logged to stderr and shown in the UI).
    """
    with st.spinner('Executing code in E2B sandbox...'):
        stdout_capture = io.StringIO()
        stderr_capture = io.StringIO()
        with contextlib.redirect_stdout(stdout_capture), contextlib.redirect_stderr(stderr_capture):
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                # Named `execution` (not `exec`) to avoid shadowing the builtin.
                execution = e2b_code_interpreter.run_code(code)

        if stderr_capture.getvalue():
            print("[Code Interpreter Warnings/Errors]", file=sys.stderr)
            print(stderr_capture.getvalue(), file=sys.stderr)

        if stdout_capture.getvalue():
            print("[Code Interpreter Output]", file=sys.stdout)
            print(stdout_capture.getvalue(), file=sys.stdout)

        if execution.error:
            print(f"[Code Interpreter ERROR] {execution.error}", file=sys.stderr)
            # Previously only logged server-side; tell the user why no results appear.
            st.error(f"Code execution failed: {execution.error}")
            return None

        return execution.results


def match_code_blocks(llm_response: str) -> str:
    """Return the body of the first fenced Python block in *llm_response*, or ``""``."""
    match = pattern.search(llm_response)
    return match.group(1) if match else ""


def chat_with_llm(e2b_code_interpreter: Sandbox, user_message: str, dataset_path: str) -> Tuple[Optional[List[Any]], str]:
    """Ask the selected Together AI model about the dataset and run its code.

    Args:
        e2b_code_interpreter: Open sandbox in which the generated code runs.
        user_message: The user's question about the dataset.
        dataset_path: Path of the dataset inside the sandbox; embedded in the
            system prompt so the generated code reads the right file.

    Returns:
        ``(results, reply_text)``. ``results`` is whatever ``code_interpret``
        returned, or ``None`` if the reply contained no Python block.
        ``reply_text`` is the model's text (``""`` if the API returned none).
    """
    # System prompt includes the dataset path so generated code can load it.
    system_prompt = (
        f"You're a Python data scientist and data visualization expert. You are given a dataset at path '{dataset_path}' and also the user's query. You need to analyze the dataset and answer the user's query with a response and you run Python code to solve them. \n"
        f"IMPORTANT: Always use the dataset path variable '{dataset_path}' in your code when reading the CSV file."
    )

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message},
    ]

    with st.spinner('Getting response from Together AI LLM model...'):
        client = Together(api_key=st.session_state.together_api_key)
        response = client.chat.completions.create(
            model=st.session_state.model_name,
            messages=messages,
        )

        # Message content may be None; normalize so regex search cannot raise.
        content = response.choices[0].message.content or ""
        python_code = match_code_blocks(content)

        if not python_code:
            st.warning("Failed to match any Python code in model's response")
            return None, content

        code_interpreter_results = code_interpret(e2b_code_interpreter, python_code)
        return code_interpreter_results, content


def upload_dataset(code_interpreter: Sandbox, uploaded_file) -> str:
    """Write the uploaded CSV into the sandbox and return its sandbox path.

    The full byte contents are sent via ``getvalue()`` rather than the file
    object itself. ``main`` has already consumed the object with
    ``pd.read_csv`` for the preview, leaving its position at EOF. Passing the
    object could therefore upload an empty file if the SDK reads from the
    current position.
    Assumes *uploaded_file* is a Streamlit ``UploadedFile`` (a ``BytesIO``
    subclass).

    Raises:
        Exception: Whatever the SDK raised, re-raised after showing it in the UI.
    """
    dataset_path = f"./{uploaded_file.name}"

    try:
        code_interpreter.files.write(dataset_path, uploaded_file.getvalue())
    except Exception as error:
        st.error(f"Error during file upload: {error}")
        raise
    return dataset_path


def _render_sidebar() -> None:
    """Render API-key inputs and model picker; store choices in session state."""
    with st.sidebar:
        st.header("🔑 API Keys and Model Configuration")
        st.session_state.together_api_key = st.text_input("Together AI API Key", type="password")
        st.info("💡 Everyone gets a free $1 credit by Together AI - AI Acceleration Cloud platform")
        st.markdown("[Get Together AI API Key](https://api.together.ai/signin)")

        st.session_state.e2b_api_key = st.text_input("Enter E2B API Key", type="password")
        st.markdown("[Get E2B API Key](https://e2b.dev/docs/legacy/getting-started/api-key)")

        # Display label -> Together model ID.
        model_options = {
            "Meta-Llama 3.1 405B": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
            "DeepSeek V3": "deepseek-ai/DeepSeek-V3",
            "Qwen 2.5 7B": "Qwen/Qwen2.5-7B-Instruct-Turbo",
            "Meta-Llama 3.3 70B": "meta-llama/Llama-3.3-70B-Instruct-Turbo"
        }
        selected_label = st.selectbox(
            "Select Model",
            options=list(model_options.keys()),
            index=0  # Default to first option
        )
        st.session_state.model_name = model_options[selected_label]


def _display_result(result: Any) -> None:
    """Render one sandbox result: PNG, matplotlib, plotly, pandas, else ``st.write``."""
    if hasattr(result, 'png') and result.png:
        # `png` is base64-encoded PNG data.
        png_data = base64.b64decode(result.png)
        image = Image.open(BytesIO(png_data))
        st.image(image, caption="Generated Visualization", use_container_width=True)
    elif hasattr(result, 'figure'):  # matplotlib figure
        st.pyplot(result.figure)
    elif hasattr(result, 'show'):  # plotly figure
        st.plotly_chart(result)
    elif isinstance(result, (pd.DataFrame, pd.Series)):
        st.dataframe(result)
    else:
        st.write(result)


def main():
    """Main Streamlit application."""
    st.set_page_config(page_title="📊 AI Data Visualization Agent", page_icon="📊", layout="wide")
    st.title("📊 AI Data Visualization Agent")
    st.write("Upload your dataset and ask questions about it!")

    # Initialize session state variables
    for key in ('together_api_key', 'e2b_api_key', 'model_name'):
        if key not in st.session_state:
            st.session_state[key] = ''

    _render_sidebar()

    # Main content: narrow upload/preview column, wide Q&A column.
    col1, col2 = st.columns([1, 2])

    with col1:
        st.header("📂 Upload Dataset")
        uploaded_file = st.file_uploader("Choose a CSV file", type="csv", key="file_uploader")

        if uploaded_file is not None:
            # Note: this consumes the stream; upload_dataset sends getvalue() instead.
            df = pd.read_csv(uploaded_file)
            st.write("### Dataset Preview")
            show_full = st.checkbox("Show full dataset")
            if show_full:
                st.dataframe(df)
            else:
                st.write("Preview (first 5 rows):")
                st.dataframe(df.head())

    with col2:
        if uploaded_file is None:
            return

        st.header("❓ Ask a Question")
        query = st.text_area(
            "What would you like to know about your data?",
            "Can you compare the average cost for two people between different categories?",
            height=100
        )

        if not st.button("Analyze", type="primary", key="analyze_button"):
            return

        if not st.session_state.together_api_key or not st.session_state.e2b_api_key:
            st.error("Please enter both API keys in the sidebar.")
            return

        with Sandbox(api_key=st.session_state.e2b_api_key) as code_interpreter:
            dataset_path = upload_dataset(code_interpreter, uploaded_file)
            code_results, llm_response = chat_with_llm(code_interpreter, query, dataset_path)

            st.header("🤖 AI Response")
            st.write(llm_response)

            if code_results:
                st.header("📊 Analysis Results")
                for result in code_results:
                    _display_result(result)


if __name__ == "__main__":
    main()