Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion neuron-explainer/neuron_explainer/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def __init__(

@exponential_backoff(retry_on=is_api_error)
async def make_request(
self, timeout_seconds: Optional[int] = None, **kwargs: Any
self, timeout_seconds: Optional[int] = None, json_mode: Optional[bool] = False, **kwargs: Any
) -> dict[str, Any]:
if self._cache is not None:
key = orjson.dumps(kwargs)
Expand All @@ -130,6 +130,8 @@ async def make_request(
# endpoint. Otherwise, it should be sent to the /completions endpoint.
url = BASE_API_URL + ("/chat/completions" if "messages" in kwargs else "/completions")
kwargs["model"] = self.model_name
if json_mode:
kwargs["response_format"] = {"type": "json_object"}
response = await http_client.post(url, headers=API_HTTP_HEADERS, json=kwargs)
# The response json has useful information but the exception doesn't include it, so print it
# out then reraise.
Expand Down
179 changes: 179 additions & 0 deletions neuron-explainer/neuron_explainer/explanations/few_shot_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class FewShotExampleSet(Enum):
ORIGINAL = "original"
NEWER = "newer"
TEST = "test"
JL_FINE_TUNED = "jl_fine_tuned"

@classmethod
def from_string(cls, string: str) -> FewShotExampleSet:
Expand All @@ -56,6 +57,8 @@ def get_examples(self) -> list[Example]:
return NEWER_EXAMPLES
elif self is FewShotExampleSet.TEST:
return TEST_EXAMPLES
elif self is FewShotExampleSet.JL_FINE_TUNED:
return JL_FINE_TUNED_EXAMPLES
else:
raise ValueError(f"Unhandled example set: {self}")

Expand Down Expand Up @@ -1038,3 +1041,179 @@ def get_single_token_prediction_example(self) -> Example:
token_index_to_score=18,
explanation="instances of the token 'ate' as part of another word",
)


# Few-shot examples for the JL_FINE_TUNED example set. Each entry pairs a short
# token sequence with per-token activations and the explanation text the model
# is expected to emit for that activation pattern.
JL_FINE_TUNED_EXAMPLES = [
    # NOTE(review): all activations here are zero while the explanation names a
    # pattern not present in the tokens — presumably a deliberate negative /
    # contrast example for fine-tuning. Confirm with the dataset author.
    Example(
        activation_records=[
            ActivationRecord(
                tokens=["The", " cat", " jumped", " on", " my", " laptop", "."],
                activations=[0, 0, 0, 0, 0, 0, 0],
            ),
        ],
        first_revealed_activation_indices=[],
        explanation='the word "laptop" before the word "cat"',
    ),
    # Single strong activation on " cat".
    Example(
        activation_records=[
            ActivationRecord(
                tokens=["The", " cat", " jumped", " on", " my", " laptop", "."],
                activations=[0, 10, 0, 0, 0, 0, 0],
            ),
        ],
        first_revealed_activation_indices=[],
        explanation='the word "cat" before the word "laptop"',
    ),
    # Single strong activation on " keyboard", the token preceding the period.
    Example(
        activation_records=[
            ActivationRecord(
                tokens=["I", " am", " using", " a", " keyboard", "."],
                activations=[0, 0, 0, 0, 10, 0],
            ),
        ],
        first_revealed_activation_indices=[],
        explanation="the word before a period",
    ),
    # Two strong activations (" shining", " gone"), each immediately before a
    # period.
    Example(
        activation_records=[
            ActivationRecord(
                tokens=[
                    "The",
                    " sun",
                    " is",
                    " shining",
                    ".",
                    " The",
                    " clouds",
                    " are",
                    " gone",
                    ".",
                    " Great",
                    " weather",
                    "!",
                ],
                activations=[0, 0, 0, 10, 0, 0, 0, 0, 10, 0, 0, 0, 0],
            ),
        ],
        first_revealed_activation_indices=[],
        explanation="the word before period",
    ),
]

# Single-token prediction example for the NEWER example set. Scoring targets
# the token at index 18 ("ATE"). The record ends with a placeholder token and
# an artificially larger activation used only to tune normalization; neither
# ever appears in an actual prompt.
NEWER_SINGLE_TOKEN_EXAMPLE = Example(
    activation_records=[
        ActivationRecord(
            tokens=[
                "B",
                "10",
                " ",
                "111",
                " MON",
                "DAY",
                ",",
                " F",
                "EB",
                "RU",
                "ARY",
                " ",
                "11",
                ",",
                " ",
                "201",
                "9",
                " DON",
                "ATE",
                # Placeholder paired with the fake activation below — never
                # shown to the model.
                "fake higher scoring token",
            ],
            activations=[
                # Indices 0-17: no activation.
                0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0,
                # Index 18 ("ATE"): the real activation being scored.
                0.37,
                # Fake maximum: with 0.45 as the max, the scored activation
                # (0.37) normalizes to 8 rather than 10, which may counteract
                # overconfident "10" answers in the one-token-at-a-time scoring
                # prompt. This value and its placeholder token above do not
                # appear anywhere in the prompt.
                0.45,
            ],
        ),
    ],
    first_revealed_activation_indices=[],
    token_index_to_score=18,
    explanation="instances of the token 'ate' as part of another word",
)
Loading