Source code for mass.visualization.phase_portraits

# -*- coding: utf-8 -*-
r"""Contains function for visualizing phase portraits of simulation results.

See the :mod:`mass.visualization` documentation for general information
on :mod:`mass.visualization` functions.

This module contains the following functions for visualization of
time-dependent solutions returned in :class:`~.MassSolution`\ s after
simulation of models.

    * :func:`~.phase_portraits.plot_phase_portrait`
    * :func:`~.phase_portraits.plot_ensemble_phase_portrait`
    * :func:`~.phase_portraits.plot_tiled_phase_portraits`

"""
from warnings import warn

import numpy as np
from six import iteritems, iterkeys, itervalues

from mass.util.util import _check_kwargs
from mass.visualization import visualization_util as v_util


def plot_phase_portrait(mass_solution, x, y, ax=None, legend=None, **kwargs):
    """Plot a phase portrait of two solutions in a :class:`~.MassSolution`.

    Any accepted ``kwargs`` are forwarded to the :mod:`matplotlib` methods
    used internally. Refer to the :mod:`~mass.visualization` module
    documentation for details on each of the possible ``kwargs``.

    Notes
    -----
    * The original :class:`~.MassSolution` is left untouched; a copy of
      it is created and used for plotting.

    Parameters
    ----------
    mass_solution : MassSolution
        The :class:`~.MassSolution` holding the time-dependent solutions
        to plot.
    x : :mod:`mass` object or its string identifier
        A :mod:`mass` object, or its string identifier, that is the key of
        the solution in the :class:`~.MassSolution` drawn on the x-axis of
        the phase portrait.
    y : :mod:`mass` object or its string identifier
        A :mod:`mass` object, or its string identifier, that is the key of
        the solution in the :class:`~.MassSolution` drawn on the y-axis of
        the phase portrait.
    ax : matplotlib.axes.Axes, None
        The :class:`~matplotlib.axes.Axes` instance to draw on. When
        ``None``, the current axes instance is used.
    legend : iterable, str, int
        The legend may be given in one of three formats:

            1. An iterable of legend labels given as strings.
            2. A ``str`` naming the legend location, or an ``int`` from
               0 to 14 (inclusive) identifying the legend location.
            3. An iterable of the form ``(labels, loc)`` that sets both the
               labels and the location, where ``labels`` and ``loc`` follow
               formats **1** and **2** respectively.

        See the :mod:`~mass.visualization` documentation for more on
        legends and valid legend locations.
    **kwargs
        * time_vector
        * plot_function
        * title
        * xlabel
        * ylabel
        * xlim
        * ylim
        * grid
        * grid_color
        * grid_linestyle
        * grid_linewidth
        * prop_cycle
        * color
        * linestyle
        * linewidth
        * marker
        * markersize
        * legend_ncol
        * annotate_time_points
        * annotate_time_points_color
        * annotate_time_points_marker
        * annotate_time_points_markersize
        * annotate_time_points_labels
        * annotate_time_points_legend
        * deviation
        * deviation_zero_centered
        * deviation_normalization

        See the :mod:`~mass.visualization` documentation for more on the
        optional ``kwargs``.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The :class:`~matplotlib.axes.Axes` instance holding the new plot.

    """
    # matplotlib must be available before anything else happens.
    v_util._validate_visualization_packages("matplotlib")
    # Merge the user kwargs with this function's defaults and validate them.
    defaults = get_phase_portrait_default_kwargs("plot_phase_portrait")
    kwargs = _check_kwargs(defaults, kwargs)
    # Resolve which axes instance to draw on.
    ax = v_util._validate_axes_instance(ax)

    # An empty MassSolution means there is nothing to draw.
    mass_solution = v_util._validate_mass_solution(mass_solution)
    if not mass_solution:
        return ax

    # Pull out the x/y solutions; the time vector is validated here too.
    xy = v_util._validate_plot_observables(mass_solution, (x, y), **kwargs)
    # Look up the requested plotting method on the axes (raises if invalid).
    plot_function = v_util._get_plotting_function(
        ax,
        plot_function_str=kwargs.get("plot_function"),
        valid={"plot", "semilogx", "semilogy", "loglog"},
    )

    # One line, labelled "<x> vs. <y>", made of the (x, y) solution pair.
    line_label = "{0} vs. {1}".format(*v_util._group_xy_items(xy, iterkeys))
    line_solutions = v_util._group_xy_items(xy, itervalues)
    observable = {line_label: line_solutions}

    show_legend = legend is not None
    legend_kwargs = None
    if show_legend:
        legend_labels, legend_kwargs = v_util._get_legend_args(
            ax, legend, observable, **kwargs
        )
        # Swap the default label for the user's legend label(s).
        observable = v_util._map_labels_to_solutions(observable, legend_labels)

    # Build a property cycler that accounts for lines already on the axes.
    n_current = len(v_util._get_ax_current(ax))
    prop_cycler = v_util._get_line_property_cycler(
        n_current=n_current, n_new=len(observable), **kwargs
    )
    if prop_cycler:
        ax.set_prop_cycle(prop_cycler)

    zorder = kwargs.get("zorder")
    for name, solutions in iteritems(observable):
        plot_function(*solutions, label=name, zorder=zorder)

    # Axis labels, limits, margins and gridlines, in that order.
    for set_axes_option in (
        v_util._set_axes_labels,
        v_util._set_axes_limits,
        v_util._set_axes_margins,
        v_util._set_axes_gridlines,
    ):
        set_axes_option(ax, **kwargs)

    legend_artist = None
    if show_legend:
        handles_and_labels = v_util._get_handles_and_labels(ax, False)
        legend_artist = ax.legend(*handles_and_labels, **legend_kwargs)

    if kwargs.get("annotate_time_points", None) is not None:
        ax = v_util._set_annotated_time_points(
            ax,
            observable=xy,
            type_of_plot="phase_portrait",
            first_legend=(legend_artist, legend_kwargs),
            time_range=xy.time,
            **kwargs
        )

    # Restore the default property cycle for subsequent plots on this axes.
    ax.set_prop_cycle(v_util._get_default_cycler())
    return ax
