Skip to content

Commit 05ce66f

Browse files
christinakopitharvik
authored andcommitted
discojs/src/models/gpt/layers.spec.ts check if there are NaN in MLPs' outputs
1 parent 0ba1f8a commit 05ce66f

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

discojs/src/models/gpt/layers.spec.ts

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ describe('GPT Layers', function () {
127127
// MLP Layer tests
128128
describe('MLP Layer', function () {
129129

130-
it('should produce deterministic outputs with the same random seed', async function () {
130+
it('should produce deterministic/non-NaN outputs with the same random seed', async function () {
131131
// an MLP config with a fixed seed
132132
const config: MLPConfig = {
133133
name: 'testMLP',
@@ -142,8 +142,6 @@ describe('GPT Layers', function () {
142142
// two separate MLP model instances using the same config
143143
const model1 = MLP(config);
144144
const model2 = MLP(config);
145-
146-
//TODO: check if there are NANs
147145

148146
const input = tf.ones([1, config.contextLength, config.nEmbd]);
149147

@@ -159,7 +157,15 @@ describe('GPT Layers', function () {
159157

160158
// check that the models produce the same output
161159
expect(arr1).to.deep.equal(arr2);
162-
160+
161+
// Check that there are no NaN values in the outputs.
162+
for (let i = 0; i < arr1.length; i++) {
163+
expect(isNaN(arr1[i])).to.be.false;
164+
}
165+
for (let i = 0; i < arr2.length; i++) {
166+
expect(isNaN(arr2[i])).to.be.false;
167+
}
168+
163169
});
164170
});
165171

0 commit comments

Comments
 (0)