## ----setup, include = FALSE-----------------------------------------------------------------------
# Global knitr chunk options: collapsed output prefixed with "#>", centred
# 8 x 6 figures rendered at 90 dpi and scaled to 95% of the page width, with
# messages and warnings hidden from the rendered document.
knitr::opts_chunk$set(
  comment = "#>",
  collapse = TRUE,
  message = FALSE,
  warning = FALSE,
  fig.align = "center",
  fig.width = 8,
  fig.height = 6,
  dpi = 90,
  out.width = "95%"
)
# Wider console so printed tables do not wrap
options(width = 100)

## -------------------------------------------------------------------------------------------------
# Load required packages
library(missoNet)
library(ggplot2)
library(reshape2)
library(gridExtra)

# Apply a minimal ggplot2 theme (11-pt base font) to every subsequent ggplot
ggplot2::theme_set(ggplot2::theme_minimal(base_size = 11))

## -------------------------------------------------------------------------------------------------
# Problem dimensions, sized to resemble a single genomic region
n <- 600  # samples
p <- 80   # SNPs in the region
q <- 20   # CpG sites

# Create realistic correlation structure
# SNPs: Linkage disequilibrium (LD) blocks

#' Build a block-diagonal covariance matrix mimicking LD blocks.
#'
#' SNPs are split into `n_blocks` contiguous blocks. Within a block the
#' correlation between SNPs i and j is `within_block_cor^|i - j|`; SNPs in
#' different blocks are uncorrelated. When `p` is a multiple of `n_blocks`
#' all blocks have size `p / n_blocks`. Otherwise block sizes differ by at
#' most one. (The previous loop produced fractional indices in that case.)
#'
#' @param p Number of SNPs, i.e. the matrix dimension (positive integer).
#' @param n_blocks Number of LD blocks (positive integer).
#' @param within_block_cor Lag-one correlation inside each block.
#' @return A `p` x `p` symmetric matrix with unit diagonal.
create_ld_structure <- function(p, n_blocks = 4, within_block_cor = 0.7) {
  stopifnot(p >= 1, n_blocks >= 1)
  idx <- seq_len(p)
  # Block label per SNP: SNP i belongs to block ceiling(i * n_blocks / p),
  # which reproduces the contiguous equal-size split when p %% n_blocks == 0
  block_id <- ceiling(idx * n_blocks / p)
  Sigma <- within_block_cor^abs(outer(idx, idx, "-"))
  # Zero out every between-block pair
  Sigma[outer(block_id, block_id, "!=")] <- 0
  diag(Sigma) <- 1
  Sigma
}

Sigma.X <- create_ld_structure(p, n_blocks = 4)

# Variable missing rates (technical dropout patterns)
# Higher missingness for CpG sites with extreme GC content.
# Seed the RNG first: runif() is called before generateData(), whose own
# `seed` argument does not cover this draw. Without it the GC content, and
# hence the per-CpG missing rates, would differ on every document build.
set.seed(100)
gc_content <- runif(q, 0.3, 0.7)  # Simulated GC content
# U-shaped missing rate: 1% at GC = 0.5, rising to 31% at GC = 0.3 or 0.7
rho_vec <- 0.01 + 0.3 * (abs(gc_content - 0.5) / 0.2)^2

# Generate the data
sim <- generateData(
  n = n,
  p = p,
  q = q,
  rho = rho_vec,               # Per-CpG missing rates
  missing.type = "MAR",        # Missing depends on technical factors
  Sigma.X = Sigma.X,
  Beta.row.sparsity = 0.15,    # 15% of SNPs are meQTLs
  Beta.elm.sparsity = 0.4,     # Each meQTL affects 40% of CpGs
  seed = 100
)