def plot_ensemble_phase_portrait(
    mass_solution_list, x, y, ax=None, legend=None, **kwargs
):
    """Plot a phase portrait for an ensemble of :class:`~.MassSolution` objects.

    The plotted line represents the mean, across the ensemble, of the
    solutions specified by ``x`` and ``y``.

    Accepted ``kwargs`` are passed onto various :mod:`matplotlib` methods
    utilized in the function. See the :mod:`~mass.visualization` module
    documentation for more detailed information about the possible
    ``kwargs``.

    Notes
    -----
    * To prevent any changes to the original :class:`~.MassSolution`,
      copies of :class:`~.MassSolution` objects are created and used.

    Parameters
    ----------
    mass_solution_list : iterable
        An iterable of :class:`~.MassSolution` objects containing the
        time-dependent solutions to be plotted.
    x : :mod:`mass` object or its string identifier
        The string identifier of a :mod:`mass` object or the object itself
        that corresponds to the key for the desired solution in the
        :class:`~.MassSolution` for the x-axis of the phase portrait.
    y : :mod:`mass` object or its string identifier
        The string identifier of a :mod:`mass` object or the object itself
        that corresponds to the key for the desired solution in the
        :class:`~.MassSolution` for the y-axis of the phase portrait.
    ax : matplotlib.axes.Axes, None
        An :class:`~matplotlib.axes.Axes` instance to plot the data on.
        If ``None`` then the current axes instance is used.
    legend : iterable, str, int
        There are three possible input formats for the legend:

            1. An iterable of legend labels as strings.
            2. A ``str`` representing the location of the legend, or an
               ``int`` between 0 and 14 (inclusive) corresponding to the
               legend location.
            3. An iterable of the format ``(labels, loc)`` to set both
               the legend labels and location, where ``labels`` and
               ``loc`` follow the formats specified in **1** and **2**.

        See the :mod:`~mass.visualization` documentation for more
        information about legend and valid legend locations.
    **kwargs
        * time_vector
        * plot_function
        * title
        * xlabel
        * ylabel
        * xlim
        * ylim
        * grid
        * grid_color
        * grid_linestyle
        * grid_linewidth
        * prop_cycle
        * color
        * linestyle
        * linewidth
        * marker
        * markersize
        * legend_ncol
        * annotate_time_points
        * annotate_time_points_color
        * annotate_time_points_marker
        * annotate_time_points_markersize
        * annotate_time_points_labels
        * annotate_time_points_legend
        * deviation
        * deviation_zero_centered
        * deviation_normalization

        See :mod:`~mass.visualization` documentation for more information
        on optional ``kwargs``.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The :class:`~matplotlib.axes.Axes` instance containing the newly
        created plot.

    """
    # Validate whether necessary packages are installed.
    v_util._validate_visualization_packages("matplotlib")
    # Check kwargs
    kwargs = _check_kwargs(
        get_phase_portrait_default_kwargs("plot_ensemble_phase_portrait"), kwargs
    )
    # Get the axes instance
    ax = v_util._validate_axes_instance(ax)
    # Validate each MassSolution and keep the object returned by the
    # validator (as plot_phase_portrait does), rather than the caller's
    # original, so later processing never touches the caller's objects.
    validated_solutions = (
        v_util._validate_mass_solution(sol) for sol in mass_solution_list
    )
    mass_solution_list = [sol for sol in validated_solutions if sol]
    if not mass_solution_list:
        warn("No valid MassSolution objects given")
        return ax

    # Get the solutions to be observed and validate time vector.
    xy, time_vector = v_util._validate_ensemble_plot_observables(
        mass_solution_list, (x, y), **kwargs
    )
    label = "{0} vs. {1}".format(*v_util._group_xy_items(xy, iterkeys))
    sols = v_util._group_xy_items(xy, itervalues)
    observable = {label: sols}

    # Get the legend arguments if desired.
    if legend is not None:
        legend_labels, legend_kwargs = v_util._get_legend_args(
            ax, legend, observable, **kwargs
        )
        observable = v_util._map_labels_to_solutions(observable, legend_labels)
    else:
        legend_kwargs = None

    # Set line colors and styles using a custom cycler
    prop_cycler = v_util._get_line_property_cycler(
        n_current=len(v_util._get_ax_current(ax)), n_new=len(observable), **kwargs
    )
    if prop_cycler:
        ax.set_prop_cycle(prop_cycler)

    # Plot the ensemble mean line(s).
    _plot_ensemble_lines(ax, observable, **kwargs)

    # Set the axes options including axis labels, limits, and gridlines.
    v_util._set_axes_labels(ax, **kwargs)
    v_util._set_axes_limits(ax, **kwargs)
    v_util._set_axes_margins(ax, **kwargs)
    v_util._set_axes_gridlines(ax, **kwargs)

    # Set the legend if desired.
    if legend is not None:
        lines_and_labels = v_util._get_handles_and_labels(ax, False)
        legend = ax.legend(*lines_and_labels, **legend_kwargs)

    if kwargs.get("annotate_time_points", None):
        ax = v_util._set_annotated_time_points(
            ax,
            observable=xy,
            type_of_plot="phase_portrait",
            first_legend=(legend, legend_kwargs),
            time_range=time_vector,
            **kwargs
        )

    # Reset default prop_cycle
    ax.set_prop_cycle(v_util._get_default_cycler())

    return ax
