# Revision c2f5ae1aad3831db93d91c932c79387063e08d93 authored by Paul Hoffman on 04 September 2020, 20:04:01 UTC, committed by GitHub on 04 September 2020, 20:04:01 UTC
# 1 parent c680e35
# Raw File
# plotting.old
#' @include seurat.R
NULL

# `cell` and `gene` are data columns referenced through non-standard evaluation
# (ggplot2 aes() / with()) in DoHeatmap; registering them keeps R CMD check from
# reporting them as undefined globals.
globalVariables(
  names = c(
    'cell',
    'gene'
  ),
  package = 'Seurat',
  add = TRUE
)
#' Gene expression heatmap
#'
#' Draws a heatmap of single cell gene expression using ggplot2.
#'
#' @param object Seurat object
#' @param data.use Option to pass in data to use in the heatmap. Default will pick from either
#' object@@data or object@@scale.data depending on use.scaled parameter. Should have cells as columns
#' and genes as rows.
#' @param use.scaled Whether to use the data or scaled data if data.use is NULL.
#' When FALSE, the data are converted to a dense matrix and a \code{disp.max}
#' left at its default of 2.5 is raised to 10.
#' @param cells.use Cells to include in the heatmap (default is all cells)
#' @param genes.use Genes to include in the heatmap (ordered)
#' @param disp.min Minimum display value (all values below are clipped)
#' @param disp.max Maximum display value (all values above are clipped)
#' @param group.by Groups cells by this variable. Default is object@@ident
#' @param group.order Order of groups from left to right in heatmap. Must list
#' every group present exactly once, otherwise an "Invalid group.order" error
#' is raised.
#' @param draw.line Draw vertical lines delineating different groups
#' @param col.low Color for lowest expression value
#' @param col.mid Color for mid expression value
#' @param col.high Color for highest expression value
#' @param slim.col.label display only the identity class name once for each group
#' @param remove.key Removes the color key from the plot.
#' @param rotate.key Rotate color scale horizontally
#' @param title Title for plot
#' @param cex.col Controls size of column labels (cells)
#' @param cex.row Controls size of row labels (genes)
#' @param group.label.loc Place group labels on bottom or top of plot.
#' @param group.label.rot Whether to rotate the group label.
#' @param group.cex Size of group label text
#' @param group.spacing Controls amount of space between columns.
#' @param assay.type Assay to plot heatmap for (default is RNA)
#' @param do.plot Currently unused; the plot object is always returned.
#'
#' @return Returns a ggplot2 plot object
#'
#' @importFrom dplyr %>%
#' @importFrom reshape2 melt
#'
#' @export
#'
#' @examples
#' DoHeatmap(object = pbmc_small)
#'
DoHeatmap <- function(
  object,
  data.use = NULL,
  use.scaled = TRUE,
  cells.use = NULL,
  genes.use = NULL,
  disp.min = -2.5,
  disp.max = 2.5,
  group.by = "ident",
  group.order = NULL,
  draw.line = TRUE,
  col.low = "#FF00FF",
  col.mid = "#000000",
  col.high = "#FFFF00",
  slim.col.label = FALSE,
  remove.key = FALSE,
  rotate.key = FALSE,
  title = NULL,
  cex.col = 10,
  cex.row = 10,
  group.label.loc = "bottom",
  group.label.rot = FALSE,
  group.cex = 15,
  group.spacing = 0.15,
  assay.type = "RNA",
  do.plot = TRUE
) {
  if (is.null(x = data.use)) {
    slot.use <- if (use.scaled) "scale.data" else "data"
    data.use <- GetAssayData(object, assay.type = assay.type, slot = slot.use)
  }
  cells.use <- SetIfNull(x = cells.use, default = object@cell.names)
  cells.use <- intersect(x = cells.use, y = colnames(x = data.use))
  if (length(x = cells.use) == 0) {
    stop("No cells given to cells.use present in object")
  }
  genes.use <- SetIfNull(x = genes.use, default = rownames(x = data.use))
  genes.use <- intersect(x = genes.use, y = rownames(x = data.use))
  if (length(x = genes.use) == 0) {
    stop("No genes given to genes.use present in object")
  }
  if (is.null(x = group.by) || group.by == "ident") {
    cells.ident <- object@ident[cells.use]
  } else {
    cells.ident <- factor(x = FetchData(
      object = object,
      cells.use = cells.use,
      vars.all = group.by
    )[, 1])
    names(x = cells.ident) <- cells.use
  }
  # Drop identity levels that no selected cell belongs to
  cells.ident <- factor(
    x = cells.ident,
    labels = intersect(x = levels(x = cells.ident), y = cells.ident)
  )
  data.use <- data.use[genes.use, cells.use, drop = FALSE]
  if (!use.scaled) {
    data.use <- as.matrix(x = data.use)
    # For unscaled data, swap the default upper clip of 2.5 for 10
    if (isTRUE(x = disp.max == 2.5)) {
      disp.max <- 10
    }
  }
  data.use <- MinMax(data = data.use, min = disp.min, max = disp.max)
  # Reshape to long format: one row per (cell, gene) pair
  data.use <- as.data.frame(x = t(x = data.use))
  data.use$cell <- rownames(x = data.use)
  colnames(x = data.use) <- make.unique(names = colnames(x = data.use))
  data.use <- melt(data = data.use, id.vars = "cell")
  names(x = data.use)[names(x = data.use) == 'variable'] <- 'gene'
  names(x = data.use)[names(x = data.use) == 'value'] <- 'expression'
  # Look up identities while `cell` is still character, so indexing is by name
  data.use$ident <- cells.ident[data.use$cell]
  if (!is.null(x = group.order)) {
    ident.levels <- levels(x = data.use$ident)
    # Each present group must appear exactly once; duplicates would otherwise
    # pass the length check and then fail inside factor()
    valid.order <- length(x = group.order) == length(x = ident.levels) &&
      !anyDuplicated(x = group.order) &&
      all(group.order %in% ident.levels)
    if (!valid.order) {
      stop("Invalid group.order")
    }
    data.use$ident <- factor(x = data.use$ident, levels = group.order)
  }
  data.use$gene <- factor(
    x = data.use$gene,
    levels = rev(x = unique(x = data.use$gene))
  )
  data.use$cell <- factor(x = data.use$cell, levels = cells.use)
  if (rotate.key) {
    key.direction <- "horizontal"
    key.title.pos <- "top"
  } else {
    key.direction <- "vertical"
    key.title.pos <- "left"
  }
  heatmap <- ggplot(
    data = data.use,
    mapping = aes(x = cell, y = gene, fill = expression)
  ) +
    geom_tile() +
    scale_fill_gradient2(
      low = col.low,
      mid = col.mid,
      high = col.high,
      name = "Expression",
      guide = guide_colorbar(
        direction = key.direction,
        title.position = key.title.pos
      )
    ) +
    scale_y_discrete(position = "right", labels = rev(genes.use)) +
    theme(
      axis.line = element_blank(),
      axis.title.y = element_blank(),
      axis.ticks.y = element_blank(),
      strip.text.x = element_text(size = group.cex),
      axis.text.y = element_text(size = cex.row),
      axis.text.x = element_text(size = cex.col),
      axis.title.x = element_blank()
    )
  if (slim.col.label) {
    heatmap <- heatmap +
      theme(
        axis.title.x = element_blank(),
        axis.text.x = element_blank(),
        axis.ticks.x = element_blank(),
        axis.line = element_blank(),
        axis.title.y = element_blank(),
        axis.ticks.y = element_blank()
      )
  } else {
    heatmap <- heatmap + theme(axis.text.x = element_text(angle = 90))
  }
  if (!is.null(x = group.by)) {
    # facet_grid(switch = 'x') moves the group strips below the panels
    strip.switch <- if (group.label.loc == "top") NULL else 'x'
    heatmap <- heatmap +
      facet_grid(
        facets = ~ident,
        drop = TRUE,
        space = "free",
        scales = "free",
        switch = strip.switch
      ) +
      scale_x_discrete(expand = c(0, 0), drop = TRUE)
    if (draw.line) {
      panel.spacing <- unit(x = group.spacing, units = 'lines')
    } else {
      panel.spacing <- unit(x = 0, units = 'lines')
    }
    heatmap <- heatmap +
      theme(strip.background = element_blank(), panel.spacing = panel.spacing)
    if (group.label.rot) {
      heatmap <- heatmap + theme(strip.text.x = element_text(angle = 90))
    }
  }
  if (remove.key) {
    heatmap <- heatmap + theme(legend.position = "none")
  }
  if (!is.null(x = title)) {
    heatmap <- heatmap + labs(title = title)
  }
  return(heatmap)
}

#' Single cell violin plot
#'
#' Draws a violin plot of single cell data (gene expression, metrics, PC
#' scores, etc.)
#'
#' @param object Seurat object
#' @param features.plot Features to plot (gene expression, metrics, PC scores,
#' anything that can be retrieved by FetchData)
#' @param ident.include Which classes to include in the plot (default is all)
#' @param nCol Number of columns if multiple plots are displayed
#' @param do.sort Sort identity classes (on the x-axis) by the average
#' expression of the attribute being plotted
#' @param y.max Maximum y axis value
#' @param same.y.lims Set all the y-axis limits to the same values
#' @param size.x.use X axis title font size
#' @param size.y.use Y axis title font size
#' @param size.title.use Main title font size
#' @param adjust.use Adjust parameter for geom_violin
#' @param point.size.use Point size for geom_violin
#' @param cols.use Colors to use for plotting
#' @param group.by Group (color) cells in different ways (for example, orig.ident)
#' @param y.log plot Y axis on log scale
#' @param x.lab.rot Rotate x-axis labels
#' @param y.lab.rot Rotate y-axis labels
#' @param legend.position Position the legend for the plot
#' @param single.legend Consolidate the legend for all plots
#' @param remove.legend Remove the legend from the plot
#' @param do.return Return the plot object(s) instead of only drawing them
#' (default : FALSE)
#' @param return.plotlist When \code{do.return = TRUE}, return the list of
#' individual plots instead of the combined plot.
#' @param \dots additional parameters to pass to FetchData (for example, use.imputed, use.scaled, use.raw)
#'
#' @import ggplot2
#' @importFrom cowplot plot_grid get_legend
#'
#' @return If \code{do.return = TRUE}, the combined plot, or the list of
#' individual ggplot objects when \code{return.plotlist = TRUE}. Otherwise the
#' combined plot is printed and returned invisibly.
#'
#' @export
#'
#' @examples
#' VlnPlot(object = pbmc_small, features.plot = 'PC1')
#'
VlnPlot <- function(
  object,
  features.plot,
  ident.include = NULL,
  nCol = NULL,
  do.sort = FALSE,
  y.max = NULL,
  same.y.lims = FALSE,
  size.x.use = 16,
  size.y.use = 16,
  size.title.use = 20,
  adjust.use = 1,
  point.size.use = 1,
  cols.use = NULL,
  group.by = NULL,
  y.log = FALSE,
  x.lab.rot = FALSE,
  y.lab.rot = FALSE,
  legend.position = "right",
  single.legend = TRUE,
  remove.legend = FALSE,
  do.return = FALSE,
  return.plotlist = FALSE,
  ...
) {
  if (is.null(x = nCol)) {
    if (length(x = features.plot) > 9) {
      nCol <- 4
    } else {
      nCol <- min(length(x = features.plot), 3)
    }
  }
  data.use <- data.frame(
    FetchData(object = object, vars.all = features.plot, ...),
    check.names = FALSE
  )
  if (is.null(x = ident.include)) {
    cells.to.include <- object@cell.names
  } else {
    cells.to.include <- WhichCells(object = object, ident = ident.include)
  }
  data.use <- data.use[cells.to.include, , drop = FALSE]
  if (!is.null(x = group.by)) {
    ident.use <- as.factor(x = FetchData(
      object = object,
      vars.all = group.by
    )[cells.to.include, 1])
  } else {
    ident.use <- object@ident[cells.to.include]
  }
  gene.names <- colnames(x = data.use)[colnames(x = data.use) %in% rownames(x = object@data)]
  # NOTE(review): forcing remove.legend here makes the shared-legend branch
  # below (single.legend && !remove.legend) unreachable, so no consolidated
  # legend is ever drawn. Kept as-is to preserve the current output.
  if (single.legend) {
    remove.legend <- TRUE
  }
  if (same.y.lims && is.null(x = y.max)) {
    y.max <- max(data.use)
  }
  plots <- lapply(
    X = features.plot,
    FUN = function(x) {
      return(SingleVlnPlot(
        feature = x,
        data = data.use[, x, drop = FALSE],
        cell.ident = ident.use,
        do.sort = do.sort, y.max = y.max,
        size.x.use = size.x.use,
        size.y.use = size.y.use,
        size.title.use = size.title.use,
        adjust.use = adjust.use,
        point.size.use = point.size.use,
        cols.use = cols.use,
        gene.names = gene.names,
        y.log = y.log,
        x.lab.rot = x.lab.rot,
        y.lab.rot = y.lab.rot,
        legend.position = legend.position,
        remove.legend = remove.legend
      ))
    }
  )
  if (length(x = features.plot) > 1) {
    plots.combined <- plot_grid(plotlist = plots, ncol = nCol)
    if (single.legend && !remove.legend) {
      legend <- get_legend(
        plot = plots[[1]] + theme(legend.position = legend.position)
      )
      if (legend.position == "bottom") {
        plots.combined <- plot_grid(
          plots.combined,
          legend,
          ncol = 1,
          rel_heights = c(1, .2)
        )
      } else if (legend.position == "right") {
        plots.combined <- plot_grid(
          plots.combined,
          legend,
          rel_widths = c(3, .3)
        )
      } else {
        warning("Shared legends must be at the bottom or right of the plot")
      }
    }
  } else {
    plots.combined <- plots[[1]]
  }
  if (do.return) {
    if (return.plotlist) {
      return(plots)
    }
    return(plots.combined)
  }
  # Draw explicitly so the plot also appears when called from a loop or another
  # function, then return it invisibly so it is not drawn twice at top level
  print(plots.combined)
  invisible(x = plots.combined)
}

