Source code for sequana.viz.heatmap

#
#  This file is part of Sequana software
#
#  Copyright (c) 2016-2020 - Sequana Development Team
#
#  File author(s):
#      Thomas Cokelaer <thomas.cokelaer@pasteur.fr>
#      Etienne Kornobis <etienne.kornobis@pasteur.fr>
#
#  Distributed under the terms of the 3-clause BSD license.
#  The full license is in the LICENSE file, distributed with this software.
#
#  website: https://github.com/sequana/sequana
#  documentation: http://sequana.readthedocs.io
#
##############################################################################
"""Heatmap and dendograms"""

from sequana.lazy import numpy as np
from sequana.lazy import pandas as pd
from sequana.lazy import pylab
from sequana.viz.linkage import Linkage

__all__ = ["Heatmap", "Clustermap"]


# Soft dependency: a missing seaborn is reported on stdout here instead of
# propagating the ImportError from this statement.
try:
    import seaborn as sns
except ImportError:
    print("You should install seaborn to use sequana.viz.heatmap")


def get_heatmap_df():
    """Return a small 4x4 demonstration data frame (columns A-D) for examples and tests."""
    import pandas as pd

    # One list of four values per column; the default RangeIndex is kept.
    columns = {
        "A": [1, 0, 1, 1],
        "B": [0.9, 0.1, 0.6, 1],
        "C": [0.5, 0.2, 0, 1],
        "D": [0.5, 0.2, 0, 1],
    }
    return pd.DataFrame(columns)


def get_clustermap_data():
    """Return demonstration inputs for :class:`Clustermap`.

    :return: a 3-tuple ``(data, sample_groups, gene_groups)`` where *data* is
        a 4 genes x 4 samples data frame, *sample_groups* is indexed by sample
        name and *gene_groups* is indexed by gene name. Both group frames have
        the two columns ``group1`` and ``group2``.
    """
    import pandas as pd

    group_names = ["group1", "group2"]
    gene_names = ["gene_A", "gene_B", "gene_C", "gene_D"]

    values = {
        "sample_1": [1, 0, 1, 1],
        "sample_2": [0.9, 0.1, 0.6, 1],
        "sample_3": [0.5, 0.2, 0, 1],
        "sample_4": [0.5, 0.2, 0, 1],
    }
    data = pd.DataFrame(values, index=gene_names)

    # The annotation frames are written one item per column and then flipped
    # so that the items (samples or genes) end up on the index.
    per_sample = {
        "sample_1": ["WT", "rep1"],
        "sample_2": ["WT", "rep2"],
        "sample_3": ["KO", "rep1"],
        "sample_4": ["KO", "rep2"],
    }
    sample_groups = pd.DataFrame(per_sample, index=group_names).T

    per_gene = {
        "gene_A": ["transcription", "GO1"],
        "gene_B": ["transcription", "GO2"],
        "gene_C": ["regulation", "GO1"],
        "gene_D": ["regulation", "GO2"],
    }
    gene_groups = pd.DataFrame(per_gene, index=group_names).T

    return data, sample_groups, gene_groups


