Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import { hardwareService, modelManager, authService, ragService, remoteServerMan
import logger from './src/utils/logger';
import { useAppStore, useAuthStore, useRemoteServerStore } from './src/stores';
import { hydrateDownloadStore } from './src/services/downloadHydration';
import { useDownloads } from './src/hooks/useDownloads';
import { useDownloadListeners } from './src/hooks/useDownloads';
import { LockScreen } from './src/screens';
import { useAppState } from './src/hooks/useAppState';
import { useDownloadStore } from './src/stores/downloadStore';
Expand All @@ -31,7 +31,7 @@ const ensureRemoteServerStoreHydrated = async () => {
};

function App() {
useDownloads();
useDownloadListeners();
const [isInitializing, setIsInitializing] = useState(true);
const setDeviceInfo = useAppStore((s) => s.setDeviceInfo);
const setModelRecommendation = useAppStore((s) => s.setModelRecommendation);
Expand Down
87 changes: 87 additions & 0 deletions __tests__/rntl/components/ModelCard.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -877,4 +877,91 @@ describe('ModelCard', () => {
expect(queryByText('Hardware-accelerated inference with vision support')).toBeNull();
});
});

// ============================================================================
// Failed download state (FailedSection)
// ============================================================================
describe('failedState', () => {
const baseFailedState = {
errorMessage: 'Network connection lost.',
bytesDownloaded: 192_000_000,
totalBytes: 386_000_000,
onRetry: jest.fn(),
onRemove: jest.fn(),
};

it('renders error message when failedState is provided', () => {
const { getByText } = render(
<ModelCard model={baseModel} failedState={baseFailedState} />,
);
expect(getByText('Network connection lost.')).toBeTruthy();
});

it('renders Retry and Remove buttons when failedState is provided', () => {
const { getByText } = render(
<ModelCard model={baseModel} failedState={baseFailedState} />,
);
expect(getByText('Retry')).toBeTruthy();
expect(getByText('Remove')).toBeTruthy();
});

it('calls onRetry when Retry is pressed', () => {
const onRetry = jest.fn();
const { getByText } = render(
<ModelCard model={baseModel} failedState={{ ...baseFailedState, onRetry }} />,
);
fireEvent.press(getByText('Retry'));
expect(onRetry).toHaveBeenCalled();
});

it('calls onRemove when Remove is pressed', () => {
const onRemove = jest.fn();
const { getByText } = render(
<ModelCard model={baseModel} failedState={{ ...baseFailedState, onRemove }} />,
);
fireEvent.press(getByText('Remove'));
expect(onRemove).toHaveBeenCalled();
});

it('shows progress percentage from bytesDownloaded / totalBytes', () => {
const { getByText } = render(
<ModelCard
model={baseModel}
failedState={{ ...baseFailedState, bytesDownloaded: 193_000_000, totalBytes: 386_000_000 }}
/>,
);
expect(getByText('50%')).toBeTruthy();
});

it('shows 0% when totalBytes is 0 (unknown size)', () => {
const { getByText } = render(
<ModelCard
model={baseModel}
failedState={{ ...baseFailedState, bytesDownloaded: 0, totalBytes: 0 }}
/>,
);
expect(getByText('0%')).toBeTruthy();
});

it('hides ModelCardActions when failedState is set', () => {
const onDownload = jest.fn();
const { queryByTestId } = render(
<ModelCard
model={baseModel}
failedState={baseFailedState}
onDownload={onDownload}
testID="card"
/>,
);
expect(queryByTestId('card-download')).toBeNull();
});

it('does not render FailedSection when failedState is absent', () => {
const { queryByText } = render(
<ModelCard model={baseModel} />,
);
expect(queryByText('Retry')).toBeNull();
expect(queryByText('Remove')).toBeNull();
});
});
});
38 changes: 19 additions & 19 deletions __tests__/unit/hooks/useDownloads.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ jest.mock('../../../src/utils/downloadErrors', () => ({
toUserMessage: jest.fn((reason: string) => reason),
}));

import { useDownloads } from '../../../src/hooks/useDownloads';
import { useDownloads, useDownloadListeners } from '../../../src/hooks/useDownloads';

