(* common-ocaml.ml *)

(* A real number that may carry nested derivative information.
   Every wrapper is tagged with an epsilon (nesting level) so that
   perturbations and tapes introduced by different derivative calls are
   kept apart. *)
type ad_number =
  (* Forward-mode value: (epsilon tag, primal, tangent). *)
  | Dual_number of ad_number * ad_number * ad_number
  (* Reverse-mode tape node: (epsilon tag, primal, partial-derivative
     factors, child tapes the factors refer to, fanout counter,
     accumulated sensitivity). *)
  | Tape of ad_number * ad_number * ad_number list * ad_number list
            * ad_number ref * ad_number ref
  (* A plain real with no perturbation attached. *)
  | Base of float

(* Current nesting level used to tag new perturbations and tapes.
   derivative_F and gradient_R increment it on entry and decrement it
   again before returning. *)
let epsilon : ad_number ref =
  ref (Base 0.0)

(* Build a forward-mode value tagged [e] with primal [x] and tangent [x'].
   When the tangent is exactly zero (tested with the supplied [<=] in both
   directions), the perturbation is dropped and [x] is returned unchanged. *)
let dual_number e x x' ( <= ) =
  let tangent_is_zero = x' <= Base 0.0 && Base 0.0 <= x' in
  if not tangent_is_zero then Dual_number (e, x, x') else x

(* Build a fresh reverse-mode tape node tagged [e] with primal [x].
   [factors] are the local partial derivatives with respect to the nodes in
   [tapes]; the fanout counter and the sensitivity both start at zero. *)
let tape e x factors tapes =
  let fanout = ref (Base 0.0)
  and sensitivity = ref (Base 0.0) in
  Tape (e, x, factors, tapes, fanout, sensitivity)

(* Lift a unary float function [f] with derivative [dfdx] to ad_number.
   [dfdx] receives the (possibly still wrapped) primal and returns the local
   derivative.  [*] and [<=] are passed in because the lifted operators are
   defined later in the file. *)
let lift_real_to_real f dfdx ( * ) ( <= ) =
  let rec lifted = function
    | Base v -> Base (f v)
    | Tape (e, primal, _, _, _, _) as node ->
        (* Record one edge back to [node] with factor dfdx(primal). *)
        tape e (lifted primal) [dfdx primal] [node]
    | Dual_number (e, primal, tangent) ->
        (* Chain rule: new tangent = dfdx(primal) * tangent. *)
        dual_number e (lifted primal) (dfdx primal * tangent) ( <= )
  in
  lifted

(* Lift a binary float function [f] with partial derivatives [dfdx1] and
   [dfdx2] to ad_number.  When both arguments are wrapped, the one whose
   epsilon tag is larger (the more recently introduced level) is unwrapped
   first.  Equal tags on two wrappers of the same kind combine both
   partials.  [+.], [*.], [<] and [<=] are passed in because the lifted
   operators are defined later in the file. *)
let lift_real_cross_real_to_real f dfdx1 dfdx2 ( +. ) ( *. ) ( < ) ( <= ) =
  let rec lifted p1 p2 =
    match (p1, p2) with
    | Base v1, Base v2 -> Base (f v1 v2)
    | Dual_number (e1, x1, x1'), Dual_number (e2, x2, x2') ->
        if e1 < e2 then push_second p1 e2 x2 x2'
        else if e2 < e1 then push_first e1 x1 x1' p2
        else
          (* Same level: tangent = d1 * x1' + d2 * x2'. *)
          dual_number e1 (lifted x1 x2)
            ((dfdx1 x1 x2 *. x1') +. (dfdx2 x1 x2 *. x2'))
            ( <= )
    | Dual_number (e1, x1, x1'), Tape (e2, x2, _, _, _, _) ->
        if e1 < e2 then record_second p1 e2 x2 p2
        else push_first e1 x1 x1' p2
    | Dual_number (e1, x1, x1'), Base _ -> push_first e1 x1 x1' p2
    | Tape (e1, x1, _, _, _, _), Dual_number (e2, x2, x2') ->
        if e1 < e2 then push_second p1 e2 x2 x2'
        else record_first p1 e1 x1 p2
    | Tape (e1, x1, _, _, _, _), Tape (e2, x2, _, _, _, _) ->
        if e1 < e2 then record_second p1 e2 x2 p2
        else if e2 < e1 then record_first p1 e1 x1 p2
        else
          (* Same level: one node with an edge to each operand. *)
          tape e1 (lifted x1 x2) [dfdx1 x1 x2; dfdx2 x1 x2] [p1; p2]
    | Tape (e1, x1, _, _, _, _), Base _ -> record_first p1 e1 x1 p2
    | Base _, Dual_number (e2, x2, x2') -> push_second p1 e2 x2 x2'
    | Base _, Tape (e2, x2, _, _, _, _) -> record_second p1 e2 x2 p2
  (* Forward step through the first argument's perturbation. *)
  and push_first e1 x1 x1' p2 =
    dual_number e1 (lifted x1 p2) (dfdx1 x1 p2 *. x1') ( <= )
  (* Forward step through the second argument's perturbation. *)
  and push_second p1 e2 x2 x2' =
    dual_number e2 (lifted p1 x2) (dfdx2 p1 x2 *. x2') ( <= )
  (* Reverse step: new tape node with one edge back to [p1]. *)
  and record_first p1 e1 x1 p2 =
    tape e1 (lifted x1 p2) [dfdx1 x1 p2] [p1]
  (* Reverse step: new tape node with one edge back to [p2]. *)
  and record_second p1 e2 x2 p2 =
    tape e2 (lifted p1 x2) [dfdx2 p1 x2] [p2]
  in
  lifted

(* Lift a binary float predicate [f] to ad_number.  Only the innermost
   primal values take part: every Dual_number and Tape wrapper is peeled
   off both arguments before [f] is applied. *)
let lift_real_cross_real_to_bool f =
  let rec primal = function
    | Dual_number (_, x, _) | Tape (_, x, _, _, _, _) -> primal x
    | Base v -> v
  in
  fun p1 p2 -> f (primal p1) (primal p2)

(* Print the innermost primal float of [p] on its own line (format %.18g)
   and return [p] unchanged. *)
let write_real p =
  let rec innermost = function
    | Dual_number (_, x, _) | Tape (_, x, _, _, _, _) -> innermost x
    | Base v -> v
  in
  Printf.printf "%.18g\n" (innermost p);
  p

(* Shadow float arithmetic, sqrt, exp, < and <= with versions lifted to
   ad_number.  The standard operations are captured first so the lifted
   versions can apply them to Base floats.  The definitions are mutually
   recursive because each derivative is itself computed with the lifted
   operators, which allows nested differentiation. *)
let (( +. ), ( -. ), ( *. ), ( /. ), sqrt, exp, ( < ), ( <= )) =
  let (plus, minus, times, divide, original_sqrt, original_exp, lt, le) =
    (( +. ), ( -. ), ( *. ), ( /. ), sqrt, exp, ( < ), ( <= ))
  in
  let rec ( +. ) x1 x2 =
    lift_real_cross_real_to_real plus
      (fun _ _ -> Base 1.0)
      (fun _ _ -> Base 1.0)
      ( +. ) ( *. ) ( < ) ( <= ) x1 x2
  and ( -. ) x1 x2 =
    lift_real_cross_real_to_real minus
      (fun _ _ -> Base 1.0)
      (fun _ _ -> Base (-1.0))
      ( +. ) ( *. ) ( < ) ( <= ) x1 x2
  and ( *. ) x1 x2 =
    lift_real_cross_real_to_real times
      (fun _ x2 -> x2)
      (fun x1 _ -> x1)
      ( +. ) ( *. ) ( < ) ( <= ) x1 x2
  and ( /. ) x1 x2 =
    lift_real_cross_real_to_real divide
      (fun _ x2 -> Base 1.0 /. x2)
      (fun x1 x2 -> Base 0.0 -. x1 /. (x2 *. x2))
      ( +. ) ( *. ) ( < ) ( <= ) x1 x2
  and sqrt x =
    lift_real_to_real original_sqrt
      (fun x ->
         (* d/dx sqrt x = 1 / (2 sqrt x); compute sqrt x once so that only
            one dual/tape node is built for it. *)
         let s = sqrt x in
         Base 1.0 /. (s +. s))
      ( *. ) ( <= ) x
  and exp x =
    (* exp is its own derivative. *)
    lift_real_to_real original_exp exp ( *. ) ( <= ) x
  and ( < ) x1 x2 = lift_real_cross_real_to_bool lt x1 x2
  and ( <= ) x1 x2 = lift_real_cross_real_to_bool le x1 x2
  in
  (( +. ), ( -. ), ( *. ), ( /. ), sqrt, exp, ( < ), ( <= ))

(* Forward-mode derivative of [f] at [x].  A fresh epsilon level is pushed,
   [f] is applied to a dual number with unit tangent at that level, and the
   tangent at that level is read back.  The result is Base 0.0 when the
   output carries no perturbation at this level.  The level is popped again
   before returning. *)
let derivative_F f x =
  epsilon := !epsilon +. Base 1.0;
  let y' =
    match f (dual_number !epsilon x (Base 1.0) ( <= )) with
    | Dual_number (e1, _, tangent) ->
        (* A smaller tag means the output is independent of this level. *)
        if e1 < !epsilon then Base 0.0 else tangent
    | Tape _ | Base _ -> Base 0.0
  in
  epsilon := !epsilon -. Base 1.0;
  y'

open List

(* Square of [x], using the lifted multiplication. *)
let sqr = fun x -> x *. x

(* [map_n f n] returns the list [f 0; f 1; ...; f (n-1)], calling [f] in
   increasing index order.  A non-positive [n] yields the empty list.
   The guard uses >= rather than =, so a negative [n] terminates instead
   of recursing forever.  Stdlib >= is used because this file rebinds
   < and <= to ad_number comparisons. *)
let map_n f n =
  let rec loop i =
    if i >= n then []
    else
      (* Bind first so [f] runs before the recursive call. *)
      let head = f i in
      head :: loop (i + 1)
  in
  loop 0

(* Component-wise sum of two vectors (lists) of equal length. *)
let vplus u v = List.map2 (fun ui vi -> ui +. vi) u v

(* Component-wise difference u - v of two vectors of equal length. *)
let vminus u v = List.map2 (fun ui vi -> ui -. vi) u v

(* Scale every component of a vector by [k]. *)
let ktimesv k v = List.map (fun vi -> k *. vi) v

(* Sum of squared components, accumulated left to right from Base 0.0. *)
let magnitude_squared x =
  List.fold_left (fun acc xi -> acc +. sqr xi) (Base 0.0) x

(* Euclidean norm of [x]. *)
let magnitude x =
  let squared = magnitude_squared x in
  sqrt squared

(* Squared Euclidean distance, computed from v - u. *)
let distance_squared u v =
  let diff = vminus v u in
  magnitude_squared diff

(* Euclidean distance between [u] and [v]. *)
let distance u v =
  let squared = distance_squared u v in
  sqrt squared

(* Return a copy of the list [xs] with the element at index [i] (an
   ad_number, counted down with lifted arithmetic) replaced by [xi].
   Raises Invalid_argument if [i] does not land on an element.  The
   original refutable parameter pattern raised Match_failure here and
   triggered a non-exhaustive-match warning. *)
let rec replace_ith xs i xi =
  match xs with
  | [] -> invalid_arg "replace_ith: index out of range"
  | xh :: xt ->
      (* i is zero exactly when both <= tests hold. *)
      if i <= Base 0.0 && Base 0.0 <= i then xi :: xt
      else xh :: replace_ith xt (i -. Base 1.0) xi

(* Gradient of [f] at the point [x] (a list of coordinates) in forward
   mode: one derivative_F per coordinate, varying only that coordinate.
   List.mapi supplies each coordinate directly, avoiding the O(n) nth
   lookup per index of the previous version. *)
let gradient_F f x =
  List.mapi
    (fun i xi ->
       derivative_F (fun xi' -> f (replace_ith x (Base (float i)) xi')) xi)
    x

(* First pass of reverse mode: count, in each tape node's fanout cell, how
   many times it is referenced from the graph reachable from the root.
   A node's children are visited only on its first visit (fanout = 1), so
   shared sub-graphs are traversed once.
   Raises Invalid_argument if called on a non-Tape value. *)
let rec determine_fanout = function
  | Tape (_, _, _, tapes, fanout, _) ->
      fanout := !fanout +. Base 1.0;
      if !fanout <= Base 1.0 && Base 1.0 <= !fanout then
        List.iter determine_fanout tapes
  | Dual_number _ | Base _ -> invalid_arg "determine_fanout"

(* Second pass of reverse mode: add [sensitivity1] to this node's
   sensitivity and decrement its fanout.  Once every reference has
   contributed (fanout back to zero), propagate sensitivity * factor to each
   child tape, in list order.
   Raises Invalid_argument if called on a non-Tape value. *)
let rec reverse_phase sensitivity1 = function
  | Tape (_, _, factors, tapes, fanout, sensitivity) ->
      sensitivity := !sensitivity +. sensitivity1;
      fanout := !fanout -. Base 1.0;
      if !fanout <= Base 0.0 && Base 0.0 <= !fanout then
        List.iter2
          (fun factor child -> reverse_phase (!sensitivity *. factor) child)
          factors tapes
  | Dual_number _ | Base _ -> invalid_arg "reverse_phase"

(* Gradient of [f] at the point [x] in reverse mode.  A fresh epsilon level
   is pushed and each coordinate is wrapped in a leaf tape at that level.
   [f] is evaluated once; if its result is a tape at this level, the
   sensitivities are back-propagated from a seed of Base 1.0.  The level is
   then popped and each leaf's accumulated sensitivity is returned.  Those
   sensitivities stay Base 0.0 when the result does not depend on this
   level. *)
let gradient_R f x =
  epsilon := !epsilon +. Base 1.0;
  let x = List.map (fun xi -> tape !epsilon xi [] []) x in
  (* The same result is both inspected and back-propagated, so f runs
     exactly once per gradient. *)
  let y = f x in
  (match y with
   | Tape (e1, _, _, _, _, _) ->
       if e1 < !epsilon then ()
       else begin
         determine_fanout y;
         reverse_phase (Base 1.0) y
       end
   | Dual_number _ | Base _ -> ());
  epsilon := !epsilon -. Base 1.0;
  List.map
    (function
      | Tape (_, _, _, _, _, sensitivity) -> !sensitivity
      | Dual_number _ | Base _ -> assert false (* leaves built above *))
    x

(* Take [n] gradient-ascent steps of size [eta] from [x0], using the
   forward-mode gradient.  Returns (final point, f at that point, gradient
   at that point). *)
let rec gradient_ascent_F f x0 n eta =
  let finished = n <= Base 0.0 && Base 0.0 <= n in
  if finished then (x0, f x0, gradient_F f x0)
  else
    let step = ktimesv eta (gradient_F f x0) in
    gradient_ascent_F f (vplus x0 step) (n -. Base 1.0) eta

(* Take [n] gradient-ascent steps of size [eta] from [x0], using the
   reverse-mode gradient.  Returns (final point, f at that point, gradient
   at that point). *)
let rec gradient_ascent_R f x0 n eta =
  let finished = n <= Base 0.0 && Base 0.0 <= n in
  if finished then (x0, f x0, gradient_R f x0)
  else
    let step = ktimesv eta (gradient_R f x0) in
    gradient_ascent_R f (vplus x0 step) (n -. Base 1.0) eta

(* Approximate minimizer of [f] by gradient descent with an adaptive step
   size, using forward-mode gradients.  Stops when the gradient norm or the
   proposed step length is at most 1e-5.  The step size is halved after a
   rejected step and doubled after ten consecutive accepted steps. *)
let multivariate_argmin_F f x =
  let gradient = gradient_F f in
  let tolerance = Base 1e-5 in
  let initial_eta = Base 1e-5 in
  let rec loop x fx gx eta i =
    if magnitude gx <= tolerance then x
    else if i <= Base 10.0 && Base 10.0 <= i then
      loop x fx gx (Base 2.0 *. eta) (Base 0.0)
    else
      let x' = vminus x (ktimesv eta gx) in
      if distance x x' <= tolerance then x
      else
        let fx' = f x' in
        if fx' < fx then loop x' fx' (gradient x') eta (i +. Base 1.0)
        else loop x fx gx (eta /. Base 2.0) (Base 0.0)
  in
  loop x (f x) (gradient x) initial_eta (Base 0.0)

(* Approximate maximizer of [f]: minimize its negation with forward-mode
   gradients.  Neither definition is recursive, so the spurious rec flags
   (unused-rec warning) are gone, matching multivariate_max_R. *)
let multivariate_argmax_F f x =
  multivariate_argmin_F (fun y -> Base 0.0 -. f y) x

(* Value of [f] at its approximate maximizer. *)
let multivariate_max_F f x = f (multivariate_argmax_F f x)

(* Approximate minimizer of [f] by gradient descent with an adaptive step
   size, using reverse-mode gradients.  Stops when the gradient norm or the
   proposed step length is at most 1e-5.  The step size is halved after a
   rejected step and doubled after ten consecutive accepted steps. *)
let multivariate_argmin_R f x =
  let gradient = gradient_R f in
  let tolerance = Base 1e-5 in
  let initial_eta = Base 1e-5 in
  let rec loop x fx gx eta i =
    if magnitude gx <= tolerance then x
    else if i <= Base 10.0 && Base 10.0 <= i then
      loop x fx gx (Base 2.0 *. eta) (Base 0.0)
    else
      let x' = vminus x (ktimesv eta gx) in
      if distance x x' <= tolerance then x
      else
        let fx' = f x' in
        if fx' < fx then loop x' fx' (gradient x') eta (i +. Base 1.0)
        else loop x fx gx (eta /. Base 2.0) (Base 0.0)
  in
  loop x (f x) (gradient x) initial_eta (Base 0.0)

(* Approximate maximizer of [f]: minimize its negation with reverse-mode
   gradients.  Not recursive, so it carries no rec flag. *)
let multivariate_argmax_R f x =
  multivariate_argmin_R (fun y -> Base 0.0 -. f y) x

(* Value of [f] at its approximate maximizer. *)
let multivariate_max_R f x = f (multivariate_argmax_R f x)

(* Generated by GNU enscript 1.6.4. *)