# Adaptive Hybrid Relaxed Lasso Regression (AHRLR) Simulation Study
# Purpose: This script evaluates AHRLR and other penalized regression methods via simulations
# ---------------------------------------------
# Load Required Libraries
# ---------------------------------------------
library(glmnet)   # For penalized regression models
library(MASS)     # For multivariate normal generation
library(dplyr)    # For data manipulation
# ---------------------------------------------
# Set random seed for reproducibility
# ---------------------------------------------
set.seed(999)
# ---------------------------------------------
# Simulate high-dimensional regression data
# ---------------------------------------------
simulate_data <- function(n, p, beta, sigma, cor, dist_type, grouped = FALSE) {
  # Create correlation structure
  if (!grouped) {
    Sigma <- cor ^ abs(outer(1:p, 1:p, "-"))  # AR(1)-type structure
  } else {
    Sigma <- diag(p)  # Start from identity, then fill in block correlation for grouped predictors
    group_size <- 10
    for (g in seq(1, p, by = group_size)) {
      idx <- g:min(g + group_size - 1, p)
      Sigma[idx, idx] <- cor ^ abs(outer(idx, idx, "-"))
    }
  }
  # Simulate design matrix X
  X <- mvrnorm(n, mu = rep(0, p), Sigma = Sigma)
  # Simulate error term; each distribution is standardized to mean 0 and variance 1
  epsilon <- switch(dist_type,
                    "normal" = rnorm(n),
                    "t3"     = rt(n, df = 3) / sqrt(3),            # t(3) scaled to unit variance
                    "exp"    = rexp(n) - 1,                        # centered exponential
                    "unif"   = runif(n, -sqrt(12)/2, sqrt(12)/2))  # uniform with unit variance
  
  # Generate response
  y <- X %*% beta + sigma * epsilon
  
  list(X = X, y = y)
}
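# Example usage of simulate_data (illustrative sketch only, not run, so the random
# seed of the main simulation is undisturbed). Object names here are hypothetical.
# toy <- simulate_data(n = 50, p = 20,
#                      beta = c(3, 2, 1, 0.5, rep(0, 16)),
#                      sigma = 1, cor = 0.5, dist_type = "normal")
# dim(toy$X)      # 50 x 20 design matrix
# length(toy$y)   # 50 responses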
# ---------------------------------------------
# Augment data for Relaxed/HRLR/AHRLR
# ---------------------------------------------
augment_data <- function(X, y, lambda2) {
  p <- ncol(X)
  # Elastic-net-style augmentation: append sqrt(lambda2) * I rows to X and p zeros to y,
  # rescaled so a lasso fit on the augmented data incorporates the ridge (L2) penalty
  X_aug <- rbind(X, sqrt(lambda2) * diag(p)) / sqrt(1 + lambda2)
  y_aug <- c(y, rep(0, p))
  list(X_aug = X_aug, y_aug = y_aug)
}
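# Example of the augmentation dimensions (illustrative sketch only, not run).
# With lambda2 = 1 the augmented design gains p extra rows, one per predictor,
# and the response is padded with p zeros. Object names are hypothetical.
# aug_toy <- augment_data(matrix(rnorm(50 * 20), 50, 20), rnorm(50), lambda2 = 1)
# dim(aug_toy$X_aug)     # 70 x 20
# length(aug_toy$y_aug)  # 70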
# ---------------------------------------------
# Compute adaptive weights using Ridge
# ---------------------------------------------
get_adaptive_weights <- function(X, y, gamma = 1, eps = 1e-4) {
  # Lightly penalized ridge fit supplies initial coefficient estimates for the weights
  fit <- glmnet(X, y, alpha = 0, lambda = 1e-4, intercept = FALSE, standardize = FALSE)
  beta_ridge <- as.vector(coef(fit))[-1]      # drop the intercept
  beta_ridge[abs(beta_ridge) < eps] <- eps    # floor tiny estimates to avoid infinite weights
  1 / (abs(beta_ridge)^gamma)
}
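# Example of the weighting behaviour (illustrative sketch only, not run): larger
# ridge coefficients receive smaller penalty weights, so strong predictors are
# penalized less in the adaptive lasso / AHRLR fits. Object names are hypothetical.
# X_toy <- matrix(rnorm(50 * 20), 50, 20)
# y_toy <- X_toy[, 1] * 3 + rnorm(50)
# w_toy <- get_adaptive_weights(X_toy, y_toy)
# w_toy[1]   # typically much smaller than the weights of the noise predictors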
# ---------------------------------------------
# Run one replication for a given setting
# ---------------------------------------------
run_one <- function(n, p, beta, sigma, cor, dist_type, grouped, strong_signals = which(beta != 0)) {
  s <- length(strong_signals)
  data <- simulate_data(n, p, beta, sigma, cor, dist_type, grouped)
  X <- scale(data$X)
  y <- scale(data$y)
  # Train/test split
  train_idx <- 1:(n/2)
  test_idx <- (n/2 + 1):n
  X_train <- X[train_idx, ]; y_train <- y[train_idx]
  X_test  <- X[test_idx, ];  y_test  <- y[test_idx]
  
  lambda2 <- 1  # Relaxation parameter (ridge/L2 penalty used in the data augmentation)
  # Internal evaluation function
  method_eval <- function(model_fit, newx, true_beta) {
    pred <- predict(model_fit, newx = newx, s = "lambda.min")
    rmse <- sqrt(mean((y_test - pred)^2))
    coef_vec <- as.vector(coef(model_fit, s = "lambda.min"))[-1]
    selected <- which(coef_vec != 0)
    tp <- sum(selected %in% strong_signals)
    fp <- sum(!(selected %in% strong_signals))
    fn <- s - tp
    tn <- (p - s) - fp
    recall <- tp / s
    conf_acc <- (tp + tn) / p
    list(rmse = rmse, tp = tp, fp = fp, recall = recall, conf_acc = conf_acc)
  }
  # Fit models
  w   <- get_adaptive_weights(X_train, y_train)    # adaptive weights for the adaptive lasso
  aug <- augment_data(X_train, y_train, lambda2)   # augmented data for HRLR/AHRLR
  w_aug <- w                                       # reuse the same ridge-based weights for AHRLR
  fit_lasso     <- cv.glmnet(X_train, y_train, alpha = 1)
  fit_ridge     <- cv.glmnet(X_train, y_train, alpha = 0)
  fit_enet      <- cv.glmnet(X_train, y_train, alpha = 0.5)
  fit_adalasso  <- cv.glmnet(X_train, y_train, alpha = 1, penalty.factor = w)
  fit_hrlr      <- cv.glmnet(aug$X_aug, aug$y_aug, alpha = 1, intercept = FALSE, standardize = FALSE)
  fit_ahr       <- cv.glmnet(aug$X_aug, aug$y_aug, alpha = 1, penalty.factor = w_aug, intercept = FALSE, standardize = FALSE)
  # Store evaluations
  evals <- list(
    Lasso = method_eval(fit_lasso, X_test, beta),
    Ridge = method_eval(fit_ridge, X_test, beta),
    ElasticNet = method_eval(fit_enet, X_test, beta),
    AdaptiveLasso = method_eval(fit_adalasso, X_test, beta),
    HRLR = method_eval(fit_hrlr, X_test / sqrt(1 + lambda2), beta),  # rescale test data to match the augmented design
    AHRLR = method_eval(fit_ahr, X_test / sqrt(1 + lambda2), beta)
  )
  
  # Return one-row-per-method data frame
  do.call(rbind, lapply(names(evals), function(m) {
    cbind(Method = m, as.data.frame(evals[[m]]))
  }))
}
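# Example of a single replication (illustrative sketch only, not run): returns one
# row per method with RMSE, true/false positives, recall, and selection accuracy.
# one_rep <- run_one(n = 40, p = 500,
#                    beta = c(6, 4, 2, 0.5, rep(0, 496)),
#                    sigma = 3, cor = 0.7, dist_type = "normal", grouped = FALSE)
# head(one_rep)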
# ---------------------------------------------
# Run all simulation configurations
# ---------------------------------------------
run_all_conditions <- function(nrep = 1000) {
  settings <- expand.grid(
    dist_type = c("normal"),        # Can be extended to "t3", "exp", "unif"
    grouped   = c(FALSE, TRUE),     # Whether predictors are grouped
    stringsAsFactors = FALSE        # Keep dist_type as character so switch() matches by name
  )
  
  # Simulation parameters
  n <- 40
  p <- 500
  beta <- c(6, 4, 2, 0.5, rep(0, p - 4))  # True coefficients: 4 non-zero signals, the rest zero
  sigma <- 3
  results <- list()
  for (i in seq_len(nrow(settings))) {
    dist_type <- settings$dist_type[i]
    grouped   <- settings$grouped[i]
    # Repeat the simulation nrep times
    reps <- replicate(nrep, run_one(n, p, beta, sigma, cor = 0.7, dist_type, grouped), simplify = FALSE)
    df <- bind_rows(reps)
    df$Dist <- dist_type
    df$Grouped <- grouped
    results[[i]] <- df
  }
  # Combine and summarize results
  all_results <- bind_rows(results)
  summary <- all_results %>%
    group_by(Dist, Grouped, Method) %>%
    summarise(
      Median_RMSE   = median(rmse),
      SE_RMSE       = sd(rmse) / sqrt(n()),   # standard error of the mean RMSE across replications
      Avg_TP        = mean(tp),
      Avg_FP        = mean(fp),
      Recall        = mean(recall, na.rm = TRUE),
      Conf_Accuracy = mean(conf_acc, na.rm = TRUE),
      .groups = "drop"
    )
  # Save results to CSV
  write.csv(summary, "full_adaptive_sim_summary.csv", row.names = FALSE)
  write.csv(all_results, "full_adaptive_sim_raw.csv", row.names = FALSE)
  return(summary)
}
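# Quick check (illustrative sketch only, not run): a small number of replications
# can be used to verify the pipeline before committing to the full run below.
# quick_summary <- run_all_conditions(nrep = 5)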
# ---------------------------------------------
# Run full simulation (1000 replications per setting by default; this can take a long time)
# ---------------------------------------------
extended_summary <- run_all_conditions()
print(extended_summary)