The Best Way to Plot a Precision-Recall Curve in Python

In this tutorial, I’ll show you how to plot a precision-recall curve in Python using Seaborn.

Specifically, I’ll show you how to use the new Seaborn Objects package to plot it.

The article is divided into several different sections to make it easier to navigate. So if you need something specific, just click on any of the links below.

Table of Contents:

  • A Quick Review of Precision-Recall Curves
  • Plot a Precision-Recall Curve in Python using Seaborn Objects

I’ll show you the actual example further down in the tutorial, but we’ll start with a quick review of what a precision-recall curve is.

Let’s jump in.

A Quick Review of Precision-Recall Curves

In machine learning, one of the most important tasks is classification.

In classification, we categorize incoming data examples into one of several pre-defined groups, which are called “classes” or “labels.”

There are fairly simple classification systems, like binary classification, where there are only 2 classes. And there are more complicated forms of classification, like multi-class classification, where there are more than 2 classes.

But ultimately, classification is about predicting which class an example belongs to, from a pre-defined list of possible classes.

Classification Evaluation is Important When You Build a Classifier

Building a classifier can be somewhat challenging.

There are multiple steps, like getting your data, cleaning the data, analyzing the data, preprocessing the data, and choosing the right classification algorithm (like logistic regression, decision trees, etc.).

But once you have your classification system, you need to evaluate how it is performing.

Evaluation is critical.

You need to know how well your classifier is performing, and you need to be able to optimize it given the specific needs of your project.

That’s where the precision-recall curve comes in.

A Precision-Recall Curve is a Tool to Evaluate Classifiers

Precision-Recall curves are an important tool for classifier evaluation.

We typically use them to evaluate binary classifiers, although it’s possible to use precision-recall curves with multiclass classification too.
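
For the multiclass case, the usual approach is one-vs-rest: binarize the labels, then compute a separate precision-recall curve for each class. Here’s a minimal sketch of that idea (the 3-class dataset and model here are made up purely for illustration):

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import precision_recall_curve
from sklearn.preprocessing import label_binarize

# hypothetical 3-class dataset, just to illustrate the one-vs-rest idea
X, y = make_classification(n_samples=500, n_classes=3, n_informative=4, random_state=1)
probabilities = LogisticRegression(max_iter=1000).fit(X, y).predict_proba(X)

# binarize the labels, then compute one precision-recall curve per class
y_binarized = label_binarize(y, classes=[0, 1, 2])
for k in range(3):
    precisions, recalls, thresholds = precision_recall_curve(y_binarized[:, k], probabilities[:, k])
    print(f"class {k}: curve has {len(thresholds)} thresholds")

For the rest of this tutorial though, we’ll stick to the binary case.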

To explain them simply: precision-recall curves visualize the precision and recall for a classifier across a range of classification thresholds.

[Image: a precision-recall curve in Python, annotated to explain that the curve plots precision vs. recall across a range of classification thresholds.]

It’s important to emphasize the part about the “range of thresholds.”

You need to remember that many classification algorithms, and in particular, probabilistic classification algorithms, produce what’s called a “probability score.” The probability score measures the confidence that the classifier has that an incoming example belongs to the positive class.

In classification systems that produce such a score, the system compares the score to a “threshold” … a cutoff point for the score.

If the probability score is above the threshold, then the incoming example is predicted as positive.

If the probability score is below the threshold, then the incoming example is predicted as negative.

So the threshold acts like a cutoff point for the system to decide when an incoming example should be predicted to be positive or negative.

And importantly, the threshold is something we can change. We can increase the threshold (which causes the classifier to become more conservative in making a positive prediction). Or we can lower the threshold (which makes the classifier more liberal in making a positive prediction).
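
To make that concrete, here’s a tiny sketch (with made-up probability scores, purely for illustration) of how moving the threshold changes the predictions:

import numpy as np

# hypothetical probability scores for five incoming examples
probability_scores = np.array([0.15, 0.40, 0.55, 0.80, 0.95])

# a high threshold makes the classifier conservative (fewer positive predictions)
print((probability_scores >= 0.9).astype(int))   # [0 0 0 0 1]

# a low threshold makes the classifier liberal (more positive predictions)
print((probability_scores >= 0.3).astype(int))   # [0 1 1 1 1]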

Moving the threshold has implications for the number of True Positives, True Negatives, False Positives, and False Negatives.

And in turn, moving the threshold can change both precision and recall.

A Quick Review of Precision and Recall

Precision and recall are themselves classification metrics. We use them to evaluate classifier performance.

Precision measures the proportion of positive predictions that are actually positive.

Recall measures the proportion of actual positive examples that are correctly predicted as positive.
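
Here’s a quick sketch of both metrics in action, using scikit-learn’s precision_score and recall_score on some made-up labels:

import numpy as np
from sklearn.metrics import precision_score, recall_score

# hypothetical true labels and predictions, purely for illustration
y_true = np.array([1, 0, 1, 1, 0, 1])
y_pred = np.array([1, 0, 1, 0, 1, 1])

# precision: 3 of the 4 positive predictions are actually positive
print(precision_score(y_true, y_pred))   # 0.75

# recall: 3 of the 4 actual positives were predicted positive
print(recall_score(y_true, y_pred))      # 0.75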

Both of these metrics are useful for classification evaluation, but they tend to have strengths and weaknesses.

Moreover, there’s a tradeoff between them.

Generally, as precision increases, recall decreases, and vice versa.

So when we build a classifier, we often need to select the right value for the classification threshold.

But changing the threshold changes precision and recall.

So we need to make sure to choose the “best” threshold for a given task that optimizes the tradeoff between precision and recall.

That’s where the precision-recall curve comes in.

A Precision-Recall Curve Visualizes the Tradeoff Between Precision and Recall

To put it very simply then:

A precision recall curve visualizes the tradeoff between precision and recall for a classifier, as we change the threshold.

As we increase the threshold, precision generally increases and recall decreases.

Conversely, as we decrease the threshold, precision generally decreases and recall increases.

[Image: how precision and recall change as we move the classification threshold.]

Again: there’s a tradeoff between precision and recall.

Furthermore, this tradeoff is important because if we optimize for precision, we’re likely to get more False Negatives, and when we optimize for recall, we’re likely to get more False Positives.

So, we often need to select the right threshold depending on the needs of the task, and also how many False Negatives or False Positives we’re willing to tolerate.
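
To see the tradeoff in miniature, here’s a small sketch (again with made-up labels and scores) that evaluates the same set of probability scores at a low threshold and at a high threshold:

import numpy as np
from sklearn.metrics import precision_score, recall_score

# hypothetical true labels and probability scores
y_true = np.array([0, 0, 1, 0, 1, 1, 0, 1])
scores = np.array([0.10, 0.30, 0.35, 0.60, 0.65, 0.70, 0.75, 0.90])

for threshold in [0.3, 0.7]:
    predictions = (scores >= threshold).astype(int)
    print(f"threshold={threshold}:",
          f"precision={precision_score(y_true, predictions):.2f},",
          f"recall={recall_score(y_true, predictions):.2f}")

# threshold=0.3: precision=0.57, recall=1.00
# threshold=0.7: precision=0.67, recall=0.50

Raising the threshold buys precision at the cost of recall, exactly as described above.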

How do we choose the right threshold?

One good way to choose is by plotting the precision-recall curve.

The precision recall curve enables you to see the tradeoff between precision and recall, and therefore helps you decide exactly which threshold will be best, given the goals of your project.

With all that said, let me show you how to create a precision-recall curve in Python.

Plot a Precision-Recall Curve in Python using Seaborn Objects

In the rest of this tutorial, I’ll show you how to plot a precision-recall curve in Python.

Specifically, I’ll show you how to plot a precision-recall curve using the excellent Seaborn Objects data visualization package (which I love).

To do that though, we need to perform several steps.

We will need to:

  • import the necessary packages
  • create the data for the precision-recall curve
  • plot the curve

We’ll do each of these one at a time.

(Note: if you already have some precision/recall data that you want to plot, then feel free to skip to the section where we plot the precision-recall curve with Seaborn.)

Import Packages

First, we need to import the necessary packages.

Since we’re going to plot the curve with the Seaborn Objects package, we need to import Seaborn Objects.

But remember: we’ll also need to generate some data to plot. That means that we’ll need to:

  • create a classification dataset
  • split the data
  • build a classification model (we’re going to build a logistic regression model)
  • generate the precision/recall data for the curve itself

All that being the case, we will need to import make_classification, train_test_split, LogisticRegression, and precision_recall_curve from Scikit Learn to execute all of those steps.

# seaborn objects
import seaborn.objects as so

# scikit learn tools to create model and data
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_curve

Make Data for the Precision-Recall Curve

Now that we have the right packages and tools, we’ll make the data for the precision recall curve.

Remember that when we plot a precision-recall curve, we’re essentially plotting the precision and recall over a range of classification thresholds for a particular model.

To get this data, we need to:

  • generate a classification dataset with make_classification
  • split the data into training and test sets with train_test_split
  • build and fit a logistic regression classification model
  • predict class probabilities for the test set
  • compute precision and recall across thresholds with precision_recall_curve

Since I want to keep this brief, I’m going to move quickly through most of these data creation steps. We have related tutorials that explain how each of these tools/functions works, if you want more details.

# GENERATE CLASSIFICATION DATASET WITH 2 CLASSES
X, y = make_classification(n_samples=1000, n_classes=2, random_state=1)


# SPLIT DATA INTO TRAIN/TEST
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=2)


# BUILD & FIT CLASSIFICATION MODEL
my_logistic_reg = LogisticRegression()
my_logistic_reg.fit(X_train, y_train)


# PREDICT PROBABILITIES FOR TEST SET
probabilities_logistic_reg = my_logistic_reg.predict_proba(X_test)


# KEEP PROBABILITIES FOR ONLY THE POSITIVE OUTCOME
probabilities_logistic_posclass = probabilities_logistic_reg[:, 1]


# CALCULATE PRECISION AND RECALL FOR EACH THRESHOLD
precisions, recalls, thresholds = precision_recall_curve(y_test, probabilities_logistic_posclass)

After you run all of the data generation code above to get the precisions and recalls for the logistic regression model, you’ll be ready to plot the precision recall curve.
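
One detail worth knowing: precision_recall_curve returns precisions and recalls arrays with one more element than thresholds (scikit-learn appends a final point at precision 1 and recall 0 that has no corresponding threshold). That matters if you want to use this data to pick a threshold. For example, here’s a sketch (one common heuristic, not the only option) that picks the threshold maximizing the F1 score, which balances precision and recall; it runs on the variables we just created:

import numpy as np

# drop the final (precision=1, recall=0) point so the arrays align with thresholds
f1_scores = (2 * precisions[:-1] * recalls[:-1]
             / (precisions[:-1] + recalls[:-1] + 1e-12))   # epsilon avoids 0/0

best_index = np.argmax(f1_scores)
print('best threshold by F1:', thresholds[best_index])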

Plot the Precision-Recall Curve with Seaborn Objects

In this final step, we’re going to plot the precision-recall curve.

To do this, we’ll use the Seaborn Objects visualization package.

Let’s run the code, and then I’ll explain.

(so.Plot()
   .add(so.Path(color = 'red', marker = '.')
        ,x = recalls
        ,y = precisions
        )
   .layout(size = (8,5))
)

OUT:

[Image: the precision-recall curve produced by the Seaborn Objects code above.]

Explanation

Here, we used Seaborn Objects to create our precision-recall curve.

But syntactically, what did we do?

We used the so.Plot() function to initiate the plot.

We used the .add() method to “add” a layer to the plot. Specifically, we added a “path” layer. This is like a line plot, but it keeps the data in the original order (as opposed to so.Line() which re-sorts the data).

As you can see, we also specified that the recalls go on the x-axis and the precisions on the y-axis. We also changed the color to ‘red’ and added a dot marker over the line with marker = '.'.

It’s fairly straightforward.
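
One optional refinement, assuming you’re using a Seaborn version with the objects interface (0.12 or later): you can label the axes with the .label() method, like this:

# same plot, with labeled axes
(so.Plot()
   .add(so.Path(color = 'red', marker = '.')
        ,x = recalls
        ,y = precisions
        )
   .label(x = 'Recall', y = 'Precision')
   .layout(size = (8,5))
)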

Closing Note: Use Seaborn for ML Visualization

One final note before I finish the tutorial.

I recommend that you use Seaborn for your Python data visualization.

It’s great for reporting and analytics, but it’s also great for machine learning visualizations.

But beyond the traditional Seaborn package, I strongly recommend that you start using the new Seaborn Objects visualization system.

Why?

Seaborn Objects is simple, powerful, and easy to use. If you’re familiar with R, it’s like ggplot2 for Python.

I’ve started to use Seaborn Objects for the vast majority of my Python visualizations.

And for machine learning, as you’ve seen here, you can use it for ML visualizations like precision recall curves, and also for ROC curves (which I show in this tutorial about making ROC curves with Seaborn Objects).

Leave your other questions in the comments below

Do you still have questions about how to make a precision-recall curve in Python?

Leave your questions in the comments section below.

For more machine learning tutorials, sign up for our email list

In this tutorial, I showed you how to make a precision-recall curve in Python with Seaborn.

But if you want to master machine learning in Python, there’s a lot more to learn.

If you want free tutorials about machine learning in Python, then sign up for our email list.

When you sign up, you’ll get free tutorials on:

  • Scikit learn
  • Machine learning
  • Deep learning
  • NumPy
  • Pandas
  • … and more.

So if you want to master machine learning, deep learning, and AI in Python, then sign up now.

