Neural network layers example #581

Open
RobertStanforth wants to merge 5 commits into google-research:main from RobertStanforth:main
Commits (5)
9e2266a  Corrected off-by-one error in bounds check on applicability of conv f…
4e68e2f  Merge branch 'google-research:main' into main
ec11ecf  Prototype layers library using type system to manage layers' params a…
55be172  Alternative prototype layers library, using existential types to hide…
59bb740  Corrections to trainModel in existentially-typed layers example
@@ -0,0 +1,277 @@
' # Neural Networks

' ## NN Prelude

' This ReLU implementation is for one neuron.
Use `map relu` to apply it to a 1D table, `map (map relu)` for 2D, etc.

def relu (input : Float) : Float =
  select (input > 0.0) input 0.0
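
' For instance, a quick illustrative check, applying it pointwise over a table:

:p map relu [1.0, -0.5, 2.0, -3.0]   -- should print [1., 0., 2., 0.]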

' A pair of vector spaces is also a vector space.

instance [Add a, Add b] Add (a & b)
  add = \(a, b) (c, d). ((a + c), (b + d))
  sub = \(a, b) (c, d). ((a - c), (b - d))
  zero = (zero, zero)

instance [VSpace a, VSpace b] VSpace (a & b)
  scaleVec = \s (a, b). (scaleVec s a, scaleVec s b)
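
' For example (an illustrative check), pairs now support `+`, `-` and
`scaleVec` componentwise:

:p (1.0, 2.0) + (3.0, 4.0)           -- should print (4., 6.)
:p scaleVec 2.0 (1.0, [1.0, 2.0])    -- should print (2., [2., 4.])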

' Layer type, describing a function with trainable side-parameters.
This may represent a primitive layer (e.g. dense), a composition of layers
(e.g. resnet_block), or an entire network.
`forward` is the function implementing the layer computation.
`init` provides initial values of the parameters, given a random key.

data Layer inp:Type out:Type params:Type =
  AsLayer {forward:(params -> inp -> out) & init:(Key -> params)}

' Convenience functions to extract the `forward` and `init` functions.

def forward (l:Layer i o p) (p:p) (x:i) : o =
  (AsLayer l') = l
  (getAt #forward l') p x

def init (l:Layer i o p) (k:Key) : p =
  (AsLayer l') = l
  (getAt #init l') k

' Adapt a pure function into a (parameterless) `Layer`.
This is unused and for illustration only: we will see below how to apply
pure functions directly with `trace_map`.

def as_layer (f:u->v) : Layer u v Unit =
  AsLayer {
    forward = \() x. f x,
    init = \_. ()
  }
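
' One possible usage sketch (illustrative; the concrete index set is arbitrary):

relu_layer : Layer (Fin 2=>Float) (Fin 2=>Float) Unit = as_layer (map relu)
:p forward relu_layer () [1.0, -2.0]   -- should print [1., 0.]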

' ## Layers

' Dense layer.

def DenseParams (a:Type) (b:Type) : Type =
  ((a=>b=>Float) & (b=>Float))

def dense (a:Type) (b:Type) : Layer (a=>Float) (b=>Float) (DenseParams a b) =
  AsLayer {
    forward = (\(weight, bias) x.
      for j. (bias.j + sum for i. weight.i.j * x.i)),
    init = arb
  }
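
' A standalone usage sketch of `dense` (illustrative; `arb` draws the initial
parameters from the key):

dense_demo = dense (Fin 3) (Fin 2)
w_demo = init dense_demo (newKey 0)
:t w_demo
:p forward dense_demo w_demo [1.0, 2.0, 3.0]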

' CNN layer.

def CNNParams (inc:Type) (outc:Type) (kw:Int) (kh:Int) : Type =
  ((outc=>inc=>Fin kh=>Fin kw=>Float) &
   (outc=>Float))

def conv2d (x:inc=>(Fin h)=>(Fin w)=>Float)
           (kernel:outc=>inc=>(Fin kh)=>(Fin kw)=>Float)
           : outc=>(Fin h)=>(Fin w)=>Float =
  for o i j.
    (i', j') = (ordinal i, ordinal j)
    case (i' + kh) <= h && (j' + kw) <= w of
      True ->
        sum for (ki, kj, inp).
          (di, dj) = (fromOrdinal (Fin h) (i' + (ordinal ki)),
                      fromOrdinal (Fin w) (j' + (ordinal kj)))
          x.inp.di.dj * kernel.o.inp.ki.kj
      False -> zero

def cnn (h:Int) ?-> (w:Int) ?-> (inc:Type) (outc:Type) (kw:Int) (kh:Int)
    : Layer (inc=>(Fin h)=>(Fin w)=>Float)
            (outc=>(Fin h)=>(Fin w)=>Float)
            (CNNParams inc outc kw kh) =
  AsLayer {
    forward = (\(weight, bias) x. for o i j. (conv2d x weight).o.i.j + bias.o),
    init = arb
  }
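
' A tiny illustrative check of `conv2d`: an all-ones 1x3x3 image convolved
with an all-ones 1x1x2x2 kernel. Window positions that run off the image are
zeroed by the bounds check above.

img : (Fin 1)=>(Fin 3)=>(Fin 3)=>Float = for c i j. 1.0
ker : (Fin 1)=>(Fin 1)=>(Fin 2)=>(Fin 2)=>Float = for o c ki kj. 1.0
:p conv2d img ker   -- interior entries should be 4.; the last row and column 0.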

' ## Tracing

' A tracer is an object that we pass through a network function. The tracer
is used as if it were an actual input to the network, and invoking constituent
layers results in new tracers denoting intermediate values of the network.
A tracer is implemented as a partial neural network, from the original inputs
to the current intermediate values.
Starting with an identity layer as the input tracer, tracing thus accumulates
an output tracer that implements the full network.

def trace (f : Layer inp inp Unit -> Layer inp out params) : Layer inp out params =
  tracer = AsLayer {
    forward = \() x. x,
    init = \key. ()
  }
  f tracer
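
' The simplest possible trace, as an illustrative sketch: tracing the identity
function yields a parameterless identity layer.

id_net : Layer Float Float Unit = trace \x. x
:p forward id_net (init id_net (newKey 0)) 1.5   -- should print 1.5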

' Converts a pure function (e.g. `map relu`) to a callable acting on tracers.

def trace_map (f : inp -> out) (tracer : Layer ph inp ps) : Layer ph out ps =
  AsLayer {
    forward = \w x. f (forward tracer w x),
    init = \key. init tracer key
  }

' Converts a two-arg pure function (e.g. `add`) to a callable acting on tracers.

def trace_map2 (f : inp0 -> inp1 -> out)
               (tracer0 : Layer ph inp0 ps0) (tracer1 : Layer ph inp1 ps1)
               : Layer ph out (ps0 & ps1) =
  AsLayer {
    forward = \(w0, w1) x. f (forward tracer0 w0 x) (forward tracer1 w1 x),
    init = (\key.
      [k0, k1] = splitKey key
      (init tracer0 k0, init tracer1 k1))
  }
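
' An illustrative sketch: applying `trace_map2 add` to the same tracer twice
doubles the input. Reusing a tracer like this is only safe when it carries no
parameters; see the ResNet discussion below.

double_net : Layer Float Float (Unit & Unit) = trace \x. trace_map2 add x x
:p forward double_net (init double_net (newKey 0)) 2.0   -- should print 4.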

' Adapts an existing `Layer` into a callable that can be invoked within a
larger layer being traced.

def callable (layer : Layer inp out params) (tracer : Layer ph inp ps)
             : Layer ph out (params & ps) =
  AsLayer {
    forward = \w x. forward layer (fst w) $ forward tracer (snd w) x,
    init = (\key.
      [k0, k1] = splitKey key
      (init layer k0, init tracer k1))
  }

' ## Networks

' MLP

mlp = trace \x.
  dense1 = callable $ dense (Fin 2) (Fin 25)
  x = dense1 x
  x = trace_map (map relu) x
  dense3 = callable $ dense _ (Fin 2)
  x = dense3 x
  x

w_mlp = init mlp (newKey 1)
:t w_mlp
:p forward mlp w_mlp (for _. 0.)

' ResNet - incorrect first attempt

' The following does not work correctly, because the (last assigned) value of
' `x` is consumed twice, directly or indirectly, by both inputs to `trace_map2`.
' Because it's a tracer, this leads to two independent copies of `dense1`
' being created, one for each branch.

resnet_incorrect = trace \x.
  dense1 = callable $ dense (Fin 2) (Fin 10)
  x = dense1 x
  x = trace_map (map relu) x
  -- This value of `x` will be consumed twice.

  dense3 = callable $ dense _ (Fin 25)
  y = dense3 x
  y = trace_map (map relu) y
  dense5 = callable $ dense _ (Fin 10)
  y = dense5 y
  y = trace_map2 add y x

  y = trace_map (map relu) y
  dense7 = callable $ dense _ (Fin 2)
  y = dense7 y
  y

w_resnet_bad = init resnet_incorrect (newKey 1)
:t w_resnet_bad
:p forward resnet_incorrect w_resnet_bad (for _. 0.)

' ResNet - nested sequential

' In `resnet_block` below, `x` is used twice, but that is safe because it's
' the input tracer, with no network parameters yet.

' We'd like the return type to just be `_`.
' Unfortunately, that currently leads to a "leaked type variable `a`" error.

def resnet_block (a:Type)
    : Layer (a=>Float) (a=>Float) (((DenseParams _ a) & (DenseParams a _) & Unit) & Unit) =
  trace \x.
    dense1 = callable $ dense a (Fin 25)
    y = dense1 x
    y = trace_map (map relu) y
    dense3 = callable $ dense _ a
    y = dense3 y
    y = trace_map2 add y x
    y

resnet = trace \x.
  dense1 = callable $ dense (Fin 2) (Fin 10)
  x = dense1 x
  x = trace_map (map relu) x
  block3 = callable $ resnet_block _
  x = block3 x
  x = trace_map (map relu) x
  dense7 = callable $ dense _ (Fin 2)
  x = dense7 x
  x

w_resnet = init resnet (newKey 1)
:t w_resnet
:p forward resnet w_resnet (for _. 0.)

' ## Training

' Train a multiclass classifier with minibatch SGD.
' `minibatch * minibatches = batch`

def split (x: batch=>v) : minibatches=>minibatch=>v =
  for b j. x.((ordinal (b, j))@batch)
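
' For example (illustrative), six elements split into two minibatches of three:

nums : Fin 6 => Int = for i. ordinal i
:p (split nums : Fin 2 => Fin 3 => Int)   -- should print [[0, 1, 2], [3, 4, 5]]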

def trainClass [VSpace p] (model: Layer a (b=>Float) p)
               (x: batch=>a)
               (y: batch=>b)
               (epochs : Type)
               (minibatch : Type)
               (minibatches : Type)
               : (epochs => p & epochs => Float) =
  xs : minibatches => minibatch => a = split x
  ys : minibatches => minibatch => b = split y
  unzip $ withState (init model $ newKey 1) $ \params.
    for _ : epochs.
      loss = sum $ for b : minibatches.
        (loss, gradfn) = vjp (\params.
          -sum for j.
            logits = forward model params xs.b.j
            (logsoftmax logits).(ys.b.j)) (get params)
        gparams = gradfn 1.0
        params := (get params) - scaleVec (0.05 / (IToF 100)) gparams
        loss
      (get params, loss)

' Sextant classification dataset: the label indicates the sign of
x1^3 - 3*x1*x2^2 = Re((x1 + i*x2)^3), which alternates across the six
sextants of the plane.

[k1, k2] = splitKey $ newKey 1
x1 : Fin 400 => Float = arb k1
x2 : Fin 400 => Float = arb k2
y = for i. case ((x1.i * x1.i * x1.i - 3. * x1.i * x2.i * x2.i) > 0.) of
  True -> 1
  False -> 0
xs = for i. [x1.i, x2.i]

import plot
:html showPlot $ xycPlot x1 x2 $ for i. IToF y.i

' Train classifier on this dataset.

-- model = mlp
model = resnet

(all_params, losses) = trainClass model xs (for i. (y.i @ (Fin 2))) (Fin 3000) (Fin 50) (Fin 8)

' Classification landscape, evolving over training time. Colour denotes the
softmax output.

span = linspace (Fin 15) (-1.0) (1.0)
tests = for h : (Fin 100). for i. for j.
  r = softmax $ forward model all_params.((ordinal h * 30)@_) [span.j, -span.i]
  [r.(1@_), 0.5*r.(1@_), r.(0@_)]

:html imseqshow tests
      
Review comment:

Might make sense to assert that minibatches * minibatch = batch, because this
function will happily throw away data.
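
One way to sketch that guard (hypothetical: this assumes a `size` function on
index-set types and an `assert`-style helper, which may not match the actual
prelude):

-- hypothetical check at the top of `trainClass`:
-- assert $ (size minibatches) * (size minibatch) == size batch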