Up and Running Nx

This is the first blog post in a series by guest writer Sean Moriarity, co-author of the Elixir Nx library and author of the book “Genetic Algorithms in Elixir”.

Nx is a new library for tensor manipulation and numerical computing on the BEAM. Nx hopes to open doors for Elixir, Erlang, and other BEAM languages to new, exciting domains by allowing users to accelerate code through JIT compilation and providing interfaces to highly-specialized tensor manipulation routines. In this post, you will learn some of the basics needed to get started with Nx, and you’ll see a basic example of how Nx can be used for machine learning applications.

Getting Comfortable with Tensors

The Nx definition of “tensor” is similar to the PyTorch or TensorFlow tensor, or the NumPy multidimensional array. If you’re coming from one of those frameworks, manipulating Nx tensors should feel familiar to you. One thing to note - the Nx definition of a tensor is not necessarily the same as the pure math definition of a tensor. Nx follows most of the conventions and precedents put forth by the Python ecosystem, so transitioning from any of those frameworks should be relatively easy. For Elixir programmers, it’s easy to think of tensors as nested lists, with some additional metadata:

iex> Nx.tensor([[1, 2, 3], [4, 5, 6]])
#Nx.Tensor<
  s64[2][3]
  [
    [1, 2, 3],
    [4, 5, 6]
  ]
>

Nx.tensor/2 is one method you can use to create a tensor. It works with both nested lists of numbers and scalars:

iex> Nx.tensor(1.0)
#Nx.Tensor<
  f32
  1.0
>

Notice the additional metadata that comes out when tensors are inspected, namely s64[2][3] and f32 in the examples above. Tensors have both shapes and types associated with them. A tensor’s shape is a tuple representing the size of each dimension in the tensor. In the examples above, the first tensor has a shape of {2, 3} as represented by [2][3] in the inspected tensor:

iex> Nx.shape(Nx.tensor([[1, 2, 3], [4, 5, 6]]))
{2, 3}

If you’re comfortable thinking of tensors as nested lists, this should make some intuitive sense - the first example contains 2 lists of 3 elements each. If you were to wrap the first example in more lists, the shape would change accordingly:

iex> Nx.shape(Nx.tensor([[[[1, 2, 3], [4, 5, 6]]]]))
{1, 1, 2, 3}

That is, 1 list of 1 list of 2 lists of 3 elements.

This line of thinking can be a bit confusing when working with scalars. The shape of a scalar tensor is represented by an empty tuple:

iex> Nx.shape(Nx.tensor(1.0))
{}

This is because scalars are actually 0-dimensional tensors. They don’t have any dimensions and therefore, they have an “empty” shape.
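To make the nested-list intuition concrete, here is a minimal pure-Elixir sketch that derives a shape tuple from a nested list. `ShapeSketch` is a hypothetical helper written for illustration, not part of Nx. It assumes the input is rectangular and only looks at the first element of each level.

```elixir
defmodule ShapeSketch do
  # Hypothetical helper, not part of Nx: derives a shape tuple from a
  # rectangular nested list by inspecting the first element at each level.
  def shape(data), do: data |> dims() |> List.to_tuple()

  # A non-empty list contributes its length, then we descend into its head.
  defp dims([first | _] = list), do: [length(list) | dims(first)]
  # Anything else (a number) has no dimensions left.
  defp dims(_scalar), do: []
end

IO.inspect(ShapeSketch.shape([[1, 2, 3], [4, 5, 6]]))     # {2, 3}
IO.inspect(ShapeSketch.shape([[[[1, 2, 3], [4, 5, 6]]]])) # {1, 1, 2, 3}
IO.inspect(ShapeSketch.shape(1.0))                         # {}
```

The scalar case falls straight through to the last clause, which is why its shape is the empty tuple.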

A tensor’s type is the numeric type associated with the tensor. Types in Nx are represented as 2 element tuples with a type-class and size or bitwidth:

iex> Nx.type(Nx.tensor([[1, 2, 3], [4, 5, 6]]))
{:s, 64}
iex> Nx.type(Nx.tensor(1.0))
{:f, 32}

Types are important because they tell Nx how to store tensors internally. Nx tensors are internally represented as binaries:

iex> Nx.to_binary(Nx.tensor(1))
<<1, 0, 0, 0, 0, 0, 0, 0>>
iex> Nx.to_binary(Nx.tensor(1.0))
<<0, 0, 128, 63>>

A Note on Endianness: Nx uses the native endianness specification, so the endianness of the binary is resolved at load-time to match the endianness of your machine. If, for some reason, your project requires big or little endian regardless of the machine it’s on, please open an issue describing your use case.
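The byte layouts shown above can be reproduced with plain Elixir binary syntax, which may help clarify what the type tuple controls. This sketch uses explicit little-endian modifiers so the bytes match the output above on any machine; Nx itself uses the native endianness, as noted.

```elixir
# A signed 64-bit integer and a 32-bit float, both little-endian.
int_bytes = <<1::signed-integer-size(64)-little>>
float_bytes = <<1.0::float-size(32)-little>>

IO.inspect(int_bytes)   # <<1, 0, 0, 0, 0, 0, 0, 0>>
IO.inspect(float_bytes) # <<0, 0, 128, 63>>

# Decoding works the same way in reverse.
<<decoded::float-size(32)-little>> = float_bytes
IO.inspect(decoded)     # 1.0
```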

Notice the internal binary representation changes with a floating-point versus a signed integer type. You should also notice that Nx will attempt to infer the input type; however, you can also specify the input type when creating tensors using the type option:

iex> Nx.to_binary(Nx.tensor(1, type: {:f, 32}))
<<0, 0, 128, 63>>
iex> Nx.to_binary(Nx.tensor(1.0))
<<0, 0, 128, 63>>

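Because Nx tensors are represented as binaries, Nx.tensor/2 has to traverse the whole nested list and pack every element into a binary, which is expensive for very large tensors. For that reason, you should almost never use Nx.tensor/2 on large inputs in practice. Nx exposes a useful function, Nx.from_binary/2, which does not require traversing a nested list: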

iex> Nx.from_binary(<<0, 0, 128, 63, 0, 0, 0, 64, 0, 0, 64, 64>>, {:f, 32})
#Nx.Tensor<
  f32[3]
  [1.0, 2.0, 3.0]
>

Nx.from_binary/2 takes a binary and a type and always returns a 1-dimensional tensor. If you want your data to have a different shape, you can use Nx.reshape/2:

iex> Nx.reshape(Nx.from_binary(<<0, 0, 128, 63, 0, 0, 0, 64, 0, 0, 64, 64>>, {:f, 32}), {3, 1})
#Nx.Tensor<
  f32[3][1]
  [
    [1.0],
    [2.0],
    [3.0]
  ]
>

Nx.reshape/2 only ever changes the shape attribute of the tensor, so it’s a relatively inexpensive operation. When your data comes in as a binary, using Nx.from_binary/2 with Nx.reshape/2 is the most efficient way to create tensors.
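As a rough illustration of that workflow, the sketch below packs a list of floats into a binary with a comprehension, loads it, and reshapes it. It assumes Nx is available in your session (for example, `iex -S mix` in a project that depends on it). The variable names are made up for this example.

```elixir
# Assumes Nx is available in the session.
floats = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

# Pack each float as a 32-bit native-endian value, matching {:f, 32}.
binary = for x <- floats, into: <<>>, do: <<x::float-size(32)-native>>

flat = Nx.from_binary(binary, {:f, 32})
matrix = Nx.reshape(flat, {2, 3})

IO.inspect(Nx.shape(flat))   # {6}
IO.inspect(Nx.shape(matrix)) # {2, 3}

# Reshaping changes only the shape; the underlying bytes are identical.
IO.inspect(Nx.to_binary(flat) == Nx.to_binary(matrix)) # true
```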

Working with Tensor Operations

If you’re an experienced Elixir programmer, you’re probably intimately familiar with the Enum module for manipulating collections that implement the Enumerable protocol. Because of this, you’ll probably search for and prefer to use the functional constructs map and reduce. Nx does expose both map and reduce as methods for manipulating tensors, and they work in almost exactly the same way you’d expect; however, you should almost never use these methods.

All of the operations in the Nx library are tensor-aware, which means they work on tensors of any shape and type. For example, in Elixir you might be used to doing something like:

iex> Enum.map([1, 2, 3], fn x -> :math.cos(x) end)
[0.5403023058681398, -0.4161468365471424, -0.9899924966004454]

But, you can achieve the same thing in Nx using just Nx.cos/1:

iex> Nx.cos(Nx.tensor([1, 2, 3]))
#Nx.Tensor<
  f32[3] 
  [0.5403022766113281, -0.416146844625473, -0.9899924993515015]
>

All of the unary operators in Nx work this way - they apply a function element-wise to a tensor of any type and any shape:

iex> Nx.exp(Nx.tensor([[[1], [2], [3]]]))
#Nx.Tensor<
  f32[1][3][1]
  [
    [
      [2.7182817459106445],
      [7.389056205749512],
      [20.08553695678711]
    ]
  ]
>
iex> Nx.sin(Nx.tensor([[1, 2, 3]]))
#Nx.Tensor<
  f32[1][3]
  [
    [0.8414709568023682, 0.9092974066734314, 0.14112000167369843]
  ]
>
iex> Nx.acosh(Nx.tensor([1, 2, 3]))
#Nx.Tensor<
  f32[3] 
  [0.0, 1.316957950592041, 1.7627471685409546]
>

There’s almost never a need to use something like Nx.map, because the element-wise unary operators can almost always be used to achieve the same effect. Nx.map will almost always be less efficient, and you will be unable to use Nx transforms like grad with Nx.map. Additionally, Nx.map cannot be supported by some Nx backends or compilers - so portability is a concern. The same applies for working with aggregate methods like Nx.reduce. You should prefer the Nx provided aggregate methods like Nx.sum, Nx.mean, and Nx.product, over implementing your own using Nx.reduce:

iex> Nx.sum(Nx.tensor([1, 2, 3]))
#Nx.Tensor<
  s64
  6
>
iex> Nx.product(Nx.tensor([1, 2, 3]))
#Nx.Tensor<
  s64
  6
>
iex> Nx.mean(Nx.tensor([1, 2, 3]))
#Nx.Tensor<
  f32
  2.0
>

Nx aggregate methods also have the added benefit of being capable of reducing along a single axis. For example, if you have a collection of examples in a batch, you might only want to compute the mean for single examples:

iex> Nx.mean(Nx.tensor([[1, 2, 3], [4, 5, 6]]), axes: [1])
#Nx.Tensor<
  f32[2] 
  [2.0, 5.0]
>

You can even provide multiple axes:

iex> Nx.mean(Nx.tensor([[[1, 2, 3], [4, 5, 6]]]), axes: [0, 1])
#Nx.Tensor<
  f32[3] 
  [2.5, 3.5, 4.5]
>
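If the axes option feels abstract, it can help to write the same reductions over a plain nested list. The sketch below is pure Elixir and only mirrors what reducing a {2, 3} tensor along axis 1 and axis 0 means. It is not how Nx implements aggregation.

```elixir
rows = [[1, 2, 3], [4, 5, 6]]
mean = fn xs -> Enum.sum(xs) / length(xs) end

# Axis 1: collapse each inner list (one mean per row).
axis1 = Enum.map(rows, mean)

# Axis 0: collapse across the outer list (one mean per column).
axis0 =
  rows
  |> Enum.zip()
  |> Enum.map(&Tuple.to_list/1)
  |> Enum.map(mean)

IO.inspect(axis1) # [2.0, 5.0]
IO.inspect(axis0) # [2.5, 3.5, 4.5]
```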

Nx also has binary operators that are tensor aware. Things like addition, subtraction, multiplication, and division work element-wise:

iex> Nx.add(Nx.tensor([1, 2, 3]), Nx.tensor([4, 5, 6]))
#Nx.Tensor<
  s64[3]
  [5, 7, 9]
>
iex> Nx.subtract(Nx.tensor([[1, 2, 3]]), Nx.tensor([[4, 5, 6]]))
#Nx.Tensor<
  s64[1][3]
  [
    [-3, -3, -3]
  ]
>
iex> Nx.multiply(Nx.tensor([[1], [2], [3]]), Nx.tensor([[4], [5], [6]]))
#Nx.Tensor<
  s64[3][1]
  [
    [4],
    [10],
    [18]
  ]
>
iex> Nx.divide(Nx.tensor([1, 2, 3]), Nx.tensor([4, 5, 6]))
#Nx.Tensor<
  f32[3] 
  [0.25, 0.4000000059604645, 0.5]
>

With binary operators, however, there is an additional caveat: the tensor shapes must be compatible or capable of being broadcasted to the same shape. Broadcasting occurs when the input tensors have different shapes:

iex> Nx.add(Nx.tensor(1), Nx.tensor([1, 2, 3]))
#Nx.Tensor<
  s64[3]
  [2, 3, 4]
>

In the previous example, the scalar tensor 1 is broadcasted over the larger tensor. Broadcasting can be used to implement more memory-efficient routines by relaxing the need to work with tensors of the same shape. For example, if you need to multiply a 50x50x50 tensor by 2, you can use broadcasting to turn the operation into a loop which iterates over the 50x50x50 tensor, multiplying each element by 2, rather than creating another 50x50x50 tensor of all 2s.

In order for two tensors to be capable of broadcasting, each of their dimensions must be compatible. Dimensions are compatible if one of the following requirements is met:

1) They are equal
2) One of the dimensions is size 1
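The two rules above can be sketched as a small pure-Elixir function that computes the broadcast shape of two shape tuples. `BroadcastSketch` is a hypothetical module, not Nx's implementation. It also assumes that the lower-rank shape is padded with leading 1s, as in the NumPy convention, which is consistent with the scalar example above.

```elixir
defmodule BroadcastSketch do
  # Hypothetical helper, not Nx's implementation. Returns {:ok, shape}
  # when the two shapes can be broadcast together, :error otherwise.
  def broadcast_shape(a, b) when is_tuple(a) and is_tuple(b) do
    la = Tuple.to_list(a)
    lb = Tuple.to_list(b)
    rank = max(length(la), length(lb))

    # Assumption: the lower-rank shape is left-padded with 1s.
    dims = Enum.zip(pad(la, rank), pad(lb, rank)) |> Enum.map(&merge/1)

    if Enum.member?(dims, :error), do: :error, else: {:ok, List.to_tuple(dims)}
  end

  defp pad(dims, rank), do: List.duplicate(1, rank - length(dims)) ++ dims

  defp merge({d, d}), do: d     # rule 1: the dimensions are equal
  defp merge({1, d}), do: d     # rule 2: one of the dimensions is size 1
  defp merge({d, 1}), do: d
  defp merge(_), do: :error
end

IO.inspect(BroadcastSketch.broadcast_shape({}, {3}))        # {:ok, {3}}
IO.inspect(BroadcastSketch.broadcast_shape({3, 1}, {1, 4})) # {:ok, {3, 4}}
IO.inspect(BroadcastSketch.broadcast_shape({2, 3}, {2, 2})) # :error
```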

When you attempt to broadcast incompatible tensors, you’ll be met with the following runtime error:

iex> Nx.add(Nx.tensor([[1, 2, 3], [4, 5, 6]]), Nx.tensor([[1, 2], [3, 4]]))
** (ArgumentError) cannot broadcast tensor of dimensions {2, 3} to {2, 2}
    (nx 0.1.0-dev) lib/nx/shape.ex:241: Nx.Shape.binary_broadcast/4
    (nx 0.1.0-dev) lib/nx.ex:2430: Nx.element_wise_bin_op/4

If necessary, you can get around broadcasting issues by expanding, padding, or slicing one of the input tensors; however, you should carefully consider how this might affect the outcome of your algorithm.
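As one example of expanding an input, suppose you want to add a per-row offset to a {2, 3} tensor. A tensor of shape {2} does not broadcast against {2, 3}, but reshaping it to {2, 1} does. This is a minimal sketch that assumes Nx is available in your session. The tensor names are made up for illustration.

```elixir
# Assumes Nx is available in the session.
matrix = Nx.tensor([[1, 2, 3], [4, 5, 6]]) # shape {2, 3}
per_row = Nx.tensor([10, 20])              # shape {2}

# {2, 3} and {2} are incompatible, but {2, 3} and {2, 1} are compatible.
expanded = Nx.reshape(per_row, {2, 1})
result = Nx.add(matrix, expanded)

IO.inspect(Nx.to_flat_list(result)) # [11, 12, 13, 24, 25, 26]
```

Whether the offset should apply per row or per column depends on your algorithm, which is exactly the kind of decision worth checking before reshaping to silence an error.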

Basic Linear Regression

So far, we’ve spent all of our time in iex with trivial examples and demonstrations of tensor operations. All of our work could have been done with some clever use of Enum and lists. In this section, we’ll start to unlock some of the real power of Nx by solving a basic linear regression problem using gradient descent.

You’ll want to start by creating a new Mix project that imports Nx, as well as an Nx compiler or backend. In this example, I’ll be using EXLA; however, you can use Torchx with some minor adjustments to this example. There are some fundamental differences between EXLA and Torchx that are outside the scope of this post; however, both of them will work fine for this example.

At the time of this writing, Nx is still not available on Hex, so you’ll need to use a Git dependency in your mix.exs:

def deps do
  [
    {:exla, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "exla"},
    {:nx, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "nx", override: true}
  ]
end

Now you can run:

$ mix deps.get && mix deps.compile

If this is your first time compiling EXLA, it will take quite a bit of time on the first compilation. You’ll also want to take a look at the installation section of the EXLA README for prerequisites and troubleshooting steps.

Once both Nx and EXLA are compiled, create a new file, regression.exs somewhere inside your Mix project. Inside regression.exs, create a module and import Nx.Defn:

defmodule LinReg do
  import Nx.Defn
end

Nx.Defn is a module that contains the Nx defn definition. defn is a macro for declaring numerical definitions. Numerical definitions work just the same as regular Elixir functions; however, they support a limited subset of the Elixir programming language, in favor of supporting JIT compilation to accelerators such as GPUs. defn also replaces much of Elixir’s Kernel with Nx-specific implementations. As an example:

defn add_two(a, b) do
  a + b
end

will work on both tensors and scalars, because + internally resolves to Nx.add/2. defn also has support for a special transformation: grad. grad is a macro that returns the gradient of a function with respect to some provided parameters. The gradient of a function provides information about the rate of change of a function with respect to some parameters. A complete discussion of gradients falls outside the scope of this post - for now, you’ll just need to know how to use grad, and what it means at a high-level.
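For a rough feel of what "rate of change" means, the pure-Elixir sketch below approximates the derivative of a one-variable function with a central difference. `NumericGradSketch` is a hypothetical helper. Nx's grad uses automatic differentiation rather than this kind of numerical approximation, so treat this only as intuition.

```elixir
defmodule NumericGradSketch do
  # Hypothetical helper for intuition only: approximates df/dx at x
  # using the central difference (f(x + h) - f(x - h)) / (2h).
  def derivative(f, x, h \\ 1.0e-5) do
    (f.(x + h) - f.(x - h)) / (2 * h)
  end
end

square = fn x -> x * x end

# The exact derivative of x^2 is 2x, so at x = 3.0 we expect roughly 6.0.
slope = NumericGradSketch.derivative(square, 3.0)
IO.inspect(slope)
```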

As I mentioned before, we’ll be implementing a basic linear regression model using gradient descent. Linear regression is an approach to modeling the relationship between some number of input variables and an output variable. The input variables are called the explanatory variables because they are assumed to have a causal relationship which explains the behavior of an output variable. As a practical example, imagine you want to predict the number of daily average users to your website based on the month, time of day, and whether or not there is an ongoing promotion on the website. You can collect data over the course of several months, and then use this data to fit a basic regression model that predicts daily average users for you.

In our example, we’ll create a regression model that predicts an output variable with respect to 1 input variable. We’ll start by defining our training set outside of the LinReg module:

target_m = :rand.normal(0.0, 10.0)
target_b = :rand.normal(0.0, 5.0)
target_fn = fn x -> target_m * x + target_b end
data =
  Stream.repeatedly(fn -> for _ <- 1..32, do: :rand.uniform() * 10 end)
  |> Stream.map(fn x -> Enum.zip(x, Enum.map(x, target_fn)) end)
IO.puts("Target m: #{target_m}\tTarget b: #{target_b}\n")

First, we define target_m, target_b and target_fn. Our linear function has the form: y = m*x + b, so we create a Stream that repeatedly generates batches of input and output pairs by applying target_fn on random inputs. Our goal is to learn target_m and target_b using gradient descent.
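To see what a single batch from this stream looks like, you can take one and inspect it. The sketch below repeats the stream definition with a fixed slope and intercept (2.0 and 1.0, chosen only for illustration) so it runs on its own.

```elixir
# Fixed m and b, chosen only so the example is self-contained.
target_fn = fn x -> 2.0 * x + 1.0 end

data =
  Stream.repeatedly(fn -> for _ <- 1..32, do: :rand.uniform() * 10 end)
  |> Stream.map(fn x -> Enum.zip(x, Enum.map(x, target_fn)) end)

# Each element of the stream is a list of 32 {input, target} tuples.
[batch] = Enum.take(data, 1)
IO.inspect(length(batch)) # 32

# Unzipping separates the inputs from the targets.
{inputs, targets} = Enum.unzip(batch)
consistent? = Enum.all?(Enum.zip(inputs, targets), fn {x, y} -> y == target_fn.(x) end)
IO.inspect(consistent?) # true
```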

The next thing we need to define is our model. A model is really just a parameterized function that maps inputs to outputs. We know our function should have the form y = m * x + b, so we can easily define our model in the same way:

defmodule LinReg do
  import Nx.Defn
  defn predict({m, b}, x) do
    m * x + b
  end
end

Next, we need to define a loss function. Loss functions evaluate predictions with respect to true data, often to measure the divergence between a model’s representation of the data-generating distribution and the true representation of the data-generating distribution. This essentially means that loss functions tell you how poorly your model performs. The goal is to minimize your loss function by fitting a function to a target function.

With linear regression problems, it’s most common to use mean-squared error (MSE) as the loss function:

defn loss(params, x, y) do
  y_pred = predict(params, x)
  Nx.mean(Nx.power(y - y_pred, 2))
end

MSE measures the average squared difference between our targets and predictions. As our predictions get closer to our targets, MSE tends towards 0. Given our loss function, we need a way to update our model such that it minimizes loss/3. We can achieve this using gradient descent. Gradient descent calculates the gradient of a loss function with respect to the input parameters. The gradient then provides information on how to update model parameters.
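Written out for this model, with n examples per batch, the loss and its partial derivatives with respect to m and b are as follows. The symbols n, x_i, and y_i are notation introduced here for illustration. The derivatives follow from the chain rule applied to the squared error.

```latex
\[
\begin{aligned}
L(m, b) &= \frac{1}{n} \sum_{i=1}^{n} \bigl(y_i - (m x_i + b)\bigr)^2 \\
\frac{\partial L}{\partial m} &= -\frac{2}{n} \sum_{i=1}^{n} x_i \bigl(y_i - (m x_i + b)\bigr) \\
\frac{\partial L}{\partial b} &= -\frac{2}{n} \sum_{i=1}^{n} \bigl(y_i - (m x_i + b)\bigr)
\end{aligned}
\]
```

These two partial derivatives together form the gradient, which automatic differentiation computes for you.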

It can be difficult to understand exactly what gradient descent is doing at first. Imagine you want to find the deepest point in a lake. You have a depth finder on your boat, but no other information. You could search over the entire lake; however, this would take an impossible amount of time. Instead, you can use your depth finder to iteratively find deeper and deeper points in smaller areas of the lake. For example, if you know traveling left increases depth from 5 to 7 meters and traveling right decreases depth from 5 to 3 meters, you would choose to move your boat left. This is, in essence, what gradient descent is doing - it gives you depth-finding information you can use to navigate a parameter space.

You can implement your update step by calculating the gradient of your loss function with respect to the parameters, and using the gradient to update each parameter, like this:

defn update({m, b} = params, inp, tar) do
  {grad_m, grad_b} = grad(params, &loss(&1, inp, tar))
  {
    m - grad_m * 0.01,
    b - grad_b * 0.01
  }
end

grad takes the parameters you want to evaluate the gradient at, as well as a parameterized function - in this case the loss function. grad_m and grad_b are the gradients of the loss with respect to m and b, respectively. You update each parameter by scaling its gradient by a factor of 0.01 and then subtracting the result from the parameter. The factor 0.01 is called the learning rate. You want to take small steps; large jumps cause you to move too erratically within the parameter space and inhibit learning.
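To show the arithmetic of a single update without any Nx machinery, here is a pure-Elixir sketch of the same step using hand-derived gradients of mean-squared error for y = m * x + b. `GradientStepSketch` is a hypothetical module. The defn version above obtains the same quantities through grad instead of the closed-form expressions in the comments.

```elixir
defmodule GradientStepSketch do
  @learning_rate 0.01

  # Mean-squared error of the line y = m * x + b over the batch.
  def loss({m, b}, xs, ys) do
    errors = Enum.zip(xs, ys) |> Enum.map(fn {x, y} -> y - (m * x + b) end)
    Enum.sum(Enum.map(errors, &(&1 * &1))) / length(xs)
  end

  # Hand-derived gradients of the loss:
  #   dL/dm = -(2/n) * sum(x_i * error_i)
  #   dL/db = -(2/n) * sum(error_i)
  def gradients({m, b}, xs, ys) do
    n = length(xs)
    errors = Enum.zip(xs, ys) |> Enum.map(fn {x, y} -> y - (m * x + b) end)
    weighted = Enum.zip(xs, errors) |> Enum.map(fn {x, e} -> x * e end)
    {-2 / n * Enum.sum(weighted), -2 / n * Enum.sum(errors)}
  end

  # One gradient descent step: move each parameter against its gradient.
  def update({m, b} = params, xs, ys) do
    {grad_m, grad_b} = gradients(params, xs, ys)
    {m - grad_m * @learning_rate, b - grad_b * @learning_rate}
  end
end

xs = [1.0, 2.0]
ys = [2.0, 4.0] # generated by y = 2x
params = {0.0, 0.0}

# Gradients at {0.0, 0.0} are {-10.0, -6.0}, so one step moves the
# parameters to approximately {0.1, 0.06}.
new_params = GradientStepSketch.update(params, xs, ys)
IO.inspect(new_params)
```

After the step, the loss on this batch is lower than before, which is the behavior the learning rate is meant to preserve.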

update returns the updated parameters m and b. At this point, however, you need some initial starting point for both m and b. Revisiting the depth-finding example, imagine you have a friend who has some intuition about the general location of the deepest point in the lake. He tells you where to start your search, and thus you have a better shot at finding the deepest point. This is essentially the same as parameter initialization. You need to have a good starting point in order to effectively learn a good parameterization of your model:

defn init_random_params do
  m = Nx.random_normal({}, 0.0, 0.1)
  b = Nx.random_normal({}, 0.0, 0.1)
  {m, b}
end

init_random_params uses Nx.random_normal/3 to initialize m and b using a normal distribution with mean 0.0 and standard deviation 0.1. Now, you need to write a training loop. A training loop takes batches of examples, and applies update after each batch, halting only after some condition is met. In this example, we’ll train on 200 batches per epoch for a total of 100 epochs, or full training iterations:

def train(epochs, data) do
  init_params = init_random_params()
  for _ <- 1..epochs, reduce: init_params do
    acc ->
      data
      |> Enum.take(200)
      |> Enum.reduce(
        acc,
        fn batch, cur_params ->
          {inp, tar} = Enum.unzip(batch)
          x = Nx.tensor(inp)
          y = Nx.tensor(tar)
          update(cur_params, x, y)
        end
      )
  end
end

In the training loop, we take 200 batches from the stream and iteratively update the model parameters after each batch. We repeat this process epochs number of times, returning the updated params after every epoch. Now, we just need to call LinReg.train/2 to return the learned m and b:

{m, b} = LinReg.train(100, data)
IO.puts("Learned m: #{Nx.to_scalar(m)}\tLearned b: #{Nx.to_scalar(b)}")
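One detail of the training loop worth noting is that the data is a lazy stream built with Stream.repeatedly, so each call to Enum.take re-runs the generator. As far as I can tell, this means every epoch trains on 200 freshly generated batches rather than the same 200. The small pure-Elixir sketch below illustrates that behavior with made-up variable names.

```elixir
stream = Stream.repeatedly(fn -> :rand.uniform() end)

# Each enumeration of the stream invokes the generator function again.
first_pass = Enum.take(stream, 3)
second_pass = Enum.take(stream, 3)

# Almost surely false: the two passes draw different random numbers.
IO.inspect(first_pass == second_pass)
```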

Overall, regression.exs should now look like:

defmodule LinReg do
  import Nx.Defn
  defn predict({m, b}, x) do
    m * x + b
  end
  defn loss(params, x, y) do
    y_pred = predict(params, x)
    Nx.mean(Nx.power(y - y_pred, 2))
  end
  defn update({m, b} = params, inp, tar) do
    {grad_m, grad_b} = grad(params, &loss(&1, inp, tar))
    {
      m - grad_m * 0.01,
      b - grad_b * 0.01
    }
  end
  defn init_random_params do
    m = Nx.random_normal({}, 0.0, 0.1)
    b = Nx.random_normal({}, 0.0, 0.1)
    {m, b}
  end
  def train(epochs, data) do
    init_params = init_random_params()
    for _ <- 1..epochs, reduce: init_params do
      acc ->
        data
        |> Enum.take(200)
        |> Enum.reduce(
          acc,
          fn batch, cur_params ->
            {inp, tar} = Enum.unzip(batch)
            x = Nx.tensor(inp)
            y = Nx.tensor(tar)
            update(cur_params, x, y)
          end
        )
    end
  end
end
target_m = :rand.normal(0.0, 10.0)
target_b = :rand.normal(0.0, 5.0)
target_fn = fn x -> target_m * x + target_b end
data =
  Stream.repeatedly(fn -> for _ <- 1..32, do: :rand.uniform() * 10 end)
  |> Stream.map(fn x -> Enum.zip(x, Enum.map(x, target_fn)) end)
IO.puts("Target m: #{target_m}\tTarget b: #{target_b}\n")
{m, b} = LinReg.train(100, data)
IO.puts("Learned m: #{Nx.to_scalar(m)}\tLearned b: #{Nx.to_scalar(b)}")

Now, you can run this example:

$ mix run regression.exs
Target m: -0.057762353079829236 Target b: 0.681480460783122
Learned m: -0.05776193365454674 Learned b: 0.6814777255058289

Notice how our learned m and b are almost identical to the target m and b! We’ve successfully implemented linear regression using gradient descent; however, there’s one thing we can do to take this implementation to the next level.

You should have noticed that training for 100 epochs took a noticeable amount of time. That’s because we’re not taking advantage of defn JIT compilation with EXLA. Because this is a relatively small example, we don’t really need JIT compilation; however, you will want to accelerate your models as your implementations get more complex. First, so we can really see the difference between EXLA and pure Elixir, let’s time how long model training takes:

{time, {m, b}} = :timer.tc(LinReg, :train, [100, data])
IO.puts("Learned m: #{Nx.to_scalar(m)}\tLearned b: #{Nx.to_scalar(b)}\n")
IO.puts("Training time: #{time / 1_000_000}s")

and then run again without any acceleration:

$ mix run regression.exs
Target m: -1.4185910271067492 Target b: -2.9781437461823965
Learned m: -1.4185925722122192  Learned b: -2.978132724761963
Training time: 4.460695s

Once again, we successfully learned m and b. This time, we can see that training took about 4.5 seconds. Now, in order to take advantage of JIT compilation using EXLA, add the following attribute to your module:

defmodule LinReg do
  import Nx.Defn
  @default_defn_compiler EXLA
end

This tells Nx to use the EXLA compiler to compile all of the numerical definitions in the module. Now, run the example again:

Target m: -3.1572039775886167 Target b: -1.9610560589959405
Learned m: -3.1572046279907227  Learned b: -1.961051106452942
Training time: 2.564152s

Once again the model recovers the target m and b, but this time we were able to train in 2.6 seconds versus 4.5 seconds - roughly a 1.7x speedup! Admittedly, this is a relatively trivial example, and the speedup you’re seeing here is only a fraction of what you would see with more complex implementations. As an example, you can attempt to run a pure Elixir implementation of the MNIST example in the EXLA repository and a single epoch will take hours to complete, whereas the EXLA-compiled version will complete in anywhere from 0.5s to 4s per epoch - depending on the accelerator and machine you’re using.
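For reference, the two timings reported above can be turned into a speedup factor and a percentage of time saved with a couple of lines of arithmetic. The variable names are illustrative.

```elixir
baseline_s = 4.460695 # pure Elixir run reported above
exla_s = 2.564152     # EXLA-compiled run reported above

speedup = baseline_s / exla_s
time_saved = 1 - exla_s / baseline_s

IO.puts("speedup: #{Float.round(speedup, 2)}x")              # speedup: 1.74x
IO.puts("time saved: #{Float.round(time_saved * 100, 1)}%") # time saved: 42.5%
```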

Conclusion

This post covered a lot of the Nx core functionality. You learned:

1) How to create tensors using Nx.tensor and Nx.from_binary
2) How to use unary, binary, and aggregate operations to manipulate tensors
3) How to implement gradient descent using defn and Nx automatic differentiation with grad
4) How to accelerate numerical definitions using the EXLA compiler

While this post covered the basics of what’s needed to get started with Nx, there’s still much more to learn. I hope this post motivates you to continue learning about the Nx project and inspires you to seek out unique applications of Nx in practice. Nx is still in its infancy, and there are many more exciting things ahead!
