# Providing Hand-Written Gradients

Writing and providing hand-written gradients for operations can be useful if you are performing low-level optimizations or equipping your library for backprop.

Ideally, as an *end user*, you should never have to do this. The whole point
of the *backprop* library is to allow you to use backpropagatable functions as
normal functions, and to let you build complicated functions by simply
composing normal Haskell functions, where the *backprop* library automatically
infers your gradients.

However, if you are writing a library, you probably need to provide "primitive" backpropagatable functions (like matrix-vector multiplication for a linear algebra library) for your users, so your users can then use those primitive functions to write their own code, without ever having to be aware of any gradients.

If you are writing code and recognize some bottlenecks related to library
overhead as described in this post, then you might also want to
provide manual gradients as a last resort. However, this should always be a
last resort, as *figuring out* manual gradients is a tedious and error-prone
process that can introduce subtle bugs in ways that don't always appear in
testing. It also makes your code much more fragile and difficult to refactor
and shuffle around (since you aren't using normal function composition and
application anymore) and much harder to read. Only proceed if you decide that
the huge cognitive costs are worth it.

## The Lifted Function

A lifted function of type

```
myFunc :: Reifies s W => BVar s a -> BVar s b
```

represents a backpropagatable function taking an `a` and returning a `b`. It is represented as a function taking a `BVar` containing an `a` and returning a `BVar` containing a `b`; the `BVar s` with the `Reifies s W` is what allows for tracking of backpropagation.

A `BVar s a -> BVar s b` is really, actually, under the hood:

```
-- Conceptual picture only; this is not valid Haskell.
type BVar s a -> BVar s b
    = a -> (b, b -> a)
```

That is, given an input `a`, you get:

- A `b`, the result (the "forward pass")
- A `b -> a`, the "scaled gradient" function.

A full technical description is given in the documentation for `Numeric.Backprop.Op`.
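To make the `a -> (b, b -> a)` picture concrete, here is a minimal plain-Haskell sketch that does not use the *backprop* library at all. The `Lifted` type synonym and the `squareL`, `sinL`, `compose`, and `gradL` helpers are hypothetical names that model the idea. They are not the library's actual internals. Composition runs the forward passes in order and then feeds the outer gradient back through the scaled gradients in reverse.

```haskell
-- Hypothetical model (not backprop's internals): a lifted function returns
-- its forward result paired with its scaled gradient function.
type Lifted a b = a -> (b, b -> a)

-- y = x^2; given dz/dy, return dz/dx = dz/dy * 2x.
squareL :: Num a => Lifted a a
squareL x = (x * x, \dzdy -> dzdy * 2 * x)

-- y = sin x; given dz/dy, return dz/dx = dz/dy * cos x.
sinL :: Floating a => Lifted a a
sinL x = (sin x, \dzdy -> dzdy * cos x)

-- Chain rule: forward passes run f then g, while the scaled gradients
-- run g's then f's, so the outer dz/dy flows back to the input.
compose :: Lifted b c -> Lifted a b -> Lifted a c
compose g f x =
  let (y, dfScaled) = f x
      (z, dgScaled) = g y
  in  (z, dfScaled . dgScaled)

-- When z is the final result, dz/dz = 1, so feed 1 into the scaled gradient.
gradL :: Num b => Lifted a b -> a -> a
gradL h x = snd (h x) 1
```

For example, `gradL (compose sinL squareL) x` should agree with \(2x\cos(x^2)\) up to floating-point rounding.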

The `b` result is simple enough; it's the result of your function. The "scaled
gradient" function requires some elaboration. Let's say you are writing a
lifted version of your function \(y = f(x)\) (whose derivative is
\(\frac{dy}{dx}\)), and that your *final result* at the end of your
computation is \(z = g(f(x))\) (whose derivative is \(\frac{dz}{dx}\)). In
that case, because of the chain rule, \(\frac{dz}{dx} = \frac{dz}{dy}
\frac{dy}{dx}\).

The scaled gradient `b -> a` is the function which, *given* \(\frac{dz}{dy}\) `:: b`, *returns* \(\frac{dz}{dx}\) `:: a` (that is, returns \(\frac{dz}{dy} \frac{dy}{dx}\) `:: a`).

For example, for the mathematical operation \(y = f(x) = x^2\), considering \(z = g(f(x))\), we get \(\frac{dz}{dx} = \frac{dz}{dy} 2x\). In fact, for all functions taking and returning scalars (just normal single numbers), \(\frac{dz}{dx} = \frac{dz}{dy} f'(x)\).
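Because hand-derived gradients are easy to get wrong, one way to sanity-check \(\frac{dz}{dx} = \frac{dz}{dy} f'(x)\) is to compare it against a central finite difference of the whole composition \(z = g(f(x))\). The plain-Haskell sketch below uses \(f(x) = x^2\) and picks \(g = \sin\) purely for illustration. The names `numDeriv`, `scaledGradF`, `dzdyAt`, and `chainRuleGrad`, as well as the step size, are hypothetical choices and are not part of *backprop*.

```haskell
-- Hypothetical helper: central finite-difference estimate of h'(x).
-- The step size is an illustrative choice, not a tuned value.
numDeriv :: (Double -> Double) -> Double -> Double
numDeriv h x = (h (x + eps) - h (x - eps)) / (2 * eps)
  where eps = 1e-5

-- y = f(x) = x^2 and its hand-written scaled gradient.
f :: Double -> Double
f x = x ^ (2 :: Int)

scaledGradF :: Double -> Double -> Double
scaledGradF x dzdy = dzdy * 2 * x

-- An arbitrary downstream function g (assumed here to be sin), z = g(f(x)).
g :: Double -> Double
g = sin

-- dz/dy evaluated at y = f(x); since g = sin, g' = cos.
dzdyAt :: Double -> Double
dzdyAt x = cos (f x)

-- dz/dx obtained by feeding dz/dy into the scaled gradient.
chainRuleGrad :: Double -> Double
chainRuleGrad x = scaledGradF x (dzdyAt x)
```

For any `x`, `chainRuleGrad x` and `numDeriv (g . f) x` should agree to within finite-difference error.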

## Simple Example

With that in mind, let's write a lifted "squared" operation that takes `x` and returns `x^2`:

```
square
    :: (Num a, Backprop a, Reifies s W)
    => BVar s a
    -> BVar s a
square = liftOp1 . op1 $ \x ->
    ( x^2                     -- actual result
    , \dzdy -> dzdy * 2 * x   -- scaled gradient function
    )
```
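Here is a minimal sketch of running `square`. It assumes the *backprop* package is installed and that `Numeric.Backprop` is in scope. `evalBP` runs only the forward pass. `gradBP` runs the backward pass starting from \(\frac{dz}{dy} = 1\), taking \(z = y\). `fourthPower` is a hypothetical name, used to show that lifted functions compose with plain `(.)`.

```haskell
{-# LANGUAGE FlexibleContexts #-}
-- Assumes the backprop package (module Numeric.Backprop).
import Numeric.Backprop

-- The lifted squaring function: forward result and scaled gradient.
square
  :: (Num a, Backprop a, Reifies s W)
  => BVar s a
  -> BVar s a
square = liftOp1 . op1 $ \x ->
  ( x ^ (2 :: Int), \dzdy -> dzdy * 2 * x )

-- Hypothetical example: x^4 built by composing square with itself;
-- backprop chains the two scaled gradients automatically.
fourthPower
  :: (Num a, Backprop a, Reifies s W)
  => BVar s a
  -> BVar s a
fourthPower = square . square
```

With `Double` inputs, `evalBP square 3` should be `9`, `gradBP square 3` should be `6`, and `gradBP fourthPower 2` should be \(4 \cdot 2^3 = 32\).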

We can write one for `sin`, as well. For \(y = f(x) = \sin(x)\), we consider
\(z = g(f(x))\) to see \(\frac{dz}{dx} = \frac{dz}{dy} \cos(x)\). So, we
have:

```
liftedSin
    :: (Floating a, Backprop a, Reifies s W)
    => BVar s a
    -> BVar s a
liftedSin = liftOp1 . op1 $ \x ->
    ( sin x, \dzdy -> dzdy * cos x )
```
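As a rough check that hand-written gradients chain correctly through ordinary function composition, the sketch below composes `liftedSin` with a lifted square. It assumes the *backprop* package is available. `sinOfSquare` is a hypothetical name. By the chain rule, its derivative should be \(2x\cos(x^2)\).

```haskell
{-# LANGUAGE FlexibleContexts #-}
-- Assumes the backprop package (module Numeric.Backprop).
import Numeric.Backprop

square
  :: (Num a, Backprop a, Reifies s W)
  => BVar s a -> BVar s a
square = liftOp1 . op1 $ \x -> ( x * x, \dzdy -> dzdy * 2 * x )

liftedSin
  :: (Floating a, Backprop a, Reifies s W)
  => BVar s a -> BVar s a
liftedSin = liftOp1 . op1 $ \x -> ( sin x, \dzdy -> dzdy * cos x )

-- Hypothetical example: sin(x^2), built by ordinary function composition.
sinOfSquare
  :: (Floating a, Backprop a, Reifies s W)
  => BVar s a -> BVar s a
sinOfSquare = liftedSin . square
```

`gradBP sinOfSquare x` should match \(2x\cos(x^2)\) up to floating-point rounding, and `gradBP liftedSin 0` should be `cos 0`, which is `1`.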

In general, for functions that take and return scalars:

```
liftedF
    :: (Num a, Backprop a, Reifies s W)
    => BVar s a
    -> BVar s a
liftedF = liftOp1 . op1 $ \x ->
    ( f x, \dzdy -> dzdy * dfdx x )
    -- f stands for your function and dfdx for its derivative f'
```
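The template above uses the placeholder names `f` and `dfdx`. One way to package it is a small helper that takes a function and its derivative. `liftScalar`, `liftedExp`, and `liftedTanh` below are hypothetical names and are not part of the *backprop* API. The example assumes the *backprop* package is installed.

```haskell
{-# LANGUAGE FlexibleContexts #-}
-- Assumes the backprop package (module Numeric.Backprop).
import Numeric.Backprop

-- Hypothetical helper: lift a scalar function, given its derivative.
-- The scaled gradient multiplies the incoming dz/dy by f'(x).
liftScalar
  :: (Num a, Backprop a, Reifies s W)
  => (a -> a)  -- the function f
  -> (a -> a)  -- its derivative f'
  -> BVar s a
  -> BVar s a
liftScalar f dfdx = liftOp1 . op1 $ \x ->
  ( f x, \dzdy -> dzdy * dfdx x )

-- exp is its own derivative.
liftedExp :: (Floating a, Backprop a, Reifies s W) => BVar s a -> BVar s a
liftedExp = liftScalar exp exp

-- d/dx tanh x = 1 - (tanh x)^2.
liftedTanh :: (Floating a, Backprop a, Reifies s W) => BVar s a -> BVar s a
liftedTanh = liftScalar tanh (\x -> 1 - tanh x ^ (2 :: Int))
```

The helper only saves boilerplate. The derivative you pass in still has to be worked out and checked by hand.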

For an example of every single numeric function in base Haskell, see the
source of `Op.hs` for the `Op` definitions for every method in `Num`,
`Fractional`, and `Floating`.

## Non-trivial example

A simple non-trivial example is `sumElements`, which we can define to take the
*hmatrix* library's `R n` type (an n-vector of `Double`). Here we have to
think about \(g(\mathrm{sum}(\mathbf{x}))\), and the types guide our
thinking:

```
sumElements :: R n -> Double
sumElementsScaledGrad :: R n -> Double -> R n
```

The simplest way for me to do this personally is to just take it element by element.

*Write out the functions in question, in a simple example.* In our case:

- \(y = f(\langle a, b, c \rangle) = a + b + c\)
- \(z = g(y) = g(a + b + c)\)

*Identify the components in your gradient.* In our case, we have to return a gradient \(\langle \frac{\partial z}{\partial a}, \frac{\partial z}{\partial b}, \frac{\partial z}{\partial c} \rangle\).

*Work out each component of the gradient until you start to notice a pattern.* Let's start with \(\frac{\partial z}{\partial a}\). We need to find \(\frac{\partial z}{\partial a}\) in terms of \(\frac{dz}{dy}\):

- Through the chain rule, \(\frac{\partial z}{\partial a} = \frac{dz}{dy} \frac{\partial y}{\partial a}\).
- Because \(y = a + b + c\), we know that \(\frac{\partial y}{\partial a} = 1\).
- Because \(\frac{\partial y}{\partial a} = 1\), we know that \(\frac{\partial z}{\partial a} = \frac{dz}{dy} \times 1 = \frac{dz}{dy}\).

So, our expression of \(\frac{\partial z}{\partial a}\) in terms of \(\frac{dz}{dy}\) is simply \(\frac{\partial z}{\partial a} = \frac{dz}{dy}\).

Now, let's look at \(\frac{\partial z}{\partial b}\). We need to find \(\frac{\partial z}{\partial b}\) in terms of \(\frac{dz}{dy}\).

- Through the chain rule, \(\frac{\partial z}{\partial b} = \frac{dz}{dy} \frac{\partial y}{\partial b}\).
- Because \(y = a + b + c\), we know that \(\frac{\partial y}{\partial b} = 1\).
- Because \(\frac{\partial y}{\partial b} = 1\), we know that \(\frac{\partial z}{\partial b} = \frac{dz}{dy} \times 1 = \frac{dz}{dy}\).

It looks like \(\frac{\partial z}{\partial b} = \frac{dz}{dy}\), as well.

At this point, we start to notice a pattern. We can apply the same logic to see that \(\frac{\partial z}{\partial c} = \frac{dz}{dy}\).

*Write out the pattern.* Extrapolating the pattern, \(\frac{\partial z}{\partial q}\), where \(q\) is *any* component, is always going to be a constant: \(\frac{dz}{dy}\).

So in the end:

```
liftedSumElements
    :: (KnownNat n, Reifies s W)
    => BVar s (R n)
    -> BVar s Double
liftedSumElements = liftOp1 . op1 $ \xs ->
    ( sumElements xs, \dzdy -> konst dzdy )  -- a constant vector
```
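The same derivation can be checked without *hmatrix*. This dependency-free sketch uses plain lists instead of `R n`, an assumption made only so that it runs anywhere. The hypothetical `sumWithGrad` returns the sum together with a scaled gradient that copies \(\frac{dz}{dy}\) into every component, mirroring `konst dzdy`. The hypothetical `numGrad` estimates each component by perturbing one input at a time.

```haskell
-- Hypothetical list-based analogue: forward pass plus scaled gradient.
-- Every component of the gradient is dz/dy, i.e. a "constant vector".
sumWithGrad :: Num a => [a] -> (a, a -> [a])
sumWithGrad xs = (sum xs, \dzdy -> map (const dzdy) xs)

-- Hypothetical numeric check: central differences, one component at a time.
numGrad :: ([Double] -> Double) -> [Double] -> [Double]
numGrad h xs =
  [ (h (bump i eps) - h (bump i (-eps))) / (2 * eps) | i <- [0 .. length xs - 1] ]
  where
    eps = 1e-5
    -- add d to the i-th component only
    bump i d = [ if j == i then x + d else x | (j, x) <- zip [0 ..] xs ]
```

For some downstream `g`, `fst (sumWithGrad xs)` gives \(y\), and feeding \(g'(y)\) into the scaled gradient should match `numGrad (g . sum) xs` component by component.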

### Multiple-argument functions

Lifting multiple-argument functions is the same thing, except using `liftOp2`
and `op2`, or `liftOpN` and `opN`.

A `BVar s a -> BVar s b -> BVar s c` is, really, under the hood:

```
-- Conceptual picture only; this is not valid Haskell.
type BVar s a -> BVar s b -> BVar s c
    = a -> b -> (c, c -> (a, b))
```

That is, given inputs `a` and `b`, you get:

- A `c`, the result (the "forward pass")
- A `c -> (a, b)`, the "scaled gradient" function returning the gradient of both inputs.

The `c` parameter of the scaled gradient is again \(\frac{dz}{dy}\), and the
final `(a,b)` is a tuple of \(\frac{\partial z}{\partial x_1}\) and
\(\frac{\partial z}{\partial x_2}\): how \(\frac{dz}{dy}\) affects both of
the inputs.

For a simple example, let's look at \(x_1 + x_2\). Working it out:

- \(y = f(x_1, x_2) = x_1 + x_2\)
- \(z = g(f(x_1, x_2)) = g(x_1 + x_2)\)
- Looking first for \(\frac{\partial z}{\partial x_1}\) in terms of \(\frac{dz}{dy}\):
  - \(\frac{\partial z}{\partial x_1} = \frac{dz}{dy} \frac{\partial y}{\partial x_1}\) (chain rule)
  - From \(y = x_1 + x_2\), we see that \(\frac{\partial y}{\partial x_1} = 1\)
  - Therefore, \(\frac{\partial z}{\partial x_1} = \frac{dz}{dy} \times 1 = \frac{dz}{dy}\).
- Looking second for \(\frac{\partial z}{\partial x_2}\) in terms of \(\frac{dz}{dy}\):
  - \(\frac{\partial z}{\partial x_2} = \frac{dz}{dy} \frac{\partial y}{\partial x_2}\) (chain rule)
  - From \(y = x_1 + x_2\), we see that \(\frac{\partial y}{\partial x_2} = 1\)
  - Therefore, \(\frac{\partial z}{\partial x_2} = \frac{dz}{dy} \times 1 = \frac{dz}{dy}\).
- Therefore, \(\frac{\partial z}{\partial x_1} = \frac{dz}{dy}\), and also \(\frac{\partial z}{\partial x_2} = \frac{dz}{dy}\).

Putting it into code:

```
add :: (Num a, Backprop a, Reifies s W)
    => BVar s a
    -> BVar s a
    -> BVar s a
add = liftOp2 . op2 $ \x1 x2 ->
    ( x1 + x2, \dzdy -> (dzdy, dzdy) )
```
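Here is a short usage sketch for `add`, assuming the *backprop* package is installed. `gradBP2` returns the gradients with respect to both arguments as a tuple, so for addition both components should be one. `selfAdd` is a hypothetical name. It shows that when one variable feeds both arguments, the library sums the two contributions.

```haskell
{-# LANGUAGE FlexibleContexts #-}
-- Assumes the backprop package (module Numeric.Backprop).
import Numeric.Backprop

add :: (Num a, Backprop a, Reifies s W)
    => BVar s a
    -> BVar s a
    -> BVar s a
add = liftOp2 . op2 $ \x1 x2 ->
  ( x1 + x2, \dzdy -> (dzdy, dzdy) )

-- Hypothetical example: x + x. Both gradient components flow back to the
-- same variable and are added together, so d/dx (x + x) = 2.
selfAdd :: (Num a, Backprop a, Reifies s W) => BVar s a -> BVar s a
selfAdd x = add x x
```

With `Double` inputs, `gradBP2 add 3 5` should give `(1, 1)`, `evalBP2 add 3 5` should give `8`, and `gradBP selfAdd 4` should give `2`.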

Let's try our hand at multiplication, or \(x_1 x_2\):

- \(y = f(x_1, x_2) = x_1 x_2\)
- \(z = g(f(x_1, x_2)) = g(x_1 x_2)\)
- Looking first for \(\frac{\partial z}{\partial x_1}\) in terms of \(\frac{dz}{dy}\):
  - \(\frac{\partial z}{\partial x_1} = \frac{dz}{dy} \frac{\partial y}{\partial x_1}\) (chain rule)
  - From \(y = x_1 x_2\), we see that \(\frac{\partial y}{\partial x_1} = x_2\)
  - Therefore, \(\frac{\partial z}{\partial x_1} = \frac{dz}{dy} x_2\).
- Looking second for \(\frac{\partial z}{\partial x_2}\) in terms of \(\frac{dz}{dy}\):
  - \(\frac{\partial z}{\partial x_2} = \frac{dz}{dy} \frac{\partial y}{\partial x_2}\) (chain rule)
  - From \(y = x_1 x_2\), we see that \(\frac{\partial y}{\partial x_2} = x_1\)
  - Therefore, \(\frac{\partial z}{\partial x_2} = \frac{dz}{dy} x_1\).
- Therefore, \(\frac{\partial z}{\partial x_1} = \frac{dz}{dy} x_2\), and \(\frac{\partial z}{\partial x_2} = \frac{dz}{dy} x_1\).

In code:

```
mul :: (Num a, Backprop a, Reifies s W)
    => BVar s a
    -> BVar s a
    -> BVar s a
mul = liftOp2 . op2 $ \x1 x2 ->
    ( x1 * x2, \dzdy -> (dzdy * x2, x1 * dzdy) )
```
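Because `add` and `mul` are ordinary functions, they compose. This is a minimal sketch, assuming the *backprop* package is installed, with `h` as a hypothetical example function. For \(h(x_1, x_2) = x_1 (x_1 + x_2)\), working it out by hand gives \(\frac{\partial h}{\partial x_1} = 2x_1 + x_2\) and \(\frac{\partial h}{\partial x_2} = x_1\). `gradBP2` should reproduce these values.

```haskell
{-# LANGUAGE FlexibleContexts #-}
-- Assumes the backprop package (module Numeric.Backprop).
import Numeric.Backprop

add :: (Num a, Backprop a, Reifies s W) => BVar s a -> BVar s a -> BVar s a
add = liftOp2 . op2 $ \x1 x2 -> ( x1 + x2, \dzdy -> (dzdy, dzdy) )

mul :: (Num a, Backprop a, Reifies s W) => BVar s a -> BVar s a -> BVar s a
mul = liftOp2 . op2 $ \x1 x2 -> ( x1 * x2, \dzdy -> (dzdy * x2, x1 * dzdy) )

-- Hypothetical example: h(x1, x2) = x1 * (x1 + x2), written with plain
-- function application; no gradient code is needed here.
h :: (Num a, Backprop a, Reifies s W) => BVar s a -> BVar s a -> BVar s a
h x1 x2 = mul x1 (add x1 x2)
```

With `Double` inputs, `gradBP2 mul 3 5` should give `(5, 3)`, and `gradBP2 h 3 5` should give `(11, 3)`.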

For non-trivial examples involving linear algebra, see the source for the *hmatrix-backprop* library.

Some examples, for the dot product between two vectors and for matrix-vector multiplication:

```
-- import qualified Numeric.LinearAlgebra.Static as H

-- | dot product between two vectors
dot
    :: (KnownNat n, Reifies s W)
    => BVar s (R n)
    -> BVar s (R n)
    -> BVar s Double
dot = liftOp2 . op2 $ \u v ->
    ( u `H.dot` v
    , \dzdy -> (H.konst dzdy * v, u * H.konst dzdy)
    )

-- | matrix-vector multiplication
(#>)
    :: (KnownNat m, KnownNat n, Reifies s W)
    => BVar s (L m n)
    -> BVar s (R n)
    -> BVar s (R m)
(#>) = liftOp2 . op2 $ \mat vec ->
    ( mat H.#> vec
    , \dzdy -> (dzdy `H.outer` vec, H.tr mat H.#> dzdy)
    )
```
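To see the shapes in these rules without *hmatrix*, here is a dependency-free sketch that stores matrices as lists of rows. This representation is an assumption made for illustration only. `dotWithGrad`, `outer`, `mv`, and `mvWithGrad` are hypothetical names. The gradient with respect to the matrix is the outer product of \(\frac{dz}{dy}\) with the input vector. The gradient with respect to the vector is the transposed matrix applied to \(\frac{dz}{dy}\).

```haskell
import Data.List (transpose)

-- Hypothetical: dot product whose scaled gradient gives each input
-- the *other* input scaled by dz/dy.
dotWithGrad :: Num a => [a] -> [a] -> (a, a -> ([a], [a]))
dotWithGrad u v =
  ( sum (zipWith (*) u v)
  , \dzdy -> (map (* dzdy) v, map (* dzdy) u) )

-- Outer product: an m-by-n matrix (list of rows) from an m- and an n-vector.
outer :: Num a => [a] -> [a] -> [[a]]
outer xs ys = [ [ x * y | y <- ys ] | x <- xs ]

-- Matrix-vector product with the matrix stored as a list of rows.
mv :: Num a => [[a]] -> [a] -> [a]
mv mat vec = [ sum (zipWith (*) row vec) | row <- mat ]

-- Hypothetical: forward pass and scaled gradient for mat #> vec.
-- dzdy has one entry per row of mat (an m-vector).
mvWithGrad :: Num a => [[a]] -> [a] -> ([a], [a] -> ([[a]], [a]))
mvWithGrad mat vec =
  ( mv mat vec
  , \dzdy -> (outer dzdy vec, mv (transpose mat) dzdy) )
```

Lists of rows make the shapes visible, but they are slow and do not check dimensions. The *hmatrix* version above encodes the dimensions in the types instead.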

## Possibilities

That's it for this introductory tutorial on lifting single operations. More information on the ways to apply these techniques to fully equip your library for backpropagation (including arguments with multiple results, taking advantage of isomorphisms, providing non-gradient functions) can be found here!