examples/nn_layers.dx
' # Neural Networks

' ## NN Prelude

' This ReLU implementation is for one neuron.
Use `map relu` to apply it to a 1D table, `map (map relu)` for 2D, etc.

def relu (input : Float) : Float =
select (input > 0.0) input 0.0
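
' For example (a quick check; the literal table is just illustrative), applying it
elementwise with `map`:

:p map relu [1.0, -2.0, 3.0]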

' A pair of vector spaces is also a vector space.

instance [Add a, Add b] Add (a & b)
add = \(a, b) (c, d). ( (a + c), (b + d))
sub = \(a, b) (c, d). ( (a - c), (b - d))
zero = (zero, zero)

instance [VSpace a, VSpace b] VSpace (a & b)
scaleVec = \ s (a, b) . (scaleVec s a, scaleVec s b)
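
' A quick check of the pair instances (assuming the prelude provides `Add` and
`VSpace` for `Float`):

:p scaleVec 2.0 (1.0, (2.0, 3.0))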

' Layer type, describing a function with trainable side-parameters.
This may represent a primitive layer (e.g. dense), a composition of layers
(e.g. resnet_block), or an entire network.
`forward` is the function implementing the layer computation.
`init` provides initial values of the parameters, given a random key.

data Layer inp:Type out:Type params:Type =
AsLayer {forward:(params -> inp -> out) & init:(Key -> params)}

' Convenience functions to extract `forward` and `init` functions.

def forward (l:Layer i o p) (p : p) (x : i): o =
(AsLayer l') = l
(getAt #forward l') p x

def init (l:Layer i o p) (k:Key) : p =
(AsLayer l') = l
(getAt #init l') k

' Adapt a pure function into a (parameterless) `Layer`.
This is unused and for illustration only: we will see below how to apply
pure functions directly with `trace_map`.

def as_layer (f:u->v) : Layer u v Unit =
AsLayer {
forward = \ () x . f x,
init = \_ . ()
}
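
' For instance (illustrative only; `relu_layer` is not used below), wrapping
`relu` yields a layer with `Unit` parameters:

relu_layer = as_layer relu
:t relu_layer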

' ## Layers

' Dense layer.

def DenseParams (a:Type) (b:Type) : Type =
((a=>b=>Float) & (b=>Float))

def dense (a:Type) (b:Type) : Layer (a=>Float) (b=>Float) (DenseParams a b) =
AsLayer {
forward = (\ ((weight, bias)) x .
for j. (bias.j + sum for i. weight.i.j * x.i)),
init = arb
}
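
' A small usage sketch (the names `d1` and `w1` are just illustrative): build a
2-to-3 dense layer, initialize its parameters from a key, and apply it to a
zero input.

d1 = dense (Fin 2) (Fin 3)
w1 = init d1 (newKey 0)
:p forward d1 w1 (for i:(Fin 2). 0.0)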


' CNN layer.

def CNNParams (inc:Type) (outc:Type) (kw:Int) (kh:Int) : Type =
((outc=>inc=>Fin kh=>Fin kw=>Float) &
(outc=>Float))

def conv2d (x:inc=>(Fin h)=>(Fin w)=>Float)
(kernel:outc=>inc=>(Fin kh)=>(Fin kw)=>Float) :
outc=>(Fin h)=>(Fin w)=>Float =
for o i j.
(i', j') = (ordinal i, ordinal j)
case (i' + kh) <= h && (j' + kw) <= w of
True ->
sum for (ki, kj, inp).
(di, dj) = (fromOrdinal (Fin h) (i' + (ordinal ki)),
fromOrdinal (Fin w) (j' + (ordinal kj)))
x.inp.di.dj * kernel.o.inp.ki.kj
False -> zero
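
' A quick sanity check (the names `img` and `ker` are illustrative): convolving a
1-channel 3x3 image of ones with a single 2x2 kernel of ones should give 4.0
wherever the kernel fits entirely, and 0.0 elsewhere.

img : (Fin 1)=>(Fin 3)=>(Fin 3)=>Float = for c i j. 1.0
ker : (Fin 1)=>(Fin 1)=>(Fin 2)=>(Fin 2)=>Float = for o c i j. 1.0
:p conv2d img ker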

def cnn (h:Int) ?-> (w:Int) ?-> (inc:Type) (outc:Type) (kw:Int) (kh:Int) :
Layer (inc=>(Fin h)=>(Fin w)=>Float)
(outc=>(Fin h)=>(Fin w)=>Float)
(CNNParams inc outc kw kh) =
AsLayer {
forward = (\ (weight, bias) x. for o i j . (conv2d x weight).o.i.j + bias.o),
init = arb
}

' ## Tracing

' A tracer is an object that we pass through a network function. The tracer
is used as if it were an actual input to the network, and invoking constituent
layers results in new tracers denoting intermediate values of the network.
A tracer is implemented as a partial neural network mapping the original inputs
to the current intermediate values.
Starting with an identity layer as the input tracer, tracing thus accumulates
an output tracer that implements the full network.

def trace (f : Layer inp inp Unit -> Layer inp out params) : Layer inp out params =
tracer = AsLayer {
forward = \() x. x,
init = \key. ()
}
f tracer

' Converts a pure function (e.g. `map relu`) to a callable acting on tracers.

def trace_map (f : inp -> out) (tracer : Layer ph inp ps) : Layer ph out ps =
AsLayer {
forward = \w x. f (forward tracer w x),
init = \key. init tracer key
}
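
' For example (a sketch; `relu_net` is an illustrative name), tracing `map relu`
on its own yields a parameterless layer:

relu_net : Layer ((Fin 4)=>Float) ((Fin 4)=>Float) Unit = trace \x. trace_map (map relu) x
:p forward relu_net () [1.0, -2.0, 3.0, -4.0]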

' Converts a two-arg pure function (e.g. `add`) to a callable acting on tracers.

def trace_map2 (f : inp0 -> inp1 -> out)
(tracer0 : Layer ph inp0 ps0) (tracer1 : Layer ph inp1 ps1) : Layer ph out (ps0&ps1) =
AsLayer {
forward = \(w0, w1) x. f (forward tracer0 w0 x) (forward tracer1 w1 x),
init = (\key.
[k0, k1] = splitKey key
(init tracer0 k0, init tracer1 k1))
}

' Adapts an existing `Layer` into a callable that can be invoked within a
larger layer being traced.

def callable (layer : Layer inp out params) (tracer : Layer ph inp ps) : Layer ph out (params & ps) =
AsLayer {
forward = \w x. forward layer (fst w) $ forward tracer (snd w) x,
init = (\key.
[k0, k1] = splitKey key
(init layer k0, init tracer k1))
}

' ## Networks

' MLP

mlp = trace \x.
dense1 = callable $ dense (Fin 2) (Fin 25)
x = dense1 x
x = trace_map (map relu) x
dense3 = callable $ dense _ (Fin 2)
x = dense3 x
x

w_mlp = init mlp (newKey 1)
:t w_mlp
:p forward mlp w_mlp (for _. 0.)

' ResNet - incorrect first attempt

' The following does not work correctly because the (last assigned) value of `x`
' is consumed twice, directly or indirectly by both inputs to `trace_map2`.
' Because `x` is a tracer, this leads to two independent copies of `dense1` (and
' its parameters) being created, one for each branch.

resnet_incorrect = trace \x.
dense1 = callable $ dense (Fin 2) (Fin 10)
x = dense1 x
x = trace_map (map relu) x
-- This value of `x` will be consumed twice.

dense3 = callable $ dense _ (Fin 25)
y = dense3 x
y = trace_map (map relu) y
dense5 = callable $ dense _ (Fin 10)
y = dense5 y
y = trace_map2 add y x

y = trace_map (map relu) y
dense7 = callable $ dense _ (Fin 2)
y = dense7 y
y

w_resnet_bad = init resnet_incorrect (newKey 1)
:t w_resnet_bad
:p forward resnet_incorrect w_resnet_bad (for _. 0.)


' ResNet - nested sequential

' In `resnet_block` below, `x` is used twice, but that is safe because it's the
' input tracer with no network parameters yet.

' We'd like to simply write `_` for the return type.
' Unfortunately that currently leads to a "leaked type variable `a`" error.

def resnet_block (a:Type)
: Layer (a=>Float) (a=>Float) (((DenseParams _ a) & (DenseParams a _) & Unit) & Unit)
= trace \x.
dense1 = callable $ dense a (Fin 25)
y = dense1 x
y = trace_map (map relu) y
dense3 = callable $ dense _ a
y = dense3 y
y = trace_map2 add y x
y

resnet = trace \x.
dense1 = callable $ dense (Fin 2) (Fin 10)
x = dense1 x
x = trace_map (map relu) x
block3 = callable $ resnet_block _
x = block3 x
x = trace_map (map relu) x
dense7 = callable $ dense _ (Fin 2)
x = dense7 x
x

w_resnet = init resnet (newKey 1)
:t w_resnet
:p forward resnet w_resnet (for _. 0.)


' ## Training

' Train a multiclass classifier with minibatch SGD.
' The sizes must satisfy `minibatch * minibatches = batch`.

def split (x: batch=>v) : minibatches=>minibatch=>v =
for b j. x.((ordinal (b,j))@batch)
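
' For example (a quick check; `nums` and `grouped` are illustrative names),
splitting 6 elements into 2 minibatches of 3 should give `[[0, 1, 2], [3, 4, 5]]`:

nums : (Fin 6)=>Int = for i. ordinal i
grouped : (Fin 2)=>(Fin 3)=>Int = split nums
:p grouped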

Review comment (Collaborator): Might make sense to assert that
minibatches * minibatch = batch, because this function will happily throw away
data.

def trainClass [VSpace p] (model: Layer a (b=>Float) p)
(x: batch=>a)
(y: batch=>b)
(epochs : Type)
(minibatch : Type)
(minibatches : Type) :
(epochs => p & epochs => Float) =
xs : minibatches => minibatch => a = split x
ys : minibatches => minibatch => b = split y
unzip $ withState (init model $ newKey 1) $ \params .
for _ : epochs.
loss = sum $ for b : minibatches.
(loss, gradfn) = vjp (\ params.
-sum for j.
logits = forward model params xs.b.j
(logsoftmax logits).(ys.b.j)) (get params)
gparams = gradfn 1.0
params := (get params) - scaleVec (0.05 / (IToF 100)) gparams
loss
(get params, loss)

' Sextant classification dataset: the label is the sign of `x1^3 - 3*x1*x2^2`,
the real part of `(x1 + i*x2)^3`, which divides the plane into six alternating
sectors.

[k1, k2] = splitKey $ newKey 1
x1 : Fin 400 => Float = arb k1
x2 : Fin 400 => Float = arb k2
y = for i. case ((x1.i * x1.i * x1.i - 3. * x1.i * x2.i * x2.i) > 0.) of
True -> 1
False -> 0
xs = for i. [x1.i, x2.i]

import plot
:html showPlot $ xycPlot x1 x2 $ for i. IToF y.i

' Train classifier on this dataset.

-- model = mlp
model = resnet

(all_params, losses) = trainClass model xs (for i. (y.i @ (Fin 2))) (Fin 3000) (Fin 50) (Fin 8)

' Classification landscape, evolving over training time. Colour encodes the
softmax class probabilities.

span = linspace (Fin 15) (-1.0) (1.0)
tests = for h : (Fin 100). for i . for j.
r = softmax $ forward model all_params.((ordinal h * 30)@_) [span.j, -span.i]
[r.(1@_), 0.5*r.(1@_), r.(0@_)]

:html imseqshow tests