Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
322 changes: 224 additions & 98 deletions neuron-explainer/demos/generate_and_score_explanation.ipynb
Original file line number Diff line number Diff line change
@@ -1,101 +1,227 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os \n",
"\n",
"os.environ[\"NEURON_EXPLAINER_API_KEY\"] = \"EMPTY\"\n",
"os.environ[\"NEURON_EXPLAINER_API_BASE\"] = \"http://localhost:8000/v1\" # Paste your vLLM API base URL here; it must end with /v1"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "uj0EX7BdBeo6"
},
"outputs": [],
"source": [
"from neuron_explainer.activations.activation_records import calculate_max_activation\n",
"from neuron_explainer.activations.activations import ActivationRecordSliceParams, load_neuron\n",
"from neuron_explainer.explanations.calibrated_simulator import UncalibratedNeuronSimulator\n",
"from neuron_explainer.explanations.explainer import TokenActivationPairExplainer\n",
"from neuron_explainer.explanations.prompt_builder import PromptFormat\n",
"from neuron_explainer.explanations.scoring import simulate_and_score\n",
"from neuron_explainer.explanations.simulator import ExplanationNeuronSimulator, ExplanationTokenByTokenSimulator"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "s3cxrE6MBhdn"
},
"outputs": [],
"source": [
"EXPLAINER_MODEL_NAME = \"Qwen/Qwen3-Coder-30B-A3B-Instruct\"\n",
"SIMULATOR_MODEL_NAME = \"Qwen/Qwen3-Coder-30B-A3B-Instruct\""
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "OGBG65ecDEW4"
},
"outputs": [],
"source": [
"from neuron_explainer.api_client import ApiClient"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "syncikDxDIib"
},
"outputs": [],
"source": [
"client = ApiClient(model_name=\"Qwen/Qwen3-Coder-30B-A3B-Instruct\", max_concurrent=1)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4JZBbRnXCldP",
"outputId": "554dc778-e35e-4eaf-c83d-511cd9d043e9"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Response: {'role': 'assistant', 'content': 'Nothing much', 'refusal': None, 'annotations': None, 'audio': None, 'function_call': None, 'tool_calls': [], 'reasoning_content': None}\n"
]
}
],
"source": [
"test_response = await client.make_request(messages=[{\"role\": \"user\", \"content\": \"What's up?\"}], max_tokens=2)\n",
"print(\"Response:\", test_response[\"choices\"][0][\"message\"])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "cCYOJ4kZCjvc"
},
"outputs": [],
"source": [
"# Load a neuron record.\n",
"neuron_record = load_neuron(9, 6236)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "aSnM2HneGAkY"
},
"outputs": [],
"source": [
"# Grab the activation records we'll need.\n",
"slice_params = ActivationRecordSliceParams(n_examples_per_split=5)\n",
"train_activation_records = neuron_record.train_activation_records(\n",
" activation_record_slice_params=slice_params\n",
")\n",
"valid_activation_records = neuron_record.valid_activation_records(\n",
" activation_record_slice_params=slice_params\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "7sX5EvHvGSE9"
},
"outputs": [],
"source": [
"# Generate an explanation for the neuron.\n",
"explainer = TokenActivationPairExplainer(\n",
" model_name=EXPLAINER_MODEL_NAME,\n",
" prompt_format=PromptFormat.HARMONY_V4,\n",
" max_concurrent=1,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "2dSp0pPzGZdS",
"outputId": "32333e37-d923-41aa-9d07-8b228cc903e8"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"explanation=' phrases indicating repetition or multiple occurrences.'\n"
]
}
],
"source": [
"explanations = await explainer.generate_explanations(\n",
" all_activation_records=train_activation_records,\n",
" max_activation=calculate_max_activation(train_activation_records),\n",
" num_samples=1,\n",
")\n",
"assert len(explanations) == 1\n",
"explanation = explanations[0]\n",
"print(f\"{explanation=}\")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 339
},
"id": "ANFzVlEKGhv6",
"outputId": "1e689aef-667d-4cdc-df25-ac6b046ca6f1"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"score=0.08\n"
]
}
],
"source": [
"# Simulate and score the explanation.\n",
"simulator = UncalibratedNeuronSimulator(\n",
" ExplanationNeuronSimulator(\n",
" SIMULATOR_MODEL_NAME,\n",
" explanation,\n",
" max_concurrent=1,\n",
" prompt_format=PromptFormat.INSTRUCTION_FOLLOWING,\n",
" )\n",
")\n",
"scored_simulation = await simulate_and_score(simulator, valid_activation_records)\n",
"print(f\"score={scored_simulation.get_preferred_score():.2f}\")"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": []
},
"kernelspec": {
"display_name": "neuron-explainer",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.4"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = \"put-key-here\"\n",
"\n",
"from neuron_explainer.activations.activation_records import calculate_max_activation\n",
"from neuron_explainer.activations.activations import ActivationRecordSliceParams, load_neuron\n",
"from neuron_explainer.explanations.calibrated_simulator import UncalibratedNeuronSimulator\n",
"from neuron_explainer.explanations.explainer import TokenActivationPairExplainer\n",
"from neuron_explainer.explanations.prompt_builder import PromptFormat\n",
"from neuron_explainer.explanations.scoring import simulate_and_score\n",
"from neuron_explainer.explanations.simulator import ExplanationNeuronSimulator\n",
"\n",
"EXPLAINER_MODEL_NAME = \"gpt-4\"\n",
"SIMULATOR_MODEL_NAME = \"text-davinci-003\"\n",
"\n",
"\n",
"# test_response = await client.make_request(prompt=\"test 123<|endofprompt|>\", max_tokens=2)\n",
"# print(\"Response:\", test_response[\"choices\"][0][\"text\"])\n",
"\n",
"# Load a neuron record.\n",
"neuron_record = load_neuron(9, 6236)\n",
"\n",
"# Grab the activation records we'll need.\n",
"slice_params = ActivationRecordSliceParams(n_examples_per_split=5)\n",
"train_activation_records = neuron_record.train_activation_records(\n",
" activation_record_slice_params=slice_params\n",
")\n",
"valid_activation_records = neuron_record.valid_activation_records(\n",
" activation_record_slice_params=slice_params\n",
")\n",
"\n",
"# Generate an explanation for the neuron.\n",
"explainer = TokenActivationPairExplainer(\n",
" model_name=EXPLAINER_MODEL_NAME,\n",
" prompt_format=PromptFormat.HARMONY_V4,\n",
" max_concurrent=1,\n",
")\n",
"explanations = await explainer.generate_explanations(\n",
" all_activation_records=train_activation_records,\n",
" max_activation=calculate_max_activation(train_activation_records),\n",
" num_samples=1,\n",
")\n",
"assert len(explanations) == 1\n",
"explanation = explanations[0]\n",
"print(f\"{explanation=}\")\n",
"\n",
"# Simulate and score the explanation.\n",
"simulator = UncalibratedNeuronSimulator(\n",
" ExplanationNeuronSimulator(\n",
" SIMULATOR_MODEL_NAME,\n",
" explanation,\n",
" max_concurrent=1,\n",
" prompt_format=PromptFormat.INSTRUCTION_FOLLOWING,\n",
" )\n",
")\n",
"scored_simulation = await simulate_and_score(simulator, valid_activation_records)\n",
"print(f\"score={scored_simulation.get_preferred_score():.2f}\")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "openai",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat": 4,
"nbformat_minor": 0
}
13 changes: 6 additions & 7 deletions neuron-explainer/neuron_explainer/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,12 @@ async def f_retry(*args: Any, **kwargs: Any) -> None:
return decorate


