# common-mzc.ss

```scheme
(module common-mzc mzscheme

(provide
 ;; Epsilon (perturbation tag) state.
 *e* set-e! <_e
 ;; Forward-mode dual numbers.
 dual-number? dual-number dual-number-epsilon dual-number-primal
 dual-number-perturbation
 ;; Reverse-mode tape nodes.
 tape? tape tape-epsilon tape-primal tape-factors tape-tapes tape-fanout
 set-tape-fanout! tape-sensitivity set-tape-sensitivity!
 ;; Lifting helpers.
 lift-real->real lift-real*real->real primal* lift-real^n->boolean dpair?
 ;; Overloaded arithmetic, elementary functions and predicates.
 d+ d- d* d/ dsqrt dexp dlog dsin dcos datan
 d= d< d> d<= d>= dzero? dpositive? dnegative? dreal?
 ;; Derivative and gradient operators.
 derivative-F replace-ith gradient-F determine-fanout! reverse-phase!
 gradient-R derivative-R
 ;; Output and list / vector utilities.
 write-real first second third fourth rest sqr map-n reduce map-reduce
 remove-if v+ v- k*v magnitude-squared my-magnitude distance-squared
 distance
 ;; Optimizers.
 multivariate-argmax-F multivariate-max-F multivariate-argmin-R
 multivariate-argmax-R multivariate-max-R)

;; *e* holds the most recently allocated perturbation tag ("epsilon").
;; derivative-F and gradient-R bump it by 1 before tagging their inputs and
;; decrement it again before returning, so a nested derivative runs under a
;; strictly larger tag than the one enclosing it.
(define *e* 0)

;; Overwrite the epsilon counter with E.
(define set-e!
  (lambda (e)
    (set! *e* e)))

;; Ordering on epsilons; tags are numbers, so this is plain numeric <.
(define <_e <)

;; True iff P is a forward-mode dual number, i.e. a list headed by the
;; symbol dual-number.  The primitive pair? is captured in a local binding
;; when this definition is evaluated.
(define dual-number?
  (let ((cons-cell? pair?))
    (lambda (p)
      (if (cons-cell? p)
          (eq? (car p) 'dual-number)
          #f))))

;; Construct a dual number tagged with epsilon E, with primal part X and
;; perturbation (tangent) X-PRIME.
;; Representation: (dual-number E X X-PRIME).
;; When X-PRIME is a plain zero the perturbation is dropped and X itself is
;; returned, so callers may get a non-dual value back.  Only a plain zero
;; collapses: dual numbers and tapes are both lists, and a perturbation that
;; is itself perturbed (e.g. (dual-number e1 0 1)) must be kept even when its
;; primal is 0, otherwise mixed higher-order derivatives are lost.
(define (dual-number e x x-prime)
  (if (and (not (pair? x-prime)) (zero? x-prime))
      x
      (list 'dual-number e x x-prime)))

;; Field accessors; positions follow the list built by dual-number above.
(define (dual-number-epsilon p) (cadr p))
(define (dual-number-primal p) (caddr p))
(define (dual-number-perturbation p) (cadddr p))

;; True iff P is a reverse-mode tape node, i.e. a list headed by the symbol
;; tape.  The primitive pair? is captured in a local binding when this
;; definition is evaluated.
(define tape?
  (let ((cons-cell? pair?))
    (lambda (p)
      (if (cons-cell? p)
          (eq? (car p) 'tape)
          #f))))

;; Construct a reverse-mode tape node tagged with epsilon E.  PRIMAL is the
;; node's value, TAPES are the parent nodes it was computed from and FACTORS
;; the corresponding local partial derivatives.  The fanout and sensitivity
;; slots start at 0.
;; Representation: (tape E PRIMAL FACTORS TAPES FANOUT SENSITIVITY).
(define (tape e primal factors tapes) (list 'tape e primal factors tapes 0 0))

;; Accessors for the first three fields, by list position.  The remaining
;; slots (tapes, fanout, sensitivity) have their own accessors below.
(define (tape-epsilon tape) (cadr tape))
(define (tape-primal tape) (caddr tape))
(define (tape-factors tape) (cadddr tape))

;; Accessors and in-place mutators for the bookkeeping slots of a tape node:
;; list position 4 holds the parent nodes, 5 the fanout count and 6 the
;; accumulated sensitivity.
(define (tape-tapes node) (list-ref node 4))

(define (tape-fanout node) (list-ref node 5))

(define (set-tape-fanout! node fanout)
  (set-car! (list-tail node 5) fanout))

(define (tape-sensitivity node) (list-ref node 6))

(define (set-tape-sensitivity! node sensitivity)
  (set-car! (list-tail node 6) sensitivity))

;; Lift a unary real function F, whose derivative is DF/DX, to dual numbers
;; and tapes.  A dual number yields a dual number in the same epsilon whose
;; perturbation is DF/DX(primal) times the input perturbation.  A tape yields
;; a new tape node whose single factor is DF/DX(primal) and whose parent is
;; the input.  Layers are peeled recursively until a plain value reaches F.
(define (lift-real->real f df/dx)
  (define (lifted p)
    (cond
     ((dual-number? p)
      (let ((x (dual-number-primal p)))
        (dual-number (dual-number-epsilon p)
                     (lifted x)
                     (d* (df/dx x) (dual-number-perturbation p)))))
     ((tape? p)
      (let ((x (tape-primal p)))
        (tape (tape-epsilon p)
              (lifted x)
              (list (df/dx x))
              (list p))))
     (else (f p))))
  lifted)

;; Lift a binary real function F to dual numbers and tapes.  DF/DX1 and
;; DF/DX2 are its partial derivatives with respect to the first and second
;; argument.  Both are called with two (possibly still perturbed) arguments
;; in the order (p1 p2).
;; Dispatch on the epsilons (perturbation tags) of the two arguments:
;;  - if only one argument is a dual number / tape, the result is built in
;;    that argument's epsilon and the other argument is treated as constant;
;;  - if both are, the one whose epsilon is larger under <_e becomes the
;;    outer layer of the result, and the other argument is passed through
;;    unchanged to the recursive call;
;;  - if both carry the same epsilon, both partials contribute: the
;;    perturbations are summed for dual numbers, and two factors / two
;;    parents are recorded for tapes.  (When a dual number meets a tape,
;;    ties are resolved in favour of p1.)
;;  - plain values are handed to F.
;; Recursion goes through SELF, so nested layers are peeled one at a time.
(define (lift-real*real->real f df/dx1 df/dx2)
(letrec ((self
(lambda (p1 p2)
(cond
;; Case: p1 is a dual number.
((dual-number? p1)
(cond
((dual-number? p2)
(cond ((<_e (dual-number-epsilon p1)
(dual-number-epsilon p2))
(dual-number (dual-number-epsilon p2)
(self p1 (dual-number-primal p2))
(d* (df/dx2 p1 (dual-number-primal p2))
(dual-number-perturbation p2))))
((<_e (dual-number-epsilon p2)
(dual-number-epsilon p1))
(dual-number (dual-number-epsilon p1)
(self (dual-number-primal p1) p2)
(d* (df/dx1 (dual-number-primal p1) p2)
(dual-number-perturbation p1))))
;; Same epsilon: both perturbations contribute to the tangent.
(else
(dual-number (dual-number-epsilon p1)
(self (dual-number-primal p1)
(dual-number-primal p2))
(d+ (d* (df/dx1 (dual-number-primal p1)
(dual-number-primal p2))
(dual-number-perturbation p1))
(d* (df/dx2 (dual-number-primal p1)
(dual-number-primal p2))
(dual-number-perturbation p2)))))))
((tape? p2)
(if (<_e (dual-number-epsilon p1) (tape-epsilon p2))
(tape (tape-epsilon p2)
(self p1 (tape-primal p2))
(list (df/dx2 p1 (tape-primal p2)))
(list p2))
(dual-number (dual-number-epsilon p1)
(self (dual-number-primal p1) p2)
(d* (df/dx1 (dual-number-primal p1) p2)
(dual-number-perturbation p1)))))
;; p2 is a plain value: only p1's perturbation propagates.
(else (dual-number (dual-number-epsilon p1)
(self (dual-number-primal p1) p2)
(d* (df/dx1 (dual-number-primal p1) p2)
(dual-number-perturbation p1))))))
;; Case: p1 is a tape node.
((tape? p1)
(cond
((dual-number? p2)
(if (<_e (tape-epsilon p1) (dual-number-epsilon p2))
(dual-number (dual-number-epsilon p2)
(self p1 (dual-number-primal p2))
(d* (df/dx2 p1 (dual-number-primal p2))
(dual-number-perturbation p2)))
(tape (tape-epsilon p1)
(self (tape-primal p1) p2)
(list (df/dx1 (tape-primal p1) p2))
(list p1))))
((tape? p2)
(cond
((<_e (tape-epsilon p1) (tape-epsilon p2))
(tape (tape-epsilon p2)
(self p1 (tape-primal p2))
(list (df/dx2 p1 (tape-primal p2)))
(list p2)))
((<_e (tape-epsilon p2) (tape-epsilon p1))
(tape (tape-epsilon p1)
(self (tape-primal p1) p2)
(list (df/dx1 (tape-primal p1) p2))
(list p1)))
;; Same epsilon: record both partials and both parent nodes.
(else (tape (tape-epsilon p1)
(self (tape-primal p1) (tape-primal p2))
(list (df/dx1 (tape-primal p1) (tape-primal p2))
(df/dx2 (tape-primal p1) (tape-primal p2)))
(list p1 p2)))))
;; p2 is a plain value: p1 is the only parent.
(else (tape (tape-epsilon p1)
(self (tape-primal p1) p2)
(list (df/dx1 (tape-primal p1) p2))
(list p1)))))
;; Case: p1 is a plain value; only p2 can carry a perturbation.
(else (cond ((dual-number? p2)
(dual-number (dual-number-epsilon p2)
(self p1 (dual-number-primal p2))
(d* (df/dx2 p1 (dual-number-primal p2))
(dual-number-perturbation p2))))
((tape? p2)
(tape (tape-epsilon p2)
(self p1 (tape-primal p2))
(list (df/dx2 p1 (tape-primal p2)))
(list p2)))
(else (f p1 p2))))))))
self))

;; Strip every dual-number and tape layer from P and return the underlying
;; plain value.
(define (primal* p)
  (let peel ((v p))
    (cond ((dual-number? v) (peel (dual-number-primal v)))
          ((tape? v) (peel (tape-primal v)))
          (else v))))

;; Lift an n-ary real predicate F: every argument is reduced to its plain
;; primal value before F is applied, so perturbations never affect the
;; result.
(define (lift-real^n->boolean f)
  (lambda args
    (apply f (map primal* args))))

;; True iff X is a pair that is neither a dual number nor a tape, i.e. an
;; ordinary data pair rather than one of this module's numeric encodings.
;; The primitive pair? is captured in a local binding at definition time.
(define dpair?
  (let ((cons-cell? pair?))
    (lambda (x)
      (cond ((not (cons-cell? x)) #f)
            ((dual-number? x) #f)
            ((tape? x) #f)
            (else #t)))))

;; Overloaded arithmetic.  Each operator lifts the plain operator together
;; with its two partial derivatives.  The partials are written with the
;; overloaded operators so they can themselves carry perturbations.

;; d(a+b)/da = 1, d(a+b)/db = 1.
(define d+
  (lift-real*real->real + (lambda (a b) 1) (lambda (a b) 1)))

;; d(a-b)/da = 1, d(a-b)/db = -1.
(define d-
  (lift-real*real->real - (lambda (a b) 1) (lambda (a b) -1)))

;; d(a*b)/da = b, d(a*b)/db = a.
(define d*
  (lift-real*real->real * (lambda (a b) b) (lambda (a b) a)))

;; d(a/b)/da = 1/b, d(a/b)/db = -a/b^2.
(define d/
  (lift-real*real->real
   /
   (lambda (a b) (d/ 1 b))
   (lambda (a b) (d- 0 (d/ a (d* b b))))))

;; Overloaded elementary functions.  Each derivative is expressed with the
;; overloaded operators.  The self / mutual references (dsqrt, dexp, dsin,
;; dcos) sit inside lambdas so they are only looked up when called.
(define dsqrt
  (lift-real->real sqrt
                   (lambda (u) (d/ 1 (d* 2 (dsqrt u))))))  ; 1 / (2 sqrt u)

(define dexp
  (lift-real->real exp
                   (lambda (u) (dexp u))))                 ; exp u

(define dlog
  (lift-real->real log
                   (lambda (u) (d/ 1 u))))                 ; 1 / u

(define dsin
  (lift-real->real sin
                   (lambda (u) (dcos u))))                 ; cos u

(define dcos
  (lift-real->real cos
                   (lambda (u) (d- 0 (dsin u)))))          ; -sin u

;; Two-argument arctangent: (datan y x) lifts the standard (atan y x), the
;; angle of the point (x, y).  Partial derivatives:
;;   d/dy atan(y, x) =  x / (x^2 + y^2)
;;   d/dx atan(y, x) = -y / (x^2 + y^2)
;; lift-real*real->real pairs the first partial with the first argument (y).
;; The previous definition had both signs inverted.
(define datan
  (lift-real*real->real
   atan
   (lambda (y x) (d/ x (d+ (d* y y) (d* x x))))
   (lambda (y x) (d/ (d- 0 y) (d+ (d* y y) (d* x x))))))

;; Overloaded comparisons and numeric predicates.  Each strips every
;; dual-number / tape layer from its arguments (via primal*) and applies the
;; plain predicate, so the answer depends only on primal values.
(define (d= . xs) (apply = (map primal* xs)))

(define (d< . xs) (apply < (map primal* xs)))

(define (d> . xs) (apply > (map primal* xs)))

(define (d<= . xs) (apply <= (map primal* xs)))

(define (d>= . xs) (apply >= (map primal* xs)))

(define (dzero? . xs) (apply zero? (map primal* xs)))

(define (dpositive? . xs) (apply positive? (map primal* xs)))

(define (dnegative? . xs) (apply negative? (map primal* xs)))

(define (dreal? . xs) (apply real? (map primal* xs)))

;; Forward-mode derivative of a unary function F.  Each call allocates a
;; fresh epsilon by bumping *e* and evaluates F on the dual number
;; x + 1*epsilon.  It returns the result's perturbation in that epsilon, or 0
;; when the result is not a dual number in this (or a newer) epsilon.  *e* is
;; decremented again before returning.
(define (derivative-F f)
  (lambda (x)
    (set! *e* (d+ *e* 1))
    (let* ((out (f (dual-number *e* x 1)))
           (tangent (if (and (dual-number? out)
                             (not (<_e (dual-number-epsilon out) *e*)))
                        (dual-number-perturbation out)
                        0)))
      (set! *e* (d- *e* 1))
      tangent)))

;; Return a copy of list X whose element at (0-based) index I is replaced by
;; XI.  Only the prefix up to index I is copied; the tail after it is shared
;; with X.  Counting uses the overloaded dzero? / d-.
(define (replace-ith x i xi)
  (let walk ((items x) (k i))
    (if (dzero? k)
        (cons xi (cdr items))
        (cons (car items) (walk (cdr items) (d- k 1))))))

;; Forward-mode gradient.  Returns a procedure that, given a list X, computes
;; the gradient of F at X.  For each index i it holds the other coordinates
;; fixed and takes the forward-mode derivative (derivative-F) with respect to
;; coordinate i, so F is evaluated once per element of X.
;; (The define header for this exported procedure was missing, leaving F
;; unbound and the parentheses unbalanced.)
(define (gradient-F f)
  (lambda (x)
    ((map-n
      (lambda (i)
        ((derivative-F (lambda (xi) (f (replace-ith x i xi))))
         (list-ref x i))))
     (length x))))

;; First reverse-mode pass: increment the fanout count of the tape node, and
;; on the first visit (count becomes 1) recurse into its parent nodes.  After
;; the pass, each node's fanout is the number of references to it from nodes
;; reachable from the starting node.
(define (determine-fanout! node)
  (let ((count (d+ (tape-fanout node) 1)))
    (set-tape-fanout! node count)
    (when (d= count 1)
      (for-each determine-fanout! (tape-tapes node)))))

;; Second reverse-mode pass: add SENSITIVITY to the node's accumulated
;; sensitivity and decrement its fanout count.  Once the count reaches zero
;; (all references have reported in), push the total sensitivity to each
;; parent, scaled by the corresponding local factor.
(define (reverse-phase! sensitivity node)
  (set-tape-sensitivity! node (d+ (tape-sensitivity node) sensitivity))
  (set-tape-fanout! node (d- (tape-fanout node) 1))
  (when (dzero? (tape-fanout node))
    (let ((total (tape-sensitivity node)))
      (for-each (lambda (factor parent)
                  (reverse-phase! (d* total factor) parent))
                (tape-factors node)
                (tape-tapes node)))))

;; Reverse-mode gradient.  Returns a procedure that, given a list X:
;;  1. allocates a fresh epsilon by bumping *e*;
;;  2. wraps each element in a new tape node and evaluates F once;
;;  3. if the result is a tape in this epsilon, counts fanout and
;;     back-propagates a sensitivity of 1 from it;
;;  4. decrements *e* and returns the sensitivities accumulated on the input
;;     nodes.
;; Inputs the result does not depend on keep their initial sensitivity of 0.
;; (The define header for this exported procedure was missing, leaving F
;; unbound and the parentheses unbalanced.)
(define (gradient-R f)
  (lambda (x)
    (set! *e* (d+ *e* 1))
    (let* ((x (map (lambda (xi) (tape *e* xi '() '())) x))
           (y (f x)))
      (cond ((and (tape? y) (not (<_e (tape-epsilon y) *e*)))
             (determine-fanout! y)
             (reverse-phase! 1 y)))
      (set! *e* (d- *e* 1))
      (map tape-sensitivity x))))

;; Reverse-mode derivative of a unary function F: view F as a function of a
;; one-element list, take its reverse-mode gradient, and return the single
;; component.
(define (derivative-R f)
  (let ((on-singleton (lambda (xs) (f (car xs)))))
    (lambda (x)
      (car ((gradient-R on-singleton) (list x))))))

;; Write the plain value underlying X (all dual-number / tape layers
;; stripped) followed by a newline, and return X itself unchanged.
(define (write-real x)
  (write (primal* x))
  (newline)
  x)

;; Positional list accessors: first through fourth element, and the list
;; without its first element.
(define (first x) (car x))

(define (second x) (cadr x))

(define (third x) (caddr x))

(define (fourth x) (cadddr x))

(define (rest x) (cdr x))

;; Square of X using the overloaded multiplication.
(define sqr
  (lambda (x) (d* x x)))

;; (map-n f) returns a procedure that, given N, builds the list
;; ((f 0) (f 1) ... (f N-1)).  The counter uses the overloaded d= / d+.
(define (map-n f)
  (lambda (n)
    (let build ((i 0))
      (if (d= i n)
          '()
          (cons (f i) (build (d+ i 1)))))))

;; (reduce f i) returns a procedure that right-folds F over a list:
;; ((reduce f i) '(a b c)) = (f a (f b (f c i))).
(define (reduce f i)
  (letrec ((walk (lambda (l)
                   (if (null? l)
                       i
                       (f (car l) (walk (cdr l)))))))
    walk))

;; Right fold of G over the images of L's elements under F, with initial
;; value I: (map-reduce g i f '(a b)) = (g (f a) (g (f b) i)).
(define (map-reduce g i f l)
  (let walk ((items l))
    (if (null? items)
        i
        (g (f (car items)) (walk (cdr items))))))

;; Return a new list of the elements of L for which predicate P is false,
;; in their original order.  P is applied to the elements left to right.
(define (remove-if p l)
  (let loop ((items l) (kept '()))
    (cond ((null? items) (reverse kept))
          ((p (car items)) (loop (cdr items) kept))
          (else (loop (cdr items) (cons (car items) kept))))))

;; Vector operations on lists of (possibly perturbed) reals, built on the
;; overloaded arithmetic.
(define (v+ u v) (map d+ u v))

(define (v- u v) (map d- u v))

;; Scale every component of V by K.
(define (k*v k v) (map (lambda (vi) (d* k vi)) v))

;; Sum of the squared components of X, folded from 0.0.
(define (magnitude-squared x) (map-reduce d+ 0.0 sqr x))

(define (my-magnitude x) (dsqrt (magnitude-squared x)))

;; Squared Euclidean distance, computed as |V - U|^2.
(define (distance-squared u v) (magnitude-squared (v- v u)))

(define (distance u v) (dsqrt (distance-squared u v)))

;; Run N steps of fixed-step gradient ascent on F from X0 with step size ETA,
;; using the forward-mode gradient: x <- x + eta * grad f(x).
;; Returns (list x-final (f x-final) (gradient-at x-final)).
;; (The recursive call head was missing, leaving the if with too many arms
;; and the parentheses unbalanced.)
(define (gradient-ascent-F f x0 n eta)
  (if (dzero? n)
      (list x0 (f x0) ((gradient-F f) x0))
      (gradient-ascent-F
       f
       (map (lambda (xi gi) (d+ xi (d* eta gi))) x0 ((gradient-F f) x0))
       (d- n 1)
       eta)))

;; Run N steps of fixed-step gradient ascent on F from X0 with step size ETA,
;; using the reverse-mode gradient: x <- x + eta * grad f(x).
;; Returns (list x-final (f x-final) (gradient-at x-final)).
;; (The recursive call head was missing, leaving the if with too many arms
;; and the parentheses unbalanced.)
(define (gradient-ascent-R f x0 n eta)
  (if (dzero? n)
      (list x0 (f x0) ((gradient-R f) x0))
      (gradient-ascent-R
       f
       (map (lambda (xi gi) (d+ xi (d* eta gi))) x0 ((gradient-R f) x0))
       (d- n 1)
       eta)))

;; Minimise F from starting point X by gradient descent with the
;; forward-mode gradient and an adaptive step size:
;;  - the step size starts at 1e-5;
;;  - it is halved (and the success counter reset) whenever a step fails to
;;    decrease F;
;;  - it is doubled after 10 consecutive successful steps.
;; Stops and returns the current point when the gradient norm, or the length
;; of the proposed step, is at most 1e-5.
(define (multivariate-argmin-F f x)
  (let ((grad (gradient-F f)))
    (let descend ((x x) (fx (f x)) (gx (grad x)) (eta 1e-5) (i 0))
      (cond
       ((d<= (my-magnitude gx) 1e-5) x)
       ((d= i 10) (descend x fx gx (d* 2.0 eta) 0))
       (else
        (let ((x-prime (v- x (k*v eta gx))))
          (if (d<= (distance x x-prime) 1e-5)
              x
              (let ((fx-prime (f x-prime)))
                (if (d< fx-prime fx)
                    (descend x-prime fx-prime (grad x-prime) eta (d+ i 1))
                    (descend x fx gx (d/ eta 2.0) 0))))))))))

;; Maximise F from X (forward mode) by minimising its negation 0.0 - F.
(define (multivariate-argmax-F f x)
  (let ((negated-f (lambda (y) (d- 0.0 (f y)))))
    (multivariate-argmin-F negated-f x)))

;; Value of F at the point returned by multivariate-argmax-F.
(define (multivariate-max-F f x)
  (let ((x-star (multivariate-argmax-F f x)))
    (f x-star)))

;; Minimise F from starting point X by gradient descent with the
;; reverse-mode gradient and an adaptive step size:
;;  - the step size starts at 1e-5;
;;  - it is halved (and the success counter reset) whenever a step fails to
;;    decrease F;
;;  - it is doubled after 10 consecutive successful steps.
;; Stops and returns the current point when the gradient norm, or the length
;; of the proposed step, is at most 1e-5.
(define (multivariate-argmin-R f x)
  (let ((grad (gradient-R f)))
    (let descend ((x x) (fx (f x)) (gx (grad x)) (eta 1e-5) (i 0))
      (cond
       ((d<= (my-magnitude gx) 1e-5) x)
       ((d= i 10) (descend x fx gx (d* 2.0 eta) 0))
       (else
        (let ((x-prime (v- x (k*v eta gx))))
          (if (d<= (distance x x-prime) 1e-5)
              x
              (let ((fx-prime (f x-prime)))
                (if (d< fx-prime fx)
                    (descend x-prime fx-prime (grad x-prime) eta (d+ i 1))
                    (descend x fx gx (d/ eta 2.0) 0))))))))))

;; Maximise F from X (reverse mode) by minimising its negation 0.0 - F.
(define (multivariate-argmax-R f x)
  (let ((negated-f (lambda (y) (d- 0.0 (f y)))))
    (multivariate-argmin-R negated-f x)))

;; Value of F at the point returned by multivariate-argmax-R.
;; The final closing parenthesis below ends the (module common-mzc ...) form.
(define (multivariate-max-R f x)
  (let ((x-star (multivariate-argmax-R f x)))
    (f x-star))))
```

Generated by GNU enscript 1.6.4.