#' Single cell ridge plot
#'
#' Draws a ridge plot of single cell data (gene expression, metrics, PC
#' scores, etc.)
#'
#' @param object Seurat object
#' @param features.plot Features to plot (gene expression, metrics, PC scores,
#' anything that can be retrieved by FetchData)
#' @param ident.include Which classes to include in the plot (default is all)
#' @param nCol Number of columns if multiple plots are displayed
#' @param do.sort Sort identity classes (on the x-axis) by the average
#' expression of the attribute being plotted
#' @param y.max Maximum y axis value
#' @param same.y.lims Set all the y-axis limits to the same values
#' @param size.x.use X axis title font size
#' @param size.y.use Y axis title font size
#' @param size.title.use Main title font size
#' @param cols.use Colors to use for plotting
#' @param group.by Group (color) cells in different ways (for example, orig.ident)
#' @param y.log plot Y axis on log scale
#' @param x.lab.rot Rotate x-axis labels
#' @param y.lab.rot Rotate y-axis labels
#' @param legend.position Position the legend for the plot
#' @param single.legend Consolidate the legend for all plots
#' @param remove.legend Remove the legend from the plot
#' @param do.return Return the plot object(s) instead of only drawing them
#' (default : FALSE)
#' @param return.plotlist When \code{do.return = TRUE}, return the list of
#' individual plots instead of the combined plot.
#' @param \dots additional parameters to pass to FetchData (for example, use.imputed, use.scaled, use.raw)
#'
#' @import ggplot2
#' @importFrom cowplot get_legend plot_grid
#' @importFrom ggridges geom_density_ridges theme_ridges
#'
#' @return If \code{do.return = TRUE}, the combined plot, or the list of
#' individual ggplot objects when \code{return.plotlist = TRUE}. Otherwise the
#' combined plot is printed and returned invisibly.
#'
#' @export
#'
#' @examples
#' RidgePlot(object = pbmc_small, features.plot = 'PC1')
#'
RidgePlot <- function(
  object,
  features.plot,
  ident.include = NULL,
  nCol = NULL,
  do.sort = FALSE,
  y.max = NULL,
  same.y.lims = FALSE,
  size.x.use = 16,
  size.y.use = 16,
  size.title.use = 20,
  cols.use = NULL,
  group.by = NULL,
  y.log = FALSE,
  x.lab.rot = FALSE,
  y.lab.rot = FALSE,
  legend.position = "right",
  single.legend = TRUE,
  remove.legend = FALSE,
  do.return = FALSE,
  return.plotlist = FALSE,
  ...
) {
  if (is.null(x = nCol)) {
    if (length(x = features.plot) > 9) {
      nCol <- 4
    } else {
      nCol <- min(length(x = features.plot), 3)
    }
  }
  data.use <- data.frame(
    FetchData(
      object = object,
      vars.all = features.plot,
      ...
    ),
    check.names = FALSE
  )
  if (is.null(x = ident.include)) {
    cells.to.include <- object@cell.names
  } else {
    cells.to.include <- WhichCells(object = object, ident = ident.include)
  }
  data.use <- data.use[cells.to.include, , drop = FALSE]
  if (!is.null(x = group.by)) {
    ident.use <- as.factor(x = FetchData(
      object = object,
      vars.all = group.by
    )[cells.to.include, 1])
  } else {
    ident.use <- object@ident[cells.to.include]
  }
  gene.names <- colnames(x = data.use)[colnames(x = data.use) %in% rownames(x = object@data)]
  # NOTE(review): forcing remove.legend here makes the shared-legend branch
  # below (single.legend && !remove.legend) unreachable, so no consolidated
  # legend is ever drawn. Kept as-is to preserve the current output.
  if (single.legend) {
    remove.legend <- TRUE
  }
  if (same.y.lims && is.null(x = y.max)) {
    y.max <- max(data.use)
  }
  plots <- lapply(
    X = features.plot,
    FUN = function(x) {
      return(SingleRidgePlot(
        feature = x,
        data = data.use[, x, drop = FALSE],
        cell.ident = ident.use,
        do.sort = do.sort, y.max = y.max,
        size.x.use = size.x.use,
        size.y.use = size.y.use,
        size.title.use = size.title.use,
        cols.use = cols.use,
        gene.names = gene.names,
        y.log = y.log,
        x.lab.rot = x.lab.rot,
        y.lab.rot = y.lab.rot,
        legend.position = legend.position,
        remove.legend = remove.legend
      ))
    }
  )
  if (length(x = features.plot) > 1) {
    plots.combined <- plot_grid(plotlist = plots, ncol = nCol)
    if (single.legend && !remove.legend) {
      legend <- get_legend(
        plot = plots[[1]] + theme(legend.position = legend.position)
      )
      if (legend.position == "bottom") {
        plots.combined <- plot_grid(
          plots.combined,
          legend,
          ncol = 1,
          rel_heights = c(1, .2)
        )
      } else if (legend.position == "right") {
        plots.combined <- plot_grid(
          plots.combined,
          legend,
          rel_widths = c(3, .3)
        )
      } else {
        warning("Shared legends must be at the bottom or right of the plot")
      }
    }
  } else {
    plots.combined <- plots[[1]]
  }
  if (do.return) {
    if (return.plotlist) {
      return(plots)
    }
    return(plots.combined)
  }
  # Draw explicitly so the plot also appears when called from a loop or another
  # function, then return it invisibly so it is not drawn twice at top level
  print(plots.combined)
  invisible(x = plots.combined)
}

#' Old Dot plot visualization (pre-ggplot implementation)
#'
#' Intuitive way of visualizing how gene expression changes across different identity classes (clusters).
#' The size of the dot encodes the percentage of cells within a class, while the color encodes the
#' AverageExpression level of 'expressing' cells (green is high with the default colors).
#'
#' @param object Seurat object
#' @param genes.plot Input vector of genes
#' @param cex.use Scaling factor for the dots (scales all dot sizes)
#' @param cols.use colors to plot; the lowest scaled expression maps to the
#' first color and the highest to the last
#' @param thresh.col Clipping threshold for the gene-wise scaled average
#' expression; values are clipped to [-thresh.col, thresh.col] before being
#' mapped onto \code{cols.use}
#' @param dot.min Constant added to every dot's size (cex) (default is 0.05)
#' @param group.by Factor to group the cells by
#'
#' @return Only graphical output
#'
#' @importFrom graphics axis plot
#'
#' @export
#'
#' @examples
#' cd_genes <- c("CD247", "CD3E", "CD9")
#' DotPlotOld(object = pbmc_small, genes.plot = cd_genes)
#'
DotPlotOld <- function(
  object,
  genes.plot,
  cex.use = 2,
  cols.use = NULL,
  thresh.col = 2.5,
  dot.min = 0.05,
  group.by = NULL
) {
  if (!is.null(x = group.by)) {
    object <- SetAllIdent(object = object, id = group.by)
  }
  # Restrict the expression data to the requested genes (genes as rows)
  object@data <- data.frame(t(x = FetchData(object = object, vars.all = genes.plot)))
  # data.frame() may mangle cell names (e.g. a '-'); restore the originals
  colnames(x = object@data) <- object@cell.names
  avg.exp <- AverageExpression(object = object)
  avg.alpha <- AverageDetectionRate(object = object)
  cols.use <- SetIfNull(
    x = cols.use,
    default = CustomPalette(low = "red", high = "green")
  )
  # Scale each gene (row) across groups, then clip to [-thresh.col, thresh.col]
  exp.scale <- t(x = scale(x = t(x = avg.exp)))
  exp.scale <- MinMax(data = exp.scale, max = thresh.col, min = (-1) * thresh.col)
  n.col <- length(x = cols.use)
  # One point per (gene, group) entry of the averages table, enumerated row by
  # row: x indexes genes (rows), y indexes groups (columns)
  n.genes <- nrow(x = avg.exp)
  n.groups <- ncol(x = avg.exp)
  data.x <- rep(x = seq_len(length.out = n.genes), each = n.groups)
  data.y <- rep(x = seq_len(length.out = n.groups), times = n.genes)
  point.index <- cbind(data.x, data.y)
  data.avg <- as.matrix(x = exp.scale)[point.index]
  # Map clipped values onto palette positions. Values close to -thresh.col
  # round down to index 0, which `[` silently drops (leaving fewer colors than
  # points and misaligning the rest), so clamp indices into 1..n.col.
  col.index <- floor(x = n.col * (data.avg + thresh.col) / (2 * thresh.col) + 0.5)
  col.index <- pmin(pmax(col.index, 1), n.col)
  exp.col <- cols.use[col.index]
  data.cex <- as.matrix(x = avg.alpha)[point.index] * cex.use + dot.min
  plot(
    x = data.x,
    y = data.y,
    cex = data.cex,
    pch = 16,
    col = exp.col,
    xaxt = "n",
    xlab = "",
    ylab = "",
    yaxt = "n"
  )
  axis(side = 1, at = seq_along(along.with = genes.plot), labels = genes.plot)
  axis(
    side = 2,
    at = seq_len(length.out = ncol(x = avg.alpha)),
    labels = colnames(x = avg.alpha),
    las = 1
  )
}

# Data columns referenced through non-standard evaluation (dplyr verbs and
# ggplot2 aes()) in the dot plot below; registered so R CMD check does not
# report them as undefined globals.
globalVariables(
  names = c(
    'cell',
    'id',
    'avg.exp',
    'avg.exp.scale',
    'pct.exp'
  ),
  package = 'Seurat',
  add = TRUE
)
#' Dot plot visualization
#'
#' Intuitive way of visualizing how gene expression changes across different
#' identity classes (clusters). The size of the dot encodes the percentage of
#' cells within a class, while the color encodes the AverageExpression level of
#' cells within a class (blue is high).
#'
#' @param object Seurat object
#' @param genes.plot Input vector of genes
#' @param cols.use Colors to plot, can pass a single character giving the name of
#' a palette from \code{RColorBrewer::brewer.pal.info}
#' @param col.min Minimum scaled average expression threshold (everything smaller
#'  will be set to this)
#' @param col.max Maximum scaled average expression threshold (everything larger
#' will be set to this)
#' @param dot.min The fraction of cells at which to draw the smallest dot
#' (default is 0). All cell groups with less than this expressing the given
#' gene will have no dot drawn.
#' @param dot.scale Scale the size of the points, similar to cex
#' @param scale.by Scale the size of the points by 'size' or by 'radius'
#' @param scale.min Set lower limit for scaling, use NA for default
#' @param scale.max Set upper limit for scaling, use NA for default
#' @param group.by Factor to group the cells by
#' @param plot.legend plots the legends
#' @param x.lab.rot Rotate x-axis labels
#' @param do.return Return ggplot2 object
#'
#' @return Draws the plot. If do.return=TRUE, also returns the ggplot2 object
#' invisibly (so it is not drawn a second time at the console).
#'
#' @importFrom tidyr gather
#' @importFrom dplyr %>% group_by summarize_each summarize mutate ungroup
#'
#' @export
#' @seealso \code{RColorBrewer::brewer.pal.info}
#'
#' @examples
#' cd_genes <- c("CD247", "CD3E", "CD9")
#' DotPlot(object = pbmc_small, genes.plot = cd_genes)
#'
DotPlot <- function(
  object,
  genes.plot,
  cols.use = c("lightgrey", "blue"),
  col.min = -2.5,
  col.max = 2.5,
  dot.min = 0,
  dot.scale = 6,
  scale.by = 'radius',
  scale.min = NA,
  scale.max = NA,
  group.by,
  plot.legend = FALSE,
  do.return = FALSE,
  x.lab.rot = FALSE
) {
  scale.func <- switch(
    EXPR = scale.by,
    'size' = scale_size,
    'radius' = scale_radius,
    stop("'scale.by' must be either 'size' or 'radius'")
  )
  if (!missing(x = group.by)) {
    object <- SetAllIdent(object = object, id = group.by)
  }
  data.to.plot <- data.frame(FetchData(object = object, vars.all = genes.plot))
  # data.frame() may mangle gene names (e.g. '-' -> '.'); restore the originals
  colnames(x = data.to.plot) <- genes.plot
  data.to.plot$cell <- rownames(x = data.to.plot)
  data.to.plot$id <- object@ident
  # Long format: one row per (cell, gene)
  data.to.plot <- data.to.plot %>%
    gather(key = genes.plot, value = expression, -c(cell, id))
  # Per (identity, gene): mean of expm1(expression), and pct.exp from
  # PercentAbove(threshold = 0) (multiplied by 100 further down)
  data.to.plot <- data.to.plot %>%
    group_by(id, genes.plot) %>%
    summarize(
      avg.exp = mean(expm1(x = expression)),
      pct.exp = PercentAbove(x = expression, threshold = 0)
    )
  # Scale each gene's averages across identities, clip to [col.min, col.max]
  data.to.plot <- data.to.plot %>%
    ungroup() %>%
    group_by(genes.plot) %>%
    mutate(avg.exp.scale = scale(x = avg.exp)) %>%
    mutate(avg.exp.scale = MinMax(
      data = avg.exp.scale,
      max = col.max,
      min = col.min
    )) %>%
    ungroup()
  data.to.plot$genes.plot <- factor(
    x = data.to.plot$genes.plot,
    levels = rev(x = genes.plot)
  )
  # Groups below dot.min get NA size, i.e. no dot
  data.to.plot$pct.exp[data.to.plot$pct.exp < dot.min] <- NA
  data.to.plot$pct.exp <- data.to.plot$pct.exp * 100
  p <- ggplot(data = data.to.plot, mapping = aes(x = genes.plot, y = id)) +
    geom_point(mapping = aes(size = pct.exp, color = avg.exp.scale)) +
    scale.func(range = c(0, dot.scale), limits = c(scale.min, scale.max)) +
    theme(axis.title.x = element_blank(), axis.title.y = element_blank())
  if (length(x = cols.use) == 1) {
    p <- p + scale_color_distiller(palette = cols.use)
  } else {
    p <- p + scale_color_gradient(low = cols.use[1], high = cols.use[2])
  }
  if (!plot.legend) {
    p <- p + theme(legend.position = "none")
  }
  if (x.lab.rot) {
    p <- p + theme(axis.text.x = element_text(angle = 90, vjust = 0.5))
  }
  suppressWarnings(print(p))
  if (do.return) {
    # Already drawn above; return invisibly to avoid a second auto-print
    return(invisible(x = p))
  }
}

