# Column names used unquoted (non-standard evaluation) inside ggplot2::aes()
# and tidyr::pivot_longer() below. Declaring them at top level of the package
# source silences R CMD check's "no visible binding for global variable" notes.
utils::globalVariables(c(
  "AgeGroup",
  "Category",
  "Total",
  "Value",
  "total"
))


#' Plot Attribution or Decomposition Results Using Sullivan Method
#'
#' This function generates visualization plots for results from either
#' longitudinal attribution models (e.g., `Attribution_sullivan`) or
#' decomposition models (e.g., `Decomp_sullivan`) applied to cohort health
#' expectancy.
#'
#' @name Plot.sullivan
#' @aliases Plot.sullivan
#' @param result A list object returned by either an attribution function (must contain
#'   `Absolute_Contributions_1` and `Absolute_Contributions_2`) or a decomposition function
#'   (must contain `total_effect`, `mortality_effect`, and `disability_effect`).
#'   Each component is a matrix with one row per category and one column per
#'   age group. Attribution matrices need a `Total` row (an `Absolute_` prefix
#'   is stripped from row names); decomposition matrices need a `total` row.
#' @param var_list A character vector naming the variables (e.g., diseases or
#'   risk factors) reported in `result`. Together with `"Background"` it defines
#'   the default color palette, so it is required when `colors` is `NULL`. It
#'   does not filter the data: every category present in `result` is plotted.
#' @param colors Optional. A named character vector of colors for each category
#'   (including `"Background"`). If `NULL`, a deterministic qualitative palette
#'   is generated for `c("Background", var_list)`.
#'
#' @return A `patchwork` object combining multiple `ggplot2` barplots. The layout depends on the result type:
#' \itemize{
#'   \item For attribution results, four plots are returned: absolute and relative contributions
#'     to disability and death.
#'   \item For decomposition results, three plots are returned: total, mortality, and disability effects.
#' }
#'
#' @details
#' This function supports two types of Sullivan-based outputs:
#' \enumerate{
#'   \item \strong{Attribution results:} The input should contain components named
#'     `Absolute_Contributions_1` (for disability) and `Absolute_Contributions_2` (for mortality),
#'     as returned by the `Attribution_sullivan()` function.
#'   \item \strong{Decomposition results:} The input should contain `total_effect`, `mortality_effect`,
#'     and `disability_effect` matrices, typically produced by the `Decomp_sullivan()` function.
#' }
#' For attribution results, the function internally computes relative contributions as a share
#' of total years lost or gained; ages whose total is zero are left blank in the
#' relative panels. Y-axis limits are chosen so that stacked bars (positive parts
#' stacked upwards, negative parts downwards) and totals are never clipped. Age
#' groups are labelled 1, 2, ... on the x-axis, assuming one matrix column per
#' single-year age group.
#'
#' @importFrom ggplot2 ggplot
#' @importFrom ggplot2 aes
#' @importFrom ggplot2 geom_bar
#' @importFrom ggplot2 geom_line
#' @importFrom ggplot2 geom_point
#' @importFrom ggplot2 scale_fill_manual
#' @importFrom ggplot2 scale_y_continuous
#' @importFrom ggplot2 theme_minimal
#' @importFrom ggplot2 theme
#' @importFrom ggplot2 element_text
#' @importFrom ggplot2 element_blank
#' @importFrom ggplot2 element_rect
#' @importFrom ggplot2 labs
#' @importFrom ggplot2 xlab
#' @importFrom ggplot2 ylab
#' @importFrom ggplot2 ylim
#'
#' @importFrom tidyr pivot_longer
#' @importFrom patchwork wrap_plots
#' @importFrom patchwork plot_layout
#' @importFrom grDevices rgb
#'
#' @examples
#' # For attribution result
#' data(attributionA)
#' Plot.sullivan(result=attributionA, var_list = c("Z1", "Z2", "Z3"))
#'
#' # For decomposition result
#' data(decom_results)
#' Plot.sullivan(result=decom_results, var_list = c("Z1", "Z2", "Z3"))
#'
#' @export
Plot.sullivan <- function(result, var_list = NULL, colors = NULL) {
  # Detect the result type from the names of its components.
  is_attribution <- all(c("Absolute_Contributions_1", "Absolute_Contributions_2") %in% names(result))
  is_decomposition <- all(c("total_effect", "mortality_effect", "disability_effect") %in% names(result))

  if (!is_attribution && !is_decomposition) {
    stop("The input object is neither a valid attribution nor decomposition result.",
         call. = FALSE)
  }

  # var_list only feeds the default palette, so it may be omitted when the
  # caller supplies `colors` directly.
  if (is.null(var_list) && is.null(colors)) {
    stop("Please provide var_list for consistent color assignment.", call. = FALSE)
  }

  if (is.null(colors)) {
    # Deterministic palette: the same var_list always yields the same colors
    # (random colors would change between calls and between panels' reruns).
    all_vars <- unique(c("Background", var_list))
    colors <- grDevices::hcl.colors(length(all_vars), palette = "Dark 3")
    names(colors) <- all_vars
  }

  # Transpose a category-by-age matrix into an age-by-category data frame and
  # append a sequential AgeGroup factor (assumes age_width = 1).
  group_by_age <- function(mat) {
    df <- as.data.frame(t(mat))
    df$AgeGroup <- factor(seq_len(nrow(df)))
    df
  }

  # Smallest range containing 0 (the bar baseline), every total and every
  # stacked bar. ggplot2 stacks positive values upwards and negative values
  # downwards, so the bar extremes are the per-age sums of the positive and of
  # the negative parts. Limits that exclude a bar end make ggplot2 drop that
  # bar silently (na.rm = TRUE suppresses the warning).
  stack_range <- function(df_long, total_col) {
    values <- as.numeric(df_long$Value)
    up <- tapply(pmax(values, 0), df_long$AgeGroup, sum, na.rm = TRUE)
    down <- tapply(pmin(values, 0), df_long$AgeGroup, sum, na.rm = TRUE)
    range(0, up, down, as.numeric(df_long[[total_col]]), na.rm = TRUE)
  }

  # Styling shared by every panel; `colors` is resolved above.
  panel_style <- function() {
    list(
      scale_fill_manual(values = colors),
      theme_minimal(base_size = 14),
      theme(legend.position = "bottom",
            panel.border = element_rect(color = "black", fill = NA),
            plot.title = element_text(hjust = 0.5))
    )
  }

  if (is_attribution) {
    # Absolute and relative long-format tables for one attribution matrix.
    prep_attribution <- function(mat) {
      df_abs <- group_by_age(mat)
      colnames(df_abs) <- gsub("^Absolute_", "", colnames(df_abs))
      # Select category columns by name, not position, so the result does not
      # depend on where the Total row sits in `mat`.
      value_cols <- setdiff(colnames(df_abs), c("AgeGroup", "Total"))
      denom <- as.numeric(df_abs$Total)
      # Shares of a zero total are undefined; use NA (not drawn) instead of
      # Inf/NaN.
      denom[which(denom == 0)] <- NA_real_
      df_rel <- df_abs
      df_rel[value_cols] <- lapply(df_rel[value_cols],
                                   function(col) as.numeric(col) / denom)
      list(
        abs_long = tidyr::pivot_longer(df_abs, cols = -c(AgeGroup, Total),
                                       names_to = "Category", values_to = "Value"),
        rel_long = tidyr::pivot_longer(df_rel, cols = -c(AgeGroup, Total),
                                       names_to = "Category", values_to = "Value")
      )
    }

    disability <- prep_attribution(result$Absolute_Contributions_1)
    death <- prep_attribution(result$Absolute_Contributions_2)

    # Both absolute panels share one y-range (10% headroom) so they are
    # directly comparable.
    abs_limits <- 1.1 * range(stack_range(disability$abs_long, "Total"),
                              stack_range(death$abs_long, "Total"))

    attribution_panel <- function(df_long, title_text, y_label, limits = NULL) {
      p <- ggplot(df_long, aes(x = AgeGroup, y = Value, fill = Category)) +
        geom_bar(stat = "identity", na.rm = TRUE, position = "stack")
      if (!is.null(limits)) {
        p <- p + ylim(limits[1], limits[2])
      }
      p + panel_style() + labs(title = title_text, x = "Age", y = y_label)
    }

    p1 <- attribution_panel(disability$abs_long, "(a) Absolute Attribution (Disability)",
                            "Years", abs_limits)
    p2 <- attribution_panel(disability$rel_long, "(b) Relative Attribution (Disability)",
                            "Proportion")
    p3 <- attribution_panel(death$abs_long, "(c) Absolute Attribution (Death)",
                            "Years", abs_limits)
    p4 <- attribution_panel(death$rel_long, "(d) Relative Attribution (Death)",
                            "Proportion")

    return(patchwork::wrap_plots(p1, p2, p3, p4, ncol = 2, guides = "collect") &
             theme(legend.position = "bottom"))
  }

  # Reaching this point implies is_decomposition is TRUE (checked above).
  process_effect <- function(mat) {
    df <- group_by_age(mat)
    df_long <- tidyr::pivot_longer(df, cols = -c(AgeGroup, total),
                                   names_to = "Category", values_to = "Value")
    # The residual category apparently arrives as "backgroud" (sic); map it,
    # and the correct lowercase spelling, onto the "Background" palette key.
    df_long$Category[df_long$Category %in% c("backgroud", "background")] <- "Background"
    list(df = df, df_long = df_long)
  }

  res_total <- process_effect(result$total_effect)
  res_mort <- process_effect(result$mortality_effect)
  res_disab <- process_effect(result$disability_effect)

  # One y-range shared by the three panels, covering 0, every stacked bar and
  # every total, padded by 0.05 years.
  y_range <- range(stack_range(res_total$df_long, "total"),
                   stack_range(res_mort$df_long, "total"),
                   stack_range(res_disab$df_long, "total"))
  y_limits <- c(y_range[1] - 0.05, y_range[2] + 0.05)

  # Stacked category bars with the age-specific total overlaid as a line; the
  # title reports the effect summed over ages.
  make_plot <- function(res, label) {
    effect <- round(sum(as.numeric(res$df$total), na.rm = TRUE), 2)
    ggplot(res$df_long, aes(x = AgeGroup, y = Value, fill = Category)) +
      geom_bar(stat = "identity", na.rm = TRUE) +
      geom_line(data = res$df, aes(x = AgeGroup, y = total, group = 1),
                color = "black", inherit.aes = FALSE, na.rm = TRUE) +
      geom_point(data = res$df, aes(x = AgeGroup, y = total),
                 color = "black", size = 2, inherit.aes = FALSE, na.rm = TRUE) +
      scale_y_continuous(limits = y_limits) +
      panel_style() +
      labs(title = paste0(label, " = ", effect, " years"), x = "Age", y = "Years")
  }

  p1 <- make_plot(res_total, "(a) Total Effect")
  p2 <- make_plot(res_mort, "(b) Mortality Effect")
  p3 <- make_plot(res_disab, "(c) Disability Effect")

  patchwork::wrap_plots(p1, p2, p3, nrow = 1, guides = "collect") &
    theme(legend.position = "bottom")
}