# Add meaningful variable names.
# Z holds the responses with missing entries (NA); Y holds the complete values.
colnames(sim$X) <- sprintf("rs%d", 1000000 + 1:p)  # SNP IDs
colnames(sim$Y) <- sprintf("cg%d", 2000000 + 1:q)  # CpG IDs
colnames(sim$Z) <- colnames(sim$Y)
rownames(sim$X) <- rownames(sim$Y) <- rownames(sim$Z) <- paste0("Sample", 1:n)

{
  cat("\nDataset Summary:\n")
  cat("================\n")
  cat("Samples:", n, "\n")
  cat("SNPs:", p, "\n")
  cat("CpG sites:", q, "\n")
  cat("Overall missing rate:", sprintf("%.1f%%", mean(is.na(sim$Z)) * 100), "\n")
  cat("True meQTLs:", sum(rowSums(abs(sim$Beta)) > 0), "\n")
}

## -------------------------------------------------------------------------------------------------
# Fraction of missing entries per CpG site (columns of Z) and per sample (rows)
miss_by_cpg <- colMeans(is.na(sim$Z))
miss_by_sample <- rowMeans(is.na(sim$Z))

# Four diagnostic panels in a 2 x 2 grid
par(mfrow = c(2, 2))
local({
  # Panel 1 -- per-CpG missing rate; dashed red line marks the average
  plot(miss_by_cpg,
       type = "h", lwd = 2, col = "steelblue",
       main = "Missing Data by CpG Site",
       xlab = "CpG Site", ylab = "Missing Rate")
  abline(h = mean(miss_by_cpg), col = "red", lty = 2)

  # Panel 2 -- spread of per-sample missing rates; solid red line marks the mean
  hist(miss_by_sample,
       breaks = 20, col = "lightblue",
       main = "Distribution of Missing Rates (Samples)",
       xlab = "Missing Rate")
  abline(v = mean(miss_by_sample), col = "red", lwd = 2)

  # Panel 3 -- missingness indicator map (CpGs on x, first 100 samples on y)
  na_map <- t(is.na(sim$Z[1:100, ]))
  image(na_map,
        col = c("white", "darkred"),
        main = "Missing Data Pattern",
        xlab = "CpG Site", ylab = "Sample (first 100)")

  # Panel 4 -- does dropout track GC content? Red curve is a lowess trend
  plot(gc_content, miss_by_cpg,
       pch = 19, col = "darkblue",
       main = "Technical Dropout vs GC Content",
       xlab = "GC Content", ylab = "Missing Rate")
  lines(lowess(gc_content, miss_by_cpg), col = "red", lwd = 2)
})

## -------------------------------------------------------------------------------------------------
# Empirical CpG-CpG correlation, computed on the complete responses (sim$Y)
# rather than the incomplete observed matrix sim$Z
Y_complete <- sim$Y
cor_cpg <- cor(Y_complete)

# Create enhanced heatmap
library(RColorBrewer)
colors <- colorRampPalette(brewer.pal(9, "RdBu"))(100)

# zlim = c(-1, 1) pins the diverging palette to the full correlation range,
# so its white midpoint always means r = 0. Without it, image() stretches the
# colours over range(cor_cpg) and the midpoint drifts.
heatmap(cor_cpg,
        col = rev(colors),
        symm = TRUE,
        zlim = c(-1, 1),
        main = "CpG Methylation Correlation Structure",
        xlab = "CpG Sites", ylab = "CpG Sites",
        margins = c(8, 8))

# Identify CpG modules: hierarchical clustering on 1 - |r|, cut into 3 groups
hc <- hclust(as.dist(1 - abs(cor_cpg)))
modules <- cutree(hc, k = 3)

cat("\nCpG modules identified:", table(modules), "\n")

## -------------------------------------------------------------------------------------------------
# Step 1: quick BIC-driven exploration of the (lambda.beta, lambda.theta) plane
fit_initial <- missoNet(
  X = sim$X,
  Y = sim$Z,
  GoF = "BIC",
  adaptive.search = TRUE,  # Fast exploration
  verbose = 1
)

