# Plotting helpers built on matplotlib (original listing: 403 lines, 10 KiB, Python).
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import os
|
|
|
|
|
|
def plot_xy(x, y, xlabel="x", ylabel="y", title="Plot of y over x"):
    """Draw a single line of ``y`` against ``x`` and display it.

    Axis limits:
        * x runs exactly from ``x[0]`` to ``x[-1]``;
        * y covers the data range widened by 10% on each side, or by 1 unit
          when every y value is identical (so the axis never collapses).

    Parameters:
        x : sequence of x values
        y : sequence of y values, same length as ``x``
        xlabel : x axis label
        ylabel : y axis label
        title : plot title

    Raises:
        ValueError: if either sequence is empty or their lengths differ.
    """
    if not len(x) or not len(y):
        raise ValueError("x and y must not be empty")
    if len(x) != len(y):
        raise ValueError("x and y must have the same length")

    plt.figure(figsize=(8, 5))
    ax = plt.gca()
    ax.plot(x, y)

    # x-axis pinned to the first and last sample
    ax.set_xlim(x[0], x[-1])

    low = min(y)
    high = max(y)
    span = high - low
    if span != 0:
        margin = span * 0.10
    else:
        margin = 1  # flat data: keep a non-zero-height y-axis
    ax.set_ylim(low - margin, high + margin)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True)

    plt.show()
|
|
|
|
def plot_xy_points(x, y, xlabel="x", ylabel="y", title="Plot of y over x"):
    """Show ``y`` against ``x`` as a connecting line plus a dot per sample.

    Layout rules:
        * the x-axis runs exactly from ``x[0]`` to ``x[-1]``;
        * the y-axis covers the data range plus 10% on each side, or +/- 1
          when every y value is the same;
        * axis labels and title come from ``xlabel``, ``ylabel`` and ``title``.

    Raises:
        ValueError: if ``x`` or ``y`` is empty, or their lengths differ.
    """
    if not (len(x) and len(y)):
        raise ValueError("x and y must not be empty")
    if len(x) != len(y):
        raise ValueError("x and y must have the same length")

    plt.figure(figsize=(8, 5))
    plt.plot(x, y)  # line through the samples
    plt.scatter(x, y)  # marker on each sample
    plt.xlim(x[0], x[-1])

    y_low = min(y)
    y_high = max(y)
    y_span = y_high - y_low
    # A flat series would give a zero-height axis, so fall back to +/- 1.
    y_pad = 1 if y_span == 0 else y_span * 0.10
    plt.ylim(y_low - y_pad, y_high + y_pad)

    text_setters = (
        (plt.xlabel, xlabel),
        (plt.ylabel, ylabel),
        (plt.title, title),
    )
    for setter, text in text_setters:
        setter(text)

    plt.grid(True)
    plt.show()
|
|
|
|
def plot_mult_xy(x_list, y_list, legends, xlabel="x", ylabel="y", title="Plot"):
    """
    Plots multiple y-over-x curves using matplotlib.

    Parameters:
        x_list : list of x arrays, one per curve
        y_list : list of y arrays, one per curve (same length as its x array)
        legends : list of legend labels, one per curve
        xlabel : x axis label
        ylabel : y axis label
        title : plot title

    The x-axis runs from the smallest first x value to the largest last x
    value over all curves; the y-axis covers the global y range padded by 10%
    on each side (by 1 when all y values are equal).

    Raises:
        ValueError: if the list lengths disagree, no curves are given, or a
            dataset is empty or has mismatched x/y lengths.
    """
    if len(x_list) != len(y_list):
        raise ValueError("x_list and y_list must have the same number of elements")
    if len(legends) != len(x_list):
        raise ValueError("Number of legends must match number of curves")
    if len(x_list) == 0:
        raise ValueError("At least one curve is required")

    # Validate every dataset before creating the figure, so an error never
    # leaves a partially drawn figure open.
    for x, y, label in zip(x_list, y_list, legends):
        if len(x) != len(y):
            raise ValueError(f"Length mismatch in dataset '{label}'")
        if len(x) == 0:
            raise ValueError(f"Dataset '{label}' is empty")

    plt.figure(figsize=(8, 5))

    all_y = []  # every y value, used for the global y range
    for x, y, label in zip(x_list, y_list, legends):
        plt.plot(x, y, label=label)
        all_y.extend(y)

    # Global y-axis padding
    ymin, ymax = min(all_y), max(all_y)
    yrange = ymax - ymin
    pad = yrange * 0.10 if yrange != 0 else 1  # avoid a zero-height axis
    plt.ylim(ymin - pad, ymax + pad)

    # Global x-axis: from smallest first x to largest last x
    x_start = min(x[0] for x in x_list)
    x_end = max(x[-1] for x in x_list)
    plt.xlim(x_start, x_end)

    # Labels, grid, legend
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True)
    plt.legend()
    plt.show()
