#' Asymptotic sensitivity analysis for weak nulls with continuous exposures.
#'
#' For every matched set the function evaluates `estimand_function` on the
#' set's doses and outcomes, applies the sensitivity-analysis correction at
#' `gamma`, and combines the per-set statistics into a studentized deviate
#' using a regression-assisted conservative standard error.
#'
#' @param Z A length N vector of observed doses.
#' @param R A length N vector of observed outcomes.
#' @param index A length N vector of indices indicating matched set membership.
#' @param gamma The nonnegative sensitivity parameter; gamma = 0 means no
#' unmeasured confounding.
#' @param theta The value at which to test the weak null.
#' @param X A matrix with one row per matched set (I rows) and fewer than I
#' columns that contains covariate information; the default `NA` means no
#' covariates are used.
#' @param estimand_function A function that takes in values z and r and outputs
#' a scalar; this function governs the causal estimand to estimate.
#' @param gamma_star_vec A vector that contains, for each matched set, the ratio
#' of the maximum probability to the minimum probability of a permutation;
#' default is NULL.
#' @param kappa_inv_vec A vector that contains, for each matched set, the
#' minimum probability of a permutation; default is NULL. If either
#' `gamma_star_vec` or `kappa_inv_vec` is NULL, both quantities are computed
#' from the doses with `prob_bounds()`.
#'
#' @return A list containing the deviate (`deviate`), one-sided p-value
#' (`pval`), observed value of the test statistic in each matched set
#' (`teststat`), and conservative standard deviation estimate (`se`).
#' @import dplyr
#' @export
#'
#' @examples
#' # Load the data
#' data <- lead_bmd
#' # prepare data
#' threshold <- log(0.74675)
#' match_info = data |> dplyr::group_by(matched_sets) |>
#' dplyr::summarise(below = sum(log_lead < threshold) > 0, disc = var(log_lead) > 0,
#' above = sum(log_lead > threshold) > 0)
#' below_indices <- match_info$matched_sets[match_info$below]
#' disc_indices <- match_info$matched_sets[match_info$disc]
#' above_indices <- match_info$matched_sets[match_info$above]
#' # outcome analysis using the stochastic intervention statistic, weak null
#' below_nbp <- data |> dplyr::filter(matched_sets %in% below_indices & matched_sets
#' %in% disc_indices)
#' above_below <- below_nbp |> dplyr::filter(matched_sets %in% above_indices)
#' extract_below_threshold_vs_baseline_function <- function(z, r) {
#'   extract_below_threshold_vs_baseline(z, r, threshold)
#' }
#' # one-sided test that estimand defined by estimand_function is 0 at gamma = 0
#' result <- weak_null_test(Z = above_below$log_lead,
#' R = above_below$lumbar_spine_bmd,
#' index = above_below$matched_sets, gamma = 0, theta = 0,
#' estimand_function = extract_below_threshold_vs_baseline_function)
#'
weak_null_test <- function(Z, R, index, gamma = 0, theta = 0, X = NA,
                           estimand_function = extract_OLS,
                           gamma_star_vec = NULL, kappa_inv_vec = NULL) {
  # total number of subjects
  N <- length(R)

  # matched set labels (in order of first appearance) and their count
  match_index <- unique(index)
  nostratum <- length(match_index)

  # per-set weights and observed statistics, filled in below
  stratum_weights <- rep(NA_real_, nostratum)
  obsT <- rep(NA_real_, nostratum)

  # user-supplied bounds are only honoured when BOTH vectors are given
  compute_bounds <- is.null(gamma_star_vec) || is.null(kappa_inv_vec)

  # compute test stat from each matched set
  for (j in seq_len(nostratum)) {
    in_set <- which(index == match_index[j])
    # doses and responses in set j
    doses <- Z[in_set]
    resp <- R[in_set]
    # number of subjects in set j and the set's share of the sample
    ns <- length(doses)
    stratum_weights[j] <- ns / N

    # smallest permutation probability and max/min probability ratio
    if (compute_bounds) {
      probs <- prob_bounds(z = doses, gamma = gamma)
      kappa_inv <- probs$min_prob
      gamma_star <- probs$max_prob / probs$min_prob
    } else {
      kappa_inv <- kappa_inv_vec[j]
      gamma_star <- gamma_star_vec[j]
    }

    # unadjusted statistic, then the sensitivity-analysis correction
    V <- estimand_function(z = doses, r = resp)
    dhat <- V - theta -
      (gamma_star - 1) / (1 + gamma_star) * abs(V - theta)

    # statistic bounded above under weak null
    obsT[j] <- dhat * (1 + gamma_star) * (1 / kappa_inv) /
      (gamma_star * factorial(ns))
  }

  # Design matrix for variance estimation: intercept, plus the stratum weights
  # when they are not all identical, plus optional covariates. The weights are
  # compared by distinct values (exact for identical set sizes), which also
  # avoids an NA condition when there is a single matched set.
  Q <- matrix(1, nrow = nostratum, ncol = 1)
  if (length(unique(stratum_weights)) > 1L) {
    Q <- cbind(Q, stratum_weights, deparse.level = 0)
  }
  if (all(!is.na(X))) {
    Q <- cbind(Q, X)
  }

  # components for variance estimation
  H_Q <- Q %*% solve(t(Q) %*% Q) %*% t(Q)
  # nrow is given explicitly so a single stratum still yields a 1 x 1 matrix
  W <- diag(nostratum * stratum_weights, nrow = nostratum)
  y <- obsT / sqrt(1 - diag(H_Q))

  # standard error estimate
  se <- var_est(y = y, W = W, H_Q = H_Q)

  # normal deviate and one-sided (upper-tail) p-value; lower.tail = FALSE
  # keeps precision for large deviates where 1 - pnorm() would round to 0
  deviate <- sum(obsT * stratum_weights) / se
  pval <- stats::pnorm(deviate, lower.tail = FALSE)

  list(deviate = deviate, pval = pval, teststat = obsT, se = se)
}