# Report the selected penalties and how sparse the resulting estimates are
local({
  est <- fit_initial$est.min
  tol <- 1e-8
  theta_upper <- est$Theta[upper.tri(est$Theta)]
  cat("\nStep 1: Initial parameter exploration\n")
  cat("=====================================\n")
  cat("  Lambda.beta:", est$lambda.beta, "\n")
  cat("  Lambda.theta:", est$lambda.theta, "\n")
  cat("  Active SNPs:", sum(rowSums(abs(est$Beta)) > tol), "\n")
  cat("  Network edges:", sum(abs(theta_upper) > tol), "\n")
})

## -------------------------------------------------------------------------------------------------
# Refine a penalty grid around the value selected in Step 1.
# The grid has `n_points` log-equally-spaced values. It starts at
#   min(0.9 * largest explored lambda, 50 * selected lambda)
# and ends at
#   max(0.005 * largest explored lambda, selected lambda / 20).
refine_lambda_grid <- function(explored, selected, n_points = 25) {
  largest <- max(explored)
  grid_start <- min(largest * 0.9, selected * 50)
  grid_end <- max(largest * 0.005, selected / 20)
  10^seq(log10(grid_start), log10(grid_end), length.out = n_points)
}

lambda.beta.refined <- refine_lambda_grid(
  explored = fit_initial$lambda.beta.seq,
  selected = fit_initial$est.min$lambda.beta
)
lambda.theta.refined <- refine_lambda_grid(
  explored = fit_initial$lambda.theta.seq,
  selected = fit_initial$est.min$lambda.theta
)

# 5-fold cross-validation over the refined grids; 1SE estimates are requested
# too, so a sparser alternative to the CV minimum is available
cvfit <- cv.missoNet(
  X = sim$X,
  Y = sim$Z,
  kfold = 5,
  lambda.beta = lambda.beta.refined,
  lambda.theta = lambda.theta.refined,
  compute.1se = TRUE,
  seed = 1000,
  verbose = 0
)

## -------------------------------------------------------------------------------------------------
# Compare different model choices: the CV minimiser, the two one-standard-error
# (1SE) alternatives, and the initial BIC fit
models <- list(
  "CV Minimum" = cvfit$est.min,
  "1SE Beta" = cvfit$est.1se.beta,
  "1SE Theta" = cvfit$est.1se.theta,
  "Initial BIC" = fit_initial$est.min
)

# Any estimate cv.missoNet did not return is NULL. Drop those so the table
# still summarises the models that do exist, instead of skipping it entirely.
available_models <- Filter(Negate(is.null), models)

if (length(available_models) > 0) {
  comparison <- data.frame(
    Model = names(available_models),
    Lambda.Beta = vapply(available_models, function(m) m$lambda.beta, numeric(1)),
    Lambda.Theta = vapply(available_models, function(m) m$lambda.theta, numeric(1)),
    Active.SNPs = vapply(available_models, function(m)
      sum(rowSums(abs(m$Beta)) > 1e-8), integer(1)),
    Total.Effects = vapply(available_models, function(m)
      sum(abs(m$Beta) > 1e-8), integer(1)),
    Network.Edges = vapply(available_models, function(m)
      sum(abs(m$Theta[upper.tri(m$Theta)]) > 1e-8), integer(1))
  )
  print(comparison, digits = 4)
}

# Select the more regularized 1SE-Beta model; fall back to the BIC fit if
# cross-validation did not provide it
final_model <- if (!is.null(models$`1SE Beta`)) {
  cvfit$est.1se.beta
} else {
  fit_initial$est.min
}

## -------------------------------------------------------------------------------------------------
# Extract coefficients and label them with SNP (row) and CpG (column) IDs
Beta <- final_model$Beta
rownames(Beta) <- colnames(sim$X)
colnames(Beta) <- colnames(sim$Z)

