diff --git a/src/components/ConnectionNoticeDialog.tsx b/src/components/ConnectionNoticeDialog.tsx index 1e51e2d..07c5380 100644 --- a/src/components/ConnectionNoticeDialog.tsx +++ b/src/components/ConnectionNoticeDialog.tsx @@ -15,6 +15,7 @@ import { } from "@tabler/icons-react"; import type { ConnectionMethod } from "./DeviceConnection"; import { saveNoticeAcceptance } from "../lib/connectionNoticeStorage"; +import { isUsbConnectionAvailable } from "../lib/transport/usb"; import { useCallback, useMemo, useState } from "react"; interface ConnectionNoticeDialogProps { @@ -36,9 +37,9 @@ export function ConnectionNoticeDialog({ }: ConnectionNoticeDialogProps) { const isUSB = method === "serial"; const isBLE = method === "ble"; - const isSerialAvailable = useMemo(() => "serial" in navigator, []); + const isUSBAvailable = useMemo(() => isUsbConnectionAvailable(), []); const isBLEAvailable = useMemo(() => "bluetooth" in navigator, []); - const canContinue = (isUSB && isSerialAvailable) || (isBLE && isBLEAvailable); + const canContinue = (isUSB && isUSBAvailable) || (isBLE && isBLEAvailable); const [neverShowAgain, setNeverShowAgain] = useState(false); const handleAgree = useCallback(() => { @@ -118,7 +119,7 @@ export function ConnectionNoticeDialog({

)} - {isUSB && !isSerialAvailable && ( + {isUSB && !isUSBAvailable && (

{ - let connectFn; + let connectFn: () => Promise; if (method === "ble") { connectFn = connectBLE; } else if (method === "demo") { connectFn = connectDemo; } else { - connectFn = connectSerial; + connectFn = connectUSB; } await zmkApp.connect(connectFn); }, diff --git a/src/components/__tests__/ConnectionNoticeDialog.test.tsx b/src/components/__tests__/ConnectionNoticeDialog.test.tsx index 2fee0b7..94b11ac 100644 --- a/src/components/__tests__/ConnectionNoticeDialog.test.tsx +++ b/src/components/__tests__/ConnectionNoticeDialog.test.tsx @@ -27,6 +27,11 @@ Object.defineProperty(window, "localStorage", { describe("ConnectionNoticeDialog", () => { beforeEach(() => { mockLocalStorage.clear(); + Object.defineProperty(navigator, "serial", { + writable: true, + configurable: true, + value: undefined, + }); }); test("renders USB connection dialog", () => { @@ -139,6 +144,39 @@ describe("ConnectionNoticeDialog", () => { expect(screen.queryByText("Connect via USB")).not.toBeInTheDocument(); }); + + test("allows USB connection on Android Chrome when WebUSB is available", () => { + const onAgree = jest.fn(); + const onCancel = jest.fn(); + const userAgentSpy = jest + .spyOn(navigator, "userAgent", "get") + .mockReturnValue( + "Mozilla/5.0 (Linux; Android 14; Pixel 8) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Mobile Safari/537.36", + ); + + Object.defineProperty(navigator, "serial", { + configurable: true, + value: undefined, + }); + Object.defineProperty(navigator, "usb", { + configurable: true, + value: { requestDevice: jest.fn() }, + }); + + render( + , + ); + + expect(screen.getByText("Data Collection Notice")).toBeInTheDocument(); + expect(screen.getByText("Agree to start")).toBeInTheDocument(); + + userAgentSpy.mockRestore(); + }); }); describe("hasAcceptedNotice", () => { diff --git a/src/components/__tests__/DeviceConnection.test.tsx b/src/components/__tests__/DeviceConnection.test.tsx index 19a07ec..c0f28eb 100644 --- a/src/components/__tests__/DeviceConnection.test.tsx +++ b/src/components/__tests__/DeviceConnection.test.tsx @@ -25,6 +25,11 @@ jest.mock("@zmkfirmware/zmk-studio-ts-client/transport/serial", () => ({ connect: jest.fn(), })); +// Mock the app-level USB transport selector +jest.mock("../../lib/transport/usb", () => ({ + connect: jest.fn(), +})); + // Mock the BLE transport jest.mock("@zmkfirmware/zmk-studio-ts-client/transport/gatt", () => ({ connect: jest.fn(), diff --git a/src/lib/transport/__tests__/usb.test.ts b/src/lib/transport/__tests__/usb.test.ts new file mode 100644 index 0000000..b2808f3 --- /dev/null +++ b/src/lib/transport/__tests__/usb.test.ts @@ -0,0 +1,78 @@ +import { connect as connectSerial } from "@zmkfirmware/zmk-studio-ts-client/transport/serial"; +import { connect as connectWebUsb } from "../webUsb"; +import { + connect, + isUsbConnectionAvailable, + shouldUseWebUsbForUsbConnection, +} from "../usb"; + +jest.mock("@zmkfirmware/zmk-studio-ts-client/transport/serial", () => ({ + connect: jest.fn(), +})); + +jest.mock("../webUsb", () => ({ + connect: jest.fn(), +})); + +type NavigatorWithOptionalTransport = Navigator & { + serial?: unknown; + usb?: unknown; +}; + +const androidChromeUserAgent = + "Mozilla/5.0 (Linux; Android 14; Pixel 8) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Mobile Safari/537.36"; +const desktopChromeUserAgent = + "Mozilla/5.0 (Macintosh; Intel Mac OS X 14_0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"; + +describe("USB transport selection", () => { + beforeEach(() => { + jest.clearAllMocks(); + delete (navigator as NavigatorWithOptionalTransport).serial; + delete (navigator as NavigatorWithOptionalTransport).usb; + }); + + test("uses WebUSB for Android Chrome", async () => { + await connectWithUserAgent(androidChromeUserAgent); + + expect(connectWebUsb).toHaveBeenCalledTimes(1); + expect(connectSerial).not.toHaveBeenCalled(); + }); + + test("uses Web Serial for non-Android Chrome", async () => { + await connectWithUserAgent(desktopChromeUserAgent); + + expect(connectSerial).toHaveBeenCalledTimes(1); + expect(connectWebUsb).not.toHaveBeenCalled(); + }); + + test("detects Android Chrome user agents", () => { + expect(shouldUseWebUsbForUsbConnection(androidChromeUserAgent)).toBe(true); + expect(shouldUseWebUsbForUsbConnection(desktopChromeUserAgent)).toBe(false); + }); + + test("treats Android Chrome WebUSB as USB-capable without Web Serial", () => { + const userAgentSpy = jest + .spyOn(navigator, "userAgent", "get") + .mockReturnValue(androidChromeUserAgent); + Object.defineProperty(navigator, "usb", { + configurable: true, + value: { requestDevice: jest.fn() }, + }); + + expect(isUsbConnectionAvailable()).toBe(true); + + userAgentSpy.mockRestore(); + }); +}); + +async function connectWithUserAgent(userAgent: string) { + const userAgentSpy = jest + .spyOn(navigator, "userAgent", "get") + .mockReturnValue(userAgent); + + try { + await connect(); + } finally { + userAgentSpy.mockRestore(); + } +} diff --git a/src/lib/transport/usb.ts b/src/lib/transport/usb.ts new file mode 100644 index 0000000..2f3fc34 --- /dev/null +++ b/src/lib/transport/usb.ts @@ -0,0 +1,23 @@ +import type { RpcTransport } from "@zmkfirmware/zmk-studio-ts-client/transport/index"; +import { connect as connectSerial } from "@zmkfirmware/zmk-studio-ts-client/transport/serial"; +import { connect as connectWebUsb } from "./webUsb"; + +export function shouldUseWebUsbForUsbConnection( + userAgent = navigator.userAgent, +) { + return /\bAndroid\b/i.test(userAgent) && /\bChrome\//i.test(userAgent); +} + +export function isUsbConnectionAvailable() { + return ( + "serial" in navigator || + (shouldUseWebUsbForUsbConnection() && "usb" in navigator) + ); +} + +export async function connect(): Promise { + if (shouldUseWebUsbForUsbConnection()) { + return connectWebUsb(); + } + return connectSerial(); +} diff --git a/src/lib/transport/webUsb.ts b/src/lib/transport/webUsb.ts new file mode 100644 index 0000000..ff08636 --- /dev/null +++ b/src/lib/transport/webUsb.ts @@ -0,0 +1,344 @@ +import type { RpcTransport } from "@zmkfirmware/zmk-studio-ts-client/transport/index"; + +type WebUsbDeviceFilterLike = { + classCode?: number; + subclassCode?: number; + protocolCode?: number; + vendorId?: number; + productId?: number; +}; + +type WebUsbEndpointLike = { + endpointNumber: number; + direction: "in" | "out"; + packetSize: number; + type: "bulk" | "interrupt" | "isochronous"; +}; + +type WebUsbAlternateInterfaceLike = { + alternateSetting: number; + interfaceClass: number; + endpoints: WebUsbEndpointLike[]; +}; + +type WebUsbInterfaceLike = { + interfaceNumber: number; + alternates: WebUsbAlternateInterfaceLike[]; +}; + +type WebUsbConfigurationLike = { + configurationValue: number; + interfaces: WebUsbInterfaceLike[]; +}; + +type WebUsbTransferResultLike = { + data?: DataView; + status: "ok" | "stall" | "babble"; +}; + +type WebUsbDeviceLike = { + configuration: WebUsbConfigurationLike | null; + configurations: WebUsbConfigurationLike[]; + opened: boolean; + productId?: number; + productName?: string; + vendorId?: number; + open: () => Promise; + close: () => Promise; + selectConfiguration: (configurationValue: number) => Promise; + claimInterface: (interfaceNumber: number) => Promise; + releaseInterface: (interfaceNumber: number) => Promise; + selectAlternateInterface?: ( + interfaceNumber: number, + alternateSetting: number, + ) => Promise; + transferIn: ( + endpointNumber: number, + length: number, + ) => Promise; + transferOut: ( + endpointNumber: number, + data: Uint8Array, + ) => Promise<{ status: "ok" | "stall" }>; + clearHalt?: ( + direction: "in" | "out", + endpointNumber: number, + ) => Promise; +}; + +type WebUsbCdcEndpoints = { + interfaceNumber: number; + alternateSetting: number; + inEndpoint: WebUsbEndpointLike; + outEndpoint: WebUsbEndpointLike; +}; + +type NavigatorWithWebUsb = Navigator & { + usb?: { + requestDevice: (options: { + filters: WebUsbDeviceFilterLike[]; + }) => Promise; + }; +}; + +const webUsbCdcFilters: WebUsbDeviceFilterLike[] = [ + { classCode: 0x02 }, + { classCode: 0x0a }, +]; + +export async function connect(): Promise { + const usb = (navigator as NavigatorWithWebUsb).usb; + if (!usb) throw new Error("WebUSB is not available in this browser"); + + const device = await usb.requestDevice({ filters: webUsbCdcFilters }); + await device.open().catch((caught) => { + if (caught instanceof DOMException && caught.name === "NetworkError") { + throw new Error( + "Failed to open the WebUSB device. Check the permissions of the device and verify it is not in use by another process.", + { cause: caught }, + ); + } + throw caught; + }); + + try { + if (!device.configuration) { + const configurationValue = device.configurations[0]?.configurationValue; + if (configurationValue === undefined) { + throw new Error("No USB configuration is available on this device."); + } + await device.selectConfiguration(configurationValue); + } + + const endpoints = findWebUsbCdcEndpoints(device.configuration); + if (!endpoints) { + throw new Error( + "No CDC bulk data interface was found on the selected WebUSB device.", + ); + } + + await device.claimInterface(endpoints.interfaceNumber).catch((caught) => { + throw new Error( + "Failed to claim the CDC interface on the WebUSB device. Check that it is not in use by another app or browser tab.", + { cause: caught }, + ); + }); + + if (endpoints.alternateSetting !== 0) { + await device.selectAlternateInterface?.( + endpoints.interfaceNumber, + endpoints.alternateSetting, + ); + } + + const abortController = new AbortController(); + const close = createWebUsbCloseHandler(device, endpoints.interfaceNumber); + abortController.signal.addEventListener("abort", () => void close(), { + once: true, + }); + + return { + label: createWebUsbLabel(device), + abortController, + readable: createWebUsbReadableStream( + device, + endpoints.inEndpoint, + abortController, + ), + writable: createWebUsbWritableStream( + device, + endpoints.outEndpoint, + abortController, + ), + }; + } catch (caught) { + await closeWebUsbDevice(device).catch((closeError) => { + console.warn( + "Failed to close ZMK WebUSB device after error.", + closeError, + ); + }); + throw caught; + } +} + +function findWebUsbCdcEndpoints( + configuration: WebUsbConfigurationLike | null, +): WebUsbCdcEndpoints | null { + const candidates: WebUsbCdcEndpoints[] = []; + + for (const usbInterface of configuration?.interfaces ?? []) { + for (const alternate of usbInterface.alternates) { + const inEndpoint = alternate.endpoints.find( + (endpoint) => endpoint.type === "bulk" && endpoint.direction === "in", + ); + const outEndpoint = alternate.endpoints.find( + (endpoint) => endpoint.type === "bulk" && endpoint.direction === "out", + ); + if (!inEndpoint || !outEndpoint) continue; + + const candidate = { + interfaceNumber: usbInterface.interfaceNumber, + alternateSetting: alternate.alternateSetting, + inEndpoint, + outEndpoint, + }; + candidates.push(candidate); + if (alternate.interfaceClass === 0x0a) return candidate; + } + } + + return candidates[0] ?? null; +} + +function createWebUsbReadableStream( + device: WebUsbDeviceLike, + endpoint: WebUsbEndpointLike, + abortController: AbortController, +) { + return new ReadableStream({ + start(controller) { + void pumpWebUsbIn(device, endpoint, abortController.signal, controller); + }, + cancel() { + abortController.abort(); + }, + }); +} + +async function pumpWebUsbIn( + device: WebUsbDeviceLike, + endpoint: WebUsbEndpointLike, + signal: AbortSignal, + controller: ReadableStreamDefaultController, +) { + while (!signal.aborted) { + try { + const result = await device.transferIn( + endpoint.endpointNumber, + endpoint.packetSize, + ); + if (signal.aborted) return; + + if (result.status === "stall") { + await device.clearHalt?.("in", endpoint.endpointNumber); + continue; + } + if (result.status !== "ok") { + throw new Error(`WebUSB transferIn failed with ${result.status}.`); + } + if (result.data?.byteLength) { + controller.enqueue(dataViewToUint8Array(result.data)); + } + } catch (caught) { + if (!signal.aborted) controller.error(caught); + return; + } + } +} + +function createWebUsbWritableStream( + device: WebUsbDeviceLike, + endpoint: WebUsbEndpointLike, + abortController: AbortController, +) { + return new WritableStream({ + async write(chunk) { + if (abortController.signal.aborted) { + throw new DOMException("WebUSB connection is closed.", "AbortError"); + } + await writeWebUsbChunks(device, endpoint, chunk); + }, + abort() { + abortController.abort(); + }, + }); +} + +async function writeWebUsbChunks( + device: WebUsbDeviceLike, + endpoint: WebUsbEndpointLike, + chunk: Uint8Array, +) { + for ( + let offset = 0; + offset < chunk.byteLength; + offset += endpoint.packetSize + ) { + const slice = chunk.subarray(offset, offset + endpoint.packetSize); + const result = await device.transferOut(endpoint.endpointNumber, slice); + if (result.status === "stall") { + await device.clearHalt?.("out", endpoint.endpointNumber); + throw new Error("WebUSB transferOut stalled."); + } + if (result.status !== "ok") { + throw new Error(`WebUSB transferOut failed with ${result.status}.`); + } + } +} + +function createWebUsbCloseHandler( + device: WebUsbDeviceLike, + interfaceNumber: number, +) { + let closePromise: Promise | null = null; + return () => { + closePromise ??= closeWebUsbDevice(device, interfaceNumber).catch( + (caught) => { + console.warn("Failed to close ZMK WebUSB device.", caught); + }, + ); + return closePromise; + }; +} + +async function closeWebUsbDevice( + device: WebUsbDeviceLike, + interfaceNumber?: number, +) { + for (let attempt = 0; attempt < 10; attempt += 1) { + await delay(attempt === 0 ? 0 : 50); + try { + if (interfaceNumber !== undefined) { + await device.releaseInterface(interfaceNumber).catch((caught) => { + if (!isWebUsbDeviceAlreadyClosed(caught)) throw caught; + }); + } + if (device.opened) await device.close(); + return; + } catch (caught) { + if (attempt === 9 || isWebUsbDeviceAlreadyClosed(caught)) { + if (!isWebUsbDeviceAlreadyClosed(caught)) throw caught; + return; + } + } + } +} + +function createWebUsbLabel(device: WebUsbDeviceLike) { + const vendorId = device.vendorId?.toLocaleString() || ""; + const productId = device.productId?.toLocaleString() || ""; + return [device.productName, `${vendorId}:${productId}`] + .filter(Boolean) + .join(" "); +} + +function isWebUsbDeviceAlreadyClosed(caught: unknown) { + return ( + caught instanceof DOMException && + caught.name === "InvalidStateError" && + /not open|already closed/i.test(caught.message) + ); +} + +function dataViewToUint8Array(data: DataView) { + return new Uint8Array( + data.buffer.slice(data.byteOffset, data.byteOffset + data.byteLength), + ); +} + +function delay(ms: number) { + return new Promise((resolve) => { + globalThis.setTimeout(resolve, ms); + }); +}