# Source code for statsmodels.graphics.correlation

'''correlation plots

Author: Josef Perktold
License: BSD-3

example for usage with different options in
statsmodels/sandbox/examples/thirdparty/ex_ratereturn.py

'''
import numpy as np

from . import utils


def plot_corr(dcorr, xnames=None, ynames=None, title=None, normcolor=False,
              ax=None, cmap='RdYlBu_r'):
    """Plot correlation of many variables in a tight color grid.

    Parameters
    ----------
    dcorr : ndarray
        Correlation matrix, square 2-D array.
    xnames : list of str, optional
        Labels for the horizontal axis.  If not given (None), then the
        matplotlib defaults (integers) are used.  If it is an empty list, [],
        then no ticks and labels are added.
    ynames : list of str, optional
        Labels for the vertical axis.  Works the same way as `xnames`.
        If not given, the same names as for `xnames` are re-used.
    title : str, optional
        The figure title. If None, the default ('Correlation Matrix') is used.
        If ``title=''``, then no title is added.
    normcolor : bool or tuple of scalars, optional
        If False (default), then the color coding range corresponds to the
        range of `dcorr`.  If True, then the color range is normalized to
        (-1, 1).  If this is a tuple of two numbers, then they define the range
        for the color bar.
    ax : Matplotlib AxesSubplot instance, optional
        If `ax` is None, then a figure is created.  If an axis instance is
        given, then only the main plot but not the colorbar is created.
    cmap : str or Matplotlib Colormap instance, optional
        The colormap for the plot.  Can be any valid Matplotlib Colormap
        instance or name.

    Returns
    -------
    fig : Matplotlib figure instance
        If `ax` is None, the created figure.  Otherwise the figure to which
        `ax` is connected.

    Examples
    --------
    >>> import numpy as np
    >>> import matplotlib.pyplot as plt
    >>> import statsmodels.api as sm
    >>> import statsmodels.graphics.api as smg

    >>> hie_data = sm.datasets.randhie.load_pandas()
    >>> corr_matrix = np.corrcoef(hie_data.data.T)
    >>> smg.plot_corr(corr_matrix, xnames=hie_data.names)
    >>> plt.show()
    """
    # A colorbar is only added when this function owns the whole figure;
    # with a caller-supplied axis only the main plot is drawn.
    create_colorbar = ax is None

    fig, ax = utils.create_mpl_ax(ax)

    nvars = dcorr.shape[0]

    if ynames is None:
        ynames = xnames
    if title is None:
        title = 'Correlation Matrix'
    if isinstance(normcolor, tuple):
        vmin, vmax = normcolor
    elif normcolor:
        vmin, vmax = -1.0, 1.0
    else:
        # Let imshow pick the color range from the data.
        vmin, vmax = None, None

    axim = ax.imshow(dcorr, cmap=cmap, interpolation='nearest',
                     extent=(0, nvars, 0, nvars), vmin=vmin, vmax=vmax)

    # Major ticks sit at the cell centers (where the labels go); minor ticks
    # sit on the cell boundaries and are only used to draw the white grid.
    label_pos = np.arange(0, nvars) + 0.5

    # The empty-sequence case has to be tested before the generic "names were
    # given" case, otherwise it can never be reached and nvars ticks would be
    # paired with zero labels.
    if ynames is not None:
        if len(ynames) == 0:
            ax.set_yticks([])
        else:
            ax.set_yticks(label_pos)
            ax.set_yticks(label_pos[:-1] + 0.5, minor=True)
            # Reversed because, with the extent used above, row 0 of the
            # matrix is drawn at the top while y increases upwards.
            ax.set_yticklabels(ynames[::-1], fontsize='small',
                               horizontalalignment='right')

    if xnames is not None:
        if len(xnames) == 0:
            ax.set_xticks([])
        else:
            ax.set_xticks(label_pos)
            ax.set_xticks(label_pos[:-1] + 0.5, minor=True)
            ax.set_xticklabels(xnames, fontsize='small', rotation=45,
                               horizontalalignment='right')

    if title != '':
        ax.set_title(title)

    if create_colorbar:
        fig.colorbar(axim, use_gridspec=True)
    fig.tight_layout()

    # Hide the minor tick marks themselves; only their grid lines are wanted.
    ax.tick_params(which='minor', length=0)
    ax.tick_params(direction='out', top=False, right=False)
    try:
        ax.grid(True, which='minor', linestyle='-', color='w', lw=1)
    except AttributeError:
        # Seems to fail for axes created with AxesGrid.  MPL bug?
        # The grid is purely cosmetic, so the plot is returned without it.
        pass

    return fig