# Identify significant associations: coefficients whose magnitude exceeds threshold
threshold <- 1e-3
sig_meqtls <- which(abs(Beta) > threshold, arr.ind = TRUE)

if (nrow(sig_meqtls) > 0) {
  meqtl_df <- data.frame(
    SNP = rownames(Beta)[sig_meqtls[, 1]],
    CpG = colnames(Beta)[sig_meqtls[, 2]],
    Effect = Beta[sig_meqtls],
    AbsEffect = abs(Beta[sig_meqtls])
  )
  meqtl_df <- meqtl_df[order(meqtl_df$AbsEffect, decreasing = TRUE), ]

  cat("Top 15 meQTL associations:\n")
  print(head(meqtl_df, 15), digits = 3)

  # Visualization: SNPs involved in the 30 strongest associations
  top_snps <- unique(head(meqtl_df$SNP, 30))
  Beta_subset <- Beta[top_snps, , drop = FALSE]

  # heatmap() stops on inputs with fewer than 2 rows or 2 columns
  if (nrow(Beta_subset) >= 2 && ncol(Beta_subset) >= 2) {
    par(mfrow = c(1, 1))
    colors <- colorRampPalette(c("blue", "white", "red"))(100)
    # Symmetric colour limits so white corresponds to a zero effect
    zmax <- max(abs(Beta_subset))
    heatmap(as.matrix(Beta_subset),
            col = colors,
            scale = "none",
            zlim = c(-zmax, zmax),
            main = "Top meQTL Effects",
            xlab = "CpG Sites",
            ylab = "SNPs",
            margins = c(8, 8))
  }
}

## -------------------------------------------------------------------------------------------------
# Extract precision matrix and convert to partial correlations
Theta <- final_model$Theta
rownames(Theta) <- colnames(Theta) <- colnames(sim$Z)

# Partial correlation: rho_ij = -Theta_ij / sqrt(Theta_ii * Theta_jj)
partial_cor <- -cov2cor(Theta)
diag(partial_cor) <- 0

# Network statistics: an edge is any |partial correlation| above edge_threshold
edge_threshold <- 0.1
n_possible_edges <- q * (q - 1) / 2
n_edges <- sum(abs(partial_cor[upper.tri(partial_cor)]) > edge_threshold)
{
  cat("\nNetwork Statistics:\n")
  cat("  Total possible edges:", n_possible_edges, "\n")
  cat("  Selected edges (|r| >", edge_threshold, "):", n_edges, "\n")
  cat("  Network density:", sprintf("%.1f%%", 100 * n_edges / n_possible_edges), "\n")
}

# Identify hub CpGs: the (up to) five sites with the highest degree.
# head() avoids NA entries when fewer than five CpGs exist.
degree <- colSums(abs(partial_cor) > edge_threshold)
hub_cpgs <- names(head(sort(degree, decreasing = TRUE), 5))
cat("\nHub CpG sites (highest connectivity):\n")
for (cpg in hub_cpgs) {
  cat("  ", cpg, ": degree =", degree[cpg], "\n")
}

# Visualize network (optional dependency)
if (requireNamespace("igraph", quietly = TRUE)) {
  library(igraph)

  # Draw only edges stronger than edge_threshold to keep the plot readable
  plot_edge_threshold <- 0.15
  edges <- which(abs(partial_cor) > plot_edge_threshold & upper.tri(partial_cor),
                 arr.ind = TRUE)
  if (nrow(edges) > 0) {
    edge_list <- data.frame(
      from = rownames(partial_cor)[edges[, 1]],
      to = colnames(partial_cor)[edges[, 2]],
      weight = abs(partial_cor[edges])
    )

    g <- graph_from_data_frame(edge_list, directed = FALSE)

    # Node properties: size grows with degree; hubs highlighted in red
    V(g)$size <- 5 + sqrt(degree[V(g)$name]) * 3
    V(g)$color <- ifelse(V(g)$name %in% hub_cpgs, "red", "lightblue")

    # Plot
    par(mfrow = c(1, 1))
    plot(g,
         layout = layout_with_fr(g),
         vertex.label.cex = 0.7,
         edge.width = E(g)$weight * 3,
         main = "CpG Conditional Dependency Network")
    legend("topright", legend = c("Hub CpG", "Regular CpG"),
           pch = 21, pt.bg = c("red", "lightblue"), pt.cex = 2)
  }
}