# Data columns referenced through non-standard evaluation (dplyr/tidyr verbs and
# ggplot2 aes()) in SplitDotPlotGG below; registered so R CMD check does not
# report them as undefined globals.
globalVariables(
  names = c(
    'cell',
    'id',
    'avg.exp',
    'pct.exp',
    'ptcolor',
    'ident2'
  ),
  package = 'Seurat',
  add = TRUE
)
#' Split Dot plot visualization
#'
#' Intuitive way of visualizing how gene expression changes across different identity classes (clusters).
#' The size of the dot encodes the percentage of cells within a class, while the color encodes the
#' AverageExpression level of 'expressing' cells. Splits the cells into groups based on a
#' grouping variable.
#' Still in BETA
#'
#' @param object Seurat object
#' @param grouping.var Grouping variable for splitting the dataset
#' @param genes.plot Input vector of genes
#' @param cols.use colors to plot, at least one per value of the grouping
#' variable (extra colors are ignored)
#' @param col.min Minimum scaled average expression threshold (everything smaller will be set to this)
#' @param col.max Maximum scaled average expression threshold (everything larger will be set to this)
#' @param dot.min The fraction of cells at which to draw the smallest dot (default is 0).
#' Groups with less than this expressing the gene have no dot drawn.
#' @param dot.scale Scale the size of the points, similar to cex
#' @param group.by Factor to group the cells by
#' @param plot.legend plots the legends
#' @param x.lab.rot Rotate x-axis labels
#' @param do.return Return ggplot2 object
#' @param gene.groups Add labeling bars to the top of the plot; one group label
#' per entry of \code{genes.plot}
#'
#' @return Draws the plot. If do.return=TRUE, also returns the ggplot2 object
#' invisibly (so it is not drawn a second time at the console).
#'
#' @importFrom grDevices colorRampPalette
#' @importFrom tidyr gather separate unite
#' @importFrom dplyr %>% group_by summarize_each summarize mutate ungroup rowwise
#'
#' @export
#'
#' @examples
#' # Create a simulated grouping variable
#' pbmc_small@meta.data$groups <- sample(
#'   x = c("g1", "g2"),
#'   size = length(x = pbmc_small@cell.names),
#'   replace = TRUE
#' )
#' SplitDotPlotGG(pbmc_small, grouping.var = "groups", genes.plot = pbmc_small@var.genes[1:5])
#'
SplitDotPlotGG <- function(
  object,
  grouping.var,
  genes.plot,
  gene.groups,
  cols.use = c("blue", "red"),
  col.min = -2.5,
  col.max = 2.5,
  dot.min = 0,
  dot.scale = 6,
  group.by,
  plot.legend = FALSE,
  do.return = FALSE,
  x.lab.rot = FALSE
) {
  if (!missing(x = group.by)) {
    object <- SetAllIdent(object = object, id = group.by)
  }
  # Coerce to a factor (dropping unused levels) so levels() below is never NULL,
  # e.g. when the grouping variable is stored as a character column
  grouping.data <- factor(x = FetchData(
    object = object,
    vars.all = grouping.var
  )[names(x = object@ident), 1])
  ncolor <- length(x = cols.use)
  ngroups <- length(x = unique(x = grouping.data))
  if (ncolor < ngroups) {
    stop(
      paste(
        "Not enough colors supplied for number of grouping variables. Need",
        ngroups,
        "got",
        ncolor,
        "colors"
      )
    )
  } else if (ncolor > ngroups) {
    cols.use <- cols.use[1:ngroups]
  }
  idents.old <- levels(x = object@ident)
  idents.new <- paste(object@ident, grouping.data, sep = "_")
  colorlist <- cols.use
  names(x = colorlist) <- levels(x = grouping.data)
  # Split identities are ordered <ident>_<group1>, <ident>_<group2>, ... per ident
  object@ident <- factor(
    x = idents.new,
    levels = unlist(x = lapply(
      X = idents.old,
      FUN = function(x) {
        return(paste(x, levels(x = grouping.data), sep = "_"))
      }
    )),
    ordered = TRUE
  )
  data.to.plot <- data.frame(FetchData(object = object, vars.all = genes.plot))
  # data.frame() may mangle gene names (e.g. '-' -> '.'); restore the originals
  # so factor levels and gene.groups lookups below match
  colnames(x = data.to.plot) <- genes.plot
  data.to.plot$cell <- rownames(x = data.to.plot)
  data.to.plot$id <- object@ident
  data.to.plot <- data.to.plot %>%
    gather(key = genes.plot, value = expression, -c(cell, id))
  data.to.plot <- data.to.plot %>%
    group_by(id, genes.plot) %>%
    summarize(
      avg.exp = ExpMean(x = expression),
      pct.exp = PercentAbove(x = expression, threshold = 0)
    )
  # Scale per gene, clip to [col.min, col.max], and bin into 20 color steps
  data.to.plot <- data.to.plot %>%
    ungroup() %>%
    group_by(genes.plot) %>%
    mutate(avg.exp = scale(x = avg.exp)) %>%
    mutate(avg.exp.scale = as.numeric(x = cut(
      x = MinMax(data = avg.exp, max = col.max, min = col.min),
      breaks = 20
    )))
  # Split "<ident>_<group>" apart to pick each point's color from its group's
  # grey-to-color ramp.
  # NOTE(review): separate() splits on every "_", so identity or group names
  # that themselves contain "_" will not round-trip correctly.
  data.to.plot <- data.to.plot %>%
    separate(col = id, into = c('ident1', 'ident2'), sep = "_") %>%
    rowwise() %>%
    mutate(
      palette.use = colorlist[[ident2]],
      ptcolor = colorRampPalette(colors = c("grey", palette.use))(20)[avg.exp.scale]
    ) %>%
    unite('id', c('ident1', 'ident2'), sep = '_') %>%
    ungroup()
  data.to.plot$genes.plot <- factor(
    x = data.to.plot$genes.plot,
    levels = rev(x = genes.plot)
  )
  data.to.plot$pct.exp[data.to.plot$pct.exp < dot.min] <- NA
  data.to.plot$id <- factor(x = data.to.plot$id, levels = levels(object@ident))
  palette.use <- unique(x = data.to.plot$palette.use)
  if (!missing(x = gene.groups)) {
    names(x = gene.groups) <- genes.plot
    # The genes.plot column is a factor here; look up by its labels, since
    # indexing with a factor would use its integer codes instead
    data.to.plot <- data.to.plot %>%
      mutate(gene.groups = unname(gene.groups[as.character(x = genes.plot)]))
  }
  data.to.plot$pct.exp <- data.to.plot$pct.exp * 100
  p <- ggplot(data = data.to.plot, mapping = aes(x = genes.plot, y = id)) +
    geom_point(mapping = aes(size = pct.exp, color = ptcolor)) +
    scale_radius(range = c(0, dot.scale)) +
    scale_color_identity() +
    theme(axis.title.x = element_blank(), axis.title.y = element_blank())
  if (!missing(x = gene.groups)) {
    p <- p +
      facet_grid(
        facets = ~gene.groups,
        scales = "free_x",
        space = "free_x",
        switch = "y"
      ) +
      theme(
        panel.spacing = unit(x = 1, units = "lines"),
        strip.background = element_blank(),
        strip.placement = "outside"
      )
  }
  if (x.lab.rot) {
    p <- p + theme(axis.text.x = element_text(angle = 90, vjust = 0.5))
  }
  if (plot.legend) {
    # Legend drawn by the plot itself (the color scale is identity-mapped)
    point.legend <- cowplot::get_legend(plot = p)
    # One grey-to-color gradient legend per grouping level
    palettes <- lapply(
      X = colorlist,
      FUN = function(color) {
        return(colorRampPalette(colors = c("grey", color))(20))
      }
    )
    gradient.legends <- mapply(
      FUN = GetGradientLegend,
      palette = palettes,
      group = names(x = palettes),
      SIMPLIFY = FALSE,
      USE.NAMES = FALSE
    )
    p <- p + theme(legend.position = "none")
    # Stack the point legend above the gradient legends
    legends <- cowplot::plot_grid(
      point.legend,
      plotlist = gradient.legends,
      ncol = 1,
      rel_heights = c(1, rep.int(x = 0.5, times = length(x = gradient.legends))),
      scale = rep(0.5, length(gradient.legends)), align = "hv"
    )
    p <- cowplot::plot_grid(
      p, legends,
      ncol = 2,
      rel_widths = c(1, 0.3),
      scale = c(1, 0.8)
    )
  } else {
    p <- p + theme(legend.position = "none")
  }
  suppressWarnings(print(p))
  if (do.return) {
    # Already drawn above; return invisibly to avoid a second auto-print
    return(invisible(x = p))
  }
}

#' Visualize 'features' on a dimensional reduction plot
#'
#' Colors single cells on a dimensional reduction plot according to a 'feature'
#' (i.e. gene expression, PC scores, number of genes detected, etc.)
#'
#' @param object Seurat object
#' @param features.plot Vector of features to plot
#' @param min.cutoff Vector of minimum cutoff values for each feature, may
#' specify quantile in the form of 'q##' where '##' is the quantile (eg, 1, 10).
#' Must have length 1 (applied to every feature) or one value per feature;
#' \code{NA} means "use the observed minimum of that feature".
#' @param max.cutoff Vector of maximum cutoff values for each feature, may
#' specify quantile in the form of 'q##' where '##' is the quantile (eg, 1, 10).
#' Must have length 1 (applied to every feature) or one value per feature;
#' \code{NA} means "use the observed maximum of that feature".
#' @param dim.1 Dimension for x-axis (default 1)
#' @param dim.2 Dimension for y-axis (default 2)
#' @param cells.use Vector of cells to plot (default is all cells)
#' @param pt.size Adjust point size for plotting
#' @param cols.use The two colors to form the gradient over. Provide as string vector with
#' the first color corresponding to low values, the second to high. Also accepts a Brewer
#' color scale or vector of colors. Note: this will bin the data into number of colors provided.
#' @param pch.use Pch for plotting
#' @param overlay Plot two features overlayed one on top of the other
#' @param do.hover Enable hovering over points to view information
#' @param data.hover Data to add to the hover, pass a character vector of features to add. Defaults to cell name and identity. Pass 'NULL' to remove extra data.
#' @param do.identify Opens a locator session to identify clusters of cells
#' @param reduction.use Which dimensionality reduction to use. Default is
#' "tsne", can also be "pca", or "ica", assuming these are precomputed.
#' @param use.imputed Use imputed values for gene expression (default is FALSE)
#' @param nCol Number of columns to use when plotting multiple features.
#' @param no.axes Remove axis labels
#' @param no.legend Remove legend from the graph. Default is TRUE.
#' @param coord.fixed Use a fixed scale coordinate system (for spatial coordinates). Default is FALSE.
#' @param dark.theme Plot in a dark theme
#' @param do.return Return the list of ggplot2 objects
#' @param vector.friendly FALSE by default. If TRUE, points are flattened into a PNG, while axes/labels retain full vector resolution. Useful for producing AI-friendly plots with large numbers of cells.
#' @param png.file Use specific name for temporary png file
#' @param png.arguments Set width, height, and DPI for ggsave
#'
#' @importFrom RColorBrewer brewer.pal.info
#'
#' @return If \code{do.hover} is TRUE, the value returned by \code{HoverLocator};
#' if \code{do.identify} is TRUE, the value returned by \code{FeatureLocator}.
#' Otherwise the plots are printed and, when \code{do.return} is TRUE, the list
#' of ggplot objects (one per feature, or a single blended plot when
#' \code{overlay} is TRUE) is returned.
#'
#' @export
#'
#' @examples
#' FeaturePlot(object = pbmc_small, features.plot = 'PC1')
#'
FeaturePlot <- function(
  object,
  features.plot,
  min.cutoff = NA,
  max.cutoff = NA,
  dim.1 = 1,
  dim.2 = 2,
  cells.use = NULL,
  pt.size = 1,
  cols.use = c("yellow", "red"),
  pch.use = 16,
  overlay = FALSE,
  do.hover = FALSE,
  data.hover = 'ident',
  do.identify = FALSE,
  reduction.use = "tsne",
  use.imputed = FALSE,
  nCol = NULL,
  no.axes = FALSE,
  no.legend = TRUE,
  coord.fixed = FALSE,
  dark.theme = FALSE,
  do.return = FALSE,
  vector.friendly = FALSE,
  png.file = NULL,
  png.arguments = c(10,10, 100)
) {
  cells.use <- SetIfNull(x = cells.use, default = colnames(x = object@data))
  n.features <- length(x = features.plot)
  # Validate cutoff lengths before mapply(): mapply() silently recycles
  # mismatched vectors, so checking lengths afterwards can never fail
  if (!(length(x = min.cutoff) %in% c(1, n.features)) ||
      !(length(x = max.cutoff) %in% c(1, n.features))) {
    stop('There must be the same number of minimum and maximum cuttoffs as there are features')
  }
  if (is.null(x = nCol)) {
    nCol <- 2
    if (n.features == 1) {
      nCol <- 1
    }
    if (n.features > 6) {
      nCol <- 3
    }
    if (n.features > 9) {
      nCol <- 4
    }
  }
  num.row <- floor(x = n.features / nCol - 1e-5) + 1
  if (overlay || do.hover) {
    num.row <- 1
    nCol <- 1
  }
  par(mfrow = c(num.row, nCol))
  # Reset graphics parameters on every exit path, including the early
  # returns taken by do.hover / do.identify
  on.exit(expr = ResetPar(), add = TRUE)
  dim.code <- GetDimReduction(
    object = object,
    reduction.type = reduction.use,
    slot = 'key'
  )
  dim.codes <- paste0(dim.code, c(dim.1, dim.2))
  embeddings <- GetCellEmbeddings(
    object = object,
    reduction.type = reduction.use,
    dims.use = c(dim.1, dim.2),
    cells.use = cells.use
  )
  # Two clean columns named 'x' and 'y'; rownames (cell names) are kept
  # as the previous as.data.frame() conversion kept them
  data.plot <- data.frame(
    x = embeddings[, dim.codes[1]],
    y = embeddings[, dim.codes[2]],
    row.names = rownames(x = embeddings)
  )
  data.use <- t(x = FetchData(
    object = object,
    vars.all = features.plot,
    cells.use = cells.use,
    use.imputed = use.imputed
  ))
  # Replace NA cutoffs with the observed per-feature minimum / maximum
  min.cutoff <- mapply(
    FUN = function(cutoff, feature) {
      if (is.na(x = cutoff)) min(data.use[feature, ]) else cutoff
    },
    cutoff = min.cutoff,
    feature = features.plot
  )
  max.cutoff <- mapply(
    FUN = function(cutoff, feature) {
      if (is.na(x = cutoff)) max(data.use[feature, ]) else cutoff
    },
    cutoff = max.cutoff,
    feature = features.plot
  )
  if (overlay) {
    # Wrap as a list so both branches yield a plot list
    pList <- list(
      BlendPlot(
        data.use = data.use,
        features.plot = features.plot,
        data.plot = data.plot,
        pt.size = pt.size,
        pch.use = pch.use,
        cols.use = cols.use,
        dim.codes = dim.codes,
        min.cutoff = min.cutoff,
        max.cutoff = max.cutoff,
        coord.fixed = coord.fixed,
        no.axes = no.axes,
        no.legend = no.legend,
        dark.theme = dark.theme
      )
    )
  } else {
    # mapply iterates feature and its cutoffs together
    pList <- mapply(
      FUN = SingleFeaturePlot,
      feature = features.plot,
      min.cutoff = min.cutoff,
      max.cutoff = max.cutoff,
      coord.fixed = coord.fixed,
      MoreArgs = list( # Arguments that are not being repeated
        data.use = data.use,
        data.plot = data.plot,
        pt.size = pt.size,
        pch.use = pch.use,
        cols.use = cols.use,
        dim.codes = dim.codes,
        no.axes = no.axes,
        no.legend = no.legend,
        dark.theme = dark.theme,
        vector.friendly = vector.friendly,
        png.file = png.file,
        png.arguments = png.arguments
      ),
      SIMPLIFY = FALSE # Get list, not matrix
    )
  }
  if (do.hover) {
    if (length(x = pList) != 1) {
      stop("'do.hover' only works on a single feature or an overlayed FeaturePlot")
    }
    if (is.null(x = data.hover)) {
      features.info <- NULL
    } else {
      features.info <- FetchData(object = object, vars.all = data.hover)
    }
    # pList[[1]] extracts the ggplot object from the plot list
    return(HoverLocator(
      plot = pList[[1]],
      data.plot = data.plot,
      features.info = features.info,
      dark.theme = dark.theme,
      title = features.plot
    ))
  }
  if (do.identify) {
    if (length(x = pList) != 1) {
      stop("'do.identify' only works on a single feature or an overlayed FeaturePlot")
    }
    return(FeatureLocator(
      plot = pList[[1]],
      data.plot = data.plot,
      dark.theme = dark.theme
    ))
  }
  print(x = cowplot::plot_grid(plotlist = pList, ncol = nCol))
  if (do.return) {
    return(pList)
  }
}

