(* common-mlton.sml *)

(* Values manipulated by the AD machinery.
   Dual_number (e, x, x')
     forward-mode value: tag e, primal x, tangent x'.
   Tape (e, x, factors, tapes, fanout, sensitivity)
     reverse-mode node: tag e, primal x, local partial derivatives `factors`
     paired positionally with the operand nodes `tapes`, plus two mutable
     cells used by the reverse sweep (a use counter and an accumulated
     sensitivity).
   Base r
     an ordinary real. *)
datatype ad_number =
         Dual_number of ad_number * ad_number * ad_number
       | Tape of ad_number * ad_number
                 * ad_number list * ad_number list
                 * ad_number ref * ad_number ref
       | Base of real

(* Global tag counter. derivative_F and gradient_R bump it on entry and undo
   the bump on exit, so a nested differentiation always sees a larger tag than
   the one enclosing it. *)
val epsilon : ad_number ref = ref (Base 0.0)

(* Build a forward-mode value with tag e, primal x and tangent x'.
   If the tangent compares equal to zero under the supplied <= (checked in
   both directions), no Dual_number is allocated and x is returned as is. *)
fun dual_number e x x' op <= =
    let val tangent_is_zero = x' <= Base 0.0 andalso Base 0.0 <= x'
    in if tangent_is_zero then x else Dual_number (e, x, x') end

(* Build a reverse-mode node tagged e with primal x, local partials `factors`
   and operand nodes `tapes`. Both the fanout counter and the sensitivity
   accumulator start out as Base 0.0. *)
fun tape e x factors tapes =
    let val fanout = ref (Base 0.0)
        val sensitivity = ref (Base 0.0)
    in Tape (e, x, factors, tapes, fanout, sensitivity) end

(* Lift a unary real function f to ad_number.
   dfdx maps the (lifted) argument to the derivative of f at that point.
   The * and <= to use are parameters because the ad_number operators are
   themselves defined later in terms of this function.
   - Base: apply f directly.
   - Dual_number: the tangent is dfdx x times the incoming tangent.
   - Tape: a new node records the single partial dfdx x and points back at
     the argument node. *)
fun lift_real_to_real f dfdx op * op <= =
    let fun lifted p =
            case p of
                Base r => Base (f r)
              | Dual_number (e, x, x') =>
                let val primal = lifted x
                    val tangent = dfdx x * x'
                in dual_number e primal tangent op <= end
              | Tape (e, x, _, _, _, _) =>
                tape e (lifted x) [dfdx x] [p]
    in lifted end

(* Lift a binary real function f to a function on pairs of ad_numbers.
   dfdx1 and dfdx2 are curried functions of the two (possibly lifted)
   arguments. They return the partial derivative of f with respect to the
   first and the second argument, respectively.
   The operators +, *, < and <= are parameters because the ad_number versions
   are themselves defined further down in terms of this function.

   Dispatch on the outermost constructor and tag of each argument:
   - Different tags: the argument with the larger tag becomes the outermost
     node of the result, and the other argument is a constant at that level.
   - Same kind and equal tags: both partials are combined into one node.
   - A Base argument is always a constant.

   NOTE(review): a Dual_number and a Tape with equal tags are not handled
   symmetrically. In both mixed cases the first argument's node ends up
   outermost. Presumably such pairs never arise because tags come from the
   global epsilon counter; confirm. *)
fun lift_real_cross_real_to_real f dfdx1 dfdx2 op + op * op < op <= =
    (* Dual and Dual: nest by tag; equal tags sum both tangent contributions. *)
    let fun self ((p1 as (Dual_number (e1, x1, x1'))),
		  (p2 as (Dual_number (e2, x2, x2')))) =
            if e1<e2
            then dual_number e2 (self (p1, x2)) ((dfdx2 p1 x2)*x2') op <=
	    else if e2<e1
	    then dual_number e1 (self (x1, p2)) ((dfdx1 x1 p2)*x1') op <=
	    else dual_number
		     e1
		     (self (x1, x2))
		     ((dfdx1 x1 x2)*x1'+(dfdx2 x1 x2)*x2')
		     op <=
	  (* Dual and Tape: the node with the larger tag is outermost. *)
	  | self ((p1 as (Dual_number (e1, x1, x1'))),
		  (p2 as (Tape (e2, x2, _, _, _, _)))) =
	    if e1<e2
	    then tape e2 (self (p1, x2)) [dfdx2 p1 x2] [p2]
	    else dual_number e1 (self (x1, p2)) ((dfdx1 x1 p2)*x1') op <=
	  (* Dual and constant. *)
	  | self ((Dual_number (e1, x1, x1')), (p2 as (Base x2))) =
	    dual_number e1 (self (x1, p2)) ((dfdx1 x1 p2)*x1') op <=
	  (* Tape and Dual: the node with the larger tag is outermost. *)
	  | self ((p1 as (Tape (e1, x1, _, _, _, _))),
		  (p2 as (Dual_number (e2, x2, x2')))) =
	    if e1<e2
	    then dual_number e2 (self (p1, x2)) ((dfdx2 p1 x2)*x2') op <=
	    else tape e1 (self (x1, p2)) [dfdx1 x1 p2] [p1]
	  (* Tape and Tape: equal tags record one node holding both partials
	     and both operand nodes. *)
	  | self ((p1 as (Tape (e1, x1, _, _, _, _))),
		  (p2 as (Tape (e2, x2, _, _, _, _)))) =
	    if e1<e2
	    then tape e2 (self (p1, x2)) [dfdx2 p1 x2] [p2]
	    else if e2<e1
	    then tape e1 (self (x1, p2)) [dfdx1 x1 p2] [p1]
	    else
		tape e1 (self (x1, x2)) [(dfdx1 x1 x2), (dfdx2 x1 x2)] [p1, p2]
	  (* Tape and constant. *)
	  | self ((p1 as (Tape (e1, x1, _, _, _, _))), (p2 as (Base x2))) =
	    tape e1 (self (x1, p2)) [dfdx1 x1 p2] [p1]
	  (* Constant and Dual. *)
	  | self ((p1 as (Base x1)), (Dual_number (e2, x2, x2'))) =
	    dual_number e2 (self (p1, x2)) ((dfdx2 p1 x2)*x2') op <=
	  (* Constant and Tape. *)
	  | self ((p1 as (Base x1)), (p2 as (Tape (e2, x2, _, _, _, _)))) =
	    tape e2 (self (p1, x2)) [dfdx2 p1 x2] [p2]
	  (* Two constants: apply the real function directly. *)
	  | self ((Base x1), (Base x2)) = Base (f (x1, x2))
    in self end

(* Lift a binary predicate on reals to ad_number pairs. Only the underlying
   real primal values are compared; all tangent and tape information in
   either argument is discarded. *)
fun lift_real_cross_real_to_bool f =
    let fun strip (Dual_number (_, x, _)) = strip x
          | strip (Tape (_, x, _, _, _, _)) = strip x
          | strip (Base r) = r
    in fn (p1, p2) => f (strip p1, strip p2) end

open Real.Math

(* Rebind +, -, *, /, sqrt, exp, < and <= at top level so that they operate
   on ad_number.

   The original operations are first captured under new names (plus, minus,
   times, divide, original_sqrt, original_exp, lt, le). These must resolve to
   the real instances, because the lifting helpers apply them to the real
   payloads of Base values.

   The new operators form one mutually recursive fun ... and ... group. As a
   result, the partial-derivative functions and the operator arguments handed
   to the lifters are the ad_number operators themselves, so derivative
   values are built with lifted arithmetic. *)
val (op +, op -, op *, op /, sqrt, exp, op <, op <=) =
    let val plus = op +
	val minus = op -
	val times = op *
	val divide = op /
	val original_sqrt = sqrt
	val original_exp = exp
        val lt = op <
        val le = op <=
    (* Sum: both partials are 1. *)
    in let fun op + (x1, x2) =
	       lift_real_cross_real_to_real
		   plus
		   (fn _ => fn _ => Base 1.0)
		   (fn _ => fn _ => Base 1.0)
		   op +
		   op *
                   op <
		   op <=
		   (x1, x2)
	   (* Difference: partials are 1 and -1. *)
	   and op - (x1, x2) =
	       lift_real_cross_real_to_real
		   minus
		   (fn _ => fn _ => Base 1.0)
		   (fn _ => fn _ => Base ~1.0)
		   op +
		   op *
                   op <
		   op <=
		   (x1, x2)
	   (* Product: the partial wrt x1 is x2 and the partial wrt x2 is x1. *)
	   and op * (x1, x2) =
	       lift_real_cross_real_to_real
		   times
		   (fn _ => fn x2 => x2)
		   (fn x1 => fn _ => x1)
		   op +
		   op *
                   op <
		   op <=
		   (x1, x2)
	   (* Quotient: the partial wrt x1 is 1/x2 and the partial wrt x2 is
	      -x1/x2^2. The latter is written as 0 - x1/(x2 x2) because the
	      rebound operators include no negation. *)
	   and op / (x1, x2) =
	       lift_real_cross_real_to_real
		   divide
		   (fn _ => fn x2 => (Base 1.0)/x2)
		   (fn x1 => fn x2 => (Base 0.0)-x1/(x2*x2))
		   op +
		   op *
                   op <
		   op <=
		   (x1, x2)
	   (* Square root: derivative 1/(2 sqrt x). *)
	   and sqrt x =
	       lift_real_to_real
		   original_sqrt
		   (fn x => (Base 1.0)/((Base 2.0)*(sqrt x)))
		   op *
		   op <=
		   x
	   (* Exponential: the lifted exp is its own derivative function. *)
	   and exp x =
	       lift_real_to_real
		   original_exp
		   exp
		   op *
		   op <=
		   x
	   (* Comparisons look only at the underlying real values. *)
	   and op < (x1, x2) = lift_real_cross_real_to_bool lt (x1, x2)
	   and op <= (x1, x2) = lift_real_cross_real_to_bool le (x1, x2)
       in (op +, op -, op *, op /, sqrt, exp, op <, op <=) end end

(* Forward-mode derivative of the unary function f at x.

   A fresh tag is taken by incrementing the global epsilon, and f is applied
   to the dual number with primal x, tangent 1 and that tag. If the result is
   a Dual_number carrying this tag, its tangent is returned. Any other result
   has derivative zero: a constant, a lower-tagged dual, or a tape.

   epsilon is put back to its previous value on normal return AND when f
   raises. Previously an escaping exception left the counter incremented, so
   every later derivative used a stale tag. *)
fun derivative_F f x =
    let val saved = !epsilon
        val e = saved+(Base 1.0)
        val () = epsilon := e
        val y = f (dual_number e x (Base 1.0) op <=)
                handle exn => (epsilon := saved; raise exn)
        val y' =
            case y of
                Dual_number (e1, _, tangent) =>
                (* A smaller tag means the perturbation was not ours. *)
                if e1<e then Base 0.0 else tangent
              | Tape _ => Base 0.0
              | Base _ => Base 0.0
    in epsilon := saved; y' end

(* Return the list with its i-th element (0-based) replaced by xi. The index
   i is an ad_number. An index that is never hit exactly (too large,
   negative, or non-integral) now raises Subscript, as List.nth does, instead
   of an anonymous Match from a non-exhaustive pattern. *)
fun replace_ith [] _ _ = raise Subscript
  | replace_ith (xh::xt) i xi =
    (* i = 0, tested with the ad_number <= in both directions *)
    if i<=(Base 0.0) andalso (Base 0.0)<=i
    then xi::xt
    else xh::(replace_ith xt (i-(Base 1.0)) xi)

(* Forward-mode gradient of f at the list x. There is one derivative_F call
   per coordinate, and each call perturbs only that coordinate while the
   others stay fixed. *)
fun gradient_F f x =
    let fun partial i =
            let val index = Base (real i)
                fun along xi = f (replace_ith x index xi)
            in derivative_F along (List.nth (x, i)) end
    in List.tabulate (length x, partial) end

(* First pass of the reverse sweep. Each visit increments the node's fanout
   cell, and recursion into the node's operand tapes happens only on the
   first visit (fanout becomes 1). Defined only for Tape nodes. *)
fun determine_fanout (Tape (_, _, _, tapes, fanout, _)) =
    let val () = fanout := !fanout+(Base 1.0)
        val first_visit = !fanout<=(Base 1.0) andalso (Base 1.0)<=(!fanout)
    in if first_visit then List.app determine_fanout tapes else () end

(* Second pass of the reverse sweep.
   1. Add sensitivity1 into the node's sensitivity cell.
   2. Decrement its fanout.
   3. When fanout is back at zero, pass (accumulated sensitivity * factor) to
      each operand tape, pairing factors with tapes positionally.
   Defined only for Tape nodes. *)
fun reverse_phase sensitivity1
		  (Tape (_, _, factors, tapes, fanout, sensitivity)) =
    let fun push (factor, operand) =
            reverse_phase (!sensitivity*factor) operand
    in sensitivity := !sensitivity+sensitivity1;
       fanout := !fanout-(Base 1.0);
       if !fanout<=(Base 0.0) andalso (Base 0.0)<=(!fanout)
       then ListPair.app push (factors, tapes)
       else ()
    end

(* Reverse-mode gradient of f at the list x.

   A fresh tag is taken by incrementing the global epsilon, and each input is
   wrapped in a leaf tape node carrying that tag. If f returns a tape node
   with this tag, the two reverse passes are run from it with seed 1. The
   result is the list of sensitivities accumulated in the input leaves; these
   stay 0 when the output does not depend on the inputs at this tag.

   epsilon is put back to its previous value on normal return AND when f or
   the reverse sweep raises. Previously an exception left the counter
   permanently incremented. *)
fun gradient_R f x =
    let val saved = !epsilon
        val e = saved+(Base 1.0)
        val () = epsilon := e
        val x = map (fn xi => tape e xi [] []) x
        fun sweep () =
            case f x of
                Dual_number _ => ()
              | y as Tape (e1, _, _, _, _, _) =>
                (* A smaller tag belongs to an enclosing computation. *)
                if e1<e
                then ()
                else (determine_fanout y; reverse_phase (Base 1.0) y)
              | Base _ => ()
        val () = sweep () handle exn => (epsilon := saved; raise exn)
    in epsilon := saved;
       (* Every element of x was built by `tape` above, so it is a Tape. *)
       map (fn (Tape (_, _, _, _, _, sensitivity)) => !sensitivity) x
    end

(* Reverse-mode derivative of the unary function f at x. It is the single
   component of the reverse-mode gradient of f viewed as a function of a
   one-element list. *)
fun derivative_R f x =
    let fun on_singleton [u] = f u
        val [dydx] = gradient_R on_singleton [x]
    in dydx end

(* Print the underlying real primal value of p followed by a newline, then
   return p unchanged. *)
fun write_real p =
    let fun primal (Dual_number (_, x, _)) = primal x
          | primal (Tape (_, x, _, _, _, _)) = primal x
          | primal (Base r) = r
    in print (Real.toString (primal p));
       print "\n";
       p
    end

(* Square of an ad_number. *)
fun sqr v = op * (v, v)

(* Elementwise sum of two vectors, truncated to the shorter one. *)
fun vplus u v = ListPair.map (fn (a, b) => a+b) (u, v)

(* Elementwise difference u - v, truncated to the shorter vector. *)
fun vminus u v = ListPair.map (fn (a, b) => a-b) (u, v)

(* Scale every component of a vector by k. *)
fun ktimesv k xs = map (fn xi => k*xi) xs

(* Sum of the squares of the components, added left to right. *)
fun magnitude_squared x =
    let val squares = map sqr x
    in foldl (fn (s, acc) => s+acc) (Base 0.0) squares end

(* Euclidean norm of a vector. *)
fun magnitude x =
    let val squared = magnitude_squared x
    in sqrt squared end

(* Squared Euclidean distance between u and v, computed from v - u. *)
fun distance_squared u v =
    let val delta = vminus v u
    in magnitude_squared delta end

(* Euclidean distance between u and v. *)
fun distance u v =
    let val d2 = distance_squared u v
    in sqrt d2 end

(* Take fixed-size gradient ascent steps x := x + eta * grad f(x), using
   forward-mode gradients. Returns (final point, f at it, gradient at it).

   Iteration stops as soon as n <= 0, so a non-negative whole n gives exactly
   n steps. The previous exact-zero test (n <= 0 and 0 <= n) never fired for
   negative or fractional n and recursed forever. *)
fun gradient_ascent_F f x0 n eta =
    if n<=(Base 0.0)
    then (x0, (f x0), (gradient_F f x0))
    else gradient_ascent_F
	     f (vplus x0 (ktimesv eta (gradient_F f x0))) (n-(Base 1.0)) eta

(* Take fixed-size gradient ascent steps x := x + eta * grad f(x), using
   reverse-mode gradients. Returns (final point, f at it, gradient at it).

   Iteration stops as soon as n <= 0, so a non-negative whole n gives exactly
   n steps. The previous exact-zero test (n <= 0 and 0 <= n) never fired for
   negative or fractional n and recursed forever. *)
fun gradient_ascent_R f x0 n eta =
    if n<=(Base 0.0)
    then (x0, (f x0), (gradient_R f x0))
    else gradient_ascent_R
	     f (vplus x0 (ktimesv eta (gradient_R f x0))) (n-(Base 1.0)) eta

(* Approximate argmin of f, starting from x. Uses gradient descent with an
   adaptive step size and forward-mode gradients.

   Step size:
   - starts at 1e-5;
   - doubles (and the counter resets) after 10 consecutive improving steps;
   - halves (and the counter resets) whenever a proposed step fails to
     decrease f.

   Stops at the current point when the gradient norm, or the length of the
   proposed step, is at most 1e-5. *)
fun multivariate_argmin_F f x =
    let val grad = gradient_F f
        val tolerance = Base 1e~5
        fun is_ten k = k<=(Base 10.0) andalso (Base 10.0)<=k
        fun search point value slope rate count =
            if magnitude slope<=tolerance then point
            else if is_ten count
            then search point value slope ((Base 2.0)*rate) (Base 0.0)
            else
                let val candidate = vminus point (ktimesv rate slope)
                in if distance point candidate<=tolerance then point
                   else
                       let val candidate_value = f candidate
                       in if candidate_value<value
                          then search candidate candidate_value
                                      (grad candidate) rate
                                      (count+(Base 1.0))
                          else search point value slope
                                      (rate/(Base 2.0)) (Base 0.0)
                       end
                end
    in search x (f x) (grad x) (Base 1e~5) (Base 0.0) end

(* Approximate argmax of f from x, found by minimizing 0 - f using
   forward-mode gradients. *)
fun multivariate_argmax_F f x =
    let fun negated y = (Base 0.0)-(f y)
    in multivariate_argmin_F negated x end

(* Value of f at its approximate forward-mode argmax from x. *)
fun multivariate_max_F f x =
    let val best = multivariate_argmax_F f x
    in f best end

(* Approximate argmin of f, starting from x. Uses gradient descent with an
   adaptive step size and reverse-mode gradients.

   Step size:
   - starts at 1e-5;
   - doubles (and the counter resets) after 10 consecutive improving steps;
   - halves (and the counter resets) whenever a proposed step fails to
     decrease f.

   Stops at the current point when the gradient norm, or the length of the
   proposed step, is at most 1e-5. *)
fun multivariate_argmin_R f x =
    let val grad = gradient_R f
        val tolerance = Base 1e~5
        fun is_ten k = k<=(Base 10.0) andalso (Base 10.0)<=k
        fun search point value slope rate count =
            if magnitude slope<=tolerance then point
            else if is_ten count
            then search point value slope ((Base 2.0)*rate) (Base 0.0)
            else
                let val candidate = vminus point (ktimesv rate slope)
                in if distance point candidate<=tolerance then point
                   else
                       let val candidate_value = f candidate
                       in if candidate_value<value
                          then search candidate candidate_value
                                      (grad candidate) rate
                                      (count+(Base 1.0))
                          else search point value slope
                                      (rate/(Base 2.0)) (Base 0.0)
                       end
                end
    in search x (f x) (grad x) (Base 1e~5) (Base 0.0) end

(* Approximate argmax of f from x, found by minimizing 0 - f using
   reverse-mode gradients. *)
fun multivariate_argmax_R f x =
    let fun negated y = (Base 0.0)-(f y)
    in multivariate_argmin_R negated x end

(* Value of f at its approximate reverse-mode argmax from x. *)
fun multivariate_max_R f x =
    let val best = multivariate_argmax_R f x
    in f best end

(* Generated by GNU enscript 1.6.4. *)