|
1 | 1 | use crate::flash_attn::flash_attn_varlen;
|
2 | 2 | use crate::layers::{get_cos_sin, get_inv_freqs, HiddenAct, Linear, RMSNorm};
|
3 | 3 | use crate::models::{Model, Qwen3Config};
|
4 |
| -use candle::{DType, Device, IndexOp, Result, Tensor}; |
| 4 | +use candle::{DType, Device, IndexOp, Result, Tensor, D}; |
5 | 5 | use candle_nn::{Embedding, Module, VarBuilder};
|
6 | 6 | use candle_rotary::apply_rotary_inplace;
|
7 | 7 | use text_embeddings_backend_core::{Batch, ModelType, Pool};
|
@@ -592,10 +592,13 @@ impl Model for FlashQwen3Model {
|
592 | 592 |
|
593 | 593 | let h_last = Tensor::stack(&last_hidden_states, 0)?; // [bs, hidden_size]
|
594 | 594 |
|
595 |
| - let true_id = 9693u32; |
596 |
| - let false_id = 2152u32; |
| 595 | + // Correct token IDs for Qwen3 (verified from tokenizer) |
| 596 | + let yes_id = 9454u32; // "yes" token ID |
| 597 | + let no_id = 2901u32; // "no" token ID |
597 | 598 |
|
598 |
| - let ids = Tensor::from_vec(vec![false_id, true_id], 2, &self.device)?; |
| 599 | + tracing::debug!("Using Qwen3 token IDs - yes: {}, no: {}", yes_id, no_id); |
| 600 | + |
| 601 | + let ids = Tensor::from_vec(vec![no_id, yes_id], 2, &self.device)?; |
599 | 602 | let w = self.lm_head_weight.index_select(&ids, 0)?; // [2, hidden_size]
|
600 | 603 | let logits = h_last.matmul(&w.t()?)?; // [bs, 2] (no, yes)
|
601 | 604 | let log_probs = candle_nn::ops::log_softmax(&logits, D::Minus1)?;
|
|
0 commit comments