Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
292 changes: 146 additions & 146 deletions src/maxplotlib/canvas/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,141 @@
from maxplotlib.utils.options import Backends


def plot_matplotlib(tikzfigure: TikzFigure, ax, layers=None):
"""
Plot all nodes and paths on the provided axis using Matplotlib.

Parameters:
- ax (matplotlib.axes.Axes): Axis on which to plot the figure.
"""

# TODO: Specify which layers to retreive nodes from with layers=layers
nodes = tikzfigure.layers.get_nodes()
paths = tikzfigure.layers.get_paths()

for path in paths:
x_coords = [node.x for node in path.nodes]
y_coords = [node.y for node in path.nodes]

# Parse path color
path_color_spec = path.kwargs.get("color", "black")
try:
color = Color(path_color_spec).to_rgb()
except ValueError as e:
print(e)
color = "black"

# Parse line width
line_width_spec = path.kwargs.get("line_width", 1)
if isinstance(line_width_spec, str):
match = re.match(r"([\d.]+)(pt)?", line_width_spec)
if match:
line_width = float(match.group(1))
else:
print(
f"Invalid line width specification: '{line_width_spec}', defaulting to 1",
)
line_width = 1
else:
line_width = float(line_width_spec)

# Parse line style using Linestyle class
style_spec = path.kwargs.get("style", "solid")
linestyle = Linestyle(style_spec).to_matplotlib()

ax.plot(
x_coords,
y_coords,
color=color,
linewidth=line_width,
linestyle=linestyle,
zorder=1, # Lower z-order to place behind nodes
)

# Plot nodes after paths so they appear on top
for node in nodes:
# Determine shape and size
shape = node.kwargs.get("shape", "circle")
fill_color_spec = node.kwargs.get("fill", "white")
edge_color_spec = node.kwargs.get("draw", "black")
linewidth = float(node.kwargs.get("line_width", 1))
size = float(node.kwargs.get("size", 1))

# Parse colors using the Color class
try:
facecolor = Color(fill_color_spec).to_rgb()
except ValueError as e:
print(e)
facecolor = "white"

try:
edgecolor = Color(edge_color_spec).to_rgb()
except ValueError as e:
print(e)
edgecolor = "black"

# Plot shapes
if shape == "circle":
radius = size / 2
circle = patches.Circle(
(node.x, node.y),
radius,
facecolor=facecolor,
edgecolor=edgecolor,
linewidth=linewidth,
zorder=2, # Higher z-order to place on top of paths
)
ax.add_patch(circle)
elif shape == "rectangle":
width = height = size
rect = patches.Rectangle(
(node.x - width / 2, node.y - height / 2),
width,
height,
facecolor=facecolor,
edgecolor=edgecolor,
linewidth=linewidth,
zorder=2, # Higher z-order
)
ax.add_patch(rect)
else:
# Default to circle if shape is unknown
radius = size / 2
circle = patches.Circle(
(node.x, node.y),
radius,
facecolor=facecolor,
edgecolor=edgecolor,
linewidth=linewidth,
zorder=2,
)
ax.add_patch(circle)

# Add text inside the shape
if node.content:
ax.text(
node.x,
node.y,
node.content,
fontsize=10,
ha="center",
va="center",
wrap=True,
zorder=3, # Even higher z-order for text
)

# Remove axes, ticks, and legend
ax.axis("off")

# Adjust plot limits
all_x = [node.x for node in nodes]
all_y = [node.y for node in nodes]
padding = 1 # Adjust padding as needed
ax.set_xlim(min(all_x) - padding, max(all_x) + padding)
ax.set_ylim(min(all_y) - padding, max(all_y) + padding)
ax.set_aspect("equal", adjustable="datalim")


