type ad_number = Dual_number of ad_number*ad_number*ad_number | Tape of ad_number* ad_number* (ad_number list)* (ad_number list)* (ad_number ref)* (ad_number ref) | Base of float let epsilon = ref (Base 0.0) let dual_number e x x' ( <= ) = if x'<=(Base 0.0) && (Base 0.0)<=x' then x else Dual_number (e, x, x') let tape e x factors tapes = Tape (e, x, factors, tapes, ref (Base 0.0), ref (Base 0.0)) let lift_real_to_real f dfdx ( * ) ( <= ) = let rec self p = match p with (Dual_number (e, x, x')) -> dual_number e (self x) ((dfdx x)*x') ( <= ) | (Tape (e, x, _, _, _, _)) -> tape e (self x) [dfdx x] [p] | Base x -> Base (f x) in self let lift_real_cross_real_to_real f dfdx1 dfdx2 ( +. ) ( *. ) ( < ) ( <= ) = let rec self p1 p2 = match p1 with (Dual_number (e1, x1, x1')) -> (match p2 with (Dual_number (e2, x2, x2')) -> if e1 if e1 dual_number e1 (self x1 p2) ((dfdx1 x1 p2)*.x1') ( <= )) | (Tape (e1, x1, _, _, _, _)) -> (match p2 with (Dual_number (e2, x2, x2')) -> if e1 if e1 tape e1 (self x1 p2) [dfdx1 x1 p2] [p1]) | (Base x1) -> (match p2 with (Dual_number (e2, x2, x2')) -> dual_number e2 (self p1 x2) ((dfdx2 p1 x2)*.x2') ( <= ) | (Tape (e2, x2, _, _, _, _)) -> tape e2 (self p1 x2) [dfdx2 p1 x2] [p2] | (Base x2) -> Base (f x1 x2)) in self let lift_real_cross_real_to_bool f = let rec self p1 p2 = match p1 with (Dual_number (_, x1, _)) -> (match p2 with (Dual_number (_, x2, _)) -> self x1 x2 | (Tape (_, x2, _, _, _, _)) -> self x1 x2 | (Base _) -> self x1 p2) | (Tape (_, x1, _, _, _, _)) -> (match p2 with (Dual_number (_, x2, _)) -> self x1 x2 | (Tape (_, x2, _, _, _, _)) -> self x1 x2 | (Base _) -> self x1 p2) | (Base x1) -> (match p2 with (Dual_number (_, x2, _)) -> self p1 x2 | (Tape (_, x2, _, _, _, _)) -> self p1 x2 | (Base x2) -> f x1 x2) in self let rec write_real p = match p with (Dual_number (_, x, _)) -> ((write_real x); p) | (Tape (_, x, _, _, _, _)) -> ((write_real x); p) | (Base x) -> ((Printf.printf "%.18g\n" x); p) let (( +. ), ( -. ), ( *. ), ( /. ), sqrt, exp, ( < ), ( <= )) = let (plus, minus, times, divide, original_sqrt, original_exp, lt, ge) = (( +. ), ( -. ), ( *. ), ( /. ), sqrt, exp, ( < ), ( <= )) in let rec ( +. ) x1 x2 = (lift_real_cross_real_to_real plus (fun x1 x2 -> Base 1.0) (fun x1 x2 -> Base 1.0) ( +. ) ( *. ) ( < ) ( <= ) x1 x2) and ( -. ) x1 x2 = (lift_real_cross_real_to_real minus (fun x1 x2 -> Base 1.0) (fun x1 x2 -> Base (-1.0)) ( +. ) ( *. ) ( < ) ( <= ) x1 x2) and ( *. ) x1 x2 = (lift_real_cross_real_to_real times (fun x1 x2 -> x2) (fun x1 x2 -> x1) ( +. ) ( *. ) ( < ) ( <= ) x1 x2) and ( /. ) x1 x2 = (lift_real_cross_real_to_real divide (fun x1 x2 -> (Base 1.0)/.x2) (fun x1 x2 -> (Base 0.0)-.x1/.(x2*.x2)) ( +. ) ( *. ) ( < ) ( <= ) x1 x2) and sqrt x = (lift_real_to_real original_sqrt (fun x -> (Base 1.0)/.((sqrt x)+.(sqrt x))) ( *. ) ( <= ) x) and exp x = (lift_real_to_real original_exp exp ( *. ) ( <= ) x) and ( < ) x1 x2 = lift_real_cross_real_to_bool lt x1 x2 and ( <= ) x1 x2 = lift_real_cross_real_to_bool ge x1 x2 in (( +. ), ( -. ), ( *. ), ( /. ), sqrt, exp, ( < ), ( <= )) let derivative_F f x = (epsilon := !epsilon +. (Base 1.0); let y' = match (f (dual_number (!epsilon) x (Base 1.0) ( <= ) )) with Dual_number (e1, _, y') -> if e1<(!epsilon) then Base 0.0 else y' | (Tape _) -> Base 0.0 | (Base _) -> Base 0.0 in epsilon := !epsilon -. (Base 1.0); y') open List let sqr x = x*.x let map_n f n = let rec loop i = if i=n then [] else (f i)::(loop (i+1)) in loop 0 let vplus u v = map2 ( +. ) u v let vminus u v = map2 ( -. ) u v let ktimesv k = map (fun x -> k*.x) let magnitude_squared x = fold_left ( +. ) (Base 0.0) (map sqr x) let magnitude x = sqrt (magnitude_squared x) let distance_squared u v = magnitude_squared (vminus v u) let distance u v = sqrt (distance_squared u v) let rec replace_ith (xh::xt) i xi = if i<=(Base 0.0) && (Base 0.0)<=i then xi::xt else xh::(replace_ith xt (i-.(Base 1.0)) xi) let gradient_F f x = map_n (fun i -> derivative_F (fun xi -> f (replace_ith x (Base (float i)) xi)) (nth x i)) (length x) let rec determine_fanout (Tape (_, _, _, tapes, fanout, _)) = (fanout := !fanout+.(Base 1.0); if !fanout<=(Base 1.0) && (Base 1.0)<=(!fanout) (* for-each *) then (map determine_fanout tapes; ()) else ()) let rec reverse_phase sensitivity1 (Tape (_, _, factors, tapes, fanout, sensitivity)) = (sensitivity := !sensitivity+.sensitivity1; fanout := !fanout-.(Base 1.0); if !fanout<=(Base 0.0) && (Base 0.0)<=(!fanout) (* for-each *) then ((map2 (fun factor tape -> reverse_phase (!sensitivity*.factor) tape) factors tapes); ()) else ()) let gradient_R f x = (epsilon := !epsilon+.(Base 1.0); let x = map (fun xi -> (tape (!epsilon) xi [] [])) x in let y = f x in (match f x with (Dual_number _) -> () | Tape (e1, _, _, _, _, _) -> if e1<(!epsilon) then () else (determine_fanout y; reverse_phase (Base 1.0) y) | Base _ -> ()); epsilon := !epsilon-.(Base 1.0); map (fun (Tape (_, _, _, _, _, sensitivity)) -> !sensitivity) x) let rec gradient_ascent_F f x0 n eta = if n<=(Base 0.0) && (Base 0.0)<=n then (x0, (f x0), (gradient_F f x0)) else gradient_ascent_F f (vplus x0 (ktimesv eta (gradient_F f x0))) (n-.(Base 1.0)) eta let rec gradient_ascent_R f x0 n eta = if n<=(Base 0.0) && (Base 0.0)<=n then (x0, (f x0), (gradient_R f x0)) else gradient_ascent_R f (vplus x0 (ktimesv eta (gradient_R f x0))) (n-.(Base 1.0)) eta let multivariate_argmin_F f x = let g = gradient_F f in let rec loop x fx gx eta i = if (magnitude gx)<=(Base 1e-5) then x else if i<=(Base 10.0) && (Base 10.0)<=i then loop x fx gx ((Base 2.0)*.eta) (Base 0.0) else let x' = vminus x (ktimesv eta gx) in if (distance x x')<=(Base 1e-5) then x else let fx' = (f x') in if fx' (Base 0.0)-.(f x)) x let rec multivariate_max_F f x = f (multivariate_argmax_F f x) let multivariate_argmin_R f x = let g = gradient_R f in let rec loop x fx gx eta i = if (magnitude gx)<=(Base 1e-5) then x else if i<=(Base 10.0) && (Base 10.0)<=i then loop x fx gx ((Base 2.0)*.eta) (Base 0.0) else let x' = vminus x (ktimesv eta gx) in if (distance x x')<=(Base 1e-5) then x else let fx' = (f x') in if fx' (Base 0.0)-.(f x)) x let multivariate_max_R f x = f (multivariate_argmax_R f x)