-
Notifications
You must be signed in to change notification settings - Fork 998
Open
Description
There might be a way to speed up `Tensor.tolist()` by simplifying its code. Below is a side-by-side comparison of my version with the 3.0.2 version we ship in Firefox:
// Load the development build of transformers.js through Firefox's
// ChromeUtils ES-module loader (chrome:// URL) — NOTE(review): this assumes
// the snippet runs in a privileged chrome context such as the Browser Console.
// `var` is presumably kept instead of `const` so the snippet can be pasted and
// re-evaluated in the console without a redeclaration error — confirm before
// changing it.
var transformers = ChromeUtils.importESModule("chrome://global/content/ml/transformers-dev.js");
/**
 * Tensor subclass with a lighter-weight `tolist()`.
 *
 * The nested array is built directly from `this.data` by walking row-major
 * strides. Intermediate flat arrays are never reshaped.
 */
class CustomTensor extends transformers.Tensor {
  /**
   * Convert the tensor into a nested plain JavaScript array shaped like `this.dims`.
   *
   * Works for any rank:
   * - a 1-D tensor yields a flat array;
   * - a 2-D tensor yields an array of rows, and so on for higher ranks;
   * - a 0-D tensor (empty `dims`) yields its single element.
   *
   * Elements are read from `this.data` assuming contiguous row-major layout.
   * NOTE(review): confirm this holds for every tensor type this is used with.
   *
   * @returns {Array|number|bigint} Nested array mirroring `this.dims`, or the
   *   lone element for a 0-D tensor.
   */
  tolist() {
    const { dims, data } = this;
    const rank = dims.length;

    // Scalar tensor: there is no axis to nest over.
    if (rank === 0) return data[0];

    // strides[axis] = number of flat elements skipped by one step along `axis`.
    // The innermost axis always has stride 1.
    const strides = new Array(rank);
    let stride = 1;
    for (let axis = rank - 1; axis >= 0; axis--) {
      strides[axis] = stride;
      stride *= dims[axis];
    }

    // Build the sub-array for `axis`, whose first element sits at flat index `offset`.
    const build = (axis, offset) => {
      const size = dims[axis];
      const out = new Array(size);
      if (axis === rank - 1) {
        // Innermost axis: copy a contiguous run of elements.
        for (let i = 0; i < size; i++) out[i] = data[offset + i];
      } else {
        const step = strides[axis];
        for (let i = 0; i < size; i++) out[i] = build(axis + 1, offset + i * step);
      }
      return out;
    };

    return build(0, 0);
  }
}
/**
 * Benchmark the stock `Tensor.tolist()` against `CustomTensor.tolist()` and
 * check that both return identical nested arrays.
 *
 * Both tensors are filled with the same random values in [-1, 1).
 *
 * @param {number} [iterations=10] - Timed runs per implementation; must be a
 *   positive integer, otherwise the averages would be NaN/Infinity.
 * @param {number[]} [shape=[16, 768]] - Tensor dims to benchmark. The default
 *   [16, 768] is a 16-token x 768-dim sentence embedding. Pass a rank-3 (or
 *   higher) shape to verify the optimized path beyond 2-D.
 * @param {number} [warmup=3] - Untimed runs of each implementation done before
 *   timing starts, so one-off warm-up cost is not counted in the averages.
 * @returns {Promise<{avgDefaultMs: number, avgOptimizedMs: number, allEqual: boolean}>}
 *   The averages that are also logged via console.debug.
 * @throws {RangeError} If `iterations` is not a positive integer.
 */
async function test_tolist_speed(iterations = 10, shape = [16, 768], warmup = 3) {
  if (!Number.isInteger(iterations) || iterations < 1) {
    throw new RangeError(`iterations must be a positive integer, got ${iterations}`);
  }

  const totalElements = shape.reduce((acc, d) => acc * d, 1);
  const data = Array.from({ length: totalElements }, () => Math.random() * 2 - 1);
  const tensor = new transformers.Tensor('float32', data, shape);
  const customTensor = new CustomTensor('float32', data, shape);

  // Structural deep equality for nested arrays; leaves are compared with ===.
  const deepEqual = (a, b) => {
    if (!Array.isArray(a) || !Array.isArray(b)) return a === b;
    return a.length === b.length && a.every((v, k) => deepEqual(v, b[k]));
  };

  for (let i = 0; i < warmup; i++) {
    tensor.tolist();
    customTensor.tolist();
  }

  let totalTimeDefault = 0;
  let totalTimeOptimized = 0;
  let allEqual = true;

  const runDefault = () => {
    const start = performance.now();
    const list = tensor.tolist();
    totalTimeDefault += performance.now() - start;
    return list;
  };
  const runOptimized = () => {
    const start = performance.now();
    const list = customTensor.tolist();
    totalTimeOptimized += performance.now() - start;
    return list;
  };

  for (let i = 0; i < iterations; i++) {
    // Alternate which implementation goes first, so GC pauses or cache effects
    // triggered by the first call do not consistently penalize the second one.
    let defaultList;
    let optimizedList;
    if (i % 2 === 0) {
      defaultList = runDefault();
      optimizedList = runOptimized();
    } else {
      optimizedList = runOptimized();
      defaultList = runDefault();
    }
    if (!deepEqual(defaultList, optimizedList)) {
      allEqual = false;
    }
  }

  const avgDefaultMs = totalTimeDefault / iterations;
  const avgOptimizedMs = totalTimeOptimized / iterations;
  console.debug(`Avg tolist() time (default): ${avgDefaultMs.toFixed(4)} ms`);
  console.debug(`Avg tolist() time (optimized): ${avgOptimizedMs.toFixed(4)} ms`);
  console.debug(`Lists are equal: ${allEqual}`);
  return { avgDefaultMs, avgOptimizedMs, allEqual };
}

await test_tolist_speed();
Results from two runs:
console.debug: "Avg tolist() time (default): 0.2163 ms"
console.debug: "Avg tolist() time (optimized): 0.0904 ms"
console.debug: "Lists are equal: true"
console.debug: "Avg tolist() time (default): 0.2083 ms"
console.debug: "Avg tolist() time (optimized): 0.0793 ms"
console.debug: "Lists are equal: true"
Let me know whether this looks correct or whether I missed something — I'm happy to contribute the patch.
Metadata
Metadata
Assignees
Labels
No labels