globalVariables(
  names = c('gene', 'dim1', 'dim2', 'ident', 'cell', 'scaled.expression'),
  package = 'Seurat',
  add = TRUE
)


#' Visualization of multiple features
#'
#' Similar to FeaturePlot, however, also splits the plot by visualizing each
#' identity class separately.
#'
#' Particularly useful for seeing if the same groups of cells co-exhibit a
#' common feature (i.e. co-express a gene), even within an identity class. Best
#' understood by example.
#'
#' @param object Seurat object
#' @param features.plot Vector of features to plot
#' @param dim.1 Dimension for x-axis (default 1)
#' @param dim.2 Dimension for y-axis (default 2)
#' @param idents.use Which identity classes to display (default is all identity
#' classes)
#' @param pt.size Adjust point size for plotting
#' @param cols.use Two colors (low, high) forming the expression gradient; only
#' the first two entries are used. Default is c("grey", "red").
#' @param pch.use Pch for plotting
#' @param reduction.use Which dimensionality reduction to use. Default is
#' "tsne", can also be "pca", or "ica", assuming these are precomputed.
#' @param group.by Group cells in different ways (for example, orig.ident)
#' @param data.use Dataset to use for plotting, choose from 'data', 'scale.data', or 'imputed'
#' @param sep.scale Scale each group separately. Default is FALSE.
#' @param max.exp Max cutoff for scaled expression value, supports quantiles in the form of 'q##' (see FeaturePlot)
#' @param min.exp Min cutoff for scaled expression value, supports quantiles in the form of 'q##' (see FeaturePlot)
#' @param rotate.key Draw the color key horizontally with its title on top
#' @param plot.horiz rotate the plot such that the features are columns, groups are the rows
#' @param key.position position of the legend ("top", "right", "bottom", "left")
#' @param do.return Return the ggplot2 object
#'
#' @return If do.return is TRUE, the ggplot2 object; otherwise the plot is
#' printed.
#'
#' @importFrom tidyr gather
#' @importFrom dplyr %>% mutate mutate_each group_by select ungroup
#'
#' @seealso \code{\link{FeaturePlot}}
#'
#' @export
#'
#' @examples
#' pbmc_small
#' FeatureHeatmap(object = pbmc_small, features.plot = "PC1")
#'
FeatureHeatmap <- function(
  object,
  features.plot,
  dim.1 = 1,
  dim.2 = 2,
  idents.use = NULL,
  pt.size = 2,
  cols.use = c("grey", "red"),
  pch.use = 16,
  reduction.use = "tsne",
  group.by = NULL,
  data.use = 'data',
  sep.scale = FALSE,
  do.return = FALSE,
  min.exp = -Inf,
  max.exp = Inf,
  rotate.key = FALSE,
  plot.horiz = FALSE,
  key.position = "right"
) {
  switch(
    EXPR = data.use,
    'data' = {
      use.imputed <- FALSE
      use.scaled <- FALSE
    },
    'scale.data' = {
      use.imputed <- FALSE
      use.scaled <- TRUE
    },
    'imputed' = {
      use.imputed <- TRUE
      use.scaled <- FALSE
    },
    stop("Invalid dataset to use")
  )
  if (!is.null(x = group.by)) {
    object <- SetAllIdent(object = object, id = group.by)
  }
  idents.use <- SetIfNull(x = idents.use, default = sort(x = unique(x = object@ident)))
  # No par(mfrow) here: the panel layout comes from ggplot2 facets, and a base
  # graphics layout would only leak into the caller's later plots
  dim.code <- GetDimReduction(
    object = object,
    reduction.type = reduction.use,
    slot = 'key'
  )
  dim.codes <- paste0(dim.code, c(dim.1, dim.2))
  data.plot <- data.frame(FetchData(
    object = object,
    vars.all = c(dim.codes, features.plot),
    use.imputed = use.imputed,
    use.scaled = use.scaled
  ))
  colnames(x = data.plot)[1:2] <- c("dim1", "dim2")
  data.plot$ident <- as.character(x = object@ident)
  # Keep only identities requested in idents.use
  data.plot <- data.plot[data.plot$ident %in% idents.use, ]
  data.plot$cell <- rownames(x = data.plot)
  # data.frame() sanitizes column names with make.names(); apply the same
  # transformation so the factor levels below match the gathered gene names
  features.plot <- make.names(names = features.plot)
  data.plot <- gather(
    data = data.plot,
    key = "gene",
    value = "expression",
    -dim1, -dim2, -ident, -cell
  )
  if (sep.scale) {
    data.plot <- group_by(data.plot, ident, gene)
  } else {
    data.plot <- group_by(data.plot, gene)
  }
  # scale() returns a one-column matrix; flatten it to a numeric column, and
  # drop the grouping so later column assignments act on a plain table
  data.plot <- data.plot %>%
    mutate(scaled.expression = as.numeric(x = scale(x = expression))) %>%
    ungroup()
  min.exp <- SetQuantile(cutoff = min.exp, data = data.plot$scaled.expression)
  max.exp <- SetQuantile(cutoff = max.exp, data = data.plot$scaled.expression)
  data.plot$gene <- factor(x = data.plot$gene, levels = features.plot)
  data.plot$scaled.expression <- MinMax(
    data = data.plot$scaled.expression,
    min = min.exp,
    max = max.exp
  )
  if (rotate.key) {
    key.guide <- guide_colorbar(
      direction = "horizontal",
      title.position = "top",
      title = "Scaled Expression"
    )
  } else {
    key.guide <- guide_colorbar(title = "Scaled Expression")
  }
  p <- ggplot(data = data.plot, mapping = aes(x = dim1, y = dim2)) +
    geom_point(mapping = aes(colour = scaled.expression), size = pt.size, shape = pch.use) +
    scale_colour_gradient(low = cols.use[1], high = cols.use[2], guide = key.guide)
  if (plot.horiz) {
    p <- p + facet_grid(ident ~ gene)
  } else {
    p <- p + facet_grid(gene ~ ident)
  }
  p <- p +
    theme_bw() +
    NoGrid() +
    ylab(label = dim.codes[2]) +
    xlab(label = dim.codes[1]) +
    theme(legend.position = key.position)
  if (do.return) {
    return(p)
  }
  print(p)
}

#' Gene expression heatmap
#'
#' Draws a heatmap of single cell gene expression using the heatmap.2 function. Has been replaced by the ggplot2
#' version (now in DoHeatmap), but kept for legacy
#'
#' @param object Seurat object
#' @param cells.use Cells to include in the heatmap (default is all cells)
#' @param genes.use Genes to include in the heatmap (ordered). Default is
#' object@@var.genes
#' @param disp.min Minimum display value (all values below are clipped)
#' @param disp.max Maximum display value (all values above are clipped). If
#' either disp.min or disp.max is NULL, both fall back to -2.5/2.5 (do.scale =
#' TRUE) or -Inf/Inf (do.scale = FALSE)
#' @param draw.line Draw vertical lines delineating cells in different identity
#' classes. Lines are placed at cumulative group sizes, so they are only
#' meaningful when order.by.ident is TRUE.
#' @param do.return Default is FALSE. If TRUE, return a matrix of scaled values
#' which would be passed to heatmap.2
#' @param order.by.ident Order cells in the heatmap by identity class (default
#' is TRUE). If FALSE, cells are ordered based on their order in cells.use
#' @param col.use Color palette to use
#' @param slim.col.label if (order.by.ident==TRUE) then instead of displaying
#' every cell name on the heatmap, display only the identity class name once
#' for each group
#' @param group.by If (order.by.ident==TRUE) default,  you can group cells in
#' different ways (for example, orig.ident)
#' @param remove.key Removes the color key from the plot.
#' @param cex.col positive numbers, used as cex.axis for the column axis
#' labeling. Only used when slim.col.label is TRUE; by default it is derived
#' from the number of identity groups.
#' @param do.scale If TRUE (default) plot the scale.data slot, otherwise the
#' data slot
#' @param ... Additional parameters to heatmap.2. Common examples are cexRow
#' and cexCol, which set row and column text sizes
#'
#' @return If do.return==TRUE, a matrix of scaled values which would be passed
#' to heatmap.2. Otherwise, no return value, only a graphical output
#'
#' @importFrom gplots heatmap.2
#'
#' @export
#'
#' @examples
#' pbmc_small
#' OldDoHeatmap(object = pbmc_small, genes.use = pbmc_small@var.genes)
#'
OldDoHeatmap <- function(
  object,
  cells.use = NULL,
  genes.use = NULL,
  disp.min = NULL,
  disp.max = NULL,
  draw.line = TRUE,
  do.return = FALSE,
  order.by.ident = TRUE,
  col.use = PurpleAndYellow(),
  slim.col.label = FALSE,
  group.by = NULL,
  remove.key = FALSE,
  cex.col = NULL,
  do.scale = TRUE,
  ...
) {
  cells.use <- SetIfNull(x = cells.use, default = object@cell.names)
  cells.use <- intersect(x = cells.use, y = object@cell.names)
  # Without a gene list the heatmap would be empty
  genes.use <- SetIfNull(x = genes.use, default = object@var.genes)
  cells.ident <- object@ident[cells.use]
  if (!is.null(x = group.by)) {
    # Restrict to cells.use so the grouping stays aligned with the plotted cells
    cells.ident <- factor(x = FetchData(
      object = object,
      vars.all = group.by,
      cells.use = cells.use
    )[, 1])
  }
  # Drop identity classes that have no cells among cells.use
  cells.ident <- droplevels(x = cells.ident)
  if (order.by.ident) {
    cells.use <- cells.use[order(cells.ident)]
  } else {
    # Keep the given cell order; levels follow first appearance
    cells.ident <- factor(
      x = cells.ident,
      levels = as.vector(x = unique(x = cells.ident))
    )
  }
  # Choose the expression slot and its default clipping range
  if (do.scale) {
    slot.use <- "scale.data"
    default.range <- c(-2.5, 2.5)
  } else {
    slot.use <- "data"
    default.range <- c(-Inf, Inf)
  }
  if (is.null(x = disp.min) || is.null(x = disp.max)) {
    disp.min <- default.range[1]
    disp.max <- default.range[2]
  }
  # Stack matching genes from the RNA assay and any additional assays
  data.use <- NULL
  for (assay.check in c("RNA", names(x = object@assay))) {
    data.assay <- GetAssayData(
      object = object,
      assay.type = assay.check,
      slot = slot.use
    )
    genes.intersect <- intersect(x = genes.use, y = rownames(x = data.assay))
    if (length(x = genes.intersect) == 0) {
      # Nothing to contribute from this assay
      next
    }
    data.use <- rbind(
      data.use,
      as.matrix(x = data.assay[genes.intersect, cells.use, drop = FALSE])
    )
  }
  if (is.null(x = data.use)) {
    stop("None of the requested genes were found in the object")
  }
  data.use <- MinMax(data = data.use, min = disp.min, max = disp.max)
  hm.fun <- if (remove.key) heatmap2NoKey else heatmap.2
  colsep.use <- NULL
  if (draw.line) {
    colsep.use <- cumsum(x = table(cells.ident))
  }
  if (slim.col.label) {
    # log10(1) == 0, so clamp the group count to keep the default size finite
    n.groups <- length(x = unique(x = cells.ident))
    cex.col <- SetIfNull(
      x = cex.col,
      default = 0.2 + 1 / log10(x = max(n.groups, 2))
    )
    col.lab <- rep("", length(x = cells.use))
    if (order.by.ident) {
      # Label each identity block once, near its midpoint
      group.sizes <- table(cells.ident)
      label.pos <- round(x = cumsum(x = group.sizes) - group.sizes / 2) + 1
      col.lab[label.pos] <- levels(x = cells.ident)
    }
    hm.fun(
      data.use,
      Rowv = NA,
      Colv = NA,
      trace = "none",
      col = col.use,
      colsep = colsep.use,
      labCol = col.lab,
      cexCol = cex.col,
      ...
    )
  } else {
    hm.fun(
      data.use,
      Rowv = NA,
      Colv = NA,
      trace = "none",
      col = col.use,
      colsep = colsep.use,
      ...
    )
  }
  if (do.return) {
    return(data.use)
  }
}

