If you’ve been using R for a while, and you’ve been working with basic data visualization and data exploration techniques, the next logical step is to start learning some machine learning.
To help you begin learning about machine learning in R, I’m going to introduce you to an R package: the caret package. We’ll build a very simple machine learning model as a way to learn some of caret’s basic syntax and functionality.
But before diving into caret, let’s quickly discuss what machine learning is and why we use it.
What is machine learning?
Machine learning is the study of data-driven, computational methods for making inferences and predictions.
Without going into extreme depth here, let’s unpack that by looking at an example.
A simple example
Imagine that you want to understand the relationship between car weight and car fuel efficiency (i.e., miles per gallon): how is fuel efficiency affected by a car’s weight?
To answer this question, you could obtain a dataset with several different car models and attempt to identify a relationship between weight (which we’ll call $x$) and miles per gallon (which we’ll call $y$).
A good starting point would be to simply plot the data, so first, we’ll create a scatterplot using R’s ggplot:
```r
require(ggplot2)

ggplot(data = mtcars, aes(x = wt, y = mpg)) +
  geom_point()
```
Just examining the data visually, it’s pretty clear that there’s some relationship. But if we want to make more precise claims about the relationship between $x$ and $y$, we need to state it mathematically: we assume that $y$ is some function of $x$, that is, $y = f(x)$.
Assuming this type of mathematical relationship, machine learning provides a set of methods for identifying that relationship. Said differently, machine learning provides a set of computational methods that accept data observations as inputs and estimate that mathematical function, $f(x)$. Machine learning methods learn the relationship by being trained on an input dataset.
Ultimately, once we have this mathematical function (a model), we can use that model to make predictions and inferences.
How much math you really need to know
What I just wrote in the last few paragraphs about “estimating functions” and “mathematical relationships” might cause you to ask a question: “how much math do I need to know to do machine learning?”
Ok, here is some good news: to implement basic machine learning techniques, you don’t need to know much math.
To be clear, there is quite a bit of math involved in machine learning, but most of that math is taken care of for you. What I mean is that, for the most part, R libraries and functions perform the mathematical calculations for you. You just need to know which functions to use, and when to use them.
Here’s an analogy: if you were a carpenter, you wouldn’t need to build your own power tools. You wouldn’t need to build your own drill and power saw. Therefore, you wouldn’t need to understand the mathematics, physics, and electrical engineering principles that would be required to construct those tools from scratch. You could just go and buy them “off the shelf.” To be clear, you’d still need to learn how to use those tools, but you wouldn’t need a deep understanding of math and electrical engineering to operate them.
When you’re first getting started with machine learning, the situation is very similar: you can learn to use some of the tools, without knowing the deep mathematics that makes those tools work.
Having said that, the above analogy is somewhat imperfect. At some point, as you progress to more advanced topics, it will be very beneficial to know the underlying mathematics. My argument, however, is that in the beginning, you can still be productive without a deep understanding of calculus, linear algebra, etc. In any case, I’ll be writing more about “how much math you need” in another blog post.
Ok, so you don’t need to know that much math to get started, but you’re not entirely off the hook. As I noted above, you still need to know how to use the tools properly.
In some sense, this is one of the challenges of using machine learning tools in R: many of them can be difficult to use.
R has many packages for implementing various machine learning methods, but unfortunately many of these tools were designed separately, and they are not always consistent in how they work. The syntax for some of the machine learning tools is very awkward, and syntax from one tool to the next is not always the same. If you don’t know where to start, machine learning in R can become very confusing.
This is why I recommend using the caret package to do machine learning in R.
A quick introduction to caret
For starters, let’s discuss what caret is.
The caret package is a set of tools for building machine learning models in R. The name “caret” stands for Classification And REgression Training. As the name implies, the caret package gives you a toolkit for building classification models and regression models. Moreover, caret provides you with essential tools for:
– Data preparation, including: imputation, centering/scaling data, removing correlated predictors, reducing skewness
– Data splitting
– Model evaluation
– Variable selection
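As a rough illustration of the data splitting and data preparation items above, here’s a minimal sketch using caret’s `createDataPartition()` and `preProcess()`. The 80/20 split, the seed, and the choice of `wt` and `hp` as the columns to center and scale are arbitrary assumptions made for the example.

```r
library(caret)

set.seed(42)  # arbitrary seed so the random split is reproducible

# Data splitting: put roughly 80% of rows in a training set (stratified on mpg)
train_idx <- createDataPartition(mtcars$mpg, p = 0.8, list = FALSE)
train_df  <- mtcars[train_idx, ]
test_df   <- mtcars[-train_idx, ]

# Data preparation: estimate centering/scaling parameters on the training rows only
pp <- preProcess(train_df[, c("wt", "hp")], method = c("center", "scale"))

# Apply the same transformation to both sets
train_scaled <- predict(pp, train_df[, c("wt", "hp")])
test_scaled  <- predict(pp, test_df[, c("wt", "hp")])

# Training columns now have mean ~0 and standard deviation ~1
print(round(colMeans(train_scaled), 10))
print(apply(train_scaled, 2, sd))
```

The scaling parameters are learned from the training rows and then reused on the test rows. This avoids leaking information from the test set into model building.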
Caret simplifies machine learning in R
While caret has broad functionality, the real reason to use caret is that it’s simple and easy to use.
As noted above, one of the major problems with machine learning in R is that most of R’s different machine learning tools have different interfaces. They almost all “work” a little differently from one another: the syntax is slightly different from one modeling tool to the next; tools for different parts of the machine learning workflow don’t always “work well” together; tools for fine tuning models or performing critical functions may be awkward or difficult to work with. Said succinctly, R has many machine learning tools, but they can be extremely clumsy to work with.
Caret solves this problem. To simplify the process, caret provides tools for almost every part of the model building process, and moreover, provides a common interface to these different machine learning methods.
For example, caret provides a simple, common interface to hundreds of the machine learning algorithms available in R. When using caret, different learning methods like linear regression, neural networks, and support vector machines all share a common syntax (the syntax is basically identical, except for a few minor changes).
Moreover, additional parts of the machine learning workflow – like cross validation and parameter tuning – are built directly into this common interface.
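To make the cross-validation point concrete, here’s a sketch of 5-fold cross validation for a simple `mpg ~ wt` regression. The fold count and seed are assumptions chosen for illustration. `trainControl()` and passing its result to `train()` through `trControl` are standard caret.

```r
library(caret)

set.seed(42)  # arbitrary seed so the folds are reproducible

# Configure the resampling scheme once: 5-fold cross validation
ctrl <- trainControl(method = "cv", number = 5)

# Pass it to train() through the trControl argument
model_cv <- train(mpg ~ wt,
                  data      = mtcars,
                  method    = "lm",
                  trControl = ctrl)

# Resampled performance (RMSE, Rsquared, MAE) averaged over the 5 folds
print(model_cv$results)

# Per-fold results
print(model_cv$resample)
```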
To say that more simply, caret provides you with an easy-to-use toolkit for building many different model types and executing critical parts of the ML workflow. This simple interface enables rapid, iterative modeling. In turn, this iterative workflow will allow you to develop good models faster, with less effort, and with less frustration.
Now that you’ve been introduced to caret, let’s return to the example above (of mpg vs. wt).
Again, imagine you want to learn the relationship between `mpg` and `wt`.
Here in this example, we’re going to make an additional assumption that will simplify the process somewhat: we’re going to assume that the relationship is linear, i.e., that it can be described by a straight line of the form $y = b_0 + b_1 x$.
In terms of our modeling effort, this means that we’ll be using linear regression to build our machine learning model.
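As a brief aside, for a single predictor the least-squares estimates of $b_0$ and $b_1$ have a well-known closed form. The sketch below computes them by hand using the textbook ordinary least squares formulas, which this post does not otherwise derive. It then checks them against base R’s `lm()`, the function that caret’s `method = "lm"` fits under the hood.

```r
x <- mtcars$wt   # input variable (weight, 1000s of lbs)
y <- mtcars$mpg  # target variable (miles per gallon)

# Ordinary least squares with one predictor:
#   b1 = sum((x - mean(x)) * (y - mean(y))) / sum((x - mean(x))^2)
#   b0 = mean(y) - b1 * mean(x)
b1 <- sum((x - mean(x)) * (y - mean(y))) / sum((x - mean(x))^2)
b0 <- mean(y) - b1 * mean(x)

# Compare with R's built-in linear regression
fit <- lm(mpg ~ wt, data = mtcars)
print(c(manual_b0 = b0, manual_b1 = b1))
print(coef(fit))
```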
Without going into the details of linear regression (I’ll save that for another blog post), let’s look at how we implement linear regression with caret.
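The core of caret's functionality is the `train()` function. It's the function that "trains" the model, learning the relationship between `mpg` and `wt` from the data.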
Let’s take a look at this syntactically.
Here is the syntax for a linear regression model, regressing `mpg` on `wt`:
```r
#~~~~~~~~~~~~~~~~~~~~~~~~~~
# Build model using train()
#~~~~~~~~~~~~~~~~~~~~~~~~~~
require(caret)

model.mtcars_lm <- train(mpg ~ wt
                         ,data = mtcars
                         ,method = "lm"
                         )
```
That's it. The syntax for building a linear regression is extremely simple with caret.
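Once a model is trained, you would typically use it to make predictions. Here is a minimal sketch: `predict()` accepts the object returned by `train()`. The two car weights below are hypothetical inputs, given in thousands of pounds (the units of `wt`).

```r
library(caret)

# Same model as above: mpg regressed on wt via caret's train()
model.mtcars_lm <- train(mpg ~ wt, data = mtcars, method = "lm")

# Hypothetical new cars; newdata must contain a wt column
new_cars <- data.frame(wt = c(2.5, 3.5))

# predict() works directly on the train object
preds <- predict(model.mtcars_lm, newdata = new_cars)
print(preds)
```

Because the final model is an `lm` fit on the full dataset, these predictions should match what `predict(lm(mpg ~ wt, data = mtcars), new_cars)` returns.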
Now that we have a simple model, let's quickly extract the regression coefficients and plot the model (i.e., plot the linear function that describes the relationship between `mpg` and `wt`):
```r
#~~~~~~~~~~~~~~~~~~~~~~~~~~
# Retrieve coefficients for
# - slope
# - intercept
#~~~~~~~~~~~~~~~~~~~~~~~~~~
coef.icept <- coef(model.mtcars_lm$finalModel)[1]  # intercept (b0)
coef.slope <- coef(model.mtcars_lm$finalModel)[2]  # slope on wt (b1)

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Plot scatterplot and regression line
# using ggplot()
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
ggplot(data = mtcars, aes(x = wt, y = mpg)) +
  geom_point() +
  geom_abline(slope = coef.slope, intercept = coef.icept, color = "red")
```
Now, let's look more closely at the syntax and how it works.
When training a model using `train()`, you only need to tell it a few things:
- The dataset you're working with
- The target variable you're trying to predict (e.g., the `mpg` variable)
- The input variable (e.g., the `wt` variable)
- The machine learning method you want to use (in this case "linear regression")
In caret's syntax, you identify the target variable and input variables using the "formula notation."
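The basic syntax for formula notation is `y ~ x`, where `y` is your target variable and `x` is your input (predictor) variable.
Now, with this knowledge about caret's formula syntax, let's reexamine the above code. Because we want to predict `mpg` using `wt`, we use the formula `mpg ~ wt`.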
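The same formula notation extends to more than one input variable. The sketch below adds `hp` (horsepower), a column in `mtcars` that this post doesn't otherwise use, purely for illustration. Separate predictors with `+`, or use `.` to mean "every other column in the data."

```r
library(caret)

# One predictor: mpg as a function of wt
m1 <- train(mpg ~ wt, data = mtcars, method = "lm")

# Several predictors: add terms with +
m2 <- train(mpg ~ wt + hp, data = mtcars, method = "lm")

# All remaining columns as predictors: the dot shorthand
m3 <- train(mpg ~ ., data = mtcars, method = "lm")

# Coefficients of the final lm fit for the two-predictor model
print(coef(m2$finalModel))
```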
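Next, the `data =` parameter tells `train()` which dataset to use. Here, that's the `mtcars` dataset.
Finally, we see the `method =` parameter. This parameter specifies which machine learning method we want to use. In this case, `method = "lm"` tells caret to use linear regression.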
Keep in mind, however, that we could select a different learning method. Although it's beyond the scope of this blog post to discuss all of the possible learning methods, there are many that we could use here. For example, if we wanted to use the k-nearest neighbors technique, we could use `method = "knn"`.
Again, it's beyond the scope of this post to discuss all of the different model types. However, as you learn more about machine learning and want to try out more advanced techniques, this is how you can implement them: you simply change the learning method by changing the argument of the `method =` parameter.
This is a good place to reiterate one of caret's primary advantages: switching between model types is extremely easy when we use caret's `train()` function.
Caret's syntax allows you to change the learning method very easily. In turn, this allows you to "try out" and evaluate many different learning methods rapidly and iteratively. You can just re-run your code with different values for the `method =` parameter.
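For instance, swapping in k-nearest neighbors only changes the `method` argument. Here is a minimal sketch (the seed is arbitrary). By default, caret evaluates a few candidate values of `k` with resampling and keeps the best one.

```r
library(caret)

set.seed(42)  # arbitrary seed for reproducible resampling

# Same formula and data as the linear model; only method changes
model.mtcars_knn <- train(mpg ~ wt,
                          data   = mtcars,
                          method = "knn")

# The number of neighbors (k) that caret selected
print(model.mtcars_knn$bestTune)
```

With more than one predictor, you would usually also center and scale the inputs for k-NN, for example by passing `preProcess = c("center", "scale")` to `train()`.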
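Here's one way that rapid, iterative comparison might look. This sketch loops over a vector of method names and evaluates each one on the same 5 cross-validation folds. It then collects the best resampled metrics with caret's `getTrainPerf()`. The method list, fold count, and seed are assumptions chosen for illustration.

```r
library(caret)

methods <- c("lm", "knn")

# Build the folds once; index = folds pins the exact training rows per fold,
# so every method is evaluated on identical splits
set.seed(42)
folds <- createFolds(mtcars$mpg, k = 5, returnTrain = TRUE)
ctrl  <- trainControl(method = "cv", index = folds)

results <- lapply(methods, function(m) {
  fit <- train(mpg ~ wt, data = mtcars, method = m, trControl = ctrl)
  getTrainPerf(fit)  # best resampled RMSE / Rsquared / MAE for this method
})

perf <- do.call(rbind, results)
print(perf)
```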
Now that you have a high-level understanding of caret, you're ready to dive deeper into machine learning in R.
Keep in mind though, if you're new to machine learning, there's still lots more to learn. Machine learning is intricate and fascinatingly complex. Moreover, caret has a variety of additional tools for model building. We've just scratched the surface here.
At Sharp Sight Labs, we'll be creating blog posts to teach you more of the basics, including tutorials about:
- the bias-variance tradeoff
- deeper looks at regression
- data preparation
- and more.
In the meantime, send me your questions.
If you have questions about machine learning, or topics you're struggling with, sign up for the email list. Once you're on the email list, reply directly to any of the emails and send in your questions.
9 thoughts on “A quick introduction to machine learning in R with caret”
Thank you for this article, this is a nice easy intro into ML and the CARET package. I’m excited to start to play around with it and learn how to leverage ML.
It’s a really exciting time in ML and data science right now, so great to hear that you’re diving in.
Great introduction, I am eager to swim in data science using R programming language,
Thanks for the article. I have base knowledge in linear regression but i didn’t know caret.
For linear regression, is caret restricted to a single independent variable?
Well, I got to install caret recently (I believe on your recommendation). I’m yet to use it, but I must say that after a couple of years of using R, I’ve never quite run into a massive package like this one! Hmm, I’m really looking forward to working with it…
I am learning data analysis with R through most of your material, and just finished the last video on your course. For machine learning, what type of modeling would likely work best where I want to predict number of cases shipped based on what month/year/ # of working days in the month? I’ve been trying to build a model to do our forecasting based on historical input, however the amount of data we have is pretty thin when going by monthly, as it only goes back about 3 years. So far I’ve not had much luck finding modeling tutorials that point me in the right direction.
Great introductory article ! Thank you for the example code and explanations ! Please follow up with different methods soon !
The cool thing is simple swap of model choice. Admittedly caret isn’t needed for this simple example at all. Basically we could have just put the regression line geometry in directly with ggplot:
```r
ggplot(data = mtcars, aes(wt, mpg)) + geom_point() + geom_smooth(formula = y ~ x, method = "lm", se = FALSE)
```
Nevertheless, the example was helpful for me to get a simple overview.
The example isn’t to say that caret is the best tool for building a simple regression line, but rather to show how caret works with a very simple example.
My philosophy with teaching complicated things (like ML) is to start with ultra simple example to help people develop intuition, and then increase complexity.