def plot_corr_grid(dcorrs, titles=None, ncols=None, normcolor=False, xnames=None,
                   ynames=None, fig=None, cmap='RdYlBu_r'):
    """Create a grid of correlation plots.

    The individual correlation plots are assumed to all have the same
    variables, axis labels can be specified only once.

    Parameters
    ----------
    dcorrs : list or iterable of ndarrays
        List of correlation matrices.  Must contain at least one matrix.
    titles : list of str, optional
        List of titles for the subplots.  By default no titles are shown.
    ncols : int, optional
        Number of columns in the subplot grid.  If not given, the number of
        columns is determined automatically.
    normcolor : bool or tuple, optional
        If False (default), then the color coding range corresponds to the
        range of `dcorr`.  If True, then the color range is normalized to
        (-1, 1).  If this is a tuple of two numbers, then they define the range
        for the color bar.
    xnames : list of str, optional
        Labels for the horizontal axis.  If not given (None), then the
        matplotlib defaults (integers) are used.  If it is an empty list, [],
        then no ticks and labels are added.
    ynames : list of str, optional
        Labels for the vertical axis.  Works the same way as `xnames`.
        If not given, the same names as for `xnames` are re-used.
    fig : Matplotlib figure instance, optional
        Figure to draw the grid into.  If not given, a new figure is created.
    cmap : str or Matplotlib Colormap instance, optional
        The colormap for the plot.  Can be any valid Matplotlib Colormap
        instance or name.

    Returns
    -------
    fig : Matplotlib figure instance
        The figure holding the grid of correlation plots and the colorbar.

    Raises
    ------
    ValueError
        If `dcorrs` is empty.

    Notes
    -----
    The single colorbar is built from the image of the first subplot.  With
    ``normcolor=False`` every subplot is scaled to its own data range, so the
    colorbar is only exact for the first plot; pass ``normcolor=True`` or a
    ``(vmin, vmax)`` tuple to get one color scale shared by all subplots.

    Examples
    --------
    >>> import numpy as np
    >>> import matplotlib.pyplot as plt
    >>> import statsmodels.api as sm

    In this example we just reuse the same correlation matrix several times.
    Of course in reality one would show a different correlation (measuring
    another type of correlation, for example Pearson (linear) and Spearman,
    Kendall (nonlinear) correlations) for the same variables.

    >>> hie_data = sm.datasets.randhie.load_pandas()
    >>> corr_matrix = np.corrcoef(hie_data.data.T)
    >>> sm.graphics.plot_corr_grid([corr_matrix] * 8, xnames=hie_data.names)
    >>> plt.show()
    """
    # Materialize the input so that any iterable (not only sequences) can be
    # counted below and then iterated.
    dcorrs = list(dcorrs)
    n_plots = len(dcorrs)
    if n_plots == 0:
        raise ValueError('dcorrs must contain at least one correlation matrix')

    if ynames is None:
        ynames = xnames

    if not titles:
        titles = [''] * n_plots

    if ncols is not None:
        nrows = int(np.ceil(n_plots / float(ncols)))
    else:
        # Determine number of rows and columns, square if possible, otherwise
        # prefer a wide (more columns) over a high layout.
        if n_plots < 4:
            nrows, ncols = 1, n_plots
        else:
            nrows = int(np.sqrt(n_plots))
            ncols = int(np.ceil(n_plots / float(nrows)))

    # Create a figure with the correct size; the aspect ratio is capped so
    # that very wide grids do not produce an extremely wide figure.
    aspect = min(ncols / float(nrows), 1.8)
    vsize = np.sqrt(nrows) * 5
    fig = utils.create_mpl_fig(fig, figsize=(vsize * aspect + 1, vsize))

    first_ax = None
    for i, c in enumerate(dcorrs):
        ax = fig.add_subplot(nrows, ncols, i + 1)
        if first_ax is None:
            first_ax = ax
        # Ensure to only plot labels on bottom row and left column.
        # ``i % ncols == 0`` marks the left column and, unlike
        # ``(i + 1) % ncols == 1``, also holds for a single-column grid.
        _xnames = xnames if nrows * ncols - (i + 1) < ncols else []
        _ynames = ynames if i % ncols == 0 else []
        plot_corr(c, xnames=_xnames, ynames=_ynames, title=titles[i],
                  normcolor=normcolor, ax=ax, cmap=cmap)

    # Adjust figure margins and add a colorbar in a dedicated axes on the
    # right.  The first subplot created here is used rather than
    # ``fig.axes[0]``, which may be an unrelated axes when the caller passed
    # in a figure that already had content.
    fig.subplots_adjust(bottom=0.1, left=0.09, right=0.9, top=0.9)
    cax = fig.add_axes([0.92, 0.1, 0.025, 0.8])
    fig.colorbar(first_ax.images[0], cax=cax)

    return fig