globalVariables(names = c('x', 'y'), package = 'Seurat', add = TRUE)
#' Scatter plot of single cell data
#'
#' Creates a scatter plot of two features (typically gene expression), across a
#' set of single cells. Cells are colored by their identity class. Pearson
#' correlation between the two features is displayed above the plot.
#'
#' @param object Seurat object
#' @inheritParams FetchData
#' @param gene1 First feature to plot. Typically gene expression but can also
#' be metrics, PC scores, etc. - anything that can be retreived with FetchData
#' @param gene2 Second feature to plot.
#' @param cell.ids Cells to include on the scatter plot.
#' @param col.use Colors to use for identity class plotting.
#' @param pch.use Pch argument for plotting
#' @param cex.use Cex argument for plotting
#' @param use.imputed Use imputed values for gene expression (Default is FALSE)
#' @param use.scaled Use scaled data
#' @param use.raw Use raw data
#' @param do.hover Enable hovering over points to view information
#' @param data.hover Data to add to the hover, pass a character vector of
#' features to add. Defaults to cell name and ident. Pass 'NULL' to clear extra
#' information.
#' @param do.identify Opens a locator session to identify clusters of cells.
#' @param dark.theme Use a dark theme for the plot. The previous background
#' colour is restored when the function exits.
#' @param do.spline Overlay a loess fit of gene2 on gene1, drawn as dark blue
#' points
#' @param spline.span span passed to the loess fit used by do.spline
#' @param \dots Additional arguments to be passed to plot.
#'
#' @return If do.hover is TRUE, the value returned by HoverLocator; if
#' do.identify is TRUE, the value returned by FeatureLocator. Otherwise no
#' return value, only graphical output.
#'
#' @importFrom graphics plot
#'
#' @export
#'
#' @examples
#' GenePlot(object = pbmc_small, gene1 = 'CD9', gene2 = 'CD3E')
#'
GenePlot <- function(
  object,
  gene1,
  gene2,
  cell.ids = NULL,
  col.use = NULL,
  pch.use = 16,
  cex.use = 1.5,
  use.imputed = FALSE,
  use.scaled = FALSE,
  use.raw = FALSE,
  do.hover = FALSE,
  data.hover = 'ident',
  do.identify = FALSE,
  dark.theme = FALSE,
  do.spline = FALSE,
  spline.span = 0.75,
  ...
) {
  cell.ids <- SetIfNull(x = cell.ids, default = object@cell.names)
  # Keep cells as rows for compatibility with FeatureLocator / HoverLocator
  data.use <- as.data.frame(
    x = FetchData(
      object = object,
      vars.all = c(gene1, gene2),
      cells.use = cell.ids,
      use.imputed = use.imputed,
      use.scaled = use.scaled,
      use.raw = use.raw
    )
  )
  # Restrict to the requested cells and features, then rename for plotting
  data.plot <- data.use[cell.ids, c(gene1, gene2)]
  names(x = data.plot) <- c('x', 'y')
  ident.use <- as.factor(x = object@ident[cell.ids])
  if (length(x = col.use) > 1) {
    col.use <- col.use[as.numeric(x = ident.use)]
  } else {
    col.use <- SetIfNull(x = col.use, default = as.numeric(x = ident.use))
  }
  gene.cor <- round(x = cor(x = data.plot$x, y = data.plot$y), digits = 2)
  if (dark.theme) {
    # Restore the caller's background colour when this function exits
    old.par <- par(bg = 'black')
    on.exit(expr = par(old.par), add = TRUE)
    # Black points would be invisible on a black background
    col.use <- sapply(
      X = col.use,
      FUN = function(color) ifelse(
        test = all(col2rgb(color) == 0),
        yes = 'white',
        no = color
      )
    )
    axes <- FALSE
    col.lab <- 'white'
  } else {
    axes <- TRUE
    col.lab <- 'black'
  }
  plot(
    x = data.plot$x,
    y = data.plot$y,
    xlab = gene1,
    ylab = gene2,
    col = col.use,
    cex = cex.use,
    main = gene.cor,
    pch = pch.use,
    axes = axes,
    col.lab = col.lab,
    col.main = col.lab,
    ...
  )
  if (dark.theme) {
    for (axis.side in c(1, 2)) {
      axis(
        side = axis.side,
        at = NULL,
        labels = TRUE,
        col.axis = col.lab,
        col = col.lab
      )
    }
  }
  if (do.spline) {
    loess.fit <- loess(formula = y ~ x, data = data.plot, span = spline.span)
    points(x = data.plot$x, y = loess.fit$fitted, col = 'darkblue')
  }
  if (do.identify || do.hover) {
    # Rebuild the scatter in ggplot2 for the interactive locators
    p <- ggplot2::ggplot(data = data.plot, mapping = aes(x = x, y = y)) +
      geom_point(size = cex.use, shape = pch.use, color = col.use) +
      labs(title = gene.cor, x = gene1, y = gene2)
    if (do.hover) {
      names(x = data.plot) <- c(gene1, gene2)
      if (is.null(x = data.hover)) {
        features.info <- NULL
      } else {
        features.info <- FetchData(object = object, vars.all = data.hover)
      }
      return(HoverLocator(
        plot = p,
        data.plot = data.plot,
        features.info = features.info,
        dark.theme = dark.theme,
        title = gene.cor
      ))
    }
    return(FeatureLocator(
      plot = p,
      data.plot = data.plot,
      dark.theme = dark.theme
    ))
  }
}

globalVariables(names = c('x', 'y'), package = 'Seurat', add = TRUE)
#' Cell-cell scatter plot
#'
#' Creates a scatter plot of genes across two single cells. Pearson
#' correlation between the two cells is displayed above the plot.
#'
#' @param object Seurat object
#' @param cell1 Cell 1 name (can also be a number, representing the position in
#' object@@cell.names)
#' @param cell2 Cell 2 name (can also be a number, representing the position in
#' object@@cell.names)
#' @param gene.ids Genes to plot (default, all genes)
#' @param col.use Colors to use for the points
#' @param nrpoints.use Parameter for smoothScatter
#' @param pch.use Point symbol to use
#' @param cex.use Point size
#' @param do.hover Enable hovering over points to view information
#' @param do.identify Opens a locator session to select points (genes) on the
#' plot
#' @param \dots Additional arguments to pass to smoothScatter; also forwarded
#' to FeatureLocator when do.identify is TRUE
#'
#' @return If do.hover is TRUE, the value returned by HoverLocator; if
#' do.identify is TRUE, the value returned by FeatureLocator. Otherwise no
#' return value (plots a scatter plot).
#'
#' @importFrom stats cor
#' @importFrom graphics smoothScatter
#'
#' @export
#'
#' @examples
#' CellPlot(object = pbmc_small, cell1 = 'ATAGGAGAAACAGA', cell2 = 'CATCAGGATGCACA')
#'
CellPlot <- function(
  object,
  cell1,
  cell2,
  gene.ids = NULL,
  col.use = "black",
  nrpoints.use = Inf,
  pch.use = 16,
  cex.use = 0.5,
  do.hover = FALSE,
  do.identify = FALSE,
  ...
) {
  gene.ids <- SetIfNull(x = gene.ids, default = rownames(x = object@data))
  # Cells may be given by position in object@cell.names
  if (is.numeric(x = cell1)) {
    cell1 <- object@cell.names[cell1]
  }
  if (is.numeric(x = cell2)) {
    cell2 <- object@cell.names[cell2]
  }
  # Transpose so genes are rows, which is what do.identify selects over
  data.plot <- as.data.frame(
    x = t(
      x = FetchData(
        object = object,
        vars.all = gene.ids,
        cells.use = c(cell1, cell2)
      )
    )
  )
  names(x = data.plot) <- c('x', 'y')
  gene.cor <- round(x = cor(x = data.plot$x, y = data.plot$y), digits = 2)
  smoothScatter(
    x = data.plot$x,
    y = data.plot$y,
    xlab = cell1,
    ylab = cell2,
    col = col.use,
    nrpoints = nrpoints.use,
    pch = pch.use,
    cex = cex.use,
    main = gene.cor,
    ...
  )
  if (do.identify || do.hover) {
    # Rebuild the scatter in ggplot2 for the interactive locators
    p <- ggplot2::ggplot(data = data.plot, mapping = aes(x = x, y = y)) +
      geom_point(size = cex.use, shape = pch.use, color = col.use) +
      labs(title = gene.cor, x = cell1, y = cell2)
    if (do.hover) {
      names(x = data.plot) <- c(cell1, cell2)
      return(HoverLocator(plot = p, data.plot = data.plot, title = gene.cor))
    }
    return(FeatureLocator(plot = p, data.plot = data.plot, ...))
  }
}

#' Dimensional reduction heatmap
#'
#' Draws a heatmap focusing on a dimensional reduction component. Cells are
#' ordered by their score on the component and genes are the top genes for
#' that component (as returned by DimTopGenes). Allows for nice visualization
#' of sources of heterogeneity in the dataset.
#'
#' @param object Seurat object.
#' @param assay.use Assay(s) to pull from - default is RNA
#' @param reduction.type Which dimensional reduction to use
#' @param dim.use Dimensions to plot
#' @param cells.use A list of cells to plot. If numeric, just plots the top cells.
#' @param num.genes Number of genes to plot
#' @param use.full Use the full PCA (projected PCA). Default is FALSE
#' @param disp.min Minimum display value (all values below are clipped)
#' @param disp.max Maximum display value (all values above are clipped)
#' @param do.return If TRUE, return the clipped matrix plotted for the last
#' dimension in dim.use
#' @param col.use Color to plot.
#' @param use.scale Default is TRUE: plot the scale.data slot. If FALSE, plot
#' the data slot.
#' @param do.balanced Plot an equal number of genes with both + and - scores.
#' @param remove.key Removes the color key from the plot. The key is always
#' removed when more than one dimension is plotted.
#' @param label.columns Whether to label the columns (cells). Default is TRUE
#' when plotting a single dimension, FALSE otherwise.
#' @param check.plot If TRUE, ask interactively before drawing a single keyed
#' heatmap with more than 700 rows or columns
#' @param ... Extra parameters for heatmap plotting.
#'
#' @return If do.return==TRUE, a matrix of scaled values which would be passed
#' to heatmap.2. Otherwise, no return value, only a graphical output
#'
#' @importFrom graphics par
#' @importFrom utils menu
#'
#' @export
#'
#' @examples
#' DimHeatmap(object = pbmc_small)
#'
DimHeatmap <- function(
  object,
  assay.use = "RNA",
  reduction.type = "pca",
  dim.use = 1,
  cells.use = NULL,
  num.genes = 30,
  use.full = FALSE,
  disp.min = -2.5,
  disp.max = 2.5,
  do.return = FALSE,
  col.use = PurpleAndYellow(),
  use.scale = TRUE,
  do.balanced = FALSE,
  remove.key = FALSE,
  label.columns = NULL,
  check.plot = TRUE,
  ...
) {
  num.row <- floor(x = length(x = dim.use) / 3.01) + 1
  orig_par <- par()$mfrow
  par(mfrow = c(num.row, min(length(x = dim.use), 3)))
  # Restore the caller's layout on every exit path (do.return, quit, errors)
  on.exit(expr = par(mfrow = orig_par), add = TRUE)
  cells <- cells.use
  if (is.null(x = label.columns)) {
    label.columns <- length(x = dim.use) <= 1
  }
  # These do not depend on the dimension being plotted
  dim.scores <- GetDimReduction(
    object = object,
    reduction.type = reduction.type,
    slot = "cell.embeddings"
  )
  dim.key <- GetDimReduction(
    object = object,
    reduction.type = reduction.type,
    slot = "key"
  )
  slot.use <- if (use.scale) "scale.data" else "data"
  # Multiple panels always use the key-less heatmap variant
  hm.fun <- if (remove.key || length(x = dim.use) > 1) heatmap2NoKey else heatmap.2
  data.use <- NULL
  for (ndim in dim.use) {
    if (is.numeric(x = cells)) {
      cells.use <- DimTopCells(
        object = object,
        dim.use = ndim,
        reduction.type = reduction.type,
        num.cells = cells,
        do.balanced = do.balanced
      )
    } else {
      cells.use <- SetIfNull(x = cells, default = object@cell.names)
    }
    genes.use <- rev(x = DimTopGenes(
      object = object,
      dim.use = ndim,
      reduction.type = reduction.type,
      num.genes = num.genes,
      use.full = use.full,
      do.balanced = do.balanced
    ))
    cells.ordered <- cells.use[order(dim.scores[cells.use, paste0(dim.key, ndim)])]
    data.use <- NULL
    for (assay.check in assay.use) {
      data.assay <- GetAssayData(
        object = object,
        assay.type = assay.check,
        slot = slot.use
      )
      genes.intersect <- intersect(x = genes.use, y = rownames(x = data.assay))
      # drop = FALSE keeps a genes x cells matrix even if only one gene matches
      new.data <- as.matrix(x = data.assay[genes.intersect, cells.ordered, drop = FALSE])
      data.use <- rbind(data.use, new.data)
    }
    if (check.plot && any(dim(x = data.use) > 700) &&
        !remove.key && length(x = dim.use) == 1) {
      choice <- menu(
        choices = c("Continue with plotting", "Quit"),
        title = "Plot(s) requested will likely take a while to plot."
      )
      if (choice == 1) {
        check.plot <- FALSE
      } else {
        return()
      }
    }
    data.use <- MinMax(data = data.use, min = disp.min, max = disp.max)
    hmTitle <- paste(dim.key, ndim)
    if (label.columns) {
      hm.fun(
        data.use,
        Rowv = NA,
        Colv = NA,
        trace = "none",
        col = col.use,
        dimTitle = hmTitle,
        ...
      )
    } else {
      hm.fun(
        data.use,
        Rowv = NA,
        Colv = NA,
        trace = "none",
        col = col.use,
        dimTitle = hmTitle,
        labCol = '',
        ...
      )
    }
  }
  if (do.return) {
    return(data.use)
  }
}

#' Principal component heatmap
#'
#' Convenience wrapper around \code{DimHeatmap} with
#' \code{reduction.type = "pca"}. Cells are ordered by their score on each
#' requested principal component, which makes it easy to inspect the sources of
#' heterogeneity each PC captures.
#'
#' @param object Seurat object.
#' @param pc.use Principal component(s) to plot
#' @param cells.use A list of cells to plot. If numeric, just plots the top cells.
#' @param num.genes Number of genes to plot
#' @param use.full Use the full PCA (projected PCA). Default is FALSE
#' @param disp.min Minimum display value (all values below are clipped)
#' @param disp.max Maximum display value (all values above are clipped)
#' @param do.return If TRUE, return the matrix of clipped values that was
#' plotted
#' @param col.use Color palette to plot with.
#' @param use.scale Default is TRUE: plot scaled data. If FALSE, plot the
#' unscaled data slot.
#' @param do.balanced Plot an equal number of genes with both + and - scores.
#' @param remove.key Removes the color key from the plot.
#' @param label.columns Whether to label the columns. Default is TRUE for 1 PC, FALSE for > 1 PC
#' @param ... Extra parameters for DimHeatmap
#'
#' @return If do.return==TRUE, a matrix of scaled values which would be passed
#' to heatmap.2. Otherwise, no return value, only a graphical output
#'
#' @export
#'
#' @examples
#' PCHeatmap(object = pbmc_small)
#'
PCHeatmap <- function(
  object,
  pc.use = 1,
  cells.use = NULL,
  num.genes = 30,
  use.full = FALSE,
  disp.min = -2.5,
  disp.max = 2.5,
  do.return = FALSE,
  col.use = PurpleAndYellow(),
  use.scale = TRUE,
  do.balanced = FALSE,
  remove.key = FALSE,
  label.columns = NULL,
  ...
) {
  # Everything is delegated; only the reduction and dimension naming differ
  DimHeatmap(
    object = object,
    reduction.type = "pca",
    dim.use = pc.use,
    cells.use = cells.use,
    num.genes = num.genes,
    use.full = use.full,
    use.scale = use.scale,
    do.balanced = do.balanced,
    disp.min = disp.min,
    disp.max = disp.max,
    col.use = col.use,
    remove.key = remove.key,
    label.columns = label.columns,
    do.return = do.return,
    ...
  )
}

