Check the documentation for `?sort` and `?order`, and then write code to output the values of `Y` sorted according to the order of `X`.
# Simulate a small example: X uniform on [0, 1], Y = X plus standard normal noise
X <- runif(10)
Y <- X + rnorm(10)
# qplot() is deprecated as of ggplot2 3.4.0; build the scatterplot explicitly
ggplot(data.frame(X, Y), aes(X, Y)) +
  geom_point()
Base R solution
# Base R: order(X) is the permutation that puts X in increasing order;
# indexing Y with it lists the Y values in that same order
Y[order(X)]
## [1] 0.70076257 0.08613248 0.58063740 0.12880954 1.04591264 0.92221713
## [7] 0.04330536 0.96979525 0.87850142 1.48724097
# tidyverse: keep X and Y together as columns, then sort the rows by X
egdf <- data.frame(X, Y)
arrange(egdf, X)
## X Y
## 1 0.1621587 0.70076257
## 2 0.1907793 0.08613248
## 3 0.3844635 0.58063740
## 4 0.4481603 0.12880954
## 5 0.5358620 1.04591264
## 6 0.7060425 0.92221713
## 7 0.7151095 0.04330536
## 8 0.7449972 0.96979525
## 9 0.8567907 0.87850142
## 10 0.9625128 1.48724097
Below is some code that computes the average values of `Y` above and below a given split point.
Base R
x_split <- 0.5
# Average Y on each side of the split: first X <= x_split, then X > x_split
vapply(
  list(X <= x_split, X > x_split),
  function(side) mean(Y[side]),
  numeric(1)
)
## [1] 0.3740855 0.8911621
tidyverse
# tidyverse: group rows by the logical indicator `X <= x_split`
# (FALSE = above the split, TRUE = at or below it) and average Y per group
egdf %>%
  group_by(X <= x_split) %>%
  summarise(avg_Y = mean(Y))
## # A tibble: 2 x 2
## `X <= x_split` avg_Y
## <lgl> <dbl>
## 1 FALSE 0.891
## 2 TRUE 0.374
What if we change `x_split`? We would have to re-sort the data to find the indexes of `Y` corresponding to `X` values above/below the new split point (sorting a list of $k$ items requires, in the worst case, on the order of $k \log k$ operations).
Write code that sorts `X` only once and then, taking each `X` value in turn as a split point, computes the average `Y`
values above and below that split point while minimizing unnecessary computation.
x_order <- order(X)
Y_sorted <- Y[x_order]
n <- length(X)
# After sorting, split i puts Y_sorted[1:i] on the left. Running totals
# give each left sum directly, and the right sum is the remainder, so each
# pair of means costs O(1) instead of re-averaging O(n) values.
left_sums <- cumsum(Y_sorted)
total <- left_sums[n]
for (i in seq_len(n - 1)) {
  print(c(left_sums[i] / i,
          (total - left_sums[i]) / (n - i)))
}
## [1] 0.7007626 0.6825058
## [1] 0.3934475 0.7570525
## [1] 0.4558441 0.7822546
## [1] 0.3740855 0.8911621
## [1] 0.5084509 0.8602120
## [1] 0.5774120 0.8447108
## [1] 0.501111 1.111846
## [1] 0.5596965 1.1828712
## [1] 0.5951193 1.4872410
We can use the average `Y` values as predictions within each leaf, compute the RSS, and choose the split point giving the lowest RSS.
To guarantee at least `min.obs` observations in each leaf, instead of making the loop go from 1 to `(n-1)`, we can make it start at `min.obs` and end at `(n - min.obs)`.
Write a function that takes a single numeric predictor and an outcome as input, and returns a split point that achieves the lowest RSS.
# Find the best single split of a numeric predictor for a regression tree.
#
# Sorts the observations by `x` once, then scores every candidate split
# "left leaf = x <= X[i], right leaf = x > X[i]" by the residual sum of
# squares (RSS) obtained when each leaf predicts its own mean of `y`.
#
# x       numeric predictor.
# y       numeric outcome, same length as `x`.
# min.obs minimum number of observations required in each leaf.
#
# Returns the threshold value of `x` (observations with x <= threshold go
# left) that minimizes the RSS, or NA_real_ when no split satisfies
# `min.obs` (too few observations, all `x` tied, or RSS undefined because
# `y` contains NA).
tree_split <- function(x, y, min.obs = 10) {
  n <- length(x)
  stopifnot(length(y) == n)
  # With fewer than 2 * min.obs points (or fewer than 2 points) no split can
  # leave at least min.obs observations on each side
  if (n < max(2, 2 * min.obs)) {
    return(NA_real_)
  }
  x_order <- order(x)
  X <- x[x_order]
  # RSS is unchanged by shifting y; centring reduces cancellation in the
  # sum-of-squares identity used below
  Y <- y[x_order] - mean(y)
  i <- seq_len(n - 1)  # split i puts the first i sorted observations left
  left_sum <- cumsum(Y)[i]
  left_sq <- cumsum(Y^2)[i]
  right_sum <- sum(Y) - left_sum
  right_sq <- sum(Y^2) - left_sq
  # For a leaf v: sum((v - mean(v))^2) == sum(v^2) - sum(v)^2 / length(v)
  RSSs <- (left_sq - left_sum^2 / i) + (right_sq - right_sum^2 / (n - i))
  # A split is usable only if both leaves have >= min.obs observations and
  # X[i] < X[i + 1]: when those are tied, the threshold x <= X[i] cannot put
  # them on different sides, so that partition is not realizable
  valid <- i >= min.obs & i <= n - min.obs & X[i] < X[i + 1]
  RSSs[is.na(valid) | !valid] <- NA
  best <- which.min(RSSs)  # ignores NA; integer(0) when nothing is valid
  if (length(best) == 0) {
    return(NA_real_)
  }
  X[best]
}
n <- 1000
# Two-component mixture: about half the points are shifted by +3 on both axes
mixture_ids <- rbinom(n, 1, 0.5)
x <- rnorm(n) + 3 * mixture_ids
y <- rnorm(n) + 3 * mixture_ids
# qplot() is deprecated as of ggplot2 3.4.0; build the scatterplot explicitly
ggplot(data.frame(x, y), aes(x, y)) +
  geom_point()
