diff --git a/Cargo.lock b/Cargo.lock index 38760318..ea2ec3ae 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4527,7 +4527,7 @@ dependencies = [ [[package]] name = "tvm-ffi" version = "0.1.0-alpha.0" -source = "git+https://github.com/brekkylab/tvm-runtime-rs?rev=ee8c5626e5309827b00af9e46111fe23b0307308#ee8c5626e5309827b00af9e46111fe23b0307308" +source = "git+https://github.com/brekkylab/tvm-runtime-rs?rev=cea927a#cea927a0f919fe16a1bed0feb03eefeae8dd044b" dependencies = [ "paste", "tvm-ffi-macros", @@ -4537,7 +4537,7 @@ dependencies = [ [[package]] name = "tvm-ffi-macros" version = "0.1.0-alpha.0" -source = "git+https://github.com/brekkylab/tvm-runtime-rs?rev=ee8c5626e5309827b00af9e46111fe23b0307308#ee8c5626e5309827b00af9e46111fe23b0307308" +source = "git+https://github.com/brekkylab/tvm-runtime-rs?rev=cea927a#cea927a0f919fe16a1bed0feb03eefeae8dd044b" dependencies = [ "proc-macro-error", "proc-macro2", @@ -4548,15 +4548,12 @@ dependencies = [ [[package]] name = "tvm-ffi-sys" version = "0.1.0-alpha.0" -source = "git+https://github.com/brekkylab/tvm-runtime-rs?rev=ee8c5626e5309827b00af9e46111fe23b0307308#ee8c5626e5309827b00af9e46111fe23b0307308" -dependencies = [ - "cmake", -] +source = "git+https://github.com/brekkylab/tvm-runtime-rs?rev=cea927a#cea927a0f919fe16a1bed0feb03eefeae8dd044b" [[package]] name = "tvm-runtime" version = "0.1.0" -source = "git+https://github.com/brekkylab/tvm-runtime-rs?rev=ee8c5626e5309827b00af9e46111fe23b0307308#ee8c5626e5309827b00af9e46111fe23b0307308" +source = "git+https://github.com/brekkylab/tvm-runtime-rs?rev=cea927a#cea927a0f919fe16a1bed0feb03eefeae8dd044b" dependencies = [ "anyhow", "serde", @@ -4569,7 +4566,7 @@ dependencies = [ [[package]] name = "tvm-runtime-sys" version = "0.1.0" -source = "git+https://github.com/brekkylab/tvm-runtime-rs?rev=ee8c5626e5309827b00af9e46111fe23b0307308#ee8c5626e5309827b00af9e46111fe23b0307308" +source = "git+https://github.com/brekkylab/tvm-runtime-rs?rev=cea927a#cea927a0f919fe16a1bed0feb03eefeae8dd044b" dependencies = [ "cmake", "tvm-ffi-sys", diff --git a/Cargo.toml b/Cargo.toml index b3bce2a7..754ff4fe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -82,10 +82,10 @@ urlencoding = "2.1" version-compare = "0.2.1" [target.'cfg(target_os = "macos")'.dependencies] -tvm-runtime = { git = "https://github.com/brekkylab/tvm-runtime-rs", rev = "ee8c5626e5309827b00af9e46111fe23b0307308", features = ["metal"] } +tvm-runtime = { git = "https://github.com/brekkylab/tvm-runtime-rs", rev = "cea927a", features = ["metal"] } [target.'cfg(any(target_os = "linux", target_os = "windows"))'.dependencies] -tvm-runtime = { git = "https://github.com/brekkylab/tvm-runtime-rs", rev = "ee8c5626e5309827b00af9e46111fe23b0307308", features = ["vulkan"] } +tvm-runtime = { git = "https://github.com/brekkylab/tvm-runtime-rs", rev = "cea927a", features = ["vulkan"] } [target.'cfg(not(target_arch = "wasm32"))'.dependencies] ailoy-faiss-sys = { path = "./crates/faiss-sys" } @@ -96,7 +96,7 @@ parking_lot = "0.12.4" rmcp = { version = "0.11.0", features = ["client", "reqwest", "transport-child-process", "transport-streamable-http-client", "transport-streamable-http-client-reqwest"] } tokenizers = { version = "0.22.2", default-features = false, features = ["onig"] } tokio = { version = "1.0", default-features = false, features = ["macros", "rt-multi-thread", "sync"] } -tvm-ffi = { git = "https://github.com/brekkylab/tvm-runtime-rs", rev = "ee8c5626e5309827b00af9e46111fe23b0307308" } +tvm-ffi = { git = "https://github.com/brekkylab/tvm-runtime-rs", rev = "cea927a" } 
uuid = { version = "1.18.0", features = ["v4"] } [target.'cfg(target_arch = "wasm32")'.dependencies] diff --git a/build.rs b/build.rs index 4381abda..87052f53 100644 --- a/build.rs +++ b/build.rs @@ -29,6 +29,26 @@ fn build_native() { if std::env::var_os("CARGO_FEATURE_NODEJS").is_some() { napi_build::setup(); } + + // Link libmlc_llm_module so its TVM_FFI_STATIC_INIT_BLOCKs run at process + // load. We deliberately link the *_module variant rather than libmlc_llm: + // mlc-llm's Python package loads libmlc_llm_module via tvm.ffi.load_module, + // so linking the same dylib keeps a single GlobalFunctionTable in process + // when both ailoy and mlc-llm are imported together. Linking libmlc_llm + // directly would put two copies of every mlc.json_ffi.* function in the + // global registry, and the second registration aborts at process load. + println!("cargo:rerun-if-env-changed=MLC_LLM_LIB_DIR"); + if let Ok(mlc_dir) = std::env::var("MLC_LLM_LIB_DIR") { + println!("cargo:rustc-link-search=native={}", mlc_dir); + println!("cargo:rustc-link-lib=dylib=mlc_llm_module"); + } + + // Make the resulting cdylib portable: tell the linker to look for runtime + // dependencies right next to itself (`@loader_path` is the Darwin spelling; + // an ELF target would use `$ORIGIN` for the same effect). Combined with the + // build/install step that copies libmlc_llm_module / libtvm{,_runtime,_ffi + // [_testing]} into bindings/python/ailoy/, this gives ailoy a self- + // contained dylib closure with no hard-coded venv paths. + println!("cargo:rustc-link-arg=-Wl,-rpath,@loader_path"); } fn build_wasm() { diff --git a/src/model/local/chat_template.rs b/src/model/local/chat_template.rs index 263b3f90..38307b2e 100644 --- a/src/model/local/chat_template.rs +++ b/src/model/local/chat_template.rs @@ -61,8 +61,25 @@ impl ChatTemplate { ThinkEffort::High => "high", _ => "", }; + // HuggingFace-style chat templates address the message body as `message.content` + // (singular), but our Message struct serializes the field as `contents`. + // Re-serialize each message into a serde_json::Value so we can also expose + // `content` as an alias; otherwise jinja sees `content == undefined` and + // emits empty user/system bodies, which collapses the model's output.
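+ // (With a stock HF template body like
+ // {%- for message in messages %}{{ message.content }}{%- endfor %},
+ // a struct exposing only `contents` renders every turn empty; the
+ // `content` alias below restores the body.)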
+ let messages_for_template: Vec<serde_json::Value> = messages + .iter() + .map(|m| { + let mut v = serde_json::to_value(m).unwrap_or(serde_json::Value::Null); + if let Some(obj) = v.as_object_mut() { + if let Some(c) = obj.get("contents").cloned() { + obj.insert("content".to_string(), c); + } + } + v + }) + .collect(); let ctx = context!( - messages => messages, + messages => messages_for_template, tools => if !tools.is_empty() { Some(tools) } else { None::<_> }, documents => if !documents.is_empty() { Some(documents) } else { None::<_> }, add_generation_prompt => add_generation_prompt, diff --git a/src/model/local/inferencer.rs b/src/model/local/inferencer.rs index 00adfeb9..560e512f 100644 --- a/src/model/local/inferencer.rs +++ b/src/model/local/inferencer.rs @@ -5,6 +5,8 @@ pub use native::{EmbeddingModelInferencer, LanguageModelInferencer}; pub use wasm::{EmbeddingModelInferencer, LanguageModelInferencer}; use super::kv_cache::{KVCache, KVCacheConfig, KVCacheOps}; +use super::rnn_state::{RNNState, RNNStateConfig, RNNStateOps}; +use super::KVStateKind; use crate::{ cache::{Cache, CacheClaim, CacheEntry}, utils::BoxFuture, @@ -399,6 +401,8 @@ mod native { vm: Module, params: Array<Tensor>, kv_cache: KVCache, + rnn_state: Option<RNNState>, + kv_state_kind: KVStateKind, history: Vec<u32>, fembed: Function, @@ -454,9 +458,24 @@ let metadata: serde_json::Value = serde_json::from_str(&metadata) .map_err(|e| anyhow!("Failed to parse metadata json: {:?}", e))?; - let tensor_cache = TensorCache::from(tensor_cache_path, device) - .map_err(|e| anyhow!("Failed to initialize tensor cache: {:?}", e))?; - let param_names = metadata + // Load params. Two paths: + // + // 1. New `tensor-cache.json` layout (Qwen3.5 / V2 hybrid models): + // delegate to the runtime's global `vm.builtin.tensor_cache.*` + // so the resulting Tensors share the same in-vm cache instance + // that the compiled `batch_prefill` packed function expects. + // Loading via our own `TensorCache::from` produces ObjectRefs + // that look identical at the dlpack level but cause subtly + // different forward outputs in V2 hybrid because the prefill + // function correlates them with global-cache entries. + // + // 2. Legacy `ndarray-cache.json` layout (Qwen3 / V1 KvCache, + // BAAI/bge-m3 embedding, …): keep the in-tree TensorCache + // path. The compiled prefill/decode for V1 doesn't tie params + // to the global cache instance, so the simpler local loader + // is sufficient and avoids requiring the new tvm-runtime + // API surface for older artifacts. + let param_names_strs: Vec<&str> = metadata .get("params") .ok_or(anyhow!("Failed to get `params` attribute"))?
.as_array() @@ -469,19 +488,94 @@ mod native { .ok_or(anyhow!("Failed to convert `name` to str")) }) .collect::<anyhow::Result<Vec<&str>>>()?; - let params = tensor_cache.get_params(param_names); + let cache_filename = tensor_cache_path + .file_name() + .and_then(|n| n.to_str()) + .unwrap_or(""); + let params: Array<Tensor> = if cache_filename == "tensor-cache.json" { + let model_dir_str: tvm_ffi::String = tvm_ffi::String::from( + tensor_cache_path + .parent() + .unwrap_or(tensor_cache_path) + .to_string_lossy() + .as_ref(), + ); + let f_tensor_cache_load = Function::get_global("vm.builtin.tensor_cache.load") + .map_err(|e| { + anyhow!("Failed to get global `vm.builtin.tensor_cache.load`: {:?}", e) + })?; + f_tensor_cache_load + .call_tuple(( + &model_dir_str, + device.device_type as i32, + device.device_id as i32, + )) + .map_err(|e| { + anyhow!("Failed to call vm.builtin.tensor_cache.load: {:?}", e) + })?; + let f_param_array_from_cache_by_name = Function::get_global( + "vm.builtin.param_array_from_cache_by_name", + ) + .map_err(|e| { + anyhow!( + "Failed to get global `vm.builtin.param_array_from_cache_by_name`: {:?}", + e + ) + })?; + let param_names_arr: Array<tvm_ffi::String> = Array::new( + param_names_strs + .iter() + .map(|s| tvm_ffi::String::from(*s)) + .collect::<Vec<_>>(), + ); + let result: Array<Tensor> = f_param_array_from_cache_by_name + .call_packed(&[AnyView::from(&param_names_arr)]) + .map_err(|e| { + anyhow!("Failed to call param_array_from_cache_by_name: {:?}", e) + })? + .try_into() + .map_err(|e| anyhow!("Failed to convert params to Array<Tensor>: {:?}", e))?; + let f_tensor_cache_clear = Function::get_global("vm.builtin.tensor_cache.clear") + .map_err(|e| { + anyhow!("Failed to get global `vm.builtin.tensor_cache.clear`: {:?}", e) + })?; + f_tensor_cache_clear.call_tuple(()).map_err(|e| { + anyhow!("Failed to call vm.builtin.tensor_cache.clear: {:?}", e) + })?; + result + } else { + let tensor_cache = TensorCache::from(tensor_cache_path, device) + .map_err(|e| anyhow!("Failed to initialize tensor cache: {:?}", e))?; + tensor_cache.get_params(param_names_strs) + }; + + let kv_state_kind = KVStateKind::from_metadata(&metadata); let kv_cache = KVCache::new(&vm, kv_cache_config)?; + let rnn_state = if kv_state_kind.needs_rnn_state() { + Some(RNNState::new(&vm, RNNStateConfig::default())?) + } else { + None + };
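+ // `KVStateKind` and `RNNState` come from src/model/local/rnn_state.rs,
+ // which is outside this diff. A minimal sketch of the interface these
+ // hunks assume, inferred from the call sites here rather than from the
+ // real module:
+ //
+ //     pub enum KVStateKind { KvCache, Hybrid, RnnState }
+ //     impl KVStateKind {
+ //         // Reads the state kind recorded in the compiled model metadata.
+ //         pub fn from_metadata(metadata: &serde_json::Value) -> Self;
+ //         // True for Hybrid | RnnState, i.e. whenever an RNNState exists.
+ //         pub fn needs_rnn_state(&self) -> bool;
+ //     }
+ //     pub struct RNNState { pub max_history: i64, /* … */ }
+ //     // RNNState mirrors KVCacheOps (new / clear / popn / begin_forward /
+ //     // end_forward / get_state), except popn can rewind at most
+ //     // `max_history` steps.
 let fembed = vm .get_function("embed") .map_err(|e| anyhow!("Failed to get `embed` function: {:?}", e))?; + + // Hybrid / RNN state models expose `batch_prefill` / `batch_decode` + // with extra arguments (logit_positions + rnn_state). Pure paged-KV + // models expose the classic `prefill` / `decode` pair.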
+ let (prefill_name, decode_name) = if kv_state_kind.needs_rnn_state() { + ("batch_prefill", "batch_decode") + } else { + ("prefill", "decode") + }; let fprefill = vm - .get_function("prefill") - .map_err(|e| anyhow!("Failed to get `prefill` function: {:?}", e))?; + .get_function(prefill_name) + .map_err(|e| anyhow!("Failed to get `{}` function: {:?}", prefill_name, e))?; let fdecode = vm - .get_function("decode") - .map_err(|e| anyhow!("Failed to get `decode` function: {:?}", e))?; + .get_function(decode_name) + .map_err(|e| anyhow!("Failed to get `{}` function: {:?}", decode_name, e))?; let fapply_bitmask_inplace = vm .get_function("apply_bitmask_inplace") .map_err(|e| anyhow!("Failed to get `apply_bitmask_inplace` function: {:?}", e))?; @@ -498,6 +592,8 @@ vm, params, kv_cache, + rnn_state, + kv_state_kind, history: Vec::new(), fembed, @@ -544,11 +640,40 @@ pub fn clear(&mut self) -> anyhow::Result<()> { self.kv_cache.clear().map_err(|e| anyhow!("{e:?}"))?; + if let Some(rnn) = self.rnn_state.as_mut() { + rnn.clear().map_err(|e| anyhow!("{e:?}"))?; + } self.history.clear(); Ok(()) } - pub fn prefill(&mut self, tokens: &[u32]) -> anyhow::Result<()> { + /// Build a 1-D int32 tensor on the inference device holding the last + /// token position in the prefill chunk. Mirrors mlc-llm's pattern: + /// allocate values on CPU, then copy to a device-side buffer of the + /// same shape via `Tensor::copy_from`. + fn make_logit_positions(&self, last_pos: i64) -> anyhow::Result<Tensor> { + use tvm_ffi::{DLDevice, DLDeviceType}; + let dtype = DLDataType { + code: DLDataTypeCode::kDLInt as u8, + bits: 32, + lanes: 1, + }; + let cpu = DLDevice { device_type: DLDeviceType::kDLCPU, device_id: 0 }; + let mut host_t = Tensor::empty(&[1i64], dtype, cpu); + // Fill via i32 slice on CPU. + host_t + .data_as_slice_mut::<i32>() + .map_err(|e| anyhow!("Failed to get i32 slice on host logit_positions: {:?}", e))? + [0] = last_pos as i32; + // Copy to device. + let mut dev_t = Tensor::empty(&[1i64], dtype, self.device); + dev_t.copy_from(&host_t).map_err(|e| { + anyhow!("Failed to copy logit_positions host→device: {:?}", e) + })?; + Ok(dev_t) + } + + pub fn prefill(&mut self, tokens: &[u32]) -> anyhow::Result<Option<Tensor>> { if tokens.is_empty() { anyhow::bail!("Token must not be empty"); } @@ -566,12 +691,28 @@ .take_while(|(h, t)| h == t) .count(); - // Rewind the head of kv-cache to the LCP + // Rewind the head of kv-cache (and rnn state, when hybrid) to the LCP. + // RNN state can only roll back at most `max_history` tokens; if the + // requested rollback exceeds that, fall back to a full clear and + // re-prefill the entire prompt. This costs the LCP-rewind speedup + // but preserves correctness for hybrid (Gated DeltaNet) models. if lcp_index < self.history.len() { - self.kv_cache - .popn(0, (self.history.len() - lcp_index) as i64) - .map_err(|e| anyhow!("{e:?}"))?; - self.history.drain(lcp_index..); + let n = (self.history.len() - lcp_index) as i64; + let rnn_can_pop = self + .rnn_state + .as_ref() + .map(|rnn| n <= rnn.max_history) + .unwrap_or(true); + if rnn_can_pop { + self.kv_cache.popn(0, n).map_err(|e| anyhow!("{e:?}"))?; + if let Some(rnn) = self.rnn_state.as_mut() { + rnn.popn(0, n).map_err(|e| anyhow!("{e:?}"))?; + } + self.history.drain(lcp_index..); + } else { + // RNN state cannot roll back this far — start over.
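+ // (For example: history = 500 tokens, shared prefix = 100, and
+ // max_history = 32: the 400-token rollback exceeds what the recurrent
+ // state retains, so rebuild the whole prompt from scratch.)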
+ self.clear()?; + } } // Tokens to be added (without common prefixes) @@ -594,20 +735,70 @@ mod native { self.kv_cache .begin_forward(0, length as i64) .map_err(|e| anyhow!("{e:?}"))?; - self.fprefill - .call_packed(&[ - AnyView::from(&embedding), - AnyView::from(self.kv_cache.get_state()), - AnyView::from(&self.params), - ]) - .map_err(|e| anyhow!("{e:?}"))?; + if let Some(rnn) = self.rnn_state.as_mut() { + rnn.begin_forward(0, length as i64) + .map_err(|e| anyhow!("{e:?}"))?; + } + + let output = match (&self.rnn_state, self.kv_state_kind) { + (Some(rnn), KVStateKind::Hybrid) => { + let last_pos = (length as i64) - 1; + let logit_positions = self.make_logit_positions(last_pos)?; + self.fprefill + .call_packed(&[ + AnyView::from(&embedding), + AnyView::from(&logit_positions), + AnyView::from(self.kv_cache.get_state()), + AnyView::from(rnn.get_state()), + AnyView::from(&self.params), + ]) + .map_err(|e| anyhow!("{e:?}"))? + } + (Some(rnn), KVStateKind::RnnState) => { + let last_pos = (length as i64) - 1; + let logit_positions = self.make_logit_positions(last_pos)?; + self.fprefill + .call_packed(&[ + AnyView::from(&embedding), + AnyView::from(&logit_positions), + AnyView::from(rnn.get_state()), + AnyView::from(&self.params), + ]) + .map_err(|e| anyhow!("{e:?}"))? + } + _ => { + self.fprefill + .call_packed(&[ + AnyView::from(&embedding), + AnyView::from(self.kv_cache.get_state()), + AnyView::from(&self.params), + ]) + .map_err(|e| anyhow!("{e:?}"))? + } + }; + + if let Some(rnn) = self.rnn_state.as_mut() { + rnn.end_forward().map_err(|e| anyhow!("{e:?}"))?; + } self.kv_cache.end_forward().map_err(|e| anyhow!("{e:?}"))?; // Update history self.history.extend(tokens_sliced.iter().map(|&v| v as u32)); + + // On the last chunk, capture logits for the caller so the + // prompt's last token is not pushed through KV state twice + // (a separate decode pass would extend KV by one extra token + // and produce incorrect logits afterwards). + if j == new_tokens.len() { + let logits: Tensor = unsafe { + tvm_runtime::get_from_any_array(output, 0) + .map_err(|e| anyhow!("Failed to get prefill logits: {:?}", e))? 
+ }; + return Ok(Some(logits)); + } } - Ok(()) + Ok(None) } pub fn decode(&mut self, last_token: u32) -> anyhow::Result<Tensor> { @@ -616,21 +807,56 @@ self.kv_cache .begin_forward(0, 1) .map_err(|e| anyhow!("Failed to begin forward: {:?}", e))?; - let output = self - .fdecode - .call_packed(&[ - AnyView::from(&embedding), - AnyView::from(self.kv_cache.get_state()), - AnyView::from(&self.params), - ]) - .map_err(|e| anyhow!("Failed to call `decode`: {:?}", e))?; + if let Some(rnn) = self.rnn_state.as_mut() { + rnn.begin_forward(0, 1) + .map_err(|e| anyhow!("Failed to begin forward (rnn): {:?}", e))?; + } + + let output = match (&self.rnn_state, self.kv_state_kind) { + (Some(rnn), KVStateKind::Hybrid) => self + .fdecode + .call_packed(&[ + AnyView::from(&embedding), + AnyView::from(self.kv_cache.get_state()), + AnyView::from(rnn.get_state()), + AnyView::from(&self.params), + ]) + .map_err(|e| anyhow!("Failed to call `batch_decode`: {:?}", e))?, + (Some(rnn), KVStateKind::RnnState) => self + .fdecode + .call_packed(&[ + AnyView::from(&embedding), + AnyView::from(rnn.get_state()), + AnyView::from(&self.params), + ]) + .map_err(|e| anyhow!("Failed to call `batch_decode`: {:?}", e))?, + _ => self + .fdecode + .call_packed(&[ + AnyView::from(&embedding), + AnyView::from(self.kv_cache.get_state()), + AnyView::from(&self.params), + ]) + .map_err(|e| anyhow!("Failed to call `decode`: {:?}", e))?, + }; + + if let Some(rnn) = self.rnn_state.as_mut() { + rnn.end_forward() + .map_err(|e| anyhow!("Failed to end forward (rnn): {:?}", e))?; + } self.kv_cache .end_forward() .map_err(|e| anyhow!("Failed to end forward: {:?}", e))?; - // The output of decode is an Array of 2 items: logits(Tensor) and kv cache. - let logits = unsafe { - tvm_ffi::collections::array::get_from_any_array(output, 0) + // The output is an Array: KVCache mode returns [logits, kv_cache]; + // hybrid/rnn modes return [logits, kv_cache, rnn_state] (or + // [logits, rnn_state] for pure rnn). `logits` is always at index 0. + // + // The array is heterogeneous (Tensor + state-object), so we cannot + // cast the whole thing to Array<Tensor>. Use the Any-level helper + // added in tvm-runtime::any_array to extract only index 0. + let logits: Tensor = unsafe { + tvm_runtime::get_from_any_array(output, 0) .map_err(|e| anyhow!("Failed to get logits from output array: {:?}", e))? }; @@ -868,6 +1094,8 @@ mod wasm { vm: tvmjs::Module, device: tvmjs::DLDevice, kv_cache: KVCache, + rnn_state: Option<RNNState>, + kv_state_kind: KVStateKind, params: tvmjs::TVMObject, history: Vec<u32>, @@ -886,10 +1114,24 @@ impl LanguageModelInferencer { fn clear(&mut self) -> anyhow::Result<()> { self.kv_cache.clear()?; + if let Some(rnn) = self.rnn_state.as_mut() { + rnn.clear()?; + } self.history.clear(); Ok(()) } + /// Build a 1-element int32 tensor on device holding `last_pos`. + fn make_logit_positions(&self, last_pos: i64) -> anyhow::Result<tvmjs::NDArray> { + let input = self.tvm.empty( + u32_slice_to_js(&[1u32]), + "int32", + self.device.clone().into(), + ); + input.copy_from_i32array(&[last_pos as i32]); + Ok(input) + } + + async fn embed(&self, tokens: &[i32]) -> anyhow::Result<tvmjs::NDArray> { let input = self.tvm.empty( u32_slice_to_js(&[tokens.len() as u32]), @@ -929,12 +1171,28 @@ .take_while(|(h, t)| h == t) .count(); - // Rewind the head of kv-cache to the LCP + // Rewind the head of kv-cache (and rnn state, when hybrid) to the LCP.
+ // RNN state can only roll back at most `max_history` tokens; if the + // requested rollback exceeds that, fall back to a full clear and + // re-prefill the entire prompt. This costs the LCP-rewind speedup + // but preserves correctness for hybrid (Gated DeltaNet) models. if lcp_index < self.history.len() { - self.kv_cache - .popn(0, (self.history.len() - lcp_index) as i64) - .map_err(|e| anyhow!("{e:?}"))?; - self.history.drain(lcp_index..); + let n = (self.history.len() - lcp_index) as i64; + let rnn_can_pop = self + .rnn_state + .as_ref() + .map(|rnn| n <= rnn.max_history) + .unwrap_or(true); + if rnn_can_pop { + self.kv_cache.popn(0, n).map_err(|e| anyhow!("{e:?}"))?; + if let Some(rnn) = self.rnn_state.as_mut() { + rnn.popn(0, n).map_err(|e| anyhow!("{e:?}"))?; + } + self.history.drain(lcp_index..); + } else { + // RNN state cannot roll back this far — start over. + self.clear()?; + } } // Tokens to be added (without common prefixes) @@ -957,9 +1215,47 @@ mod wasm { self.kv_cache .begin_forward(0, length as i64) .map_err(|e| anyhow!("{e:?}"))?; - self.fprefill - .call3(&embedding, self.kv_cache.get_state(), &self.params) - .map_err(|e| anyhow!("{e:?}"))?; + if let Some(rnn) = self.rnn_state.as_mut() { + rnn.begin_forward(0, length as i64) + .map_err(|e| anyhow!("{e:?}"))?; + } + + match (&self.rnn_state, self.kv_state_kind) { + (Some(rnn), KVStateKind::Hybrid) => { + let last_pos = (length as i64) - 1; + let logit_positions = self.make_logit_positions(last_pos)?; + self.fprefill + .call5( + &embedding, + &logit_positions, + self.kv_cache.get_state(), + rnn.get_state(), + &self.params, + ) + .map_err(|e| anyhow!("{e:?}"))?; + } + (Some(rnn), KVStateKind::RnnState) => { + let last_pos = (length as i64) - 1; + let logit_positions = self.make_logit_positions(last_pos)?; + self.fprefill + .call4( + &embedding, + &logit_positions, + rnn.get_state(), + &self.params, + ) + .map_err(|e| anyhow!("{e:?}"))?; + } + _ => { + self.fprefill + .call3(&embedding, self.kv_cache.get_state(), &self.params) + .map_err(|e| anyhow!("{e:?}"))?; + } + } + + if let Some(rnn) = self.rnn_state.as_mut() { + rnn.end_forward().map_err(|e| anyhow!("{e:?}"))?; + } self.kv_cache.end_forward().map_err(|e| anyhow!("{e:?}"))?; // Update history @@ -984,11 +1280,38 @@ mod wasm { self.kv_cache .begin_forward(0, 1) .map_err(|e| anyhow!("Failed to begin forward: {:?}", e))?; - let output: tvmjs::TVMArray = self - .fdecode - .call3(&embedding, self.kv_cache.get_state(), &self.params) - .map_err(|e| anyhow!("Failed to call `decode`: {:?}", e))? - .into(); + if let Some(rnn) = self.rnn_state.as_mut() { + rnn.begin_forward(0, 1) + .map_err(|e| anyhow!("Failed to begin forward (rnn): {:?}", e))?; + } + + let output: tvmjs::TVMArray = match (&self.rnn_state, self.kv_state_kind) { + (Some(rnn), KVStateKind::Hybrid) => self + .fdecode + .call4( + &embedding, + self.kv_cache.get_state(), + rnn.get_state(), + &self.params, + ) + .map_err(|e| anyhow!("Failed to call `batch_decode`: {:?}", e))? + .into(), + (Some(rnn), KVStateKind::RnnState) => self + .fdecode + .call3(&embedding, rnn.get_state(), &self.params) + .map_err(|e| anyhow!("Failed to call `batch_decode`: {:?}", e))? + .into(), + _ => self + .fdecode + .call3(&embedding, self.kv_cache.get_state(), &self.params) + .map_err(|e| anyhow!("Failed to call `decode`: {:?}", e))? 
+ .into(), + }; + + if let Some(rnn) = self.rnn_state.as_mut() { + rnn.end_forward() + .map_err(|e| anyhow!("Failed to end forward (rnn): {:?}", e))?; + } self.kv_cache .end_forward() .map_err(|e| anyhow!("Failed to end forward: {:?}", e))?; @@ -1046,9 +1369,16 @@ let metadata = get_metadata(&vm)?; let params = initialize_params(&tvm, &device, &metadata, contents).await?; + let kv_state_kind = KVStateKind::from_metadata(&metadata); + let fembed: tvmjs::PackedFunc = tvm.detach(vm.get_function("embed")); - let fprefill: tvmjs::PackedFunc = tvm.detach(vm.get_function("prefill")); - let fdecode: tvmjs::PackedFunc = tvm.detach(vm.get_function("decode")); + let (prefill_name, decode_name) = if kv_state_kind.needs_rnn_state() { + ("batch_prefill", "batch_decode") + } else { + ("prefill", "decode") + }; + let fprefill: tvmjs::PackedFunc = tvm.detach(vm.get_function(prefill_name)); + let fdecode: tvmjs::PackedFunc = tvm.detach(vm.get_function(decode_name)); let fsample_top_p_from_logits: tvmjs::PackedFunc = tvm.detach(tvm.get_global_func("vm.builtin.sample_top_p_from_logits")); @@ -1061,6 +1391,20 @@ let kv_cache = KVCache::new(tvm.clone().into(), &vm, &metadata, kv_cache_config).unwrap(); + let rnn_state = if kv_state_kind.needs_rnn_state() { + Some( + RNNState::new( + tvm.clone().into(), + &vm, + &metadata, + RNNStateConfig::default(), + ) + .unwrap(), + ) + } else { + None + }; + tvm.end_scope(); Ok(LanguageModelInferencer { @@ -1068,6 +1412,8 @@ vm, device, kv_cache, + rnn_state, + kv_state_kind, params, history: Vec::new(), fembed, diff --git a/src/model/local/local_language_model.rs b/src/model/local/local_language_model.rs index f27e19e7..4881b405 100644 --- a/src/model/local/local_language_model.rs +++ b/src/model/local/local_language_model.rs @@ -214,25 +214,59 @@ impl LocalLangModelImpl { config: LangModelInferConfig, ) -> BoxStream<'a, anyhow::Result<MessageDeltaOutput>> { let strm = try_stream! { + let think_effort = config.think_effort.unwrap_or_default(); let prompt = if let Some(polyfill) = config.document_polyfill { let msgs = polyfill.polyfill(msgs, docs)?; - self.chat_template.apply(msgs, tools, Vec::new(), config.think_effort.unwrap_or_default(), true)? + self.chat_template.apply(msgs, tools, Vec::new(), think_effort.clone(), true)? } else { - self.chat_template.apply(msgs, tools, docs, config.think_effort.unwrap_or_default(), true)? + self.chat_template.apply(msgs, tools, docs, think_effort.clone(), true)? }; let input_tokens = self.tokenizer.encode(&prompt, true)?; - { - #[cfg(not(target_family = "wasm"))] - self.inferencer.prefill(&input_tokens).unwrap(); - #[cfg(target_family = "wasm")] + // Sample the first token from prefill's last-position logits to + // avoid pushing the prompt's last token through KV state twice + // (a separate decode pass would extend KV by one extra token and + // produce garbage logits afterwards).
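+ // (Concretely: after prefilling t0..tN, the old flow called decode(tN)
+ // to obtain the first logits, so tN entered the KV/recurrent state a
+ // second time and every later position was shifted by one.)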
+ #[cfg(not(target_family = "wasm"))] + let mut prefilled_first_token: Option<u32> = { + let prefill_logits = self.inferencer.prefill(&input_tokens).unwrap(); + let temperature = config.temperature.unwrap_or(0.6); + let top_p = config.top_p.unwrap_or(0.9); + prefill_logits + .map(|l| self.inferencer.sample(l, temperature, top_p).unwrap()) + }; + #[cfg(target_family = "wasm")] + let mut prefilled_first_token: Option<u32> = { self.inferencer.prefill(&input_tokens).await.unwrap(); - } + None + }; let mut last_token = *input_tokens.last().unwrap(); let mut agg_tokens = Vec::<u32>::new(); let mut count = 0; - let mut mode = "content".to_owned(); + // Decide the starting mode by scanning the rendered prompt for + // the most recent `<think>` / `</think>` marker. Different chat + // templates emit different trailers: + // - Qwen3.5 closes `<think>\n\n</think>\n\n` when thinking is + // enabled, so the model jumps straight to final content. + // - Older templates leave `<think>\n` open and the model emits + // reasoning until it produces `</think>`. + // We resolve marker ids through the tokenizer's vocab because + // Rust's tokenizers crate does not render special tokens through + // encode/decode even with skip_special_tokens=false. + let open_id = self.tokenizer.token_to_id("<think>"); + let close_id = self.tokenizer.token_to_id("</think>"); + let mut last_open = None; + let mut last_close = None; + for (i, t) in input_tokens.iter().enumerate() { + if Some(*t) == open_id { last_open = Some(i); } + if Some(*t) == close_id { last_close = Some(i); } + } + let mut mode = match (last_open, last_close) { + (Some(o), Some(c)) if o > c => "reasoning".to_owned(), + (Some(_), None) => "reasoning".to_owned(), + _ => "content".to_owned(), + }; let mut last_s = "".to_owned(); yield MessageDeltaOutput{delta: MessageDelta::new().with_role(Role::Assistant), finish_reason: None}; @@ -250,10 +284,11 @@ let top_p = config.top_p.unwrap_or(0.9); #[cfg(not(target_family = "wasm"))] - let new_token = { + let new_token = if let Some(t) = prefilled_first_token.take() { + t + } else { let logits = self.inferencer.decode(last_token).unwrap(); - let new_token = self.inferencer.sample(logits, temperature, top_p).unwrap(); - new_token + self.inferencer.sample(logits, temperature, top_p).unwrap() }; #[cfg(target_family = "wasm")] let new_token = self.inferencer.decode(last_token, temperature, top_p).await.unwrap(); @@ -434,6 +469,500 @@ mod tests { })); } + // Local-only test: verifies that Hybrid (Qwen3.5 GatedDeltaNet) loads + // via batch_prefill/batch_decode and RNN state. Requires that + // ~/.cache/ailoy/Qwen--Qwen3.5-0.8B/ has been populated by the local + // compile script (scripts/compile_qwen35_local.sh) + chat_template.j2 + // extracted from tokenizer_config.json. + #[multi_platform_test] + #[ignore = "requires locally-compiled Qwen3.5-0.8B artifacts"] + async fn local_infer_qwen35_hybrid() { + // Qwen3.5 ships with context_window_size=262144 in metadata; that + // blows past available RAM when create_tir_paged_kv_cache sizes the + // page table. Force an 8K window for this smoke test.
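+ // (The paged-KV pool is allocated for the whole window up front, so
+ // the cost lands at model load time, not at the first long prompt.)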
+ let kv_cfg = crate::model::local::KVCacheConfig { + context_window_size: Some(8192), + prefill_chunk_size: Some(2048), + sliding_window_size: None, + }; + let mut model = LocalLangModel::try_new( + "Qwen/Qwen3.5-0.8B", + Some( + LocalLangModelConfig::default() + .with_validate_checksum(false) + .with_kv_cache(&kv_cfg), + ), + ) + .await + .unwrap(); + let msgs = vec![ + Message::new(Role::System).with_contents([Part::Text { + text: "You are an assistant.".to_owned(), + }]), + Message::new(Role::User).with_contents([Part::Text { + text: "Say hi in one short sentence.".to_owned(), + }]), + ]; + let mut assistant_msg = MessageDelta::new(); + let infer_cfg = LangModelInferConfig { + max_tokens: Some(64), + ..LangModelInferConfig::default() + }; + let mut strm = model.infer_delta(msgs, Vec::new(), Vec::new(), infer_cfg); + let mut finish_reason = None; + while let Some(output_opt) = strm.next().await { + let output = output_opt.unwrap(); + assistant_msg = assistant_msg.accumulate(output.delta).unwrap(); + finish_reason = output.finish_reason; + } + // Accept either Stop or Length — a short max_tokens smoke test just + // needs the loop to terminate cleanly with at least one content part. + assert!(matches!( + finish_reason, + Some(FinishReason::Stop {}) | Some(FinishReason::Length {}) + )); + assert!(assistant_msg.finish().is_ok_and(|message| { + debug!("{:?}", message.contents.first().and_then(|c| c.as_text())); + message.contents.len() > 0 + })); + } + + /// Drives a single user turn through `model`, accumulates the assistant + /// message, and asserts the loop terminated cleanly. Returns the final + /// `MessageDelta` so the caller can fold it back into the next turn. + async fn drive_one_turn( + model: &mut LocalLangModel, + msgs: Vec<Message>, + max_tokens: Option<u32>, + ) -> MessageDelta { + let cfg = LangModelInferConfig { + max_tokens, + ..LangModelInferConfig::default() + }; + let mut strm = model.infer_delta(msgs, Vec::new(), Vec::new(), cfg); + let mut acc = MessageDelta::new(); + let mut finish_reason = None; + while let Some(out) = strm.next().await { + let out = out.unwrap(); + acc = acc.accumulate(out.delta).unwrap(); + finish_reason = out.finish_reason; + } + assert!( + matches!( + finish_reason, + Some(FinishReason::Stop {}) | Some(FinishReason::Length {}) + ), + "unexpected finish_reason: {:?}", + finish_reason + ); + acc + } + + /// Largest dense Qwen3.5 size verified on 24GB Apple Silicon. Context + /// is forced down to 4K because 9B + 8K KV pages overrun a tight RAM + /// envelope. q4f16_1 puts the weights at ~4.7GB on disk.
+ #[multi_platform_test] + #[ignore = "requires locally-compiled Qwen3.5-9B artifacts"] + async fn local_infer_qwen35_9b_hybrid() { + let kv_cfg = crate::model::local::KVCacheConfig { + context_window_size: Some(4096), + prefill_chunk_size: Some(2048), + sliding_window_size: None, + }; + let mut model = LocalLangModel::try_new( + "Qwen/Qwen3.5-9B", + Some( + LocalLangModelConfig::default() + .with_validate_checksum(false) + .with_kv_cache(&kv_cfg), + ), + ) + .await + .unwrap(); + let msgs = vec![ + Message::new(Role::System).with_contents([Part::Text { + text: "You are an assistant.".to_owned(), + }]), + Message::new(Role::User).with_contents([Part::Text { + text: "Say hi in one short sentence.".to_owned(), + }]), + ]; + let cfg = LangModelInferConfig { + max_tokens: Some(64), + ..LangModelInferConfig::default() + }; + let mut strm = model.infer_delta(msgs, Vec::new(), Vec::new(), cfg); + let mut acc = MessageDelta::new(); + let mut finish_reason = None; + while let Some(out) = strm.next().await { + let out = out.unwrap(); + acc = acc.accumulate(out.delta).unwrap(); + finish_reason = out.finish_reason; + } + assert!(matches!( + finish_reason, + Some(FinishReason::Stop {}) | Some(FinishReason::Length {}) + )); + assert!(acc.finish().unwrap().contents.len() > 0); + } + + /// Same shape as `local_infer_qwen35_hybrid` but for the 4B variant. + /// Tests the upper end of the comfortable range on a 24GB Apple Silicon + /// host with q4f16_1 quantization. + #[multi_platform_test] + #[ignore = "requires locally-compiled Qwen3.5-4B artifacts"] + async fn local_infer_qwen35_4b_hybrid() { + let kv_cfg = crate::model::local::KVCacheConfig { + context_window_size: Some(8192), + prefill_chunk_size: Some(2048), + sliding_window_size: None, + }; + let mut model = LocalLangModel::try_new( + "Qwen/Qwen3.5-4B", + Some( + LocalLangModelConfig::default() + .with_validate_checksum(false) + .with_kv_cache(&kv_cfg), + ), + ) + .await + .unwrap(); + let msgs = vec![ + Message::new(Role::System).with_contents([Part::Text { + text: "You are an assistant.".to_owned(), + }]), + Message::new(Role::User).with_contents([Part::Text { + text: "Say hi in one short sentence.".to_owned(), + }]), + ]; + let cfg = LangModelInferConfig { + max_tokens: Some(64), + ..LangModelInferConfig::default() + }; + let mut strm = model.infer_delta(msgs, Vec::new(), Vec::new(), cfg); + let mut acc = MessageDelta::new(); + let mut finish_reason = None; + while let Some(out) = strm.next().await { + let out = out.unwrap(); + acc = acc.accumulate(out.delta).unwrap(); + finish_reason = out.finish_reason; + } + assert!(matches!( + finish_reason, + Some(FinishReason::Stop {}) | Some(FinishReason::Length {}) + )); + assert!(acc.finish().unwrap().contents.len() > 0); + } + + /// Same shape as `local_infer_qwen35_hybrid` but for the 2B variant. + /// Pipeline generality check — proves the hybrid path scales beyond + /// the smallest member of the family. 
+ #[multi_platform_test] + #[ignore = "requires locally-compiled Qwen3.5-2B artifacts"] + async fn local_infer_qwen35_2b_hybrid() { + let kv_cfg = crate::model::local::KVCacheConfig { + context_window_size: Some(8192), + prefill_chunk_size: Some(2048), + sliding_window_size: None, + }; + let mut model = LocalLangModel::try_new( + "Qwen/Qwen3.5-2B", + Some( + LocalLangModelConfig::default() + .with_validate_checksum(false) + .with_kv_cache(&kv_cfg), + ), + ) + .await + .unwrap(); + let msgs = vec![ + Message::new(Role::System).with_contents([Part::Text { + text: "You are an assistant.".to_owned(), + }]), + Message::new(Role::User).with_contents([Part::Text { + text: "Say hi in one short sentence.".to_owned(), + }]), + ]; + let cfg = LangModelInferConfig { + max_tokens: Some(64), + ..LangModelInferConfig::default() + }; + let mut strm = model.infer_delta(msgs, Vec::new(), Vec::new(), cfg); + let mut acc = MessageDelta::new(); + let mut finish_reason = None; + while let Some(out) = strm.next().await { + let out = out.unwrap(); + acc = acc.accumulate(out.delta).unwrap(); + finish_reason = out.finish_reason; + } + assert!(matches!( + finish_reason, + Some(FinishReason::Stop {}) | Some(FinishReason::Length {}) + )); + assert!(acc.finish().unwrap().contents.len() > 0); + } + + /// Measures decode throughput (tokens/sec) so the numbers can be folded + /// into FINAL_REPORT. Covers Qwen3-0.6B (KvCache) and the Qwen3.5 family + /// from 0.8B through 9B (Hybrid). Does not assert specific numbers + /// because the absolute speed depends on the host; we just exercise the + /// pipeline end-to-end and emit measurements. + #[multi_platform_test] + #[ignore = "throughput measurement, run on demand with --ignored"] + async fn local_infer_throughput_measurement() { + async fn measure(model_key: &'static str, kv_cfg: Option<crate::model::local::KVCacheConfig>) { + let mut config = LocalLangModelConfig::default().with_validate_checksum(false); + if let Some(kv) = kv_cfg.as_ref() { + config = config.with_kv_cache(kv); + } + let mut model = LocalLangModel::try_new(model_key, Some(config)) + .await + .unwrap(); + let msgs = vec![ + Message::new(Role::System).with_contents([Part::Text { + text: "You are an assistant.".to_owned(), + }]), + Message::new(Role::User).with_contents([Part::Text { + text: "Count from one to twenty in English.".to_owned(), + }]), + ]; + let cfg = LangModelInferConfig { + max_tokens: Some(128), + ..LangModelInferConfig::default() + }; + let start = std::time::Instant::now(); + let mut strm = model.infer_delta(msgs, Vec::new(), Vec::new(), cfg); + let mut produced = 0u64; + while let Some(out) = strm.next().await { + let out = out.unwrap(); + // Each yielded delta corresponds to (at most) one decode step, + // ignoring the leading role-only delta. Approximate by counting + // any non-empty content/tool_call delta.
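+ // (Empty deltas are skipped, so `produced` is a lower bound on decode
+ // steps; close enough for a rough tokens-per-second figure.)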
+ if !out.delta.contents.is_empty() || !out.delta.tool_calls.is_empty() { + produced += 1; + } + if out.finish_reason.is_some() { + break; + } + } + let elapsed = start.elapsed(); + let tps = produced as f64 / elapsed.as_secs_f64(); + eprintln!( + "[throughput] model={} produced={} elapsed={:.2}s rate={:.2} tok/s", + model_key, + produced, + elapsed.as_secs_f64(), + tps + ); + } + measure("Qwen/Qwen3-0.6B", None).await; + let kv8k = crate::model::local::KVCacheConfig { + context_window_size: Some(8192), + prefill_chunk_size: Some(2048), + sliding_window_size: None, + }; + measure("Qwen/Qwen3.5-0.8B", Some(kv8k.clone())).await; + measure("Qwen/Qwen3.5-2B", Some(kv8k.clone())).await; + measure("Qwen/Qwen3.5-4B", Some(kv8k)).await; + let kv4k = crate::model::local::KVCacheConfig { + context_window_size: Some(4096), + prefill_chunk_size: Some(2048), + sliding_window_size: None, + }; + measure("Qwen/Qwen3.5-9B", Some(kv4k)).await; + } + + /// Backward-compatibility regression for a larger V1 artifact (Qwen3-8B). + /// The V2 reader path must still load V1 bytecode produced by the legacy + /// brekky toolchain. Marked `#[ignore]` because Qwen3-8B is ~4GB and is + /// not always present in the local cache. + #[multi_platform_test] + #[ignore = "requires Qwen3-8B in ~/.cache/ailoy"] + async fn local_infer_qwen3_8b_v1_compat() { + let mut model = LocalLangModel::try_new( + "Qwen/Qwen3-8B", + Some(LocalLangModelConfig::default().with_validate_checksum(false)), + ) + .await + .unwrap(); + let msgs = vec![ + Message::new(Role::System).with_contents([Part::Text { + text: "You are an assistant.".to_owned(), + }]), + Message::new(Role::User).with_contents([Part::Text { + text: "Say hi in one short sentence.".to_owned(), + }]), + ]; + let cfg = LangModelInferConfig { + max_tokens: Some(64), + ..LangModelInferConfig::default() + }; + let mut strm = model.infer_delta(msgs, Vec::new(), Vec::new(), cfg); + let mut acc = MessageDelta::new(); + let mut finish_reason = None; + while let Some(out) = strm.next().await { + let out = out.unwrap(); + acc = acc.accumulate(out.delta).unwrap(); + finish_reason = out.finish_reason; + } + assert!(matches!( + finish_reason, + Some(FinishReason::Stop {}) | Some(FinishReason::Length {}) + )); + assert!(acc.finish().unwrap().contents.len() > 0); + } + + /// Multi-turn LCP rewind regression on the existing KvCache path + /// (Qwen3-0.6B). Two consecutive user turns share a prefix; the second + /// `infer_delta` recomputes the LCP against the stored history and must + /// reuse the cached prefix (rewinding via `popn(0, ...)` only where the + /// re-rendered prompt diverges from the stored history) instead of + /// re-prefilling from scratch. + #[multi_platform_test] + async fn local_infer_multi_turn_kvcache() { + let mut model = LocalLangModel::try_new( + "Qwen/Qwen3-0.6B", + Some(LocalLangModelConfig::default().with_validate_checksum(false)), + ) + .await + .unwrap(); + let system = Message::new(Role::System).with_contents([Part::Text { + text: "You are an assistant. Answer in one short sentence.".to_owned(), + }]); + // Turn 1 + let user1 = Message::new(Role::User) + .with_contents([Part::Text { text: "Say hi.".to_owned() }]); + let acc1 = drive_one_turn( + &mut model, + vec![system.clone(), user1.clone()], + Some(64), + ) + .await; + let assistant1 = acc1.finish().unwrap(); + // Turn 2: same system + same user1 + assistant1 + user2. + // History inside the runtime ends with assistant1's tokens; the next + // prefill computes LCP up to that point and `popn` should be a no-op + // because tokens[..lcp_index] covers the full history.
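+ // (If the chat template re-renders assistant1 differently, e.g. by
+ // dropping its reasoning block, the LCP falls short of the stored
+ // history and `popn` rewinds exactly the divergent suffix first.)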
+ let user2 = Message::new(Role::User) + .with_contents([Part::Text { text: "And again, please.".to_owned() }]); + let acc2 = drive_one_turn( + &mut model, + vec![system, user1, assistant1, user2], + Some(64), + ) + .await; + let assistant2 = acc2.finish().unwrap(); + assert!(assistant2.contents.len() > 0); + } + + /// Generic multi-turn driver. Same shape as + /// `local_infer_multi_turn_kvcache` but with a configurable model and KV + /// config so the same pattern can be reused for every Qwen3.5 size. + async fn run_multi_turn(model_key: &'static str, kv_cfg: crate::model::local::KVCacheConfig) { + let mut model = LocalLangModel::try_new( + model_key, + Some( + LocalLangModelConfig::default() + .with_validate_checksum(false) + .with_kv_cache(&kv_cfg), + ), + ) + .await + .unwrap(); + let system = Message::new(Role::System).with_contents([Part::Text { + text: "You are an assistant. Answer in one short sentence.".to_owned(), + }]); + let user1 = Message::new(Role::User) + .with_contents([Part::Text { text: "Say hi.".to_owned() }]); + let acc1 = drive_one_turn( + &mut model, + vec![system.clone(), user1.clone()], + Some(48), + ) + .await; + let assistant1 = acc1.finish().unwrap(); + let user2 = Message::new(Role::User) + .with_contents([Part::Text { text: "And again, please.".to_owned() }]); + let acc2 = drive_one_turn( + &mut model, + vec![system, user1, assistant1, user2], + Some(48), + ) + .await; + let assistant2 = acc2.finish().unwrap(); + assert!(assistant2.contents.len() > 0); + } + + fn hybrid_kv(ctx: u32) -> crate::model::local::KVCacheConfig { + crate::model::local::KVCacheConfig { + context_window_size: Some(ctx), + prefill_chunk_size: Some(2048), + sliding_window_size: None, + } + } + + #[multi_platform_test] + #[ignore = "requires locally-compiled Qwen3.5-2B artifacts"] + async fn local_infer_multi_turn_qwen35_2b() { + run_multi_turn("Qwen/Qwen3.5-2B", hybrid_kv(8192)).await; + } + + #[multi_platform_test] + #[ignore = "requires locally-compiled Qwen3.5-4B artifacts"] + async fn local_infer_multi_turn_qwen35_4b() { + run_multi_turn("Qwen/Qwen3.5-4B", hybrid_kv(8192)).await; + } + + #[multi_platform_test] + #[ignore = "requires locally-compiled Qwen3.5-9B artifacts; very slow on tight RAM"] + async fn local_infer_multi_turn_qwen35_9b() { + run_multi_turn("Qwen/Qwen3.5-9B", hybrid_kv(4096)).await; + } + + /// Multi-turn LCP rewind regression on the Hybrid path (Qwen3.5-0.8B). + /// Same shape as `local_infer_multi_turn_kvcache` but additionally + /// exercises `RNNState::popn` and the rnn_state begin/end_forward pair + /// across two consecutive turns. + #[multi_platform_test] + #[ignore = "requires locally-compiled Qwen3.5-0.8B artifacts"] + async fn local_infer_multi_turn_hybrid() { + let kv_cfg = crate::model::local::KVCacheConfig { + context_window_size: Some(8192), + prefill_chunk_size: Some(2048), + sliding_window_size: None, + }; + let mut model = LocalLangModel::try_new( + "Qwen/Qwen3.5-0.8B", + Some( + LocalLangModelConfig::default() + .with_validate_checksum(false) + .with_kv_cache(&kv_cfg), + ), + ) + .await + .unwrap(); + let system = Message::new(Role::System).with_contents([Part::Text { + text: "You are an assistant. 
Answer in one short sentence.".to_owned(), + }]); + let user1 = Message::new(Role::User) + .with_contents([Part::Text { text: "Say hi.".to_owned() }]); + let acc1 = drive_one_turn( + &mut model, + vec![system.clone(), user1.clone()], + Some(64), + ) + .await; + let assistant1 = acc1.finish().unwrap(); + let user2 = Message::new(Role::User) + .with_contents([Part::Text { text: "And again, please.".to_owned() }]); + let acc2 = drive_one_turn( + &mut model, + vec![system, user1, assistant1, user2], + Some(64), + ) + .await; + let assistant2 = acc2.finish().unwrap(); + assert!(assistant2.contents.len() > 0); + } + #[multi_platform_test] async fn local_infer_tool_call() { let mut model = LocalLangModel::try_new( diff --git a/src/model/local/tokenizer.rs b/src/model/local/tokenizer.rs index 7b42bda3..50da6e0f 100644 --- a/src/model/local/tokenizer.rs +++ b/src/model/local/tokenizer.rs @@ -33,6 +33,11 @@ impl Tokenizer { .decode(ids, skip_special_tokens) .map_err(|e| anyhow!("Tokenizer::decode failed: {}", e)) } + + /// Resolve a (special or regular) token literal to its id. + pub fn token_to_id(&self, token: &str) -> Option<u32> { + self.inner.token_to_id(token) + }
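+ // Used by local_language_model.rs above to resolve `<think>` / `</think>`;
+ // tokenizers::Tokenizer::token_to_id returns None when a vocab lacks the
+ // marker, which degrades the starting-mode detection to plain `content`.
 } impl<'this> TryFromCache<'this> for Tokenizer {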