function fireProgress(event: Parameters<ProgressCb>[0]) {
if (!onAnyProgressCb) throw new Error('onAnyProgressCb not set');
Expand Down Expand Up @@ -106,14 +106,14 @@ describe('useDownloads', () => {

it('subscribes to all three event channels on mount', () => {
const { backgroundDownloadService: svc } = jest.requireMock('../../../src/services/backgroundDownloadService');
renderHook(() => useDownloads());
renderHook(() => useDownloadListeners());
expect(svc.onAnyProgress).toHaveBeenCalled();
expect(svc.onAnyComplete).toHaveBeenCalled();
expect(svc.onAnyError).toHaveBeenCalled();
});

it('unsubscribes all listeners on unmount', () => {
const { unmount } = renderHook(() => useDownloads());
const { unmount } = renderHook(() => useDownloadListeners());
unmount();
expect(mockUnsubProgress).toHaveBeenCalled();
expect(mockUnsubComplete).toHaveBeenCalled();
Expand All @@ -123,34 +123,34 @@ describe('useDownloads', () => {
it('skips subscription when service is unavailable', () => {
const { backgroundDownloadService: svc } = jest.requireMock('../../../src/services/backgroundDownloadService');
(svc.isAvailable as jest.Mock).mockReturnValueOnce(false);
renderHook(() => useDownloads());
renderHook(() => useDownloadListeners());
expect(svc.onAnyProgress).not.toHaveBeenCalled();
});

it('ignores progress event when downloadId not in index', () => {
renderHook(() => useDownloads());
renderHook(() => useDownloadListeners());
act(() => { fireProgress({ downloadId: 'unknown', bytesDownloaded: 100, totalBytes: 1000 }); });
expect(mockUpdateProgress).not.toHaveBeenCalled();
});

it('routes retrying status through setStatus instead of updateProgress', () => {
withSingleTextEntry();
renderHook(() => useDownloads());
renderHook(() => useDownloadListeners());
act(() => { fireProgress({ downloadId: 'dl-1', bytesDownloaded: 0, totalBytes: 0, status: 'retrying' }); });
expect(mockSetStatus).toHaveBeenCalledWith('dl-1', 'retrying');
expect(mockUpdateProgress).not.toHaveBeenCalled();
});

it('routes waiting_for_network status through setStatus', () => {
withSingleTextEntry();
renderHook(() => useDownloads());
renderHook(() => useDownloadListeners());
act(() => { fireProgress({ downloadId: 'dl-1', bytesDownloaded: 0, totalBytes: 0, status: 'waiting_for_network' }); });
expect(mockSetStatus).toHaveBeenCalledWith('dl-1', 'waiting_for_network');
});

it('calls updateProgress for main download progress event', () => {
withSingleTextEntry();
renderHook(() => useDownloads());
renderHook(() => useDownloadListeners());
act(() => { fireProgress({ downloadId: 'dl-1', bytesDownloaded: 500, totalBytes: 1000 }); });
expect(mockUpdateProgress).toHaveBeenCalledWith('dl-1', 500, 1000);
});
Expand All @@ -160,7 +160,7 @@ describe('useDownloads', () => {
downloadIdIndex: { 'mmproj-1': 'llm:model' },
downloads: { 'llm:model': { downloadId: 'dl-1', mmProjDownloadId: 'mmproj-1', modelType: 'text' } },
}));
renderHook(() => useDownloads());
renderHook(() => useDownloadListeners());
act(() => { fireProgress({ downloadId: 'mmproj-1', bytesDownloaded: 200, totalBytes: 400 }); });
expect(mockUpdateMmProjProgress).toHaveBeenCalledWith('mmproj-1', 200);
});
Expand All @@ -171,15 +171,15 @@ describe('useDownloads', () => {
downloadIdIndex: { 'other': 'llm:model' },
downloads: { 'llm:model': { downloadId: 'dl-1', mmProjDownloadId: 'mmproj-1', modelType: 'text' } },
}));
renderHook(() => useDownloads());
renderHook(() => useDownloadListeners());
act(() => { fireProgress({ downloadId: 'other', bytesDownloaded: 100, totalBytes: 200 }); });
expect(mockUpdateProgress).not.toHaveBeenCalled();
expect(mockUpdateMmProjProgress).not.toHaveBeenCalled();
warnSpy.mockRestore();
});

it('ignores complete event when downloadId not in index', () => {
renderHook(() => useDownloads());
renderHook(() => useDownloadListeners());
act(() => { fireComplete({ downloadId: 'unknown', bytesDownloaded: 100, totalBytes: 100 }); });
expect(mockSetCompleted).not.toHaveBeenCalled();
});
Expand All @@ -194,7 +194,7 @@ describe('useDownloads', () => {
storeState.downloads['llm:model'] = updatedEntry;
});
mockGetState.mockReturnValue(storeState);
renderHook(() => useDownloads());
renderHook(() => useDownloadListeners());
act(() => { fireComplete({ downloadId: 'mmproj-1', bytesDownloaded: 400, totalBytes: 400 }); });
expect(storeState.setMmProjCompleted).toHaveBeenCalledWith('mmproj-1', 400);
expect(mockSetCompleted).toHaveBeenCalledWith('dl-1');
Expand All @@ -206,15 +206,15 @@ describe('useDownloads', () => {
downloads: { 'llm:model': { downloadId: 'dl-1', mmProjDownloadId: 'mmproj-1', status: 'running', modelType: 'text' } },
});
mockGetState.mockReturnValue(storeState);
renderHook(() => useDownloads());
renderHook(() => useDownloadListeners());
act(() => { fireComplete({ downloadId: 'mmproj-1', bytesDownloaded: 400, totalBytes: 400 }); });
expect(mockSetMmProjCompleted).toHaveBeenCalled();
expect(mockSetCompleted).not.toHaveBeenCalled();
});

