KNN

App to demonstrate how the KNN classifier works
Classification
Author

Haley Grant

#| '!! shinylive warning !!': |
#|   shinylive does not work in self-contained HTML documents.
#|   Please set `embed-resources: false` in your metadata.
#| standalone: true
#| viewerHeight: 600
library(shiny)
library(plotly)
library(FNN)
library(dplyr)
library(tidyverse)
library(caret)
library(kableExtra)

# ---- Data setup ----
# Create a local folder for the downloaded training data.
# `showWarnings = FALSE` keeps re-runs quiet when the folder already exists.
dir.create("./Data", showWarnings = FALSE)

# Download the diabetes training data into the local folder.
url.csv <- "https://haleykgrant.github.io/tutorial_data/data/trainingData_diabetes.csv"
download.file(url.csv, "./Data/trainingData_diabetes.csv")

# Read the data, sort by class, then add a row-number `id` (used further down
# as the plotly point key) and store `class` as a factor.
df <- read_csv("./Data/trainingData_diabetes.csv") %>%
  arrange(class) %>%
  mutate(
    id = row_number(),
    class = factor(class)
  )

# User interface: a sidebar with the KNN controls (number of neighbours,
# optional scaling, and the two predictors to plot) next to a main panel
# holding the interactive scatter plot and the neighbour summary table.
ui <- local({
  # Numeric columns of `df` are the candidate predictors; `id` is only a
  # row identifier and is never offered as a predictor.
  numeric_cols <- names(df)[vapply(df, is.numeric, logical(1))]
  x_choices <- setdiff(numeric_cols, "id")
  # The first predictor defaults to "bmi", so it is left out of the initial
  # second-predictor menu.
  y_choices <- setdiff(x_choices, "bmi")

  controls <- sidebarPanel(
    sliderInput(
      inputId = "k",
      label = "Number of Neighbors (K):",
      min = 1,
      max = 11,
      value = 3,
      step = 2
    ),
    checkboxInput(
      inputId = "scale_data",
      label = "Scale Predictors",
      value = FALSE
    ),
    selectInput(
      inputId = "x_var",
      label = "First predictor:",
      choices = x_choices,
      selected = "bmi"
    ),
    selectInput(
      inputId = "y_var",
      label = "Second predictor:",
      choices = y_choices,
      selected = "tg"
    )
  )

  results <- mainPanel(
    plotlyOutput("plot"),
    br(),
    uiOutput("knn_table")
  )

  fluidPage(
    titlePanel("KNN Visualization"),
    sidebarLayout(controls, results)
  )
})