#' Asymptotic sensitivity analysis for weak nulls with continuous exposures
#' assuming constant effects across matched sets.
#'
#' For every matched set the function evaluates `estimand_function` on the
#' set's doses and outcomes, applies the sensitivity-analysis correction at
#' `gamma`, and combines the per-set statistics into a studentized deviate
#' using a regression-assisted conservative standard error.
#'
#' @param Z A length N vector of observed doses.
#' @param R A length N vector of observed outcomes.
#' @param index A length N vector of indices indicating matched set membership.
#' @param gamma The nonnegative sensitivity parameter; gamma = 0 means no
#' unmeasured confounding.
#' @param theta The value at which to test the weak null.
#' @param X A matrix with one row per matched set (I rows) and fewer than I
#' columns that contains covariate information; the default `NA` means no
#' covariates are used.
#' @param estimand_function A function that takes in values z and r and outputs
#' a scalar; this function governs the causal estimand to estimate.
#' @param gamma_star_vec A vector that contains the maximum ratio of any two
#' probabilities of permutations for each matched set; default is NULL, in
#' which case the ratio is computed from the doses with `max_ratio()`.
#'
#' @return A list containing the deviate (`deviate`), one-sided p-value
#' (`pval`), observed value of the test statistic in each matched set
#' (`teststat`), and conservative standard deviation estimate (`se`).
#' @import dplyr
#' @export
#'
#' @examples
#'  # Load the data
#' data <- lead_bmd
#' # prepare data
#' threshold <- log(0.74675)
#' match_info = data |> dplyr::group_by(matched_sets) |>
#' dplyr::summarise(below = sum(log_lead < threshold) > 0, disc = var(log_lead) > 0,
#' above = sum(log_lead > threshold) > 0)
#' below_indices <- match_info$matched_sets[match_info$below]
#' disc_indices <- match_info$matched_sets[match_info$disc]
#' above_indices <- match_info$matched_sets[match_info$above]
#' # outcome analysis using the stochastic intervention statistic, weak null
#' below_nbp <- data |> dplyr::filter(matched_sets %in% below_indices & matched_sets
#' %in% disc_indices)
#' above_below <- below_nbp |> dplyr::filter(matched_sets %in% above_indices)
#' extract_below_threshold_vs_baseline_function <- function(z, r) {
#'   extract_below_threshold_vs_baseline(z, r, threshold)
#' }
#' # one-sided test that estimand defined by estimand_function is 0 at gamma = 0
#' result <- constant_effects_test(Z = above_below$log_lead,
#' R = above_below$lumbar_spine_bmd,
#' index = above_below$matched_sets, gamma = 0, theta = 0,
#' estimand_function = extract_below_threshold_vs_baseline_function)
#'
constant_effects_test <- function(Z, R, index, gamma = 0, theta = 0, X = NA,
                                  estimand_function = extract_OLS,
                                  gamma_star_vec = NULL) {
  # total number of subjects
  N <- length(R)

  # matched set labels (in order of first appearance) and their count
  match_index <- unique(index)
  nostratum <- length(match_index)

  # per-set weights and observed statistics, filled in below
  stratum_weights <- rep(NA_real_, nostratum)
  obsT <- rep(NA_real_, nostratum)

  # whether the probability ratio has to be derived from the doses
  compute_ratio <- is.null(gamma_star_vec)

  # compute test stat from each matched set
  for (j in seq_len(nostratum)) {
    in_set <- which(index == match_index[j])
    # doses and responses in set j
    doses <- Z[in_set]
    resp <- R[in_set]
    # number of subjects in set j and the set's share of the sample
    ns <- length(doses)
    stratum_weights[j] <- ns / N

    # maximum ratio of any two permutation probabilities in set j
    if (compute_ratio) {
      gamma_star <- max_ratio(z = doses, gamma = gamma)
    } else {
      gamma_star <- gamma_star_vec[j]
    }

    # unadjusted statistic, then the sensitivity-analysis correction;
    # the corrected value is the statistic bounded above under the weak null
    V <- estimand_function(z = doses, r = resp)
    obsT[j] <- V - theta -
      (gamma_star - 1) / (1 + gamma_star) * abs(V - theta)
  }

  # Design matrix for variance estimation: intercept, plus the stratum weights
  # when they are not all identical, plus optional covariates. The weights are
  # compared by distinct values (exact for identical set sizes), which also
  # avoids an NA condition when there is a single matched set.
  Q <- matrix(1, nrow = nostratum, ncol = 1)
  if (length(unique(stratum_weights)) > 1L) {
    Q <- cbind(Q, stratum_weights, deparse.level = 0)
  }
  if (all(!is.na(X))) {
    Q <- cbind(Q, X)
  }

  # components for variance estimation
  H_Q <- Q %*% solve(t(Q) %*% Q) %*% t(Q)
  # nrow is given explicitly so a single stratum still yields a 1 x 1 matrix
  W <- diag(nostratum * stratum_weights, nrow = nostratum)
  y <- obsT / sqrt(1 - diag(H_Q))

  # standard error estimate
  se <- var_est(y = y, W = W, H_Q = H_Q)

  # normal deviate and one-sided (upper-tail) p-value; lower.tail = FALSE
  # keeps precision for large deviates where 1 - pnorm() would round to 0
  deviate <- sum(obsT * stratum_weights) / se
  pval <- stats::pnorm(deviate, lower.tail = FALSE)

  list(deviate = deviate, pval = pval, teststat = obsT, se = se)
}

