11  Interpreting Model Results Through Visualisation

Learning Objectives

  • Overview of the basic plots for visualising predictions and outcomes
  • How to customise a plot using colours, palettes, legends, and guides
  • How to tell a story with data through plot layouts, and how to save plots as images


This chapter is an essential part of the data visualisation section. It focuses on how to interpret and use the results of machine learning models in the context of health metrics and infectious diseases, with the aim of providing actionable insights, supporting decision-making, and communicating findings to different stakeholders.

  1. Visualising Predictions and Outcomes:
    • Techniques for visualising predicted vs. actual values.
    • Using scatter plots, line plots, and bar charts to compare predictions and observed data.
    • Highlighting discrepancies and trends through residual plots and error distribution charts.
  2. Case Studies and Applications:
    • Real-world examples of applying model results in public health.
    • Case studies demonstrating the impact of machine learning and spatial modelling on health metrics.
    • Lessons learned and best practices from successful implementations.

11.1 Practical Insights and Examples

11.1.1 Example: Visualising COVID-19 Data

The data come from the ECDC (European Centre for Disease Prevention and Control) and are updated daily, so values downloaded today may differ from the output shown here. For each country, the data include the numbers of confirmed cases, deaths, and recoveries. The dataset can be accessed through the covid19.analytics package.

# install.packages("covid19.analytics")  # run once if the package is not installed
library(tidyverse)
library(covid19.analytics)

covid19.ALL.agg.cases <- covid19.data("aggregated")
data <- covid19.ALL.agg.cases %>% 
  select(Country_Region, Last_Update, 
         Confirmed, Deaths, Recovered) %>%
  mutate(Last_Update = as.Date(Last_Update)) 

data %>% 
  head
#>        Country_Region Last_Update Confirmed Deaths Recovered
#> 1         Afghanistan  2021-02-02     55059   2404     47723
#> 2             Albania  2021-02-02     78992   1393     47922
#> 3             Algeria  2021-02-02    107578   2894     73530
#> 4             Andorra  2021-02-02      9972    101      9206
#> 5              Angola  2021-02-02     19829    466     18180
#> 6 Antigua and Barbuda  2021-02-02       234      7       177
data %>%
  group_by(Last_Update) %>%
  summarise(across(Confirmed:Recovered, ~ mean(.x, na.rm = TRUE)))
#> # A tibble: 6 × 4
#>   Last_Update Confirmed Deaths Recovered
#>   <date>          <dbl>  <dbl>     <dbl>
#> 1 2020-08-04       76      1.5       0  
#> 2 2020-08-07        0      0         0  
#> 3 2020-12-21       29.9    0.3       1.3
#> 4 2021-01-08      196      3       180  
#> 5 2021-01-24        0      0         0  
#> 6 2021-02-02    26088.   565.    14460.
data %>%
  ggplot(aes(x = Confirmed, y = Deaths)) +
  geom_point() +
  geom_smooth() +
  labs(title = "COVID-19 Deaths by Country",
       x = "Confirmed", y = "Number of Deaths") 
Figure 11.1: COVID-19 Deaths by Country
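Figure 11.1 uses default colours and has no legend. As a sketch of the customisation options listed in the learning objectives, the block below rebuilds a small scatter plot from the six country rows printed above, maps colour to country with a ColorBrewer palette, and controls the legend with guides() and theme(). The log scales, the "Dark2" palette, and the legend layout are illustrative choices, not part of the original analysis.

```r
library(ggplot2)

# The six country rows printed above (Last_Update 2021-02-02)
six_countries <- data.frame(
  Country_Region = c("Afghanistan", "Albania", "Algeria",
                     "Andorra", "Angola", "Antigua and Barbuda"),
  Confirmed = c(55059, 78992, 107578, 9972, 19829, 234),
  Deaths    = c(2404, 1393, 2894, 101, 466, 7)
)

p_custom <- ggplot(six_countries,
                   aes(x = Confirmed, y = Deaths, colour = Country_Region)) +
  geom_point(size = 3) +
  # Log scales spread out countries whose counts differ by orders of magnitude
  scale_x_log10() +
  scale_y_log10() +
  # Qualitative palette for the discrete country variable; 'name' sets the legend title
  scale_colour_brewer(palette = "Dark2", name = "Country") +
  # Arrange the legend keys in two columns ...
  guides(colour = guide_legend(ncol = 2)) +
  theme_minimal() +
  # ... and move the legend below the plot panel
  theme(legend.position = "bottom") +
  labs(title = "COVID-19 Deaths by Country (subset)",
       x = "Confirmed cases (log scale)", y = "Deaths (log scale)")

print(p_custom)
```

Because theme_minimal() is a complete theme, the theme() call that moves the legend must come after it.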

The scatter plot suggests a roughly linear relationship between the number of confirmed cases and the number of deaths. A linear regression can therefore be used to predict deaths; here we use both the number of confirmed cases and the number of recovered cases as predictors.
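In equation form, the model fitted below is the following, where i indexes the rows of data, the β coefficients are estimated by ordinary least squares, and ε_i is the error term. The residual of a row is its observed minus its predicted number of deaths.

```latex
\text{Deaths}_i = \beta_0 + \beta_1\,\text{Confirmed}_i + \beta_2\,\text{Recovered}_i + \varepsilon_i,
\qquad
e_i = \text{Deaths}_i - \widehat{\text{Deaths}}_i
```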

mod <- lm(Deaths ~ Confirmed + Recovered, data = data) 

summary(mod)$coefficients
#>                  Estimate   Std. Error    t value     Pr(>|t|)
#> (Intercept) -34.040275885 1.989420e+01  -1.711066 8.714698e-02
#> Confirmed     0.025050716 2.300418e-04 108.896384 0.000000e+00
#> Recovered    -0.003785725 3.075941e-04 -12.307533 3.432948e-34
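# Predicted deaths for every row of the data
Pred <- predict(mod, newdata = data)

# Combine observed and predicted values; residual = observed - predicted
data_pred <- cbind(data %>% select(Confirmed, Deaths), Pred) %>%
  mutate(Residuals = Deaths - Pred)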

data_pred %>%
  head()
#>   Confirmed Deaths       Pred  Residuals
#> 1     55059   2404 1164.56098 1239.43902
#> 2     78992   1393 1763.34641 -370.34641
#> 3    107578   2894 2382.50135  511.49865
#> 4      9972    101  180.91409  -79.91409
#> 5     19829    466  393.86590   72.13410
#> 6       234      7  -28.84848   35.84848
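As a check on the table above, the first row (Afghanistan: 55,059 confirmed, 47,723 recovered, 2,404 deaths) can be reproduced by hand from the printed coefficients, rounded here for readability:

```latex
\widehat{\text{Deaths}}_1 = -34.0403 + 0.0250507 \times 55059 - 0.0037857 \times 47723
\approx -34.04 + 1379.27 - 180.67 \approx 1164.56
\qquad
e_1 = 2404 - 1164.56 = 1239.44
```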

Let’s see the results in a plot:

data_pred %>%
  ggplot(aes(Deaths, Pred)) +
  geom_point() +
  geom_abline(intercept = 0, slope = 1, color = "red") +
  # Vertical segment from each point to the y = x line; its length is |residual|
  geom_segment(aes(xend = Deaths, yend = Deaths), alpha = 0.5) +
  labs(title = "Predicted vs. Actual Deaths",
       x = "Actual Deaths", y = "Predicted Deaths") 

data_pred %>%
  ggplot(aes(Confirmed, Residuals)) +
  geom_point() +
  geom_hline(yintercept = 0, color = "red") +
  labs(title = "Residuals vs. Confirmed Cases",
       x = "Confirmed Cases", y = "Residuals") 
Figure 11.2: COVID-19 data: (a) Predicted vs. Actual Deaths; (b) Residuals vs. Confirmed Cases
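The chapter introduction also mentions error distribution charts. A histogram of the residuals shows whether the errors are centred on zero and how heavy their tails are. The sketch below uses simulated counts (the simulation settings are arbitrary assumptions, not fitted to the COVID-19 data) so that it runs without downloading anything; with the real model, you would plot data_pred$Residuals instead.

```r
library(ggplot2)

set.seed(42)
# Simulated stand-in data: deaths roughly proportional to confirmed cases
sim <- data.frame(Confirmed = round(runif(200, min = 0, max = 1e5)))
sim$Deaths <- rpois(nrow(sim), lambda = 0.02 * sim$Confirmed)

# Linear model with a single predictor, as a simplified version of the one above
sim_mod <- lm(Deaths ~ Confirmed, data = sim)
sim$Residuals <- residuals(sim_mod)

p_resid_hist <- ggplot(sim, aes(x = Residuals)) +
  geom_histogram(bins = 30, fill = "grey70", colour = "white") +
  # Reference line at zero error
  geom_vline(xintercept = 0, colour = "red") +
  labs(title = "Distribution of Residuals (simulated data)",
       x = "Residual (observed - predicted deaths)", y = "Count")

print(p_resid_hist)
```

With an intercept in the model, ordinary least squares residuals always average to zero, so the interesting features are the spread and any skew or outliers.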
data_pred %>%
  ggplot() +
  geom_point(aes(Confirmed, Deaths)) +
  geom_line(aes(Confirmed, Pred), linetype = "dashed") +
  labs(title = "Observed vs. Predicted",
       x = "Confirmed Cases", y = "Deaths",
       caption = "Source: ECDC") 
  
data_pred %>%
  ggplot(aes(x = Confirmed, y = Deaths)) +
  geom_point() +
  geom_line(aes(Confirmed, Pred), linetype = "dashed") +
  geom_segment(aes(xend = Confirmed, yend = Pred)) +
  labs(title = "Observed vs. Predicted",
       subtitle = "Segments represent the residuals",
       x = "Confirmed Cases", y = "Deaths",
       caption = "Source: ECDC")   
Figure 11.3: COVID-19 Model Results: (a) Observed vs. Predicted; (b) Segments represent the residuals
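Figures 11.2 and 11.3 place two panels side by side. One way to build such a layout within a single ggplot is to facet long-format data, and the result can then be written to disk with ggsave(). The sketch below re-enters the six rows of data_pred printed earlier so that it runs without the download; the panel labels, file name, and dimensions are arbitrary choices.

```r
library(ggplot2)
library(tidyr)

# First six rows of data_pred as printed above
six_pred <- data.frame(
  Confirmed = c(55059, 78992, 107578, 9972, 19829, 234),
  Deaths    = c(2404, 1393, 2894, 101, 466, 7),
  Pred      = c(1164.56098, 1763.34641, 2382.50135,
                180.91409, 393.86590, -28.84848)
)
six_pred$Residuals <- six_pred$Deaths - six_pred$Pred

# Long format: one facet per quantity, sharing the Confirmed axis
six_long <- pivot_longer(six_pred, cols = c(Pred, Residuals),
                         names_to = "Panel", values_to = "Value")

p_layout <- ggplot(six_long, aes(Confirmed, Value)) +
  geom_point() +
  # Zero reference line, most useful in the residuals panel
  geom_hline(yintercept = 0, linetype = "dotted") +
  facet_wrap(~ Panel, ncol = 2, scales = "free_y",
             labeller = as_labeller(c(Pred = "Predicted deaths",
                                      Residuals = "Residuals"))) +
  labs(x = "Confirmed Cases", y = NULL)

# Save to a temporary PDF; ggsave() picks the device from the file extension
out_file <- file.path(tempdir(), "covid19_model_results.pdf")
ggsave(out_file, plot = p_layout, width = 8, height = 4)
```

Faceting works when the panels share a plot type; for panels of different types, a layout package such as patchwork is the usual alternative.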
cfr <- covid19.ALL.agg.cases %>% 
  select(Country_Region, Last_Update, Lat, Long_, Case_Fatality_Ratio)

cfr %>%
  head()
#>        Country_Region         Last_Update       Lat     Long_
#> 1         Afghanistan 2021-02-02 05:22:49  33.93911  67.70995
#> 2             Albania 2021-02-02 05:22:49  41.15330  20.16830
#> 3             Algeria 2021-02-02 05:22:49  28.03390   1.65960
#> 4             Andorra 2021-02-02 05:22:49  42.50630   1.52180
#> 5              Angola 2021-02-02 05:22:49 -11.20270  17.87390
#> 6 Antigua and Barbuda 2021-02-02 05:22:49  17.06080 -61.79640
#>   Case_Fatality_Ratio
#> 1            4.366225
#> 2            1.763470
#> 3            2.690141
#> 4            1.012836
#> 5            2.350093
#> 6            2.991453
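The Lat and Long_ columns make it possible to show the case fatality ratio spatially. Comparing with the first table, Case_Fatality_Ratio is a percentage, 100 × Deaths / Confirmed (for Afghanistan, 100 × 2404 / 55059 ≈ 4.37). The sketch below, a minimal example using only the six rows printed above, draws each location as a point whose area scales with its CFR; a full map would also draw country borders, which this sketch leaves out.

```r
library(ggplot2)

# The six rows of cfr printed above
cfr_six <- data.frame(
  Country_Region = c("Afghanistan", "Albania", "Algeria",
                     "Andorra", "Angola", "Antigua and Barbuda"),
  Lat   = c(33.93911, 41.15330, 28.03390, 42.50630, -11.20270, 17.06080),
  Long_ = c(67.70995, 20.16830, 1.65960, 1.52180, 17.87390, -61.79640),
  Case_Fatality_Ratio = c(4.366225, 1.763470, 2.690141,
                          1.012836, 2.350093, 2.991453)
)

p_cfr <- ggplot(cfr_six, aes(x = Long_, y = Lat)) +
  # Point area is proportional to the case fatality ratio (in %)
  geom_point(aes(size = Case_Fatality_Ratio), colour = "firebrick", alpha = 0.7) +
  geom_text(aes(label = Country_Region), vjust = -1.2, size = 3) +
  scale_size_area(name = "CFR (%)", max_size = 10) +
  # Approximate map projection so that distances look plausible
  coord_quickmap() +
  labs(title = "Case Fatality Ratio by Location (six countries)",
       x = "Longitude", y = "Latitude")

print(p_cfr)
```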
covid19.TS.deaths <- covid19.data("ts-deaths")
covid19.TS.deaths %>% 
    pivot_longer(cols = c(-Province.State, 
                          -Country.Region, 
                          -Lat, -Long),
               names_to = "Date", values_to = "Deaths") %>%
  head()
#> # A tibble: 6 × 6
#>   Province.State Country.Region   Lat  Long Date       Deaths
#>   <chr>          <chr>          <dbl> <dbl> <chr>       <int>
#> 1 ""             Afghanistan     33.9  67.7 2020-01-22      0
#> 2 ""             Afghanistan     33.9  67.7 2020-01-23      0
#> 3 ""             Afghanistan     33.9  67.7 2020-01-24      0
#> 4 ""             Afghanistan     33.9  67.7 2020-01-25      0
#> 5 ""             Afghanistan     33.9  67.7 2020-01-26      0
#> 6 ""             Afghanistan     33.9  67.7 2020-01-27      0
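After pivoting, Date is still a character column (note <chr> above), and in this time-series format the death counts are cumulative. The minimal sketch below uses a toy wide table in the same layout, with hypothetical countries and made-up counts (not real data). It converts the dates with as.Date() and differences each country's cumulative series into daily new deaths with dplyr::lag() before plotting.

```r
library(dplyr)
library(tidyr)
library(ggplot2)

# Toy wide table in the same layout as covid19.data("ts-deaths");
# countries and counts are hypothetical
ts_toy <- tibble(
  Province.State = c("", ""),
  Country.Region = c("Country A", "Country B"),
  Lat  = c(10, -20),
  Long = c(30, 40),
  `2020-03-01` = c(0, 1),
  `2020-03-02` = c(2, 1),
  `2020-03-03` = c(5, 4),
  `2020-03-04` = c(9, 6)
)

daily <- ts_toy %>%
  pivot_longer(cols = -c(Province.State, Country.Region, Lat, Long),
               names_to = "Date", values_to = "Deaths") %>%
  mutate(Date = as.Date(Date)) %>%             # character -> Date
  group_by(Country.Region) %>%
  arrange(Date, .by_group = TRUE) %>%
  # Cumulative -> daily: today's total minus yesterday's (0 before the first day)
  mutate(New_Deaths = Deaths - lag(Deaths, default = 0)) %>%
  ungroup()

p_daily <- ggplot(daily, aes(Date, New_Deaths, colour = Country.Region)) +
  geom_line() +
  labs(title = "Daily New Deaths (toy data)", x = NULL,
       y = "New deaths", colour = "Country")

print(p_daily)
```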

11.1.2 Example: Ischemic Stroke Decision Tree

In this example, we look at how to visualise the results of a decision tree model that predicts ischemic stroke.

Load the necessary libraries:

# Ischemic Stroke decision tree
library(tidymodels)
library(rpart)
library(rpart.plot)

Load the ischemic stroke data from the GitHub repository of the book “Feature Engineering and Selection: A Practical Approach for Predictive Models” by Max Kuhn and Kjell Johnson.

load(url("https://github.com/topepo/FES/blob/master/Data_Sets/Ischemic_Stroke/stroke_data.RData?raw=true"))

The data are already split into training and test sets; here we use only the training set, stroke_train. We keep the outcome, Stroke, together with the predictors named in the character vector VC_preds. The selection helper any_of() picks the columns whose names appear in a character vector and silently skips any names that are not present in the data.

?any_of

selected_train <- 
  stroke_train %>%
  dplyr::select(any_of(VC_preds), Stroke)
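The difference between any_of() and its stricter sibling all_of() matters when a name list may not match the data exactly. A small sketch on a toy tibble (the column names are hypothetical):

```r
library(dplyr)

# Toy data frame; column names are hypothetical
toy <- tibble(a = 1:3, b = 4:6, Stroke = c("Y", "N", "Y"))

# "z" is not a column of toy: any_of() skips it silently,
# whereas all_of() would raise an error
wanted <- c("a", "z")
toy_selected <- toy %>% select(any_of(wanted), Stroke)

names(toy_selected)
```

The result keeps only a and Stroke. Using all_of() instead is safer when a missing predictor should stop the analysis rather than be dropped quietly.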

Next, set up the preprocessing recipe with recipe() from the recipes package (part of tidymodels). The steps are applied in order: step_corr() removes predictors that have large absolute correlations (above 0.75) with other predictors, step_center() and step_scale() standardise the predictors, step_YeoJohnson() applies a Yeo-Johnson transformation to make their distributions more symmetric, and step_zv() removes zero-variance predictors.

is_recipe <- recipe(Stroke ~ ., data = selected_train) %>%
  step_corr(all_predictors(), threshold = 0.75) %>%  # drop highly correlated predictors
  step_center(all_predictors()) %>%                  # subtract the mean
  step_scale(all_predictors()) %>%                   # divide by the standard deviation
  step_YeoJohnson(all_predictors()) %>%              # reduce skewness
  step_zv(all_predictors())                          # drop zero-variance predictors

is_recipe %>%
  prep() %>%
  bake(new_data = NULL) %>%
  select(1:5) %>%
  head(5)
#> # A tibble: 5 × 5
#>   CALCVolProp MATXVol MATXVolProp MaxCALCAreaProp MaxDilationByArea
#>         <dbl>   <dbl>       <dbl>           <dbl>             <dbl>
#> 1      -0.143  0.106       -0.161          0.0541            0.0249
#> 2      -1.35  -0.0530       0.858         -0.931            -0.748 
#> 3      -0.784  1.03         0.218          0.284            -0.320 
#> 4       1.06   0.587       -0.144          1.05              0.423 
#> 5      -0.708 -0.107       -0.307         -0.203            -0.738
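To see what the filtering and standardising steps do, the sketch below applies a similar recipe to a small made-up data set (all column names and values are hypothetical). It places step_zv() first; that ordering is a choice made for the sketch, because removing constant columns before step_corr() avoids correlations with a zero-variance column, which are undefined.

```r
library(dplyr)
library(recipes)

# Toy data: x2 is almost a copy of x1, x3 is constant, x4 is unrelated to x1
toy <- data.frame(
  x1 = 1:10,
  x4 = c(3, -1, 4, 1, -5, 9, 2, -6, 5, 3),
  x3 = rep(1, 10),
  y  = factor(rep(c("yes", "no"), 5))
)
toy$x2 <- 2 * toy$x1 + rep(c(0.1, -0.1), 5)

toy_rec <- recipe(y ~ ., data = toy) %>%
  step_zv(all_predictors()) %>%                      # drops x3 (zero variance)
  step_corr(all_predictors(), threshold = 0.75) %>%  # drops one of x1 / x2
  step_center(all_predictors()) %>%                  # mean 0
  step_scale(all_predictors())                       # standard deviation 1

toy_baked <- toy_rec %>% prep() %>% bake(new_data = NULL)
names(toy_baked)
```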

Set up the decision tree model with decision_tree() from the parsnip package (part of tidymodels), using the rpart engine in classification mode.

class_tree_spec <- decision_tree() %>%
  set_engine("rpart") %>%
  set_mode("classification")
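decision_tree() also accepts the main tree hyperparameters as arguments: tree_depth (maximum depth), min_n (minimum number of data points in a node for it to be split further) and cost_complexity (the pruning parameter). The sketch below sets them explicitly and fits the built-in iris data, which stands in for the stroke data so that no download is needed; the values are arbitrary, not tuned.

```r
library(tidymodels)

# Same engine and mode as above, with explicit (illustrative) hyperparameters
tree_spec <- decision_tree(tree_depth = 3, min_n = 10, cost_complexity = 0.01) %>%
  set_engine("rpart") %>%
  set_mode("classification")

# iris stands in for the stroke data so the example runs offline
iris_tree <- tree_spec %>% fit(Species ~ ., data = iris)

# Class predictions are returned in a tibble column named .pred_class
iris_preds <- predict(iris_tree, new_data = iris)
head(iris_preds)
```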

Finally, bundle the recipe and the model specification into a workflow, fit it with fit(), extract the underlying rpart object with extract_fit_engine(), and plot the tree with rpart.plot().

is_wfl <- workflow() %>%
  add_model(class_tree_spec) %>%
  add_recipe(is_recipe)

is_dt_fit_wfl <- is_wfl %>%
  fit(data = selected_train,
      control = control_workflow())

is_dt_fit_wfl %>%
  extract_fit_engine() %>%
  rpart.plot::rpart.plot(roundint = FALSE)
Figure 11.4: Decision Tree for Ischemic Stroke
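The tree diagram shows the individual splits but not how much each predictor contributes overall. The rpart object returned by extract_fit_engine() stores a variable.importance vector (the summed split improvements, including credit from surrogate splits), which can be drawn as a bar chart. The sketch below again uses iris as a stand-in; with the stroke model, you would call extract_fit_engine() on is_dt_fit_wfl instead.

```r
library(tidymodels)

# Fit a small tree on iris (a stand-in for the stroke data)
iris_tree <- decision_tree() %>%
  set_engine("rpart") %>%
  set_mode("classification") %>%
  fit(Species ~ ., data = iris)

# rpart stores importance scores in a named numeric vector
imp <- extract_fit_engine(iris_tree)$variable.importance
imp_df <- tibble(Variable = names(imp), Importance = unname(imp))

# Horizontal bars, ordered from least to most important
p_imp <- ggplot(imp_df, aes(x = Importance,
                            y = reorder(Variable, Importance))) +
  geom_col(fill = "steelblue") +
  labs(title = "Variable Importance (rpart, iris)",
       x = "Importance", y = NULL)

print(p_imp)
```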