From beab4049ba1d2ea47ac35c543d959245e78d2a5c Mon Sep 17 00:00:00 2001 From: giulis13 Date: Mon, 25 May 2026 23:07:35 +0200 Subject: [PATCH 1/4] Add llama.cpp LLM backend for LLM player with JSON action enforcing --- python/llmplayer.py | 22 ++-- python/rlc/llm_runner.py | 264 +++++++++++++++++++++++++++++++++++++-- python/rlc/utils.py | 19 +++ run-requirements.txt | 1 + stdlib/regex.rl | 176 ++++++++++++++++++++++++++ 5 files changed, 465 insertions(+), 17 deletions(-) create mode 100644 python/rlc/utils.py create mode 100644 stdlib/regex.rl diff --git a/python/llmplayer.py b/python/llmplayer.py index b91fd222..7383e57c 100644 --- a/python/llmplayer.py +++ b/python/llmplayer.py @@ -13,14 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from os import devnull +from pathlib import Path + from command_line import ( + get_included_conents_from_args, load_program_from_args, make_rlc_argparse, - get_included_conents_from_args, ) -from rlc import State, Program, make_llm, run_game -from sys import stdout -from os import devnull +from rlc import make_llm, run_game def main(): @@ -35,9 +36,7 @@ def main(): parser.add_argument("--trace-output", type=str, default="-", nargs="?") parser.add_argument( "--gemini-stateless", - type=bool, - default=True, - nargs="?", + action="store_true", help="Use gemini but send only the current state, and do not keep track of past knowledge", ) parser.add_argument( @@ -48,23 +47,26 @@ def main(): parser.add_argument( "--ollama-local", action="store_true", help="Use ollama locally" ) + parser.add_argument("--llamacpp", action="store_true", help="Use llama.cpp") args = parser.parse_args() output = open(args.output, "w+") if args.output != "-" else open(devnull, "w") trace_output = open(args.trace_output, "w+") if args.trace_output != "-" else open(devnull, "w") rules = get_included_conents_from_args(args) - with load_program_from_args(args, optimize=True) as program: - llm = make_llm(args, program) + game_name = Path(args.source_file).stem + with load_program_from_args(args, optimize=True, extra_source_files=["stdlib/regex.rl"]) as program: + llm = make_llm(args, program, game_name) for action, thought in run_game( llm=llm, + game_name=game_name, program=program, rules=rules, output=output, trace_output=trace_output, ): print(thought) - print(program.to_string(action)) + print(program.module.to_string(action)) if __name__ == "__main__": diff --git a/python/rlc/llm_runner.py b/python/rlc/llm_runner.py index d3884092..11d9a315 100644 --- a/python/rlc/llm_runner.py +++ b/python/rlc/llm_runner.py @@ -1,6 +1,18 @@ -from rlc import State, Program -from sys import stdout +import json +import os +import re +import sys +import tempfile +from collections import defaultdict from random import choice +from sys import stdout +from textwrap import dedent + +import openai +import xgrammar as xgr + +from rlc import Program, State +from rlc.utils import rl_string_to_python, rl_vector_of_strings_to_python class Ollama: @@ -10,7 +22,7 @@ def __init__(self, program: Program): self.contexts = [None for x in range(program.module.get_num_players())] self.generate = generate - def chat(self, message: str, player_id: int) -> str: + def chat(self, message: str, player_id: int, *args, **kwargs) -> str: answer = self.generate( model="deepseek-r1:14b", prompt=message, context=self.contexts[player_id] ) @@ -34,7 +46,7 @@ def __init__(self, program: Program, model="gemini-2.0-flash"): for x in range(program.module.get_num_players()) ] - def chat(self, message: str, player_id: int) -> str: + def chat(self, message: str, player_id: int, *args, **kwargs) -> str: response = self.chats[player_id].send_message( f"You are player {player_id}. Notice the game code may imply that your id is mapped onto other numbers in the game state. " + message @@ -56,7 +68,7 @@ def __init__(self, program: Program, model="gemini-2.5-flash-preview-04-17", fir ] self.first_messages = [first_message for x in range(program.module.get_num_players())] - def chat(self, message: str, player_id: int) -> str: + def chat(self, message: str, player_id: int, *args, **kwargs) -> str: from google import genai if self.first_messages[player_id] == None: self.first_messages[player_id] = message @@ -74,6 +86,81 @@ def chat(self, message: str, player_id: int) -> str: return response.text +class LlamaCpp: + def __init__(self, program: Program, game_name: str, model_name: str = "model"): + self.client = openai.Client(api_key="...", base_url="http://localhost:8000/v1") + self.model_name = model_name + self.program = program + self.chats = [ + [{ + "role": "system", + "content": (f"You are an agent that plays in a reinforcement learning enviroment, you're player id is {x}, the game is {game_name}. Your goal is to win the game, by selecting the best action (among the legal ones) at each turn.\n" + "INSTRUCTIONS:\n- First, reason about your strategy (max 300 tokens).\n- Once you decided, write the chosen action.\nOutput only the action you choose, nothing else, no explanations, no comments, just the action.\n") + }] + for x in range(program.module.get_num_players()) + ] + + def chat(self, message: str, player_id: int, state: State) -> str: + chat_messages = self.chats[player_id] + self.chats[player_id].append({"role": "user", "content": message}) + + max_turns_in_history = 10 + previous_turns_in_history = 0 + for i in range(1, len(chat_messages)-1, 2): + if chat_messages[i]["role"] == "system": + break # found sys error of current turn, stop counting + previous_turns_in_history += 1 + if previous_turns_in_history > max_turns_in_history: + # keep system message and last max_turns_in_history interactions + n_removed_interactions = previous_turns_in_history - max_turns_in_history + chat_messages = [chat_messages[0]] + \ + [{"role": "system", "content": f"Previous {n_removed_interactions} interactions removed to keep the context within the limit."}] +\ + chat_messages[n_removed_interactions * 2 + 1:] + + + # do a first call just to let the model reason about the current state and decide what action to take + # no action is expected to be output, as we can't constrain the output yet + chat_response_reasoning = self.client.chat.completions.create( + model=self.model_name, + messages=chat_messages, + max_tokens=768 + 23, # reasononing budget tokens + reasoning budget message tokens + ) + if hasattr(chat_response_reasoning.choices[0].message, "reasoning_content"): + reasoning_str = chat_response_reasoning.choices[0].message.reasoning_content.strip() + start_reasoning_str = "<|channel>thought\n" + end_reasoning_str = "" + + # prepare grammar from regex to contrain the action output of the model + allowed_actions_regex = to_regex(self.program.module, state.legal_actions_indicies) + g = xgr.Grammar.from_regex(allowed_actions_regex) + gbnf = str(g) + extra_body = { + "grammar": gbnf, + "chat_template_kwargs": { + "enable_thinking": False, + } + } + self.chats[player_id].append({"role": "assistant", "content": start_reasoning_str + reasoning_str + end_reasoning_str}) + chat_response = self.client.chat.completions.create( + model=self.model_name, + messages=chat_messages, + max_tokens=1500, + extra_body=extra_body, + ) + answer = chat_response.choices[ + 0 + ].message.content.strip() + # remove the reasoning from the chat history + if hasattr(chat_response_reasoning.choices[0].message, "reasoning_content"): + self.chats[player_id].pop() + else: + # sometimes the model forgets to reason and just outputs the action + answer = chat_response_reasoning.choices[0].message.content.strip() + + self.chats[player_id].append({"role": "assistant", "content": answer}) + + return answer + def extract_index(string: str): position = string.rfind("action:") if position == -1: @@ -115,17 +202,22 @@ def solve_randomness(program: Program, state: State, trace_output): yield (action, "") -def make_llm(args, program): +def make_llm(args, program, game_name: str): if args.ollama_local: return Ollama(program) if args.gemini_statefull: return Gemini(program) if args.gemini_stateless: return GeminiStateless(program) + if args.llamacpp: + return LlamaCpp(program, game_name) return None -def run_game(llm, program: Program, rules: str, output=stdout, trace_output=stdout): +def run_game(llm, game_name: str, program: Program, rules: str, output=stdout, trace_output=stdout): + if isinstance(llm, LlamaCpp): + yield from run_game_with_llamacpp(llm, game_name, program, rules, output, trace_output) + return prompt_message = "The following is the current state, follwed by the actions you can take. Terminate your message with the number of the action you want to take, with the following sintax ACTION: INDEX. Explain your decisions." num_players = program.module.get_num_players() state = program.start() @@ -172,3 +264,161 @@ def run_game(llm, program: Program, rules: str, output=stdout, trace_output=stdo "FINAL SCORE: " + str([program.module.score(state.state, x) for x in range(num_players)]) ) + +def run_game_with_llamacpp(llm, game_name: str, program: Program, rules: str, output=stdout, trace_output=stdout): + prompt_message = "The following is the current state, followed by the actions you can take.\n" + num_players = program.module.get_num_players() + state = program.start() + for x in solve_randomness(program, state, trace_output): + yield x + + output.write(rules + "\n") + # for x in range(num_players): + # message = ( + # f"Here are the rules of the game: read them carefully. You will be prompted to play a game as player {x} against the opponent.\n```" + # + rules + "\n```\n\nSummarize the rules (be concise) and formulate a strategy to play the game. Keep it short." + # ) + # answer_to_rules = llm.chat(message=message, player_id=x) + # output.write(answer_to_rules) + + action_format_str = dedent(""" + ```json + { + "action_name": "$ACTION_NAME", + "parameters": { + "$PARAM_NAME1_FLOAT": $VALUE1, + "$PARAM_NAME2_STR": "$VALUE2", + "$PARAM_NAME3_BOOL": $VALUE3 + } + } + ```""") + + output.write(f"starting game {game_name}\n") + turn = 0 + + while not state.is_done(): + current_player = program.module.get_current_player(state.state) + output.write(f"---------- TURN {turn}, PLAYER {current_player} ----------\n") + output.write(capture_stdout(state.pretty_print) + "\n") + + output.write("CURRENT_PLAYER " + str(current_player) + "\n") + state_str = capture_stdout(state.pretty_print) + message = prompt_message + "\nCURRENT STATE:\n" + state_str + "\n" + message += "\nLEGAL ACTIONS:\n" + message += json.dumps(list(map(json.loads, rl_vector_of_strings_to_python(program.module.describe_actions()))), indent=4) + "\n" + + message += f"\nSelect your action by answering with one of above actions, using the format:\n{action_format_str}" + + output.write(message + "\n") + output.flush() + + answer = llm.chat(message=message, player_id=current_player, state=state) + action_index = get_action_index_from_llamacpp_answer(answer, state) + output.write(answer + "\n") + output.flush() + n_attempts = 1 + max_attempts = 20 + while action_index == -1 or not state.can_apply(state.actions[action_index]): + n_attempts += 1 + if n_attempts > max_attempts: + raise Exception(f"LLM failed to provide a valid action after {max_attempts} attempts, aborting the game.") + error_msg = "Failed to apply action, " + if action_index == -1: + error_msg += f"unable to parse answer. Please answer with format:\n{action_format_str}" + else: + error_msg += "the action you selected is not legal in the current state." + output.write(error_msg + "\n") + llm.chats[current_player].append({"role": "system", "content": error_msg}) + answer = llm.chat(message=message, player_id=current_player, state=state) + output.write(answer + "\n") + output.flush() + action_index = get_action_index_from_llamacpp_answer(answer, state) + + action = state.actions[action_index] + trace_output.write(str(action) + "\n") + trace_output.flush() + + output.write(f"player {current_player} chose action {action_index}: {str(action).strip()} ({n_attempts} attempts)\n") + output.flush() + + state.step(action) + yield (action, answer) + for x in solve_randomness(program, state, trace_output): + yield x + + if n_attempts > 1: + # clear wrong attempts from the chat history + for _ in range(3 * (n_attempts - 1)): # for each failed attempt, remove system, user and assistant messages + llm.chats[current_player].pop(-3) # leave last two messages there + turn += 1 + + output.write(f"game {game_name} ended\n") + output.write( + "FINAL SCORE: " + + str([program.module.score(state.state, x) for x in range(num_players)]) + ) + output.flush() + + +def capture_stdout(callable, *args, **kwargs) -> str: + original_stdout_fd = os.dup(sys.stdout.fileno()) + with tempfile.TemporaryFile(mode="w+b") as tmp: + try: + # Redirect stdout to the temporary file + os.dup2(tmp.fileno(), sys.stdout.fileno()) + + callable(*args, **kwargs) + sys.stdout.flush() + + # Seek to the beginning and read + tmp.seek(0) + output_str = tmp.read().decode() + return output_str + finally: + # Always restore original stdout, even if error happens + os.dup2(original_stdout_fd, sys.stdout.fileno()) + + +def to_regex(program_module, legal_actions_indicies): + if isinstance(legal_actions_indicies, list): + # actions = [rl_string_to_python(program_module.to_regex(i)) for i in legal_actions_indicies] + # regex = "(" + "|".join(set(actions)) + ")" + actions_json_list = rl_vector_of_strings_to_python(program_module.describe_actions()) + actions_regex_list = [] + for action_json in actions_json_list: + action = json.loads(action_json) + params = action['parameters_description'] + new_line = ",\n " + action_regex = dedent(f""" + ```json + {{ + "action_name": "{action['action_name']}", + "parameters": {{ + {new_line.join(f'"{param["name"]}": {param["regex"]}' for param in params)} + }} + }} + ``` + """).lstrip().replace("{", r"\{").replace("}", r"\}") + actions_regex_list.append(action_regex) + regex = "(" + "|".join(set(actions_regex_list)) + ")" + return regex + raise ValueError(f"unsupported type for legal_actions_indicies: {type(legal_actions_indicies)}") + +def get_action_index_from_llamacpp_answer(answer: str, state) -> int: + answer_json = json.loads(answer.split("```json")[-1].split("```")[0]) + chosen_action_name = answer_json["action_name"] + chosen_action_params = answer_json.get("parameters", {}) + try: + # find the index of what action was chosen by the LLM + action_index = -1 + for i, action_i in enumerate(state.actions): + action_i_name = str(action_i).split("{")[0].strip() + action_i_params_str = (re.search("({.*})", str(action_i)).group(0)) + action_i_params_json_str = re.sub(r"([\w_]+):", r'"\1":', action_i_params_str) # add double quotes + action_i_params = json.loads(action_i_params_json_str) + if action_i_name == chosen_action_name and action_i_params == chosen_action_params: + action_index = i + break + except ValueError: + action_index = -1 + return action_index \ No newline at end of file diff --git a/python/rlc/utils.py b/python/rlc/utils.py new file mode 100644 index 00000000..d497979b --- /dev/null +++ b/python/rlc/utils.py @@ -0,0 +1,19 @@ +def rl_string_to_python(rl_string) -> str: + """ + Convert an RL string representation to a Python string. + """ + python_string = "".join( + [chr(rl_string.get(i).contents.value) for i in range(rl_string._data._size)] + )[:-1] + return python_string + +def rl_vector_of_strings_to_python(vec) -> list[str]: + """ + Convert an RL vector of strings to a list of Python strings. + """ + result: list[str] = [] + for i in range(vec.size()): + s_ptr = vec.get(i) + s = s_ptr.contents + result.append(rl_string_to_python(s)) + return result \ No newline at end of file diff --git a/run-requirements.txt b/run-requirements.txt index 03529aa4..2029e9c4 100644 --- a/run-requirements.txt +++ b/run-requirements.txt @@ -4,3 +4,4 @@ gym3 tensorboard matplotlib standard-imghdr; python_version >= '3.13' +openai~=2.2 diff --git a/stdlib/regex.rl b/stdlib/regex.rl new file mode 100644 index 00000000..2c1c3c72 --- /dev/null +++ b/stdlib/regex.rl @@ -0,0 +1,176 @@ +import string +import bounded_arg + +import collections.vector +import enum_utils # for s(name) if not already in scope + +import serialization.print +import range + + +cls ParamInfo: + String type_name + String regex + + +fun describe_actions() -> Vector: + let any_action : AnyGameAction + return describe_actions_schema(any_action) + + +fun describe_actions_schema(AllActionsVariant variant) -> Vector: + let out : Vector + + # Case 1: we have a union like AnyGameAction + if variant is Alternative: + # print("variant is Alternative"s) + for alt_name, alt_field of variant: + # alt_name is a StringLiteral with the lowercase action name + # Get the actual type name using get_type_name() method + let action_name : String + if alt_field is CustomGetTypeName: + action_name.append(alt_field.get_type_name()) + else: + action_name.append(alt_name) + + let line = describe_single_action(alt_field, action_name) + out.append(line) + return out + + # Case 2: not a union, just a single action type – treat it as one action + let dummy : AllActionsVariant + let line = describe_single_action(dummy, "Action"s) + out.append(line) + return out + +# produces a JSON string describing the action schema, e.g.: +# { +# "action_name": "$ACTION_NAME", +# "parameters_description": [ +# { +# "name": "$parameter_name1", +# "type": "$parameter_type1", +# "regex": "$VALUE1" +# }, +# { +# "name": "$parameter_name2", +# "type": "$parameter_type2", +# "regex": "$VALUE2" +# } +# ] +# } +fun describe_single_action(ActionType action, String action_name) -> String: + let result : String + result.append("{\n") + result.append(" \"action_name\": \""s + action_name + "\",\n"s) + result.append(" \"parameters_description\": [\n"s) + + let first = true + for field_name, field of action: + if !first: + result.append(",\n") + first = false + + let info = describe_param(field) + result.append(" {\n"s) + result.append(" \"name\": \""s + s(field_name) + "\",\n"s) # convert field_name to String + result.append(" \"type\": \""s + info.type_name + "\",\n"s) # get the type of the field + result.append(" \"regex\": \""s + info.regex + "\"\n"s) + result.append(" }"s) + + result.append("\n ]\n"s) + result.append("}") + return result + +fun ensure_describe_param_implementations_are_instantiated() -> Int: + # Int + let x_int : Int + let _ = describe_param(x_int) + + # BInt TODO find a better way to do this + let x : BInt<0, 3> + let _ = describe_param(x) + let x : BInt<0, 7> + let _ = describe_param(x) + let x : BInt<0, 9> + let _ = describe_param(x) + let x : BInt<0, 10> + let _ = describe_param(x) + let x : BInt<0, 14> + let _ = describe_param(x) + let x : BInt<0, 52> + let _ = describe_param(x) + let x : BInt<1, 10> + let _ = describe_param(x) + + # ... + return 0 + + +fun describe_param(Int x) -> ParamInfo: + let info : ParamInfo + info.type_name = ""s + info.type_name.append("Int"s) + info.regex = ""s + info.regex.append("\\d"s) + return info + + +# For BInt, valid values are from min to max-1 +fun describe_param(BInt x) -> ParamInfo: + let info : ParamInfo + info.type_name = ""s + info.type_name.append("BInt<"s) + info.type_name.append(to_string(min)) + info.type_name.append(","s) + info.type_name.append(to_string(max)) + info.type_name.append(">"s) + + info.regex = ""s + info.regex.append("["s) + info.regex.append(to_string(min)) + info.regex.append("-"s) + info.regex.append(to_string(max - 1)) + info.regex.append("]"s) + return info + + + +# Fallback: anything we don't know how to serialize → ".*" +fun describe_param(T x) -> ParamInfo: + let info : ParamInfo + info.type_name = ""s + info.type_name.append("unknown"s) + info.regex = ""s + info.regex.append(".*"s) + return info + + +fun to_regex(Bool obj) -> String: + return "(true|false)"s + +fun to_regex(Float obj) -> String: + return "\d+\.\d+"s + +fun to_regex(Int obj) -> String: + return "\\d+"s + +fun to_regex(T obj) -> String: + return "unknown"s + +fun ensure_to_regex_is_instantiated() -> Int: + let _ = to_regex(true) + let _ = to_regex(123) + let _ = to_regex(10.123) + let _ = to_regex(""s) + let _ = to_regex([1, 2, 3]) + let x : Int[10] + let _ = to_regex(x) + let y : Float[10] + let _ = to_regex(y) + return 0 + +# fun main() -> Int: +# let r = test_describe_bint() +# print("result = "s + to_string(r)) +# return 0 \ No newline at end of file From 7da859fe042fadcd22a0169d4336ad1869848a06 Mon Sep 17 00:00:00 2001 From: giulis13 Date: Sun, 31 May 2026 23:55:51 +0200 Subject: [PATCH 2/4] Clean up code and add some comments --- python/rlc/llm_runner.py | 62 +++++++++++++++++++++++----------------- stdlib/regex.rl | 49 +++++++++---------------------- 2 files changed, 49 insertions(+), 62 deletions(-) diff --git a/python/rlc/llm_runner.py b/python/rlc/llm_runner.py index 11d9a315..4445be22 100644 --- a/python/rlc/llm_runner.py +++ b/python/rlc/llm_runner.py @@ -91,19 +91,28 @@ def __init__(self, program: Program, game_name: str, model_name: str = "model"): self.client = openai.Client(api_key="...", base_url="http://localhost:8000/v1") self.model_name = model_name self.program = program + # init chat histories for each player self.chats = [ [{ "role": "system", "content": (f"You are an agent that plays in a reinforcement learning enviroment, you're player id is {x}, the game is {game_name}. Your goal is to win the game, by selecting the best action (among the legal ones) at each turn.\n" "INSTRUCTIONS:\n- First, reason about your strategy (max 300 tokens).\n- Once you decided, write the chosen action.\nOutput only the action you choose, nothing else, no explanations, no comments, just the action.\n") - }] + }] # this system prompt is generic because it's supposed to work for any game, the specific info it contains are player_id and game_name for x in range(program.module.get_num_players()) ] def chat(self, message: str, player_id: int, state: State) -> str: + """Interacts with the LLM to get the action to play given the current state. + + The interaction is done in two steps: + 1. We first send the message to the model and let it reason about the current state + 2. We then send the message again, this time with a grammar that constrains the output of the model to be one of the possible actions + """ chat_messages = self.chats[player_id] self.chats[player_id].append({"role": "user", "content": message}) + ### CLEAN HISTORY FROM OLD MESSAGES ### + # to avoid hitting the context window limit, we keep in the history only the last max_turns_in_history interactions max_turns_in_history = 10 previous_turns_in_history = 0 for i in range(1, len(chat_messages)-1, 2): @@ -118,6 +127,7 @@ def chat(self, message: str, player_id: int, state: State) -> str: chat_messages[n_removed_interactions * 2 + 1:] + ### 1. LET THE MODEL REASON ### # do a first call just to let the model reason about the current state and decide what action to take # no action is expected to be output, as we can't constrain the output yet chat_response_reasoning = self.client.chat.completions.create( @@ -126,12 +136,13 @@ def chat(self, message: str, player_id: int, state: State) -> str: max_tokens=768 + 23, # reasononing budget tokens + reasoning budget message tokens ) if hasattr(chat_response_reasoning.choices[0].message, "reasoning_content"): + ### 2. CONSTRAIN THE MODEL OUTPUT WITH A GRAMMAR ### reasoning_str = chat_response_reasoning.choices[0].message.reasoning_content.strip() start_reasoning_str = "<|channel>thought\n" - end_reasoning_str = "" + end_reasoning_str = "\n" # prepare grammar from regex to contrain the action output of the model - allowed_actions_regex = to_regex(self.program.module, state.legal_actions_indicies) + allowed_actions_regex = create_regex_for_constrained_generation(self.program.module) g = xgr.Grammar.from_regex(allowed_actions_regex) gbnf = str(g) extra_body = { @@ -379,30 +390,29 @@ def capture_stdout(callable, *args, **kwargs) -> str: os.dup2(original_stdout_fd, sys.stdout.fileno()) -def to_regex(program_module, legal_actions_indicies): - if isinstance(legal_actions_indicies, list): - # actions = [rl_string_to_python(program_module.to_regex(i)) for i in legal_actions_indicies] - # regex = "(" + "|".join(set(actions)) + ")" - actions_json_list = rl_vector_of_strings_to_python(program_module.describe_actions()) - actions_regex_list = [] - for action_json in actions_json_list: - action = json.loads(action_json) - params = action['parameters_description'] - new_line = ",\n " - action_regex = dedent(f""" - ```json - {{ - "action_name": "{action['action_name']}", - "parameters": {{ - {new_line.join(f'"{param["name"]}": {param["regex"]}' for param in params)} - }} +def create_regex_for_constrained_generation(program_module): + """Creates a regex that matches the possible actions for the game, to be used for constrained generation with LlamaCpp. + + The idea is to generete a regex for the single action starting from the description returned from rlc, than combine the regexes of all the actions into one, using the OR operator.""" + actions_json_list = rl_vector_of_strings_to_python(program_module.describe_actions()) + actions_regex_list = [] + for action_json in actions_json_list: + action = json.loads(action_json) + params = action['parameters_description'] + new_line = ",\n " + action_regex = dedent(f""" + ```json + {{ + "action_name": "{action['action_name']}", + "parameters": {{ + {new_line.join(f'"{param["name"]}": {param["regex"]}' for param in params)} }} - ``` - """).lstrip().replace("{", r"\{").replace("}", r"\}") - actions_regex_list.append(action_regex) - regex = "(" + "|".join(set(actions_regex_list)) + ")" - return regex - raise ValueError(f"unsupported type for legal_actions_indicies: {type(legal_actions_indicies)}") + }} + ``` + """).lstrip().replace("{", r"\{").replace("}", r"\}") + actions_regex_list.append(action_regex) + regex = "(" + "|".join(set(actions_regex_list)) + ")" + return regex def get_action_index_from_llamacpp_answer(answer: str, state) -> int: answer_json = json.loads(answer.split("```json")[-1].split("```")[0]) diff --git a/stdlib/regex.rl b/stdlib/regex.rl index 2c1c3c72..37a204bf 100644 --- a/stdlib/regex.rl +++ b/stdlib/regex.rl @@ -15,15 +15,16 @@ cls ParamInfo: fun describe_actions() -> Vector: let any_action : AnyGameAction - return describe_actions_schema(any_action) + # AnyGameAction gets all possible actions for the game + return create_vector_of_actions_descriptions(any_action) -fun describe_actions_schema(AllActionsVariant variant) -> Vector: +# produces a vector whose elements are JSON strings describing the action schema (see also describe_single_action) +fun create_vector_of_actions_descriptions(AllActionsVariant variant) -> Vector: let out : Vector # Case 1: we have a union like AnyGameAction if variant is Alternative: - # print("variant is Alternative"s) for alt_name, alt_field of variant: # alt_name is a StringLiteral with the lowercase action name # Get the actual type name using get_type_name() method @@ -82,12 +83,14 @@ fun describe_single_action(ActionType action, String action_name) -> result.append("}") return result +# this function is needed to avoid the fallback to the generic implementation +# unfortunately, for BInt right now it's required to explicitly call describe_param with the exact numbers fun ensure_describe_param_implementations_are_instantiated() -> Int: # Int let x_int : Int let _ = describe_param(x_int) - # BInt TODO find a better way to do this + # BInt TODO find a better way to do this, instead of enumerating combinations let x : BInt<0, 3> let _ = describe_param(x) let x : BInt<0, 7> @@ -98,15 +101,16 @@ fun ensure_describe_param_implementations_are_instantiated() -> Int: let _ = describe_param(x) let x : BInt<0, 14> let _ = describe_param(x) - let x : BInt<0, 52> + let x : BInt<0, 52> # range 0-51 for blackjack let _ = describe_param(x) - let x : BInt<1, 10> + let x : BInt<1, 10> # range 0-9 for sudoku let _ = describe_param(x) # ... return 0 +# For an integer, return the regex "\d" and the name "Int" fun describe_param(Int x) -> ParamInfo: let info : ParamInfo info.type_name = ""s @@ -132,11 +136,13 @@ fun describe_param(BInt x) -> ParamInfo: info.regex.append("-"s) info.regex.append(to_string(max - 1)) info.regex.append("]"s) + + # TODO handle numbers with more than 1 digit return info -# Fallback: anything we don't know how to serialize → ".*" +# Fallback: anything we don't know how to serialize -> ".*" fun describe_param(T x) -> ParamInfo: let info : ParamInfo info.type_name = ""s @@ -145,32 +151,3 @@ fun describe_param(T x) -> ParamInfo: info.regex.append(".*"s) return info - -fun to_regex(Bool obj) -> String: - return "(true|false)"s - -fun to_regex(Float obj) -> String: - return "\d+\.\d+"s - -fun to_regex(Int obj) -> String: - return "\\d+"s - -fun to_regex(T obj) -> String: - return "unknown"s - -fun ensure_to_regex_is_instantiated() -> Int: - let _ = to_regex(true) - let _ = to_regex(123) - let _ = to_regex(10.123) - let _ = to_regex(""s) - let _ = to_regex([1, 2, 3]) - let x : Int[10] - let _ = to_regex(x) - let y : Float[10] - let _ = to_regex(y) - return 0 - -# fun main() -> Int: -# let r = test_describe_bint() -# print("result = "s + to_string(r)) -# return 0 \ No newline at end of file From 2f4dd6fb2fa814139553c5b054d40a06e1dc1b59 Mon Sep 17 00:00:00 2001 From: giulis13 Date: Mon, 1 Jun 2026 00:00:02 +0200 Subject: [PATCH 3/4] add xgrammar dependency --- run-requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/run-requirements.txt b/run-requirements.txt index 2029e9c4..6710a9e6 100644 --- a/run-requirements.txt +++ b/run-requirements.txt @@ -5,3 +5,4 @@ tensorboard matplotlib standard-imghdr; python_version >= '3.13' openai~=2.2 +xgrammar==0.1.31 From 42576a4151c7517b0113d0ed4494075f6f01ce69 Mon Sep 17 00:00:00 2001 From: giulis13 Date: Sun, 7 Jun 2026 10:41:02 +0200 Subject: [PATCH 4/4] Add variables to control reasoning and regex --- python/llmplayer.py | 8 +- python/rlc/llm_runner.py | 230 +++++++++++++++++++++------------------ 2 files changed, 129 insertions(+), 109 deletions(-) diff --git a/python/llmplayer.py b/python/llmplayer.py index 7383e57c..7fe75516 100644 --- a/python/llmplayer.py +++ b/python/llmplayer.py @@ -48,6 +48,8 @@ def main(): "--ollama-local", action="store_true", help="Use ollama locally" ) parser.add_argument("--llamacpp", action="store_true", help="Use llama.cpp") + parser.add_argument("--no-reasoning", action="store_true", help="Do not ask the model to reason, just output the action") + parser.add_argument("--no-regex", action="store_true", help="Do not use regex to constrain the model output") args = parser.parse_args() @@ -56,7 +58,7 @@ def main(): rules = get_included_conents_from_args(args) game_name = Path(args.source_file).stem with load_program_from_args(args, optimize=True, extra_source_files=["stdlib/regex.rl"]) as program: - llm = make_llm(args, program, game_name) + llm = make_llm(args, program, game_name, should_reason=not args.no_reasoning, should_use_regex=not args.no_regex) for action, thought in run_game( llm=llm, game_name=game_name, @@ -65,8 +67,8 @@ def main(): output=output, trace_output=trace_output, ): - print(thought) - print(program.module.to_string(action)) + print(f"thought: {thought}") + print(f"action: {action}") if __name__ == "__main__": diff --git a/python/rlc/llm_runner.py b/python/rlc/llm_runner.py index 4445be22..f0ba39e2 100644 --- a/python/rlc/llm_runner.py +++ b/python/rlc/llm_runner.py @@ -87,21 +87,23 @@ def chat(self, message: str, player_id: int, *args, **kwargs) -> str: class LlamaCpp: - def __init__(self, program: Program, game_name: str, model_name: str = "model"): + def __init__(self, program: Program, game_name: str, model_name: str = "model", should_reason: bool = True, should_use_regex: bool = True): self.client = openai.Client(api_key="...", base_url="http://localhost:8000/v1") self.model_name = model_name self.program = program + self.should_reason = should_reason + self.should_use_regex = should_use_regex # init chat histories for each player self.chats = [ [{ "role": "system", "content": (f"You are an agent that plays in a reinforcement learning enviroment, you're player id is {x}, the game is {game_name}. Your goal is to win the game, by selecting the best action (among the legal ones) at each turn.\n" - "INSTRUCTIONS:\n- First, reason about your strategy (max 300 tokens).\n- Once you decided, write the chosen action.\nOutput only the action you choose, nothing else, no explanations, no comments, just the action.\n") + "INSTRUCTIONS:\n- First, reason about your strategy (keep reasoning extremely concise, limiting yourself to no more than 2–3 key logical steps).\n- Once you decided, write the chosen action.\nOutput only the action you choose, nothing else, no explanations, no comments, just the action.\n") }] # this system prompt is generic because it's supposed to work for any game, the specific info it contains are player_id and game_name for x in range(program.module.get_num_players()) ] - def chat(self, message: str, player_id: int, state: State) -> str: + def chat(self, message: str, player_id: int, state: State) -> tuple[str, str]: """Interacts with the LLM to get the action to play given the current state. The interaction is done in two steps: @@ -130,47 +132,48 @@ def chat(self, message: str, player_id: int, state: State) -> str: ### 1. LET THE MODEL REASON ### # do a first call just to let the model reason about the current state and decide what action to take # no action is expected to be output, as we can't constrain the output yet - chat_response_reasoning = self.client.chat.completions.create( - model=self.model_name, - messages=chat_messages, - max_tokens=768 + 23, # reasononing budget tokens + reasoning budget message tokens - ) - if hasattr(chat_response_reasoning.choices[0].message, "reasoning_content"): - ### 2. CONSTRAIN THE MODEL OUTPUT WITH A GRAMMAR ### - reasoning_str = chat_response_reasoning.choices[0].message.reasoning_content.strip() - start_reasoning_str = "<|channel>thought\n" - end_reasoning_str = "\n" - - # prepare grammar from regex to contrain the action output of the model - allowed_actions_regex = create_regex_for_constrained_generation(self.program.module) - g = xgr.Grammar.from_regex(allowed_actions_regex) - gbnf = str(g) - extra_body = { - "grammar": gbnf, - "chat_template_kwargs": { - "enable_thinking": False, - } - } - self.chats[player_id].append({"role": "assistant", "content": start_reasoning_str + reasoning_str + end_reasoning_str}) - chat_response = self.client.chat.completions.create( + start_reasoning_str = "<|channel>thought\n" + end_reasoning_str = "\n" + reasoning_str = "" + if self.should_reason: + chat_response_reasoning = self.client.chat.completions.create( model=self.model_name, messages=chat_messages, - max_tokens=1500, - extra_body=extra_body, + max_tokens=768 + 23, # reasononing budget tokens + reasoning budget message tokens ) - answer = chat_response.choices[ - 0 - ].message.content.strip() - # remove the reasoning from the chat history if hasattr(chat_response_reasoning.choices[0].message, "reasoning_content"): - self.chats[player_id].pop() - else: - # sometimes the model forgets to reason and just outputs the action - answer = chat_response_reasoning.choices[0].message.content.strip() + reasoning_str = chat_response_reasoning.choices[0].message.reasoning_content.strip() + + ### 2. CONSTRAIN THE MODEL OUTPUT WITH A GRAMMAR ### + # prepare grammar from regex to contrain the action output of the model + extra_body = { + "chat_template_kwargs": { + "enable_thinking": False, + } + } + if self.should_use_regex: + allowed_actions_regex = create_regex_for_constrained_generation(self.program.module) + g = xgr.Grammar.from_regex(allowed_actions_regex) + gbnf = str(g) + extra_body["grammar"] = gbnf + + self.chats[player_id].append({"role": "assistant", "content": start_reasoning_str + reasoning_str + end_reasoning_str}) + chat_response = self.client.chat.completions.create( + model=self.model_name, + messages=chat_messages, + max_tokens=1500, + extra_body=extra_body, + ) + answer = chat_response.choices[ + 0 + ].message.content.strip() + # remove the reasoning from the chat history + if self.should_reason and hasattr(chat_response_reasoning.choices[0].message, "reasoning_content"): + self.chats[player_id].pop() self.chats[player_id].append({"role": "assistant", "content": answer}) - return answer + return reasoning_str, answer def extract_index(string: str): position = string.rfind("action:") @@ -213,7 +216,7 @@ def solve_randomness(program: Program, state: State, trace_output): yield (action, "") -def make_llm(args, program, game_name: str): +def make_llm(args, program, game_name: str, should_reason=True, should_use_regex=True): if args.ollama_local: return Ollama(program) if args.gemini_statefull: @@ -221,7 +224,7 @@ def make_llm(args, program, game_name: str): if args.gemini_stateless: return GeminiStateless(program) if args.llamacpp: - return LlamaCpp(program, game_name) + return LlamaCpp(program, game_name, should_reason=should_reason, should_use_regex=should_use_regex) return None @@ -284,13 +287,13 @@ def run_game_with_llamacpp(llm, game_name: str, program: Program, rules: str, ou yield x output.write(rules + "\n") - # for x in range(num_players): - # message = ( - # f"Here are the rules of the game: read them carefully. You will be prompted to play a game as player {x} against the opponent.\n```" - # + rules + "\n```\n\nSummarize the rules (be concise) and formulate a strategy to play the game. Keep it short." - # ) - # answer_to_rules = llm.chat(message=message, player_id=x) - # output.write(answer_to_rules) + for x in range(num_players): + message = ( + f"Here are the rules of the game: read them carefully. You will be prompted to play a game as player {x} against the opponent.\n```" + + rules + "\n```\n\nSummarize the rules (be concise) and formulate a strategy to play the game. Keep it short." + ) + reasoning, answer_to_rules = llm.chat(message=message, player_id=x, state=state) + output.write(answer_to_rules) action_format_str = dedent(""" ```json @@ -307,69 +310,82 @@ def run_game_with_llamacpp(llm, game_name: str, program: Program, rules: str, ou output.write(f"starting game {game_name}\n") turn = 0 - while not state.is_done(): - current_player = program.module.get_current_player(state.state) - output.write(f"---------- TURN {turn}, PLAYER {current_player} ----------\n") - output.write(capture_stdout(state.pretty_print) + "\n") - - output.write("CURRENT_PLAYER " + str(current_player) + "\n") - state_str = capture_stdout(state.pretty_print) - message = prompt_message + "\nCURRENT STATE:\n" + state_str + "\n" - message += "\nLEGAL ACTIONS:\n" - message += json.dumps(list(map(json.loads, rl_vector_of_strings_to_python(program.module.describe_actions()))), indent=4) + "\n" - - message += f"\nSelect your action by answering with one of above actions, using the format:\n{action_format_str}" - - output.write(message + "\n") - output.flush() + num_actions_with_invalid_syntax = 0 + num_illegal_actions = 0 + state_str = "" + try: + while not state.is_done(): + current_player = program.module.get_current_player(state.state) + output.write(f"---------- TURN {turn}, PLAYER {current_player} ----------\n") + output.write(capture_stdout(state.pretty_print) + "\n") + + output.write("CURRENT_PLAYER " + str(current_player) + "\n") + state_str = capture_stdout(state.pretty_print) + message = prompt_message + "\nCURRENT STATE:\n" + state_str + "\n" + message += "\nLEGAL ACTIONS:\n" + message += json.dumps(list(map(json.loads, rl_vector_of_strings_to_python(program.module.describe_actions()))), indent=4) + "\n" + + output.write(message + "\n") + output.flush() - answer = llm.chat(message=message, player_id=current_player, state=state) - action_index = get_action_index_from_llamacpp_answer(answer, state) - output.write(answer + "\n") - output.flush() - n_attempts = 1 - max_attempts = 20 - while action_index == -1 or not state.can_apply(state.actions[action_index]): - n_attempts += 1 - if n_attempts > max_attempts: - raise Exception(f"LLM failed to provide a valid action after {max_attempts} attempts, aborting the game.") - error_msg = "Failed to apply action, " - if action_index == -1: - error_msg += f"unable to parse answer. Please answer with format:\n{action_format_str}" - else: - error_msg += "the action you selected is not legal in the current state." - output.write(error_msg + "\n") - llm.chats[current_player].append({"role": "system", "content": error_msg}) - answer = llm.chat(message=message, player_id=current_player, state=state) + reasoning, answer = llm.chat(message=message, player_id=current_player, state=state) + action_index = get_action_index_from_llamacpp_answer(answer, state) output.write(answer + "\n") output.flush() - action_index = get_action_index_from_llamacpp_answer(answer, state) - - action = state.actions[action_index] - trace_output.write(str(action) + "\n") - trace_output.flush() + n_attempts = 1 + max_attempts = 20 + while action_index == -1 or not state.can_apply(state.actions[action_index]): + if action_index == -1: + num_actions_with_invalid_syntax += 1 + else: + num_illegal_actions += 1 + n_attempts += 1 + if n_attempts > max_attempts: + raise RuntimeError(f"LLM failed to provide a valid action after {max_attempts} attempts, aborting the game.") + error_msg = "Failed to apply action, " + if action_index == -1: + error_msg += f"unable to parse answer. Please answer with format:\n{action_format_str}" + else: + error_msg += "the action you selected is not legal in the current state." + output.write(error_msg + "\n") + llm.chats[current_player].append({"role": "system", "content": error_msg}) + reasoning, answer = llm.chat(message=message, player_id=current_player, state=state) + output.write(answer + "\n") + output.flush() + action_index = get_action_index_from_llamacpp_answer(answer, state) + + action = state.actions[action_index] + trace_output.write(str(action) + "\n") + trace_output.flush() + + output.write(f"player {current_player} chose action {action_index}: {str(action).strip()} ({n_attempts} attempts)\n") + output.flush() - output.write(f"player {current_player} chose action {action_index}: {str(action).strip()} ({n_attempts} attempts)\n") + state.step(action) + yield (action, reasoning) + for x in solve_randomness(program, state, trace_output): + yield x + + if n_attempts > 1: + # clear wrong attempts from the chat history + for _ in range(3 * (n_attempts - 1)): # for each failed attempt, remove system, user and assistant messages + llm.chats[current_player].pop(-3) # leave last two messages there + turn += 1 + except RuntimeError as e: + output.write(str(e) + "\n") + finally: + print(f"final state:\n{state_str}") + output.write(f"game {game_name} ended\n") + output.write(f"number of actions with invalid syntax: {num_actions_with_invalid_syntax}\n") + output.write(f"number of illegal actions: {num_illegal_actions}\n") + output.write(f"total number of wrong actions: {num_actions_with_invalid_syntax + num_illegal_actions}\n") + output.write(f"total number of turns: {turn}\n") + output.write( + "FINAL SCORE: " + + str([program.module.score(state.state, x) for x in range(num_players)]) + ) output.flush() - state.step(action) - yield (action, answer) - for x in solve_randomness(program, state, trace_output): - yield x - - if n_attempts > 1: - # clear wrong attempts from the chat history - for _ in range(3 * (n_attempts - 1)): # for each failed attempt, remove system, user and assistant messages - llm.chats[current_player].pop(-3) # leave last two messages there - turn += 1 - - output.write(f"game {game_name} ended\n") - output.write( - "FINAL SCORE: " - + str([program.module.score(state.state, x) for x in range(num_players)]) - ) - output.flush() - def capture_stdout(callable, *args, **kwargs) -> str: original_stdout_fd = os.dup(sys.stdout.fileno()) @@ -415,10 +431,10 @@ def create_regex_for_constrained_generation(program_module): return regex def get_action_index_from_llamacpp_answer(answer: str, state) -> int: - answer_json = json.loads(answer.split("```json")[-1].split("```")[0]) - chosen_action_name = answer_json["action_name"] - chosen_action_params = answer_json.get("parameters", {}) try: + answer_json = json.loads(answer.split("```json")[-1].split("```")[0]) + chosen_action_name = answer_json["action_name"] + chosen_action_params = answer_json.get("parameters", {}) # find the index of what action was chosen by the LLM action_index = -1 for i, action_i in enumerate(state.actions): @@ -429,6 +445,8 @@ def get_action_index_from_llamacpp_answer(answer: str, state) -> int: if action_i_name == chosen_action_name and action_i_params == chosen_action_params: action_index = i break + except json.JSONDecodeError: + action_index = -1 except ValueError: action_index = -1 return action_index \ No newline at end of file