class Canvas:
def __init__(
self,
Expand All @@ -29,7 +164,7 @@ def __init__(
label: str | None = None,
fontsize: int = 14,
dpi: int = 300,
width: str = "17cm",
width: str = "5cm",
ratio: str = "golden", # TODO Add literal
gridspec_kw: Dict = {"wspace": 0.08, "hspace": 0.1},
):
Expand Down Expand Up @@ -62,6 +197,8 @@ def __init__(
self._ratio = ratio
self._gridspec_kw = gridspec_kw
self._plotted = False
self._matplotlib_fig = None
self._matplotlib_axes = None

# Dictionary to store lines for each subplot
# Key: (row, col), Value: list of lines with their data and kwargs
Expand Down Expand Up @@ -106,7 +243,6 @@ def add_line(
subplot: LinePlot | None = None,
row: int | None = None,
col: int | None = None,
plot_type="plot",
**kwargs,
):
if row is not None and col is not None:
Expand All @@ -126,7 +262,6 @@ def add_line(
x_data=x_data,
y_data=y_data,
layer=layer,
plot_type=plot_type,
**kwargs,
)

Expand Down Expand Up @@ -304,7 +439,7 @@ def show(
elif backend == "plotly":
self.plot_plotly(savefig=False)
elif backend == "tikzpics":
fig = self.plot_tikzpics(savefig=False)
fig = self.plot_tikzpics(savefig=False, verbose=verbose)
fig.show()
else:
raise ValueError("Invalid backend")
Expand Down Expand Up @@ -374,8 +509,8 @@ def plot_matplotlib(

def plot_tikzpics(
self,
savefig=None,
verbose=False,
savefig: str | None = None,
verbose: bool = False,
) -> TikzFigure:
if len(self.subplots) > 1:
raise NotImplementedError(
Expand Down Expand Up @@ -507,13 +642,6 @@ def label(self, value):
def figsize(self, value):
self._figsize = value

# Magic methods
def __str__(self):
return f"Canvas(nrows={self.nrows}, ncols={self.ncols}, figsize={self.figsize})"

def __repr__(self):
return f"Canvas(nrows={self.nrows}, ncols={self.ncols}, caption={self.caption}, label={self.label})"

def __getitem__(self, key):
"""Allows accessing subplots by tuple index."""
row, col = key
Expand All @@ -528,140 +656,12 @@ def __setitem__(self, key, value):
raise IndexError("Subplot index out of range")
self._subplot_matrix[row][col] = value

def __repr__(self):
return f"Canvas(nrows={self.nrows}, ncols={self.ncols}, caption={self.caption}, label={self.label})"

def plot_matplotlib(tikzfigure: TikzFigure, ax, layers=None):
"""
Plot all nodes and paths on the provided axis using Matplotlib.

Parameters:
- ax (matplotlib.axes.Axes): Axis on which to plot the figure.
"""

# TODO: Specify which layers to retreive nodes from with layers=layers
nodes = tikzfigure.layers.get_nodes()
paths = tikzfigure.layers.get_paths()

for path in paths:
x_coords = [node.x for node in path.nodes]
y_coords = [node.y for node in path.nodes]

# Parse path color
path_color_spec = path.kwargs.get("color", "black")
try:
color = Color(path_color_spec).to_rgb()
except ValueError as e:
print(e)
color = "black"

# Parse line width
line_width_spec = path.kwargs.get("line_width", 1)
if isinstance(line_width_spec, str):
match = re.match(r"([\d.]+)(pt)?", line_width_spec)
if match:
line_width = float(match.group(1))
else:
print(
f"Invalid line width specification: '{line_width_spec}', defaulting to 1",
)
line_width = 1
else:
line_width = float(line_width_spec)

# Parse line style using Linestyle class
style_spec = path.kwargs.get("style", "solid")
linestyle = Linestyle(style_spec).to_matplotlib()

ax.plot(
x_coords,
y_coords,
color=color,
linewidth=line_width,
linestyle=linestyle,
zorder=1, # Lower z-order to place behind nodes
)

# Plot nodes after paths so they appear on top
for node in nodes:
# Determine shape and size
shape = node.kwargs.get("shape", "circle")
fill_color_spec = node.kwargs.get("fill", "white")
edge_color_spec = node.kwargs.get("draw", "black")
linewidth = float(node.kwargs.get("line_width", 1))
size = float(node.kwargs.get("size", 1))

# Parse colors using the Color class
try:
facecolor = Color(fill_color_spec).to_rgb()
except ValueError as e:
print(e)
facecolor = "white"

try:
edgecolor = Color(edge_color_spec).to_rgb()
except ValueError as e:
print(e)
edgecolor = "black"

# Plot shapes
if shape == "circle":
radius = size / 2
circle = patches.Circle(
(node.x, node.y),
radius,
facecolor=facecolor,
edgecolor=edgecolor,
linewidth=linewidth,
zorder=2, # Higher z-order to place on top of paths
)
ax.add_patch(circle)
elif shape == "rectangle":
width = height = size
rect = patches.Rectangle(
(node.x - width / 2, node.y - height / 2),
width,
height,
facecolor=facecolor,
edgecolor=edgecolor,
linewidth=linewidth,
zorder=2, # Higher z-order
)
ax.add_patch(rect)
else:
# Default to circle if shape is unknown
radius = size / 2
circle = patches.Circle(
(node.x, node.y),
radius,
facecolor=facecolor,
edgecolor=edgecolor,
linewidth=linewidth,
zorder=2,
)
ax.add_patch(circle)

# Add text inside the shape
if node.content:
ax.text(
node.x,
node.y,
node.content,
fontsize=10,
ha="center",
va="center",
wrap=True,
zorder=3, # Even higher z-order for text
)

# Remove axes, ticks, and legend
ax.axis("off")

# Adjust plot limits
all_x = [node.x for node in nodes]
all_y = [node.y for node in nodes]
padding = 1 # Adjust padding as needed
ax.set_xlim(min(all_x) - padding, max(all_x) + padding)
ax.set_ylim(min(all_y) - padding, max(all_y) + padding)
ax.set_aspect("equal", adjustable="datalim")
# Magic methods
def __str__(self):
return f"Canvas(nrows={self.nrows}, ncols={self.ncols}, figsize={self.figsize})"


if __name__ == "__main__":
Expand Down
21 changes: 11 additions & 10 deletions src/maxplotlib/colors/colors.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,6 @@


class Color:
def __init__(self, color_spec):
"""
Initialize the Color object by parsing the color specification.

Parameters:
- color_spec: Can be a TikZ color string (e.g., 'blue!20'), a standard color name,
an RGB tuple, a hex code, etc.
"""
self.color_spec = color_spec
self.rgb = self._parse_color(color_spec)

def _parse_color(self, color_spec):
"""
Expand Down Expand Up @@ -53,6 +43,17 @@ def _parse_color(self, color_spec):
except ValueError:
raise ValueError(f"Invalid color specification: '{color_spec}'")

def __init__(self, color_spec):
"""
Initialize the Color object by parsing the color specification.

Parameters:
- color_spec: Can be a TikZ color string (e.g., 'blue!20'), a standard color name,
an RGB tuple, a hex code, etc.
"""
self.color_spec = color_spec
self.rgb = self._parse_color(color_spec)

def to_rgb(self):
"""
Return the color as an RGB tuple.
Expand Down
Loading
Loading