it('calls updateProgress when main gguf finishes but mmproj not yet done', () => {
withSingleTextEntry('dl-1', { mmProjDownloadId: 'mmproj-1', mmProjStatus: 'running' });
renderHook(() => useDownloads());
renderHook(() => useDownloadListeners());
act(() => { fireComplete({ downloadId: 'dl-1', bytesDownloaded: 1000, totalBytes: 1000 }); });
expect(mockUpdateProgress).toHaveBeenCalled();
expect(mockSetCompleted).not.toHaveBeenCalled();
Expand All @@ -225,15 +225,15 @@ describe('useDownloads', () => {
downloadIdIndex: { 'dl-1': 'image:model' },
downloads: { 'image:model': { downloadId: 'dl-1', modelType: 'image' } },
}));
renderHook(() => useDownloads());
renderHook(() => useDownloadListeners());
act(() => { fireComplete({ downloadId: 'dl-1', bytesDownloaded: 500, totalBytes: 500 }); });
expect(mockSetProcessing).toHaveBeenCalledWith('dl-1');
expect(mockSetCompleted).not.toHaveBeenCalled();
});

it('calls updateProgress for text model on complete (finalization handled elsewhere)', () => {
withSingleTextEntry();
renderHook(() => useDownloads());
renderHook(() => useDownloadListeners());
act(() => { fireComplete({ downloadId: 'dl-1', bytesDownloaded: 1000, totalBytes: 1000 }); });
expect(mockUpdateProgress).toHaveBeenCalled();
expect(mockSetCompleted).not.toHaveBeenCalled();
Expand All @@ -244,20 +244,20 @@ describe('useDownloads', () => {
downloadIdIndex: { 'dl-1': 'other:model' },
downloads: { 'other:model': { downloadId: 'dl-1', modelType: 'other' } },
}));
renderHook(() => useDownloads());
renderHook(() => useDownloadListeners());
act(() => { fireComplete({ downloadId: 'dl-1', bytesDownloaded: 500, totalBytes: 500 }); });
expect(mockSetCompleted).toHaveBeenCalledWith('dl-1');
});

it('ignores error event when downloadId not in index', () => {
renderHook(() => useDownloads());
renderHook(() => useDownloadListeners());
act(() => { fireError({ downloadId: 'unknown', reason: 'oops' }); });
expect(mockSetStatus).not.toHaveBeenCalled();
});

it('calls setStatus with failed on error event', () => {
withSingleTextEntry();
renderHook(() => useDownloads());
renderHook(() => useDownloadListeners());
act(() => { fireError({ downloadId: 'dl-1', reason: 'timeout', reasonCode: 'E_TIMEOUT' }); });
expect(mockSetStatus).toHaveBeenCalledWith('dl-1', 'failed', expect.objectContaining({ message: 'timeout' }));
});
Expand Down
34 changes: 32 additions & 2 deletions __tests__/unit/services/activeModelService.loaders.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ describe('resolveMmProjPath', () => {
expect(result).toBeUndefined();
});

