Skip to content

Commit 6fe4532

Browse files
committed
fix: correctly save and load model name in UI
1 parent f4f36b4 commit 6fe4532

File tree

2 files changed

+20
-18
lines changed

2 files changed

+20
-18
lines changed

src/utils/utils.py

+12-11
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def get_llm_model(provider: str, **kwargs):
182182

183183

184184
# Callback to update the model name dropdown based on the selected provider
185-
def update_model_dropdown(llm_provider, api_key=None, base_url=None):
185+
def update_model_dropdown(llm_provider, current_model_value, api_key=None, base_url=None):
186186
"""
187187
Update the model name dropdown with predefined models for the selected provider.
188188
"""
@@ -194,10 +194,13 @@ def update_model_dropdown(llm_provider, api_key=None, base_url=None):
194194
base_url = os.getenv(f"{llm_provider.upper()}_BASE_URL", "")
195195

196196
# Use predefined models for the selected provider
197-
if llm_provider in model_names:
198-
return gr.Dropdown(choices=model_names[llm_provider], value=model_names[llm_provider][0], interactive=True)
199-
else:
200-
return gr.Dropdown(choices=[], value="", interactive=True, allow_custom_value=True)
197+
choices = model_names[llm_provider]
198+
new_value = current_model_value
199+
200+
if not choices or current_model_value not in choices:
201+
new_value = choices[0]
202+
203+
return gr.update(choices=choices, value=new_value)
201204

202205

203206
class MissingAPIKeyError(Exception):
@@ -289,13 +292,11 @@ def register_component(self, name: str, component):
289292
self.component_order.append(name)
290293
return component
291294

292-
def save_current_config(self):
293-
"""Save the current configuration of all registered components."""
295+
def save_current_config(self, *component_values):
296+
"""Save the current configuration passed directly from Gradio's inputs."""
294297
current_config = {}
295-
for name in self.component_order:
296-
component = self.components[name]
297-
# Get the current value from the component
298-
current_config[name] = getattr(component, "value", None)
298+
for i, name in enumerate(self.component_order):
299+
current_config[name] = component_values[i]
299300

300301
return save_config_to_file(current_config)
301302

webui.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -1150,16 +1150,11 @@ def list_recordings(save_recording_path):
11501150
lines=2,
11511151
interactive=False
11521152
)
1153-
save_config_button.click(
1154-
fn=save_current_config,
1155-
inputs=[], # 不需要输入参数
1156-
outputs=[config_status]
1157-
)
11581153

11591154
# Attach the callback to the LLM provider dropdown
11601155
llm_provider.change(
1161-
lambda provider, api_key, base_url: update_model_dropdown(provider, api_key, base_url),
1162-
inputs=[llm_provider, llm_api_key, llm_base_url],
1156+
lambda provider, model_name, api_key, base_url: update_model_dropdown(provider, model_name, api_key, base_url),
1157+
inputs=[llm_provider, llm_model_name, llm_api_key, llm_base_url],
11631158
outputs=llm_model_name
11641159
)
11651160

@@ -1177,6 +1172,12 @@ def list_recordings(save_recording_path):
11771172
global webui_config_manager
11781173
all_components = webui_config_manager.get_all_components()
11791174

1175+
save_config_button.click(
1176+
fn=webui_config_manager.save_current_config,
1177+
inputs=all_components,
1178+
outputs=[config_status]
1179+
)
1180+
11801181
load_config_button.click(
11811182
fn=update_ui_from_config,
11821183
inputs=[config_file_input],

0 commit comments

Comments
 (0)