
Working with tensors symbolically in Wolfram Language

When I want to do work with tensors, I use JAX in Python. But when I want to understand and visualize them, I use Wolfram Language. This article is an extended demonstration of why that is. :)

For a while now, I’ve spent a little time each week working through Pattern Recognition and Machine Learning (PRML), a textbook that teaches the fundamentals of ML. Wolfram Language has rich library functions for most of the concepts in the book, like matrices, probability distributions, and function minimization. And it plots!

But symbolic manipulation is the best part. I can write an expression and it stays an expression! When I calculate a derivative, I get to see the expression for the derivative.

Getting to this point wasn’t all sunshine and rainbows, though. My goal here is to succinctly summarize the lessons I’ve learned in how to use Wolfram Language to bridge the gap between the tensor work I already knew how to do in JAX/numpy, and the symbolic math I can do in Wolfram Language.

I’m going to assume that you have a similar background to mine: you know some linear algebra and have implemented neural networks in a framework like TensorFlow or JAX, but you’re not sure how best to represent those concepts in Wolfram Language for inspectability.


Defining tensors in Wolfram Language

In a math text, I can define a tensor A and write A_ij to refer to its entries.

However, in Wolfram Language, if I were to define a tensor a containing article_3.gif, article_4.gif, and so on, the definition would be recursive and not evaluate very well:

article_5.gif

article_6.gif

So my naming convention is to append “t” (for “tensor”) to the name that refers to the entire tensor. For example, I could define a tensor of shape {2,5} like this:

article_7.gif

I leave a undefined so that expressions like article_8.gif remain unevaluated.
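
Spelled out by hand, that definition might look something like this (the exact entry symbols aren't shown above, so the subscripted a's are my guess):

    at = {
       {Subscript[a, 1, 1], Subscript[a, 1, 2], Subscript[a, 1, 3],
        Subscript[a, 1, 4], Subscript[a, 1, 5]},
       {Subscript[a, 2, 1], Subscript[a, 2, 2], Subscript[a, 2, 3],
        Subscript[a, 2, 4], Subscript[a, 2, 5]}};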

Instead of writing out all the terms like that, though, I’d actually use Array, so that I can get the same result just by specifying the shape:

article_11.gif

article_12.gif
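
A sketch of the Array version, assuming the same subscripted entries as above:

    (* Build the same 2x5 tensor of symbolic entries from just the shape. *)
    at = Array[Subscript[a, #1, #2] &, {2, 5}];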

Poking at neural networks

Neural networks are a nice first example because they’re genuinely useful, but not too complicated.

We can use the built-in NetGraph library to quickly create a visual representation of the network we want to build:

article_13.gif

article_14.png

It takes 2-D inputs, has one 5-unit hidden layer with Tanh activation, and produces 2-D outputs through a linear output layer.
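
A sketch of what that NetGraph call might look like for this architecture (the layer specs are my reading of the description above):

    NetGraph[
     {LinearLayer[5], ElementwiseLayer[Tanh], LinearLayer[2]},
     {1 -> 2 -> 3},
     "Input" -> 2]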

Building the network

While NetGraph is a great library, our goal is to understand all the math, and it’s tedious to disassemble a NetGraph as thoroughly as we can inspect a network we define ourselves. So let’s define it ourselves, like this:

article_15.gif
article_16.gif
article_17.gif
article_18.gif
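
A sketch of those definitions, following the “t” convention from earlier (the base symbols input, w1, b1, w2, b2 are my naming):

    inputt = Array[Subscript[input, #] &, 2];        (* 2-D input vector *)
    w1t = Array[Subscript[w1, #1, #2] &, {5, 2}];    (* hidden-layer weights *)
    b1t = Array[Subscript[b1, #] &, 5];              (* hidden-layer biases *)
    w2t = Array[Subscript[w2, #1, #2] &, {2, 5}];    (* output-layer weights *)
    b2t = Array[Subscript[b2, #] &, 2];              (* output-layer biases *)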

The expressions for calculating the output of the network are easy to write:

article_19.gif
article_20.gif
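
Something like this, assuming the tensors sketched above (the intermediate name hiddent is mine):

    hiddent = Tanh[w1t . inputt + b1t];    (* 5-unit hidden layer, Tanh activation *)
    outputt = w2t . hiddent + b2t;         (* 2-D linear output *)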

As is squared error loss:

article_21.gif
article_22.gif
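
A sketch of the loss against a symbolic target vector (again, the name targett is my guess):

    targett = Array[Subscript[target, #] &, 2];
    loss = Total[(outputt - targett)^2];    (* squared error *)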

Pretty-printing the equations

I suppressed the printing of the expressions for the network outputs and loss above because, by default, Wolfram Language fully expands them until every defined identifier has been replaced by its definition. That’s handy sometimes, but not typically elucidating for expressions this large:

article_23.gif

article_24.gif

But MatrixForm and HoldForm can help. MatrixForm displays a matrix like a matrix, and HoldForm prevents identifiers from being expanded during evaluation.

They’re great for showing matrices as matrices:

article_26.gif

article_27.gif
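
For example, something along these lines:

    MatrixForm[w1t]                                   (* shows the 5x2 matrix of symbolic entries *)
    HoldForm[w2t . Tanh[w1t . inputt + b1t] + b2t]    (* displays without expanding the identifiers *)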

And with a helper…

article_28.gif

… we can render the network output expression in a way that communicates shape as well:

article_29.gif

article_30.gif

And the loss, of course:

article_31.gif

article_32.gif
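
One way to write such a helper (my own sketch, not necessarily the article’s): wrap the expression in HoldForm so it isn’t expanded, then swap each parameter tensor for its MatrixForm so the shapes stay visible.

    SetAttributes[prettyForm, HoldAll];
    prettyForm[expr_] :=
      HoldForm[expr] /. {
        HoldPattern[inputt] -> MatrixForm[inputt],
        HoldPattern[w1t] -> MatrixForm[w1t],
        HoldPattern[b1t] -> MatrixForm[b1t],
        HoldPattern[w2t] -> MatrixForm[w2t],
        HoldPattern[b2t] -> MatrixForm[b2t]};

    (* e.g. prettyForm[w2t . Tanh[w1t . inputt + b1t] + b2t] *)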

Poking the network

Since everything’s defined symbolically, we can substitute values in arbitrarily.

There are dozens of variables across inputt, w1t, b1t, etc., so a function to assign them random values is helpful:

article_33.gif

It works by returning a list of replacement Rules:

article_34.gif

article_35.gif
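
A sketch of such a function (randomAssignments is my name for it): flatten whatever tensors you hand it into a list of symbolic entries, and Thread them against random values.

    randomAssignments[tensors__] :=
      Module[{syms = Flatten[{tensors}]},
        Thread[syms -> RandomReal[{-1, 1}, Length[syms]]]];

    (* e.g. randomAssignments[w1t, b1t] gives {Subscript[w1, 1, 1] -> (random), ...} *)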

To actually substitute them in, we can use ReplaceAll. If we substitute in values for the inputs and all the network weights, we simply get the network’s output:

article_36.gif

article_37.gif
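
With the helper sketched above, that substitution is a one-liner:

    (* Every symbol gets a value, so this evaluates to a numeric 2-vector. *)
    outputt /. randomAssignments[inputt, w1t, b1t, w2t, b2t]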

But we can get the output in terms of arbitrary subsets of the variables! We could plot the network’s response to changing just the second input:

article_41.gif

article_42.gif

article_43.png
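
Roughly like this (a sketch; which output component to show and what range to plot are arbitrary choices of mine):

    Module[{rules, s},
     rules = randomAssignments[w1t, b1t, w2t, b2t, inputt[[1]]];  (* fix everything but the 2nd input *)
     Plot[Evaluate[outputt[[1]] /. rules /. inputt[[2]] -> s], {s, -2, 2}]]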

Or the network’s response to editing a single value in the hidden layer’s weight matrix:

article_44.gif

article_45.gif

article_46.png
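
Same idea, this time leaving a single entry of w1t free (again a sketch):

    Module[{rules, s},
     rules = randomAssignments[inputt, b1t, w2t, b2t, Rest[Flatten[w1t]]];
     Plot[Evaluate[outputt[[1]] /. rules /. w1t[[1, 1]] -> s], {s, -2, 2}]]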

Since the input is 2-D and the loss is a single scalar (what a coincidence!), we can plot the loss over the input domain, assuming a target value of {0, 0}:

article_47.gif

article_48.png
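
A sketch of that plot, with random weights and the target pinned to {0, 0}:

    Module[{rules = randomAssignments[w1t, b1t, w2t, b2t], u, v},
     Plot3D[
      Evaluate[loss /. Thread[targett -> {0, 0}] /. rules /. Thread[inputt -> {u, v}]],
      {u, -1, 1}, {v, -1, 1}]]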

Or, since the input and output domains are both ℝ², we can view it as a transform and plot the vector field:

article_50.gif

article_51.png
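
For example, with VectorPlot (a sketch):

    Module[{rules = randomAssignments[w1t, b1t, w2t, b2t], u, v},
     VectorPlot[
      Evaluate[outputt /. rules /. Thread[inputt -> {u, v}]],
      {u, -1, 1}, {v, -1, 1}]]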

The usefulness of the above plots depends entirely on what you’re trying to learn when you decide to render them. The point is the flexibility and relative ease with which you can make them!

Of more conventionally practical importance, we can calculate the partial derivative of the loss with respect to any of the variables, which we will most likely want if we ever decide to train this network…!

article_52.gif

article_53.gif
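
Because loss is an ordinary symbolic expression, D and Grad apply to it directly, for example:

    D[loss, w1t[[1, 1]]]        (* partial derivative with respect to one weight entry *)
    Grad[loss, Flatten[w1t]]    (* gradient with respect to every entry of w1t *)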

Training the neural network

A spiral dataset

We need a dataset, so to start, I’m going to define one where the inputs are 2-D coordinates, and the outputs are those same coordinates rotated 45° counter-clockwise around the origin.

First, a function for performing the rotation:

article_54.gif
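
A sketch (the function name is mine):

    rotate45[p_] := RotationMatrix[45 Degree] . p;    (* counter-clockwise by 45 degrees *)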

And a function that generates an example for the dataset by picking a random point in the unit circle and rotating it:

article_55.gif
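
A sketch, assuming RandomPoint and Disk for the sampling, with each example stored as an input -> output Rule:

    randomExample[] := With[{p = RandomPoint[Disk[]]}, p -> rotate45[p]];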

We can use that to generate a whole dataset:

article_56.gif
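
For example (the dataset size is arbitrary):

    dataset = Table[randomExample[], {200}];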

We can visualize it by mapping every example to an Arrow:

article_57.gif

article_58.png
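
A sketch, treating each input -> output Rule as the two ends of an Arrow:

    Graphics[{Arrowheads[0.02],
      Arrow[{First[#], Last[#]}] & /@ dataset},
     Axes -> True]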

Training with a built-in optimizer

One way to find good weights for our network using this dataset is with one of Wolfram Language’s built-in optimizers. For small datasets, NMinimize is enough to demonstrate that our little neural network can actually learn something.

Our goal is to minimize the total error of all examples in the dataset. That translates almost directly into code:

article_59.gif

article_60.gif
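
A sketch of that translation, reusing the names from the earlier sketches: substitute each example’s input and target into the symbolic loss, sum the copies, and hand the whole expression to NMinimize along with the list of weight variables.

    vars = Flatten[{w1t, b1t, w2t, b2t}];
    totalLoss = Total[
       (loss /. Thread[inputt -> First[#]] /. Thread[targett -> Last[#]]) & /@ dataset];
    {minValue, fit} = NMinimize[totalLoss, vars];    (* fit is a list of replacement rules *)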

We can plot the transformation that the network learned, which, if we eyeball the tangents in a stream plot, looks like the intended rotation:

article_61.gif

article_62.png
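
For example (a sketch, using the fit rules from NMinimize above):

    Module[{u, v},
     StreamPlot[
      Evaluate[outputt /. fit /. Thread[inputt -> {u, v}]],
      {u, -1, 1}, {v, -1, 1}]]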

If we plot the loss, we see that the error is small everywhere in the unit circle that we trained on:

article_63.gif

article_64.png
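
A sketch of that plot, presumably measuring each point’s loss against its true rotated target:

    Module[{u, v},
     DensityPlot[
      Evaluate[loss /. fit /. Thread[targett -> rotate45[{u, v}]] /. Thread[inputt -> {u, v}]],
      {u, -1, 1}, {v, -1, 1}, PlotLegends -> Automatic]]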

Does it generalize?

Unfortunately, you don’t have to zoom out very far to see that it doesn’t generalize, though you have to check the scale on the legend to notice:

article_65.gif

article_66.png

We can see the errors compounding more clearly if we feed the output of the network back into itself to generate trails, and overlay the circles that the trails would ideally trace:

article_67.gif

article_68.png
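
A sketch of the trails: apply the fitted network repeatedly with NestList, then draw the resulting path next to the circle an exact rotation would trace.

    netF[p_] := outputt /. fit /. Thread[inputt -> p];
    trail[p0_, n_] := NestList[netF, p0, n];
    Graphics[{
      Circle[{0, 0}, 0.5],                (* ideal path for a point starting at {0.5, 0} *)
      {Red, Line[trail[{0.5, 0}, 16]]}    (* what the trained network actually traces *)
     }]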

Anyway, we did it!

Does it scale?

Nah, this is not a great way to do it if you care about training a realistic neural network. Like I said at the start, if I were trying to do real work, I’d be using a different tool, or at least completely different language constructs, like NetTrain.

Why isn’t it a good way, though? Because the function we’re asking the library to minimize is big. It contains a copy of the entire loss function (which contains the entire inference function) for every point in the dataset. Here’s what it looks like with just 3 data points:

article_69.gif

article_70.gif

Even with only ~200 data points, I get tired of waiting for it.

Meanwhile, NetTrain manages 10,000+ unique examples per second on my laptop:

article_71.gif

article_72.png
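
A sketch of the equivalent NetTrain setup, using a NetChain with the same 2 -> 5 -> 2 shape:

    net = NetChain[{LinearLayer[5], Tanh, LinearLayer[2]}, "Input" -> 2];
    trained = NetTrain[net, dataset];    (* dataset is the same list of input -> output rules *)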

The final weights actually aren’t as good, even on the unit circle, but that’s a story for another day…

article_73.gif

article_74.png

To be continued…

I’m publishing new sections as I go, so subscribe to know when there’s more!