#' Independent component heatmap
#'
#' Draws a heatmap focusing on a principal component. Both cells and genes are sorted by their
#' principal component scores. Allows for nice visualization of sources of heterogeneity
#' in the dataset."()
#'
#' @param object Seurat object
#' @param ic.use Components to use
#' @param cells.use A list of cells to plot. If numeric, just plots the top cells.
#' @param num.genes NUmber of genes to plot
#' @param disp.min Minimum display value (all values below are clipped)
#' @param disp.max Maximum display value (all values above are clipped)
#' @param do.return If TRUE, returns plot object, otherwise plots plot object
#' @param col.use Colors to plot.
#' @param use.scale Default is TRUE: plot scaled data. If FALSE, plot raw data on the heatmap.
#' @param do.balanced Plot an equal number of genes with both + and - scores.
#' @param remove.key Removes the color key from the plot.
#' @param label.columns Labels for columns
#' @param ... Extra parameters passed to DimHeatmap
#'
#' @return If do.return==TRUE, a matrix of scaled values which would be passed
#' to heatmap.2. Otherwise, no return value, only a graphical output
#'
#' @export
#'
#' @examples
#' pbmc_small <- RunICA(object = pbmc_small, ics.compute = 25, print.results = FALSE)
#' ICHeatmap(object = pbmc_small)
#'
ICHeatmap <- function(
  object,
  ic.use = 1,
  cells.use = NULL,
  num.genes = 30,
  disp.min = -2.5,
  disp.max = 2.5,
  do.return = FALSE,
  col.use = PurpleAndYellow(),
  use.scale = TRUE,
  do.balanced = FALSE,
  remove.key = FALSE,
  label.columns = NULL,
  ...
) {
  # ICA-specific front end to DimHeatmap: the reduction is pinned to "ica"
  # and ic.use is handed over as dim.use; all remaining options (and `...`)
  # are forwarded unchanged.
  DimHeatmap(
    object = object,
    reduction.type = "ica",
    dim.use = ic.use,
    # Which cells/genes to draw
    cells.use = cells.use,
    num.genes = num.genes,
    do.balanced = do.balanced,
    # How values are scaled, clipped and coloured
    use.scale = use.scale,
    disp.min = disp.min,
    disp.max = disp.max,
    col.use = col.use,
    # Decorations and output mode
    remove.key = remove.key,
    label.columns = label.columns,
    do.return = do.return,
    ...
  )
}


#' Visualize Dimensional Reduction genes
#'
#' Visualize top genes associated with reduction components
#'
#' @param object Seurat object
#' @param reduction.type Reduction technique to visualize results for
#' @param dims.use Dimensions to display (a vector of component indices)
#' @param num.genes Number of genes to display
#' @param use.full Use reduction values for full dataset (i.e. projected dimensional reduction values)
#' @param font.size Font size
#' @param nCol Number of columns to display
#' @param do.balanced Return an equal number of genes with + and - scores. If FALSE (default), returns
#' the top genes ranked by the scores absolute values
#'
#' @return Graphical, no return value
#'
#' @importFrom graphics axis plot
#'
#' @export
#'
#' @examples
#' VizDimReduction(object = pbmc_small)
#'
VizDimReduction <- function(
  object,
  reduction.type = "pca",
  dims.use = 1:5,
  num.genes = 30,
  use.full = FALSE,
  font.size = 0.5,
  nCol = NULL,
  do.balanced = FALSE
) {
  # Gene loadings: projected ("full") loadings or the standard ones
  dim.scores <- GetDimReduction(
    object = object,
    reduction.type = reduction.type,
    slot = if (use.full) "gene.loadings.full" else "gene.loadings"
  )
  # Label each panel with the reduction's own key (as DimPlot does), so ICA
  # panels no longer read "PC<i>"; fall back to the old "PC" prefix when no
  # key is stored
  dim.key <- GetDimReduction(
    object = object,
    reduction.type = reduction.type,
    slot = "key"
  )
  if (length(x = dim.key) == 0) {
    dim.key <- "PC"
  }
  if (is.null(x = nCol)) {
    # Check the larger threshold first; with the old ordering the `> 9`
    # branch was unreachable because `> 6` always matched before it
    n.dims <- length(x = dims.use)
    nCol <- if (n.dims > 9) {
      4
    } else if (n.dims > 6) {
      3
    } else {
      2
    }
  }
  num.row <- ceiling(x = length(x = dims.use) / nCol)
  par(mfrow = c(num.row, nCol))
  # Reset graphics parameters even if a panel fails part-way through
  on.exit(expr = ResetPar(), add = TRUE)
  for (i in dims.use) {
    top.genes <- DimTopGenes(
      object = object,
      dim.use = i,
      reduction.type = reduction.type,
      num.genes = num.genes,
      use.full = use.full,
      do.balanced = do.balanced
    )
    # drop = FALSE keeps a matrix even when only a single gene is selected
    subset.use <- dim.scores[top.genes, , drop = FALSE]
    gene.positions <- seq_len(length.out = nrow(x = subset.use))
    plot(
      x = subset.use[, i],
      y = gene.positions,
      pch = 16,
      col = "blue",
      xlab = paste0(dim.key, i),
      yaxt = "n",
      ylab = ""
    )
    axis(
      side = 2,
      at = gene.positions,
      labels = rownames(x = subset.use),
      las = 1,
      cex.axis = font.size
    )
  }
  # Graphical output only
  invisible(x = NULL)
}

#' Visualize PCA genes
#'
#' Visualize top genes associated with principal components
#'
#' @param object Seurat object
#' @param pcs.use PCs to display (a vector of component indices)
#' @param num.genes Number of genes to display
#' @param use.full Use full PCA (i.e. the projected PCA, by default FALSE)
#' @param font.size Font size
#' @param nCol Number of columns to display
#' @param do.balanced Return an equal number of genes with both + and - PC scores.
#' If FALSE (by default), returns the top genes ranked by the score's absolute values
#'
#' @return Graphical, no return value
#'
#' @export
#'
#' @examples
#' VizPCA(object = pbmc_small)
#'
VizPCA <- function(
  object,
  pcs.use = 1:5,
  num.genes = 30,
  use.full = FALSE,
  font.size = 0.5,
  nCol = NULL,
  do.balanced = FALSE
) {
  # PCA flavour of VizDimReduction: the reduction is pinned to "pca" and
  # pcs.use is passed through as the dimensions to draw
  VizDimReduction(
    object,
    reduction.type = "pca",
    dims.use = pcs.use,
    use.full = use.full,
    num.genes = num.genes,
    do.balanced = do.balanced,
    nCol = nCol,
    font.size = font.size
  )
}

#' Visualize ICA genes
#'
#' Visualize top genes associated with independent components
#'
#' @param object Seurat object
#' @param ics.use ICs to display (a vector of component indices)
#' @param num.genes Number of genes to display
#' @param use.full Use full ICA (i.e. the projected ICA, by default FALSE)
#' @param font.size Font size
#' @param nCol Number of columns to display
#' @param do.balanced Return an equal number of genes with both + and - IC scores.
#' If FALSE (by default), returns the top genes ranked by the score's absolute values
#'
#' @return Graphical, no return value
#'
#' @export
#'
#' @examples
#' pbmc_small <- RunICA(object = pbmc_small, ics.compute = 25, print.results = FALSE)
#' VizICA(object = pbmc_small)
#'
VizICA <- function(
  object,
  ics.use = 1:5,
  num.genes = 30,
  use.full = FALSE,
  font.size = 0.5,
  nCol = NULL,
  do.balanced = FALSE
) {
  # ICA flavour of VizDimReduction: the reduction is pinned to "ica" and
  # ics.use is passed through as the dimensions to draw
  VizDimReduction(
    object,
    reduction.type = "ica",
    dims.use = ics.use,
    use.full = use.full,
    num.genes = num.genes,
    do.balanced = do.balanced,
    nCol = nCol,
    font.size = font.size
  )
}

# Column names used inside ggplot2 aes() calls below; registered so that
# R CMD check does not report them as undefined globals
globalVariables(
  names = c("x", "y", "ident"),
  package = "Seurat",
  add = TRUE
)
#' Dimensional reduction plot
#'
#' Graphs the output of a dimensional reduction technique (PCA by default).
#' Cells are colored by their identity class.
#'
#' @param object Seurat object
#' @param reduction.use Which dimensionality reduction to use. Default is
#' "pca", can also be "tsne", or "ica", assuming these are precomputed.
#' @param dim.1 Dimension for x-axis (default 1)
#' @param dim.2 Dimension for y-axis (default 2)
#' @param cells.use Vector of cells to plot (default is all cells)
#' @param pt.size Adjust point size for plotting
#' @param do.return Return a ggplot2 object (default : FALSE)
#' @param do.bare Do only minimal formatting (default : FALSE)
#' @param cols.use Vector of colors, each color corresponds to an identity
#' class. By default, ggplot assigns colors.
#' @param group.by Group (color) cells in different ways (for example, orig.ident)
#' @param pt.shape If NULL, all points are circles (default). You can specify any
#' cell attribute (that can be pulled with FetchData) allowing for both
#' different colors and different shapes on cells.
#' @param do.hover Enable hovering over points to view information
#' @param data.hover Data to add to the hover, pass a character vector of
#' features to add. Defaults to cell name and ident. Pass 'NULL' to clear extra
#' information.
#' @param do.identify Opens a locator session to identify clusters of cells.
#' @param do.label Whether to label the clusters
#' @param label.size Sets size of labels
#' @param no.legend Setting to TRUE will remove the legend
#' @param coord.fixed Use a fixed scale coordinate system (for spatial coordinates). Default is FALSE.
#' @param no.axes Setting to TRUE will remove the axes
#' @param dark.theme Use a dark theme for the plot
#' @param plot.order Specify the order of plotting for the idents. This can be
#' useful for crowded plots if points of interest are being buried. Provide
#' either a full list of valid idents or a subset to be plotted last (on top).
#' @param cells.highlight A list of character or numeric vectors of cells to
#' highlight. If only one group of cells desired, can simply
#' pass a vector instead of a list. If set, colors selected cells to the color(s)
#' in \code{cols.highlight} and other cells black (white if dark.theme = TRUE);
#'  will also resize to the size(s) passed to \code{sizes.highlight}
#' @param cols.highlight A vector of colors to highlight the cells as; will
#' repeat to the length groups in cells.highlight
#' @param sizes.highlight Size of highlighted cells; will repeat to the length
#' groups in cells.highlight
#' @param plot.title Title for plot
#' @param vector.friendly FALSE by default. If TRUE, points are flattened into
#' a PNG, while axes/labels retain full vector resolution. Useful for producing
#' AI-friendly plots with large numbers of cells.
#' @param png.file Used only if vector.friendly is TRUE. Location for temporary
#' PNG file.
#' @param png.arguments Used only if vector.friendly is TRUE. Vector of three
#' elements (PNG width, PNG height, PNG DPI) to be used for temporary PNG.
#' Default is c(10,10,100)
#' @param na.value Color value for NA points when using custom scale.
#' @param ... Extra parameters to FeatureLocator for do.identify = TRUE
#'
#' @return If do.return==TRUE, returns a ggplot2 object. Otherwise, only
#' graphical output.
#'
#' @seealso \code{FeatureLocator}
#'
#' @import SDMTools
#' @importFrom stats median
#' @importFrom dplyr summarize group_by
#' @importFrom png readPNG
#'
#' @export
#'
#' @examples
#' DimPlot(object = pbmc_small)
#'
DimPlot <- function(
  object,
  reduction.use = "pca",
  dim.1 = 1,
  dim.2 = 2,
  cells.use = NULL,
  pt.size = 1,
  do.return = FALSE,
  do.bare = FALSE,
  cols.use = NULL,
  group.by = "ident",
  pt.shape = NULL,
  do.hover = FALSE,
  data.hover = 'ident',
  do.identify = FALSE,
  do.label = FALSE,
  label.size = 4,
  no.legend = FALSE,
  coord.fixed = FALSE,
  no.axes = FALSE,
  dark.theme = FALSE,
  plot.order = NULL,
  cells.highlight = NULL,
  cols.highlight = 'red',
  sizes.highlight = 1,
  plot.title = NULL,
  vector.friendly = FALSE,
  png.file = NULL,
  png.arguments = c(10,10, 100),
  na.value = 'grey50',
  ...
) {
  # Vector-friendly mode: re-evaluate this same call twice, once as a blank
  # plot (pt.size = -1) that carries axes/labels and once as a points-only
  # plot. The points-only plot is saved to a PNG and overlaid on the blank one.
  if (vector.friendly) {
    blank_call <- png_call <- match.call()
    blank_call$pt.size <- -1
    blank_call$do.return <- TRUE
    blank_call$vector.friendly <- FALSE
    png_call$no.axes <- TRUE
    png_call$no.legend <- TRUE
    png_call$do.return <- TRUE
    png_call$vector.friendly <- FALSE
    png_call$plot.title <- NULL
    blank_plot <- eval(blank_call, sys.frame(sys.parent()))
    png_plot <- eval(png_call, sys.frame(sys.parent()))
    png.file <- SetIfNull(x = png.file, default = paste0(tempfile(), ".png"))
    ggsave(
      filename = png.file,
      plot = png_plot,
      width = png.arguments[1],
      height = png.arguments[2],
      dpi = png.arguments[3]
    )
    to_return <- AugmentPlot(plot1 = blank_plot, imgFile = png.file)
    file.remove(png.file)
    if (do.return) {
      return(to_return)
    }
    print(to_return)
    # Stop here. Previously execution fell through to the regular path below
    # and drew a second, non-rasterized copy of the plot.
    return(invisible(x = to_return))
  }
  embeddings.use <- GetDimReduction(
    object = object,
    reduction.type = reduction.use,
    slot = "cell.embeddings"
  )
  if (length(x = embeddings.use) == 0) {
    stop(paste(reduction.use, "has not been run for this object yet."))
  }
  cells.use <- SetIfNull(x = cells.use, default = colnames(x = object@data))
  # Embedding columns are named <key><dimension number>
  dim.code <- GetDimReduction(
    object = object,
    reduction.type = reduction.use,
    slot = "key"
  )
  dim.codes <- paste0(dim.code, c(dim.1, dim.2))
  data.plot <- as.data.frame(x = embeddings.use)
  cells.use <- intersect(x = cells.use, y = rownames(x = data.plot))
  data.plot <- data.plot[cells.use, dim.codes]
  ident.use <- as.factor(x = object@ident[cells.use])
  if (group.by != "ident") {
    ident.use <- as.factor(x = FetchData(
      object = object,
      vars.fetch = group.by
    )[cells.use, 1])
  }
  data.plot$ident <- ident.use
  data.plot$x <- data.plot[, dim.codes[1]]
  data.plot$y <- data.plot[, dim.codes[2]]
  data.plot$pt.size <- pt.size
  if (!is.null(x = cells.highlight)) {
    # Ensure that cells.highlight are in our data.frame
    if (is.character(x = cells.highlight)) {
      cells.highlight <- list(cells.highlight)
    } else if (is.data.frame(x = cells.highlight) || !is.list(x = cells.highlight)) {
      cells.highlight <- as.list(x = cells.highlight)
    }
    cells.highlight <- lapply(
      X = cells.highlight,
      FUN = function(cells) {
        cells.return <- if (is.character(x = cells)) {
          cells[cells %in% rownames(x = data.plot)]
        } else {
          cells <- as.numeric(x = cells)
          cells <- cells[cells <= nrow(x = data.plot)]
          rownames(x = data.plot)[cells]
        }
        return(cells.return)
      }
    )
    # Remove groups that had no cells in our dataframe
    cells.highlight <- Filter(f = length, x = cells.highlight)
    if (length(x = cells.highlight) > 0) {
      if (!no.legend) {
        no.legend <- is.null(x = names(x = cells.highlight))
      }
      names.highlight <- if (is.null(x = names(x = cells.highlight))) {
        paste0('Group_', 1L:length(x = cells.highlight))
      } else {
        names(x = cells.highlight)
      }
      sizes.highlight <- rep_len(
        x = sizes.highlight,
        length.out = length(x = cells.highlight)
      )
      cols.highlight <- rep_len(
        x = cols.highlight,
        length.out = length(x = cells.highlight)
      )
      highlight <- rep_len(x = NA_character_, length.out = nrow(x = data.plot))
      if (is.null(x = cols.use)) {
        cols.use <- 'black'
      }
      cols.use <- c(cols.use[1], cols.highlight)
      size <- rep_len(x = pt.size, length.out = nrow(x = data.plot))
      for (i in 1:length(x = cells.highlight)) {
        cells.check <- cells.highlight[[i]]
        index.check <- match(x = cells.check, rownames(x = data.plot))
        highlight[index.check] <- names.highlight[i]
        size[index.check] <- sizes.highlight[i]
      }
      plot.order <- sort(x = unique(x = highlight), na.last = TRUE)
      plot.order[is.na(x = plot.order)] <- 'Unselected'
      highlight[is.na(x = highlight)] <- 'Unselected'
      highlight <- as.factor(x = highlight)
      data.plot$ident <- highlight
      data.plot$pt.size <- size
      if (dark.theme) {
        cols.use[1] <- 'white'
      }
    }
  }
  if (!is.null(x = plot.order)) {
    if (any(!plot.order %in% data.plot$ident)) {
      stop("invalid ident in plot.order")
    }
    # Idents listed in plot.order are drawn last (on top)
    plot.order <- rev(x = c(
      plot.order,
      setdiff(x = unique(x = data.plot$ident), y = plot.order)
    ))
    data.plot$ident <- factor(x = data.plot$ident, levels = plot.order)
    data.plot <- data.plot[order(data.plot$ident), ]
  }
  p <- ggplot(data = data.plot, mapping = aes(x = x, y = y)) +
    geom_point(mapping = aes(colour = factor(x = ident), size = pt.size))
  if (!is.null(x = pt.shape)) {
    # Look shapes up by data.plot's own row names: the plot.order step above
    # may have re-sorted the rows, so cells.use order no longer lines up
    shape.val <- FetchData(
      object = object,
      vars.all = pt.shape
    )[rownames(x = data.plot), 1]
    if (is.numeric(shape.val)) {
      shape.val <- cut(x = shape.val, breaks = 5)
    }
    data.plot[, "pt.shape"] <- shape.val
    p <- ggplot(data = data.plot, mapping = aes(x = x, y = y)) +
      geom_point(mapping = aes(
        colour = factor(x = ident),
        shape = factor(x = pt.shape),
        size = pt.size
      ))
  }
  if (!is.null(x = cols.use)) {
    p <- p + scale_colour_manual(values = cols.use, na.value=na.value)
  }
  if (coord.fixed) {
    p <- p + coord_fixed()
  }
  # "none" is the supported way to hide a guide; `FALSE` is deprecated
  p <- p + guides(size = "none")
  p2 <- p +
    xlab(label = dim.codes[[1]]) +
    ylab(label = dim.codes[[2]]) +
    scale_size(range = c(min(data.plot$pt.size), max(data.plot$pt.size)))
  p3 <- p2 +
    SetXAxisGG() +
    SetYAxisGG() +
    SetLegendPointsGG(x = 6) +
    SetLegendTextGG(x = 12) +
    no.legend.title +
    theme_bw() +
    NoGrid()
  if (dark.theme) {
    p <- p + DarkTheme()
    p3 <- p3 + DarkTheme()
  }
  p3 <- p3 + theme(legend.title = element_blank())
  if (!is.null(plot.title)) {
    p3 <- p3 + ggtitle(plot.title) + theme(plot.title = element_text(hjust = 0.5))
  }
  if (do.label) {
    # Place each label at the median embedding position of its group
    data.plot %>%
      dplyr::group_by(ident) %>%
      summarize(x = median(x = x), y = median(x = y)) -> centers
    p3 <- p3 +
      geom_point(data = centers, mapping = aes(x = x, y = y), size = 0, alpha = 0) +
      geom_text(data = centers, mapping = aes(label = ident), size = label.size)
  }
  if (no.legend) {
    p3 <- p3 + theme(legend.position = "none")
  }
  if (no.axes) {
    p3 <- p3 + theme(
      axis.line = element_blank(),
      axis.text.x = element_blank(),
      axis.text.y = element_blank(),
      axis.ticks = element_blank(),
      axis.title.x = element_blank(),
      axis.title.y = element_blank(),
      panel.background = element_blank(),
      panel.border = element_blank(),
      panel.grid.major = element_blank(),
      panel.grid.minor = element_blank(),
      plot.background = element_blank()
    )
  }
  if (do.identify || do.hover) {
    if (do.bare) {
      plot.use <- p
    } else {
      plot.use <- p3
    }
    if (do.hover) {
      if (is.null(x = data.hover)) {
        features.info <- NULL
      } else {
        features.info <- FetchData(object = object, vars.all = data.hover)
      }
      return(HoverLocator(
        plot = plot.use,
        data.plot = data.plot,
        features.info = features.info,
        dark.theme = dark.theme
      ))
    } else if (do.identify) {
      return(FeatureLocator(
        plot = plot.use,
        data.plot = data.plot,
        dark.theme = dark.theme,
        ...
      ))
    }
  }
  if (do.return) {
    if (do.bare) {
      return(p)
    } else {
      return(p3)
    }
  }
  if (do.bare) {
    print(p)
  } else {
    print(p3)
  }
}

