common-ghc.hs

```haskell
module Common where

-- | A forward-mode dual number: @B x x'@ pairs a primal value @x@ with its
-- tangent (derivative) @x'@.
--
-- The declaration used to carry a datatype context (@Num a =>@). Haskell
-- 2010 removed datatype contexts, and GHC rejects them unless the deprecated
-- @DatatypeContexts@ extension is enabled. The context added nothing here:
-- every function and instance that needs arithmetic on @a@ already demands
-- @Num a@ (or a subclass of it) itself.
data Bundle a = B a a

-- | Renders a bundle as @(B value tangent)@, always parenthesised.
instance (Num a, Show a) => Show (Bundle a) where
  show (B v dv) = "(B " ++ unwords [show v, show dv] ++ ")"

-- | Embed a plain value as a constant bundle, i.e. with a zero tangent.
lift :: Num a => a -> Bundle a
lift = flip B 0

-- | Dual-number arithmetic. The first component is the value and the second
-- is the derivative. The derivative is propagated by the sum, difference and
-- product rules. @abs@ scales the derivative by the sign of the value, and
-- @signum@ (piecewise constant) gets a zero derivative.
instance (Num a) => Num (Bundle a) where
  fromInteger n = lift (fromInteger n)
  B u du + B v dv = B (u + v) (du + dv)
  B u du - B v dv = B (u - v) (du - dv)
  B u du * B v dv = B (u * v) (u * dv + du * v)
  negate (B u du) = B (negate u) (negate du)
  abs (B u du) = B (sgn * u) (sgn * du)
    where
      sgn = signum u
  signum (B u _) = lift (signum u)

-- | Reciprocal via d(1/x) = -x' / x^2.
-- Division uses the class default @x / y = x * recip y@.
instance Fractional a => Fractional (Bundle a) where
  fromRational q = lift (fromRational q)
  recip (B v dv) = B inv (negate (dv * inv * inv))
    where
      inv = recip v

-- | Equality looks only at the primal values; tangents are ignored.
instance (Num a, Eq a) => Eq (Bundle a) where
  (==) (B u _) (B v _) = u == v

-- | Ordering looks only at the primal values; tangents are ignored.
instance (Num a, Ord a) => Ord (Bundle a) where
  compare (B u _) (B v _) = compare u v

-- | Chain rule for the elementary functions:
-- @f (B x x') = B (f x) (x' * f'(x))@.
--
-- Previously the inverse-trigonometric and all hyperbolic members put
-- @error "unimplemented"@ in the tangent. Any derivative taken through them
-- (and through the class default @tanh = sinh / cosh@) therefore crashed as
-- soon as the tangent was forced. Each member now uses its standard
-- closed-form derivative. These are finite only on the open domain of the
-- function: |x| < 1 for asin/acos/atanh, and x > 1 for acosh.
instance Floating a => Floating (Bundle a) where
  pi = lift pi
  exp (B x x') = B (exp x) (x' * exp x)
  log (B x x') = B (log x) (x' / x)
  sqrt (B x x') = let y = sqrt x in B y (x' / (2 * y))
  sin (B x x') = B (sin x) (x' * cos x)
  cos (B x x') = B (cos x) (x' * (- sin x))
  -- d/dx asin x = 1 / sqrt (1 - x^2)
  asin (B x x') = B (asin x) (x' / sqrt (1 - x * x))
  -- d/dx acos x = -1 / sqrt (1 - x^2)
  acos (B x x') = B (acos x) (- x' / sqrt (1 - x * x))
  -- d/dx atan x = 1 / (1 + x^2)
  atan (B x x') = B (atan x) (x' / (1 + x * x))
  sinh (B x x') = B (sinh x) (x' * cosh x)
  cosh (B x x') = B (cosh x) (x' * sinh x)
  -- d/dx asinh x = 1 / sqrt (x^2 + 1)
  asinh (B x x') = B (asinh x) (x' / sqrt (x * x + 1))
  -- d/dx acosh x = 1 / sqrt (x^2 - 1)
  acosh (B x x') = B (acosh x) (x' / sqrt (x * x - 1))
  -- d/dx atanh x = 1 / (1 - x^2)
  atanh (B x x') = B (atanh x) (x' / (1 - x * x))

-- | @toEnum@ produces a constant bundle and @fromEnum@ reads the primal.
-- @succ@/@pred@ add or subtract 1 using bundle arithmetic, so the tangent is
-- carried through unchanged. The range methods (@enumFromTo@ etc.) are not
-- defined here, so the class defaults apply.
instance (Num a, Enum a) => Enum (Bundle a) where
  toEnum = lift . toEnum
  fromEnum (B n _) = fromEnum n
  succ b = b + 1
  pred b = b - 1

-- | Converts the primal value; the tangent is discarded.
instance (Num a, Ord a, Real a) => Real (Bundle a) where
  toRational b = case b of
    B v _ -> toRational v

-- | Derivative of a unary function at @x@. The function is evaluated on the
-- seeded bundle @B x 1@ and the tangent of the result is returned.
derivative :: Num a => (Bundle a -> Bundle a) -> a -> a
derivative f x = case f (B x 1) of
  B _ dy -> dy

-- | Square of a number (the value multiplied by itself).
sqr :: Num a => a -> a
sqr v = v * v

-- | Component-wise vector sum. The result is truncated to the shorter input.
vplus :: Num a => [a] -> [a] -> [a]
vplus u v = [a + b | (a, b) <- zip u v]

-- | Component-wise vector difference. The result is truncated to the shorter
-- input.
vminus :: Num a => [a] -> [a] -> [a]
vminus u v = [a - b | (a, b) <- zip u v]

-- | Scale every component of a vector by @k@ (multiplying on the left).
ktimesv :: Num a => a -> [a] -> [a]
ktimesv k v = [k * c | c <- v]

-- | Squared Euclidean norm: the sum of the squares of the components
-- (0 for the empty vector).
--
-- Uses 'sum', a left fold from 0 like the previous @foldl (+) 0@. In current
-- @base@, 'sum' accumulates strictly. The lazy 'foldl' instead built a chain
-- of unevaluated additions as long as the vector.
magnitude_squared :: Num a => [a] -> a
magnitude_squared x = sum (map sqr x)

-- | Euclidean norm: the square root of 'magnitude_squared'.
magnitude :: Floating a => [a] -> a
magnitude v = sqrt (magnitude_squared v)

-- | Squared Euclidean distance between two vectors.
distance_squared :: Num a => [a] -> [a] -> a
distance_squared u v = magnitude_squared (u `vminus` v)

-- | Euclidean distance between two vectors.
distance :: Floating a => [a] -> [a] -> a
distance u = sqrt . distance_squared u

-- | @replace_ith xs i y@ returns @xs@ with the element at 0-based index @i@
-- replaced by @y@.
--
-- This used to be written with an n+k pattern (@(i + 1)@). Haskell 2010
-- removed n+k patterns, and GHC only accepts them under @NPlusKPatterns@.
-- Guards give the same result for every valid index. An out-of-range index
-- (negative, or past the end of the list) previously failed with an
-- anonymous pattern-match error; it now fails with a descriptive one.
replace_ith :: Integral i => [a] -> i -> a -> [a]
replace_ith (x : xs) n xi
  | n == 0 = xi : xs
  | n > 0 = x : replace_ith xs (n - 1) xi
replace_ith _ _ _ = error "replace_ith: index out of range"

-- | Gradient of a scalar function of a vector, computed in forward mode with
-- one sweep per coordinate. For coordinate @i@, every input is lifted to a
-- constant bundle except position @i@, which receives the seeded bundle from
-- 'derivative'. The tangent of @f@'s result is then the i-th partial
-- derivative.
--
-- The lifted input vector is built once and shared by all sweeps. Indices
-- are paired with coordinates via 'zipWith', which replaces the partial
-- @x !! i@ lookups over @[0 .. length x - 1]@.
gradient :: Num a => ([Bundle a] -> Bundle a) -> [a] -> [a]
gradient f x = zipWith partial [(0 :: Int) ..] x
  where
    lifted = map lift x
    partial i xi = derivative (\ b -> f (replace_ith lifted i b)) xi

-- | Evaluate a function written over bundles on plain values. Each input is
-- lifted to a constant bundle and only the primal value of the result is
-- returned.
lower_fs :: Num a => ([Bundle a] -> Bundle a) -> [a] -> a
lower_fs f xs = case f (map lift xs) of
  B y _ -> y

-- | Minimise @f@ by gradient descent with an adaptive step size @eta@,
-- starting from @x@ with @eta = 1e-5@.
--
-- Each iteration considers the step @x' = x - eta * grad f x@:
--
--   * it returns the current point when the gradient norm, or the length of
--     the proposed step, is at most @1e-5@;
--   * a step that lowers @f@ is accepted and extends a streak counter; once
--     10 consecutive steps have been accepted, @eta@ is doubled and the
--     counter reset;
--   * a step that does not lower @f@ is rejected, @eta@ is halved and the
--     counter reset.
multivariate_argmin :: (Floating a, Ord a) => ([Bundle a] -> Bundle a) -> [a] -> [a]
multivariate_argmin f x = go x (ff x) (g x) 1e-5 (0 :: Int)
  where
    g = gradient f
    ff = lower_fs f
    -- cur: current point, fcur/gcur: its value and gradient,
    -- eta: step size, streak: consecutive accepted steps.
    go cur fcur gcur eta streak
      | magnitude gcur <= 1e-5 = cur
      | streak == 10 = go cur fcur gcur (2 * eta) 0
      | distance cur next <= 1e-5 = cur
      | fnext < fcur = go next fnext (g next) eta (streak + 1)
      | otherwise = go cur fcur gcur (eta / 2) 0
      where
        -- Lazily bound: only evaluated when a guard reaches them.
        next = vminus cur (ktimesv eta gcur)
        fnext = ff next

-- | Maximise @f@ by minimising its negation with 'multivariate_argmin'.
multivariate_argmax :: (Floating a, Ord a) =>
                       ([Bundle a] -> Bundle a) -> [a] -> [a]
multivariate_argmax f = multivariate_argmin (negate . f)

-- | Value of @f@ at the point found by 'multivariate_argmax' from @x@.
multivariate_max :: (Floating a, Ord a) => ([Bundle a] -> Bundle a) -> [a] -> a
multivariate_max f x = lower_fs f $ multivariate_argmax f x
```

Generated by GNU enscript 1.6.4.