def _plot_ensemble_lines(ax, observable, **kwargs):
    """Plot the ensemble mean of each x/y solution pair onto ``ax``.

    Each value in ``observable`` is unpacked as ``(x_sol_df, y_sol_df)``,
    and the line drawn is ``x_sol_df.mean(axis=0)`` against
    ``y_sol_df.mean(axis=0)``. Only the mean is plotted; no interval is
    drawn. (The ``_df`` names suggest pandas DataFrames with one row per
    ensemble member. This is an assumption to confirm against
    ``_validate_ensemble_plot_observables``.)

    The ``zorder`` kwarg is forwarded to the plotting call, matching
    :func:`plot_phase_portrait`.

    Warnings
    --------
    This method is intended for internal use only.

    """
    # Get the plotting function or raise an error if invalid.
    plot_function = v_util._get_plotting_function(
        ax,
        plot_function_str=kwargs.get("plot_function"),
        valid={"plot", "semilogx", "semilogy", "loglog"},
    )
    # Previously ignored; forward it like plot_phase_portrait does.
    zorder = kwargs.get("zorder")
    for label, (x_sol_df, y_sol_df) in iteritems(observable):
        plot_function(
            x_sol_df.mean(axis=0),
            y_sol_df.mean(axis=0),
            label=label,
            zorder=zorder,
        )
