diff --git a/Assets/Editor/AICommandWindow.cs b/Assets/Editor/AICommandWindow.cs index 9c22739..c3ab160 100644 --- a/Assets/Editor/AICommandWindow.cs +++ b/Assets/Editor/AICommandWindow.cs @@ -35,7 +35,7 @@ static string WrapPrompt(string input) void RunGenerator() { - var code = OpenAIUtil.InvokeChat(WrapPrompt(_prompt)); + var code = OpenAIUtil.InvokeChat(WrapPrompt(_prompt), _modelType); Debug.Log("AI command script:" + code); CreateScriptAsset(code); } @@ -45,6 +45,8 @@ void RunGenerator() #region Editor GUI string _prompt = "Create 100 cubes at random points."; + + ModelType _modelType = ModelType.gpt_3_5_turbo; const string ApiKeyErrorText = "API Key hasn't been set. Please check the project settings " + @@ -61,6 +63,11 @@ void OnGUI() if (IsApiKeyOk) { _prompt = EditorGUILayout.TextArea(_prompt, GUILayout.ExpandHeight(true)); + _modelType = (ModelType)EditorGUILayout.EnumPopup("Model To Use", _modelType); + + if (_modelType == ModelType.gpt_4) + EditorGUILayout.HelpBox("Ensure you have approved API access to GPT-4, or the command will fail!", MessageType.Info); + if (GUILayout.Button("Run")) RunGenerator(); } else diff --git a/Assets/Editor/OpenAIUtil.cs b/Assets/Editor/OpenAIUtil.cs index 95513c1..ef2e4d3 100644 --- a/Assets/Editor/OpenAIUtil.cs +++ b/Assets/Editor/OpenAIUtil.cs @@ -2,55 +2,78 @@ using UnityEditor; using UnityEngine.Networking; -namespace AICommand { - -static class OpenAIUtil +namespace AICommand { - static string CreateChatRequestBody(string prompt) + public enum ModelType { - var msg = new OpenAI.RequestMessage(); - msg.role = "user"; - msg.content = prompt; - - var req = new OpenAI.Request(); - req.model = "gpt-3.5-turbo"; - req.messages = new [] { msg }; + gpt_4, + gpt_3_5_turbo + }; - return JsonUtility.ToJson(req); + //Dictionary for the ModelType to the string used by the API + static class ModelTypeDict + { + public static string GetModelTypeString(ModelType modelType) + { + switch (modelType) + { + case ModelType.gpt_4: + return "gpt-4"; + default: + return "gpt-3.5-turbo"; + } + } } - public static string InvokeChat(string prompt) + static class OpenAIUtil { - var settings = AICommandSettings.instance; - // POST - using var post = UnityWebRequest.Post - (OpenAI.Api.Url, CreateChatRequestBody(prompt), "application/json"); - - // Request timeout setting - post.timeout = settings.timeout; + static string CreateChatRequestBody(string prompt, ModelType modelType) + { + var msg = new OpenAI.RequestMessage(); + msg.role = "user"; + msg.content = prompt; - // API key authorization - post.SetRequestHeader("Authorization", "Bearer " + settings.apiKey); + var req = new OpenAI.Request(); + req.model = ModelTypeDict.GetModelTypeString(modelType); + req.messages = new[] { msg }; - // Request start - var req = post.SendWebRequest(); + return JsonUtility.ToJson(req); + } - // Progress bar (Totally fake! Don't try this at home.) - for (var progress = 0.0f; !req.isDone; progress += 0.01f) + public static string InvokeChat(string prompt, ModelType _modelType) { - EditorUtility.DisplayProgressBar - ("AI Command", "Generating...", progress); - System.Threading.Thread.Sleep(100); - progress += 0.01f; - } - EditorUtility.ClearProgressBar(); + var settings = AICommandSettings.instance; - // Response extraction - var json = post.downloadHandler.text; - var data = JsonUtility.FromJson(json); - return data.choices[0].message.content; + // POST + using var post = UnityWebRequest.Post + (OpenAI.Api.Url, CreateChatRequestBody(prompt, _modelType), "application/json"); + + // Request timeout setting + post.timeout = settings.timeout; + + // API key authorization + post.SetRequestHeader("Authorization", "Bearer " + settings.apiKey); + + // Request start + var req = post.SendWebRequest(); + + // Progress bar (Totally fake! Don't try this at home.) + for (var progress = 0.0f; !req.isDone; progress += 0.01f) + { + EditorUtility.DisplayProgressBar + ("AI Command", "Generating...", progress); + System.Threading.Thread.Sleep(100); + progress += 0.01f; + } + + EditorUtility.ClearProgressBar(); + + // Response extraction + var json = post.downloadHandler.text; + var data = JsonUtility.FromJson(json); + return data.choices[0].message.content; + } } } - -} // namespace AICommand +// namespace AICommand