class Clustermap:
    """Heatmap and dendrograms based on seaborn Clustermap

    .. plot::
        :include-source:
        :width: 80%

        from sequana.viz.heatmap import Clustermap, get_clustermap_data
        df, sample_groups_df, gene_groups_df = get_clustermap_data()
        h = Clustermap(df, sample_groups_df=sample_groups_df, gene_groups_df=gene_groups_df)
        h.plot()

    """

    def __init__(
        self,
        data_df,
        sample_groups_df=None,
        sample_groups_sel=None,
        sample_groups_palette=None,
        gene_groups_df=None,
        gene_groups_sel=None,
        gene_groups_palette=None,
        yticklabels="auto",
        **kwargs
    ):
        """.. rubric:: Constructor

        :param data_df: a dataframe.
        :param sample_groups_df: a dataframe with sample id as index (same as
            in data_df columns) and a group definition per column. Use to
            produce the x axis color groups.
        :param sample_groups_sel: a list of the columns to select from the
            sample_groups_df. ``None`` or an empty list keeps every column.
        :param sample_groups_palette: the palette (sequence of colors) to use
            for sample color groups. ``None`` means
            ``seaborn.color_palette(desat=0.6)``.
        :param gene_groups_df: a dataframe with gene id as index (same as in
            data_df index) and a group definition per column. Use to produce
            the y axis color groups.
        :param gene_groups_sel: a list of the columns to select from the
            gene_groups_df. ``None`` or an empty list keeps every column.
        :param gene_groups_palette: the palette (sequence of colors) to use
            for gene color groups. ``None`` means
            ``seaborn.color_palette(desat=0.6)``.
        :param yticklabels: "auto" for classical heatmap behaviour, [] for no
            ticks or a pandas Series giving the mapping between the index
            (gene names in data_df) and the gene names to be used for the
            heatmap
        :param kwargs: All other kwargs are passed to seaborn.Clustermap.
        """
        # Defaults are resolved here rather than in the signature: a list
        # default would be shared by every instance (it is stored on self),
        # and a seaborn call in the signature would run at class-definition
        # time, before the caller had any chance to provide a palette.
        if sample_groups_sel is None:
            sample_groups_sel = []
        if gene_groups_sel is None:
            gene_groups_sel = []
        if sample_groups_palette is None:
            sample_groups_palette = sns.color_palette(desat=0.6)
        if gene_groups_palette is None:
            gene_groups_palette = sns.color_palette(desat=0.6)

        self.data_df = data_df
        self.sample_groups_df = sample_groups_df
        self.sample_groups_sel = sample_groups_sel
        self.sample_groups_palette = sample_groups_palette
        self.gene_groups_df = gene_groups_df
        self.gene_groups_sel = gene_groups_sel
        self.gene_groups_palette = gene_groups_palette
        self.kwargs = kwargs

        self.sample_groups_col_df, self.sample_color_dict = self._get_group_colors(
            sample_groups_df, sample_groups_sel, sample_groups_palette
        )
        self.gene_groups_col_df, self.gene_color_dict = self._get_group_colors(
            gene_groups_df, gene_groups_sel, gene_groups_palette
        )
        self.yticklabels = self._convert_gene_names(yticklabels)
        self.params = {"legend.fontsize": 16, "legend.sample": True, "legend.gene": True}

    def _convert_gene_names(self, yticklabels):
        """Align a Series of tick labels on the rows of ``data_df``.

        A pandas Series is reindexed with ``data_df.index``; any other value
        is returned untouched.
        """
        if isinstance(yticklabels, pd.Series):
            return yticklabels.loc[self.data_df.index]
        return yticklabels

    def _get_group_colors(self, groups_df, selection, palette):
        """Map every group label found in *groups_df* to a color.

        :param groups_df: data frame with one group definition per column, or
            anything else (e.g. ``None``) when there are no groups.
        :param selection: columns of *groups_df* to keep; empty keeps all.
        :param palette: sequence of colors.
        :return: ``(colors_df, color_dict)`` where *colors_df* has the shape of
            the (selected) *groups_df* with labels replaced by colors and
            *color_dict* maps label to color; ``(None, None)`` when
            *groups_df* is not a data frame.
        """
        from itertools import cycle

        if not isinstance(groups_df, pd.DataFrame):
            return None, None

        if selection:
            groups_df = groups_df.loc[:, selection]

        labels = [label for column in groups_df for label in groups_df[column].unique()]
        # Cycle through the palette so that having more labels than colors
        # reuses colors instead of leaving the extra labels unmapped (NaN).
        col_map = dict(zip(labels, cycle(palette)))
        groups_col_df = groups_df.apply(lambda column: column.map(col_map))
        return groups_col_df, col_map

    def _do_legend(self, figure, color_dict, bbox_to_anchor, title="Sample groups"):
        """Draw one legend listing the label/color pairs of *color_dict*.

        Nothing is drawn when *color_dict* is empty or ``None``.

        :param figure: object whose ``add_artist`` method receives the legend.
        :param color_dict: mapping label -> color.
        :param bbox_to_anchor: anchor forwarded to ``pylab.legend``.
        :param title: legend title.
        """
        if not color_dict:
            return

        import matplotlib.patches as mpatches

        patches = [mpatches.Patch(color=color, label=label) for label, color in color_dict.items()]
        legend = pylab.legend(
            loc="upper center",
            handles=patches,
            bbox_to_anchor=bbox_to_anchor,
            frameon=True,
            title=title,
            fontsize=self.params["legend.fontsize"],
        )
        # Registered explicitly so that a later pylab.legend call on the same
        # axes does not make this legend disappear.
        figure.add_artist(legend)

    def plot(self, cmap=None):
        """Draw the clustermap and its optional sample/gene legends.

        :param cmap: colormap forwarded to ``seaborn.clustermap``. When
            ``None``, a diverging palette is built and its "bad" (masked)
            color is set to grey.
        :return: the object returned by ``seaborn.clustermap``.
        """
        if cmap is None:
            cmap = sns.diverging_palette(220, 10, as_cmap=True)
            cmap.set_bad("grey", 1.0)

        p = sns.clustermap(
            self.data_df,
            cmap=cmap,
            col_colors=self.sample_groups_col_df,
            yticklabels=self.yticklabels,
            row_colors=self.gene_groups_col_df,
            **self.kwargs,
        )

        f = pylab.gca()
        if self.params["legend.sample"]:
            self._do_legend(f, self.sample_color_dict, (12, 2), title="Sample groups")
        if self.params["legend.gene"]:
            self._do_legend(f, self.gene_color_dict, (-2, -2), title="Gene groups")

        return p