|
|
|
|
|
|
|
|
def plot_mult_xy_points(x_list, y_list, legends, xlabel="x", ylabel="y", title="Plot"):
    """
    Same as plot_mult_xy, but also places dots at each measured point.

    Parameters:
        x_list : list of x arrays, one per curve
        y_list : list of y arrays, one per curve (same length as its x array)
        legends : list of legend labels, one per curve
        xlabel : x axis label
        ylabel : y axis label
        title : plot title

    Each curve's dots are drawn in the same colour as its line. The x-axis
    runs from the smallest first x value to the largest last x value over all
    curves; the y-axis covers the global y range padded by 10% on each side
    (by 1 when all y values are equal).

    Raises:
        ValueError: if the list lengths disagree, no curves are given, or a
            dataset is empty or has mismatched x/y lengths.
    """
    if len(x_list) != len(y_list):
        raise ValueError("x_list and y_list must have the same number of elements")
    if len(legends) != len(x_list):
        raise ValueError("Number of legends must match number of curves")
    if len(x_list) == 0:
        raise ValueError("At least one curve is required")

    # Validate every dataset before creating the figure, so an error never
    # leaves a partially drawn figure open.
    for x, y, label in zip(x_list, y_list, legends):
        if len(x) != len(y):
            raise ValueError(f"Length mismatch in dataset '{label}'")
        if len(x) == 0:
            raise ValueError(f"Dataset '{label}' is empty")

    plt.figure(figsize=(8, 5))

    all_y = []  # every y value, used for the global y range
    for x, y, label in zip(x_list, y_list, legends):
        # Line + dots; pass the line's colour to the dots explicitly instead of
        # relying on scatter's own colour cycle staying in step with plot's.
        (line,) = plt.plot(x, y, label=label)
        plt.scatter(x, y, color=line.get_color())
        all_y.extend(y)

    # Y-axis padding
    ymin, ymax = min(all_y), max(all_y)
    yrange = ymax - ymin
    pad = yrange * 0.10 if yrange != 0 else 1  # avoid a zero-height axis
    plt.ylim(ymin - pad, ymax + pad)

    # X-axis range
    x_start = min(x[0] for x in x_list)
    x_end = max(x[-1] for x in x_list)
    plt.xlim(x_start, x_end)

    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True)
    plt.legend()
    plt.show()
