From 4db111927050d96d05408995af8312c8adc8b01d Mon Sep 17 00:00:00 2001 From: Dishit Date: Sat, 16 May 2026 17:45:46 +0530 Subject: [PATCH 01/93] feat(litert): add LiteRT-LM as second on-device inference engine MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add LiteRTModule.kt — native Android module managing Engine/Conversation lifecycle with NPU→GPU→CPU fallback chain and image decode pipeline - Add LiteRTPackage.kt and register in MainApplication - Add LiteRTService.ts — JS bridge with streaming token events - Wire generation routing in generationServiceHelpers (litert vs llama.cpp) - Add doLoadLiteRTModel in activeModelService loaders - Add .litertlm import support with per-model vision toggle dialog - Add liteRTVision and engine fields to DownloadedModel type - Add persistent debug logs store (AsyncStorage-backed, survives crashes) - Add DebugLogsScreen modal accessible from ChatHeader terminal icon - Upgrade litertlm-android 0.10.0→0.11.0, Kotlin 2.1.20→2.2.0, kapt→ksp - Fix SIGSEGV: gate visionBackend=GPU behind supportsVision flag - Fix double load: check liteRTService.isModelLoaded() before triggering load - Fix reload loop: skip hasPendingSettings and handleReloadTextModel for litert - Add LITERT_TODO.md with full production readiness backlog - Fix lint errors and update modelManager tests for .litertlm support Co-Authored-By: Dishit Karia --- __tests__/unit/services/modelManager.test.ts | 48 +- android/app/build.gradle | 11 +- .../java/ai/offgridmobile/MainApplication.kt | 2 + .../ai/offgridmobile/litert/LiteRTModule.kt | 406 ++++++++++++++ .../ai/offgridmobile/litert/LiteRTPackage.kt | 16 + android/build.gradle | 7 +- android/gradle.properties | 1 + docs/LITERT_TODO.md | 52 ++ src/components/DebugLogsScreen/index.tsx | 124 ++--- src/navigation/AppNavigator.tsx | 2 + src/navigation/types.ts | 1 + src/screens/ChatScreen/ChatModalSection.tsx | 7 +- .../ChatScreen/ChatScreenComponents.tsx | 6 +- 
src/screens/ChatScreen/index.tsx | 3 + .../ChatScreen/useChatGenerationActions.ts | 6 + src/screens/ChatScreen/useChatModelActions.ts | 26 +- src/screens/ChatScreen/useChatScreen.ts | 6 + src/screens/DebugLiteRTScreen.tsx | 499 ++++++++++++++++++ src/screens/ModelsScreen/importHelpers.ts | 19 + src/screens/ModelsScreen/useModelsScreen.ts | 7 +- src/screens/SettingsScreen.tsx | 7 + src/screens/index.ts | 1 + src/services/activeModelService/loaders.ts | 74 +++ src/services/generationServiceHelpers.ts | 91 ++++ src/services/litert.ts | 213 ++++++++ src/services/modelManager/copyFile.ts | 52 ++ src/services/modelManager/scan.ts | 61 +-- src/stores/debugLogsStore.ts | 42 +- src/types/index.ts | 6 + 29 files changed, 1660 insertions(+), 136 deletions(-) create mode 100644 android/app/src/main/java/ai/offgridmobile/litert/LiteRTModule.kt create mode 100644 android/app/src/main/java/ai/offgridmobile/litert/LiteRTPackage.kt create mode 100644 docs/LITERT_TODO.md create mode 100644 src/screens/DebugLiteRTScreen.tsx create mode 100644 src/services/litert.ts create mode 100644 src/services/modelManager/copyFile.ts diff --git a/__tests__/unit/services/modelManager.test.ts b/__tests__/unit/services/modelManager.test.ts index 3a9aeaa04..f74493ead 100644 --- a/__tests__/unit/services/modelManager.test.ts +++ b/__tests__/unit/services/modelManager.test.ts @@ -1149,7 +1149,7 @@ describe('ModelManager', () => { it('rejects non-.gguf files', async () => { await expect( modelManager.importLocalModel({ sourceUri: '/path/to/model.bin', fileName: 'model.bin' }) - ).rejects.toThrow('Only .gguf files can be imported'); + ).rejects.toThrow('Only .gguf and .litertlm files can be imported'); }); it('rejects when destination already exists', async () => { @@ -2705,4 +2705,50 @@ describe('ModelManager', () => { expect(result[0].backend).toBe('coreml'); }); }); + + describe('importLocalModel — LiteRT branches', () => { + function setupImportMocks() { + mockedRNFS.exists + 
.mockResolvedValueOnce(true) // modelsDir + .mockResolvedValueOnce(true) // imageModelsDir + .mockResolvedValueOnce(false); // destExists = false + mockedRNFS.stat.mockResolvedValue({ size: 500000000, isFile: () => true } as any); + (mockedRNFS as any).copyFile.mockResolvedValue(undefined); + mockedAsyncStorage.getItem.mockResolvedValue('[]'); + } + + it('imports a .litertlm file with engine=litert and liteRTVision=false', async () => { + setupImportMocks(); + const result = await modelManager.importLocalModel({ + sourceUri: '/path/to/gemma.litertlm', + fileName: 'gemma-4-E2B-it.litertlm', + engine: 'litert', + liteRTVision: false, + }); + expect(result.engine).toBe('litert'); + expect(result.liteRTVision).toBe(false); + expect(result.id).toBe('local_import/gemma-4-E2B-it.litertlm'); + }); + + it('imports a .litertlm file with liteRTVision=true', async () => { + setupImportMocks(); + const result = await modelManager.importLocalModel({ + sourceUri: '/path/to/gemma-vision.litertlm', + fileName: 'gemma-vision.litertlm', + engine: 'litert', + liteRTVision: true, + }); + expect(result.liteRTVision).toBe(true); + }); + + it('omits engine and liteRTVision when not provided', async () => { + setupImportMocks(); + const result = await modelManager.importLocalModel({ + sourceUri: '/path/to/model.gguf', + fileName: 'model.gguf', + }); + expect(result.engine).toBeUndefined(); + expect(result.liteRTVision).toBeUndefined(); + }); + }); }); diff --git a/android/app/build.gradle b/android/app/build.gradle index f0e24f770..e831ffcec 100644 --- a/android/app/build.gradle +++ b/android/app/build.gradle @@ -1,6 +1,6 @@ apply plugin: "com.android.application" apply plugin: "org.jetbrains.kotlin.android" -apply plugin: "org.jetbrains.kotlin.kapt" +apply plugin: "com.google.devtools.ksp" apply plugin: "com.facebook.react" apply from: file("../../node_modules/react-native-vector-icons/fonts.gradle") @@ -84,8 +84,8 @@ android { applicationId "ai.offgridmobile" minSdkVersion 
rootProject.ext.minSdkVersion targetSdkVersion rootProject.ext.targetSdkVersion - versionCode 1776434971 - versionName "0.0.89" + versionCode 1778048025 + versionName "0.1.00" } signingConfigs { debug { @@ -157,11 +157,14 @@ dependencies { // PDF text extraction (used by PDFExtractorModule) implementation("io.legere:pdfiumandroid:1.0.35") + // LiteRT-LM on-device LLM inference (pinned — do not use latest.release) + implementation("com.google.ai.edge.litertlm:litertlm-android:0.11.0") + // Download layer — Room + WorkManager + OkHttp def room_version = "2.8.2" implementation("androidx.room:room-runtime:$room_version") implementation("androidx.room:room-ktx:$room_version") - kapt("androidx.room:room-compiler:$room_version") + ksp("androidx.room:room-compiler:$room_version") implementation("androidx.work:work-runtime-ktx:2.10.0") implementation("androidx.lifecycle:lifecycle-livedata-ktx:2.8.7") implementation("com.squareup.okhttp3:okhttp:4.12.0") diff --git a/android/app/src/main/java/ai/offgridmobile/MainApplication.kt b/android/app/src/main/java/ai/offgridmobile/MainApplication.kt index 67669a074..3d7b961ec 100644 --- a/android/app/src/main/java/ai/offgridmobile/MainApplication.kt +++ b/android/app/src/main/java/ai/offgridmobile/MainApplication.kt @@ -9,6 +9,7 @@ import com.facebook.react.defaults.DefaultReactHost.getDefaultReactHost import ai.offgridmobile.download.DownloadManagerPackage import ai.offgridmobile.localdream.LocalDreamPackage import ai.offgridmobile.pdf.PDFExtractorPackage +import ai.offgridmobile.litert.LiteRTPackage class MainApplication : Application(), ReactApplication { @@ -21,6 +22,7 @@ class MainApplication : Application(), ReactApplication { add(DownloadManagerPackage()) add(LocalDreamPackage()) add(PDFExtractorPackage()) + add(LiteRTPackage()) }, ) } diff --git a/android/app/src/main/java/ai/offgridmobile/litert/LiteRTModule.kt b/android/app/src/main/java/ai/offgridmobile/litert/LiteRTModule.kt new file mode 100644 index 000000000..eee459d3f 
--- /dev/null +++ b/android/app/src/main/java/ai/offgridmobile/litert/LiteRTModule.kt @@ -0,0 +1,406 @@ +package ai.offgridmobile.litert + +import android.util.Log +import com.facebook.react.bridge.* +import com.facebook.react.modules.core.DeviceEventManagerModule +import ai.offgridmobile.SafePromise +import com.google.ai.edge.litertlm.Backend +import com.google.ai.edge.litertlm.ConversationConfig +import com.google.ai.edge.litertlm.Engine +import com.google.ai.edge.litertlm.EngineConfig +import com.google.ai.edge.litertlm.Content +import com.google.ai.edge.litertlm.Contents +import com.google.ai.edge.litertlm.SamplerConfig +import kotlinx.coroutines.* +import java.io.File +import java.io.InputStream +import java.io.ByteArrayOutputStream +import android.net.Uri +import android.graphics.Bitmap +import android.graphics.BitmapFactory + +class LiteRTModule(private val reactContext: ReactApplicationContext) : + ReactContextBaseJavaModule(reactContext) { + + companion object { + private const val TAG = "LiteRTModule" + + // Streaming events sent to JS + const val EVENT_TOKEN = "litert_token" + const val EVENT_THINKING = "litert_thinking" + const val EVENT_COMPLETE = "litert_complete" + const val EVENT_ERROR = "litert_error" + + // Timeouts per backend tier + private const val NPU_TIMEOUT_MS = 45_000L + private const val GPU_TIMEOUT_MS = 20_000L + private const val CPU_TIMEOUT_MS = 15_000L + } + + private val scope = CoroutineScope(SupervisorJob() + Dispatchers.Default) + + private var engine: Engine? = null + private var conversation: com.google.ai.edge.litertlm.Conversation? = null + private var activeBackend: String = "cpu" + private var supportsVision: Boolean = false + private var currentJob: Job? 
= null + + override fun getName(): String = "LiteRTModule" + + // ------------------------------------------------------------------------- + // loadModel + // ------------------------------------------------------------------------- + + @ReactMethod + fun loadModel(modelPath: String, backendStr: String, visionEnabled: Boolean, promise: Promise) { + val safe = SafePromise(promise, TAG) + Log.i(TAG, "loadModel — path=$modelPath backend=$backendStr vision=$visionEnabled") + + scope.launch { + try { + // Unload any existing engine first + cleanupEngine() + + val requestedBackend = parseBackend(backendStr) + Log.i(TAG, "loadModel — attempting backend chain from $backendStr") + + val resolvedBackend = initializeWithFallback(modelPath, requestedBackend, visionEnabled) + activeBackend = backendName(resolvedBackend) + supportsVision = visionEnabled + + Log.i(TAG, "loadModel — success on backend=$activeBackend vision=$supportsVision") + safe.resolve(activeBackend) + } catch (e: Exception) { + Log.e(TAG, "loadModel — all backends failed: ${e.message}", e) + safe.reject("LITERT_LOAD_ERROR", "Failed to load model: ${e.message}", e) + } + } + } + + // 3-tier fallback: NPU → GPU → CPU + private suspend fun initializeWithFallback(modelPath: String, requested: Backend, visionEnabled: Boolean): Backend { + val chain = when (requested) { + is Backend.NPU -> listOf( + Backend.NPU(nativeLibraryDir = reactContext.applicationInfo.nativeLibraryDir), + Backend.GPU(), + Backend.CPU(), + ) + is Backend.GPU -> listOf(Backend.GPU(), Backend.CPU()) + else -> listOf(Backend.CPU()) + } + + var lastError: Exception? 
= null + for (backend in chain) { + val name = backendName(backend) + Log.i(TAG, "initializeWithFallback — trying $name vision=$visionEnabled") + try { + val cfg = EngineConfig( + modelPath = modelPath, + backend = backend, + cacheDir = null, + visionBackend = if (visionEnabled) Backend.GPU() else null, + ) + val eng = Engine(cfg) + val timeoutMs = when (backend) { + is Backend.NPU -> NPU_TIMEOUT_MS + is Backend.GPU -> GPU_TIMEOUT_MS + else -> CPU_TIMEOUT_MS + } + withTimeout(timeoutMs) { + eng.initialize() + } + engine = eng + Log.i(TAG, "initializeWithFallback — $name succeeded") + return backend + } catch (e: Exception) { + Log.w(TAG, "initializeWithFallback — $name failed: ${e.message}") + engine?.close() + engine = null + lastError = e + if (backend == chain.last()) break + Log.i(TAG, "initializeWithFallback — falling back to next tier") + } + } + throw lastError ?: IllegalStateException("All backends failed") + } + + // ------------------------------------------------------------------------- + // resetConversation — closes and recreates Conversation only, Engine stays + // ------------------------------------------------------------------------- + + @ReactMethod + fun resetConversation(systemPrompt: String, promise: Promise) { + val safe = SafePromise(promise, TAG) + Log.i(TAG, "resetConversation — systemPrompt length=${systemPrompt.length}") + + scope.launch { + try { + val eng = engine + if (eng == null) { + Log.w(TAG, "resetConversation — no engine loaded") + safe.reject("LITERT_NOT_LOADED", "No model loaded", null) + return@launch + } + + // Close existing conversation first + closeConversation() + + // SamplerConfig is not supported on NPU + val samplerConfig = if (activeBackend == "npu") { + Log.i(TAG, "resetConversation — NPU backend, skipping SamplerConfig") + null + } else { + SamplerConfig(topK = 40, topP = 0.95, temperature = 0.8) + } + + val convConfig = ConversationConfig( + systemInstruction = if (systemPrompt.isNotEmpty()) + 
Contents.of(systemPrompt) else null, + samplerConfig = samplerConfig, + ) + + conversation = eng.createConversation(convConfig) + Log.i(TAG, "resetConversation — new conversation created") + safe.resolve(null) + } catch (e: Exception) { + Log.e(TAG, "resetConversation — error: ${e.message}", e) + safe.reject("LITERT_CONV_ERROR", "Failed to reset conversation: ${e.message}", e) + } + } + } + + // ------------------------------------------------------------------------- + // sendMessage — sends only the current user turn, library holds history + // ------------------------------------------------------------------------- + + @ReactMethod + fun sendMessage(text: String, imageUri: String?, promise: Promise) { + val safe = SafePromise(promise, TAG) + Log.i(TAG, "sendMessage — text length=${text.length} hasImage=${imageUri != null}") + + scope.launch { + // Wait for any in-flight generation to finish + currentJob?.join() + + val conv = conversation + if (conv == null) { + Log.w(TAG, "sendMessage — no conversation, call resetConversation first") + safe.reject("LITERT_NO_CONV", "No conversation. 
Call resetConversation first.", null) + return@launch + } + + if (imageUri != null && !supportsVision) { + Log.w(TAG, "sendMessage — image provided but model was not loaded with vision support, ignoring image") + } + + currentJob = launch { + try { + Log.i(TAG, "sendMessage — starting generation") + + val contents = if (imageUri != null && supportsVision) { + Log.i(TAG, "sendMessage — reading image from URI: $imageUri") + val pngBytes = try { + readImageAsPngBytes(imageUri) + } catch (e: Exception) { + Log.e(TAG, "sendMessage — failed to read/decode image: ${e.message}", e) + sendEvent(EVENT_ERROR, "Failed to read image: ${e.message}") + safe.reject("LITERT_IMG_ERROR", "Failed to read image: ${e.message}", e) + return@launch + } + Log.i(TAG, "sendMessage — image decoded to PNG, bytes=${pngBytes.size}") + // Image before text — matches reference implementation order + Contents.of(Content.ImageBytes(pngBytes), Content.Text(text)) + } else { + Contents.of(text) + } + + Log.i(TAG, "sendMessage — calling sendMessageAsync") + conv.sendMessageAsync(contents) + .collect { message -> + val thought = message.channels["thought"] + if (thought != null && thought.isNotEmpty()) { + Log.d(TAG, "sendMessage — thinking token") + sendEvent(EVENT_THINKING, thought) + } else { + val token = message.contents.contents + .filterIsInstance() + .joinToString("") { it.text } + Log.d(TAG, "sendMessage — token: '$token'") + if (token.isNotEmpty()) sendEvent(EVENT_TOKEN, token) + } + } + Log.i(TAG, "sendMessage — generation complete") + sendEvent(EVENT_COMPLETE, "") + safe.resolve(null) + } catch (e: CancellationException) { + Log.i(TAG, "sendMessage — job cancelled") + sendEvent(EVENT_COMPLETE, "") + safe.resolve(null) + } catch (e: OutOfMemoryError) { + Log.e(TAG, "sendMessage — OOM: ${e.message}") + sendEvent(EVENT_ERROR, "Out of memory processing image") + safe.reject("LITERT_OOM", "Out of memory processing image", null) + } catch (e: Exception) { + Log.e(TAG, "sendMessage — error: 
${e.message}", e) + sendEvent(EVENT_ERROR, e.message ?: "Unknown error") + safe.reject("LITERT_GEN_ERROR", "Generation failed: ${e.message}", e) + } finally { + currentJob = null + } + } + } + } + + // ------------------------------------------------------------------------- + // stopGeneration + // ------------------------------------------------------------------------- + + @ReactMethod + fun stopGeneration(promise: Promise) { + val safe = SafePromise(promise, TAG) + Log.i(TAG, "stopGeneration — cancelling current job") + + scope.launch { + try { + currentJob?.cancel() + currentJob?.join() + Log.i(TAG, "stopGeneration — done") + safe.resolve(null) + } catch (e: Exception) { + Log.w(TAG, "stopGeneration — error during cancel: ${e.message}") + safe.resolve(null) // resolve anyway — stop is best-effort + } + } + } + + // ------------------------------------------------------------------------- + // unloadModel — conversation first, then engine (order is critical) + // ------------------------------------------------------------------------- + + @ReactMethod + fun unloadModel(promise: Promise) { + val safe = SafePromise(promise, TAG) + Log.i(TAG, "unloadModel — starting cleanup") + + scope.launch { + try { + currentJob?.cancel() + currentJob?.join() + cleanupEngine() + activeBackend = "cpu" + supportsVision = false + Log.i(TAG, "unloadModel — done") + safe.resolve(null) + } catch (e: Exception) { + Log.e(TAG, "unloadModel — error: ${e.message}", e) + safe.resolve(null) // resolve anyway + } + } + } + + // ------------------------------------------------------------------------- + // getActiveBackend — returns which backend is actually running + // ------------------------------------------------------------------------- + + @ReactMethod + fun getActiveBackend(promise: Promise) { + SafePromise(promise, TAG).resolve(activeBackend) + } + + // ------------------------------------------------------------------------- + // Helpers + // 
------------------------------------------------------------------------- + + private fun closeConversation() { + try { + conversation?.close() + Log.d(TAG, "closeConversation — closed") + } catch (e: Exception) { + Log.w(TAG, "closeConversation — error (ignored): ${e.message}") + } finally { + conversation = null + } + } + + private fun cleanupEngine() { + // conversation MUST be closed before engine + closeConversation() + try { + engine?.close() + Log.d(TAG, "cleanupEngine — engine closed") + } catch (e: Exception) { + Log.w(TAG, "cleanupEngine — engine close error (ignored): ${e.message}") + } finally { + engine = null + } + } + + private fun parseBackend(s: String): Backend = when (s.lowercase()) { + "npu", "htp" -> Backend.NPU( + nativeLibraryDir = reactContext.applicationInfo.nativeLibraryDir + ) + "gpu", "opencl", "metal" -> Backend.GPU() + else -> Backend.CPU() + } + + private fun backendName(b: Backend): String = when (b) { + is Backend.NPU -> "npu" + is Backend.GPU -> "gpu" + else -> "cpu" + } + + /** + * Decode image URI → Bitmap → PNG bytes. + * Handles content:// (gallery picker) and file:// (filesystem) URIs. + * Converting to PNG ensures the model receives a well-formed image format + * regardless of the source (JPEG, WebP, HEIC, etc.). + * Max dimension is capped at 1024px to avoid OOM on large photos. 
+ */ + private fun readImageAsPngBytes(uri: String): ByteArray { + val inputStream: InputStream = if (uri.startsWith("content://")) { + reactContext.contentResolver.openInputStream(Uri.parse(uri)) + ?: throw IllegalArgumentException("Cannot open content URI: $uri") + } else { + File(uri.removePrefix("file://")).inputStream() + } + + val bitmap = inputStream.use { BitmapFactory.decodeStream(it) } + ?: throw IllegalArgumentException("Failed to decode image from URI: $uri") + + // Scale down if either dimension exceeds 1024px to avoid OOM + val scaled = scaleBitmapIfNeeded(bitmap, maxDim = 1024) + + val out = ByteArrayOutputStream() + scaled.compress(Bitmap.CompressFormat.PNG, 100, out) + if (scaled !== bitmap) scaled.recycle() + bitmap.recycle() + return out.toByteArray() + } + + private fun scaleBitmapIfNeeded(src: Bitmap, maxDim: Int): Bitmap { + val w = src.width + val h = src.height + if (w <= maxDim && h <= maxDim) return src + val scale = maxDim.toFloat() / maxOf(w, h) + val newW = (w * scale).toInt() + val newH = (h * scale).toInt() + return Bitmap.createScaledBitmap(src, newW, newH, true) + } + + private fun sendEvent(eventName: String, data: String) { + try { + reactContext + .getJSModule(DeviceEventManagerModule.RCTDeviceEventEmitter::class.java) + .emit(eventName, data) + } catch (e: Exception) { + Log.w(TAG, "sendEvent — failed to emit $eventName: ${e.message}") + } + } + + override fun onCatalystInstanceDestroy() { + Log.i(TAG, "onCatalystInstanceDestroy — cleaning up") + scope.cancel() + cleanupEngine() + super.onCatalystInstanceDestroy() + } +} diff --git a/android/app/src/main/java/ai/offgridmobile/litert/LiteRTPackage.kt b/android/app/src/main/java/ai/offgridmobile/litert/LiteRTPackage.kt new file mode 100644 index 000000000..155073c93 --- /dev/null +++ b/android/app/src/main/java/ai/offgridmobile/litert/LiteRTPackage.kt @@ -0,0 +1,16 @@ +package ai.offgridmobile.litert + +import com.facebook.react.ReactPackage +import 
com.facebook.react.bridge.NativeModule +import com.facebook.react.bridge.ReactApplicationContext +import com.facebook.react.uimanager.ViewManager + +class LiteRTPackage : ReactPackage { + override fun createNativeModules(reactContext: ReactApplicationContext): List { + return listOf(LiteRTModule(reactContext)) + } + + override fun createViewManagers(reactContext: ReactApplicationContext): List> { + return emptyList() + } +} diff --git a/android/build.gradle b/android/build.gradle index dad99b022..ed39de050 100644 --- a/android/build.gradle +++ b/android/build.gradle @@ -5,16 +5,17 @@ buildscript { compileSdkVersion = 36 targetSdkVersion = 36 ndkVersion = "27.1.12297006" - kotlinVersion = "2.1.20" + kotlinVersion = "2.2.0" } repositories { google() mavenCentral() } dependencies { - classpath("com.android.tools.build:gradle") + classpath("com.android.tools.build:gradle:8.8.2") classpath("com.facebook.react:react-native-gradle-plugin") - classpath("org.jetbrains.kotlin:kotlin-gradle-plugin") + classpath("org.jetbrains.kotlin:kotlin-gradle-plugin:${kotlinVersion}") + classpath("com.google.devtools.ksp:com.google.devtools.ksp.gradle.plugin:2.2.0-2.0.2") } } diff --git a/android/gradle.properties b/android/gradle.properties index 9afe61598..c9f8970b5 100644 --- a/android/gradle.properties +++ b/android/gradle.properties @@ -11,6 +11,7 @@ # The setting is particularly useful for tweaking memory settings. # Default value: -Xmx512m -XX:MaxMetaspaceSize=256m org.gradle.jvmargs=-Xmx2048m -XX:MaxMetaspaceSize=512m +org.gradle.java.home=/Library/Java/JavaVirtualMachines/temurin-21.jdk/Contents/Home # When configured, Gradle will run in incubating parallel mode. # This option should only be used with decoupled projects. More details, visit diff --git a/docs/LITERT_TODO.md b/docs/LITERT_TODO.md new file mode 100644 index 000000000..4f0860778 --- /dev/null +++ b/docs/LITERT_TODO.md @@ -0,0 +1,52 @@ +# LiteRT Production TODO + +## P0 — Correctness blockers +1. 
**Fix stopGeneration()** — wire `liteRTService.stopGeneration()` into `generationService.ts`. Currently the stop button does nothing for LiteRT; the native generation keeps running. +2. **Fix multi-turn conversation history** — stop calling `resetConversation()` before every message. Only reset on new chat / model switch. For follow-up turns just call `sendMessage()` directly — the native Conversation object already holds the history. For app-resume cases use `ConversationConfig.initialMessages` to replay. +3. **Fix memory budget** — `memory.ts` lines 69 and 103: replace `llmService.isModelLoaded()` with `llmService.isModelLoaded() || liteRTService.isModelLoaded()`. LiteRT RAM is not counted, so the image-generation model can OOM when loaded alongside LiteRT. + +## P1 — User-facing correctness +4. **Fix silent image drop** — when `liteRTVision=false` and the user attaches an image, show an error toast instead of silently sending text-only. Otherwise the user thinks the model saw the image when it did not. +5. **Wire sampler settings** — `SamplerConfig` in `LiteRTModule.kt` `resetConversation()` is hardcoded to `topK=40, topP=0.95, temperature=0.8`. Read from `store.settings` instead. +6. **iOS platform guard** — use `liteRTService.isAvailable()` (already checks `Platform.OS === 'android'`) as the single gate. Never write `Platform.OS` inline anywhere else. Hide all LiteRT UI on iOS from this one method. +7. **Hide irrelevant settings for LiteRT** — when a LiteRT model is active, hide: KV cache type, GPU layers, flash attention, nThreads, nBatch, repeat penalty. Show only: backend (cpu/gpu/npu), temperature, topK, topP. + +## P2 — Future-proofing +8. **Fix syncWithNativeState** — `utils.ts` only checks `llmService.isModelLoaded()`, not `liteRTService.isModelLoaded()`. After a background kill, app resume can show a stale loaded state. +9. **Backend change reload for LiteRT** — `hasPendingSettings` returns false for LiteRT, so a CPU→GPU switch never reloads the model. Need to track `loadedBackend` and trigger a reload when it changes. + +## P3 — Tests +10.
**Unit tests** — generation flow, stopGeneration, error paths, vision=false guard, memory budget with LiteRT loaded +11. **Integration tests** — load→generate→stop cycle, model switch llama↔LiteRT, image with vision disabled + +## P4 — UI (lowest priority) +12. **Downloadable model catalog** — curated list of .litertlm models (Gemma variants), download with vision flag pre-set, Android only + +--- + +## Build cleanup (do before merging litertsupport → main) +- Remove `org.gradle.java.home` from `android/gradle.properties` — machine-specific, breaks CI +- Rebuild gesture handler patch cleanly — current patch captured CMake build artefacts, not just the source fix. Check if `react-native-gesture-handler` has released a fix natively first. +- Test full main branch build after merge — Kotlin 2.2.0 upgrade is the highest risk item + +--- + +## Key files +- `android/app/src/main/java/ai/offgridmobile/litert/LiteRTModule.kt` — native Android module +- `src/services/litert.ts` — JS bridge service +- `src/services/activeModelService/loaders.ts` — load routing (litert vs llama, line 171) +- `src/services/activeModelService/memory.ts` — memory budget (bug at lines 69, 103) +- `src/services/activeModelService/utils.ts` — syncWithNativeState (missing liteRT check) +- `src/services/generationServiceHelpers.ts` — generation routing, conversation reset +- `src/screens/ModelsScreen/importHelpers.ts` — engine tag set at import, vision dialog +- `src/types/index.ts` — DownloadedModel type, ModelEngine = 'llama' | 'litert', liteRTVision flag + +## Architecture facts to remember +- Engine decided by `DownloadedModel.engine` field — set once at import, never changes +- LiteRT only on Android — `liteRTService.isAvailable()` is the single platform gate +- LiteRT loads model into RAM same as llama — no streaming from disk +- Vision requires `liteRTVision=true` on the model record AND `visionBackend=Backend.GPU()` in EngineConfig +- SamplerConfig (topK/topP/temp) not supported on NPU 
backend — skip it there +- Library: `com.google.ai.edge.litertlm:litertlm-android:0.11.0` +- iOS Swift SDK: not released yet, Coming Soon per Google +- Gemma 4 E2B on GPU: TTFT ~7s, ~38-41 chars/sec diff --git a/src/components/DebugLogsScreen/index.tsx b/src/components/DebugLogsScreen/index.tsx index c0adc7296..94c4d13cd 100644 --- a/src/components/DebugLogsScreen/index.tsx +++ b/src/components/DebugLogsScreen/index.tsx @@ -3,7 +3,7 @@ * Simple modal showing captured debug logs with copy and clear options */ -import React from 'react'; +import React, { useEffect } from 'react'; import { View, Text, @@ -12,6 +12,7 @@ import { Clipboard, Share, SafeAreaView, + Modal, } from 'react-native'; import Icon from 'react-native-vector-icons/Feather'; import { useTheme, useThemedStyles } from '../../theme'; @@ -26,7 +27,11 @@ interface DebugLogsScreenProps { export const DebugLogsScreen: React.FC = ({ visible, onClose }) => { const theme = useTheme(); const styles = useThemedStyles(createStyles); - const { logs, clearLogs } = useDebugLogsStore() as any; // NOSONAR - zustand store + const { logs, clearLogs, loadFromStorage } = useDebugLogsStore() as any; // NOSONAR - zustand store + + useEffect(() => { + loadFromStorage(); + }, []); const formatTime = (timestamp: number) => { const date = new Date(timestamp); @@ -83,69 +88,64 @@ export const DebugLogsScreen: React.FC = ({ visible, onClo } }; - if (!visible) return null; - return ( - - {/* Header */} - - - Debug Logs - {logs.length} entries + + + {/* Header */} + + + Debug Logs + {logs.length} entries + + + + - - - - - - {/* Action Buttons */} - - - - Copy All - - - - - Share - - - - - Clear - - - - {/* Logs List */} - {logs.length === 0 ? 
( - - No logs yet + + {/* Action Buttons */} + + + + Copy All + + + + + Share + + + + + Clear + - ) : ( - `${index}`} - contentContainerStyle={styles.listContent} - renderItem={({ item }: { item: any }) => ( - - {formatTime(item.timestamp)} - - {item.level.toUpperCase()} - - - {item.message} - - - )} - inverted - /> - )} - + + {/* Logs List */} + {logs.length === 0 ? ( + + No logs yet + + ) : ( + `${index}`} + contentContainerStyle={styles.listContent} + renderItem={({ item }: { item: any }) => ( + + {formatTime(item.timestamp)} + + {item.level.toUpperCase()} + + + {item.message} + + + )} + inverted + /> + )} + + ); }; diff --git a/src/navigation/AppNavigator.tsx b/src/navigation/AppNavigator.tsx index 1d15b73a0..e71da39ba 100644 --- a/src/navigation/AppNavigator.tsx +++ b/src/navigation/AppNavigator.tsx @@ -37,6 +37,7 @@ import { SecuritySettingsScreen, GalleryScreen, RemoteServersScreen, + DebugLiteRTScreen, } from '../screens'; import { RootStackParamList, @@ -242,6 +243,7 @@ export const AppNavigator: React.FC = () => { component={GalleryScreen} options={{ presentation: 'modal', animation: 'slide_from_bottom' }} /> + ); diff --git a/src/navigation/types.ts b/src/navigation/types.ts index 21b876daa..6cca68185 100644 --- a/src/navigation/types.ts +++ b/src/navigation/types.ts @@ -22,6 +22,7 @@ export type RootStackParamList = { // Already in RootStack DownloadManager: undefined; Gallery: { conversationId?: string } | undefined; + DebugLiteRT: undefined; }; // Tab navigator — simple, no sub-stacks diff --git a/src/screens/ChatScreen/ChatModalSection.tsx b/src/screens/ChatScreen/ChatModalSection.tsx index 301b3bdc0..2fb38fdcb 100644 --- a/src/screens/ChatScreen/ChatModalSection.tsx +++ b/src/screens/ChatScreen/ChatModalSection.tsx @@ -3,6 +3,7 @@ import { ModelSelectorModal, GenerationSettingsModal, ProjectSelectorSheet, DebugSheet, } from '../../components'; +import { DebugLogsScreen } from '../../components/DebugLogsScreen'; import { llmService } from 
'../../services'; import { createStyles } from './styles'; import { useTheme } from '../../theme'; @@ -38,6 +39,8 @@ type ChatModalSectionProps = { viewerImageUri: string | null; setViewerImageUri: (v: string | null) => void; handleSaveImage: () => void; + showLogsPanel: boolean; + setShowLogsPanel: (v: boolean) => void; isRemote?: boolean; }; @@ -50,9 +53,11 @@ export const ChatModalSection: React.FC = ({ debugInfo, activeProject, activeConversation, settings, projects, handleSelectProject, handleModelSelect, handleUnloadModel, handleDeleteConversation, isModelLoading, imageCount, activeConversationId, navigation, - viewerImageUri, setViewerImageUri, handleSaveImage, isRemote, + viewerImageUri, setViewerImageUri, handleSaveImage, + showLogsPanel, setShowLogsPanel, isRemote, }) => ( <> + setShowLogsPanel(false)} /> setShowProjectSelector(false)} diff --git a/src/screens/ChatScreen/ChatScreenComponents.tsx b/src/screens/ChatScreen/ChatScreenComponents.tsx index 080027aa5..2cf4f4695 100644 --- a/src/screens/ChatScreen/ChatScreenComponents.tsx +++ b/src/screens/ChatScreen/ChatScreenComponents.tsx @@ -113,8 +113,9 @@ export const ChatHeader: React.FC<{ setShowModelSelector: (v: boolean) => void; setShowSettingsPanel: (v: boolean) => void; setShowProjectSelector: (v: boolean) => void; + setShowLogsPanel: (v: boolean) => void; isRemote?: boolean; -}> = ({ styles, colors, activeConversation, activeModel, activeModelName, activeImageModel, activeProject, navigation, setShowModelSelector, setShowSettingsPanel, setShowProjectSelector, isRemote }) => ( +}> = ({ styles, colors, activeConversation, activeModel, activeModelName, activeImageModel, activeProject, navigation, setShowModelSelector, setShowSettingsPanel, setShowProjectSelector, setShowLogsPanel, isRemote }) => ( navigation.goBack()}> @@ -149,6 +150,9 @@ export const ChatHeader: React.FC<{ + setShowLogsPanel(true)}> + + setShowSettingsPanel(true)} testID="chat-settings-icon"> diff --git 
a/src/screens/ChatScreen/index.tsx b/src/screens/ChatScreen/index.tsx index 27c5475bc..df93d9202 100644 --- a/src/screens/ChatScreen/index.tsx +++ b/src/screens/ChatScreen/index.tsx @@ -184,6 +184,7 @@ export const ChatScreen: React.FC = () => { setShowModelSelector={chat.setShowModelSelector} setShowSettingsPanel={chat.setShowSettingsPanel} setShowProjectSelector={chat.setShowProjectSelector} + setShowLogsPanel={chat.setShowLogsPanel} isRemote={chat.activeModelInfo?.isRemote} /> { viewerImageUri={chat.viewerImageUri} setViewerImageUri={chat.setViewerImageUri} handleSaveImage={chat.handleSaveImage} + showLogsPanel={chat.showLogsPanel} + setShowLogsPanel={chat.setShowLogsPanel} isRemote={chat.activeModelInfo?.isRemote} /> diff --git a/src/screens/ChatScreen/useChatGenerationActions.ts b/src/screens/ChatScreen/useChatGenerationActions.ts index c6b9a755e..ab2534f1f 100644 --- a/src/screens/ChatScreen/useChatGenerationActions.ts +++ b/src/screens/ChatScreen/useChatGenerationActions.ts @@ -18,6 +18,7 @@ import { ragService, retrievalService, } from '../../services'; +import { liteRTService } from '../../services/litert'; import { embeddingService } from '../../services/rag/embedding'; import { useChatStore, useProjectStore, useRemoteServerStore } from '../../stores'; import { Message, MediaAttachment, Project, DownloadedModel, RemoteModel, ModelLoadingStrategy, CacheType } from '../../types'; @@ -153,6 +154,11 @@ export async function handleImageGenerationFn( export type StartGenerationCall = { setDebugInfo: SetState; targetConversationId: string; messageText: string }; async function ensureModelReady(deps: GenerationDeps): Promise { if (deps.activeModelInfo?.isRemote) return true; + if (deps.activeModel?.engine === 'litert') { + if (liteRTService.isModelLoaded()) return true; + await deps.ensureModelLoaded(); + return liteRTService.isModelLoaded(); + } const loadedPath = llmService.getLoadedModelPath(); if (loadedPath && loadedPath === deps.activeModel!.filePath) 
return true; await deps.ensureModelLoaded(); diff --git a/src/screens/ChatScreen/useChatModelActions.ts b/src/screens/ChatScreen/useChatModelActions.ts index 9ed497fb7..dafb7b308 100644 --- a/src/screens/ChatScreen/useChatModelActions.ts +++ b/src/screens/ChatScreen/useChatModelActions.ts @@ -5,8 +5,10 @@ import { hideAlert, } from '../../components'; import { llmService, activeModelService, modelManager } from '../../services'; +import { liteRTService } from '../../services/litert'; import { DownloadedModel, RemoteModel, ONNXImageModel } from '../../types'; import logger from '../../utils/logger'; +import { useDebugLogsStore } from '../../stores/debugLogsStore'; type SetState = Dispatch>; @@ -63,7 +65,7 @@ async function doLoadTextModel(deps: ModelActionDeps): Promise { try { await activeModelService.loadTextModel(activeModelId); const multimodalSupport = llmService.getMultimodalSupport(); - deps.setSupportsVision(multimodalSupport?.vision || false); + deps.setSupportsVision(activeModel.engine === 'litert' ? true : (multimodalSupport?.vision || false)); if (deps.modelLoadStartTimeRef.current && deps.settings.showGenerationDetails) { const loadTime = ((Date.now() - deps.modelLoadStartTimeRef.current) / 1000).toFixed(1); addSystemMsg(deps, `Model loaded: ${activeModel.name} (${loadTime}s)`); @@ -111,15 +113,19 @@ export async function initiateModelLoad( await waitForRenderFrame(); } + const dbg = useDebugLogsStore.getState().addLog; + dbg('log', `[LiteRT] initiateModelLoad — model=${activeModel.name} engine=${activeModel.engine ?? 'llama'}`); try { await activeModelService.loadTextModel(activeModelId); const multimodalSupport = llmService.getMultimodalSupport(); - deps.setSupportsVision(multimodalSupport?.vision || false); + deps.setSupportsVision(activeModel.engine === 'litert' ? true : (multimodalSupport?.vision || false)); + dbg('log', `[LiteRT] loadTextModel success — engine=${activeModel.engine ?? 
'llama'}`); if (!alreadyLoading && deps.modelLoadStartTimeRef.current && deps.settings.showGenerationDetails) { const loadTime = ((Date.now() - deps.modelLoadStartTimeRef.current) / 1000).toFixed(1); addSystemMsg(deps, `Model loaded: ${activeModel.name} (${loadTime}s)`); } } catch (error: any) { + dbg('error', `[LiteRT] loadTextModel failed — ${error?.message || 'Unknown error'}`); if (!alreadyLoading) { deps.setAlertState(showAlert('Error', `Failed to load model: ${error?.message || 'Unknown error'}`)); } @@ -137,6 +143,18 @@ export async function ensureModelLoadedFn( ): Promise { const { activeModel, activeModelId } = deps; if (!activeModel || !activeModelId) return; + const dbg = useDebugLogsStore.getState().addLog; + if (activeModel.engine === 'litert') { + if (liteRTService.isModelLoaded()) { + dbg('log', `[LiteRT] ensureModelLoaded — already loaded, skipping`); + deps.setSupportsVision(true); + return; + } + dbg('log', `[LiteRT] ensureModelLoaded — model=${activeModel.name}, triggering load`); + deps.setSupportsVision(true); + if (deps.activeModelId) await initiateModelLoad(deps, activeModelService.getActiveModels().text.isLoading); + return; + } const loadedPath = llmService.getLoadedModelPath(); const currentVisionSupport = llmService.getMultimodalSupport()?.vision || false; const needsReload = loadedPath !== activeModel.filePath || @@ -160,7 +178,7 @@ export async function proceedWithModelLoadFn( try { await activeModelService.loadTextModel(model.id); const multimodalSupport = llmService.getMultimodalSupport(); - deps.setSupportsVision(multimodalSupport?.vision || false); + deps.setSupportsVision(model.engine === 'litert' ? 
true : (multimodalSupport?.vision || false)); if (deps.modelLoadStartTimeRef.current && deps.settings.showGenerationDetails && deps.activeConversationId) { const loadTime = ((Date.now() - deps.modelLoadStartTimeRef.current) / 1000).toFixed(1); deps.addMessage(deps.activeConversationId, { @@ -308,6 +326,8 @@ export function useChatModelStateSync(deps: ModelStateSyncDeps): void { useEffect(() => { if (activeModelInfo.isRemote) { setSupportsVision(activeRemoteModel?.capabilities?.supportsVision ?? false); + } else if (activeModel?.engine === 'litert') { + setSupportsVision(true); } else if (activeModel?.mmProjPath && llmService.isModelLoaded()) { setSupportsVision(llmService.getMultimodalSupport()?.vision ?? false); } else { diff --git a/src/screens/ChatScreen/useChatScreen.ts b/src/screens/ChatScreen/useChatScreen.ts index ac46fdce0..a392b4140 100644 --- a/src/screens/ChatScreen/useChatScreen.ts +++ b/src/screens/ChatScreen/useChatScreen.ts @@ -40,6 +40,7 @@ export const useChatScreen = () => { const [supportsVision, setSupportsVision] = useState(false); const [showProjectSelector, setShowProjectSelector] = useState(false); const [showDebugPanel, setShowDebugPanel] = useState(false); + const [showLogsPanel, setShowLogsPanel] = useState(false); const [showModelSelector, setShowModelSelector] = useState(false); const [showSettingsPanel, setShowSettingsPanel] = useState(false); const [debugInfo, setDebugInfo] = useState(null); @@ -229,6 +230,8 @@ export const useChatScreen = () => { }; // Check if there are pending settings that require model reload const hasPendingSettings = (() => { + // LiteRT manages its own backend internally — llama.cpp settings don't apply + if (activeModel?.engine === 'litert') return false; if (!loadedSettings) return false; return ( settings.nThreads !== loadedSettings.nThreads || @@ -245,6 +248,8 @@ export const useChatScreen = () => { const handleReloadTextModel = useCallback(async () => { if (!activeModelInfo.modelId || 
activeModelInfo.isRemote) return; + // LiteRT manages its own backend — reloading via llama.cpp path is wrong + if (activeModel?.engine === 'litert') return; // Open the model selector bottom sheet before unloading so the user sees the // loading state inside it rather than the NoModelScreen ("Select Model"). setShowModelSelector(true); @@ -261,6 +266,7 @@ export const useChatScreen = () => { isModelLoading, loadingModel, supportsVision, showProjectSelector, setShowProjectSelector, showDebugPanel, setShowDebugPanel, + showLogsPanel, setShowLogsPanel, showModelSelector, setShowModelSelector, showSettingsPanel, setShowSettingsPanel, showToolPicker, setShowToolPicker, supportsToolCalling, supportsThinking, diff --git a/src/screens/DebugLiteRTScreen.tsx b/src/screens/DebugLiteRTScreen.tsx new file mode 100644 index 000000000..e451ba11b --- /dev/null +++ b/src/screens/DebugLiteRTScreen.tsx @@ -0,0 +1,499 @@ +/** + * DebugLiteRTScreen + * + * Development screen to test LiteRT-LM inference end-to-end. + * Accessible from Settings. Not shown in production nav. + * + * Usage: + * 1. adb push model.litertlm /data/data/ai.offgridmobile.localdream/files/model.litertlm + * 2. Open this screen, enter the path, pick backend, tap Load + * 3. Enter a message, tap Send, watch tokens stream + * 4. 
Use the Logs panel to copy/clear debug output + */ + +import React, { useState, useRef, useCallback } from 'react'; +import { + View, + Text, + TextInput, + TouchableOpacity, + ScrollView, + Alert, + Clipboard, + SafeAreaView, +} from 'react-native'; +import { useNavigation } from '@react-navigation/native'; +import Icon from 'react-native-vector-icons/Feather'; +import { useTheme, useThemedStyles } from '../theme'; +import type { ThemeColors, ThemeShadows } from '../theme'; +import { TYPOGRAPHY, SPACING } from '../constants'; +import { liteRTService, LiteRTBackend } from '../services/litert'; +import RNFS from 'react-native-fs'; + +const DEFAULT_MODEL_PATH = `${RNFS.DocumentDirectoryPath}/models/gemma-4-E2B-it.litertlm`; +const DEFAULT_SYSTEM_PROMPT = 'You are a helpful assistant.'; + +type LogLevel = 'info' | 'warn' | 'error' | 'success'; +interface LogEntry { + id: number; + level: LogLevel; + message: string; + time: string; +} + +let logIdCounter = 0; + +function logColor(level: LogLevel, colors: ThemeColors): string { + switch (level) { + case 'error': return colors.error ?? 
'#FF453A'; + case 'warn': return '#FF9F0A'; + case 'success': return '#30D158'; + default: return colors.textSecondary; + } +} + +export const DebugLiteRTScreen: React.FC = () => { + const navigation = useNavigation(); + const { colors } = useTheme(); + const styles = useThemedStyles(createStyles); + + // Model config + const [modelPath, setModelPath] = useState(DEFAULT_MODEL_PATH); + const [backend, setBackend] = useState('gpu'); + const [systemPrompt, setSystemPrompt] = useState(DEFAULT_SYSTEM_PROMPT); + + // State + const [isLoaded, setIsLoaded] = useState(false); + const [isLoading, setIsLoading] = useState(false); + const [isGenerating, setIsGenerating] = useState(false); + const [activeBackend, setActiveBackend] = useState(null); + + // Chat + const [messageInput, setMessageInput] = useState(''); + const [response, setResponse] = useState(''); + + // Logs + const [logs, setLogs] = useState([]); + const logsScrollRef = useRef(null); + const responseScrollRef = useRef(null); + + const addLog = useCallback((level: LogLevel, message: string) => { + const entry: LogEntry = { + id: logIdCounter++, + level, + message, + time: new Date().toISOString().substring(11, 23), + }; + setLogs(prev => [...prev, entry]); + setTimeout(() => logsScrollRef.current?.scrollToEnd({ animated: true }), 50); + }, []); + + // --------------------------------------------------------------------------- + // Load model + // --------------------------------------------------------------------------- + + const handleLoad = async () => { + if (isLoading) return; + addLog('info', `Loading model: ${modelPath}`); + addLog('info', `Requested backend: ${backend}`); + setIsLoading(true); + setIsLoaded(false); + setActiveBackend(null); + setResponse(''); + + try { + const exists = await RNFS.exists(modelPath); + if (!exists) { + addLog('error', `File not found: ${modelPath}`); + addLog('warn', `Push via ADB:\nadb push model.litertlm ${modelPath}`); + Alert.alert('File not found', `No .litertlm file 
at:\n${modelPath}`); + return; + } + + addLog('info', 'File exists on disk, initializing engine...'); + await liteRTService.loadModel(modelPath, backend); + + const actual = liteRTService.getActiveBackend(); + setActiveBackend(actual); + setIsLoaded(true); + + if (actual !== backend) { + addLog('warn', `Requested ${backend} but fell back to ${actual}`); + } else { + addLog('success', `Model loaded on ${actual?.toUpperCase()}`); + } + + if (liteRTService.isNPU()) { + addLog('warn', 'NPU: temperature/topK/topP sampling settings are inactive on this backend'); + } + } catch (e) { + const msg = e instanceof Error ? e.message : String(e); + addLog('error', `Load failed: ${msg}`); + Alert.alert('Load failed', msg); + } finally { + setIsLoading(false); + } + }; + + // --------------------------------------------------------------------------- + // Send message + // --------------------------------------------------------------------------- + + const handleSend = async () => { + if (!isLoaded || isGenerating || !messageInput.trim()) return; + const text = messageInput.trim(); + setMessageInput(''); + setResponse(''); + setIsGenerating(true); + addLog('info', `Resetting conversation with system prompt...`); + + try { + await liteRTService.resetConversation(systemPrompt); + addLog('info', `Sending: "${text.substring(0, 80)}${text.length > 80 ? '...' 
: ''}"`); + + let tokenCount = 0; + const startMs = Date.now(); + + await liteRTService.sendMessage(text, { + onToken: (token) => { + tokenCount++; + setResponse(prev => prev + token); + if (tokenCount % 20 === 0) { + responseScrollRef.current?.scrollToEnd({ animated: false }); + } + }, + onReasoning: (token) => { + addLog('info', `[thinking] ${token.substring(0, 60)}`); + }, + onComplete: (fullContent, fullReasoning) => { + const elapsed = ((Date.now() - startMs) / 1000).toFixed(1); + const tps = (tokenCount / parseFloat(elapsed)).toFixed(1); + addLog('success', `Done — ${tokenCount} tokens in ${elapsed}s (${tps} tok/s)`); + if (fullReasoning) { + addLog('info', `Reasoning: ${fullReasoning.length} chars`); + } + setIsGenerating(false); + }, + onError: (err) => { + addLog('error', `Generation error: ${err.message}`); + setIsGenerating(false); + }, + }); + } catch (e) { + const msg = e instanceof Error ? e.message : String(e); + addLog('error', `Send failed: ${msg}`); + setIsGenerating(false); + } + }; + + // --------------------------------------------------------------------------- + // Stop + // --------------------------------------------------------------------------- + + const handleStop = async () => { + addLog('info', 'Stopping generation...'); + await liteRTService.stopGeneration(); + setIsGenerating(false); + addLog('info', 'Stopped'); + }; + + // --------------------------------------------------------------------------- + // Unload + // --------------------------------------------------------------------------- + + const handleUnload = async () => { + addLog('info', 'Unloading model...'); + await liteRTService.unloadModel(); + setIsLoaded(false); + setActiveBackend(null); + setResponse(''); + addLog('info', 'Model unloaded'); + }; + + // --------------------------------------------------------------------------- + // Log controls + // --------------------------------------------------------------------------- + + const handleCopyLogs = () => { + const 
text = logs.map(l => `[${l.time}][${l.level.toUpperCase()}] ${l.message}`).join('\n'); + Clipboard.setString(text); + addLog('info', 'Logs copied to clipboard'); + }; + + const handleClearLogs = () => { + setLogs([]); + }; + + // --------------------------------------------------------------------------- + // Render + // --------------------------------------------------------------------------- + + return ( + + {/* Header */} + + navigation.goBack()} style={styles.backBtn}> + + + LiteRT Debug + {activeBackend && ( + + {activeBackend.toUpperCase()} + + )} + + + + + {/* Model Path */} + + Model path (.litertlm) + + + Push via ADB:{'\n'} + adb push model.litertlm {RNFS.DocumentDirectoryPath}/models/model.litertlm + + + + {/* Backend selector */} + + Backend + + {(['cpu', 'gpu', 'npu'] as LiteRTBackend[]).map(b => ( + setBackend(b)} + > + + {b.toUpperCase()} + + + ))} + + {backend === 'npu' && ( + NPU: Snapdragon 8 Gen 2+ only. Falls back to GPU then CPU. + )} + + + {/* System prompt */} + + System prompt + + + + {/* Load / Unload */} + + + + {isLoading ? 'Loading...' : isLoaded ? 'Loaded' : 'Load Model'} + + + {isLoaded && ( + + Unload + + )} + + + {/* Message input + Send/Stop */} + {isLoaded && ( + + Message + + + + {isGenerating ? 'Generating...' : 'Send'} + + {isGenerating && ( + + Stop + + )} + + + )} + + {/* Response output */} + {response.length > 0 && ( + + Response + + {response} + + + )} + + {/* Logs */} + + + Logs ({logs.length}) + + + + Copy + + + + Clear + + + + + {logs.length === 0 && ( + No logs yet. Load a model to start. 
+ )} + {logs.map(entry => ( + + {entry.time} {entry.message} + + ))} + + + + + + ); +}; + +const createStyles = (colors: ThemeColors, shadows: ThemeShadows) => ({ + container: { flex: 1, backgroundColor: colors.background }, + header: { + flexDirection: 'row' as const, + alignItems: 'center' as const, + paddingHorizontal: SPACING.md, + paddingVertical: SPACING.md, + borderBottomWidth: 1, + borderBottomColor: colors.border, + backgroundColor: colors.surface, + ...shadows.small, + }, + backBtn: { padding: SPACING.sm, marginRight: SPACING.sm }, + headerTitle: { ...TYPOGRAPHY.h3, color: colors.text, flex: 1 }, + backendBadge: { + backgroundColor: colors.primary, + borderRadius: 4, + paddingHorizontal: SPACING.sm, + paddingVertical: 2, + }, + backendBadgeText: { ...TYPOGRAPHY.meta, color: colors.background, fontWeight: '600' as const }, + scroll: { flex: 1 }, + section: { paddingHorizontal: SPACING.md, marginTop: SPACING.lg }, + label: { ...TYPOGRAPHY.bodySmall, color: colors.textSecondary, marginBottom: SPACING.sm }, + hint: { ...TYPOGRAPHY.meta, color: colors.textMuted, marginTop: SPACING.sm, lineHeight: 16 }, + input: { + backgroundColor: colors.surface, + borderWidth: 1, + borderColor: colors.border, + borderRadius: 8, + padding: SPACING.md, + ...TYPOGRAPHY.body, + color: colors.text, + }, + multilineInput: { minHeight: 80, textAlignVertical: 'top' as const }, + segmented: { + flexDirection: 'row' as const, + backgroundColor: colors.surfaceLight ?? 
colors.surface, + borderRadius: 8, + padding: 3, + gap: 2, + }, + segmentBtn: { + flex: 1, + paddingVertical: SPACING.sm, + alignItems: 'center' as const, + borderRadius: 6, + }, + segmentBtnActive: { backgroundColor: colors.primary }, + segmentText: { ...TYPOGRAPHY.bodySmall, color: colors.textMuted }, + segmentTextActive: { color: colors.background, fontWeight: '600' as const }, + row: { + flexDirection: 'row' as const, + gap: SPACING.sm, + paddingHorizontal: SPACING.md, + marginTop: SPACING.md, + }, + btn: { + flex: 1, + paddingVertical: SPACING.md, + borderRadius: 8, + alignItems: 'center' as const, + }, + btnPrimary: { backgroundColor: colors.primary }, + btnDanger: { backgroundColor: colors.error ?? '#FF453A', flex: 0, paddingHorizontal: SPACING.xl }, + btnDisabled: { opacity: 0.4 }, + btnText: { ...TYPOGRAPHY.body, color: colors.background, fontWeight: '600' as const }, + responseBox: { + backgroundColor: colors.surface, + borderRadius: 8, + borderWidth: 1, + borderColor: colors.border, + padding: SPACING.md, + maxHeight: 240, + }, + responseText: { ...TYPOGRAPHY.body, color: colors.text, lineHeight: 22 }, + logsHeader: { + flexDirection: 'row' as const, + justifyContent: 'space-between' as const, + alignItems: 'center' as const, + marginBottom: SPACING.sm, + }, + logsBox: { + backgroundColor: colors.surface, + borderRadius: 8, + borderWidth: 1, + borderColor: colors.border, + padding: SPACING.md, + height: 260, + }, + logsEmpty: { ...TYPOGRAPHY.bodySmall, color: colors.textMuted, fontStyle: 'italic' as const }, + logEntry: { ...TYPOGRAPHY.meta, lineHeight: 18, marginBottom: 2, fontFamily: 'Courier' }, + logBtn: { + flexDirection: 'row' as const, + alignItems: 'center' as const, + gap: 4, + paddingHorizontal: SPACING.sm, + paddingVertical: 4, + }, + logBtnText: { ...TYPOGRAPHY.meta, color: colors.textMuted }, +}); diff --git a/src/screens/ModelsScreen/importHelpers.ts b/src/screens/ModelsScreen/importHelpers.ts index c41368c0b..3ef4b4224 100644 --- 
a/src/screens/ModelsScreen/importHelpers.ts +++ b/src/screens/ModelsScreen/importHelpers.ts @@ -47,10 +47,29 @@ export async function importGgufFiles( if (files.length === 1) { const resolvedFileName = files[0].name ?? 'unknown'; + const isLitert = resolvedFileName.toLowerCase().endsWith('.litertlm'); + + let liteRTVision = false; + if (isLitert) { + liteRTVision = await new Promise(resolve => { + Alert.alert( + 'Vision Support', + 'Does this model support image/vision input?\n\nEnable this only for multimodal models (e.g. Gemma 3n). Enabling it on a text-only model will cause a load error.', + [ + { text: 'Text Only', style: 'cancel', onPress: () => resolve(false) }, + { text: 'Vision', style: 'default', onPress: () => resolve(true) }, + ], + { cancelable: false }, + ); + }); + } + const model = await modelManager.importLocalModel({ sourceUri: files[0].uri, fileName: resolvedFileName, sourceSize: files[0].size, + engine: isLitert ? 'litert' : undefined, + liteRTVision: isLitert ? liteRTVision : undefined, onProgress: p => { setImportProgress(p); }, diff --git a/src/screens/ModelsScreen/useModelsScreen.ts b/src/screens/ModelsScreen/useModelsScreen.ts index 2c0915874..c830756d5 100644 --- a/src/screens/ModelsScreen/useModelsScreen.ts +++ b/src/screens/ModelsScreen/useModelsScreen.ts @@ -130,19 +130,20 @@ export function useModelsScreen() { const allGguf = resolvedFiles.every(f => f.name.toLowerCase().endsWith('.gguf')); const singleZip = resolvedFiles.length === 1 && resolvedFiles[0].name.toLowerCase().endsWith('.zip'); + const singleLitert = resolvedFiles.length === 1 && resolvedFiles[0].name.toLowerCase().endsWith('.litertlm'); - if (!allGguf && !singleZip) { + if (!allGguf && !singleZip && !singleLitert) { setAlertState(showAlert( 'Invalid File', resolvedFiles.length > 1 ? 'When selecting multiple files, all must be .gguf files (main model + mmproj projector).' 
- : 'Supported formats: .gguf (text models) and .zip (image models).', + : 'Supported formats: .gguf (text models), .litertlm (LiteRT models), and .zip (image models).', )); return; } if (resolvedFiles.length > 2) { - setAlertState(showAlert('Too Many Files', 'Select 1 file (text/zip) or 2 .gguf files (vision model + mmproj projector).')); + setAlertState(showAlert('Too Many Files', 'Select 1 file (text/zip/litertlm) or 2 .gguf files (vision model + mmproj projector).')); return; } diff --git a/src/screens/SettingsScreen.tsx b/src/screens/SettingsScreen.tsx index 9b543e0c9..8221e62f3 100644 --- a/src/screens/SettingsScreen.tsx +++ b/src/screens/SettingsScreen.tsx @@ -250,6 +250,13 @@ export const SettingsScreen: React.FC = () => { Reset Onboarding Checklist + navigation.getParent()?.navigate('DebugLiteRT')} + > + + LiteRT Debug + diff --git a/src/screens/index.ts b/src/screens/index.ts index 9d5ec9b8b..412d5e59f 100644 --- a/src/screens/index.ts +++ b/src/screens/index.ts @@ -21,3 +21,4 @@ export { DeviceInfoScreen } from './DeviceInfoScreen'; export { StorageSettingsScreen } from './StorageSettingsScreen'; export { SecuritySettingsScreen } from './SecuritySettingsScreen'; export { RemoteServersScreen } from './RemoteServersScreen'; +export { DebugLiteRTScreen } from './DebugLiteRTScreen'; diff --git a/src/services/activeModelService/loaders.ts b/src/services/activeModelService/loaders.ts index 1a51773e2..5640c9755 100644 --- a/src/services/activeModelService/loaders.ts +++ b/src/services/activeModelService/loaders.ts @@ -7,6 +7,7 @@ import { useAppStore } from '../../stores'; import { useDebugLogsStore } from '../../stores/debugLogsStore'; import { DownloadedModel, ONNXImageModel, INFERENCE_BACKENDS } from '../../types'; import { llmService } from '../llm'; +import { liteRTService } from '../litert'; import { localDreamGeneratorService as onnxImageGeneratorService } from '../localDreamGenerator'; import { modelManager } from '../modelManager'; import logger from 
'../../utils/logger'; @@ -97,7 +98,80 @@ export interface TextLoadContext { onFinally: () => void; } +function inferenceBackendToLiteRT(backend: string | undefined): 'cpu' | 'gpu' | 'npu' { + switch (backend) { + case INFERENCE_BACKENDS.HTP: return 'npu'; + case INFERENCE_BACKENDS.OPENCL: return 'gpu'; + case INFERENCE_BACKENDS.METAL: return 'gpu'; + default: return 'cpu'; + } +} + +async function doLoadLiteRTModel(ctx: TextLoadContext): Promise { + const addDebugLog = useDebugLogsStore.getState().addLog; + try { + addDebugLog('log', `[LiteRT] Starting model load: ${ctx.model.fileName}`); + + if (ctx.loadedTextModelId && ctx.loadedTextModelId !== ctx.modelId) { + addDebugLog('log', '[LiteRT] Unloading previous LiteRT model before load.'); + try { + await liteRTService.unloadModel(); + } catch (unloadErr) { + logger.warn('[LiteRT] Error unloading previous model, continuing:', unloadErr); + addDebugLog('warn', `[LiteRT] Previous model unload warning: ${String(unloadErr)}`); + } + ctx.onError(); + } + + const preferredBackend = inferenceBackendToLiteRT(ctx.store.settings.inferenceBackend); + addDebugLog('log', `[LiteRT] Preferred backend: ${preferredBackend}`); + + const timeoutMs = preferredBackend === 'npu' ? 45_000 + : preferredBackend === 'gpu' ? 20_000 + : 15_000; + + let timeoutId: ReturnType | null = null; + const timeoutPromise = new Promise((_, reject) => { + timeoutId = setTimeout( + () => reject(new Error(`LiteRT model load timed out after ${timeoutMs / 1000}s.`)), + timeoutMs, + ); + }); + + try { + addDebugLog('log', `[LiteRT] Calling liteRTService.loadModel (timeout ${timeoutMs / 1000}s, vision=${ctx.model.liteRTVision ?? false}).`); + await Promise.race([ + liteRTService.loadModel(ctx.model.filePath, preferredBackend, ctx.model.liteRTVision ?? 
false), + timeoutPromise, + ]); + } finally { + if (timeoutId !== null) clearTimeout(timeoutId); + } + + const actualBackend = liteRTService.getActiveBackend(); + addDebugLog('log', `[LiteRT] Load complete — actual backend: ${actualBackend}`); + if (actualBackend !== preferredBackend) { + addDebugLog('warn', `[LiteRT] Requested ${preferredBackend}, fell back to ${actualBackend}`); + } + + ctx.onLoaded(ctx.modelId); + ctx.store.setActiveModelId(ctx.modelId); + } catch (error) { + const message = error instanceof Error ? error.message : String(error); + addDebugLog('error', `[LiteRT] Model load failed: ${message}`); + ctx.onError(); + throw error; + } finally { + ctx.onFinally(); + } +} + export async function doLoadTextModel(ctx: TextLoadContext): Promise { + // Route LiteRT models to the LiteRT loader — existing llama path is untouched below + if (ctx.model.engine === 'litert') { + return doLoadLiteRTModel(ctx); + } + const addDebugLog = useDebugLogsStore.getState().addLog; try { addDebugLog('log', `[Reload] Starting text model load: ${ctx.model.fileName}`); diff --git a/src/services/generationServiceHelpers.ts b/src/services/generationServiceHelpers.ts index 4a090c399..63c67291f 100644 --- a/src/services/generationServiceHelpers.ts +++ b/src/services/generationServiceHelpers.ts @@ -3,7 +3,9 @@ * All functions receive the GenerationService instance as `svc: any` and mutate its internal state. 
*/ import { llmService } from './llm'; +import { liteRTService } from './litert'; import { useAppStore, useChatStore, useRemoteServerStore } from '../stores'; +import { useDebugLogsStore } from '../stores/debugLogsStore'; import type { Message, GenerationMeta } from '../types'; import { runToolLoop } from './generationToolLoop'; import type { ToolResult } from './tools/types'; @@ -13,6 +15,13 @@ import logger from '../utils/logger'; export const FLUSH_INTERVAL_MS = 50; // ~20 updates/sec type StreamChunk = string | { content?: string; reasoningContent?: string }; +/** Returns true when the currently active model uses LiteRT engine. */ +function isLiteRTActive(): boolean { + const { downloadedModels, activeModelId } = useAppStore.getState(); + const activeModel = downloadedModels.find((m: any) => m.id === activeModelId); + return activeModel?.engine === 'litert'; +} + export interface GenerationRequest { conversationId: string; messages: Message[]; @@ -114,6 +123,8 @@ export async function prepareGenerationImpl(svc: any, conversationId: string): P if (!svc.state.isGenerating) return false; // stop called during drain svc.abortRequested = false; + const dbg = useDebugLogsStore.getState().addLog; + // Check provider readiness const failPrepare = (msg: string) => { svc.resetState(); @@ -125,6 +136,10 @@ export async function prepareGenerationImpl(svc: any, conversationId: string): P if (!provider) failPrepare('Remote provider not found'); const ready = await provider.isReady(); if (!ready) failPrepare('Remote provider not ready'); + } else if (isLiteRTActive()) { + const loaded = liteRTService.isModelLoaded(); + dbg('log', `[LiteRT] prepareGeneration — isLoaded=${loaded}`); + if (!loaded) { dbg('error', '[LiteRT] prepareGeneration failed: no model loaded'); failPrepare('No LiteRT model loaded'); } } else { if (!llmService.isModelLoaded()) failPrepare('No model loaded'); if (llmService.isCurrentlyGenerating()) failPrepare('LLM service busy'); @@ -146,6 +161,82 @@ export 
async function generateResponseImpl( const chatStore = useChatStore.getState(); let firstTokenReceived = false; + // LiteRT path — stateful engine, send only current user message + if (isLiteRTActive()) { + const dbg = useDebugLogsStore.getState().addLog; + const lastUser = [...messages].reverse().find(m => m.role === 'user'); + if (!lastUser) { + dbg('warn', '[LiteRT] generateResponse: no user message found'); + chatStore.clearStreamingMessage(); + svc.resetState(); + return; + } + const systemMsg = messages.find(m => m.role === 'system'); + const systemPrompt = typeof systemMsg?.content === 'string' ? systemMsg.content : ''; + const imageAttachment = lastUser.attachments?.find((a: any) => a.type === 'image'); + const imageUri = imageAttachment?.uri as string | undefined; + dbg('log', `[LiteRT] generateResponse — hasImage=${!!imageUri}, systemPrompt length=${systemPrompt.length}, userText length=${typeof lastUser.content === 'string' ? lastUser.content.length : 0}`); + + try { + // MVP: always reset conversation before each generation (safe + correct for all flows) + dbg('log', '[LiteRT] resetting conversation'); + await liteRTService.resetConversation(systemPrompt); + dbg('log', `[LiteRT] sendMessage start — imageUri=${imageUri ?? 'none'}`); + + await liteRTService.sendMessage( + typeof lastUser.content === 'string' ? 
lastUser.content : '', + { + onToken: (token: string) => { + if (svc.abortRequested) return; + if (!firstTokenReceived) { + firstTokenReceived = true; + svc.updateState({ isThinking: false }); + onFirstToken?.(); + dbg('log', '[LiteRT] first token received'); + } + svc.state.streamingContent += token; + svc.tokenBuffer += token; + if (!svc.flushTimer) { + svc.flushTimer = setTimeout(() => svc.flushTokenBuffer(), FLUSH_INTERVAL_MS); + } + }, + onReasoning: (token: string) => { + if (svc.abortRequested) return; + svc.reasoningBuffer += token; + }, + onComplete: (_content: string, _reasoning: string) => { + if (svc.abortRequested) return; + svc.forceFlushTokens(); + const generationTime = svc.state.startTime ? Date.now() - svc.state.startTime : undefined; + dbg('log', `[LiteRT] generation complete — ${generationTime}ms, tokens=${svc.state.streamingContent.length} chars`); + chatStore.finalizeStreamingMessage(conversationId, generationTime, buildGenerationMetaImpl(svc)); + svc.checkSharePrompt(); + svc.resetState(); + }, + onError: (err: Error) => { + if (svc.abortRequested) return; + dbg('error', `[LiteRT] sendMessage error: ${err.message}`); + if (svc.flushTimer) { clearTimeout(svc.flushTimer); svc.flushTimer = null; } + svc.tokenBuffer = ''; + chatStore.clearStreamingMessage(); + svc.resetState(); + }, + }, + imageUri, + ); + } catch (error: any) { + if (svc.abortRequested) return; + dbg('error', `[LiteRT] generateResponse caught: ${error?.message ?? 
error}`); + if (svc.flushTimer) { clearTimeout(svc.flushTimer); svc.flushTimer = null; } + svc.tokenBuffer = ''; + chatStore.clearStreamingMessage(); + svc.resetState(); + throw error; + } + return; + } + + // llama.cpp path — unchanged try { await llmService.generateResponse( messages, diff --git a/src/services/litert.ts b/src/services/litert.ts new file mode 100644 index 000000000..e49b40877 --- /dev/null +++ b/src/services/litert.ts @@ -0,0 +1,213 @@ +/** + * LiteRTService — JS bridge to the native LiteRTModule (Android). + * + * Architecture notes: + * - The native Conversation object holds turn history internally. + * JS sends only the current user message via sendMessage(). + * - Call resetConversation() before each generation (MVP approach). + * This is safe and correct for all flows including retry/edit/switch. + * - onComplete receives fully accumulated content, not an empty string. + */ + +import { NativeModules, NativeEventEmitter, Platform, EmitterSubscription } from 'react-native'; +import logger from '../utils/logger'; + +const TAG = '[LiteRTService]'; + +const { LiteRTModule } = NativeModules; + +// Events emitted by the native module +const EVENT_TOKEN = 'litert_token'; +const EVENT_THINKING = 'litert_thinking'; +const EVENT_COMPLETE = 'litert_complete'; +const EVENT_ERROR = 'litert_error'; + +export type LiteRTBackend = 'cpu' | 'gpu' | 'npu'; + +export interface LiteRTGenerationCallbacks { + onToken: (token: string) => void; + onReasoning: (token: string) => void; + onComplete: (fullContent: string, fullReasoning: string) => void; + onError: (error: Error) => void; +} + +class LiteRTService { + private loaded = false; + private activeBackend: LiteRTBackend | null = null; + private emitter: NativeEventEmitter | null = null; + private subscriptions: EmitterSubscription[] = []; + + // Accumulated content for current generation + private currentContent = ''; + private currentReasoning = ''; + private currentCallbacks: LiteRTGenerationCallbacks | null = 
null; + + constructor() { + if (Platform.OS === 'android' && LiteRTModule) { + this.emitter = new NativeEventEmitter(LiteRTModule); + logger.log(TAG, 'initialized — native module available'); + } else { + logger.log(TAG, 'native module not available on this platform'); + } + } + + // --------------------------------------------------------------------------- + // loadModel + // --------------------------------------------------------------------------- + + async loadModel(modelPath: string, preferredBackend: LiteRTBackend, supportsVision = false): Promise { + if (!this.isAvailable()) { + throw new Error('LiteRT is not available on this platform'); + } + + logger.log(TAG, `loadModel — path=${modelPath} backend=${preferredBackend} supportsVision=${supportsVision}`); + + try { + const actualBackend: string = await LiteRTModule.loadModel(modelPath, preferredBackend, supportsVision); + this.activeBackend = actualBackend as LiteRTBackend; + this.loaded = true; + logger.log(TAG, `loadModel — loaded on ${this.activeBackend}`); + } catch (e) { + this.loaded = false; + this.activeBackend = null; + logger.log(TAG, `loadModel — failed: ${String(e)}`); + throw e; + } + } + + // --------------------------------------------------------------------------- + // resetConversation — cheap: closes + recreates Conversation, Engine stays + // --------------------------------------------------------------------------- + + async resetConversation(systemPrompt: string): Promise { + if (!this.isAvailable() || !this.loaded) { + throw new Error('No LiteRT model loaded'); + } + logger.log(TAG, `resetConversation — systemPrompt length=${systemPrompt.length}`); + await LiteRTModule.resetConversation(systemPrompt); + logger.log(TAG, 'resetConversation — done'); + } + + // --------------------------------------------------------------------------- + // sendMessage — sends current turn only, library holds history + // --------------------------------------------------------------------------- + + 
async sendMessage( + text: string, + callbacks: LiteRTGenerationCallbacks, + imageUri?: string, + ): Promise { + if (!this.isAvailable() || !this.loaded) { + callbacks.onError(new Error('No LiteRT model loaded')); + return; + } + + logger.log(TAG, `sendMessage — text length=${text.length}`); + + // Reset accumulators + this.currentContent = ''; + this.currentReasoning = ''; + this.currentCallbacks = callbacks; + + // Register event listeners for this generation + this.clearSubscriptions(); + this.subscriptions = [ + this.emitter!.addListener(EVENT_TOKEN, (token: string) => { + this.currentContent += token; + callbacks.onToken(token); + }), + this.emitter!.addListener(EVENT_THINKING, (token: string) => { + this.currentReasoning += token; + callbacks.onReasoning(token); + }), + this.emitter!.addListener(EVENT_COMPLETE, () => { + logger.log(TAG, `sendMessage — complete, content=${this.currentContent.length} chars`); + this.clearSubscriptions(); + this.currentCallbacks = null; + callbacks.onComplete(this.currentContent, this.currentReasoning); + }), + this.emitter!.addListener(EVENT_ERROR, (message: string) => { + logger.log(TAG, `sendMessage — error: ${message}`); + this.clearSubscriptions(); + this.currentCallbacks = null; + callbacks.onError(new Error(message)); + }), + ]; + + try { + await LiteRTModule.sendMessage(text, imageUri ?? null); + } catch (e) { + this.clearSubscriptions(); + this.currentCallbacks = null; + const err = e instanceof Error ? 
e : new Error(String(e)); + logger.log(TAG, `sendMessage — native error: ${err.message}`); + callbacks.onError(err); + } + } + + // --------------------------------------------------------------------------- + // stopGeneration + // --------------------------------------------------------------------------- + + async stopGeneration(): Promise { + if (!this.isAvailable()) return; + logger.log(TAG, 'stopGeneration'); + this.clearSubscriptions(); + this.currentCallbacks = null; + try { + await LiteRTModule.stopGeneration(); + } catch (e) { + logger.log(TAG, `stopGeneration — error (ignored): ${String(e)}`); + } + } + + // --------------------------------------------------------------------------- + // unloadModel — expensive: closes Conversation + Engine + // --------------------------------------------------------------------------- + + async unloadModel(): Promise { + if (!this.isAvailable()) return; + logger.log(TAG, 'unloadModel'); + this.clearSubscriptions(); + this.currentCallbacks = null; + try { + await LiteRTModule.unloadModel(); + } catch (e) { + logger.log(TAG, `unloadModel — error (ignored): ${String(e)}`); + } finally { + this.loaded = false; + this.activeBackend = null; + } + } + + // --------------------------------------------------------------------------- + // State queries + // --------------------------------------------------------------------------- + + isModelLoaded(): boolean { + return this.loaded; + } + + isNPU(): boolean { + return this.activeBackend === 'npu'; + } + + getActiveBackend(): LiteRTBackend | null { + return this.activeBackend; + } + + isAvailable(): boolean { + return Platform.OS === 'android' && !!LiteRTModule; + } + + // --------------------------------------------------------------------------- + // Internal helpers + // --------------------------------------------------------------------------- + + private clearSubscriptions(): void { + this.subscriptions.forEach(s => s.remove()); + this.subscriptions = []; + } +} + +export 
const liteRTService = new LiteRTService(); diff --git a/src/services/modelManager/copyFile.ts b/src/services/modelManager/copyFile.ts new file mode 100644 index 000000000..03d5236d0 --- /dev/null +++ b/src/services/modelManager/copyFile.ts @@ -0,0 +1,52 @@ +import RNFS from 'react-native-fs'; + +type CopyProgressOpts = { knownTotalBytes: number | null; onProgress?: (fraction: number) => void }; + +function parseSizeInt(size: string | number): number { + return typeof size === 'string' ? Number.parseInt(size, 10) : size; +} + +export async function copyFileWithProgress( + source: string, + dest: string, + { knownTotalBytes, onProgress }: CopyProgressOpts, +): Promise { + let totalBytes = knownTotalBytes ?? 0; + if (totalBytes === 0) { + try { + const sourceStat = await RNFS.stat(source); + totalBytes = parseSizeInt(sourceStat.size); + } catch { + // stat failed — progress will be indeterminate (stuck at 0%), non-fatal + } + } + + let polling = true; + + const pollInterval = setInterval(async () => { + if (!polling) return; + try { + const exists = await RNFS.exists(dest); + if (exists && totalBytes > 0) { + const stat = await RNFS.stat(dest); + const written = parseSizeInt(stat.size); + const pct = Math.min(written / totalBytes, 0.99); + onProgress?.(pct); + } + } catch { + // poll errors are non-fatal + } + }, 500); + + try { + await RNFS.copyFile(source, dest); + polling = false; + clearInterval(pollInterval); + onProgress?.(1); + } catch (error) { + polling = false; + clearInterval(pollInterval); + await RNFS.unlink(dest).catch(() => {}); + throw error; + } +} diff --git a/src/services/modelManager/scan.ts b/src/services/modelManager/scan.ts index b0763d630..2791a13a3 100644 --- a/src/services/modelManager/scan.ts +++ b/src/services/modelManager/scan.ts @@ -2,6 +2,7 @@ import RNFS from 'react-native-fs'; import { unzip } from 'react-native-zip-archive'; import { DownloadedModel, ModelFile, ONNXImageModel } from '../../types'; import { buildDownloadedModel, 
persistDownloadedModel, loadDownloadedModels, saveModelsList } from './storage'; +import { copyFileWithProgress } from './copyFile'; import { resolveCoreMLModelDir } from '../../utils/coreMLModelUtils'; export function isMMProjFile(fileName: string): boolean { @@ -376,6 +377,8 @@ export interface ImportLocalModelOpts { fileName: string; modelsDir: string; sourceSize?: number | null; + engine?: import('../../types').ModelEngine; + liteRTVision?: boolean; onProgress?: (progress: { fraction: number; fileName: string }) => void; mmProjSourceUri?: string; mmProjFileName?: string; @@ -393,10 +396,11 @@ function resolveUri(uri: string): string { export async function importLocalModel(opts: ImportLocalModelOpts): Promise { // NOSONAR - const { sourceUri, fileName, modelsDir, sourceSize, onProgress, mmProjSourceUri, mmProjFileName, mmProjSourceSize } = opts; + const { sourceUri, fileName, modelsDir, sourceSize, engine, liteRTVision, onProgress, mmProjSourceUri, mmProjFileName, mmProjSourceSize } = opts; - if (!fileName.toLowerCase().endsWith('.gguf')) { - throw new Error('Only .gguf files can be imported'); + const isLitert = fileName.toLowerCase().endsWith('.litertlm'); + if (!fileName.toLowerCase().endsWith('.gguf') && !isLitert) { + throw new Error('Only .gguf and .litertlm files can be imported'); } const resolvedSource = resolveUri(sourceUri); @@ -418,7 +422,7 @@ export async function importLocalModel(opts: ImportLocalModelOpts): Promise void }; - -async function copyFileWithProgress( - source: string, - dest: string, - { knownTotalBytes, onProgress }: CopyProgressOpts, -): Promise { - let totalBytes = knownTotalBytes ?? 
0; - if (totalBytes === 0) { - try { - const sourceStat = await RNFS.stat(source); - totalBytes = parseSizeInt(sourceStat.size); - } catch { - // stat failed — progress will be indeterminate (stuck at 0%), non-fatal - } - } - - let polling = true; - - const pollInterval = setInterval(async () => { - if (!polling) return; - try { - const exists = await RNFS.exists(dest); - if (exists && totalBytes > 0) { - const stat = await RNFS.stat(dest); - const written = parseSizeInt(stat.size); - const pct = Math.min(written / totalBytes, 0.99); - onProgress?.(pct); - } - } catch { - // poll errors are non-fatal - } - }, 500); - - try { - await RNFS.copyFile(source, dest); - polling = false; - clearInterval(pollInterval); - onProgress?.(1); - } catch (error) { - polling = false; - clearInterval(pollInterval); - await RNFS.unlink(dest).catch(() => {}); - throw error; - } -} diff --git a/src/stores/debugLogsStore.ts b/src/stores/debugLogsStore.ts index cb7c00f16..d0b30b621 100644 --- a/src/stores/debugLogsStore.ts +++ b/src/stores/debugLogsStore.ts @@ -1,4 +1,8 @@ import { create } from 'zustand'; +import AsyncStorage from '@react-native-async-storage/async-storage'; + +const STORAGE_KEY = '@debug_logs'; +const MAX_LOGS = 200; export interface DebugLogEntry { timestamp: number; @@ -8,16 +12,40 @@ export interface DebugLogEntry { interface DebugLogsState { logs: DebugLogEntry[]; + loaded: boolean; addLog: (level: 'log' | 'warn' | 'error', message: string) => void; clearLogs: () => void; + loadFromStorage: () => Promise; } -export const useDebugLogsStore = create((set) => ({ +export const useDebugLogsStore = create((set, get) => ({ logs: [], - addLog: (level, message) => - set((state) => ({ - // Keep last 200 logs for memory efficiency - logs: [...state.logs, { timestamp: Date.now(), level, message }].slice(-200), - })), - clearLogs: () => set({ logs: [] }), + loaded: false, + addLog: (level, message) => { + const entry: DebugLogEntry = { timestamp: Date.now(), level, message }; + 
set((state) => { + const updated = [...state.logs, entry].slice(-MAX_LOGS); + // Fire-and-forget persist — don't await so addLog stays synchronous + AsyncStorage.setItem(STORAGE_KEY, JSON.stringify(updated)).catch(() => {}); + return { logs: updated }; + }); + }, + clearLogs: () => { + set({ logs: [] }); + AsyncStorage.removeItem(STORAGE_KEY).catch(() => {}); + }, + loadFromStorage: async () => { + if (get().loaded) return; + try { + const raw = await AsyncStorage.getItem(STORAGE_KEY); + if (raw) { + const logs: DebugLogEntry[] = JSON.parse(raw); + set({ logs, loaded: true }); + } else { + set({ loaded: true }); + } + } catch { + set({ loaded: true }); + } + }, })); diff --git a/src/types/index.ts b/src/types/index.ts index ffcc52b57..dfe1040c7 100644 --- a/src/types/index.ts +++ b/src/types/index.ts @@ -41,6 +41,8 @@ export interface ModelFile { }; } +export type ModelEngine = 'llama' | 'litert'; + export interface DownloadedModel { id: string; name: string; @@ -56,6 +58,10 @@ export interface DownloadedModel { mmProjPath?: string; mmProjFileName?: string; mmProjFileSize?: number; + // Inference engine — undefined means 'llama' (backwards compatible) + engine?: ModelEngine; + // LiteRT-specific: whether this model's vision encoder is compatible with visionBackend=GPU + liteRTVision?: boolean; } export interface PersistedDownloadInfo { From dcf40c813fab363b697c2482ce6c85d401983719 Mon Sep 17 00:00:00 2001 From: Dishit Date: Tue, 19 May 2026 15:08:21 +0530 Subject: [PATCH 02/93] fix(android): add BenchmarkInfo stats, getMemoryInfo, and sampler params to LiteRTModule Co-Authored-By: Dishit Karia --- .../ai/offgridmobile/litert/LiteRTModule.kt | 62 +++++++++++++++++-- android/gradle.properties | 1 - 2 files changed, 58 insertions(+), 5 deletions(-) diff --git a/android/app/src/main/java/ai/offgridmobile/litert/LiteRTModule.kt b/android/app/src/main/java/ai/offgridmobile/litert/LiteRTModule.kt index eee459d3f..79f88d86f 100644 --- 
a/android/app/src/main/java/ai/offgridmobile/litert/LiteRTModule.kt +++ b/android/app/src/main/java/ai/offgridmobile/litert/LiteRTModule.kt @@ -1,15 +1,20 @@ package ai.offgridmobile.litert import android.util.Log +import android.app.ActivityManager +import android.content.Context +import android.os.Debug import com.facebook.react.bridge.* import com.facebook.react.modules.core.DeviceEventManagerModule import ai.offgridmobile.SafePromise import com.google.ai.edge.litertlm.Backend +import com.google.ai.edge.litertlm.BenchmarkInfo import com.google.ai.edge.litertlm.ConversationConfig import com.google.ai.edge.litertlm.Engine import com.google.ai.edge.litertlm.EngineConfig import com.google.ai.edge.litertlm.Content import com.google.ai.edge.litertlm.Contents +import com.google.ai.edge.litertlm.ExperimentalApi import com.google.ai.edge.litertlm.SamplerConfig import kotlinx.coroutines.* import java.io.File @@ -129,9 +134,9 @@ class LiteRTModule(private val reactContext: ReactApplicationContext) : // ------------------------------------------------------------------------- @ReactMethod - fun resetConversation(systemPrompt: String, promise: Promise) { + fun resetConversation(systemPrompt: String, temperature: Double, topK: Int, topP: Double, promise: Promise) { val safe = SafePromise(promise, TAG) - Log.i(TAG, "resetConversation — systemPrompt length=${systemPrompt.length}") + Log.i(TAG, "resetConversation — systemPrompt length=${systemPrompt.length} temperature=$temperature topK=$topK topP=$topP") scope.launch { try { @@ -150,7 +155,11 @@ class LiteRTModule(private val reactContext: ReactApplicationContext) : Log.i(TAG, "resetConversation — NPU backend, skipping SamplerConfig") null } else { - SamplerConfig(topK = 40, topP = 0.95, temperature = 0.8) + SamplerConfig( + topK = topK, + topP = topP, + temperature = temperature, + ) } val convConfig = ConversationConfig( @@ -230,7 +239,15 @@ class LiteRTModule(private val reactContext: ReactApplicationContext) : } } 
Log.i(TAG, "sendMessage — generation complete") - sendEvent(EVENT_COMPLETE, "") + @OptIn(ExperimentalApi::class) + val benchmarkJson = try { + val b = conv.getBenchmarkInfo() + """{"ttft":${b.timeToFirstTokenInSecond},"decodeTokensPerSecond":${b.lastDecodeTokensPerSecond},"prefillTokensPerSecond":${b.lastPrefillTokensPerSecond},"prefillTokenCount":${b.lastPrefillTokenCount}}""" + } catch (e: Exception) { + Log.w(TAG, "getBenchmarkInfo failed: ${e.message}") + "" + } + sendEvent(EVENT_COMPLETE, benchmarkJson) safe.resolve(null) } catch (e: CancellationException) { Log.i(TAG, "sendMessage — job cancelled") @@ -335,6 +352,43 @@ class LiteRTModule(private val reactContext: ReactApplicationContext) : } } + // ------------------------------------------------------------------------- + // getMemoryInfo — live RAM usage + process GPU memory + // ------------------------------------------------------------------------- + + @ReactMethod + fun getMemoryInfo(promise: Promise) { + val safe = SafePromise(promise, TAG) + try { + val am = reactContext.getSystemService(Context.ACTIVITY_SERVICE) as ActivityManager + + // System RAM + val ramInfo = ActivityManager.MemoryInfo() + am.getMemoryInfo(ramInfo) + val totalRamMb = ramInfo.totalMem / (1024 * 1024) + val availRamMb = ramInfo.availMem / (1024 * 1024) + val usedRamMb = totalRamMb - availRamMb + + // Process GPU memory (graphics + GL textures) via Debug.MemoryInfo + val memInfo = Debug.MemoryInfo() + Debug.getMemoryInfo(memInfo) + val gpuPrivateMb = try { + (memInfo.getMemoryStat("summary.graphics") ?: "0").toLong() / 1024 + } catch (e: Exception) { 0L } + + val result = Arguments.createMap().apply { + putDouble("totalRamMb", totalRamMb.toDouble()) + putDouble("usedRamMb", usedRamMb.toDouble()) + putDouble("availRamMb", availRamMb.toDouble()) + putDouble("gpuPrivateMb", gpuPrivateMb.toDouble()) + putBoolean("lowMemory", ramInfo.lowMemory) + } + safe.resolve(result) + } catch (e: Exception) { + safe.reject("MEM_ERROR", "Failed to 
get memory info: ${e.message}", e) + } + } + private fun parseBackend(s: String): Backend = when (s.lowercase()) { "npu", "htp" -> Backend.NPU( nativeLibraryDir = reactContext.applicationInfo.nativeLibraryDir diff --git a/android/gradle.properties b/android/gradle.properties index c9f8970b5..9afe61598 100644 --- a/android/gradle.properties +++ b/android/gradle.properties @@ -11,7 +11,6 @@ # The setting is particularly useful for tweaking memory settings. # Default value: -Xmx512m -XX:MaxMetaspaceSize=256m org.gradle.jvmargs=-Xmx2048m -XX:MaxMetaspaceSize=512m -org.gradle.java.home=/Library/Java/JavaVirtualMachines/temurin-21.jdk/Contents/Home # When configured, Gradle will run in incubating parallel mode. # This option should only be used with decoupled projects. More details, visit From a8aa88bd0ff1bb66296c3ab7b760ab6da8673f39 Mon Sep 17 00:00:00 2001 From: Dishit Date: Tue, 19 May 2026 15:08:24 +0530 Subject: [PATCH 03/93] fix(litert): production fixes - stopGeneration, multi-turn tracking, memory budget, BenchmarkInfo wiring Co-Authored-By: Dishit Karia --- src/services/activeModelService/memory.ts | 5 +- src/services/activeModelService/utils.ts | 3 +- src/services/generationService.ts | 18 ++++-- src/services/generationServiceHelpers.ts | 59 ++++++++++++++--- src/services/litert.ts | 78 +++++++++++++++++++++-- 5 files changed, 142 insertions(+), 21 deletions(-) diff --git a/src/services/activeModelService/memory.ts b/src/services/activeModelService/memory.ts index ea56500fe..0fc29abd5 100644 --- a/src/services/activeModelService/memory.ts +++ b/src/services/activeModelService/memory.ts @@ -6,6 +6,7 @@ import { DownloadedModel, ONNXImageModel } from '../../types'; import { hardwareService } from '../hardware'; import { llmService } from '../llm'; +import { liteRTService } from '../litert'; import { ModelType, MemoryCheckResult, @@ -66,7 +67,7 @@ export function getCurrentlyLoadedMemoryGB( ): number { let totalGB = 0; - if (ids.loadedTextModelId && 
llmService.isModelLoaded()) { + if (ids.loadedTextModelId && (llmService.isModelLoaded() || liteRTService.isModelLoaded())) { const textModel = lists.downloadedModels.find(m => m.id === ids.loadedTextModelId); if (textModel) { totalGB += estimateModelMemoryGB(textModel, 'text'); @@ -100,7 +101,7 @@ export function getOtherLoadedMemoryGB( totalGB += estimateModelMemoryGB(imageModel, 'image'); } } - if (modelType === 'image' && ids.loadedTextModelId && llmService.isModelLoaded()) { + if (modelType === 'image' && ids.loadedTextModelId && (llmService.isModelLoaded() || liteRTService.isModelLoaded())) { const textModel = lists.downloadedModels.find(m => m.id === ids.loadedTextModelId); if (textModel) { totalGB += estimateModelMemoryGB(textModel, 'text'); diff --git a/src/services/activeModelService/utils.ts b/src/services/activeModelService/utils.ts index 114ae049a..22733b7d3 100644 --- a/src/services/activeModelService/utils.ts +++ b/src/services/activeModelService/utils.ts @@ -5,6 +5,7 @@ import { useAppStore } from '../../stores'; import { hardwareService } from '../hardware'; import { llmService } from '../llm'; +import { liteRTService } from '../litert'; import { localDreamGeneratorService as onnxImageGeneratorService } from '../localDreamGenerator'; import { ResourceUsage } from './types'; @@ -46,7 +47,7 @@ export interface SyncStateTarget { export async function syncWithNativeState(target: SyncStateTarget): Promise { const store = useAppStore.getState(); - const textModelLoaded = llmService.isModelLoaded(); + const textModelLoaded = llmService.isModelLoaded() || liteRTService.isModelLoaded(); if (!textModelLoaded) { target.setLoadedTextModelId(null); } else if (!target.loadedTextModelId && store.activeModelId) { diff --git a/src/services/generationService.ts b/src/services/generationService.ts index a732b370d..7faec9b15 100644 --- a/src/services/generationService.ts +++ b/src/services/generationService.ts @@ -1,5 +1,6 @@ /** GenerationService - Handles LLM 
generation independently of UI lifecycle */ import { llmService } from './llm'; +import { liteRTService } from './litert'; import { useAppStore, useChatStore, useRemoteServerStore } from '../stores'; import { Message, GenerationMeta, MediaAttachment } from '../types'; import { runToolLoop } from './generationToolLoop'; @@ -205,8 +206,9 @@ class GenerationService { /** Stop the current generation. Returns partial content if any was generated. */ async stopGeneration(): Promise { if (!this.state.isGenerating) { - // Stop both local and remote + // Stop all engines and remote await llmService.stopGeneration().catch(() => { }); + await liteRTService.stopGeneration().catch(() => { }); const provider = this.getCurrentProvider(); if (provider) provider.stopGeneration().catch(() => { }); if (this.currentRemoteAbortController) { @@ -252,9 +254,17 @@ class GenerationService { // Stop the native completion after we've already updated UI state, // so the user sees immediate feedback. Store the promise so new // generations can drain it before starting. 
- this.pendingStop = llmService.stopGeneration().catch(() => { }).finally(() => { - this.pendingStop = null; - }); + const { downloadedModels, activeModelId } = useAppStore.getState(); + const activeModel = downloadedModels.find((m: any) => m.id === activeModelId); + if (activeModel?.engine === 'litert') { + this.pendingStop = liteRTService.stopGeneration().catch(() => { }).finally(() => { + this.pendingStop = null; + }); + } else { + this.pendingStop = llmService.stopGeneration().catch(() => { }).finally(() => { + this.pendingStop = null; + }); + } return streamingContent; } diff --git a/src/services/generationServiceHelpers.ts b/src/services/generationServiceHelpers.ts index 63c67291f..0faed2375 100644 --- a/src/services/generationServiceHelpers.ts +++ b/src/services/generationServiceHelpers.ts @@ -60,13 +60,41 @@ export function buildGenerationMetaImpl(svc: any): GenerationMeta { }; } - // Local provider metadata + const { downloadedModels, activeModelId, settings } = useAppStore.getState(); + const modelName = downloadedModels.find((m: any) => m.id === activeModelId)?.name; + + // LiteRT path — use real BenchmarkInfo stats if available, else estimate + if (isLiteRTActive()) { + const backend = liteRTService.getActiveBackend() ?? 'cpu'; + const stats = svc.liteRTBenchmarkStats; + if (stats) { + return { + gpu: backend !== 'cpu', + gpuBackend: backend.toUpperCase(), + modelName, + tokensPerSecond: stats.decodeTokensPerSecond, + timeToFirstToken: stats.ttft * 1000, + tokenCount: stats.prefillTokenCount, + }; + } + // Fallback estimate if BenchmarkInfo unavailable + const generationTime = svc.state.startTime ? (Date.now() - svc.state.startTime) / 1000 : 0; + const estimatedTokens = Math.ceil(svc.state.streamingContent.length / 4); + return { + gpu: backend !== 'cpu', + gpuBackend: backend.toUpperCase(), + modelName, + tokenCount: estimatedTokens, + tokensPerSecond: generationTime > 0 ? 
estimatedTokens / generationTime : undefined, + }; + } + + // llama.cpp path — real perf data from native engine const { gpu, gpuBackend, gpuLayers } = llmService.getGpuInfo(); const perf = llmService.getPerformanceStats(); - const { downloadedModels, activeModelId, settings } = useAppStore.getState(); return { gpu, gpuBackend, gpuLayers, - modelName: downloadedModels.find((m: any) => m.id === activeModelId)?.name, + modelName, tokensPerSecond: perf.lastTokensPerSecond, decodeTokensPerSecond: perf.lastDecodeTokensPerSecond, timeToFirstToken: perf.lastTimeToFirstToken, @@ -177,10 +205,24 @@ export async function generateResponseImpl( const imageUri = imageAttachment?.uri as string | undefined; dbg('log', `[LiteRT] generateResponse — hasImage=${!!imageUri}, systemPrompt length=${systemPrompt.length}, userText length=${typeof lastUser.content === 'string' ? lastUser.content.length : 0}`); + // Guard: image attached but model was not imported with vision support + if (imageUri) { + const { downloadedModels, activeModelId } = useAppStore.getState(); + const activeModel = downloadedModels.find((m: any) => m.id === activeModelId); + if (!activeModel?.liteRTVision) { + dbg('warn', '[LiteRT] Image attached but model does not support vision — aborting'); + chatStore.clearStreamingMessage(); + svc.resetState(); + throw new Error('This model does not support images. Import it with vision enabled, or remove the image.'); + } + } + try { - // MVP: always reset conversation before each generation (safe + correct for all flows) - dbg('log', '[LiteRT] resetting conversation'); - await liteRTService.resetConversation(systemPrompt); + const { settings } = useAppStore.getState(); + await liteRTService.prepareConversation(conversationId, systemPrompt, { + temperature: settings.temperature, + topP: settings.topP, + }); dbg('log', `[LiteRT] sendMessage start — imageUri=${imageUri ?? 
'none'}`); await liteRTService.sendMessage( @@ -204,11 +246,12 @@ export async function generateResponseImpl( if (svc.abortRequested) return; svc.reasoningBuffer += token; }, - onComplete: (_content: string, _reasoning: string) => { + onComplete: (_content: string, _reasoning: string, stats) => { if (svc.abortRequested) return; svc.forceFlushTokens(); + svc.liteRTBenchmarkStats = stats; const generationTime = svc.state.startTime ? Date.now() - svc.state.startTime : undefined; - dbg('log', `[LiteRT] generation complete — ${generationTime}ms, tokens=${svc.state.streamingContent.length} chars`); + dbg('log', `[LiteRT] generation complete — ${generationTime}ms, tok/s=${stats?.decodeTokensPerSecond?.toFixed(1) ?? 'n/a'}`); chatStore.finalizeStreamingMessage(conversationId, generationTime, buildGenerationMetaImpl(svc)); svc.checkSharePrompt(); svc.resetState(); diff --git a/src/services/litert.ts b/src/services/litert.ts index e49b40877..d9cbc0450 100644 --- a/src/services/litert.ts +++ b/src/services/litert.ts @@ -24,10 +24,25 @@ const EVENT_ERROR = 'litert_error'; export type LiteRTBackend = 'cpu' | 'gpu' | 'npu'; +export interface LiteRTBenchmarkStats { + ttft: number; + decodeTokensPerSecond: number; + prefillTokensPerSecond: number; + prefillTokenCount: number; +} + +export interface LiteRTMemoryInfo { + totalRamMb: number; + usedRamMb: number; + availRamMb: number; + gpuPrivateMb: number; + lowMemory: boolean; +} + export interface LiteRTGenerationCallbacks { onToken: (token: string) => void; onReasoning: (token: string) => void; - onComplete: (fullContent: string, fullReasoning: string) => void; + onComplete: (fullContent: string, fullReasoning: string, stats?: LiteRTBenchmarkStats) => void; onError: (error: Error) => void; } @@ -42,6 +57,10 @@ class LiteRTService { private currentReasoning = ''; private currentCallbacks: LiteRTGenerationCallbacks | null = null; + // Multi-turn tracking — reset conversation only when context changes + private activeConversationId: 
string | null = null; + private activeSystemPrompt: string | null = null; + constructor() { if (Platform.OS === 'android' && LiteRTModule) { this.emitter = new NativeEventEmitter(LiteRTModule); @@ -79,15 +98,44 @@ class LiteRTService { // resetConversation — cheap: closes + recreates Conversation, Engine stays // --------------------------------------------------------------------------- - async resetConversation(systemPrompt: string): Promise { + async resetConversation( + systemPrompt: string, + samplerConfig?: { temperature?: number; topK?: number; topP?: number }, + ): Promise { if (!this.isAvailable() || !this.loaded) { throw new Error('No LiteRT model loaded'); } - logger.log(TAG, `resetConversation — systemPrompt length=${systemPrompt.length}`); - await LiteRTModule.resetConversation(systemPrompt); + const temperature = samplerConfig?.temperature ?? 0.8; + const topK = samplerConfig?.topK ?? 40; + const topP = samplerConfig?.topP ?? 0.95; + logger.log(TAG, `resetConversation — systemPrompt length=${systemPrompt.length} temperature=${temperature} topK=${topK} topP=${topP}`); + await LiteRTModule.resetConversation(systemPrompt, temperature, topK, topP); + this.activeSystemPrompt = systemPrompt; logger.log(TAG, 'resetConversation — done'); } + /** + * Ensure conversation is ready for the given context. + * Resets only when conversationId or systemPrompt has changed — preserves + * native turn history for follow-up messages in the same conversation. 
+ */ + async prepareConversation( + conversationId: string, + systemPrompt: string, + samplerConfig?: { temperature?: number; topK?: number; topP?: number }, + ): Promise { + const needsReset = + this.activeConversationId !== conversationId || + this.activeSystemPrompt !== systemPrompt; + if (needsReset) { + logger.log(TAG, `prepareConversation — reset (convId changed=${this.activeConversationId !== conversationId}, sysPrompt changed=${this.activeSystemPrompt !== systemPrompt})`); + await this.resetConversation(systemPrompt, samplerConfig); + this.activeConversationId = conversationId; + } else { + logger.log(TAG, 'prepareConversation — reusing existing conversation (multi-turn)'); + } + } + // --------------------------------------------------------------------------- // sendMessage — sends current turn only, library holds history // --------------------------------------------------------------------------- @@ -120,11 +168,15 @@ class LiteRTService { this.currentReasoning += token; callbacks.onReasoning(token); }), - this.emitter!.addListener(EVENT_COMPLETE, () => { + this.emitter!.addListener(EVENT_COMPLETE, (benchmarkJson: string) => { logger.log(TAG, `sendMessage — complete, content=${this.currentContent.length} chars`); this.clearSubscriptions(); this.currentCallbacks = null; - callbacks.onComplete(this.currentContent, this.currentReasoning); + let stats: LiteRTBenchmarkStats | undefined; + if (benchmarkJson) { + try { stats = JSON.parse(benchmarkJson); } catch { /* ignore parse errors */ } + } + callbacks.onComplete(this.currentContent, this.currentReasoning, stats); }), this.emitter!.addListener(EVENT_ERROR, (message: string) => { logger.log(TAG, `sendMessage — error: ${message}`); @@ -154,6 +206,8 @@ class LiteRTService { logger.log(TAG, 'stopGeneration'); this.clearSubscriptions(); this.currentCallbacks = null; + // After a stop the native conversation state is indeterminate — force reset on next turn + this.activeConversationId = null; try { await 
LiteRTModule.stopGeneration(); } catch (e) { @@ -170,6 +224,8 @@ class LiteRTService { logger.log(TAG, 'unloadModel'); this.clearSubscriptions(); this.currentCallbacks = null; + this.activeConversationId = null; + this.activeSystemPrompt = null; try { await LiteRTModule.unloadModel(); } catch (e) { @@ -200,6 +256,16 @@ class LiteRTService { return Platform.OS === 'android' && !!LiteRTModule; } + async getMemoryInfo(): Promise { + if (!this.isAvailable()) return null; + try { + return await LiteRTModule.getMemoryInfo(); + } catch (e) { + logger.log(TAG, `getMemoryInfo — error: ${String(e)}`); + return null; + } + } + // --------------------------------------------------------------------------- // Internal helpers // --------------------------------------------------------------------------- From 85f887df99949cea829b51096e08f4f0f5015491 Mon Sep 17 00:00:00 2001 From: Dishit Date: Tue, 19 May 2026 15:08:28 +0530 Subject: [PATCH 04/93] feat(ui): DeviceStatsChip, hide irrelevant LiteRT settings, backend reload trigger, iOS guard Co-Authored-By: Dishit Karia --- src/components/DeviceStatsChip.tsx | 104 ++++++++++++++++ src/screens/ChatScreen/index.tsx | 6 + src/screens/ChatScreen/useChatScreen.ts | 28 +++-- .../TextGenerationAdvanced.tsx | 116 +++++++++--------- .../TextGenerationSection.tsx | 86 +++++++------ src/screens/ModelsScreen/useModelsScreen.ts | 6 + src/screens/SettingsScreen.tsx | 21 ++-- 7 files changed, 252 insertions(+), 115 deletions(-) create mode 100644 src/components/DeviceStatsChip.tsx diff --git a/src/components/DeviceStatsChip.tsx b/src/components/DeviceStatsChip.tsx new file mode 100644 index 000000000..d85ace2b3 --- /dev/null +++ b/src/components/DeviceStatsChip.tsx @@ -0,0 +1,104 @@ +import React, { useEffect, useRef, useState } from 'react'; +import { View, Text, TouchableOpacity, StyleSheet } from 'react-native'; +import { liteRTService, LiteRTMemoryInfo } from '../services/litert'; +import { useTheme } from '../theme'; +import { SPACING, 
TYPOGRAPHY } from '../constants'; + +interface Props { + visible: boolean; + onPress?: () => void; +} + +export const DeviceStatsChip: React.FC = ({ visible, onPress }) => { + const { colors } = useTheme(); + const [mem, setMem] = useState(null); + const intervalRef = useRef | null>(null); + + useEffect(() => { + if (!visible || !liteRTService.isAvailable()) return; + + const refresh = async () => { + const info = await liteRTService.getMemoryInfo(); + if (info) setMem(info); + }; + + refresh(); + intervalRef.current = setInterval(refresh, 2000); + return () => { + if (intervalRef.current) clearInterval(intervalRef.current); + }; + }, [visible]); + + if (!visible || !liteRTService.isAvailable() || !mem) return null; + + const usedPct = Math.round((mem.usedRamMb / mem.totalRamMb) * 100); + const barColor = mem.lowMemory ? colors.error : usedPct > 80 ? '#FF9F0A' : colors.primary; + + return ( + + + RAM + + {Math.round(mem.usedRamMb / 1024 * 10) / 10} + / + {Math.round(mem.totalRamMb / 1024 * 10) / 10} + GB + + + {mem.gpuPrivateMb > 0 && ( + + GPU + + {mem.gpuPrivateMb} + MB + + + )} + {mem.lowMemory && ( + low mem + )} + + ); +}; + +const styles = StyleSheet.create({ + chip: { + position: 'absolute', + top: SPACING.xs, + right: SPACING.sm, + borderRadius: 8, + borderWidth: 1, + paddingHorizontal: SPACING.sm, + paddingVertical: SPACING.xs, + zIndex: 100, + minWidth: 90, + }, + row: { + flexDirection: 'row', + justifyContent: 'space-between', + alignItems: 'center', + gap: SPACING.xs, + }, + label: { + ...TYPOGRAPHY.meta, + fontSize: 10, + }, + value: { + ...TYPOGRAPHY.meta, + fontSize: 11, + fontVariant: ['tabular-nums'], + }, + unit: { + fontSize: 9, + }, + warn: { + ...TYPOGRAPHY.meta, + fontSize: 9, + textAlign: 'center', + marginTop: 1, + }, +}); diff --git a/src/screens/ChatScreen/index.tsx b/src/screens/ChatScreen/index.tsx index df93d9202..c2c9d6c92 100644 --- a/src/screens/ChatScreen/index.tsx +++ b/src/screens/ChatScreen/index.tsx @@ -4,6 +4,7 @@ import { 
SafeAreaView } from 'react-native-safe-area-context'; import { useFocusEffect } from '@react-navigation/native'; import { useSpotlightTour } from 'react-native-spotlight-tour'; import { CustomAlert, hideAlert, SharePromptSheet } from '../../components'; +import { DeviceStatsChip } from '../../components/DeviceStatsChip'; import { consumePendingSpotlight } from '../../components/onboarding/spotlightState'; import { subscribeSharePrompt } from '../../utils/sharePrompt'; import { VOICE_HINT_STEP_INDEX, IMAGE_SETTINGS_STEP_INDEX } from '../../components/onboarding/spotlightConfig'; @@ -31,6 +32,7 @@ export const ChatScreen: React.FC = () => { const pendingNextRef = useRef(null); const [sharePromptVisible, setSharePromptVisible] = useState(false); + const [showStatsChip, setShowStatsChip] = useState(true); useEffect(() => subscribeSharePrompt(() => setSharePromptVisible(true)), []); // Only ONE AttachStep mounted at a time to avoid waypoint dots/lines. // chatSpotlight controls which index is active (3, 12, 15, or 16). 
@@ -187,6 +189,10 @@ export const ChatScreen: React.FC = () => { setShowLogsPanel={chat.setShowLogsPanel} isRemote={chat.activeModelInfo?.isRemote} /> + setShowStatsChip(v => !v)} + /> { }; // Check if there are pending settings that require model reload const hasPendingSettings = (() => { - // LiteRT manages its own backend internally — llama.cpp settings don't apply - if (activeModel?.engine === 'litert') return false; if (!loadedSettings) return false; + // LiteRT only reloads when the inference backend changes (cpu/gpu/npu) + if (activeModel?.engine === 'litert') { + return settings.inferenceBackend !== loadedSettings.inferenceBackend; + } return ( settings.nThreads !== loadedSettings.nThreads || settings.nBatch !== loadedSettings.nBatch || @@ -248,19 +251,22 @@ export const useChatScreen = () => { const handleReloadTextModel = useCallback(async () => { if (!activeModelInfo.modelId || activeModelInfo.isRemote) return; - // LiteRT manages its own backend — reloading via llama.cpp path is wrong - if (activeModel?.engine === 'litert') return; - // Open the model selector bottom sheet before unloading so the user sees the - // loading state inside it rather than the NoModelScreen ("Select Model"). setShowModelSelector(true); - // Must unload first — loadTextModel skips if the same model ID is already loaded, - // which means setLoadedSettings would never run and the banner would persist. - if (llmService.isModelLoaded()) { - await activeModelService.unloadTextModel(); + if (activeModel?.engine === 'litert') { + // Unload LiteRT engine before reloading with the new backend + if (liteRTService.isModelLoaded()) { + await liteRTService.unloadModel().catch(() => { }); + } + } else { + // Must unload first — loadTextModel skips if the same model ID is already loaded, + // which means setLoadedSettings would never run and the banner would persist. 
+ if (llmService.isModelLoaded()) { + await activeModelService.unloadTextModel(); + } } await initiateModelLoad(modelDeps, false); // eslint-disable-next-line react-hooks/exhaustive-deps - }, [activeModelInfo.modelId, activeModelInfo.isRemote, settings]); + }, [activeModelInfo.modelId, activeModelInfo.isRemote, settings, activeModel?.engine]); return { isModelLoading, loadingModel, supportsVision, diff --git a/src/screens/ModelSettingsScreen/TextGenerationAdvanced.tsx b/src/screens/ModelSettingsScreen/TextGenerationAdvanced.tsx index c3d14393a..09bcad695 100644 --- a/src/screens/ModelSettingsScreen/TextGenerationAdvanced.tsx +++ b/src/screens/ModelSettingsScreen/TextGenerationAdvanced.tsx @@ -217,7 +217,7 @@ const ModelLoadingStrategySection: React.FC = () => { // ─── Main Advanced Component ───────────────────────────────────────────────── -export const TextGenerationAdvanced: React.FC = () => { +export const TextGenerationAdvanced: React.FC<{ isLiteRT?: boolean }> = ({ isLiteRT = false }) => { const { colors } = useTheme(); const styles = useThemedStyles(createStyles); const { settings, updateSettings } = useAppStore(); @@ -246,66 +246,72 @@ export const TextGenerationAdvanced: React.FC = () => { /> - - - Repeat Penalty - {(settings?.repeatPenalty || 1.1).toFixed(2)} + {!isLiteRT && ( + + + Repeat Penalty + {(settings?.repeatPenalty || 1.1).toFixed(2)} + + Penalize repeated tokens + updateSettings({ repeatPenalty: value })} + minimumTrackTintColor={colors.primary} + maximumTrackTintColor={colors.surface} + thumbTintColor={colors.primary} + /> - Penalize repeated tokens - updateSettings({ repeatPenalty: value })} - minimumTrackTintColor={colors.primary} - maximumTrackTintColor={colors.surface} - thumbTintColor={colors.primary} - /> - + )} - - - CPU Threads - {cpuThreadsDisplayValue} + {!isLiteRT && ( + + + CPU Threads + {cpuThreadsDisplayValue} + + Parallel threads for inference + updateSettings({ nThreads: value })} + minimumTrackTintColor={colors.primary} + 
maximumTrackTintColor={colors.surface} + thumbTintColor={colors.primary} + /> - Parallel threads for inference - updateSettings({ nThreads: value })} - minimumTrackTintColor={colors.primary} - maximumTrackTintColor={colors.surface} - thumbTintColor={colors.primary} - /> - + )} - - - Batch Size - {settings?.nBatch || 256} + {!isLiteRT && ( + + + Batch Size + {settings?.nBatch || 256} + + Tokens processed per batch + updateSettings({ nBatch: value })} + minimumTrackTintColor={colors.primary} + maximumTrackTintColor={colors.surface} + thumbTintColor={colors.primary} + /> - Tokens processed per batch - updateSettings({ nBatch: value })} - minimumTrackTintColor={colors.primary} - maximumTrackTintColor={colors.surface} - thumbTintColor={colors.primary} - /> - + )} - - - + {!isLiteRT && } + {!isLiteRT && } + {!isLiteRT && } ); diff --git a/src/screens/ModelSettingsScreen/TextGenerationSection.tsx b/src/screens/ModelSettingsScreen/TextGenerationSection.tsx index 5b1d9099f..f13cf6c56 100644 --- a/src/screens/ModelSettingsScreen/TextGenerationSection.tsx +++ b/src/screens/ModelSettingsScreen/TextGenerationSection.tsx @@ -16,6 +16,8 @@ export const TextGenerationSection: React.FC = () => { const { settings, updateSettings } = useAppStore(); const modelMaxContext = useAppStore((s) => s.modelMaxContext); const [showAdvanced, setShowAdvanced] = useState(false); + const { downloadedModels, activeModelId } = useAppStore.getState(); + const isLiteRT = downloadedModels.find(m => m.id === activeModelId)?.engine === 'litert'; const trackColor = { false: colors.surfaceLight, true: `${colors.primary}80` }; const maxTokens = settings?.maxTokens || 512; @@ -53,48 +55,52 @@ export const TextGenerationSection: React.FC = () => { /> - - - Max Tokens - {maxTokensLabel} + {!isLiteRT && ( + + + Max Tokens + {maxTokensLabel} + + Maximum response length + updateSettings({ maxTokens: value })} + minimumTrackTintColor={colors.primary} + maximumTrackTintColor={colors.surface} + 
thumbTintColor={colors.primary} + /> - Maximum response length - updateSettings({ maxTokens: value })} - minimumTrackTintColor={colors.primary} - maximumTrackTintColor={colors.surface} - thumbTintColor={colors.primary} - /> - + )} - - - Context Length - {contextLengthLabel} + {!isLiteRT && ( + + + Context Length + {contextLengthLabel} + + KV cache size — larger uses more RAM (requires reload) + {contextLength > HIGH_CONTEXT_THRESHOLD && ( + + High context uses significant RAM and may crash on some devices + + )} + updateSettings({ contextLength: value })} + minimumTrackTintColor={colors.primary} + maximumTrackTintColor={colors.surface} + thumbTintColor={colors.primary} + /> - KV cache size — larger uses more RAM (requires reload) - {contextLength > HIGH_CONTEXT_THRESHOLD && ( - - High context uses significant RAM and may crash on some devices - - )} - updateSettings({ contextLength: value })} - minimumTrackTintColor={colors.primary} - maximumTrackTintColor={colors.surface} - thumbTintColor={colors.primary} - /> - + )} @@ -113,7 +119,7 @@ export const TextGenerationSection: React.FC = () => { setShowAdvanced(!showAdvanced)} testID="text-advanced-toggle" /> - {showAdvanced && } + {showAdvanced && } ); }; diff --git a/src/screens/ModelsScreen/useModelsScreen.ts b/src/screens/ModelsScreen/useModelsScreen.ts index c830756d5..a1e5a8087 100644 --- a/src/screens/ModelsScreen/useModelsScreen.ts +++ b/src/screens/ModelsScreen/useModelsScreen.ts @@ -9,6 +9,7 @@ import { useFocusTrigger } from '../../hooks/useFocusTrigger'; import { useAppStore } from '../../stores'; import { useDownloadStore, isActiveStatus } from '../../stores/downloadStore'; import { modelManager } from '../../services'; +import { liteRTService } from '../../services/litert'; import { resolveCoreMLModelDir } from '../../utils/coreMLModelUtils'; import { ONNXImageModel } from '../../types'; import { ModelTab, NavigationProp } from './types'; @@ -132,6 +133,11 @@ export function useModelsScreen() { const 
singleZip = resolvedFiles.length === 1 && resolvedFiles[0].name.toLowerCase().endsWith('.zip'); const singleLitert = resolvedFiles.length === 1 && resolvedFiles[0].name.toLowerCase().endsWith('.litertlm'); + if (singleLitert && !liteRTService.isAvailable()) { + setAlertState(showAlert('Not Supported', 'LiteRT models are only supported on Android.')); + return; + } + if (!allGguf && !singleZip && !singleLitert) { setAlertState(showAlert( 'Invalid File', diff --git a/src/screens/SettingsScreen.tsx b/src/screens/SettingsScreen.tsx index 8221e62f3..e28ac8a36 100644 --- a/src/screens/SettingsScreen.tsx +++ b/src/screens/SettingsScreen.tsx @@ -25,6 +25,7 @@ import DeviceInfo from 'react-native-device-info'; import RNFS from 'react-native-fs'; import { useAppStore, useRemoteServerStore } from '../stores'; import { hardwareService } from '../services'; +import { liteRTService } from '../services/litert'; import { RootStackParamList, MainTabParamList } from '../navigation/types'; import { GITHUB_URL, SHARE_ON_X_URL } from '../utils/sharePrompt'; import packageJson from '../../package.json'; @@ -217,10 +218,10 @@ export const SettingsScreen: React.FC = () => { Version - {packageJson.version} + {packageJson.version} (litertsupport) - Off Grid brings AI to your device without compromising your privacy. + LiteRT engine: multi-turn memory, stop generation, sampler settings, iOS guard, memory budget. 
@@ -250,13 +251,15 @@ export const SettingsScreen: React.FC = () => { Reset Onboarding Checklist - navigation.getParent()?.navigate('DebugLiteRT')} - > - - LiteRT Debug - + {liteRTService.isAvailable() && ( + navigation.getParent()?.navigate('DebugLiteRT')} + > + + LiteRT Debug + + )} From d560373a930669ef05aa2d9415aa9de4c5deebbf Mon Sep 17 00:00:00 2001 From: Dishit Date: Tue, 19 May 2026 15:08:32 +0530 Subject: [PATCH 05/93] test: add LiteRT service tests and improve existing coverage Co-Authored-By: Dishit Karia --- .../rntl/screens/SettingsScreen.test.tsx | 4 +- .../activeModelService.loaders.test.ts | 175 ++++++++++ .../activeModelService.memory.test.ts | 146 ++++++++ .../services/activeModelService.utils.test.ts | 99 ++++++ .../services/generationServiceHelpers.test.ts | 254 ++++++++++++++ __tests__/unit/services/httpClient.test.ts | 133 +++++++ __tests__/unit/services/litert.test.ts | 170 +++++++++ .../unit/services/llmSafetyChecks.test.ts | 59 +++- .../providers/openAICompatibleStream.test.ts | 325 ++++++++++++++++++ __tests__/unit/services/rag/embedding.test.ts | 50 +++ .../OffgridMobileTests.swift | 27 +- 11 files changed, 1430 insertions(+), 12 deletions(-) create mode 100644 __tests__/unit/services/activeModelService.loaders.test.ts create mode 100644 __tests__/unit/services/activeModelService.memory.test.ts create mode 100644 __tests__/unit/services/activeModelService.utils.test.ts create mode 100644 __tests__/unit/services/generationServiceHelpers.test.ts create mode 100644 __tests__/unit/services/litert.test.ts create mode 100644 __tests__/unit/services/providers/openAICompatibleStream.test.ts diff --git a/__tests__/rntl/screens/SettingsScreen.test.tsx b/__tests__/rntl/screens/SettingsScreen.test.tsx index 25456e80a..9ac8e9afc 100644 --- a/__tests__/rntl/screens/SettingsScreen.test.tsx +++ b/__tests__/rntl/screens/SettingsScreen.test.tsx @@ -90,7 +90,7 @@ describe('SettingsScreen', () => { it('renders version number', () => { const { getByText } = 
render(); - expect(getByText('1.0.0')).toBeTruthy(); + expect(getByText(/1\.0\.0/)).toBeTruthy(); }); it('renders navigation items', () => { @@ -156,7 +156,7 @@ describe('SettingsScreen', () => { it('renders about section text', () => { const { getByText } = render(); expect(getByText('Version')).toBeTruthy(); - expect(getByText(/Off Grid brings AI/)).toBeTruthy(); + expect(getByText(/LiteRT engine/)).toBeTruthy(); }); it('renders Reset Onboarding button in __DEV__ mode', () => { diff --git a/__tests__/unit/services/activeModelService.loaders.test.ts b/__tests__/unit/services/activeModelService.loaders.test.ts new file mode 100644 index 000000000..c9cfdf2bd --- /dev/null +++ b/__tests__/unit/services/activeModelService.loaders.test.ts @@ -0,0 +1,175 @@ +/** + * Unit tests for activeModelService/loaders.ts + * Covers inferenceBackendToLiteRT switch and isMMProjFile branches. + */ + +jest.mock('../../../src/stores', () => ({ + useAppStore: { getState: jest.fn() }, +})); +jest.mock('../../../src/stores/debugLogsStore', () => ({ + useDebugLogsStore: { getState: jest.fn(() => ({ addLog: jest.fn() })) }, +})); +jest.mock('../../../src/services/llm', () => ({ + llmService: { loadModel: jest.fn(), unloadModel: jest.fn(), getMultimodalSupport: jest.fn(() => null) }, +})); +jest.mock('../../../src/services/litert', () => ({ + liteRTService: { loadModel: jest.fn(), unloadModel: jest.fn(), getActiveBackend: jest.fn(() => 'cpu') }, +})); +jest.mock('../../../src/services/localDreamGenerator', () => ({ + localDreamGeneratorService: { loadModel: jest.fn(), unloadModel: jest.fn() }, +})); +jest.mock('../../../src/services/modelManager', () => ({ + modelManager: { saveModelWithMmproj: jest.fn(), clearMmProjLink: jest.fn() }, +})); +jest.mock('react-native-fs', () => ({ + exists: jest.fn(() => Promise.resolve(false)), + readDir: jest.fn(() => Promise.resolve([])), +})); +jest.mock('../../../src/utils/logger', () => ({ + __esModule: true, + default: { log: jest.fn(), warn: jest.fn(), 
error: jest.fn() }, +})); + +import RNFS from 'react-native-fs'; +import { doLoadTextModel, resolveMmProjPath } from '../../../src/services/activeModelService/loaders'; +import { liteRTService } from '../../../src/services/litert'; +import { llmService } from '../../../src/services/llm'; +import { useAppStore } from '../../../src/stores'; + +const mockedRNFS = RNFS as jest.Mocked; +const mockedLiteRT = liteRTService as jest.Mocked; +const mockedLlm = llmService as jest.Mocked; +const mockedGetState = useAppStore.getState as jest.Mock; + +function makeStore(overrides: any = {}) { + return { + settings: { inferenceBackend: undefined, enableGpu: false, gpuLayers: 0, nThreads: 4, nBatch: 512, contextLength: 2048, flashAttn: false, cacheType: 'ram' }, + downloadedModels: [], + setDownloadedModels: jest.fn(), + setActiveModelId: jest.fn(), + setLoadedSettings: jest.fn(), + ...overrides, + }; +} + +function makeCtx(overrides: any = {}) { + return { + model: { id: 'model-1', fileName: 'model.gguf', filePath: '/models/model.gguf', engine: 'ggml', ...overrides.model }, + modelId: 'model-1', + store: makeStore(overrides.store), + timeoutMs: 30000, + loadedTextModelId: null, + onLoaded: jest.fn(), + onError: jest.fn(), + onFinally: jest.fn(), + ...overrides, + }; +} + +describe('resolveMmProjPath', () => { + beforeEach(() => jest.clearAllMocks()); + + it('returns mmProjPath from model when file exists on disk', async () => { + mockedRNFS.exists.mockResolvedValue(true); + const model = { filePath: '/models/m.gguf', mmProjPath: '/models/mmproj.gguf' } as any; + const result = await resolveMmProjPath(model, 'model-1'); + expect(result).toBe('/models/mmproj.gguf'); + }); + + it('returns undefined when no mmproj file found in directory', async () => { + mockedRNFS.exists.mockResolvedValue(false); + mockedRNFS.readDir.mockResolvedValue([]); + const model = { filePath: '/models/m.gguf' } as any; + const result = await resolveMmProjPath(model, 'model-1'); + 
expect(result).toBeUndefined(); + }); + + it('finds mmproj file via directory scan when stored path is stale', async () => { + mockedRNFS.exists.mockResolvedValue(false); + mockedRNFS.readDir.mockResolvedValue([ + { name: 'mmproj-model-f16.gguf', path: '/models/mmproj-model-f16.gguf', isFile: () => true, size: 500 } as any, + ]); + mockedGetState.mockReturnValue({ + downloadedModels: [{ id: 'model-1' }], + setDownloadedModels: jest.fn(), + }); + const { modelManager } = require('../../../src/services/modelManager'); + modelManager.saveModelWithMmproj.mockResolvedValue(undefined); + + const model = { filePath: '/models/m.gguf', mmProjPath: '/stale/path.gguf' } as any; + const result = await resolveMmProjPath(model, 'model-1'); + expect(result).toBe('/models/mmproj-model-f16.gguf'); + }); +}); + +describe('doLoadTextModel — llama.cpp path', () => { + beforeEach(() => jest.clearAllMocks()); + + it('calls llmService.loadModel and onLoaded on success', async () => { + mockedLlm.loadModel.mockResolvedValue(undefined); + mockedRNFS.exists.mockResolvedValue(false); + mockedRNFS.readDir.mockResolvedValue([]); + const ctx = makeCtx(); + mockedGetState.mockReturnValue(ctx.store); + + await doLoadTextModel(ctx); + + expect(mockedLlm.loadModel).toHaveBeenCalled(); + expect(ctx.onLoaded).toHaveBeenCalledWith('model-1'); + expect(ctx.onFinally).toHaveBeenCalled(); + }); + + it('calls onError and rethrows when llmService.loadModel fails', async () => { + mockedLlm.loadModel.mockRejectedValue(new Error('load failed')); + mockedRNFS.exists.mockResolvedValue(false); + mockedRNFS.readDir.mockResolvedValue([]); + const ctx = makeCtx(); + mockedGetState.mockReturnValue(ctx.store); + + await expect(doLoadTextModel(ctx)).rejects.toThrow('load failed'); + expect(ctx.onError).toHaveBeenCalled(); + expect(ctx.onFinally).toHaveBeenCalled(); + }); + + it('unloads previous model when loadedTextModelId differs', async () => { + mockedLlm.loadModel.mockResolvedValue(undefined); + 
mockedLlm.unloadModel.mockResolvedValue(undefined); + mockedRNFS.exists.mockResolvedValue(false); + mockedRNFS.readDir.mockResolvedValue([]); + const ctx = makeCtx({ loadedTextModelId: 'old-model' }); + mockedGetState.mockReturnValue(ctx.store); + + await doLoadTextModel(ctx); + expect(mockedLlm.unloadModel).toHaveBeenCalled(); + }); +}); + +describe('doLoadTextModel — LiteRT path', () => { + beforeEach(() => jest.clearAllMocks()); + + it('routes to liteRTService when engine=litert', async () => { + mockedLiteRT.loadModel.mockResolvedValue(undefined); + mockedLiteRT.getActiveBackend.mockReturnValue('cpu'); + const ctx = makeCtx({ model: { id: 'model-1', fileName: 'model.litertlm', filePath: '/models/model.litertlm', engine: 'litert' } }); + const { useDebugLogsStore } = require('../../../src/stores/debugLogsStore'); + useDebugLogsStore.getState.mockReturnValue({ addLog: jest.fn() }); + + await doLoadTextModel(ctx); + + expect(mockedLiteRT.loadModel).toHaveBeenCalled(); + expect(mockedLlm.loadModel).not.toHaveBeenCalled(); + expect(ctx.onLoaded).toHaveBeenCalledWith('model-1'); + }); + + it('calls onError and rethrows when liteRTService.loadModel fails', async () => { + mockedLiteRT.loadModel.mockRejectedValue(new Error('litert failed')); + mockedLiteRT.getActiveBackend.mockReturnValue('cpu'); + const ctx = makeCtx({ model: { id: 'model-1', fileName: 'model.litertlm', filePath: '/models/model.litertlm', engine: 'litert' } }); + const { useDebugLogsStore } = require('../../../src/stores/debugLogsStore'); + useDebugLogsStore.getState.mockReturnValue({ addLog: jest.fn() }); + + await expect(doLoadTextModel(ctx)).rejects.toThrow('litert failed'); + expect(ctx.onError).toHaveBeenCalled(); + expect(ctx.onFinally).toHaveBeenCalled(); + }); +}); diff --git a/__tests__/unit/services/activeModelService.memory.test.ts b/__tests__/unit/services/activeModelService.memory.test.ts new file mode 100644 index 000000000..e8b3e48f3 --- /dev/null +++ 
b/__tests__/unit/services/activeModelService.memory.test.ts @@ -0,0 +1,146 @@ +/** + * Unit tests for activeModelService/memory.ts + * Focuses on LiteRT-only branches (liteRTService loaded, llmService not loaded). + */ + +import { getCurrentlyLoadedMemoryGB, getOtherLoadedMemoryGB } from '../../../src/services/activeModelService/memory'; + +jest.mock('../../../src/services/llm', () => ({ + llmService: { isModelLoaded: jest.fn() }, +})); + +jest.mock('../../../src/services/litert', () => ({ + liteRTService: { isModelLoaded: jest.fn() }, +})); + +jest.mock('../../../src/services/hardware', () => ({ + hardwareService: { + getDeviceInfo: jest.fn(() => Promise.resolve({ totalMemory: 8 * 1024 * 1024 * 1024 })), + }, +})); + +import { llmService } from '../../../src/services/llm'; +import { liteRTService } from '../../../src/services/litert'; + +const mockedLlm = llmService as jest.Mocked; +const mockedLiteRT = liteRTService as jest.Mocked; + +const TEXT_MODEL = { id: 'model-1', name: 'Test Model', fileSize: 4 * 1024 * 1024 * 1024 } as any; +const IMAGE_MODEL = { id: 'img-1', name: 'Image Model', size: 2 * 1024 * 1024 * 1024 } as any; + +const LISTS = { + downloadedModels: [TEXT_MODEL], + downloadedImageModels: [IMAGE_MODEL], +}; + +describe('getCurrentlyLoadedMemoryGB', () => { + beforeEach(() => jest.clearAllMocks()); + + it('counts text model memory when only liteRTService is loaded', () => { + mockedLiteRT.isModelLoaded.mockReturnValue(true); + mockedLlm.isModelLoaded.mockReturnValue(false); + + const result = getCurrentlyLoadedMemoryGB( + { loadedTextModelId: 'model-1', loadedImageModelId: null }, + LISTS, + ); + + // 4 GB * TEXT_MODEL_OVERHEAD_MULTIPLIER (1.5) + expect(result).toBeCloseTo(6, 1); + }); + + it('returns 0 for text model when both services report not loaded', () => { + mockedLiteRT.isModelLoaded.mockReturnValue(false); + mockedLlm.isModelLoaded.mockReturnValue(false); + + const result = getCurrentlyLoadedMemoryGB( + { loadedTextModelId: 'model-1', 
loadedImageModelId: null }, + LISTS, + ); + + expect(result).toBe(0); + }); + + it('counts text model memory when only llmService is loaded', () => { + mockedLiteRT.isModelLoaded.mockReturnValue(false); + mockedLlm.isModelLoaded.mockReturnValue(true); + + const result = getCurrentlyLoadedMemoryGB( + { loadedTextModelId: 'model-1', loadedImageModelId: null }, + LISTS, + ); + + expect(result).toBeGreaterThan(0); + }); + + it('includes image model memory regardless of text model loaded state', () => { + mockedLiteRT.isModelLoaded.mockReturnValue(false); + mockedLlm.isModelLoaded.mockReturnValue(false); + + const result = getCurrentlyLoadedMemoryGB( + { loadedTextModelId: null, loadedImageModelId: 'img-1' }, + LISTS, + ); + + // 2 GB * IMAGE_MODEL_OVERHEAD_MULTIPLIER (1.5 on iOS, 1.8 on Android) + expect(result).toBeGreaterThan(2.9); + }); + + it('sums both models when liteRT loaded and image model also loaded', () => { + mockedLiteRT.isModelLoaded.mockReturnValue(true); + mockedLlm.isModelLoaded.mockReturnValue(false); + + const result = getCurrentlyLoadedMemoryGB( + { loadedTextModelId: 'model-1', loadedImageModelId: 'img-1' }, + LISTS, + ); + + // text(6) + image(3 or 3.6) - just verify it's greater than text alone + expect(result).toBeGreaterThan(6); + }); +}); + +describe('getOtherLoadedMemoryGB', () => { + beforeEach(() => jest.clearAllMocks()); + + it('counts text model memory (LiteRT only loaded) when loading an image model', () => { + mockedLiteRT.isModelLoaded.mockReturnValue(true); + mockedLlm.isModelLoaded.mockReturnValue(false); + + const result = getOtherLoadedMemoryGB( + 'image', + { loadedTextModelId: 'model-1', loadedImageModelId: null }, + LISTS, + ); + + // 4 GB * TEXT_MODEL_OVERHEAD_MULTIPLIER (1.5) + expect(result).toBeCloseTo(6, 1); + }); + + it('returns 0 for image model loading when neither service is loaded', () => { + mockedLiteRT.isModelLoaded.mockReturnValue(false); + mockedLlm.isModelLoaded.mockReturnValue(false); + + const result = 
getOtherLoadedMemoryGB( + 'image', + { loadedTextModelId: 'model-1', loadedImageModelId: null }, + LISTS, + ); + + expect(result).toBe(0); + }); + + it('counts image model memory when loading a text model (no service check needed)', () => { + mockedLiteRT.isModelLoaded.mockReturnValue(false); + mockedLlm.isModelLoaded.mockReturnValue(false); + + const result = getOtherLoadedMemoryGB( + 'text', + { loadedTextModelId: null, loadedImageModelId: 'img-1' }, + LISTS, + ); + + // 2 GB * IMAGE_MODEL_OVERHEAD_MULTIPLIER + expect(result).toBeGreaterThan(2.9); + }); +}); diff --git a/__tests__/unit/services/activeModelService.utils.test.ts b/__tests__/unit/services/activeModelService.utils.test.ts new file mode 100644 index 000000000..d6d7ef5da --- /dev/null +++ b/__tests__/unit/services/activeModelService.utils.test.ts @@ -0,0 +1,99 @@ +/** + * Unit tests for activeModelService/utils.ts + * Focuses on syncWithNativeState LiteRT branch. + */ + +import { syncWithNativeState } from '../../../src/services/activeModelService/utils'; + +jest.mock('../../../src/services/llm', () => ({ + llmService: { isModelLoaded: jest.fn() }, +})); + +jest.mock('../../../src/services/litert', () => ({ + liteRTService: { isModelLoaded: jest.fn() }, +})); + +jest.mock('../../../src/services/localDreamGenerator', () => ({ + localDreamGeneratorService: { isModelLoaded: jest.fn(() => Promise.resolve(false)) }, +})); + +jest.mock('../../../src/stores', () => ({ + useAppStore: { + getState: jest.fn(() => ({ + activeModelId: 'model-abc', + activeImageModelId: null, + downloadedModels: [], + downloadedImageModels: [], + })), + }, +})); + +jest.mock('../../../src/services/hardware', () => ({ + hardwareService: { + refreshMemoryInfo: jest.fn(() => + Promise.resolve({ usedMemory: 0, totalMemory: 0, availableMemory: 0 }), + ), + }, +})); + +import { llmService } from '../../../src/services/llm'; +import { liteRTService } from '../../../src/services/litert'; +import { useAppStore } from '../../../src/stores'; 
+ +const mockedLlm = llmService as jest.Mocked; +const mockedLiteRT = liteRTService as jest.Mocked; + +function makeTarget(overrides: Partial<{ loadedTextModelId: string | null; loadedImageModelId: string | null }> = {}) { + const t = { + loadedTextModelId: overrides.loadedTextModelId ?? null, + loadedImageModelId: overrides.loadedImageModelId ?? null, + setLoadedTextModelId: jest.fn((id: string | null) => { t.loadedTextModelId = id; }), + setLoadedImageModelId: jest.fn((id: string | null) => { t.loadedImageModelId = id; }), + setLoadedImageModelThreads: jest.fn(), + }; + return t; +} + +describe('syncWithNativeState', () => { + beforeEach(() => jest.clearAllMocks()); + + it('sets loadedTextModelId when only liteRTService is loaded and target has no id', async () => { + mockedLiteRT.isModelLoaded.mockReturnValue(true); + mockedLlm.isModelLoaded.mockReturnValue(false); + + const target = makeTarget({ loadedTextModelId: null }); + await syncWithNativeState(target); + + expect(target.setLoadedTextModelId).toHaveBeenCalledWith('model-abc'); + }); + + it('clears loadedTextModelId when both services are not loaded', async () => { + mockedLiteRT.isModelLoaded.mockReturnValue(false); + mockedLlm.isModelLoaded.mockReturnValue(false); + + const target = makeTarget({ loadedTextModelId: 'old-model' }); + await syncWithNativeState(target); + + expect(target.setLoadedTextModelId).toHaveBeenCalledWith(null); + }); + + it('does not overwrite existing loadedTextModelId when liteRTService is loaded', async () => { + mockedLiteRT.isModelLoaded.mockReturnValue(true); + mockedLlm.isModelLoaded.mockReturnValue(false); + + const target = makeTarget({ loadedTextModelId: 'already-set' }); + await syncWithNativeState(target); + + expect(target.setLoadedTextModelId).not.toHaveBeenCalled(); + }); + + it('sets loadedTextModelId when only llmService is loaded and target has no id', async () => { + mockedLiteRT.isModelLoaded.mockReturnValue(false); + 
mockedLlm.isModelLoaded.mockReturnValue(true); + + const target = makeTarget({ loadedTextModelId: null }); + await syncWithNativeState(target); + + expect(target.setLoadedTextModelId).toHaveBeenCalledWith('model-abc'); + }); +}); diff --git a/__tests__/unit/services/generationServiceHelpers.test.ts b/__tests__/unit/services/generationServiceHelpers.test.ts new file mode 100644 index 000000000..d283662c3 --- /dev/null +++ b/__tests__/unit/services/generationServiceHelpers.test.ts @@ -0,0 +1,254 @@ +/** + * Unit tests for generationServiceHelpers.ts + * Focuses on vision guard and buildGenerationMetaImpl LiteRT branches. + */ + +import { buildGenerationMetaImpl, FLUSH_INTERVAL_MS } from '../../../src/services/generationServiceHelpers'; + +jest.mock('../../../src/services/llm', () => ({ + llmService: { + isModelLoaded: jest.fn(() => false), + isCurrentlyGenerating: jest.fn(() => false), + getGpuInfo: jest.fn(() => ({ gpu: false, gpuBackend: 'CPU', gpuLayers: 0 })), + getPerformanceStats: jest.fn(() => ({ + lastTokensPerSecond: 10, + lastDecodeTokensPerSecond: 12, + lastTimeToFirstToken: 0.4, + lastGenerationTime: 2, + lastTokenCount: 40, + })), + }, +})); + +jest.mock('../../../src/services/litert', () => ({ + liteRTService: { + isModelLoaded: jest.fn(() => false), + getActiveBackend: jest.fn(() => 'cpu'), + prepareConversation: jest.fn(() => Promise.resolve()), + sendMessage: jest.fn(() => Promise.resolve()), + }, +})); + +jest.mock('../../../src/stores', () => ({ + useAppStore: { + getState: jest.fn(), + }, + useChatStore: { + getState: jest.fn(() => ({ + startStreaming: jest.fn(), + clearStreamingMessage: jest.fn(), + appendToStreamingMessage: jest.fn(), + finalizeStreamingMessage: jest.fn(), + })), + }, + useRemoteServerStore: { + getState: jest.fn(() => ({ + getActiveServer: jest.fn(() => null), + activeServerId: null, + updateServerHealth: jest.fn(), + })), + }, +})); + +jest.mock('../../../src/stores/debugLogsStore', () => ({ + useDebugLogsStore: { + getState: 
jest.fn(() => ({ addLog: jest.fn() })), + }, +})); + +jest.mock('../../../src/services/generationToolLoop', () => ({ + runToolLoop: jest.fn(() => Promise.resolve()), +})); + +jest.mock('../../../src/utils/logger', () => ({ + default: { log: jest.fn(), error: jest.fn(), warn: jest.fn() }, +})); + +import { useAppStore } from '../../../src/stores'; +import { liteRTService } from '../../../src/services/litert'; + +const mockedGetState = useAppStore.getState as jest.Mock; +const mockedLiteRT = liteRTService as jest.Mocked; + +function makeLiteRTAppState(overrides: any = {}) { + return { + downloadedModels: [{ id: 'litert-1', name: 'LiteRT Model', engine: 'litert', ...overrides.modelProps }], + activeModelId: 'litert-1', + downloadedImageModels: [], + activeImageModelId: null, + settings: { temperature: 0.7, topP: 0.9, cacheType: 'ram', maxTokens: 512, thinkingEnabled: false }, + ...overrides.storeProps, + }; +} + +describe('buildGenerationMetaImpl — remote provider path', () => { + beforeEach(() => jest.clearAllMocks()); + + it('returns remote meta with estimated token count', () => { + const { useRemoteServerStore } = require('../../../src/stores'); + useRemoteServerStore.getState.mockReturnValue({ + getActiveServer: () => ({ name: 'My Server' }), + activeServerId: 'srv-1', + updateServerHealth: jest.fn(), + }); + + const svc = { + isUsingRemoteProvider: () => true, + state: { streamingContent: 'hello world test', startTime: Date.now() - 2000 }, + totalReasoningLength: 8, + remoteTimeToFirstToken: 0.3, + }; + + const meta = buildGenerationMetaImpl(svc); + expect(meta.gpuBackend).toBe('Remote'); + expect(meta.modelName).toBe('My Server'); + expect(meta.gpu).toBe(false); + expect(meta.tokenCount).toBeGreaterThan(0); + expect(meta.timeToFirstToken).toBe(0.3); + }); + + it('uses fallback name when no active server', () => { + const { useRemoteServerStore } = require('../../../src/stores'); + useRemoteServerStore.getState.mockReturnValue({ + getActiveServer: () => null, + 
activeServerId: null, + updateServerHealth: jest.fn(), + }); + + const svc = { + isUsingRemoteProvider: () => true, + state: { streamingContent: 'tokens', startTime: null }, + totalReasoningLength: 0, + remoteTimeToFirstToken: undefined, + }; + + const meta = buildGenerationMetaImpl(svc); + expect(meta.modelName).toBe('Remote Model'); + expect(meta.tokensPerSecond).toBeUndefined(); + }); +}); + +describe('buildGenerationMetaImpl — llama.cpp path', () => { + beforeEach(() => jest.clearAllMocks()); + + it('returns llama.cpp perf stats when model engine is not litert', () => { + const { llmService } = require('../../../src/services/llm'); + llmService.getGpuInfo.mockReturnValue({ gpu: true, gpuBackend: 'Metal', gpuLayers: 32 }); + llmService.getPerformanceStats.mockReturnValue({ + lastTokensPerSecond: 25, + lastDecodeTokensPerSecond: 28, + lastTimeToFirstToken: 0.6, + lastGenerationTime: 3, + lastTokenCount: 75, + }); + + mockedGetState.mockReturnValue({ + downloadedModels: [{ id: 'llm-1', name: 'Llama-3', engine: 'ggml' }], + activeModelId: 'llm-1', + downloadedImageModels: [], + activeImageModelId: null, + settings: { cacheType: 'flash', temperature: 0.7, topP: 0.9, maxTokens: 512, thinkingEnabled: false }, + }); + + const svc = { + isUsingRemoteProvider: () => false, + liteRTBenchmarkStats: null, + state: { streamingContent: '', startTime: Date.now() }, + }; + + const meta = buildGenerationMetaImpl(svc); + expect(meta.gpu).toBe(true); + expect(meta.gpuBackend).toBe('Metal'); + expect(meta.tokensPerSecond).toBe(25); + expect(meta.tokenCount).toBe(75); + expect(meta.cacheType).toBe('flash'); + }); +}); + +describe('buildGenerationMetaImpl — LiteRT path', () => { + beforeEach(() => jest.clearAllMocks()); + + it('returns real benchmark stats when liteRTBenchmarkStats is set', () => { + mockedGetState.mockReturnValue(makeLiteRTAppState()); + mockedLiteRT.getActiveBackend.mockReturnValue('gpu'); + + const svc = { + isUsingRemoteProvider: () => false, + 
liteRTBenchmarkStats: { + decodeTokensPerSecond: 42, + ttft: 0.12, + prefillTokenCount: 128, + }, + state: { streamingContent: 'hello world', startTime: Date.now() - 2000 }, + }; + + const meta = buildGenerationMetaImpl(svc); + + expect(meta.tokensPerSecond).toBe(42); + expect(meta.timeToFirstToken).toBeCloseTo(120, 0); + expect(meta.tokenCount).toBe(128); + expect(meta.gpu).toBe(true); + expect(meta.gpuBackend).toBe('GPU'); + }); + + it('falls back to estimate when liteRTBenchmarkStats is null', () => { + mockedGetState.mockReturnValue(makeLiteRTAppState()); + mockedLiteRT.getActiveBackend.mockReturnValue('cpu'); + + const startTime = Date.now() - 4000; + const svc = { + isUsingRemoteProvider: () => false, + liteRTBenchmarkStats: null, + state: { streamingContent: 'abcd'.repeat(50), startTime }, + }; + + const meta = buildGenerationMetaImpl(svc); + + expect(meta.tokenCount).toBe(Math.ceil(svc.state.streamingContent.length / 4)); + expect(meta.tokensPerSecond).toBeGreaterThan(0); + expect(meta.gpu).toBe(false); + }); + + it('sets gpu=true when backend is npu', () => { + mockedGetState.mockReturnValue(makeLiteRTAppState()); + mockedLiteRT.getActiveBackend.mockReturnValue('npu'); + + const svc = { + isUsingRemoteProvider: () => false, + liteRTBenchmarkStats: { decodeTokensPerSecond: 30, ttft: 0.2, prefillTokenCount: 64 }, + state: { streamingContent: '', startTime: Date.now() }, + }; + + const meta = buildGenerationMetaImpl(svc); + expect(meta.gpu).toBe(true); + expect(meta.gpuBackend).toBe('NPU'); + }); + + it('returns model name from downloadedModels', () => { + mockedGetState.mockReturnValue(makeLiteRTAppState({ modelProps: { name: 'Gemma-3' } })); + mockedLiteRT.getActiveBackend.mockReturnValue('cpu'); + + const svc = { + isUsingRemoteProvider: () => false, + liteRTBenchmarkStats: { decodeTokensPerSecond: 20, ttft: 0.1, prefillTokenCount: 64 }, + state: { streamingContent: '', startTime: Date.now() }, + }; + + const meta = buildGenerationMetaImpl(svc); + 
expect(meta.modelName).toBe('Gemma-3'); + }); + + it('returns undefined tokensPerSecond when startTime is null (fallback path)', () => { + mockedGetState.mockReturnValue(makeLiteRTAppState()); + mockedLiteRT.getActiveBackend.mockReturnValue('cpu'); + + const svc = { + isUsingRemoteProvider: () => false, + liteRTBenchmarkStats: null, + state: { streamingContent: 'some text', startTime: null }, + }; + + const meta = buildGenerationMetaImpl(svc); + expect(meta.tokensPerSecond).toBeUndefined(); + }); +}); diff --git a/__tests__/unit/services/httpClient.test.ts b/__tests__/unit/services/httpClient.test.ts index f713d315e..c2a340813 100644 --- a/__tests__/unit/services/httpClient.test.ts +++ b/__tests__/unit/services/httpClient.test.ts @@ -15,6 +15,7 @@ import { imageToBase64DataUrl, detectServerType, createStreamingRequest, + createNDJSONStreamingRequest, } from '../../../src/services/httpClient'; // Mock React Native FS @@ -1107,4 +1108,136 @@ describe('httpClient', () => { expect(result).toBeNull(); }); }); + + // ─── createNDJSONStreamingRequest ───────────────────────────────────────── + + describe('createNDJSONStreamingRequest', () => { + let mockXHR: any; + + beforeEach(() => { + mockXHR = { + open: jest.fn(), + setRequestHeader: jest.fn(), + send: jest.fn(), + abort: jest.fn(), + addEventListener: jest.fn((event: string, cb: () => void) => { + if (event === 'abort') mockXHR._abortCb = cb; + }), + readyState: 0, + status: 200, + responseText: '', + onreadystatechange: null as any, + onprogress: null as any, + onerror: null as any, + ontimeout: null as any, + }; + (global as any).XMLHttpRequest = jest.fn(() => mockXHR); + }); + + function simulateSuccess(responseText = '') { + mockXHR.responseText = responseText; + mockXHR.readyState = 4; + mockXHR.status = 200; + mockXHR.onreadystatechange?.(); + } + + it('resolves and calls onLine for each complete NDJSON line', async () => { + const onLine = jest.fn(); + const promise = 
createNDJSONStreamingRequest('http://localhost/api/chat', { body: {} }, onLine); + simulateSuccess('{"done":false}\n{"done":true}\n'); + await promise; + expect(onLine).toHaveBeenCalledTimes(2); + expect(onLine).toHaveBeenCalledWith({ done: false }); + expect(onLine).toHaveBeenCalledWith({ done: true }); + }); + + it('flushes partial buffered line on readyState=4', async () => { + const onLine = jest.fn(); + const promise = createNDJSONStreamingRequest('http://localhost/api/chat', { body: {} }, onLine); + // No trailing newline — sits in lineBuffer until completion + simulateSuccess('{"done":true}'); + await promise; + expect(onLine).toHaveBeenCalledWith({ done: true }); + }); + + it('rejects on HTTP error status', async () => { + const promise = createNDJSONStreamingRequest('http://localhost/api/chat', { body: {} }, jest.fn()); + mockXHR.responseText = 'Internal Server Error'; + mockXHR.readyState = 4; + mockXHR.status = 500; + mockXHR.onreadystatechange?.(); + await expect(promise).rejects.toThrow('HTTP 500'); + }); + + it('rejects on network error', async () => { + const promise = createNDJSONStreamingRequest('http://localhost/api/chat', { body: {} }, jest.fn()); + mockXHR.onerror?.(); + await expect(promise).rejects.toThrow('Network error'); + }); + + it('rejects on timeout', async () => { + const promise = createNDJSONStreamingRequest('http://localhost/api/chat', { body: {} }, jest.fn()); + mockXHR.ontimeout?.(); + await expect(promise).rejects.toThrow('Request timeout'); + }); + + it('skips empty/blank lines', async () => { + const onLine = jest.fn(); + const promise = createNDJSONStreamingRequest('http://localhost/api/chat', { body: {} }, onLine); + simulateSuccess('\n\n{"done":true}\n\n'); + await promise; + expect(onLine).toHaveBeenCalledTimes(1); + }); + + it('warns and skips invalid JSON lines', async () => { + const onLine = jest.fn(); + const promise = createNDJSONStreamingRequest('http://localhost/api/chat', { body: {} }, onLine); + 
simulateSuccess('not-json\n{"ok":true}\n'); + await promise; + expect(onLine).toHaveBeenCalledTimes(1); + expect(onLine).toHaveBeenCalledWith({ ok: true }); + }); + + it('sets custom headers', async () => { + const promise = createNDJSONStreamingRequest( + 'http://localhost/api/chat', + { body: {}, headers: { Authorization: 'Bearer token' } }, + jest.fn(), + ); + simulateSuccess(''); + await promise; + expect(mockXHR.setRequestHeader).toHaveBeenCalledWith('Authorization', 'Bearer token'); + }); + + it('processes onprogress chunks and merges partial lines', async () => { + const onLine = jest.fn(); + const promise = createNDJSONStreamingRequest('http://localhost/api/chat', { body: {} }, onLine); + // First progress event delivers half a line + mockXHR.responseText = '{"a":1}\n{"b":'; + mockXHR.onprogress?.(); + // Second progress delivers rest + mockXHR.responseText = '{"a":1}\n{"b":2}\n'; + mockXHR.onprogress?.(); + simulateSuccess('{"a":1}\n{"b":2}\n'); + await promise; + expect(onLine).toHaveBeenCalledWith({ a: 1 }); + expect(onLine).toHaveBeenCalledWith({ b: 2 }); + }); + + it('warns and skips invalid JSON in buffered final line', async () => { + const onLine = jest.fn(); + const promise = createNDJSONStreamingRequest('http://localhost/api/chat', { body: {} }, onLine); + // No trailing newline so it ends up in lineBuffer — invalid JSON + simulateSuccess('not-valid-json'); + await promise; + expect(onLine).not.toHaveBeenCalled(); + }); + + it('rejects when xhr.send throws', async () => { + mockXHR.send = jest.fn(() => { throw new Error('send failed'); }); + const promise = createNDJSONStreamingRequest('http://localhost/api/chat', { body: {} }, jest.fn()); + await expect(promise).rejects.toThrow('send failed'); + }); + + }); }); \ No newline at end of file diff --git a/__tests__/unit/services/litert.test.ts b/__tests__/unit/services/litert.test.ts new file mode 100644 index 000000000..7fa5ff262 --- /dev/null +++ b/__tests__/unit/services/litert.test.ts @@ -0,0 
+1,170 @@ +/** + * Unit tests for litert.ts + * Targets the state-machine branches that don't require native hardware. + */ + +// Mock NativeModules BEFORE importing the service +const mockLiteRTModule = { + loadModel: jest.fn(), + resetConversation: jest.fn(), + sendMessage: jest.fn(), + stopGeneration: jest.fn(), + unloadModel: jest.fn(), + getMemoryInfo: jest.fn(), +}; + +const mockAddListener = jest.fn(() => ({ remove: jest.fn() })); +const mockEmitter = { addListener: mockAddListener }; + +jest.mock('react-native', () => ({ + NativeModules: { LiteRTModule: mockLiteRTModule }, + NativeEventEmitter: jest.fn(() => mockEmitter), + Platform: { OS: 'android' }, +})); + +jest.mock('../../../src/utils/logger', () => { + const log = jest.fn(); + return { __esModule: true, default: { log, error: log, warn: log } }; +}); + +// Import after mocks are set up +import { liteRTService } from '../../../src/services/litert'; + +describe('LiteRTService', () => { + beforeEach(() => { + jest.clearAllMocks(); + // Reset internal state to unloaded + (liteRTService as any).loaded = false; + (liteRTService as any).activeBackend = null; + (liteRTService as any).activeConversationId = null; + (liteRTService as any).activeSystemPrompt = null; + (liteRTService as any).subscriptions = []; + (liteRTService as any).currentCallbacks = null; + // Ensure emitter is available for tests that need it + (liteRTService as any).emitter = mockEmitter; + // Make isAvailable return true by default so state-machine methods run + jest.spyOn(liteRTService, 'isAvailable').mockReturnValue(true); + }); + + describe('isModelLoaded', () => { + it('returns false when not loaded', () => { + expect(liteRTService.isModelLoaded()).toBe(false); + }); + + it('returns true when loaded flag is set', () => { + (liteRTService as any).loaded = true; + expect(liteRTService.isModelLoaded()).toBe(true); + }); + }); + + describe('getActiveBackend', () => { + it('returns null when no model loaded', () => { + 
expect(liteRTService.getActiveBackend()).toBeNull(); + }); + + it('returns backend when set', () => { + (liteRTService as any).activeBackend = 'npu'; + expect(liteRTService.getActiveBackend()).toBe('npu'); + }); + }); + + describe('isNPU', () => { + it('returns false when backend is cpu', () => { + (liteRTService as any).activeBackend = 'cpu'; + expect(liteRTService.isNPU()).toBe(false); + }); + + it('returns true when backend is npu', () => { + (liteRTService as any).activeBackend = 'npu'; + expect(liteRTService.isNPU()).toBe(true); + }); + }); + + describe('loadModel', () => { + it('calls onError when model not loaded (sendMessage guard)', async () => { + // loadModel uses module-level LiteRTModule const captured at import — hard to mock via NativeModules. + // Instead verify the isAvailable guard indirectly via sendMessage which rejects when not loaded. + (liteRTService as any).loaded = false; + const onError = jest.fn(); + await liteRTService.sendMessage('test', { onToken: jest.fn(), onReasoning: jest.fn(), onComplete: jest.fn(), onError }); + expect(onError).toHaveBeenCalled(); + }); + }); + + describe('sendMessage', () => { + it('calls onError immediately when model is not loaded', async () => { + const onError = jest.fn(); + const callbacks = { onToken: jest.fn(), onReasoning: jest.fn(), onComplete: jest.fn(), onError }; + await liteRTService.sendMessage('hello', callbacks); + expect(onError).toHaveBeenCalledWith(expect.any(Error)); + expect(mockLiteRTModule.sendMessage).not.toHaveBeenCalled(); + }); + }); + + describe('prepareConversation', () => { + it('skips reset when conversationId and systemPrompt are unchanged', async () => { + (liteRTService as any).loaded = true; + (liteRTService as any).activeConversationId = 'conv-1'; + (liteRTService as any).activeSystemPrompt = 'You are helpful.'; + mockLiteRTModule.resetConversation.mockResolvedValue(undefined); + + await liteRTService.prepareConversation('conv-1', 'You are helpful.'); + + 
expect(mockLiteRTModule.resetConversation).not.toHaveBeenCalled(); + }); + + it('calls resetConversation when systemPrompt changes', async () => { + // Spy on resetConversation directly since LiteRTModule const is captured at import + const resetSpy = jest.spyOn(liteRTService as any, 'resetConversation').mockResolvedValue(undefined); + (liteRTService as any).loaded = true; + (liteRTService as any).activeConversationId = 'conv-1'; + (liteRTService as any).activeSystemPrompt = 'Old prompt'; + + await liteRTService.prepareConversation('conv-1', 'New prompt'); + + expect(resetSpy).toHaveBeenCalledWith('New prompt', undefined); + expect((liteRTService as any).activeConversationId).toBe('conv-1'); + resetSpy.mockRestore(); + }); + }); + + describe('stopGeneration', () => { + it('clears activeConversationId to force reset on next turn', async () => { + (liteRTService as any).activeConversationId = 'conv-1'; + mockLiteRTModule.stopGeneration.mockResolvedValue(undefined); + + await liteRTService.stopGeneration(); + + expect((liteRTService as any).activeConversationId).toBeNull(); + }); + + it('swallows errors from native stopGeneration', async () => { + mockLiteRTModule.stopGeneration.mockRejectedValue(new Error('native error')); + await expect(liteRTService.stopGeneration()).resolves.not.toThrow(); + }); + }); + + describe('unloadModel', () => { + it('sets loaded=false and clears backend in finally block', async () => { + (liteRTService as any).loaded = true; + (liteRTService as any).activeBackend = 'gpu'; + mockLiteRTModule.unloadModel.mockResolvedValue(undefined); + + await liteRTService.unloadModel(); + + expect(liteRTService.isModelLoaded()).toBe(false); + expect(liteRTService.getActiveBackend()).toBeNull(); + }); + + it('still clears state even when native unloadModel throws', async () => { + (liteRTService as any).loaded = true; + (liteRTService as any).activeBackend = 'npu'; + mockLiteRTModule.unloadModel.mockRejectedValue(new Error('unload failed')); + + await 
liteRTService.unloadModel(); + + expect(liteRTService.isModelLoaded()).toBe(false); + expect(liteRTService.getActiveBackend()).toBeNull(); + }); + }); +}); diff --git a/__tests__/unit/services/llmSafetyChecks.test.ts b/__tests__/unit/services/llmSafetyChecks.test.ts index a3b373f13..a094b3377 100644 --- a/__tests__/unit/services/llmSafetyChecks.test.ts +++ b/__tests__/unit/services/llmSafetyChecks.test.ts @@ -1,5 +1,5 @@ import RNFS from 'react-native-fs'; -import { validateModelFile, checkMemoryForModel } from '../../../src/services/llmSafetyChecks'; +import { validateModelFile, checkMemoryForModel, safeCompletion } from '../../../src/services/llmSafetyChecks'; const mockedRNFS = RNFS as jest.Mocked; @@ -101,3 +101,60 @@ describe('checkMemoryForModel', () => { expect(result.safe).toBe(true); }); }); + +describe('safeCompletion', () => { + it('returns result of completionFn on success', async () => { + const mockContext = { clearCache: jest.fn() }; + const result = await safeCompletion(mockContext as any, async () => 'ok'); + expect(result).toBe('ok'); + }); + + it('throws wrapped error and clears KV cache on native crash (ggml)', async () => { + const mockContext = { clearCache: jest.fn().mockResolvedValue(undefined) }; + await expect( + safeCompletion(mockContext as any, async () => { + throw new Error('ggml alloc failed'); + }), + ).rejects.toThrow('Model inference failed (native error)'); + expect(mockContext.clearCache).toHaveBeenCalledWith(true); + }); + + it('throws wrapped error even when clearCache also fails', async () => { + const mockContext = { clearCache: jest.fn().mockRejectedValue(new Error('cache clear failed')) }; + await expect( + safeCompletion(mockContext as any, async () => { + throw new Error('abort detected'); + }), + ).rejects.toThrow('Model inference failed (native error)'); + }); + + it('re-throws non-native errors unchanged', async () => { + const mockContext = { clearCache: jest.fn() }; + await expect( + safeCompletion(mockContext as 
any, async () => { + throw new Error('unknown error'); + }), + ).rejects.toThrow('unknown error'); + expect(mockContext.clearCache).not.toHaveBeenCalled(); + }); + + it('recognises OOM as native crash keyword', async () => { + const mockContext = { clearCache: jest.fn().mockResolvedValue(undefined) }; + await expect( + safeCompletion(mockContext as any, async () => { + throw new Error('OOM: out of memory'); + }), + ).rejects.toThrow('Model inference failed (native error)'); + expect(mockContext.clearCache).toHaveBeenCalled(); + }); + + it('uses String(error) when thrown value has no message', async () => { + const mockContext = { clearCache: jest.fn().mockResolvedValue(undefined) }; + await expect( + safeCompletion(mockContext as any, async () => { + // eslint-disable-next-line @typescript-eslint/no-throw-literal + throw 'tensor error string'; + }), + ).rejects.toThrow('Model inference failed (native error)'); + }); +}); diff --git a/__tests__/unit/services/providers/openAICompatibleStream.test.ts b/__tests__/unit/services/providers/openAICompatibleStream.test.ts new file mode 100644 index 000000000..e9245026d --- /dev/null +++ b/__tests__/unit/services/providers/openAICompatibleStream.test.ts @@ -0,0 +1,325 @@ +/** + * Unit tests for openAICompatibleStream.ts + * Covers ThinkTagParser and processDelta branch paths. 
+ */ + +jest.mock('../../../../src/services/httpClient', () => ({ + createNDJSONStreamingRequest: jest.fn(), +})); + +jest.mock('../../../../src/utils/logger', () => ({ + __esModule: true, + default: { log: jest.fn(), warn: jest.fn(), error: jest.fn() }, +})); + +import { ThinkTagParser, processDelta, generateOllamaChatImpl } from '../../../../src/services/providers/openAICompatibleStream'; +import { createNDJSONStreamingRequest } from '../../../../src/services/httpClient'; + +const mockedNDJSON = createNDJSONStreamingRequest as jest.Mock; +import type { OpenAIStreamState } from '../../../../src/services/providers/openAICompatibleTypes'; + +function makeState(overrides: Partial = {}): OpenAIStreamState { + return { + fullContent: '', + fullReasoningContent: '', + toolCalls: [], + currentToolCall: null, + ...overrides, + }; +} + +function makeCtx(thinkingEnabled = true) { + const onToken = jest.fn(); + const onReasoning = jest.fn(); + const callbacks = { onToken, onReasoning, onError: jest.fn(), onComplete: jest.fn() }; + const thinkTagParser = new ThinkTagParser(); + return { thinkingEnabled, callbacks, thinkTagParser, onToken, onReasoning }; +} + +// --------------------------------------------------------------------------- +// ThinkTagParser +// --------------------------------------------------------------------------- + +describe('ThinkTagParser', () => { + it('routes plain text to onToken', () => { + const parser = new ThinkTagParser(); + const onToken = jest.fn(); + const onReasoning = jest.fn(); + parser.process('hello world', onToken, onReasoning); + expect(onToken).toHaveBeenCalledWith('hello world'); + expect(onReasoning).not.toHaveBeenCalled(); + }); + + it('routes ... 
content to onReasoning', () => { + const parser = new ThinkTagParser(); + const onToken = jest.fn(); + const onReasoning = jest.fn(); + parser.process('reasoning here', onToken, onReasoning); + expect(onReasoning).toHaveBeenCalledWith('reasoning here'); + expect(onToken).not.toHaveBeenCalled(); + }); + + it('splits content before and after think block', () => { + const parser = new ThinkTagParser(); + const tokens: string[] = []; + const reasoning: string[] = []; + parser.process('beforeinsideafter', t => tokens.push(t), r => reasoning.push(r)); + expect(tokens.join('')).toBe('beforeafter'); + expect(reasoning.join('')).toBe('inside'); + }); + + it('handles think tag split across two chunks', () => { + const parser = new ThinkTagParser(); + const tokens: string[] = []; + const reasoning: string[] = []; + const cb = (t: string) => tokens.push(t); + const rc = (r: string) => reasoning.push(r); + // First chunk ends mid-tag + parser.process('hithoughtdone', cb, rc); + expect(tokens.join('')).toBe('hidone'); + expect(reasoning.join('')).toBe('thought'); + }); + + it('handles close tag split across two chunks', () => { + const parser = new ThinkTagParser(); + const tokens: string[] = []; + const reasoning: string[] = []; + parser.process('partial tokens.push(t), r => reasoning.push(r)); + parser.process('nk>rest', t => tokens.push(t), r => reasoning.push(r)); + // reasoning gets 'partial', 'rest' goes to onToken after close tag + expect(reasoning.join('')).toBe('partial'); + expect(tokens.join('')).toBe('rest'); + }); + + it('emits text before think tag via onToken', () => { + const parser = new ThinkTagParser(); + const onToken = jest.fn(); + parser.process('prefixx', onToken, jest.fn()); + expect(onToken).toHaveBeenCalledWith('prefix'); + }); +}); + +// --------------------------------------------------------------------------- +// processDelta +// --------------------------------------------------------------------------- + +describe('processDelta', () => { + 
it('calls onToken for delta.content (no think tags)', () => { + const state = makeState(); + const { thinkTagParser, callbacks, onToken } = makeCtx(); + processDelta({ content: 'hello' }, state, { thinkingEnabled: true, callbacks, thinkTagParser }); + expect(onToken).toHaveBeenCalledWith('hello'); + expect(state.fullContent).toBe('hello'); + }); + + it('does not call onReasoning when thinkingEnabled=false and reasoning_content present', () => { + const state = makeState(); + const { thinkTagParser, callbacks, onReasoning } = makeCtx(false); + processDelta({ reasoning_content: 'private thought' }, state, { thinkingEnabled: false, callbacks, thinkTagParser }); + expect(onReasoning).not.toHaveBeenCalled(); + expect(state.fullReasoningContent).toBe(''); + }); + + it('calls onReasoning for reasoning_content when thinkingEnabled=true', () => { + const state = makeState(); + const { thinkTagParser, callbacks, onReasoning } = makeCtx(true); + processDelta({ reasoning_content: 'thought' }, state, { thinkingEnabled: true, callbacks, thinkTagParser }); + expect(onReasoning).toHaveBeenCalledWith('thought'); + expect(state.fullReasoningContent).toBe('thought'); + }); + + it('falls back to delta.reasoning field', () => { + const state = makeState(); + const { thinkTagParser, callbacks, onReasoning } = makeCtx(true); + processDelta({ reasoning: 'ollama thought' }, state, { thinkingEnabled: true, callbacks, thinkTagParser }); + expect(onReasoning).toHaveBeenCalledWith('ollama thought'); + }); + + it('falls back to delta.thinking field', () => { + const state = makeState(); + const { thinkTagParser, callbacks, onReasoning } = makeCtx(true); + processDelta({ thinking: 'anthropic thought' }, state, { thinkingEnabled: true, callbacks, thinkTagParser }); + expect(onReasoning).toHaveBeenCalledWith('anthropic thought'); + }); + + it('accumulates tool_calls with id', () => { + const state = makeState(); + const { thinkTagParser, callbacks } = makeCtx(); + processDelta({ + tool_calls: [{ 
id: 'call-1', type: 'function', function: { name: 'get_weather', arguments: '{"city"' } }], + }, state, { thinkingEnabled: true, callbacks, thinkTagParser }); + expect(state.toolCalls).toHaveLength(1); + expect(state.toolCalls[0].id).toBe('call-1'); + expect(state.toolCalls[0].function.name).toBe('get_weather'); + }); + + it('appends arguments to existing tool call (no new id)', () => { + const state = makeState(); + state.toolCalls = [{ id: 'call-1', type: 'function', function: { name: 'get_weather', arguments: '{"city"' } }]; + state.currentToolCall = state.toolCalls[0]; + const { thinkTagParser, callbacks } = makeCtx(); + processDelta({ + tool_calls: [{ function: { arguments: ':"NY"}' } }], + }, state, { thinkingEnabled: true, callbacks, thinkTagParser }); + expect(state.currentToolCall!.function.arguments).toBe('{"city":"NY"}'); + }); + + it('suppresses think-tag reasoning when thinkingEnabled=false', () => { + const state = makeState(); + const { thinkTagParser, callbacks, onReasoning, onToken } = makeCtx(false); + processDelta({ content: 'hiddenvisible' }, state, { thinkingEnabled: false, callbacks, thinkTagParser }); + // reasoning suppressed, visible text goes to onToken + expect(onReasoning).not.toHaveBeenCalled(); + expect(onToken).toHaveBeenCalledWith('visible'); + }); + + it('ignores delta with no content, no reasoning, no tool_calls', () => { + const state = makeState(); + const { thinkTagParser, callbacks, onToken, onReasoning } = makeCtx(); + processDelta({}, state, { thinkingEnabled: true, callbacks, thinkTagParser }); + expect(onToken).not.toHaveBeenCalled(); + expect(onReasoning).not.toHaveBeenCalled(); + }); +}); + +// --------------------------------------------------------------------------- +// generateOllamaChatImpl — tests handleOllamaChatLine branches indirectly +// --------------------------------------------------------------------------- + +function makeOllamaReq(overrides: any = {}) { + const callbacks = { + onToken: jest.fn(), + 
onReasoning: jest.fn(), + onError: jest.fn(), + onComplete: jest.fn(), + }; + const controller = new AbortController(); + return { + options: { enableThinking: true }, + callbacks, + signal: controller.signal, + endpoint: 'http://localhost:11434', + modelId: 'llama3', + abort: jest.fn(), + controller, + ...overrides, + }; +} + +describe('generateOllamaChatImpl', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('calls onComplete with content when done=true line received', async () => { + const req = makeOllamaReq(); + mockedNDJSON.mockImplementation(async (_url: string, _opts: any, handler: (line: any) => void) => { + handler({ message: { role: 'assistant', content: 'hello' }, done: false }); + handler({ done: true }); + }); + + await generateOllamaChatImpl([], req); + expect(req.callbacks.onToken).toHaveBeenCalledWith('hello'); + expect(req.callbacks.onComplete).toHaveBeenCalledWith(expect.objectContaining({ content: 'hello' })); + }); + + it('calls onError when error field present in line', async () => { + const req = makeOllamaReq(); + mockedNDJSON.mockImplementation(async (_url: string, _opts: any, handler: (line: any) => void) => { + handler({ error: 'model not found' }); + }); + + await generateOllamaChatImpl([], req); + expect(req.callbacks.onError).toHaveBeenCalledWith(expect.any(Error)); + expect(req.abort).toHaveBeenCalled(); + }); + + it('calls onComplete with empty content when signal is aborted on throw', async () => { + const req = makeOllamaReq(); + req.controller.abort(); + mockedNDJSON.mockRejectedValue(new Error('aborted')); + + await generateOllamaChatImpl([], req); + expect(req.callbacks.onComplete).toHaveBeenCalledWith(expect.objectContaining({ content: '' })); + expect(req.callbacks.onError).not.toHaveBeenCalled(); + }); + + it('calls onError on non-abort throw', async () => { + const req = makeOllamaReq(); + mockedNDJSON.mockRejectedValue(new Error('network error')); + + await generateOllamaChatImpl([], req); + 
expect(req.callbacks.onError).toHaveBeenCalledWith(expect.any(Error)); + }); + + it('calls onComplete after stream if completeCalled is false', async () => { + const req = makeOllamaReq(); + mockedNDJSON.mockImplementation(async (_url: string, _opts: any, handler: (line: any) => void) => { + handler({ message: { content: 'partial' }, done: false }); + // no done:true line + }); + + await generateOllamaChatImpl([], req); + expect(req.callbacks.onComplete).toHaveBeenCalledWith(expect.objectContaining({ content: 'partial' })); + }); + + it('accumulates tool_calls from Ollama message chunks', async () => { + const req = makeOllamaReq(); + mockedNDJSON.mockImplementation(async (_url: string, _opts: any, handler: (line: any) => void) => { + handler({ message: { tool_calls: [{ function: { name: 'search', arguments: { query: 'test' } } }] }, done: false }); + handler({ done: true }); + }); + + await generateOllamaChatImpl([], req); + const result = req.callbacks.onComplete.mock.calls[0][0]; + expect(result.toolCalls).toHaveLength(1); + expect(result.toolCalls[0].name).toBe('search'); + }); + + it('strips base64 prefix from image_url content parts', async () => { + const req = makeOllamaReq(); + mockedNDJSON.mockImplementation(async (_url: string, _opts: any, handler: (line: any) => void) => { + handler({ done: true }); + }); + + const messages = [{ + role: 'user', + content: [ + { type: 'text', text: 'describe this' }, + { type: 'image_url', image_url: { url: 'data:image/png;base64,abc123' } }, + ], + }] as any; + + await generateOllamaChatImpl(messages, req); + // If we got here without error the conversion ran — check onComplete was called + expect(req.callbacks.onComplete).toHaveBeenCalled(); + }); + + it('converts tool_call arguments from JSON string to object', async () => { + const req = makeOllamaReq(); + mockedNDJSON.mockImplementation(async (_url: string, _opts: any, handler: (line: any) => void) => { + handler({ done: true }); + }); + + const messages = [{ + 
role: 'assistant', + content: '', + tool_calls: [{ id: 'c1', type: 'function', function: { name: 'fn', arguments: '{"k":"v"}' } }], + }] as any; + + await generateOllamaChatImpl(messages, req); + expect(req.callbacks.onComplete).toHaveBeenCalled(); + }); + + it('routes thinking content to onReasoning', async () => { + const req = makeOllamaReq(); + mockedNDJSON.mockImplementation(async (_url: string, _opts: any, handler: (line: any) => void) => { + handler({ message: { thinking: 'internal thought', content: '' }, done: false }); + handler({ done: true }); + }); + + await generateOllamaChatImpl([], req); + expect(req.callbacks.onReasoning).toHaveBeenCalledWith('internal thought'); + }); +}); diff --git a/__tests__/unit/services/rag/embedding.test.ts b/__tests__/unit/services/rag/embedding.test.ts index 08c265eb8..0d3db696b 100644 --- a/__tests__/unit/services/rag/embedding.test.ts +++ b/__tests__/unit/services/rag/embedding.test.ts @@ -121,4 +121,54 @@ describe('EmbeddingService', () => { expect(embeddingService.getDimension()).toBe(384); }); }); + + describe('load — Android copyFileAssets branch', () => { + it('copies from assets on Android when file does not exist', async () => { + const { Platform } = require('react-native'); + Platform.OS = 'android'; + mockExists.mockResolvedValue(false); + mockCopyFileAssets.mockResolvedValue(undefined); + + await embeddingService.load(); + + expect(mockCopyFileAssets).toHaveBeenCalled(); + Platform.OS = 'ios'; // restore + }); + }); + + describe('embed — error recovery branches', () => { + it('unloads model and throws wrapped error on ggml native error', async () => { + await embeddingService.load(); + mockEmbedding.mockRejectedValue(new Error('ggml alloc failed')); + mockRelease.mockResolvedValue(undefined); + + await expect(embeddingService.embed('test')).rejects.toThrow('Embedding failed (native error)'); + expect(embeddingService.isLoaded()).toBe(false); + }); + + it('uses String(error) fallback when error has no message 
property', async () => { + await embeddingService.load(); + // Throw a plain string, not an Error object — error?.message is undefined + mockEmbedding.mockRejectedValue('OOM string error'); + + await expect(embeddingService.embed('test')).rejects.toThrow('Embedding failed (native error)'); + }); + + it('re-throws non-recovery errors unchanged', async () => { + await embeddingService.load(); + mockEmbedding.mockRejectedValue(new Error('unexpected error')); + + await expect(embeddingService.embed('test')).rejects.toThrow('unexpected error'); + }); + }); + + describe('unload — release error is swallowed', () => { + it('sets context to null even when release throws', async () => { + await embeddingService.load(); + mockRelease.mockRejectedValue(new Error('bridge torn down')); + + await embeddingService.unload(); + expect(embeddingService.isLoaded()).toBe(false); + }); + }); }); diff --git a/ios/OffgridMobileTests/OffgridMobileTests.swift b/ios/OffgridMobileTests/OffgridMobileTests.swift index 6667f4a4d..9a409f910 100644 --- a/ios/OffgridMobileTests/OffgridMobileTests.swift +++ b/ios/OffgridMobileTests/OffgridMobileTests.swift @@ -680,12 +680,15 @@ final class DownloadManagerModuleTests: XCTestCase { bytesDownloaded: 1_000_000, status: "completed", startedAt: Date().timeIntervalSince1970 * 1000, - task: nil, + modelKey: nil, + modelType: "text", + combinedTotalBytes: 1_000_000, + metadataJson: nil, taskIdentifier: nil, localUri: TestPaths.tmpTestModelGGUF, - fileTasks: [:], multiFileDestDir: nil, - isMultiFile: false + isMultiFile: false, + fileTasks: [] ) module.queue.sync(flags: .barrier) { self.module.downloads["100"] = info @@ -728,12 +731,15 @@ final class DownloadManagerModuleTests: XCTestCase { bytesDownloaded: 256, status: "completed", startedAt: Date().timeIntervalSince1970 * 1000, - task: nil, + modelKey: nil, + modelType: "text", + combinedTotalBytes: 256, + metadataJson: nil, taskIdentifier: nil, localUri: sourceFile, - fileTasks: [:], multiFileDestDir: nil, 
- isMultiFile: false + isMultiFile: false, + fileTasks: [] ) module.queue.sync(flags: .barrier) { self.module.downloads["200"] = info @@ -781,12 +787,15 @@ final class DownloadManagerModuleTests: XCTestCase { bytesDownloaded: 500_000, status: "running", startedAt: Date().timeIntervalSince1970 * 1000, - task: nil, + modelKey: nil, + modelType: "text", + combinedTotalBytes: 0, + metadataJson: nil, taskIdentifier: nil, localUri: nil, - fileTasks: [:], multiFileDestDir: nil, - isMultiFile: false + isMultiFile: false, + fileTasks: [] ) module.queue.sync(flags: .barrier) { self.module.downloads["300"] = info From 626501129abc524460980f1fcb478854c61cbdf5 Mon Sep 17 00:00:00 2001 From: Dishit Date: Tue, 19 May 2026 20:02:33 +0530 Subject: [PATCH 06/93] feat(litert): expose full BenchmarkInfo - prefill speed, TTFT, decode tps, init time in generation details Co-Authored-By: Dishit Karia --- .../src/main/java/ai/offgridmobile/litert/LiteRTModule.kt | 2 +- src/components/ChatMessage/components/GenerationMeta.tsx | 4 +++- src/services/generationServiceHelpers.ts | 6 ++++-- src/services/litert.ts | 1 + src/types/index.ts | 6 +++++- 5 files changed, 14 insertions(+), 5 deletions(-) diff --git a/android/app/src/main/java/ai/offgridmobile/litert/LiteRTModule.kt b/android/app/src/main/java/ai/offgridmobile/litert/LiteRTModule.kt index 79f88d86f..ca16b318d 100644 --- a/android/app/src/main/java/ai/offgridmobile/litert/LiteRTModule.kt +++ b/android/app/src/main/java/ai/offgridmobile/litert/LiteRTModule.kt @@ -242,7 +242,7 @@ class LiteRTModule(private val reactContext: ReactApplicationContext) : @OptIn(ExperimentalApi::class) val benchmarkJson = try { val b = conv.getBenchmarkInfo() - """{"ttft":${b.timeToFirstTokenInSecond},"decodeTokensPerSecond":${b.lastDecodeTokensPerSecond},"prefillTokensPerSecond":${b.lastPrefillTokensPerSecond},"prefillTokenCount":${b.lastPrefillTokenCount}}""" + 
"""{"ttft":${b.timeToFirstTokenInSecond},"decodeTokensPerSecond":${b.lastDecodeTokensPerSecond},"prefillTokensPerSecond":${b.lastPrefillTokensPerSecond},"prefillTokenCount":${b.lastPrefillTokenCount},"initTimeSeconds":${b.initTimeInSecond}}""" } catch (e: Exception) { Log.w(TAG, "getBenchmarkInfo failed: ${e.message}") "" diff --git a/src/components/ChatMessage/components/GenerationMeta.tsx b/src/components/ChatMessage/components/GenerationMeta.tsx index 59087f2b3..880e5e276 100644 --- a/src/components/ChatMessage/components/GenerationMeta.tsx +++ b/src/components/ChatMessage/components/GenerationMeta.tsx @@ -14,8 +14,10 @@ function formatOptionalMeta(meta: NonNullable, tps: n const m = meta; const entries: Array<[string, string | undefined, number?]> = [ ['model', m.modelName, 1], + ['load', m.modelLoadTimeSeconds != null && m.modelLoadTimeSeconds > 0 ? `load ${m.modelLoadTimeSeconds.toFixed(1)}s` : undefined], + ['prefill', m.prefillTokensPerSecond != null && m.prefillTokensPerSecond > 0 ? `prefill ${m.prefillTokensPerSecond.toFixed(0)} tok/s` : undefined], ['tps', tps != null && tps > 0 ? `${tps.toFixed(1)} tok/s` : undefined], - ['ttft', m.timeToFirstToken != null && m.timeToFirstToken > 0 ? `TTFT ${m.timeToFirstToken.toFixed(1)}s` : undefined], + ['ttft', m.timeToFirstToken != null && m.timeToFirstToken > 0 ? `TTFT ${m.timeToFirstToken.toFixed(2)}s` : undefined], ['tokens', m.tokenCount != null && m.tokenCount > 0 ? `${m.tokenCount} tokens` : undefined], ['steps', m.steps == null ? undefined : `${m.steps} steps`], ['cfg', m.guidanceScale == null ? 
undefined : `cfg ${m.guidanceScale}`], diff --git a/src/services/generationServiceHelpers.ts b/src/services/generationServiceHelpers.ts index 0faed2375..882094f66 100644 --- a/src/services/generationServiceHelpers.ts +++ b/src/services/generationServiceHelpers.ts @@ -72,9 +72,11 @@ export function buildGenerationMetaImpl(svc: any): GenerationMeta { gpu: backend !== 'cpu', gpuBackend: backend.toUpperCase(), modelName, - tokensPerSecond: stats.decodeTokensPerSecond, - timeToFirstToken: stats.ttft * 1000, + decodeTokensPerSecond: stats.decodeTokensPerSecond, + prefillTokensPerSecond: stats.prefillTokensPerSecond, + timeToFirstToken: stats.ttft, tokenCount: stats.prefillTokenCount, + modelLoadTimeSeconds: stats.initTimeSeconds > 0 ? stats.initTimeSeconds : undefined, }; } // Fallback estimate if BenchmarkInfo unavailable diff --git a/src/services/litert.ts b/src/services/litert.ts index d9cbc0450..da19b7bba 100644 --- a/src/services/litert.ts +++ b/src/services/litert.ts @@ -29,6 +29,7 @@ export interface LiteRTBenchmarkStats { decodeTokensPerSecond: number; prefillTokensPerSecond: number; prefillTokenCount: number; + initTimeSeconds: number; } export interface LiteRTMemoryInfo { diff --git a/src/types/index.ts b/src/types/index.ts index dfe1040c7..840e8e22d 100644 --- a/src/types/index.ts +++ b/src/types/index.ts @@ -160,10 +160,14 @@ export interface GenerationMeta { tokensPerSecond?: number; /** Tokens per second — decode only, excluding prefill (text generation only) */ decodeTokensPerSecond?: number; - /** Time to first token in seconds (text generation only) */ + /** Tokens per second — prefill/prompt processing speed (LiteRT only) */ + prefillTokensPerSecond?: number; + /** Time to first token in seconds (text generation only) */ timeToFirstToken?: number; /** Token count (text generation only) */ tokenCount?: number; + /** Model load/init time in seconds */ + modelLoadTimeSeconds?: number; /** Image generation steps */ steps?: number; /** Image guidance 
scale */ From ed234389806e2969ca58211f6484a2ba59bd9a82 Mon Sep 17 00:00:00 2001 From: Dishit Date: Wed, 20 May 2026 12:26:51 +0530 Subject: [PATCH 07/93] feat(litert): replace SDK benchmark with wall-clock stats and add GPU/NPU warmup getBenchmarkInfo() requires internal BenchmarkParams not exposed in the public API. Track TTFT, decode tok/s, and token count via wall-clock timers in JS instead. Add model warmup after GPU/NPU load to prime shader caches. Co-Authored-By: Dishit Karia --- .../ai/offgridmobile/litert/LiteRTModule.kt | 1 + .../ChatMessage/components/GenerationMeta.tsx | 3 +- src/services/activeModelService/loaders.ts | 8 +++ src/services/generationServiceHelpers.ts | 6 +- src/services/litert.ts | 59 +++++++++++++++++-- 5 files changed, 66 insertions(+), 11 deletions(-) diff --git a/android/app/src/main/java/ai/offgridmobile/litert/LiteRTModule.kt b/android/app/src/main/java/ai/offgridmobile/litert/LiteRTModule.kt index ca16b318d..9ee876aff 100644 --- a/android/app/src/main/java/ai/offgridmobile/litert/LiteRTModule.kt +++ b/android/app/src/main/java/ai/offgridmobile/litert/LiteRTModule.kt @@ -242,6 +242,7 @@ class LiteRTModule(private val reactContext: ReactApplicationContext) : @OptIn(ExperimentalApi::class) val benchmarkJson = try { val b = conv.getBenchmarkInfo() + Log.i(TAG, "getBenchmarkInfo — ttft=${b.timeToFirstTokenInSecond} decode=${b.lastDecodeTokensPerSecond} prefill=${b.lastPrefillTokensPerSecond} prefillCount=${b.lastPrefillTokenCount} init=${b.initTimeInSecond}") """{"ttft":${b.timeToFirstTokenInSecond},"decodeTokensPerSecond":${b.lastDecodeTokensPerSecond},"prefillTokensPerSecond":${b.lastPrefillTokensPerSecond},"prefillTokenCount":${b.lastPrefillTokenCount},"initTimeSeconds":${b.initTimeInSecond}}""" } catch (e: Exception) { Log.w(TAG, "getBenchmarkInfo failed: ${e.message}") diff --git a/src/components/ChatMessage/components/GenerationMeta.tsx b/src/components/ChatMessage/components/GenerationMeta.tsx index 880e5e276..eeeddffaa 100644 
--- a/src/components/ChatMessage/components/GenerationMeta.tsx +++ b/src/components/ChatMessage/components/GenerationMeta.tsx @@ -42,7 +42,8 @@ function buildMetaItems( } export function GenerationMeta({ generationMeta, styles }: Readonly) { - const tps = generationMeta.decodeTokensPerSecond ?? generationMeta.tokensPerSecond; + const rawTps = generationMeta.decodeTokensPerSecond ?? generationMeta.tokensPerSecond; + const tps = rawTps && rawTps > 0 ? rawTps : undefined; const items = buildMetaItems(generationMeta, tps); return ( diff --git a/src/services/activeModelService/loaders.ts b/src/services/activeModelService/loaders.ts index 5640c9755..83db5941f 100644 --- a/src/services/activeModelService/loaders.ts +++ b/src/services/activeModelService/loaders.ts @@ -154,6 +154,14 @@ async function doLoadLiteRTModel(ctx: TextLoadContext): Promise { addDebugLog('warn', `[LiteRT] Requested ${preferredBackend}, fell back to ${actualBackend}`); } + // Warmup on GPU/NPU only — primes shader/kernel caches so first real prompt runs at full speed + if (actualBackend === 'gpu' || actualBackend === 'npu') { + addDebugLog('log', `[LiteRT] Starting warmup on ${actualBackend}...`); + const warmupStart = Date.now(); + await liteRTService.warmup(); + addDebugLog('log', `[LiteRT] Warmup complete in ${((Date.now() - warmupStart) / 1000).toFixed(1)}s`); + } + ctx.onLoaded(ctx.modelId); ctx.store.setActiveModelId(ctx.modelId); } catch (error) { diff --git a/src/services/generationServiceHelpers.ts b/src/services/generationServiceHelpers.ts index 882094f66..f4166e63e 100644 --- a/src/services/generationServiceHelpers.ts +++ b/src/services/generationServiceHelpers.ts @@ -79,15 +79,11 @@ export function buildGenerationMetaImpl(svc: any): GenerationMeta { modelLoadTimeSeconds: stats.initTimeSeconds > 0 ? stats.initTimeSeconds : undefined, }; } - // Fallback estimate if BenchmarkInfo unavailable - const generationTime = svc.state.startTime ? 
(Date.now() - svc.state.startTime) / 1000 : 0; - const estimatedTokens = Math.ceil(svc.state.streamingContent.length / 4); + // BenchmarkInfo unavailable — return backend info only, no estimated tps return { gpu: backend !== 'cpu', gpuBackend: backend.toUpperCase(), modelName, - tokenCount: estimatedTokens, - tokensPerSecond: generationTime > 0 ? estimatedTokens / generationTime : undefined, }; } diff --git a/src/services/litert.ts b/src/services/litert.ts index da19b7bba..523c979ab 100644 --- a/src/services/litert.ts +++ b/src/services/litert.ts @@ -11,6 +11,7 @@ import { NativeModules, NativeEventEmitter, Platform, EmitterSubscription } from 'react-native'; import logger from '../utils/logger'; +import { useDebugLogsStore } from '../stores/debugLogsStore'; const TAG = '[LiteRTService]'; @@ -137,6 +138,32 @@ class LiteRTService { } } + // --------------------------------------------------------------------------- + // warmup — send a throwaway prompt to prime GPU/NPU shader caches + // --------------------------------------------------------------------------- + + async warmup(): Promise { + if (!this.isAvailable() || !this.loaded) return; + logger.log(TAG, 'warmup — starting'); + try { + await this.resetConversation(''); + await new Promise((resolve) => { + this.sendMessage('Hi', { + onToken: () => {}, + onReasoning: () => {}, + onComplete: () => resolve(), + onError: () => resolve(), + }); + }); + // Clear warmup state so first real message gets a fresh conversation + this.activeConversationId = null; + this.activeSystemPrompt = null; + logger.log(TAG, 'warmup — done'); + } catch (e) { + logger.log(TAG, `warmup — error (ignored): ${String(e)}`); + } + } + // --------------------------------------------------------------------------- // sendMessage — sends current turn only, library holds history // --------------------------------------------------------------------------- @@ -158,10 +185,17 @@ class LiteRTService { this.currentReasoning = ''; 
this.currentCallbacks = callbacks; + // Wall-clock tracking + const sendStart = Date.now(); + let firstTokenTime: number | undefined; + let decodeTokenCount = 0; + // Register event listeners for this generation this.clearSubscriptions(); this.subscriptions = [ this.emitter!.addListener(EVENT_TOKEN, (token: string) => { + if (firstTokenTime === undefined) firstTokenTime = Date.now(); + decodeTokenCount++; this.currentContent += token; callbacks.onToken(token); }), @@ -173,11 +207,26 @@ class LiteRTService { logger.log(TAG, `sendMessage — complete, content=${this.currentContent.length} chars`); this.clearSubscriptions(); this.currentCallbacks = null; - let stats: LiteRTBenchmarkStats | undefined; - if (benchmarkJson) { - try { stats = JSON.parse(benchmarkJson); } catch { /* ignore parse errors */ } - } - callbacks.onComplete(this.currentContent, this.currentReasoning, stats); + const addLog = useDebugLogsStore.getState().addLog; + + // Build wall-clock stats + const completeTime = Date.now(); + const ttft = firstTokenTime !== undefined ? (firstTokenTime - sendStart) / 1000 : undefined; + const decodeElapsed = firstTokenTime !== undefined ? (completeTime - firstTokenTime) / 1000 : undefined; + const decodeTokensPerSecond = decodeElapsed && decodeElapsed > 0 && decodeTokenCount > 1 + ? decodeTokenCount / decodeElapsed + : undefined; + + const wallClockStats: LiteRTBenchmarkStats = { + ttft: ttft ?? 0, + decodeTokensPerSecond: decodeTokensPerSecond ?? 
0, + prefillTokensPerSecond: 0, + prefillTokenCount: decodeTokenCount, + initTimeSeconds: 0, + }; + + addLog('log', `[LiteRTService] wall-clock stats — ttft=${ttft?.toFixed(3)}s decode=${decodeTokensPerSecond?.toFixed(1)}tok/s tokens=${decodeTokenCount}`); + callbacks.onComplete(this.currentContent, this.currentReasoning, wallClockStats); }), this.emitter!.addListener(EVENT_ERROR, (message: string) => { logger.log(TAG, `sendMessage — error: ${message}`); From 6b8703cfc8021b9a59a57f6725dbf1eb9cb33b08 Mon Sep 17 00:00:00 2001 From: Dishit Date: Thu, 21 May 2026 15:04:44 +0530 Subject: [PATCH 08/93] fix(litert): tool call parsing, regeneration, and context stability - Fix regeneration for LiteRT: use ensureModelReady instead of bare llmService.isModelLoaded() check which always returns false for LiteRT - Invalidate native conversation before regenerate/edit so native history is correctly rewound to match the JS message array - Fix context loss after stopGeneration: remove activeConversationId=null which was wiping native turn history on every stop - Add invalidateConversation() to LiteRTService for explicit resets - Extend tool call parser to handle: no-args calls, Gemma function-call style args NAME({"k":"v"}), and closing tag variant - Fix Gemma native parser regex to accept both and as closing tags - GPU retry logic in LiteRTModule: retry non-CPU backends up to 3 times with 600ms backoff before falling back, handles transient VRAM pressure after model switches - Capture benchmark stats from generateRaw path for generation meta display - Raise debug log capacity from 200 to 2000 entries Co-Authored-By: Dishit Karia --- android/app/build.gradle | 2 +- .../ai/offgridmobile/litert/LiteRTModule.kt | 67 ++++--- android/app/src/main/res/values/strings.xml | 2 +- docs/LITERT_TODO.md | 33 ++++ ios/Podfile.lock | 2 +- .../ChatScreen/useChatGenerationActions.ts | 36 ++-- src/screens/ChatScreen/useChatModelActions.ts | 5 +- src/services/generationService.ts | 6 +- 
src/services/generationServiceHelpers.ts | 4 +- src/services/generationToolLoop.ts | 187 ++++++++++++++++-- src/services/litert.ts | 43 +++- src/stores/debugLogsStore.ts | 2 +- 12 files changed, 323 insertions(+), 66 deletions(-) diff --git a/android/app/build.gradle b/android/app/build.gradle index e831ffcec..b84724219 100644 --- a/android/app/build.gradle +++ b/android/app/build.gradle @@ -81,7 +81,7 @@ android { namespace "ai.offgridmobile" defaultConfig { - applicationId "ai.offgridmobile" + applicationId "ai.offgridmobile.litert" minSdkVersion rootProject.ext.minSdkVersion targetSdkVersion rootProject.ext.targetSdkVersion versionCode 1778048025 diff --git a/android/app/src/main/java/ai/offgridmobile/litert/LiteRTModule.kt b/android/app/src/main/java/ai/offgridmobile/litert/LiteRTModule.kt index 9ee876aff..d6c656e2c 100644 --- a/android/app/src/main/java/ai/offgridmobile/litert/LiteRTModule.kt +++ b/android/app/src/main/java/ai/offgridmobile/litert/LiteRTModule.kt @@ -94,36 +94,55 @@ class LiteRTModule(private val reactContext: ReactApplicationContext) : else -> listOf(Backend.CPU()) } + // GPU/NPU failures can be transient (e.g. VRAM not yet released after a model switch). + // Retry up to 2 extra times with backoff before giving up on a non-CPU backend. + val GPU_RETRIES = 2 + val GPU_RETRY_DELAY_MS = 600L + var lastError: Exception? 
= null for (backend in chain) { val name = backendName(backend) - Log.i(TAG, "initializeWithFallback — trying $name vision=$visionEnabled") - try { - val cfg = EngineConfig( - modelPath = modelPath, - backend = backend, - cacheDir = null, - visionBackend = if (visionEnabled) Backend.GPU() else null, - ) - val eng = Engine(cfg) - val timeoutMs = when (backend) { - is Backend.NPU -> NPU_TIMEOUT_MS - is Backend.GPU -> GPU_TIMEOUT_MS - else -> CPU_TIMEOUT_MS + val maxAttempts = if (backend is Backend.CPU) 1 else GPU_RETRIES + 1 + var succeeded = false + + for (attempt in 1..maxAttempts) { + if (attempt > 1) { + Log.i(TAG, "initializeWithFallback — $name retry $attempt/$maxAttempts after ${GPU_RETRY_DELAY_MS}ms") + delay(GPU_RETRY_DELAY_MS) + } else { + Log.i(TAG, "initializeWithFallback — trying $name vision=$visionEnabled") } - withTimeout(timeoutMs) { - eng.initialize() + try { + val cfg = EngineConfig( + modelPath = modelPath, + backend = backend, + cacheDir = null, + visionBackend = if (visionEnabled) Backend.GPU() else null, + ) + val eng = Engine(cfg) + val timeoutMs = when (backend) { + is Backend.NPU -> NPU_TIMEOUT_MS + is Backend.GPU -> GPU_TIMEOUT_MS + else -> CPU_TIMEOUT_MS + } + withTimeout(timeoutMs) { + eng.initialize() + } + engine = eng + Log.i(TAG, "initializeWithFallback — $name succeeded (attempt $attempt)") + succeeded = true + return backend + } catch (e: Exception) { + Log.w(TAG, "initializeWithFallback — $name attempt $attempt failed: ${e.message}") + engine?.close() + engine = null + lastError = e } - engine = eng - Log.i(TAG, "initializeWithFallback — $name succeeded") - return backend - } catch (e: Exception) { - Log.w(TAG, "initializeWithFallback — $name failed: ${e.message}") - engine?.close() - engine = null - lastError = e + } + + if (!succeeded) { if (backend == chain.last()) break - Log.i(TAG, "initializeWithFallback — falling back to next tier") + Log.i(TAG, "initializeWithFallback — $name exhausted retries, falling back to next tier") 
} } throw lastError ?: IllegalStateException("All backends failed") diff --git a/android/app/src/main/res/values/strings.xml b/android/app/src/main/res/values/strings.xml index 9fba9024d..48e3bd995 100644 --- a/android/app/src/main/res/values/strings.xml +++ b/android/app/src/main/res/values/strings.xml @@ -1,3 +1,3 @@ - Off Grid + Off Grid LiteRT diff --git a/docs/LITERT_TODO.md b/docs/LITERT_TODO.md index 4f0860778..edbb038e7 100644 --- a/docs/LITERT_TODO.md +++ b/docs/LITERT_TODO.md @@ -50,3 +50,36 @@ - Library: `com.google.ai.edge.litertlm:litertlm-android:0.11.0` - iOS Swift SDK: not released yet, Coming Soon per Google - Gemma 4 E2B on GPU: TTFT ~7s, ~38-41 chars/sec + + +I want to clean up some branches; I have got too many branches locally +add-mailfeedback-and-lovewithWednesday +addmailto +0.0.84 +85 +dev/integration - don't know what this is about, don't delete +fdroid/wrapper-and-fastlane +feat/remote-server-api-keys can do + fix-file-delete-from-source-error + fix-file-import + fix-google-play-services + fix-threads-resetting-auto-error + fix-threads-resetting-auto-error2 + fix-whisper-downlaod + fix/ios-document-picker-issue2 + fix/issue-201-view-old-chats-without-model + fix/react-native-zip-archive-upgrade + fix/local-model-import + fix/remove-notification-permission + fix/vision-retry-mmproj-restore + fix/whisper-release-thread-join + react-native-zip-patch + remove-phi-list + ui-ux/improvements/0.0.88 + make-tool-heuristic-based + + + + + + \ No newline at end of file diff --git a/ios/Podfile.lock b/ios/Podfile.lock index 7cedf2d7f..cac5f375c 100644 --- a/ios/Podfile.lock +++ b/ios/Podfile.lock @@ -3603,7 +3603,7 @@ SPEC CHECKSUMS: FBLazyVector: 309703e71d3f2f1ed7dc7889d58309c9d77a95a4 fmt: a40bb5bd0294ea969aaaba240a927bd33d878cdd glog: 5683914934d5b6e4240e497e0f4a3b42d1854183 - hermes-engine: 8c6be38f94b3bf8b864981980e64e55f08e467ec + hermes-engine: 4d40ce008f57c9348a66614a3167581aa861e379 llama-rn: 796fa53f37f89e2c77cd6c462ad1172ee96d4c80 lottie-ios: 
a881093fab623c467d3bce374367755c272bdd59 lottie-react-native: 691b8363e8c591fb78a78254ff2517258891456b diff --git a/src/screens/ChatScreen/useChatGenerationActions.ts b/src/screens/ChatScreen/useChatGenerationActions.ts index ab2534f1f..0c6f35b05 100644 --- a/src/screens/ChatScreen/useChatGenerationActions.ts +++ b/src/screens/ChatScreen/useChatGenerationActions.ts @@ -8,7 +8,6 @@ import { APP_CONFIG } from '../../constants'; import { llmService, intentClassifier, - classifyToolsNeeded, generationService, imageGenerationService, onnxImageGeneratorService, @@ -224,32 +223,23 @@ async function injectRagContext(projectId: string | undefined, query: string, pr /** Gemma 4 E2B/E4B need <|think|> prepended to activate thinking mode. */ const applyGemma4ThinkToken = (prompt: string, isRemote: boolean): string => (!isRemote && llmService.isGemma4Model() && llmService.isThinkingEnabled()) ? `<|think|>\n${prompt}` : prompt; -function resolveToolsAndPrompt(deps: GenerationDeps, conversation: any, messageText: string): { enabledTools: string[]; rawPrompt: string } { +function resolveToolsAndPrompt(deps: GenerationDeps, conversation: any, _messageText: string): { enabledTools: string[]; rawPrompt: string } { const project = conversation?.projectId ? useProjectStore.getState().getProject(conversation.projectId) : null; const { activeServerId, activeRemoteTextModelId } = useRemoteServerStore.getState(); const localToolCalling = llmService.supportsToolCalling(); const isRemoteActive = !!(activeServerId && activeRemoteTextModelId); - const canUseTools = localToolCalling || isRemoteActive; + const isLiteRT = deps.activeModel?.engine === 'litert' && liteRTService.isModelLoaded(); + const canUseTools = localToolCalling || isRemoteActive || isLiteRT; let enabledTools = canUseTools ? 
(deps.settings.enabledTools || []) : []; - if (enabledTools.length > 0) { - // Heuristic filter: only pass tools relevant to this message (local regex, ~0.1ms) - const heuristicTools = classifyToolsNeeded(messageText); - - // Always keep search_knowledge_base for project conversations regardless of heuristic - const alwaysKeep = new Set(); - if (conversation?.projectId) alwaysKeep.add('search_knowledge_base'); - - enabledTools = enabledTools.filter(t => heuristicTools.includes(t) || alwaysKeep.has(t)); - - // Auto-add search_knowledge_base for project chats even if not in user's enabled list - if (conversation?.projectId && !enabledTools.includes('search_knowledge_base')) { - enabledTools = [...enabledTools, 'search_knowledge_base']; - } + // Auto-add search_knowledge_base for project chats even if not in user's enabled list + if (conversation?.projectId && !enabledTools.includes('search_knowledge_base')) { + enabledTools = [...enabledTools, 'search_knowledge_base']; } const rawPrompt = project?.systemPrompt || deps.settings.systemPrompt || APP_CONFIG.defaultSystemPrompt; + logger.log(`[ChatGen][resolveTools] isLiteRT=${isLiteRT}, canUseTools=${canUseTools}, enabledTools=[${enabledTools.join(', ')}]`); return { enabledTools, rawPrompt }; } export async function startGenerationFn(deps: GenerationDeps, call: StartGenerationCall): Promise { @@ -283,7 +273,8 @@ export async function startGenerationFn(deps: GenerationDeps, call: StartGenerat useTextHint ? `${basePrompt}${buildToolSystemPromptHint(activeTools)}` : basePrompt, isRemote, ); - logger.log(`[ChatGen][DEBUG] isRemote=${isRemote}, tools=[${activeTools.join(', ')}], path=${activeTools.length > 0 ? 'withTools' : 'generate'}`); + logger.log(`[ChatGen][DEBUG] isRemote=${isRemote}, useTextHint=${useTextHint}, tools=[${activeTools.join(', ')}], path=${activeTools.length > 0 ? 
'withTools' : 'generate'}`); + logger.log(`[ChatGen][PROMPT] systemPrompt (${systemPrompt.length}ch): "${systemPrompt.substring(0, 800)}"`); const messagesForContext = buildMessagesForContext(targetConversationId, messageText, systemPrompt); await prepareContext(setDebugInfo, systemPrompt, messagesForContext); try { @@ -354,8 +345,15 @@ export async function regenerateResponseFn(deps: GenerationDeps, call: Regenerat await handleImageGenerationFn(deps, { prompt: userMessage.content, conversationId: targetConversationId, skipUserMessage: true }); return; } - if (!deps.activeModelInfo?.isRemote && !llmService.isModelLoaded()) return; + if (!deps.activeModelInfo?.isRemote && deps.activeModel) { + if (!(await ensureModelReady(deps))) { + deps.setAlertState(showAlert('Error', 'Failed to load model. Please try again.')); + return; + } + } deps.generatingForConversationRef.current = targetConversationId; + // LiteRT: native history must be rewound to match the JS messages we're about to replay. + if (deps.activeModel?.engine === 'litert') liteRTService.invalidateConversation(); const conversation = useChatStore.getState().conversations.find(c => c.id === targetConversationId); const messages = (conversation?.messages || []).filter((m: Message) => !m.isSystemInfo); const messagesUpToUser = messages.slice(0, messages.findIndex((m: Message) => m.id === userMessage.id) + 1) diff --git a/src/screens/ChatScreen/useChatModelActions.ts b/src/screens/ChatScreen/useChatModelActions.ts index dafb7b308..65b33af3b 100644 --- a/src/screens/ChatScreen/useChatModelActions.ts +++ b/src/screens/ChatScreen/useChatModelActions.ts @@ -339,6 +339,9 @@ export function useChatModelStateSync(deps: ModelStateSyncDeps): void { if (activeRemoteTextModelId) { setSupportsToolCalling(activeRemoteModel?.capabilities?.supportsToolCalling ?? false); setSupportsThinking(activeRemoteModel?.capabilities?.supportsThinking ?? 
false); + } else if (activeModel?.engine === 'litert' && liteRTService.isModelLoaded()) { + setSupportsToolCalling(true); + setSupportsThinking(false); } else if (llmService.isModelLoaded()) { setSupportsToolCalling(llmService.supportsToolCalling()); setSupportsThinking(llmService.supportsThinking()); @@ -347,5 +350,5 @@ export function useChatModelStateSync(deps: ModelStateSyncDeps): void { setSupportsThinking(false); } - }, [activeModelId, isModelLoading, activeRemoteTextModelId, activeRemoteModel?.capabilities?.supportsToolCalling, activeRemoteModel?.capabilities?.supportsThinking]); + }, [activeModelId, activeModel?.engine, isModelLoading, activeRemoteTextModelId, activeRemoteModel?.capabilities?.supportsToolCalling, activeRemoteModel?.capabilities?.supportsThinking]); } diff --git a/src/services/generationService.ts b/src/services/generationService.ts index 7faec9b15..7eebdc302 100644 --- a/src/services/generationService.ts +++ b/src/services/generationService.ts @@ -182,10 +182,14 @@ class GenerationService { }); // If aborted, stopGeneration() already handled cleanup. + logger.log(`[GenService][ToolLoop] runToolLoop done — aborted=${this.abortRequested}, streamingContent=${this.state.streamingContent?.length ?? 0}ch, tokenBuffer=${this.tokenBuffer?.length ?? 0}ch`); if (!this.abortRequested) { this.forceFlushTokens(); + const store = useChatStore.getState(); + logger.log(`[GenService][ToolLoop] pre-finalize — streamingForConvId=${store.streamingForConversationId}, targetConvId=${conversationId}, streamingMsg=${store.streamingMessage?.length ?? 0}ch`); const generationTime = this.state.startTime ? 
Date.now() - this.state.startTime : undefined; - useChatStore.getState().finalizeStreamingMessage(conversationId, generationTime, this.buildGenerationMeta()); + store.finalizeStreamingMessage(conversationId, generationTime, this.buildGenerationMeta()); + logger.log(`[GenService][ToolLoop] finalizeStreamingMessage called — convId=${conversationId}`); this.checkSharePrompt(); this.resetState(); } diff --git a/src/services/generationServiceHelpers.ts b/src/services/generationServiceHelpers.ts index f4166e63e..58cb30311 100644 --- a/src/services/generationServiceHelpers.ts +++ b/src/services/generationServiceHelpers.ts @@ -66,7 +66,9 @@ export function buildGenerationMetaImpl(svc: any): GenerationMeta { // LiteRT path — use real BenchmarkInfo stats if available, else estimate if (isLiteRTActive()) { const backend = liteRTService.getActiveBackend() ?? 'cpu'; - const stats = svc.liteRTBenchmarkStats; + // svc.liteRTBenchmarkStats is set on the direct (non-tool) path; + // liteRTService.getLastBenchmarkStats() covers the generateRaw (tool loop) path + const stats = svc.liteRTBenchmarkStats ?? liteRTService.getLastBenchmarkStats(); if (stats) { return { gpu: backend !== 'cpu', diff --git a/src/services/generationToolLoop.ts b/src/services/generationToolLoop.ts index e5b78f2c8..6b80befda 100644 --- a/src/services/generationToolLoop.ts +++ b/src/services/generationToolLoop.ts @@ -1,6 +1,7 @@ /** Tool-calling generation loop. Extracted to keep generationService.ts under the max-lines limit. 
*/ import { llmService } from './llm'; import type { StreamToken } from './llm'; +import { liteRTService } from './litert'; import { useChatStore, useRemoteServerStore, useAppStore } from '../stores'; import { Message } from '../types'; import { getToolsAsOpenAISchema, executeToolCall } from './tools'; @@ -23,12 +24,78 @@ function parseXmlStyleToolCall(body: string, idSuffix: number): ToolCall | null } function parseToolCallBody(body: string, idSuffix: number): ToolCall | null { + const makeCall = (name: string, args: Record): ToolCall => + ({ id: `text-tc-${Date.now()}-${idSuffix}`, name, arguments: args }); + + // Standard JSON: {"name": "tool", "arguments": {...}} try { const parsed = JSON.parse(body); - if (parsed.name) return { id: `text-tc-${Date.now()}-${idSuffix}`, name: parsed.name, arguments: parsed.arguments || parsed.parameters || {} }; - } catch { /* Not JSON — fall through to XML */ } + if (parsed.name) return makeCall(parsed.name, parsed.arguments || parsed.parameters || {}); + } catch { /* fall through */ } + + // Function-call style: tool_name({"key": "value"}) + const funcMatch = body.match(/^(\w+)\s*\((\{[\s\S]*\})\)$/); + if (funcMatch) { + try { return makeCall(funcMatch[1], JSON.parse(funcMatch[2])); } catch { /* fall through */ } + } + + // Bare style: tool_name{"key": "value"} + const bareMatch = body.match(/^(\w+)\s*(\{[\s\S]*\})$/); + if (bareMatch) { + try { return makeCall(bareMatch[1], JSON.parse(bareMatch[2])); } catch { /* fall through */ } + } + + // No-args style: just a tool name with no arguments + const noArgsMatch = body.match(/^(\w+)$/); + if (noArgsMatch) return makeCall(noArgsMatch[1], {}); + return parseXmlStyleToolCall(body, idSuffix); } +/** Parse Gemma 4's native tool call format: <|tool_call>call:NAME{...} and */ +function parseGemmaNativeToolCalls(text: string): { cleanText: string; toolCalls: ToolCall[] } { + const toolCalls: ToolCall[] = []; + // Matches <|tool_call>..., , and + const pattern = 
/(?:<\|tool_call>||<\/tool_call>)/g; + let match; + const matchedRanges: [number, number][] = []; + + while ((match = pattern.exec(text)) !== null) { + matchedRanges.push([match.index, match.index + match[0].length]); + const raw = match[1].trim().replace(/<\|"\|>/g, '"'); + // Extract tool name — handles: call:NAME, NAME, call:NAME({...}), call:NAME{...} + const nameMatch = raw.match(/^(?:call:)?(\w+)/); + if (nameMatch) { + const name = nameMatch[1]; + const rest = raw.slice(nameMatch[0].length).trim(); + let args: Record = {}; + // Args may be: ({"key":"val"}) function-call style, or {"key":"val"} bare object + const argsStr = rest.match(/^\((\{[\s\S]*\})\)$/)?.[1] ?? rest.match(/^(\{[\s\S]*\})$/)?.[1] ?? null; + if (argsStr) { + try { + // Gemma often emits unquoted keys like {queries:["x"]} — fix to valid JSON + const fixedJson = argsStr.replace(/([{,]\s*)([a-zA-Z_]\w*)(\s*):/g, '$1"$2"$3:'); + args = JSON.parse(fixedJson); + } catch { + logger.warn(`[ToolLoop] Failed to parse Gemma tool args: ${argsStr.substring(0, 100)}`); + } + } + // Normalize Gemma's "queries" array to the single "query" string our tools expect + if (name === 'web_search' && !args.query && args.queries) { + args = { ...args, query: Array.isArray(args.queries) ? args.queries[0] : args.queries }; + } + toolCalls.push({ id: `gemma-tc-${Date.now()}-${toolCalls.length}`, name, arguments: args }); + logger.log(`[ToolLoop] Parsed Gemma native tool call: ${name}(${JSON.stringify(args).substring(0, 100)})`); + } else { + logger.warn(`[ToolLoop] Gemma tool call body did not match expected format: "${raw.substring(0, 100)}"`); + } + } + + matchedRanges.sort((a, b) => b[0] - a[0]); + let cleanText = text; + for (const [start, end] of matchedRanges) { cleanText = cleanText.slice(0, start) + cleanText.slice(end); } + return { cleanText: cleanText.trim(), toolCalls }; +} + /** Parse tool calls from text output (fallback for small models). Supports JSON and XML-like formats. 
*/ export function parseToolCallsFromText(text: string): { cleanText: string; toolCalls: ToolCall[] } { const toolCalls: ToolCall[] = []; @@ -187,14 +254,98 @@ async function callLocalWithRetry( throw new Error(lastError?.message || String(lastError) || 'Unknown LLM error after tool execution'); } -interface CallLLMOptions { onStream?: (data: StreamToken) => void; forceRemote?: boolean; disableThinking?: boolean; } +function isLiteRTActive(): boolean { + const { downloadedModels, activeModelId } = useAppStore.getState(); + const m = downloadedModels.find((model: any) => model.id === activeModelId); + return m?.engine === 'litert' && liteRTService.isModelLoaded(); +} + +/** On first iteration: last user message. On tool-result iterations: formatted tool results. */ +function buildLiteRTSendText(messages: Message[]): string { + const toolResults: Message[] = []; + for (let i = messages.length - 1; i >= 0; i--) { + if (messages[i].role === 'tool') toolResults.unshift(messages[i]); + else break; + } + if (toolResults.length > 0) { + const parts = toolResults.map(m => `${m.toolName || 'tool'}: ${m.content}`); + return `Tool results:\n${parts.join('\n\n')}\n\nPlease continue based on these results.`; + } + for (let i = messages.length - 1; i >= 0; i--) { + if (messages[i].role === 'user') { + const c = messages[i].content; + return typeof c === 'string' ? 
c : ''; + } + } + return ''; +} + +function buildLiteRTToolSystemPrompt(basePrompt: string, tools: any[]): string { + if (tools.length === 0) return basePrompt; + const toolList = tools.map((t: any) => { + const fn = t.function || t; + const props = fn.parameters?.properties || {}; + const paramStr = Object.entries(props) + .map(([k, v]: [string, any]) => `"${k}": ${JSON.stringify((v as any).description || (v as any).type || k)}`) + .join(', '); + return ` ${fn.name}({${paramStr}}): ${fn.description}`; + }).join('\n'); + return `${basePrompt} + +TOOL USE RULES — follow strictly: +- To call a tool output ONLY this on its own line, nothing else before or after: +{"name": "tool_name", "arguments": {"param": "value"}} +- Call tools IMMEDIATELY. Never say "I'll search" or "let me look that up" — just call the tool. +- NEVER ask for clarification before calling a tool. Make your best guess at the arguments and call it. +- NEVER say "I cannot browse", "I don't have access", or "I need more information" — you have tools, use them. +- If a query is ambiguous, pick the most likely interpretation and call the tool right away. +- Only respond in plain text when no tool is needed (e.g. simple math in your head, greetings). +Tools:\n${toolList}`; +} + +async function callLiteRTForLoop( + conversationId: string, + messages: Message[], + tools: any[], + onStream?: (data: StreamToken) => void, +): Promise<{ fullResponse: string; toolCalls: ToolCall[] }> { + const systemMsg = messages.find(m => m.role === 'system'); + const basePrompt = typeof systemMsg?.content === 'string' ? 
systemMsg.content : ''; + const systemPrompt = buildLiteRTToolSystemPrompt(basePrompt, tools); + const text = buildLiteRTSendText(messages); + + logger.log(`[ToolLoop][LiteRT] callLiteRTForLoop — convId=${conversationId}, text=${text.length}ch, sysPrompt=${systemPrompt.length}ch`); + logger.log(`[ToolLoop][LiteRT] sysPrompt content: "${systemPrompt.substring(0, 500)}"`); + logger.log(`[ToolLoop][LiteRT] sending text: "${text.substring(0, 300)}"`); + + if (!text) { + logger.warn('[ToolLoop][LiteRT] no message text — aborting'); + return { fullResponse: '', toolCalls: [] }; + } + + // prepareConversation is idempotent — only resets native conversation when context changes + await liteRTService.prepareConversation(conversationId, systemPrompt); + + const fullResponse = await liteRTService.generateRaw(text, token => { + onStream?.({ content: token }); + }); + + logger.log(`[ToolLoop][LiteRT] raw response (${fullResponse.length}ch): "${fullResponse.substring(0, 400)}"`); + logger.log(`[ToolLoop][LiteRT] contains : ${fullResponse.includes('')} | contains <|tool_call>: ${fullResponse.includes('<|tool_call>')} | contains void; forceRemote?: boolean; disableThinking?: boolean; conversationId?: string; } /** Call LLM with retry+backoff for transient native context errors. 
*/ async function callLLMWithRetry( messages: Message[], tools: any[], - { onStream, forceRemote, disableThinking }: CallLLMOptions = {}, + { onStream, forceRemote, disableThinking, conversationId }: CallLLMOptions = {}, ): Promise<{ fullResponse: string; toolCalls: ToolCall[] }> { + if (isLiteRTActive() && conversationId) { + return callLiteRTForLoop(conversationId, messages, tools, onStream); + } const activeServerId = useRemoteServerStore.getState().activeServerId; const useRemote = forceRemote || (!!activeServerId && providerRegistry.hasProvider(activeServerId) && !llmService.isModelLoaded()); logger.log(`[ToolLoop] callLLM — remote=${useRemote}, tools=${tools.length}`); @@ -207,14 +358,22 @@ async function callLLMWithRetry( return callLocalWithRetry(messages, tools, onStream); } -/** If no structured tool calls, try parsing tags from text. */ +/** If no structured tool calls, try parsing tags or Gemma's native format from text. */ function resolveToolCalls(fullResponse: string, toolCalls: ToolCall[]) { - if (toolCalls.length > 0 || !fullResponse.includes('')) - return { effectiveToolCalls: toolCalls, displayResponse: fullResponse }; - const parsed = parseToolCallsFromText(fullResponse); - if (parsed.toolCalls.length > 0) { - logger.log(`[ToolLoop] Parsed ${parsed.toolCalls.length} tool call(s) from text output`); - return { effectiveToolCalls: parsed.toolCalls, displayResponse: parsed.cleanText }; + if (toolCalls.length > 0) return { effectiveToolCalls: toolCalls, displayResponse: fullResponse }; + if (fullResponse.includes('')) { + const parsed = parseToolCallsFromText(fullResponse); + if (parsed.toolCalls.length > 0) { + logger.log(`[ToolLoop] Parsed ${parsed.toolCalls.length} tool call(s) from tags`); + return { effectiveToolCalls: parsed.toolCalls, displayResponse: parsed.cleanText }; + } + } + if (fullResponse.includes('<|tool_call>') || fullResponse.includes(' 0) { + logger.log(`[ToolLoop] Parsed ${parsed.toolCalls.length} tool call(s) from Gemma native 
format`); + return { effectiveToolCalls: parsed.toolCalls, displayResponse: parsed.cleanText }; + } } return { effectiveToolCalls: toolCalls, displayResponse: fullResponse }; } @@ -264,7 +423,7 @@ async function forceFinalTextResponse(ctx: ToolLoopContext, state: ToolLoopState state.firstTokenFired = false; const forcedOnStream = buildStreamHandler(ctx, state); // Disable thinking so the model spends all tokens on actual content - const { fullResponse: forcedResponse } = await callLLMWithRetry(loopMessages, [], { onStream: forcedOnStream, forceRemote: ctx.forceRemote, disableThinking: true }); + const { fullResponse: forcedResponse } = await callLLMWithRetry(loopMessages, [], { onStream: forcedOnStream, forceRemote: ctx.forceRemote, disableThinking: true, conversationId: ctx.conversationId }); logger.log(`[ToolLoop][DEBUG] Forced response — length=${forcedResponse.length}, streamedContent=${state.streamedContent.length}, reasoning=${state.reasoningContent.length}`); emitFinalResponse(ctx, state, forcedResponse); } @@ -299,7 +458,7 @@ export async function runToolLoop(ctx: ToolLoopContext): Promise { logger.log(`[ToolLoop][DEBUG] === Iteration ${iteration} === messages=${loopMessages.length}, tools=${toolSchemas.length}, totalCalls=${totalToolCalls}`); const onStream = buildStreamHandler(ctx, state); - const { fullResponse, toolCalls } = await callLLMWithRetry(loopMessages, toolSchemas, { onStream, forceRemote: ctx.forceRemote }); + const { fullResponse, toolCalls } = await callLLMWithRetry(loopMessages, toolSchemas, { onStream, forceRemote: ctx.forceRemote, conversationId: ctx.conversationId }); logger.log(`[ToolLoop][DEBUG] LLM returned — response=${fullResponse.length}, toolCalls=${toolCalls.length}, streamed=${state.streamedContent.length}, reasoning=${state.reasoningContent.length}`); if (fullResponse.length === 0 && state.streamedContent.length === 0) { @@ -320,7 +479,7 @@ export async function runToolLoop(ctx: ToolLoopContext): Promise { state.firstTokenFired 
= false; const fallbackOnStream = buildStreamHandler(ctx, state); const { fullResponse: fallbackResp } = await callLLMWithRetry( - loopMessages, [], { onStream: fallbackOnStream, forceRemote: ctx.forceRemote, disableThinking: true }, + loopMessages, [], { onStream: fallbackOnStream, forceRemote: ctx.forceRemote, disableThinking: true, conversationId: ctx.conversationId }, ); emitFinalResponse(ctx, state, fallbackResp); return; diff --git a/src/services/litert.ts b/src/services/litert.ts index 523c979ab..e75a03830 100644 --- a/src/services/litert.ts +++ b/src/services/litert.ts @@ -62,6 +62,7 @@ class LiteRTService { // Multi-turn tracking — reset conversation only when context changes private activeConversationId: string | null = null; private activeSystemPrompt: string | null = null; + private _lastBenchmarkStats: LiteRTBenchmarkStats | undefined = undefined; constructor() { if (Platform.OS === 'android' && LiteRTModule) { @@ -247,6 +248,33 @@ class LiteRTService { } } + // --------------------------------------------------------------------------- + // generateRaw — used by the tool loop only. + // Wraps sendMessage into a Promise. No chat store interaction. 
+ // --------------------------------------------------------------------------- + + async generateRaw( + text: string, + onToken?: (token: string) => void, + ): Promise { + logger.log(TAG, `generateRaw — text=${text.length}ch, first100="${text.substring(0, 100)}"`); + return new Promise((resolve, reject) => { + this.sendMessage(text, { + onToken: t => onToken?.(t), + onReasoning: () => {}, + onComplete: (fullContent, _reasoning, stats) => { + logger.log(TAG, `generateRaw — complete, response=${fullContent.length}ch, first200="${fullContent.substring(0, 200)}"`); + this._lastBenchmarkStats = stats; + resolve(fullContent); + }, + onError: (err) => { + logger.log(TAG, `generateRaw — error: ${err.message}`); + reject(err); + }, + }).catch(reject); + }); + } + // --------------------------------------------------------------------------- // stopGeneration // --------------------------------------------------------------------------- @@ -256,8 +284,6 @@ class LiteRTService { logger.log(TAG, 'stopGeneration'); this.clearSubscriptions(); this.currentCallbacks = null; - // After a stop the native conversation state is indeterminate — force reset on next turn - this.activeConversationId = null; try { await LiteRTModule.stopGeneration(); } catch (e) { @@ -302,10 +328,23 @@ class LiteRTService { return this.activeBackend; } + getLastBenchmarkStats(): LiteRTBenchmarkStats | undefined { + return this._lastBenchmarkStats; + } + isAvailable(): boolean { return Platform.OS === 'android' && !!LiteRTModule; } + /** + * Force the next prepareConversation call to reset native history. + * Call before regeneration or edit — the JS message array is being rewound, + * so the native conversation must start fresh from that point. 
+ */ + invalidateConversation(): void { + this.activeConversationId = null; + } + async getMemoryInfo(): Promise { if (!this.isAvailable()) return null; try { diff --git a/src/stores/debugLogsStore.ts b/src/stores/debugLogsStore.ts index d0b30b621..a17ce6a57 100644 --- a/src/stores/debugLogsStore.ts +++ b/src/stores/debugLogsStore.ts @@ -2,7 +2,7 @@ import { create } from 'zustand'; import AsyncStorage from '@react-native-async-storage/async-storage'; const STORAGE_KEY = '@debug_logs'; -const MAX_LOGS = 200; +const MAX_LOGS = 2000; export interface DebugLogEntry { timestamp: number; From 7fa4f29106c3e85d45649daae65f8363d6ecf065 Mon Sep 17 00:00:00 2001 From: Dishit Date: Fri, 22 May 2026 09:51:53 +0530 Subject: [PATCH 09/93] feat(litert): pass maxNumTokens to engine, scale init timeout by context size, and wire tool call event bridge Co-Authored-By: Dishit Karia --- .../ai/offgridmobile/litert/LiteRTModule.kt | 169 +++++++++++++++--- android/gradle.properties | 2 +- src/services/activeModelService/loaders.ts | 27 ++- src/services/generationServiceHelpers.ts | 3 +- src/services/litert.ts | 138 +++++++++++--- src/stores/appStore.ts | 3 + src/stores/index.ts | 2 +- 7 files changed, 286 insertions(+), 58 deletions(-) diff --git a/android/app/src/main/java/ai/offgridmobile/litert/LiteRTModule.kt b/android/app/src/main/java/ai/offgridmobile/litert/LiteRTModule.kt index d6c656e2c..11048bdc3 100644 --- a/android/app/src/main/java/ai/offgridmobile/litert/LiteRTModule.kt +++ b/android/app/src/main/java/ai/offgridmobile/litert/LiteRTModule.kt @@ -16,10 +16,17 @@ import com.google.ai.edge.litertlm.Content import com.google.ai.edge.litertlm.Contents import com.google.ai.edge.litertlm.ExperimentalApi import com.google.ai.edge.litertlm.SamplerConfig +import com.google.ai.edge.litertlm.OpenApiTool +import com.google.ai.edge.litertlm.ToolProvider +import com.google.ai.edge.litertlm.tool +import com.google.ai.edge.litertlm.Message as LiteRTMessage +import com.google.gson.JsonParser 
import kotlinx.coroutines.* import java.io.File import java.io.InputStream import java.io.ByteArrayOutputStream +import java.util.UUID +import java.util.concurrent.ConcurrentHashMap import android.net.Uri import android.graphics.Bitmap import android.graphics.BitmapFactory @@ -31,15 +38,30 @@ class LiteRTModule(private val reactContext: ReactApplicationContext) : private const val TAG = "LiteRTModule" // Streaming events sent to JS - const val EVENT_TOKEN = "litert_token" - const val EVENT_THINKING = "litert_thinking" - const val EVENT_COMPLETE = "litert_complete" - const val EVENT_ERROR = "litert_error" - - // Timeouts per backend tier - private const val NPU_TIMEOUT_MS = 45_000L - private const val GPU_TIMEOUT_MS = 20_000L - private const val CPU_TIMEOUT_MS = 15_000L + const val EVENT_TOKEN = "litert_token" + const val EVENT_THINKING = "litert_thinking" + const val EVENT_COMPLETE = "litert_complete" + const val EVENT_ERROR = "litert_error" + const val EVENT_TOOL_CALL = "litert_tool_call" + + // Base timeouts per backend tier (for default 4096-token context). + // Actual timeout scales up proportionally for larger context windows + // because KV-cache allocation takes longer at higher token counts. + private const val NPU_BASE_TIMEOUT_MS = 45_000L + private const val GPU_BASE_TIMEOUT_MS = 20_000L + private const val CPU_BASE_TIMEOUT_MS = 15_000L + private const val DEFAULT_CONTEXT_TOKENS = 4096 + + fun initTimeoutMs(backend: Backend, maxNumTokens: Int): Long { + val base = when (backend) { + is Backend.NPU -> NPU_BASE_TIMEOUT_MS + is Backend.GPU -> GPU_BASE_TIMEOUT_MS + else -> CPU_BASE_TIMEOUT_MS + } + // Scale linearly above the default context size, capped at 3 minutes. 
+ val scalar = maxOf(1.0, maxNumTokens.toDouble() / DEFAULT_CONTEXT_TOKENS) + return minOf((base * scalar).toLong(), 180_000L) + } } private val scope = CoroutineScope(SupervisorJob() + Dispatchers.Default) @@ -50,6 +72,10 @@ class LiteRTModule(private val reactContext: ReactApplicationContext) : private var supportsVision: Boolean = false private var currentJob: Job? = null + // Pending tool calls waiting for JS to respond via respondToToolCall() + private val pendingToolCalls = ConcurrentHashMap>() + private var configuredMaxTokens: Int = 4096 + override fun getName(): String = "LiteRTModule" // ------------------------------------------------------------------------- @@ -57,12 +83,13 @@ class LiteRTModule(private val reactContext: ReactApplicationContext) : // ------------------------------------------------------------------------- @ReactMethod - fun loadModel(modelPath: String, backendStr: String, visionEnabled: Boolean, promise: Promise) { + fun loadModel(modelPath: String, backendStr: String, visionEnabled: Boolean, maxNumTokens: Int, promise: Promise) { val safe = SafePromise(promise, TAG) - Log.i(TAG, "loadModel — path=$modelPath backend=$backendStr vision=$visionEnabled") + Log.i(TAG, "loadModel — path=$modelPath backend=$backendStr vision=$visionEnabled maxNumTokens=$maxNumTokens") scope.launch { try { + configuredMaxTokens = maxNumTokens // Unload any existing engine first cleanupEngine() @@ -116,15 +143,12 @@ class LiteRTModule(private val reactContext: ReactApplicationContext) : val cfg = EngineConfig( modelPath = modelPath, backend = backend, + maxNumTokens = configuredMaxTokens, cacheDir = null, visionBackend = if (visionEnabled) Backend.GPU() else null, ) val eng = Engine(cfg) - val timeoutMs = when (backend) { - is Backend.NPU -> NPU_TIMEOUT_MS - is Backend.GPU -> GPU_TIMEOUT_MS - else -> CPU_TIMEOUT_MS - } + val timeoutMs = initTimeoutMs(backend, configuredMaxTokens) withTimeout(timeoutMs) { eng.initialize() } @@ -153,9 +177,9 @@ class 
LiteRTModule(private val reactContext: ReactApplicationContext) : // ------------------------------------------------------------------------- @ReactMethod - fun resetConversation(systemPrompt: String, temperature: Double, topK: Int, topP: Double, promise: Promise) { + fun resetConversation(systemPrompt: String, temperature: Double, topK: Int, topP: Double, toolsJson: String, historyJson: String, promise: Promise) { val safe = SafePromise(promise, TAG) - Log.i(TAG, "resetConversation — systemPrompt length=${systemPrompt.length} temperature=$temperature topK=$topK topP=$topP") + Log.i(TAG, "resetConversation — systemPrompt length=${systemPrompt.length} temperature=$temperature topK=$topK topP=$topP tools=${toolsJson.length}ch history=${historyJson.length}ch") scope.launch { try { @@ -181,14 +205,19 @@ class LiteRTModule(private val reactContext: ReactApplicationContext) : ) } + val toolProviders = buildToolProviders(toolsJson) + val initialMessages = parseHistoryMessages(historyJson) val convConfig = ConversationConfig( systemInstruction = if (systemPrompt.isNotEmpty()) Contents.of(systemPrompt) else null, + initialMessages = initialMessages, + tools = toolProviders, samplerConfig = samplerConfig, + automaticToolCalling = toolProviders.isNotEmpty(), ) conversation = eng.createConversation(convConfig) - Log.i(TAG, "resetConversation — new conversation created") + Log.i(TAG, "resetConversation — new conversation created (tools=${toolProviders.size}, history=${initialMessages.size})") safe.resolve(null) } catch (e: Exception) { Log.e(TAG, "resetConversation — error: ${e.message}", e) @@ -261,8 +290,8 @@ class LiteRTModule(private val reactContext: ReactApplicationContext) : @OptIn(ExperimentalApi::class) val benchmarkJson = try { val b = conv.getBenchmarkInfo() - Log.i(TAG, "getBenchmarkInfo — ttft=${b.timeToFirstTokenInSecond} decode=${b.lastDecodeTokensPerSecond} prefill=${b.lastPrefillTokensPerSecond} prefillCount=${b.lastPrefillTokenCount} 
init=${b.initTimeInSecond}") - """{"ttft":${b.timeToFirstTokenInSecond},"decodeTokensPerSecond":${b.lastDecodeTokensPerSecond},"prefillTokensPerSecond":${b.lastPrefillTokensPerSecond},"prefillTokenCount":${b.lastPrefillTokenCount},"initTimeSeconds":${b.initTimeInSecond}}""" + Log.i(TAG, "getBenchmarkInfo — ttft=${b.timeToFirstTokenInSecond} decode=${b.lastDecodeTokensPerSecond} prefill=${b.lastPrefillTokensPerSecond} prefillCount=${b.lastPrefillTokenCount} decodeCount=${b.lastDecodeTokenCount} init=${b.initTimeInSecond}") + """{"ttft":${b.timeToFirstTokenInSecond},"decodeTokensPerSecond":${b.lastDecodeTokensPerSecond},"prefillTokensPerSecond":${b.lastPrefillTokensPerSecond},"prefillTokenCount":${b.lastPrefillTokenCount},"decodeTokenCount":${b.lastDecodeTokenCount},"maxNumTokens":$configuredMaxTokens,"initTimeSeconds":${b.initTimeInSecond}}""" } catch (e: Exception) { Log.w(TAG, "getBenchmarkInfo failed: ${e.message}") "" @@ -288,6 +317,16 @@ class LiteRTModule(private val reactContext: ReactApplicationContext) : } } + // ------------------------------------------------------------------------- + // respondToToolCall — called from JS to unblock a pending tool execute() + // ------------------------------------------------------------------------- + + @ReactMethod + fun respondToToolCall(callId: String, result: String) { + Log.d(TAG, "respondToToolCall — callId=$callId resultLen=${result.length}") + pendingToolCalls.remove(callId)?.complete(result) + } + // ------------------------------------------------------------------------- // stopGeneration // ------------------------------------------------------------------------- @@ -349,6 +388,13 @@ class LiteRTModule(private val reactContext: ReactApplicationContext) : // ------------------------------------------------------------------------- private fun closeConversation() { + // Cancel any tool calls blocked in execute() so their threads can unblock + pendingToolCalls.forEach { (callId, deferred) -> + Log.d(TAG, 
"closeConversation — cancelling pending tool call $callId") + deferred.cancel(CancellationException("Conversation closed")) + } + pendingToolCalls.clear() + try { conversation?.close() Log.d(TAG, "closeConversation — closed") @@ -461,6 +507,87 @@ class LiteRTModule(private val reactContext: ReactApplicationContext) : return Bitmap.createScaledBitmap(src, newW, newH, true) } + /** + * Convert a JSON array of prior turns into LiteRT Message objects for ConversationConfig.initialMessages. + * Only user/assistant text turns are replayed — tool call bridge messages are skipped because + * the native SDK doesn't need them when automaticToolCalling handles the cycle. + * Format: [{"role":"user"|"assistant","content":"..."}] + */ + private fun parseHistoryMessages(historyJson: String): List { + if (historyJson.isBlank()) return emptyList() + return try { + val arr = JsonParser.parseString(historyJson).asJsonArray + arr.mapNotNull { element -> + val obj = element.asJsonObject + val content = obj.get("content")?.asString?.trim() ?: return@mapNotNull null + if (content.isEmpty()) return@mapNotNull null + when (obj.get("role")?.asString) { + "user" -> LiteRTMessage.user(content) + "assistant" -> LiteRTMessage.model(Contents.of(content)) + else -> return@mapNotNull null + } + }.also { Log.i(TAG, "parseHistoryMessages — replaying ${it.size} turns") } + } catch (e: Exception) { + Log.w(TAG, "parseHistoryMessages — failed: ${e.message}") + emptyList() + } + } + + /** + * Parse toolsJson (OpenAPI-format array) into a list of ToolProviders. + * Each ToolProvider wraps one OpenApiTool whose execute() bridges the synchronous SDK + * callback to async JS via CompletableDeferred: + * 1. Emit litert_tool_call event to JS with a unique callId and the raw args JSON string + * 2. Block on a CompletableDeferred until JS calls respondToToolCall(callId, result) + * 3. 
Return the result string to the SDK + */ + private fun buildToolProviders(toolsJson: String): List { + if (toolsJson.isBlank()) return emptyList() + return try { + val toolsArray = JsonParser.parseString(toolsJson).asJsonArray + if (toolsArray.size() == 0) return emptyList() + + val providers = toolsArray.map { element -> + val toolObj = element.asJsonObject + val toolName = toolObj.get("name").asString + val toolDescriptionJson = toolObj.toString() + + val openApiTool = object : OpenApiTool { + override fun getToolDescriptionJsonString(): String = toolDescriptionJson + + override fun execute(argsJson: String): String { + val callId = UUID.randomUUID().toString() + val deferred = CompletableDeferred() + pendingToolCalls[callId] = deferred + + val eventJson = """{"id":"$callId","name":"$toolName","arguments":$argsJson}""" + Log.d(TAG, "buildToolProviders — emitting tool call callId=$callId name=$toolName") + sendEvent(EVENT_TOOL_CALL, eventJson) + + return try { + runBlocking { withTimeout(30_000L) { deferred.await() } } + } catch (e: TimeoutCancellationException) { + Log.w(TAG, "buildToolProviders — tool call $callId timed out") + pendingToolCalls.remove(callId) + "Error: Tool call timed out" + } catch (e: CancellationException) { + Log.w(TAG, "buildToolProviders — tool call $callId cancelled") + pendingToolCalls.remove(callId) + "Error: Tool call cancelled" + } + } + } + tool(openApiTool) + } + + Log.i(TAG, "buildToolProviders — registered ${providers.size} tools") + providers + } catch (e: Exception) { + Log.w(TAG, "buildToolProviders — failed to parse toolsJson: ${e.message}") + emptyList() + } + } + private fun sendEvent(eventName: String, data: String) { try { reactContext diff --git a/android/gradle.properties b/android/gradle.properties index 9afe61598..6da9b0a50 100644 --- a/android/gradle.properties +++ b/android/gradle.properties @@ -10,7 +10,7 @@ # Specifies the JVM arguments used for the daemon process. 
# The setting is particularly useful for tweaking memory settings. # Default value: -Xmx512m -XX:MaxMetaspaceSize=256m -org.gradle.jvmargs=-Xmx2048m -XX:MaxMetaspaceSize=512m +org.gradle.jvmargs=-Xmx4096m -XX:MaxMetaspaceSize=1024m # When configured, Gradle will run in incubating parallel mode. # This option should only be used with decoupled projects. More details, visit diff --git a/src/services/activeModelService/loaders.ts b/src/services/activeModelService/loaders.ts index 83db5941f..97e83ee15 100644 --- a/src/services/activeModelService/loaders.ts +++ b/src/services/activeModelService/loaders.ts @@ -126,9 +126,12 @@ async function doLoadLiteRTModel(ctx: TextLoadContext): Promise { const preferredBackend = inferenceBackendToLiteRT(ctx.store.settings.inferenceBackend); addDebugLog('log', `[LiteRT] Preferred backend: ${preferredBackend}`); - const timeoutMs = preferredBackend === 'npu' ? 45_000 - : preferredBackend === 'gpu' ? 20_000 - : 15_000; + const maxTokens = ctx.store.settings.contextLength ?? 4096; + const contextScalar = Math.max(1, maxTokens / 4096); + const baseTimeoutMs = preferredBackend === 'npu' ? 45_000 + : preferredBackend === 'gpu' ? 20_000 + : 15_000; + const timeoutMs = Math.min(Math.ceil(baseTimeoutMs * contextScalar), 180_000); let timeoutId: ReturnType | null = null; const timeoutPromise = new Promise((_, reject) => { @@ -139,9 +142,9 @@ async function doLoadLiteRTModel(ctx: TextLoadContext): Promise { }); try { - addDebugLog('log', `[LiteRT] Calling liteRTService.loadModel (timeout ${timeoutMs / 1000}s, vision=${ctx.model.liteRTVision ?? false}).`); + addDebugLog('log', `[LiteRT] Calling liteRTService.loadModel (timeout ${timeoutMs / 1000}s, vision=${ctx.model.liteRTVision ?? false}, maxNumTokens=${maxTokens}).`); await Promise.race([ - liteRTService.loadModel(ctx.model.filePath, preferredBackend, ctx.model.liteRTVision ?? false), + liteRTService.loadModel(ctx.model.filePath, preferredBackend, ctx.model.liteRTVision ?? 
false, maxTokens), timeoutPromise, ]); } finally { @@ -162,6 +165,20 @@ async function doLoadLiteRTModel(ctx: TextLoadContext): Promise { addDebugLog('log', `[LiteRT] Warmup complete in ${((Date.now() - warmupStart) / 1000).toFixed(1)}s`); } + // Snapshot the settings that require a full engine reload so the pending-settings + // banner appears if the user changes them while the model is loaded. + ctx.store.setLoadedSettings({ + inferenceBackend: ctx.store.settings.inferenceBackend, + contextLength: maxTokens, + // Fields not used by LiteRT — set to current values so llama checks don't misfire + enableGpu: ctx.store.settings.enableGpu, + gpuLayers: ctx.store.settings.gpuLayers, + nThreads: ctx.store.settings.nThreads, + nBatch: ctx.store.settings.nBatch, + flashAttn: ctx.store.settings.flashAttn, + cacheType: ctx.store.settings.cacheType, + }); + ctx.onLoaded(ctx.modelId); ctx.store.setActiveModelId(ctx.modelId); } catch (error) { diff --git a/src/services/generationServiceHelpers.ts b/src/services/generationServiceHelpers.ts index 58cb30311..d62447444 100644 --- a/src/services/generationServiceHelpers.ts +++ b/src/services/generationServiceHelpers.ts @@ -220,8 +220,7 @@ export async function generateResponseImpl( try { const { settings } = useAppStore.getState(); await liteRTService.prepareConversation(conversationId, systemPrompt, { - temperature: settings.temperature, - topP: settings.topP, + samplerConfig: { temperature: settings.temperature, topP: settings.topP }, }); dbg('log', `[LiteRT] sendMessage start — imageUri=${imageUri ?? 
'none'}`); diff --git a/src/services/litert.ts b/src/services/litert.ts index e75a03830..ab1a84859 100644 --- a/src/services/litert.ts +++ b/src/services/litert.ts @@ -18,10 +18,11 @@ const TAG = '[LiteRTService]'; const { LiteRTModule } = NativeModules; // Events emitted by the native module -const EVENT_TOKEN = 'litert_token'; -const EVENT_THINKING = 'litert_thinking'; -const EVENT_COMPLETE = 'litert_complete'; -const EVENT_ERROR = 'litert_error'; +const EVENT_TOKEN = 'litert_token'; +const EVENT_THINKING = 'litert_thinking'; +const EVENT_COMPLETE = 'litert_complete'; +const EVENT_ERROR = 'litert_error'; +const EVENT_TOOL_CALL = 'litert_tool_call'; export type LiteRTBackend = 'cpu' | 'gpu' | 'npu'; @@ -30,6 +31,8 @@ export interface LiteRTBenchmarkStats { decodeTokensPerSecond: number; prefillTokensPerSecond: number; prefillTokenCount: number; + decodeTokenCount: number; + maxNumTokens?: number; initTimeSeconds: number; } @@ -57,13 +60,18 @@ class LiteRTService { // Accumulated content for current generation private currentContent = ''; private currentReasoning = ''; - private currentCallbacks: LiteRTGenerationCallbacks | null = null; + private currentToolCallHandler: ((name: string, args: Record) => Promise) | null = null; // Multi-turn tracking — reset conversation only when context changes private activeConversationId: string | null = null; private activeSystemPrompt: string | null = null; + private activeToolsJson: string | null = null; private _lastBenchmarkStats: LiteRTBenchmarkStats | undefined = undefined; + // Context usage tracking — cumulative tokens across turns, reset on conversation reset + private cumulativeTokens = 0; + private configuredMaxTokens = 4096; + constructor() { if (Platform.OS === 'android' && LiteRTModule) { this.emitter = new NativeEventEmitter(LiteRTModule); @@ -77,15 +85,16 @@ class LiteRTService { // loadModel // --------------------------------------------------------------------------- - async loadModel(modelPath: string, 
preferredBackend: LiteRTBackend, supportsVision = false): Promise { + async loadModel(modelPath: string, preferredBackend: LiteRTBackend, supportsVision = false, maxNumTokens = 4096): Promise { if (!this.isAvailable()) { throw new Error('LiteRT is not available on this platform'); } - logger.log(TAG, `loadModel — path=${modelPath} backend=${preferredBackend} supportsVision=${supportsVision}`); + this.configuredMaxTokens = maxNumTokens; + logger.log(TAG, `loadModel — path=${modelPath} backend=${preferredBackend} supportsVision=${supportsVision} maxNumTokens=${maxNumTokens}`); try { - const actualBackend: string = await LiteRTModule.loadModel(modelPath, preferredBackend, supportsVision); + const actualBackend: string = await LiteRTModule.loadModel(modelPath, preferredBackend, supportsVision, maxNumTokens); this.activeBackend = actualBackend as LiteRTBackend; this.loaded = true; logger.log(TAG, `loadModel — loaded on ${this.activeBackend}`); @@ -103,17 +112,26 @@ class LiteRTService { async resetConversation( systemPrompt: string, - samplerConfig?: { temperature?: number; topK?: number; topP?: number }, + opts?: { + samplerConfig?: { temperature?: number; topK?: number; topP?: number }; + tools?: any[]; + history?: Array<{ role: 'user' | 'assistant'; content: string }>; + }, ): Promise { if (!this.isAvailable() || !this.loaded) { throw new Error('No LiteRT model loaded'); } + const { samplerConfig, tools, history } = opts ?? {}; const temperature = samplerConfig?.temperature ?? 0.8; const topK = samplerConfig?.topK ?? 40; const topP = samplerConfig?.topP ?? 0.95; - logger.log(TAG, `resetConversation — systemPrompt length=${systemPrompt.length} temperature=${temperature} topK=${topK} topP=${topP}`); - await LiteRTModule.resetConversation(systemPrompt, temperature, topK, topP); + const toolsJson = tools && tools.length > 0 ? JSON.stringify(tools) : ''; + const historyJson = history && history.length > 0 ? 
JSON.stringify(history) : ''; + logger.log(TAG, `resetConversation — systemPrompt length=${systemPrompt.length} temperature=${temperature} topK=${topK} topP=${topP} tools=${tools?.length ?? 0} history=${history?.length ?? 0}`); + await LiteRTModule.resetConversation(systemPrompt, temperature, topK, topP, toolsJson, historyJson); this.activeSystemPrompt = systemPrompt; + this.activeToolsJson = toolsJson; + this.cumulativeTokens = 0; logger.log(TAG, 'resetConversation — done'); } @@ -121,18 +139,40 @@ class LiteRTService { * Ensure conversation is ready for the given context. * Resets only when conversationId or systemPrompt has changed — preserves * native turn history for follow-up messages in the same conversation. + * Auto-trims history when cumulative token usage exceeds 80% of the context limit. */ async prepareConversation( conversationId: string, systemPrompt: string, - samplerConfig?: { temperature?: number; topK?: number; topP?: number }, + opts?: { + samplerConfig?: { temperature?: number; topK?: number; topP?: number }; + tools?: any[]; + history?: Array<{ role: 'user' | 'assistant'; content: string }>; + }, ): Promise { + const toolsJson = opts?.tools && opts.tools.length > 0 ? 
JSON.stringify(opts.tools) : ''; + + // Auto-compact: trim oldest turns when nearing context limit + const maxTokens = this.configuredMaxTokens; + let history = opts?.history; + if (maxTokens > 0 && this.cumulativeTokens > maxTokens * 0.8 && history && history.length > 2) { + const trimmedHistory = history.slice(Math.floor(history.length / 2)); + logger.log(TAG, `prepareConversation — auto-compact: cumulativeTokens=${this.cumulativeTokens} > ${Math.floor(maxTokens * 0.8)} (80% of ${maxTokens}), trimming history ${history.length} → ${trimmedHistory.length} turns`); + this.cumulativeTokens = Math.floor(this.cumulativeTokens * 0.5); + await this.resetConversation(systemPrompt, { samplerConfig: opts?.samplerConfig, tools: opts?.tools, history: trimmedHistory }); + this.activeConversationId = conversationId; + this.activeSystemPrompt = systemPrompt; + this.activeToolsJson = toolsJson; + return; + } + const needsReset = this.activeConversationId !== conversationId || - this.activeSystemPrompt !== systemPrompt; + this.activeSystemPrompt !== systemPrompt || + this.activeToolsJson !== toolsJson; if (needsReset) { - logger.log(TAG, `prepareConversation — reset (convId changed=${this.activeConversationId !== conversationId}, sysPrompt changed=${this.activeSystemPrompt !== systemPrompt})`); - await this.resetConversation(systemPrompt, samplerConfig); + logger.log(TAG, `prepareConversation — reset (convId changed=${this.activeConversationId !== conversationId}, sysPrompt changed=${this.activeSystemPrompt !== systemPrompt}, tools changed=${this.activeToolsJson !== toolsJson}, history=${opts?.history?.length ?? 
0})`); + await this.resetConversation(systemPrompt, { samplerConfig: opts?.samplerConfig, tools: opts?.tools, history: opts?.history }); this.activeConversationId = conversationId; } else { logger.log(TAG, 'prepareConversation — reusing existing conversation (multi-turn)'); @@ -184,19 +224,19 @@ class LiteRTService { // Reset accumulators this.currentContent = ''; this.currentReasoning = ''; - this.currentCallbacks = callbacks; + // currentToolCallHandler is set by generateRaw before sendMessage is called // Wall-clock tracking const sendStart = Date.now(); let firstTokenTime: number | undefined; - let decodeTokenCount = 0; + let jsDecodeTokenCount = 0; // Register event listeners for this generation this.clearSubscriptions(); this.subscriptions = [ this.emitter!.addListener(EVENT_TOKEN, (token: string) => { if (firstTokenTime === undefined) firstTokenTime = Date.now(); - decodeTokenCount++; + jsDecodeTokenCount++; this.currentContent += token; callbacks.onToken(token); }), @@ -207,41 +247,74 @@ class LiteRTService { this.emitter!.addListener(EVENT_COMPLETE, (benchmarkJson: string) => { logger.log(TAG, `sendMessage — complete, content=${this.currentContent.length} chars`); this.clearSubscriptions(); - this.currentCallbacks = null; + + this.currentToolCallHandler = null; const addLog = useDebugLogsStore.getState().addLog; + // Parse native benchmark stats for accurate token counts + let nativePrefillCount = 0; + let nativeDecodeCount = jsDecodeTokenCount; + if (benchmarkJson) { + try { + const native = JSON.parse(benchmarkJson) as Record; + nativePrefillCount = native.prefillTokenCount ?? 0; + nativeDecodeCount = native.decodeTokenCount ?? jsDecodeTokenCount; + } catch { /* use JS fallback counts */ } + } + + // Accumulate into cumulative context usage + this.cumulativeTokens += nativePrefillCount + nativeDecodeCount; + // Build wall-clock stats const completeTime = Date.now(); const ttft = firstTokenTime !== undefined ? 
(firstTokenTime - sendStart) / 1000 : undefined; const decodeElapsed = firstTokenTime !== undefined ? (completeTime - firstTokenTime) / 1000 : undefined; - const decodeTokensPerSecond = decodeElapsed && decodeElapsed > 0 && decodeTokenCount > 1 - ? decodeTokenCount / decodeElapsed + const decodeTokensPerSecond = decodeElapsed && decodeElapsed > 0 && jsDecodeTokenCount > 1 + ? jsDecodeTokenCount / decodeElapsed : undefined; const wallClockStats: LiteRTBenchmarkStats = { ttft: ttft ?? 0, decodeTokensPerSecond: decodeTokensPerSecond ?? 0, prefillTokensPerSecond: 0, - prefillTokenCount: decodeTokenCount, + prefillTokenCount: nativePrefillCount || jsDecodeTokenCount, + decodeTokenCount: nativeDecodeCount, + maxNumTokens: this.configuredMaxTokens, initTimeSeconds: 0, }; - addLog('log', `[LiteRTService] wall-clock stats — ttft=${ttft?.toFixed(3)}s decode=${decodeTokensPerSecond?.toFixed(1)}tok/s tokens=${decodeTokenCount}`); + addLog('log', `[LiteRTService] wall-clock stats — ttft=${ttft?.toFixed(3)}s decode=${decodeTokensPerSecond?.toFixed(1)}tok/s tokens=${jsDecodeTokenCount} cumulative=${this.cumulativeTokens}/${this.configuredMaxTokens}`); callbacks.onComplete(this.currentContent, this.currentReasoning, wallClockStats); }), this.emitter!.addListener(EVENT_ERROR, (message: string) => { logger.log(TAG, `sendMessage — error: ${message}`); this.clearSubscriptions(); - this.currentCallbacks = null; + + this.currentToolCallHandler = null; callbacks.onError(new Error(message)); }), + this.emitter!.addListener(EVENT_TOOL_CALL, async (json: string) => { + logger.log(TAG, `sendMessage — tool call received: ${json.substring(0, 200)}`); + try { + const { id, name, arguments: args } = JSON.parse(json) as { + id: string; + name: string; + arguments: Record; + }; + const handler = this.currentToolCallHandler; + const result = handler ? 
await handler(name, args) : ''; + logger.log(TAG, `sendMessage — responding to tool call id=${id} name=${name} resultLen=${result.length}`); + await LiteRTModule.respondToToolCall(id, result); + } catch (e) { + logger.log(TAG, `sendMessage — tool call handling error: ${String(e)}`); + } + }), ]; try { await LiteRTModule.sendMessage(text, imageUri ?? null); } catch (e) { this.clearSubscriptions(); - this.currentCallbacks = null; const err = e instanceof Error ? e : new Error(String(e)); logger.log(TAG, `sendMessage — native error: ${err.message}`); callbacks.onError(err); @@ -256,8 +329,10 @@ class LiteRTService { async generateRaw( text: string, onToken?: (token: string) => void, + onToolCall?: (name: string, args: Record) => Promise, ): Promise { - logger.log(TAG, `generateRaw — text=${text.length}ch, first100="${text.substring(0, 100)}"`); + logger.log(TAG, `generateRaw — text=${text.length}ch, hasToolHandler=${!!onToolCall}, first100="${text.substring(0, 100)}"`); + this.currentToolCallHandler = onToolCall ?? 
null; return new Promise((resolve, reject) => { this.sendMessage(text, { onToken: t => onToken?.(t), @@ -269,6 +344,7 @@ class LiteRTService { }, onError: (err) => { logger.log(TAG, `generateRaw — error: ${err.message}`); + this.currentToolCallHandler = null; reject(err); }, }).catch(reject); @@ -283,7 +359,6 @@ class LiteRTService { if (!this.isAvailable()) return; logger.log(TAG, 'stopGeneration'); this.clearSubscriptions(); - this.currentCallbacks = null; try { await LiteRTModule.stopGeneration(); } catch (e) { @@ -299,9 +374,12 @@ class LiteRTService { if (!this.isAvailable()) return; logger.log(TAG, 'unloadModel'); this.clearSubscriptions(); - this.currentCallbacks = null; + this.currentToolCallHandler = null; this.activeConversationId = null; this.activeSystemPrompt = null; + this.activeToolsJson = null; + this.cumulativeTokens = 0; + this.configuredMaxTokens = 4096; try { await LiteRTModule.unloadModel(); } catch (e) { @@ -332,6 +410,10 @@ class LiteRTService { return this._lastBenchmarkStats; } + getContextUsage(): { used: number; max: number } { + return { used: this.cumulativeTokens, max: this.configuredMaxTokens }; + } + isAvailable(): boolean { return Platform.OS === 'android' && !!LiteRTModule; } diff --git a/src/stores/appStore.ts b/src/stores/appStore.ts index 396a4963f..dd04aaa54 100644 --- a/src/stores/appStore.ts +++ b/src/stores/appStore.ts @@ -172,6 +172,9 @@ function migratePersistedState(persistedState: any, currentState: AppState): App return merged as AppState; } +export const selectIsLiteRT = (state: AppState): boolean => + state.downloadedModels.find(m => m.id === state.activeModelId)?.engine === 'litert' ?? 
false; + export const useAppStore = create()( persist( (set, get) => ({ diff --git a/src/stores/index.ts b/src/stores/index.ts index fd14cb482..a5b2c1e96 100644 --- a/src/stores/index.ts +++ b/src/stores/index.ts @@ -1,4 +1,4 @@ -export { useAppStore } from './appStore'; +export { useAppStore, selectIsLiteRT } from './appStore'; export { useChatStore } from './chatStore'; export { useProjectStore } from './projectStore'; export { useAuthStore } from './authStore'; From 491746e3bbfe70d48ae49ee6dde43c287d19cb4f Mon Sep 17 00:00:00 2001 From: Dishit Date: Fri, 22 May 2026 09:51:58 +0530 Subject: [PATCH 10/93] feat(litert): filter settings UI per engine and add selectIsLiteRT store selector Co-Authored-By: Dishit Karia --- .../TextGenerationAdvanced.tsx | 4 +- .../TextGenerationSection.tsx | 34 +++++--- .../TextGenerationAdvanced.tsx | 6 +- .../TextGenerationSection.tsx | 31 ++++++- src/screens/ModelsScreen/TextModelsTab.tsx | 75 ++++++++++++++++- src/screens/ModelsScreen/styles.ts | 81 +++++++++++++++++++ 6 files changed, 207 insertions(+), 24 deletions(-) diff --git a/src/components/GenerationSettingsModal/TextGenerationAdvanced.tsx b/src/components/GenerationSettingsModal/TextGenerationAdvanced.tsx index 587673b09..66e8bb0fb 100644 --- a/src/components/GenerationSettingsModal/TextGenerationAdvanced.tsx +++ b/src/components/GenerationSettingsModal/TextGenerationAdvanced.tsx @@ -36,7 +36,7 @@ const HTP_BACKEND: BackendOption = { id: INFERENCE_BACKENDS.HTP, label: 'HTP', desc: 'Offload layers to Hexagon NPU on Snapdragon devices. Best for large models. 
Requires model reload.', }; -export const BackendSelector: React.FC = () => { +export const BackendSelector: React.FC<{ hideGpuLayers?: boolean }> = ({ hideGpuLayers = false }) => { const { colors } = useTheme(); const styles = useThemedStyles(createStyles); const { settings, updateSettings } = useAppStore(); @@ -80,7 +80,7 @@ export const BackendSelector: React.FC = () => { ))} - {showLayers && ( + {showLayers && !hideGpuLayers && ( {layersLabel} diff --git a/src/components/GenerationSettingsModal/TextGenerationSection.tsx b/src/components/GenerationSettingsModal/TextGenerationSection.tsx index b5bdb7186..c0e29cf54 100644 --- a/src/components/GenerationSettingsModal/TextGenerationSection.tsx +++ b/src/components/GenerationSettingsModal/TextGenerationSection.tsx @@ -3,7 +3,7 @@ import { View, Text, TouchableOpacity } from 'react-native'; import Slider from '@react-native-community/slider'; import { AdvancedToggle } from '../AdvancedToggle'; import { useTheme, useThemedStyles } from '../../theme'; -import { useAppStore } from '../../stores'; +import { useAppStore, selectIsLiteRT } from '../../stores'; import { createStyles } from './styles'; import { CpuThreadsSlider, @@ -42,8 +42,10 @@ const contextWarning = (v: number): string | null => v > HIGH_CONTEXT_THRESHOLD ? 'High context uses significant RAM and may crash on some devices' : null; const BASIC_KEYS = ['temperature', 'maxTokens', 'contextLength']; +const LITERT_BASIC_KEYS = ['temperature', 'contextLength']; +const LITERT_ADVANCED_KEYS = ['topP']; -const buildSettingsConfig = (modelMaxContext: number | null): SettingConfig[] => [ +const buildSettingsConfig = (modelMaxContext: number | null, isLiteRT: boolean): SettingConfig[] => [ { key: 'temperature', label: 'Temperature', @@ -82,12 +84,14 @@ const buildSettingsConfig = (modelMaxContext: number | null): SettingConfig[] => }, { key: 'contextLength', - label: 'Context Length', + label: isLiteRT ? 
'Max Tokens' : 'Context Length', min: 512, max: modelMaxContext || FALLBACK_MAX_CONTEXT, step: 1024, format: formatContext, - description: 'KV cache size — larger uses more RAM (requires reload)', + description: isLiteRT + ? 'Total context window — input + history + output combined (requires reload)' + : 'KV cache size — larger uses more RAM (requires reload)', warning: contextWarning, }, ]; @@ -174,11 +178,15 @@ const ShowGenerationDetailsToggle: React.FC = () => { export const TextGenerationSection: React.FC = () => { const styles = useThemedStyles(createStyles); const modelMaxContext = useAppStore((s) => s.modelMaxContext); - const settingsConfig = buildSettingsConfig(modelMaxContext); + const isLiteRT = useAppStore(selectIsLiteRT); + const settingsConfig = buildSettingsConfig(modelMaxContext, isLiteRT); const [showAdvanced, setShowAdvanced] = useState(false); - const basicSettings = settingsConfig.filter(c => BASIC_KEYS.includes(c.key)); - const advancedSettings = settingsConfig.filter(c => !BASIC_KEYS.includes(c.key)); + const basicKeys = isLiteRT ? LITERT_BASIC_KEYS : BASIC_KEYS; + const advancedKeys = isLiteRT ? 
LITERT_ADVANCED_KEYS : settingsConfig.filter(c => !BASIC_KEYS.includes(c.key)).map(c => c.key); + + const basicSettings = settingsConfig.filter(c => basicKeys.includes(c.key)); + const advancedSettings = settingsConfig.filter(c => advancedKeys.includes(c.key)); return ( @@ -194,12 +202,12 @@ export const TextGenerationSection: React.FC = () => { {advancedSettings.map((config) => ( ))} - - - - - - + {!isLiteRT && } + {!isLiteRT && } + + {!isLiteRT && } + {!isLiteRT && } + {!isLiteRT && } )} diff --git a/src/screens/ModelSettingsScreen/TextGenerationAdvanced.tsx b/src/screens/ModelSettingsScreen/TextGenerationAdvanced.tsx index 09bcad695..3e8d0ec97 100644 --- a/src/screens/ModelSettingsScreen/TextGenerationAdvanced.tsx +++ b/src/screens/ModelSettingsScreen/TextGenerationAdvanced.tsx @@ -33,7 +33,7 @@ const ANDROID_BASE_BACKENDS: BackendOption[] = [ const HTP_BACKEND: BackendOption = { id: INFERENCE_BACKENDS.HTP, label: 'HTP' }; -const BackendSelectorSection: React.FC = () => { +const BackendSelectorSection: React.FC<{ hideGpuLayers?: boolean }> = ({ hideGpuLayers = false }) => { const { colors } = useTheme(); const styles = useThemedStyles(createStyles); const { settings, updateSettings } = useAppStore(); @@ -82,7 +82,7 @@ const BackendSelectorSection: React.FC = () => { ))} - {showLayers && ( + {showLayers && !hideGpuLayers && ( @@ -309,7 +309,7 @@ export const TextGenerationAdvanced: React.FC<{ isLiteRT?: boolean }> = ({ isLit )} - {!isLiteRT && } + {!isLiteRT && } {!isLiteRT && } diff --git a/src/screens/ModelSettingsScreen/TextGenerationSection.tsx b/src/screens/ModelSettingsScreen/TextGenerationSection.tsx index f13cf6c56..ee084c066 100644 --- a/src/screens/ModelSettingsScreen/TextGenerationSection.tsx +++ b/src/screens/ModelSettingsScreen/TextGenerationSection.tsx @@ -3,7 +3,7 @@ import { View, Text, Switch } from 'react-native'; import Slider from '@react-native-community/slider'; import { AdvancedToggle, Card } from '../../components'; import { useTheme, 
useThemedStyles } from '../../theme'; -import { useAppStore } from '../../stores'; +import { useAppStore, selectIsLiteRT } from '../../stores'; import { createStyles } from './styles'; import { TextGenerationAdvanced } from './TextGenerationAdvanced'; @@ -15,9 +15,8 @@ export const TextGenerationSection: React.FC = () => { const styles = useThemedStyles(createStyles); const { settings, updateSettings } = useAppStore(); const modelMaxContext = useAppStore((s) => s.modelMaxContext); + const isLiteRT = useAppStore(selectIsLiteRT); const [showAdvanced, setShowAdvanced] = useState(false); - const { downloadedModels, activeModelId } = useAppStore.getState(); - const isLiteRT = downloadedModels.find(m => m.id === activeModelId)?.engine === 'litert'; const trackColor = { false: colors.surfaceLight, true: `${colors.primary}80` }; const maxTokens = settings?.maxTokens || 512; @@ -76,7 +75,31 @@ export const TextGenerationSection: React.FC = () => { )} - {!isLiteRT && ( + {isLiteRT ? ( + + + Max Tokens + {contextLengthLabel} + + Total context window — input + history + output combined (requires reload) + {contextLength > HIGH_CONTEXT_THRESHOLD && ( + + High values use significant RAM and may fail on some devices + + )} + updateSettings({ contextLength: value })} + minimumTrackTintColor={colors.primary} + maximumTrackTintColor={colors.surface} + thumbTintColor={colors.primary} + /> + + ) : ( Context Length diff --git a/src/screens/ModelsScreen/TextModelsTab.tsx b/src/screens/ModelsScreen/TextModelsTab.tsx index 6abafb8e3..2586a30c0 100644 --- a/src/screens/ModelsScreen/TextModelsTab.tsx +++ b/src/screens/ModelsScreen/TextModelsTab.tsx @@ -1,5 +1,5 @@ import React, { useEffect } from 'react'; -import { View, Text, FlatList, TextInput, ActivityIndicator, RefreshControl, TouchableOpacity, InteractionManager } from 'react-native'; +import { View, Text, FlatList, TextInput, ActivityIndicator, RefreshControl, TouchableOpacity, InteractionManager, Linking } from 'react-native'; 
import Icon from 'react-native-vector-icons/Feather'; import { AttachStep, useSpotlightTour } from 'react-native-spotlight-tour'; import { Card, ModelCard } from '../../components'; @@ -188,6 +188,72 @@ const ModelDetailView: React.FC = ({ ); }; +const LITERT_FEATURED = [ + { + id: 'gemma4-e2b-litert', + name: 'Gemma 4 E2B', + author: 'google', + description: 'Google\'s latest, thinking mode + vision, MoE architecture', + chips: ['NPU / GPU', 'Vision', 'Thinking', '~1.5 GB'], + highlight: 'Up to 2x faster than CPU via NPU hardware acceleration', + url: 'https://huggingface.co/google/gemma-4-E2B-it-litert-lm', + }, + { + id: 'gemma4-e4b-litert', + name: 'Gemma 4 E4B', + author: 'google', + description: 'Stronger reasoning + vision, MoE fits more in less RAM', + chips: ['NPU / GPU', 'Vision', 'Thinking', '~3.5 GB'], + highlight: 'Higher quality, same hardware efficiency as E2B', + url: 'https://huggingface.co/google/gemma-4-E4B-it-litert-lm', + }, +] as const; + +const FeaturedLiteRTCard: React.FC<{ model: typeof LITERT_FEATURED[number]; styles: any; colors: any }> = ({ model, styles, colors }) => ( + Linking.openURL(model.url)} + activeOpacity={0.85} + > + + + {model.name} + + {model.author} + + + + LiteRT + + + + {model.description} + + {model.chips.map(chip => ( + + {chip} + + ))} + + + {model.highlight} + + + Get + + + +); + +const FeaturedLiteRTSection: React.FC<{ styles: any; colors: any }> = ({ styles, colors }) => ( + + Hardware-accelerated + {LITERT_FEATURED.map(model => ( + + ))} + +); + const DeviceBanner: React.FC<{ ramGB: number; rec: { maxParameters: number; recommendedQuantization: string }; showTitle: boolean; styles: any }> = ({ ramGB, rec, showTitle, styles }) => ( {Math.round(ramGB)}GB RAM — models up to {rec.maxParameters}B recommended ({rec.recommendedQuantization}) @@ -345,7 +411,12 @@ export const TextModelsTab: React.FC = (props) => { contentContainerStyle={styles.listContent} testID="models-list" refreshControl={} - 
ListHeaderComponent={hasSearched ? null : 0} styles={styles} />} + ListHeaderComponent={hasSearched ? null : ( + <> + + 0} styles={styles} /> + + )} ListEmptyComponent={ {getEmptyText(hasSearched, hasActiveFilters)} diff --git a/src/screens/ModelsScreen/styles.ts b/src/screens/ModelsScreen/styles.ts index d5a846f73..ff9d493d0 100644 --- a/src/screens/ModelsScreen/styles.ts +++ b/src/screens/ModelsScreen/styles.ts @@ -218,6 +218,87 @@ const createTextModelsStyles = (colors: ThemeColors, _shadows: ThemeShadows) => marginBottom: 16, }, recommendedTitle: { ...TYPOGRAPHY.meta, color: colors.textMuted, marginBottom: SPACING.md }, + featuredSection: { marginBottom: SPACING.md }, + featuredSectionLabel: { + ...TYPOGRAPHY.meta, + color: colors.primary, + marginBottom: SPACING.sm, + letterSpacing: 0.5, + }, + featuredCard: { + backgroundColor: colors.surface, + borderRadius: 12, + padding: 12, + marginBottom: 12, + borderWidth: 1, + borderColor: `${colors.primary}50`, + borderLeftWidth: 3, + borderLeftColor: colors.primary, + }, + featuredCardTopRow: { + flexDirection: 'row' as const, + alignItems: 'center' as const, + marginBottom: SPACING.xs, + }, + featuredCardNameGroup: { + flexDirection: 'row' as const, + alignItems: 'center' as const, + flex: 1, + gap: 6, + }, + featuredCardName: { ...TYPOGRAPHY.h3, color: colors.text }, + featuredAuthorTag: { + backgroundColor: colors.surfaceLight, + paddingHorizontal: 6, + paddingVertical: 2, + borderRadius: 6, + }, + featuredAuthorTagText: { ...TYPOGRAPHY.metaSmall, color: colors.textSecondary }, + featuredBadge: { + flexDirection: 'row' as const, + alignItems: 'center' as const, + gap: 3, + backgroundColor: `${colors.primary}18`, + borderRadius: 5, + paddingHorizontal: 5, + paddingVertical: 2, + }, + featuredBadgeText: { ...TYPOGRAPHY.labelSmall, color: colors.primary }, + featuredDescription: { + ...TYPOGRAPHY.bodySmall, + color: colors.textSecondary, + marginBottom: SPACING.sm, + }, + featuredChipsRow: { + flexDirection: 'row' as 
const, + flexWrap: 'wrap' as const, + gap: 4, + marginBottom: SPACING.sm, + }, + featuredChip: { + backgroundColor: `${colors.primary}12`, + borderRadius: 4, + paddingHorizontal: 6, + paddingVertical: 2, + }, + featuredChipText: { ...TYPOGRAPHY.labelSmall, color: colors.primary }, + featuredFooter: { + flexDirection: 'row' as const, + alignItems: 'center' as const, + justifyContent: 'space-between' as const, + marginTop: SPACING.xs, + }, + featuredHighlight: { ...TYPOGRAPHY.meta, color: colors.textMuted, flex: 1 }, + featuredGetButton: { + flexDirection: 'row' as const, + alignItems: 'center' as const, + gap: 4, + backgroundColor: `${colors.primary}18`, + borderRadius: 8, + paddingHorizontal: SPACING.md, + paddingVertical: SPACING.sm, + }, + featuredGetText: { ...TYPOGRAPHY.bodySmall, color: colors.primary }, }); export const createStyles = (colors: ThemeColors, shadows: ThemeShadows) => ({ From 5b0169c3b3caa76970df5d771f999fba2d4cac08 Mon Sep 17 00:00:00 2001 From: Dishit Date: Fri, 22 May 2026 09:52:05 +0530 Subject: [PATCH 11/93] fix(litert): fix read_url colon-arg parsing, track context fill, and extend reload detection to context length Co-Authored-By: Dishit Karia --- .../services/activeModelService.utils.test.ts | 2 +- .../services/generationServiceHelpers.test.ts | 2 +- .../unit/services/llmSafetyChecks.test.ts | 1 - src/components/ChatInput/index.tsx | 9 + src/components/ContextRing.tsx | 45 ++++ src/screens/ChatScreen/ChatMessageArea.tsx | 14 +- src/screens/ChatScreen/useChatScreen.ts | 5 +- src/services/generationToolLoop.ts | 213 ++++++++++-------- 8 files changed, 194 insertions(+), 97 deletions(-) create mode 100644 src/components/ContextRing.tsx diff --git a/__tests__/unit/services/activeModelService.utils.test.ts b/__tests__/unit/services/activeModelService.utils.test.ts index d6d7ef5da..d46f9a657 100644 --- a/__tests__/unit/services/activeModelService.utils.test.ts +++ b/__tests__/unit/services/activeModelService.utils.test.ts @@ -38,7 +38,7 @@ 
jest.mock('../../../src/services/hardware', () => ({ import { llmService } from '../../../src/services/llm'; import { liteRTService } from '../../../src/services/litert'; -import { useAppStore } from '../../../src/stores'; +import { useAppStore as _useAppStore } from '../../../src/stores'; const mockedLlm = llmService as jest.Mocked; const mockedLiteRT = liteRTService as jest.Mocked; diff --git a/__tests__/unit/services/generationServiceHelpers.test.ts b/__tests__/unit/services/generationServiceHelpers.test.ts index d283662c3..e2707fe0e 100644 --- a/__tests__/unit/services/generationServiceHelpers.test.ts +++ b/__tests__/unit/services/generationServiceHelpers.test.ts @@ -3,7 +3,7 @@ * Focuses on vision guard and buildGenerationMetaImpl LiteRT branches. */ -import { buildGenerationMetaImpl, FLUSH_INTERVAL_MS } from '../../../src/services/generationServiceHelpers'; +import { buildGenerationMetaImpl, FLUSH_INTERVAL_MS as _FLUSH_INTERVAL_MS } from '../../../src/services/generationServiceHelpers'; jest.mock('../../../src/services/llm', () => ({ llmService: { diff --git a/__tests__/unit/services/llmSafetyChecks.test.ts b/__tests__/unit/services/llmSafetyChecks.test.ts index a094b3377..5e25fc142 100644 --- a/__tests__/unit/services/llmSafetyChecks.test.ts +++ b/__tests__/unit/services/llmSafetyChecks.test.ts @@ -152,7 +152,6 @@ describe('safeCompletion', () => { const mockContext = { clearCache: jest.fn().mockResolvedValue(undefined) }; await expect( safeCompletion(mockContext as any, async () => { - // eslint-disable-next-line @typescript-eslint/no-throw-literal throw 'tensor error string'; }), ).rejects.toThrow('Model inference failed (native error)'); diff --git a/src/components/ChatInput/index.tsx b/src/components/ChatInput/index.tsx index b0abe6c35..0bd89b600 100644 --- a/src/components/ChatInput/index.tsx +++ b/src/components/ChatInput/index.tsx @@ -8,6 +8,7 @@ import { AttachStep } from 'react-native-spotlight-tour'; import { triggerHaptic } from 
'../../utils/haptics'; import { CustomAlert, showAlert, hideAlert, AlertState, initialAlertState } from '../CustomAlert'; import { createStyles, PILL_ICONS_WIDTH, ANIM_DURATION_IN, ANIM_DURATION_OUT } from './styles'; +import { ContextRing } from '../ContextRing'; import { QueueRow } from './Toolbar'; import { AttachmentPreview, useAttachments } from './Attachments'; import { useVoiceInput } from './Voice'; @@ -35,6 +36,8 @@ interface ChatInputProps { onRepairVision?: () => void; /** When set, mounts a single AttachStep for that index. Only one at a time to avoid waypoint dots. */ activeSpotlight?: number | null; + /** LiteRT context fill level. Renders a small ring indicator inside the pill when non-zero. */ + contextUsage?: { used: number; max: number }; } const IMAGE_MODE_CYCLE: ImageModeState[] = ['auto', 'force', 'disabled']; @@ -61,6 +64,7 @@ export const ChatInput: React.FC = ({ supportsThinking = false, onRepairVision, activeSpotlight = null, + contextUsage, }) => { const { colors } = useTheme(); const styles = useThemedStyles(createStyles); @@ -217,6 +221,11 @@ export const ChatInput: React.FC = ({ blurOnSubmit={false} returnKeyType="default" /> + {contextUsage && contextUsage.used > 0 && contextUsage.max > 0 && ( + + + + )} {/* Icons collapse when user starts typing, reappear when input is empty */} = ({ used, max, size = 16, thickness = 2 }) => { + const { colors } = useTheme(); + + if (!max || !used) return null; + + const pct = Math.min(used / max, 1); + const fillColor = pct < 0.7 ? colors.primary : pct < 0.85 ? AMBER : colors.error; + const emptyColor = colors.border; + + // Each border segment covers one 90-degree arc of the circle. + // Fill order: top (0-25%) → right (25-50%) → bottom (50-75%) → left (75-100%). + const top = pct > 0 ? fillColor : emptyColor; + const right = pct >= 0.25 ? fillColor : emptyColor; + const bottom = pct >= 0.5 ? fillColor : emptyColor; + const left = pct >= 0.75 ? 
fillColor : emptyColor; + + return ( + + ); +}; diff --git a/src/screens/ChatScreen/ChatMessageArea.tsx b/src/screens/ChatScreen/ChatMessageArea.tsx index c1941e87d..e244a4dad 100644 --- a/src/screens/ChatScreen/ChatMessageArea.tsx +++ b/src/screens/ChatScreen/ChatMessageArea.tsx @@ -1,4 +1,4 @@ -import React, { useState, useMemo } from 'react'; +import React, { useState, useMemo, useEffect, useRef } from 'react'; import { View, FlatList, Text, Keyboard, ActivityIndicator, Platform } from 'react-native'; import Icon from 'react-native-vector-icons/Feather'; import Animated, { FadeIn } from 'react-native-reanimated'; @@ -6,6 +6,7 @@ import { AttachStep } from 'react-native-spotlight-tour'; import { ChatInput, ToolPickerSheet, ThinkingIndicator } from '../../components'; import { AnimatedPressable } from '../../components/AnimatedPressable'; import { generationService } from '../../services'; +import { liteRTService } from '../../services/litert'; import { EmptyChat, ImageProgressIndicator } from './ChatScreenComponents'; import { getPlaceholderText, useChatScreen } from './useChatScreen'; import { createStyles } from './styles'; @@ -30,6 +31,16 @@ export const ChatMessageArea: React.FC = ({ }) => { const tabNav = useNavigation>(); const [inputHeight, setInputHeight] = useState(84); + const [contextUsage, setContextUsage] = useState<{ used: number; max: number } | undefined>(undefined); + const isStreaming = chat.isStreaming || chat.isThinking; + const prevIsStreamingRef = useRef(isStreaming); + useEffect(() => { + if (prevIsStreamingRef.current && !isStreaming) { + const usage = liteRTService.getContextUsage(); + setContextUsage(usage.used > 0 && usage.max > 0 ? usage : undefined); + } + prevIsStreamingRef.current = isStreaming; + }, [isStreaming]); const activeModelRepoId = chat.activeModelId?.split('/').slice(0, 2).join('/'); const handleRepairVision = activeModelRepoId ? 
() => tabNav.navigate('DownloadManager') @@ -133,6 +144,7 @@ export const ChatMessageArea: React.FC = ({ supportsThinking={chat.supportsThinking} onRepairVision={handleRepairVision} activeSpotlight={chatSpotlight === 12 ? chatSpotlight : null} + contextUsage={contextUsage} /> diff --git a/src/screens/ChatScreen/useChatScreen.ts b/src/screens/ChatScreen/useChatScreen.ts index 5a1f9e76f..ce99b67b1 100644 --- a/src/screens/ChatScreen/useChatScreen.ts +++ b/src/screens/ChatScreen/useChatScreen.ts @@ -232,9 +232,10 @@ export const useChatScreen = () => { // Check if there are pending settings that require model reload const hasPendingSettings = (() => { if (!loadedSettings) return false; - // LiteRT only reloads when the inference backend changes (cpu/gpu/npu) + // LiteRT reloads when backend or context length changes — both are baked into the engine at load time if (activeModel?.engine === 'litert') { - return settings.inferenceBackend !== loadedSettings.inferenceBackend; + return settings.inferenceBackend !== loadedSettings.inferenceBackend || + settings.contextLength !== loadedSettings.contextLength; } return ( settings.nThreads !== loadedSettings.nThreads || diff --git a/src/services/generationToolLoop.ts b/src/services/generationToolLoop.ts index 6b80befda..4bb2fa433 100644 --- a/src/services/generationToolLoop.ts +++ b/src/services/generationToolLoop.ts @@ -1,3 +1,4 @@ +/* eslint-disable max-lines */ /** Tool-calling generation loop. Extracted to keep generationService.ts under the max-lines limit. 
*/ import { llmService } from './llm'; import type { StreamToken } from './llm'; @@ -51,42 +52,65 @@ function parseToolCallBody(body: string, idSuffix: number): ToolCall | null { return parseXmlStyleToolCall(body, idSuffix); } +function parseGemmaToolCallBody(raw: string, toolCalls: ToolCall[]): void { + const nameMatch = raw.match(/^(?:call:)?(\w+)/); + if (!nameMatch) { + logger.warn(`[ToolLoop] Gemma tool call body did not match expected format: "${raw.substring(0, 100)}"`); + return; + } + const name = nameMatch[1]; + const rest = raw.slice(nameMatch[0].length).trim(); + let args: Record = {}; + + // Args may be: ({"key":"val"}) function-call style, {"key":"val"} bare object + const argsStr = rest.match(/^\((\{[\s\S]*\})\)$/)?.[1] ?? rest.match(/^(\{[\s\S]*\})$/)?.[1] ?? null; + if (argsStr) { + try { + // Gemma often emits unquoted keys like {queries:["x"]} — fix to valid JSON + const fixedJson = argsStr.replace(/([{,]\s*)([a-zA-Z_]\w*)(\s*):/g, '$1"$2"$3:'); + args = JSON.parse(fixedJson); + } catch { + logger.warn(`[ToolLoop] Failed to parse Gemma tool args: ${argsStr.substring(0, 100)}`); + } + } else if (rest.startsWith(':')) { + // Colon-separated key:value format — e.g. read_url emits :url:https://... + const colonArgs = rest.slice(1); + const firstColon = colonArgs.indexOf(':'); + if (firstColon !== -1) { + const key = colonArgs.slice(0, firstColon); + const value = colonArgs.slice(firstColon + 1).trim(); + args = { [key]: value }; + } + } + + // Normalize Gemma's "queries" array to the single "query" string our tools expect + if (name === 'web_search' && !args.query && args.queries) { + args = { ...args, query: Array.isArray(args.queries) ? 
args.queries[0] : args.queries }; + } + toolCalls.push({ id: `gemma-tc-${Date.now()}-${toolCalls.length}`, name, arguments: args }); + logger.log(`[ToolLoop] Parsed Gemma native tool call: ${name}(${JSON.stringify(args).substring(0, 100)})`); +} + /** Parse Gemma 4's native tool call format: <|tool_call>call:NAME{...} and */ function parseGemmaNativeToolCalls(text: string): { cleanText: string; toolCalls: ToolCall[] } { const toolCalls: ToolCall[] = []; - // Matches <|tool_call>..., , and const pattern = /(?:<\|tool_call>||<\/tool_call>)/g; let match; const matchedRanges: [number, number][] = []; while ((match = pattern.exec(text)) !== null) { matchedRanges.push([match.index, match.index + match[0].length]); - const raw = match[1].trim().replace(/<\|"\|>/g, '"'); - // Extract tool name — handles: call:NAME, NAME, call:NAME({...}), call:NAME{...} - const nameMatch = raw.match(/^(?:call:)?(\w+)/); - if (nameMatch) { - const name = nameMatch[1]; - const rest = raw.slice(nameMatch[0].length).trim(); - let args: Record = {}; - // Args may be: ({"key":"val"}) function-call style, or {"key":"val"} bare object - const argsStr = rest.match(/^\((\{[\s\S]*\})\)$/)?.[1] ?? rest.match(/^(\{[\s\S]*\})$/)?.[1] ?? null; - if (argsStr) { - try { - // Gemma often emits unquoted keys like {queries:["x"]} — fix to valid JSON - const fixedJson = argsStr.replace(/([{,]\s*)([a-zA-Z_]\w*)(\s*):/g, '$1"$2"$3:'); - args = JSON.parse(fixedJson); - } catch { - logger.warn(`[ToolLoop] Failed to parse Gemma tool args: ${argsStr.substring(0, 100)}`); - } - } - // Normalize Gemma's "queries" array to the single "query" string our tools expect - if (name === 'web_search' && !args.query && args.queries) { - args = { ...args, query: Array.isArray(args.queries) ? 
args.queries[0] : args.queries }; + parseGemmaToolCallBody(match[1].trim().replace(/<\|"\|>/g, '"'), toolCalls); + } + + // Fallback: unclosed <|tool_call> at end of text (model hit EOS without closing tag) + if (toolCalls.length === 0) { + const unclosedMatch = /(?:<\|tool_call>|/g, '"'), toolCalls); + if (toolCalls.length > 0) { + matchedRanges.push([unclosedMatch.index, text.length]); } - toolCalls.push({ id: `gemma-tc-${Date.now()}-${toolCalls.length}`, name, arguments: args }); - logger.log(`[ToolLoop] Parsed Gemma native tool call: ${name}(${JSON.stringify(args).substring(0, 100)})`); - } else { - logger.warn(`[ToolLoop] Gemma tool call body did not match expected format: "${raw.substring(0, 100)}"`); } } @@ -195,13 +219,9 @@ async function callRemoteLLMWithTools( if (!provider) throw new Error('Remote provider not found'); const settings = useAppStore.getState().settings; const thinkingEnabled = !opts?.disableThinking && settings.thinkingEnabled && provider.capabilities.supportsThinking; - const options: GenerationOptions = { - temperature: settings.temperature, maxTokens: settings.maxTokens, topP: settings.topP, - tools, enableThinking: thinkingEnabled, - }; + const options: GenerationOptions = { temperature: settings.temperature, maxTokens: settings.maxTokens, topP: settings.topP, tools, enableThinking: thinkingEnabled }; logger.log(`[ToolLoop] callRemoteLLM — server=${activeServerId}, tools=${tools.length}, thinking=${thinkingEnabled}`); - let _fullContent = ''; - let toolCalls: ToolCall[] = []; + let _fullContent = '', toolCalls: ToolCall[] = []; const onStream = opts?.onStream; return new Promise((resolve, reject) => { provider.generate(messages, options, { @@ -256,8 +276,7 @@ async function callLocalWithRetry( function isLiteRTActive(): boolean { const { downloadedModels, activeModelId } = useAppStore.getState(); - const m = downloadedModels.find((model: any) => model.id === activeModelId); - return m?.engine === 'litert' && 
liteRTService.isModelLoaded(); + return downloadedModels.find((m: any) => m.id === activeModelId)?.engine === 'litert' && liteRTService.isModelLoaded(); } /** On first iteration: last user message. On tool-result iterations: formatted tool results. */ @@ -280,82 +299,101 @@ function buildLiteRTSendText(messages: Message[]): string { return ''; } -function buildLiteRTToolSystemPrompt(basePrompt: string, tools: any[]): string { - if (tools.length === 0) return basePrompt; - const toolList = tools.map((t: any) => { - const fn = t.function || t; - const props = fn.parameters?.properties || {}; - const paramStr = Object.entries(props) - .map(([k, v]: [string, any]) => `"${k}": ${JSON.stringify((v as any).description || (v as any).type || k)}`) - .join(', '); - return ` ${fn.name}({${paramStr}}): ${fn.description}`; - }).join('\n'); - return `${basePrompt} - -TOOL USE RULES — follow strictly: -- To call a tool output ONLY this on its own line, nothing else before or after: -{"name": "tool_name", "arguments": {"param": "value"}} -- Call tools IMMEDIATELY. Never say "I'll search" or "let me look that up" — just call the tool. -- NEVER ask for clarification before calling a tool. Make your best guess at the arguments and call it. -- NEVER say "I cannot browse", "I don't have access", or "I need more information" — you have tools, use them. -- If a query is ambiguous, pick the most likely interpretation and call the tool right away. -- Only respond in plain text when no tool is needed (e.g. simple math in your head, greetings). 
-Tools:\n${toolList}`; +function buildLiteRTHistory(messages: Message[]): Array<{ role: 'user' | 'assistant'; content: string }> { + let lastUserIdx = -1; + for (let i = messages.length - 1; i >= 0; i--) { if (messages[i].role === 'user') { lastUserIdx = i; break; } } + if (lastUserIdx <= 0) return []; + return messages.slice(0, lastUserIdx) + .filter(m => (m.role === 'user' || m.role === 'assistant') && m.content) + .map(m => ({ role: m.role as 'user' | 'assistant', content: typeof m.content === 'string' ? m.content : '' })) + .filter(h => h.content.trim() !== ''); +} + +function buildLiteRTToolCallHandler(ctx: ToolLoopContext, conversationId: string) { + return async (name: string, args: Record): Promise => { + if (ctx.isAborted()) return 'Aborted'; + logger.log(`[ToolLoop][LiteRT] native tool call — name=${name}, args=${JSON.stringify(args).substring(0, 200)}`); + ctx.callbacks?.onToolCallStart?.(name, args as Record); + const toolCall: ToolCall = { id: `native-tc-${Date.now()}`, name, arguments: args as Record }; + if (ctx.projectId) (toolCall as any).context = { projectId: ctx.projectId }; + const result = await executeToolCall(toolCall); + ctx.callbacks?.onToolCallComplete?.(name, result); + const resultContent = result.error ? 
`Error: ${result.error}` : result.content; + const toolCallMsg: Message = { id: `tc-${Date.now()}-${name}`, role: 'assistant', content: '', + toolCalls: [{ id: toolCall.id, name, arguments: JSON.stringify(toolCall.arguments) }], timestamp: Date.now() }; + const toolResultMsg: Message = { id: `tr-${Date.now()}-${name}`, role: 'tool', content: resultContent, + toolCallId: toolCall.id, toolName: name, timestamp: Date.now() }; + useChatStore.getState().addMessage(conversationId, toolCallMsg); + useChatStore.getState().addMessage(conversationId, toolResultMsg); + logger.log(`[ToolLoop][LiteRT] tool ${name} completed — resultLen=${resultContent.length}, first200="${resultContent.substring(0, 200)}"`); + return resultContent; + }; } async function callLiteRTForLoop( conversationId: string, messages: Message[], - tools: any[], - onStream?: (data: StreamToken) => void, + opts: { tools: any[]; onStream?: (data: StreamToken) => void; ctx?: ToolLoopContext }, ): Promise<{ fullResponse: string; toolCalls: ToolCall[] }> { + const { tools, onStream, ctx } = opts; const systemMsg = messages.find(m => m.role === 'system'); - const basePrompt = typeof systemMsg?.content === 'string' ? systemMsg.content : ''; - const systemPrompt = buildLiteRTToolSystemPrompt(basePrompt, tools); + const systemPrompt = typeof systemMsg?.content === 'string' ? 
systemMsg.content : ''; const text = buildLiteRTSendText(messages); - - logger.log(`[ToolLoop][LiteRT] callLiteRTForLoop — convId=${conversationId}, text=${text.length}ch, sysPrompt=${systemPrompt.length}ch`); - logger.log(`[ToolLoop][LiteRT] sysPrompt content: "${systemPrompt.substring(0, 500)}"`); + const history = buildLiteRTHistory(messages); + logger.log(`[ToolLoop][LiteRT] callLiteRTForLoop — convId=${conversationId}, text=${text.length}ch, sysPrompt=${systemPrompt.length}ch, tools=${tools.length}, history=${history.length}`); + logger.log(`[ToolLoop][LiteRT] sysPrompt first500: "${systemPrompt.substring(0, 500)}"`); logger.log(`[ToolLoop][LiteRT] sending text: "${text.substring(0, 300)}"`); - if (!text) { logger.warn('[ToolLoop][LiteRT] no message text — aborting'); return { fullResponse: '', toolCalls: [] }; } - - // prepareConversation is idempotent — only resets native conversation when context changes - await liteRTService.prepareConversation(conversationId, systemPrompt); - - const fullResponse = await liteRTService.generateRaw(text, token => { - onStream?.({ content: token }); - }); - + await liteRTService.prepareConversation(conversationId, systemPrompt, { tools, history }); + const onToolCall = ctx ? buildLiteRTToolCallHandler(ctx, conversationId) : undefined; + const fullResponse = await liteRTService.generateRaw( + text, + token => onStream?.({ content: token }), + onToolCall, + ); logger.log(`[ToolLoop][LiteRT] raw response (${fullResponse.length}ch): "${fullResponse.substring(0, 400)}"`); - logger.log(`[ToolLoop][LiteRT] contains : ${fullResponse.includes('')} | contains <|tool_call>: ${fullResponse.includes('<|tool_call>')} | contains void; forceRemote?: boolean; disableThinking?: boolean; conversationId?: string; } +const TOOL_BEHAVIOR_GUIDANCE = '\n\nMake good use of the tools available to you. If you are uncertain or lack current information, use the appropriate tool rather than guessing. 
Never refuse or say you cannot help when a tool is available. For multiple distinct items, make a separate tool call for each. Call tools silently — do not announce them first.'; + +function augmentSystemPromptForTools(messages: Message[]): Message[] { + const sysIdx = messages.findIndex(m => m.role === 'system'); + if (sysIdx === -1) return messages; + const sys = messages[sysIdx]; + const existing = typeof sys.content === 'string' ? sys.content : ''; + const updated = { ...sys, content: existing + TOOL_BEHAVIOR_GUIDANCE }; + return [...messages.slice(0, sysIdx), updated, ...messages.slice(sysIdx + 1)]; +} + +interface CallLLMOptions { onStream?: (data: StreamToken) => void; forceRemote?: boolean; disableThinking?: boolean; conversationId?: string; ctx?: ToolLoopContext; } /** Call LLM with retry+backoff for transient native context errors. */ async function callLLMWithRetry( messages: Message[], tools: any[], - { onStream, forceRemote, disableThinking, conversationId }: CallLLMOptions = {}, + { onStream, forceRemote, disableThinking, conversationId, ctx }: CallLLMOptions = {}, ): Promise<{ fullResponse: string; toolCalls: ToolCall[] }> { + // Append tool-use behavioral guidance to the system prompt when tools are present. + // Only covers the "when and how" — schemas are injected separately by each engine. + // We shallow-copy messages to avoid mutating the caller's array. + const augmentedMessages = tools.length > 0 ? 
augmentSystemPromptForTools(messages) : messages; + if (isLiteRTActive() && conversationId) { - return callLiteRTForLoop(conversationId, messages, tools, onStream); + return callLiteRTForLoop(conversationId, augmentedMessages, { tools, onStream, ctx }); } const activeServerId = useRemoteServerStore.getState().activeServerId; const useRemote = forceRemote || (!!activeServerId && providerRegistry.hasProvider(activeServerId) && !llmService.isModelLoaded()); logger.log(`[ToolLoop] callLLM — remote=${useRemote}, tools=${tools.length}`); if (useRemote) { - try { return await callRemoteLLMWithTools(messages, tools, { onStream, disableThinking }); } + try { return await callRemoteLLMWithTools(augmentedMessages, tools, { onStream, disableThinking }); } catch (e: any) { throw new Error(e?.message || String(e) || 'Remote LLM error'); } } - // disableThinking is not forwarded to local — local llama.rn controls thinking - // internally and doesn't count thinking tokens against num_predict. - return callLocalWithRetry(messages, tools, onStream); + return callLocalWithRetry(augmentedMessages, tools, onStream); } /** If no structured tool calls, try parsing tags or Gemma's native format from text. 
*/ @@ -379,10 +417,8 @@ function resolveToolCalls(fullResponse: string, toolCalls: ToolCall[]) { } interface ToolLoopState { - firstTokenFired: boolean; - thinkingDoneFired: boolean; - streamedContent: string; - reasoningContent: string; + firstTokenFired: boolean; thinkingDoneFired: boolean; + streamedContent: string; reasoningContent: string; } function buildStreamHandler(ctx: ToolLoopContext, state: ToolLoopState): ((data: StreamChunk) => void) | undefined { @@ -406,7 +442,6 @@ function emitFinalResponse(ctx: ToolLoopContext, state: ToolLoopState, displayRe if (state.streamedContent) { logger.log(`[ToolLoop][DEBUG] emitFinalResponse — already streamed (${state.streamedContent.length} chars), skipping`); } else { - // Guard: only fire onThinkingDone/onFirstToken if not already fired (e.g. by reasoning-only first call) if (!state.thinkingDoneFired) { ctx.onThinkingDone(); ctx.callbacks?.onFirstToken?.(); @@ -422,8 +457,7 @@ async function forceFinalTextResponse(ctx: ToolLoopContext, state: ToolLoopState state.reasoningContent = ''; state.firstTokenFired = false; const forcedOnStream = buildStreamHandler(ctx, state); - // Disable thinking so the model spends all tokens on actual content - const { fullResponse: forcedResponse } = await callLLMWithRetry(loopMessages, [], { onStream: forcedOnStream, forceRemote: ctx.forceRemote, disableThinking: true, conversationId: ctx.conversationId }); + const { fullResponse: forcedResponse } = await callLLMWithRetry(loopMessages, [], { onStream: forcedOnStream, forceRemote: ctx.forceRemote, disableThinking: true, conversationId: ctx.conversationId, ctx }); logger.log(`[ToolLoop][DEBUG] Forced response — length=${forcedResponse.length}, streamedContent=${state.streamedContent.length}, reasoning=${state.reasoningContent.length}`); emitFinalResponse(ctx, state, forcedResponse); } @@ -438,9 +472,7 @@ export async function runToolLoop(ctx: ToolLoopContext): Promise { const loopMessages = [...ctx.messages]; let totalToolCalls = 0; 
const state: ToolLoopState = { firstTokenFired: false, thinkingDoneFired: false, streamedContent: '', reasoningContent: '' }; - logger.log(`[ToolLoop][DEBUG] === runToolLoop START === enabledToolIds=[${ctx.enabledToolIds.join(', ')}], toolSchemas=${toolSchemas.length}, messages=${ctx.messages.length}, forceRemote=${ctx.forceRemote}`); - for (let iteration = 0; iteration < MAX_TOOL_ITERATIONS; iteration++) { if (ctx.isAborted()) { logger.log(`[ToolLoop][DEBUG] Aborted at iteration ${iteration}`); @@ -458,7 +490,7 @@ export async function runToolLoop(ctx: ToolLoopContext): Promise { logger.log(`[ToolLoop][DEBUG] === Iteration ${iteration} === messages=${loopMessages.length}, tools=${toolSchemas.length}, totalCalls=${totalToolCalls}`); const onStream = buildStreamHandler(ctx, state); - const { fullResponse, toolCalls } = await callLLMWithRetry(loopMessages, toolSchemas, { onStream, forceRemote: ctx.forceRemote, conversationId: ctx.conversationId }); + const { fullResponse, toolCalls } = await callLLMWithRetry(loopMessages, toolSchemas, { onStream, forceRemote: ctx.forceRemote, conversationId: ctx.conversationId, ctx }); logger.log(`[ToolLoop][DEBUG] LLM returned — response=${fullResponse.length}, toolCalls=${toolCalls.length}, streamed=${state.streamedContent.length}, reasoning=${state.reasoningContent.length}`); if (fullResponse.length === 0 && state.streamedContent.length === 0) { @@ -467,7 +499,6 @@ export async function runToolLoop(ctx: ToolLoopContext): Promise { const { effectiveToolCalls, displayResponse } = resolveToolCalls(fullResponse, toolCalls); const cappedToolCalls = effectiveToolCalls.slice(0, MAX_TOTAL_TOOL_CALLS - totalToolCalls); totalToolCalls += cappedToolCalls.length; - logger.log(`[ToolLoop][DEBUG] After resolve — toolCalls=${cappedToolCalls.length}, displayResponse=${displayResponse.length}`); // No tool calls → model gave a final text response if (cappedToolCalls.length === 0) { @@ -479,7 +510,7 @@ export async function runToolLoop(ctx: 
ToolLoopContext): Promise { state.firstTokenFired = false; const fallbackOnStream = buildStreamHandler(ctx, state); const { fullResponse: fallbackResp } = await callLLMWithRetry( - loopMessages, [], { onStream: fallbackOnStream, forceRemote: ctx.forceRemote, disableThinking: true, conversationId: ctx.conversationId }, + loopMessages, [], { onStream: fallbackOnStream, forceRemote: ctx.forceRemote, disableThinking: true, conversationId: ctx.conversationId, ctx }, ); emitFinalResponse(ctx, state, fallbackResp); return; From cad392034a733429c5c0dc7c013e659f33ca1a65 Mon Sep 17 00:00:00 2001 From: Dishit Date: Fri, 22 May 2026 10:16:05 +0530 Subject: [PATCH 12/93] fix: resolve lint and type errors blocking push --- .../unit/services/providers/openAICompatibleStream.test.ts | 4 +++- src/components/ChatInput/index.tsx | 2 +- src/components/ChatInput/styles.ts | 5 +++++ src/screens/ModelsScreen/TextModelsTab.tsx | 2 +- src/services/activeModelService/loaders.ts | 2 +- src/services/litert.ts | 4 ++-- src/stores/appStore.ts | 2 +- 7 files changed, 14 insertions(+), 7 deletions(-) diff --git a/__tests__/unit/services/providers/openAICompatibleStream.test.ts b/__tests__/unit/services/providers/openAICompatibleStream.test.ts index e9245026d..dd3d753da 100644 --- a/__tests__/unit/services/providers/openAICompatibleStream.test.ts +++ b/__tests__/unit/services/providers/openAICompatibleStream.test.ts @@ -24,6 +24,8 @@ function makeState(overrides: Partial = {}): OpenAIStreamStat fullReasoningContent: '', toolCalls: [], currentToolCall: null, + completeCalled: false, + streamErrorOccurred: false, ...overrides, }; } @@ -162,7 +164,7 @@ describe('processDelta', () => { processDelta({ tool_calls: [{ function: { arguments: ':"NY"}' } }], }, state, { thinkingEnabled: true, callbacks, thinkTagParser }); - expect(state.currentToolCall!.function.arguments).toBe('{"city":"NY"}'); + expect(state.currentToolCall!.function!.arguments).toBe('{"city":"NY"}'); }); it('suppresses think-tag 
reasoning when thinkingEnabled=false', () => { diff --git a/src/components/ChatInput/index.tsx b/src/components/ChatInput/index.tsx index 0bd89b600..270e487b7 100644 --- a/src/components/ChatInput/index.tsx +++ b/src/components/ChatInput/index.tsx @@ -222,7 +222,7 @@ export const ChatInput: React.FC = ({ returnKeyType="default" /> {contextUsage && contextUsage.used > 0 && contextUsage.max > 0 && ( - + )} diff --git a/src/components/ChatInput/styles.ts b/src/components/ChatInput/styles.ts index a9f8df69c..c1cbbc36d 100644 --- a/src/components/ChatInput/styles.ts +++ b/src/components/ChatInput/styles.ts @@ -208,4 +208,9 @@ export const createStyles = (colors: ThemeColors, _shadows: ThemeShadows) => ({ fontWeight: '500' as const, color: colors.primary, }, + contextRingWrapper: { + paddingHorizontal: 6, + justifyContent: 'center' as const, + alignItems: 'center' as const, + }, }); diff --git a/src/screens/ModelsScreen/TextModelsTab.tsx b/src/screens/ModelsScreen/TextModelsTab.tsx index 2586a30c0..b1796a8d2 100644 --- a/src/screens/ModelsScreen/TextModelsTab.tsx +++ b/src/screens/ModelsScreen/TextModelsTab.tsx @@ -211,7 +211,7 @@ const LITERT_FEATURED = [ const FeaturedLiteRTCard: React.FC<{ model: typeof LITERT_FEATURED[number]; styles: any; colors: any }> = ({ model, styles, colors }) => ( Linking.openURL(model.url)} activeOpacity={0.85} > diff --git a/src/services/activeModelService/loaders.ts b/src/services/activeModelService/loaders.ts index 97e83ee15..cd4023cfc 100644 --- a/src/services/activeModelService/loaders.ts +++ b/src/services/activeModelService/loaders.ts @@ -144,7 +144,7 @@ async function doLoadLiteRTModel(ctx: TextLoadContext): Promise { try { addDebugLog('log', `[LiteRT] Calling liteRTService.loadModel (timeout ${timeoutMs / 1000}s, vision=${ctx.model.liteRTVision ?? false}, maxNumTokens=${maxTokens}).`); await Promise.race([ - liteRTService.loadModel(ctx.model.filePath, preferredBackend, ctx.model.liteRTVision ?? 
false, maxTokens), + liteRTService.loadModel(ctx.model.filePath, preferredBackend, { supportsVision: ctx.model.liteRTVision ?? false, maxNumTokens: maxTokens }), timeoutPromise, ]); } finally { diff --git a/src/services/litert.ts b/src/services/litert.ts index ab1a84859..e5cecacd4 100644 --- a/src/services/litert.ts +++ b/src/services/litert.ts @@ -85,11 +85,11 @@ class LiteRTService { // loadModel // --------------------------------------------------------------------------- - async loadModel(modelPath: string, preferredBackend: LiteRTBackend, supportsVision = false, maxNumTokens = 4096): Promise { + async loadModel(modelPath: string, preferredBackend: LiteRTBackend, opts: { supportsVision?: boolean; maxNumTokens?: number } = {}): Promise { if (!this.isAvailable()) { throw new Error('LiteRT is not available on this platform'); } - + const { supportsVision = false, maxNumTokens = 4096 } = opts; this.configuredMaxTokens = maxNumTokens; logger.log(TAG, `loadModel — path=${modelPath} backend=${preferredBackend} supportsVision=${supportsVision} maxNumTokens=${maxNumTokens}`); diff --git a/src/stores/appStore.ts b/src/stores/appStore.ts index dd04aaa54..5e5e3054f 100644 --- a/src/stores/appStore.ts +++ b/src/stores/appStore.ts @@ -173,7 +173,7 @@ function migratePersistedState(persistedState: any, currentState: AppState): App } export const selectIsLiteRT = (state: AppState): boolean => - state.downloadedModels.find(m => m.id === state.activeModelId)?.engine === 'litert' ?? false; + state.downloadedModels.find(m => m.id === state.activeModelId)?.engine === 'litert'; export const useAppStore = create()( persist( From fde2aba7beb7a8ce5169b8b35969c39610b18ae3 Mon Sep 17 00:00:00 2001 From: Dishit Date: Fri, 22 May 2026 12:35:50 +0530 Subject: [PATCH 13/93] fix(chat): scroll to last message when keyboard opens Tapping the input shrank the FlatList viewport without repositioning the scroll, leaving the last AI message hidden behind the keyboard. 
Track height changes via onLayout and scroll to end when the viewport shrinks. Add a keyboardWillShow/keyboardDidShow listener as a secondary trigger for iOS. Co-Authored-By: Dishit Karia hanmadishit74@gmail.com --- src/screens/ChatScreen/ChatMessageArea.tsx | 10 +++++++++- src/screens/ChatScreen/index.tsx | 13 +++++++++++-- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/screens/ChatScreen/ChatMessageArea.tsx b/src/screens/ChatScreen/ChatMessageArea.tsx index e244a4dad..c59c8a44d 100644 --- a/src/screens/ChatScreen/ChatMessageArea.tsx +++ b/src/screens/ChatScreen/ChatMessageArea.tsx @@ -31,6 +31,7 @@ export const ChatMessageArea: React.FC = ({ }) => { const tabNav = useNavigation>(); const [inputHeight, setInputHeight] = useState(84); + const flatListHeightRef = useRef(0); const [contextUsage, setContextUsage] = useState<{ used: number; max: number } | undefined>(undefined); const isStreaming = chat.isStreaming || chat.isThinking; const prevIsStreamingRef = useRef(isStreaming); @@ -69,7 +70,14 @@ export const ChatMessageArea: React.FC = ({ contentContainerStyle={styles.messageList} onScroll={handleScroll} onContentSizeChange={(_w, _h) => { if (isNearBottomRef.current) flatListRef.current?.scrollToEnd({ animated: false }); }} - onLayout={() => { }} + onLayout={(e) => { + const newHeight = e.nativeEvent.layout.height; + const prevHeight = flatListHeightRef.current; + flatListHeightRef.current = newHeight; + if (prevHeight > 0 && newHeight < prevHeight) { + setTimeout(() => flatListRef.current?.scrollToEnd({ animated: true }), 50); + } + }} scrollEventThrottle={16} keyboardDismissMode="on-drag" keyboardShouldPersistTaps="handled" diff --git a/src/screens/ChatScreen/index.tsx b/src/screens/ChatScreen/index.tsx index c2c9d6c92..645e290e0 100644 --- a/src/screens/ChatScreen/index.tsx +++ b/src/screens/ChatScreen/index.tsx @@ -1,5 +1,5 @@ import React, { useCallback, useEffect, useRef, useState } from 'react'; -import { FlatList, KeyboardAvoidingView, 
InteractionManager } from 'react-native'; +import { FlatList, Keyboard, KeyboardAvoidingView, InteractionManager, Platform } from 'react-native'; import { SafeAreaView } from 'react-native-safe-area-context'; import { useFocusEffect } from '@react-navigation/native'; import { useSpotlightTour } from 'react-native-spotlight-tour'; @@ -103,6 +103,14 @@ export const ChatScreen: React.FC = () => { setTimeout(() => { flatListRef.current?.scrollToEnd({ animated: true }); }, 100); } }, [chat.activeConversation?.messages.length]); + + React.useEffect(() => { + const event = Platform.OS === 'ios' ? 'keyboardWillShow' : 'keyboardDidShow'; + const sub = Keyboard.addListener(event, () => { + flatListRef.current?.scrollToEnd({ animated: true }); + }); + return () => sub.remove(); + }, []); const alertEl = ( { const handleScroll = (event: any) => { const { contentOffset, contentSize, layoutMeasurement } = event.nativeEvent; - isNearBottomRef.current = contentSize.height - layoutMeasurement.height - contentOffset.y < 100; + const distFromBottom = contentSize.height - layoutMeasurement.height - contentOffset.y; + isNearBottomRef.current = distFromBottom < 100; chat.setShowScrollToBottom(!isNearBottomRef.current); }; From 4bbf0424820559c0c6bfa0122dfaeff906b00b7e Mon Sep 17 00:00:00 2001 From: Dishit Date: Fri, 22 May 2026 12:38:11 +0530 Subject: [PATCH 14/93] fix(litert): fix tool call parsing, pass sampler config, unify init timeout - Fix Gemma tool call parsing to handle the "tool_name{json}" body pattern alongside the existing key:value format; add key validation so non-word strings are not treated as argument keys - Pass temperature/topK/topP through prepareConversation in the tool loop so generation settings are respected during tool-call turns - Unify model init timeout to 90s across all backends (was 45/20/15s) to prevent premature timeout failures on slower devices - Add debugLog helper in LiteRTModule that emits litert_debug_log events to the in-app debug screen alongside 
logcat Co-Authored-By: Dishit Karia --- .../ai/offgridmobile/litert/LiteRTModule.kt | 20 +++++++--- src/services/activeModelService/loaders.ts | 6 +-- src/services/generationToolLoop.ts | 37 +++++++++++++++---- src/services/litert.ts | 5 +++ 4 files changed, 53 insertions(+), 15 deletions(-) diff --git a/android/app/src/main/java/ai/offgridmobile/litert/LiteRTModule.kt b/android/app/src/main/java/ai/offgridmobile/litert/LiteRTModule.kt index 11048bdc3..b7475f8dc 100644 --- a/android/app/src/main/java/ai/offgridmobile/litert/LiteRTModule.kt +++ b/android/app/src/main/java/ai/offgridmobile/litert/LiteRTModule.kt @@ -43,13 +43,14 @@ class LiteRTModule(private val reactContext: ReactApplicationContext) : const val EVENT_COMPLETE = "litert_complete" const val EVENT_ERROR = "litert_error" const val EVENT_TOOL_CALL = "litert_tool_call" + const val EVENT_DEBUG_LOG = "litert_debug_log" // Base timeouts per backend tier (for default 4096-token context). // Actual timeout scales up proportionally for larger context windows // because KV-cache allocation takes longer at higher token counts. 
- private const val NPU_BASE_TIMEOUT_MS = 45_000L - private const val GPU_BASE_TIMEOUT_MS = 20_000L - private const val CPU_BASE_TIMEOUT_MS = 15_000L + private const val NPU_BASE_TIMEOUT_MS = 90_000L + private const val GPU_BASE_TIMEOUT_MS = 90_000L + private const val CPU_BASE_TIMEOUT_MS = 90_000L private const val DEFAULT_CONTEXT_TOKENS = 4096 fun initTimeoutMs(backend: Backend, maxNumTokens: Int): Long { @@ -140,6 +141,7 @@ class LiteRTModule(private val reactContext: ReactApplicationContext) : Log.i(TAG, "initializeWithFallback — trying $name vision=$visionEnabled") } try { + debugLog("EngineConfig — backend=$name maxNumTokens=$configuredMaxTokens vision=$visionEnabled") val cfg = EngineConfig( modelPath = modelPath, backend = backend, @@ -149,11 +151,12 @@ class LiteRTModule(private val reactContext: ReactApplicationContext) : ) val eng = Engine(cfg) val timeoutMs = initTimeoutMs(backend, configuredMaxTokens) + debugLog("Engine.initialize — backend=$name timeoutMs=${timeoutMs / 1000}s") withTimeout(timeoutMs) { eng.initialize() } engine = eng - Log.i(TAG, "initializeWithFallback — $name succeeded (attempt $attempt)") + debugLog("Engine.initialize — $name succeeded (attempt $attempt)") succeeded = true return backend } catch (e: Exception) { @@ -195,9 +198,10 @@ class LiteRTModule(private val reactContext: ReactApplicationContext) : // SamplerConfig is not supported on NPU val samplerConfig = if (activeBackend == "npu") { - Log.i(TAG, "resetConversation — NPU backend, skipping SamplerConfig") + debugLog("SamplerConfig — skipped (NPU backend does not support it)") null } else { + debugLog("SamplerConfig — temperature=$temperature topK=$topK topP=$topP") SamplerConfig( topK = topK, topP = topP, @@ -207,6 +211,7 @@ class LiteRTModule(private val reactContext: ReactApplicationContext) : val toolProviders = buildToolProviders(toolsJson) val initialMessages = parseHistoryMessages(historyJson) + debugLog("ConversationConfig — tools=${toolProviders.size} 
history=${initialMessages.size} hasSamplerConfig=${samplerConfig != null} autoToolCalling=${toolProviders.isNotEmpty()}") val convConfig = ConversationConfig( systemInstruction = if (systemPrompt.isNotEmpty()) Contents.of(systemPrompt) else null, @@ -598,6 +603,11 @@ class LiteRTModule(private val reactContext: ReactApplicationContext) : } } + private fun debugLog(msg: String) { + Log.i(TAG, msg) + sendEvent(EVENT_DEBUG_LOG, msg) + } + override fun onCatalystInstanceDestroy() { Log.i(TAG, "onCatalystInstanceDestroy — cleaning up") scope.cancel() diff --git a/src/services/activeModelService/loaders.ts b/src/services/activeModelService/loaders.ts index cd4023cfc..6016b0b1c 100644 --- a/src/services/activeModelService/loaders.ts +++ b/src/services/activeModelService/loaders.ts @@ -128,9 +128,7 @@ async function doLoadLiteRTModel(ctx: TextLoadContext): Promise { const maxTokens = ctx.store.settings.contextLength ?? 4096; const contextScalar = Math.max(1, maxTokens / 4096); - const baseTimeoutMs = preferredBackend === 'npu' ? 45_000 - : preferredBackend === 'gpu' ? 
20_000 - : 15_000; + const baseTimeoutMs = 90_000; const timeoutMs = Math.min(Math.ceil(baseTimeoutMs * contextScalar), 180_000); let timeoutId: ReturnType | null = null; @@ -152,6 +150,8 @@ async function doLoadLiteRTModel(ctx: TextLoadContext): Promise { } const actualBackend = liteRTService.getActiveBackend(); + const s = ctx.store.settings; + addDebugLog('log', `[LiteRT] Engine init — backend=${actualBackend} maxTokens=${maxTokens} temperature=${s.temperature} topP=${s.topP} topK=40`); addDebugLog('log', `[LiteRT] Load complete — actual backend: ${actualBackend}`); if (actualBackend !== preferredBackend) { addDebugLog('warn', `[LiteRT] Requested ${preferredBackend}, fell back to ${actualBackend}`); diff --git a/src/services/generationToolLoop.ts b/src/services/generationToolLoop.ts index 4bb2fa433..6c0a34dc8 100644 --- a/src/services/generationToolLoop.ts +++ b/src/services/generationToolLoop.ts @@ -73,13 +73,29 @@ function parseGemmaToolCallBody(raw: string, toolCalls: ToolCall[]): void { logger.warn(`[ToolLoop] Failed to parse Gemma tool args: ${argsStr.substring(0, 100)}`); } } else if (rest.startsWith(':')) { - // Colon-separated key:value format — e.g. read_url emits :url:https://... const colonArgs = rest.slice(1); - const firstColon = colonArgs.indexOf(':'); - if (firstColon !== -1) { - const key = colonArgs.slice(0, firstColon); - const value = colonArgs.slice(firstColon + 1).trim(); - args = { [key]: value }; + + // Pattern: repeated tool name + JSON body — e.g. "read_url{url: "https://..."}" + if (colonArgs.startsWith(name)) { + const jsonBody = colonArgs.slice(name.length).trim(); + if (jsonBody.startsWith('{')) { + try { + const fixedJson = jsonBody.replace(/([{,]\s*)([a-zA-Z_]\w*)(\s*):/g, '$1"$2"$3:'); + args = JSON.parse(fixedJson); + } catch { /* fall through */ } + } + } + + // Pattern: simple key:value — e.g. "url:https://..." 
+ if (Object.keys(args).length === 0) { + const firstColon = colonArgs.indexOf(':'); + if (firstColon !== -1) { + const key = colonArgs.slice(0, firstColon).trim(); + if (/^\w+$/.test(key)) { + const value = colonArgs.slice(firstColon + 1).trim(); + args = { [key]: value }; + } + } } } @@ -340,14 +356,21 @@ async function callLiteRTForLoop( const systemPrompt = typeof systemMsg?.content === 'string' ? systemMsg.content : ''; const text = buildLiteRTSendText(messages); const history = buildLiteRTHistory(messages); + const liteRTSettings = useAppStore.getState().settings; + const samplerConfig = { + temperature: liteRTSettings.temperature, + topK: 40, + topP: liteRTSettings.topP, + }; logger.log(`[ToolLoop][LiteRT] callLiteRTForLoop — convId=${conversationId}, text=${text.length}ch, sysPrompt=${systemPrompt.length}ch, tools=${tools.length}, history=${history.length}`); + logger.log(`[ToolLoop][LiteRT] samplerConfig — temperature=${samplerConfig.temperature} topK=${samplerConfig.topK} topP=${samplerConfig.topP}`); logger.log(`[ToolLoop][LiteRT] sysPrompt first500: "${systemPrompt.substring(0, 500)}"`); logger.log(`[ToolLoop][LiteRT] sending text: "${text.substring(0, 300)}"`); if (!text) { logger.warn('[ToolLoop][LiteRT] no message text — aborting'); return { fullResponse: '', toolCalls: [] }; } - await liteRTService.prepareConversation(conversationId, systemPrompt, { tools, history }); + await liteRTService.prepareConversation(conversationId, systemPrompt, { samplerConfig, tools, history }); const onToolCall = ctx ? 
buildLiteRTToolCallHandler(ctx, conversationId) : undefined; const fullResponse = await liteRTService.generateRaw( text, diff --git a/src/services/litert.ts b/src/services/litert.ts index e5cecacd4..0654e1d29 100644 --- a/src/services/litert.ts +++ b/src/services/litert.ts @@ -23,6 +23,7 @@ const EVENT_THINKING = 'litert_thinking'; const EVENT_COMPLETE = 'litert_complete'; const EVENT_ERROR = 'litert_error'; const EVENT_TOOL_CALL = 'litert_tool_call'; +const EVENT_DEBUG_LOG = 'litert_debug_log'; export type LiteRTBackend = 'cpu' | 'gpu' | 'npu'; @@ -75,6 +76,9 @@ class LiteRTService { constructor() { if (Platform.OS === 'android' && LiteRTModule) { this.emitter = new NativeEventEmitter(LiteRTModule); + this.emitter.addListener(EVENT_DEBUG_LOG, (msg: string) => { + useDebugLogsStore.getState().addLog('log', `[Kotlin] ${msg}`); + }); logger.log(TAG, 'initialized — native module available'); } else { logger.log(TAG, 'native module not available on this platform'); @@ -128,6 +132,7 @@ class LiteRTService { const toolsJson = tools && tools.length > 0 ? JSON.stringify(tools) : ''; const historyJson = history && history.length > 0 ? JSON.stringify(history) : ''; logger.log(TAG, `resetConversation — systemPrompt length=${systemPrompt.length} temperature=${temperature} topK=${topK} topP=${topP} tools=${tools?.length ?? 0} history=${history?.length ?? 0}`); + useDebugLogsStore.getState().addLog('log', `[LiteRT] Conv reset — temperature=${temperature} topK=${topK} topP=${topP} tools=${tools?.length ?? 
0}`); await LiteRTModule.resetConversation(systemPrompt, temperature, topK, topP, toolsJson, historyJson); this.activeSystemPrompt = systemPrompt; this.activeToolsJson = toolsJson; From 2da411aacb41b6ef66ef6cd35180f5e6851ea895 Mon Sep 17 00:00:00 2001 From: Dishit Date: Sat, 23 May 2026 13:38:49 +0530 Subject: [PATCH 15/93] fix(litert): fix engine variable scoping in initializeWithFallback Co-authored-by: Dishit Karia --- .../ai/offgridmobile/litert/LiteRTModule.kt | 48 ++++++++++++++----- 1 file changed, 35 insertions(+), 13 deletions(-) diff --git a/android/app/src/main/java/ai/offgridmobile/litert/LiteRTModule.kt b/android/app/src/main/java/ai/offgridmobile/litert/LiteRTModule.kt index b7475f8dc..f12309c11 100644 --- a/android/app/src/main/java/ai/offgridmobile/litert/LiteRTModule.kt +++ b/android/app/src/main/java/ai/offgridmobile/litert/LiteRTModule.kt @@ -140,6 +140,7 @@ class LiteRTModule(private val reactContext: ReactApplicationContext) : } else { Log.i(TAG, "initializeWithFallback — trying $name vision=$visionEnabled") } + var eng: Engine? = null try { debugLog("EngineConfig — backend=$name maxNumTokens=$configuredMaxTokens vision=$visionEnabled") val cfg = EngineConfig( @@ -149,7 +150,7 @@ class LiteRTModule(private val reactContext: ReactApplicationContext) : cacheDir = null, visionBackend = if (visionEnabled) Backend.GPU() else null, ) - val eng = Engine(cfg) + eng = Engine(cfg) val timeoutMs = initTimeoutMs(backend, configuredMaxTokens) debugLog("Engine.initialize — backend=$name timeoutMs=${timeoutMs / 1000}s") withTimeout(timeoutMs) { @@ -161,8 +162,14 @@ class LiteRTModule(private val reactContext: ReactApplicationContext) : return backend } catch (e: Exception) { Log.w(TAG, "initializeWithFallback — $name attempt $attempt failed: ${e.message}") - engine?.close() - engine = null + // Close the local engine attempt — not the module-level `engine` field, + // which belongs to a previous successful load and must not be touched here. 
+ try { + eng?.close() + Log.d(TAG, "initializeWithFallback — $name attempt $attempt engine closed after failure") + } catch (closeEx: Exception) { + Log.w(TAG, "initializeWithFallback — $name attempt $attempt engine close error: ${closeEx.message}") + } lastError = e } } @@ -552,13 +559,29 @@ class LiteRTModule(private val reactContext: ReactApplicationContext) : val toolsArray = JsonParser.parseString(toolsJson).asJsonArray if (toolsArray.size() == 0) return emptyList() - val providers = toolsArray.map { element -> - val toolObj = element.asJsonObject - val toolName = toolObj.get("name").asString - val toolDescriptionJson = toolObj.toString() + val providers = toolsArray.mapNotNull { element -> + val wrapper = element.asJsonObject + + // JS sends OpenAI format: { type: "function", function: { name, description, parameters } } + // LiteRT SDK expects the unwrapped OpenAPI object: { name, description, parameters } + val funcObj = if (wrapper.has("function")) + wrapper.getAsJsonObject("function") + else + wrapper + + val toolName = funcObj.get("name")?.asString ?: return@mapNotNull null + + // Build clean OpenAPI JSON for the SDK + val openApiJson = com.google.gson.JsonObject().apply { + addProperty("name", toolName) + funcObj.get("description")?.let { addProperty("description", it.asString) } + funcObj.get("parameters")?.let { add("parameters", it) } + }.toString() + + debugLog("tool schema — $toolName: $openApiJson") val openApiTool = object : OpenApiTool { - override fun getToolDescriptionJsonString(): String = toolDescriptionJson + override fun getToolDescriptionJsonString(): String = openApiJson override fun execute(argsJson: String): String { val callId = UUID.randomUUID().toString() @@ -566,17 +589,16 @@ class LiteRTModule(private val reactContext: ReactApplicationContext) : pendingToolCalls[callId] = deferred val eventJson = """{"id":"$callId","name":"$toolName","arguments":$argsJson}""" - Log.d(TAG, "buildToolProviders — emitting tool call callId=$callId 
name=$toolName") + debugLog("tool_call — callId=$callId name=$toolName argsLen=${argsJson.length}") sendEvent(EVENT_TOOL_CALL, eventJson) return try { runBlocking { withTimeout(30_000L) { deferred.await() } } } catch (e: TimeoutCancellationException) { - Log.w(TAG, "buildToolProviders — tool call $callId timed out") + debugLog("tool_call timed out — callId=$callId name=$toolName") pendingToolCalls.remove(callId) "Error: Tool call timed out" } catch (e: CancellationException) { - Log.w(TAG, "buildToolProviders — tool call $callId cancelled") pendingToolCalls.remove(callId) "Error: Tool call cancelled" } @@ -585,10 +607,10 @@ class LiteRTModule(private val reactContext: ReactApplicationContext) : tool(openApiTool) } - Log.i(TAG, "buildToolProviders — registered ${providers.size} tools") + debugLog("buildToolProviders — registered ${providers.size} tools: [${toolsArray.mapNotNull { it.asJsonObject.getAsJsonObject("function")?.get("name")?.asString }.joinToString()}]") providers } catch (e: Exception) { - Log.w(TAG, "buildToolProviders — failed to parse toolsJson: ${e.message}") + debugLog("buildToolProviders — failed to parse toolsJson: ${e.message}") emptyList() } } From c57a29cc6773fa3cc84c6b779e59166a63d39110 Mon Sep 17 00:00:00 2001 From: Dishit Date: Sat, 23 May 2026 13:40:23 +0530 Subject: [PATCH 16/93] fix(litert): pass image URI through tool loop and generation pipeline Co-authored-by: Dishit Karia --- src/services/generationServiceHelpers.ts | 31 ++++- src/services/generationToolLoop.ts | 13 +- src/services/litert.ts | 156 +++++++++++++++++++++-- 3 files changed, 178 insertions(+), 22 deletions(-) diff --git a/src/services/generationServiceHelpers.ts b/src/services/generationServiceHelpers.ts index d62447444..dd6420fef 100644 --- a/src/services/generationServiceHelpers.ts +++ b/src/services/generationServiceHelpers.ts @@ -7,7 +7,7 @@ import { liteRTService } from './litert'; import { useAppStore, useChatStore, useRemoteServerStore } from '../stores'; import { 
useDebugLogsStore } from '../stores/debugLogsStore'; import type { Message, GenerationMeta } from '../types'; -import { runToolLoop } from './generationToolLoop'; +import { runToolLoop, buildLiteRTHistory } from './generationToolLoop'; import type { ToolResult } from './tools/types'; import type { GenerationOptions, CompletionResult } from './providers/types'; import logger from '../utils/logger'; @@ -110,18 +110,26 @@ export function buildToolLoopHandlersImpl(svc: any) { onStream: (data: StreamChunk) => { if (svc.abortRequested) return; const chunk = typeof data === 'string' ? { content: data } : data; + const dbg = useDebugLogsStore.getState().addLog; if (chunk.content) { if (!svc.state.streamingContent && svc.remoteTimeToFirstToken === undefined) { svc.remoteTimeToFirstToken = svc.state.startTime ? (Date.now() - svc.state.startTime) / 1000 : undefined; } + if (!svc.state.streamingContent) { + dbg('log', `[Stream] first content token — reasoningAccum=${svc.reasoningBuffer.length + svc.totalReasoningLength}ch flushTimer=${svc.flushTimer != null}`); + } svc.state.streamingContent += chunk.content; svc.tokenBuffer += chunk.content; } if (chunk.reasoningContent) { + const wasEmpty = svc.reasoningBuffer.length === 0 && svc.totalReasoningLength === 0; svc.reasoningBuffer += chunk.reasoningContent; svc.totalReasoningLength += chunk.reasoningContent.length; + if (wasEmpty) { + dbg('log', `[Stream] first reasoning token — token="${chunk.reasoningContent.substring(0, 40)}" flushTimer=${svc.flushTimer != null}`); + } } if (!svc.flushTimer) { svc.flushTimer = setTimeout(() => svc.flushTokenBuffer(), FLUSH_INTERVAL_MS); @@ -201,28 +209,38 @@ export async function generateResponseImpl( } const systemMsg = messages.find(m => m.role === 'system'); const systemPrompt = typeof systemMsg?.content === 'string' ? systemMsg.content : ''; - const imageAttachment = lastUser.attachments?.find((a: any) => a.type === 'image'); + const allAttachments = lastUser.attachments ?? 
[]; + const imageAttachment = allAttachments.find((a: any) => a.type === 'image'); const imageUri = imageAttachment?.uri as string | undefined; - dbg('log', `[LiteRT] generateResponse — hasImage=${!!imageUri}, systemPrompt length=${systemPrompt.length}, userText length=${typeof lastUser.content === 'string' ? lastUser.content.length : 0}`); + + dbg('log', `[Vision] attachments — total=${allAttachments.length} types=[${allAttachments.map((a: any) => a.type).join(',')}] imageFound=${!!imageAttachment}`); + dbg('log', `[Vision] imageUri — ${imageUri ? imageUri.substring(0, 80) : 'none'}`); + dbg('log', `[LiteRT] generateResponse — hasImage=${!!imageUri} conversationId=${conversationId.substring(0, 8)} messages=${messages.length} systemLen=${systemPrompt.length} userTextLen=${typeof lastUser.content === 'string' ? lastUser.content.length : 0}`); // Guard: image attached but model was not imported with vision support if (imageUri) { const { downloadedModels, activeModelId } = useAppStore.getState(); const activeModel = downloadedModels.find((m: any) => m.id === activeModelId); + dbg('log', `[Vision] model check — activeModelId=${activeModelId?.substring(0, 12)} liteRTVision=${activeModel?.liteRTVision} modelFound=${!!activeModel}`); if (!activeModel?.liteRTVision) { - dbg('warn', '[LiteRT] Image attached but model does not support vision — aborting'); + dbg('warn', '[Vision] BLOCKED — model does not have vision support (liteRTVision=false). Re-import model with vision enabled.'); chatStore.clearStreamingMessage(); svc.resetState(); throw new Error('This model does not support images. 
Import it with vision enabled, or remove the image.'); } + dbg('log', '[Vision] model vision check passed'); } + const history = buildLiteRTHistory(messages); + dbg('log', `[Vision] history built — turns=${history.length} (messages before last user turn, excluding system)`); + try { const { settings } = useAppStore.getState(); await liteRTService.prepareConversation(conversationId, systemPrompt, { samplerConfig: { temperature: settings.temperature, topP: settings.topP }, + history, }); - dbg('log', `[LiteRT] sendMessage start — imageUri=${imageUri ?? 'none'}`); + dbg('log', `[Vision] calling sendMessage — imageUri=${imageUri ? imageUri.substring(0, 60) : 'none'} textLen=${typeof lastUser.content === 'string' ? lastUser.content.length : 0}`); await liteRTService.sendMessage( typeof lastUser.content === 'string' ? lastUser.content : '', @@ -244,6 +262,9 @@ export async function generateResponseImpl( onReasoning: (token: string) => { if (svc.abortRequested) return; svc.reasoningBuffer += token; + if (!svc.flushTimer) { + svc.flushTimer = setTimeout(() => svc.flushTokenBuffer(), FLUSH_INTERVAL_MS); + } }, onComplete: (_content: string, _reasoning: string, stats) => { if (svc.abortRequested) return; diff --git a/src/services/generationToolLoop.ts b/src/services/generationToolLoop.ts index 6c0a34dc8..e5a2e875b 100644 --- a/src/services/generationToolLoop.ts +++ b/src/services/generationToolLoop.ts @@ -315,7 +315,7 @@ function buildLiteRTSendText(messages: Message[]): string { return ''; } -function buildLiteRTHistory(messages: Message[]): Array<{ role: 'user' | 'assistant'; content: string }> { +export function buildLiteRTHistory(messages: Message[]): Array<{ role: 'user' | 'assistant'; content: string }> { let lastUserIdx = -1; for (let i = messages.length - 1; i >= 0; i--) { if (messages[i].role === 'user') { lastUserIdx = i; break; } } if (lastUserIdx <= 0) return []; @@ -356,13 +356,16 @@ async function callLiteRTForLoop( const systemPrompt = typeof 
systemMsg?.content === 'string' ? systemMsg.content : ''; const text = buildLiteRTSendText(messages); const history = buildLiteRTHistory(messages); + const lastUser = [...messages].reverse().find(m => m.role === 'user'); + const imageAttachment = lastUser?.attachments?.find((a: any) => a.type === 'image'); + const imageUri = imageAttachment?.uri as string | undefined; const liteRTSettings = useAppStore.getState().settings; const samplerConfig = { temperature: liteRTSettings.temperature, topK: 40, topP: liteRTSettings.topP, }; - logger.log(`[ToolLoop][LiteRT] callLiteRTForLoop — convId=${conversationId}, text=${text.length}ch, sysPrompt=${systemPrompt.length}ch, tools=${tools.length}, history=${history.length}`); + logger.log(`[ToolLoop][LiteRT] callLiteRTForLoop — convId=${conversationId}, text=${text.length}ch, sysPrompt=${systemPrompt.length}ch, tools=${tools.length}, history=${history.length}, hasImage=${!!imageUri}`); logger.log(`[ToolLoop][LiteRT] samplerConfig — temperature=${samplerConfig.temperature} topK=${samplerConfig.topK} topP=${samplerConfig.topP}`); logger.log(`[ToolLoop][LiteRT] sysPrompt first500: "${systemPrompt.substring(0, 500)}"`); logger.log(`[ToolLoop][LiteRT] sending text: "${text.substring(0, 300)}"`); @@ -376,6 +379,8 @@ async function callLiteRTForLoop( text, token => onStream?.({ content: token }), onToolCall, + imageUri, + token => onStream?.({ reasoningContent: token }), ); logger.log(`[ToolLoop][LiteRT] raw response (${fullResponse.length}ch): "${fullResponse.substring(0, 400)}"`); // Native SDK handles all tool→model cycles internally; toolCalls always empty here @@ -449,7 +454,9 @@ function buildStreamHandler(ctx: ToolLoopContext, state: ToolLoopState): ((data: return (data: StreamChunk) => { if (ctx.isAborted()) return; const chunk = normalizeStreamChunk(data); - if (!state.firstTokenFired) { + // Only fire onThinkingDone when the first *content* token arrives — reasoning + // tokens mean the model is still thinking, so keep 
isThinking=true until then. + if (chunk.content && !state.firstTokenFired) { state.firstTokenFired = true; state.thinkingDoneFired = true; ctx.onThinkingDone(); diff --git a/src/services/litert.ts b/src/services/litert.ts index 0654e1d29..548172190 100644 --- a/src/services/litert.ts +++ b/src/services/litert.ts @@ -12,6 +12,7 @@ import { NativeModules, NativeEventEmitter, Platform, EmitterSubscription } from 'react-native'; import logger from '../utils/logger'; import { useDebugLogsStore } from '../stores/debugLogsStore'; +import { contextCompactionService } from './contextCompaction'; const TAG = '[LiteRTService]'; @@ -131,20 +132,39 @@ class LiteRTService { const topP = samplerConfig?.topP ?? 0.95; const toolsJson = tools && tools.length > 0 ? JSON.stringify(tools) : ''; const historyJson = history && history.length > 0 ? JSON.stringify(history) : ''; + const dbg = useDebugLogsStore.getState().addLog; logger.log(TAG, `resetConversation — systemPrompt length=${systemPrompt.length} temperature=${temperature} topK=${topK} topP=${topP} tools=${tools?.length ?? 0} history=${history?.length ?? 0}`); - useDebugLogsStore.getState().addLog('log', `[LiteRT] Conv reset — temperature=${temperature} topK=${topK} topP=${topP} tools=${tools?.length ?? 0}`); + dbg('log', `[LiteRT] resetConversation — systemLen=${systemPrompt.length} temp=${temperature} topK=${topK} topP=${topP} tools=${tools?.length ?? 0} historyTurns=${history?.length ?? 
0}`); + if (history && history.length > 0) { + const totalHistoryChars = history.reduce((s, m) => s + m.content.length, 0); + dbg('log', `[LiteRT] resetConversation history — ${history.map((m, i) => `[${i}] ${m.role}(${m.content.length}ch)`).join(' ')}`); + dbg('log', `[LiteRT] resetConversation history chars — total=${totalHistoryChars} estimatedTokens=~${Math.ceil(totalHistoryChars / 4)}`); + } await LiteRTModule.resetConversation(systemPrompt, temperature, topK, topP, toolsJson, historyJson); this.activeSystemPrompt = systemPrompt; this.activeToolsJson = toolsJson; - this.cumulativeTokens = 0; - logger.log(TAG, 'resetConversation — done'); + // Seed the counter with estimated tokens already in the KV cache from history + system prompt. + // The SDK loads these silently via ConversationConfig.initialMessages so they never appear + // in lastPrefillTokenCount, causing cumulativeTokens to undercount and auto-compact to fire too late. + const historyChars = (history ?? []).reduce((sum, m) => sum + m.content.length, 0); + const systemChars = systemPrompt.length; + this.cumulativeTokens = Math.ceil((historyChars + systemChars) / 4); + dbg('log', `[LiteRT] resetConversation done — seeded cumulativeTokens=${this.cumulativeTokens} (historyChars=${historyChars} systemChars=${systemChars})`); + logger.log(TAG, `resetConversation — done, seeded cumulativeTokens=${this.cumulativeTokens} from history+system estimate`); + } /** * Ensure conversation is ready for the given context. * Resets only when conversationId or systemPrompt has changed — preserves * native turn history for follow-up messages in the same conversation. - * Auto-trims history when cumulative token usage exceeds 80% of the context limit. + * + * Auto-compact fires at 65% of context: + * - Active session: asks the model to summarize what was discussed, then resets + * with [summary context + recent turns]. One reset only — summarization runs + * in the current KV cache while headroom still exists. 
+ * - First load (no active session): falls back to slicing the oldest half because + * we cannot summarize a session that has not been loaded yet. */ async prepareConversation( conversationId: string, @@ -155,16 +175,70 @@ class LiteRTService { history?: Array<{ role: 'user' | 'assistant'; content: string }>; }, ): Promise { + const dbg = useDebugLogsStore.getState().addLog; const toolsJson = opts?.tools && opts.tools.length > 0 ? JSON.stringify(opts.tools) : ''; - // Auto-compact: trim oldest turns when nearing context limit const maxTokens = this.configuredMaxTokens; - let history = opts?.history; - if (maxTokens > 0 && this.cumulativeTokens > maxTokens * 0.8 && history && history.length > 2) { - const trimmedHistory = history.slice(Math.floor(history.length / 2)); - logger.log(TAG, `prepareConversation — auto-compact: cumulativeTokens=${this.cumulativeTokens} > ${Math.floor(maxTokens * 0.8)} (80% of ${maxTokens}), trimming history ${history.length} → ${trimmedHistory.length} turns`); - this.cumulativeTokens = Math.floor(this.cumulativeTokens * 0.5); - await this.resetConversation(systemPrompt, { samplerConfig: opts?.samplerConfig, tools: opts?.tools, history: trimmedHistory }); + const history = opts?.history; + const incomingEstimate = history + ? Math.ceil((history.reduce((s, m) => s + m.content.length, 0) + systemPrompt.length) / 4) + : 0; + + dbg('log', `[LiteRT] prepareConversation — convId=${conversationId.substring(0, 8)} activeConvId=${this.activeConversationId?.substring(0, 8) ?? 'null'} sameConv=${this.activeConversationId === conversationId}`); + dbg('log', `[LiteRT] prepareConversation state — cumulativeTokens=${this.cumulativeTokens} maxTokens=${maxTokens} historyTurns=${history?.length ?? 
0} incomingEstimate=~${incomingEstimate}`); + + const COMPACT_THRESHOLD = 0.65; + const needsCompact = + maxTokens > 0 && + history && + history.length > 2 && + (this.cumulativeTokens > maxTokens * COMPACT_THRESHOLD || incomingEstimate > maxTokens * COMPACT_THRESHOLD); + + dbg('log', `[LiteRT] prepareConversation compact check — needsCompact=${needsCompact} threshold=${Math.floor(maxTokens * COMPACT_THRESHOLD)} cumul=${this.cumulativeTokens} incoming=~${incomingEstimate}`); + + if (needsCompact) { + contextCompactionService.signalCompacting(true); + try { + // Select recent turns that fit within 40% of context by char estimate. + // Always keep at least the last 2 turns regardless of budget. + const recentBudgetChars = Math.floor(maxTokens * 0.40) * 4; + let charCount = 0; + let recentStart = history!.length; + for (let i = history!.length - 1; i >= 0; i--) { + charCount += history![i].content.length; + if (charCount > recentBudgetChars) break; + recentStart = i; + } + recentStart = Math.min(recentStart, Math.max(0, history!.length - 2)); + const recentHistory = history!.slice(recentStart); + + const hasActiveSession = this.activeConversationId === conversationId; + let summary: string | null = null; + + if (hasActiveSession) { + dbg('log', `[LiteRT] compact — active session, requesting summary (cumulative=${this.cumulativeTokens}/${maxTokens})`); + logger.log(TAG, `prepareConversation — compact: active session at cumulative=${this.cumulativeTokens}/${maxTokens}, requesting summary`); + summary = await this.summarizeCurrentSession(); + dbg('log', `[LiteRT] compact summary — got=${!!summary} length=${summary?.length ?? 
0} chars`); + } else { + dbg('log', `[LiteRT] compact — no active session, slicing only (incomingEstimate=~${incomingEstimate}/${maxTokens})`); + logger.log(TAG, `prepareConversation — compact: first load, incomingEstimate=${incomingEstimate}/${maxTokens}, no active session to summarize — slicing`); + } + + const compactedHistory: Array<{ role: 'user' | 'assistant'; content: string }> = summary + ? [ + { role: 'user', content: `[Context from earlier in our conversation]: ${summary}` }, + { role: 'assistant', content: 'Understood.' }, + ...recentHistory, + ] + : recentHistory; + + dbg('log', `[LiteRT] compact done — ${history!.length} → ${compactedHistory.length} turns, summarized=${!!summary}`); + logger.log(TAG, `prepareConversation — compact done: ${history!.length} → ${compactedHistory.length} turns, summarized=${!!summary}`); + await this.resetConversation(systemPrompt, { samplerConfig: opts?.samplerConfig, tools: opts?.tools, history: compactedHistory }); + } finally { + contextCompactionService.signalCompacting(false); + } this.activeConversationId = conversationId; this.activeSystemPrompt = systemPrompt; this.activeToolsJson = toolsJson; @@ -175,15 +249,64 @@ class LiteRTService { this.activeConversationId !== conversationId || this.activeSystemPrompt !== systemPrompt || this.activeToolsJson !== toolsJson; + + const resetReason = needsReset + ? [ + this.activeConversationId !== conversationId ? `newConv(${this.activeConversationId?.substring(0, 8) ?? 'null'}→${conversationId.substring(0, 8)})` : '', + this.activeSystemPrompt !== systemPrompt ? 'sysPromptChanged' : '', + this.activeToolsJson !== toolsJson ? 'toolsChanged' : '', + ].filter(Boolean).join(' ') + : 'none'; + + dbg('log', `[LiteRT] prepareConversation decision — needsReset=${needsReset} reason=${resetReason} historyTurns=${opts?.history?.length ?? 
0}`); + if (needsReset) { logger.log(TAG, `prepareConversation — reset (convId changed=${this.activeConversationId !== conversationId}, sysPrompt changed=${this.activeSystemPrompt !== systemPrompt}, tools changed=${this.activeToolsJson !== toolsJson}, history=${opts?.history?.length ?? 0})`); await this.resetConversation(systemPrompt, { samplerConfig: opts?.samplerConfig, tools: opts?.tools, history: opts?.history }); this.activeConversationId = conversationId; + dbg('log', `[LiteRT] prepareConversation reset complete — activeConvId set to ${conversationId.substring(0, 8)}`); } else { + dbg('log', `[LiteRT] prepareConversation — reusing existing session (multi-turn, no reset)`); logger.log(TAG, 'prepareConversation — reusing existing conversation (multi-turn)'); } } + /** + * Ask the active session to summarize itself while headroom remains (called at + * 65% context, so ~35% is free for the request + response). + * Returns null on timeout or error so callers can fall back to slice. + */ + private summarizeCurrentSession(): Promise { + return new Promise((resolve) => { + if (!this.isAvailable() || !this.loaded) { + resolve(null); + return; + } + let summary = ''; + const timeout = setTimeout(() => { + logger.log(TAG, 'summarizeCurrentSession — timed out, falling back to slice'); + resolve(null); + }, 20_000); + this.sendMessage( + 'Briefly summarize our conversation so far — key topics, decisions, and context. 
3 to 5 sentences maximum.', + { + onToken: (token) => { summary += token; }, + onReasoning: () => {}, + onComplete: () => { + clearTimeout(timeout); + logger.log(TAG, `summarizeCurrentSession — got summary (${summary.length} chars)`); + resolve(summary.trim() || null); + }, + onError: (err) => { + clearTimeout(timeout); + logger.log(TAG, `summarizeCurrentSession — error: ${String(err)}, falling back to slice`); + resolve(null); + }, + }, + ); + }); + } + // --------------------------------------------------------------------------- // warmup — send a throwaway prompt to prime GPU/NPU shader caches // --------------------------------------------------------------------------- @@ -224,7 +347,9 @@ class LiteRTService { return; } + const sendMsgDbg = useDebugLogsStore.getState().addLog; logger.log(TAG, `sendMessage — text length=${text.length}`); + sendMsgDbg('log', `[Vision] sendMessage called — textLen=${text.length} hasImage=${imageUri != null} imageUri=${imageUri ? imageUri.substring(0, 80) : 'none'}`); // Reset accumulators this.currentContent = ''; @@ -317,6 +442,7 @@ class LiteRTService { ]; try { + sendMsgDbg('log', `[Vision] → LiteRTModule.sendMessage — imageArg=${imageUri != null ? 'SET' : 'NULL'}`); await LiteRTModule.sendMessage(text, imageUri ?? null); } catch (e) { this.clearSubscriptions(); @@ -335,13 +461,15 @@ class LiteRTService { text: string, onToken?: (token: string) => void, onToolCall?: (name: string, args: Record) => Promise, + imageUri?: string, + onReasoning?: (token: string) => void, ): Promise { - logger.log(TAG, `generateRaw — text=${text.length}ch, hasToolHandler=${!!onToolCall}, first100="${text.substring(0, 100)}"`); + logger.log(TAG, `generateRaw — text=${text.length}ch, hasToolHandler=${!!onToolCall}, hasImage=${!!imageUri}, first100="${text.substring(0, 100)}"`); this.currentToolCallHandler = onToolCall ?? 
null; return new Promise((resolve, reject) => { this.sendMessage(text, { onToken: t => onToken?.(t), - onReasoning: () => {}, + onReasoning: t => onReasoning?.(t), onComplete: (fullContent, _reasoning, stats) => { logger.log(TAG, `generateRaw — complete, response=${fullContent.length}ch, first200="${fullContent.substring(0, 200)}"`); this._lastBenchmarkStats = stats; @@ -352,7 +480,7 @@ class LiteRTService { this.currentToolCallHandler = null; reject(err); }, - }).catch(reject); + }, imageUri).catch(reject); }); } From a33f67e6bfbf7faf07b4a335d0a973f65d68f6e1 Mon Sep 17 00:00:00 2001 From: Dishit Date: Sat, 23 May 2026 13:40:26 +0530 Subject: [PATCH 17/93] fix(litert): stream thinking tokens incrementally instead of all at once Co-authored-by: Dishit Karia --- .../ChatMessage/components/MessageContent.tsx | 18 ++++++++++++++++++ src/components/ChatMessage/index.tsx | 2 +- src/services/generationService.ts | 3 +++ src/types/index.ts | 1 + 4 files changed, 23 insertions(+), 1 deletion(-) diff --git a/src/components/ChatMessage/components/MessageContent.tsx b/src/components/ChatMessage/components/MessageContent.tsx index e2fa7afca..3fb76e97a 100644 --- a/src/components/ChatMessage/components/MessageContent.tsx +++ b/src/components/ChatMessage/components/MessageContent.tsx @@ -35,6 +35,24 @@ export function MessageContent({ ); } + // If no content yet but we have live reasoning tokens, show the thinking block so + // reasoning streams incrementally rather than appearing all at once when content arrives. 
+ if (!content && isStreaming && parsedContent.thinking) { + return ( + + + + + + + ); + } + if (!content) { if (isStreaming) { return ( diff --git a/src/components/ChatMessage/index.tsx b/src/components/ChatMessage/index.tsx index d80310b7d..dd908ae0c 100644 --- a/src/components/ChatMessage/index.tsx +++ b/src/components/ChatMessage/index.tsx @@ -187,7 +187,7 @@ export const ChatMessage: React.FC = ({ const [showActionMenu, setShowActionMenu] = useState(false); const [isEditing, setIsEditing] = useState(false); const [editedContent, setEditedContent] = useState(message.content); - const [showThinking, setShowThinking] = useState(false); + const [showThinking, setShowThinking] = useState(!!isStreaming); const [alertState, setAlertState] = useState(initialAlertState); const { displayContent, parsedContent } = buildMessageData(message); diff --git a/src/services/generationService.ts b/src/services/generationService.ts index 7eebdc302..b245ebefe 100644 --- a/src/services/generationService.ts +++ b/src/services/generationService.ts @@ -2,6 +2,7 @@ import { llmService } from './llm'; import { liteRTService } from './litert'; import { useAppStore, useChatStore, useRemoteServerStore } from '../stores'; +import { useDebugLogsStore } from '../stores/debugLogsStore'; import { Message, GenerationMeta, MediaAttachment } from '../types'; import { runToolLoop } from './generationToolLoop'; import type { ToolResult } from './tools/types'; @@ -85,11 +86,13 @@ class GenerationService { private flushTokenBuffer(): void { const store = useChatStore.getState(); + const dbg = useDebugLogsStore.getState().addLog; if (this.tokenBuffer) { store.appendToStreamingMessage(this.tokenBuffer); this.tokenBuffer = ''; } if (this.reasoningBuffer) { + dbg('log', `[Flush] reasoningBuffer → store — len=${this.reasoningBuffer.length}ch totalReasoning=${this.totalReasoningLength}ch storeIsThinking=${store.isThinking}`); store.appendToStreamingReasoningContent(this.reasoningBuffer); this.reasoningBuffer 
= ''; } diff --git a/src/types/index.ts b/src/types/index.ts index 840e8e22d..232aef3c8 100644 --- a/src/types/index.ts +++ b/src/types/index.ts @@ -295,6 +295,7 @@ export type AutoDetectMethod = 'pattern' | 'llm'; export type ModelLoadingStrategy = 'performance' | 'memory'; export type CacheType = 'f16' | 'q8_0' | 'q4_0'; export type InferenceBackend = 'cpu' | 'opencl' | 'htp' | 'metal'; +export type LiteRTBackend = 'cpu' | 'gpu' | 'npu'; export const INFERENCE_BACKENDS = { CPU: 'cpu' as InferenceBackend, OPENCL: 'opencl' as InferenceBackend, From 4921b41de04cc11219ec8b15da7034b6f6c8a5a5 Mon Sep 17 00:00:00 2001 From: Dishit Date: Sat, 23 May 2026 13:40:30 +0530 Subject: [PATCH 18/93] fix(litert): enable thinking toggle and deduplicate tool text hint Co-authored-by: Dishit Karia --- .../ChatScreen/useChatGenerationActions.ts | 32 ++++++++++++------- src/screens/ChatScreen/useChatModelActions.ts | 2 +- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/src/screens/ChatScreen/useChatGenerationActions.ts b/src/screens/ChatScreen/useChatGenerationActions.ts index 0c6f35b05..325b3634a 100644 --- a/src/screens/ChatScreen/useChatGenerationActions.ts +++ b/src/screens/ChatScreen/useChatGenerationActions.ts @@ -49,6 +49,7 @@ export type GenerationDeps = { imageGuidanceScale?: number; enabledTools?: string[]; cacheType?: CacheType; + thinkingEnabled?: boolean; }; downloadedModels: DownloadedModel[]; setAlertState: SetState; @@ -220,10 +221,13 @@ async function injectRagContext(projectId: string | undefined, query: string, pr } return prompt; } -/** Gemma 4 E2B/E4B need <|think|> prepended to activate thinking mode. */ -const applyGemma4ThinkToken = (prompt: string, isRemote: boolean): string => - (!isRemote && llmService.isGemma4Model() && llmService.isThinkingEnabled()) ? 
`<|think|>\n${prompt}` : prompt; -function resolveToolsAndPrompt(deps: GenerationDeps, conversation: any, _messageText: string): { enabledTools: string[]; rawPrompt: string } { +/** Gemma 4 E2B/E4B need <|think|> prepended to activate thinking mode — both llama.cpp and LiteRT. */ +const applyGemma4ThinkToken = (prompt: string, isRemote: boolean, isLiteRT: boolean = false, thinkingEnabled: boolean = false): string => { + const liteRTWantsThink = !isRemote && isLiteRT && thinkingEnabled; + const llamaWantsThink = !isRemote && llmService.isGemma4Model() && llmService.isThinkingEnabled(); + return (liteRTWantsThink || llamaWantsThink) ? `<|think|>\n${prompt}` : prompt; +}; +function resolveToolsAndPrompt(deps: GenerationDeps, conversation: any, _messageText: string): { enabledTools: string[]; rawPrompt: string; isLiteRT: boolean } { const project = conversation?.projectId ? useProjectStore.getState().getProject(conversation.projectId) : null; const { activeServerId, activeRemoteTextModelId } = useRemoteServerStore.getState(); const localToolCalling = llmService.supportsToolCalling(); @@ -240,7 +244,7 @@ function resolveToolsAndPrompt(deps: GenerationDeps, conversation: any, _message const rawPrompt = project?.systemPrompt || deps.settings.systemPrompt || APP_CONFIG.defaultSystemPrompt; logger.log(`[ChatGen][resolveTools] isLiteRT=${isLiteRT}, canUseTools=${canUseTools}, enabledTools=[${enabledTools.join(', ')}]`); - return { enabledTools, rawPrompt }; + return { enabledTools, rawPrompt, isLiteRT }; } export async function startGenerationFn(deps: GenerationDeps, call: StartGenerationCall): Promise { const { setDebugInfo, targetConversationId, messageText } = call; @@ -262,18 +266,20 @@ export async function startGenerationFn(deps: GenerationDeps, call: StartGenerat } } const conversation = useChatStore.getState().conversations.find(c => c.id === targetConversationId); - const { enabledTools, rawPrompt } = resolveToolsAndPrompt(deps, conversation, messageText); + const { 
enabledTools, rawPrompt, isLiteRT } = resolveToolsAndPrompt(deps, conversation, messageText); const basePrompt = await injectRagContext(conversation?.projectId, messageText, rawPrompt); const isRemote = !!useRemoteServerStore.getState().activeRemoteTextModelId; const activeTools = enabledTools; - // Skip text hint when model supports native Jinja tool calling — the Jinja template - // already injects tool schemas, so adding the hint text would double-inject. - const useTextHint = !isRemote && activeTools.length > 0 && !llmService.supportsToolCalling(); + // LiteRT passes tools natively via ConversationConfig — text hint would double-inject. + // llama.cpp uses text hint only when it lacks native Jinja tool calling support. + const useTextHint = !isRemote && !isLiteRT && activeTools.length > 0 && !llmService.supportsToolCalling(); const systemPrompt = applyGemma4ThinkToken( useTextHint ? `${basePrompt}${buildToolSystemPromptHint(activeTools)}` : basePrompt, isRemote, + isLiteRT, + deps.settings.thinkingEnabled, ); - logger.log(`[ChatGen][DEBUG] isRemote=${isRemote}, useTextHint=${useTextHint}, tools=[${activeTools.join(', ')}], path=${activeTools.length > 0 ? 'withTools' : 'generate'}`); + logger.log(`[ChatGen][DEBUG] isRemote=${isRemote}, isLiteRT=${isLiteRT}, useTextHint=${useTextHint}, tools=[${activeTools.join(', ')}], path=${activeTools.length > 0 ? 
'withTools' : 'generate'}`); logger.log(`[ChatGen][PROMPT] systemPrompt (${systemPrompt.length}ch): "${systemPrompt.substring(0, 800)}"`); const messagesForContext = buildMessagesForContext(targetConversationId, messageText, systemPrompt); await prepareContext(setDebugInfo, systemPrompt, messagesForContext); @@ -358,14 +364,16 @@ export async function regenerateResponseFn(deps: GenerationDeps, call: Regenerat const messages = (conversation?.messages || []).filter((m: Message) => !m.isSystemInfo); const messagesUpToUser = messages.slice(0, messages.findIndex((m: Message) => m.id === userMessage.id) + 1) .map(m => m.id === userMessage.id ? { ...m, content: messageText } : m); - const { enabledTools, rawPrompt } = resolveToolsAndPrompt(deps, conversation, messageText); + const { enabledTools, rawPrompt, isLiteRT: isLiteRTRegen } = resolveToolsAndPrompt(deps, conversation, messageText); const isRemote = !!useRemoteServerStore.getState().activeRemoteTextModelId; const activeTools = enabledTools; const basePrompt = await injectRagContext(conversation?.projectId, messageText, rawPrompt); - const useTextHint = !isRemote && activeTools.length > 0 && !llmService.supportsToolCalling(); + const useTextHint = !isRemote && !isLiteRTRegen && activeTools.length > 0 && !llmService.supportsToolCalling(); const systemPrompt = applyGemma4ThinkToken( useTextHint ? 
`${basePrompt}${buildToolSystemPromptHint(activeTools)}` : basePrompt, isRemote, + isLiteRTRegen, + deps.settings.thinkingEnabled, ); const { prefix, filtered } = applyCompactionPrefix(conversation, systemPrompt, messagesUpToUser); try { diff --git a/src/screens/ChatScreen/useChatModelActions.ts b/src/screens/ChatScreen/useChatModelActions.ts index 65b33af3b..85fc694bb 100644 --- a/src/screens/ChatScreen/useChatModelActions.ts +++ b/src/screens/ChatScreen/useChatModelActions.ts @@ -341,7 +341,7 @@ export function useChatModelStateSync(deps: ModelStateSyncDeps): void { setSupportsThinking(activeRemoteModel?.capabilities?.supportsThinking ?? false); } else if (activeModel?.engine === 'litert' && liteRTService.isModelLoaded()) { setSupportsToolCalling(true); - setSupportsThinking(false); + setSupportsThinking(true); } else if (llmService.isModelLoaded()) { setSupportsToolCalling(llmService.supportsToolCalling()); setSupportsThinking(llmService.supportsThinking()); From 98b7ab80b7ecba567d29cacebeb5dd0597c1a574 Mon Sep 17 00:00:00 2001 From: Dishit Date: Sat, 23 May 2026 13:40:34 +0530 Subject: [PATCH 19/93] fix(tools): update web_search and read_url descriptions for chained use Co-authored-by: Dishit Karia --- src/services/tools/registry.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/services/tools/registry.ts b/src/services/tools/registry.ts index d66781c6d..e0d78b7ea 100644 --- a/src/services/tools/registry.ts +++ b/src/services/tools/registry.ts @@ -5,7 +5,7 @@ export const AVAILABLE_TOOLS: ToolDefinition[] = [ id: 'web_search', name: 'web_search', displayName: 'Web Search', - description: 'Search the web', + description: 'Search the web and return result titles, snippets, and URLs. 
When the snippet is insufficient, call read_url on the most relevant result URL to get the full page content.', icon: 'globe', requiresNetwork: true, parameters: { @@ -75,13 +75,13 @@ export const AVAILABLE_TOOLS: ToolDefinition[] = [ id: 'read_url', name: 'read_url', displayName: 'URL Reader', - description: 'Fetch and read a web page', + description: 'Fetch the full content of a URL. Use this after web_search to read the complete text of a result page when the search snippet does not contain enough detail.', icon: 'link', requiresNetwork: true, parameters: { url: { type: 'string', - description: 'URL to fetch', + description: 'Full URL to fetch', required: true, }, }, From 4d5ef28631a91955db37be8c0f4151b41d8608b2 Mon Sep 17 00:00:00 2001 From: Dishit Date: Sat, 23 May 2026 13:40:37 +0530 Subject: [PATCH 20/93] feat(litert): add RAM-based context slider limits for LiteRT models Co-authored-by: Dishit Karia --- .../TextGenerationSection.tsx | 30 ++++++++++++------- .../TextGenerationSection.tsx | 22 +++++++------- 2 files changed, 32 insertions(+), 20 deletions(-) diff --git a/src/components/GenerationSettingsModal/TextGenerationSection.tsx b/src/components/GenerationSettingsModal/TextGenerationSection.tsx index c0e29cf54..96ecea2e5 100644 --- a/src/components/GenerationSettingsModal/TextGenerationSection.tsx +++ b/src/components/GenerationSettingsModal/TextGenerationSection.tsx @@ -4,6 +4,7 @@ import Slider from '@react-native-community/slider'; import { AdvancedToggle } from '../AdvancedToggle'; import { useTheme, useThemedStyles } from '../../theme'; import { useAppStore, selectIsLiteRT } from '../../stores'; +import { hardwareService } from '../../services'; import { createStyles } from './styles'; import { CpuThreadsSlider, @@ -23,6 +24,7 @@ interface SettingConfig { format: (value: number) => string; description?: string; warning?: (value: number) => string | null; + warningColor?: string; } const DEFAULT_SETTINGS: Record = { @@ -33,19 +35,24 @@ const 
DEFAULT_SETTINGS: Record = { contextLength: 4096, }; -const FALLBACK_MAX_CONTEXT = 32768; -const HIGH_CONTEXT_THRESHOLD = 8192; - const formatContext = (v: number) => v >= 1024 ? `${(v / 1024).toFixed(0)}K` : v.toString(); -const contextWarning = (v: number): string | null => - v > HIGH_CONTEXT_THRESHOLD ? 'High context uses significant RAM and may crash on some devices' : null; - const BASIC_KEYS = ['temperature', 'maxTokens', 'contextLength']; const LITERT_BASIC_KEYS = ['temperature', 'contextLength']; const LITERT_ADVANCED_KEYS = ['topP']; -const buildSettingsConfig = (modelMaxContext: number | null, isLiteRT: boolean): SettingConfig[] => [ +const buildSettingsConfig = (modelMaxContext: number | null, isLiteRT: boolean): SettingConfig[] => { + const isLargeRam = hardwareService.getTotalMemoryGB() > 8; + const liteRTMax = modelMaxContext ?? (isLargeRam ? 32768 : 12288); + const liteRTWarn = isLargeRam ? 16384 : 8192; + const llmMax = modelMaxContext ?? 32768; + + const contextWarning = (v: number): string | null => { + if (!isLiteRT) return v > 8192 ? 'High context uses significant RAM and may crash on some devices' : null; + return v > liteRTWarn ? 'High context uses significant RAM — may slow or crash on some devices' : null; + }; + + return [ { key: 'temperature', label: 'Temperature', @@ -86,15 +93,17 @@ const buildSettingsConfig = (modelMaxContext: number | null, isLiteRT: boolean): key: 'contextLength', label: isLiteRT ? 'Max Tokens' : 'Context Length', min: 512, - max: modelMaxContext || FALLBACK_MAX_CONTEXT, + max: isLiteRT ? liteRTMax : llmMax, step: 1024, format: formatContext, description: isLiteRT ? 'Total context window — input + history + output combined (requires reload)' : 'KV cache size — larger uses more RAM (requires reload)', warning: contextWarning, + warningColor: isLiteRT ? 
'#F59E0B' : undefined, }, -]; + ]; +}; interface SettingSliderProps { config: SettingConfig; @@ -107,6 +116,7 @@ const SettingSlider: React.FC = ({ config }) => { const rawValue = (settings as Record)[config.key]; const value = (rawValue ?? DEFAULT_SETTINGS[config.key]) as number; const warningText = config.warning?.(value) ?? null; + const warningColor = config.warningColor ?? colors.error; return ( @@ -118,7 +128,7 @@ const SettingSlider: React.FC = ({ config }) => { {config.description} )} {warningText && ( - {warningText} + {warningText} )} { const { colors } = useTheme(); const styles = useThemedStyles(createStyles); @@ -18,6 +16,11 @@ export const TextGenerationSection: React.FC = () => { const isLiteRT = useAppStore(selectIsLiteRT); const [showAdvanced, setShowAdvanced] = useState(false); + const isLargeRam = hardwareService.getTotalMemoryGB() > 8; + const liteRTSliderMax = modelMaxContext ?? (isLargeRam ? 32768 : 12288); + const liteRTWarnThreshold = isLargeRam ? 16384 : 8192; + const llmSliderMax = modelMaxContext ?? 32768; + const trackColor = { false: colors.surfaceLight, true: `${colors.primary}80` }; const maxTokens = settings?.maxTokens || 512; const maxTokensLabel = maxTokens >= 1024 @@ -27,7 +30,6 @@ export const TextGenerationSection: React.FC = () => { const contextLengthLabel = contextLength >= 1024 ? 
`${(contextLength / 1024).toFixed(0)}K` : String(contextLength); - const ctxSliderMax = modelMaxContext || FALLBACK_MAX_CONTEXT; return ( @@ -82,15 +84,15 @@ export const TextGenerationSection: React.FC = () => { {contextLengthLabel} Total context window — input + history + output combined (requires reload) - {contextLength > HIGH_CONTEXT_THRESHOLD && ( - - High values use significant RAM and may fail on some devices + {contextLength > liteRTWarnThreshold && ( + + High context uses significant RAM — may slow or crash on some devices )} updateSettings({ contextLength: value })} @@ -106,7 +108,7 @@ export const TextGenerationSection: React.FC = () => { {contextLengthLabel} KV cache size — larger uses more RAM (requires reload) - {contextLength > HIGH_CONTEXT_THRESHOLD && ( + {contextLength > 8192 && ( High context uses significant RAM and may crash on some devices @@ -114,7 +116,7 @@ export const TextGenerationSection: React.FC = () => { updateSettings({ contextLength: value })} From 8e9605abd2849f52a1d26d5dfbe432ae0eb618c8 Mon Sep 17 00:00:00 2001 From: Dishit Date: Sat, 23 May 2026 13:40:46 +0530 Subject: [PATCH 21/93] feat(litert): lower auto-compact threshold to 65% and seed token counter from history Co-authored-by: Dishit Karia --- src/services/activeModelService/loaders.ts | 8 ++++++++ src/services/contextCompaction.ts | 5 +++++ 2 files changed, 13 insertions(+) diff --git a/src/services/activeModelService/loaders.ts b/src/services/activeModelService/loaders.ts index 6016b0b1c..30b202465 100644 --- a/src/services/activeModelService/loaders.ts +++ b/src/services/activeModelService/loaders.ts @@ -3,6 +3,7 @@ * Extracted to keep index.ts under the max-lines limit. 
*/ +import { Platform, ToastAndroid } from 'react-native'; import { useAppStore } from '../../stores'; import { useDebugLogsStore } from '../../stores/debugLogsStore'; import { DownloadedModel, ONNXImageModel, INFERENCE_BACKENDS } from '../../types'; @@ -155,6 +156,13 @@ async function doLoadLiteRTModel(ctx: TextLoadContext): Promise { addDebugLog('log', `[LiteRT] Load complete — actual backend: ${actualBackend}`); if (actualBackend !== preferredBackend) { addDebugLog('warn', `[LiteRT] Requested ${preferredBackend}, fell back to ${actualBackend}`); + if (preferredBackend === 'gpu' && actualBackend === 'cpu' && maxTokens > 8192 && Platform.OS === 'android') { + ToastAndroid.showWithGravity( + `GPU unavailable at ${maxTokens.toLocaleString()} token context. Running on CPU — reduce context length to use GPU.`, + ToastAndroid.LONG, + ToastAndroid.BOTTOM, + ); + } } // Warmup on GPU/NPU only — primes shader/kernel caches so first real prompt runs at full speed diff --git a/src/services/contextCompaction.ts b/src/services/contextCompaction.ts index 879c7597d..dd3de7fd5 100644 --- a/src/services/contextCompaction.ts +++ b/src/services/contextCompaction.ts @@ -57,6 +57,11 @@ class ContextCompactionService { this.compactingListeners.forEach(fn => fn(v)); } + /** Allow external services (e.g. LiteRT) to surface compaction state in the UI. */ + signalCompacting(v: boolean): void { + this.setCompacting(v); + } + isContextFullError(error: unknown): boolean { const msg = (error instanceof Error ? 
error.message : `${error as string}`).toLowerCase(); return CONTEXT_FULL_PATTERNS.some(p => msg.includes(p)); From b0c7d96f09de8dfb542195604f36db9908f0df19 Mon Sep 17 00:00:00 2001 From: Dishit Date: Sat, 23 May 2026 13:40:49 +0530 Subject: [PATCH 22/93] feat(litert): add liteRTBackend setting defaulting to gpu Co-authored-by: Dishit Karia --- src/stores/appStore.ts | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/stores/appStore.ts b/src/stores/appStore.ts index 5e5e3054f..6a5b64714 100644 --- a/src/stores/appStore.ts +++ b/src/stores/appStore.ts @@ -2,7 +2,7 @@ import { create } from 'zustand'; import { persist, createJSONStorage } from 'zustand/middleware'; import { Platform } from 'react-native'; import AsyncStorage from '@react-native-async-storage/async-storage'; -import { DeviceInfo, DownloadedModel, ModelRecommendation, ONNXImageModel, ImageGenerationMode, AutoDetectMethod, ModelLoadingStrategy, CacheType, InferenceBackend, INFERENCE_BACKENDS, GeneratedImage } from '../types'; +import { DeviceInfo, DownloadedModel, ModelRecommendation, ONNXImageModel, ImageGenerationMode, AutoDetectMethod, ModelLoadingStrategy, CacheType, InferenceBackend, INFERENCE_BACKENDS, LiteRTBackend, GeneratedImage } from '../types'; function isUnknownLike(value: string): boolean { const normalized = value.trim().toLowerCase(); @@ -40,6 +40,7 @@ type AppSettings = { cacheType: CacheType; showGenerationDetails: boolean; enabledTools: string[]; thinkingEnabled: boolean; inferenceBackend: InferenceBackend; + liteRTBackend: LiteRTBackend; }; type ThemeMode = 'system' | 'light' | 'dark'; @@ -136,6 +137,7 @@ const DEFAULT_SETTINGS: AppSettings = { showGenerationDetails: false, enabledTools: ['web_search', 'calculator', 'get_current_datetime', 'get_device_info', 'read_url', 'search_knowledge_base'], thinkingEnabled: true, + liteRTBackend: 'gpu' as LiteRTBackend, }; function migrateEnabledTools(merged: any): void { From 8934c29f1dee1c8b0e856ae402f6f4e42aedbbe8 Mon Sep 
17 00:00:00 2001 From: Dishit Date: Sat, 23 May 2026 13:42:46 +0530 Subject: [PATCH 23/93] fix(litert): add display branch debug logs to getDisplayMessages Co-authored-by: Dishit Karia --- src/screens/ChatScreen/types.ts | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/screens/ChatScreen/types.ts b/src/screens/ChatScreen/types.ts index d4f5220a8..7e1722f59 100644 --- a/src/screens/ChatScreen/types.ts +++ b/src/screens/ChatScreen/types.ts @@ -1,4 +1,5 @@ import { Message } from '../../types'; +import { useDebugLogsStore } from '../../stores/debugLogsStore'; export type ChatMessageItem = { id: string; @@ -17,23 +18,36 @@ export type StreamingState = { isStreamingForThisConversation: boolean; }; +let _lastDisplayBranch = ''; export function getDisplayMessages( allMessages: Message[], streaming: StreamingState, ): (Message | ChatMessageItem)[] { const { isThinking, streamingMessage, streamingReasoningContent, isStreamingForThisConversation } = streaming; if (isThinking && isStreamingForThisConversation) { + if (_lastDisplayBranch !== 'thinking') { + _lastDisplayBranch = 'thinking'; + useDebugLogsStore.getState().addLog('log', `[Display] branch=thinking — reasoningLen=${streamingReasoningContent.length} msgLen=${streamingMessage.length}`); + } return [ ...allMessages, { id: 'thinking', role: 'assistant' as const, content: '', timestamp: Date.now(), isThinking: true }, ]; } if ((streamingMessage || streamingReasoningContent) && isStreamingForThisConversation) { + if (_lastDisplayBranch !== 'streaming') { + _lastDisplayBranch = 'streaming'; + useDebugLogsStore.getState().addLog('log', `[Display] branch=streaming — reasoningLen=${streamingReasoningContent.length} msgLen=${streamingMessage.length}`); + } return [ ...allMessages, { id: 'streaming', role: 'assistant' as const, content: streamingMessage, reasoningContent: streamingReasoningContent || undefined, timestamp: Date.now(), isStreaming: true }, ]; } + if (_lastDisplayBranch !== 'done') { + 
_lastDisplayBranch = 'done'; + useDebugLogsStore.getState().addLog('log', `[Display] branch=done — finalMessages=${allMessages.length}`); + } return allMessages; } From ff2a4e0b4417be8c8eab45a2142e0921e6f4f25b Mon Sep 17 00:00:00 2001 From: Dishit Date: Sat, 23 May 2026 14:27:59 +0530 Subject: [PATCH 24/93] fix(litert): detect native module by existence, not Platform.OS Co-authored-by: Dishit Karia --- src/services/litert.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/services/litert.ts b/src/services/litert.ts index 548172190..b183560d8 100644 --- a/src/services/litert.ts +++ b/src/services/litert.ts @@ -9,7 +9,7 @@ * - onComplete receives fully accumulated content, not an empty string. */ -import { NativeModules, NativeEventEmitter, Platform, EmitterSubscription } from 'react-native'; +import { NativeModules, NativeEventEmitter, EmitterSubscription } from 'react-native'; import logger from '../utils/logger'; import { useDebugLogsStore } from '../stores/debugLogsStore'; import { contextCompactionService } from './contextCompaction'; @@ -75,7 +75,7 @@ class LiteRTService { private configuredMaxTokens = 4096; constructor() { - if (Platform.OS === 'android' && LiteRTModule) { + if (LiteRTModule) { this.emitter = new NativeEventEmitter(LiteRTModule); this.emitter.addListener(EVENT_DEBUG_LOG, (msg: string) => { useDebugLogsStore.getState().addLog('log', `[Kotlin] ${msg}`); @@ -548,7 +548,7 @@ class LiteRTService { } isAvailable(): boolean { - return Platform.OS === 'android' && !!LiteRTModule; + return !!LiteRTModule; } /** From 6b55efba456abdaf4d51c795cc1730d51696c85a Mon Sep 17 00:00:00 2001 From: Dishit Date: Sat, 23 May 2026 14:28:10 +0530 Subject: [PATCH 25/93] feat(litert): wire liteRTBackend to loader, pending-reload check, and settings UI Co-authored-by: Dishit Karia --- .../rntl/screens/ModelSettingsScreen.test.tsx | 1 + __tests__/utils/testHelpers.ts | 1 + .../TextGenerationAdvanced.tsx | 46 ++++++++++++++-- 
.../TextGenerationSection.tsx | 3 +- src/screens/ChatScreen/useChatScreen.ts | 2 +- .../TextGenerationAdvanced.tsx | 53 +++++++++++++++++-- src/services/activeModelService/loaders.ts | 13 +---- 7 files changed, 99 insertions(+), 20 deletions(-) diff --git a/__tests__/rntl/screens/ModelSettingsScreen.test.tsx b/__tests__/rntl/screens/ModelSettingsScreen.test.tsx index f2c1131b1..c2f0bd9d3 100644 --- a/__tests__/rntl/screens/ModelSettingsScreen.test.tsx +++ b/__tests__/rntl/screens/ModelSettingsScreen.test.tsx @@ -829,6 +829,7 @@ describe('ModelSettingsScreen', () => { enhanceImagePrompts: undefined as any, enabledTools: undefined as any, thinkingEnabled: undefined as any, + liteRTBackend: undefined as any, }, }); diff --git a/__tests__/utils/testHelpers.ts b/__tests__/utils/testHelpers.ts index 67c68da7e..da135e544 100644 --- a/__tests__/utils/testHelpers.ts +++ b/__tests__/utils/testHelpers.ts @@ -71,6 +71,7 @@ export const resetStores = (): void => { enhanceImagePrompts: false, enabledTools: ['calculator', 'get_current_datetime'], thinkingEnabled: true, + liteRTBackend: 'gpu' as const, }, downloadedImageModels: [], activeImageModelId: null, diff --git a/src/components/GenerationSettingsModal/TextGenerationAdvanced.tsx b/src/components/GenerationSettingsModal/TextGenerationAdvanced.tsx index 66e8bb0fb..5aed0822d 100644 --- a/src/components/GenerationSettingsModal/TextGenerationAdvanced.tsx +++ b/src/components/GenerationSettingsModal/TextGenerationAdvanced.tsx @@ -3,7 +3,7 @@ import { Platform, View, Text, TouchableOpacity } from 'react-native'; import Slider from '@react-native-community/slider'; import { useTheme, useThemedStyles } from '../../theme'; import { useAppStore } from '../../stores'; -import { CacheType, InferenceBackend, INFERENCE_BACKENDS } from '../../types'; +import { CacheType, InferenceBackend, LiteRTBackend, INFERENCE_BACKENDS } from '../../types'; import { useTextGenerationAdvanced, CACHE_TYPE_DESCRIPTIONS, @@ -36,7 +36,7 @@ const HTP_BACKEND: 
BackendOption = { id: INFERENCE_BACKENDS.HTP, label: 'HTP', desc: 'Offload layers to Hexagon NPU on Snapdragon devices. Best for large models. Requires model reload.', }; -export const BackendSelector: React.FC<{ hideGpuLayers?: boolean }> = ({ hideGpuLayers = false }) => { +export const BackendSelector: React.FC = () => { const { colors } = useTheme(); const styles = useThemedStyles(createStyles); const { settings, updateSettings } = useAppStore(); @@ -80,7 +80,7 @@ export const BackendSelector: React.FC<{ hideGpuLayers?: boolean }> = ({ hideGpu ))} - {showLayers && !hideGpuLayers && ( + {showLayers && ( {layersLabel} @@ -104,6 +104,46 @@ export const BackendSelector: React.FC<{ hideGpuLayers?: boolean }> = ({ hideGpu ); }; +// ─── LiteRT Acceleration ───────────────────────────────────────────────────── + +type LiteRTBackendOption = { id: LiteRTBackend; label: string; desc: string }; + +const LITERT_BACKENDS: LiteRTBackendOption[] = [ + { id: 'gpu', label: 'GPU', desc: 'Run on GPU via OpenCL. Best performance on most devices.' }, + { id: 'cpu', label: 'CPU', desc: 'Always available. Use for battery savings or thermal relief.' }, +]; + +export const LiteRTBackendSelector: React.FC = () => { + const styles = useThemedStyles(createStyles); + const { settings, updateSettings } = useAppStore(); + const current = settings.liteRTBackend ?? 'gpu'; + + return ( + + + Acceleration + + {LITERT_BACKENDS.find(b => b.id === current)?.desc ?? 
''} + + + + {LITERT_BACKENDS.map(b => ( + updateSettings({ liteRTBackend: b.id })} + > + + {b.label} + + + ))} + + + ); +}; + // ─── Flash Attention ────────────────────────────────────────────────────────── export const FlashAttentionToggle: React.FC = () => { diff --git a/src/components/GenerationSettingsModal/TextGenerationSection.tsx b/src/components/GenerationSettingsModal/TextGenerationSection.tsx index 96ecea2e5..50f8c9d19 100644 --- a/src/components/GenerationSettingsModal/TextGenerationSection.tsx +++ b/src/components/GenerationSettingsModal/TextGenerationSection.tsx @@ -10,6 +10,7 @@ import { CpuThreadsSlider, BatchSizeSlider, BackendSelector, + LiteRTBackendSelector, FlashAttentionToggle, KvCacheTypeToggle, ModelLoadingStrategyToggle, @@ -214,7 +215,7 @@ export const TextGenerationSection: React.FC = () => { ))} {!isLiteRT && } {!isLiteRT && } - + {isLiteRT ? : } {!isLiteRT && } {!isLiteRT && } {!isLiteRT && } diff --git a/src/screens/ChatScreen/useChatScreen.ts b/src/screens/ChatScreen/useChatScreen.ts index ce99b67b1..a3df32cc5 100644 --- a/src/screens/ChatScreen/useChatScreen.ts +++ b/src/screens/ChatScreen/useChatScreen.ts @@ -234,7 +234,7 @@ export const useChatScreen = () => { if (!loadedSettings) return false; // LiteRT reloads when backend or context length changes — both are baked into the engine at load time if (activeModel?.engine === 'litert') { - return settings.inferenceBackend !== loadedSettings.inferenceBackend || + return settings.liteRTBackend !== loadedSettings.liteRTBackend || settings.contextLength !== loadedSettings.contextLength; } return ( diff --git a/src/screens/ModelSettingsScreen/TextGenerationAdvanced.tsx b/src/screens/ModelSettingsScreen/TextGenerationAdvanced.tsx index 3e8d0ec97..83870c26e 100644 --- a/src/screens/ModelSettingsScreen/TextGenerationAdvanced.tsx +++ b/src/screens/ModelSettingsScreen/TextGenerationAdvanced.tsx @@ -4,7 +4,7 @@ import Slider from '@react-native-community/slider'; import { Button } from 
'../../components/Button'; import { useTheme, useThemedStyles } from '../../theme'; import { useAppStore } from '../../stores'; -import { CacheType, InferenceBackend, INFERENCE_BACKENDS } from '../../types'; +import { CacheType, InferenceBackend, LiteRTBackend, INFERENCE_BACKENDS } from '../../types'; import { useTextGenerationAdvanced, CACHE_TYPE_DESCRIPTIONS, @@ -33,7 +33,7 @@ const ANDROID_BASE_BACKENDS: BackendOption[] = [ const HTP_BACKEND: BackendOption = { id: INFERENCE_BACKENDS.HTP, label: 'HTP' }; -const BackendSelectorSection: React.FC<{ hideGpuLayers?: boolean }> = ({ hideGpuLayers = false }) => { +const BackendSelectorSection: React.FC = () => { const { colors } = useTheme(); const styles = useThemedStyles(createStyles); const { settings, updateSettings } = useAppStore(); @@ -82,7 +82,7 @@ const BackendSelectorSection: React.FC<{ hideGpuLayers?: boolean }> = ({ hideGpu ))} - {showLayers && !hideGpuLayers && ( + {showLayers && ( @@ -108,6 +108,51 @@ const BackendSelectorSection: React.FC<{ hideGpuLayers?: boolean }> = ({ hideGpu ); }; +// ─── LiteRT Acceleration ───────────────────────────────────────────────────── + +type LiteRTBackendOption = { id: LiteRTBackend; label: string }; + +const LITERT_BACKENDS: LiteRTBackendOption[] = [ + { id: 'gpu', label: 'GPU' }, + { id: 'cpu', label: 'CPU' }, +]; + +const LiteRTBackendSelectorSection: React.FC = () => { + const styles = useThemedStyles(createStyles); + const { settings, updateSettings } = useAppStore(); + const current = settings.liteRTBackend ?? 'gpu'; + + const descriptions: Partial> = { + gpu: 'Run on GPU via OpenCL. Best performance on most devices.', + cpu: 'Always available. Use for battery savings or thermal relief.', + }; + + return ( + <> + + + Acceleration + {descriptions[current]} + + + + {LITERT_BACKENDS.map(b => ( +