Added done Work
This commit is contained in:
402
plot_lib.py
Normal file
402
plot_lib.py
Normal file
@@ -0,0 +1,402 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
|
||||
def plot_xy(x, y, xlabel="x", ylabel="y", title="Plot of y over x"):
|
||||
"""
|
||||
Plots y over x using matplotlib with:
|
||||
- X limits set exactly to the first and last x values
|
||||
- Y limits padded by ±10% of the data range
|
||||
- Custom x/y labels and title
|
||||
"""
|
||||
if len(x) == 0 or len(y) == 0:
|
||||
raise ValueError("x and y must not be empty")
|
||||
if len(x) != len(y):
|
||||
raise ValueError("x and y must have the same length")
|
||||
|
||||
plt.figure(figsize=(8, 5))
|
||||
plt.plot(x, y)
|
||||
|
||||
# Set x-axis exactly to first and last value
|
||||
plt.xlim(x[0], x[-1])
|
||||
|
||||
# Compute y padding
|
||||
ymin, ymax = min(y), max(y)
|
||||
yrange = ymax - ymin
|
||||
pad = yrange * 0.10 if yrange != 0 else 1 # avoid zero-range issue
|
||||
|
||||
plt.ylim(ymin - pad, ymax + pad)
|
||||
|
||||
# Labels and title
|
||||
plt.xlabel(xlabel)
|
||||
plt.ylabel(ylabel)
|
||||
plt.title(title)
|
||||
|
||||
plt.grid(True)
|
||||
plt.show()
|
||||
|
||||
def plot_xy_points(x, y, xlabel="x", ylabel="y", title="Plot of y over x"):
|
||||
"""
|
||||
Plots y over x using matplotlib with:
|
||||
- A line plot + dots on each measurement point
|
||||
- X limits set exactly to the first and last x values
|
||||
- Y limits padded by ±10% of the data range
|
||||
- Custom x/y labels and title
|
||||
"""
|
||||
if len(x) == 0 or len(y) == 0:
|
||||
raise ValueError("x and y must not be empty")
|
||||
if len(x) != len(y):
|
||||
raise ValueError("x and y must have the same length")
|
||||
|
||||
plt.figure(figsize=(8, 5))
|
||||
|
||||
# Plot line
|
||||
plt.plot(x, y)
|
||||
|
||||
# Plot dots at measurement values
|
||||
plt.scatter(x, y)
|
||||
|
||||
# Set x limits exactly to first and last value
|
||||
plt.xlim(x[0], x[-1])
|
||||
|
||||
# Compute y padding
|
||||
ymin, ymax = min(y), max(y)
|
||||
yrange = ymax - ymin
|
||||
pad = yrange * 0.10 if yrange != 0 else 1 # avoid zero-range issue
|
||||
|
||||
plt.ylim(ymin - pad, ymax + pad)
|
||||
|
||||
# Labels and title
|
||||
plt.xlabel(xlabel)
|
||||
plt.ylabel(ylabel)
|
||||
plt.title(title)
|
||||
|
||||
plt.grid(True)
|
||||
plt.show()
|
||||
|
||||
def plot_mult_xy(x_list, y_list, legends, xlabel="x", ylabel="y", title="Plot"):
|
||||
"""
|
||||
Plots multiple y-over-x curves using matplotlib.
|
||||
|
||||
Parameters:
|
||||
x_list : list of x arrays
|
||||
y_list : list of y arrays
|
||||
legends : list of legend labels
|
||||
xlabel : x axis label
|
||||
ylabel : y axis label
|
||||
title : plot title
|
||||
"""
|
||||
|
||||
if len(x_list) != len(y_list):
|
||||
raise ValueError("x_list and y_list must have the same number of elements")
|
||||
if len(legends) != len(x_list):
|
||||
raise ValueError("Number of legends must match number of curves")
|
||||
|
||||
plt.figure(figsize=(8, 5))
|
||||
|
||||
# Track global y range
|
||||
all_y = []
|
||||
|
||||
for x, y, label in zip(x_list, y_list, legends):
|
||||
if len(x) != len(y):
|
||||
raise ValueError(f"Length mismatch in dataset '{label}'")
|
||||
|
||||
plt.plot(x, y, label=label)
|
||||
all_y.extend(y)
|
||||
|
||||
# Global y-axis padding
|
||||
ymin, ymax = min(all_y), max(all_y)
|
||||
yrange = ymax - ymin
|
||||
pad = yrange * 0.10 if yrange != 0 else 1
|
||||
plt.ylim(ymin - pad, ymax + pad)
|
||||
|
||||
# Global x-axis: from smallest first x to largest last x
|
||||
x_start = min(x[0] for x in x_list)
|
||||
x_end = max(x[-1] for x in x_list)
|
||||
plt.xlim(x_start, x_end)
|
||||
|
||||
# Labels, grid, legend
|
||||
plt.xlabel(xlabel)
|
||||
plt.ylabel(ylabel)
|
||||
plt.title(title)
|
||||
plt.grid(True)
|
||||
plt.legend()
|
||||
plt.show()
|
||||
|
||||
|
||||
|
||||
def plot_mult_xy_points(x_list, y_list, legends, xlabel="x", ylabel="y", title="Plot"):
|
||||
"""
|
||||
Same as plot_xy, but also places dots at each measured point.
|
||||
"""
|
||||
|
||||
if len(x_list) != len(y_list):
|
||||
raise ValueError("x_list and y_list must have the same number of elements")
|
||||
if len(legends) != len(x_list):
|
||||
raise ValueError("Number of legends must match number of curves")
|
||||
|
||||
plt.figure(figsize=(8, 5))
|
||||
|
||||
all_y = []
|
||||
|
||||
for x, y, label in zip(x_list, y_list, legends):
|
||||
if len(x) != len(y):
|
||||
raise ValueError(f"Length mismatch in dataset '{label}'")
|
||||
|
||||
# Line + dots
|
||||
plt.plot(x, y, label=label)
|
||||
plt.scatter(x, y)
|
||||
|
||||
all_y.extend(y)
|
||||
|
||||
# Y-axis padding
|
||||
ymin, ymax = min(all_y), max(all_y)
|
||||
yrange = ymax - ymin
|
||||
pad = yrange * 0.10 if yrange != 0 else 1
|
||||
plt.ylim(ymin - pad, ymax + pad)
|
||||
|
||||
# X-axis range
|
||||
x_start = min(x[0] for x in x_list)
|
||||
x_end = max(x[-1] for x in x_list)
|
||||
plt.xlim(x_start, x_end)
|
||||
|
||||
plt.xlabel(xlabel)
|
||||
plt.ylabel(ylabel)
|
||||
plt.title(title)
|
||||
plt.grid(True)
|
||||
plt.legend()
|
||||
plt.show()
|
||||
|
||||
|
||||
def plot_shared_xy(
|
||||
x,
|
||||
y_list,
|
||||
labels=None,
|
||||
xlabel="X",
|
||||
ylabel="Y",
|
||||
title="Plot",
|
||||
show_points=False,
|
||||
figsize=(10,6),
|
||||
linewidth=2,
|
||||
marker_size=60,
|
||||
log_x=False,
|
||||
log_y=False,
|
||||
save_path=None # <-- NEW
|
||||
):
|
||||
"""
|
||||
Plot multiple y-arrays over the same x-array using shared axes.
|
||||
Supports optional logarithmic axes, padding, and SVG saving.
|
||||
"""
|
||||
|
||||
x = np.asarray(x)
|
||||
|
||||
# Validate y lengths
|
||||
for idx, y in enumerate(y_list):
|
||||
if len(y) != len(x):
|
||||
raise ValueError(f"y_list[{idx}] does not match length of x")
|
||||
|
||||
plt.figure(figsize=figsize)
|
||||
|
||||
# Plot each curve
|
||||
for idx, y in enumerate(y_list):
|
||||
y = np.asarray(y)
|
||||
label = labels[idx] if labels else None
|
||||
plt.plot(x, y, linewidth=linewidth, label=label)
|
||||
if show_points:
|
||||
plt.scatter(x, y, s=marker_size)
|
||||
|
||||
# X-axis range
|
||||
plt.xlim(x[0], x[-1])
|
||||
|
||||
# Y-axis padding
|
||||
all_y = np.concatenate([np.asarray(y) for y in y_list])
|
||||
ymin, ymax = np.min(all_y), np.max(all_y)
|
||||
yrange = ymax - ymin
|
||||
pad = yrange * 0.10 if yrange != 0 else 1
|
||||
plt.ylim(ymin - pad, ymax + pad)
|
||||
|
||||
# Axis labels
|
||||
plt.xlabel(xlabel)
|
||||
plt.ylabel(ylabel)
|
||||
|
||||
# Optional log scaling
|
||||
if log_x:
|
||||
plt.xscale("log")
|
||||
if log_y:
|
||||
plt.yscale("log")
|
||||
|
||||
# Title / legend / grid
|
||||
plt.title(title)
|
||||
if labels:
|
||||
plt.legend()
|
||||
plt.grid(True, which="both")
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
# === SVG Saving Logic ===============================
|
||||
if save_path is not None and str(save_path).strip() != "":
|
||||
save_path = str(save_path)
|
||||
|
||||
# Ensure .svg extension
|
||||
if not save_path.lower().endswith(".svg"):
|
||||
save_path += ".svg"
|
||||
|
||||
# If file exists, ask for overwrite or rename
|
||||
if os.path.exists(save_path):
|
||||
print(f"File '{save_path}' already exists.")
|
||||
choice = input("Overwrite? (y/n): ").strip().lower()
|
||||
|
||||
if choice != "y":
|
||||
new_path = input("Enter new filename (with or without .svg): ").strip()
|
||||
if not new_path.lower().endswith(".svg"):
|
||||
new_path += ".svg"
|
||||
save_path = new_path
|
||||
|
||||
# Save the file now
|
||||
plt.savefig(save_path, format="svg")
|
||||
print(f"Saved SVG to: {save_path}")
|
||||
# =====================================================
|
||||
|
||||
plt.show()
|
||||
|
||||
def plot_difference(
|
||||
x,
|
||||
y1,
|
||||
y2,
|
||||
absolute=False,
|
||||
return_sum=False, # Return sum instead of array
|
||||
xlabel="X",
|
||||
ylabel="Difference",
|
||||
title="Difference Plot",
|
||||
show_points=False,
|
||||
figsize=(10,6),
|
||||
linewidth=2,
|
||||
marker_size=60,
|
||||
log_x=False,
|
||||
log_y=False,
|
||||
save_path=None
|
||||
):
|
||||
"""
|
||||
Plot the difference between two y-series over the same x.
|
||||
Optionally compute absolute difference and optionally return the sum
|
||||
of the difference values instead of the full array.
|
||||
|
||||
NOTE: This version intentionally provides NO LEGEND.
|
||||
"""
|
||||
|
||||
# Convert to arrays
|
||||
x = np.asarray(x)
|
||||
y1 = np.asarray(y1)
|
||||
y2 = np.asarray(y2)
|
||||
|
||||
# Validate size
|
||||
if len(y1) != len(x) or len(y2) != len(x):
|
||||
raise ValueError("y1 and y2 must match the length of x")
|
||||
|
||||
# Compute difference
|
||||
diff = y2 - y1
|
||||
if absolute:
|
||||
diff = np.abs(diff)
|
||||
|
||||
# === Plotting ===
|
||||
plt.figure(figsize=figsize)
|
||||
|
||||
plt.plot(
|
||||
x,
|
||||
diff,
|
||||
linewidth=linewidth
|
||||
)
|
||||
|
||||
if show_points:
|
||||
plt.scatter(x, diff, s=marker_size)
|
||||
|
||||
# X limits
|
||||
plt.xlim(x[0], x[-1])
|
||||
|
||||
# Y padding
|
||||
ymin, ymax = np.min(diff), np.max(diff)
|
||||
yrange = ymax - ymin
|
||||
pad = yrange * 0.10 if yrange != 0 else 1
|
||||
plt.ylim(ymin - pad, ymax + pad)
|
||||
|
||||
# Labels & styling
|
||||
plt.xlabel(xlabel)
|
||||
plt.ylabel(ylabel)
|
||||
plt.title(title)
|
||||
plt.grid(True, which="both")
|
||||
|
||||
# Apply log scales
|
||||
if log_x:
|
||||
plt.xscale("log")
|
||||
if log_y:
|
||||
plt.yscale("log")
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
# === SVG Saving Logic ===
|
||||
if save_path is not None and str(save_path).strip() != "":
|
||||
save_path = str(save_path)
|
||||
|
||||
if not save_path.lower().endswith(".svg"):
|
||||
save_path += ".svg"
|
||||
|
||||
if os.path.exists(save_path):
|
||||
print(f"File '{save_path}' already exists.")
|
||||
choice = input("Overwrite? (y/n): ").strip().lower()
|
||||
if choice != "y":
|
||||
new_path = input(
|
||||
"Enter new filename (with or without .svg): "
|
||||
).strip()
|
||||
if not new_path.lower().endswith(".svg"):
|
||||
new_path += ".svg"
|
||||
save_path = new_path
|
||||
|
||||
plt.savefig(save_path, format="svg")
|
||||
print(f"Saved SVG to: {save_path}")
|
||||
# ===========================
|
||||
|
||||
plt.show()
|
||||
|
||||
# === Return requested output ===
|
||||
if return_sum:
|
||||
return np.sum(diff)
|
||||
else:
|
||||
return diff
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Example usage
|
||||
# x = [1, 2, 3, 4, 5]
|
||||
# y = [2, 3, 5, 7, 11]
|
||||
# plot_xy_points(x, y, "X-axis", "Y-axis", "Sample Plot")
|
||||
|
||||
# x1 = [0, 1, 2, 3]
|
||||
# y1 = [2, 5, 3, 6]
|
||||
|
||||
# x2 = [0, 1, 2, 3]
|
||||
# y2 = [1, 4, 2, 7]
|
||||
|
||||
# plot_mult_xy_points(
|
||||
# [x1, x2],
|
||||
# [y1, y2],
|
||||
# legends=["Sensor A", "Sensor B"],
|
||||
# xlabel="Time (s)",
|
||||
# ylabel="Value",
|
||||
# title="Line Plot"
|
||||
# )
|
||||
|
||||
|
||||
x = [0, 1, 2, 3]
|
||||
y1 = [10, 20, 18, 30]
|
||||
y2 = [12, 22, 19, 29]
|
||||
y3 = [8, 18, 17, 27]
|
||||
|
||||
plot_shared_xy(
|
||||
x=[0,1,2,3],
|
||||
y_list=[[10,20,18,30], [12,21,17,28]],
|
||||
labels=["A","B"],
|
||||
show_points=True,
|
||||
save_path="my_plot.svg"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user