library(testthat)
library(EpidigiR)

test_that("epi_model returns correct logistic regression output", {
  # Seed the RNG so the simulated outcome (and therefore the fitted model)
  # is identical on every run; an unseeded sample() makes failures
  # impossible to reproduce.
  set.seed(123)
  n <- 100
  survey_data <- data.frame(
    respondent_id = seq_len(n),
    age_group = rep(c("18-29", "30-44", "45-59", "60+"), each = n / 4),
    vaccinated = sample(c(0, 1), n, replace = TRUE)
  )

  res_log <- epi_model(
    data = survey_data,
    formula = vaccinated ~ age_group,
    type = "logistic"
  )

  # The result must expose coefficients and one prediction per respondent.
  expect_true(all(c("coefficients", "predictions") %in% names(res_log)))
  expect_length(res_log$predictions, nrow(survey_data))
})

test_that("epi_model performs k-means clustering", {
  # Seed the RNG so the simulated points -- and any random initialisation
  # epi_model may perform for k-means -- are reproducible across runs.
  set.seed(123)
  n <- 10
  k <- 2
  km_data <- data.frame(x = rnorm(n), y = rnorm(n))

  res_km <- epi_model(data = km_data, type = "kmeans", k = k)

  expect_true(all(c("clusters", "centers") %in% names(res_km)))
  # One cluster label per input row, one centre (row) per requested cluster.
  expect_length(res_km$clusters, nrow(km_data))
  expect_equal(nrow(res_km$centers), k)
})

test_that("epi_model returns correct survival analysis output", {
  # Simulated right-censored data: exponential follow-up times plus a
  # random 0/1 status indicator. Draw order (times first, then status)
  # determines the seeded values.
  set.seed(123)
  n <- 100
  event_time <- rexp(n, rate = 0.1)
  event_status <- sample(0:1, n, replace = TRUE)
  surv_data <- data.frame(time = event_time, status = event_status)

  fit_result <- epi_model(data = surv_data, type = "survival")

  # The fitted object must carry the "survfit" class ...
  expect_s3_class(fit_result$survfit, "survfit")

  # ... and its summary must expose parallel `time` and `surv` vectors.
  surv_summary <- fit_result$summary
  expect_true(all(c("time", "surv") %in% names(surv_summary)))
  expect_equal(length(surv_summary$time), length(surv_summary$surv))
})

# Survival analysis ----

# NOTE(review): none of the expectations below call survival functions
# directly (expect_s3_class only inspects the class attribute); confirm
# whether epi_model relies on survival being attached before removing this.
library(survival)

test_that("epi_model survival output has well-formed survfit and summary", {
  set.seed(123)
  n <- 100
  surv_data <- data.frame(
    # survival times
    time = rexp(n, rate = 0.1),
    # 1 = event occurred, 0 = censored
    status = sample(0:1, n, replace = TRUE)
  )

  res_surv <- epi_model(data = surv_data, type = "survival")

  # Result contains both the fitted object and its summary.
  expect_true("survfit" %in% names(res_surv))
  expect_true("summary" %in% names(res_surv))

  # The fitted object carries the "survfit" class.
  expect_s3_class(res_surv$survfit, "survfit")

  # Summary exposes parallel time / survival vectors ...
  sum_obj <- res_surv$summary
  expect_true(all(c("time", "surv") %in% names(sum_obj)))
  expect_equal(length(sum_obj$time), length(sum_obj$surv))

  # ... with non-negative times and survival probabilities in [0, 1]
  # (the simulated times from rexp() are all positive).
  expect_true(all(sum_obj$time >= 0))
  expect_true(all(sum_obj$surv >= 0 & sum_obj$surv <= 1))
})