it('finds mmproj file via directory scan when stored path is stale', async () => {
it('finds mmproj file via directory scan when stored path is stale (vision model)', async () => {
mockedRNFS.exists.mockResolvedValue(false);
mockedRNFS.readDir.mockResolvedValue([
{ name: 'mmproj-model-f16.gguf', path: '/models/mmproj-model-f16.gguf', isFile: () => true, size: 500 } as any,
Expand All @@ -96,10 +96,40 @@ describe('resolveMmProjPath', () => {
const { modelManager } = require('../../../src/services/modelManager');
modelManager.saveModelWithMmproj.mockResolvedValue(undefined);

const model = { filePath: '/models/m.gguf', mmProjPath: '/stale/path.gguf' } as any;
// isVisionModel: true so the guard allows the scan
const model = { filePath: '/models/m.gguf', mmProjPath: '/stale/path.gguf', isVisionModel: true } as any;
const result = await resolveMmProjPath(model, 'model-1');
expect(result).toBe('/models/mmproj-model-f16.gguf');
});

it('returns undefined for text-only model when no mmproj file exists in the directory', async () => {
// Text-only model: neither isVisionModel nor mmProjFileName is set,
// and the models directory contains no mmproj file.
mockedRNFS.exists.mockResolvedValue(false);
mockedRNFS.readDir.mockResolvedValue([]);

const model = { filePath: '/models/SmolLM2-360M-Instruct-Q8_0.gguf' } as any;
const result = await resolveMmProjPath(model, 'bartowski/SmolLM2-360M-Instruct-GGUF/SmolLM2-360M-Instruct-Q8_0.gguf');

expect(result).toBeUndefined();
});

it('allows scan for model with mmProjFileName sentinel even when isVisionModel is false (repair case)', async () => {
// After a failed mmproj download buildDownloadedModel sets mmProjFileName as a sentinel
// so needsVisionRepair can detect the gap. resolveMmProjPath must still scan for
// this model so that if the user repairs vision the path can be recovered.
mockedRNFS.exists.mockResolvedValue(false);
mockedRNFS.readDir.mockResolvedValue([]); // mmproj not on disk yet
const model = {
filePath: '/models/SmolVLM2-256M-Video-Instruct-Q8_0.gguf',
isVisionModel: false,
mmProjFileName: 'SmolVLM2-256M-Video-Instruct-Q8_0-mmproj.gguf',
} as any;
const result = await resolveMmProjPath(model, 'ggml-org/SmolVLM2');

expect(result).toBeUndefined(); // mmproj not on disk β†’ scan found nothing
expect(mockedRNFS.readDir).toHaveBeenCalled(); // guard did NOT block the scan
});
});

describe('doLoadTextModel β€” llama.cpp path', () => {
Expand Down
20 changes: 5 additions & 15 deletions __tests__/unit/services/generationServiceHelpers.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ describe('buildGenerationMetaImpl β€” LiteRT path', () => {
// ---------------------------------------------------------------------------

function makeSvc(overrides: any = {}) {
const state = Object.assign({ isGenerating: false, startTime: Date.now(), streamingContent: '' }, overrides.state);
const state = { isGenerating: false, startTime: Date.now(), streamingContent: '', ...overrides.state };
const svc = {
state,
updateState: jest.fn((patch: any) => { Object.assign(state, patch); }),
Expand Down Expand Up @@ -348,7 +348,7 @@ function makeLiteRTState() {
};
}

function makeLiteRTSvc() {
function makeServiceSvc() {
return {
...makeSvc(),
flushTimer: null,
Expand All @@ -361,6 +361,9 @@ function makeLiteRTSvc() {
};
}

const makeLiteRTSvc = makeServiceSvc;
const makeLlmSvc = makeServiceSvc;

describe('generateResponseImpl β€” LiteRT path', () => {
beforeEach(() => {
mockedLiteRT.isModelLoaded.mockReturnValue(true);
Expand Down Expand Up @@ -430,19 +433,6 @@ describe('generateResponseImpl β€” llama.cpp path', () => {
mockedGetState.mockReturnValue(makeLlmAppState());
});

function makeLlmSvc() {
return {
...makeSvc(),
flushTimer: null,
liteRTBenchmarkStats: null,
forceFlushTokens: jest.fn(),
flushTokenBuffer: jest.fn(),
checkSharePrompt: jest.fn(),
isUsingRemoteProvider: () => false,
getCurrentProvider: () => null,
};
}

it('calls finalizeStreamingMessage on successful completion', async () => {
const { llmService: llm } = require('../../../src/services/llm');
llm.isModelLoaded.mockReturnValue(true);
Expand Down
Loading
Loading