;;; common-chicken.sc

;;; File-wide compiler declarations.
;;; NOTE(review): the exact meaning of each specifier is compiler-specific --
;;; confirm against the compiler manual.  `(not safe)` presumably disables
;;; run-time checks, so the list-based accessors in this file would then rely
;;; on always being handed well-formed records.
(declare (block)
         (standard-bindings)
         (extended-bindings)
         (not safe)
         (not interrupts-enabled))

;;; Global tag counter.  derivative-F and gradient-R add 1 to it before
;;; calling the user function and subtract 1 again before returning; the
;;; value in effect during the call is stored in the dual numbers / tapes
;;; they create.
(define *e* 0)

;;; Ordering used to compare the tags stored in dual numbers and tapes.
(define <_e <)

;;; Predicate: #t when the argument is a pair whose car is the symbol
;;; `dual-number`, #f otherwise.
(define dual-number?
 (let ((pair? pair?))
  (lambda (candidate)
   (if (pair? candidate)
       (eq? 'dual-number (car candidate))
       #f))))

;;; Build a dual number tagged `e` with primal `x` and perturbation `x-prime`.
;;;
;;; When the perturbation is a plain numeric zero the primal is returned
;;; unwrapped.  The test deliberately looks only at plain numbers: a
;;; perturbation that is itself a dual number or a tape (both are lists) may
;;; have a zero primal and still carry non-zero higher-order information, so
;;; it must not be discarded.  (Testing with a predicate that strips nested
;;; tags first would, for instance, make the second derivative of x*x at 0
;;; come out as 0 instead of 2.)
(define (dual-number e x x-prime)
 (if (and (number? x-prime) (zero? x-prime))
     x
     (list 'dual-number e x x-prime)))

;;; Field accessors for the list ('dual-number e x x-prime).

;;; Tag (second element).
(define (dual-number-epsilon p) (cadr p))

;;; Primal value (third element).
(define (dual-number-primal p) (caddr p))

;;; Perturbation (fourth element).
(define (dual-number-perturbation p) (cadddr p))

;;; Predicate: #t when the argument is a pair whose car is the symbol `tape`,
;;; #f otherwise.
(define tape?
 (let ((pair? pair?))
  (lambda (candidate)
   (if (pair? candidate)
       (eq? 'tape (car candidate))
       #f))))

;;; Build a fresh tape record
;;;   ('tape e primal factors tapes fanout sensitivity)
;;; with fanout and sensitivity both starting at 0.  The record is a newly
;;; allocated list because the last two fields are mutated in place later.
(define (tape e primal factors tapes)
 (cons 'tape
       (cons e
             (cons primal
                   (list factors tapes 0 0)))))

;;; Accessors for the first three data fields of a tape record.

;;; Tag (second element).
(define (tape-epsilon t) (cadr t))

;;; Primal value (third element).
(define (tape-primal t) (caddr t))

;;; List of factors (fourth element).
(define (tape-factors t) (cadddr t))

;;; Fifth element of a tape record: the list of input tapes.
(define (tape-tapes tape) (list-ref tape 4))

;;; Sixth element of a tape record: the fanout counter.
(define (tape-fanout tape) (list-ref tape 5))

;;; Overwrite, in place, the sixth element (fanout counter) of a tape record.
(define (set-tape-fanout! tape fanout)
 (let ((cell (list-tail tape 5)))
  (set-car! cell fanout)))

;;; Seventh element of a tape record: the accumulated sensitivity.
(define (tape-sensitivity tape) (list-ref tape 6))

;;; Overwrite, in place, the seventh element (sensitivity) of a tape record.
(define (set-tape-sensitivity! tape sensitivity)
 (let ((cell (list-tail tape 6)))
  (set-car! cell sensitivity)))

;;; Lift a unary numeric procedure `f`, with derivative `df/dx`, to a
;;; procedure that also accepts dual numbers and tapes.
;;;   - dual number: result is a dual number with the same tag, the lifted
;;;     procedure applied to the primal, and df/dx(primal) * perturbation.
;;;   - tape: result is a new tape with the same tag, one factor
;;;     df/dx(primal) and the argument itself as the single input tape.
;;;   - anything else: f is applied directly.
(define (lift-real->real f df/dx)
 (define (lifted v)
  (cond ((dual-number? v)
         (let ((x (dual-number-primal v)))
          (dual-number (dual-number-epsilon v)
                       (lifted x)
                       (d* (df/dx x) (dual-number-perturbation v)))))
        ((tape? v)
         (let ((x (tape-primal v)))
          (tape (tape-epsilon v)
                (lifted x)
                (list (df/dx x))
                (list v))))
        (else (f v))))
 lifted)

;;; Lift a binary numeric procedure `f` to dual numbers and tapes.
;;; `df/dx1` and `df/dx2` take the same two arguments as `f` and return the
;;; partial derivative with respect to the first / second argument.
;;;
;;; The returned procedure dispatches on what each argument is (dual number,
;;; tape, or anything else).  When both arguments are tagged, their tags are
;;; compared with <_e:
;;;   - if one tag is smaller, only the argument with the larger tag is
;;;     unwrapped; the other is passed through whole to the recursive call
;;;     and to the partial derivative;
;;;   - dual/dual or tape/tape with neither tag smaller: both are unwrapped
;;;     and both partials contribute;
;;;   - dual/tape (either order) with the second tag not larger: the first
;;;     argument is the one unwrapped.
;;; When neither argument is tagged, `f` is applied directly.
(define (lift-real*real->real f df/dx1 df/dx2)
 (letrec ((self
	   (lambda (p1 p2)
	    (cond
	     ;; --- first argument is a dual number ---
	     ((dual-number? p1)
	      (cond
	       ((dual-number? p2)
		(cond ((<_e (dual-number-epsilon p1)
			    (dual-number-epsilon p2))
		       (dual-number (dual-number-epsilon p2)
				    (self p1 (dual-number-primal p2))
				    (d* (df/dx2 p1 (dual-number-primal p2))
					(dual-number-perturbation p2))))
		      ((<_e (dual-number-epsilon p2)
			    (dual-number-epsilon p1))
		       (dual-number (dual-number-epsilon p1)
				    (self (dual-number-primal p1) p2)
				    (d* (df/dx1 (dual-number-primal p1) p2)
					(dual-number-perturbation p1))))
		      ;; neither tag is smaller: sum of both partial terms
		      (else
		       (dual-number (dual-number-epsilon p1)
				    (self (dual-number-primal p1)
					  (dual-number-primal p2))
				    (d+ (d* (df/dx1 (dual-number-primal p1)
						    (dual-number-primal p2))
					    (dual-number-perturbation p1))
					(d* (df/dx2 (dual-number-primal p1)
						    (dual-number-primal p2))
					    (dual-number-perturbation p2)))))))
	       ((tape? p2)
		(if (<_e (dual-number-epsilon p1) (tape-epsilon p2))
		    (tape (tape-epsilon p2)
			  (self p1 (tape-primal p2))
			  (list (df/dx2 p1 (tape-primal p2)))
			  (list p2))
		    (dual-number (dual-number-epsilon p1)
				 (self (dual-number-primal p1) p2)
				 (d* (df/dx1 (dual-number-primal p1) p2)
				     (dual-number-perturbation p1)))))
	       (else (dual-number (dual-number-epsilon p1)
				  (self (dual-number-primal p1) p2)
				  (d* (df/dx1 (dual-number-primal p1) p2)
				      (dual-number-perturbation p1))))))
	     ;; --- first argument is a tape ---
	     ((tape? p1)
	      (cond
	       ((dual-number? p2)
		(if (<_e (tape-epsilon p1) (dual-number-epsilon p2))
		    (dual-number (dual-number-epsilon p2)
				 (self p1 (dual-number-primal p2))
				 (d* (df/dx2 p1 (dual-number-primal p2))
				     (dual-number-perturbation p2)))
		    (tape (tape-epsilon p1)
			  (self (tape-primal p1) p2)
			  (list (df/dx1 (tape-primal p1) p2))
			  (list p1))))
	       ((tape? p2)
		(cond
		 ((<_e (tape-epsilon p1) (tape-epsilon p2))
		  (tape (tape-epsilon p2)
			(self p1 (tape-primal p2))
			(list (df/dx2 p1 (tape-primal p2)))
			(list p2)))
		 ((<_e (tape-epsilon p2) (tape-epsilon p1))
		  (tape (tape-epsilon p1)
			(self (tape-primal p1) p2)
			(list (df/dx1 (tape-primal p1) p2))
			(list p1)))
		 ;; neither tag is smaller: record both factors and both inputs
		 (else (tape (tape-epsilon p1)
			     (self (tape-primal p1) (tape-primal p2))
			     (list (df/dx1 (tape-primal p1) (tape-primal p2))
				   (df/dx2 (tape-primal p1) (tape-primal p2)))
			     (list p1 p2)))))
	       (else (tape (tape-epsilon p1)
			   (self (tape-primal p1) p2)
			   (list (df/dx1 (tape-primal p1) p2))
			   (list p1)))))
	     ;; --- first argument is untagged ---
	     (else (cond ((dual-number? p2)
			  (dual-number (dual-number-epsilon p2)
				       (self p1 (dual-number-primal p2))
				       (d* (df/dx2 p1 (dual-number-primal p2))
					   (dual-number-perturbation p2))))
			 ((tape? p2)
			  (tape (tape-epsilon p2)
				(self p1 (tape-primal p2))
				(list (df/dx2 p1 (tape-primal p2)))
				(list p2)))
			 (else (f p1 p2))))))))
  self))

;;; Strip every layer of dual-number / tape wrapping and return the innermost
;;; primal value.  Untagged values are returned unchanged.
(define (primal* p)
 (let strip ((v p))
  (cond ((dual-number? v) (strip (dual-number-primal v)))
        ((tape? v) (strip (tape-primal v)))
        (else v))))

;;; Lift an n-ary procedure `f` so that it is applied to the fully stripped
;;; primal values (see primal*) of its arguments.
(define (lift-real^n->boolean f)
 (lambda args
  (let ((primals (map primal* args)))
   (apply f primals))))

;;; Pair predicate that excludes the list-encoded dual numbers and tapes:
;;; #t only for pairs that are neither.
(define dpair?
 (let ((pair? pair?))
  (lambda (x)
   (cond ((not (pair? x)) #f)
         ((dual-number? x) #f)
         (else (not (tape? x)))))))

;;; Lifted addition: both partial derivatives are 1.
(define d+ (lift-real*real->real + (lambda (x1 x2) 1) (lambda (x1 x2) 1)))

;;; Lifted subtraction: partials are 1 and -1.
(define d- (lift-real*real->real - (lambda (x1 x2) 1) (lambda (x1 x2) -1)))

;;; Lifted multiplication: partials are x2 and x1.
(define d*
 (lift-real*real->real * (lambda (x1 x2) x2) (lambda (x1 x2) x1)))

;;; Lifted division: partials are 1/x2 and -x1/(x2*x2).  The partials are
;;; themselves written with the lifted operators.
(define d/
 (lift-real*real->real
  / (lambda (x1 x2) (d/ 1 x2)) (lambda (x1 x2) (d- 0 (d/ x1 (d* x2 x2))))))

;;; Lifted square root: derivative 1/(2*sqrt(x)).
(define dsqrt (lift-real->real sqrt (lambda (x) (d/ 1 (d* 2 (dsqrt x))))))

;;; Lifted exponential: derivative exp(x).
(define dexp (lift-real->real exp (lambda (x) (dexp x))))

;;; Lifted natural logarithm: derivative 1/x.
(define dlog (lift-real->real log (lambda (x) (d/ 1 x))))

;;; Lifted sine: derivative cos(x).
(define dsin (lift-real->real sin (lambda (x) (dcos x))))

;;; Lifted cosine: derivative -sin(x), written as 0 - sin(x).
(define dcos (lift-real->real cos (lambda (x) (d- 0 (dsin x)))))

;;; Lifted two-argument arctangent.
;;;
;;; (atan x1 x2) is the angle of the point whose ordinate is x1 and whose
;;; abscissa is x2, i.e. atan(x1/x2) on the principal branch, so
;;;   d/dx1 =  x2 / (x1^2 + x2^2)
;;;   d/dx2 = -x1 / (x1^2 + x2^2)
;;; (Sanity check: near x1 = 0, x2 = 1 the function is approximately x1, so
;;; the first partial must be +1.)
(define datan
 (lift-real*real->real
  atan
  (lambda (x1 x2) (d/ x2 (d+ (d* x1 x1) (d* x2 x2))))
  (lambda (x1 x2) (d/ (d- 0 x1) (d+ (d* x1 x1) (d* x2 x2))))))

;;; Comparisons and numeric predicates lifted with lift-real^n->boolean:
;;; each applies the underlying procedure to its arguments after primal* has
;;; stripped any dual-number / tape wrapping from them.

(define d= (lift-real^n->boolean =))

(define d< (lift-real^n->boolean <))

(define d> (lift-real^n->boolean >))

(define d<= (lift-real^n->boolean <=))

(define d>= (lift-real^n->boolean >=))

(define dzero? (lift-real^n->boolean zero?))

(define dpositive? (lift-real^n->boolean positive?))

(define dnegative? (lift-real^n->boolean negative?))

(define dreal? (lift-real^n->boolean real?))

;;; Forward-mode derivative operator.  Returns a procedure of x that
;;;   1. adds 1 to the global tag *e*,
;;;   2. calls f on a dual number tagged *e* with primal x and perturbation 1,
;;;   3. takes the perturbation of the result, or 0 when the result is not a
;;;      dual number or carries a tag smaller than *e*,
;;;   4. subtracts 1 from *e* again and returns the value from step 3.
(define (derivative-F f)
 (lambda (x)
  (set! *e* (d+ *e* 1))
  (let ((y (f (dual-number *e* x 1))))
   (let ((y-prime (cond ((not (dual-number? y)) 0)
                        ((<_e (dual-number-epsilon y) *e*) 0)
                        (else (dual-number-perturbation y)))))
    (set! *e* (d- *e* 1))
    y-prime))))

;;; Return a copy of list x in which the element at zero-based index i is
;;; replaced by xi.  The tail after position i is shared with x.
(define (replace-ith x i xi)
 (cond ((dzero? i) (cons xi (cdr x)))
       (else
        (let ((remaining (d- i 1)))
         (cons (car x) (replace-ith (cdr x) remaining xi))))))

;;; Forward-mode gradient.  Returns a procedure that, given a list x, builds
;;; the list of partial derivatives of f: for each index i, f is
;;; differentiated with derivative-F as a function of the i-th element alone
;;; (the other elements held at their values in x).
(define (gradient-F f)
 (lambda (x)
  (let ((partial
         (lambda (i)
          (let ((f-of-xi (lambda (xi) (f (replace-ith x i xi)))))
           ((derivative-F f-of-xi) (list-ref x i))))))
   ((map-n partial) (length x)))))

;;; Increment the fanout counter of `tape`; the first time a tape is reached
;;; (counter becomes 1) recurse into its input tapes.
(define (determine-fanout! tape)
 (let ((fanout (d+ (tape-fanout tape) 1)))
  (set-tape-fanout! tape fanout)
  (if (d= fanout 1)
      (for-each determine-fanout! (tape-tapes tape)))))

;;; Add `sensitivity` to the tape's accumulated sensitivity and decrement its
;;; fanout counter.  Once the counter reaches zero, push the accumulated
;;; sensitivity, scaled by each recorded factor, to the matching input tape.
(define (reverse-phase! sensitivity tape)
 (let ((total (d+ (tape-sensitivity tape) sensitivity))
       (remaining (d- (tape-fanout tape) 1)))
  (set-tape-sensitivity! tape total)
  (set-tape-fanout! tape remaining)
  (if (dzero? remaining)
      (for-each (lambda (factor input)
                 (reverse-phase! (d* total factor) input))
                (tape-factors tape)
                (tape-tapes tape)))))

;;; Reverse-mode gradient.  Returns a procedure that, given a list x:
;;;   1. adds 1 to the global tag *e*,
;;;   2. wraps every element of x in a fresh input tape tagged *e*,
;;;   3. calls f on the wrapped list,
;;;   4. if the result is a tape whose tag is not smaller than *e*, runs the
;;;      fanout pass and then the reverse pass seeded with 1,
;;;   5. subtracts 1 from *e* and returns the sensitivities of the input tapes.
(define (gradient-R f)
 (lambda (x)
  (set! *e* (d+ *e* 1))
  (let* ((inputs (map (lambda (xi) (tape *e* xi '() '())) x))
         (y (f inputs)))
   (if (and (tape? y) (not (<_e (tape-epsilon y) *e*)))
       (begin
        (determine-fanout! y)
        (reverse-phase! 1 y)))
   (set! *e* (d- *e* 1))
   (map tape-sensitivity inputs))))

;;; Reverse-mode derivative of a unary f, obtained by treating f as a
;;; function of a one-element list and taking the single gradient component.
(define (derivative-R f)
 (lambda (x)
  (let* ((f-of-list (lambda (xs) (f (car xs))))
         (gradient ((gradient-R f-of-list) (list x))))
   (car gradient))))

;;; Write the innermost primal value of x followed by a newline, and return
;;; x itself (still wrapped) so the call can be used inline.
(define (write-real x)
 (write (primal* x))
 (newline)
 x)

;;; Everything after the first element of a list.
(define rest
 (lambda (x) (cdr x)))

;;; Square of x, computed with the lifted multiplication d*.
(define sqr
 (lambda (x) (d* x x)))

;;; Curried index map: ((map-n f) n) is the list (f 0) (f 1) ... (f (- n 1)).
(define (map-n f)
 (lambda (n)
  (let loop ((i 0))
   (cond ((d= i n) '())
         (else (cons (f i) (loop (d+ i 1))))))))

;;; Curried right fold: ((reduce f i) (list a b c)) is (f a (f b (f c i))),
;;; and i for the empty list.
(define (reduce f i)
 (lambda (l)
  (let fold ((remaining l))
   (if (null? remaining)
       i
       (f (car remaining) (fold (cdr remaining)))))))

;;; Map f over list l and right-fold the results with g, starting from i:
;;;   (map-reduce g i f (list a b)) => (g (f a) (g (f b) i))
;;; Returns i for the empty list.
;;;
;;; Uses car/cdr directly so the procedure does not depend on a `first`
;;; binding, which this file never defines.
(define (map-reduce g i f l)
 (if (null? l)
     i
     (g (f (car l)) (map-reduce g i f (cdr l)))))

;;; Return a new list holding, in order, the elements of l for which the
;;; predicate p is false.
;;;
;;; Uses car/cdr directly so the procedure does not depend on a `first`
;;; binding, which this file never defines.
(define (remove-if p l)
 (cond ((null? l) '())
       ((p (car l)) (remove-if p (cdr l)))
       (else (cons (car l) (remove-if p (cdr l))))))

;;; Element-wise operations on vectors represented as lists, built on the
;;; lifted arithmetic.

;;; Element-wise sum of u and v.
(define (v+ u v)
 (map (lambda (ui vi) (d+ ui vi)) u v))

;;; Element-wise difference u - v.
(define (v- u v)
 (map (lambda (ui vi) (d- ui vi)) u v))

;;; Scale every element of v by k.
(define (k*v k v)
 (let ((scale (lambda (vi) (d* k vi))))
  (map scale v)))

;;; Sum of the squares of the elements of x, folded onto 0.0.
(define (magnitude-squared x)
 (let ((sum (reduce d+ 0.0))
       (squares (map sqr x)))
  (sum squares)))

;;; Euclidean norm of x.
(define (magnitude x)
 (let ((m2 (magnitude-squared x)))
  (dsqrt m2)))

;;; Squared Euclidean distance between u and v.
(define (distance-squared u v)
 (let ((difference (v- v u)))
  (magnitude-squared difference)))

;;; Euclidean distance between u and v.
(define (distance u v)
 (let ((d2 (distance-squared u v)))
  (dsqrt d2)))

;;; Take n fixed-size gradient-ascent steps on f from x0, using the
;;; forward-mode gradient.  Each step moves x by eta times the gradient.
;;; Returns the list (x  f(x)  gradient-at-x) for the final point.
(define (gradient-ascent-F f x0 n eta)
 (let climb ((x x0) (steps n))
  (if (dzero? steps)
      (list x (f x) ((gradient-F f) x))
      (climb (map (lambda (xi gi) (d+ xi (d* eta gi)))
                  x
                  ((gradient-F f) x))
             (d- steps 1)))))

;;; Take n fixed-size gradient-ascent steps on f from x0, using the
;;; reverse-mode gradient.  Each step moves x by eta times the gradient.
;;; Returns the list (x  f(x)  gradient-at-x) for the final point.
(define (gradient-ascent-R f x0 n eta)
 (let climb ((x x0) (steps n))
  (if (dzero? steps)
      (list x (f x) ((gradient-R f) x))
      (climb (map (lambda (xi gi) (d+ xi (d* eta gi)))
                  x
                  ((gradient-R f) x))
             (d- steps 1)))))

;;; Gradient descent with an adaptive step size, using the forward-mode
;;; gradient.  Starting from x with step 1e-5:
;;;   - stop and return x when the gradient magnitude is at most 1e-5, or
;;;     when the proposed step would move x by at most 1e-5;
;;;   - after 10 consecutive accepted steps, double the step size;
;;;   - accept a step only if it strictly lowers f, otherwise halve the step
;;;     size and retry from the same point.
(define (multivariate-argmin-F f x)
 (define g (gradient-F f))
 (define (descend x fx gx eta i)
  (cond ((d<= (magnitude gx) 1e-5) x)
        ((d= i 10) (descend x fx gx (d* 2.0 eta) 0))
        (else
         (let ((x-prime (v- x (k*v eta gx))))
          (if (d<= (distance x x-prime) 1e-5)
              x
              (let ((fx-prime (f x-prime)))
               (if (d< fx-prime fx)
                   (descend x-prime fx-prime (g x-prime) eta (d+ i 1))
                   (descend x fx gx (d/ eta 2.0) 0))))))))
 (descend x (f x) (g x) 1e-5 0))

;;; Maximise f from starting point x by minimising 0.0 - f with
;;; multivariate-argmin-F; returns the maximising point.
(define (multivariate-argmax-F f x)
 (let ((negated (lambda (x) (d- 0.0 (f x)))))
  (multivariate-argmin-F negated x)))

;;; Value of f at the point found by multivariate-argmax-F.
(define (multivariate-max-F f x)
 (let ((best (multivariate-argmax-F f x)))
  (f best)))

;;; Gradient descent with an adaptive step size, using the reverse-mode
;;; gradient.  Starting from x with step 1e-5:
;;;   - stop and return x when the gradient magnitude is at most 1e-5, or
;;;     when the proposed step would move x by at most 1e-5;
;;;   - after 10 consecutive accepted steps, double the step size;
;;;   - accept a step only if it strictly lowers f, otherwise halve the step
;;;     size and retry from the same point.
(define (multivariate-argmin-R f x)
 (define g (gradient-R f))
 (define (descend x fx gx eta i)
  (cond ((d<= (magnitude gx) 1e-5) x)
        ((d= i 10) (descend x fx gx (d* 2.0 eta) 0))
        (else
         (let ((x-prime (v- x (k*v eta gx))))
          (if (d<= (distance x x-prime) 1e-5)
              x
              (let ((fx-prime (f x-prime)))
               (if (d< fx-prime fx)
                   (descend x-prime fx-prime (g x-prime) eta (d+ i 1))
                   (descend x fx gx (d/ eta 2.0) 0))))))))
 (descend x (f x) (g x) 1e-5 0))

;;; Maximise f from starting point x by minimising 0.0 - f with
;;; multivariate-argmin-R; returns the maximising point.
(define (multivariate-argmax-R f x)
 (let ((negated (lambda (x) (d- 0.0 (f x)))))
  (multivariate-argmin-R negated x)))

;;; Value of f at the point found by multivariate-argmax-R.
(define (multivariate-max-R f x)
 (let ((best (multivariate-argmax-R f x)))
  (f best)))

;;; Generated by GNU enscript 1.6.4.