1
1
#!/usr/bin/env python
2
2
# coding=utf-8
3
-
4
3
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
5
4
#
6
5
# Licensed under the Apache License, Version 2.0 (the "License");
15
14
# See the License for the specific language governing permissions and
16
15
# limitations under the License.
17
16
import gradio as gr
17
+ import shutil
18
+ import os
19
+ import mimetypes
20
+ import re
18
21
19
22
from .agents import ActionStep , AgentStep , MultiStepAgent
20
23
from .types import AgentAudio , AgentImage , AgentText , handle_agent_output_types
@@ -82,8 +85,12 @@ def stream_to_gradio(
82
85
class GradioUI :
83
86
"""A one-line interface to launch your agent in Gradio"""
84
87
85
- def __init__ (self , agent : MultiStepAgent ):
88
+ def __init__ (self , agent : MultiStepAgent , file_upload_folder : str | None = None ):
86
89
self .agent = agent
90
+ self .file_upload_folder = file_upload_folder
91
+ if self .file_upload_folder is not None :
92
+ if not os .path .exists (file_upload_folder ):
93
+ os .mkdir (file_upload_folder )
87
94
88
95
def interact_with_agent (self , prompt , messages ):
89
96
messages .append (gr .ChatMessage (role = "user" , content = prompt ))
@@ -93,6 +100,45 @@ def interact_with_agent(self, prompt, messages):
93
100
yield messages
94
101
yield messages
95
102
103
+ def upload_file (self , file , allowed_file_types = ["application/pdf" , "application/vnd.openxmlformats-officedocument.wordprocessingml.document" , "text/plain" ]):
104
+ """
105
+ Handle file uploads, default allowed types are pdf, docx, and .txt
106
+ """
107
+
108
+ # Check if file is uploaded
109
+ if file is None :
110
+ return "No file uploaded"
111
+
112
+ # Check if file is in allowed filetypes
113
+ name = os .path .basename (file .name )
114
+ try :
115
+ mime_type , _ = mimetypes .guess_type (file .name )
116
+ except Exception as e :
117
+ return f"Error: { e } "
118
+
119
+ if mime_type not in allowed_file_types :
120
+ return "File type disallowed"
121
+
122
+ # Sanitize file name
123
+ original_name = os .path .basename (file .name )
124
+ sanitized_name = re .sub (r'[^\w\-.]' , '_' , original_name ) # Replace any non-alphanumeric, non-dash, or non-dot characters with underscores
125
+
126
+ type_to_ext = {}
127
+ for ext , t in mimetypes .types_map .items ():
128
+ if t not in type_to_ext :
129
+ type_to_ext [t ] = ext
130
+
131
+ # Ensure the extension correlates to the mime type
132
+ sanitized_name = sanitized_name .split ("." )[:- 1 ]
133
+ sanitized_name .append ("" + type_to_ext [mime_type ])
134
+ sanitized_name = "" .join (sanitized_name )
135
+
136
+ # Save the uploaded file to the specified folder
137
+ file_path = os .path .join (self .file_upload_folder , os .path .basename (sanitized_name ))
138
+ shutil .copy (file .name , file_path )
139
+
140
+ return f"File uploaded successfully to { self .file_upload_folder } "
141
+
96
142
def launch (self ):
97
143
with gr .Blocks () as demo :
98
144
stored_message = gr .State ([])
@@ -104,6 +150,14 @@ def launch(self):
104
150
"https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png" ,
105
151
),
106
152
)
153
+ # If an upload folder is provided, enable the upload feature
154
+ if self .file_upload_folder is not None :
155
+ upload_file = gr .File (label = "Upload a file" )
156
+ upload_status = gr .Textbox (label = "Upload Status" , interactive = False )
157
+
158
+ upload_file .change (
159
+ self .upload_file , [upload_file ], [upload_status ]
160
+ )
107
161
text_input = gr .Textbox (lines = 1 , label = "Chat Message" )
108
162
text_input .submit (
109
163
lambda s : (s , "" ), [text_input ], [stored_message , text_input ]
0 commit comments