Skip to content

Commit cc2e048

Browse files
committed
fix: correctly save and load model name in UI
1 parent 7fdf95e commit cc2e048

File tree

2 files changed

+38
-25
lines changed

2 files changed

+38
-25
lines changed

src/utils/utils.py

+7-16
Original file line numberDiff line numberDiff line change
@@ -182,22 +182,13 @@ def get_llm_model(provider: str, **kwargs):
182182

183183

184184
# Callback to update the model name dropdown based on the selected provider
185-
def update_model_dropdown(llm_provider, api_key=None, base_url=None):
186-
"""
187-
Update the model name dropdown with predefined models for the selected provider.
188-
"""
189-
import gradio as gr
190-
# Use API keys from .env if not provided
191-
if not api_key:
192-
api_key = os.getenv(f"{llm_provider.upper()}_API_KEY", "")
193-
if not base_url:
194-
base_url = os.getenv(f"{llm_provider.upper()}_BASE_URL", "")
195-
196-
# Use predefined models for the selected provider
197-
if llm_provider in model_names:
198-
return gr.Dropdown(choices=model_names[llm_provider], value=model_names[llm_provider][0], interactive=True)
199-
else:
200-
return gr.Dropdown(choices=[], value="", interactive=True, allow_custom_value=True)
185+
def update_model_dropdown(selected_provider, current_model):
186+
"""Update model choices based on provider."""
187+
choices = model_names[selected_provider]
188+
189+
if current_model not in choices and any(current_model in models for models in model_names.values()):
190+
current_model = choices[0]
191+
return gr.update(choices=choices, value=current_model)
201192

202193

203194
class MissingAPIKeyError(Exception):

webui.py

+31-9
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import gradio as gr
1616
import inspect
17-
from functools import wraps
17+
from functools import partial, wraps
1818

1919
from browser_use.agent.service import Agent
2020
from playwright.async_api import async_playwright
@@ -49,37 +49,59 @@
4949
webui_config_manager = utils.ConfigManager()
5050

5151

52+
def sync_component_value_to_manager(component_registered_name, new_value):
53+
"""Sync the value of a component to the config manager"""
54+
global webui_config_manager
55+
if webui_config_manager:
56+
component_object = webui_config_manager.components.get(component_registered_name)
57+
if component_object:
58+
current_manager_value = getattr(component_object, "value", None)
59+
if current_manager_value != new_value:
60+
component_object.value = new_value
61+
return None
62+
5263
def scan_and_register_components(blocks):
5364
"""扫描一个 Blocks 对象并注册其中的所有交互式组件,但不包括按钮"""
5465
global webui_config_manager
66+
component_map = {}
5567

5668
def traverse_blocks(block, prefix=""):
69+
nonlocal component_map
5770
registered = 0
58-
5971
# 处理 Blocks 自身的组件
6072
if hasattr(block, "children"):
6173
for i, child in enumerate(block.children):
74+
name = None
75+
is_eligible_for_config = False
6276
if isinstance(child, gr.components.Component):
63-
# 排除按钮 (Button) 组件
64-
if getattr(child, "interactive", False) and not isinstance(child, gr.Button):
77+
# 排除按钮 (Button/File) 组件
78+
if getattr(child, "interactive", False) and not isinstance(child, gr.Button) and not isinstance(child, gr.File):
79+
is_eligible_for_config = True
6580
name = f"{prefix}component_{i}"
6681
if hasattr(child, "label") and child.label:
6782
# 使用标签作为名称的一部分
6883
label = child.label
6984
name = f"{prefix}{label}"
70-
logger.debug(f"Registering component: {name}")
71-
webui_config_manager.register_component(name, child)
72-
registered += 1
7385
elif hasattr(child, "children"):
7486
# 递归处理嵌套的 Blocks
7587
new_prefix = f"{prefix}block_{i}_"
7688
registered += traverse_blocks(child, new_prefix)
7789

90+
if is_eligible_for_config and name:
91+
webui_config_manager.register_component(name, child)
92+
component_map[name] = child
93+
registered += 1
7894
return registered
7995

8096
total = traverse_blocks(blocks)
8197
logger.info(f"Total registered components: {total}")
8298

99+
# Register the components with the config manager
100+
for name, component_obj in component_map.items():
101+
sync_handler = partial(sync_component_value_to_manager, name)
102+
if hasattr(component_obj, 'change'):
103+
component_obj.change(fn=sync_handler, inputs=[component_obj], outputs=None)
104+
83105

84106
def save_current_config():
85107
return webui_config_manager.save_current_config()
@@ -1158,8 +1180,8 @@ def list_recordings(save_recording_path):
11581180

11591181
# Attach the callback to the LLM provider dropdown
11601182
llm_provider.change(
1161-
lambda provider, api_key, base_url: update_model_dropdown(provider, api_key, base_url),
1162-
inputs=[llm_provider, llm_api_key, llm_base_url],
1183+
fn=utils.update_model_dropdown,
1184+
inputs=[llm_provider, llm_model_name],
11631185
outputs=llm_model_name
11641186
)
11651187

0 commit comments

Comments
 (0)