#!/usr/bin/env python
# @Author: Kelvin
# @Date: 2020-05-18 00:15:00
# @Last Modified by: Kelvin
# @Last Modified time: 2020-12-30 01:46:37
import seaborn as sns
import pandas as pd
import numpy as np
from ..utilities._utilities import *
from ..tools._diversity import rarefun
from scanpy.plotting._tools.scatterplots import embedding
import matplotlib.pyplot as plt
from anndata import AnnData
import random
from adjustText import adjust_text
from plotnine import ggplot, theme_classic, aes, geom_line, xlab, ylab, options, ggtitle, labs, scale_color_manual
from scanpy.plotting import palettes
from time import sleep
import matplotlib.pyplot as plt
from itertools import combinations
[docs]def clone_rarefaction(self, groupby, clone_key=None, palette=None, figsize=(6,4), save=None):
"""
Plots rarefaction curve for cell numbers vs clone size.
Parameters
----------
self : AnnData
`AnnData` object.
groupby : str
Column name to split the calculation of clone numbers for a given number of cells for e.g. sample, patient etc.
clone_key : str, optional
Column name specifying the clone_id column in metadata/obs.
palette : sequence, optional
Color mapping for unique elements in groupby. Will try to retrieve from AnnData `.uns` slot if present.
figsize : tuple[float, float]
Size of plot.
save : str, optional
Save path.
Returns
-------
rarefaction curve plot.
"""
if self.__class__ == AnnData:
metadata = self.obs.copy()
if clone_key is None:
clonekey = 'clone_id'
else:
clonekey = clone_key
groups = list(set(metadata[groupby]))
metadata = metadata[metadata['bcr_QC_pass'].isin([True, 'True'])]
metadata[clonekey] = metadata[clonekey].cat.remove_unused_categories()
res = {}
for g in groups:
_metadata = metadata[metadata[groupby]==g]
res[g] = _metadata[clonekey].value_counts()
res_ = pd.DataFrame.from_dict(res, orient = 'index')
# remove those with no counts
rowsum = res_.sum(axis = 1)
print('removing due to zero counts:', ', '.join([res_.index[i] for i, x in enumerate(res_.sum(axis = 1) == 0) if x]))
sleep(0.5)
res_ = res_[~(res_.sum(axis = 1) == 0)]
# set up for calculating rarefaction
tot = res_.apply(sum, axis = 1)
S = res_.apply(lambda x: x[x > 0].shape[0], axis = 1)
nr = res_.shape[0]
# append the results to a dictionary
rarecurve = {}
for i in tqdm(range(0, nr), desc = 'Calculating rarefaction curve '):
n = np.arange(1, tot[i], step = 10)
if (n[-1:] != tot[i]):
n = np.append(n, tot[i])
rarecurve[res_.index[i]] = [rarefun(np.array(res_.iloc[i,]), z) for z in n]
y = pd.DataFrame([rarecurve[c] for c in rarecurve]).T
pred = pd.DataFrame([np.append(np.arange(1, s, 10),s) for s in res_.sum(axis = 1)], index = res_.index).T
y = y.melt()
pred = pred.melt()
pred['yhat'] = y['value']
options.figure_size = figsize
if palette is None:
if self.__class__ == AnnData:
try:
pal = self.uns[str(groupby)+'_colors']
except:
if len(list(set((pred.variable)))) <= 20:
pal = palettes.default_20
elif len(list(set((pred.variable)))) <= 28:
pal = palettes.default_28
elif len(list(set((pred.variable)))) <= 102:
pal = palettes.default_102
else:
pal = None
if pal is not None:
p = (ggplot(pred, aes(x = "value", y = "yhat", color = "variable"))
+ theme_classic()
+ xlab('number of cells')
+ ylab('number of clones')
+ ggtitle('rarefaction curve')
+ labs(color = groupby)
+ scale_color_manual(values=(pal))
+ geom_line())
else:
p = (ggplot(pred, aes(x = "value", y = "yhat", color = "variable"))
+ theme_classic()
+ xlab('number of cells')
+ ylab('number of clones')
+ ggtitle('rarefaction curve')
+ labs(color = groupby)
+ geom_line())
else:
if len(list(set((pred.variable)))) <= 20:
pal = palettes.default_20
elif len(list(set((pred.variable)))) <= 28:
pal = palettes.default_28
elif len(list(set((pred.variable)))) <= 102:
pal = palettes.default_102
else:
pal = None
if pal is not None:
p = (ggplot(pred, aes(x = "value", y = "yhat", color = "variable"))
+ theme_classic()
+ xlab('number of cells')
+ ylab('number of clones')
+ ggtitle('rarefaction curve')
+ labs(color = groupby)
+ scale_color_manual(values=(pal))
+ geom_line())
else:
p = (ggplot(pred, aes(x = "value", y = "yhat", color = "variable"))
+ theme_classic()
+ xlab('number of cells')
+ ylab('number of clones')
+ ggtitle('rarefaction curve')
+ labs(color = groupby)
+ geom_line())
else:
p = (ggplot(pred, aes(x = "value", y = "yhat", color = "variable"))
+ theme_classic()
+ xlab('number of cells')
+ ylab('number of clones')
+ ggtitle('rarefaction curve')
+ labs(color = groupby)
+ geom_line())
if save:
p.save(filename = 'figures/rarefaction'+str(save), height= plt.rcParams['figure.figsize'][0], width= plt.rcParams['figure.figsize'][1], units = 'in', dpi= plt.rcParams["savefig.dpi"])
return(p)
def random_palette(n):
# a list of 900+colours
cols = list(sns.xkcd_rgb.keys())
# if max_colors_needed1 > len(cols):
cols2 = list(sns.color_palette('husl', n))
palette = random.sample(sns.xkcd_palette(cols) + cols2, n)
return(palette)
[docs]def clone_network(adata, basis = 'bcr', edges = True, **kwargs):
"""
Using scanpy's plotting module to plot the network. Only thing that is changed is the dfault options: `basis = 'bcr'` and `edges = True`.
Parameters
----------
adata : AnnData
AnnData object.
basis : str
key for embedding. Default is 'bcr'.
edges : bool
whether or not to plot edges. Default is True.
**kwargs
passed `sc.pl.embedding`.
"""
embedding(adata, basis = basis, edges = edges, **kwargs)
[docs]def barplot(self, variable, palette = 'Set1', figsize = (12, 4), normalize = True, sort_descending = True, title = None, xtick_rotation = None, min_clone_size = None, clone_key = None, **kwargs):
"""
A barplot function to plot usage of V/J genes in the data.
Parameters
----------
self : Dandelion, AnnData
`Dandelion` or `AnnData` object.
variable : str
column name in metadata for plotting in bar plot.
palette : str
Colors to use for the different levels of the variable. Should be something that can be interpreted by [color_palette](https://seaborn.pydata.org/generated/seaborn.color_palette.html#seaborn.color_palette), or a dictionary mapping hue levels to matplotlib colors. See [seaborn.barplot](https://seaborn.pydata.org/generated/seaborn.barplot.html).
figsize : tuple[float, float]
figure size. Default is (12, 4).
normalize : bool
if True, will return as proportion out of 1, otherwise False will return counts. Default is True.
sort_descending : bool
whether or not to sort the order of the plot. Default is True.
title : str, optional
title of plot.
xtick_rotation : int, optional
rotation of x tick labels.
min_clone_size : int, optional
minimum clone size to keep. Defaults to 1 if left as None.
clone_key : str, optional
column name for clones. None defaults to 'clone_id'.
**kwargs
passed to `sns.barplot`.
Returns
-------
a seaborn barplot.
"""
if self.__class__ == Dandelion:
data = self.metadata.copy()
elif self.__class__ == AnnData:
data = self.obs.copy()
if min_clone_size is None:
min_size = 1
else:
min_size = int(min_clone_size)
if clone_key is None:
clone_ = 'clone_id'
else:
clone_ = clone_key
size = data[clone_].value_counts()
keep = list(size[size >= min_size].index)
data_ = data[data[clone_].isin(keep)]
sns.set_style('whitegrid', {'axes.grid' : False})
res = pd.DataFrame(data_[variable].value_counts(normalize=normalize))
if not sort_descending:
res = res.sort_index()
res.reset_index(drop = False, inplace = True)
# Initialize the matplotlib figure
fig, ax = plt.subplots(figsize=figsize)
# plot
sns.barplot(x='index', y = variable, data=res, palette = palette, **kwargs)
# change some parts
if title is None:
ax.set_title(variable.replace('_', ' ')+' usage')
else:
ax.set_title(title)
if normalize:
ax.set_ylabel('proportion')
else:
ax.set_ylabel('count')
ax.set_xlabel('')
if xtick_rotation is None:
plt.xticks(rotation=90)
else:
plt.xticks(rotation=xtick_rotation)
return fig, ax
[docs]def stackedbarplot(self, variable, groupby, figsize = (12, 4), normalize = False, title = None, sort_descending=True, xtick_rotation=None, hide_legend=True, legend_options = None, labels=None, min_clone_size = None, clone_key = None, **kwargs):
"""
A stackedbarplot function to plot usage of V/J genes in the data split by groups.
Parameters
----------
self : Dandelion, AnnData
`Dandelion` or `AnnData` object.
variable : str
column name in metadata for plotting in bar plot.
groupby : str
column name in metadata to split by during plotting.
figsize : tuple[float, float]
figure size. Default is (12, 4).
normalize : bool
if True, will return as proportion out of 1, otherwise False will return counts. Default is True.
sort_descending : bool
whether or not to sort the order of the plot. Default is True.
title : str, optional
title of plot.
xtick_rotation : int, optional
rotation of x tick labels.
hide_legend : bool
whether or not to hide the legend.
legend_options : tuple[str, tuple[float, float], int]
a tuple holding 3 options for specify legend options: 1) loc (string), 2) bbox_to_anchor (tuple), 3) ncol (int).
labels : list
Names of objects will be used for the legend if list of multiple dataframes supplied.
min_clone_size : int, optional
minimum clone size to keep. Defaults to 1 if left as None.
clone_key : str, optional
column name for clones. None defaults to 'clone_id'.
**kwargs
other kwargs passed to `matplotlib.plt`.
Returns
-------
stacked bar plot.
"""
if self.__class__ == Dandelion:
data = self.metadata.copy()
elif self.__class__ == AnnData:
data = self.obs.copy()
data[groupby] = [str(l) for l in data[groupby]] # quick fix to prevent dropping of nan
if min_clone_size is None:
min_size = 1
else:
min_size = int(min_clone_size)
if clone_key is None:
clone_ = 'clone_id'
else:
clone_ = clone_key
size = data[clone_].value_counts()
keep = list(size[size >= min_size].index)
data_ = data[data[clone_].isin(keep)]
dat_ = pd.DataFrame(data_.groupby(variable)[groupby].value_counts(normalize=normalize).unstack(fill_value=0).stack(), columns = ['value'])
dat_.reset_index(drop = False, inplace = True)
dat_order = pd.DataFrame(data[variable].value_counts(normalize=normalize))
dat_ = dat_.pivot(index=variable, columns=groupby, values='value')
if sort_descending is True:
dat_ = dat_.reindex(dat_order.index)
elif sort_descending is False:
dat_ = dat_.reindex(dat_order.index[::-1])
elif sort_descending is None:
dat_ = dat_.sort_index()
def _plot_bar_stacked(dfall, labels=None, figsize = (12, 4), title="multiple stacked bar plot", xtick_rotation=None, legend_options = None, hide_legend=True, H="/", **kwargs):
"""
Given a list of dataframes, with identical columns and index, create a clustered stacked bar plot.
Parameters
----------
labels
a list of the dataframe objects. Names of objects will be used for the legend.
title
string for the title of the plot
H
is the hatch used for identification of the different dataframes
**kwargs
other kwargs passed to matplotlib.plt
"""
if type(dfall) is not list:
dfall = [dfall]
n_df = len(dfall)
n_col = len(dfall[0].columns)
n_ind = len(dfall[0].index)
# Initialize the matplotlib figure
fig, ax = plt.subplots(figsize=figsize)
for df in dfall : # for each data frame
ax = df.plot(kind="bar",
linewidth=0,
stacked=True,
ax=ax,
legend=False,
grid=False,
**kwargs) # make bar plots
h,l = ax.get_legend_handles_labels() # get the handles we want to modify
for i in range(0, n_df * n_col, n_col): # len(h) = n_col * n_df
for j, pa in enumerate(h[i:i+n_col]):
for rect in pa.patches: # for each index
rect.set_x(rect.get_x() + 1 / float(n_df + 1) * i / float(n_col))
rect.set_hatch(H * int(i / n_col)) #edited part
rect.set_width(1 / float(n_df + 1))
ax.set_xticks((np.arange(0, 2 * n_ind, 2) + 1 / float(n_df + 1)) / 2.)
ax.set_xticklabels(df.index, rotation = 0)
ax.set_title(title)
if normalize:
ax.set_ylabel('proportion')
else:
ax.set_ylabel('count')
# Add invisible data to add another legend
n=[]
for i in range(n_df):
n.append(ax.bar(0, 0, color="grey", hatch=H * i))
if legend_options is None:
Legend = ('center right', (1.15, 0.5), 1)
else:
Legend = legend_options
if hide_legend is False:
l1 = ax.legend(h[:n_col], l[:n_col], loc=Legend[0], bbox_to_anchor=Legend[1], ncol = Legend[2], frameon=False)
if labels is not None:
l2 = plt.legend(n, labels, loc=Legend[0], bbox_to_anchor=Legend[1], ncol = Legend[2], frameon=False)
ax.add_artist(l1)
if xtick_rotation is None:
plt.xticks(rotation=90)
else:
plt.xticks(rotation=xtick_rotation)
return fig, ax
if title is None:
title = "multiple stacked bar plot : " + variable.replace('_', ' ') +' usage'
else:
title = title
return _plot_bar_stacked(dat_, labels = labels, figsize = figsize, title = title, xtick_rotation = xtick_rotation, legend_options = legend_options, hide_legend = hide_legend, **kwargs)
[docs]def spectratype(self, variable, groupby, locus, clone_key = None, figsize = (6, 4), width = None, title = None, xtick_rotation=None, hide_legend=True, legend_options = None, labels=None, **kwargs):
"""
A stackedbarplot function to plot usage of V/J genes in the data split by groups.
Parameters
----------
self : Dandelion, AnnData
`Dandelion` or `AnnData` object.
variable : str
column name in metadata for plotting in bar plot.
groupby : str
column name in metadata to split by during plotting.
locus : str
either IGH or IGL.
figsize : tuple[float, float]
figure size. Default is (6, 4).
width : float, optional
width of bars.
title : str, optional
title of plot.
xtick_rotation : int, optional
rotation of x tick labels.
hide_legend : bool
whether or not to hide the legend.
legend_options : tuple[str, tuple[float, float], int]
a tuple holding 3 options for specify legend options: 1) loc (string), 2) bbox_to_anchor (tuple), 3) ncol (int).
labels : list
Names of objects will be used for the legend if list of multiple dataframes supplied.
**kwargs
other kwargs passed to matplotlib.pyplot.plot
Returns
-------
sectratype plot
"""
if clone_key is None:
clonekey = 'clone_id'
else:
clonekey = clone_key
if self.__class__ == Dandelion:
data = self.data.copy()
else:
try:
data = self.copy()
except:
AttributeError("Please provide a <class 'Dandelion'> class object or a pandas dataframe instead of %s." % self.__class__)
if 'locus' not in data.columns:
raise AttributeError("Please ensure dataframe contains 'locus' column")
if type(locus) is not list:
locus = [locus]
data = data[data['locus'].isin(locus)]
data[groupby] = [str(l) for l in data[groupby]]
dat_ = pd.DataFrame(data.groupby(variable)[groupby].value_counts(normalize=False).unstack(fill_value=0).stack(), columns = ['value'])
dat_.reset_index(drop = False, inplace = True)
dat_[variable] = pd.to_numeric(dat_[variable], errors='coerce')
dat_.sort_values(by = variable)
dat_2 = dat_.pivot(index=variable, columns=groupby, values='value')
new_index = range(0, int(dat_[variable].max())+1)
dat_2 = dat_2.reindex(new_index, fill_value=0)
def _plot_spectra_stacked(dfall, labels=None, figsize = (6, 4), title="multiple stacked bar plot", width = None, xtick_rotation=None, legend_options = None, hide_legend=True, H="/", **kwargs):
if type(dfall) is not list:
dfall = [dfall]
n_df = len(dfall)
n_col = len(dfall[0].columns)
n_ind = len(dfall[0].index)
if width is None:
wdth = 0.1 * n_ind/60+0.8
else:
wdth = width
# Initialize the matplotlib figure
fig, ax = plt.subplots(figsize=figsize)
for df in dfall : # for each data frame
ax = df.plot(kind="bar",
linewidth=0,
stacked=True,
ax=ax,
legend=False,
grid=False,
**kwargs) # make bar plots
h,l = ax.get_legend_handles_labels() # get the handles we want to modify
for i in range(0, n_df * n_col, n_col): # len(h) = n_col * n_df
for j, pa in enumerate(h[i:i+n_col]):
for rect in pa.patches: # for each index
rect.set_x(rect.get_x() + 1 / float(n_df + 1) * i / float(n_col))
rect.set_hatch(H * int(i / n_col)) #edited part
rect.set_width(wdth) # need to see if there's a better way to toggle this.
n = 5 # Keeps every 5th label visible and hides the rest
[l.set_visible(False) for (i,l) in enumerate(ax.xaxis.get_ticklabels()) if i % n != 0]
ax.set_title(title)
ax.set_ylabel('count')
# Add invisible data to add another legend
n=[]
for i in range(n_df):
n.append(ax.bar(0, 0, color="gray", hatch=H * i))
if legend_options is None:
Legend = ('center right', (1.25, 0.5), 1)
else:
Legend = legend_options
if hide_legend is False:
l1 = ax.legend(h[:n_col], l[:n_col], loc=Legend[0], bbox_to_anchor=Legend[1], ncol = Legend[2], frameon=False)
if labels is not None:
l2 = plt.legend(n, labels, loc=Legend[0], bbox_to_anchor=Legend[1], ncol = Legend[2], frameon=False)
ax.add_artist(l1)
if xtick_rotation is None:
plt.xticks(rotation=0)
else:
plt.xticks(rotation=xtick_rotation)
return fig, ax
return _plot_spectra_stacked(dat_2, labels = labels, figsize = figsize, title = title, width = width, xtick_rotation = xtick_rotation, legend_options = legend_options, hide_legend =hide_legend, **kwargs)
[docs]def clone_overlap(self, groupby, colorby, min_clone_size = None, clone_key = None, color_mapping = None, node_labels = True, node_label_layout = 'rotation', group_label_position = 'middle', group_label_offset = 8, figsize = (8, 8), return_graph = False, save = None, **kwargs):
"""
A plot function to visualise clonal overlap as a circos-style plot.
Parameters
----------
self : Dandelion, AnnData
`Dandelion` or `AnnData` object.
groupby : str
column name in obs/metadata for collapsing to nodes in circos plot.
colorby : str
column name in obs/metadata for grouping and color of nodes in circos plot.
min_clone_size : int, optional
minimum size of clone for plotting connections. Defaults to 2 if left as None.
clone_key : str, optional
column name for clones. None defaults to 'clone_id'.
color_maopping : dict, sequence, optional
custom color mapping provided as a sequence (correpsonding to order of categories or alpha-numeric order if dtype is not category), or dictionary containing custom {category:color} mapping.
node_labels : bool, optional
whether to use node objects as labels or not
node_label_layout : bool, optional
which/whether (a) node layout is used. One of 'rotation', 'numbers' or None.
group_label_position : str
The position of the group label. One of 'beginning', 'middle' or 'end'.
group_label_offset : int, float
how much to offset the group labels, so that they are not overlapping with node labels.
figsize : tuple[float, float]
figure size. Default is (8, 8).
return_graph : bool
whether or not to return the graph for fine tuning. Default is False.
**kwargs
passed to `matplotlib.pyplot.savefig`.
Returns
-------
a `nxviz.CircosPlot`.
"""
import networkx as nx
try:
import nxviz as nxv
except:
raise(ImportError("Unable to import module `nxviz`. Have you done install nxviz? Try pip install git+https://github.com/zktuong/nxviz.git"))
if min_clone_size is None:
min_size = 2
else:
min_size = int(min_clone_size)
if clone_key is None:
clone_ = 'clone_id'
else:
clone_ = clone_key
if self.__class__ == AnnData:
data = self.obs.copy()
# get rid of problematic rows that appear because of category conversion?
data = data[~(data[clone_].isin([np.nan, 'nan', 'NaN', None]))]
if 'clone_overlap' in self.uns:
overlap = self.uns['clone_overlap'].copy()
else:
# prepare a summary table
datc_ = data[clone_].str.split('|', expand = True).stack()
datc_ = pd.DataFrame(datc_)
datc_.reset_index(drop = False, inplace = True)
datc_.columns = ['cell_id', 'tmp', clone_]
datc_.drop('tmp', inplace = True, axis = 1)
dictg_ = dict(data[groupby])
datc_[groupby] = [dictg_[l] for l in datc_['cell_id']]
overlap = pd.crosstab(data[clone_], data[groupby])
if min_size == 0:
raise ValueError('min_size must be greater than 0.')
elif min_size > 2:
overlap[overlap < min_size] = 0
overlap[overlap >= min_size] = 1
elif min_size == 2:
overlap[overlap >= min_size] = 1
overlap.index.name = None
overlap.columns.name = None
elif self.__class__ == Dandelion:
data = self.metadata.copy()
# get rid of problematic rows that appear because of category conversion?
data = data[~(data[clone_].isin([np.nan, 'nan', 'NaN', None]))]
# prepare a summary table
datc_ = data[clone_].str.split('|', expand = True).stack()
datc_ = pd.DataFrame(datc_)
datc_.reset_index(drop = False, inplace = True)
datc_.columns = ['cell_id', 'tmp', clone_]
datc_.drop('tmp', inplace = True, axis = 1)
dictg_ = dict(data[groupby])
datc_[groupby] = [dictg_[l] for l in datc_['cell_id']]
overlap = pd.crosstab(data[clone_], data[groupby])
if min_size == 0:
raise ValueError('min_size must be greater than 0.')
elif min_size > 2:
overlap[overlap < min_size] = 0
overlap[overlap >= min_size] = 1
elif min_size == 2:
overlap[overlap >= min_size] = 1
overlap.index.name = None
overlap.columns.name = None
edges = {}
for x in overlap.index:
if overlap.loc[x].sum() > 1:
edges[x] = [y + ({str(clone_):x},) for y in list(combinations([i for i in overlap.loc[x][overlap.loc[x] == 1].index], 2))]
# create graph
G = nx.Graph()
# add in the nodes
G.add_nodes_from([(p, {str(colorby): d}) for p,d in zip(data[groupby], data[colorby])])
# unpack the edgelist and add to the graph
for edge in edges:
G.add_edges_from(edges[edge])
groupby_dict = dict(zip(data[groupby], data[colorby]))
if color_mapping is None:
if self.__class__ == AnnData:
if pd.api.types.is_categorical_dtype(self.obs[groupby]):
try:
colorby_dict = dict(zip(list(self.obs[str(colorby)].cat.categories), self.uns[str(colorby)+'_colors']))
except:
pass
else:
if type(color_mapping) is dict:
colorby_dict = color_mapping
else:
if pd.api.types.is_categorical_dtype(data[groupby]):
colorby_dict = dict(zip(list(data[str(colorby)].cat.categories), color_mapping))
else:
colorby_dict = dict(zip(sorted(list(set(data[str(colorby)]))), color_mapping))
df = data[[groupby, colorby]]
if groupby == colorby:
df = data[[groupby]]
df = df.sort_values(groupby).drop_duplicates(subset=groupby, keep="first").reset_index(drop = True)
else:
df = df.sort_values(colorby).drop_duplicates(subset=groupby, keep="first").reset_index(drop = True)
c = nxv.CircosPlot(G,
node_color=colorby,
node_grouping=colorby,
node_labels=node_labels,
node_label_layout=node_label_layout,
group_label_position=group_label_position,
group_label_offset=group_label_offset,
figsize=figsize)
c.nodes = list(df[groupby])
if 'colorby_dict' in locals():
c.node_colors = [colorby_dict[groupby_dict[c]] for c in c.nodes]
c.compute_group_label_positions()
c.compute_group_colors()
c.draw()
if save is not None:
plt.savefig(save, bbox_inches = 'tight', **kwargs)
if return_graph:
return(c)