Next-Gen Machine Learning with FLAME and Nx: Beyond Serverless Solutions

Sean Moriarity

Machine Learning Advisor

Introduction

Recently, Chris McCord and the Phoenix team released FLAME, a new library and paradigm for addressing elastic workloads that is meant to replace serverless. While the original FLAME library is implemented in Elixir and targets Fly.io infrastructure, the idea extends to other languages and cloud providers. But how does FLAME compare to serverless?

The promise of serverless is scalability, flexibility, and agility, with potentially reduced costs compared to server-centric applications. Serverless really shines when your application has particularly elastic workloads. An elastic workload is one with sudden spikes in demand for resources. In a server-centric world, you would need to keep enough resources around to meet peak demand, even while those resources sit idle. Serverless promises elasticity: your serverless functions can scale up to meet the peak demands of your workloads and scale to zero during dormant times. This also implies cost savings. Serverless, like Liberty Mutual Insurance, has the motto: “Only pay for what you need”.

Unfortunately, the benefits of serverless can quickly become overshadowed by drawbacks. Serverless can very quickly become more expensive than if you were simply managing your own server. Additionally, as your business needs evolve, serverless applications can quickly become extremely complex. Because serverless functions are isolated from other serverless functions, you are often required to introduce additional pieces of infrastructure for state management, communication, and orchestration. In these cases, serverless itself isn’t necessarily the problem. The problem is that serverless makes it really easy to design unnecessarily complex and expensive applications.

FLAME is a different paradigm for addressing elastic workloads. In Chris’ own words from Rethinking Serverless with FLAME:

“With FLAME, you treat your entire application as a lambda, where modular parts can be executed on short-lived infrastructure. No rewrites. No bespoke runtimes. No outrageous layers of complexity. Need to insert the results of an expensive operation to the database? PubSub broadcast the result of some expensive work? No problem! It’s your whole app so of course you can do it.”

FLAME is enabled by the decades-old power of the BEAM, and has the promise to replace traditional serverless workflows. In this post, we’ll focus on the specific applications of FLAME on machine learning workflows. We’ll look at some traditional applications of serverless and machine learning, and how we might be able to use FLAME and Nx as a replacement.

Serverless Inference

The most obvious application of serverless in machine learning is inference. GPU time is expensive, and if you don’t always have demand for running models on GPUs, it doesn’t make sense to pay for time you don’t use. Within the serverless paradigm, you wrap a model in a function that encapsulates an end-to-end machine learning task and deploy it as a serverless function. As demand for this function increases, your serverless function scales up to meet it.

With FLAME and Nx, we can build inference pipelines for elastic workloads that are embedded directly into our applications. There are two types of inference workloads to consider: online and offline. Online inference workloads, also called real-time inference workloads, happen on demand: as inputs come in, you pass them to a model and use the resulting predictions immediately. These are the short-lived (relatively speaking) requests that abstractions such as Nx.Serving are designed to address.

Offline inference workloads, or batch inference workloads, run periodically in the background to compute predictions, which are then retrieved from some cache or store when needed.

A simple example contrasting offline and online workloads is the classic vector-retrieval use case. Suppose you have a set of documents you want to embed and search over using vector search. First, you compute vector embeddings of all of your documents up front with an offline inference workload and store them in a vector database for retrieval later on. Then, as users need to retrieve documents, you compute embeddings of their queries online, in real time, and query the vector database for similar documents. As you update your models or add documents to the corpus, you periodically kick off more embedding jobs to update the vector database.
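
To make the online half concrete, the query path might look something like the following sketch. The module and function names (MyApp.Search, MyApp.VectorDatabase.search/2) are hypothetical, and MyApp.OnlineEmbedding is the serving we’ll set up below:

defmodule MyApp.Search do
  def search(query) do
    # Embed the query in real time with the online serving, assuming the
    # serving returns a map with an :embedding key (as Bumblebee's
    # text-embedding servings do)
    %{embedding: embedding} = Nx.Serving.batched_run(MyApp.OnlineEmbedding, query)

    # Retrieve the nearest documents from the vector store
    MyApp.VectorDatabase.search(embedding, limit: 10)
  end
end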

For this application, we can also contrast using FLAME with using a plain Nx.Serving. First, let’s assume we do not want our online use case to run with FLAME; later on, we’ll discuss the tradeoffs between using FLAME and a plain Nx.Serving on a GPU-enabled machine. For now, only our offline embedding jobs will run in FLAME. We can start our Nx.Serving in the normal way:

  children = [
    {Nx.Serving,
      name: MyApp.OnlineEmbedding,
      serving: MyApp.OnlineEmbedding.serving()}
  ]
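
The serving/0 function is whatever embedding serving your application defines. As a rough sketch, one built with Bumblebee might look like this (the model repository and compile options are illustrative):

defmodule MyApp.OnlineEmbedding do
  def serving do
    repo = {:hf, "sentence-transformers/all-MiniLM-L6-v2"}
    {:ok, model_info} = Bumblebee.load_model(repo)
    {:ok, tokenizer} = Bumblebee.load_tokenizer(repo)

    # Pre-compile for the batch size the Nx.Serving process will use
    Bumblebee.Text.text_embedding(model_info, tokenizer,
      compile: [batch_size: 8, sequence_length: 512],
      defn_options: [compiler: EXLA]
    )
  end
end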

Now, for our offline embedding use case, we don’t necessarily want to launch a job to embed our documents on the same server where we are serving web requests. Instead, we can use FLAME to launch a job on a separate machine. For example, we can start a FLAME pool in our application:

  children = [
    ...
    {Nx.Serving,
      name: MyApp.OnlineEmbedding,
      serving: MyApp.OnlineEmbedding.serving()},
    {FLAME.Pool,
      name: MyApp.OfflineEmbedding,
      min: 0,
      max: 1,
      backend: {FLAME.FlyBackend,
        gpu_kind: "a100-pcie-40gb",
        cpu_kind: "performance",
        cpus: 8,
        memory_mb: 20480}}
  ]

Then we can use FLAME.call/2 to start our embedding jobs:

FLAME.call(MyApp.OfflineEmbedding, &MyApp.OfflineEmbedding.run/0)

Our embedding job can access parts of our application transparently because the job is running in a copy of our application. In another world, we might have a separate Python microservice/script that kicks off these jobs and interacts with our application via a REST API or some other means. With FLAME, we can interact with our vector database through the same Elixir API we use in production. Since our application already starts an Nx.Serving, we can use the Nx.Serving inference APIs as well!

defmodule MyApp.OfflineEmbedding do
  def run() do
    # Pull the batches of documents that still need embeddings
    batches_to_embed = MyApp.Documents.get()

    Enum.each(batches_to_embed, fn batch ->
      # Reuse the same serving the app already runs for online inference
      embedded = Nx.Serving.batched_run({:local, MyApp.OnlineEmbedding}, batch)

      MyApp.VectorDatabase.put(batch, embedded)
    end)
  end
end

The only stipulation when using a serving inside a FLAME call is that we need to specify {:local, ServingName} to force the serving to run in a non-distributed manner. Otherwise, the API is exactly the same! Now our offline embedding jobs will launch on a separate GPU-enabled machine and won’t hog resources we need for online inference. You can imagine how easily this combines with Oban to periodically run embedding jobs after a certain period of time, or when the number of documents waiting to be embedded reaches some threshold; a sketch of such a worker follows the next example. Using FLAME.Parent.get/0, we can also configure our Nx.Serving to start with a different batch size when it runs inside a FLAME node:

  # FLAME.Parent.get/0 returns parent info when running inside a FLAME node,
  # and nil when running in the parent application
  flame_parent = FLAME.Parent.get()
  batch_size = if flame_parent, do: 64, else: 8

  children = [
    ...
    {Nx.Serving,
      name: MyApp.OnlineEmbedding,
      serving: MyApp.OnlineEmbedding.serving(),
      batch_size: batch_size},
    ...
  ]
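
For scheduling, the Oban combination mentioned above could be as small as a worker that wraps the FLAME call. The worker module and queue name here are hypothetical; you would enqueue it from Oban’s Cron plugin, or whenever your document threshold is hit:

defmodule MyApp.Workers.EmbedDocuments do
  use Oban.Worker, queue: :embeddings, max_attempts: 3

  @impl Oban.Worker
  def perform(%Oban.Job{}) do
    # Runs on a separate GPU-enabled machine via the FLAME pool
    FLAME.call(MyApp.OfflineEmbedding, &MyApp.OfflineEmbedding.run/0)
    :ok
  end
end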

FLAME also has applications for online machine learning workloads; however, there are tradeoffs to consider versus a monolithic approach where the model runs on the same machine as your server. If you want your online workloads to run with FLAME, you first need to consider the cold-start time for your application. With Nx.Serving, there is an upfront model compilation cost that your application pays on every startup. That means if you have a lot of cold starts, those requests will take a long time to complete. FLAME does have configuration options to keep the application hot for a period of time; however, it can be difficult to predict how long that interval should be.
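
Those knobs live on the FLAME pool itself: :min keeps some number of runners warm at all times, and :idle_shutdown_after controls how long an idle runner lingers before shutting down. A sketch with illustrative values (the MyApp.OnlineInference pool name is hypothetical):

  {FLAME.Pool,
    name: MyApp.OnlineInference,
    # keep one GPU runner warm at all times to avoid cold starts
    min: 1,
    max: 4,
    # let idle runners linger before scaling back down
    idle_shutdown_after: :timer.minutes(10),
    backend: {FLAME.FlyBackend,
      gpu_kind: "a100-pcie-40gb",
      cpu_kind: "performance",
      cpus: 8,
      memory_mb: 20480}}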

If you do decide to use FLAME for online machine learning workloads, you’ll at least want to make sure your model artifacts are baked into your build container. This won’t allow you to overcome model compilation times; however, it would save model download time. At the moment we’re working on a solution that allows you to serialize compiled executables and load them from disk, which should seriously improve cold start times for Nx applications.

Serverless Training

Another promising use case for FLAME in machine learning is executing training runs. To continuously train and update models, you might imagine additional infrastructure that manages some sort of ETL pipeline to aggregate data and then kicks off a training run on the aggregated data. In the Python ecosystem, this might involve Airflow or some other DAG runner. With FLAME, we can build something that lives directly in our application. For example, I might have a module that defines both the inference and training pipelines for a model:

defmodule MyApp.EmbeddingModel do
  def predict(data) do
    Nx.Serving.batched_run(MyApp.EmbeddingModel, data)
  end

  def train(config, data) do
    model = model()

    trained =
      model
      |> Axon.Loop.trainer(config.loss, config.optimizer)
      |> Axon.Loop.run(data, %{}, epochs: config.epochs, compiler: EXLA)

    # persist the trained parameters for later promotion to production
    save_to_s3(trained)
  end

  defp model do
    ...
  end
end
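
This assumes a dedicated FLAME pool for training runs in our supervision tree, alongside the pools from earlier. A sketch, with illustrative machine sizing:

  children = [
    ...
    {FLAME.Pool,
      name: MyApp.TrainingPool,
      min: 0,
      max: 2,
      backend: {FLAME.FlyBackend,
        gpu_kind: "a100-pcie-40gb",
        cpu_kind: "performance",
        cpus: 8,
        memory_mb: 32768}}
  ]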

And then I can launch my training run with FLAME.call/2:

FLAME.call(MyApp.TrainingPool, fn ->
  config = %{
    optimizer: Axon.Optimizers.sgd(1.0e-3),
    loss: :mean_squared_error,
    epochs: 3
  }

  data = MyApp.Data.get()
  MyApp.EmbeddingModel.train(config, data)
end)

If you have the resources, you can imagine a world where we use FLAME to kick off multiple training runs at once to grid search across a set of hyperparameters:

configs = [config1, config2, config3]
data = MyApp.Data.get()

configs
|> Task.async_stream(fn config ->
  FLAME.call(MyApp.TrainingPool, fn -> MyApp.EmbeddingModel.train(config, data) end)
end, timeout: :infinity) # training runs outlast async_stream's default 5-second timeout
|> Stream.run()

You can even imagine a “Phoenix Live Dashboard-like” environment for managing models and training in your application. Additionally, because FLAME runs a copy of your app, you can broadcast events using PubSub and monitor training from your actual application. For example, I might monitor loss using Axon event handlers and PubSub:

    model
    |> Axon.Loop.trainer(config.loss, config.optimizer)
    |> Axon.Loop.handle_event(:iteration_completed, fn state ->
      # broadcast the running loss, then tell Axon to continue training
      MyApp.PubSub.broadcast({:loss, state.metrics["loss"]})
      {:continue, state}
    end)
    |> Axon.Loop.run(data, %{}, epochs: config.epochs, compiler: EXLA)

And then I can handle those events in my application to render a plot of training loss over time. We can also extend this to evaluation pipelines, and use hot-code reloading or some other means to dynamically promote newly trained models to production.
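
For example, a LiveView could subscribe to the same topic and accumulate the loss values for a chart. Here is a minimal sketch; the module name, the "training" topic, and the assumption that MyApp.PubSub.broadcast/1 publishes over Phoenix.PubSub are all hypothetical:

defmodule MyAppWeb.TrainingLive do
  use MyAppWeb, :live_view

  def mount(_params, _session, socket) do
    # Listen for the {:loss, loss} events broadcast by the training loop
    if connected?(socket), do: Phoenix.PubSub.subscribe(MyApp.PubSub, "training")
    {:ok, assign(socket, losses: [])}
  end

  def handle_info({:loss, loss}, socket) do
    # Accumulate loss values; a chart hook or component can render them
    {:noreply, update(socket, :losses, &[loss | &1])}
  end

  def render(assigns) do
    ~H"""
    <p>Latest loss: <%= inspect(List.first(@losses)) %></p>
    """
  end
end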

Conclusion

The potential applications of FLAME and Nx for machine learning workloads are exciting. FLAME is still a new development, but I believe there is potential to provide tooling for machine learning practitioners that changes the way people build machine learning applications. We are just getting started! Until next time!
