diff --git a/pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py b/pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py index 4df1ae9205..ea0677527e 100644 --- a/pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py +++ b/pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py @@ -17,6 +17,12 @@ get_embedding_matrix, get_embeddings, ) +from pyrit.auxiliary_attacks.gcg.default_implementations import ( + CrossEntropyLoss, + LengthPreservingFilter, + StandardGCGSampling, +) +from pyrit.auxiliary_attacks.gcg.extension_protocols import CandidateFilter, LossFunction, SamplingStrategy logger = logging.getLogger(__name__) @@ -125,6 +131,99 @@ def sample_control( class GCGMultiPromptAttack(MultiPromptAttack): """GCG-specific multi-prompt attack that implements the GCG optimization step.""" + def __init__( + self, + goals: list[str], + targets: list[str], + workers: list[Any], + control_init: str = "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !", + test_prefixes: list[str] | None = None, + logfile: str | None = None, + managers: dict[str, Any] | None = None, + test_goals: list[str] | None = None, + test_targets: list[str] | None = None, + test_workers: list[Any] | None = None, + *, + sampling: SamplingStrategy | None = None, + loss: LossFunction | None = None, + candidate_filter: CandidateFilter | None = None, + ) -> None: + super().__init__( + goals, + targets, + workers, + control_init, + test_prefixes, + logfile, + managers, + test_goals, + test_targets, + test_workers, + ) + self._sampling = sampling + self._loss = loss + self._candidate_filter = candidate_filter + + def _resolve_sampling(self) -> SamplingStrategy: + sampling = getattr(self, "_sampling", None) + if sampling is not None: + return sampling + return StandardGCGSampling() + + def _resolve_loss(self, *, target_weight: float, control_weight: float) -> LossFunction: + loss = getattr(self, "_loss", None) + if loss is not None: + return loss + return CrossEntropyLoss(target_weight=target_weight, control_weight=control_weight) + + def _resolve_candidate_filter(self, *, filter_cand: bool) -> CandidateFilter: + candidate_filter = getattr(self, "_candidate_filter", None) + if candidate_filter is not None: + return candidate_filter + return LengthPreservingFilter(filter=filter_cand) + + def _sample_control_candidates( + self, + *, + worker_index: int, + gradient: torch.Tensor, + batch_size: int, + topk: int, + temp: float, + allow_non_ascii: bool, + ) -> torch.Tensor: + sampler = self._resolve_sampling() + prompt_manager = self.prompts[worker_index] + return sampler.sample_candidates( + gradient=gradient, + control_tokens=prompt_manager.control_toks, + batch_size=batch_size, + top_k=topk, + temperature=temp, + allow_non_ascii=allow_non_ascii, + non_ascii_tokens=prompt_manager.disallowed_toks, + ) + + def _filter_control_candidates( + self, + *, + worker_index: int, + control_cand: torch.Tensor, + filter_cand: bool, + ) -> list[str]: + candidate_filter = self._resolve_candidate_filter(filter_cand=filter_cand) + return candidate_filter.filter_candidates( + candidate_tokens=control_cand, + tokenizer=self.workers[worker_index].tokenizer, + current_control=self.control_str, + ) + + def _get_control_length(self, *, control: str) -> int | None: + try: + return len(self.workers[0].tokenizer(control).input_ids[1:]) + except (AttributeError, TypeError, ValueError): + return None + def step( self, *, @@ -158,6 +257,7 @@ def step( """ main_device = self.models[0].device control_cands = [] + loss_function = self._resolve_loss(target_weight=target_weight, control_weight=control_weight) for j, worker in enumerate(self.workers): worker(self.prompts[j], "grad", worker.model) @@ -171,10 +271,19 @@ def step( grad = torch.zeros_like(new_grad) if grad.shape != new_grad.shape: with torch.no_grad(): - control_cand = self.prompts[j - 1].sample_control(grad, batch_size, topk, temp, allow_non_ascii) + control_cand = self._sample_control_candidates( + worker_index=j - 1, + gradient=grad, + batch_size=batch_size, + topk=topk, + temp=temp, + allow_non_ascii=allow_non_ascii, + ) control_cands.append( - self.get_filtered_cands( - j - 1, control_cand, filter_cand=filter_cand, curr_control=self.control_str + self._filter_control_candidates( + worker_index=j - 1, + control_cand=control_cand, + filter_cand=filter_cand, ) ) grad = new_grad @@ -182,9 +291,20 @@ def step( grad += new_grad with torch.no_grad(): - control_cand = self.prompts[j].sample_control(grad, batch_size, topk, temp, allow_non_ascii) + control_cand = self._sample_control_candidates( + worker_index=j, + gradient=grad, + batch_size=batch_size, + topk=topk, + temp=temp, + allow_non_ascii=allow_non_ascii, + ) control_cands.append( - self.get_filtered_cands(j, control_cand, filter_cand=filter_cand, curr_control=self.control_str) + self._filter_control_candidates( + worker_index=j, + control_cand=control_cand, + filter_cand=filter_cand, + ) ) del grad, control_cand gc.collect() @@ -205,14 +325,14 @@ def step( worker(self.prompts[k][i], "logits", worker.model, cand, return_ids=True) logits, ids = zip(*[worker.results.get() for worker in self.workers]) loss[j * batch_size : (j + 1) * batch_size] += sum( - target_weight * self.prompts[k][i].target_loss(logit, id).mean(dim=-1).to(main_device) + loss_function.compute_loss( + logits=logit, + token_ids=id, + target_slice=self.prompts[k][i]._target_slice, + control_slice=self.prompts[k][i]._control_slice, + ).to(main_device) for k, (logit, id) in enumerate(zip(logits, ids)) ) - if control_weight != 0: - loss[j * batch_size : (j + 1) * batch_size] += sum( - control_weight * self.prompts[k][i].control_loss(logit, id).mean(dim=-1).to(main_device) - for k, (logit, id) in enumerate(zip(logits, ids)) - ) del logits, ids gc.collect() @@ -229,7 +349,9 @@ def step( del control_cands, loss gc.collect() - logger.info(f"Current length: {len(self.workers[0].tokenizer(next_control).input_ids[1:])}") + current_length = self._get_control_length(control=next_control) + if current_length is not None: + logger.info(f"Current length: {current_length}") logger.info(next_control) return next_control, cand_loss.item() / len(self.prompts[0]) / len(self.workers) diff --git a/pyrit/auxiliary_attacks/gcg/config.py b/pyrit/auxiliary_attacks/gcg/config.py index 097a9087af..c2debada6e 100644 --- a/pyrit/auxiliary_attacks/gcg/config.py +++ b/pyrit/auxiliary_attacks/gcg/config.py @@ -25,6 +25,13 @@ if TYPE_CHECKING: from pathlib import Path + from pyrit.auxiliary_attacks.gcg.extension_protocols import ( + CandidateFilter, + LossFunction, + SamplingStrategy, + SuffixInitializer, + ) + _DEFAULT_CONTROL_INIT: str = "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !" @@ -147,6 +154,18 @@ class GCGAlgorithmConfig: random_seed (int): Seed for ``torch``/``numpy``/``random``. Defaults to 42. control_init (str): Initial suffix string the optimization starts from. Defaults to twenty space-separated ``!`` tokens. + sampling (SamplingStrategy | None): Optional strategy object that + samples candidate suffix token sequences from the aggregated + gradient. ``None`` uses the built-in default implementation. + loss (LossFunction | None): Optional loss object used to score each + candidate suffix. ``None`` uses the built-in weighted + cross-entropy default that preserves legacy behavior. + candidate_filter (CandidateFilter | None): Optional candidate-filter + object that decodes/prunes sampled candidate token sequences. + ``None`` uses the built-in length-preserving filter. + suffix_init (SuffixInitializer | None): Optional initializer object + that produces the initial suffix string at attack construction + time. ``None`` uses ``control_init`` verbatim. """ n_steps: int = 500 @@ -161,6 +180,10 @@ class GCGAlgorithmConfig: filter_cand: bool = True random_seed: int = 42 control_init: str = _DEFAULT_CONTROL_INIT + sampling: SamplingStrategy | None = None + loss: LossFunction | None = None + candidate_filter: CandidateFilter | None = None + suffix_init: SuffixInitializer | None = None def __post_init__(self) -> None: if self.n_steps <= 0: @@ -183,6 +206,27 @@ def __post_init__(self) -> None: ) if not self.control_init: raise ValueError("GCGAlgorithmConfig.control_init must be a non-empty string.") + self._validate_extensions() + + def _validate_extensions(self) -> None: + from pyrit.auxiliary_attacks.gcg.extension_protocols import ( + CandidateFilter, + LossFunction, + SamplingStrategy, + SuffixInitializer, + ) + + checks = ( + ("sampling", self.sampling, SamplingStrategy), + ("loss", self.loss, LossFunction), + ("candidate_filter", self.candidate_filter, CandidateFilter), + ("suffix_init", self.suffix_init, SuffixInitializer), + ) + for field_name, value, protocol in checks: + if value is not None and not isinstance(value, protocol): + raise ValueError( + f"GCGAlgorithmConfig.{field_name} must satisfy {protocol.__name__}, got {type(value)!r}." + ) @dataclass diff --git a/pyrit/auxiliary_attacks/gcg/extension_protocols.py b/pyrit/auxiliary_attacks/gcg/extension_protocols.py index f9f1a3013e..973fb22a2b 100644 --- a/pyrit/auxiliary_attacks/gcg/extension_protocols.py +++ b/pyrit/auxiliary_attacks/gcg/extension_protocols.py @@ -16,12 +16,11 @@ - ``SuffixInitializer`` — how the initial suffix string fed into the optimization loop is constructed. -The module is **typing surface only**. It ships no concrete implementations, -no defaults, and no wiring into ``GCGAlgorithmConfig`` or -``GCGMultiPromptAttack``. The default behaviors that match the current attack -code will land as concrete classes in a follow-up PR; the optional -``GCGAlgorithmConfig`` fields that select between defaults and custom -implementations will land in the PR after that. +The module is **typing surface only**. Concrete defaults live in +``default_implementations.py``, and orchestration wiring lives in +``GCGAlgorithmConfig`` + ``GCGMultiPromptAttack``. Keeping this module purely +protocol definitions preserves a stable extension API that can be imported +without pulling in heavy runtime dependencies. Tensor-typed signatures are kept lazy via ``from __future__ import annotations`` plus a ``TYPE_CHECKING`` import for ``torch`` so that diff --git a/pyrit/auxiliary_attacks/gcg/generator.py b/pyrit/auxiliary_attacks/gcg/generator.py index 4c812594e9..12ef46040c 100644 --- a/pyrit/auxiliary_attacks/gcg/generator.py +++ b/pyrit/auxiliary_attacks/gcg/generator.py @@ -38,6 +38,7 @@ import logging import time from dataclasses import dataclass, field +from functools import partial from typing import Any, overload import numpy as np @@ -212,6 +213,18 @@ def _build_identifier(self) -> ComponentIdentifier: "topk": self._algorithm.topk, "target_weight": self._algorithm.target_weight, "control_weight": self._algorithm.control_weight, + "sampling_impl": ( + type(self._algorithm.sampling).__name__ if self._algorithm.sampling is not None else "default" + ), + "loss_impl": type(self._algorithm.loss).__name__ if self._algorithm.loss is not None else "default", + "candidate_filter_impl": ( + type(self._algorithm.candidate_filter).__name__ + if self._algorithm.candidate_filter is not None + else "default" + ), + "suffix_init_impl": ( + type(self._algorithm.suffix_init).__name__ if self._algorithm.suffix_init is not None else "default" + ), "transfer": self._strategy.transfer, "progressive_goals": self._strategy.progressive_goals, "progressive_models": self._strategy.progressive_models, @@ -257,7 +270,12 @@ async def _perform_async(self, *, context: GCGContext) -> GCGResult: managers = { "AP": attack_lib.GCGAttackPrompt, "PM": attack_lib.GCGPromptManager, - "MPA": attack_lib.GCGMultiPromptAttack, + "MPA": partial( + attack_lib.GCGMultiPromptAttack, + sampling=self._algorithm.sampling, + loss=self._algorithm.loss, + candidate_filter=self._algorithm.candidate_filter, + ), } context.attack = self._create_attack( params=params, @@ -400,6 +418,7 @@ def _create_attack( logfile_path: str, ) -> Any: """Build the right attack object based on the strategy flags.""" + control_init = self._resolve_control_init(workers=workers) if self._strategy.transfer: return ProgressiveMultiPromptAttack( train_goals, @@ -407,7 +426,7 @@ def _create_attack( workers, progressive_models=self._strategy.progressive_models, progressive_goals=self._strategy.progressive_goals, - control_init=self._algorithm.control_init, + control_init=control_init, logfile=logfile_path, managers=managers, test_goals=test_goals, @@ -421,7 +440,7 @@ def _create_attack( train_goals, train_targets, workers, - control_init=self._algorithm.control_init, + control_init=control_init, logfile=logfile_path, managers=managers, test_goals=test_goals, @@ -432,6 +451,18 @@ def _create_attack( mpa_n_steps=self._algorithm.n_steps, ) + def _resolve_control_init(self, *, workers: list[Any]) -> str: + """Resolve the initial suffix string for a run. + + Uses the configured ``suffix_init`` extension when provided; otherwise + falls back to the legacy literal ``control_init`` value. + """ + if self._algorithm.suffix_init is None: + return self._algorithm.control_init + if not workers: + raise ValueError("Cannot resolve suffix_init without at least one worker tokenizer.") + return self._algorithm.suffix_init.make_initial_suffix(tokenizer=workers[0].tokenizer) + @staticmethod def _read_result(*, logfile_path: str, memory_labels: dict[str, str]) -> GCGResult: """Pull final-step values out of the JSON log written during the run.""" diff --git a/tests/unit/auxiliary_attacks/gcg/test_config.py b/tests/unit/auxiliary_attacks/gcg/test_config.py index da0a7f6a9a..922b0ffedd 100644 --- a/tests/unit/auxiliary_attacks/gcg/test_config.py +++ b/tests/unit/auxiliary_attacks/gcg/test_config.py @@ -9,7 +9,7 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import pytest @@ -28,6 +28,49 @@ _LLAMA_2 = "meta-llama/Llama-2-7b-chat-hf" +class _SamplingStub: + def sample_candidates( + self, + *, + gradient: Any, + control_tokens: Any, + batch_size: int, + top_k: int, + temperature: float, + allow_non_ascii: bool, + non_ascii_tokens: Any, + ) -> Any: + return control_tokens + + +class _LossStub: + def compute_loss( + self, + *, + logits: Any, + token_ids: Any, + target_slice: slice, + control_slice: slice, + ) -> Any: + return logits + + +class _FilterStub: + def filter_candidates( + self, + *, + candidate_tokens: Any, + tokenizer: Any, + current_control: str, + ) -> list[str]: + return [current_control] + + +class _SuffixInitStub: + def make_initial_suffix(self, *, tokenizer: Any) -> str: + return "stub suffix" + + def _minimal_config() -> GCGConfig: return GCGConfig(models=[GCGModelConfig(name=_LLAMA_2)]) @@ -42,6 +85,10 @@ def test_minimal_config_constructs_with_defaults() -> None: assert config.test_models == [] assert config.algorithm.n_steps == 500 assert config.algorithm.batch_size == 512 + assert config.algorithm.sampling is None + assert config.algorithm.loss is None + assert config.algorithm.candidate_filter is None + assert config.algorithm.suffix_init is None assert config.strategy.transfer is False assert config.output.verbose is True assert config.hf_token is None @@ -100,6 +147,33 @@ def test_algorithm_empty_control_init_raises() -> None: GCGAlgorithmConfig(control_init="") +@pytest.mark.parametrize( + "field_name,value", + [ + ("sampling", object()), + ("loss", object()), + ("candidate_filter", object()), + ("suffix_init", object()), + ], +) +def test_algorithm_extension_type_validation(field_name: str, value: object) -> None: + with pytest.raises(ValueError, match=rf"GCGAlgorithmConfig\.{field_name} must satisfy"): + GCGAlgorithmConfig(**{field_name: value}) + + +def test_algorithm_accepts_protocol_implementations() -> None: + config = GCGAlgorithmConfig( + sampling=_SamplingStub(), + loss=_LossStub(), + candidate_filter=_FilterStub(), + suffix_init=_SuffixInitStub(), + ) + assert config.sampling is not None + assert config.loss is not None + assert config.candidate_filter is not None + assert config.suffix_init is not None + + @pytest.mark.parametrize("field_name", ["n_train_data", "n_test_data"]) def test_data_negative_count_raises(field_name: str) -> None: with pytest.raises(ValueError, match=f"GCGDataConfig.{field_name} must be >= 0"): diff --git a/tests/unit/auxiliary_attacks/gcg/test_gcg_core.py b/tests/unit/auxiliary_attacks/gcg/test_gcg_core.py index c3858bf357..a90563ed85 100644 --- a/tests/unit/auxiliary_attacks/gcg/test_gcg_core.py +++ b/tests/unit/auxiliary_attacks/gcg/test_gcg_core.py @@ -1,7 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from unittest.mock import MagicMock +from typing import Any +from unittest.mock import MagicMock, patch import pytest @@ -25,6 +26,7 @@ "pyrit.auxiliary_attacks.gcg.attack.gcg.gcg_attack", reason="GCG optional dependencies not installed", ) +GCGMultiPromptAttack = gcg_attack_mod.GCGMultiPromptAttack GCGPromptManager = gcg_attack_mod.GCGPromptManager token_gradients = gcg_attack_mod.token_gradients @@ -501,3 +503,466 @@ def test_raises_when_tokenizer_has_no_chat_template(self) -> None: with patch.object(attack_manager_mod.AutoTokenizer, "from_pretrained", return_value=bare_tokenizer): with pytest.raises(ValueError, match="no chat_template configured"): get_workers(params) + + +class _Queue: + def __init__(self, items: list[Any]) -> None: + self._items = list(items) + + def get(self) -> Any: + return self._items.pop(0) + + +class _WorkerStub: + def __init__( + self, + *, + gradient: torch.Tensor, + logits: torch.Tensor, + token_ids: torch.Tensor, + tokenizer: MagicMock, + ) -> None: + self.model = MagicMock() + self.model.device = "cpu" + self.tokenizer = tokenizer + self.results = _Queue([gradient, (logits, token_ids)]) + self.calls: list[tuple] = [] + + def __call__(self, *args: Any, **kwargs: Any) -> None: + self.calls.append((args, kwargs)) + + +class _PromptManagerStub: + def __init__( + self, + *, + prompt: AttackPrompt, + control_tokens: torch.Tensor, + disallowed_tokens: torch.Tensor, + control_str: str, + ) -> None: + self._prompts = [prompt] + self._control_tokens = control_tokens + self._disallowed_tokens = disallowed_tokens + self.control_str = control_str + + def __len__(self) -> int: + return len(self._prompts) + + def __getitem__(self, i: int) -> AttackPrompt: + return self._prompts[i] + + @property + def control_toks(self) -> torch.Tensor: + return self._control_tokens + + @property + def disallowed_toks(self) -> torch.Tensor: + return self._disallowed_tokens + + +class _SpySampling: + def __init__(self, *, sampled_tokens: torch.Tensor) -> None: + self.sampled_tokens = sampled_tokens + self.calls: list[dict] = [] + + def sample_candidates( + self, + *, + gradient: torch.Tensor, + control_tokens: torch.Tensor, + batch_size: int, + top_k: int, + temperature: float, + allow_non_ascii: bool, + non_ascii_tokens: torch.Tensor, + ) -> torch.Tensor: + self.calls.append( + { + "gradient": gradient.clone(), + "control_tokens": control_tokens.clone(), + "batch_size": batch_size, + "top_k": top_k, + "temperature": temperature, + "allow_non_ascii": allow_non_ascii, + "non_ascii_tokens": non_ascii_tokens.clone(), + } + ) + return self.sampled_tokens.clone() + + +class _SpyLoss: + def __init__(self, *, losses: torch.Tensor) -> None: + self.losses = losses + self.calls: list[dict] = [] + + def compute_loss( + self, + *, + logits: torch.Tensor, + token_ids: torch.Tensor, + target_slice: slice, + control_slice: slice, + ) -> torch.Tensor: + self.calls.append( + { + "logits": logits.clone(), + "token_ids": token_ids.clone(), + "target_slice": target_slice, + "control_slice": control_slice, + } + ) + return self.losses.to(logits.device) + + +class _SpyFilter: + def __init__(self, *, candidates: list[str]) -> None: + self.candidates = list(candidates) + self.calls: list[dict] = [] + + def filter_candidates( + self, + *, + candidate_tokens: torch.Tensor, + tokenizer: MagicMock, + current_control: str, + ) -> list[str]: + self.calls.append( + { + "candidate_tokens": candidate_tokens.clone(), + "tokenizer": tokenizer, + "current_control": current_control, + } + ) + return list(self.candidates) + + +class TestGCGMultiPromptAttackStepWiring: + @staticmethod + def _make_tokenizer() -> MagicMock: + tokenizer = MagicMock() + tokenizer.vocab_size = 100 + + def decode_fn(ids, **_kwargs): + values = ids.tolist() if hasattr(ids, "tolist") else list(ids) + return " ".join(str(int(v)) for v in values) + + def call_fn(text, **_kwargs): + output = MagicMock() + if text == "!": + output.input_ids = [0] + else: + output.input_ids = [int(piece) for piece in text.split()] if text else [] + return output + + tokenizer.decode.side_effect = decode_fn + tokenizer.side_effect = call_fn + return tokenizer + + @staticmethod + def _make_prompt(*, target_slice: slice, control_slice: slice) -> AttackPrompt: + prompt = object.__new__(AttackPrompt) + prompt._target_slice = target_slice + prompt._control_slice = control_slice + return prompt + + @staticmethod + def _make_attack( + *, + worker: _WorkerStub, + prompt_manager: _PromptManagerStub, + sampling: object | None = None, + loss: object | None = None, + candidate_filter: object | None = None, + ) -> GCGMultiPromptAttack: + attack = object.__new__(GCGMultiPromptAttack) + attack.workers = [worker] + attack.models = [worker.model] + attack.prompts = [prompt_manager] + attack._sampling = sampling + attack._loss = loss + attack._candidate_filter = candidate_filter + return attack + + def test_step_default_path_matches_legacy_behavior(self) -> None: + gradient = torch.tensor( + [ + [0.3, -0.4, 0.8, -0.2, 0.1, 0.5], + [-0.3, 0.2, -0.8, 0.4, 0.1, 0.7], + [0.2, 0.6, -0.1, -0.5, 0.4, -0.2], + ], + dtype=torch.float32, + ) + logits = torch.randn(1, 8, 10) + token_ids = torch.randint(0, 10, (1, 8)) + control_tokens = torch.tensor([1, 2, 3], dtype=torch.long) + disallowed_tokens = torch.tensor([], dtype=torch.long) + target_slice = slice(4, 6) + control_slice = slice(1, 4) + current_control = "99 99 99" + tokenizer = self._make_tokenizer() + + worker = _WorkerStub(gradient=gradient.clone(), logits=logits, token_ids=token_ids, tokenizer=tokenizer) + prompt = self._make_prompt(target_slice=target_slice, control_slice=control_slice) + prompt_manager = _PromptManagerStub( + prompt=prompt, + control_tokens=control_tokens, + disallowed_tokens=disallowed_tokens, + control_str=current_control, + ) + attack = self._make_attack(worker=worker, prompt_manager=prompt_manager) + + target_weight = 1.3 + control_weight = 0.2 + torch.manual_seed(2026) + actual_control, actual_loss = attack.step( + batch_size=1, + topk=3, + temp=1.0, + allow_non_ascii=True, + target_weight=target_weight, + control_weight=control_weight, + verbose=True, + filter_cand=True, + ) + + legacy_prompt_manager = object.__new__(GCGPromptManager) + legacy_prompt_for_sampling = MagicMock() + legacy_prompt_for_sampling.control_toks = control_tokens.clone() + legacy_prompt_manager._prompts = [legacy_prompt_for_sampling] + legacy_prompt_manager._nonascii_toks = disallowed_tokens + + legacy_attack = object.__new__(MultiPromptAttack) + legacy_worker = MagicMock() + legacy_worker.tokenizer = tokenizer + legacy_attack.workers = [legacy_worker] + + legacy_prompt_for_loss = self._make_prompt(target_slice=target_slice, control_slice=control_slice) + normalized_gradient = gradient / gradient.norm(dim=-1, keepdim=True) + torch.manual_seed(2026) + legacy_control_cand = legacy_prompt_manager.sample_control( + normalized_gradient.clone(), + 1, + topk=3, + temp=1.0, + allow_non_ascii=True, + ) + legacy_controls = legacy_attack.get_filtered_cands( + 0, + legacy_control_cand, + filter_cand=True, + curr_control=current_control, + ) + legacy_loss = target_weight * legacy_prompt_for_loss.target_loss(logits, token_ids).mean( + dim=-1 + ) + control_weight * legacy_prompt_for_loss.control_loss(logits, token_ids).mean(dim=-1) + + assert actual_control == legacy_controls[0] + assert actual_loss == pytest.approx(legacy_loss[0].item()) + + def test_step_uses_custom_protocol_implementations_when_supplied(self) -> None: + gradient = torch.randn(3, 6) + logits = torch.randn(2, 8, 10) + token_ids = torch.randint(0, 10, (2, 8)) + control_tokens = torch.tensor([1, 2, 3], dtype=torch.long) + disallowed_tokens = torch.tensor([5], dtype=torch.long) + tokenizer = self._make_tokenizer() + + worker = _WorkerStub(gradient=gradient.clone(), logits=logits, token_ids=token_ids, tokenizer=tokenizer) + prompt = self._make_prompt(target_slice=slice(4, 6), control_slice=slice(1, 4)) + prompt_manager = _PromptManagerStub( + prompt=prompt, + control_tokens=control_tokens, + disallowed_tokens=disallowed_tokens, + control_str="current control", + ) + + sampled_tokens = torch.tensor([[8, 8, 8], [9, 9, 9]], dtype=torch.long) + sampling = _SpySampling(sampled_tokens=sampled_tokens) + candidate_filter = _SpyFilter(candidates=["candidate-A", "candidate-B"]) + custom_losses = torch.tensor([3.0, 0.5], dtype=torch.float32) + loss = _SpyLoss(losses=custom_losses) + attack = self._make_attack( + worker=worker, + prompt_manager=prompt_manager, + sampling=sampling, + loss=loss, + candidate_filter=candidate_filter, + ) + + selected_control, normalized_loss = attack.step( + batch_size=2, + topk=4, + temp=0.8, + allow_non_ascii=False, + target_weight=0.0, + control_weight=1.0, + verbose=True, + filter_cand=True, + ) + + assert selected_control == "candidate-B" + assert normalized_loss == pytest.approx(0.5) + assert len(sampling.calls) == 1 + assert len(candidate_filter.calls) == 1 + assert len(loss.calls) == 1 + assert sampling.calls[0]["batch_size"] == 2 + assert sampling.calls[0]["top_k"] == 4 + assert sampling.calls[0]["allow_non_ascii"] is False + assert candidate_filter.calls[0]["current_control"] == "current control" + + def test_gcg_multi_prompt_attack_init_with_custom_protocols(self) -> None: + """Test GCGMultiPromptAttack.__init__ stores custom sampling/loss/filter.""" + sampling = _SpySampling(sampled_tokens=torch.tensor([[1, 2, 3]])) + loss = _SpyLoss(losses=torch.tensor([1.0])) + candidate_filter = _SpyFilter(candidates=["filtered"]) + workers = [MagicMock()] + + with patch.object(MultiPromptAttack, "__init__", return_value=None) as mock_base_init: + attack = GCGMultiPromptAttack( + goals=["goal"], + targets=["target"], + workers=workers, + control_init="seed control", + sampling=sampling, + loss=loss, + candidate_filter=candidate_filter, + ) + + assert mock_base_init.call_count == 1 + assert mock_base_init.call_args.args[:4] == (["goal"], ["target"], workers, "seed control") + + assert attack._sampling is sampling + assert attack._loss is loss + assert attack._candidate_filter is candidate_filter + + def test_step_aggregates_workers_when_grad_shapes_mismatch(self) -> None: + """Test step handles a worker gradient shape mismatch by sampling per group.""" + tokenizer = self._make_tokenizer() + prompt = self._make_prompt(target_slice=slice(0, 1), control_slice=slice(0, 1)) + prompt_manager1 = _PromptManagerStub( + prompt=prompt, + control_tokens=torch.tensor([1], dtype=torch.long), + disallowed_tokens=torch.tensor([], dtype=torch.long), + control_str="seed", + ) + prompt_manager2 = _PromptManagerStub( + prompt=prompt, + control_tokens=torch.tensor([1], dtype=torch.long), + disallowed_tokens=torch.tensor([], dtype=torch.long), + control_str="seed", + ) + + grad1 = torch.tensor([[0.1, 0.2, 0.3]], dtype=torch.float32) + grad2 = torch.tensor([[0.4, 0.5, 0.6, 0.7]], dtype=torch.float32) + logits = torch.randn(1, 8, 10) + token_ids = torch.randint(0, 10, (1, 8)) + worker1 = _WorkerStub(gradient=grad1, logits=logits, token_ids=token_ids, tokenizer=tokenizer) + worker2 = _WorkerStub(gradient=grad2, logits=logits, token_ids=token_ids, tokenizer=tokenizer) + worker1.results = _Queue([grad1, (logits, token_ids), (logits, token_ids)]) + worker2.results = _Queue([grad2, (logits, token_ids), (logits, token_ids)]) + + attack = object.__new__(GCGMultiPromptAttack) + attack.workers = [worker1, worker2] + attack.models = [worker1.model] + attack.prompts = [prompt_manager1, prompt_manager2] + attack.control_str = "seed" + + class _ConstantLoss: + @staticmethod + def compute_loss( + *, + logits: torch.Tensor, + token_ids: torch.Tensor, + target_slice: slice, + control_slice: slice, + ) -> torch.Tensor: + return torch.tensor([0.5], dtype=torch.float32) + + with ( + patch.object( + attack, + "_sample_control_candidates", + return_value=torch.tensor([[1, 2, 3]], dtype=torch.long), + ) as mock_sample, + patch.object(attack, "_filter_control_candidates", return_value=["candidate"]), + patch.object(attack, "_resolve_loss", return_value=_ConstantLoss()), + patch.object(attack, "_get_control_length", return_value=None), + ): + control, normalized_loss = attack.step( + batch_size=1, + topk=2, + temp=1.0, + allow_non_ascii=True, + target_weight=1.0, + control_weight=0.1, + verbose=True, + filter_cand=True, + ) + + assert control == "candidate" + assert normalized_loss == pytest.approx(0.5) + assert mock_sample.call_count == 2 + assert mock_sample.call_args_list[0].kwargs["worker_index"] == 0 + assert mock_sample.call_args_list[1].kwargs["worker_index"] == 1 + + def test_resolve_methods_return_defaults_when_none(self) -> None: + """Test _resolve_* methods return defaults when custom protocols are None.""" + worker = _WorkerStub( + gradient=torch.tensor([[0.1]]), + logits=torch.randn(1, 8, 10), + token_ids=torch.randint(0, 10, (1, 8)), + tokenizer=self._make_tokenizer(), + ) + prompt_manager = _PromptManagerStub( + prompt=self._make_prompt(target_slice=slice(0, 1), control_slice=slice(0, 1)), + control_tokens=torch.tensor([1]), + disallowed_tokens=torch.tensor([]), + control_str="test", + ) + + attack = self._make_attack(worker=worker, prompt_manager=prompt_manager) + + # Test _resolve_sampling returns default + sampler = attack._resolve_sampling() + assert sampler is not None + + # Test _resolve_loss returns default + loss_func = attack._resolve_loss(target_weight=1.0, control_weight=0.1) + assert loss_func is not None + + # Test _resolve_candidate_filter returns default + filter_func = attack._resolve_candidate_filter(filter_cand=True) + assert filter_func is not None + + def test_get_control_length_success(self) -> None: + """Test _get_control_length returns token count after dropping the first token.""" + tokenizer = self._make_tokenizer() + worker = _WorkerStub( + gradient=torch.tensor([[0.1]]), + logits=torch.randn(1, 8, 10), + token_ids=torch.randint(0, 10, (1, 8)), + tokenizer=tokenizer, + ) + attack = object.__new__(GCGMultiPromptAttack) + attack.workers = [worker] + + length = attack._get_control_length(control="1 2 3") + assert length == 2 + + def test_get_control_length_handles_error(self) -> None: + """Test _get_control_length returns None on tokenizer error.""" + tokenizer = MagicMock() + tokenizer.side_effect = ValueError("Tokenizer error") + + worker = _WorkerStub( + gradient=torch.tensor([[0.1]]), + logits=torch.randn(1, 8, 10), + token_ids=torch.randint(0, 10, (1, 8)), + tokenizer=tokenizer, + ) + attack = object.__new__(GCGMultiPromptAttack) + attack.workers = [worker] + + length = attack._get_control_length(control="test") + assert length is None diff --git a/tests/unit/auxiliary_attacks/gcg/test_generator.py b/tests/unit/auxiliary_attacks/gcg/test_generator.py index f410aa5079..956dcc4953 100644 --- a/tests/unit/auxiliary_attacks/gcg/test_generator.py +++ b/tests/unit/auxiliary_attacks/gcg/test_generator.py @@ -6,8 +6,9 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING -from unittest.mock import MagicMock, patch +from functools import partial +from typing import TYPE_CHECKING, Any +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -238,6 +239,137 @@ def test_augmentation_modifies_at_least_some_targets(self) -> None: assert num_changed > 0 +class TestExtensionWiring: + def test_create_attack_uses_suffix_initializer_when_configured(self) -> None: + class _SuffixInitStub: + def __init__(self) -> None: + self.calls: list[object] = [] + + def make_initial_suffix(self, *, tokenizer: object) -> str: + self.calls.append(tokenizer) + return "initialized suffix" + + suffix_init = _SuffixInitStub() + gen = GCGGenerator( + models=[GCGModelConfig(name=_LLAMA_2)], + algorithm=GCGAlgorithmConfig(suffix_init=suffix_init), + ) + worker = MagicMock() + worker.tokenizer = MagicMock() + + with patch.object(generator_mod, "IndividualPromptAttack") as mock_individual: + gen._create_attack( + params=MagicMock(), + managers={"MPA": MagicMock()}, + train_goals=["g"], + train_targets=["t"], + test_goals=[], + test_targets=[], + workers=[worker], + test_workers=[], + logfile_path="out.json", + ) + + assert suffix_init.calls == [worker.tokenizer] + assert mock_individual.call_args.kwargs["control_init"] == "initialized suffix" + + def test_resolve_control_init_returns_default_when_suffix_init_not_configured(self) -> None: + gen = GCGGenerator( + models=[GCGModelConfig(name=_LLAMA_2)], + algorithm=GCGAlgorithmConfig(control_init="seed control"), + ) + + assert gen._resolve_control_init(workers=[]) == "seed control" + + def test_resolve_control_init_raises_when_suffix_init_requires_workers(self) -> None: + """Test _resolve_control_init raises ValueError when suffix_init configured but no workers.""" + + class _SuffixInitStub: + def make_initial_suffix(self, *, tokenizer: object) -> str: + return "initialized suffix" + + suffix_init = _SuffixInitStub() + gen = GCGGenerator( + models=[GCGModelConfig(name=_LLAMA_2)], + algorithm=GCGAlgorithmConfig(suffix_init=suffix_init), + ) + + with pytest.raises(ValueError, match="Cannot resolve suffix_init without at least one worker"): + gen._resolve_control_init(workers=[]) + + async def test_perform_async_binds_algorithm_extensions_into_mpa_factory(self, tmp_path: Path) -> None: + class _SamplingStub: + def sample_candidates( + self, + *, + gradient: Any, + control_tokens: Any, + batch_size: int, + top_k: int, + temperature: float, + allow_non_ascii: bool, + non_ascii_tokens: Any, + ) -> Any: + return control_tokens + + class _LossStub: + def compute_loss( + self, + *, + logits: Any, + token_ids: Any, + target_slice: slice, + control_slice: slice, + ) -> Any: + return logits + + class _FilterStub: + def filter_candidates( + self, + *, + candidate_tokens: Any, + tokenizer: Any, + current_control: str, + ) -> list[str]: + return [current_control] + + sampling = _SamplingStub() + loss = _LossStub() + candidate_filter = _FilterStub() + gen = GCGGenerator( + models=[GCGModelConfig(name=_LLAMA_2)], + algorithm=GCGAlgorithmConfig( + sampling=sampling, + loss=loss, + candidate_filter=candidate_filter, + ), + output=GCGOutputConfig(result_prefix=str(tmp_path / "gcg")), + ) + context = GCGContext( + goals=["g"], + targets=["t"], + workers=[MagicMock()], + test_workers=[], + ) + fake_attack = MagicMock() + + with ( + patch.object(gen, "_create_attack", return_value=fake_attack) as mock_create_attack, + patch.object(gen, "_build_logfile_path", return_value=str(tmp_path / "result.json")), + patch.object(gen, "_read_result", return_value=GCGResult(final_suffix="x")), + patch("pyrit.auxiliary_attacks.gcg.generator.asyncio.to_thread", new=AsyncMock(return_value=None)), + ): + await gen._perform_async(context=context) + + managers = mock_create_attack.call_args.kwargs["managers"] + mpa_factory = managers["MPA"] + assert isinstance(mpa_factory, partial) + assert mpa_factory.func is generator_mod.attack_lib.GCGMultiPromptAttack + assert mpa_factory.keywords["sampling"] is sampling + assert mpa_factory.keywords["loss"] is loss + assert mpa_factory.keywords["candidate_filter"] is candidate_filter + + class TestReadResult: def test_reads_final_suffix_and_loss(self, tmp_path: Path) -> None: log_path = tmp_path / "result.json"