Load Testing a Machine Learning Model API

Deploying a Machine Learning (ML) model as a live service to be consumed by a business-critical system or directly by end-users can be a scary prospect. This post looks at how you can perform load testing on your model APIs to ensure they can stand up to even the highest-demand situations.


Deploying with confidence

Deploying a Machine Learning (ML) model as a live service to be consumed by a business-critical system or directly by end-users can be a scary prospect. The more in-demand the system, the higher the risk that issues with your model may disrupt it. However, you also can't keep your models stuck in a notebook either. They need to be out there generating value for your customers and your business.

There are a number of ways to allay these anxieties. One approach is (on paper) straightforward: have a robust testing strategy. An effective testing strategy can help you establish a higher degree of confidence in your ML system. Indeed, much like with any other software system, you should regard a well-thought-out testing strategy as an essential prerequisite for deploying any ML system.

However, the testing strategies and techniques you employ will vary depending on the context in which your system is being developed and deployed. One technique that is particularly important for software systems deployed as some form of web service (e.g. a REST API) is load testing. And, if you hadn't guessed from the title of this page, that's the subject of this post: load testing ML model APIs.

The aim here is to get you familiar with the motivations for and concepts behind load testing, to review some important design considerations that load testing can be used to explore, and to provide a short example of how you could go about load testing a Scikit-Learn model deployed as a basic web app (complete with a template repository).

What is Load Testing?

Let's start from a general software engineering viewpoint. Most modern websites have at least a handful of dependencies on external web services. These services can range from notification managers, to analytics tools, to core infrastructure like authentication and payment services.

When changes are made to these services, it's important to ensure that the system is not disrupted. This is often particularly important for high-traffic platforms where service downtime or disruption can lead to material financial loss. When you deploy a new service (or update an existing one), you want to be confident your service can withstand the demands of 'the real world'.

This is where load testing can come in handy. Load testing is a form of non-functional testing: it's aimed at testing system behaviours, and not the specifics of the functionality of the system. In other words, the aim of load testing is to explore how your service responds — as the name suggests — under load. Specifically, it can be used to characterise your system at varying loads and under different operating conditions — including different types of users and user behaviours. Depending on the framework or tool you use, you can typically exert fine-grained control over these conditions and behaviours, and tailor your own testing to reflect the dynamics of interactions specific to your services.

The aim here is to demonstrate that under the conditions you specify, your service can sustain desired performance characteristics. But what performance characteristics, you ask? In general, load testing focusses on two key metrics: latency and throughput. Let's define these:

  • Latency - How long it takes for the service to respond. It's commonly referred to as response time, too. It's typically given in milliseconds (ms). A service with very low latency can be described as being highly responsive.
  • Throughput - The total number of requests a service can handle in a given time frame. For example, requests-per-second (RPS) is a commonly used metric for quantifying throughput characteristics.
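To make these two definitions concrete, here's a small, standard-library-only sketch that summarises a batch of response times the way a load testing dashboard typically does. The function name summarise and the sample timings are purely illustrative.

```python
import statistics
from typing import Dict, List


def summarise(latencies_ms: List[float], duration_s: float) -> Dict[str, float]:
    """Summarise one load test window.

    latencies_ms: response time of each completed request, in milliseconds.
    duration_s: wall-clock length of the window, in seconds.
    """
    return {
        "requests": len(latencies_ms),
        "rps": len(latencies_ms) / duration_s,  # throughput
        "avg_ms": statistics.mean(latencies_ms),  # latency statistics
        "median_ms": statistics.median(latencies_ms),
        "min_ms": min(latencies_ms),
        "max_ms": max(latencies_ms),
    }


stats = summarise([20, 25, 30, 45, 180], duration_s=2.0)
```

Here five requests in two seconds give 2.5 RPS, and the single slow request (180 ms) pulls the average (60 ms) well above the median (30 ms). That's one reason it's worth looking at both rather than the average alone.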

Furthermore, as with any engineering problem, characterising a system's limitations is also important. Aside from being useful for establishing that a service can support expected demand and performance constraints, load testing can (and should) be used to understand the upper limits of your service. This can be helpful in establishing under what conditions your service begins to degrade (and eventually, fail altogether). With this sort of information in hand, you can more confidently deploy your service 'into production', and also understand how it is likely to perform as your service scales — both in the present and over time as demand for your service (hopefully) grows.

How does this relate to ML?

So far, so good. Let's map this back to a concrete ML problem. Say you're a Data Scientist at a business that runs an online marketplace. You've been tasked with creating a pricing model to help users of the marketplace price the products they're selling more competitively. You've built your model, performed various evaluation tasks, and now you're ready to deploy. After thinking through your options, you arrive at the decision to deploy your model as a simple HTTP web service for your engineering teams (and internal services!) to use in other features. Conveniently for you, this is exactly the situation this post is written to address.

As with any similar service, one of the first questions you need to ask yourself is: 'is this service responsive enough to ensure good user experience?'. This can be broken down in multiple ways, but some natural questions flowing from this include:

  • Does your service have sufficient 'physical' resources (CPU, RAM etc.) to support your expected (and 'surge') usage?
  • Are there any bottlenecks in your system design? Think carefully about other services you depend on, particularly with respect to any rate limiting or concurrency limits that may kick in under load.
  • Are there any behaviours or request-types that are less responsive? Think about how your model performs inference, and whether it's worth adding additional constraints to limit the types of requests users can make in the interests of guaranteeing performance.

At this point, it's also worth recognising that it's important not just to consider your 'happy path' (i.e. expected 'normal' behaviour). For example, you should make sure to ask the above questions in the context of errors: what happens to your system if you add a bunch of invalid queries into the mix? Does your service suddenly hang and block further requests when it receives an invalid query? Or do you quickly and gracefully handle errors?

Ensuring your load testing strategy can answer questions such as this is important. It can pay off to think adversarially about how you can go about testing the design and implementation of your system (including ML aspects) — not just in a load testing context either!

Load testing in practice

Hopefully you're now sold on the utility of load testing. Perhaps you're thinking through how this relates to your own work. Perhaps you already were thinking this way about your ML projects, in which case: great job!

As you might be aware, there is a wide variety of tools out there for software testing, and the sub-field of load testing is no different. However, the user-friendliness of these tools is... variable. One of the more user-friendly tools out there (particularly for most Data Scientists) is Locust. Locust is a load testing tool and framework written in and controlled using Python.

Practically, this means you can define your tests (including defining simulated customer behaviour) in pure Python. You write your user behaviours in Python, specify the number of users you'd like to simulate, and you're off! Plus, the chances are that as a Data Scientist, you're at least familiar with Python, even if you're not a 'power user'. My own experience suggests that Locust has one of the lowest barriers to entry for this class of tools in practice, with the associated adoption and compliance benefits this can bring to Data Science/ML teams.

It's not just easy to use, either. It can scale to massive applications too. For example, it is reportedly used by the folks over at DICE (of the Battlefield video game series fame) to load test their servers, as well as by various teams at Google and Microsoft.

Beyond the ease of use afforded by its Pythonic design, Locust also provides several useful additional features. For one, it allows you to export the statistics of your test in CSV format for further analysis or archival purposes. This can be a handy low-tech way of tracking load testing experiments.

For another, it has a simple but informative user interface (UI). Out of the box, the UI allows you to configure the number of users you'd like to simulate, the rate at which you'd like to add users to the simulation, and the host you'd like to test, so you can fire up your test suite in no time at all.

When running, the UI gives you a few tabs that provide you with a range of information, from tabulated statistics, 'real-time' charts of latency, throughput and user count, to detailed information on failures and errors the testing uncovers. For the minimal case of a service with a single route (/), the top-level tabulated view of statistics looks like this:

A minimal example of the Locust dashboard for an API with a single endpoint.

As you can see, you get information on the number of requests per second (RPS) the service is currently handling (i.e. a top-level throughput statistic). You also get statistics on the average, median, minimum and maximum latency (all quoted in milliseconds). Considered across each of the routes under test, this information can give you a surprisingly comprehensive view of the behaviour of your system under load, which is exactly what you're after. There are also some nice 'real-time' charts showing this information:

An example of the real-time monitoring dashboard provided by Locust.

And that's a 30,000ft view of Locust. Time to use it with a concrete example!

Getting set up

Time for the code. The rest of this post shows you how to configure a basic load testing setup for a minimal ML API on your local machine, demonstrating some of the ideas outlined above. The example project I've put together also includes some utilities for deploying to your favourite cloud provider, and the general process of load testing an API is the same for a deployed service. Actually deploying the setup is beyond the scope of this article, though the process is very similar to the one laid out in my earlier post, over here:

Flask in Production: Minimal Web APIs
Flask is a popular ‘micro-framework’ for building web APIs in Python. However, getting a Flask API ‘into production’ can be a little tricky for newcomers. This post provides a minimal template project for a Flask API, and gives some tips on how to build out basic production Flask APIs.

First off, the basics. I've decided to go with Flask for this tutorial, as it has (at least in my experience) a great community, is feature-rich and easy to use. Plus, it's used in a huge variety of projects out there, so having a good grasp of Flask may be more generally useful than other web frameworks in Python.

That said, if you're looking for a Flask alternative, you would do well to take a look at FastAPI, a newer 'micro web framework' that fills a similar niche to Flask, but has been designed with a focus on speed. By extension, if the responsiveness of an API is paramount, that may be the better place to start. Otherwise, Flask is likely a strong bet!

FastAPI
FastAPI framework, high performance, easy to learn, fast to code, ready for production

For this walkthrough, you'll need to familiarise yourself with the example project repository:

markdouthwaite/sklearn-flask-api-demo
This repository provides example code for deploying a Scikit-Learn model as a containerised Flask app with Gunicorn. It includes example pytest & locust testing files.

To use this repository, you should fork the repository (you could also create your own version from the template) and clone it into your working environment. You'll also need to ensure you are running Python 3.7 or greater to run the code. If you do create your project from the template, make sure to check the Include all branches box before doing so!

All set? Awesome. Time to install some dependencies. To do this, navigate into your newly downloaded project and run the following command:

make develop

This will install both the requirements for your model API and development dependencies including test packages and styling tools too (such as the excellent black and isort packages).

What is in the box?

The obvious next question is: what have you just downloaded? Thanks for asking. As previously mentioned, the structure of this project is largely the same as that defined in the vanilla Flask project it is derived from. I'd advise you check out the blog post running through that, if you haven't already:

Flask in Production: Minimal Web APIs

There are a couple of key distinctions between the 'vanilla' Flask project above and the scikit-learn flavoured one we're looking at here, primarily the addition of the following directories:

  • data - For this example, this is where any data (including intermediate artefacts like models and model metrics) will be stored. By default, it stores a data set (heart-disease.csv, from Kaggle) and an example payload.json providing an illustrative request payload for the API you'll be deploying. In practice, you should consider storing all data and artefacts using a dedicated cloud storage service (like Google Cloud Storage or AWS S3).
  • steps - This is where any (automated) steps required to train, evaluate and monitor models should sit. In practice, this directory may therefore contain several Python scripts, each training, evaluating and validating your models and outputs. In this example, however, it simply contains a single train.py script that trains a simple scikit-learn classifier pipeline.

You may also notice that the Flask application code is slightly different too, and is quite similar to a related tutorial I've put together previously:

markdouthwaite/serverless-scikit-learn-demo
A repository providing demo code for deploying a lightweight Scikit-Learn based ML pipeline modelling heart disease data as a Google Cloud Function.

Next, you'll need to create a pickled model artefact that your API can load and expose. To do this, simply run:

python steps/train.py --path=data/heart-disease.csv 

This will train a Scikit-Learn Pipeline on the UCI Heart Disease dataset, pickle it and export it to the data directory (along with some important metadata). Feel free to take a deeper dive into the steps/train.py file for a better look at what's going on here.

Starting the model server

Before continuing, it's worth checking your model server is properly configured and runnable. To run the model server (to expose your API) locally, you can simply run:

make start

In the top-level directory of the project. This will create a gunicorn server running your Flask app. You should then be able to send the query:

curl http://localhost:8080/health

And see the response:

OK

With status code 200. To query the model directly, you can use the data in the data/payload.json example and send:

curl --location --request POST 'http://localhost:8080/predict' --header 'Content-Type: application/json' -d @data/payload.json

You should see a response similar to:

{"diagnosis":"heart-disease"}

This means your model is live. You can leave the model server running for now as you'll need it active to run Locust against.

Now to write a load test or two...

Simulating user behaviour

As you saw earlier, the main objective of load testing isn't to evaluate the correctness (or in an ML context: accuracy/other metric of your choice) of the API per se. It's to explore how your API behaves under different loads, to understand if, when, how and why an API begins to degrade under said load, and ultimately to deploy your API 'into production' armed with this information.

Practically, the first step is to define how you expect users to interact with your service. For an API with a single endpoint (e.g. one that queries a single model), there may be only a single type of interaction, and that's the case in this example. However, if you develop a more complete API, say a full-blown REST API with many endpoints, the number of possible user interactions may be much larger.

Defining user behaviours programmatically is very simple in the world of Locust. The framework provides a User class which you can use to build out your behaviours with. Furthermore, as this is built on pure Python it ensures you can define users with very complex behaviours, if you so wish. Here's an example of a minimal Locust User:

from locust import HttpUser, task

class SimpleUser(HttpUser):
    @task
    def get_health(self):
        # The base URL (host) is supplied at runtime, so only the path is needed.
        self.client.get("/health")

What do we have here? In this example, you have a single user that inherits from the provided HttpUser class and defines a single task behaviour, in this case get_health. Here's a breakdown of what these things mean:

  • HttpUser - The HttpUser class is a special subclass of Locust User that provides access to the client attribute. This client attribute in turn provides an HTTP client you can use to query your API (provided it exposes an HTTP interface, of course!). Notice that nowhere in this file is there a reference to the base URL for your API. This is provided to Locust either at runtime via the Locust CLI or from within the Locust web UI. We'll take a look at that later.
  • get_health - This method defines a user behaviour you'd like to simulate for your API: in this case, simply a GET request to the /health route. In other words, it checks that the GET /health endpoint responds with a non-error status code. The client attribute exposes a fully-featured HTTP client, so you can send arbitrary HTTP requests to your chosen service. You can make multiple calls with the client in each method, meaning you can construct arbitrarily complicated user flows if you need to.
  • task - This decorator marks the method it's applied to as a task. While a simulated User is running, Locust repeatedly picks one of its tasks and executes it. You can mark as many methods with this decorator as you like, so a single User can define multiple tasks, which lets you build out still more complex user behaviours.

If you look in the tests/load/locustfiles/api.py file, you'll see that this example is similar to what's provided there. However, you might notice a couple of distinctions between the minimal case presented here and the one in that file. For one thing, there's a wait_time attribute set to between(1, 5) on this class. This ensures that Locust will wait a random duration between one and five seconds before a User takes a follow-up action (or in Locust terms: executes another task). Additionally, given it's a class variable, this applies to every DefaultUser instance Locust creates.
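If you're curious what between(1, 5) actually hands to Locust, the following is a rough, standard-library stand-in: a callable that returns a fresh, uniformly random wait in seconds each time it's invoked. It's my own sketch of the behaviour, not Locust's source, and the name between_sketch is hypothetical.

```python
import random
from typing import Callable


def between_sketch(min_wait: float, max_wait: float) -> Callable[[object], float]:
    """Return a wait_time-style callable producing waits in [min_wait, max_wait)."""

    def wait_time(user: object) -> float:
        # A new random wait is drawn every time the user finishes a task.
        return min_wait + random.random() * (max_wait - min_wait)

    return wait_time


wait = between_sketch(1, 5)
samples = [wait(None) for _ in range(1000)]
```

Because the wait is re-drawn for each task, simulated users drift out of lockstep, which tends to give a more realistic request pattern than firing at fixed intervals.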

A brief warning

Before continuing to a slightly more advanced case, I'd like to raise an important conceptual point that (in my experience) is particularly relevant to how you should set up your Locust tests for ML APIs. Locust creates a separate instance of your User class for every user it simulates. In other words, if you have a single User defined and your load test simulates 100 concurrent users, Locust will create 100 instances of this class.

This has implications for how you create your custom User classes. For example, when defining user behaviours for testing an ML model API, it might be tempting to load historical data as sample payloads to test your API with. This may in turn involve loading CSVs as pandas DataFrames for each user, potentially resulting in a dataset being initialised for each of those hundred users. Clearly, you'll want to think carefully about whether this is something you feel is appropriate for your problem. There are ways of effectively testing your APIs without loading large amounts of data into memory all at once and (potentially) crashing your machine/s!
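If you do need historical samples, one way to keep memory in check is to load them once per Locust process and have every simulated user hold a reference to the same object. The sketch below does this with the standard library's csv module. SharedSamples and HypotheticalUser are illustrative names (the latter stands in for an HttpUser subclass). Note that csv.DictReader yields strings, so in practice you'd convert values to the types your API expects.

```python
import csv
import random
from typing import Dict, List


class SharedSamples:
    """Reads each CSV file once and shares the rows between all users."""

    _cache: Dict[str, List[Dict[str, str]]] = {}
    loads = 0  # how many times a file has actually been read from disk

    @classmethod
    def rows(cls, path: str) -> List[Dict[str, str]]:
        # Only the first caller for a given path pays the cost of reading it.
        if path not in cls._cache:
            with open(path, newline="") as f:
                cls._cache[path] = list(csv.DictReader(f))
            cls.loads += 1
        return cls._cache[path]


class HypotheticalUser:
    """Stand-in for an HttpUser subclass: each instance stores a reference only."""

    def __init__(self, path: str) -> None:
        self.samples = SharedSamples.rows(path)

    def next_payload(self) -> Dict[str, str]:
        # Pick a random historical row to use as a request body.
        return random.choice(self.samples)
```

However many users are spawned, the file is read once and every instance points at the same list.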

Let's look at one option now.

Querying your model

Okay. You know the basics of what makes up a Locust User. Let's dig a bit deeper. The simplest case is to send a fixed payload to query your model API. If you recall, there's a sample payload in the data/payload.json file. Here's what that might look like if we create a custom user based on that data structure:

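from locust import HttpUser, task


class FixedPayloadUser(HttpUser):
    # The same request body is sent on every call (mirrors data/payload.json).
    payload = {
        "sex": 0,
        "cp": 1,
        "restecg": 0,
        # ... remaining fields from data/payload.json
    }

    @task
    def get_diagnosis_with_valid_payload(self):
        self.client.post("/predict", json=self.payload)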

This User would then be sending a fixed payload (defined by the class variable payload) to the POST /predict endpoint. Simple enough, right? This is of course better than nothing, and in many cases may well be adequate. However, the utility of a test such as this will clearly be defeated by any caching layers you may have in front of your API, and the payload chosen may not quite stretch your API as much as you'd like.
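To see why, the standalone sketch below puts a toy cache (functools.lru_cache) in front of a placeholder "model" and counts how many requests actually reach it. Everything here is hypothetical and standard-library only. Real caching layers (a reverse proxy or an application cache, say) differ in detail, but the effect is the same: a fixed payload is computed once and then served from the cache.

```python
import functools
import json
import random
from typing import Iterable


@functools.lru_cache(maxsize=1024)
def cached_predict(payload_key: str) -> str:
    # Placeholder for real inference; this body only runs on a cache miss.
    return "placeholder-diagnosis"


def model_calls(payloads: Iterable[dict]) -> int:
    """Return how many payloads actually reached the 'model' (cache misses)."""
    cached_predict.cache_clear()  # also resets the hit/miss statistics
    for payload in payloads:
        # Sorted keys so that equal payloads map to the same cache key.
        cached_predict(json.dumps(payload, sort_keys=True))
    return cached_predict.cache_info().misses


fixed = [{"sex": 0, "cp": 1, "restecg": 0}] * 100
rng = random.Random(42)
randomised = [
    {"sex": rng.randint(0, 1), "cp": rng.randint(0, 3), "restecg": rng.randint(0, 1)}
    for _ in range(100)
]
```

A hundred fixed payloads reach the model exactly once. With only three low-cardinality fields there are just 16 distinct combinations, so even these randomised payloads mostly hit the cache; the more fields you randomise, the fewer repeats you'd expect.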

That's where randomised payloads can sometimes come in handy. Again, we're not particularly interested in notions of 'absolute model accuracy', but rather the service's throughput and resilience, so it doesn't typically matter too much if the randomised payloads are not 'physically plausible' (i.e. realistically-valued). Additionally, if your API requires queries to be within certain ranges, the same approach can be used with small modifications to ensure your server-side validation works as you expect. Here's an example of how you might go about setting up randomised queries for your API:

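import random
from dataclasses import asdict, dataclass, field
from typing import Iterator

random.seed(42)


@dataclass
class Payload:
    # default_factory draws fresh values every time Payload() is called. A plain
    # default like `sex: int = random.randint(0, 1)` is evaluated only once, when
    # the class is defined, so every payload would be identical.
    sex: int = field(default_factory=lambda: random.randint(0, 1))
    cp: int = field(default_factory=lambda: random.randint(0, 3))
    restecg: int = field(default_factory=lambda: random.randint(0, 1))
    ...  # other attributes here.


def iter_random_payloads() -> Iterator[dict]:
    # Infinite generator: each next() builds one new randomised payload dict.
    while True:
        yield asdict(Payload())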

We can then define our user behaviour with:

from locust import HttpUser, between, task


class RandomizedPayloadUser(HttpUser):

    # wait between requests from one user for between 1 and 5 seconds.
    wait_time = between(1, 5)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.payloads = iter_random_payloads()

    @task
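    def get_diagnosis_with_valid_payload(self):
        # iter_random_payloads() already yields plain dicts, so the payload can be
        # passed straight to json= (calling asdict() on a dict raises TypeError).
        payload = next(self.payloads)
        self.client.post("/predict", json=payload)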

This approach has a couple of benefits. As you may see, the iter_random_payloads function is a Python generator that will produce infinitely many randomised payloads (one dict per Payload it builds). I'm not going to dig into the advantages and disadvantages of generators in this post, but suffice it to say that generators can help you to create memory-efficient iterators. In this case specifically, each instance of a RandomizedPayloadUser only holds a single payload in memory at a time.

As alluded to earlier, this is an important consideration when you have hundreds (or even many thousands) of these users initialised as it can eliminate the need for initialising large numbers of potentially very large payloads at once (e.g. as you would if you used a common Python list rather than a generator). This can be particularly helpful in applications where you're working with images or similar large numerical arrays, for example.
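As a rough, standalone illustration of the difference (it isn't part of the locustfile), the following compares an eagerly built list of payloads with a generator. It uses simplified three-field payloads for brevity. Note that sys.getsizeof only measures the container itself: the list additionally keeps all 10,000 dicts alive, so the real gap is larger than the printed numbers suggest.

```python
import random
import sys
from itertools import islice
from typing import Iterator


def random_payload() -> dict:
    # Simplified payload with only a few of the model's fields.
    return {
        "sex": random.randint(0, 1),
        "cp": random.randint(0, 3),
        "restecg": random.randint(0, 1),
    }


def iter_payloads() -> Iterator[dict]:
    # Lazily produces one payload per next() call, forever.
    while True:
        yield random_payload()


eager = [random_payload() for _ in range(10_000)]  # every payload held at once
lazy = iter_payloads()  # nothing has been generated yet

print(sys.getsizeof(eager), sys.getsizeof(lazy))
first_five = list(islice(lazy, 5))  # take a bounded slice of an infinite stream
```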

At this point, I recommend you copy the above snippets (for the randomised case) into a new file on the path tests/load/locustfiles/randomized.py. When you've done this, you can boot up Locust with:

locust -f tests/load/locustfiles/randomized.py 

Open a browser to http://localhost:8089 (or follow the link in the terminal). You'll be prompted to provide a total number of users to simulate (i.e. the maximum number of concurrent users you'd like to run), the spawn rate (i.e. how quickly you'd like to ramp up the addition of new users from launch), and finally the host for your API. As this example is running locally, you can set this to http://localhost:8080 (assuming you used the make start command from earlier). Hit Start swarming and you'll be redirected to the Locust dashboard. After a few moments, you should see something that looks a little like this:

An example of a Locust dashboard for a live test against a locally-running model API.

Getting adversarial

Great! You've got a simple test running against your model API! For the rest of this post, you'll want to make sure you've committed any changes you've made so far and then check out the locust-tests branch of the repository. All done? Let's continue.

Part of the objective of load testing is to characterise the behaviour of your service/s under expected load. If you're of an engineering mindset, you'll recognise that this includes what happens when some percentage of traffic is sending erroneous requests, or otherwise behaving in ways you didn't expect. Practically, it can be very helpful to think adversarially about how users may interact with your service and test accordingly.


With this in mind, you could define a random payload generator in a manner similar to that shown previously, but this time you could modify the resulting payload to omit required fields, for example. This might look a bit like this:

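from typing import Iterator


def iter_random_invalid_payloads() -> Iterator[dict]:
    # Reuses random, asdict and Payload from the snippet above.
    fields = list(asdict(Payload()).keys())
    while True:
        missing = random.choice(fields)
        payload = asdict(Payload())
        # drop one required field at random to simulate a malformed request
        del payload[missing]
        yield payload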

Clearly, you'd then expect malformed payloads to produce some type of error in your API. In this case specifically, you'd likely expect this to return a 4xx error of some kind (i.e. a client error). For the API in this project (on the locust-tests branch, at least), errors of this type are configured to return status code 400 (bad request). Locust lets you inspect the responses from your API to ensure they're handled as you expect. To incorporate this behaviour into the previously defined RandomizedPayloadUser user, you could have:

class RandomizedPayloadUser(HttpUser):

    # wait between requests from one user for between 1 and 5 seconds.
    wait_time = between(1, 5)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.valid_payloads = iter_random_payloads()
        self.invalid_payloads = iter_random_invalid_payloads()

    @task(4)
    def get_diagnosis_with_valid_payload(self):
        self.client.post("/predict", json=next(self.valid_payloads))

    @task(1)
    def get_diagnosis_with_invalid_payload(self):
        payload = next(self.invalid_payloads)
        with self.client.post(
            "/predict", json=payload, catch_response=True
        ) as response:
            if response.status_code != 400:
                response.failure(
                    f"Wrong response. Expected status code 400, "
                    f"got {response.status_code}"
                )
            else:
                response.success()

Here, we've added the get_diagnosis_with_invalid_payload method as a task to the User. You might also notice that both task decorators now take an argument: a weight. This tells Locust how often to pick each method relative to the other, so in this case it'll make roughly four times as many calls to get_diagnosis_with_valid_payload as it does to get_diagnosis_with_invalid_payload. This allows you to 'tune' your User behaviours to moderate the number of erroneous calls made to your API.
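To get a feel for what @task(4) and @task(1) mean in practice, here's a simplified, standalone simulation of weighted task selection using random.choices (it isn't part of the locustfile). This is my approximation of the behaviour, not Locust's scheduler: over many picks, roughly 80% of calls should go to the valid-payload task and 20% to the invalid one.

```python
import random
from collections import Counter
from typing import Dict

# Weights matching the @task(4) / @task(1) decorators above.
TASK_WEIGHTS = {
    "get_diagnosis_with_valid_payload": 4,
    "get_diagnosis_with_invalid_payload": 1,
}


def simulate_task_picks(weights: Dict[str, int], n: int, seed: int = 0) -> Counter:
    """Pick n tasks at random, in proportion to their weights."""
    rng = random.Random(seed)
    names = list(weights)
    picks = rng.choices(names, weights=[weights[name] for name in names], k=n)
    return Counter(picks)


counts = simulate_task_picks(TASK_WEIGHTS, n=10_000)
share_valid = counts["get_diagnosis_with_valid_payload"] / 10_000
```

In a short real test the split will be noisier, since each simulated user only gets through a handful of tasks.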

You'll also see that the new method inspects the response to the call, and ensures that invalid payloads return the correct response code. This can be useful in ensuring your API continues to respond to errors as you expect it to. For example, failure to consistently respond with expected status codes can be a sign your API isn't handling errors as gracefully as you expect, and could be a sign of serious degradation of your service. This is all important information to monitor.

I've added the above code examples (plus some small tweaks to error handling on the app-side) to the locust-tests branch. Feel free to check out that branch and have a play!
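The app-side changes aren't reproduced in this post, so the following is only a hypothetical sketch of the kind of check that turns a malformed payload into a 400 rather than an unhandled exception (and a 500). REQUIRED_FIELDS and validate_payload are illustrative names, and the field list is deliberately incomplete.

```python
from typing import Any, Dict, Tuple

# Hypothetical subset of required fields; the real model expects every
# field present in data/payload.json.
REQUIRED_FIELDS = ("sex", "cp", "restecg")


def validate_payload(payload: Any) -> Tuple[int, Dict[str, str]]:
    """Return (status_code, body) for a /predict request payload.

    Client mistakes (a non-object body or missing fields) map to
    400 Bad Request instead of failing later inside the model.
    """
    if not isinstance(payload, dict):
        return 400, {"error": "payload must be a JSON object"}
    missing = [name for name in REQUIRED_FIELDS if name not in payload]
    if missing:
        return 400, {"error": "missing required fields: " + ", ".join(missing)}
    return 200, {}
```

In a Flask view you might call it with request.get_json(silent=True) (which returns None for an unparseable body) and, whenever the status isn't 200, return jsonify(body), status before touching the model.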

With that, you've got a basic example of using Locust for load testing a model API. As you may be able to see, the opportunity for extending Locust tests is extensive, and you'll certainly need to think carefully about how to use Locust in your own projects. If you're interested in discovering more about how you can configure and extend Locust, I'd recommend checking out the Locust docs:

Writing a locustfile — Locust 1.4.3 documentation

Tracking results

Before wrapping up, I'd like to draw your attention to one more important aspect of load testing: how to store and use results. With load tests, it can be useful to track results over time. For example, this might help you detect how your service's limits (say, the maximum number of concurrent users it can support before degrading) vary over time. Locust lets you write test outputs in CSV format to a path of your choice using the following:

locust -f tests/load/locustfiles/api.py --csv data/example 

In this case the results of the load test will be written into the data directory. You'll see at least four files appear:

  • example_stats.csv - In this file, you'll find top-level aggregated statistics for each endpoint your User calls. This file is useful for getting information about the responsiveness of your service overall.
  • example_stats_history.csv - This is a full log of every timestamp sampled by Locust. In other words, it's a time series of the statistics of your test aggregated over all Users over the duration of the test.
  • example_failures.csv - This file contains a summary of data associated with any failures that occurred during the test and the number of times they occurred. Failures are responses that returned error codes unless you intercept error codes and mark them as success (for use cases similar to that outlined above).
  • example_exceptions.csv - This file contains summary information of uncaught exceptions your test ran into. This can be useful for debugging your testing setup.
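Once you have these files, a few lines of standard-library Python are enough to pull out headline numbers. The sketch below computes a failure rate per row of the stats file. The column names are assumptions based on the Locust 1.x CSV header, so check the first line of your own example_stats.csv; that's why they're exposed as parameters. Bear in mind that responses you mark with response.success() (like the invalid-payload checks above) won't be counted as failures.

```python
import csv
from typing import Dict


def failure_rates(
    path: str,
    name_col: str = "Name",
    requests_col: str = "Request Count",
    failures_col: str = "Failure Count",
) -> Dict[str, float]:
    """Map each row name in a Locust *_stats.csv file to its failure rate."""
    rates: Dict[str, float] = {}
    with open(path, newline="") as f:
        for row in csv.DictReader(f):
            requests = int(row[requests_col])
            failures = int(row[failures_col])
            # Guard against rows that received no requests at all.
            rates[row[name_col]] = failures / requests if requests else 0.0
    return rates
```

For the run above, you'd call failure_rates("data/example_stats.csv").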

It is worth considering using some form of unique identifier when naming your test outputs to ensure it's clear when the results were generated and under what conditions. I'll leave this up to you to decide what's appropriate for your use-case, however.
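For example, a small helper like the hypothetical results_prefix below builds a --csv prefix from a UTC timestamp plus the number of users and spawn rate you're testing with. Which details you encode is up to you; this is just one possible scheme.

```python
from datetime import datetime, timezone
from typing import Optional


def results_prefix(
    directory: str, users: int, spawn_rate: float, now: Optional[datetime] = None
) -> str:
    """Build a unique, sortable prefix for Locust's --csv option."""
    if now is None:
        now = datetime.now(timezone.utc)
    stamp = now.strftime("%Y%m%dT%H%M%SZ")  # e.g. 20210301T120000Z
    return f"{directory}/loadtest_{stamp}_u{users}_r{spawn_rate:g}"
```

You'd then pass the result to Locust in place of data/example, e.g. locust -f tests/load/locustfiles/api.py --csv followed by the generated prefix. Because the timestamp comes first, the files also sort chronologically.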

Finally, while simply recording results is a good start, it can be useful to ensure that the latest results are displayed prominently in a project so team members or other users/developers know what capabilities to expect when deploying your API or making changes. One option might be to write a small script to update your README.md (or other documentation) to reflect the latest results of your testing. Another more advanced option might be to export results to be tracked in a database, and to write some diagnostic notebooks to analyse performance over time. The choice is yours, but whatever the case, be aware that test results buried in your own head or in some dark and dusty repository somewhere are of very limited use to others and that is a vital consideration in production software. Engineering is a team sport, after all.
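As a sketch of the first option, the functions below render a Markdown results table and swap it into a README between two marker comments. The marker strings, the row keys and the table layout are all assumptions for illustration; you'd feed it numbers taken from your own stats CSV.

```python
import re
from typing import Dict, List, Union

# Hypothetical markers you'd place in README.md around the results section.
START = "<!-- load-test-results:start -->"
END = "<!-- load-test-results:end -->"

Row = Dict[str, Union[str, int, float]]


def render_table(rows: List[Row]) -> str:
    """Render load test results as a Markdown table."""
    lines = [
        "| Endpoint | Requests | Failures | Median (ms) | RPS |",
        "|---|---|---|---|---|",
    ]
    for r in rows:
        lines.append(
            f"| {r['name']} | {r['requests']} | {r['failures']} "
            f"| {r['median_ms']} | {r['rps']} |"
        )
    return "\n".join(lines)


def update_readme(text: str, table: str) -> str:
    """Replace everything between START and END with the new table."""
    pattern = re.compile(re.escape(START) + r".*?" + re.escape(END), re.DOTALL)
    if not pattern.search(text):
        raise ValueError("results markers not found in README")
    replacement = f"{START}\n{table}\n{END}"
    # A function replacement stops re from interpreting backslashes in the table.
    return pattern.sub(lambda _match: replacement, text, count=1)
```

Everything outside the markers is left untouched, so the rest of the README can be edited by hand as usual.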

Until next time!

That's it for this post. I hope it helps you get set up with Locust, and ultimately helps you build better, more reliable software. As always, I love feedback, so if I've missed anything or you have any further questions or issues, let me know!