Catching Up: Where are Nx, Axon, and Bumblebee Headed?

Sean Moriarity

Machine Learning Advisor



Introduction

In the 20 months since ChatGPT's release in November 2022, there have been many developments in both inference/serving and training of Large Language Models (LLMs). To highlight just a few:

  1. Companies like Meta, Microsoft, Mistral, and Google have released open-source large language models that rival closed-source counterparts on certain evaluations.

  2. New and old quantization techniques applied to LLMs have pushed the limits of what models can run on consumer hardware.

  3. Efficient training algorithms like LoRA and GaLore have enabled fine-tuning LLMs on consumer hardware.

  4. A large number of efficient attention implementations have popped up. The most popular, FlashAttention, enables significant inference and training speedups.

  5. LLM inference companies and model providers have driven the cost of inference down significantly. Along with this trend, we’ve seen an increasing number of hardware startups looking to challenge Nvidia’s supremacy.

  6. Transformer alternatives such as Mamba have shown the potential to challenge the ubiquity of the traditional Transformer architecture. Even still, most open-source LLM architectures seem to have converged a bit.

  7. Models still seem to be trending bigger. While 7B, 8B, and 9B parameter models keep improving, the best-performing models are still the largest and require multiple GPUs or machines to run.

  8. Projects like llama.cpp have made it easy to get up and running with LLMs quickly. llama.cpp and similar projects have also encouraged a trend toward local-first and portable models.

  9. Specialized function calling models and constrained text generation algorithms have helped bridge the gap between legacy software and LLMs. Projects like instructor have made it easy to turn LLM outputs into Ecto schemas compatible with your application.

This list is by no means comprehensive. It would be impossible to highlight everything that’s happened in this space over the last 20 months. That said, it is starting to feel like the dust has settled a bit (famous last words). It also feels like Nx, Axon, and Bumblebee have some catching up to do. In this post, I’ll highlight some of the key next steps in the works for Nx, Axon, and Bumblebee.

Quantization: Running Large Models on Small Hardware

Quantization is a technique used to reduce the memory requirements of LLMs without significantly compromising their performance. Quantization involves converting the model’s parameters, typically stored as 32-bit or 16-bit floating-point numbers, into lower-precision formats. Usually, model parameters are converted into 8-bit (or less) integers. This process effectively compresses the model, allowing it to run on devices with limited resources. While some minor accuracy loss may occur, quantization can maintain most of the model’s capabilities while dramatically reducing its size.
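To make this concrete, here's a minimal, illustrative sketch of weight-only 8-bit quantization using plain Nx operations. The weight values are made up, and this is not the implementation Axon uses; it just shows the core scale, round, and dequantize steps:

# Toy weights (made-up values), stored as 32-bit floats
weights = Nx.tensor([[0.12, -0.98, 0.45], [0.77, -0.33, 0.05]])

# Compute a per-tensor scale so the largest magnitude maps to 127
scale = weights |> Nx.abs() |> Nx.reduce_max() |> Nx.divide(127)

# Quantize: store the weights as 8-bit integers alongside the scale
quantized = weights |> Nx.divide(scale) |> Nx.round() |> Nx.as_type(:s8)

# Dequantize at inference time to recover an approximation of the originals
dequantized = quantized |> Nx.as_type(:f32) |> Nx.multiply(scale)

Storing the int8 tensor plus a single float scale uses roughly a quarter of the memory of the original 32-bit weights, at the cost of a small rounding error.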

We have recently started work on adding quantization support to Axon, with the ultimate goal of having quantized types natively supported in Nx. The new Axon.Quantization module supports a single method for rewriting an Axon model and its parameters to use weight-only, 8-bit integer quantization. The current implementation only targets the :dense (linear) layers in an existing model. It also uses a somewhat naive inference implementation, which means the reduced memory footprint may trade off some inference speed. The current implementation is also inference-only; we do not yet support any quantization-aware training techniques. Despite these limitations, you can still benefit immediately from Axon's quantization implementation. Here's a simple example that loads and quantizes the Phi-3 mini LLM with Bumblebee (credit to Yurko Hoshko in the EEF ML Slack for trying this first):

get_quantized_phi = fn ->
  # Load the full-precision model and parameters from Hugging Face
  {:ok, %{params: model_state, model: model} = model_info} =
    Bumblebee.load_model({:hf, "microsoft/Phi-3-mini-4k-instruct"})

  IO.inspect(model_state, label: "Unquantized")

  # Rewrite the model and parameters to use weight-only, 8-bit integer quantization
  {quantized_model, quantized_model_state} = Axon.Quantization.quantize(model, model_state)
  IO.inspect(quantized_model_state, label: "Quantized")

  # Return the model info with the quantized model and parameters swapped in
  %{model_info | model: quantized_model, params: quantized_model_state}
end

quantized_model = get_quantized_phi.()

:ok

After running this code, you should see a significant memory reduction between the unquantized and quantized parameters: the unquantized parameters take up 15.28GB, while the quantized parameters take up only 4.42GB, around a 71% reduction. Note that we wrapped the quantization process in a function. Because Elixir is a functional language, all of our operations are immutable, and Axon's quantization returns a brand-new model state struct rather than modifying the old one. Wrapping the process in a function lets the original full-precision parameters go out of scope so they can be garbage collected and freed. We can force that collection with:

:erlang.garbage_collect()

Once you’ve quantized your model, there’s no difference in running it versus running a normal model. For example, I can take the quantized model and parameters and create a normal text generation serving:

{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "microsoft/Phi-3-mini-4k-instruct"})

{:ok, generation_config} =
  Bumblebee.load_generation_config({:hf, "microsoft/Phi-3-mini-4k-instruct"})

generation_config = Bumblebee.configure(generation_config, max_new_tokens: 32)

serving =
  Bumblebee.Text.generation(quantized_model, tokenizer, generation_config,
    compile: [batch_size: 1, sequence_length: 64],
    stream: true,
    defn_options: [compiler: EXLA]
  )

And then you can run the serving like normal:

Nx.Serving.run(serving, "What is your name?")
|> Enum.each(&IO.write/1)

What’s Next?

This initial quantization support is a win; however, there's still a lot of work to be done. First, we have some in-progress work on supporting FP8 natively in Nx, which should eventually enable us to run models in FP8. Next, we plan to add lower-precision integer types such as uint4 and int4 to Nx, which should enable Axon to support lower-precision quantization schemes. We will also continue adding more sophisticated quantization strategies to Axon; the plan is to loosely follow what pytorch/ao supports. Additionally, we will likely eventually support native quantized types in Nx, which would bring the benefits of quantization to any library that uses Nx.

Finally, I would like to support directly loading quantized models. Ideally, we would not have to initialize a model in full precision in order to convert it to quantized form. Along with this work, I plan to explore the possibility of loading GGUF parameters directly into Elixir and Nx.

LoRA and Graph Rewriters

In one of my recent posts, I highlighted some of the improvements we’ve made to Axon’s training capabilities. In that post, I briefly touched on LoRA and the Lorax library. Over the last month or so, I’ve made some additions to Axon in an attempt to make LoRA implementations a bit easier. The main addition was the introduction of graph rewriters.

LoRA, or Low-Rank Adaptation, is an efficient fine-tuning technique for LLMs that significantly reduces the number of trainable parameters while maintaining model performance. LoRA works by adding small, trainable "rank decomposition" matrices to the existing weights of the model. This approach allows for task-specific adaptation without modifying the original pre-trained weights, resulting in dramatically smaller file sizes and reduced memory requirements. At a low level, LoRA works by replacing certain layers in a model with adapter layers.

Previously, performing graph rewrites on Axon models was a challenge. Now, Axon supports a simple rewriter API that allows you to implement graph transformations to rewrite models on the fly. For example, you can write a transformation that replaces :dense layers in an LLM with a LoRA-based adapter:

lora_config = %{
  # Size of the LoRA "down" projection (the adapter rank)
  units_a: 128,
  # Size of the LoRA "up" projection; this must match the original dense
  # layer's output units so the adapter output can be added to `wx`
  units_b: 128,
  scaling: 0.9,
  dropout: 0.1
}

# The rewriter receives the original layer's input `x` and its original output `wx`
lora_dense_rewriter = fn [%Axon{} = x], %Axon{} = wx ->
  x
  |> Axon.dropout(rate: lora_config.dropout)
  |> Axon.dense(lora_config.units_a, use_bias: false, name: "lora_a")
  |> Axon.dense(lora_config.units_b, use_bias: false, name: "lora_b")
  |> Axon.multiply(Axon.constant(lora_config.scaling))
  |> Axon.add(wx)
end

# Rewrite every :dense layer so its output includes the LoRA adapter path
lora_model =
  Axon.rewrite_nodes(model, fn
    %Axon.Node{op: :dense} -> lora_dense_rewriter
    _ -> :skip
  end)

With these new improvements, implementing LoRA with Axon is much simpler than before. However, there’s still a bit of work to do. In the future, you can expect to see more tightly integrated LoRA support in Axon, as well as the ability to load and run pre-trained adapters. Finally, with some additional work on the quantization front, QLoRA should be possible in Axon very soon.

Model Sharding: Running LLMs across Multiple Machines

Model sharding is a technique used to distribute LLMs across multiple computational devices. Some LLMs are too large to fit on a single device and thus need to be split across several. Nx and Axon currently do not have any ability to shard or partition inference across multiple machines; however, this is one of my biggest priorities moving forward. You can expect to see developments on this front soon. You should also expect to see overall improvements in big model inference.

MLIR, IREE, and Inference Portability

One of the big behind-the-scenes developments of the last several months was EXLA’s transition from using the legacy XLA builder APIs to using MLIR.

MLIR, which stands for Multi-Level Intermediate Representation, is a compiler framework and representation format. It acts like a bridge between different programming languages and hardware platforms in the world of machine learning. MLIR allows developers to represent complex machine-learning models in a way that’s both flexible and efficient, making it easier to optimize these models for various types of hardware. Most users will not interact directly with MLIR; however, this change will still benefit them. First, we will get all of the portability benefits of MLIR out of the box. Second, MLIR opens a door to potentially introducing lower-level programs and kernels into Nx (think Triton-lang). Third, it opens the door for us to take advantage of projects like IREE.

In addition to leading the MLIR transition, Paulo Valente has been working on integrating the IREE runtime as a companion library to EXLA. One of the biggest benefits of IREE is that it will enable us to accelerate Nx programs on Mac Metal hardware (and potentially other Apple devices). IREE supports diverse deployment platforms as well. For example, it’s possible we could run Nx programs on WebGPU thanks to IREE.

Constrained Sampling

Constrained sampling in LLMs is a technique that guides the model’s output to follow specific patterns or rules. This approach allows developers to control the format and content of the generated text more precisely. It’s particularly useful when you need the LLM to produce structured data or adhere to certain constraints. While it’s possible to fine-tune models to reliably produce JSON or data in another format, constrained sampling can guarantee a model’s outputs will match a specific grammar.
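To illustrate the core mechanism, here's a minimal sketch of constrained decoding at the logits level with Nx. The logits and allowed-token mask are made-up values, and this is not Bumblebee's API; in practice, the grammar determines which token ids are permitted at each step, and everything else is masked out before the next token is chosen:

# Made-up next-token logits over a tiny vocabulary
logits = Nx.tensor([1.2, -0.3, 0.8, 2.1, 0.05])

# 1 marks token ids the grammar allows at this step, 0 marks everything else
allowed = Nx.tensor([1, 0, 0, 1, 0])

# Push disallowed tokens to negative infinity so they can never be selected
masked_logits = Nx.select(allowed, logits, Nx.Constants.neg_infinity())

# Greedy decoding now only ever picks a token the grammar allows
next_token = Nx.argmax(masked_logits)

The same masking applies when sampling with temperature or top-p: as long as disallowed tokens carry zero probability, the generated text is guaranteed to follow the grammar.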

The most popular application of constrained sampling is generating structured outputs. You can convert JSON schemas into grammars and then use constrained sampling to generate valid outputs that follow the designated schema. To support Bumblebee as an adapter for a library like Instructor, we're working on support for constrained sampling as a text generation technique in Bumblebee. This work has stalled a bit due to other priorities; however, I expect it will pick back up in the next month or so.

Opportunities to Contribute

There are always opportunities to contribute and help us further the Nx ecosystem. One area in particular where we could use some help is in contributing new Bumblebee architectures. I did a small write-up a year ago on the process of adding models. This process is relatively unchanged. The few developers we have working in our ecosystem have limited capacity to add new architectures, so if there’s one you want to see I recommend giving it a shot! It’s a great exercise in understanding how Nx, Axon, and Bumblebee work.

We are also always looking for individuals to contribute guides and tutorials. The more educational content we can produce, the better! If you are looking for some inspiration, I recommend taking a look at one of the Keras Examples and trying to replicate one in Axon. If you’re able to do so, I’d be happy to accept it as an official guide in the Axon documentation.

Conclusion

And that sums up some of the priorities and opportunities in the ecosystem moving into the end of 2024! I don’t think there is a better time to bet on Elixir and Nx than right now. Until next time.
