Fine-Tuning Code Models in Elixir: How to Create a Code Copilot

Sean Moriarity

Machine Learning Advisor


Introduction

A recent blog post from Hugging Face demonstrated how to fine-tune your own personal copilot on your own codebases using Python and the transformers library. In particular, the post walked through fine-tuning StarCoder on Hugging Face’s own repositories.

As stated in the post, code completion models have been a significant innovation in the world of software development. Programmers can gain serious productivity boosts by using code completion models within their codebases. While these code completion models are very powerful, they are limited to the context provided within short snippets of your codebase. They have only a snapshot with which to work. Additionally, it’s not always possible to use these assistants when working in proprietary codebases. With a personally trained model, you can potentially address both of these issues at once.

First, your model theoretically can gain knowledge of your entire codebase during the fine-tuning process (though this doesn’t always work out how you’d expect). Second, because the model is owned by you and your organization, you can utilize it within proprietary codebases.

These fine-tuned code models offer additional potential benefits for Elixir programmers as well. While the best code completion models are trained on massive code datasets consisting of examples from many different programming languages, they are usually not that great at producing Elixir compared to programming languages with larger ecosystems such as Python and Java. This is a consequence of these models being trained on less Elixir code than other languages. By fine-tuning on more Elixir examples, we can potentially achieve performance gains on Elixir-specific problems.

Of course, all of these benefits I’ve listed come with a caveat. There’s no guarantee that our fine-tuned code model will perform better than the base model we’re fine-tuning—even on our specific problems. The reason for this is that fine-tuning may result in a loss of generality that causes the model to perform worse on all problems—even specific ones from your dataset. If you’re not able to produce a large enough dataset of training content, then it may not be worth it to pursue a personal copilot like this one.

Nonetheless, it is a fun and interesting exercise to demonstrate what’s possible in the Elixir machine-learning ecosystem.

Setting Up

To get started, we’ll fire up a Livebook. For this tutorial, I’ll be fine-tuning the smaller deepseek-coder-1.3b-base model. Despite being “smaller,” this model still requires a lot of memory to train. To account for that, I’ll be using Fly GPUs, specifically an 80GB A100 instance running Livebook (shoutout to Chris McCord and Jason Stiebs for helping me get set up). Getting Livebook up and running on the Fly GPU node was surprisingly simple. I followed the official instructions from Fly’s guide on deploying Livebook as well as some GPU-specific instructions provided by Chris located here. Despite this being the first time I had deployed Livebook on Fly, the entire process took me less than 30 minutes.

To verify that we’re using the GPU and that everything is working as expected, we first need to pull in our dependencies:

Mix.install(
  [
    # Data Collection
    {:tentacat, "~> 2.0"},
    # Machine Learning
    {:bumblebee, github: "elixir-nx/bumblebee", branch: "main"},
    {:axon, github: "elixir-nx/axon", branch: "main", override: true},
    {:polaris, "~> 0.1"},
    {:exla, "~> 0.6", override: true},
    {:nx, "~> 0.6", override: true}
  ],
  system_env: %{
    "XLA_TARGET" => "cuda120"
  }
)

If you’ve worked with any of the Elixir machine-learning libraries before, you should be familiar with most of these libraries. We’ll use Bumblebee to pull in a pre-trained code completion model, Axon and Polaris to do our fine-tuning, and EXLA for acceleration. We’ll also use Nx to do some pre-processing of our input data. The additional library, tentacat, is an Elixir client for the GitHub API. For this tutorial, we’ll be training our model on Elixir code from the Elixir Nx GitHub organization. In order to pull that code in, we’ll be using tentacat.

After your libraries are installed, you can check that you have access to a GPU on your machine by running the following:

platforms = EXLA.Client.get_supported_platforms()

unless :cuda in Map.keys(platforms) do
  raise "no gpu detected"
end

If all goes well, nothing will blow up and you should be ready to fine-tune your model!

Dataset Collection

Just for demonstration purposes, we’ll fine-tune our model on the Elixir Nx repositories. This process can easily be extended to a much larger number of repositories and projects. The more data you can fine-tune on, the better.

In order to build our dataset, we’ll first list all of the repositories available in the organization using Tentacat:

org = "elixir-nx"
repos_dir = "repos"

client = Tentacat.Client.new()
{200, repos, _} = Tentacat.Repositories.list_orgs(client, org)

Next, we’ll loop through these repositories and clone each available repository using git:

File.mkdir_p!(repos_dir)

Enum.each(repos, fn %{"full_name" => repo} ->
  [_, project] = String.split(repo, "/")
  path = Path.join([repos_dir, project])

  if not File.exists?(path) do
    System.cmd("git", ["clone", "https://github.com/#{repo}", "#{repos_dir}/#{project}"])
  end
end)

We can verify that we’ve pulled all of the repositories we need by running:

File.ls!("repos")

Now we want to create a dataset where each entry contains a project, a file path, and the file contents. This matches the approach taken in the personal copilot article. For now, we’ll use only *.ex and *.exs files, but you could extend this to include other types of files in your projects:

data =
  Path.wildcard(Path.join([repos_dir, "**", "*.{ex,exs}"]))
  |> Enum.map(fn file ->
    ["repos", project, path] = String.split(file, "/", parts: 3)
    contents = File.read!(file)
    %{project: project, path: path, contents: contents}
  end)

After running this, you’ll have a list of maps with file paths, projects, and full file contents for all the Elixir files in the Elixir Nx organization.
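To get a feel for the structure, you can inspect a single entry. The output shown here is only a hypothetical illustration of the shape of the data; your exact project, path, and contents will differ:

hd(data)
#=> %{project: "nx", path: "lib/nx.ex", contents: "defmodule Nx do\n..."}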

Now, we need to create a training dataset. These code completion models are trained with a fill-in-the-middle (FIM) task. With this task, we specify a prefix, then some code, then a “hole” token to be filled, and then a suffix.

At first, this reordering might not make a ton of sense. Why move the middle of the sequence to the end? The reason is that “fill-in-the-middle” is modeled with autoregressive language models. These models generate sequences from left to right, but we can “trick” them into filling in the middle by training them to predict the middle given the prefix and suffix first.
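As a purely hypothetical illustration (the snippet and split points are made up), a FIM example takes a piece of code and cuts it into a prefix, a middle to be filled in, and a suffix. The permute/2 function we define below performs the equivalent reordering at the token level:

snippet = """
def add(a, b) do
  a + b
end
"""

# prefix: "def add(a, b) do\n  "
# middle: "a + b"   <- the part the model learns to generate last
# suffix: "\nend\n"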

Now, in order to build our dataset, we need to implement a preprocessing function that will convert our raw file contents into a proper FIM dataset. To do that, you can use the following code:

defmodule FIM do
  def permute(sample, opts \\ []) do
    opts =
      Keyword.validate!(opts, [
        :prefix_token_id,
        :middle_token_id,
        :suffix_token_id,
        :pad_token_id,
        fim_rate: 0.5,
        fim_spm_rate: 0.5
      ])

    if :rand.uniform() > opts[:fim_rate] do
      # split the tokenized sample at two random points into prefix, middle, and suffix
      [low, high] = Enum.take_random(1..(Nx.axis_size(sample, 0) - 1), 2) |> Enum.sort()

      prefix = sample[0..(low - 1)//1]
      middle = sample[low..(high - 1)//1]
      suffix = sample[high..-1//1]

      # reorder the spans so the middle comes last, marking them with special tokens
      Nx.concatenate([
        Nx.tensor([opts[:prefix_token_id], opts[:suffix_token_id]]),
        suffix,
        Nx.tensor([opts[:middle_token_id]]),
        prefix,
        middle
      ])
    else
      # leave the sample as a plain next-token prediction example, padding it so
      # it has the same length as FIM examples (which gain three special tokens)
      pad_token = Nx.tensor([opts[:pad_token_id]])
      Nx.concatenate([sample, pad_token, pad_token, pad_token])
    end
  end
end

This module implements a simple permutation algorithm that takes a tokenized sequence and transforms it into a sequence with an FIM task. We only convert around 50% of the original samples to FIM tasks and leave the rest as is. This means that our fine-tuned model will perform both FIM tasks and simple next-token prediction.
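As a quick sanity check, here is a minimal sketch of calling FIM.permute/2 on a tiny tokenized sample. The token IDs here are placeholders; in the real pipeline they come from the deepseek tokenizer we load next:

sample = Nx.tensor([10, 11, 12, 13, 14, 15])

FIM.permute(sample,
  prefix_token_id: 1,
  middle_token_id: 2,
  suffix_token_id: 3,
  pad_token_id: 0
)

Depending on the random draw, you’ll get back either a reordered FIM sequence with the three special tokens inserted, or the original sample followed by three pad tokens.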

Next, we want to apply this pre-processing to our entire dataset. First, we need to grab the deepseek tokenizer from Hugging Face:

{:ok, tokenizer} =
  Bumblebee.load_tokenizer(
    {:hf, "deepseek-ai/deepseek-coder-1.3b-base",
    revision: "e94f2b11bc28abbd67ecadfaad058c30b24a589f"}
  )

Next, we need our “special tokens” to pass to our pre-processing function:

fim_prefix = "<|fim▁begin|>"
fim_middle = "<|fim▁hole|>"
fim_suffix = "<|fim▁end|>"
pad_token = "<|end▁of▁sentence|>"

prefix_token_id = Tokenizers.Tokenizer.token_to_id(tokenizer.tokenizer, fim_prefix)
middle_token_id = Tokenizers.Tokenizer.token_to_id(tokenizer.tokenizer, fim_middle)
suffix_token_id = Tokenizers.Tokenizer.token_to_id(tokenizer.tokenizer, fim_suffix)
pad_token_id = Tokenizers.Tokenizer.token_to_id(tokenizer.tokenizer, pad_token)

Now, we can actually apply our pipeline. One thing to be very particular about is to ensure your resulting dataset is created on the binary backend or the EXLA CPU backend. If you set the EXLA backend as your default backend up front, you will likely silently consume GPU memory by loading the entire pre-processed dataset onto the GPU. We only want to move this data to the GPU when necessary. You can achieve that either by using Stream and letting Axon/Nx lazily transfer data to the GPU, or by explicitly setting Nx to use the binary backend or the CPU.
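For example, a minimal sketch of keeping the default backend on the CPU during preprocessing might look like this (the :host client name assumes EXLA’s default CPU client):

# keep tensors on the CPU by default while preprocessing
Nx.default_backend({EXLA.Backend, client: :host})

# or, to use the pure-Elixir backend instead:
# Nx.default_backend(Nx.BinaryBackend)

With that handled, here is the full preprocessing pipeline: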

max_seq_len = 128
batch_size = 8

train_data =
  data
  |> Stream.flat_map(fn %{contents: contents} ->
    tokenized = Bumblebee.apply_tokenizer(tokenizer, contents)

    case tokenized["input_ids"] do
      # this will discard some examples, but that's okay
      %{shape: {1, seq}} when seq < max_seq_len ->
        []

      tensor ->
        tensor
        |> Nx.transpose()
        |> Nx.to_batched(max_seq_len - 3, leftover: :discard)
        |> Enum.map(&Nx.squeeze/1)
    end
  end)
  |> Stream.map(
    &FIM.permute(&1,
      prefix_token_id: prefix_token_id,
      middle_token_id: middle_token_id,
      suffix_token_id: suffix_token_id,
      pad_token_id: pad_token_id
    )
  )
  |> Stream.chunk_every(batch_size, batch_size, leftover: :discard)
  |> Stream.map(fn input_ids ->
    batch = Nx.stack(input_ids)
    {%{"input_ids" => batch}, batch}
  end)
  |> Stream.take(200)

Notice that each batch in your dataset is a tuple of your input sequence and your input sequence (again!). With autoregressive language models, training occurs on the data itself. Our training objective is to predict the next token correctly, so our targets and our inputs are the same (just shifted!).

With your dataset ready, it’s time to train!

Training

First, we need to download our model:

repo = {:hf, "deepseek-ai/deepseek-coder-1.3b-base"}
{:ok, model_info} = Bumblebee.load_model(repo, backend: {EXLA.Backend, client: :cuda})

Next, we need to implement a simple “causal loss” that shifts the labels relative to the logits for the next-token prediction task. This loss also masks out padding tokens in the input (which we don’t care about):

defmodule Trainer do
  import Nx.Defn

  defn causal_loss(labels, logits, opts \\ []) do
    opts = keyword!(opts, [:pad_token_id])
    # drop the first label and the last logit so the logits at position i
    # are scored against the token at position i + 1
    labels = labels[[.., 1..-1//1]]
    logits = logits[[.., 0..-2//1, ..]]

    # zero out the loss wherever the label is a padding token
    padding_mask = Nx.equal(labels, opts[:pad_token_id])

    padding_mask
    |> Nx.select(
      0.0,
      Axon.Losses.categorical_cross_entropy(Nx.new_axis(labels, -1), logits,
        from_logits: true,
        sparse: true
      )
    )
    |> Nx.mean()
  end
end

Next, we can define our training loop. For this example, we’re going to make use of mixed precision training to save some memory. Mixed precision training performs computation in FP16 or BF16 rather than FP32. We can apply mixed precision in Axon pretty simply with the Axon.MixedPrecision module:

%{model: model, params: params} = model_info

model = Axon.nx(model, & &1.logits)

bf16 = {:bf, 16}
mp_policy = Axon.MixedPrecision.create_policy(params: bf16, compute: bf16, output: bf16)
mp_model = Axon.MixedPrecision.apply_policy(model, mp_policy)

This example takes our original model, extracts the logits output from it, and applies a mixed precision policy to do everything in {:bf, 16}. Now, we can define our training loop and execute it:

loss = &Trainer.causal_loss(&1, &2, pad_token_id: pad_token_id)

optimizer = :sgd

trained_model_state =
  mp_model
  |> Axon.Loop.trainer(loss, optimizer, gradient_accumulation_steps: 1, log: 1)
  |> Axon.Loop.run(train_data, params, epochs: 1, compiler: EXLA)

Once training completes—which may take some time depending on the size of your dataset—you can wrap your new trained model parameters in a generation serving and try it out:

model_info = %{model_info | params: trained_model_state}
{:ok, generation_config} = Bumblebee.load_generation_config(repo)
serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config, compiler: EXLA)

prompt = "<|fim▁begin|>defn softmax(x) do\n <|fim▁hole|> \n <|fim▁end|> "
Nx.Serving.run(serving, prompt)
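The serving returns a map with a :results list. Assuming a single returned sequence, you can pattern match the generated text out of the result, for example:

%{results: [%{text: completion}]} = Nx.Serving.run(serving, prompt)
IO.puts(completion)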

Conclusion

This was a pretty quick and dirty example, and there are a lot of improvements to make. For example, we can train even larger models using LoRA and the Lorax library by Ted Wang. Additionally, we probably want to train a model with a larger sequence length and on more data.

However, the purpose of this tutorial was just to get our feet wet with fine-tuning in Axon and Bumblebee. Finally, there are a lot of efforts happening in the code completion space specifically for use in Elixir. I highly encourage you to join the Elixir ML working group to get involved in these efforts. Until next time!