#' Plot PCA map
#'
#' Graphs the output of a PCA analysis
#' Cells are colored by their identity class.
#'
#' This function is a wrapper for DimPlot. See ?DimPlot for a full list of possible
#' arguments which can be passed in here.
#'
#' @param object Seurat object
#' @param \dots Additional parameters to DimPlot, for example, which dimensions to plot.
#'
#' @export
#'
#' @examples
#' PCAPlot(object = pbmc_small)
#'
PCAPlot <- function(object, ...) {
  # Pin the reduction to PCA; every other DimPlot option travels via `...`
  DimPlot(
    object = object,
    reduction.use = "pca",
    label.size = 4,
    ...
  )
}

#' Plot Diffusion map
#'
#' Graphs the output of a Diffusion map analysis
#' Cells are colored by their identity class.
#'
#' This function is a wrapper for DimPlot. See ?DimPlot for a full list of possible
#' arguments which can be passed in here.
#'
#' @param object Seurat object
#' @param \dots Additional parameters to DimPlot, for example, which dimensions to plot.
#'
#' @export
#'
#' @examples
#' pbmc_small <- RunDiffusion(object = pbmc_small)
#' DMPlot(object = pbmc_small)
#'
DMPlot <- function(object, ...) {
  # Pin the reduction to the diffusion map ("dm"); other DimPlot options
  # travel via `...`
  DimPlot(
    object = object,
    reduction.use = "dm",
    label.size = 4,
    ...
  )
}

#' Plot ICA map
#'
#' Graphs the output of an ICA analysis
#' Cells are colored by their identity class.
#'
#' This function is a wrapper for DimPlot. See ?DimPlot for a full list of possible
#' arguments which can be passed in here.
#'
#' @param object Seurat object
#' @param \dots Additional parameters to DimPlot, for example, which dimensions to plot.
#'
#' @export
#'
#' @examples
#' pbmc_small <- RunICA(object = pbmc_small, ics.compute = 25, print.results = FALSE)
#' ICAPlot(object = pbmc_small)
#'
ICAPlot <- function(object, ...) {
  # Pin the reduction to ICA; every other DimPlot option travels via `...`
  DimPlot(
    object = object,
    reduction.use = "ica",
    ...
  )
}

#' Plot tSNE map
#'
#' Graphs the output of a tSNE analysis
#' Cells are colored by their identity class.
#'
#' This function is a wrapper for DimPlot. See ?DimPlot for a full list of possible
#' arguments which can be passed in here.
#'
#' @param object Seurat object
#' @param do.label FALSE by default. If TRUE, plots an alternate view where the center of each
#' cluster is labeled
#' @param pt.size Set the point size
#' @param label.size Set the size of the text labels
#' @param cells.use Vector of cell names to use in the plot.
#' @param colors.use Manually set the color palette to use for the points
#' @param \dots Additional parameters to DimPlot, for example, which dimensions to plot.
#'
#' @seealso DimPlot
#'
#' @export
#'
#' @examples
#' TSNEPlot(object = pbmc_small)
#'
TSNEPlot <- function(
  object,
  do.label = FALSE,
  pt.size=1,
  label.size=4,
  cells.use = NULL,
  colors.use = NULL,
  ...
) {
  # tSNE flavour of DimPlot; colors.use is this wrapper's name for DimPlot's
  # cols.use, and everything else is forwarded as-is
  DimPlot(
    object = object,
    reduction.use = "tsne",
    cells.use = cells.use,
    cols.use = colors.use,
    pt.size = pt.size,
    do.label = do.label,
    label.size = label.size,
    ...
  )
}

#' Quickly Pick Relevant Dimensions
#'
#' Plots the standard deviations (or approximate singular values if running PCAFast)
#' of the principal components for easy identification of an elbow in the graph.
#' This elbow often corresponds well with the significant dims and is much faster to run than
#' Jackstraw
#'
#'
#' @param object Seurat object
#' @param reduction.type  Type of dimensional reduction to plot data for
#' @param dims.plot Number of dimensions to plot sd for
#' @param xlab X axis label
#' @param ylab Y axis label
#' @param title Plot title
#'
#' @return Returns ggplot object
#'
#' @export
#'
#' @examples
#' DimElbowPlot(object = pbmc_small)
#'
DimElbowPlot <- function(
  object,
  reduction.type = "pca",
  dims.plot = 20,
  xlab = "",
  ylab = "",
  title = ""
) {
  data.use <- GetDimReduction(
    object = object,
    reduction.type = reduction.type,
    slot = "sdev"
  )
  if (length(x = data.use) == 0) {
    stop(paste("No standard deviation info stored for", reduction.type))
  }
  # 1:dims.plot would silently yield c(1, 0) for dims.plot = 0
  if (dims.plot < 1) {
    stop("dims.plot must be at least 1")
  }
  # Short component name for reductions that have built-in axis labels;
  # NULL for any other reduction
  dim.name <- switch(EXPR = reduction.type, pca = "PC", ica = "IC", NULL)
  if (length(x = data.use) < dims.plot) {
    warning(paste(
      "The object only has information for",
      length(x = data.use),
      if (is.null(x = dim.name)) "dimensions." else paste0(dim.name, "s.")
    ))
    dims.plot <- length(x = data.use)
  }
  data.use <- data.use[seq_len(length.out = dims.plot)]
  dims <- seq_along(along.with = data.use)
  data.plot <- data.frame(dims, data.use)
  # PCA/ICA get descriptive default labels; an explicitly supplied label is
  # now respected instead of being silently ignored
  if (!is.null(x = dim.name)) {
    if (is.null(x = xlab) || identical(x = xlab, y = "")) {
      xlab <- dim.name
    }
    if (is.null(x = ylab) || identical(x = ylab, y = "")) {
      ylab <- paste("Standard Deviation of", dim.name)
    }
  }
  plot <- ggplot(data = data.plot, mapping = aes(x = dims, y = data.use)) +
    geom_point() +
    labs(y = ylab, x = xlab, title = title)
  return(plot)
}

#' Quickly Pick Relevant PCs
#'
#' Plots the standard deviations (or approximate singular values if running PCAFast)
#' of the principal components for easy identification of an elbow in the graph.
#' This elbow often corresponds well with the significant PCs and is much faster to run.
#'
#' @param object Seurat object
#' @param num.pc Number of PCs to plot
#'
#' @return Returns ggplot object
#'
#' @export
#'
#' @examples
#' PCElbowPlot(object = pbmc_small)
#'
PCElbowPlot <- function(object, num.pc = 20) {
  # PCA-only shortcut: num.pc is DimElbowPlot's dims.plot
  DimElbowPlot(object, reduction.type = "pca", dims.plot = num.pc)
}

#' View variable genes
#'
#' @param object Seurat object
#' @param do.text Add text names of variable genes to plot (default is TRUE)
#' @param cex.use Point size
#' @param cex.text.use Text size
#' @param do.spike FALSE by default. If TRUE, color all genes starting with ^ERCC a different color
#' @param pch.use Pch value for points
#' @param col.use Color to use
#' @param spike.col.use if do.spike, color for spike-in genes
#' @param plot.both Plot both the scaled and non-scaled graphs.
#' @param do.contour Draw contour lines calculated based on all genes
#' @param contour.lwd Contour line width
#' @param contour.col Contour line color
#' @param contour.lty Contour line type
#' @param x.low.cutoff Bottom cutoff on x-axis for identifying variable genes
#' @param x.high.cutoff Top cutoff on x-axis for identifying variable genes
#' @param y.cutoff Bottom cutoff on y-axis for identifying variable genes
#' @param y.high.cutoff Top cutoff on y-axis for identifying variable genes
#'
#' @importFrom stats cor loess smooth.spline
#' @importFrom grDevices col2rgb
#' @importFrom graphics axis points smoothScatter contour points text
#'
#' @export
#'
#' @examples
#' VariableGenePlot(object = pbmc_small)
#'
VariableGenePlot <- function(
  object,
  do.text = TRUE,
  cex.use = 0.5,
  cex.text.use = 0.5,
  do.spike = FALSE,
  pch.use = 16,
  col.use = "black",
  spike.col.use = "red",
  plot.both = FALSE,
  do.contour = TRUE,
  contour.lwd = 3,
  contour.col = "white",
  contour.lty = 2,
  x.low.cutoff = 0.1,
  x.high.cutoff = 8,
  y.cutoff = 1,
  y.high.cutoff = Inf
) {
  # Columns of hvg.info are read by position: mean, dispersion, scaled dispersion
  gene.mean <- object@hvg.info[, 1]
  gene.dispersion <- object@hvg.info[, 2]
  gene.dispersion.scaled <- object@hvg.info[, 3]
  names(x = gene.mean) <- names(x = gene.dispersion) <- names(x = gene.dispersion.scaled) <- rownames(x = object@hvg.info)
  # Genes inside both the mean window and the scaled-dispersion window
  # (these are the ones labelled when do.text = TRUE)
  pass.cutoff <- names(x = gene.mean)[which(
    x = (gene.mean > x.low.cutoff) &
      (gene.mean < x.high.cutoff) &
      (gene.dispersion.scaled > y.cutoff) &
      (gene.dispersion.scaled < y.high.cutoff)
  )]
  if (do.spike) {
    spike.genes <- rownames(x = SubsetRow(data = object@data, code = "^ERCC"))
  }
  # Draw one mean-vs-dispersion panel plus the optional contour, spike-in
  # and text overlays; shared by the raw and scaled panels
  DrawDispersionPanel <- function(dispersion) {
    smoothScatter(
      x = gene.mean,
      y = dispersion,
      pch = pch.use,
      cex = cex.use,
      col = col.use,
      xlab = "Average expression",
      ylab = "Dispersion",
      nrpoints = Inf
    )
    if (do.contour) {
      data.kde <- kde2d(x = gene.mean, y = dispersion)
      contour(
        x = data.kde,
        add = TRUE,
        lwd = contour.lwd,
        col = contour.col,
        lty = contour.lty
      )
    }
    if (do.spike) {
      # points() has no `nrpoints` argument; passing it (as the scaled panel
      # used to) only produced "not a graphical parameter" warnings
      points(
        x = gene.mean[spike.genes],
        y = dispersion[spike.genes],
        pch = 16,
        cex = cex.use,
        col = spike.col.use
      )
    }
    if (do.text) {
      text(
        x = gene.mean[pass.cutoff],
        y = dispersion[pass.cutoff],
        labels = pass.cutoff,
        cex = cex.text.use
      )
    }
  }
  if (plot.both) {
    # Side-by-side raw and scaled panels; restore the caller's layout on exit
    old.par <- par(mfrow = c(1, 2))
    on.exit(expr = par(old.par), add = TRUE)
    DrawDispersionPanel(dispersion = gene.dispersion)
  }
  DrawDispersionPanel(dispersion = gene.dispersion.scaled)
  invisible(x = NULL)
}