def plot_tiled_phase_portraits(
    mass_solution,
    observable=None,
    ax=None,
    plot_tile_placement="all",
    additional_data=None,
    **kwargs
):
    """Plot tiled phase portraits of solutions in a :class:`~.MassSolution`.

    Accepted ``kwargs`` are passed onto various matplotlib methods utilized
    in the function. See the :mod:`~mass.visualization` module
    documentation for more detailed information about the possible
    ``kwargs``.

    Notes
    -----
    * To prevent any changes to the original :class:`~.MassSolution`, a
      copy of the :class:`~.MassSolution` will be created and used.
    * ``i`` and ``j`` represent the column and row index of a tile,
      respectively.

    Parameters
    ----------
    mass_solution : MassSolution
        The :class:`~.MassSolution` containing the time-dependent solutions
        to be plotted.
    observable : iterable
        An iterable containing string identifiers of the :mod:`mass` objects
        or the objects themselves that correspond to the keys for the
        desired solutions in the :class:`~.MassSolution`.
    ax : matplotlib.axes.Axes, None
        An :class:`~matplotlib.axes.Axes` instance to plot the data on.
        If ``None`` then the current axes instance is used.
    plot_tile_placement : str
        A string representing the location to place the tiles containing
        phase portrait plots. Must be one of the following:

            * ``"lower"`` to place plot tiles on the lower left triangular
              section ``(i < j)`` on the figure tiles.
            * ``"upper"`` to place plot tiles on the upper right triangular
              section ``(i > j)`` on the figure tiles.
            * ``"all"`` to place plot tiles on the lower left triangular
              section ``(i < j)`` AND on the upper right triangular section
              ``(i > j)`` on the figure tiles.

    additional_data : array_like, None
        A matrix of shape ``(N, N)`` where ``N`` is the number of
        observables provided, or the number of keys in the
        :class:`~.MassSolution` if ``observable=None``. The value at
        ``(i, j)`` of the matrix must correspond to the empty tile that the
        data should be displayed on. All other values are ignored. If
        ``None`` then no data will be displayed and tiles will be left
        empty.
    **kwargs
        * time_vector
        * plot_function
        * title
        * xlim
        * ylim
        * grid
        * grid_color
        * grid_linestyle
        * grid_linewidth
        * prop_cycle
        * color
        * linestyle
        * linewidth
        * marker
        * markersize
        * annotate_time_points
        * annotate_time_points_color
        * annotate_time_points_marker
        * annotate_time_points_markersize
        * annotate_time_points_legend
        * annotate_time_points_zorder
        * tile_ticks_on
        * tile_xlabel_fontdict
        * tile_ylabel_fontdict
        * data_tile_fontsize
        * data_tile_color
        * diag_tile_color
        * empty_tile_color

        See :mod:`~mass.visualization` documentation for more information
        on optional ``kwargs``.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The :class:`~matplotlib.axes.Axes` instance containing the newly
        created plot.

    """
    # Validate whether necessary packages are installed.
    v_util._validate_visualization_packages("matplotlib")
    # Check kwargs
    kwargs = _check_kwargs(
        get_phase_portrait_default_kwargs("plot_tiled_phase_portraits"), kwargs
    )
    # Get the axes instance
    ax = v_util._validate_axes_instance(ax)
    # Validate the MassSolution input, ensure it is not empty.
    mass_solution = v_util._validate_mass_solution(mass_solution)
    if not mass_solution:
        return ax

    # Get the solutions to be observed and validate time vector.
    observable = v_util._validate_plot_observables(mass_solution, observable, **kwargs)
    n_obs = len(observable)

    # Reuse existing inset axes when the axes already holds an N x N grid.
    if ax.child_axes is not None and len(ax.child_axes) == n_obs**2:
        subaxes = np.reshape(np.array(ax.child_axes), (n_obs, n_obs))
    else:
        subaxes = None

    plot_tile_placement = v_util._validate_tile_placement(
        plot_tile_placement, prefix="plot"
    )

    # Split the tile kwargs and the phase portrait kwargs into separate dicts
    tile_kwargs, pp_kwargs = _sep_kwargs_for_tiled_phase_portraits(**kwargs)

    # Remove axis lines
    ax.axis("off")

    # Create N x N subplots where N is the number of observable solutions.
    # Fraction of the larger figure used by each subplot (inverse of N).
    sub_ax_placement_vals = [1.0 / n_obs] * 2
    # Shrink tile width/height when ticks are shown, to leave room for them.
    if tile_kwargs.get("tile_ticks_on"):
        sub_ax_placement_vals[1] = sub_ax_placement_vals[1] * 0.75

    for j, y in enumerate(observable):
        for i, x in enumerate(observable):
            if subaxes is not None:
                sub_ax = subaxes[j, i]
            else:
                # [x0, y0, width, height] from lower left corner of inset axes
                sub_ax = ax.inset_axes(
                    bounds=[
                        i * sub_ax_placement_vals[0],
                        1 - sub_ax_placement_vals[0] * (j + 1),
                        sub_ax_placement_vals[1],
                        sub_ax_placement_vals[1],
                    ]
                )
            # Create tile (either phase_portrait, data, or empty)
            sub_ax = _create_tiled_phase_portraits_tile(
                sub_ax,
                observable,
                i,
                j,
                x,
                y,
                plot_tile_placement,
                additional_data,
                tile_kwargs,
                pp_kwargs,
            )
            # If tile_ticks is not True, remove them
            if not tile_kwargs.get("tile_ticks_on"):
                sub_ax.set_xticks([])
                sub_ax.set_yticks([])
            # Set xlabels only on the final row
            if j == n_obs - 1:
                sub_ax.set_xlabel(x, kwargs.get("tile_xlabel_fontdict"))
            # Set ylabels only on the first column
            if i == 0:
                sub_ax.set_ylabel(y, kwargs.get("tile_ylabel_fontdict"))

    if kwargs.get("annotate_time_points_legend"):
        # Build the legend from the first tile that has annotated time
        # points. Previously, when no such tile existed, ``leg_args`` was
        # never bound and an UnboundLocalError was raised.
        for sub_ax in ax.get_children():
            if sub_ax.__class__.__name__ == "Axes" and v_util._get_ax_current(
                sub_ax, time_points=True
            ):
                leg_args = v_util._get_annotated_time_points_legend_args(
                    sub_ax, kwargs.get("annotate_time_points_legend")
                )
                ax = v_util._set_additional_legend_box(
                    ax, leg_args, first_legend=None
                )
                break
        else:
            warn(
                "Unable to create the annotated time points legend: no tile "
                "contains annotated time points."
            )

    # Set the axes title.
    v_util._set_axes_labels(ax, **kwargs)
    # Reset default prop_cycle
    ax.set_prop_cycle(v_util._get_default_cycler())

    return ax
