Files
gem/plot_lib.py
2026-02-02 16:52:12 +01:00

403 lines
10 KiB
Python

import matplotlib.pyplot as plt
import numpy as np
import os
def plot_xy(x, y, xlabel="x", ylabel="y", title="Plot of y over x"):
    """
    Plot y over x as a single line and show the figure.

    - X limits set exactly to the first and last x values
    - Y limits padded by ±10% of the finite data range (NaN/inf values are
      ignored; a flat series is padded by ±1)
    - Custom x/y labels and title

    Parameters:
        x, y   : non-empty sequences (lists or arrays) of equal length
        xlabel : x axis label
        ylabel : y axis label
        title  : plot title

    Raises:
        ValueError: if x or y is empty, or their lengths differ.
    """
    if len(x) == 0 or len(y) == 0:
        raise ValueError("x and y must not be empty")
    if len(x) != len(y):
        raise ValueError("x and y must have the same length")

    plt.figure(figsize=(8, 5))
    plt.plot(x, y)

    # Pin the x-axis to the first and last sample. With a single sample the
    # two limits coincide (a singular range matplotlib warns about), so let
    # autoscaling handle that case.
    if x[0] != x[-1]:
        plt.xlim(x[0], x[-1])

    # Pad the y-axis using finite values only: builtin min()/max() return NaN
    # when the first element is NaN, and plt.ylim() rejects NaN limits.
    y_vals = np.asarray(y, dtype=float)
    y_vals = y_vals[np.isfinite(y_vals)]
    if y_vals.size:
        ymin, ymax = y_vals.min(), y_vals.max()
        yrange = ymax - ymin
        pad = yrange * 0.10 if yrange != 0 else 1  # avoid zero-range issue
        plt.ylim(ymin - pad, ymax + pad)

    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True)
    plt.show()
def plot_xy_points(x, y, xlabel="x", ylabel="y", title="Plot of y over x"):
    """
    Plot y over x as a line with a dot on every measurement point, then show
    the figure.

    - X limits set exactly to the first and last x values
    - Y limits padded by ±10% of the finite data range (NaN/inf values are
      ignored; a flat series is padded by ±1)
    - Custom x/y labels and title

    Parameters:
        x, y   : non-empty sequences (lists or arrays) of equal length
        xlabel : x axis label
        ylabel : y axis label
        title  : plot title

    Raises:
        ValueError: if x or y is empty, or their lengths differ.
    """
    if len(x) == 0 or len(y) == 0:
        raise ValueError("x and y must not be empty")
    if len(x) != len(y):
        raise ValueError("x and y must have the same length")

    plt.figure(figsize=(8, 5))
    line = plt.plot(x, y)[0]
    # Dots at the measurement values, explicitly in the line's colour rather
    # than relying on scatter's colour cycle staying in step with plot's.
    plt.scatter(x, y, color=line.get_color())

    # Pin the x-axis to the first and last sample; skip a singular range
    # (single sample) and let autoscaling handle it.
    if x[0] != x[-1]:
        plt.xlim(x[0], x[-1])

    # Pad the y-axis using finite values only: builtin min()/max() return NaN
    # when the first element is NaN, and plt.ylim() rejects NaN limits.
    y_vals = np.asarray(y, dtype=float)
    y_vals = y_vals[np.isfinite(y_vals)]
    if y_vals.size:
        ymin, ymax = y_vals.min(), y_vals.max()
        yrange = ymax - ymin
        pad = yrange * 0.10 if yrange != 0 else 1  # avoid zero-range issue
        plt.ylim(ymin - pad, ymax + pad)

    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True)
    plt.show()