# #' Highlight classification results
# #'
# #' This function is useful to view where proportionally the clusters returned from
# #' classification map to the clusters present in the given object. Utilizes the FeaturePlot()
# #' function to color clusters in object.
# #'
# #' @param object Seurat object on which the classifier was trained and
# #' onto which the classification results will be highlighted
# #' @param clusters vector of cluster ids (output of ClassifyCells)
# #' @param ... additional parameters to pass to FeaturePlot()
# #'
# #' @return Returns a feature plot with clusters highlighted by proportion of cells
# #' mapping to that cluster
# #'
# #' @export
# #'
# VizClassification <- function(object, clusters, ...) {
#   cluster.dist <- prop.table(x = table(out)) # What is out?
#   object@meta.data$Classification <- numeric(nrow(x = object@meta.data))
#   for (cluster in 1:length(x = cluster.dist)) {
#     cells.to.highlight <- WhichCells(object, names(cluster.dist[cluster]))
#     if (length(x = cells.to.highlight) > 0) {
#       object@meta.data[cells.to.highlight, ]$Classification <- cluster.dist[cluster]
#     }
#   }
#   if (any(grepl(pattern = "cols.use", x = deparse(match.call())))) {
#     return(FeaturePlot(object, "Classification", ...))
#   }
#   cols.use = c("#f6f6f6", "black")
#   return(FeaturePlot(object, "Classification", cols.use = cols.use, ...))
# }

#' Plot phylogenetic tree
#'
#' Plots previously computed phylogenetic tree (from BuildClusterTree)
#'
#' @param object Seurat object
#' @param \dots Additional arguments for plotting the phylogeny
#'
#' @return Plots dendrogram (must be precomputed using BuildClusterTree), returns no value
#'
#' @importFrom ape plot.phylo
#' @importFrom ape nodelabels
#'
#' @export
#'
#' @examples
#' PlotClusterTree(object = pbmc_small)
#'
PlotClusterTree <- function(object, ...) {
  # Guard: the tree slot is empty until BuildClusterTree has been run
  if (length(x = object@cluster.tree) < 1) {
    stop("Phylogenetic tree does not exist, build using BuildClusterTree")
  }
  # Draw the first stored tree top-down, then overlay ape's node labels
  plot.phylo(x = object@cluster.tree[[1]], direction = "downwards", ...)
  nodelabels()
}

#' Color tSNE Plot Based on Split
#'
#' Returns a tSNE plot colored based on whether the cells fall in clusters
#' to the left or to the right of a node split in the cluster tree.
#'
#' @param object Seurat object
#' @param node Node in cluster tree on which to base the split
#' @param color1 Color for the left side of the split
#' @param color2 Color for the right side of the split
#' @param color3 Color for all other cells
#' @inheritDotParams TSNEPlot -object
#'
#' @return Returns a tSNE plot
#'
#' @export
#'
#' @examples
#' pbmc_small
#' PlotClusterTree(pbmc_small)
#' ColorTSNESplit(pbmc_small, node = 6)
#'
ColorTSNESplit <- function(
  object,
  node,
  color1 = "red",
  color2 = "blue",
  color3 = "gray",
  ...
) {
  if (length(x = object@cluster.tree) == 0) {
    stop("Phylogenetic tree does not exist, build using BuildClusterTree")
  }
  tree <- object@cluster.tree[[1]]
  # Children of `node`. Indexing column 2 directly avoids the row-drop that
  # made `[which(...), ][, 2]` fail when only one edge row matched.
  split <- tree$edge[which(x = tree$edge[, 1] == node), 2]
  if (length(x = split) < 2) {
    stop("node must be an internal node of the cluster tree with two children")
  }
  # Tips are the node ids that never appear as a parent
  all.children <- sort(x = tree$edge[, 2][!tree$edge[, 2] %in% tree$edge[, 1]])
  left.group <- DFT(tree = tree, node = split[1], only.children = TRUE)
  right.group <- DFT(tree = tree, node = split[2], only.children = TRUE)
  # A side that is a single tip has no descendants; use the tip itself
  if (any(is.na(x = left.group))) {
    left.group <- split[1]
  }
  if (any(is.na(x = right.group))) {
    right.group <- split[2]
  }
  left.group <- MapVals(v = left.group, from = all.children, to = tree$tip.label)
  right.group <- MapVals(v = right.group, from = all.children, to = tree$tip.label)
  remaining.group <- setdiff(x = tree$tip.label, y = c(left.group, right.group))
  left.cells <- WhichCells(object = object, ident = left.group)
  right.cells <- WhichCells(object = object, ident = right.group)
  remaining.cells <- WhichCells(object = object, ident = remaining.group)
  object <- SetIdent(
    object = object,
    cells.use = left.cells,
    ident.use = "Left Split"
  )
  object <- SetIdent(
    object = object,
    cells.use = right.cells,
    ident.use = "Right Split"
  )
  object <- SetIdent(
    object = object,
    cells.use = remaining.cells,
    ident.use = "Not in Split"
  )
  # Name the colours by identity so each group keeps its colour regardless of
  # factor-level order or of which groups are present. For example, splitting
  # at the root leaves no "Not in Split" cells, and the old positional
  # c(color1, color3, color2) then painted the right split with color3.
  colors.use <- c(
    "Left Split" = color1,
    "Right Split" = color2,
    "Not in Split" = color3
  )
  return(TSNEPlot(object = object, colors.use = colors.use, ...))
}

#' Plot k-means clusters
#'
#' @param object A Seurat object
#' @param cells.use Cells to include in the heatmap
#' @param genes.cluster Clusters to include in heatmap
#' @param max.genes Maximum number of genes to include in the heatmap
#' @param slim.col.label Instead of displaying every cell name on the heatmap,
#' display only the identity class name once for each group
#' @param remove.key Removes the color key from the plot
#' @param row.lines Color separations of clusters (currently unused; no row separators are passed to DoHeatmap)
#' @param ... Extra parameters to DoHeatmap
#'
#' @seealso \code{DoHeatmap}
#'
#' @export
#'
#' @examples
#' pbmc_small <- DoKMeans(object = pbmc_small, k.genes = 3)
#' KMeansHeatmap(object = pbmc_small)
#'
KMeansHeatmap <- function(
  object,
  cells.use = object@cell.names,
  genes.cluster = NULL,
  max.genes = 1e6,
  slim.col.label = TRUE,
  remove.key = TRUE,
  row.lines = TRUE,
  ...
) {
  # Default to every k-means gene cluster stored on the object
  genes.cluster <- SetIfNull(
    x = genes.cluster,
    default = unique(x = object@kmeans@gene.kmeans.obj$cluster)
  )
  genes.use <- GenesInCluster(
    object = object,
    cluster.num = genes.cluster,
    max.genes = max.genes
  )
  # NOTE: row.lines is kept for backward compatibility but has no effect.
  # DoHeatmap receives no row separators, so the per-cluster gene counts that
  # used to be computed (and then discarded) are no longer calculated.
  DoHeatmap(
    object = object,
    cells.use = cells.use,
    genes.use = genes.use,
    slim.col.label = slim.col.label,
    remove.key = remove.key,
    ...
  )
}

globalVariables(
  names = c('cluster', 'avg_diff', 'gene'),
  package = 'Seurat',
  add = TRUE
)
# Node Heatmap
#
# Takes an object, a marker list (output of FindAllMarkersNode), and a node
# and plots a heatmap where genes are ordered vertically by the splits present
# in the object@@cluster.tree slot. Genes are gathered while walking the tree
# depth-first from `node`; cells are restricted to the leaves below `node`.
#
# @param object Seurat object. Must have the cluster.tree slot filled (use BuildClusterTree)
# @param marker.list Data frame of marker genes from FindAllMarkersNode. Must
# contain the columns cluster, gene and avg_diff; row order is used as the
# marker ranking (earlier rows are preferred)
# @param node Node in the cluster tree from which to start the plot; defaults
# to the smallest node ID in marker.list$cluster
# @param max.genes Maximum number of positive and of negative markers to keep
# for each node
# @param ... Additional parameters to pass to DoHeatmap
#
#' @importFrom dplyr %>% group_by filter top_n select
#
# @return The value returned by DoHeatmap
#
# @export
#
NodeHeatmap <- function(object, marker.list, node = NULL, max.genes = 10, ...) {
  tree <- object@cluster.tree[[1]]
  node <- SetIfNull(x = node, default = min(marker.list$cluster))
  node.order <- c(node, DFT(tree = tree, node = node))
  # Rank markers by their row position; seq_len() is safe for zero rows,
  # unlike seq(1:n) which yields c(1, 0) -> 1:2 when n == 0
  marker.list$rank <- seq_len(length.out = nrow(x = marker.list))
  # Per cluster, keep the max.genes best-ranked (earliest) markers of each sign
  pos.genes <- marker.list %>%
    group_by(cluster) %>%
    filter(avg_diff > 0) %>%
    top_n(max.genes, -rank) %>%
    select(gene, cluster)
  neg.genes <- marker.list %>%
    group_by(cluster) %>%
    filter(avg_diff < 0) %>%
    top_n(max.genes, -rank) %>%
    select(gene, cluster)
  gene.list <- vector()
  node.stack <- vector()
  for (n in node.order) {
    if (NodeHasChild(tree = tree, node = n)) {
      # Node with children: its positive markers, then its negative markers
      gene.list <- c(
        gene.list,
        subset(x = pos.genes, subset = cluster == n)$gene,
        subset(x = neg.genes, subset = cluster == n)$gene
      )
      if (NodeHasOnlyChildren(tree = tree, node = n)) {
        # Also append the negative markers of the most recently stacked node,
        # then pop that node off the stack
        gene.list <- c(
          gene.list,
          subset(x = neg.genes, subset = cluster == node.stack[length(node.stack)])$gene
        )
        node.stack <- node.stack[-length(x = node.stack)]
      }
    } else {
      # Node without children: its positive markers; remember it on the stack
      gene.list <- c(gene.list, subset(x = pos.genes, subset = cluster == n)$gene)
      node.stack <- append(x = node.stack, values = n)
    }
  }
  # Leaves below `node`: descendants that never appear as a parent in tree$edge
  descendants <- GetDescendants(tree = tree, node = node)
  children <- descendants[!descendants %in% tree$edge[, 1]]
  DoHeatmap(
    object = object,
    cells.use = WhichCells(object = object, ident = children),
    genes.use = gene.list,
    slim.col.label = TRUE,
    remove.key = TRUE,
    ...
  )
}

globalVariables(
  names = c('cc', 'bicor', "Group"),
  package = 'Seurat',
  add = TRUE
)
#' Plot CC bicor saturation plot
#'
#' The function provides a useful plot for evaluating the number of CCs to
#' proceed with in the Seurat alignment workflow. Here we look at the biweight
#' midcorrelation (bicor) of the Xth gene ranked by minimum bicor across the
#' specified CCs for each group in the grouping.var. For alignment of more than
#' two groups, we average the bicor results for the reference group across the
#' pairwise alignments.
#'
#' @param object A Seurat object
#' @param bicor.data Optionally provide data.frame returned by function to avoid
#' recalculation
#' @param grouping.var Grouping variable specified in alignment procedure
#' @param dims.eval Dimensions to evaluate the bicor for. May be omitted when
#' bicor.data is supplied, in which case the unique values of bicor.data$cc
#' are used
#' @param gene.num Xth gene to look at bicor for
#' @param num.possible.genes Number of possible genes to search when choosing
#' genes for the metagene. Set to 2000 by default. Lowering will decrease runtime
#' but may result in metagenes constructed on fewer than num.genes genes.
#' @param smooth Smooth curves. Smoothing is only applied when at least 10
#' dimensions are evaluated; otherwise a line plot is drawn (with a warning if
#' smooth = TRUE was passed explicitly)
#' @param return.mat Return the bicor data.frame instead of the ggplot2 object
#' @param display.progress Show progress bar
#'
#' @return The plot is always printed. Returns the bicor data.frame when
#' return.mat is TRUE, otherwise the ggplot2 object (invisibly, so it is not
#' drawn a second time at the console)
#'
#' @import ggplot2
#' @export
#'
#' @examples
#' \dontrun{
#' # `object` must contain CCA results for cells split by a metadata
#' # column, here named "group"
#' MetageneBicorPlot(object = object, grouping.var = "group", dims.eval = 1:5)
#' }
#'
MetageneBicorPlot <- function(
  object,
  bicor.data,
  grouping.var,
  dims.eval,
  gene.num = 30,
  num.possible.genes = 2000,
  return.mat = FALSE,
  smooth = TRUE,
  display.progress = TRUE
) {
  if (missing(x = bicor.data)) {
    bicor.data <- EvaluateCCs(
      object = object,
      grouping.var = grouping.var,
      dims.eval = dims.eval,
      gene.num = gene.num,
      num.possible.genes = num.possible.genes,
      display.progress = display.progress
    )
  }
  # With precomputed bicor.data, dims.eval is only needed to decide whether to
  # smooth; infer the evaluated dimensions from the data instead of erroring
  if (missing(x = dims.eval)) {
    dims.eval <- unique(x = bicor.data$cc)
  }
  # geom_smooth needs enough points; fall back to lines for fewer than 10 dims
  use.smooth <- smooth && length(x = dims.eval) >= 10
  if (!use.smooth && smooth && !missing(x = smooth)) {
    warning("Curves not smoothed. Falling back to line plot")
  }
  p <- ggplot(data = bicor.data, mapping = aes(x = cc, y = abs(bicor)))
  if (use.smooth) {
    p <- p + geom_smooth(mapping = aes(col = Group), se = FALSE)
  } else {
    p <- p + geom_line(mapping = aes(col = Group))
  }
  p <- p + ylab("Shared Correlation Strength") + xlab("CC")
  print(p)
  if (return.mat) {
    return(bicor.data)
  }
  return(invisible(x = p))
}
back to top