Overloading Haskell numbers, part 2, Forward Automatic Differentiation.

I will continue my overloading with some examples that have been nicely illustrated in an article by Jerzy Karczmarczuk, and blogged about by sigfpe. But I'll at least end this entry with a small twist I've not seen before.

When computing the derivative of a function you normally do it by either symbolic differentiation or by numerical approximation. Say that you have a function

    f(x) = x² + 1

and you want to know the derivative at x=5. Doing it symbolically you first get

    f'(x) = 2x

using high school calculus (maybe they don't teach it in high school anymore?), and then you plug in 5:

    f'(5) = 2*5 = 10

Computing it by numeric differentiation you compute

    f'(x) ≈ (f(x+h) - f(x)) / h

for some small h. Let's pick h=1e-5, and we get f'(5) = 10.000009999444615. Close, but not that good.

So why don't we always use the symbolic method? Well, some functions are not that easy to differentiate. Take this one:

```haskell
g x = if abs (x - 0.7) < 0.4 then x else g (cos x)
```

What's the derivative? Well, it's tricky, because this is not really a proper definition of g. It's an equation that, if solved, will yield a definition of g. And like equations in general, it could have zero, one, or many solutions. (If we happen to use CPOs there is always a unique smallest solution, which is what programs compute, as if by magic.) If you think g is contrived, let's pick a different example: computing the square root with Newton-Raphson.
```haskell
sqr x = convAbs $ iterate improve 1
  where improve r = (r + x/r) / 2
        convAbs (x1:x2:_) | abs (x1-x2) < 1e-10 = x2
        convAbs (_:xs) = convAbs xs
```

So symbolic differentiation is not so easy here, and numeric differentiation is not very accurate. But there is a third way! Automatic differentiation. The idea behind AD is that instead of computing with just numbers, we compute with pairs of numbers. The first component is the normal number, and the second component is the derivative. What are the rules for these numbers? Let's look at addition:

    (x, x') + (y, y') = (x+y, x'+y')

To add two numbers you just add the regular parts and the derivatives. For multiplication you have to remember how to compute the derivative of a product:

    (f(x)*g(x))' = f(x)*g'(x) + f'(x)*g(x)

So for our pairs we get

    (x, x') * (y, y') = (x*y, x*y' + x'*y)

i.e., first the regular product, then the derivative according to the recipe above.

Let's see how it works on f(x) = x² + 1. We want the derivative at x=5. So what is the pair we use for x? It is (5, 1). Why? Well, it has to be 5 for the regular part, and since this represents x and the derivative of x is 1, the pair is (5, 1). In the right hand side of f we need to replace 1 by (1, 0), since the derivative of a constant is 0. So then we get

    f (5,1) = (5,1)*(5,1) + (1,0) = (26,10)

using the rules above. And look! There is the normal result, 26, as well as the derivative, 10. Let's turn this into Haskell, using the type PD to hold a pair of Doubles:
```haskell
data PD = P Double Double deriving (Eq, Ord, Show)

instance Num PD where
  P x x' + P y y' = P (x+y) (x'+y')
  P x x' - P y y' = P (x-y) (x'-y')
  P x x' * P y y' = P (x*y) (x*y' + x'*y)
  fromInteger i   = P (fromInteger i) 0
```

A first observation is that there is nothing Double specific in this definition; it would work for any Num. So we can change it to
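(Before generalizing, a quick mechanical check of my own, not from the original, that the Double version reproduces the worked example: f(x) = x² + 1 at P 5 1 should give P 26.0 10.0. The `negate` case is added so the instance is complete.)

```haskell
data PD = P Double Double deriving (Eq, Ord, Show)

instance Num PD where
  P x x' + P y y' = P (x+y) (x'+y')
  P x x' - P y y' = P (x-y) (x'-y')
  P x x' * P y y' = P (x*y) (x*y' + x'*y)
  negate (P x x') = P (negate x) (negate x')
  abs (P x x')    = P (abs x) (signum x * x')
  signum (P x _)  = P (signum x) 0
  fromInteger i   = P (fromInteger i) 0

-- The running example: the literal 1 becomes P 1 0 via fromInteger.
f :: PD -> PD
f x = x*x + 1

main :: IO ()
main = print (f (P 5 1))  -- P 26.0 10.0
```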
```haskell
data PD a = P a a deriving (Eq, Ord, Show)

instance Num a => Num (PD a) where
  ...
```

Let's also add abs and signum, and the Fractional instance:
```haskell
  ...
  abs (P x x')    = P (abs x) (signum x * x')
  signum (P x x') = P (signum x) 0

instance Fractional a => Fractional (PD a) where
  P x x' / P y y' = P (x / y) ((x'*y - x*y') / (y * y))
  fromRational r  = P (fromRational r) 0
```

We can now try the sqr example:
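(As a side check of my own before running sqr: with the division rule above, the derivative of 1/x should come out as -1/x², so at x=2 we expect P 0.5 (-0.25). The function `h` is mine, not from the original.)

```haskell
data PD a = P a a deriving (Eq, Ord, Show)

instance Num a => Num (PD a) where
  P x x' + P y y' = P (x+y) (x'+y')
  P x x' - P y y' = P (x-y) (x'-y')
  P x x' * P y y' = P (x*y) (x*y' + x'*y)
  negate (P x x') = P (negate x) (negate x')
  abs (P x x')    = P (abs x) (signum x * x')
  signum (P x _)  = P (signum x) 0
  fromInteger i   = P (fromInteger i) 0

instance Fractional a => Fractional (PD a) where
  P x x' / P y y' = P (x / y) ((x'*y - x*y') / (y * y))
  fromRational r  = P (fromRational r) 0

-- Derivative of 1/x at 2: the literal 1 becomes P 1 0, then the
-- quotient rule in (/) produces the derivative component.
h :: PD Double -> PD Double
h x = 1 / x

main :: IO ()
main = print (h (P 2 1))  -- P 0.5 (-0.25)
```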
```
Main> sqr (P 9 1)
P 3.0 0.16666666666666666
```

The derivative of x**0.5 is 0.5*x**(-0.5), i.e., 0.5*9**(-0.5) = 0.5/3 = 0.16666666666666666. So we got the right answer.

BTW, if you want to be picky, the derivative of signum is not 0. The signum function makes a jump from -1 to 1 at 0, so the "proper" value would be 2*dirac, if dirac is a Dirac pulse. But since we don't have numbers with Dirac pulses (yet), I'll just pretend the derivative is 0 everywhere.

The very clever insight that Jerzy had was that when doing these numbers in Haskell there is no need to limit yourself to just the first derivative. Since Haskell is lazy we can easily keep an infinite list of all derivatives instead of just the first one. Let's look at how that definition looks. It's very similar to what we just did, but instead of the derivative being just a number, it's now one of our new numbers, with a value and all derivatives... Since we are now dealing with an infinite data structure we need to define our own show, (==), etc.
```haskell
data Dif a = D a (Dif a)

val (D x _)  = x
df  (D _ x') = x'

dVar x = D x 1

instance (Show a) => Show (Dif a) where
  show x = show (val x)

instance (Eq a) => Eq (Dif a) where
  x == y = val x == val y

instance (Ord a) => Ord (Dif a) where
  x `compare` y = val x `compare` val y

instance (Num a) => Num (Dif a) where
  D x x' + D y y'         = D (x + y) (x' + y')
  D x x' - D y y'         = D (x - y) (x' - y')
  p@(D x x') * q@(D y y') = D (x * y) (x' * q + p * y')
  fromInteger i           = D (fromInteger i) 0
  abs p@(D x x')          = D (abs x) (signum p * x')
  signum (D x _)          = D (signum x) 0

instance (Fractional a) => Fractional (Dif a) where
  recip (D x x') = ip where ip = D (recip x) (-x' * ip * ip)
  fromRational r = D (fromRational r) 0
```

This looks simple, but it's rather subtle. For instance, take the 0 in the definition of fromInteger. It's actually of Dif type, so it's a recursive call to fromInteger.

So let's try our sqr function again, this time computing up to the third derivative. The dVar is used to create a value for the "variable" with respect to which we differentiate.
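(Before the sqr run, here is a simpler check of my own that the whole tower works: pushing x³ through Dif at x = 2 should give 8 with derivatives 12, 12, and 6. The `cube` function is mine.)

```haskell
data Dif a = D a (Dif a)

val (D x _)  = x
df  (D _ x') = x'
dVar x       = D x 1

-- Show only the value; the tail of derivatives is infinite.
instance (Show a) => Show (Dif a) where
  show = show . val

instance (Num a) => Num (Dif a) where
  D x x' + D y y'         = D (x + y) (x' + y')
  D x x' - D y y'         = D (x - y) (x' - y')
  p@(D x x') * q@(D y y') = D (x * y) (x' * q + p * y')
  fromInteger i           = D (fromInteger i) 0
  abs p@(D x x')          = D (abs x) (signum p * x')
  signum (D x _)          = D (signum x) 0

cube :: Num a => a -> a
cube x = x*x*x

main :: IO ()
main = mapM_ print [d, df d, df (df d), df (df (df d))]
  where d = cube (dVar (2 :: Integer))
```

Working over Integer keeps the results exact, so the printed values are 8, 12, 12, and 6.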
```
Main> sqr $ dVar 9
3.0
Main> df $ sqr $ dVar 9
0.16666666666666669
Main> df $ df $ sqr $ dVar 9
-9.259259259259259e-3
Main> df $ df $ df $ sqr $ dVar 9
1.5432098765432098e-3
```

And the transcendentals in a similar way:
```haskell
lift (f : f') p@(D x x') = D (f x) (x' * lift f' p)

instance (Floating a) => Floating (Dif a) where
  pi              = D pi 0
  exp (D x x')    = r where r = D (exp x) (x' * r)
  log p@(D x x')  = D (log x) (x' / p)
  sqrt (D x x')   = r where r = D (sqrt x) (x' / (2 * r))
  sin             = lift (cycle [sin, cos, negate . sin, negate . cos])
  cos             = lift (cycle [cos, negate . sin, negate . cos, sin])
  acos p@(D x x') = D (acos x) (-x' / sqrt (1 - p*p))
  asin p@(D x x') = D (asin x) ( x' / sqrt (1 - p*p))
  atan p@(D x x') = D (atan x) ( x' / (1 + p*p))
  sinh x          = (exp x - exp (-x)) / 2
  cosh x          = (exp x + exp (-x)) / 2
  asinh x         = log (x + sqrt (x*x + 1))
  acosh x         = log (x + sqrt (x*x - 1))
  atanh x         = (log (1 + x) - log (1 - x)) / 2
```

And why not try the function g we defined above?
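(First, a standalone check of my own of the cycle trick: `lift` takes the infinite list of successive derivatives of a function and applies it to a Dif value. With `dSin` as my stand-in name for the sin method above, the derivatives of sin at 0 should cycle through sin 0, cos 0, -sin 0, -cos 0, i.e. 0, 1, 0, -1.)

```haskell
data Dif a = D a (Dif a)

val (D x _)  = x
df  (D _ x') = x'
dVar x       = D x 1

instance (Num a) => Num (Dif a) where
  D x x' + D y y'         = D (x + y) (x' + y')
  D x x' - D y y'         = D (x - y) (x' - y')
  p@(D x x') * q@(D y y') = D (x * y) (x' * q + p * y')
  fromInteger i           = D (fromInteger i) 0
  abs p@(D x x')          = D (abs x) (signum p * x')
  signum (D x _)          = D (signum x) 0

-- Apply a function given the infinite list of its successive derivatives:
-- the chain rule multiplies by x', and the tail of the list supplies f'.
lift (f : f') p@(D x x') = D (f x) (x' * lift f' p)

dSin :: Floating a => Dif a -> Dif a
dSin = lift (cycle [sin, cos, negate . sin, negate . cos])

main :: IO ()
main = mapM_ print (take 4 (map val (iterate df (dSin (dVar (0 :: Double))))))
```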
```
Main> g 10
0.6681539175313869
Main> g (dVar 10)
0.6681539175313869
Main> df $ g (dVar 10)
0.4047642621121782
Main> df $ df $ g (dVar 10)
0.4265424381635987
Main> df $ df $ df $ g (dVar 10)
-1.4395397945007182
```

It all works very nicely. So now that we can compute the derivative of a function, let's define something somewhat more interesting with it. Let's revisit the sqr function again. It uses Newton-Raphson to find the square root. How does Newton-Raphson actually work? Given a differentiable function f(x), it finds a zero by starting with some x₁ and then iterating

    xₙ₊₁ = xₙ - f(xₙ)/f'(xₙ)

until we meet some convergence criterion. Using this, let's define a function that finds a zero of another function:
```haskell
findZero f = convRel $ cut $ iterate step start
  where step x = x - val fx / val (df fx) where fx = f (dVar x)
        start = 1 -- just some value
        epsilon = 1e-10
        cut = (++ error "No convergence in 1000 steps") . take 1000
        convRel (x1:x2:_) | x1 == x2 || abs (x1+x2) / abs (x1-x2) > 1/epsilon = x2
        convRel (_:xs) = convRel xs
```

The only interesting part is the step function, which does one Newton-Raphson iteration. It computes f x and then divides the normal value by the derivative. We produce the infinite list of approximations using step, cut it off at some point (in case it doesn't converge), and then look down the list for two values that are within some relative epsilon. And it even seems to work.
```
Main> findZero (\x -> x*x - 9)
3.0
Main> findZero (\x -> sin x - 0.5)
0.5235987755982989
Main> sin it
0.5
Main> findZero (\x -> x*x + 9)
*** Exception: No convergence in 1000 steps
Main> findZero (\x -> sqr x - 3)
9.0
```

Note how it finds a zero of the sqr function, which actually uses recursion internally to compute the square root.

So now we can compute numerical derivatives. But wait! We also have symbolic numbers. Can we combine them? Of course, that is the power of polymorphism. Let's load up both modules:
```
Data.Number.Symbolic Dif3> let x :: Num a => Dif (Sym a); x = dVar (var "x")
Data.Number.Symbolic Dif3> df $ x*x
x+x
Data.Number.Symbolic Dif3> df $ sin x
cos x
Data.Number.Symbolic Dif3> df $ sin (exp (x - 4) * x)
(exp (-4.0+x)*x+exp (-4.0+x))*cos (exp (-4.0+x)*x)
Data.Number.Symbolic Dif3> df $ df $ sin (exp (x - 4) * x)
(exp (-4.0+x)*x+exp (-4.0+x)+exp (-4.0+x))*cos (exp (-4.0+x)*x)+(exp (-4.0+x)*x+exp (-4.0+x))*(exp (-4.0+x)*x+exp (-4.0+x))*(-sin (exp (-4.0+x)*x))
```

We define x to be a differentiable number, "the variable", over symbolic numbers, over some numbers. And then we just happily use df to get the differentiated versions. So we set out to compute numeric derivatives, and we got symbolic derivatives for free. Not too bad.

One final note: the Dif type defined above can be made more efficient by not keeping all the infinite tails of 0 derivatives around. In a real module for this, you'd want to make that optimization.
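Here is one way that optimization could look, as a sketch of my own (not the post's code): add a constructor C for constants, whose derivatives are all zero, so fromInteger no longer builds an infinite tower of zeroes, and products with constants collapse eagerly.

```haskell
data Dif a = C a           -- a constant: all derivatives are 0
           | D a (Dif a)   -- a value and its tower of derivatives

val (C x)   = x
val (D x _) = x

df :: Num a => Dif a -> Dif a
df (C _)    = C 0
df (D _ x') = x'

dVar x = D x (C 1)

instance Num a => Num (Dif a) where
  C x    + C y    = C (x + y)
  C x    + D y y' = D (x + y) y'
  D x x' + C y    = D (x + y) x'
  D x x' + D y y' = D (x + y) (x' + y')
  C x    * C y    = C (x * y)
  C x    * D y y' = D (x * y) (C x * y')
  D x x' * C y    = D (x * y) (x' * C y)
  p@(D x x') * q@(D y y') = D (x * y) (x' * q + p * y')
  negate (C x)    = C (negate x)
  negate (D x x') = D (negate x) (negate x')
  abs (C x)       = C (abs x)
  abs p@(D x x')  = D (abs x) (signum p * x')
  signum p        = C (signum (val p))  -- pretend the derivative is 0, as before
  fromInteger i   = C (fromInteger i)   -- a single node, no infinite tail

f :: Num a => a -> a
f x = x*x + 1

main :: IO ()
main = mapM_ print (map val (take 4 (iterate df (f (dVar (5 :: Integer))))))
```

With this representation the literal 1 in f stays a single C node, and the tower for f (dVar 5) bottoms out in constants after the second derivative: the printed values are 26, 10, 2, and 0.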