def get_phase_portrait_default_kwargs(function_name):
    """Get default ``kwargs`` for plotting functions in :mod:`phase_portraits`.

    Parameters
    ----------
    function_name : str
        The name of the plotting function to get the ``kwargs`` for.
        Valid values include the following:

            * ``"plot_phase_portrait"``
            * ``"plot_ensemble_phase_portrait"``
            * ``"plot_tiled_phase_portraits"``

    Returns
    -------
    dict
        Default ``kwarg`` values for the given ``function_name``. A new
        dict is built on every call, so callers may mutate it freely.

    Raises
    ------
    ValueError
        If ``function_name`` is not one of the valid values.

    """
    # Listed explicitly rather than sliced from ``__all__``. Validity should
    # not depend on the ordering of the module's public names.
    valid_function_names = (
        "plot_phase_portrait",
        "plot_ensemble_phase_portrait",
        "plot_tiled_phase_portraits",
    )
    if function_name not in valid_function_names:
        raise ValueError(
            "Invalid 'function_name'. Valid values include the following: "
            + str(valid_function_names)
        )

    # kwargs shared by every phase portrait plotting function.
    default_kwargs = {
        "time_vector": None,
        "plot_function": "plot",
        "title": None,
        "xlim": None,
        "ylim": None,
        "xmargin": None,
        "ymargin": None,
        "color": None,
        "linestyle": None,
        "linewidth": None,
        "marker": None,
        "markersize": None,
        "grid": None,
        "grid_color": None,
        "grid_linestyle": None,
        "grid_linewidth": None,
        "annotate_time_points": None,
        "annotate_time_points_color": None,
        "annotate_time_points_marker": None,
        "annotate_time_points_markersize": None,
        "annotate_time_points_labels": False,
        "annotate_time_points_legend": None,
        "annotate_time_points_zorder": None,
        "prop_cycle": None,
        "deviation": False,
        "deviation_zero_centered": False,
        "deviation_normalization": "initial value",
        "zorder": None,
    }

    if function_name in ("plot_phase_portrait", "plot_ensemble_phase_portrait"):
        # Single-axes plots accept axis labels and a legend column count.
        default_kwargs.update(
            {
                "xlabel": None,
                "ylabel": None,
                "legend_ncol": None,
            }
        )
    elif function_name == "plot_tiled_phase_portraits":
        # Tile-specific appearance options.
        default_kwargs.update(
            {
                "tile_ticks_on": False,
                "tile_xlabel_fontdict": None,
                "tile_ylabel_fontdict": None,
                "data_tile_fontsize": "large",
                "data_tile_color": None,
                "diag_tile_color": None,
                "empty_tile_color": None,
            }
        )

    return default_kwargs
