Spaces:
Running on CPU Upgrade

File size: 1,888 Bytes
5ea344f
7360ef0
 
 
f013f59
 
 
 
 
 
 
 
 
 
 
 
7360ef0
f013f59
 
7360ef0
f013f59
 
7360ef0
5ef8568
f013f59
7360ef0
 
f013f59
 
7360ef0
f013f59
 
7360ef0
f013f59
 
 
 
 
7360ef0
59c5924
f013f59
7360ef0
f013f59
7360ef0
f013f59
7360ef0
 
 
f013f59
 
 
7360ef0
f013f59
 
 
 
 
7360ef0
f013f59
59c5924
f013f59
5d7ac94
f013f59
 
 
02ae823
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import gradio as gr
from gradio_client import Client

# Display name shown in the UI -> identifier handed to gradio_client.Client.
MODELS = {
    "OLMo-2-1124-13B-Instruct": "akhaliq/olmo-anychat",
    "Llama-3.1-Tulu-3-8B": "akhaliq/allen-test",
}


def create_chat_fn(client):
    """Return a chat callback bound to *client*.

    The returned function has the ``(message, history)`` signature and
    forwards ``message`` to ``client.predict`` on the ``/chat`` endpoint with
    a fixed system prompt and fixed sampling settings. ``history`` is accepted
    but not sent along.
    """
    # Settings that are identical for every request made through this callback.
    fixed_settings = {
        "system_prompt": "You are a helpful AI assistant.",
        "temperature": 0.7,
        "max_new_tokens": 1024,
        "top_k": 40,
        "repetition_penalty": 1.1,
        "top_p": 0.95,
    }

    def chat(message, history):
        # Only the newest message is forwarded; `history` is intentionally unused here.
        return client.predict(message=message, **fixed_settings, api_name="/chat")

    return chat


def set_client_for_session(model_name, request: gr.Request):
    """Create a ``Client`` for the entry of ``MODELS`` selected by *model_name*.

    If the incoming request carries a non-empty ``x-ip-token`` header, it is
    passed on to the new client as ``X-IP-Token``; otherwise the client is
    created with no extra headers.
    NOTE(review): presumably the token lets the upstream app attribute usage
    to the end user rather than to this app - confirm.

    Raises:
        KeyError: if *model_name* is not a key of ``MODELS``.
    """
    forwarded_headers = {}

    # The request may be absent or may not wrap an object exposing headers.
    has_headers = request and hasattr(request, "request") and hasattr(request.request, "headers")
    if has_headers:
        ip_token = request.request.headers.get("x-ip-token")
        if ip_token:
            forwarded_headers["X-IP-Token"] = ip_token

    target = MODELS[model_name]
    return Client(target, headers=forwarded_headers)


def safe_chat_fn(message, history, client):
    """Answer *message* through *client*, or return an error text if there is none.

    When *client* is ``None`` a fixed error string is returned instead of
    raising; otherwise the call is delegated to the callback built by
    ``create_chat_fn``.
    """
    if client is not None:
        chat = create_chat_fn(client)
        return chat(message, history)
    return "Error: Client not initialized. Please refresh the page."


with gr.Blocks() as demo:
    # Model selected when the page first loads; shared by the dropdown's
    # initial value and the load-time client initialisation below.
    default_model = "OLMo-2-1124-13B-Instruct"

    # Holds the Client used by the chat callback; filled in on page load and
    # replaced whenever a different model is selected.
    client = gr.State()

    model_dropdown = gr.Dropdown(
        choices=list(MODELS), value=default_model, label="Select Model", interactive=True
    )

    chat_interface = gr.ChatInterface(fn=safe_chat_fn, additional_inputs=[client])

    # Update client when model changes
    def update_model(model_name, request: gr.Request):
        """Build a new client for *model_name* using the current request.

        The ``gr.Request`` annotation is required: Gradio injects the request
        only into parameters annotated with that type. Without it the handler
        is called with the dropdown value alone and fails on the missing
        ``request`` argument.
        """
        return set_client_for_session(model_name, request)

    model_dropdown.change(
        fn=update_model,
        inputs=[model_dropdown],
        outputs=[client],
    )

    # Initialize client on page load
    demo.load(
        fn=set_client_for_session,
        inputs=gr.State(default_model),
        outputs=client,
    )

# Launch the app only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    demo.launch()