class Heatmap(Linkage):
    """Heatmap and dendrograms of an input matrix

    A heat map is an image representation of a matrix with a dendrogram added
    to the left side and to the top. Typically, reordering of the rows and
    columns according to some set of values (row or column means) within the
    restrictions imposed by the dendrogram is carried out.

    .. plot::
        :include-source:
        :width: 80%

        from sequana.viz import heatmap
        df = heatmap.get_heatmap_df()
        h = heatmap.Heatmap(df)
        h.plot()

    Side colors can be added::

        h = viz.Heatmap(df, col_side_colors=['r', 'g', 'b', 'y', 'k'])
        h.category_column = category
        h.category_row = category

    where ``category_column`` is a dictionary with the data frame columns as
    keys and ``category_row`` a dictionary with the data frame index as keys.
    Values are the categories defined by you; they are converted to integers
    when plotted. The number of colors in col_side_colors and row_side_colors
    should match the number of categories.
    """

    def __init__(
        self,
        data=None,
        row_method="complete",
        column_method="complete",
        row_metric="euclidean",
        column_metric="euclidean",
        cmap="yellow_black_blue",
        col_side_colors=None,
        row_side_colors=None,
        verbose=True,
    ):
        """.. rubric:: constructor

        :param data: a dataframe or possibly a numpy matrix. It is copied.
            When None, the `df` attribute must be set before plotting.
        :param row_method: *complete* by default
        :param column_method: *complete* by default. See linkage module for
            details
        :param row_metric: *euclidean* by default
        :param column_metric: *euclidean* by default
        :param cmap: colormap. any matplotlib accepted or combo of colors as
            defined in colormap package (pypi)
        :param col_side_colors: colors of the column category strip; a
            built-in list of 7 colors is used when not provided.
        :param row_side_colors: colors of the row category strip; a built-in
            list of 7 colors is used when not provided.
        :param verbose: print a reminder when no data is provided.
        :raises AttributeError: if *data* has no ``copy`` method.
        """
        # should be a copy since it may be reshuffled ?
        if data is None:
            if verbose is True:
                print("No data provided, please fill the `df` attribute manually")
        else:
            try:
                self._df = data.copy()
            except AttributeError:
                print("input must be a pandas data frame or numpy matrix")
                raise

        self._row_method = row_method
        self._column_method = column_method

        self._column_metric = column_metric
        self._row_metric = row_metric

        # some default parameters
        self.cluster_criterion = "distance"

        # Plain attribute container for the plotting parameters.
        class Params:
            pass

        self.params = Params()
        self.params.col_side_colors = ["r", "g", "b", "y", "w", "k", "m"]
        self.params.row_side_colors = ["r", "g", "b", "y", "w", "k", "m"]
        self.params.cmap = cmap

        self.category_row = {}
        self.category_column = {}

        if col_side_colors:
            self.params.col_side_colors = col_side_colors
        if row_side_colors:
            self.params.row_side_colors = row_side_colors

    def _get_df(self):
        return self._df

    def _set_df(self, data):
        # Store a copy: plot() reorders the stored frame.
        self._df = data.copy()

    df = property(_get_df, _set_df)
    frame = property(_get_df, _set_df)

    def _get_row_method(self):
        return self._row_method

    def _set_row_method(self, value):
        self.check_method(value)
        self._row_method = value

    row_method = property(_get_row_method, _set_row_method)

    def _get_col_method(self):
        return self._column_method

    def _set_col_method(self, value):
        self.check_method(value)
        self._column_method = value

    column_method = property(_get_col_method, _set_col_method)

    def _get_col_metric(self):
        return self._column_metric

    def _set_col_metric(self, value):
        self.check_metric(value)
        self._column_metric = value

    column_metric = property(_get_col_metric, _set_col_metric)

    def _get_row_metric(self):
        return self._row_metric

    def _set_row_metric(self, value):
        self.check_metric(value)
        self._row_metric = value

    row_metric = property(_get_row_metric, _set_row_metric)

    def plot(
        self,
        num=1,
        cmap=None,
        colorbar=True,
        vmin=None,
        vmax=None,
        colorbar_position="right",
        gradient_span="None",
        figsize=(12, 8),
        fontsize=None,
    ):
        """Plot the heatmap together with the row and column dendrograms.

        Using as input::

            df = pd.DataFrame({'A':[1,0,1,1],
                               'B':[.9,0.1,.6,1],
                               'C':[.5,.2,0,1],
                               'D':[.5,.2,0,1]})

        we can plot the heatmap + dendrogram as follows::

            h = Heatmap(df)
            h.plot(vmin=0, vmax=1.1)

        .. plot::
            :include-source:
            :width: 80%

            from sequana.viz import heatmap
            df = heatmap.get_heatmap_df()
            h = heatmap.Heatmap(df)
            h.category_column['A'] = 1
            h.category_column['C'] = 1
            h.category_column['D'] = 2
            h.category_column['B'] = 2
            h.plot()

        :param num: figure number given to ``pylab.figure``; the figure is
            cleared before plotting.
        :param cmap: colormap; the one given to the constructor when None.
        :param colorbar: add a color bar when true.
        :param vmin: lower bound of the color normalisation.
        :param vmax: upper bound of the color normalisation.
        :param colorbar_position: ``"right"`` or ``"top left"``.
        :param gradient_span: ``"min_to_max_centered"``, ``"only_max"`` or
            ``"only_min"`` to derive vmin/vmax from the data; any other value
            leaves vmin/vmax as given.
        :param figsize: figure size given to ``pylab.figure``.
        :param fontsize: font size overriding the automatic choice.
        :return: dictionary of the created axes/artists. Possible keys are
            ``dendogram1``, ``dendogram2``, ``heatmap``, ``category_column``,
            ``category_row``, ``colorbar`` and ``colorbar_scalablemap``.
        :raises ValueError: if *colorbar_position* is not supported.

        Side effects: ``matplotlib.rcParams["font.size"]`` is modified, the
        stored data frame is reordered according to the dendrograms and the
        reordered frame plus the row/column orders are saved in ``self.d``.
        """
        import scipy.cluster.hierarchy as hierarchy

        # save all parameters in a dict
        layout = {}

        if cmap is None:
            cmap = self.params.cmap
        # Best effort: the optional colormap package can build a colormap
        # from a name; if that is not possible cmap is used as it is.
        try:
            import colormap

            cmap = colormap.cmap_builder(cmap)
        except Exception:  # pragma: no cover
            pass

        # keep track of row and column names for later.
        row_header = self.frame.index
        column_header = self.frame.columns

        import matplotlib

        # FIXME something clever for the fontsize
        # The larger the matrix, the smaller the font. This must be a single
        # if/elif chain, otherwise the smallest sizes are always overwritten.
        largest_dimension = max(len(row_header), len(column_header))
        if largest_dimension > 100:
            matplotlib.rcParams["font.size"] = 6
        elif largest_dimension > 50:
            matplotlib.rcParams["font.size"] = 7
        elif largest_dimension > 30:
            matplotlib.rcParams["font.size"] = 8
        else:
            matplotlib.rcParams["font.size"] = 12
        if fontsize:
            matplotlib.rcParams["font.size"] = fontsize

        # scaling min/max range
        self.gradient_span = gradient_span  #'only_max'
        # min_to_max, min_to_max_centered, only_max, only_min
        if self.gradient_span == "min_to_max_centered":
            vmax = self.frame.max().max()
            vmin = self.frame.min().min()
            vmax = max([vmax, abs(vmin)])
            vmin = vmax * -1
        elif self.gradient_span == "only_max":
            vmin = 0
            vmax = self.frame.max().max()
        elif self.gradient_span == "only_min":
            vmin = self.frame.min().min()
            vmax = 0
        norm = matplotlib.colors.Normalize(vmin, vmax)

        # Scale the figure window size
        fig = pylab.figure(num=num, figsize=figsize)
        fig.clf()

        # LAYOUT --------------------------------------------------
        # All positions are [x, y, width, height] in figure coordinates.
        # ax1 (dendrogram 1) on the left of the heatmap
        [ax1_x, ax1_y, ax1_w, ax1_h] = [0.05, 0.22, 0.2, 0.6]
        width_between_ax1_axr = 0.004
        # distance between the top color bar axis and the matrix
        height_between_ax1_axc = 0.004
        # Sufficient size to show
        color_bar_w = 0.015

        # axr, placement of row side colorbar: right of dendrogram 1, same
        # vertical extent; its width is the width of the side color bar.
        axr_x = ax1_x + ax1_w + width_between_ax1_axr
        axr_y = ax1_y
        axr_w = color_bar_w
        axr_h = ax1_h
        width_between_axr_axm = 0.004

        # axc, placement of column side colorbar: above the heatmap; its
        # height is the height of the top color bar.
        axc_x = axr_x + axr_w + width_between_axr_axm
        axc_y = ax1_y + ax1_h + height_between_ax1_axc
        axc_w = 0.5
        axc_h = color_bar_w
        height_between_axc_ax2 = 0.004

        # axm, placement of heatmap for the data matrix
        axm_x = axr_x + axr_w + width_between_axr_axm
        axm_y = ax1_y
        axm_w = axc_w
        axm_h = ax1_h

        # ax2 (dendrogram 2), on the top of the heatmap
        ax2_x = axr_x + axr_w + width_between_axr_axm
        ax2_y = ax1_y + ax1_h + height_between_ax1_axc + axc_h + height_between_axc_ax2
        ax2_w = axc_w
        ax2_h = 0.15

        # axcb - placement of the color legend
        if colorbar_position == "top left":
            [axcb_x, axcb_y, axcb_w, axcb_h] = [0.07, 0.88, 0.18, 0.09]
        elif colorbar_position == "right":
            [axcb_x, axcb_y, axcb_w, axcb_h] = [0.85, 0.2, 0.08, 0.6]
        else:
            raise ValueError("'top left' or 'right' accepted for now")

        # COMPUTATION DENDROGRAM ON COLUMNS (top) ---------------------
        if self.column_method:
            Y = self.linkage(self.frame.transpose(), self.column_method, self.column_metric)
            # The new axes becomes the current one, which is where
            # hierarchy.dendrogram draws.
            ax2 = fig.add_axes([ax2_x, ax2_y, ax2_w, ax2_h], frame_on=True)

            # color_threshold=0 and above_threshold_color='k' colors all
            # dendrogram into black
            Z = hierarchy.dendrogram(
                Y,
                color_threshold=0,
                above_threshold_color="k",
                distance_sort="descending",
            )
            # NOTE: the flat clusters are computed but not used further in
            # this method.
            ind2 = hierarchy.fcluster(Y, 0.7 * max(Y[:, 2]), self.cluster_criterion)

            ax2.set_xticks([])
            ax2.set_yticks([])
            # apply the clustering for the array-dendrograms to the actual matrix data
            idx2 = Z["leaves"]
            self.frame = self.frame.iloc[:, idx2]
            # reorder the flat cluster to match the order of the leaves the dendrogram
            ind2 = ind2[idx2]
            layout["dendogram2"] = ax2
        else:
            idx2 = range(self.frame.shape[1])

        # COMPUTATION DENDROGRAM ON ROWS (left) -----------------------
        if self.row_method:
            Y = self.linkage(self.frame, self.row_method, self.row_metric)

            ax1 = fig.add_axes([ax1_x, ax1_y, ax1_w, ax1_h], frame_on=True)
            Z = hierarchy.dendrogram(
                Y,
                orientation="left",
                color_threshold=0,
                above_threshold_color="k",
                distance_sort="descending",
            )
            # NOTE: as for the columns, the flat clusters are not used
            # further in this method.
            ind1 = hierarchy.fcluster(Y, 0.7 * max(Y[:, 2]), self.cluster_criterion)

            ax1.set_xticks([])
            ax1.set_yticks([])
            # apply the clustering for the array-dendrograms to the actual matrix data
            idx1 = Z["leaves"]
            self.frame = self.frame.iloc[idx1, :]
            # reorder the flat cluster to match the order of the leaves the dendrogram
            ind1 = ind1[idx1]
            layout["dendogram1"] = ax1
        else:
            idx1 = range(self.frame.shape[0])

        # HEATMAP itself
        axm = fig.add_axes([axm_x, axm_y, axm_w, axm_h])
        axm.imshow(
            self.frame,
            aspect="auto",
            origin="lower",
            interpolation="None",
            cmap=cmap,
            norm=norm,
        )
        axm.set_xticks([])
        axm.set_yticks([])
        layout["heatmap"] = axm

        # TEXT
        # The headers were saved before the reordering, so row/column i of
        # the reordered frame is labelled with header[order[i]].
        n_rows, n_columns = self.frame.shape
        for i in range(n_rows):
            axm.text(
                n_columns - 0.5,
                i,
                "  " + str(row_header[idx1[i]]),
                verticalalignment="center",
            )

        for i in range(n_columns):
            axm.text(
                i,
                -0.9,
                " " + str(column_header[idx2[i]]),
                rotation=90,
                verticalalignment="top",
                horizontalalignment="center",
            )

        # CATEGORY column ------------------------------
        if self.category_column:
            axc = fig.add_axes([axc_x, axc_y, axc_w, axc_h])

            category_col = [self.category_column[x] for x in self.frame.columns]

            dc = np.array(category_col, dtype=int)
            # One row, one cell per column. The length of the category list
            # is used so that this also works without column clustering.
            dc.shape = (1, len(category_col))
            cmap_c = matplotlib.colors.ListedColormap(self.params.col_side_colors)
            axc.matshow(dc, aspect="auto", origin="lower", cmap=cmap_c)
            axc.set_xticks([])
            axc.set_yticks([])
            layout["category_column"] = axc

        # CATEGORY row -------------------------------
        if self.category_row:
            axr = fig.add_axes([axr_x, axr_y, axr_w, axr_h])
            # self.category_row must be a dictionary with names as found in
            # the index of the dataframe.
            category_row = [self.category_row[x] for x in self.frame.index]

            dr = np.array(category_row, dtype=int)
            # One column, one cell per row.
            dr.shape = (len(category_row), 1)
            cmap_r = matplotlib.colors.ListedColormap(self.params.row_side_colors)
            axr.matshow(dr, aspect="auto", origin="lower", cmap=cmap_r)
            axr.set_xticks([])
            axr.set_yticks([])
            layout["category_row"] = axr

        # COLORBAR ----------------------
        if colorbar:
            axcb = fig.add_axes([axcb_x, axcb_y, axcb_w, axcb_h], frame_on=False)
            if colorbar_position == "right":
                orientation = "vertical"
            else:
                orientation = "horizontal"
            cb = matplotlib.colorbar.ColorbarBase(ax=axcb, cmap=cmap, norm=norm, orientation=orientation)
            layout["colorbar"] = cb
            layout["colorbar_scalablemap"] = axcb

        # could be useful
        self.d = {"ordered": self.frame.copy(), "rorder": idx1, "corder": idx2}

        return layout