aboutsummaryrefslogtreecommitdiffstats
path: root/fg21sim
diff options
context:
space:
mode:
Diffstat (limited to 'fg21sim')
-rw-r--r--fg21sim/extragalactic/clusters/mergertree.py54
1 files changed, 52 insertions, 2 deletions
diff --git a/fg21sim/extragalactic/clusters/mergertree.py b/fg21sim/extragalactic/clusters/mergertree.py
index 7c25430..61b9542 100644
--- a/fg21sim/extragalactic/clusters/mergertree.py
+++ b/fg21sim/extragalactic/clusters/mergertree.py
@@ -10,6 +10,9 @@ import os
import pickle
import logging
+from matplotlib.figure import Figure
+from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
+
logger = logging.getLogger(__name__)
@@ -66,8 +69,55 @@ def read_mtree(infile):
return mtree
-def plot_mtree(mtree):
+def plot_mtree(mtree, outfile, figsize=(12, 8)):
"""
Plot the cluster merger tree.
+
+ Parameters
+ ----------
+ mtree : `~MergerTree`
+ The merger tree to be plotted
+ outfile : str
+ Output filename to save the plotted figure
+ figsize : tuple
+ The (width, height) of the plotting figure
"""
- raise NotImplementedError
+ def _plot(tree, ax):
+ if tree is None:
+ return
+ if tree.main is None:
+ # Only plot a point for current tree node
+ x = [tree.data["age"]]
+ y = [tree.data["mass"]]
+ ax.plot(x, y, marker="o", markersize=1.5, color="black",
+ linestyle=None)
+ return
+ # Plot a point for current tree node
+ x = [tree.data["age"]]
+ y = [tree.data["mass"]]
+ ax.plot(x, y, marker="o", markersize=1.5, color="black",
+ linestyle=None)
+ # Plot a line from current tree node to its main node
+ x = [tree.data["age"], tree.main.data["age"]]
+ y = [tree.data["mass"], tree.main.data["mass"]]
+ ax.plot(x, y, color="blue")
+ if tree.sub:
+ # Plot a line between main and sub nodes
+ x = [tree.main.data["age"], tree.sub.data["age"]]
+ y = [tree.main.data["mass"], tree.sub.data["mass"]]
+ ax.plot(x, y, color="green", linewidth=1, alpha=0.8)
+ # Recursively plot the descendant nodes
+ _plot(tree.main, ax)
+ _plot(tree.sub, ax)
+
+ fig = Figure(figsize=figsize)
+ canvas = FigureCanvas(fig)
+ ax = fig.add_subplot(1, 1, 1)
+ ax.hold(True)
+ _plot(mtree, ax=ax)
+ ax.set_xlabel("Cosmic time [Gyr]")
+ ax.set_ylabel("Mass [Msun]")
+ ax.set_xlim((0, mtree.data["age"]))
+ ax.set_ylim((0, mtree.data["mass"]))
+ fig.tight_layout()
+ canvas.print_figure(outfile)