import pathlib
import subprocess
from os import getenv

import streamlit as st
from dotenv import load_dotenv
from openai import AzureOpenAI

load_dotenv()
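# A hypothetical .env for reference -- these are the variable names the sidebar
# fields below read via getenv(); the values are placeholders, not defaults:
#
#   AZURE_OPENAI_ENDPOINT=https://<your-resource>.openai.azure.com/
#   AZURE_OPENAI_API_KEY=<your-api-key>
#   AZURE_OPENAI_API_VERSION=<api-version>
#   AZURE_OPENAI_GPT_MODEL=<your-gpt-deployment>
#   AZURE_AI_SPEECH_API_SUBSCRIPTION_KEY=<your-speech-key>
#   AZURE_AI_SPEECH_API_REGION=<your-region>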

# Initialize the session state
if "transcribed_result" not in st.session_state:
    st.session_state["transcribed_result"] = ""

with st.sidebar:
    inference_type = st.selectbox(
        label="INFERENCE_TYPE",
        options=[
            "azure",
            "local",
        ],
        key="INFERENCE_TYPE",
    )
    azure_ai_speech_api_language = st.selectbox(
        label="AZURE_AI_SPEECH_API_LANGUAGE",
        options=[
            "en-US",
            "ja-JP",
        ],
        key="AZURE_AI_SPEECH_API_LANGUAGE",
    )
    if inference_type == "local":
        path_to_model = st.text_input(
            label="PATH_TO_MODEL",
            value="./model",
            key="PATH_TO_MODEL",
            type="default",
        )
        stt_host = st.text_input(
            label="STT_HOST",
            value="ws://localhost:5000",
            key="STT_HOST",
            type="default",
        )
        st.warning("Local inference is not yet implemented.")
    if inference_type == "azure":
        azure_openai_endpoint = st.text_input(
            label="AZURE_OPENAI_ENDPOINT",
            value=getenv("AZURE_OPENAI_ENDPOINT"),
            key="AZURE_OPENAI_ENDPOINT",
            type="default",
        )
        azure_openai_api_key = st.text_input(
            label="AZURE_OPENAI_API_KEY",
            value=getenv("AZURE_OPENAI_API_KEY"),
            key="AZURE_OPENAI_API_KEY",
            type="password",
        )
        azure_openai_api_version = st.text_input(
            label="AZURE_OPENAI_API_VERSION",
            value=getenv("AZURE_OPENAI_API_VERSION"),
            key="AZURE_OPENAI_API_VERSION",
            type="default",
        )
        azure_openai_gpt_model = st.text_input(
            label="AZURE_OPENAI_GPT_MODEL",
            value=getenv("AZURE_OPENAI_GPT_MODEL"),
            key="AZURE_OPENAI_GPT_MODEL",
            type="default",
        )
        azure_ai_speech_api_subscription_key = st.text_input(
            label="AZURE_AI_SPEECH_API_SUBSCRIPTION_KEY",
            value=getenv("AZURE_AI_SPEECH_API_SUBSCRIPTION_KEY"),
            key="AZURE_AI_SPEECH_API_SUBSCRIPTION_KEY",
            type="password",
        )
        azure_ai_speech_api_region = st.text_input(
            label="AZURE_AI_SPEECH_API_REGION",
            value=getenv("AZURE_AI_SPEECH_API_REGION"),
            key="AZURE_AI_SPEECH_API_REGION",
            type="default",
        )
    "[Azure Portal](https://portal.azure.com/)"
    "[Azure OpenAI Studio](https://oai.azure.com/resource/overview)"
    "[View the source code](https://github.com/ks6088ts-labs/workshop-azure-openai/blob/main/apps/14_streamlit_azure_ai_speech/main.py)"


def is_configured():
    # The sidebar widgets above run at module scope on every Streamlit rerun,
    # so the names checked here always reflect the current sidebar state.
    if inference_type == "local":
        return path_to_model and stt_host
    if inference_type == "azure":
        return azure_openai_api_key and azure_openai_endpoint and azure_openai_api_version and azure_openai_gpt_model
    return False


st.title("transcribe text")

if not is_configured():
    st.warning("Please fill in the required fields in the sidebar.")

st.info("This is a sample app that transcribes speech to text.")

# ---
# 2-column layout

# 1st row
row1_left, row1_right = st.columns(2)
with row1_left:
    input_text = st.text_area(
        "Transcribed text",
        height=400,
        placeholder="Please enter the text to transcribe.",
        key="input",
        value=st.session_state["transcribed_result"],
    )
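    # The text area is seeded from st.session_state["transcribed_result"],
    # which the stop handler at the bottom of this file fills in before
    # calling st.rerun(), so a finished transcription appears here.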

with row1_right:
    start_transcribe_button = st.button("start", disabled=not is_configured())
    stop_transcribe_button = st.button("stop", disabled=not is_configured())
    transcription_status = st.empty()

# Horizontal rule
st.markdown("---")

# 2nd row
row2_left, row2_right = st.columns(2)

with row2_left:
    selected_task = st.selectbox(
        "Task",
        [
            "Create summaries from the following text",
            "Extract 3 main points from the following text",
            # Add more tasks here
        ],
        key="selected_task",
        index=0,
    )

with row2_right:
    run_task_button = st.button("run_task", disabled=not is_configured())

path_to_transcribed_text = ".transcribed.txt"

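# File-based handshake with the worker script (inferred from the code below):
# speech_to_text.py writes recognized text to `.transcribed.txt`, and stopping
# is signalled by touching a `.stop` file that the worker presumably polls for.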
def start_recognition():
    # Launch the speech-to-text worker as a background process and keep a
    # handle to it in the session state so repeated clicks can be detected.
    # NOTE: shell=True interpolates the sidebar values into the command verbatim.
    if inference_type == "local":
        command = f"python apps/14_streamlit_azure_ai_speech/speech_to_text.py --output {path_to_transcribed_text} --endpoint {stt_host} --language {azure_ai_speech_api_language} --type local --verbose"  # noqa
        st.session_state["process"] = subprocess.Popen(command, shell=True)
        st.warning("Local inference is not yet implemented.")
        return
    if inference_type == "azure":
        command = f"python apps/14_streamlit_azure_ai_speech/speech_to_text.py --output {path_to_transcribed_text} --subscription {azure_ai_speech_api_subscription_key} --region {azure_ai_speech_api_region} --language {azure_ai_speech_api_language} --type azure --verbose"  # noqa
        st.session_state["process"] = subprocess.Popen(command, shell=True)
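# For reference, the azure branch above spawns a shell command shaped like this
# (placeholder values, not real credentials):
#
#   python apps/14_streamlit_azure_ai_speech/speech_to_text.py \
#       --output .transcribed.txt --subscription <key> --region <region> \
#       --language en-US --type azure --verbose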


def run_task(selected_task: str, input_text: str) -> str | None:
    # Forward the selected task and the transcribed text to the chat model.
    if inference_type == "local":
        st.warning("Local inference is not yet implemented.")
        return None
    if inference_type == "azure":
        client = AzureOpenAI(
            api_key=azure_openai_api_key,
            api_version=azure_openai_api_version,
            azure_endpoint=azure_openai_endpoint,
        )

        response = client.chat.completions.create(
            model=azure_openai_gpt_model,
            messages=[
                {
                    "role": "system",
                    "content": f"""
                    Task: {selected_task}.
                    ---
                    {input_text}
                    ---
                    """,
                },
            ],
        )
        return response.choices[0].message.content
    raise ValueError(f"Inference type is not supported: {inference_type}")
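# Example call (hypothetical input), matching the first Task selectbox entry:
#
#   summary = run_task(
#       selected_task="Create summaries from the following text",
#       input_text="...transcribed text...",
#   )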


def load_transcribed_text():
    # Read whatever the worker script has written to the shared output file.
    with open(path_to_transcribed_text) as f:
        return f.read()


if start_transcribe_button:
    if not st.session_state.get("process"):
        transcription_status.info(f"Transcribing... (language={azure_ai_speech_api_language})")
        start_recognition()
    else:
        transcription_status.warning("Transcription is already running.")

if stop_transcribe_button:
    # Touch the stop-sentinel file (the worker script is expected to watch for
    # it), clear the process handle, and pull the transcript into the UI.
    pathlib.Path(".stop").touch()
    st.session_state["process"] = None
    output = load_transcribed_text()
    st.session_state.transcribed_result = output
    st.rerun()

if run_task_button:
    with st.spinner("Running..."):
        output = run_task(
            selected_task=selected_task,
            input_text=input_text,
        )
    st.write(output)