#!/usr/bin/env Rscript

### Some constants


# Fixed page grid used when the input list holds transcript/allele IDs, i.e.
# when every requested ID gets a panel of its own.
nrow_per_page <- 3
ncol_per_page <- 2
# Number of panels that fit on one such page.
num_plots_per_page <- nrow_per_page * ncol_per_page


### Load program arguments


# Terminate the script when a required condition does not hold.
#
# expr:   condition that must be true for the script to continue.
# errmsg: text written to stderr (followed by a newline) before quitting
#         with exit status 1.
assert <- function(expr, errmsg) {
  violated <- !expr
  if (violated) {
    cat(errmsg, "\n", file = stderr(), sep = "")
    quit(status = 1, save = "no")
  }
}

# Exactly six trailing command line arguments are required; anything else
# prints the usage line and stops.
args <- commandArgs(trailingOnly = TRUE)
assert(length(args) == 6, "Usage: rsem-gen-transcript-plots sample_name input_list is_allele_specific id_type<0,allele;1,isoform;2,gene> show_uniq output_plot_file")

# Textual arguments.
sample_name <- args[1]
input_list <- args[2]
output_plot_file <- args[6]

# Numeric flags; they are later used both as conditions and in comparisons.
alleleS <- as.numeric(args[3])
id_type <- as.numeric(args[4])
show_uniq <- as.numeric(args[5])


### Load read depth files


# Read a header-less, tab-separated read depth file into a data frame.
#
# file: path of the file to read.
#
# Returns the data frame produced by read.table(), with the values of the
# first column also installed as row names.
load_read_depth <- function(file) {
  # Quote and comment processing are switched off so that IDs containing
  # "'", '"' or "#" are read verbatim; with the read.table() defaults such
  # characters would swallow or truncate the rest of the line.
  depth <- read.table(file, sep = "\t", quote = "", comment.char = "",
                      stringsAsFactors = FALSE)
  rownames(depth) <- depth[, 1]
  depth
}

# Depth table for all reads, its number of rows, and the row permutation that
# sorts it by ID (column 1).
readdepth <- load_read_depth(sprintf("%s.transcript.readdepth", sample_name))
M <- nrow(readdepth)
ord_depth <- order(readdepth[, 1])

# all2uniq[i] is the row of readdepth_uniq that has the same rank in ID order
# as row i of readdepth. It stays empty unless show_uniq is set.
all2uniq <- c()
if (show_uniq) {
  readdepth_uniq <- load_read_depth(sprintf("%s.uniq.transcript.readdepth", sample_name))
  ord_uniq_depth <- order(readdepth_uniq[, 1])

  # Once both tables are sorted by ID, their IDs (column 1) and lengths
  # (column 2) must agree row by row.
  n_id_mismatch <- sum(readdepth[ord_depth, 1] != readdepth_uniq[ord_uniq_depth, 1])
  assert(n_id_mismatch == 0, "transcript/allele IDS in read depth and unique read depth files are not the same!")
  n_len_mismatch <- sum(readdepth[ord_depth, 2] != readdepth_uniq[ord_uniq_depth, 2])
  assert(n_len_mismatch == 0, "transcript lengths in read depth and unique read depth files are not the same!")

  all2uniq[ord_depth] <- ord_uniq_depth
}

cat("Loading read depth files is done!\n")


### Build Gene-Isoform/Gene-Allele map and maps between IDs and ID_NAMEs


# Prefix comparison of IDs: TRUE where `a` equals the leading nchar(a)
# characters of `b`.
id_equal <- function(a, b) {
  prefix_len <- nchar(a)
  b_prefix <- substr(b, 1, prefix_len)
  a == b_prefix
}


# Expression table: "<sample>.alleles.results" in allele-specific mode,
# "<sample>.isoforms.results" otherwise.
expr_data <- read.delim(sprintf("%s.%s.results", sample_name, if (alleleS) "alleles" else "isoforms"), stringsAsFactors = FALSE)
assert(M == nrow(expr_data), "The number of transcripts/alleles contained in the expression file is not equal to the number in the readdepth file!")
ord_expr <- order(expr_data[, 1])

