diff --git a/App.tsx b/App.tsx index ac8cee15..7ee0a5bb 100644 --- a/App.tsx +++ b/App.tsx @@ -139,6 +139,7 @@ function App() { progress: progress.progress, bytesDownloaded: progress.bytesDownloaded, totalBytes: progress.totalBytes, + ownerDownloadId: progress.downloadId, }); }, ); diff --git a/__tests__/App.test.tsx b/__tests__/App.test.tsx index e532f701..47f2984d 100644 --- a/__tests__/App.test.tsx +++ b/__tests__/App.test.tsx @@ -1,13 +1,155 @@ /** - * @format + * App startup tests */ import React from 'react'; import ReactTestRenderer from 'react-test-renderer'; + +const appState = { + setDeviceInfo: jest.fn(), + setModelRecommendation: jest.fn(), + setDownloadedModels: jest.fn(), + setDownloadedImageModels: jest.fn(), + clearImageModelDownloading: jest.fn(), + setBackgroundDownload: jest.fn(), + addDownloadedModel: jest.fn(), + setDownloadProgress: jest.fn(), + activeBackgroundDownloads: { + 42: { + modelId: 'test/model', + fileName: 'model.gguf', + quantization: 'Q4_K_M', + author: 'test', + totalBytes: 1000, + }, + }, +}; + +const authState = { + isEnabled: false, + isLocked: false, + setLocked: jest.fn(), + setLastBackgroundTime: jest.fn(), +}; + +const mockUseAppStore = Object.assign( + (selector?: (state: typeof appState) => unknown) => (selector ? selector(appState) : appState), + { + getState: () => appState, + persist: { hasHydrated: () => true, rehydrate: jest.fn() }, + }, +); + +const mockUseAuthStore = Object.assign( + (selector?: (state: typeof authState) => unknown) => (selector ? selector(authState) : authState), + { + getState: () => authState, + }, +); + +const mockUseRemoteServerStore = Object.assign( + () => ({}), + { + persist: { hasHydrated: () => true, rehydrate: jest.fn() }, + }, +); + +const mockModelManager = { + initialize: jest.fn(() => Promise.resolve()), + cleanupMMProjEntries: jest.fn(() => Promise.resolve()), + setBackgroundDownloadMetadataCallback: jest.fn(), + syncBackgroundDownloads: jest.fn(() => Promise.resolve([])), + syncCompletedImageDownloads: jest.fn(() => Promise.resolve([])), + restoreInProgressDownloads: jest.fn((_persisted, onProgress?: (progress: any) => void) => { + onProgress?.({ + downloadId: 42, + modelId: 'test/model', + fileName: 'model.gguf', + bytesDownloaded: 600, + totalBytes: 1000, + progress: 0.6, + }); + return Promise.resolve([]); + }), + refreshModelLists: jest.fn(() => Promise.resolve({ textModels: [], imageModels: [] })), + watchDownload: jest.fn(), +}; + +jest.mock('../src/navigation', () => ({ + AppNavigator: () => null, +})); + +jest.mock('../src/screens', () => ({ + LockScreen: () => null, +})); + +jest.mock('../src/theme', () => ({ + useTheme: () => ({ + colors: { background: '#fff', primary: '#000' }, + isDark: false, + }), +})); + +jest.mock('../src/hooks/useAppState', () => ({ + useAppState: jest.fn(), +})); + +jest.mock('../src/utils/logger', () => ({ + __esModule: true, + default: { + log: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + }, +})); + +jest.mock('../src/stores', () => ({ + useAppStore: mockUseAppStore, + useAuthStore: mockUseAuthStore, + useRemoteServerStore: mockUseRemoteServerStore, +})); + +jest.mock('../src/services', () => ({ + hardwareService: { + getDeviceInfo: jest.fn(() => Promise.resolve({ totalMemory: 8 * 1024 * 1024 * 1024 })), + getModelRecommendation: jest.fn(() => ({ maxParameters: 7, recommendedQuantization: 'Q4_K_M' })), + }, + modelManager: mockModelManager, + authService: { + hasPassphrase: jest.fn(() => Promise.resolve(false)), + }, + ragService: { + ensureReady: jest.fn(() => Promise.resolve()), + }, + remoteServerManager: { + initializeProviders: jest.fn(() => Promise.resolve()), + }, +})); + import App from '../App'; -test('renders correctly', async () => { - await ReactTestRenderer.act(() => { - ReactTestRenderer.create(); +describe('App', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('tags restored download progress with ownerDownloadId during startup restore', async () => { + await ReactTestRenderer.act(async () => { + ReactTestRenderer.create(); + await Promise.resolve(); + await Promise.resolve(); + await Promise.resolve(); + }); + + expect(mockModelManager.restoreInProgressDownloads).toHaveBeenCalledWith( + appState.activeBackgroundDownloads, + expect.any(Function), + ); + expect(appState.setDownloadProgress).toHaveBeenCalledWith('test/model/model.gguf', { + progress: 0.6, + bytesDownloaded: 600, + totalBytes: 1000, + ownerDownloadId: 42, + }); }); }); diff --git a/__tests__/rntl/screens/ChatScreen.test.tsx b/__tests__/rntl/screens/ChatScreen.test.tsx index d71ad6ff..dadec416 100644 --- a/__tests__/rntl/screens/ChatScreen.test.tsx +++ b/__tests__/rntl/screens/ChatScreen.test.tsx @@ -461,6 +461,7 @@ const renderChatScreen = () => { describe('ChatScreen', () => { afterEach(() => { cleanup(); + jest.useRealTimers(); }); beforeEach(() => { @@ -3687,10 +3688,9 @@ describe('ChatScreen', () => { await act(async () => { fireEvent.press(getByTestId('chat-settings-icon')); }); await act(async () => { fireEvent.press(getByTestId('delete-conversation-btn')); }); await act(async () => { fireEvent.press(getByTestId('alert-button-Delete')); }); - await act(async () => { await new Promise(r => setTimeout(() => r(), 200)); }); - - // llmService.stopGeneration should have been called (was streaming) - expect(llmService.stopGeneration).toHaveBeenCalled(); + await waitFor(() => { + expect(llmService.stopGeneration).toHaveBeenCalled(); + }); }); }); @@ -3733,10 +3733,9 @@ describe('ChatScreen', () => { await act(async () => { fireEvent.press(getByTestId('send-with-image')); }); - await act(async () => { await new Promise(r => setTimeout(() => r(), 300)); }); - - // The test exercises handleImageGeneration failure path - no crash - expect(getByTestId('chat-screen')).toBeTruthy(); + await waitFor(() => { + expect(getByTestId('chat-screen')).toBeTruthy(); + }); }); }); @@ -3785,6 +3784,7 @@ describe('ChatScreen', () => { }); // Queue count should appear and clear queue button + await waitFor(() => expect(getByTestId('clear-queue-button')).toBeTruthy()); const clearQueueBtn = getByTestId('clear-queue-button'); await act(async () => { fireEvent.press(clearQueueBtn); diff --git a/__tests__/rntl/screens/DownloadManagerScreen.test.tsx b/__tests__/rntl/screens/DownloadManagerScreen.test.tsx index af1d76ae..b5b4fa9a 100644 --- a/__tests__/rntl/screens/DownloadManagerScreen.test.tsx +++ b/__tests__/rntl/screens/DownloadManagerScreen.test.tsx @@ -825,27 +825,6 @@ describe('DownloadManagerScreen', () => { expect(mockModelManager.getActiveBackgroundDownloads).toHaveBeenCalled(); }); - it('handleRefresh reloads models and image models', async () => { - const setDownloadedModels = jest.fn(); - const setDownloadedImageModels = jest.fn(); - const state = createDefaultState({ setDownloadedModels, setDownloadedImageModels }); - mockStoreState(state); - - const { UNSAFE_root } = render(); - - // Find the FlatList and trigger its RefreshControl onRefresh - const flatList = UNSAFE_root.findAll((node: any) => node.type && node.type.displayName === 'FlatList')[0] - || UNSAFE_root.findAll((node: any) => node.props?.refreshControl)[0]; - - if (flatList && flatList.props.refreshControl) { - await act(async () => { - flatList.props.refreshControl.props.onRefresh(); - }); - } - - expect(mockModelManager.getDownloadedModels).toHaveBeenCalled(); - expect(mockModelManager.getDownloadedImageModels).toHaveBeenCalled(); - }); it('confirming delete model calls deleteModel and removeDownloadedModel', async () => { const removeDownloadedModel = jest.fn(); diff --git a/__tests__/rntl/screens/ModelsScreen.test.tsx b/__tests__/rntl/screens/ModelsScreen.test.tsx index 08431ac8..182a9178 100644 --- a/__tests__/rntl/screens/ModelsScreen.test.tsx +++ b/__tests__/rntl/screens/ModelsScreen.test.tsx @@ -1664,12 +1664,12 @@ describe('ModelsScreen', () => { it('filters search results by size', async () => { mockSearchModels.mockResolvedValue([ createModelInfo({ - id: 'test/small-1B', + id: 'test/model-1B', name: 'Small 1B', files: [createModelFile({ size: 1000000000 })], }), createModelInfo({ - id: 'test/large-70B', + id: 'test/model-70B', name: 'Large 70B', files: [createModelFile({ size: 4000000000 })], }), @@ -1691,14 +1691,11 @@ describe('ModelsScreen', () => { fireEvent.press(getByText('1-3B')); }); - // Search - await act(async () => { - fireEvent.changeText(getByTestId('search-input'), 'test'); - }); - + // Size filters auto-trigger search even with an empty query. await waitFor(() => { expect(getByText('Small 1B')).toBeTruthy(); }); + expect(mockSearchModels).toHaveBeenCalled(); // Large 70B doesn't match 1-3B size filter expect(queryByText('Large 70B')).toBeNull(); }); diff --git a/__tests__/unit/services/modelManager.test.ts b/__tests__/unit/services/modelManager.test.ts index dc323bf3..4edb52d7 100644 --- a/__tests__/unit/services/modelManager.test.ts +++ b/__tests__/unit/services/modelManager.test.ts @@ -246,6 +246,52 @@ describe('ModelManager', () => { await expect(modelManager.deleteModel('nonexistent')).rejects.toThrow('Model not found'); }); + + it('preserves mmproj file when another model still references it', async () => { + const sharedMmproj = '/mock/documents/models/shared-mmproj.gguf'; + const storedModels = [ + { + id: 'qwen-vl-q4', + name: 'Qwen VL Q4', + filePath: '/mock/documents/models/qwen-vl-q4.gguf', + fileSize: 100, + mmProjPath: sharedMmproj, + }, + { + id: 'qwen-vl-q8', + name: 'Qwen VL Q8', + filePath: '/mock/documents/models/qwen-vl-q8.gguf', + fileSize: 200, + mmProjPath: sharedMmproj, + }, + ]; + mockedAsyncStorage.getItem.mockResolvedValue(JSON.stringify(storedModels)); + mockedRNFS.exists.mockResolvedValue(true); + + await modelManager.deleteModel('qwen-vl-q4'); + + expect(RNFS.unlink).toHaveBeenCalledWith('/mock/documents/models/qwen-vl-q4.gguf'); + expect(RNFS.unlink).not.toHaveBeenCalledWith(sharedMmproj); + }); + + it('deletes mmproj file when no other model references it', async () => { + const mmprojPath = '/mock/documents/models/solo-mmproj.gguf'; + const storedModels = [ + { + id: 'qwen-vl-q4', + name: 'Qwen VL Q4', + filePath: '/mock/documents/models/qwen-vl-q4.gguf', + fileSize: 100, + mmProjPath: mmprojPath, + }, + ]; + mockedAsyncStorage.getItem.mockResolvedValue(JSON.stringify(storedModels)); + mockedRNFS.exists.mockResolvedValue(true); + + await modelManager.deleteModel('qwen-vl-q4'); + + expect(RNFS.unlink).toHaveBeenCalledWith(mmprojPath); + }); }); // ======================================================================== diff --git a/__tests__/unit/services/parallelMmproj.test.ts b/__tests__/unit/services/parallelMmproj.test.ts index 353193b7..6b6c8054 100644 --- a/__tests__/unit/services/parallelMmproj.test.ts +++ b/__tests__/unit/services/parallelMmproj.test.ts @@ -454,6 +454,104 @@ describe('Parallel mmproj download', () => { expect(metadataCallback).toHaveBeenCalledWith(42, null); }); + + it('keeps vision when mmproj move fails but target exists and is a valid GGUF', async () => { + const completeCbs = await setupVisionDownload(); + const onComplete = jest.fn(); + + // Main move succeeds; mmproj move rejects (target already exists because + // a sibling quant already downloaded it). + mockService.moveCompletedDownload + .mockImplementationOnce((id: number) => { + if (id === 43) return Promise.reject(new Error('target exists')); + return Promise.resolve(`${MODELS_DIR}/vision.gguf`); + }) + .mockResolvedValue(`${MODELS_DIR}/vision.gguf`); + + // checkMmProjExists: file exists, size OK, GGUF magic OK + mockedRNFS.exists.mockResolvedValue(true); + mockedRNFS.stat.mockResolvedValue({ size: 500_000_000 } as any); + (mockedRNFS.read as jest.Mock).mockResolvedValue('GGUF'); + + watchBackgroundDownload({ + downloadId: 42, + modelsDir: MODELS_DIR, + backgroundDownloadContext: bgContext, + backgroundDownloadMetadataCallback: metadataCallback, + onComplete, + }); + + await completeCbs[43]?.({ downloadId: 43, fileName: 'mmproj.gguf' }); + await completeCbs[42]?.({ downloadId: 42, fileName: 'vision.gguf' }); + + expect(onComplete).toHaveBeenCalledTimes(1); + const savedModel = onComplete.mock.calls[0][0]; + expect(savedModel.mmProjPath).toBeDefined(); + }); + + it('downgrades to text-only when mmproj move fails and target has bad magic bytes', async () => { + const completeCbs = await setupVisionDownload(); + const onComplete = jest.fn(); + + mockService.moveCompletedDownload + .mockImplementationOnce((id: number) => { + if (id === 43) return Promise.reject(new Error('io error')); + return Promise.resolve(`${MODELS_DIR}/vision.gguf`); + }) + .mockResolvedValue(`${MODELS_DIR}/vision.gguf`); + + // File exists + correct size, but magic bytes wrong → invalid, unlink + downgrade + mockedRNFS.exists.mockResolvedValue(true); + mockedRNFS.stat.mockResolvedValue({ size: 500_000_000 } as any); + (mockedRNFS.read as jest.Mock).mockResolvedValue('XXXX'); + + watchBackgroundDownload({ + downloadId: 42, + modelsDir: MODELS_DIR, + backgroundDownloadContext: bgContext, + backgroundDownloadMetadataCallback: metadataCallback, + onComplete, + }); + + await completeCbs[43]?.({ downloadId: 43, fileName: 'mmproj.gguf' }); + await completeCbs[42]?.({ downloadId: 42, fileName: 'vision.gguf' }); + + expect(onComplete).toHaveBeenCalledTimes(1); + const savedModel = onComplete.mock.calls[0][0]; + expect(savedModel.mmProjPath).toBeUndefined(); + }); + + it('downgrades to text-only when mmproj move fails and target does not exist', async () => { + const completeCbs = await setupVisionDownload(); + const onComplete = jest.fn(); + + mockService.moveCompletedDownload + .mockImplementationOnce((id: number) => { + if (id === 43) return Promise.reject(new Error('io error')); + return Promise.resolve(`${MODELS_DIR}/vision.gguf`); + }) + .mockResolvedValue(`${MODELS_DIR}/vision.gguf`); + + // Target does not exist anywhere + mockedRNFS.exists.mockImplementation((p: any) => + Promise.resolve(!String(p).includes('mmproj')), + ); + + watchBackgroundDownload({ + downloadId: 42, + modelsDir: MODELS_DIR, + backgroundDownloadContext: bgContext, + backgroundDownloadMetadataCallback: metadataCallback, + onComplete, + }); + + await completeCbs[43]?.({ downloadId: 43, fileName: 'mmproj.gguf' }); + await completeCbs[42]?.({ downloadId: 42, fileName: 'vision.gguf' }); + + expect(onComplete).toHaveBeenCalledTimes(1); + const savedModel = onComplete.mock.calls[0][0]; + expect(savedModel.mmProjPath).toBeUndefined(); + }); }); // ======================================================================== diff --git a/__tests__/unit/services/restore.test.ts b/__tests__/unit/services/restore.test.ts index 00bb814e..e600a9f5 100644 --- a/__tests__/unit/services/restore.test.ts +++ b/__tests__/unit/services/restore.test.ts @@ -244,6 +244,7 @@ describe('restoreInProgressDownloads', () => { }); expect(onProgress).toHaveBeenCalledWith({ + downloadId: 42, modelId: 'test/model', fileName: 'model.gguf', bytesDownloaded: 2_000_000_000, diff --git a/android/app/src/main/java/ai/offgridmobile/download/DownloadManagerModule.kt b/android/app/src/main/java/ai/offgridmobile/download/DownloadManagerModule.kt index 3264961a..aa025171 100644 --- a/android/app/src/main/java/ai/offgridmobile/download/DownloadManagerModule.kt +++ b/android/app/src/main/java/ai/offgridmobile/download/DownloadManagerModule.kt @@ -76,7 +76,7 @@ class DownloadManagerModule(reactContext: ReactApplicationContext) : val downloadId = System.currentTimeMillis() val destination = File( reactApplicationContext.getExternalFilesDir(Environment.DIRECTORY_DOWNLOADS), - fileName, + "${downloadId}_${fileName}", ).absolutePath val entity = DownloadEntity( @@ -265,6 +265,12 @@ class DownloadManagerModule(reactContext: ReactApplicationContext) : } if (!sourceFile.exists()) { + val targetFile = File(targetPath) + if (targetFile.exists()) { + withContext(Dispatchers.IO) { downloadDao.deleteDownload(d) } + SafePromise(promise, NAME).resolve(targetPath) + return@launch + } SafePromise(promise, NAME).reject("MOVE_ERROR", "Downloaded file not found: ${sourceFile.absolutePath}") return@launch } diff --git a/src/screens/DownloadManagerScreen/index.tsx b/src/screens/DownloadManagerScreen/index.tsx index aff12ee9..d565d720 100644 --- a/src/screens/DownloadManagerScreen/index.tsx +++ b/src/screens/DownloadManagerScreen/index.tsx @@ -1,5 +1,5 @@ import React from 'react'; -import { View, Text, FlatList, RefreshControl, TouchableOpacity } from 'react-native'; +import { View, Text, FlatList, TouchableOpacity } from 'react-native'; import { SafeAreaView } from 'react-native-safe-area-context'; import { useNavigation } from '@react-navigation/native'; import Icon from 'react-native-vector-icons/Feather'; @@ -15,12 +15,10 @@ export const DownloadManagerScreen: React.FC = () => { const { colors } = useTheme(); const styles = useThemedStyles(createStyles); const { - isRefreshing, activeItems, completedItems, alertState, setAlertState, - handleRefresh, handleRemoveDownload, handleRetryDownload, handleDeleteItem, @@ -104,13 +102,6 @@ export const DownloadManagerScreen: React.FC = () => { )} keyExtractor={item => item.key} - refreshControl={ - - } contentContainerStyle={styles.listContent} /> diff --git a/src/screens/DownloadManagerScreen/useDownloadManager.ts b/src/screens/DownloadManagerScreen/useDownloadManager.ts index a6d9a877..1de928b8 100644 --- a/src/screens/DownloadManagerScreen/useDownloadManager.ts +++ b/src/screens/DownloadManagerScreen/useDownloadManager.ts @@ -17,12 +17,10 @@ import logger from '../../utils/logger'; import { getUserFacingDownloadMessage } from '../../utils/downloadErrors'; export interface UseDownloadManagerResult { - isRefreshing: boolean; activeItems: DownloadItem[]; completedItems: DownloadItem[]; alertState: AlertState; setAlertState: (state: AlertState) => void; - handleRefresh: () => Promise; handleRemoveDownload: (item: DownloadItem) => void; handleRetryDownload: (item: DownloadItem) => void; handleDeleteItem: (item: DownloadItem) => void; @@ -196,7 +194,6 @@ function syncDownloadSnapshot( } export function useDownloadManager(): UseDownloadManagerResult { - const [isRefreshing, setIsRefreshing] = useState(false); const [activeDownloads, setActiveDownloads] = useState([]); const [alertState, setAlertState] = useState(initialAlertState); const cancelledKeysRef = useRef>(new Set()); @@ -209,7 +206,6 @@ export function useDownloadManager(): UseDownloadManagerResult { activeBackgroundDownloads, setBackgroundDownload, downloadedImageModels, - setDownloadedImageModels, removeDownloadedImageModel, removeImageModelDownloading, } = useAppStore(); @@ -309,16 +305,6 @@ export function useDownloadManager(): UseDownloadManagerResult { }, [loadActiveDownloads, setDownloadProgress]); - const handleRefresh = useCallback(async () => { - setIsRefreshing(true); - await loadActiveDownloads(); - const models = await modelManager.getDownloadedModels(); - setDownloadedModels(models); - const imageModels = await modelManager.getDownloadedImageModels(); - setDownloadedImageModels(imageModels); - setIsRefreshing(false); - - }, [loadActiveDownloads, setDownloadedModels, setDownloadedImageModels]); const executeRemoveDownload = async (item: DownloadItem) => { setAlertState(hideAlert()); @@ -554,12 +540,10 @@ export function useDownloadManager(): UseDownloadManagerResult { const totalStorageUsed = completedItems.reduce((sum, item) => sum + item.fileSize, 0); return { - isRefreshing, activeItems, completedItems, alertState, setAlertState, - handleRefresh, handleRemoveDownload, handleRetryDownload, handleDeleteItem, diff --git a/src/screens/ModelsScreen/useModelsScreen.ts b/src/screens/ModelsScreen/useModelsScreen.ts index 319f9439..0bab4d5c 100644 --- a/src/screens/ModelsScreen/useModelsScreen.ts +++ b/src/screens/ModelsScreen/useModelsScreen.ts @@ -172,10 +172,14 @@ export function useModelsScreen() { } }; - const activeDownloadCount = Object.keys(text.downloadProgress).filter(key => { - if (!key.startsWith('image:')) return true; - const imageId = key.split('/').slice(0, -1).join('/').replace('image:', ''); - return !image.downloadedImageModels.some(m => m.id === imageId); + const activeDownloadCount = Object.entries(text.downloadProgress).filter(([key, progress]) => { + const status = progress?.status; + if (status === 'failed' || status === 'completed') return false; + if (key.startsWith('image:')) { + const imageId = key.split('/').slice(0, -1).join('/').replace('image:', ''); + return !image.downloadedImageModels.some(m => m.id === imageId); + } + return true; }).length; const totalModelCount = text.downloadedModels.length + @@ -184,16 +188,38 @@ export function useModelsScreen() { const handleDownload = useCallback( (...args: Parameters) => { + if (activeDownloadCount >= 2) { + setAlertState(showAlert( + 'Downloads Already Active', + '2 downloads are already running. Starting more can affect performance.', + [ + { text: 'Cancel', style: 'cancel' }, + { text: 'Start Anyway', style: 'default', onPress: () => { text.handleDownload(...args); } }, + ], + )); + return; + } text.handleDownload(...args); }, - [text], + [text, activeDownloadCount, setAlertState], ); const handleDownloadImageModel = useCallback( (...args: Parameters) => { + if (activeDownloadCount >= 2) { + setAlertState(showAlert( + 'Downloads Already Active', + '2 downloads are already running. Starting more can affect performance.', + [ + { text: 'Cancel', style: 'cancel' }, + { text: 'Start Anyway', style: 'default', onPress: () => { image.handleDownloadImageModel(...args); } }, + ], + )); + return; + } image.handleDownloadImageModel(...args); }, - [image], + [image, activeDownloadCount, setAlertState], ); return { diff --git a/src/screens/ModelsScreen/useTextModels.ts b/src/screens/ModelsScreen/useTextModels.ts index 0e0788b6..00d4b97a 100644 --- a/src/screens/ModelsScreen/useTextModels.ts +++ b/src/screens/ModelsScreen/useTextModels.ts @@ -228,10 +228,25 @@ export function useTextModels(setAlertState: (s: AlertState) => void) { return rest; }); addDownloadedModel(dm); - setAlertState(showAlert('Success', `${model.name} downloaded successfully!`)); + if (file.mmProjFile && !dm.isVisionModel) { + setAlertState(showAlert( + 'Model Downloaded', + `${model.name} downloaded but the vision projection file could not be saved. Go to Download Manager and use "Repair Vision" to fix it.`, + )); + } else { + setAlertState(showAlert('Success', `${model.name} downloaded successfully!`)); + } }; const onError = (err: Error) => { - setDownloadProgress(downloadKey, null); + const existing = useAppStore.getState().downloadProgress[downloadKey]; + setDownloadProgress(downloadKey, { + progress: existing?.progress ?? 0, + bytesDownloaded: existing?.bytesDownloaded ?? 0, + totalBytes: existing?.totalBytes ?? totalBytes, + ownerDownloadId: existing?.ownerDownloadId, + status: 'failed', + reason: err.message, + }); setDownloadIds(prev => { const { [downloadKey]: _r, ...rest } = prev; downloadIdsRef.current = rest; diff --git a/src/services/huggingface.ts b/src/services/huggingface.ts index e0837d77..30ab0c35 100644 --- a/src/services/huggingface.ts +++ b/src/services/huggingface.ts @@ -150,35 +150,26 @@ class HuggingFaceService { return undefined; } - // modelQuant intentionally unused; matching is done via modelLower below - const modelLower = modelFileName.toLowerCase(); - - // Try to match by quantization level - for (const mmProj of mmProjFiles) { - const mmProjQuant = this.extractQuantization(mmProj.path); - // Match exact quantization or if model uses the mmproj's quantization variant - if (mmProjQuant !== 'Unknown' && modelLower.includes(mmProjQuant.toLowerCase())) { - return { - name: mmProj.path, - size: mmProj.lfs?.size || mmProj.size || 0, - downloadUrl: this.getDownloadUrl(modelId, mmProj.path), - }; - } + const toResult = (f: { path: string; size?: number; lfs?: { size: number } }) => ({ + name: f.path, + size: f.lfs?.size || f.size || 0, + downloadUrl: this.getDownloadUrl(modelId, f.path), + }); + + // Exact symmetric match: model quant === mmproj quant + const modelQuant = this.extractQuantization(modelFileName); + if (modelQuant !== 'Unknown') { + const exactMatch = mmProjFiles.find(f => this.extractQuantization(f.path) === modelQuant); + if (exactMatch) return toResult(exactMatch); } - // Fallback: prefer f16 mmproj if available, otherwise use the first one - // Match F16/FP16 but exclude BF16 — BF16 mmproj can be incompatible with some runtimes - const f16MMProj = mmProjFiles.find(f => { + // Fallback: prefer F16/FP16, exclude BF16 (can be incompatible with some runtimes) + const f16 = mmProjFiles.find(f => { const lower = f.path.toLowerCase(); return (lower.includes('f16') || lower.includes('fp16')) && !lower.includes('bf16'); }); - const selectedMMProj = f16MMProj || mmProjFiles[0]; - return { - name: selectedMMProj.path, - size: selectedMMProj.lfs?.size || selectedMMProj.size || 0, - downloadUrl: this.getDownloadUrl(modelId, selectedMMProj.path), - }; + return toResult(f16 ?? mmProjFiles[0]); } private detectModelType(name: string, tags: string[]): string { diff --git a/src/services/modelManager/download.ts b/src/services/modelManager/download.ts index cf8c8fd3..80c71620 100644 --- a/src/services/modelManager/download.ts +++ b/src/services/modelManager/download.ts @@ -49,18 +49,41 @@ export async function performBackgroundDownload(opts: PerformBackgroundDownloadO }); } +const GGUF_MAGIC = 'GGUF'; + +async function hasGgufMagic(path: string): Promise { + // Returns true if GGUF magic confirmed, false if confirmed invalid, null if + // the read itself failed (iOS RNFS.read has a known NSInteger bridging bug; + // treat as inconclusive rather than invalid in that case). + try { + const header = await RNFS.read(path, 4, 0, 'ascii'); + return header.startsWith(GGUF_MAGIC); + } catch { + return null; + } +} + async function checkMmProjExists(path: string | null, expectedSize?: number): Promise { if (!path) return true; const exists = await RNFS.exists(path); - if (!exists || !expectedSize) return exists; + if (!exists) return false; try { - const stat = await RNFS.stat(path); - const actualSize = typeof stat.size === 'string' ? Number.parseInt(stat.size, 10) : stat.size; - if (actualSize < expectedSize) { - logger.warn(`[ModelManager] mmproj partial (${actualSize}/${expectedSize}), re-downloading`); + if (expectedSize) { + const stat = await RNFS.stat(path); + const actualSize = typeof stat.size === 'string' ? Number.parseInt(stat.size, 10) : stat.size; + if (actualSize < expectedSize) { + logger.warn(`[ModelManager] mmproj partial (${actualSize}/${expectedSize}), re-downloading`); + await RNFS.unlink(path).catch(() => {}); + return false; + } + } + const magicOk = await hasGgufMagic(path); + if (magicOk === false) { + logger.warn(`[ModelManager] mmproj failed GGUF magic check, re-downloading: ${path}`); await RNFS.unlink(path).catch(() => {}); return false; } + // magicOk === true or null (inconclusive): accept — llama.rn validates on load. return true; } catch { await RNFS.unlink(path).catch(() => {}); @@ -166,6 +189,7 @@ async function startBgDownload(opts: StartBgDownloadOpts): Promise 0 ? combinedDownloaded / combinedTotalBytes : 0, @@ -281,9 +305,21 @@ export function watchBackgroundDownload(opts: WatchDownloadOpts): void { ctx.mmProjCompleteHandled = true; try { await backgroundDownloadService.moveCompletedDownload(event.downloadId, ctx.mmProjLocalPath!); - ctx.mmProjCompleted = true; - await tryFinalize(); - } catch (error) { handleError(error as Error, downloadId); } + } catch (moveErr) { + // Move can fail legitimately if another quant of the same family already + // placed a valid mmproj at this shared path. Validate the existing file + // (size + GGUF magic); trust it if valid, discard + downgrade if not. + const expectedSize = ctx.file.mmProjFile?.size; + const valid = await checkMmProjExists(ctx.mmProjLocalPath, expectedSize); + if (!valid) { + logger.warn('[ModelManager] mmproj move failed and target invalid, continuing without vision:', moveErr); + ctx.mmProjLocalPath = null; + } else { + logger.log('[ModelManager] mmproj move failed but target valid (likely shared with sibling quant):', moveErr); + } + } + ctx.mmProjCompleted = true; + await tryFinalize(); }); removeMmProjError = backgroundDownloadService.onError(ctx.mmProjDownloadId, (event) => { handleError(new Error(`Vision projection download failed: ${event.reason || 'Unknown error'}`), downloadId); diff --git a/src/services/modelManager/index.ts b/src/services/modelManager/index.ts index 4b5c2508..6b94946c 100644 --- a/src/services/modelManager/index.ts +++ b/src/services/modelManager/index.ts @@ -124,8 +124,13 @@ class ModelManager { throw new Error('Invalid mmproj path: outside app directory'); } await RNFS.unlink(model.filePath); - if (model.mmProjPath) await RNFS.unlink(model.mmProjPath).catch(() => {}); - await saveModelsList(models.filter(m => m.id !== modelId)); + const remaining = models.filter(m => m.id !== modelId); + // mmproj files are shared across quantizations of the same family. Only + // delete when no other model still references this path. + if (model.mmProjPath && !remaining.some(m => m.mmProjPath === model.mmProjPath)) { + await RNFS.unlink(model.mmProjPath).catch(() => {}); + } + await saveModelsList(remaining); } async getModelPath(modelId: string): Promise { @@ -309,7 +314,14 @@ class ModelManager { await new Promise((resolve, reject) => { const removeProgress = backgroundDownloadService.onProgress(info.downloadId, (event) => { if (event.status === 'retrying' || event.status === 'waiting_for_network') return; - opts?.onProgress?.({ modelId, fileName: file.mmProjFile!.name, bytesDownloaded: event.bytesDownloaded, totalBytes, progress: totalBytes > 0 ? event.bytesDownloaded / totalBytes : 0 }); + opts?.onProgress?.({ + downloadId: info.downloadId, + modelId, + fileName: file.mmProjFile!.name, + bytesDownloaded: event.bytesDownloaded, + totalBytes, + progress: totalBytes > 0 ? event.bytesDownloaded / totalBytes : 0, + }); }); const removeComplete = backgroundDownloadService.onComplete(info.downloadId, async (event) => { removeProgress(); removeComplete(); removeError(); diff --git a/src/services/modelManager/restore.ts b/src/services/modelManager/restore.ts index 858fd13b..1e38f823 100644 --- a/src/services/modelManager/restore.ts +++ b/src/services/modelManager/restore.ts @@ -117,6 +117,7 @@ async function restoreDownloadEntry(opts: RestoreEntryOpts): Promise { const reportProgress = () => { const combinedDownloaded = mainBytesDownloaded + mmProjBytesDownloaded; onProgress?.({ + downloadId: download.downloadId, modelId: metadata.modelId, fileName: metadata.fileName, bytesDownloaded: combinedDownloaded, totalBytes: combinedTotalBytes, progress: combinedTotalBytes > 0 ? combinedDownloaded / combinedTotalBytes : 0, diff --git a/src/types/index.ts b/src/types/index.ts index dbf99063..a0711500 100644 --- a/src/types/index.ts +++ b/src/types/index.ts @@ -80,6 +80,7 @@ export interface PersistedDownloadInfo { } export interface DownloadProgress { + downloadId?: number; modelId: string; fileName: string; bytesDownloaded: number;