Skip to content

Commit

Permalink
Add orientation plot option
Browse files Browse the repository at this point in the history
  • Loading branch information
moshi4 committed Jan 29, 2024
1 parent 0697fbd commit 5504643
Showing 1 changed file with 19 additions and 2 deletions.
21 changes: 19 additions & 2 deletions src/phytreeviz/treeviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(
format: str = "newick",
height: float = 0.5,
width: float = 8,
orientation: str = "right",
align_leaf_label: bool = False,
ignore_branch_length: bool = False,
leaf_label_size: float = 12,
Expand All @@ -52,6 +53,8 @@ def __init__(
Figure height per leaf node of tree
width : float, optional
Figure width
orientation : str, optional
Tree orientation (`right`|`left`)
align_leaf_label: bool, optional
If True, align leaf label.
ignore_branch_length : bool, optional
Expand Down Expand Up @@ -91,6 +94,10 @@ def __init__(
self._innode_label_xmargin_ratio = innode_label_xmargin_ratio
self._reverse = reverse
self._ax: Axes | None = None
if orientation in ("right", "left"):
self._orientation = orientation
else:
raise ValueError(f"{orientation=} is invalid (`right` or `left`).")

self._node2label_props: dict[str, dict[str, Any]] = defaultdict(lambda: {})
self._node2line_props: dict[str, dict[str, Any]] = defaultdict(lambda: {})
Expand Down Expand Up @@ -121,7 +128,10 @@ def figsize(self) -> tuple[float, float]:
@property
def xlim(self) -> tuple[float, float]:
"""Axes xlim"""
return (0, self.max_tree_depth)
if self._orientation == "left":
return (self.max_tree_depth, 0)
else:
return (0, self.max_tree_depth)

@property
def ylim(self) -> tuple[float, float]:
Expand Down Expand Up @@ -180,7 +190,7 @@ def ax(self) -> Axes:
Can't access `ax` property before calling `tv.plotfig()` method
"""
if self._ax is None:
err_msg = "Can't access ax property before calling `tv.plotfig() method"
err_msg = "Can't access ax property before calling `tv.plotfig()` method"
raise ValueError(err_msg)
return self._ax

Expand Down Expand Up @@ -464,6 +474,8 @@ def annotate(
line_kws.setdefault("lw", 1)
line_kws.update(dict(color=line_color, clip_on=False))
text_kws.update(dict(size=text_size, color=text_color, va="center", ha="left"))
if self._orientation == "left":
text_kws.update(ha="right")
if text_orientation == "horizontal":
text_kws.update(dict(rotation=0))
elif text_orientation == "vertical":
Expand Down Expand Up @@ -656,6 +668,9 @@ def text_on_branch(
ypos2va = dict(top="bottom", center="center", bottom="top")
ha, va = xpos, ypos2va[ypos]

if self._orientation == "left":
xpos = dict(left="right", right="left", center="center")[xpos]

# Get text plot target node & xy coordinate
target_node_name = self._search_target_node_name(query)
xpos2xy = dict(
Expand Down Expand Up @@ -1000,6 +1015,8 @@ def _plot_node_label(self, ax: Axes) -> None:
# Plot label
text_kws = dict(size=label_size, ha="left", va="center_baseline")
text_kws.update(self._node2label_props[str(node.name)])
if self._orientation == "left":
text_kws.update(ha="right")
ax.text(x, y, s=node.name, **text_kws) # type: ignore

def _load_tree(self, data: str | Path | Tree, format: str) -> Tree:
Expand Down

0 comments on commit 5504643

Please sign in to comment.