# After sorting both tables by ID, every read depth ID has to be a prefix of
# the expression ID at the same rank. id_equal() is vectorised, so the whole
# comparison is done in one call instead of looping over 1:M.
ids_match <- id_equal(readdepth[ord_depth, 1], expr_data[ord_expr, 1])
assert(all(ids_match), "Transcript/Allele IDs in the expression file is not exactly the same as the ones in the readdepth file!")

# expr2depth maps a row of expr_data (by position, or by ID through its
# names) to the corresponding row of readdepth.
expr2depth <- c() # from id_name to pos
expr2depth[ord_expr] <- ord_depth
names(expr2depth) <- expr_data[, 1]

# Composite mode: each requested ID stands for a group of transcripts/alleles
# (id_type 2 without allele-specific data, id_type 1 or 2 with it).
is_composite <- (!alleleS && (id_type == 2)) || (alleleS && (id_type > 0))

if (is_composite) {
  # Grouping column of expr_data: column 3 for allele-specific data with
  # id_type 2, column 2 otherwise.
  group_col <- if (alleleS && id_type == 2) 3 else 2
  tmp_df <- data.frame(expr2depth, expr_data[, group_col], stringsAsFactors = FALSE)
  # simplify = FALSE keeps the member positions as a list column in every
  # case. With the default (simplify = TRUE) aggregate() turns them into a
  # matrix whenever all groups have the same size, and then
  # tmp_agg[pos, 2][[1]] yields only the first member of a group.
  tmp_agg <- aggregate(tmp_df[1], tmp_df[2], function(x) { x }, simplify = FALSE)
}

cat("Building transcript to gene map is done!\n")

  
### Load and transfer IDs


# One requested ID per line of the input list, surrounding blanks stripped.
ids <- scan(file = input_list, what = "", sep = "\n", strip.white = TRUE)
assert(length(ids) > 0, "You should provide at least one ID.")

# Translate every ID into a position; values below 1 mark IDs that could not
# be resolved.
if (is_composite) {
  # Positions are rows of tmp_agg. charmatch() yields 0 for an ambiguous
  # partial match and -1 (nomatch) when nothing matches.
  poses <- charmatch(ids, tmp_agg[, 1], nomatch = -1)
} else {
  # Positions are rows of readdepth: look the ID up in the expression table
  # first and map it through expr2depth, then fall back to an exact match
  # against the read depth IDs.
  poses <- match(ids, expr_data[, 1])
  in_expr <- !is.na(poses)
  poses[in_expr] <- expr2depth[poses[in_expr]]
  poses[!in_expr] <- match(ids[!in_expr], readdepth[, 1], nomatch = -1)
}

invalid <- poses < 1
n_invalid <- sum(invalid)
if (n_invalid > 0) {
  cat("Warning: The following IDs are not in the RSEM indices and thus ignored: ", paste(ids[invalid], collapse = ", "), "\n", sep = "")
}

# Keep only the IDs that were resolved.
ids <- ids[!invalid]
poses <- poses[!invalid]

assert(length(poses) > 0, "There is no valid ID. Stopped.")


### Generate plots

