diff options
Diffstat (limited to 'fg21sim/extragalactic')
| -rw-r--r-- | fg21sim/extragalactic/clusters/mergertree.py | 54 | 
1 files changed, 52 insertions, 2 deletions
diff --git a/fg21sim/extragalactic/clusters/mergertree.py b/fg21sim/extragalactic/clusters/mergertree.py index 7c25430..61b9542 100644 --- a/fg21sim/extragalactic/clusters/mergertree.py +++ b/fg21sim/extragalactic/clusters/mergertree.py @@ -10,6 +10,9 @@ import os  import pickle  import logging +from matplotlib.figure import Figure +from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas +  logger = logging.getLogger(__name__) @@ -66,8 +69,55 @@ def read_mtree(infile):      return mtree -def plot_mtree(mtree): +def plot_mtree(mtree, outfile, figsize=(12, 8)):      """      Plot the cluster merger tree. + +    Parameters +    ---------- +    mtree : `~MergerTree` +        The merger tree to be plotted +    outfile : str +        Output filename to save the plotted figure +    figsize : tuple +        The (width, height) of the plotting figure      """ -    raise NotImplementedError +    def _plot(tree, ax): +        if tree is None: +            return +        if tree.main is None: +            # Only plot a point for current tree node +            x = [tree.data["age"]] +            y = [tree.data["mass"]] +            ax.plot(x, y, marker="o", markersize=1.5, color="black", +                    linestyle=None) +            return +        # Plot a point for current tree node +        x = [tree.data["age"]] +        y = [tree.data["mass"]] +        ax.plot(x, y, marker="o", markersize=1.5, color="black", +                linestyle=None) +        # Plot a line from current tree node to its main node +        x = [tree.data["age"], tree.main.data["age"]] +        y = [tree.data["mass"], tree.main.data["mass"]] +        ax.plot(x, y, color="blue") +        if tree.sub: +            # Plot a line between main and sub nodes +            x = [tree.main.data["age"], tree.sub.data["age"]] +            y = [tree.main.data["mass"], tree.sub.data["mass"]] +            ax.plot(x, y, color="green", linewidth=1, alpha=0.8) +        # Recursively plot the descendant nodes +        _plot(tree.main, ax) +        _plot(tree.sub, ax) + +    fig = Figure(figsize=figsize) +    canvas = FigureCanvas(fig) +    ax = fig.add_subplot(1, 1, 1) +    ax.hold(True) +    _plot(mtree, ax=ax) +    ax.set_xlabel("Cosmic time [Gyr]") +    ax.set_ylabel("Mass [Msun]") +    ax.set_xlim((0, mtree.data["age"])) +    ax.set_ylim((0, mtree.data["mass"])) +    fig.tight_layout() +    canvas.print_figure(outfile)  | 