## -------------------------------------------------------------------------------------------------
# Integration: for the first (up to) five active SNPs, report how strongly the
# CpGs each one affects are linked in the estimated partial-correlation network
active_snps <- which(rowSums(abs(Beta)) > threshold)

if (length(active_snps) > 0) {
  cat("\nIntegration Analysis:\n")
  cat("=====================\n")
  for (i in head(active_snps, 5)) {
    affected_cpgs <- which(abs(Beta[i, ]) > threshold)
    # A single affected CpG has no pairs whose connection could be measured
    if (length(affected_cpgs) <= 1) next

    subnet_partial <- partial_cor[affected_cpgs, affected_cpgs]
    mean_connection <- mean(abs(subnet_partial[upper.tri(subnet_partial)]))

    cat("\n", rownames(Beta)[i], "affects", length(affected_cpgs), "CpGs\n")
    cat("Mean network connection among affected CpGs:",
        round(mean_connection, 3), "\n")
    cat("Affected CpGs:", paste(colnames(Beta)[affected_cpgs], collapse = ", "), "\n")
  }
}

## -------------------------------------------------------------------------------------------------
# Split data for validation: 75% training, 25% held-out test samples.
# Seeded so the split, and every metric derived from it, is reproducible.
set.seed(2025)
n_train <- round(0.75 * n)
train_idx <- sample(n, n_train)
test_idx <- setdiff(seq_len(n), train_idx)

# Refit on training data only. `seed` is passed as in the full-data CV fit
# (presumably it fixes the fold assignment).
cvfit_train <- cv.missoNet(
  X = sim$X[train_idx, ],
  Y = sim$Z[train_idx, ],
  kfold = 5,
  lambda.beta = lambda.beta.refined,
  lambda.theta = lambda.theta.refined,
  verbose = 0,
  seed = 1000
)

# Predictions for held-out samples, compared with their complete responses
Y_pred <- predict(cvfit_train, newx = sim$X[test_idx, ])
Y_test <- sim$Y[test_idx, ]  # True complete values

# Per-CpG performance metrics.
# cor() is undefined for a constant prediction column (presumably a CpG with
# no selected SNP effects). Record NA explicitly rather than relying on cor().
mse_per_cpg <- colMeans((Y_pred - Y_test)^2)
cor_per_cpg <- vapply(seq_len(q), function(j) {
  pred_j <- Y_pred[, j]
  if (length(unique(pred_j)) < 2) NA_real_ else cor(pred_j, Y_test[, j])
}, numeric(1))
has_cor <- !is.na(cor_per_cpg)

# Visualization
par(mfrow = c(1, 2))

# MSE vs missing rate
plot(miss_by_cpg, mse_per_cpg, pch = 19, col = "darkblue",
     xlab = "Missing Rate", ylab = "Prediction MSE",
     main = "Prediction Error vs Missing Rate")
lines(lowess(miss_by_cpg, mse_per_cpg), col = "red", lwd = 2)

# Correlation vs missing rate; lowess() does no NA handling, so smooth only
# the CpGs with a defined correlation
plot(miss_by_cpg, cor_per_cpg, pch = 19, col = "darkgreen",
     xlab = "Missing Rate", ylab = "Prediction Correlation",
     main = "Prediction Accuracy vs Missing Rate")
if (sum(has_cor) >= 2) {
  lines(lowess(miss_by_cpg[has_cor], cor_per_cpg[has_cor]), col = "red", lwd = 2)
}

