Processor saving does not work with multiple tokenizers #41816

@AmitMY

Description

System Info

  • transformers version: 4.57.1
  • Platform: macOS-26.0.1-arm64-arm-64bit
  • Python version: 3.12.2
  • Huggingface_hub version: 0.34.3
  • Safetensors version: 0.5.3
  • Accelerate version: 1.10.1
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (accelerator?): 2.8.0 (NA)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:

Who can help?

@Cyrilvallez

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Currently, processors are saved with fixed names:

FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
IMAGE_PROCESSOR_NAME = "preprocessor_config.json"
VIDEO_PROCESSOR_NAME = "video_preprocessor_config.json"
AUDIO_TOKENIZER_NAME = "audio_tokenizer_config.json"
PROCESSOR_NAME = "processor_config.json"
GENERATION_CONFIG_NAME = "generation_config.json"
MODEL_CARD_NAME = "modelcard.json"

That means that if you create a processor that uses two subprocessors of the same kind (for example, a byte-level tokenizer and a BPE tokenizer, or two image processors), they overwrite each other, because they are saved under the same file names. Tokenizers behave the same way: every tokenizer writes the same fixed files (tokenizer_config.json, tokenizer.json, special_tokens_map.json, ...), regardless of how many tokenizers a processor holds.
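
For reference, the saving path works roughly as sketched below (a simplified sketch of the mechanism, not the exact library code): ProcessorMixin.save_pretrained writes every attribute into the same directory, so two attributes of the same kind reuse identical file names.

# Simplified sketch of the mechanism (an assumption, not the exact library
# code): every sub-attribute is saved into the same directory, so attributes
# of the same kind collide on their fixed file names.
def save_pretrained_sketch(processor, save_directory):
    for attribute_name in processor.attributes:
        attribute = getattr(processor, attribute_name)
        attribute.save_pretrained(save_directory)  # same directory for all

The following script reproduces the problem: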

import tempfile

from transformers import ProcessorMixin, AutoTokenizer, PreTrainedTokenizer


class OtherProcessor(ProcessorMixin):
    name = "other-processor"

    attributes = [
        "tokenizer1",
        "tokenizer2",
    ]
    tokenizer1_class = "AutoTokenizer"
    tokenizer2_class = "AutoTokenizer"

    def __init__(self,
                 tokenizer1: PreTrainedTokenizer,
                 tokenizer2: PreTrainedTokenizer
                 ):
        super().__init__(tokenizer1=tokenizer1,
                         tokenizer2=tokenizer2)


tokenizer1 = AutoTokenizer.from_pretrained("google/gemma-3-270m")
tokenizer2 = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-1.7B")

processor = OtherProcessor(tokenizer1=tokenizer1,
                           tokenizer2=tokenizer2)

with tempfile.TemporaryDirectory() as temp_dir:
    # Save processor
    processor.save_pretrained(save_directory=temp_dir, push_to_hub=False)
    # Load processor
    new_processor = OtherProcessor.from_pretrained(temp_dir)

assert processor.tokenizer1.__class__ != processor.tokenizer2.__class__  # passes: two different tokenizer classes
assert new_processor.tokenizer1.__class__ != new_processor.tokenizer2.__class__  # fails: both were reloaded from the same overwritten files
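
Listing the save directory makes the collision visible. The exact file names depend on the tokenizer classes (treat the names in the comment as illustrative), but both tokenizers write the same set of names, so the second save overwrites the first:

import os

with tempfile.TemporaryDirectory() as temp_dir:
    processor.save_pretrained(save_directory=temp_dir, push_to_hub=False)
    # Only one set of tokenizer files remains (e.g. tokenizer_config.json,
    # tokenizer.json, special_tokens_map.json), even though two tokenizers
    # were saved.
    print(sorted(os.listdir(temp_dir)))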

Expected behavior

You should be able to use multiple subprocessors of the same kind within a processor; save_pretrained followed by from_pretrained should round-trip all of them without one overwriting another.
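
Until that works, a sketch of one possible workaround (not current library behavior) is to save each sub-tokenizer into its own subdirectory, so the fixed file names cannot clash, and to rebuild the processor manually on load:

import os
import tempfile

with tempfile.TemporaryDirectory() as temp_dir:
    # Save each tokenizer into its own subfolder to avoid file name clashes.
    processor.tokenizer1.save_pretrained(os.path.join(temp_dir, "tokenizer1"))
    processor.tokenizer2.save_pretrained(os.path.join(temp_dir, "tokenizer2"))

    # Reload each tokenizer from its subfolder and rebuild the processor.
    restored = OtherProcessor(
        tokenizer1=AutoTokenizer.from_pretrained(os.path.join(temp_dir, "tokenizer1")),
        tokenizer2=AutoTokenizer.from_pretrained(os.path.join(temp_dir, "tokenizer2")),
    )
    assert restored.tokenizer1.__class__ != restored.tokenizer2.__class__  # passes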