|
|
|
|
|
|
def plot_shared_xy(
    x,
    y_list,
    labels=None,
    xlabel="X",
    ylabel="Y",
    title="Plot",
    show_points=False,
    figsize=(10, 6),
    linewidth=2,
    marker_size=60,
    log_x=False,
    log_y=False,
    save_path=None,
):
    """
    Plot multiple y-arrays over the same x-array using shared axes.

    Supports optional logarithmic axes, padding, and SVG saving.

    Parameters:
        x : x values shared by every curve
        y_list : list of y arrays, each the same length as ``x``
        labels : optional list of legend labels (at least one per curve); the
            legend is only drawn when this is non-empty
        xlabel, ylabel, title : axis labels and plot title
        show_points : also draw a marker at every sample, in the line's colour
        figsize : figure size passed to ``plt.figure``
        linewidth : line width of every curve
        marker_size : marker area (``s``) used when ``show_points`` is true
        log_x, log_y : use logarithmic x / y axes
        save_path : if given (and not blank), save the figure as SVG; ".svg" is
            appended when missing, and if the file exists the user is asked on
            stdin whether to overwrite it or choose another name

    Axis limits:
        x spans exactly ``x[0]`` to ``x[-1]``. y spans the global data range
        padded by 10% on each side (1 unit if all values are equal); on a log
        y-axis with strictly positive data the padding is 10% of the range in
        decades instead, so the lower limit stays positive.

    Raises:
        ValueError: if ``x`` or ``y_list`` is empty, a y-array's length differs
            from ``x``, or ``labels`` has fewer entries than ``y_list``.
    """
    x = np.asarray(x)
    ys = [np.asarray(y) for y in y_list]  # materialise once; iterated several times

    # Validate everything before a figure exists, so errors leave no open figure.
    if not ys:
        raise ValueError("y_list must contain at least one y array")
    for idx, y in enumerate(ys):
        if len(y) != len(x):
            raise ValueError(f"y_list[{idx}] does not match length of x")
    if len(x) == 0:
        raise ValueError("x must not be empty")
    if labels and len(labels) < len(ys):
        raise ValueError("labels must provide one entry per y array")

    plt.figure(figsize=figsize)

    # Plot each curve
    for idx, y in enumerate(ys):
        label = labels[idx] if labels else None
        (line,) = plt.plot(x, y, linewidth=linewidth, label=label)
        if show_points:
            # Tie the dots to their line's colour explicitly.
            plt.scatter(x, y, s=marker_size, color=line.get_color())

    # Set scales before limits so the limits are applied on the final scale.
    if log_x:
        plt.xscale("log")
    if log_y:
        plt.yscale("log")

    # X-axis range
    plt.xlim(x[0], x[-1])

    # Y-axis padding
    all_y = np.concatenate(ys)
    ymin, ymax = np.min(all_y), np.max(all_y)
    if log_y and ymin > 0:
        # Pad in decades: a linear 10% pad can push the lower limit to <= 0,
        # which a log axis cannot display.
        lo, hi = np.log10(ymin), np.log10(ymax)
        pad = (hi - lo) * 0.10 if hi != lo else 1
        plt.ylim(10 ** (lo - pad), 10 ** (hi + pad))
    else:
        yrange = ymax - ymin
        pad = yrange * 0.10 if yrange != 0 else 1
        plt.ylim(ymin - pad, ymax + pad)

    # Axis labels, title, legend, grid
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    if labels:
        plt.legend()
    plt.grid(True, which="both")

    plt.tight_layout()

    # === SVG Saving Logic ===============================
    if save_path is not None and str(save_path).strip() != "":
        save_path = str(save_path).strip()

        # Ensure .svg extension
        if not save_path.lower().endswith(".svg"):
            save_path += ".svg"

        # Re-check every candidate name: a replacement name may itself exist,
        # and an empty answer must not turn into a hidden ".svg" file.
        while os.path.exists(save_path):
            print(f"File '{save_path}' already exists.")
            choice = input("Overwrite? (y/n): ").strip().lower()
            if choice == "y":
                break
            new_path = input("Enter new filename (with or without .svg): ").strip()
            if not new_path:
                continue
            if not new_path.lower().endswith(".svg"):
                new_path += ".svg"
            save_path = new_path

        plt.savefig(save_path, format="svg")
        print(f"Saved SVG to: {save_path}")
    # =====================================================

    plt.show()