def plot_mult_xy(x_list, y_list, legends, xlabel="x", ylabel="y", title="Plot"):
    """
    Plot multiple y-over-x curves in one figure, with a legend, and show it.

    The x-axis spans from the smallest first x to the largest last x over all
    curves. The y-axis is padded by ±10% of the finite range of all curves
    (NaN/inf ignored; a flat range is padded by ±1).

    Parameters:
        x_list  : list of x arrays
        y_list  : list of y arrays (same count as x_list, pairwise same length)
        legends : list of legend labels, one per curve
        xlabel  : x axis label
        ylabel  : y axis label
        title   : plot title

    Raises:
        ValueError: on mismatched counts, no curves at all, or a dataset that
                    is empty or whose x/y lengths differ.
    """
    if len(x_list) != len(y_list):
        raise ValueError("x_list and y_list must have the same number of elements")
    if len(legends) != len(x_list):
        raise ValueError("Number of legends must match number of curves")
    if len(x_list) == 0:
        raise ValueError("x_list and y_list must not be empty")
    # Validate every dataset before creating the figure, so bad input does
    # not leave a half-built figure behind in pyplot's global state.
    for x, y, label in zip(x_list, y_list, legends):
        if len(x) != len(y):
            raise ValueError(f"Length mismatch in dataset '{label}'")
        if len(x) == 0:
            raise ValueError(f"Dataset '{label}' is empty")

    plt.figure(figsize=(8, 5))
    for x, y, label in zip(x_list, y_list, legends):
        plt.plot(x, y, label=label)

    # Global y-axis padding over the finite values of all curves; NaN would
    # otherwise reach plt.ylim(), which rejects non-finite limits.
    all_y = np.concatenate([np.asarray(y, dtype=float).ravel() for y in y_list])
    all_y = all_y[np.isfinite(all_y)]
    if all_y.size:
        ymin, ymax = all_y.min(), all_y.max()
        yrange = ymax - ymin
        pad = yrange * 0.10 if yrange != 0 else 1
        plt.ylim(ymin - pad, ymax + pad)

    # Global x-axis: from smallest first x to largest last x (skipped when
    # the range would be singular).
    x_start = min(x[0] for x in x_list)
    x_end = max(x[-1] for x in x_list)
    if x_start != x_end:
        plt.xlim(x_start, x_end)

    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True)
    plt.legend()
    plt.show()
def plot_mult_xy_points(x_list, y_list, legends, xlabel="x", ylabel="y", title="Plot"):
    """
    Same as plot_mult_xy, but also places a dot at each measured point.

    The x-axis spans from the smallest first x to the largest last x over all
    curves. The y-axis is padded by ±10% of the finite range of all curves
    (NaN/inf ignored; a flat range is padded by ±1).

    Parameters:
        x_list  : list of x arrays
        y_list  : list of y arrays (same count as x_list, pairwise same length)
        legends : list of legend labels, one per curve
        xlabel  : x axis label
        ylabel  : y axis label
        title   : plot title

    Raises:
        ValueError: on mismatched counts, no curves at all, or a dataset that
                    is empty or whose x/y lengths differ.
    """
    if len(x_list) != len(y_list):
        raise ValueError("x_list and y_list must have the same number of elements")
    if len(legends) != len(x_list):
        raise ValueError("Number of legends must match number of curves")
    if len(x_list) == 0:
        raise ValueError("x_list and y_list must not be empty")
    # Validate every dataset before creating the figure, so bad input does
    # not leave a half-built figure behind in pyplot's global state.
    for x, y, label in zip(x_list, y_list, legends):
        if len(x) != len(y):
            raise ValueError(f"Length mismatch in dataset '{label}'")
        if len(x) == 0:
            raise ValueError(f"Dataset '{label}' is empty")

    plt.figure(figsize=(8, 5))
    for x, y, label in zip(x_list, y_list, legends):
        # Line + dots; the dots reuse the line's colour explicitly instead of
        # relying on scatter's colour cycle staying in step with plot's.
        line = plt.plot(x, y, label=label)[0]
        plt.scatter(x, y, color=line.get_color())

    # Y-axis padding over the finite values of all curves; NaN would
    # otherwise reach plt.ylim(), which rejects non-finite limits.
    all_y = np.concatenate([np.asarray(y, dtype=float).ravel() for y in y_list])
    all_y = all_y[np.isfinite(all_y)]
    if all_y.size:
        ymin, ymax = all_y.min(), all_y.max()
        yrange = ymax - ymin
        pad = yrange * 0.10 if yrange != 0 else 1
        plt.ylim(ymin - pad, ymax + pad)

    # X-axis range (skipped when it would be singular).
    x_start = min(x[0] for x in x_list)
    x_end = max(x[-1] for x in x_list)
    if x_start != x_end:
        plt.xlim(x_start, x_end)

    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True)
    plt.legend()
    plt.show()
