Skip to content
146 changes: 134 additions & 12 deletions pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
get_embedding_matrix,
get_embeddings,
)
from pyrit.auxiliary_attacks.gcg.default_implementations import (
CrossEntropyLoss,
LengthPreservingFilter,
StandardGCGSampling,
)
from pyrit.auxiliary_attacks.gcg.extension_protocols import CandidateFilter, LossFunction, SamplingStrategy

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -125,6 +131,99 @@ def sample_control(
class GCGMultiPromptAttack(MultiPromptAttack):
"""GCG-specific multi-prompt attack that implements the GCG optimization step."""

def __init__(
self,
goals: list[str],
targets: list[str],
workers: list[Any],
control_init: str = "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !",
Comment thread
romanlutz marked this conversation as resolved.
test_prefixes: list[str] | None = None,
logfile: str | None = None,
managers: dict[str, Any] | None = None,
test_goals: list[str] | None = None,
test_targets: list[str] | None = None,
test_workers: list[Any] | None = None,
*,
sampling: SamplingStrategy | None = None,
loss: LossFunction | None = None,
candidate_filter: CandidateFilter | None = None,
) -> None:
super().__init__(
goals,
targets,
workers,
control_init,
test_prefixes,
logfile,
managers,
test_goals,
test_targets,
test_workers,
)
self._sampling = sampling
self._loss = loss
self._candidate_filter = candidate_filter

def _resolve_sampling(self) -> SamplingStrategy:
sampling = getattr(self, "_sampling", None)
if sampling is not None:
return sampling
return StandardGCGSampling()

def _resolve_loss(self, *, target_weight: float, control_weight: float) -> LossFunction:
loss = getattr(self, "_loss", None)
if loss is not None:
return loss
return CrossEntropyLoss(target_weight=target_weight, control_weight=control_weight)

def _resolve_candidate_filter(self, *, filter_cand: bool) -> CandidateFilter:
candidate_filter = getattr(self, "_candidate_filter", None)
if candidate_filter is not None:
return candidate_filter
return LengthPreservingFilter(filter=filter_cand)

def _sample_control_candidates(
self,
*,
worker_index: int,
gradient: torch.Tensor,
batch_size: int,
topk: int,
temp: float,
allow_non_ascii: bool,
) -> torch.Tensor:
sampler = self._resolve_sampling()
prompt_manager = self.prompts[worker_index]
return sampler.sample_candidates(
gradient=gradient,
control_tokens=prompt_manager.control_toks,
batch_size=batch_size,
top_k=topk,
temperature=temp,
allow_non_ascii=allow_non_ascii,
non_ascii_tokens=prompt_manager.disallowed_toks,
)

def _filter_control_candidates(
self,
*,
worker_index: int,
control_cand: torch.Tensor,
filter_cand: bool,
) -> list[str]:
candidate_filter = self._resolve_candidate_filter(filter_cand=filter_cand)
return candidate_filter.filter_candidates(
candidate_tokens=control_cand,
tokenizer=self.workers[worker_index].tokenizer,
current_control=self.control_str,
)

def _get_control_length(self, *, control: str) -> int | None:
try:
return len(self.workers[0].tokenizer(control).input_ids[1:])
except (AttributeError, TypeError, ValueError):
return None

def step(
self,
*,
Expand Down Expand Up @@ -158,6 +257,7 @@ def step(
"""
main_device = self.models[0].device
control_cands = []
loss_function = self._resolve_loss(target_weight=target_weight, control_weight=control_weight)

for j, worker in enumerate(self.workers):
worker(self.prompts[j], "grad", worker.model)
Expand All @@ -171,20 +271,40 @@ def step(
grad = torch.zeros_like(new_grad)
if grad.shape != new_grad.shape:
with torch.no_grad():
control_cand = self.prompts[j - 1].sample_control(grad, batch_size, topk, temp, allow_non_ascii)
control_cand = self._sample_control_candidates(
worker_index=j - 1,
gradient=grad,
batch_size=batch_size,
topk=topk,
temp=temp,
allow_non_ascii=allow_non_ascii,
)
control_cands.append(
self.get_filtered_cands(
j - 1, control_cand, filter_cand=filter_cand, curr_control=self.control_str
self._filter_control_candidates(
worker_index=j - 1,
control_cand=control_cand,
filter_cand=filter_cand,
)
)
grad = new_grad
else:
grad += new_grad

with torch.no_grad():
control_cand = self.prompts[j].sample_control(grad, batch_size, topk, temp, allow_non_ascii)
control_cand = self._sample_control_candidates(
worker_index=j,
gradient=grad,
batch_size=batch_size,
topk=topk,
temp=temp,
allow_non_ascii=allow_non_ascii,
)
control_cands.append(
self.get_filtered_cands(j, control_cand, filter_cand=filter_cand, curr_control=self.control_str)
self._filter_control_candidates(
worker_index=j,
control_cand=control_cand,
filter_cand=filter_cand,
)
)
del grad, control_cand
gc.collect()
Expand All @@ -205,14 +325,14 @@ def step(
worker(self.prompts[k][i], "logits", worker.model, cand, return_ids=True)
logits, ids = zip(*[worker.results.get() for worker in self.workers])
loss[j * batch_size : (j + 1) * batch_size] += sum(
target_weight * self.prompts[k][i].target_loss(logit, id).mean(dim=-1).to(main_device)
loss_function.compute_loss(
logits=logit,
token_ids=id,
target_slice=self.prompts[k][i]._target_slice,
control_slice=self.prompts[k][i]._control_slice,
).to(main_device)
for k, (logit, id) in enumerate(zip(logits, ids))
)
if control_weight != 0:
loss[j * batch_size : (j + 1) * batch_size] += sum(
control_weight * self.prompts[k][i].control_loss(logit, id).mean(dim=-1).to(main_device)
for k, (logit, id) in enumerate(zip(logits, ids))
)
del logits, ids
gc.collect()

Expand All @@ -229,7 +349,9 @@ def step(
del control_cands, loss
gc.collect()

logger.info(f"Current length: {len(self.workers[0].tokenizer(next_control).input_ids[1:])}")
current_length = self._get_control_length(control=next_control)
if current_length is not None:
logger.info(f"Current length: {current_length}")
logger.info(next_control)

return next_control, cand_loss.item() / len(self.prompts[0]) / len(self.workers)
44 changes: 44 additions & 0 deletions pyrit/auxiliary_attacks/gcg/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@
if TYPE_CHECKING:
from pathlib import Path

from pyrit.auxiliary_attacks.gcg.extension_protocols import (
CandidateFilter,
LossFunction,
SamplingStrategy,
SuffixInitializer,
)

_DEFAULT_CONTROL_INIT: str = "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !"


Expand Down Expand Up @@ -147,6 +154,18 @@ class GCGAlgorithmConfig:
random_seed (int): Seed for ``torch``/``numpy``/``random``. Defaults to 42.
control_init (str): Initial suffix string the optimization starts from.
Defaults to twenty space-separated ``!`` tokens.
sampling (SamplingStrategy | None): Optional strategy object that
samples candidate suffix token sequences from the aggregated
gradient. ``None`` uses the built-in default implementation.
loss (LossFunction | None): Optional loss object used to score each
candidate suffix. ``None`` uses the built-in weighted
cross-entropy default that preserves legacy behavior.
candidate_filter (CandidateFilter | None): Optional candidate-filter
object that decodes/prunes sampled candidate token sequences.
``None`` uses the built-in length-preserving filter.
suffix_init (SuffixInitializer | None): Optional initializer object
that produces the initial suffix string at attack construction
time. ``None`` uses ``control_init`` verbatim.
"""

n_steps: int = 500
Expand All @@ -161,6 +180,10 @@ class GCGAlgorithmConfig:
filter_cand: bool = True
random_seed: int = 42
control_init: str = _DEFAULT_CONTROL_INIT
sampling: SamplingStrategy | None = None
loss: LossFunction | None = None
candidate_filter: CandidateFilter | None = None
suffix_init: SuffixInitializer | None = None

def __post_init__(self) -> None:
if self.n_steps <= 0:
Expand All @@ -183,6 +206,27 @@ def __post_init__(self) -> None:
)
if not self.control_init:
raise ValueError("GCGAlgorithmConfig.control_init must be a non-empty string.")
self._validate_extensions()

def _validate_extensions(self) -> None:
from pyrit.auxiliary_attacks.gcg.extension_protocols import (
CandidateFilter,
LossFunction,
SamplingStrategy,
SuffixInitializer,
)

checks = (
("sampling", self.sampling, SamplingStrategy),
("loss", self.loss, LossFunction),
("candidate_filter", self.candidate_filter, CandidateFilter),
("suffix_init", self.suffix_init, SuffixInitializer),
)
for field_name, value, protocol in checks:
if value is not None and not isinstance(value, protocol):
raise ValueError(
f"GCGAlgorithmConfig.{field_name} must satisfy {protocol.__name__}, got {type(value)!r}."
)


@dataclass
Expand Down
11 changes: 5 additions & 6 deletions pyrit/auxiliary_attacks/gcg/extension_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,11 @@
- ``SuffixInitializer`` — how the initial suffix string fed into the
optimization loop is constructed.

The module is **typing surface only**. It ships no concrete implementations,
no defaults, and no wiring into ``GCGAlgorithmConfig`` or
``GCGMultiPromptAttack``. The default behaviors that match the current attack
code will land as concrete classes in a follow-up PR; the optional
``GCGAlgorithmConfig`` fields that select between defaults and custom
implementations will land in the PR after that.
The module is **typing surface only**. Concrete defaults live in
``default_implementations.py``, and orchestration wiring lives in
``GCGAlgorithmConfig`` + ``GCGMultiPromptAttack``. Keeping this module purely
protocol definitions preserves a stable extension API that can be imported
without pulling in heavy runtime dependencies.

Tensor-typed signatures are kept lazy via ``from __future__ import
annotations`` plus a ``TYPE_CHECKING`` import for ``torch`` so that
Expand Down
37 changes: 34 additions & 3 deletions pyrit/auxiliary_attacks/gcg/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import logging
import time
from dataclasses import dataclass, field
from functools import partial
from typing import Any, overload

import numpy as np
Expand Down Expand Up @@ -212,6 +213,18 @@ def _build_identifier(self) -> ComponentIdentifier:
"topk": self._algorithm.topk,
"target_weight": self._algorithm.target_weight,
"control_weight": self._algorithm.control_weight,
"sampling_impl": (
type(self._algorithm.sampling).__name__ if self._algorithm.sampling is not None else "default"
),
"loss_impl": type(self._algorithm.loss).__name__ if self._algorithm.loss is not None else "default",
"candidate_filter_impl": (
type(self._algorithm.candidate_filter).__name__
if self._algorithm.candidate_filter is not None
else "default"
),
"suffix_init_impl": (
type(self._algorithm.suffix_init).__name__ if self._algorithm.suffix_init is not None else "default"
),
"transfer": self._strategy.transfer,
"progressive_goals": self._strategy.progressive_goals,
"progressive_models": self._strategy.progressive_models,
Expand Down Expand Up @@ -257,7 +270,12 @@ async def _perform_async(self, *, context: GCGContext) -> GCGResult:
managers = {
"AP": attack_lib.GCGAttackPrompt,
"PM": attack_lib.GCGPromptManager,
"MPA": attack_lib.GCGMultiPromptAttack,
"MPA": partial(
attack_lib.GCGMultiPromptAttack,
sampling=self._algorithm.sampling,
loss=self._algorithm.loss,
candidate_filter=self._algorithm.candidate_filter,
),
}
context.attack = self._create_attack(
params=params,
Expand Down Expand Up @@ -400,14 +418,15 @@ def _create_attack(
logfile_path: str,
) -> Any:
"""Build the right attack object based on the strategy flags."""
control_init = self._resolve_control_init(workers=workers)
if self._strategy.transfer:
return ProgressiveMultiPromptAttack(
train_goals,
train_targets,
workers,
progressive_models=self._strategy.progressive_models,
progressive_goals=self._strategy.progressive_goals,
control_init=self._algorithm.control_init,
control_init=control_init,
logfile=logfile_path,
managers=managers,
test_goals=test_goals,
Expand All @@ -421,7 +440,7 @@ def _create_attack(
train_goals,
train_targets,
workers,
control_init=self._algorithm.control_init,
control_init=control_init,
logfile=logfile_path,
managers=managers,
test_goals=test_goals,
Expand All @@ -432,6 +451,18 @@ def _create_attack(
mpa_n_steps=self._algorithm.n_steps,
)

def _resolve_control_init(self, *, workers: list[Any]) -> str:
"""Resolve the initial suffix string for a run.

Uses the configured ``suffix_init`` extension when provided; otherwise
falls back to the legacy literal ``control_init`` value.
"""
if self._algorithm.suffix_init is None:
return self._algorithm.control_init
if not workers:
raise ValueError("Cannot resolve suffix_init without at least one worker tokenizer.")
return self._algorithm.suffix_init.make_initial_suffix(tokenizer=workers[0].tokenizer)

@staticmethod
def _read_result(*, logfile_path: str, memory_labels: dict[str, str]) -> GCGResult:
"""Pull final-step values out of the JSON log written during the run."""
Expand Down
Loading
Loading