install.packages("covid19.analytics")  # run once if the package is not yet installed
11 Interpreting Model Results Through Visualisation
Learning Objectives
- An overview of the basic plots for visualising predictions and outcomes
- How to customise a plot using colours, palettes, legends, and guides
- How to tell a story with data through plot layouts, and how to save plots as images
This chapter is an essential part of the data visualisation section. It focuses on how to interpret and use the results of machine learning models effectively in the context of health metrics and infectious diseases. The aim is to provide actionable insights, support decision-making, and communicate findings to different stakeholders.
- Visualising Predictions and Outcomes:
  - Techniques for visualising predicted vs. actual values.
  - Using scatter plots, line plots, and bar charts to compare predictions with observed data.
  - Highlighting discrepancies and trends through residual plots and error distribution charts.
- Case Studies and Applications:
  - Real-world examples of applying model results in public health.
  - Case studies demonstrating the impact of machine learning and spatial modelling on health metrics.
  - Lessons learned and best practices from successful implementations.
11.1 Practical Insights and Examples
11.1.1 Example: Visualising COVID-19 Data
The data come from the ECDC (European Centre for Disease Prevention and Control) and are updated daily. They include the number of confirmed cases, deaths, and recovered cases for each country, and can be accessed through the covid19.analytics package.
library(tidyverse)
library(covid19.analytics)
covid19.ALL.agg.cases <- covid19.data("aggregated")
data <- covid19.ALL.agg.cases %>%
  select(Country_Region, Last_Update,
         Confirmed, Deaths, Recovered) %>%
  mutate(Last_Update = as.Date(Last_Update))
data %>%
  head()
#> Country_Region Last_Update Confirmed Deaths Recovered
#> 1 Afghanistan 2021-02-02 55059 2404 47723
#> 2 Albania 2021-02-02 78992 1393 47922
#> 3 Algeria 2021-02-02 107578 2894 73530
#> 4 Andorra 2021-02-02 9972 101 9206
#> 5 Angola 2021-02-02 19829 466 18180
#> 6 Antigua and Barbuda 2021-02-02 234 7 177
# Mean counts per country for each reporting date
data %>%
  group_by(Last_Update) %>%
  summarise(across(Confirmed:Recovered, ~ mean(.x, na.rm = TRUE)))
#> # A tibble: 6 × 4
#> Last_Update Confirmed Deaths Recovered
#> <date> <dbl> <dbl> <dbl>
#> 1 2020-08-04 76 1.5 0
#> 2 2020-08-07 0 0 0
#> 3 2020-12-21 29.9 0.3 1.3
#> 4 2021-01-08 196 3 180
#> 5 2021-01-24 0 0 0
#> 6 2021-02-02 26088. 565. 14460.
data %>%
  ggplot(aes(x = Confirmed, y = Deaths)) +
  geom_point() +
  geom_smooth() +
  labs(title = "COVID-19 Deaths by Country",
       x = "Confirmed", y = "Number of Deaths")
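Whether the relationship is close to linear can be checked visually by overlaying a straight-line fit on the default smoother. The sketch below uses simulated counts: the column names mirror the COVID-19 table, but the values are synthetic, not ECDC data. Mapping a constant string to `colour` inside `aes()` is what makes ggplot2 draw a legend that names the two smoothers.

```r
library(ggplot2)

# Simulated counts for illustration only (not the ECDC data)
set.seed(42)
sim <- data.frame(Confirmed = round(runif(150, min = 0, max = 1e6)))
sim$Deaths <- pmax(0, round(0.02 * sim$Confirmed + rnorm(150, sd = 2000)))

p_smooth <- ggplot(sim, aes(x = Confirmed, y = Deaths)) +
  geom_point(alpha = 0.5) +
  # Flexible LOESS curve (ggplot2's default smoother for small data sets)
  geom_smooth(aes(colour = "LOESS"), method = "loess",
              formula = y ~ x, se = FALSE) +
  # Straight-line fit: close agreement with LOESS supports a linear model
  geom_smooth(aes(colour = "Linear"), method = "lm",
              formula = y ~ x, se = FALSE) +
  labs(title = "Checking linearity", colour = "Smoother",
       x = "Confirmed Cases", y = "Number of Deaths")
p_smooth
```

If the two curves diverge noticeably, a transformation or a non-linear model is worth considering before relying on a linear fit.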
The scatter plot suggests an approximately linear relationship between the number of confirmed cases and the number of deaths. We therefore fit a linear regression that predicts the number of deaths from the number of confirmed cases, with the number of recovered cases as an additional predictor.
mod <- lm(Deaths ~ Confirmed + Recovered, data = data)
summary(mod)$coefficients
#> Estimate Std. Error t value Pr(>|t|)
#> (Intercept) -34.040275885 1.989420e+01 -1.711066 8.714698e-02
#> Confirmed 0.025050716 2.300418e-04 108.896384 0.000000e+00
#> Recovered -0.003785725 3.075941e-04 -12.307533 3.432948e-34
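Read from the table above, the Confirmed estimate of about 0.025 means that each additional confirmed case is associated with roughly 0.025 more deaths, holding Recovered fixed. A coefficient plot shows the same estimates with their confidence intervals. The sketch below is a minimal version built with base `lm()` and `confint()` on simulated data (the generating coefficients are assumptions chosen to resemble the table, not the real data); with the chapter's `mod`, the same steps apply directly.

```r
library(ggplot2)

# Simulated stand-in for the country-level COVID-19 table
set.seed(1)
n <- 200
sim <- data.frame(Confirmed = runif(n, 0, 5e5))
sim$Recovered <- 0.6 * sim$Confirmed + rnorm(n, sd = 2e4)
sim$Deaths <- 0.025 * sim$Confirmed - 0.004 * sim$Recovered + rnorm(n, sd = 500)

mod_sim <- lm(Deaths ~ Confirmed + Recovered, data = sim)

# Estimates and 95% confidence intervals in one data frame
ci <- confint(mod_sim, level = 0.95)
coef_df <- data.frame(
  term      = rownames(ci),
  estimate  = unname(coef(mod_sim)),
  lower     = ci[, 1],
  upper     = ci[, 2],
  row.names = NULL
)

# Drop the intercept so that only the slopes are compared
coef_df <- coef_df[coef_df$term != "(Intercept)", ]

p_coef <- ggplot(coef_df, aes(x = estimate, y = term)) +
  geom_vline(xintercept = 0, linetype = "dashed", colour = "grey50") +
  geom_pointrange(aes(xmin = lower, xmax = upper)) +
  labs(title = "Coefficient estimates with 95% CIs",
       x = "Estimated change in deaths per additional case", y = NULL)
p_coef
```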
# Predictions on the same data used to fit the model
Pred <- predict(mod, newdata = data)
data_pred <- cbind(data %>% select(Confirmed, Deaths), Pred) %>%
  mutate(Residuals = Deaths - Pred)  # residual = observed - predicted
data_pred %>%
  head()
#> Confirmed Deaths Pred Residuals
#> 1 55059 2404 1164.56098 1239.43902
#> 2 78992 1393 1763.34641 -370.34641
#> 3 107578 2894 2382.50135 511.49865
#> 4 9972 101 180.91409 -79.91409
#> 5 19829 466 393.86590 72.13410
#> 6 234 7 -28.84848 35.84848
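The residuals can also be summarised in one or two numbers. The sketch below computes the mean absolute error (MAE) and root mean squared error (RMSE) from the six rows printed above only, so the values describe those six countries rather than the whole model; applied to the full `Residuals` column of `data_pred`, the same formulas give the overall error.

```r
# The six rows printed above: observed deaths and model predictions
obs  <- c(2404, 1393, 2894, 101, 466, 7)
pred <- c(1164.56098, 1763.34641, 2382.50135, 180.91409, 393.86590, -28.84848)

res <- obs - pred  # residual = observed - predicted

mae  <- mean(abs(res))     # mean absolute error
rmse <- sqrt(mean(res^2))  # root mean squared error, penalises large misses more

round(c(MAE = mae, RMSE = rmse), 1)
```

RMSE is always at least as large as MAE; a large gap between them, as Afghanistan's residual of about 1239 produces here, signals that a few big errors dominate.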
Let’s see the results in a plot:
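data_pred %>%
  ggplot(aes(Deaths, Pred)) +
  geom_point() +
  # Points on the red identity line are perfect predictions
  geom_abline(intercept = 0, slope = 1, color = "red") +
  # Vertical segments from each point to the identity line show the error
  geom_segment(aes(xend = Deaths, yend = Deaths), alpha = 0.5) +
  labs(title = "Predicted vs. Actual Deaths",
       x = "Actual Deaths", y = "Predicted Deaths")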
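A linear model is not constrained to non-negative values: in the rows printed above, the prediction for Antigua and Barbuda is about -28.8 deaths. Highlighting such predictions in colour makes them easy to spot, and it is worth doing before switching to a log-scale axis, which cannot display values of zero or below. A minimal sketch using only those six printed rows:

```r
library(ggplot2)

# Six rows copied from the data_pred output above
pred_df <- data.frame(
  Country = c("Afghanistan", "Albania", "Algeria", "Andorra",
              "Angola", "Antigua and Barbuda"),
  Deaths  = c(2404, 1393, 2894, 101, 466, 7),
  Pred    = c(1164.56098, 1763.34641, 2382.50135, 180.91409,
              393.86590, -28.84848)
)

# Label impossible (negative) predictions
pred_df$Status <- ifelse(pred_df$Pred < 0, "Negative", "Non-negative")

p_neg <- ggplot(pred_df, aes(Deaths, Pred, colour = Status)) +
  geom_abline(intercept = 0, slope = 1, colour = "grey50") +
  geom_hline(yintercept = 0, linetype = "dotted") +
  geom_point(size = 2) +
  # Manual palette: named values tie each colour to a category
  scale_colour_manual(values = c("Non-negative" = "black", "Negative" = "red")) +
  labs(title = "Predicted vs. Actual Deaths",
       x = "Actual Deaths", y = "Predicted Deaths", colour = "Prediction")
p_neg
```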
data_pred %>%
  ggplot(aes(Confirmed, Residuals)) +
  geom_point() +
  geom_hline(yintercept = 0, color = "red") +
  labs(title = "Residuals vs. Confirmed Cases",
       x = "Confirmed Cases", y = "Residuals")
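Besides plotting residuals against a predictor, an error distribution chart shows their overall shape. A histogram reveals skew or outliers, and a normal Q-Q plot shows how far the residuals depart from normality. The sketch below fits a simple model to simulated data (synthetic values, for illustration only); with the real results, `data_pred$Residuals` would take the place of `sim$Residuals`.

```r
library(ggplot2)

# Simulated model and residuals (synthetic data, for illustration only)
set.seed(123)
sim <- data.frame(Confirmed = runif(300, 0, 5e5))
sim$Deaths <- 0.02 * sim$Confirmed + rnorm(300, sd = 800)
fit <- lm(Deaths ~ Confirmed, data = sim)
sim$Residuals <- residuals(fit)

# Histogram: residuals should be roughly symmetric around zero
p_hist <- ggplot(sim, aes(Residuals)) +
  geom_histogram(bins = 30, fill = "steelblue", colour = "white") +
  geom_vline(xintercept = 0, colour = "red") +
  labs(title = "Distribution of residuals", x = "Residuals", y = "Count")

# Normal Q-Q plot: points close to the line suggest roughly normal errors
p_qq <- ggplot(sim, aes(sample = Residuals)) +
  stat_qq() +
  stat_qq_line(colour = "red") +
  labs(title = "Normal Q-Q plot of residuals",
       x = "Theoretical quantiles", y = "Sample quantiles")

p_hist
p_qq
```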
data_pred %>%
  ggplot() +
  geom_point(aes(Confirmed, Deaths)) +
  # Pred also depends on Recovered, so this dashed line is not perfectly straight
  geom_line(aes(Confirmed, Pred), linetype = "dashed") +
  labs(title = "Observed vs. Predicted",
       x = "Confirmed Cases", y = "Deaths",
       caption = "Source: ECDC")
data_pred %>%
  ggplot(aes(x = Confirmed, y = Deaths)) +
  geom_point() +
  geom_line(aes(Confirmed, Pred), linetype = "dashed") +
  # Each vertical segment joins an observation to its prediction (the residual)
  geom_segment(aes(xend = Confirmed, yend = Pred)) +
  labs(title = "Observed vs. Predicted",
       subtitle = "Segments represent the residuals",
       x = "Confirmed Cases", y = "Deaths",
       caption = "Source: ECDC")
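In the plots above, observed and predicted values are distinguished only by geometry, so no legend is drawn. Reshaping the data to long format and mapping the series to `colour` and `shape` produces a legend and lets a palette be applied. The sketch below uses the six rows printed earlier; the `"Set1"` Brewer palette and the bottom legend position are illustrative choices.

```r
library(ggplot2)
library(tidyr)

# Six rows copied from the data_pred output above
pred_df <- data.frame(
  Confirmed = c(55059, 78992, 107578, 9972, 19829, 234),
  Deaths    = c(2404, 1393, 2894, 101, 466, 7),
  Pred      = c(1164.56098, 1763.34641, 2382.50135, 180.91409,
                393.86590, -28.84848)
)

# Long format: one row per (country, series), so aesthetics can map to the series
long_df <- pivot_longer(pred_df, cols = c(Deaths, Pred),
                        names_to = "Series", values_to = "Value")
long_df$Series <- ifelse(long_df$Series == "Deaths", "Observed", "Predicted")

p_long <- ggplot(long_df, aes(Confirmed, Value, colour = Series, shape = Series)) +
  geom_point(size = 2.5) +
  scale_colour_brewer(palette = "Set1") +
  # Identical titles for colour and shape merge them into a single legend
  labs(title = "Observed vs. Predicted", x = "Confirmed Cases",
       y = "Deaths", colour = "Series", shape = "Series") +
  theme(legend.position = "bottom")
p_long
```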
cfr <- covid19.ALL.agg.cases %>%
  select(Country_Region, Last_Update, Lat, Long_, Case_Fatality_Ratio)
cfr %>%
  head()
#> Country_Region Last_Update Lat Long_
#> 1 Afghanistan 2021-02-02 05:22:49 33.93911 67.70995
#> 2 Albania 2021-02-02 05:22:49 41.15330 20.16830
#> 3 Algeria 2021-02-02 05:22:49 28.03390 1.65960
#> 4 Andorra 2021-02-02 05:22:49 42.50630 1.52180
#> 5 Angola 2021-02-02 05:22:49 -11.20270 17.87390
#> 6 Antigua and Barbuda 2021-02-02 05:22:49 17.06080 -61.79640
#> Case_Fatality_Ratio
#> 1 4.366225
#> 2 1.763470
#> 3 2.690141
#> 4 1.012836
#> 5 2.350093
#> 6 2.991453
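The Case_Fatality_Ratio column appears to be a percentage: for Afghanistan, 2404 deaths / 55059 confirmed cases × 100 ≈ 4.37, which matches the printed 4.366225. Because the table also carries coordinates, the ratio can be shown on a simple longitude-latitude scatter with a continuous colour scale. The sketch below uses the six printed rows; `scale_colour_viridis_c()` and `coord_quickmap()` are one possible choice of palette and projection, not necessarily what the original analysis used.

```r
library(ggplot2)

# Six rows copied from the cfr output above
cfr_df <- data.frame(
  Country_Region = c("Afghanistan", "Albania", "Algeria", "Andorra",
                     "Angola", "Antigua and Barbuda"),
  Lat   = c(33.93911, 41.15330, 28.03390, 42.50630, -11.20270, 17.06080),
  Long_ = c(67.70995, 20.16830, 1.65960, 1.52180, 17.87390, -61.79640),
  Case_Fatality_Ratio = c(4.366225, 1.763470, 2.690141, 1.012836,
                          2.350093, 2.991453)
)

p_cfr <- ggplot(cfr_df, aes(x = Long_, y = Lat)) +
  # Colour encodes the case fatality ratio on a perceptually uniform palette
  geom_point(aes(colour = Case_Fatality_Ratio), size = 4) +
  geom_text(aes(label = Country_Region), vjust = -1, size = 3) +
  scale_colour_viridis_c(name = "CFR (%)") +
  # Approximate map projection so that degrees look proportionate
  coord_quickmap() +
  labs(title = "Case fatality ratio by country location",
       x = "Longitude", y = "Latitude")
p_cfr
```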
covid19.TS.deaths <- covid19.data("ts-deaths")
covid19.TS.deaths %>%
  pivot_longer(cols = c(-Province.State,
                        -Country.Region,
                        -Lat, -Long),
               names_to = "Date", values_to = "Deaths") %>%
  head()
#> # A tibble: 6 × 6
#> Province.State Country.Region Lat Long Date Deaths
#> <chr> <chr> <dbl> <dbl> <chr> <int>
#> 1 "" Afghanistan 33.9 67.7 2020-01-22 0
#> 2 "" Afghanistan 33.9 67.7 2020-01-23 0
#> 3 "" Afghanistan 33.9 67.7 2020-01-24 0
#> 4 "" Afghanistan 33.9 67.7 2020-01-25 0
#> 5 "" Afghanistan 33.9 67.7 2020-01-26 0
#> 6 "" Afghanistan 33.9 67.7 2020-01-27 0
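In the long table, Date is still a character column (`<chr>`), and some countries are split across several rows in Province.State. Before plotting a time series, it helps to convert Date with `as.Date()` and sum over provinces. The sketch below runs these steps on a small hypothetical table with the same layout (made-up values), and it assumes that the series are cumulative counts, so daily new deaths are the day-to-day differences; check that assumption against the real data before using it.

```r
library(dplyr)
library(tidyr)
library(ggplot2)

# Hypothetical wide table with the same layout as covid19.data("ts-deaths"):
# one row per region and one column per date (values are made up)
ts_wide <- tibble(
  Province.State = c("", ""),
  Country.Region = c("Country A", "Country B"),
  Lat  = c(10, 20),
  Long = c(30, 40),
  `2020-03-01` = c(0L, 1L),
  `2020-03-02` = c(2L, 1L),
  `2020-03-03` = c(5L, 4L),
  `2020-03-04` = c(9L, 4L)
)

ts_long <- ts_wide %>%
  pivot_longer(cols = -c(Province.State, Country.Region, Lat, Long),
               names_to = "Date", values_to = "Deaths") %>%
  mutate(Date = as.Date(Date)) %>%                       # character -> Date
  group_by(Country.Region, Date) %>%
  summarise(Deaths = sum(Deaths), .groups = "drop") %>%  # combine provinces
  group_by(Country.Region) %>%
  arrange(Date, .by_group = TRUE) %>%
  # Assumes cumulative counts: new deaths = today's total - yesterday's total
  mutate(New_Deaths = Deaths - lag(Deaths, default = 0L)) %>%
  ungroup()

p_ts <- ggplot(ts_long, aes(Date, New_Deaths, colour = Country.Region)) +
  geom_line() +
  labs(title = "Daily new deaths", x = NULL, y = "New deaths",
       colour = "Country")
p_ts
```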
11.1.2 Example: Ischemic Stroke Decision Tree
In this example, we look at how to visualise the results of a decision tree model that predicts ischemic stroke.
Load the necessary libraries:
# Ischemic Stroke decision tree
library(tidymodels)
library(rpart)
library(rpart.plot)
Load the ischemic stroke data from the GitHub repository of the book “Feature Engineering and Selection: A Practical Approach for Predictive Models” by Max Kuhn and Kjell Johnson.
The data are already split into training and test sets; we will combine them for the analysis. We then only need to select the variables of interest with any_of(), a tidyselect helper that selects the columns named in a character vector and silently skips any names that are not present.
?any_of
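The difference between `any_of()` and its strict counterpart `all_of()` shows up when a listed column is missing. The sketch below uses a small hypothetical tibble; the variable list, including the deliberately absent `"Age"`, is made up for illustration and is not the book's list of stroke predictors.

```r
library(dplyr)

# Hypothetical variable list: "Age" is deliberately absent from the toy data
vars_of_interest <- c("Stroke", "MATXVol", "CALCVolProp", "Age")

toy <- tibble(
  Stroke      = factor(c("N", "Y", "N")),
  MATXVol     = c(1.2, 3.4, 2.2),
  CALCVolProp = c(0.1, 0.3, 0.2),
  Other       = c(9, 8, 7)
)

# any_of() keeps the listed columns that exist and skips the rest
selected <- toy %>% select(any_of(vars_of_interest))
names(selected)

# all_of() is strict: it raises an error because "Age" does not exist
strict <- try(toy %>% select(all_of(vars_of_interest)), silent = TRUE)
```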
Set up the recipe for the data with the recipe() function from the recipes package, which is loaded with tidymodels. We use step_corr() to remove highly correlated predictors, step_center() and step_scale() to standardise the predictors, step_YeoJohnson() to transform the predictors towards symmetry, and step_zv() to remove zero-variance predictors.
is_recipe <- recipe(Stroke ~ ., data = selected_train) %>%
  # step_interact(int_form) %>%
  step_corr(all_predictors(), threshold = 0.75) %>%
  step_center(all_predictors()) %>%
  step_scale(all_predictors()) %>%
  step_YeoJohnson(all_predictors()) %>%
  step_zv(all_predictors())

# Preview the first five processed predictors for the first five rows
is_recipe %>%
  prep() %>%
  bake(new_data = NULL) %>%
  select(1:5) %>%
  head(5)
#> # A tibble: 5 × 5
#> CALCVolProp MATXVol MATXVolProp MaxCALCAreaProp MaxDilationByArea
#> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 -0.143 0.106 -0.161 0.0541 0.0249
#> 2 -1.35 -0.0530 0.858 -0.931 -0.748
#> 3 -0.784 1.03 0.218 0.284 -0.320
#> 4 1.06 0.587 -0.144 1.05 0.423
#> 5 -0.708 -0.107 -0.307 -0.203 -0.738
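The order of recipe steps affects the result. The recipe above centres and scales before Yeo-Johnson and drops zero-variance columns last. An alternative, in line with the general ordering advice in the recipes documentation (individual transformations such as Yeo-Johnson before normalisation), is sketched below on synthetic data. It is not the recipe that produced the output above. It also shows that `tidy()` on a trained step reports what that step did, for example which columns `step_corr()` removed.

```r
library(recipes)
library(dplyr)

# Synthetic data: x2 is almost a copy of x1, x3 is independent, zv is constant
set.seed(2024)
n <- 200
toy <- data.frame(x1 = rnorm(n), x3 = rexp(n), zv = 1)
toy$x2 <- toy$x1 + rnorm(n, sd = 0.1)
toy$Stroke <- factor(ifelse(toy$x1 + rnorm(n) > 0, "Y", "N"))

# Drop zero-variance columns first, transform, normalise, then filter correlations
rec <- recipe(Stroke ~ ., data = toy) %>%
  step_zv(all_predictors()) %>%
  step_YeoJohnson(all_predictors()) %>%
  step_normalize(all_predictors()) %>%
  step_corr(all_predictors(), threshold = 0.75)

prepped <- prep(rec)

# tidy() on a single trained step lists the columns it removed
removed_zv   <- tidy(prepped, number = 1)$terms
removed_corr <- tidy(prepped, number = 4)$terms
removed_zv
removed_corr
```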
Set up the decision tree model with the decision_tree() function from the parsnip package (part of tidymodels). We use the rpart engine for the decision tree and set the mode to classification. Finally, we fit the model with the fit() function and visualise the results with the rpart.plot package.
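The steps just described can be sketched end to end on synthetic data, because the stroke data are not loaded in this section. The predictors below reuse two column names from the recipe output, but their values, the outcome rule, and the `tree_depth`/`min_n` settings are all illustrative assumptions. `extract_fit_engine()` pulls the underlying rpart object out of the fitted workflow so that `rpart.plot()` can draw it; `roundint = FALSE` avoids rpart.plot trying to recover the original training data.

```r
library(tidymodels)
library(rpart.plot)

# Synthetic stand-in for the stroke data (values and outcome rule are made up)
set.seed(11)
n <- 300
toy <- tibble(
  MATXVol     = rnorm(n),
  CALCVolProp = rnorm(n),
  Noise       = rnorm(n)
)
toy$Stroke <- factor(ifelse(toy$MATXVol + 0.5 * toy$CALCVolProp +
                              rnorm(n, sd = 0.5) > 0, "Y", "N"))

# Model specification: rpart engine, classification mode
tree_spec <- decision_tree(tree_depth = 4, min_n = 10) %>%
  set_engine("rpart") %>%
  set_mode("classification")

# A short preprocessing recipe
tree_rec <- recipe(Stroke ~ ., data = toy) %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_predictors())

# Bundle the recipe and model in a workflow, then fit
tree_fit <- workflow() %>%
  add_recipe(tree_rec) %>%
  add_model(tree_spec) %>%
  fit(data = toy)

# Extract the underlying rpart object and draw the tree
rpart_obj <- extract_fit_engine(tree_fit)
rpart.plot(rpart_obj, roundint = FALSE)

# Class predictions for the training data
preds <- predict(tree_fit, new_data = toy)
```

Because the predictors are normalised in the recipe, the split thresholds in the plotted tree are on the standardised scale rather than in the original units.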