(include "common-stalingrad")

;;; Representation for weights:
;;;  list with one element for each layer following the input;
;;;  each such list has one element for each unit in that layer;
;;;  each unit's element consists of a bias, followed by the weights
;;;  for each unit in the previous layer.
;;;
;;; NOTE(review): reduce, map, map2, map-n, first, v-, magnitude-squared
;;; and write-real are not defined in this file; presumably they come
;;; from "common-stalingrad" -- confirm.

;;; Basic MLP

;;; Curried.  Given the ACTIVITIES of the previous layer, yields a
;;; procedure of BIAS and WS that folds + over the pairwise products of
;;; WS and ACTIVITIES, starting the fold from BIAS.
(define ((sum-activities activities) bias ws)
  ((reduce + bias) ((map2 *) ws activities)))

;;; Applies (sum-activities ACTIVITIES) to every unit entry of WS-LAYER,
;;; giving one summed input per unit.  Each entry is handed to the
;;; two-parameter procedure as is.
;;; NOTE(review): this relies on the entry (bias followed by weights)
;;; destructuring into BIAS and WS -- confirm against the dialect's
;;; argument conventions.
(define (sum-layer activities ws-layer)
  ((map (sum-activities activities)) ws-layer))

;;; Logistic function 1 / (exp(-x) + 1).  The negation is written as
;;; (- 0 x).
(define (sigmoid x) (/ 1 (+ (exp (- 0 x)) 1)))

;;; Curried.  Given WS-LAYERS, yields a procedure that pushes the
;;; activity list IN through the layers front to back: each step maps
;;; sigmoid over the summed inputs of the first remaining layer and
;;; recurs on the rest.  With no layers left, IN is returned unchanged.
(define ((forward-pass ws-layers) in)
  (if (null? ws-layers)
      in
      ((forward-pass (cdr ws-layers))
       ((map sigmoid) (sum-layer in (first ws-layers))))))

;;; Curried.  Given DATASET, a list of (in target) entries, yields a
;;; procedure of WS-LAYERS returning the sum over all entries of
;;; 0.5 * magnitude-squared(forward-pass output - target).
(define ((error-on-dataset dataset) ws-layers)
  ((reduce + 0)
   ((map (lambda ((list in target))
           (* 0.5
              (magnitude-squared
               (v- ((forward-pass ws-layers) in) target)))))
    dataset)))

;;; Optimization of the sort used with MLPs and backpropagation,
;;; often called "vanilla backprop"

;;; Scaled structure subtraction: computes x - k*y leaf by leaf.
;;; Real leaves are combined arithmetically, pairs are walked through
;;; car and cdr in parallel with Y, and anything else (for example the
;;; empty list) is returned as found in X.
(define (s-k* x k y)
  (cond ((real? x) (- x (* k y)))
        ((pair? x) (cons (s-k* (car x) k (car y))
                         (s-k* (cdr x) k (cdr y))))
        (else x)))

;;; Builds a layer/unit/weight nesting with the same lengths as WS that
;;; holds 1 at position (L, U, W) and 0 everywhere else.  The innermost
;;; index runs over the whole unit entry, so the bias slot is indexed
;;; like any weight.
;;; NOTE(review): the indices are whatever map-n hands to its procedure,
;;; presumably 0 .. n-1 -- confirm against common-stalingrad.
(define (weight-e ws l u w)
  ((map-n
    (lambda (li)
      (let ((ll (list-ref ws li)))
        ((map-n
          (lambda (ui)
            ((map-n
              (lambda (wi)
                (if (and (= li l) (= ui u) (= wi w)) 1 0)))
             (length (list-ref ll ui)))))
         (length ll)))))
   (length ws)))

;;; Curried.  Given F, yields a procedure of WS returning a structure
;;; with the same layer/unit/weight nesting as WS.  For each position
;;; (li, ui, wi) it bundles WS with the perturbed one-hot structure from
;;; weight-e, applies (j* f) to the bundle, and keeps the unperturbed
;;; tangent of the result.
;;; NOTE(review): this looks like one forward-mode directional
;;; derivative per weight, i.e. F is re-evaluated for every weight --
;;; confirm that this is the intended formulation.
(define ((weight-gradient f) ws)
  ((map-n
    (lambda (li)
      (let ((ll (list-ref ws li)))
        ((map-n
          (lambda (ui)
            ((map-n
              (lambda (wi)
                (unperturb
                 (tangent
                  ((j* f)
                   (bundle ws (perturb (weight-e ws li ui wi))))))))
             (length (list-ref ll ui)))))
         (length ll)))))
   (length ws)))

;;; Vanilla gradient optimization.
;;; Gradient minimize f starting at w0 for n iterations via
;;; w(t+1) = w(t) - eta * grad_w f.
;;; returns the last f(w)
(define (vanilla f w0 n eta)
  (if (zero? n)
      (f w0)
      (vanilla f (s-k* w0 eta ((weight-gradient f) w0)) (- n 1) eta)))

;;; Allow compiler to grok structure of sexpr but not the numbers at
;;; the leaves: every real leaf is passed through `real', pairs are
;;; rebuilt around the converted parts, and other leaves are kept.
(define (map-real x)
  (cond ((real? x) (real x))
        ((pair? x) (cons (map-real (car x)) (map-real (cdr x))))
        (else x)))

;;; XOR network

;;; Initial weights: a first layer of two units and a second layer of
;;; one unit, each unit written as a bias followed by two weights.
(define (xor-ws0)
  (map-real '(((0 -0.284227 1.16054) (0 0.617194 1.30467))
              ((0 -0.084395 0.648461)))))

;;; Training set: the four (input target) entries of XOR.
(define (xor-data)
  '(((0 0) (0)) ((0 1) (1)) ((1 0) (1)) ((1 1) (0))))

;;; Run 1000000 steps with eta = 0.3 and print the final error.
;;; NOTE(review): the step count is wrapped in `real', presumably to
;;; keep it opaque to the compiler in the same way map-real does for
;;; the weights -- confirm.
(write-real
 (vanilla (error-on-dataset (xor-data)) (xor-ws0) (real 1000000) 0.3))