def plot_shared_xy(
    x,
    y_list,
    labels=None,
    xlabel="X",
    ylabel="Y",
    title="Plot",
    show_points=False,
    figsize=(10, 6),
    linewidth=2,
    marker_size=60,
    log_x=False,
    log_y=False,
    save_path=None,
):
    """
    Plot multiple y-arrays over the same x-array using shared axes.

    Supports optional logarithmic axes, padding, and SVG saving.

    Parameters:
        x           : x values shared by every curve (non-empty)
        y_list      : non-empty list of y sequences, each len(x) long
        labels      : optional legend labels, one per curve; None or an empty
                      sequence means "no legend"
        xlabel, ylabel, title : axis labels and plot title
        show_points : also draw a dot at every sample
        figsize, linewidth, marker_size : matplotlib styling
        log_x, log_y: use logarithmic x / y axes
        save_path   : if given (non-blank), also save the figure as SVG; ".svg"
                      is appended when missing. If the file already exists the
                      user is asked on stdin to overwrite it or enter another
                      name (re-asked until the name is free or overwrite is
                      confirmed).

    X limits are x[0]..x[-1]. Y limits are padded by 10% of the finite data
    range; on a log y-axis the padding is measured in decades over the
    positive values, so the limits stay positive.

    Raises:
        ValueError: on empty x or y_list, a y length mismatch, or a labels
                    count that does not match y_list.
    """
    x = np.asarray(x)
    if len(x) == 0:
        raise ValueError("x must not be empty")
    if len(y_list) == 0:
        raise ValueError("y_list must contain at least one series")
    # Validate y lengths
    for idx, y in enumerate(y_list):
        if len(y) != len(x):
            raise ValueError(f"y_list[{idx}] does not match length of x")
    # An empty labels sequence keeps meaning "no legend"; otherwise it must
    # provide exactly one label per curve.
    has_labels = labels is not None and len(labels) > 0
    if has_labels and len(labels) != len(y_list):
        raise ValueError("Number of labels must match number of curves")

    plt.figure(figsize=figsize)
    for idx, y in enumerate(y_list):
        y = np.asarray(y)
        label = labels[idx] if has_labels else None
        line = plt.plot(x, y, linewidth=linewidth, label=label)[0]
        if show_points:
            # Dots in the line's colour, independent of scatter's own cycle.
            plt.scatter(x, y, s=marker_size, color=line.get_color())

    # Apply the scales before the limits, so limits are validated against
    # (and padded for) the final axis type.
    if log_x:
        plt.xscale("log")
    if log_y:
        plt.yscale("log")

    # X-axis range (skipped when it would be singular, i.e. one sample).
    if x[0] != x[-1]:
        plt.xlim(x[0], x[-1])

    # Y-axis padding over finite values only; np.min/np.max would propagate
    # NaN into plt.ylim(), which rejects non-finite limits.
    all_y = np.concatenate([np.asarray(y, dtype=float).ravel() for y in y_list])
    all_y = all_y[np.isfinite(all_y)]
    if log_y:
        all_y = all_y[all_y > 0]  # only positive values exist on a log axis
    if all_y.size:
        ymin, ymax = all_y.min(), all_y.max()
        if log_y:
            # Linear padding would often push the lower limit to <= 0.
            lo, hi = np.log10(ymin), np.log10(ymax)
            pad = (hi - lo) * 0.10 if hi != lo else 1
            plt.ylim(10 ** (lo - pad), 10 ** (hi + pad))
        else:
            yrange = ymax - ymin
            pad = yrange * 0.10 if yrange != 0 else 1
            plt.ylim(ymin - pad, ymax + pad)

    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    if has_labels:
        plt.legend()
    plt.grid(True, which="both")
    plt.tight_layout()

    # === SVG Saving Logic ===============================
    if save_path is not None and str(save_path).strip() != "":
        save_path = str(save_path)
        # Ensure .svg extension
        if not save_path.lower().endswith(".svg"):
            save_path += ".svg"
        # Re-check every candidate: the replacement name may exist too, and a
        # blank answer would otherwise save to a file literally named ".svg".
        while os.path.exists(save_path):
            print(f"File '{save_path}' already exists.")
            choice = input("Overwrite? (y/n): ").strip().lower()
            if choice == "y":
                break
            new_path = input("Enter new filename (with or without .svg): ").strip()
            if not new_path:
                continue
            if not new_path.lower().endswith(".svg"):
                new_path += ".svg"
            save_path = new_path
        plt.savefig(save_path, format="svg")
        print(f"Saved SVG to: {save_path}")
    # =====================================================

    plt.show()