# Server logic for the KNN demo.
#
# Tracks which observation the user clicked, finds its K nearest neighbours
# on the two selected predictors (optionally scaled), highlights them in the
# scatter plot, and renders a table with the neighbours' class counts,
# proportions, and the majority-vote class.
#
# Arguments are the standard Shiny server arguments:
#   input   - reads `k`, `scale_data`, `x_var`, `y_var`
#   output  - writes `plot` (plotly widget) and `knn_table` (HTML table)
#   session - used to update the `y_var` select input
server <- function(input, output, session) {

  # `id` of the currently selected observation; NULL when nothing is selected.
  selected_index <- reactiveVal(NULL)

  # Keep the second-predictor menu free of the currently chosen first predictor.
  observeEvent(input$x_var, {
    numeric_cols <- names(df)[vapply(df, is.numeric, logical(1))]
    y_choices <- setdiff(numeric_cols, c(input$x_var, "id"))

    updateSelectInput(
      session, "y_var",
      choices = y_choices,
      # If the current selection is now invalid, pick the first remaining one.
      selected = if (input$y_var %in% y_choices) input$y_var else y_choices[1]
    )
  })

  # Click events from the scatter plot. `priority = "event"` makes every click
  # trigger the observer, including a repeated click on the same point; with
  # the default priority a repeated click sends an unchanged value, so the
  # "click again to deselect" branch below would never run.
  plot_click <- function() {
    event_data("plotly_click", source = "knnplot", priority = "event")
  }

  # Clicking a point selects it; clicking the selected point again clears it.
  observeEvent(plot_click(), {
    # Use a single key even if the event carries more than one point.
    clicked_id <- plot_click()$key[[1]]

    if (!is.null(clicked_id)) {
      # isTRUE() also covers the "nothing selected yet" case, where the
      # comparison has length zero.
      if (isTRUE(selected_index() == clicked_id)) {
        selected_index(NULL)
      } else {
        selected_index(clicked_id)
      }
    }
  })

  # The two selected predictor columns, standardised with scale() on request.
  predictors_reactive <- reactive({
    req(input$x_var, input$y_var)
    # Right after the first predictor changes, both inputs can briefly name
    # the same column until the y menu has been updated; wait for that.
    req(input$x_var != input$y_var)

    predictors <- df[, c(input$x_var, input$y_var)]

    if (input$scale_data) {
      predictors <- scale(predictors)
    }

    predictors
  })

  # Nearest-neighbour search, shared by the plot and the table so it is only
  # computed once per change of K or of the predictors.
  knn_reactive <- reactive({
    req(input$k)
    get.knn(predictors_reactive(), k = input$k)
  })

  output$plot <- renderPlotly({
    req(input$k)

    x_name <- input$x_var
    y_name <- input$y_var
    knn_result <- knn_reactive()
    sel <- selected_index()

    # Plot exactly the coordinates the neighbour search used.
    # as.data.frame() handles both the unscaled tibble and the scaled matrix.
    coords <- as.data.frame(predictors_reactive())
    plot_data <- df
    plot_data$x_var <- coords[[1]]
    plot_data$y_var <- coords[[2]]

    if (!is.null(sel)) {
      # Emphasise the selected point (drawn as an "x") and its neighbours,
      # and fade every other point.
      row_idx <- which(plot_data$id == sel)
      highlight_rows <- c(row_idx, knn_result$nn.index[row_idx, ])
      is_highlighted <- seq_len(nrow(plot_data)) %in% highlight_rows

      plot_data <- plot_data %>%
        mutate(
          opacity = ifelse(is_highlighted, 1, 0.2),
          size = ifelse(is_highlighted, 12, 6),
          symbol = ifelse(row_number() %in% row_idx, "x", "circle")
        )
    } else {
      plot_data <- plot_data %>%
        mutate(opacity = 1, size = 6, symbol = "circle")
    }

    plty <- plot_ly(
      data = plot_data,
      x = ~x_var,
      y = ~y_var,
      color = ~class,
      # The key is what the click handler receives back as `key`.
      key = ~id,
      text = ~paste0("ID: ", id,
                     "<br>", x_name, ": ", round(x_var, 2),
                     "<br>", y_name, ": ", round(y_var, 2),
                     "<br>Class: ", class),
      hoverinfo = "text",
      type = "scatter",
      mode = "markers",
      marker = list(
        size = ~size,
        opacity = ~opacity,
        symbol = ~symbol
      ),
      source = "knnplot",
      colors = c("Non.Diabetic" = "#F8766D",
                 "Pre.Diabetic" = "#00BA38",
                 "Diabetic" = "#619CFF")
    )

    # Axis titles flag scaled data. The x axis is anchored to the y axis with
    # ratio 1, so one unit is the same length on both axes.
    axis_suffix <- if (input$scale_data) " (Scaled)" else ""

    # Return the finished widget explicitly.
    plty %>%
      layout(
        xaxis = list(title = paste0(x_name, axis_suffix),
                     scaleanchor = "y",
                     scaleratio = 1),
        yaxis = list(title = paste0(y_name, axis_suffix))
      )
  })

  output$knn_table <- renderUI({
    req(selected_index(), input$k)

    knn_result <- knn_reactive()

    row_idx <- which(df$id == selected_index())
    neighbors <- knn_result$nn.index[row_idx, ]
    neighbor_classes <- df$class[neighbors]

    # Per-class neighbour counts and proportions. `assigned` is the class with
    # the most neighbours; which.max() returns the first maximum, so a tie
    # goes to the earlier factor level.
    tbl <- df %>%
      slice(neighbors) %>%
      count(class, name = "n") %>%
      mutate(p = n / sum(n),
             assigned = names(which.max(table(neighbor_classes))))

    HTML(tbl %>%
           kable(digits = 3, format = "html") %>%
           kable_styling("striped", full_width = FALSE) %>%
           collapse_rows(columns = 4, valign = "middle"))
  })
}

# Launch the app by combining the UI definition with the server logic.
shinyApp(ui, server)