diff options
Diffstat (limited to 'fg21sim/extragalactic/clusters/mergertree.py')
-rw-r--r-- | fg21sim/extragalactic/clusters/mergertree.py | 54 |
1 files changed, 52 insertions, 2 deletions
diff --git a/fg21sim/extragalactic/clusters/mergertree.py b/fg21sim/extragalactic/clusters/mergertree.py index 7c25430..61b9542 100644 --- a/fg21sim/extragalactic/clusters/mergertree.py +++ b/fg21sim/extragalactic/clusters/mergertree.py @@ -10,6 +10,9 @@ import os import pickle import logging +from matplotlib.figure import Figure +from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas + logger = logging.getLogger(__name__) @@ -66,8 +69,55 @@ def read_mtree(infile): return mtree -def plot_mtree(mtree): +def plot_mtree(mtree, outfile, figsize=(12, 8)): """ Plot the cluster merger tree. + + Parameters + ---------- + mtree : `~MergerTree` + The merger tree to be plotted + outfile : str + Output filename to save the plotted figure + figsize : tuple + The (width, height) of the plotting figure """ - raise NotImplementedError + def _plot(tree, ax): + if tree is None: + return + if tree.main is None: + # Only plot a point for current tree node + x = [tree.data["age"]] + y = [tree.data["mass"]] + ax.plot(x, y, marker="o", markersize=1.5, color="black", + linestyle=None) + return + # Plot a point for current tree node + x = [tree.data["age"]] + y = [tree.data["mass"]] + ax.plot(x, y, marker="o", markersize=1.5, color="black", + linestyle=None) + # Plot a line from current tree node to its main node + x = [tree.data["age"], tree.main.data["age"]] + y = [tree.data["mass"], tree.main.data["mass"]] + ax.plot(x, y, color="blue") + if tree.sub: + # Plot a line between main and sub nodes + x = [tree.main.data["age"], tree.sub.data["age"]] + y = [tree.main.data["mass"], tree.sub.data["mass"]] + ax.plot(x, y, color="green", linewidth=1, alpha=0.8) + # Recursively plot the descendant nodes + _plot(tree.main, ax) + _plot(tree.sub, ax) + + fig = Figure(figsize=figsize) + canvas = FigureCanvas(fig) + ax = fig.add_subplot(1, 1, 1) + ax.hold(True) + _plot(mtree, ax=ax) + ax.set_xlabel("Cosmic time [Gyr]") + ax.set_ylabel("Mass [Msun]") + ax.set_xlim((0, mtree.data["age"])) + ax.set_ylim((0, mtree.data["mass"])) + fig.tight_layout() + canvas.print_figure(outfile) |