def _sep_kwargs_for_tiled_phase_portraits(**kwargs): """Seperate kwargs for tile properties from kwargs for phase portraits. Warnings -------- This method is intended for internal use only. """ # Create dicts for each kwarg tile_kwargs = {} pp_kwargs = {} for key, value in iteritems(kwargs): # Get tile kwarg if "tile" in key: # Add to the tile kwargs tile_kwargs[key] = value # Get tile kwarge that is not exclusive to tiled phase portraits elif key in ["title", "annotate_time_points_legend"]: # Add to the tile kwargs tile_kwargs[key] = value else: # Otherwise kwarg belongs to the phase portrait and will be # validated in phase portrait function. if "margin" in key and value is None: value = 0.15 pp_kwargs[key] = value # Iterate through tile colors, setting defaults if no color provided. for key, default_color in zip( ["data", "diag", "empty"], ["lightgray", "black", "white"] ): color = tile_kwargs.get(key + "_tile_color") if color is None: color = default_color tile_kwargs[key + "_tile_color"] = color # Make endpoints default for tiled phase portrait plots. if pp_kwargs.get("annotate_time_points") is None: pp_kwargs["annotate_time_points"] = "endpoints" if pp_kwargs.get("annotate_time_points_color") is None: pp_kwargs["annotate_time_points_color"] = ["red", "blue"] return tile_kwargs, pp_kwargs def _create_tiled_phase_portraits_tile(ax, observable, *args): """Create a tile for the tiled phase portrait figure. Warnings -------- This method is intended for internal use only. """ def get_plot_tile_bool(i, j, plot_tile_placement): """Get a bool indicating if a plot should be made.""" return {"all": i < j or i > j, "lower": i < j, "upper": i > j}.get( plot_tile_placement ) i, j, x, y, plot_tile_placement, data_matrix, tile_kwargs, pp_kwargs = args plot_tile_bool = get_plot_tile_bool(i, j, plot_tile_placement) # Validate fontsize and set default data tile fontsize as large if needed. 
if not tile_kwargs.get("data_tile_fontsize"): tile_kwargs["data_tile_fontsize"] = "large" # Set diagonal tile color if i == j: ax.set_facecolor(tile_kwargs.get("diag_tile_color")) # Create a phase portrait for the tile elif plot_tile_bool: ax = plot_phase_portrait(observable, x=x, y=y, ax=ax, legend=None, **pp_kwargs) # Create the data tile elif data_matrix is not None and not plot_tile_bool: # Create the data tile only if there is information, # otherwise set the facecolor as an empty tile if data_matrix[j][i] == 0 or not data_matrix[j][i]: ax.set_facecolor(tile_kwargs.get("empty_tile_color")) else: # Place data onto tile and set facecolor ax.annotate( str(data_matrix[j][i]), xy=(0.5, 0.5), xycoords="axes fraction", va="center", ha="center", fontsize=tile_kwargs.get("data_tile_fontsize"), ) ax.set_facecolor(tile_kwargs.get("data_tile_color")) else: ax.set_facecolor(tile_kwargs.get("empty_tile_color")) return ax __all__ = ( "plot_phase_portrait", "plot_ensemble_phase_portrait", "plot_tiled_phase_portraits", "get_phase_portrait_default_kwargs", )