(include "common-stalingrad")

;;; Representation for weights:
;;;  a list with one element for each layer following the input;
;;;  each such element is a list with one element for each unit in that
;;;  layer, and each unit is a list of a bias followed by the weights
;;;  for each unit in the previous layer.

;;; Basic MLP

;;; Net input of a single unit: BIAS plus the sum of the elementwise
;;; products of the unit's incoming weights WS with ACTIVITIES, the
;;; previous layer's outputs.
;;; NOTE(review): when mapped over a layer, each unit list (b w1 ... wn)
;;; is presumably destructured by VLAD's multi-argument lambda as
;;; (cons* bias ws) -- confirm against the VLAD lambda desugaring.
(define ((sum-activities activities) bias ws)
 (let ((products ((map2 *) ws activities)))
  ((reduce + bias) products)))

;;; Net input of every unit in one layer: applies the per-unit weighted
;;; sum (bound to the previous layer's ACTIVITIES) to each unit of WS-LAYER.
(define (sum-layer activities ws-layer)
 (let ((unit-net-input (sum-activities activities)))
  ((map unit-net-input) ws-layer)))

;;; Logistic squashing function 1 / (1 + e^-x).
(define (sigmoid x)
 (let ((e-neg-x (exp (- 0 x))))
  (/ 1 (+ 1 e-neg-x))))

;;; Run input activities IN through the network WS-LAYERS.
;;; Each layer's activities are the sigmoid of that layer's per-unit net
;;; inputs, and they feed the next layer. Returns the activities of the
;;; last layer, or IN itself when WS-LAYERS is empty.
(define ((forward-pass ws-layers) in)
 (if (null? ws-layers)
     ;; Base case: no layers left, so the current activities are the output.
     in
     ((forward-pass (cdr ws-layers))
      ((map sigmoid) (sum-layer in (first ws-layers))))))

;;; Total error of the network WS-LAYERS on DATASET.
;;; Each dataset item is a two-element list (in target). The item's
;;; error is half the squared magnitude of (output - target), and the
;;; item errors are summed starting from 0.
(define ((error-on-dataset dataset) ws-layers)
 ((reduce + 0)
  ((map (lambda ((list in target))
         (* 0.5
            (magnitude-squared (v- ((forward-pass ws-layers) in) target)))))
   ;; The per-item error function must actually be applied to the data.
   dataset)))

;;; Optimization of the sort used with MLPs and backpropagation,
;;; often called "vanilla backprop"

;;; Scaled structure subtraction

;;; Leafwise X - K*Y over two structures of the same shape.
;;; Real leaves are combined; pairs are walked recursively (car and cdr);
;;; any other leaf of X (e.g. '()) is returned unchanged.
(define (s-k* x k y)
 (if (real? x)
     (- x (* k y))
     (if (pair? x)
         (cons (s-k* (car x) k (car y))
               (s-k* (cdr x) k (cdr y)))
         x)))

;;; Vanilla gradient optimization.
;;; Minimize f by gradient descent, starting at w0, for n iterations via
;;; w(t+1) = w(t) - eta * grad_w f.
;;; Returns f(w) evaluated at the final weights.

;;; One-hot direction in weight space.
;;; Returns a structure with the same nested layer/unit/weight shape as
;;; WS, where the entry at layer L, unit U, position W (position 0 is the
;;; bias) is 1 and every other entry is 0.
;;; NOTE(review): assumes ((map-n f) n) from common-stalingrad builds the
;;; list (f 0) ... (f (- n 1)) -- confirm.
(define (weight-e ws l u w)
 ;; Outer map over layers; this wrapper is what binds LI.
 ((map-n
   (lambda (li)
    (let ((ll (list-ref ws li)))
     ((map-n
       (lambda (ui)
        ((map-n (lambda (wi) (if (and (= li l) (= ui u) (= wi w)) 1 0)))
         (length (list-ref ll ui)))))
      (length ll)))))
  (length ws)))

;;; Gradient of F at WS by forward-mode AD.
;;; Returns a structure shaped like WS. Each entry is the directional
;;; derivative of F at WS along the one-hot direction that weight-e
;;; builds for that (layer, unit, weight) position. This evaluates
;;; (j* f) once per weight.
(define ((weight-gradient f) ws)
 ((map-n
   (lambda (li)
    (let ((ll (list-ref ws li)))
     ((map-n
       (lambda (ui)
        ((map-n
          (lambda (wi)
           ;; (j* f) maps a primal/tangent bundle to a bundle. Take its
           ;; tangent and unperturb it to get a plain real derivative.
           (unperturb
            (tangent ((j* f) (bundle ws (perturb (weight-e ws li ui wi))))))))
         (length (list-ref ll ui)))))
      (length ll)))))
  (length ws)))

;;; Plain gradient descent on F from weights W0.
;;; Takes N steps of W <- W - ETA * grad F(W), then returns F at the
;;; resulting weights.
(define (vanilla f w0 n eta)
 (cond ((zero? n) (f w0))
       (else
        (let ((grad ((weight-gradient f) w0)))
         (vanilla f (s-k* w0 eta grad) (- n 1) eta)))))

;;; Let the compiler see the structure of an s-expression but not the
;;; numbers at its leaves.

;;; Copy X, passing every real leaf through `real` and leaving the pair
;;; structure and non-real leaves as they are.
(define (map-real x)
 (if (pair? x)
     (cons (map-real (car x)) (map-real (cdr x)))
     (if (real? x) (real x) x)))

;;; XOR network

;;; Initial weights for a 2-2-1 XOR network, with every leaf run through
;;; map-real. Each unit is (bias weight-from-input-1 weight-from-input-2).
(define (xor-ws0)
 (map-real
  '(;; hidden layer: two units
    ((0 -0.284227 1.16054)
     (0 0.617194 1.30467))
    ;; output layer: one unit
    ((0 -0.084395 0.648461)))))

;;; XOR truth table as a list of (input target) items.
(define (xor-data)
 '(((0 0) (0))                          ; 0 xor 0
   ((0 1) (1))                          ; 0 xor 1
   ((1 0) (1))                          ; 1 xor 0
   ((1 1) (0))))                        ; 1 xor 1

 ;; Train the XOR network with 1000000 gradient steps at eta = 0.3 and
 ;; print the final error. The iteration count goes through `real` like
 ;; the weights in xor-ws0. This restores the enclosing output form that
 ;; the trailing paren belonged to.
 (write-real
  (vanilla (error-on-dataset (xor-data)) (xor-ws0) (real 1000000) 0.3))

;;; Generated by GNU enscript 1.6.4.