Conversation

@tasercake

Hello! I wanted to place a hard cap on the number of tokens used for reasoning (inside a <think>...</think> block from models like Qwen3) because the models would often spend way too long thinking.

So this is something I built for my own use, and I figured there might be some interest in having it upstream as well. Happy to iterate on this if needed :)

Here's a summary of changes:

  • Add thinking_budget parameter to limit tokens in <think>...</think> blocks
  • Force-inserts </think> when budget is exceeded to end reasoning phase
  • Configurable think tags via thinking_start_token and thinking_end_token
  • CLI support via --thinking-budget, --thinking-start-token, --thinking-end-token

Usage

# Limit thinking to 100 tokens
mlx_vlm generate --model <model> --prompt "..." --thinking-budget 100

# Custom thinking tokens (for models with different tags)
mlx_vlm generate --model <model> --prompt "..." --thinking-budget 100 \
    --thinking-start-token "<reasoning>" --thinking-end-token "</reasoning>"
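
For reference, the same knobs should be reachable from the Python API. A minimal sketch, assuming the generate entry point accepts the keyword arguments that the CLI flags above map to (load/generate are mlx_vlm's usual entry points; the parameter names mirror the flags):

from mlx_vlm import load, generate

model, processor = load("<model>")

# Cap reasoning at 100 tokens; tag names default to <think>/</think>
output = generate(
    model,
    processor,
    prompt="...",
    thinking_budget=100,
    thinking_start_token="<think>",
    thinking_end_token="</think>",
)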

Testing

Added a TestThinkingBudget class to cover a few cases.

@Blaizzy
Owner

Blaizzy commented Dec 20, 2025

Hey @tasercake
This is really cool, I love the idea!

One question: is there a way to find the thinking tokens automatically? If not, that's fine; we can go ahead and merge.

@Blaizzy
Owner

Blaizzy commented Dec 20, 2025

Could you add support to the batch_gen API as well?

@tasercake
Author

tasercake commented Dec 21, 2025

Thanks for taking a look @Blaizzy! I'll see if I can get it working with batch_gen as well.

One question: is there a way to find the thinking tokens automatically?

Unfortunately it looks like there is no standardized API for reasoning token IDs. qwen3-vl defines them in its tokenizer config while deepseek does not appear to.
I suppose I could add some sort of heuristic, but I suspect that would be brittle. And I feel it might be best to do this in a separate PR anyway. Let me know your thoughts :)
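
For what it's worth, a rough sketch of the kind of heuristic I had in mind (illustrative only; it assumes a Hugging Face-style tokenizer with get_vocab() and just probes a couple of common tag pairs):

from typing import Optional, Tuple

def detect_thinking_token_ids(tokenizer) -> Optional[Tuple[int, int]]:
    """Best-effort lookup of (start_id, end_id) for thinking tags.

    Probes a few common tag pairs against the tokenizer vocabulary and
    returns the first pair where both tags map to a dedicated token.
    Returns None when nothing matches (e.g. tokenizers that don't register
    the tags as single tokens).
    """
    candidates = [("<think>", "</think>"), ("<reasoning>", "</reasoning>")]
    vocab = tokenizer.get_vocab()
    for start_tag, end_tag in candidates:
        if start_tag in vocab and end_tag in vocab:
            return vocab[start_tag], vocab[end_tag]
    return None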

@tasercake
Author

@Blaizzy I've taken a crack at supporting this in the batch_generate API as well (and also fixed a couple of minor issues with the returned logprobs). Let me know what you think. Thanks!

@Blaizzy
Owner

Blaizzy commented Jan 2, 2026

Unfortunately it looks like there is no standardized API for reasoning token IDs. qwen3-vl defines them in its tokenizer config while deepseek does not appear to.
I suppose I could add some sort of heuristic, but I suspect that would be brittle. And I feel it might be best to do this in a separate PR anyway. Let me know your thoughts :)

Happy new year!

Agreed, we can address it in a future PR :)

@Blaizzy
Owner

Blaizzy commented Jan 2, 2026

Hey @tasercake,

I really like the thinking budget concept, but the current approach feels a bit verbose.
This reminds me of how StoppingCriteria works, which got me thinking: what if we formulated it in a similar way?
Here's the gist:

1. Create a new ThinkingBudgetCriteria class

Here we could also search the tokenizer for the thinking tokens.
https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Thinking/blob/main/tokenizer_config.json#L197-L212

# Add to mlx_vlm/utils.py after the existing StoppingCriteria class

class ThinkingBudgetCriteria(StoppingCriteria):
    """
    Stopping criteria that enforces a budget on thinking tokens.
    
    Tracks tokens within thinking blocks (between start and end tokens) and
    stops generation or forces the end token when budget is exceeded.
    """
    
    def __init__(
        self,
        eos_token_ids: List[int],
        thinking_budget: int,
        thinking_end_token_id: int,
        thinking_start_token_id: Optional[int] = None,
        tokenizer=None
    ):
        super().__init__(eos_token_ids, tokenizer)
        
        self.thinking_budget = thinking_budget
        self.thinking_start_token_id = thinking_start_token_id
        self.thinking_end_token_id = thinking_end_token_id
        
        # State tracking
        self.in_thinking = thinking_start_token_id is None  # Start immediately if no start token
        self.thinking_token_count = 0
        self.budget_exceeded = False
    
    def reset_thinking_state(self):
        """Reset thinking state between generations."""
        self.in_thinking = self.thinking_start_token_id is None
        self.thinking_token_count = 0
        self.budget_exceeded = False
    
    def should_force_end_token(self) -> bool:
        """Check if we should force the thinking end token."""
        return self.budget_exceeded
    
    def get_forced_token_id(self) -> Optional[int]:
        """Get the token ID to force (if any)."""
        if self.budget_exceeded:
            return self.thinking_end_token_id
        return None
    
    def process_token(self, token_id: int):
        """
        Process a generated token and update thinking state.
        Should be called before checking if generation should stop.
        """
        if self.thinking_start_token_id is not None and token_id == self.thinking_start_token_id:
            self.in_thinking = True
            return
        
        if token_id == self.thinking_end_token_id:
            self.in_thinking = False
            self.budget_exceeded = False  # Reset for potential next thinking block
            return
        
        if self.in_thinking:
            self.thinking_token_count += 1
            if self.thinking_token_count > self.thinking_budget:
                self.budget_exceeded = True
    
    def __call__(self, input_ids: mx.array) -> bool:
        """Check if generation should stop."""
        token_id = input_ids.item() if isinstance(input_ids, mx.array) else input_ids
        
        # Process the token to update thinking state
        self.process_token(token_id)
        
        # Check standard EOS conditions (optional ⚠️; preferably we'd keep these checks separate)
        return super().__call__(input_ids)

2. Modify generate_step to use the criteria

# In mlx_vlm/generate.py, modify generate_step function

def generate_step(
    prompt: mx.array,
    model: nn.Module,
    pixel_values: Optional[mx.array] = None,
    mask: Optional[mx.array] = None,
    temp: float = 0.0,
    top_p: float = 1.0,
    logit_bias: Optional[Dict[int, float]] = None,
    repetition_penalty: Optional[float] = None,
    repetition_context_size: Optional[int] = 20,
    max_tokens: int = 100,
    prefill_step_size: int = 512,
    kv_bits: Optional[int] = None,
    kv_group_size: int = 64,
    quantized_kv_start: int = 0,
    stopping_criteria: Optional[StoppingCriteria] = None,  # Add this parameter
    **kwargs,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
    """
    ... existing docstring ...
    """

    # ... existing setup code ...
    
    n = 0
    while True:
        if n != max_tokens:
            next_y, next_logprobs = _step(y, **kwargs)
            
            # Check if we need to force a token (e.g., thinking end token)
            if stopping_criteria and hasattr(stopping_criteria, 'should_force_end_token'):
                if stopping_criteria.should_force_end_token():
                    forced_token_id = stopping_criteria.get_forced_token_id()
                    if forced_token_id is not None:
                        next_y = mx.array([forced_token_id])
                        # Update logprobs to reflect forced token
                        next_logprobs = mx.full(
                            next_logprobs.shape,
                            -float("inf"),
                            dtype=next_logprobs.dtype,
                        )
                        next_logprobs[forced_token_id] = 0.0
            
            mx.async_eval(next_y)
            # ... rest of the existing code ... 

3. Update stream_generate and generate to use the new criteria

# In mlx_vlm/generate.py

def stream_generate(
    model: nn.Module,
    processor,
    prompt: str,
    image: Union[str, List[str]] = None,
    audio: Union[str, List[str]] = None,
    thinking_budget: Optional[int] = None,
    thinking_start_token: Optional[str] = None,
    thinking_end_token: str = "</think>",
    **kwargs,
) -> Union[str, Generator[str, None, None]]:
    
    tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
    
    # Create thinking budget criteria if needed
    if thinking_budget is not None:
        thinking_start_token_id = None
        if thinking_start_token is not None:
            thinking_start_token_id = tokenizer.encode(
                thinking_start_token, add_special_tokens=False
            )[-1]
        
        thinking_end_token_id = tokenizer.encode(
            thinking_end_token, add_special_tokens=False
        )[-1]
        
        # Create the special criteria
        from mlx_vlm.utils import ThinkingBudgetCriteria
        
        # Get existing EOS token IDs
        eos_token_ids = tokenizer.stopping_criteria.eos_token_ids
        
        # Replace with thinking budget criteria
        thinking_criteria = ThinkingBudgetCriteria(
            eos_token_ids=eos_token_ids,
            thinking_budget=thinking_budget,
            thinking_end_token_id=thinking_end_token_id,
            thinking_start_token_id=thinking_start_token_id,
            tokenizer=tokenizer
        )
        
        # Pass it to kwargs for generate_step
        kwargs['stopping_criteria'] = thinking_criteria
    
    # ... rest of existing code ...

4. Update the CLI arguments handling

# In main(), the args are already added in your PR; just pass them through:

def main():
    args = parse_arguments()
    # ... existing code ...
    
    for chunk in stream_generate(
        model,
        processor,
        args.prompt,
        args.image,
        args.audio,
        max_tokens=args.max_tokens,
        temperature=args.temperature,
        thinking_budget=args.thinking_budget,  # Already added in your PR
        thinking_start_token=args.thinking_start_token,  # Already added
        thinking_end_token=args.thinking_end_token,  # Already added
        **kwargs,
    ):
        response += chunk.text

5. Add tests

# In mlx_vlm/tests/test_utils.py

def test_thinking_budget_criteria():
    """Test thinking budget enforcement."""
    class MockProcessor:
        def __init__(self):
            self.tokenizer = type(
                "DummyTokenizer", (), {"pad_token": None, "eos_token": "</s>"}
            )()
    
    processor = MockProcessor()
    
    # Test with explicit start token
    criteria = ThinkingBudgetCriteria(
        eos_token_ids=[2],
        thinking_budget=5,
        thinking_end_token_id=100,
        thinking_start_token_id=99,
        tokenizer=processor
    )
    
    # Should not be in thinking initially
    assert not criteria.in_thinking
    assert criteria.thinking_token_count == 0
    
    # Process start token
    criteria.process_token(99)
    assert criteria.in_thinking
    
    # Process thinking tokens
    for i in range(5):
        criteria.process_token(50 + i)
    
    assert criteria.thinking_token_count == 5
    assert not criteria.should_force_end_token()
    
    # One more token should exceed budget
    criteria.process_token(60)
    assert criteria.should_force_end_token()
    assert criteria.get_forced_token_id() == 100


def test_thinking_budget_criteria_implicit_start():
    """Test thinking budget with implicit start (no start token)."""
    criteria = ThinkingBudgetCriteria(
        eos_token_ids=[2],
        thinking_budget=3,
        thinking_end_token_id=100,
        thinking_start_token_id=None,  # Implicit start
        tokenizer=None
    )
    
    # Should be in thinking immediately
    assert criteria.in_thinking
    
    # Process 3 tokens
    for i in range(3):
        criteria.process_token(50 + i)
    
    assert criteria.thinking_token_count == 3
    
    # Next token exceeds budget
    criteria.process_token(60)
    assert criteria.budget_exceeded

Benefits of this approach:

  1. ✅ No changes to Batch class - No new fields needed
  2. ✅ Encapsulated logic - All thinking budget logic in one class
  3. ✅ Leverages existing pattern - Uses the same stopping_criteria mechanism
  4. ✅ Easier to test - State and logic isolated in the criteria class
  5. ✅ Composable - Can be combined with other stopping criteria if needed (see the sketch after this list)
  6. ✅ Minimal footprint - Requires significantly fewer code changes
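
As a hypothetical illustration of the "composable" point above (a sketch only, not part of mlx_vlm), criteria that follow this callable protocol could be OR-ed together with a tiny wrapper:

class AnyOfCriteria:
    """Stops generation as soon as any of the wrapped criteria fires."""

    def __init__(self, *criteria):
        self.criteria = criteria

    def __call__(self, input_ids) -> bool:
        # Each wrapped criterion follows the same __call__(input_ids) -> bool
        # protocol as StoppingCriteria / ThinkingBudgetCriteria above.
        return any(criterion(input_ids) for criterion in self.criteria)

# e.g. stopping_criteria = AnyOfCriteria(thinking_criteria, eos_criteria)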

tasercake and others added 3 commits January 3, 2026 17:15
This adds a `thinking_budget` parameter to the generate functions that
limits the number of tokens a model can generate inside `<think>...</think>`
blocks. When the budget is exceeded, the `</think>` token is force-inserted
to end the thinking phase.

Features:
- New parameters: `thinking_budget`, `thinking_start_token`, `thinking_end_token`
- Configurable think tags (default: `<think>`/`</think>`)
- CLI support via `--thinking-budget`, `--thinking-start-token`, `--thinking-end-token`
- Works with any reasoning model that uses think tags (Qwen3, DeepSeek-R1, etc.)

The implementation tracks thinking state in the generation loop and enforces
the budget by replacing the sampled token with the end token when exceeded.
* Add thinking budget support to batch generation

Extend the thinking budget enforcement to BatchGenerator to support streaming batch inference with token limits on thinking blocks. Tracks thinking state per sample and forces </think> tokens when budget is exceeded.

* Update logprobs when forcing thinking end token

When the thinking budget is exceeded and the </think> token is forced,
the logprobs are now updated to reflect this forced selection (setting
the forced token's log probability to 0 and all others to -inf).
* Add test and fix mlx array syntax for thinking budget

- Add TestBatchGeneratorThinking test class with thinking budget test
- Fix mlx array syntax: use direct indexing instead of .at[].set()
- Fix Batch.filter/extend to handle None thinking state
- Skip counting the start token in thinking token count
@tasercake
Author

@Blaizzy I agree that something of this sort would be nice! I'd been thinking of doing something like this, but extending the StoppingCriteria class instead – I think your suggestion is much cleaner though.
I'll take a crack at building this out over the next few days.

@Blaizzy
Owner

Blaizzy commented Jan 3, 2026

That's awesome, extending the stop criteria is a great idea too and is simpler!

Looking forward to your changes 🚀
