#!/usr/bin/env Rscript

# ---- Command-line arguments ----
# Positional arguments, in order:
#   1  path              directory appended to the R library search path
#   2  ngvector_file     file holding the Ng vector, or "#" for none
#   3  data_matrix_file  expression matrix read with read.table()
#   4  output_file       main result file; also the prefix of the other outputs
#   5+ one replicate count per condition (at least two conditions)
argv <- commandArgs(TRUE)
if (length(argv) < 6) {
  cat("Usage: rsem-for-ebseq-find-DE path ngvector_file data_matrix_file output_file number_of_replicate_for_condition_1 number_of_replicate_for_condition_2 ...\n")
  q(status = 1)
}

path <- argv[1]
ngvector_file <- argv[2]
data_matrix_file <- argv[3]
output_file <- argv[4]
norm_out_file <- paste0(output_file, ".normalized_data_matrix")

# Everything after the four fixed arguments is a per-condition replicate count.
rep_args <- argv[-(1:4)]
nc <- length(rep_args)

# Convert quietly and validate explicitly: a non-numeric count would otherwise
# turn into NA and only surface later as an unrelated "missing value where
# TRUE/FALSE needed" error when the counts are summed and compared.
num_reps <- suppressWarnings(as.numeric(rep_args))
bad_reps <- !is.finite(num_reps) | num_reps < 1 | num_reps != floor(num_reps)
if (any(bad_reps)) {
  stop(
    "Number of replicates must be a positive integer for every condition; got: ",
    paste(rep_args[bad_reps], collapse = ", "),
    call. = FALSE
  )
}

.libPaths(c(.libPaths(), path))
library(EBSeq)

# ---- Load the expression matrix and derive per-column condition labels ----
# Each column of the matrix is treated as one replicate, so the replicate
# counts given on the command line must add up to the number of columns.
DataMat <- data.matrix(read.table(data_matrix_file))
n <- ncol(DataMat)
if (n != sum(num_reps)) {
  stop("Total number of replicates given does not match the number of columns from the data matrix!")
}

# Labels "C1", "C2", ... each repeated once per replicate of that condition.
condition_labels <- paste0("C", seq_len(nc))
conditions <- factor(rep(condition_labels, times = num_reps))

# Library size factors (EBSeq MedianNorm) and the matrix normalised by them.
Sizes <- MedianNorm(DataMat)
NormMat <- GetNormalizedMat(DataMat, Sizes)

# "#" is the placeholder meaning "no Ng vector supplied"; ngvector stays NULL.
if (ngvector_file == "#") {
  ngvector <- NULL
} else {
  ngvector <- as.vector(data.matrix(read.table(ngvector_file)))
  stopifnot(!is.null(ngvector))
}

# ---- Differential-expression test and result files ----
# Two conditions: EBTest; the result table goes to `output_file`.
# More than two: EBMultiTest; additionally writes `<output_file>.pattern` and
# `<output_file>.condmeans`.
# In both cases the normalised data matrix is written last to `norm_out_file`.

# Reorder a named vector so it follows `keys`. If the vector carries no names,
# or some key is missing from them, it is returned unchanged (positional order).
.align_by_name <- function(values, keys) {
  if (!is.null(names(values)) && all(keys %in% names(values))) {
    values[keys]
  } else {
    values
  }
}

if (nc == 2) {
  EBOut <- EBTest(
    Data = DataMat, NgVector = ngvector, Conditions = conditions,
    sizeFactors = Sizes, maxround = 5
  )
  stopifnot(!is.null(EBOut))

  PP <- as.data.frame(GetPPMat(EBOut))
  genes <- rownames(PP)
  fc_res <- PostFC(EBOut)

  # The condition means are matched to PP by row name; do the same for the fold
  # changes so that every column of a row describes the same feature even if
  # PostFC() returns its values in a different order than GetPPMat().
  results <- cbind(
    PP,
    .align_by_name(fc_res$PostFC, genes),
    .align_by_name(fc_res$RealFC, genes),
    unlist(EBOut$C1Mean)[genes],
    unlist(EBOut$C2Mean)[genes]
  )
  colnames(results) <- c("PPEE", "PPDE", "PostFC", "RealFC", "C1Mean", "C2Mean")
  # Highest PPDE first.
  results <- results[order(results[, "PPDE"], decreasing = TRUE), ]
  write.table(results, file = output_file, sep = "\t")
} else {
  patterns <- GetPatterns(conditions)
  # The pattern whose row sums to nc is the one excluded from PPDE below;
  # exactly one such pattern must exist.
  eename <- rownames(patterns)[which(rowSums(patterns) == nc)]
  stopifnot(length(eename) == 1)

  MultiOut <- EBMultiTest(
    Data = DataMat, NgVector = ngvector, Conditions = conditions,
    AllParti = patterns, sizeFactors = Sizes, maxround = 5
  )
  stopifnot(!is.null(MultiOut))

  MultiPP <- GetMultiPP(MultiOut)

  PP <- as.data.frame(MultiPP$PP)
  pos <- which(names(PP) == eename)
  # Without this guard an unmatched name would make `PP[, -pos]` an empty
  # selection and PPDE would silently come out as 0 for every row.
  stopifnot(length(pos) == 1)
  # PPDE = summed posterior probability of every pattern except `eename`.
  probs <- rowSums(PP[, -pos, drop = FALSE])

  results <- cbind(PP, MultiPP$MAP[rownames(PP)], probs)
  colnames(results) <- c(colnames(PP), "MAP", "PPDE")
  ord <- order(results[, "PPDE"], decreasing = TRUE)
  results <- results[ord, ]
  write.table(results, file = output_file, sep = "\t")

  write.table(MultiPP$Patterns, file = paste0(output_file, ".pattern"), sep = "\t")

  MultiFC <- GetMultiFC(MultiOut)
  cond_means <- MultiFC$CondMeans
  sorted_ids <- rownames(results)
  # Write the condition means in the same row order as the main result table.
  # Prefer matching by row name; fall back to the positional permutation when
  # the row names cannot be matched.
  if (!is.null(rownames(cond_means)) && all(sorted_ids %in% rownames(cond_means))) {
    cond_means <- cond_means[sorted_ids, , drop = FALSE]
  } else {
    cond_means <- cond_means[ord, , drop = FALSE]
  }
  write.table(cond_means, file = paste0(output_file, ".condmeans"), sep = "\t")
}
write.table(NormMat, file = norm_out_file, sep = "\t")
