This lab on Splines and GAMs in R comes from pp. 293-297 of "An Introduction to Statistical Learning with Applications in R" by Gareth James, Daniela Witten, Trevor Hastie and Robert Tibshirani. It was re-implemented in Fall 2016 in tidyverse format by Amelia McNamara and R. Jordan Crouser at Smith College.

library(ISLR)
library(dplyr)
library(ggplot2)

7.8.2 Splines

In order to fit regression splines in R, we use the splines library. In lecture, we saw that regression splines can be fit by constructing an appropriate matrix of basis functions. The bs() function generates the entire matrix of basis functions for splines with the specified set of knots. By default, cubic splines are produced. Fitting wage to age using a regression spline is simple:

library(splines)

# Get min/max values of age using the range() function
agelims = Wage %>%
    select(age) %>%
    range

# Generate a sequence of age values spanning the range
age_grid = seq(from = min(agelims), to = max(agelims))

# Fit a regression spline using basis functions
fit = lm(wage~bs(age, knots = c(25,40,60)), data = Wage)

# Predict the value of the generated ages, 
# returning the standard error using se = TRUE
pred = predict(fit, newdata = list(age = age_grid), se = TRUE)

# Compute error bands (2*SE)
se_bands = with(pred, cbind("upper" = fit+2*se.fit, 
                            "lower" = fit-2*se.fit))

# Plot the spline and error bands
ggplot() +
  geom_point(data = Wage, aes(x = age, y = wage)) +
  geom_line(aes(x = age_grid, y = pred$fit), color = "#0000FF") + 
  geom_ribbon(aes(x = age_grid, 
                  ymin = se_bands[,"lower"], 
                  ymax = se_bands[,"upper"]), 
              alpha = 0.3) +
  xlim(agelims)

Here we have prespecified knots at ages 25, 40, and 60. This produces a spline with six basis functions. (Recall that a cubic spline with three knots has seven degrees of freedom; these degrees of freedom are used up by an intercept, plus six basis functions.) We could also use the df option to produce a spline with knots at uniform quantiles of the data:

# Specifying knots directly: 6 basis functions
with(Wage, dim(bs(age, knots = c(25,40,60))))

# Specify desired degrees of freedom, select knots automatically: 
# still 6 basis functions
with(Wage, dim(bs(age, df = 6)))

# Show me where the knots were placed
with(Wage, attr(bs(age, df = 6),"knots"))

In this case R chooses knots at ages 33.8, 42.0, and 51.0, which correspond to the 25th, 50th, and 75th percentiles of age. The function bs() also has a degree argument, so we can fit splines of any degree, rather than the default degree of 3 (which yields a cubic spline).
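
To make the degrees-of-freedom bookkeeping concrete, here is a small sketch of how the number of basis functions changes with the degree argument. It uses simulated ages (the name age_sim and the simulated values are assumptions for illustration, not part of the Wage data) so that it runs with only the splines library loaded:

```r
library(splines)

# Simulated ages, used only so the example runs without the Wage data
set.seed(1)
age_sim <- runif(200, min = 18, max = 80)
knots <- c(25, 40, 60)

# Without an intercept column, bs() returns (number of knots + degree) columns
n_cols <- sapply(1:3, function(d) ncol(bs(age_sim, knots = knots, degree = d)))
names(n_cols) <- c("linear", "quadratic", "cubic")
n_cols
```

With three knots, the cubic basis has 3 + 3 = 6 columns, which matches the six basis functions described above; adding the intercept gives the seven degrees of freedom.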

In order to instead fit a natural spline, we use the ns() function. Here we fit a natural spline with four degrees of freedom:

fit2 = lm(wage~ns(age, df = 4), data = Wage)
pred2 = predict(fit2, newdata = list(age = age_grid), se = TRUE)

# Compute error bands (2*SE)
se_bands2 = with(pred2, cbind("upper" = fit+2*se.fit, 
                              "lower" = fit-2*se.fit))

# Plot the natural spline and error bands
ggplot() +
  geom_point(data = Wage, aes(x = age, y = wage)) +
  geom_line(aes(x = age_grid, y = pred2$fit), color = "#0000FF") + 
  geom_ribbon(aes(x = age_grid, 
                  ymin = se_bands2[,"lower"], 
                  ymax = se_bands2[,"upper"]), 
              alpha = 0.3) +
  xlim(agelims)

As with the bs() function, we could instead specify the knots directly using the knots option.
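
As a quick sketch of the two ways of placing knots with ns(), the example below compares the df option with the knots option. The ages are simulated (age_sim is a hypothetical stand-in for the age column of Wage), so the exact knot locations will differ from those in the real data:

```r
library(splines)

# Simulated ages standing in for Wage$age
set.seed(1)
age_sim <- runif(200, min = 18, max = 80)

# Let ns() choose the knots: df = 4 gives three interior knots
basis_df <- ns(age_sim, df = 4)
attr(basis_df, "knots")

# Specify the knots directly: three interior knots again give four columns
basis_knots <- ns(age_sim, knots = c(25, 40, 60))
dim(basis_knots)
```

When df is used, the interior knots fall at the 25th, 50th, and 75th percentiles of the simulated ages. A natural spline needs fewer columns than bs() for the same knots because it is constrained to be linear beyond the boundary knots.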

In order to fit a smoothing spline, we use the smooth.spline() function. Here we'll reproduce the plot we saw in lecture showing two smoothing splines on the Wage data: one with 16 degrees of freedom, and one whose smoothness is chosen by leave-one-out cross-validation (LOOCV):

# Fit 2 smoothing splines
fit_smooth = with(Wage, smooth.spline(age, wage, df = 16))
fit_smooth_cv = with(Wage, smooth.spline(age, wage, cv = TRUE))

# Plot the smoothing splines
ggplot() +
  geom_point(data = Wage, aes(x = age, y = wage)) +
  geom_line(aes(x = fit_smooth$x, y = fit_smooth$y, 
                color = "16 degrees of freedom"))  +
  geom_line(aes(x = fit_smooth_cv$x, y = fit_smooth_cv$y, 
                color = "6.8 effective degrees of freedom")) +
  theme(legend.position = 'bottom')+ 
  labs(title = "Smoothing Splines", colour="")

Notice that in the first call to smooth.spline(), we specified df=16. The function then determines which value of $\lambda$ leads to 16 degrees of freedom. In the second call to smooth.spline(), we select the smoothness level by cross-validation; this results in a value of $\lambda$ that yields 6.8 degrees of freedom.
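
The link between df and $\lambda$ can be inspected directly on the fitted objects, which store both values. The sketch below uses simulated data (x and y are assumptions for illustration, not the Wage data), so the numbers it prints will not match the 6.8 reported above:

```r
# Simulated data with a smooth non-linear trend
set.seed(1)
x <- sort(runif(300, min = 18, max = 80))
y <- 80 + 30 * sin(x / 12) + rnorm(300, sd = 10)

# Request a number of degrees of freedom; smooth.spline() solves for lambda
fit_df4  <- smooth.spline(x, y, df = 4)
fit_df16 <- smooth.spline(x, y, df = 16)

# Let leave-one-out cross-validation choose lambda instead
fit_cv <- smooth.spline(x, y, cv = TRUE)

# Each fit records both the effective degrees of freedom and lambda
rbind(df4  = c(df = fit_df4$df,  lambda = fit_df4$lambda),
      df16 = c(df = fit_df16$df, lambda = fit_df16$lambda),
      cv   = c(df = fit_cv$df,   lambda = fit_cv$lambda))
```

A smaller number of degrees of freedom corresponds to a larger $\lambda$, that is, a heavier penalty on roughness and a smoother curve.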

7.8.3 GAMs

We now fit a really simple GAM to predict wage using natural spline functions of year and age, treating education as a qualitative predictor. Since this is just a big linear regression model using an appropriate choice of basis functions, we can simply do this using the lm() function:

gam1 = lm(wage ~ ns(year, 4) + ns(age, 5) + education, data = Wage)
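
To see why this is "just a big linear regression", we can look at the design matrix that lm() builds from the formula. The sketch below uses a simulated data frame (sim, with made-up education levels and wages, is purely an assumption so the example runs without the ISLR package):

```r
library(splines)

# Simulated stand-in for the Wage data
set.seed(1)
n <- 500
sim <- data.frame(
  year = sample(2003:2009, n, replace = TRUE),
  age = round(runif(n, min = 18, max = 80)),
  education = factor(sample(paste0("level", 1:5), n, replace = TRUE))
)
sim$wage <- 60 + 0.5 * (sim$year - 2003) + 0.4 * sim$age + rnorm(n, sd = 10)

# Same model form as gam1, fit to the simulated data
fit_sim <- lm(wage ~ ns(year, 4) + ns(age, 5) + education, data = sim)

# One column per coefficient: intercept + 4 (year) + 5 (age) + 4 (dummies)
X <- model.matrix(fit_sim)
dim(X)
colnames(X)
```

Each natural spline term contributes one column per degree of freedom, and the five-level factor contributes four dummy variables, for 14 columns in total.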

What if we want to fit the model using smoothing splines rather than natural splines? In order to fit more general sorts of GAMs, using smoothing splines or other components that cannot be expressed in terms of basis functions and then fit using least squares regression, we will need to use the gam library in R. The s() function, which is part of the gam library, is used to indicate that we would like to use a smoothing spline. We'll specify that the function of year should have 4 degrees of freedom, and that the function of age will have 5 degrees of freedom. Since education is qualitative, we leave it as is, and it is converted into four dummy variables.

We can use the gam() function in order to fit a GAM using these components. All of the terms are fit simultaneously, taking each other into account to explain the response:

library(gam)
gam2 = gam(wage ~ s(year, 4) + s(age, 5) + education, data = Wage)
par(mfrow = c(1,3))
plot(gam2, se = TRUE, col = "blue")

The generic plot() function recognizes that gam2 is an object of class gam, and invokes the appropriate plot.gam() method. Conveniently, even though our simple gam1 is not of class gam but rather of class lm, we can still use plot.gam() on it:

par(mfrow = c(1,3))
plot.gam(gam1, se = TRUE, col = "red")

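Notice here we had to use plot.gam() rather than the generic plot() function. In newer releases of the gam package the class was renamed to Gam, so if plot.gam() is not found, call plot.Gam() with the same arguments instead.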

In these plots, the function of year looks rather linear. We can perform a series of ANOVA tests in order to determine which of these three models is best: a GAM that excludes year ($M_1$), a GAM that uses a linear function of year ($M_2$), or a GAM that uses a spline function of year ($M_3$):

gam_no_year = gam(wage ~ s(age, 5) + education, data = Wage)
gam_linear_year = gam(wage ~ year + s(age, 5) + education, data = Wage)
print(anova(gam_no_year, gam_linear_year, gam2, test = "F"))

We find that there is compelling evidence that a GAM with a linear function of year is better than a GAM that does not include year at all ($p$-value=0.00014). However, there is no evidence that a non-linear function of year is helpful ($p$-value=0.349). In other words, based on the results of this ANOVA, $M_2$ is preferred.
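
The same nested-model comparison can be sketched without the gam library by swapping in lm() with ns() terms. This is only an analogue of the test above, run on simulated data (sim and its coefficients are assumptions for illustration) in which wage truly depends linearly on year:

```r
library(splines)

# Simulated data: linear effect of year, non-linear effect of age
set.seed(1)
n <- 500
sim <- data.frame(year = sample(2003:2009, n, replace = TRUE),
                  age = runif(n, min = 18, max = 80))
sim$wage <- 2 * (sim$year - 2003) + 30 * sin(sim$age / 20) + rnorm(n, sd = 5)

# Three nested models, mirroring M1, M2, and M3
m1 <- lm(wage ~ ns(age, 5), data = sim)                # no year
m2 <- lm(wage ~ year + ns(age, 5), data = sim)         # linear year
m3 <- lm(wage ~ ns(year, 4) + ns(age, 5), data = sim)  # spline in year

# Each row tests a model against the one listed before it
tab <- anova(m1, m2, m3, test = "F")
tab
```

The second row compares the linear-year model with the no-year model, and the third row asks whether the spline in year improves on the linear term.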

The summary() function produces a summary of the GAM fit:

summary(gam2)

The $p$-values for year and age correspond to a null hypothesis of a linear relationship versus the alternative of a non-linear relationship. The large $p$-value for year reinforces our conclusion from the ANOVA test that a linear function is adequate for this term. However, there is very clear evidence that a non-linear term is required for age.

We can make predictions from gam objects, just like from lm objects, using the predict() method for the class gam. Here we make predictions on the training set:

preds = predict(gam_linear_year, newdata = Wage)
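
As a small sketch of how predict() handles spline terms, the example below uses lm() with ns() on simulated data (sim and new_obs are hypothetical names introduced here). It checks that predictions on the training set reproduce the fitted values, then predicts at a few new ages:

```r
library(splines)

# Simulated training data
set.seed(1)
sim <- data.frame(age = runif(200, min = 18, max = 80))
sim$wage <- 80 + 30 * sin(sim$age / 20) + rnorm(200, sd = 5)

fit_sim <- lm(wage ~ ns(age, df = 4), data = sim)

# Predicting on the training data reproduces the fitted values
preds_train <- predict(fit_sim, newdata = sim)
all.equal(unname(preds_train), unname(fitted(fit_sim)))

# New observations need a column named after each predictor in the formula
new_obs <- data.frame(age = c(30, 45, 60))
preds_new <- predict(fit_sim, newdata = new_obs)
preds_new
```

The spline basis for the new ages is built with the knots chosen during fitting, so newdata only has to supply the raw predictor values.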

Logistic Regression GAMs

In order to fit a logistic regression GAM, we once again use the I() function in constructing the binary response variable, and set family=binomial:

gam_logistic = gam(I(wage>250) ~ year + s(age, df = 5) + education, 
                   family = binomial, data = Wage)
par(mfrow=c(1,3))
plot(gam_logistic, se = TRUE, col = "green")

It is easy to see that there are no high earners in the <HS category:

with(Wage, table(education, I(wage>250)))

Hence, we fit a logistic regression GAM using all but this category. This provides more sensible results:

# Drop the "< HS Grad" level, which contains no high earners
hs_or_more = Wage %>%
  filter(education != "1. < HS Grad")

gam_logistic_subset = gam(I(wage>250) ~ year + s(age, df = 5) + education, 
                          family = binomial, data = hs_or_more)
par(mfrow=c(1,3))
plot(gam_logistic_subset, se = TRUE, col = "green")
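
To see why a category with no high earners causes trouble, here is a minimal sketch using an ordinary logistic regression (glm() rather than gam()) on made-up data. The factor edu and the counts in high_earner are assumptions chosen so that the baseline level has no successes:

```r
# Made-up data: the "low" level has 0 high earners out of 50
edu <- factor(rep(c("low", "mid", "high"), each = 50),
              levels = c("low", "mid", "high"))
high_earner <- c(rep(0, 50), rep(c(1, 0), c(10, 40)), rep(c(1, 0), c(15, 35)))
table(edu, high_earner)

# With an empty cell, the estimates run off toward infinity;
# suppressWarnings() hides any warnings glm() may emit along the way
fit_sep <- suppressWarnings(glm(high_earner ~ edu, family = binomial))
est_sep <- coef(summary(fit_sep))[, c("Estimate", "Std. Error")]
est_sep

# Dropping the empty level gives finite, interpretable estimates
keep <- edu != "low"
fit_ok <- glm(high_earner[keep] ~ droplevels(edu[keep]), family = binomial)
coef(summary(fit_ok))[, c("Estimate", "Std. Error")]
```

In the first fit the intercept is a large negative number with an enormous standard error, because the fitted probability for the empty level is being pushed toward zero. After dropping that level, the intercept is simply the log-odds for the "mid" group.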

To get credit for this lab, post your answer to the following question to Moodle (https://moodle.smith.edu/mod/quiz/view.php?id=262963):

  • How would you choose whether to use a polynomial, step, or spline function for each predictor when building a GAM?