Providing Hand-Written Gradients

Writing hand-crafted gradients for operations can be useful if you are performing low-level optimizations or equipping your library for backprop.

Ideally, as an end user, you should never have to do this. The whole point of the backprop library is to allow you to use backpropagatable functions as normal functions, and to let you build complicated functions by simply composing normal Haskell functions, where the backprop library automatically infers your gradients.

However, if you are writing a library, you probably need to provide "primitive" backpropagatable functions (like matrix-vector multiplication for a linear algebra library) for your users, so your users can then use those primitive functions to write their own code, without ever having to be aware of any gradients.

If you are writing code and recognize bottlenecks from library overhead as described in this post, you might also want to provide manual gradients. However, this should always be a last resort: working out manual gradients is a tedious and error-prone process that can introduce subtle bugs that don't always show up in testing. It also makes your code much more fragile, harder to refactor and shuffle around (since you are no longer using normal function composition and application), and much harder to read. Only proceed if you decide that the huge cognitive costs are worth it.

The Lifted Function

A lifted function of type

myFunc :: Reifies s W => BVar s a -> BVar s b

is a backpropagatable function taking an a and returning a b: a function taking a BVar containing an a and returning a BVar containing a b. The BVar s (together with the Reifies s W constraint) is what allows backpropagation to be tracked.

Under the hood, a BVar s a -> BVar s b is essentially:

BVar s a -> BVar s b
    ≈  a -> (b, b -> a)

That is, given an input a, you get:

  1. A b, the result (the "forward pass")
  2. A b -> a, the "scaled gradient" function.

A full technical description is given in the documentation for Numeric.Backprop.Op.

The b result is simple enough; it's the result of your function. The "scaled gradient" function requires some elaboration. Let's say you are writing a lifted version of your function $y = f(x)$ (whose derivative is $\frac{dy}{dx}$), and that your final result at the end of your computation is $z = g(f(x))$ (whose derivative is $\frac{dz}{dx}$). In that case, because of the chain rule, $\frac{dz}{dx} = \frac{dz}{dy} \frac{dy}{dx}$.

The scaled gradient b -> a is the function which, given $\frac{dz}{dy}$ :: b, returns $\frac{dz}{dx}$ :: a (that is, it returns $\frac{dz}{dy} \frac{dy}{dx}$ :: a).

For example, for the mathematical operation $y = f(x) = x^2$, considering $z = g(f(x))$, we get $\frac{dz}{dx} = \frac{dz}{dy} \, 2x$. In fact, for all functions taking and returning scalars (just normal single numbers), $\frac{dz}{dx} = \frac{dz}{dy} f'(x)$.

Simple Example

With that in mind, let's write a lifted "square" operation, which takes $x$ and returns $x^2$:

square
    :: (Num a, Backprop a, Reifies s W)
    => BVar s a
    -> BVar s a
square = liftOp1 . op1 $ \x ->
    ( x^2              , \dzdy -> dzdy * 2 * x)
--    ^- actual result   ^- scaled gradient function
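
We can sanity-check this with the library's runners, evalBP and gradBP (from Numeric.Backprop): evalBP runs the forward pass, and gradBP computes the gradient. The expected values here are just $f(3) = 9$ and $f'(3) = 2 \times 3 = 6$:

ghci> evalBP square (3 :: Double)
9.0
ghci> gradBP square (3 :: Double)
6.0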

We can write one for sin, as well. For $y = f(x) = \sin(x)$, we consider $z = g(f(x))$ to see that $\frac{dz}{dx} = \frac{dz}{dy} \cos(x)$. So, we have:

liftedSin
    :: (Floating a, Backprop a, Reifies s W)
    => BVar s a
    -> BVar s a
liftedSin = liftOp1 . op1 $ \x ->
    ( sin x, \dzdy -> dzdy * cos x )

In general, for functions that take and return scalars:

-- given some function f :: a -> a, with derivative dfdx :: a -> a
liftedF
    :: (Num a, Backprop a, Reifies s W)
    => BVar s a
    -> BVar s a
liftedF = liftOp1 . op1 $ \x ->
    ( f x, \dzdy -> dzdy * dfdx x )
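
As a concrete instance of this template, plugging in $f = \exp$ (whose derivative is also $\exp$) gives a lifted exponential; liftedExp is just a name for this sketch:

liftedExp
    :: (Floating a, Backprop a, Reifies s W)
    => BVar s a
    -> BVar s a
liftedExp = liftOp1 . op1 $ \x ->
    ( exp x, \dzdy -> dzdy * exp x )  -- here, dfdx = exp as well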

For Op definitions covering every numeric function in base Haskell, see the source of Op.hs, which contains the Op for every method of Num, Fractional, and Floating.

Non-trivial example

A simple non-trivial example is sumElements, which we can define to take the hmatrix library's R n type (an n-vector of Double). Here we have to think about $g(\mathrm{sum}(\mathbf{x}))$, and the types guide our thinking:

sumElements           :: R n -> Double
sumElementsScaledGrad :: R n -> Double -> R n

The simplest way for me to do this is to take it element by element.

  1. Write out the functions in question, in a simple example

    In our case:

    • $y = f(\langle a, b, c \rangle) = a + b + c$
    • $z = g(y) = g(a + b + c)$
  2. Identify the components in your gradient

    In our case, we have to return a gradient $\langle \frac{\partial z}{\partial a}, \frac{\partial z}{\partial b}, \frac{\partial z}{\partial c} \rangle$.

  3. Work out each component of the gradient until you start to notice a pattern

    Let's start with $\frac{\partial z}{\partial a}$. We need to find $\frac{\partial z}{\partial a}$ in terms of $\frac{dz}{dy}$:

    • Through the chain rule, $\frac{\partial z}{\partial a} = \frac{dz}{dy} \frac{\partial y}{\partial a}$.
    • Because $y = a + b + c$, we know that $\frac{\partial y}{\partial a} = 1$.
    • Because $\frac{\partial y}{\partial a} = 1$, we know that $\frac{\partial z}{\partial a} = \frac{dz}{dy} \times 1 = \frac{dz}{dy}$.

    So, our expression of $\frac{\partial z}{\partial a}$ in terms of $\frac{dz}{dy}$ is simply $\frac{\partial z}{\partial a} = \frac{dz}{dy}$.

    Now, let's look at $\frac{\partial z}{\partial b}$. We need to find $\frac{\partial z}{\partial b}$ in terms of $\frac{dz}{dy}$.

    • Through the chain rule, $\frac{\partial z}{\partial b} = \frac{dz}{dy} \frac{\partial y}{\partial b}$.
    • Because $y = a + b + c$, we know that $\frac{\partial y}{\partial b} = 1$.
    • Because $\frac{\partial y}{\partial b} = 1$, we know that $\frac{\partial z}{\partial b} = \frac{dz}{dy} \times 1 = \frac{dz}{dy}$.

    It looks like $\frac{\partial z}{\partial b} = \frac{dz}{dy}$, as well.

    At this point, we start to notice a pattern. We can apply the same logic to see that $\frac{\partial z}{\partial c} = \frac{dz}{dy}$.

  4. Write out the pattern

    Extrapolating the pattern, $\frac{\partial z}{\partial q}$, where $q$ is any component, is always going to be the same constant: $\frac{dz}{dy}$.

So in the end:

liftedSumElements
    :: (KnownNat n, Reifies s W)
    => BVar s (R n)
    -> BVar s Double
liftedSumElements = liftOp1 . op1 $ \xs ->
    ( sumElements xs, \dzdy -> konst dzdy )  -- a constant vector
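
As a quick sanity check, the gradient of a sum should be a vector of all ones, no matter the input. A sketch, assuming the Backprop instance for R n (provided by hmatrix-backprop) is in scope, along with vector and konst from Numeric.LinearAlgebra.Static:

checkSum :: R 3
checkSum = gradBP liftedSumElements (vector [1, 2, 3])
-- checkSum should equal konst 1, the vector <1, 1, 1>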

Multiple-argument functions

Lifting multiple-argument functions is the same thing, except using liftOp2 and op2, or liftOpN and opN.

Under the hood, a BVar s a -> BVar s b -> BVar s c is essentially:

BVar s a -> BVar s b -> BVar s c
    ≈  a -> b -> (c, c -> (a, b))

That is, given an input a and b, you get:

  1. A c, the result (the "forward pass")
  2. A c -> (a, b), the "scaled gradient" function returning the gradient of both inputs.

The c parameter of the scaled gradient is again $\frac{dz}{dy}$, and the final (a, b) is a tuple of $\frac{\partial z}{\partial x_1}$ and $\frac{\partial z}{\partial x_2}$: the gradients of the final result with respect to each of the two inputs.

For a simple example, let's look at addition, $x_1 + x_2$. Working it out:

  • $y = f(x_1, x_2) = x_1 + x_2$
  • $z = g(f(x_1, x_2)) = g(x_1 + x_2)$
  • Looking first for $\frac{\partial z}{\partial x_1}$ in terms of $\frac{dz}{dy}$:
    • $\frac{\partial z}{\partial x_1} = \frac{dz}{dy} \frac{\partial y}{\partial x_1}$ (chain rule)
    • From $y = x_1 + x_2$, we see that $\frac{\partial y}{\partial x_1} = 1$
    • Therefore, $\frac{\partial z}{\partial x_1} = \frac{dz}{dy} \times 1 = \frac{dz}{dy}$.
  • Looking second for $\frac{\partial z}{\partial x_2}$ in terms of $\frac{dz}{dy}$:
    • $\frac{\partial z}{\partial x_2} = \frac{dz}{dy} \frac{\partial y}{\partial x_2}$ (chain rule)
    • From $y = x_1 + x_2$, we see that $\frac{\partial y}{\partial x_2} = 1$
    • Therefore, $\frac{\partial z}{\partial x_2} = \frac{dz}{dy} \times 1 = \frac{dz}{dy}$.
  • Therefore, $\frac{\partial z}{\partial x_1} = \frac{dz}{dy}$, and also $\frac{\partial z}{\partial x_2} = \frac{dz}{dy}$.

Putting it into code:

add :: (Num a, Backprop a, Reifies s W)
    => BVar s a
    -> BVar s a
    -> BVar s a
add = liftOp2 . op2 $ \x1 x2 ->
    ( x1 + x2, \dzdy -> (dzdy, dzdy) )
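
Running this through gradBP2 (the two-argument analogue of gradBP, from Numeric.Backprop) confirms the derivation: the gradient of each input is 1, regardless of the inputs themselves:

ghci> gradBP2 add (3 :: Double) (5 :: Double)
(1.0,1.0)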

Let's try our hand at multiplication, $x_1 x_2$:

  • $y = f(x_1, x_2) = x_1 x_2$
  • $z = g(f(x_1, x_2)) = g(x_1 x_2)$
  • Looking first for $\frac{\partial z}{\partial x_1}$ in terms of $\frac{dz}{dy}$:
    • $\frac{\partial z}{\partial x_1} = \frac{dz}{dy} \frac{\partial y}{\partial x_1}$ (chain rule)
    • From $y = x_1 x_2$, we see that $\frac{\partial y}{\partial x_1} = x_2$
    • Therefore, $\frac{\partial z}{\partial x_1} = \frac{dz}{dy} \, x_2$.
  • Looking second for $\frac{\partial z}{\partial x_2}$ in terms of $\frac{dz}{dy}$:
    • $\frac{\partial z}{\partial x_2} = \frac{dz}{dy} \frac{\partial y}{\partial x_2}$ (chain rule)
    • From $y = x_1 x_2$, we see that $\frac{\partial y}{\partial x_2} = x_1$
    • Therefore, $\frac{\partial z}{\partial x_2} = \frac{dz}{dy} \, x_1$.
  • Therefore, $\frac{\partial z}{\partial x_1} = \frac{dz}{dy} \, x_2$, and $\frac{\partial z}{\partial x_2} = \frac{dz}{dy} \, x_1$.

In code:

mul :: (Num a, Backprop a, Reifies s W)
    => BVar s a
    -> BVar s a
    -> BVar s a
mul = liftOp2 . op2 $ \x1 x2 ->
    ( x1 * x2, \dzdy -> (dzdy * x2, x1 * dzdy) )
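
Again, a quick check with gradBP2: at the inputs 3 and 5, the gradient with respect to each input is the other input:

ghci> gradBP2 mul (3 :: Double) (5 :: Double)
(5.0,3.0)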

For non-trivial examples involving linear algebra, see the source for the hmatrix-backprop library.
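
Working these out by hand follows the same element-by-element recipe we used for sumElements. For the dot product $y = \mathbf{u} \cdot \mathbf{v} = \sum_i u_i v_i$, each component satisfies $\frac{\partial y}{\partial u_i} = v_i$, so the full gradients are

$$\frac{\partial z}{\partial \mathbf{u}} = \frac{dz}{dy}\,\mathbf{v} \qquad \frac{\partial z}{\partial \mathbf{v}} = \frac{dz}{dy}\,\mathbf{u}$$

For matrix-vector multiplication $\mathbf{y} = M \mathbf{x}$, the result is a vector, so the incoming gradient $\frac{dz}{d\mathbf{y}}$ is itself a vector, and the same process gives

$$\frac{\partial z}{\partial M} = \frac{dz}{d\mathbf{y}} \otimes \mathbf{x} \qquad \frac{\partial z}{\partial \mathbf{x}} = M^\top \, \frac{dz}{d\mathbf{y}}$$

where $\otimes$ is the outer product.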

Some examples, for the dot product between two vectors and for matrix-vector multiplication:

-- import qualified Numeric.LinearAlgebra.Static as H

-- | dot product between two vectors
dot
    :: (KnownNat n, Reifies s W)
    => BVar s (R n)
    -> BVar s (R n)
    -> BVar s Double
dot = liftOp2 . op2 $ \u v ->
    ( u `H.dot` v
    , \dzdy -> (H.konst dzdy * v, u * H.konst dzdy)
    )


-- | matrix-vector multiplication
(#>)
    :: (KnownNat m, KnownNat n, Reifies s W)
    => BVar s (L m n)
    -> BVar s (R n)
    -> BVar s (R m)
(#>) = liftOp2 . op2 $ \mat vec ->
    ( mat H.#> vec
    , \dzdy -> (dzdy `H.outer` vec, H.tr mat H.#> dzdy)
    )
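
As with the scalar examples, these can be sanity-checked with gradBP2. Since dot returns a Double, the top-level $\frac{dz}{dy}$ fed in by gradBP2 is 1, so the two gradients should simply be the input vectors swapped. A sketch, using vec3 from Numeric.LinearAlgebra.Static and the Backprop instances from hmatrix-backprop:

checkDot :: (H.R 3, H.R 3)
checkDot = gradBP2 dot (H.vec3 1 2 3) (H.vec3 4 5 6)
-- checkDot should equal (H.vec3 4 5 6, H.vec3 1 2 3)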

Possibilities

That's it for this introductory tutorial on lifting single operations. More information on the ways to apply these techniques to fully equip your library for backpropagation (including functions with multiple results, taking advantage of isomorphisms, and providing non-gradient functions) can be found here!