{
  cat("\nPrediction Performance:\n")
  cat("  Mean MSE:", round(mean(mse_per_cpg), 4), "\n")
  cat("  Mean correlation:", round(mean(cor_per_cpg, na.rm = TRUE), 3), "\n")
  cat("  Worst CpG MSE:", round(max(mse_per_cpg), 4), "\n")
  cat("  Best CpG correlation:", round(max(cor_per_cpg, na.rm = TRUE), 3), "\n")
}

## -------------------------------------------------------------------------------------------------
# Bootstrap stability (simplified for demonstration): refit at the final
# penalties on resampled data and count how often each effect/edge reappears.
# Seeded so the resamples are reproducible.
set.seed(2026)
n_boot <- 10
selection_freq_beta <- matrix(0, p, q)
selection_freq_theta <- matrix(0, q, q)

for (b in seq_len(n_boot)) {
  # Bootstrap sample
  boot_idx <- sample(n, replace = TRUE)

  # Fit model at the fixed penalties of the final model
  fit_boot <- missoNet(
    X = sim$X[boot_idx, ],
    Y = sim$Z[boot_idx, ],
    lambda.beta = final_model$lambda.beta,
    lambda.theta = final_model$lambda.theta,
    verbose = 0
  )
  est_boot <- fit_boot$est.min

  # Track selections
  selection_freq_beta <- selection_freq_beta + (abs(est_boot$Beta) > threshold)
  # edge_threshold is a partial-correlation cut-off (see the network analysis),
  # so convert the precision matrix before thresholding. Comparing raw Theta
  # entries against it mixes scales.
  pcor_boot <- abs(cov2cor(est_boot$Theta))
  diag(pcor_boot) <- 0
  selection_freq_theta <- selection_freq_theta + (pcor_boot > edge_threshold)
}

# Normalize counts to selection frequencies in [0, 1]
selection_freq_beta <- selection_freq_beta / n_boot
selection_freq_theta <- selection_freq_theta / n_boot

# Identify stable features (selected in more than 80% of refits)
stable_meqtls <- which(selection_freq_beta > 0.8, arr.ind = TRUE)
stable_edges <- which(selection_freq_theta > 0.8 & upper.tri(selection_freq_theta),
                      arr.ind = TRUE)

{
  cat("\nStability Results:\n")
  cat("  Stable meQTLs (>80% selection):", nrow(stable_meqtls), "\n")
  cat("  Stable network edges (>80% selection):", nrow(stable_edges), "\n")
}

if (nrow(stable_meqtls) > 0) {
  cat("\nMost stable meQTL associations:\n")
  stable_df <- data.frame(
    SNP = colnames(sim$X)[stable_meqtls[, 1]],
    CpG = colnames(sim$Z)[stable_meqtls[, 2]],
    Frequency = selection_freq_beta[stable_meqtls]
  )
  print(head(stable_df[order(stable_df$Frequency, decreasing = TRUE), ], 10))
}

## -------------------------------------------------------------------------------------------------
# Annotate SNPs with genes (simulated). Seeded so the random gene and CpG
# annotations are identical across document builds.
set.seed(2027)
active_snp_ids <- which(rowSums(abs(Beta)) > threshold)
if (length(active_snp_ids) > 0) {
  gene_names <- paste0("GENE", sample(1:50, length(active_snp_ids), replace = TRUE))

  snp_annotation <- data.frame(
    SNP = colnames(sim$X)[active_snp_ids],
    Gene = gene_names,
    # drop = FALSE keeps a one-row matrix when only one SNP is active;
    # rowSums() errors on a plain vector
    Effect_Size = rowSums(abs(Beta[active_snp_ids, , drop = FALSE]))
  )

  cat("Genes with meQTLs:\n")
  gene_summary <- aggregate(Effect_Size ~ Gene, snp_annotation, sum)
  gene_summary <- gene_summary[order(gene_summary$Effect_Size, decreasing = TRUE), ]
  print(head(gene_summary, 10))
}

