diff --git a/src/autoencoder.test.ts b/src/autoencoder.test.ts index 166903b6..b649add4 100644 --- a/src/autoencoder.test.ts +++ b/src/autoencoder.test.ts @@ -66,3 +66,25 @@ test('test a data sample for anomalies', async () => { includesAnomalies(1, 0, 1); includesAnomalies(1, 1, 0); }); + +test('restores a net fromJSON', () => { + expect(result.error).toBeLessThanOrEqual(errorThresh); + function xor(net: AE, ...args: number[]) { + return Math.round(net.denoise(args)[2]); + } + + const json = xornet.toJSON(); + const net = new AE({ + json, + }); + + const run1 = xor(net, 0, 0, 0); + const run2 = xor(net, 0, 1, 1); + const run3 = xor(net, 1, 0, 1); + const run4 = xor(net, 1, 1, 0); + + expect(run1).toBe(0); + expect(run2).toBe(1); + expect(run3).toBe(1); + expect(run4).toBe(0); +}); diff --git a/src/autoencoder.ts b/src/autoencoder.ts index 357d7fcd..aae06053 100644 --- a/src/autoencoder.ts +++ b/src/autoencoder.ts @@ -3,6 +3,7 @@ import { IJSONLayer, INeuralNetworkData, INeuralNetworkDatum, + INeuralNetworkJSON, INeuralNetworkTrainOptions, } from './neural-network'; import { @@ -16,6 +17,7 @@ export interface IAEOptions { binaryThresh: number; decodedSize: number; hiddenLayers: number[]; + json?: INeuralNetworkJSON; } /** @@ -26,7 +28,7 @@ export class AE< EncodedData extends INeuralNetworkData > { private decoder?: NeuralNetworkGPU; - private readonly denoiser: NeuralNetworkGPU; + private denoiser: NeuralNetworkGPU; constructor(options?: Partial) { // Create default options for the autoencoder. @@ -47,6 +49,10 @@ export class AE< // Create the denoiser subnet of the autoencoder. this.denoiser = new NeuralNetworkGPU(options); + + if (options.json) { + this.denoiser = this.denoiser.fromJSON(options.json); + } } /** @@ -191,6 +197,15 @@ export class AE< return (decoder as unknown) as NeuralNetworkGPU; } + toJSON(): INeuralNetworkJSON { + return this.denoiser.toJSON(); + } + + fromJSON(json: INeuralNetworkJSON): this { + this.denoiser = this.denoiser.fromJSON(json); + return this; + } + /** * Get the layer containing the encoded representation. */