# Draw the read depth panel of one transcript/allele on the current device.
#
# pos: row position in readdepth; the matching row of readdepth_uniq is found
#      through all2uniq.
#
# Reads the globals readdepth, readdepth_uniq, all2uniq, show_uniq and
# alleleS. Without show_uniq the depth is drawn with plot(type = "h"). With
# show_uniq a stacked bar chart is drawn: the unique-read depth in black and
# the remainder (total minus unique) in red. The panel title is the ID stored
# in column 1 of readdepth.
make_a_plot <- function(pos) {
  id <- readdepth[pos, 1]
  len <- readdepth[pos, 2]

  # A depth field is a space-separated list of per-position values; a missing
  # field (NA) stands for zero depth along the whole length.
  parse_depths <- function(depths) {
    if (is.na(depths)) rep(0, len) else as.numeric(unlist(strsplit(depths, split = " ")))
  }

  wiggle <- parse_depths(readdepth[pos, 3])

  if (!show_uniq) {
    plot(wiggle, type = "h")
  } else {
    wiggle_uniq <- parse_depths(readdepth_uniq[all2uniq[pos], 3])
    if (len != sum(wiggle >= wiggle_uniq)) {
      # The ID is taken from readdepth (the original referenced an undefined
      # variable here, which aborted the script instead of warning), and the
      # offending positions are joined with ", " so they stay readable.
      cat("Warning: ", ifelse(alleleS, "allele-specific transcript", "transcript"), " ", id, " has position(s) that read covarege with multireads is smaller than read covarge without multireads.\n", "         The 1-based position(s) is(are) : ", paste(which(wiggle < wiggle_uniq), collapse = ", "), ".\n", "         This may be due to floating point arithmetics.\n", sep = "")
    }
    heights <- rbind(wiggle_uniq, wiggle - wiggle_uniq)
    barplot(heights, space = 0, border = NA, names.arg = 1:len, col = c("black", "red"))
  }
  title(main = id)
}

# Set up the panel grid of one page and draw one panel per entry of `poses`.
#
# poses: vector of positions, each handed to make_a_plot().
# title: page heading written into the outer margin; only used in composite
#        mode.
#
# In composite mode the grid is sized from the number of panels
# (floor(sqrt(n)) columns and as many rows as needed); otherwise the fixed
# nrow_per_page x ncol_per_page grid is used.
generate_a_page <- function(poses, title = NULL) {
  n <- length(poses)
  if (is_composite) {
    n_cols <- floor(sqrt(n))
    n_rows <- ceiling(n / n_cols)
  } else {
    n_cols <- ncol_per_page
    n_rows <- nrow_per_page
  }

  par(mfrow = c(n_rows, n_cols), mar = c(2, 2, 2, 2))
  # Reserve three lines of outer margin at the top for the page heading.
  if (is_composite) par(oma = c(0, 0, 3, 0))
  for (p in poses) make_a_plot(p)
  if (is_composite) mtext(title, outer = TRUE, line = 1)
}

# Draw page number `i` (1-based) of the non-composite output.
#
# Each page takes the next num_plots_per_page entries of the global `poses`;
# the last page stops at the global `n`, the total number of entries.
plot_individual <- function(i) {
  page_end <- min(i * num_plots_per_page, n)
  page_start <- (i - 1) * num_plots_per_page + 1
  on_this_page <- poses[page_start:page_end]
  generate_a_page(on_this_page)
}

# Draw one page holding every member of a single composite group.
#
# pos: row position in tmp_agg. Column 1 of that row is the composite ID (a
#      gene ID, or a transcript ID for allele-specific expression) and is used
#      as the page title; column 2 holds the readdepth positions of the
#      group's members.
plot_composite <- function(pos) {
  members <- tmp_agg[[2]]
  # aggregate() keeps the members as a list when group sizes differ, but can
  # simplify them to a matrix with one row per group when all groups have the
  # same size. Handle both layouts so that every member gets plotted, not
  # just the first one.
  member_poses <- if (is.matrix(members)) members[pos, ] else members[[pos]]
  generate_a_page(member_poses, tmp_agg[pos, 1])
}


# Everything drawn from here on goes into the output PDF.
pdf(output_plot_file)

if (is_composite) {
  # One page per requested composite ID.
  for (pos in poses) plot_composite(pos)
} else {
  # Spread the individual panels over as many fixed-size pages as needed.
  # `n` has to live at top level because plot_individual() reads it.
  n <- length(ids)
  ub <- (n - 1) %/% num_plots_per_page + 1
  for (page in seq_len(ub)) plot_individual(page)
}

cat("Plots are generated!\n")

# Close the device without echoing its return value.
invisible(dev.off())