# CpG annotation (simulated)
cpg_annotation <- data.frame(
  CpG = colnames(sim$Z),
  Region = sample(c("Promoter", "Gene Body", "Enhancer", "Intergenic"),
                  q, replace = TRUE, prob = c(0.4, 0.3, 0.2, 0.1)),
  Island = sample(c("Island", "Shore", "Shelf", "Open Sea"),
                  q, replace = TRUE, prob = c(0.3, 0.2, 0.2, 0.3))
)

{
  cat("\nCpG distribution by genomic region:\n")
  print(table(cpg_annotation$Region))
}

{
  cat("\nCpG distribution by island status:\n")
  print(table(cpg_annotation$Island))
}

## ----echo = FALSE, include = TRUE-----------------------------------------------------------------
# final_model is the 1SE-Beta CV estimate when CV returned one, otherwise the
# initial BIC fit. Report whichever was actually used.
used_cv_1se <- !is.null(cvfit$est.1se.beta)
selection_method <- if (used_cv_1se) "5-fold cross-validation" else "BIC (initial missoNet fit)"
selection_rule <- if (used_cv_1se) "1-SE.Beta CV error" else "minimum BIC (1-SE.Beta CV estimate unavailable)"

# Share of the final model's associations that were also selected in more than
# 80% of bootstrap refits. Counting only associations present in the final
# model keeps the ratio within 0-100%.
selected_assoc <- abs(Beta) > threshold
n_stable_selected <- sum(selected_assoc & selection_freq_beta > 0.8)

cat("\n========================================\n")
cat("       ANALYSIS SUMMARY REPORT\n")
cat("========================================\n\n")

cat("\nDATA CHARACTERISTICS:\n")
cat("--------------------\n")
cat("• Samples analyzed:", n, "\n")
cat("• SNPs tested:", p, "\n")
cat("• CpG sites measured:", q, "\n")
cat("• Missing data rate:", sprintf("%.1f%%", mean(is.na(sim$Z)) * 100), "\n")
cat("• Missing pattern: MAR (technical dropout)\n\n")

cat("\nMODEL SELECTION:\n")
cat("---------------\n")
cat("• Method: ", selection_method, "\n", sep = "")
cat("• Selection criterion: ", selection_rule, "\n", sep = "")
cat("• Lambda (Beta):", sprintf("%.4f", final_model$lambda.beta), "\n")
cat("• Lambda (Theta):", sprintf("%.4f", final_model$lambda.theta), "\n\n")

cat("\nKEY FINDINGS:\n")
cat("------------\n")
cat("• meQTLs identified:", sum(rowSums(abs(Beta)) > threshold), "/", p, "SNPs\n")
cat("• SNP-CpG associations:", sum(selected_assoc), "\n")
cat("• CpG network edges:", n_edges, "/", q * (q - 1) / 2, "possible\n")
cat("• Hub CpGs identified:", length(hub_cpgs), "\n\n")

cat("\nMODEL PERFORMANCE:\n")
cat("-----------------\n")
cat("• Mean prediction correlation:", sprintf("%.3f", mean(cor_per_cpg, na.rm = TRUE)), "\n")
cat("• Mean prediction MSE:", sprintf("%.4f", mean(mse_per_cpg)), "\n")
cat("• Stability (bootstrap):", sprintf("%.0f%%",
    100 * n_stable_selected / max(sum(selected_assoc), 1)),
    "of associations stable\n\n")

cat("\nBIOLOGICAL INSIGHTS (SIMULATED):\n")
cat("-------------------\n")
cat("• Primary affected regions: Promoters and gene bodies\n")
cat("• Network structure suggests co-regulated CpG modules\n")
cat("• Hub CpGs may represent key regulatory sites\n")

