This lab on Splines and GAMs in R comes from pp. 293-297 of "Introduction to Statistical Learning with Applications in R" by Gareth James, Daniela Witten, Trevor Hastie and Robert Tibshirani. It was re-implemented in Fall 2016 in tidyverse format by Amelia McNamara and R. Jordan Crouser at Smith College.
library(ISLR)
library(dplyr)
library(ggplot2)
In order to fit regression splines in R, we use the splines library. In lecture, we saw that regression splines can be fit by constructing an appropriate matrix of basis functions. The bs() function generates the entire matrix of basis functions for splines with the specified set of knots. By default, cubic splines are produced. Fitting wage to age using a regression spline is simple:
library(splines)
# Get min/max values of age using the range() function
agelims = Wage %>%
  select(age) %>%
  range
# Generate a sequence of age values spanning the range
age_grid = seq(from = min(agelims), to = max(agelims))
# Fit a regression spline using basis functions
fit = lm(wage~bs(age, knots = c(25,40,60)), data = Wage)
# Predict wage at each age in the grid,
# returning the standard errors using se = TRUE
pred = predict(fit, newdata = list(age = age_grid), se = TRUE)
# Compute error bands (2*SE)
se_bands = with(pred, cbind("upper" = fit+2*se.fit,
                            "lower" = fit-2*se.fit))
# Plot the spline and error bands
ggplot() +
  geom_point(data = Wage, aes(x = age, y = wage)) +
  geom_line(aes(x = age_grid, y = pred$fit), color = "#0000FF") +
  geom_ribbon(aes(x = age_grid,
                  ymin = se_bands[,"lower"],
                  ymax = se_bands[,"upper"]),
              alpha = 0.3) +
  xlim(agelims)
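The claim that bs() simply builds a matrix of basis functions can be checked directly. The following is a minimal sketch on simulated data (x and y are hypothetical stand-ins for age and wage, not the Wage data): it constructs the basis matrix by hand and confirms that regressing on it gives the same fit as placing bs() inside the formula.

```r
library(splines)

set.seed(1)
# Simulated stand-in for age and wage (hypothetical values, not the Wage data)
x <- runif(200, min = 18, max = 80)
y <- 50 + 30 * sin(x / 15) + rnorm(200, sd = 5)

# Build the cubic-spline basis matrix explicitly: one column per basis function
B <- bs(x, knots = c(25, 40, 60))
dim(B)   # 200 rows, 6 columns

# Regressing on the basis matrix is the same fit as using bs() in the formula
fit_matrix  <- lm(y ~ B)
fit_formula <- lm(y ~ bs(x, knots = c(25, 40, 60)))
all.equal(unname(fitted(fit_matrix)), unname(fitted(fit_formula)))
```

Writing bs() inside the formula, as the lab does, is usually more convenient because predict() can then rebuild the basis at new values of the predictor.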
Here we have prespecified knots at ages 25, 40, and 60. This produces a spline with six basis functions. (Recall that a cubic spline with three knots has seven degrees of freedom; these degrees of freedom are used up by an intercept, plus six basis functions.) We could also use the df option to produce a spline with knots at uniform quantiles of the data:
# Specifying knots directly: 6 basis functions
with(Wage, dim(bs(age, knots = c(25,40,60))))
# Specify desired degrees of freedom, select knots automatically:
# still 6 basis functions
with(Wage, dim(bs(age, df = 6)))
# Show me where the knots were placed
with(Wage, attr(bs(age, df = 6),"knots"))
In this case R chooses knots at ages 33.8, 42.0, and 51.0, which correspond to the 25th, 50th, and 75th percentiles of age. The function bs() also has a degree argument, so we can fit splines of any degree, rather than the default degree of 3 (which yields a cubic spline).
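Both points in this paragraph can be verified in a few lines. The sketch below uses a simulated predictor (x is a hypothetical stand-in for age): it checks that the knots chosen through df sit at the quartiles, and shows how the degree argument changes the number of basis columns for a fixed set of knots.

```r
library(splines)

set.seed(1)
x <- runif(500, min = 18, max = 80)   # simulated stand-in for age

# With df = 6 and a cubic spline, bs() places 6 - 3 = 3 interior knots
knots_auto <- attr(bs(x, df = 6), "knots")
knots_auto
quantile(x, probs = c(0.25, 0.50, 0.75))   # same values

# For the same knots, columns = number of knots + degree
# (bs() leaves out the intercept column by default)
n_linear <- ncol(bs(x, knots = c(25, 40, 60), degree = 1))   # 4: piecewise linear
n_cubic  <- ncol(bs(x, knots = c(25, 40, 60), degree = 3))   # 6: cubic (default)
c(linear = n_linear, cubic = n_cubic)
```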
In order to instead fit a natural spline, we use the ns() function. Here we fit a natural spline with four degrees of freedom:
fit2 = lm(wage~ns(age, df = 4), data = Wage)
pred2 = predict(fit2, newdata = list(age = age_grid), se = TRUE)
# Compute error bands (2*SE) from the natural spline predictions
se_bands2 = with(pred2, cbind("upper" = fit+2*se.fit,
                              "lower" = fit-2*se.fit))
# Plot the natural spline and error bands
ggplot() +
  geom_point(data = Wage, aes(x = age, y = wage)) +
  geom_line(aes(x = age_grid, y = pred2$fit), color = "#0000FF") +
  geom_ribbon(aes(x = age_grid,
                  ymin = se_bands2[,"lower"],
                  ymax = se_bands2[,"upper"]),
              alpha = 0.3) +
  xlim(agelims)
As with the bs() function, we could instead specify the knots directly using the knots option.
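As a small illustration of the knots option, the sketch below builds a natural spline basis with three explicitly chosen interior knots. It uses a simulated predictor (x is a hypothetical stand-in for age), and the knot locations 25, 40, and 60 are reused from the bs() example only for comparison.

```r
library(splines)

set.seed(1)
x <- runif(500, min = 18, max = 80)   # simulated stand-in for age

# Natural cubic spline with three interior knots given explicitly
N_knots <- ns(x, knots = c(25, 40, 60))
ncol(N_knots)   # 4 columns: number of interior knots + 1

# Asking for df = 4 gives the same number of columns,
# but ns() then picks the three interior knots itself
N_df <- ns(x, df = 4)
ncol(N_df)
attr(N_df, "knots")
```

In a model this would be used the same way as before, for example lm(wage ~ ns(age, knots = c(25, 40, 60)), data = Wage).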
In order to fit a smoothing spline, we use the smooth.spline() function. Here we'll reproduce the plot we saw in lecture showing two smoothing splines on the Wage data, one with 16 degrees of freedom and one with the smoothness level chosen by LOOCV:
# Fit 2 smoothing splines
fit_smooth = with(Wage, smooth.spline(age, wage, df = 16))
fit_smooth_cv = with(Wage, smooth.spline(age, wage, cv = TRUE))
# Plot the smoothing splines
ggplot() +
  geom_point(data = Wage, aes(x = age, y = wage)) +
  geom_line(aes(x = fit_smooth$x, y = fit_smooth$y,
                color = "16 degrees of freedom")) +
  geom_line(aes(x = fit_smooth_cv$x, y = fit_smooth_cv$y,
                color = "6.8 effective degrees of freedom")) +
  theme(legend.position = 'bottom') +
  labs(title = "Smoothing Splines", colour = "")
Notice that in the first call to smooth.spline(), we specified df=16. The function then determines which value of $\lambda$ leads to 16 degrees of freedom. In the second call to smooth.spline(), we select the smoothness level by cross-validation; this results in a value of $\lambda$ that yields 6.8 degrees of freedom.
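The link between $\lambda$ and the effective degrees of freedom can be read off the fitted objects, which store both in their df and lambda components. The sketch below is a minimal example on simulated data (x and y are hypothetical stand-ins for age and wage, so the numbers will not match the 6.8 reported for the Wage data).

```r
set.seed(1)
# Simulated data with unique x values (hypothetical stand-in for age and wage)
x <- sort(runif(300, min = 18, max = 80))
y <- 50 + 30 * sin(x / 15) + rnorm(300, sd = 10)

# Fix the effective degrees of freedom; lambda is solved for
fit_df <- smooth.spline(x, y, df = 16)
# Let leave-one-out cross-validation choose lambda; df follows from it
fit_cv <- smooth.spline(x, y, cv = TRUE)

# Compare the effective df and the penalty for each fit
comparison <- data.frame(fit    = c("df = 16", "LOOCV"),
                         df     = c(fit_df$df, fit_cv$df),
                         lambda = c(fit_df$lambda, fit_cv$lambda))
comparison
```

On the Wage data the same components are available as fit_smooth$df, fit_smooth_cv$df, and fit_smooth_cv$lambda.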
We now fit a really simple GAM to predict wage using natural spline functions of year and age, treating education as a qualitative predictor. Since this is just a big linear regression model using an appropriate choice of basis functions, we can simply do this using the lm() function:
gam1 = lm(wage ~ ns(year, 4) + ns(age, 5) + education, data = Wage)
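To make the "big linear regression" point concrete, the design matrix of such a model can be inspected like any other. The sketch below is an illustration on simulated data (the data frame sim and its five education levels are hypothetical stand-ins for the Wage data); it counts the columns contributed by each term.

```r
library(splines)

set.seed(1)
n <- 400
# Simulated stand-in for the Wage data (hypothetical values)
sim <- data.frame(
  year      = sample(2003:2009, n, replace = TRUE),
  age       = runif(n, min = 18, max = 80),
  education = factor(sample(paste0("level_", 1:5), n, replace = TRUE))
)
sim$wage <- 60 + 1.2 * (sim$year - 2003) + 25 * sin(sim$age / 15) +
  5 * as.integer(sim$education) + rnorm(n, sd = 10)

fit_sim <- lm(wage ~ ns(year, 4) + ns(age, 5) + education, data = sim)

# An ordinary design matrix: intercept + 4 and 5 spline columns + 4 dummies
X <- model.matrix(fit_sim)
ncol(X)   # 14
```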
What if we want to fit the model using smoothing splines rather than natural splines? In order to fit more general sorts of GAMs, using smoothing splines or other components that cannot be expressed in terms of basis functions and then fit using least squares regression, we will need to use the gam library in R. The s() function, which is part of the gam library, is used to indicate that we would like to use a smoothing spline. We'll specify that the function of year should have 4 degrees of freedom, and that the function of age will have 5 degrees of freedom. Since education is qualitative, we leave it as is, and it is converted into four dummy variables.
We can use the gam() function in order to fit a GAM using these components. All of the terms are fit simultaneously, taking each other into account to explain the response:
library(gam)
gam2 = gam(wage ~ s(year, 4) + s(age, 5) + education, data = Wage)
par(mfrow = c(1,3))
plot(gam2, se = TRUE, col = "blue")
The generic plot() function recognizes that gam2 is an object of class gam, and invokes the appropriate plot.gam() method. Conveniently, even though our simple gam1 is not of class gam but rather of class lm, we can still use plot.gam() on it:
par(mfrow = c(1,3))
# Note: recent releases of the gam package name this method plot.Gam()
plot.gam(gam1, se = TRUE, col = "red")
Notice here we had to use plot.gam() rather than the generic plot() function.
In these plots, the function of year looks rather linear. We can perform a series of ANOVA tests in order to determine which of these three models is best: a GAM that excludes year ($M_1$), a GAM that uses a linear function of year ($M_2$), or a GAM that uses a spline function of year ($M_3$):
gam_no_year = gam(wage ~ s(age, 5) + education, data = Wage)
gam_linear_year = gam(wage ~ year + s(age, 5) + education, data = Wage)
print(anova(gam_no_year, gam_linear_year, gam2, test = "F"))
We find that there is compelling evidence that a GAM with a linear function of year is better than a GAM that does not include year at all ($p$-value=0.00014). However, there is no evidence that a non-linear function of year is helpful ($p$-value=0.349). In other words, based on the results of this ANOVA, $M_2$ is preferred.
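The same sequence of nested comparisons can be sketched without the gam library by using natural splines and lm(). The example below is only an analogue of the test above, run on simulated data (sim is a hypothetical data frame in which wage depends linearly on year by construction), so its p-values are not those of the Wage analysis.

```r
library(splines)

set.seed(1)
n <- 500
# Simulated stand-in for the Wage data (hypothetical values)
sim <- data.frame(year = sample(2003:2009, n, replace = TRUE),
                  age  = runif(n, min = 18, max = 80))
# wage depends linearly on year by construction
sim$wage <- 60 + 2 * (sim$year - 2003) + 25 * sin(sim$age / 15) +
  rnorm(n, sd = 5)

m1 <- lm(wage ~ ns(age, 5), data = sim)                 # M1: no year
m2 <- lm(wage ~ year + ns(age, 5), data = sim)          # M2: linear in year
m3 <- lm(wage ~ ns(year, 4) + ns(age, 5), data = sim)   # M3: spline in year

# Each row tests whether a model improves on the one listed before it
comparison <- anova(m1, m2, m3, test = "F")
comparison
```

The second row of the table corresponds to the $M_1$ versus $M_2$ comparison and the third row to $M_2$ versus $M_3$.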
The summary() function produces a summary of the GAM fit:
summary(gam2)
The $p$-values for year and age correspond to a null hypothesis of a linear relationship versus the alternative of a non-linear relationship. The large $p$-value for year reinforces our conclusion from the ANOVA test that a linear function is adequate for this term. However, there is very clear evidence that a non-linear term is required for age.
We can make predictions from gam objects, just like from lm objects, using the predict() method for the class gam. Here we make predictions on the training set:
preds = predict(gam_linear_year, newdata = Wage)
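Predicting at new observations follows the same call pattern: newdata must be a data frame whose column names match the predictors in the formula. The sketch below illustrates this with an lm() fit using a natural spline rather than a gam object, on simulated data (sim and new_obs are hypothetical).

```r
library(splines)

set.seed(1)
n <- 300
# Simulated stand-in for the Wage data (hypothetical values)
sim <- data.frame(year = sample(2003:2009, n, replace = TRUE),
                  age  = runif(n, min = 18, max = 80))
sim$wage <- 60 + 2 * (sim$year - 2003) + 25 * sin(sim$age / 15) +
  rnorm(n, sd = 5)

fit_sim <- lm(wage ~ year + ns(age, 5), data = sim)

# New observations need the same column names as the predictors in the formula
new_obs   <- data.frame(year = c(2005, 2008), age = c(30, 55))
preds_new <- predict(fit_sim, newdata = new_obs)
preds_new

# Predicting on the training data reproduces the fitted values,
# because the spline basis is rebuilt with the knots chosen at fit time
all.equal(unname(predict(fit_sim, newdata = sim)), unname(fitted(fit_sim)))
```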
In order to fit a logistic regression GAM, we once again use the I() function in constructing the binary response variable, and set family=binomial:
gam_logistic = gam(I(wage>250) ~ year + s(age, df = 5) + education,
                   family = binomial, data = Wage)
par(mfrow=c(1,3))
plot(gam_logistic, se = TRUE, col = "green")
It is easy to see that there are no high earners in the <HS category:
with(Wage, table(education, I(wage>250)))
Hence, we fit a logistic regression GAM using all but this category. This provides more sensible results:
college_educated = Wage %>%
  filter(education != "1. < HS Grad")
gam_logistic_subset = gam(I(wage>250) ~ year + s(age, df = 5) + education,
                          family = binomial, data = college_educated)
par(mfrow=c(1,3))
plot(gam_logistic_subset, se = TRUE, col = "green")
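One way to see why dropping the empty category helps is to look at what a logistic regression does when a factor level contains no positive outcomes. The sketch below is an illustration with plain glm() on simulated data (grp and high are hypothetical stand-ins for education and the high-earner indicator), not a reproduction of the GAM fit above.

```r
set.seed(1)
n <- 600
# Simulated stand-in (hypothetical): group "a" never has a positive outcome
grp  <- factor(sample(c("a", "b", "c"), n, replace = TRUE))
p    <- c(a = 0, b = 0.10, c = 0.20)[as.character(grp)]
high <- rbinom(n, size = 1, prob = p)

table(grp, high)

# With "a" included, its log-odds are pushed toward minus infinity,
# so the coefficient estimates are extreme and their standard errors huge
fit_all <- suppressWarnings(glm(high ~ grp, family = binomial))
se_all  <- summary(fit_all)$coefficients[, "Std. Error"]

# Dropping the empty group gives finite, interpretable estimates
keep    <- grp != "a"
grp_sub <- droplevels(grp[keep])
fit_sub <- glm(high[keep] ~ grp_sub, family = binomial)
se_sub  <- summary(fit_sub)$coefficients[, "Std. Error"]

max(se_all)
max(se_sub)
```

The very wide standard errors in the first fit are the kind of behavior that makes the education panel hard to read when the <HS category is kept.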
To get credit for this lab, post your answer to the following question to Moodle: https://moodle.smith.edu/mod/quiz/view.php?id=262963