# -*- coding: utf-8 -*-
"""Contains function for visually comparing values in various objects.
See the :mod:`mass.visualization` documentation for general information
on :mod:`mass.visualization` functions.
This module contains the following functions for visually comparing a set of
values in one object against a similar set of valeus in another object.
* :func:`~.comparison.plot_comparison`
"""
import math
from mass.util.util import _check_kwargs
from mass.visualization import visualization_util as v_util
[docs]def plot_comparison(
x, y, compare=None, observable=None, ax=None, legend=None, **kwargs
):
"""Plot values of two objects for comparision.
This function can take two :class:`.MassModel`, :class:`.ConcSolution`,
:class:`cobra.Solution <cobra.core.solution.Solution>`, or
:class:`pandas.Series` objects and plot them against one another in a
calibration plot.
Accepted ``kwargs`` are passed onto various :mod:`matplotlib` methods
utilized in the function. See the :mod:`~mass.visualization` module
documentation for more detailed information about the possible ``kwargs``.
Notes
-----
* If a :class:`pandas.Series`, the index must correspond to the identifier
of the assoicated object. (e.g. a metabolite identifier for
``compare="concentrations"``, or a reaction identifier for
``compare="Keqs"``)
Parameters
----------
x : MassModel, ConcSolution, ~cobra.core.solution.Solution, ~pandas.Series
The object to access for x-axis values.
y : MassModel, ConcSolution, ~cobra.core.solution.Solution, ~pandas.Series
The object to access for y-axis values.
compare : str
The values to be compared. Must be one of the following:
* ``"concentrations"`` for :class:`.MassModel` and
:class:`.ConcSolution` objects.
* ``"Keqs"`` for :class:`.MassModel` and
:class:`.ConcSolution` objects.
* ``"fluxes"`` for :class:`.MassModel` and
:class:`cobra.Solution <cobra.core.solution.Solution>` objects.
* ``"kfs"`` for :class:`.MassModel` objects.
Not required if both ``x`` and ``y`` are :class:`pandas.Series`.
observable : iterable
An iterable containing string identifiers of :mod:`mass` objects
or the objects themselves corresponding to the object or index where
the value is located.
ax : matplotlib.axes.Axes, None
An :class:`~matplotlib.axes.Axes` instance to plot the data on.
If ``None`` then the current axes instance is used.
legend : iterable, str, int
There are three possible input formats for the legend:
1. An iterable of legend labels as strings.
2. A ``str`` representing the location of the legend, or an
``int`` between 0 and 14 (inclusive) corresponding to the
legend location.
3. An iterable of the format ``(labels, loc)`` to set both
the legend labels and location, where ``labels`` and ``loc``
follows the labels specified in **1** and **2**.
See the :mod:`~mass.visualization` documentation for more information
about legend and valid legend locations.
**kwargs
* plot_function
* title
* xlabel
* ylabel
* xlim
* ylim
* grid
* grid_color
* grid_linestyle
* grid_linewidth
* prop_cycle
* color
* marker
* markersize
* legend_ncol
* xy_line
* xy_linecolor
* xy_linewidth
* xy_linestyle
* xy_legend
See :mod:`~mass.visualization` documentation for more information on
optional ``kwargs``.
Returns
-------
ax : matplotlib.axes.Axes
The :class:`~matplotlib.axes.Axes` instance containing the newly
created plot.
"""
# Validate whether necessary packages are installed.
v_util._validate_visualization_packages("matplotlib")
# Check kwargs
kwargs = _check_kwargs(get_comparison_default_kwargs("plot_comparison"), kwargs)
kwargs["linestyle"] = " "
# Get the axies instance
ax = v_util._validate_axes_instance(ax)
x = v_util._get_values_as_series(x, compare, name="x")
y = v_util._get_values_as_series(y, compare, name="y")
observable = v_util._get_dataframe_of_observables(x, y, compare, observable)
# Get the plotting function or raise an error if invalid.
plot_function = v_util._get_plotting_function(
ax, plot_function_str=kwargs.get("plot_function"), valid={"plot", "loglog"}
)
# Get the legend arguments if desired.
if legend is not None:
legend_labels, legend_kwargs = v_util._get_legend_args(
ax, legend, list(observable.index), **kwargs
)
observable = v_util._map_labels_to_solutions(observable, legend_labels)
else:
legend_kwargs = None
# Set line colors and styles using a custom cycler
prop_cycler = v_util._get_line_property_cycler(
n_current=len(v_util._get_ax_current(ax)), n_new=len(observable), **kwargs
)
if prop_cycler:
ax.set_prop_cycle(prop_cycler)
# Plot lines onto axes using legend entries as labels (if legend valid).
for sol in observable.itertuples():
plot_function(sol.x, sol.y, label=sol.Index)
# Set the axes options including axis labels, limits, and gridlines.
v_util._set_axes_labels(ax, **kwargs)
v_util._set_axes_limits(ax, **kwargs)
v_util._set_axes_margins(ax, **kwargs)
v_util._set_axes_gridlines(ax, **kwargs)
# Set the legend if desired.
if legend is not None:
lines_and_labels = v_util._get_handles_and_labels(ax, False)
legend = ax.legend(*lines_and_labels, **legend_kwargs)
if kwargs.get("xy_line"):
# Get min and max value, make line extend slightly past those values.
if kwargs.get("plot_function") == "plot":
limits = (
int(math.floor(observable.values.min()) / 1.05 - 1),
int(math.ceil(observable.values.max()) * 1.05 + 1),
)
else:
limits = (
int(math.floor(observable.values.min()) / 10),
int(math.ceil(observable.values.max()) * 10),
)
ax = _plot_xy_line(ax, limits, first_legend=(legend, legend_kwargs), **kwargs)
# Reset default prop_cycle
ax.set_prop_cycle(v_util._get_default_cycler())
return ax
def _plot_xy_line(ax, limits, first_legend=None, **kwargs):
"""Plot a a line for ``y=x`` on the comparison plot.
Warnings
--------
This method is intended for internal use only.
"""
# Validate color
linecolor = kwargs.get("xy_linecolor")
if linecolor is None:
linecolor = "grey"
# Validate linestyle
linestyle = kwargs.get("xy_linestyle")
if linestyle is None:
linestyle = "--"
# Validate linewidth
linewidth = kwargs.get("xy_linewidth")
plot_function = v_util._get_plotting_function(
ax, plot_function_str=kwargs.get("plot_function"), valid={"plot", "loglog"}
)
# Plot the line using the set kwarg options, set zorder to 1.9 so that
# it is below original points (default is 2.)
line = plot_function(
limits,
limits,
label="y=x",
color=linecolor,
linestyle=linestyle,
linewidth=linewidth,
marker="",
zorder=1.9,
)
if kwargs.get("xy_legend") is not None:
desired, taken = v_util._check_second_legend_location(
kwargs.get("xy_legend"), first_legend[1]
)
# Set default desired location
if desired is None:
desired = "best" if taken != "best" else "right outside"
# Get kwargs for legend location
anch = None
if desired in v_util.OUTSIDE_LEGEND_LOCATION_AND_ANCHORS:
desired, anch = v_util.OUTSIDE_LEGEND_LOCATION_AND_ANCHORS[desired]
legend_args = (line, ["y=x"], {"loc": desired, "bbox_to_anchor": anch})
ax = v_util._set_additional_legend_box(
ax, legend_args, first_legend=first_legend[0]
)
return ax
[docs]def get_comparison_default_kwargs(function_name):
"""Get default ``kwargs`` for plotting functions in :mod:`comparison`.
Parameters
----------
function_name : str
The name of the plotting function to get the ``kwargs`` for.
Valid values include the following:
* ``"plot_comparison"``
Returns
-------
dict
Default ``kwarg`` values for the given ``function_name``.
"""
if function_name not in __all__[:-1]:
raise ValueError(
"Invalid 'function_name'. Valid values include the following: "
+ str(__all__[:-1])
)
default_kwargs = {
"plot_function": "plot",
"title": None,
"xlabel": None,
"ylabel": None,
"xlim": None,
"ylim": None,
"xmargin": None,
"ymargin": None,
"color": None,
"marker": "o",
"markersize": None,
"grid": None,
"grid_color": None,
"grid_linestyle": None,
"grid_linewidth": None,
"legend_ncol": None,
"prop_cycle": None,
"xy_line": False,
"xy_linecolor": None,
"xy_linewidth": None,
"xy_linestyle": None,
"xy_legend": None,
}
return default_kwargs
__all__ = ("plot_comparison", "get_comparison_default_kwargs")