tree_split(x, y, min.obs = 10)
## [1] 1.658915
n <- 1000
# Less separated mixture: the shifted component moves +2 in x and +1 in y
mixture_ids <- rbinom(n, 1, 0.5)
x <- rnorm(n) + 2 * mixture_ids
y <- rnorm(n) + 1 * mixture_ids
# qplot() is deprecated as of ggplot2 3.4.0; build the scatterplot explicitly
ggplot(data.frame(x, y), aes(x, y)) +
  geom_point()
tree_split(x, y, min.obs = 10)
## [1] 1.071319
Try different values of `n` and repeat.
Using the `gapminder`
data, plot the initial split point.
gm2007 <- filter(gapminder, year == 2007)
# Best single split of life expectancy on GDP per capita (default min.obs)
split2007 <- tree_split(gm2007$gdpPercap, gm2007$lifeExp)
split2007
## [1] 2280.77
# Scatterplot with the chosen threshold drawn as a vertical line
ggplot(gm2007, aes(x = gdpPercap, y = lifeExp)) +
  geom_point() +
  geom_vline(xintercept = split2007)
n <- 1000
# Two-component mixture as before...
mixture_ids <- rbinom(n, 1, .5)
x <- rnorm(n) + 3 * mixture_ids
y <- rnorm(n) + 3 * mixture_ids
# ...plus n/2 extra points forming a third cluster centred at (-2, 5)
x <- c(x, rnorm(n / 2, mean = -2))
y <- c(y, rnorm(n / 2, mean = 5))
egdf <- data.frame(x = x, y = y)
egplot <- ggplot(egdf, aes(x, y)) +
  geom_point()
egplot
# First split on the full data
split1 <- tree_split(x, y)
split1
## [1] -1.026118
# Recurse: look for a further split on each side of split1
inds1 <- which(x <= split1)
inds2 <- setdiff(seq_along(x), inds1)
split2 <- tree_split(x[inds1], y[inds1])
split2
## [1] -2.30362
split3 <- tree_split(x[inds2], y[inds2])
split3
## [1] 1.894641
# geom_vline() is vectorised over xintercept, so one layer draws all three
egplot +
  geom_vline(xintercept = c(split1, split2, split3))