Training LoRA Models with Axon

Sean Moriarity

Machine Learning Advisor

Introduction

Low-Rank Adaptation (LoRA) is an extremely popular topic in the world of model fine-tuning. LoRA addresses one of the critical challenges of local-first model development on commercial GPUs: memory limitations. It enables researchers and developers with limited computational resources to fine-tune state-of-the-art models for specific tasks.

LoRA works by adding small, trainable “rank decomposition” matrices to the existing weights of the model. This approach allows for task-specific adaptation while keeping most of the original pre-trained weights frozen. The result? Dramatically smaller file sizes, reduced memory requirements, and faster training times, all without significant loss in model performance.
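As a rough illustration of the savings, the sketch below counts trainable parameters for a single weight matrix. Full fine-tuning updates all d_out × d_in entries, while a LoRA adapter trains only B (d_out × rank) and A (rank × d_in). The LoraMath module is hypothetical, and the 4096 × 4096 shape is an assumption chosen to match the query projection in a Llama 3 8B attention block.

```elixir
defmodule LoraMath do
  # Hypothetical helper, not part of Axon.
  # Returns {full, lora}: trainable parameters for one d_out x d_in weight
  # under full fine-tuning versus a rank-`rank` LoRA adapter (B and A).
  def param_counts(d_out, d_in, rank) do
    {d_out * d_in, rank * (d_out + d_in)}
  end

  # LoRA scales the low-rank update B * A by alpha / rank.
  def scaling(alpha, rank), do: alpha / rank
end

{full, lora} = LoraMath.param_counts(4096, 4096, 8)
# full = 16777216, lora = 65536
lora / full
# => 0.00390625
LoraMath.scaling(16, 8)
# => 2.0
```

With the rank of 8 and alpha of 16 used later in this post, the adapter for that one matrix trains 65,536 parameters instead of 16,777,216 (about 0.4%), and its output is scaled by 2.0.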

In the Elixir ecosystem, the Lorax library, by Ted Wong, was the first implementation of LoRA in Axon and Elixir. As with any good “frontier” library, Lorax had to rely on internal APIs and some hacks to get LoRA to run properly. After looking over the library, I knew we needed to add APIs to Axon that made LoRA much simpler. However, I wanted something that was not tightly coupled to LoRA in particular. This led to the introduction of graph rewriters in Axon, which have proved useful even beyond LoRA.

In this article, we’ll discuss how to fine-tune a LoRA model in Elixir, and we’ll dive deep into Axon graph rewriters and some of the internals that make Axon tick.

What is a model?

In order to fully understand the approach Axon takes to implementing LoRA, you need to understand a bit about Axon’s internal representation of a model. Axon builds on top of Nx to provide conveniences for creating neural networks. Every Axon model is represented by the %Axon{} struct, which is a graph data structure:

defmodule Axon do
  defstruct [
    :nodes,
    :output
  ]
end

nodes is a map from integer IDs to %Axon.Node{} structs:

defmodule Axon.Node do
  @moduledoc false

  defstruct [
    :id,
    :name,
    :mode,
    :parent,
    :parameters,
    :args,
    :op,
    :policy,
    :hooks,
    :opts,
    :global_options,
    :op_name,
    :meta,
    :stacktrace
  ]
end

output is a list of output node IDs that point to the nodes representing the model’s outputs. Each %Axon.Node{} references its parent layers (i.e., its inputs) by the same integer IDs used as keys in the nodes map. Originally, the %Axon{} data structure was recursive; however, this led to some interesting challenges when compiling large models, particularly because the entire %Axon{} struct is serialized during the Nx compilation process.
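To see why integer IDs are enough to describe the whole graph, here is a simplified, hypothetical stand-in for %Axon{}: plain maps instead of the real structs, with only an op and a list of parent IDs per node. Walking parent IDs back from the outputs recovers the order in which layers must be evaluated, and a node shared by two consumers is emitted only once. This is a sketch under those assumptions, not Axon’s actual traversal code.

```elixir
defmodule GraphSketch do
  # Hypothetical, simplified stand-in for %Axon{}: nodes is a map of
  # integer ID => node, and each node lists its parents by ID.
  def example do
    %{
      nodes: %{
        1 => %{op: :input, parent: []},
        2 => %{op: :dense, parent: [1]},
        3 => %{op: :relu, parent: [2]}
      },
      output: [3]
    }
  end

  # Returns ops in dependency (topological) order by walking parent IDs
  # back from each output. Each node is added to the order exactly once.
  def topo_ops(%{nodes: nodes, output: outputs}) do
    {order, _seen} =
      Enum.reduce(outputs, {[], MapSet.new()}, fn id, acc -> visit(id, nodes, acc) end)

    order
    |> Enum.reverse()
    |> Enum.map(fn id -> nodes[id].op end)
  end

  defp visit(id, nodes, {order, seen}) do
    if MapSet.member?(seen, id) do
      {order, seen}
    else
      {order, seen} =
        Enum.reduce(nodes[id].parent, {order, MapSet.put(seen, id)}, fn parent, acc ->
          visit(parent, nodes, acc)
        end)

      # All parents are already in `order`, so this node can follow them
      {[id | order], seen}
    end
  end
end

GraphSketch.topo_ops(GraphSketch.example())
# => [:input, :dense, :relu]
```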

This data structure contains all of the information and metadata necessary for “lowering” Axon models into Nx expressions that can then be JIT-compiled using XLA or another Nx compiler. It works well when you’re building models from scratch, but what if you need to rewrite parts of an existing model?

Understanding LoRA and Graph Rewriting

LoRA works by replacing updates to the full weight matrices with low-rank updates. Implementing LoRA in an existing model requires us to modify its architecture, replacing certain operations, such as dense layers, with LoRA versions. What makes this more difficult is that a LoRA layer relies on both the original layer’s input and its output. That requires more than simply swapping the operation a given node uses, and it calls for a more flexible graph manipulation API than Axon had. Enter Axon.rewrite_nodes.
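Concretely, a LoRA-adapted layer computes h = W x + (alpha / rank) * B (A x), where W is frozen and only A and B are trained. It has to read the original input x and add to the original output W x, which is why it cannot be expressed by swapping a node’s op. The plain-list sketch below is hypothetical: it uses tiny hand-written matrices, omits dropout, and is not how Axon evaluates layers.

```elixir
defmodule LoraSketch do
  # Matrices are lists of rows; vectors are lists of floats.
  def matvec(m, v), do: Enum.map(m, fn row -> dot(row, v) end)

  defp dot(a, b), do: Enum.zip_with(a, b, fn x, y -> x * y end) |> Enum.sum()

  # h = W x + (alpha / rank) * B (A x)
  # W: d_out x d_in (frozen), A: rank x d_in, B: d_out x rank (trainable)
  def forward(w, a, b, x, alpha, rank) do
    wx = matvec(w, x)
    bax = matvec(b, matvec(a, x))
    scaling = alpha / rank
    Enum.zip_with(wx, bax, fn base, delta -> base + scaling * delta end)
  end
end

w = [[1.0, 0.0], [0.0, 1.0]]
a = [[1.0, 1.0]]
x = [2.0, 3.0]

# With B = 0 the adapted layer reproduces the original output W x
LoraSketch.forward(w, a, [[0.0], [0.0]], x, 2, 1)
# => [2.0, 3.0]

# A nonzero B adds a low-rank correction on top of W x
LoraSketch.forward(w, a, [[1.0], [0.0]], x, 2, 1)
# => [12.0, 3.0]
```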

Axon.rewrite_nodes allows us to manipulate the graph structure of existing models. The rewrite_nodes function traverses the nodes in an Axon model’s graph and applies a user-defined rewrite function to each node. This rewrite function can either modify the node, replace it entirely, or leave it unchanged.

Axon.rewrite_nodes takes an Axon model and then a function which matches on %Axon.Node{}:

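# A small example model with a single relu activation
model =
  Axon.input("features", shape: {nil, 4})
  |> Axon.dense(8)
  |> Axon.relu()

# Rewriters receive the node's inputs (a list) and its original output
tanh_rewriter = fn [%Axon{} = x], _output ->
  Axon.tanh(x)
end

new_model =
  Axon.rewrite_nodes(model, fn
    %Axon.Node{op: :relu} -> tanh_rewriter
    _ -> :skip
  end)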

The function returns either a rewriter or :skip, which leaves the original node alone. Rewriters receive the original list of node inputs, as well as the original node output, as regular %Axon{} structs, which means you can work on them using regular Axon APIs. In effect, the rewriter replaces the node output y with f(x, y), where x is the original input, y is the original output, and f is your function.

The simple example above just rewrites all relu layers as tanh layers; however, we can use this to do even more complex manipulations. For example, rewriters are used to implement quantization:

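# weight_only_quantized_dense/3 is the quantized layer defined alongside this
# rewrite in Axon's quantization module
quantized_dense_rewriter = fn [%Axon{} = x], _output, name_fn, units, use_bias ->
  weight_only_quantized_dense(x, units,
    use_bias: use_bias,
    name: name_fn
  )
end

# Replace every dense node, reusing its name function, units, and bias setting
quantized_model =
  Axon.rewrite_nodes(model, fn
    %Axon.Node{op: :dense, meta: meta, name: name_fn} ->
      &quantized_dense_rewriter.(&1, &2, name_fn, meta[:units], meta[:use_bias])

    _ ->
      :skip
  end)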

The purpose of Axon.rewrite_nodes is to make it simpler to implement graph manipulations without disconnecting the original graph. We can use this tool to implement LoRA.
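As a rough mental model of what a rewriter does, treat one layer as a plain function: rewriting it with f means the layer now computes f(x, layer(x)). The sketch below is hypothetical and uses single-input Elixir functions rather than Axon graphs (real rewriters receive a list of inputs), but it shows both styles used in this post: ignoring the original output, as the relu-to-tanh rewriter does, and building on it, as a LoRA rewriter must.

```elixir
defmodule RewriteSketch do
  # Hypothetical model of a rewrite: the rewriter receives the original
  # input x and the original output y = layer.(x), and its result replaces y.
  def rewrite(layer, rewriter) do
    fn x -> rewriter.(x, layer.(x)) end
  end
end

relu = fn x -> max(x, 0.0) end

# Ignores y entirely, like the relu -> tanh rewriter
as_tanh = RewriteSketch.rewrite(relu, fn x, _y -> :math.tanh(x) end)

# Keeps y and adds a correction computed from x
with_correction = RewriteSketch.rewrite(relu, fn x, y -> y + 0.5 * x end)

with_correction.(2.0)
# => 3.0
with_correction.(-2.0)
# => -1.0
```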

Fine-tuning a LoRA Llama 3

In order to demonstrate how to implement LoRA in Elixir, we’ll fine-tune a LoRA version of Llama 3 8B. To start, install the following dependencies:

Mix.install([
  {:bumblebee, github: "elixir-nx/bumblebee"},
  {:axon, github: "elixir-nx/axon", override: true},
  {:nx, "~> 0.7"},
  {:exla, ">= 0.0.0"},
  {:explorer, "~> 0.9.0"},
  {:table_rex, "~> 3.1.1", override: true}
])

Nx.default_backend(EXLA.Backend)

We’ll train a model to predict function names from decompiled code using this dataset. Each example contains a decompiled function, such as:

void fcn.140030b80(ulong param_1, ulong param_2, ulong param_3) {
  ulong uVar1; uVar1 = fcn.140030ae0(param_3);
  fcn.14002efc0(param_1, param_2, uVar1);
  return;
}

Each function is paired with a descriptive name, in this case process_with_params. We’ll train a LoRA model that outputs the name as a JSON object. Download the dataset’s parquet file, save it as function_names.parquet in your working directory, and load it with Explorer:

df = Explorer.DataFrame.from_parquet!("function_names.parquet")

Next, we’ll preprocess the dataset by wrapping each input in a prompt and encoding each output as a JSON object:

output = fn name ->
  Jason.encode!(%{name: name})
end

prompt = fn code ->
  "<|start_header_id|>system<|end_header_id|>" <>
    "Given the following disassembled code, provide a descriptive" <>
    " function name for the code. Your function name should" <>
    " accurately describe the purpose of the code. It should" <>
    " be formatted in C style with lowercase and snakecase." <>
    " Only output the name as valid JSON, e.g. #{output.("function_name")}" <>
    "<|eot_id|><|start_header_id|>user<|end_header_id|>" <>
    "Code: #{code}" <>
    "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
end

Now we can apply these functions to the input and output series in the dataframe:

df =
  Explorer.DataFrame.transform(df, [names: ["input", "output"]], fn row ->
    %{"input" => code, "output" => name} = row

    # Concatenate the prompt and the JSON-encoded target into one training sequence
    %{"sequence" => prompt.(code) <> output.(name)}
  end)

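Now, let’s load the Llama 3 tokenizer and turn each sequence into an {inputs, labels} pair. For causal language modeling, the labels are simply the input IDs; the loss function takes care of shifting them: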

# Llama 3 is a gated model on Hugging Face, so an access token is required
token = System.get_env("LB_HF_AUTH_TOKEN")
repo = {:hf, "meta-llama/Meta-Llama-3-8B-Instruct", auth_token: token}
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
# Pad or truncate every sequence to exactly 1024 tokens
tokenizer = Bumblebee.configure(tokenizer, length: 1024)

train_dataset =
  df
  |> Explorer.DataFrame.pull("sequence")
  |> Explorer.Series.to_list()
  |> Stream.map(fn sequence ->
    inputs =
      tokenizer
      |> Bumblebee.apply_tokenizer(sequence)
      |> Map.take(["input_ids"])

    # Inputs and labels are the same token IDs
    {inputs, Map.fetch!(inputs, "input_ids")}
  end)

Next, we need to load the original model:

{:ok, %{model: model, params: params}} = Bumblebee.load_model(repo)

And now we need to rewrite the model to use LoRA for specific layers:

lora_config = %{
  rank: 8,
  alpha: 16,
  dropout: 0.2
}

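# Given the original input x and output wx, returns wx + scaling * B(A(dropout(x)))
lora_rewriter = fn [%Axon{} = x], %Axon{} = wx, name, units ->
  lora_a_name = "#{name}.lora_a"
  lora_b_name = "#{name}.lora_b"

  scaling = lora_config[:alpha] / lora_config[:rank]

  x
  |> Axon.dropout(rate: lora_config[:dropout])
  # A projects the input down to the low rank
  |> Axon.dense(lora_config[:rank], name: lora_a_name, use_bias: false)
  # B projects back up; zero-initializing it (as in the LoRA paper) means the
  # adapted model starts out identical to the pretrained one
  |> Axon.dense(units, name: lora_b_name, use_bias: false, kernel_initializer: :zeros)
  |> Axon.multiply(Axon.constant(Nx.tensor(scaling)))
  |> Axon.add(wx)
end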

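lora_model =
  Axon.rewrite_nodes(model, fn
    # Restrict the rewrite to dense layers, which record their units in meta
    %Axon.Node{op: :dense, op_name: op, name: name, meta: meta} ->
      node_name = name.(op, %{})

      # Target the attention query and value projections
      if String.contains?(node_name, ["query", "value"]) do
        &lora_rewriter.(&1, &2, node_name, meta[:units])
      else
        :skip
      end

    _ ->
      :skip
  end)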

This code replaces the query and value projections of the original Llama model with LoRA-adapted versions. Next, we need to initialize a new model state that also contains parameters for the LoRA layers:

{init_fn, _predict_fn} = Axon.build(lora_model)

input_template = %{"input_ids" => Nx.template({1, 1024}, :s64)}
lora_model_state = init_fn.(input_template, params)
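Passing the pretrained params to init_fn is what should keep the original weights: parameters that already exist are reused, and only the new LoRA layers are freshly initialized. The sketch below models that merge rule on flat maps; InitSketch and the layer names are hypothetical, and Axon’s real model state holds nested maps of tensors.

```elixir
defmodule InitSketch do
  # Hypothetical illustration of the merge rule: reuse a layer's existing
  # parameters when present, otherwise initialize them with init_fun.
  def init_missing(layer_names, existing, init_fun) do
    Map.new(layer_names, fn name ->
      {name, Map.get_lazy(existing, name, fn -> init_fun.(name) end)}
    end)
  end
end

pretrained = %{"decoder.blocks.0.self_attention.query" => :pretrained_weights}

layers = [
  "decoder.blocks.0.self_attention.query",
  "decoder.blocks.0.self_attention.query.lora_a",
  "decoder.blocks.0.self_attention.query.lora_b"
]

InitSketch.init_missing(layers, pretrained, fn _name -> :fresh_weights end)
# the query layer keeps :pretrained_weights; both LoRA layers get :fresh_weights
```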

Now we just need to freeze all of the non-LoRA layers in our model state:

lora_model_state =
  Axon.ModelState.freeze(lora_model_state, fn
    [layer | _] -> not String.contains?(layer, "lora")
  end)
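To check which parameters the predicate leaves trainable, you can run it on a few paths by hand. The paths below are hypothetical, chosen to resemble Bumblebee’s layer naming plus the .lora_a and .lora_b suffixes added by lora_rewriter, and this assumes the freeze function receives a path whose first element is the layer name, as the [layer | _] pattern implies.

```elixir
# Same predicate as above: true means "freeze this parameter"
frozen? = fn [layer | _] -> not String.contains?(layer, "lora") end

# Hypothetical parameter paths: [layer_name, parameter_name]
paths = [
  ["decoder.blocks.0.self_attention.query", "kernel"],
  ["decoder.blocks.0.self_attention.query.lora_a", "kernel"],
  ["decoder.blocks.0.self_attention.query.lora_b", "kernel"],
  ["decoder.blocks.0.self_attention.key", "kernel"]
]

trainable = Enum.reject(paths, frozen?)
# only the lora_a and lora_b kernels remain trainable
```

Note that the check is a substring match, so any layer whose name happens to contain "lora" would also stay trainable.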

Finally, we need a causal language-modeling loss, and then we can fine-tune the model on our dataset:

defmodule Trainer do
  import Nx.Defn

  defn causal_loss(labels, logits, opts \\ []) do
    opts = keyword!(opts, [:pad_token_id])
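    # Align predictions with targets: the logits at position t predict the
    # token at t + 1, so drop the first label and the last set of logits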
    labels = labels[[.., 1..-1//1]]
    logits = logits[[.., 0..-2//1, ..]]

    padding_mask = Nx.equal(labels, opts[:pad_token_id])

    Nx.select(
      padding_mask,
      0.0,
      Axon.Losses.categorical_cross_entropy(Nx.new_axis(labels, -1), logits,
        from_logits: true,
        sparse: true
      )
    )
    |> Nx.mean()
  end
end

pad_token_id = Bumblebee.Tokenizer.token_to_id(tokenizer, "<|eot_id|>")
lora_logits = Axon.nx(lora_model, & &1.logits)

loss = &Trainer.causal_loss(&1, &2, pad_token_id: pad_token_id)

optimizer = :sgd

trained_model_state =
  lora_logits
  |> Axon.Loop.trainer(loss, optimizer, log: 1)
  |> Axon.Loop.run(train_dataset, lora_model_state, epochs: 1, compiler: EXLA)
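The slicing in causal_loss is easiest to follow on a toy sequence. The hypothetical plain-list sketch below pairs each prediction position with the token it is scored against and drops padded targets; the token IDs are made up, and 0 stands in for pad_token_id (the <|eot_id|> ID in this post).

```elixir
defmodule CausalShiftSketch do
  # Hypothetical plain-list version of the alignment in causal_loss: the
  # logits at position t are scored against the label at position t + 1.
  def training_pairs(token_ids, pad_id) do
    n = length(token_ids)

    # Like labels[[.., 1..-1//1]]: drop the first token
    targets = Enum.drop(token_ids, 1)

    # Like logits[[.., 0..-2//1, ..]]: only positions 0..n-2 make predictions
    positions = Enum.to_list(0..(n - 2)//1)

    positions
    |> Enum.zip(targets)
    # Like the padding mask: padded targets are excluded here
    # (causal_loss sets their loss to 0.0 instead)
    |> Enum.reject(fn {_position, target} -> target == pad_id end)
  end
end

CausalShiftSketch.training_pairs([11, 12, 13, 0, 0], 0)
# => [{0, 12}, {1, 13}]
```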

And that’s all it takes. With just a few lines of code you can fine-tune a LoRA model using Axon and Bumblebee!

Conclusion

The training and serving story in the Elixir ecosystem is getting progressively better. If you’re working on interesting machine learning problems and you’re tempted to make the leap from Python to Elixir, give it a shot. We’ve come a long way in the last year. Until next time!