def plot_difference(
    x,
    y1,
    y2,
    absolute=False,
    return_sum=False,  # Return sum instead of array
    xlabel="X",
    ylabel="Difference",
    title="Difference Plot",
    show_points=False,
    figsize=(10, 6),
    linewidth=2,
    marker_size=60,
    log_x=False,
    log_y=False,
    save_path=None,
):
    """
    Plot the difference y2 - y1 between two y-series over the same x.

    Optionally compute the absolute difference, and optionally return the
    sum of the difference values instead of the full array.
    NOTE: This version intentionally provides NO LEGEND.

    Parameters:
        x, y1, y2   : non-empty sequences of equal length
        absolute    : use |y2 - y1| instead of y2 - y1
        return_sum  : return np.sum(diff) instead of the diff array
        xlabel, ylabel, title : axis labels and plot title
        show_points : also draw a dot at every sample
        figsize, linewidth, marker_size : matplotlib styling
        log_x, log_y: use logarithmic x / y axes (on a log y-axis the limits
                      are derived from the positive differences only)
        save_path   : if given (non-blank), also save as SVG (".svg" appended
                      when missing); an existing file triggers a stdin prompt
                      to overwrite or enter another name, re-asked until the
                      name is free or overwrite is confirmed.

    Returns:
        The difference array, or its sum when return_sum is True. It is
        returned after plt.show() returns.

    Raises:
        ValueError: if x is empty or y1/y2 do not match the length of x.
    """
    x = np.asarray(x)
    y1 = np.asarray(y1)
    y2 = np.asarray(y2)
    if len(x) == 0:
        raise ValueError("x must not be empty")
    if len(y1) != len(x) or len(y2) != len(x):
        raise ValueError("y1 and y2 must match the length of x")

    diff = y2 - y1
    if absolute:
        diff = np.abs(diff)

    # === Plotting ===
    plt.figure(figsize=figsize)
    line = plt.plot(x, diff, linewidth=linewidth)[0]
    if show_points:
        # Dots in the line's colour, independent of scatter's own cycle.
        plt.scatter(x, diff, s=marker_size, color=line.get_color())

    # Apply the scales before the limits, so limits are validated against
    # (and padded for) the final axis type.
    if log_x:
        plt.xscale("log")
    if log_y:
        plt.yscale("log")

    # X limits (skipped when singular, i.e. a single sample).
    if x[0] != x[-1]:
        plt.xlim(x[0], x[-1])

    # Y padding over finite values only; np.min/np.max would propagate NaN
    # into plt.ylim(), which rejects non-finite limits.
    d_vals = np.asarray(diff, dtype=float).ravel()
    d_vals = d_vals[np.isfinite(d_vals)]
    if log_y:
        d_vals = d_vals[d_vals > 0]  # only positive values exist on a log axis
    if d_vals.size:
        dmin, dmax = d_vals.min(), d_vals.max()
        if log_y:
            # Linear padding would often push the lower limit to <= 0.
            lo, hi = np.log10(dmin), np.log10(dmax)
            pad = (hi - lo) * 0.10 if hi != lo else 1
            plt.ylim(10 ** (lo - pad), 10 ** (hi + pad))
        else:
            drange = dmax - dmin
            pad = drange * 0.10 if drange != 0 else 1
            plt.ylim(dmin - pad, dmax + pad)

    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True, which="both")
    plt.tight_layout()

    # === SVG Saving Logic ===
    if save_path is not None and str(save_path).strip() != "":
        save_path = str(save_path)
        if not save_path.lower().endswith(".svg"):
            save_path += ".svg"
        # Re-check every candidate: the replacement name may exist too, and a
        # blank answer would otherwise save to a file literally named ".svg".
        while os.path.exists(save_path):
            print(f"File '{save_path}' already exists.")
            choice = input("Overwrite? (y/n): ").strip().lower()
            if choice == "y":
                break
            new_path = input(
                "Enter new filename (with or without .svg): "
            ).strip()
            if not new_path:
                continue
            if not new_path.lower().endswith(".svg"):
                new_path += ".svg"
            save_path = new_path
        plt.savefig(save_path, format="svg")
        print(f"Saved SVG to: {save_path}")
    # ===========================

    plt.show()

    # === Return requested output ===
    if return_sum:
        return np.sum(diff)
    return diff
if __name__ == "__main__":
    # Demo: two curves over a shared x-axis with a dot at every sample, also
    # saved as my_plot.svg. If that file already exists, plot_shared_xy asks
    # on stdin whether to overwrite it or use another name.
    plot_shared_xy(
        x=[0, 1, 2, 3],
        y_list=[[10, 20, 18, 30], [12, 21, 17, 28]],
        labels=["A", "B"],
        show_points=True,
        save_path="my_plot.svg",
    )

    # More examples (uncomment to try):
    # plot_xy_points([1, 2, 3, 4, 5], [2, 3, 5, 7, 11],
    #                "X-axis", "Y-axis", "Sample Plot")
    #
    # plot_mult_xy_points(
    #     [[0, 1, 2, 3], [0, 1, 2, 3]],
    #     [[2, 5, 3, 6], [1, 4, 2, 7]],
    #     legends=["Sensor A", "Sensor B"],
    #     xlabel="Time (s)",
    #     ylabel="Value",
    #     title="Line Plot",
    # )