|
|
|
|
def plot_difference(
    x,
    y1,
    y2,
    absolute=False,
    return_sum=False,  # Return sum instead of array
    xlabel="X",
    ylabel="Difference",
    title="Difference Plot",
    show_points=False,
    figsize=(10, 6),
    linewidth=2,
    marker_size=60,
    log_x=False,
    log_y=False,
    save_path=None,
):
    """
    Plot the difference ``y2 - y1`` between two y-series over the same x.

    Optionally compute the absolute difference and optionally return the sum
    of the difference values instead of the full array.

    NOTE: This version intentionally provides NO LEGEND.

    Parameters:
        x : x values shared by both series
        y1, y2 : y values, each the same length as ``x``
        absolute : use ``|y2 - y1|`` instead of ``y2 - y1``
        return_sum : return ``np.sum`` of the difference instead of the array
        xlabel, ylabel, title : axis labels and plot title
        show_points : also draw a marker at every sample, in the line's colour
        figsize : figure size passed to ``plt.figure``
        linewidth : line width of the difference curve
        marker_size : marker area (``s``) used when ``show_points`` is true
        log_x, log_y : use logarithmic x / y axes
        save_path : if given (and not blank), save the figure as SVG; ".svg" is
            appended when missing, and if the file exists the user is asked on
            stdin whether to overwrite it or choose another name

    Returns:
        The difference array, or its sum when ``return_sum`` is true.
        Unsigned-integer and boolean inputs are promoted to a signed dtype
        before subtracting.

    Raises:
        ValueError: if ``y1``/``y2`` do not match the length of ``x`` or the
            inputs are empty.
    """
    # Convert to arrays
    x = np.asarray(x)
    y1 = np.asarray(y1)
    y2 = np.asarray(y2)

    # Validate before any figure exists, so errors leave no open figure.
    if len(y1) != len(x) or len(y2) != len(x):
        raise ValueError("y1 and y2 must match the length of x")
    if len(x) == 0:
        raise ValueError("x, y1 and y2 must not be empty")

    # Unsigned subtraction wraps around on negative results and boolean
    # subtraction raises TypeError, so promote those dtypes to a signed type
    # (int64, or float64 for uint64 which int64 cannot hold).
    y1, y2 = [
        a.astype(np.result_type(a.dtype, np.int64)) if a.dtype.kind in "ub" else a
        for a in (y1, y2)
    ]

    # Compute difference
    diff = y2 - y1
    if absolute:
        diff = np.abs(diff)

    # === Plotting ===
    plt.figure(figsize=figsize)
    (line,) = plt.plot(x, diff, linewidth=linewidth)
    if show_points:
        plt.scatter(x, diff, s=marker_size, color=line.get_color())

    # Set scales before limits so the limits are applied on the final scale.
    if log_x:
        plt.xscale("log")
    if log_y:
        plt.yscale("log")

    # X limits
    plt.xlim(x[0], x[-1])

    # Y padding
    ymin, ymax = np.min(diff), np.max(diff)
    if log_y and ymin > 0:
        # Pad in decades: a linear 10% pad can push the lower limit to <= 0,
        # which a log axis cannot display.
        lo, hi = np.log10(ymin), np.log10(ymax)
        pad = (hi - lo) * 0.10 if hi != lo else 1
        plt.ylim(10 ** (lo - pad), 10 ** (hi + pad))
    else:
        yrange = ymax - ymin
        pad = yrange * 0.10 if yrange != 0 else 1
        plt.ylim(ymin - pad, ymax + pad)

    # Labels & styling
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True, which="both")

    plt.tight_layout()

    # === SVG Saving Logic ===
    if save_path is not None and str(save_path).strip() != "":
        save_path = str(save_path).strip()

        if not save_path.lower().endswith(".svg"):
            save_path += ".svg"

        # Re-check every candidate name: a replacement name may itself exist,
        # and an empty answer must not turn into a hidden ".svg" file.
        while os.path.exists(save_path):
            print(f"File '{save_path}' already exists.")
            choice = input("Overwrite? (y/n): ").strip().lower()
            if choice == "y":
                break
            new_path = input("Enter new filename (with or without .svg): ").strip()
            if not new_path:
                continue
            if not new_path.lower().endswith(".svg"):
                new_path += ".svg"
            save_path = new_path

        plt.savefig(save_path, format="svg")
        print(f"Saved SVG to: {save_path}")
    # ===========================

    plt.show()

    # === Return requested output ===
    if return_sum:
        return np.sum(diff)
    return diff
|
|
|
|
|
|
if __name__ == "__main__":
    # Example usage. The other helpers can be tried the same way, e.g.:
    #   plot_xy_points([1, 2, 3, 4, 5], [2, 3, 5, 7, 11], "X-axis", "Y-axis", "Sample Plot")
    #   plot_mult_xy_points(
    #       [[0, 1, 2, 3], [0, 1, 2, 3]],
    #       [[2, 5, 3, 6], [1, 4, 2, 7]],
    #       legends=["Sensor A", "Sensor B"],
    #       xlabel="Time (s)",
    #       ylabel="Value",
    #       title="Line Plot",
    #   )
    x = [0, 1, 2, 3]
    y_series = [[10, 20, 18, 30], [12, 21, 17, 28]]

    plot_shared_xy(
        x=x,
        y_list=y_series,
        labels=["A", "B"],
        show_points=True,
        save_path="my_plot.svg",
    )
|
|
|