Elixir Machine Learning: Training Models in Axon is Getting Better

Sean Moriarity

Machine Learning Advisor



Introduction

I’ve written a couple of posts in the past on fine-tuning LLMs with Elixir, and recently I’ve been working to improve both the performance and the experience of training models in Elixir. Since it’s been a while and a lot has changed, I thought it would be interesting to document the process, some of the improvements we’ve made, and some of the improvements we’re still making.

In this post, I’ll walk you through what the fine-tuning process looks like in Elixir, and how we’re working to make it even better moving forward.

Before starting, you’ll need to pull a few dependencies in:

Mix.install(
  [
    {:bumblebee, github: "elixir-nx/bumblebee", branch: "main", override: true},
    {:axon, github: "elixir-nx/axon", branch: "main", override: true},
    {:polaris, github: "elixir-nx/polaris", branch: "main", override: true},
    {:nx, "~> 0.7", override: true},
    {:exla, ">= 0.0.0"},
    {:explorer, "~> 0.7.0"},
    {:table_rex, "~> 3.1.1"}
  ],
  force: true,
  system_env: %{
    "XLA_TARGET" => "cuda120"
  }
)

Nx.default_backend(EXLA.Backend)

Now, let’s dig in.

Classifying Yelp Reviews

One of the first write-ups I did on fine-tuning came shortly after Bumblebee was released. The guide walks a user through fine-tuning a Bert-based model to categorize Yelp reviews. It’s a simple example, but it effectively demonstrates how to go about training and fine-tuning models in Axon. We’ll start by working through the original example and highlight where things have changed.

To start, let’s prepare our dataset for fine-tuning in the same way that is presented in the linked guide:

defmodule Yelp do
  def load(path, tokenizer, opts \\ []) do
    path
    |> Explorer.DataFrame.from_parquet!()
    |> Explorer.DataFrame.rename(["label", "text"])
    |> stream()
    |> tokenize_and_batch(tokenizer, opts[:batch_size], opts[:sequence_length])
  end

  # Lazily pairs each review text with its label
  def stream(df) do
    xs = df["text"]
    ys = df["label"]

    xs
    |> Explorer.Series.to_enum()
    |> Stream.zip(Explorer.Series.to_enum(ys))
  end

  # Groups {text, label} pairs into batches and tokenizes each batch
  def tokenize_and_batch(stream, tokenizer, batch_size, sequence_length) do
    tokenizer = Bumblebee.configure(tokenizer, length: sequence_length)

    stream
    |> Stream.chunk_every(batch_size)
    |> Stream.map(fn batch ->
      {text, labels} = Enum.unzip(batch)
      tokenized = Bumblebee.apply_tokenizer(tokenizer, text)
      {tokenized, Nx.stack(labels)}
    end)
  end
end

This defines a module that we can use to load the Yelp reviews dataset. You’ll need to download the dataset locally first. You can get a copy of it here. I’ve made a slight modification to the original code to use Parquet rather than CSV files. Now we can use Yelp.load/3 to pull in training and testing data. Before we do that, however, we’ll need to pull in a tokenizer:

{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "google-bert/bert-base-cased"})

And now we can build the train and test sets:

batch_size = 128
sequence_length = 128

train_data =
  Yelp.load("yelp/yelp_review_full/yelp_review_full/train-00000-of-00001.parquet", tokenizer,
    batch_size: batch_size,
    sequence_length: sequence_length
  )

test_data =
  Yelp.load("yelp/yelp_review_full/yelp_review_full/test-00000-of-00001.parquet", tokenizer,
    batch_size: batch_size,
    sequence_length: sequence_length
  )
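To make the batching step in Yelp.tokenize_and_batch/4 concrete, here is a dependency-free sketch of the same zip, chunk, and unzip pipeline on a few made-up reviews. The texts and labels are hypothetical stand-ins for the Parquet columns, and tokenization is left out. Note that the final batch is smaller than the rest; since compiled functions are specialized to input shapes, that last batch will likely trigger one extra compilation, and Stream.chunk_every(stream, batch_size, batch_size, :discard) is one way to drop it.

```elixir
# Hypothetical stand-ins for the "text" and "label" columns
texts = ["great food", "slow service", "would return"]
labels = [4, 1, 3]

batches =
  texts
  # Pair each text with its label: {"great food", 4}, ...
  |> Stream.zip(labels)
  # Group pairs into batches of 2 (the last batch keeps the leftover pair)
  |> Stream.chunk_every(2)
  # Split each batch back into {texts, labels}
  |> Enum.map(&Enum.unzip/1)

# => [{["great food", "slow service"], [4, 1]}, {["would return"], [3]}]
```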

Improvement #1: The new Axon Model State

Now, let’s pull in a pre-trained Bert model to highlight the first major improvement we’ve made to the fine-tuning process:

{:ok, spec} =
  Bumblebee.load_spec({:hf, "google-bert/bert-base-cased"},
    architecture: :for_sequence_classification
  )

spec = Bumblebee.configure(spec, num_labels: 5)

{:ok, model} = Bumblebee.load_model({:hf, "google-bert/bert-base-cased"}, spec: spec)

Take a second to inspect the model.params value:

IO.inspect(model.params)
#Axon.ModelState<
  Parameters: 108314117 (433.26 MB)
  Trainable Parameters: 108314117 (433.26 MB)
  Trainable State: 0 (0 B)
>

You might notice a big difference here. In older versions of Axon, model parameters were always represented as regular Elixir maps. I kept it this way for a long time because I believed there was no reason to wrap a perfectly good data structure with a custom API. The map approach worked well enough, but it had some quirks.

One quirk is that Axon did not actually differentiate internally between model state and model parameters. For example, layers such as batch normalization keep track of a running mean and variance throughout training.

Model state, such as the mean and variance in this case, is not meant to be included in the optimization process that happens during training. Because Axon did not differentiate between the two, I had to implement a hack that “patched” the value of the model state after each optimization step.

Even though this state should never have been part of the optimization process, older versions of Axon would still include it in both the gradient computation and the update computation during training.

The new model state includes metadata that explicitly differentiates between model parameters and model state, meaning we no longer have to worry about state leaking into the optimization process.
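The distinction can be illustrated with a toy data structure. The sketch below is hypothetical and is not how Axon.ModelState is implemented internally; it only shows the idea of keeping parameter values in one nested map while tracking which paths are running state (like a batch norm mean and variance) and which are frozen, so that only the remaining paths are handed to the optimizer.

```elixir
defmodule ToyModelState do
  # Hypothetical toy, NOT Axon.ModelState's real internals.
  # `data` holds values as %{layer => %{param => value}}; `state` and
  # `frozen` hold [layer, param] paths the optimizer must not touch.
  defstruct data: %{}, state: MapSet.new(), frozen: MapSet.new()

  # Every [layer, param] path present in `data`
  def paths(%__MODULE__{data: data}) do
    for {layer, params} <- data, {name, _value} <- params, do: [layer, name]
  end

  # Paths handed to the optimizer: neither running state nor frozen
  def trainable_paths(%__MODULE__{state: state, frozen: frozen} = ms) do
    ms
    |> paths()
    |> Enum.reject(fn path -> MapSet.member?(state, path) or MapSet.member?(frozen, path) end)
    |> Enum.sort()
  end

  # Marks every path for which `predicate` returns true as frozen
  def freeze(%__MODULE__{} = ms, predicate) do
    newly_frozen = ms |> paths() |> Enum.filter(predicate) |> MapSet.new()
    %{ms | frozen: MapSet.union(ms.frozen, newly_frozen)}
  end
end

ms = %ToyModelState{
  data: %{
    "batch_norm" => %{"gamma" => 1.0, "beta" => 0.0, "mean" => 0.0, "var" => 1.0},
    "dense" => %{"kernel" => 0.5, "bias" => 0.0}
  },
  # The running mean and variance are state, not trainable parameters
  state: MapSet.new([["batch_norm", "mean"], ["batch_norm", "var"]])
}

ToyModelState.trainable_paths(ms)
# => [["batch_norm", "beta"], ["batch_norm", "gamma"], ["dense", "bias"], ["dense", "kernel"]]

frozen_ms = ToyModelState.freeze(ms, fn [layer, _param] -> layer == "batch_norm" end)
ToyModelState.trainable_paths(frozen_ms)
# => [["dense", "bias"], ["dense", "kernel"]]
```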

Additionally, Ted Wong pointed out to me that freezing model layers still resulted in large amounts of memory usage during training.

Previously, in order to “freeze” layers in Axon, you had to use an API that would mark a model’s nodes as frozen or not-trainable. The Axon compiler would recognize these frozen layers and wrap them in a stop_grad node to tell Nx’s automatic differentiation system to return zero gradient for nodes in that path.

This served its purpose, but it was rather wasteful from a resource perspective. Typically, when you train a neural network, a lot of memory usage comes from the optimizer state: for each parameter you’ve marked as trainable, you need to initialize some state.

With our old approach, we would always initialize the optimizer state for every parameter, even if it was frozen or represented state. Additionally, we would always return a value for the gradient of a frozen layer. You can see how wasteful this entire process was!
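A quick back-of-the-envelope calculation shows how much this matters. The sketch below assumes Adam keeps two f32 moment tensors per trainable parameter and ignores its small step counter; it compares the optimizer state for all 108,314,117 parameters of this Bert model against only the 3,845 parameters of the classification head we train later in this post.

```elixir
defmodule OptimizerMemory do
  # Assumption: Adam stores a first and a second moment per trainable
  # parameter, each as a 4-byte f32; the step counter is ignored.
  @bytes_per_f32 4
  @adam_moments 2

  def adam_state_bytes(trainable_params) do
    trainable_params * @adam_moments * @bytes_per_f32
  end
end

# Every parameter treated as trainable: 866_512_936 bytes (about 866.5 MB)
OptimizerMemory.adam_state_bytes(108_314_117)

# Only the classification head trainable: 30_760 bytes (about 30.8 KB)
OptimizerMemory.adam_state_bytes(3_845)
```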

The new Axon.ModelState data structure and API wraps a model’s parameters into a struct that keeps around additional metadata about parameters, state, and what should be considered “trainable” within that state. This ends up being a significant memory optimization and also simplifies things such as quantization, mixed precision, and weight tying.

An additional benefit is that using the new model state API doesn’t really require any code changes. We can train our model like normal. For example, we could fine-tune the model like we do in the example guide with:

# Bind the Axon graph to a new name so `model` still holds the
# Bumblebee model info used in the next snippet
%{model: base_model, params: params} = model
logits_model = Axon.nx(base_model, & &1.logits)

loss =
  &Axon.Losses.categorical_cross_entropy(&1, &2,
    reduction: :mean,
    from_logits: true,
    sparse: true
  )

accuracy = &Axon.Metrics.accuracy(&1, &2, from_logits: true, sparse: true)

optimizer = Polaris.Optimizers.adam(learning_rate: 5.0e-5)

model_state =
  logits_model
  |> Axon.Loop.trainer(loss, optimizer, log: 1)
  |> Axon.Loop.metric(accuracy, "accuracy")
  |> Axon.Loop.checkpoint(event: :epoch_completed)
  |> Axon.Loop.run(train_data, params, epochs: 3, compiler: EXLA)

Or, better yet, we can take advantage of Axon’s new optimizations and freeze the entire model except for the final layer. To do that, let’s first grab the model and model state:

%{model: model, params: model_state} = model
logits_model = Axon.nx(model, & &1.logits)

Now, we can freeze every parameter in the model except for the classification head. Axon.ModelState.freeze/2 takes a predicate that receives the path to a parameter and returns a boolean: true means the parameter should be frozen, and false means it should stay trainable. In the future, we will make some efforts to simplify this API. To freeze the parameters we want, we can run:

model_state =
  Axon.ModelState.freeze(model_state, fn
    ["sequence_classification_head.output", _] -> false
    _ -> true
  end)
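To see how the predicate behaves, you can apply it to a list of paths yourself. This plain Elixir sketch assumes, as the pattern above suggests, that each path is a [layer_name, parameter_name] list; the layer names other than the classification head are hypothetical.

```elixir
# Same predicate shape as the snippet above: true means "freeze"
freeze? = fn
  ["sequence_classification_head.output", _] -> false
  _ -> true
end

# Hypothetical parameter paths, for illustration only
paths = [
  ["embedder.token_embedding", "kernel"],
  ["encoder.blocks.0.self_attention.query", "kernel"],
  ["sequence_classification_head.output", "kernel"],
  ["sequence_classification_head.output", "bias"]
]

# Enum.split_with/2 returns {paths where freeze? is true, the rest}
{frozen, trainable} = Enum.split_with(paths, freeze?)
# trainable => the two "sequence_classification_head.output" paths
```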

If you inspect the state, you will see:

#Axon.ModelState<
  Parameters: 108314117 (433.26 MB)
  Trainable Parameters: 3845 (15.38 KB)
  Trainable State: 0 (0 B)
>

Notice how we’ve suddenly gone from 108,314,117 trainable parameters to 3,845! Now, we can train our model in the same way as before:

loss =
  &Axon.Losses.categorical_cross_entropy(&1, &2,
    reduction: :mean,
    from_logits: true,
    sparse: true
  )

accuracy = &Axon.Metrics.accuracy(&1, &2, from_logits: true, sparse: true)

optimizer = Polaris.Optimizers.adam(learning_rate: 5.0e-5)

model_state =
  logits_model
  |> Axon.Loop.trainer(loss, optimizer, log: 1)
  |> Axon.Loop.metric(accuracy, "accuracy")
  |> Axon.Loop.checkpoint(event: :epoch_completed)
  |> Axon.Loop.run(train_data, model_state,
    epochs: 3,
    compiler: EXLA,
    force_garbage_collection?: true
  )
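Where does 3,845 come from? Assuming the classification head’s output layer is a dense layer from Bert’s hidden size (768 for bert-base-cased) to our 5 labels, its kernel and bias account for exactly that many parameters, and the byte counts in the inspected state line up with 4-byte f32 values:

```elixir
defmodule HeadSize do
  # A dense layer has an in x out kernel plus an out-sized bias
  def dense_params(in_features, out_features) do
    in_features * out_features + out_features
  end

  # Bytes needed to store `n` f32 values
  def f32_bytes(n), do: n * 4
end

HeadSize.dense_params(768, 5)     # 3845
HeadSize.f32_bytes(3_845)         # 15_380 bytes, i.e. 15.38 KB
HeadSize.f32_bytes(108_314_117)   # 433_256_468 bytes, i.e. 433.26 MB
```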

If you haven’t tried doing this before, you might not appreciate how much faster and better the new Axon.ModelState data structure has made the training process. Previously, I would consistently run into OOMs and really slow training, even on a machine with a 4090! Honestly, prior to this change it was a bit impractical to do any fine-tuning at all. Now, you shouldn’t run into many issues with OOMs, and if you do, please let me know!

Improvement #2: Mixed Precision Training

One of the biggest speed improvements you can get during training is to use mixed precision.

Many recent LLMs are trained entirely in half precision. While Axon has been able to work with mixed precision for a while, it didn’t always work as expected. The implementation was buggy, and at times Axon would randomly cast values that should never have been cast in the first place.

Since the original fine-tuning guide for Bumblebee was written, we’ve done some work to make working with mixed precision models less buggy than before.

First, you can now just directly load models with Bumblebee in the precision you want:

{:ok, model} =
  Bumblebee.load_model({:hf, "google-bert/bert-base-cased"},
    spec: spec,
    policy: Axon.MixedPrecision.create_policy(compute: {:bf, 16})
  )

This loads a model that performs all computations in bf16 ({:bf, 16}) while keeping the model parameters and outputs in f32. Now, you can train this model as you would have before, and you should notice a speed-up in training.
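To get a feel for what computing in bf16 means, the sketch below round-trips an f32 value through bfloat16, which keeps f32’s sign and 8 exponent bits but only 7 mantissa bits. It truncates the low 16 bits for simplicity, whereas hardware and XLA conversions typically round to nearest even, so treat it as an approximation of the real cast.

```elixir
defmodule BF16 do
  import Bitwise

  # Converts `x` to f32, zeroes the low 16 bits (the part of the
  # mantissa bf16 drops), and reads the result back as an f32.
  # Truncation only; real casts usually round to nearest even.
  def round_trip(x) when is_float(x) do
    <<bits::32>> = <<x::float-32>>
    <<y::float-32>> = <<band(bits, 0xFFFF0000)::32>>
    y
  end
end

BF16.round_trip(1.0)         # 1.0, exactly representable
BF16.round_trip(3.14159265)  # 3.140625, only ~3 significant digits survive
BF16.round_trip(1.0e30)      # still close to 1.0e30: bf16 keeps f32's exponent range
```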

Another thing you might notice is that you can potentially double your training batch size thanks to the memory savings of computing everything in half rather than full precision. For example, we can first try to train a model without mixed precision on a dataset with a batch size of 256:

# Rebuild the training set with a batch size of 256
train_data =
  Yelp.load("yelp/yelp_review_full/yelp_review_full/train-00000-of-00001.parquet", tokenizer,
    batch_size: 256,
    sequence_length: sequence_length
  )

loss =
  &Axon.Losses.categorical_cross_entropy(&1, &2,
    reduction: :mean,
    from_logits: true,
    sparse: true
  )

accuracy = &Axon.Metrics.accuracy(&1, &2, from_logits: true, sparse: true)

optimizer = Polaris.Optimizers.adam(learning_rate: 5.0e-5)

model_state =
  logits_model
  |> Axon.Loop.trainer(loss, optimizer, log: 1)
  |> Axon.Loop.metric(accuracy, "accuracy")
  |> Axon.Loop.checkpoint(event: :epoch_completed)
  |> Axon.Loop.run(train_data, model_state,
    epochs: 3,
    compiler: EXLA,
    force_garbage_collection?: true
  )

On a 4090 I get an OOM at this batch size:

** (RuntimeError) Out of memory while trying to allocate 19832509512 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:    1.21GiB
              constant allocation:     2.1KiB
        maybe_live_out allocation:    1.21GiB
     preallocated temp allocation:   18.47GiB
  preallocated temp fragmentation:   96.00MiB (0.51%)
                 total allocation:   20.89GiB
              total fragmentation:    1.05GiB (5.05%)

But, if we try again after applying mixed precision to our model, it will run just fine:

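# Rebuild the logits model and frozen state from the mixed precision
# model loaded with the bf16 policy above
%{model: mp_model, params: mp_params} = model
mp_logits_model = Axon.nx(mp_model, & &1.logits)

mp_model_state =
  Axon.ModelState.freeze(mp_params, fn
    ["sequence_classification_head.output", _] -> false
    _ -> true
  end)

loss =
  &Axon.Losses.categorical_cross_entropy(&1, &2,
    reduction: :mean,
    from_logits: true,
    sparse: true
  )

accuracy = &Axon.Metrics.accuracy(&1, &2, from_logits: true, sparse: true)

optimizer = Polaris.Optimizers.adam(learning_rate: 5.0e-5)

model_state =
  mp_logits_model
  |> Axon.Loop.trainer(loss, optimizer, log: 1, loss_scale: :dynamic)
  |> Axon.Loop.metric(accuracy, "accuracy")
  |> Axon.Loop.checkpoint(event: :epoch_completed)
  |> Axon.Loop.run(train_data, mp_model_state,
    epochs: 3,
    compiler: EXLA,
    force_garbage_collection?: true
  )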

Note that you need to specify a loss scale in order to work with mixed precision properly. In this case, you can just specify loss_scale: :dynamic and everything will work fine.
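Conceptually, dynamic loss scaling multiplies the loss by a scale factor before differentiation so small half-precision gradients don’t underflow, divides the gradients by the same factor afterwards, and adjusts the factor as training runs. The sketch below is a generic version of that scheme, not Axon’s implementation, and its default values are hypothetical. It takes a boolean for whether the gradients were finite, since in a real loop that check happens on tensors.

```elixir
defmodule DynamicLossScale do
  # Hypothetical defaults, chosen for illustration only
  defstruct scale: 32_768.0, good_steps: 0, period: 2_000, factor: 2.0, min_scale: 1.0

  # Gradients of the scaled loss are divided by the scale before the update
  def unscale(grads, %__MODULE__{scale: s}), do: Enum.map(grads, &(&1 / s))

  # On overflow: skip the update and shrink the scale.
  # After `period` clean steps in a row: grow the scale.
  def update(%__MODULE__{} = ls, grads_finite?) do
    cond do
      not grads_finite? ->
        {:skip, %{ls | scale: max(ls.scale / ls.factor, ls.min_scale), good_steps: 0}}

      ls.good_steps + 1 >= ls.period ->
        {:apply, %{ls | scale: ls.scale * ls.factor, good_steps: 0}}

      true ->
        {:apply, %{ls | good_steps: ls.good_steps + 1}}
    end
  end
end

ls = %DynamicLossScale{period: 2}
{:apply, ls} = DynamicLossScale.update(ls, true)   # good_steps: 1
{:apply, ls} = DynamicLossScale.update(ls, true)   # scale grows to 65_536.0
{:skip, ls} = DynamicLossScale.update(ls, false)   # overflow: back to 32_768.0
```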

Improvement #3: Gradient Accumulation

Another minor improvement we’ve made is around the ergonomics of gradient accumulation in Axon.

Gradient accumulation is a way to increase the effective batch size during training without consuming much additional memory by accumulating the gradients for multiple steps before performing an update. The previous implementation was tied directly to Axon’s training loop as an option :gradient_accumulation_steps. I always felt these ergonomics were a bit strange.

Another issue with the previous implementation is that it would conditionally perform a large optimization step or just return 0s for the updates, which seemed to cause problems for XLA when optimizing that large conditional expression during compilation.

Now, we’ve moved gradient accumulation from Axon to Polaris and implemented it as a Polaris.Updates function:

Polaris.Optimizers.adam(learning_rate: 5.0e-5)
|> Polaris.Updates.accumulate_gradients(5)

This implementation will accumulate gradients for five steps before performing an update. Interestingly, this approach seems to have eliminated some of the issues XLA had with optimizing large conditionals previously.
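The idea behind accumulate_gradients can be sketched without Nx. This toy version uses plain lists and a plain SGD step, and only applies an update every k steps; it assumes the accumulated gradients are averaged before the update, which may differ from what Polaris actually does.

```elixir
defmodule GradAccum do
  # Accumulator state: {summed gradients, number of steps so far}
  def init(n), do: {List.duplicate(0.0, n), 0}

  # Adds `grads` to the running sum; on every k-th step applies one SGD
  # update with the mean gradient and resets the accumulator.
  def step(params, {acc, count}, grads, k, lr) do
    acc = Enum.zip_with(acc, grads, &+/2)
    count = count + 1

    if count == k do
      new_params = Enum.zip_with(params, acc, fn p, g -> p - lr * (g / k) end)
      {new_params, init(length(params))}
    else
      {params, {acc, count}}
    end
  end
end

# k = 2, lr = 0.5
{params, acc} = GradAccum.step([1.0, 2.0], GradAccum.init(2), [1.0, 1.0], 2, 0.5)
# params unchanged: [1.0, 2.0]
{params, _acc} = GradAccum.step(params, acc, [3.0, 5.0], 2, 0.5)
# mean gradient [2.0, 3.0], so params == [0.0, 0.5]
```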

Additionally, the new Axon.ModelState comes in handy here too. Previously, we would store gradient state for every entry in the parameter map, so it was always as large as the model’s full set of parameters. Now, we store gradient state only for trainable parameters, which in some cases, like this Yelp example, is much smaller than before. If you re-run training with gradient accumulation:

loss =
  &Axon.Losses.categorical_cross_entropy(&1, &2,
    reduction: :mean,
    from_logits: true,
    sparse: true
  )

accuracy = &Axon.Metrics.accuracy(&1, &2, from_logits: true, sparse: true)

optimizer =
  Polaris.Optimizers.adam(learning_rate: 5.0e-5)
  |> Polaris.Updates.accumulate_gradients(5)

model_state =
  logits_model
  |> Axon.Loop.trainer(loss, optimizer, log: 1, loss_scale: :dynamic)
  |> Axon.Loop.metric(accuracy, "accuracy")
  |> Axon.Loop.checkpoint(event: :epoch_completed)
  |> Axon.Loop.run(train_data, model_state,
    epochs: 3,
    compiler: EXLA,
    force_garbage_collection?: true
  )

You may notice a slight boost in training speed and results.

Improvement #4: General Nx Improvements

Axon also benefits for free from the many improvements Paulo has been making upstream to Nx and EXLA. We’ve just completed a migration to MLIR, and Paulo has been working on several other improvements that should prove immensely useful for folks looking to adopt Nx.

The best part about our ecosystem is that Axon inherits any performance improvement to Nx downstream, without any changes to the library at all.

Improvement #5: LoRA and Galore

Low-rank adaptation (LoRA) is a method for fine-tuning large language models that significantly reduces memory consumption by reducing the number of trainable parameters. Rather than updating the weights of certain large layers directly, LoRA freezes them and trains small low-rank adapter matrices alongside them. Thanks to the work of Ted Wong and SpawnFest, there is a framework for performing LoRA in Elixir called Lorax.
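The memory savings come from the parameter count. For a frozen d x k weight, LoRA trains a d x r matrix B and an r x k matrix A, so the learned update B·A has rank at most r. Taking a 768 x 768 projection (the size of an attention projection in bert-base) and a hypothetical rank of 8:

```elixir
defmodule LoRASize do
  # Trainable parameters when fine-tuning a d x k weight directly
  def full(d, k), do: d * k

  # Trainable parameters for LoRA adapters B (d x r) and A (r x k);
  # the original d x k weight stays frozen
  def lora(d, k, r), do: d * r + r * k
end

LoRASize.full(768, 768)     # 589_824
LoRASize.lora(768, 768, 8)  # 12_288, about 2% of the full count
```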

Additionally, thanks to Ted’s feedback, I have a number of changes I am working on to improve the ergonomics of training LoRA models with Axon and Bumblebee.

Another approach along the same lines as LoRA, called Galore, was released earlier this year and promised the ability to train full-rank LLMs on consumer GPUs. Galore is a special kind of optimizer that improves the memory efficiency of training LLMs. We are currently working on an implementation of Galore as an optimizer in Polaris.

Still Coming…

I have been motivated by the promise of some of the performance improvements we’ve made in the last month, and plan to continue pushing to improve the model training story in Elixir. Over the next year, you should hopefully see improvements to support some or all of the following:

  • Model quantization (for inference and training, e.g. QLoRA)
  • Ergonomic improvements to LoRA
  • Fault-tolerant training loops (e.g. training loops run under a supervisor)
  • Training loops as streams
  • Fused training loops
  • Distributed training loops

And anything else you might ask for along the way :)

If you previously tried and failed to train or fine-tune models with Axon, please let me know so I can help you fix the problems you ran into. Until next time!