API_KEY = os.getenv("OPENAI_API_KEY")
assert API_KEY, "Please set the OPENAI_API_KEY environment variable"
API_HTTP_HEADERS = {
"Content-Type": "application/json",
"Authorization": "Bearer " + API_KEY,
}
BASE_API_URL = "https://api.openai.com/v1"
# Resolve credentials and endpoint from the environment.
# NEURON_EXPLAINER_* variables take precedence over the legacy OPENAI_* names;
# both default to empty / the public OpenAI endpoint when unset.
API_KEY = os.getenv("NEURON_EXPLAINER_API_KEY", os.getenv("OPENAI_API_KEY", ""))
BASE_API_URL = os.getenv("NEURON_EXPLAINER_API_BASE", "https://api.openai.com/v1")

# Attach an Authorization header only when a key is actually configured, so
# keyless local servers (e.g. a vLLM instance) can be targeted as-is.
API_HTTP_HEADERS = (
    {"Content-Type": "application/json", "Authorization": f"Bearer {API_KEY}"}
    if API_KEY
    else {"Content-Type": "application/json"}
)


class ApiClient:
Expand Down
9 changes: 0 additions & 9 deletions neuron-explainer/neuron_explainer/explanations/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,6 @@ def from_int(cls, i: int) -> ContextSize:
raise ValueError(f"{i} is not a valid ContextSize")


HARMONY_V4_MODELS = ["gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview"]


class NeuronExplainer(ABC):
"""
Abstract base class for Explainer classes that generate explanations from subclass-specific
Expand All @@ -82,12 +79,6 @@ def __init__(
max_concurrent: Optional[int] = 10,
cache: bool = False,
):
if prompt_format == PromptFormat.HARMONY_V4:
assert model_name in HARMONY_V4_MODELS
elif prompt_format in [PromptFormat.NONE, PromptFormat.INSTRUCTION_FOLLOWING]:
assert model_name not in HARMONY_V4_MODELS
else:
raise ValueError(f"Unhandled prompt format {prompt_format}")

self.model_name = model_name
self.prompt_format = prompt_format
Expand Down