#' Asymptotic sharp null sensitivity analysis for a class of test statistics
#' accommodating continuous exposures and any scalar outcome.
#'
#' In each matched set the statistic is the sum over subjects of
#' `q1(dose) * q2(outcome)`, centred by its maximal expectation at `gamma`
#' (obtained from `max_expectation()`). The per-set statistics are combined
#' into a studentized deviate.
#'
#' @param Z A length N vector of observed doses.
#' @param R A length N vector of observed outcomes.
#' @param index A length N vector of indices indicating matched set membership.
#' @param gamma The nonnegative sensitivity parameter; gamma = 0 means no
#' unmeasured confounding.
#' @param q1 A transformation to apply to the doses. Required when
#' `double_rank = FALSE`; ignored (replaced by the rank among all doses, ties
#' receiving the minimum rank) when `double_rank = TRUE`.
#' @param q2 A transformation to apply to the outcomes. Required when
#' `double_rank = FALSE`; ignored (replaced by the rank among all outcomes,
#' ties receiving the minimum rank) when `double_rank = TRUE`.
#' @param X A matrix with one row per matched set (I rows) and fewer than I
#' columns that contains covariate information; the default `NA` means no
#' covariates are used. Only used when `conservative_variance = TRUE`.
#' @param stratum_weights A weight vector. Note that the entry for every
#' matched set is currently overwritten inside the function: with 1 when
#' `double_rank = TRUE`, and with the set's share of the N subjects otherwise.
#' @param conservative_variance Whether to use the conservative variance or not;
#' default is TRUE. If FALSE, the variance returned by `max_expectation()` is
#' used instead.
#' @param double_rank Whether to use the ranks of the doses and
#' outcomes; default is TRUE. When TRUE the per-set statistics are divided by
#' N^2.
#'
#' @return A list containing the deviate (`deviate`), one-sided p-value
#' (`pval`), observed value of the test statistic in each matched set
#' (`teststat`), and standard deviation estimate (`se`).
#' @export
#'
#' @examples
#' # Load the data
#' data <- lead_bmd
#' # conduct sharp null test at gamma = 0.
#' result <- sharp_null_double_test(Z = data$log_lead,
#' R = -data$lumbar_spine_bmd, index = data$matched_sets, gamma = 0)
#'
sharp_null_double_test <- function(Z, R, index, gamma = 0, q1 = NA, q2 = NA, X = NA,
                            stratum_weights = rep(NA, nostratum),
                            conservative_variance = TRUE, double_rank = TRUE) {
  # total number of subjects
  N <- length(R)

  # matched set labels (in order of first appearance) and their count
  match_index <- unique(index)
  nostratum <- length(match_index)

  # observed (centred) test statistic per matched set
  obsT <- rep(NA_real_, nostratum)

  # per-set variance; only filled when conservative_variance = FALSE
  varT <- rep(NA_real_, nostratum)

  if (double_rank) {
    # In a sorted sample the position of a value's first occurrence equals its
    # rank under ties.method = "min", so a lookup with match() reproduces the
    # minimum-rank transform; values absent from the sample map to NA.
    sorted_doses <- sort(Z)
    sorted_resp <- sort(R)
    q1 <- function(z) match(z, sorted_doses)
    q2 <- function(r) match(r, sorted_resp)
  } else if (identical(q1, NA) || identical(q2, NA)) {
    # without rank transforms there is nothing to apply; fail early and clearly
    stop("`q1` and `q2` must be supplied when `double_rank = FALSE`.",
         call. = FALSE)
  }

  # compute test stat from each matched set
  for (j in seq_len(nostratum)) {
    in_set <- which(index == match_index[j])
    # doses and responses in set j
    doses <- Z[in_set]
    resp <- R[in_set]
    # number of subjects in set j
    ns <- length(doses)
    # weights: unit weights for the rank statistic, sample shares otherwise
    stratum_weights[j] <- if (double_rank) 1 else ns / N

    # unadjusted statistic for the observed dose/outcome pairing
    T_i <- sum(vapply(doses, q1, FUN.VALUE = numeric(1)) *
                 vapply(resp, q2, FUN.VALUE = numeric(1)))
    # statistic values handed to max_expectation()
    # (presumably one per permutation of the set -- confirm in
    # sharp_double_statistic)
    q_pi <- sharp_double_statistic(z = doses, r = resp, q1, q2)

    # maximal expectation at gamma, with the variance when it is needed
    if (!conservative_variance) {
      max_exp_dummy <- max_expectation(z = doses, gamma = gamma, f_pi = q_pi,
                                       with_variance = TRUE)
      max_exp <- max_exp_dummy$max_exp
      varT[j] <- max_exp_dummy$variance
    } else {
      max_exp <- max_expectation(z = doses, gamma = gamma, f_pi = q_pi)
    }

    # statistic bounded above under sharp null
    obsT[j] <- T_i - max_exp
  }

  if (double_rank) {
    # NOTE(review): obsT is rescaled here but varT is not; with
    # conservative_variance = FALSE the deviate may be mis-scaled by N^2 --
    # confirm against max_expectation()'s variance definition.
    obsT <- obsT / N^2
  }

  if (conservative_variance) {
    # Design matrix for variance estimation: intercept, plus the stratum
    # weights when they are not all identical, plus optional covariates.
    # Comparing distinct values also avoids an NA condition when there is a
    # single matched set.
    Q <- matrix(1, nrow = nostratum, ncol = 1)
    if (length(unique(stratum_weights)) > 1L) {
      Q <- cbind(Q, stratum_weights, deparse.level = 0)
    }
    if (all(!is.na(X))) {
      Q <- cbind(Q, X)
    }

    # components for variance estimation
    H_Q <- Q %*% solve(t(Q) %*% Q) %*% t(Q)
    # nrow is given explicitly so a single stratum still yields a 1 x 1 matrix
    W <- diag(nostratum * stratum_weights, nrow = nostratum)
    y <- obsT / sqrt(1 - diag(H_Q))

    # standard error estimate
    se <- var_est(y = y, W = W, H_Q = H_Q)
  } else {
    se <- sqrt(sum(varT * stratum_weights^2))
  }

  # normal deviate and one-sided (upper-tail) p-value; lower.tail = FALSE
  # keeps precision for large deviates where 1 - pnorm() would round to 0
  deviate <- sum(obsT * stratum_weights) / se
  pval <- stats::pnorm(deviate, lower.tail = FALSE)

  list(deviate = deviate, pval = pval, teststat = obsT, se = se)
}



