diff --git a/neuron-explainer/neuron_explainer/api_client.py b/neuron-explainer/neuron_explainer/api_client.py index 46b5e96..e63ca32 100644 --- a/neuron-explainer/neuron_explainer/api_client.py +++ b/neuron-explainer/neuron_explainer/api_client.py @@ -114,7 +114,7 @@ def __init__( @exponential_backoff(retry_on=is_api_error) async def make_request( - self, timeout_seconds: Optional[int] = None, **kwargs: Any + self, timeout_seconds: Optional[int] = None, json_mode: Optional[bool] = False, **kwargs: Any ) -> dict[str, Any]: if self._cache is not None: key = orjson.dumps(kwargs) @@ -130,6 +130,8 @@ async def make_request( # endpoint. Otherwise, it should be sent to the /completions endpoint. url = BASE_API_URL + ("/chat/completions" if "messages" in kwargs else "/completions") kwargs["model"] = self.model_name + if json_mode: + kwargs["response_format"] = {"type": "json_object"} response = await http_client.post(url, headers=API_HTTP_HEADERS, json=kwargs) # The response json has useful information but the exception doesn't include it, so print it # out then reraise. 
diff --git a/neuron-explainer/neuron_explainer/explanations/few_shot_examples.py b/neuron-explainer/neuron_explainer/explanations/few_shot_examples.py index 1fb933b..0ae040a 100644 --- a/neuron-explainer/neuron_explainer/explanations/few_shot_examples.py +++ b/neuron-explainer/neuron_explainer/explanations/few_shot_examples.py @@ -40,6 +40,7 @@ class FewShotExampleSet(Enum): ORIGINAL = "original" NEWER = "newer" TEST = "test" + JL_FINE_TUNED = "jl_fine_tuned" @classmethod def from_string(cls, string: str) -> FewShotExampleSet: @@ -56,6 +57,8 @@ def get_examples(self) -> list[Example]: return NEWER_EXAMPLES elif self is FewShotExampleSet.TEST: return TEST_EXAMPLES + elif self is FewShotExampleSet.JL_FINE_TUNED: + return JL_FINE_TUNED_EXAMPLES else: raise ValueError(f"Unhandled example set: {self}") @@ -1038,3 +1041,179 @@ def get_single_token_prediction_example(self) -> Example: token_index_to_score=18, explanation="instances of the token 'ate' as part of another word", ) + + +JL_FINE_TUNED_EXAMPLES = [ + Example( + activation_records=[ + ActivationRecord( + tokens=[ + "The", + " cat", + " jumped", + " on", + " my", + " laptop", + ".", + ], + activations=[ + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ], + ), + ], + first_revealed_activation_indices=[], + explanation="the word \"laptop\" before the word \"cat\"", + ), + Example( + activation_records=[ + ActivationRecord( + tokens=[ + "The", + " cat", + " jumped", + " on", + " my", + " laptop", + ".", + ], + activations=[ + 0, + 10, + 0, + 0, + 0, + 0, + 0 + ], + ), + ], + first_revealed_activation_indices=[], + explanation="the word \"cat\" before the word \"laptop\"", + ), + Example( + activation_records=[ + ActivationRecord( + tokens=[ + "I", + " am", + " using", + " a", + " keyboard", + ".", + ], + activations=[ + 0, + 0, + 0, + 0, + 10, + 0 + ], + ), + ], + first_revealed_activation_indices=[], + explanation="the word before a period", + ), + Example( + activation_records=[ + ActivationRecord( + tokens=[ + "The", + " 
sun", + " is", + " shining", + ".", + " The", + " clouds", + " are", + " gone", + ".", + " Great", + " weather", + "!", + ], + activations=[ + 0, + 0, + 0, + 10, + 0, + 0, + 0, + 0, + 10, + 0, + 0, + 0, + 0 + ], + ), + ], + first_revealed_activation_indices=[], + explanation="the word before period", + ), +] + +NEWER_SINGLE_TOKEN_EXAMPLE = Example( + activation_records=[ + ActivationRecord( + tokens=[ + "B", + "10", + " ", + "111", + " MON", + "DAY", + ",", + " F", + "EB", + "RU", + "ARY", + " ", + "11", + ",", + " ", + "201", + "9", + " DON", + "ATE", + "fake higher scoring token", # See below. + ], + activations=[ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.37, + # This fake activation makes the previous token's activation normalize to 8, which + # might help address overconfidence in "10" activations for the one-token-at-a-time + # scoring prompt. This value and the associated token don't actually appear anywhere + # in the prompt. + 0.45, + ], + ), + ], + first_revealed_activation_indices=[], + token_index_to_score=18, + explanation="instances of the token 'ate' as part of another word", +) diff --git a/neuron-explainer/neuron_explainer/explanations/simulator.py b/neuron-explainer/neuron_explainer/explanations/simulator.py index 4111ead..d40cea5 100644 --- a/neuron-explainer/neuron_explainer/explanations/simulator.py +++ b/neuron-explainer/neuron_explainer/explanations/simulator.py @@ -4,6 +4,7 @@ import asyncio import logging +import json from abc import ABC, abstractmethod from collections import OrderedDict from enum import Enum @@ -36,6 +37,9 @@ VALID_ACTIVATION_TOKENS_ORDERED = list(str(i) for i in range(MAX_NORMALIZED_ACTIVATION + 1)) VALID_ACTIVATION_TOKENS = set(VALID_ACTIVATION_TOKENS_ORDERED) +# Edge Case #3: The chat-based simulator is confused by end token. 
Replace it with a "not end token" +END_OF_TEXT_TOKEN = "<|endoftext|>" +END_OF_TEXT_TOKEN_REPLACEMENT = "<|not_endoftext|>" class SimulationType(str, Enum): """How to simulate neuron activations. Values correspond to subclasses of NeuronSimulator.""" @@ -590,6 +594,9 @@ def _format_record_for_logprob_free_simulation( activation_record.activations, max_activation=max_activation ) for i, token in enumerate(activation_record.tokens): + # Edge Case #3: End tokens confuse the chat-based simulator. Replace end token with "not end token". + if token.strip() == END_OF_TEXT_TOKEN: + token = END_OF_TEXT_TOKEN_REPLACEMENT # We use a weird unicode character here to make it easier to parse the response (can split on "༗\n"). if include_activations: response += f"{token}\t{normalized_activations[i]}༗\n" @@ -597,44 +604,167 @@ def _format_record_for_logprob_free_simulation( response += f"{token}\t༗\n" return response +def _format_record_for_logprob_free_simulation_json( + explanation: str, + activation_record: ActivationRecord, + include_activations: bool = False, +) -> str: + if include_activations: + assert len(activation_record.tokens) == len( + activation_record.activations + ), f"{len(activation_record.tokens)=}, {len(activation_record.activations)=}" + return json.dumps({ + "to_find": explanation, + "document": "".join(activation_record.tokens), + "activations": [ + { + "token": token, + "activation": activation_record.activations[i] if include_activations else None + } for i, token in enumerate(activation_record.tokens) + ] + }) + +def _parse_no_logprobs_completion_json( + completion: str, + tokens: Sequence[str], +) -> Sequence[float]: + """ + Parse a completion into a list of simulated activations. If the model did not faithfully + reproduce the token sequence, return a list of 0s. If the model's activation for a token + is not a number between 0 and 10 (inclusive), substitute 0. 
+ + Args: + completion: completion from the API + tokens: list of tokens as strings in the sequence where the neuron is being simulated + """ + + logger.debug("for tokens:\n%s", tokens) + logger.debug("received completion:\n%s", completion) + + zero_prediction = [0] * len(tokens) + + try: + completion = json.loads(completion) + if "activations" not in completion: + logger.error("The key 'activations' is not in the completion:\n%s\nExpected Tokens:\n%s", json.dumps(completion), tokens) + return zero_prediction + activations = completion["activations"] + if len(activations) != len(tokens): + logger.error("Tokens and activations length did not match:\n%s\nExpected Tokens:\n%s", json.dumps(completion), tokens) + return zero_prediction + predicted_activations = [] + # check that there is a token and activation value + # no need to double check the token matches exactly + for i, activation in enumerate(activations): + if "token" not in activation: + logger.error("The key 'token' is not in activation:\n%s\nCompletion:%s\nExpected Tokens:\n%s", activation, json.dumps(completion), tokens) + predicted_activations.append(0) + continue + if "activation" not in activation: + logger.error("The key 'activation' is not in activation:\n%s\nCompletion:%s\nExpected Tokens:\n%s", activation, json.dumps(completion), tokens) + predicted_activations.append(0) + continue + # Ensure activation value is between 0-10 inclusive + try: + predicted_activation_float = float(activation["activation"]) + if predicted_activation_float < 0 or predicted_activation_float > MAX_NORMALIZED_ACTIVATION: + logger.error("activation value out of range: %s\nCompletion:%s\nExpected Tokens:\n%s", predicted_activation_float, json.dumps(completion), tokens) + predicted_activations.append(0) + else: + predicted_activations.append(predicted_activation_float) + except ValueError: + logger.error("activation value invalid: %s\nCompletion:%s\nExpected Tokens:\n%s", activation["activation"], json.dumps(completion), 
tokens) + predicted_activations.append(0) + except TypeError: + logger.error("activation value incorrect type: %s\nCompletion:%s\nExpected Tokens:\n%s", activation["activation"], json.dumps(completion), tokens) + predicted_activations.append(0) + logger.debug("predicted activations: %s", predicted_activations) + return predicted_activations + + except json.JSONDecodeError: + logger.error("Failed to parse completion JSON:\n%s\nExpected Tokens:\n%s", completion, tokens) + return zero_prediction def _parse_no_logprobs_completion( completion: str, tokens: Sequence[str], -) -> Sequence[int]: +) -> Sequence[float]: """ Parse a completion into a list of simulated activations. If the model did not faithfully reproduce the token sequence, return a list of 0s. If the model's activation for a token - is not an integer betwee 0 and 10, substitute 0. + is not a number between 0 and 10 (inclusive), substitute 0. Args: completion: completion from the API tokens: list of tokens as strings in the sequence where the neuron is being simulated """ + + logger.debug("for tokens:\n%s", tokens) + logger.debug("received completion:\n%s", completion) + zero_prediction = [0] * len(tokens) - token_lines = completion.strip("\n").split("༗\n") + # FIX: Strip the last ༗\n, otherwise all last activations are invalid + token_lines = completion.strip("\n").strip("༗\n").split("༗\n") + # Edge Case #2: Sometimes GPT doesn't use the special character when it answers, it only uses the \n" + # The fix is to try splitting by \n if we detect that the response isn't the right format + # TODO: If there are also line breaks in the text, this will probably break + if (len(token_lines)) == 1: + token_lines = completion.strip("\n").strip("༗\n").split("\n") + logger.debug("parsed completion into token_lines as:\n%s", token_lines) + start_line_index = None for i, token_line in enumerate(token_lines): - if token_line.startswith(f"{tokens[0]}\t"): + if (token_line.startswith(f"{tokens[0]}\t") + # Edge Case #1: GPT 
often omits the space before the first token. + # Allow the returned token line to be either " token" or "token". + or f" {token_line}".startswith(f"{tokens[0]}\t") + # Edge Case #3: Allow our "not end token" replacement + or (token_line.startswith(END_OF_TEXT_TOKEN_REPLACEMENT) and tokens[0].strip() == END_OF_TEXT_TOKEN) + ): + logger.debug("start_line_index is: %s", start_line_index) + logger.debug("matched token %s with token_line %s", tokens[0], token_line) start_line_index = i break # If we didn't find the first token, or if the number of lines in the completion doesn't match # the number of tokens, return a list of 0s. if start_line_index is None or len(token_lines) - start_line_index != len(tokens): + logger.debug("didn't find first token or number of lines didn't match, returning all zeroes") return zero_prediction + predicted_activations = [] for i, token_line in enumerate(token_lines[start_line_index:]): - if not token_line.startswith(f"{tokens[i]}\t"): + if (not token_line.startswith(f"{tokens[i]}\t") + # Edge Case #1: GPT often omits the space before the token. + # Allow the returned token line to be either " token" or "token". + and not f" {token_line}".startswith(f"{tokens[i]}\t") + # Edge Case #3: Allow our "not end token" replacement + and not token_line.startswith(END_OF_TEXT_TOKEN_REPLACEMENT) + ): + logger.debug("failed to match token %s with token_line %s, returning all zeroes", tokens[i], token_line) return zero_prediction - predicted_activation = token_line.split("\t")[1] - if predicted_activation not in VALID_ACTIVATION_TOKENS: + predicted_activation_split = token_line.split("\t") + # Ensure token line has correct size after splitting. If not then assume it's a zero. 
+ if len(predicted_activation_split) != 2: + logger.debug("tokenline split invalid size: %s", token_line) predicted_activations.append(0) - else: - predicted_activations.append(int(predicted_activation)) + continue + predicted_activation = predicted_activation_split[1] + # Sometimes the activation value returned by GPT is not a float (GPT likes to append an extra ༗). + # In all cases if the activation is not numerically parseable, set it to 0 + try: + predicted_activation_float = float(predicted_activation) + if predicted_activation_float < 0 or predicted_activation_float > MAX_NORMALIZED_ACTIVATION: + logger.debug("activation value out of range: %s", predicted_activation_float) + predicted_activations.append(0) + else: + predicted_activations.append(predicted_activation_float) + except ValueError: + logger.debug("activation value not numeric: %s", predicted_activation) + predicted_activations.append(0) + logger.debug("predicted activations: %s", predicted_activations) return predicted_activations - class LogprobFreeExplanationTokenSimulator(NeuronSimulator): """ Simulate neuron behavior based on an explanation. 
@@ -695,6 +825,7 @@ def __init__( model_name: str, explanation: str, max_concurrent: Optional[int] = 10, + json_mode: Optional[bool] = True, few_shot_example_set: FewShotExampleSet = FewShotExampleSet.NEWER, prompt_format: PromptFormat = PromptFormat.HARMONY_V4, cache: bool = False, @@ -705,6 +836,7 @@ def __init__( self.api_client = ApiClient( model_name=model_name, max_concurrent=max_concurrent, cache=cache ) + self.json_mode = json_mode self.explanation = explanation self.few_shot_example_set = few_shot_example_set self.prompt_format = prompt_format @@ -713,24 +845,30 @@ async def simulate( self, tokens: Sequence[str], ) -> SequenceSimulation: - prompt = self._make_simulation_prompt( - tokens, - self.explanation, - ) - response = await self.api_client.make_request( - prompt=prompt, echo=False, max_tokens=1000 - ) - assert len(response["choices"]) == 1 - - choice = response["choices"][0] - if self.prompt_format == PromptFormat.HARMONY_V4: + if self.json_mode: + prompt = self._make_simulation_prompt_json( + tokens, + self.explanation, + ) + response = await self.api_client.make_request( + messages=prompt, max_tokens=2000, temperature=0, json_mode=True + ) + assert len(response["choices"]) == 1 + choice = response["choices"][0] completion = choice["message"]["content"] - elif self.prompt_format in [PromptFormat.NONE, PromptFormat.INSTRUCTION_FOLLOWING]: - completion = choice["text"] + predicted_activations = _parse_no_logprobs_completion_json(completion, tokens) else: - raise ValueError(f"Unhandled prompt format {self.prompt_format}") - - predicted_activations = _parse_no_logprobs_completion(completion, tokens) + prompt = self._make_simulation_prompt( + tokens, + self.explanation, + ) + response = await self.api_client.make_request( + messages=prompt, max_tokens=1000, temperature=0 + ) + assert len(response["choices"]) == 1 + choice = response["choices"][0] + completion = choice["message"]["content"] + predicted_activations = 
_parse_no_logprobs_completion(completion, tokens) result = SequenceSimulation( activation_scale=ActivationScale.SIMULATED_NORMALIZED_ACTIVATIONS, @@ -743,6 +881,77 @@ async def simulate( logger.debug("result in score_explanation_by_activations is %s", result) return result + def _make_simulation_prompt_json( + self, + tokens: Sequence[str], + explanation: str, + ) -> Union[str, list[HarmonyMessage]]: + """Make a few-shot prompt for predicting the neuron's activations on a sequence.""" + """NOTE: The JSON version does not give GPT multiple sequence examples per neuron.""" + assert explanation != "" + prompt_builder = PromptBuilder() + prompt_builder.add_message( + Role.SYSTEM, + """We're studying neurons in a neural network. Each neuron looks for certain things in a short document. Your task is to read the explanation of what the neuron does, and predict the neuron's activations for each token in the document. + +For each document, you will see the full text of the document, then the tokens in the document with the activation left blank. You will print, in valid json, the exact same tokens verbatim, but with the activation values filled in according to the explanation. Pay special attention to the explanation's description of the context and order of tokens or words. + +Fill out the activation values from 0 to 10. 
Please think carefully. +""", + ) + + few_shot_examples = self.few_shot_example_set.get_examples() + for example in few_shot_examples: + """ + { + "to_find": "hello", + "document": "The", + "activations": [ + { + "token": "The", + "activation": null + } + ] + } + """ + prompt_builder.add_message( + Role.USER, + _format_record_for_logprob_free_simulation_json(explanation=example.explanation, activation_record=example.activation_records[0], include_activations=False) + ) + """ + { + "to_find": "hello", + "document": "The", + "activations": [ + { + "token": "The", + "activation": 10 + } + ] + } + """ + prompt_builder.add_message( + Role.ASSISTANT, + _format_record_for_logprob_free_simulation_json(explanation=example.explanation, activation_record=example.activation_records[0], include_activations=True) + ) + """ + { + "to_find": "hello", + "document": "The", + "activations": [ + { + "token": "The", + "activation": null + } + ] + } + """ + prompt_builder.add_message( + Role.USER, + _format_record_for_logprob_free_simulation_json(explanation=explanation, activation_record=ActivationRecord(tokens=tokens, activations=[]), include_activations=False) + ) + return prompt_builder.build(self.prompt_format, allow_extra_system_messages=True) + def _make_simulation_prompt( self, tokens: Sequence[str], @@ -753,7 +962,7 @@ def _make_simulation_prompt( prompt_builder = PromptBuilder(allow_extra_system_messages=True) prompt_builder.add_message( Role.SYSTEM, - """We're studying neurons in a neural network. Each neuron looks for some particular thing in a short document. Look at an explanation of what the neuron does, and try to predict its activations on a particular token. + """We're studying neurons in a neural network. Each neuron looks for some particular thing in a short document. Look at an explanation of what the neuron does, and try to predict its activations on a particular token. The activation format is tokenactivation, and activations range from 0 to 10. 
Most activations will be 0. For each sequence, you will see the tokens in the sequence where the activations are left blank. You will print the exact same tokens verbatim, but with the activations filled in according to the explanation.