Skip to content

Commit 2a51efe

Browse files
Add option to upload files to GradioUI (huggingface#138)
* Add option to upload files to GradioUI
1 parent 67ee777 commit 2a51efe

File tree

3 files changed

+68
-2
lines changed

3 files changed

+68
-2
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ wandb
66
# Data
77
data
88
outputs
9+
data/
910

1011
# Apple
1112
.DS_Store

examples/gradio_upload.py

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from smolagents import (
2+
CodeAgent,
3+
HfApiModel,
4+
GradioUI
5+
)
6+
7+
agent = CodeAgent(
8+
tools=[], model=HfApiModel(), max_steps=4, verbose=True
9+
)
10+
11+
GradioUI(agent, file_upload_folder='./data').launch()

src/smolagents/gradio_ui.py

+56-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#!/usr/bin/env python
22
# coding=utf-8
3-
43
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
54
#
65
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,6 +14,10 @@
1514
# See the License for the specific language governing permissions and
1615
# limitations under the License.
1716
import gradio as gr
17+
import shutil
18+
import os
19+
import mimetypes
20+
import re
1821

1922
from .agents import ActionStep, AgentStep, MultiStepAgent
2023
from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
@@ -82,8 +85,12 @@ def stream_to_gradio(
8285
class GradioUI:
8386
"""A one-line interface to launch your agent in Gradio"""
8487

85-
def __init__(self, agent: MultiStepAgent):
88+
def __init__(self, agent: MultiStepAgent, file_upload_folder: str | None=None):
8689
self.agent = agent
90+
self.file_upload_folder = file_upload_folder
91+
if self.file_upload_folder is not None:
92+
if not os.path.exists(file_upload_folder):
93+
os.mkdir(file_upload_folder)
8794

8895
def interact_with_agent(self, prompt, messages):
8996
messages.append(gr.ChatMessage(role="user", content=prompt))
@@ -93,6 +100,45 @@ def interact_with_agent(self, prompt, messages):
93100
yield messages
94101
yield messages
95102

103+
def upload_file(self, file, allowed_file_types=["application/pdf", "application/vnd.openxmlformats-officedocument.wordprocessingml.document", "text/plain"]):
104+
"""
105+
Handle file uploads, default allowed types are pdf, docx, and .txt
106+
"""
107+
108+
# Check if file is uploaded
109+
if file is None:
110+
return "No file uploaded"
111+
112+
# Check if file is in allowed filetypes
113+
name = os.path.basename(file.name)
114+
try:
115+
mime_type, _ = mimetypes.guess_type(file.name)
116+
except Exception as e:
117+
return f"Error: {e}"
118+
119+
if mime_type not in allowed_file_types:
120+
return "File type disallowed"
121+
122+
# Sanitize file name
123+
original_name = os.path.basename(file.name)
124+
sanitized_name = re.sub(r'[^\w\-.]', '_', original_name) # Replace any non-alphanumeric, non-dash, or non-dot characters with underscores
125+
126+
type_to_ext = {}
127+
for ext, t in mimetypes.types_map.items():
128+
if t not in type_to_ext:
129+
type_to_ext[t] = ext
130+
131+
# Ensure the extension correlates to the mime type
132+
sanitized_name = sanitized_name.split(".")[:-1]
133+
sanitized_name.append("" + type_to_ext[mime_type])
134+
sanitized_name = "".join(sanitized_name)
135+
136+
# Save the uploaded file to the specified folder
137+
file_path = os.path.join(self.file_upload_folder, os.path.basename(sanitized_name))
138+
shutil.copy(file.name, file_path)
139+
140+
return f"File uploaded successfully to {self.file_upload_folder}"
141+
96142
def launch(self):
97143
with gr.Blocks() as demo:
98144
stored_message = gr.State([])
@@ -104,6 +150,14 @@ def launch(self):
104150
"https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png",
105151
),
106152
)
153+
# If an upload folder is provided, enable the upload feature
154+
if self.file_upload_folder is not None:
155+
upload_file = gr.File(label="Upload a file")
156+
upload_status = gr.Textbox(label="Upload Status", interactive=False)
157+
158+
upload_file.change(
159+
self.upload_file, [upload_file], [upload_status]
160+
)
107161
text_input = gr.Textbox(lines=1, label="Chat Message")
108162
text_input.submit(
109163
lambda s: (s, ""), [text_input], [stored_message, text_input]

0 commit comments

Comments
 (0)