huggingface/plot.py
2025-08-18 13:22:25 +02:00

97 lines
3.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
def plotVariables(x, y, title=None):
    """Scatter-plot *y* against *x* with a dashed reference line at y = 0.

    Parameters
    ----------
    x, y : array-like
        Horizontal and vertical values. If an argument has a ``name``
        attribute (as a pandas Series does), it becomes that axis label;
        otherwise the label is left empty.
    title : str, optional
        Figure title. Defaults to ``"<y name> vs. <x name>"`` (``"y"``/``"x"``
        when a name is missing), so the title describes the plotted variables
        instead of a fixed pair of column names.
    """
    # getattr keeps plain arrays/lists usable; they have no .name attribute.
    x_name = getattr(x, "name", None)
    y_name = getattr(y, "name", None)
    if title is None:
        y_part = "y" if y_name is None else y_name
        x_part = "x" if x_name is None else x_name
        title = f"{y_part} vs. {x_part}"

    # scatter plot (matplotlib, single plot, no explicit colors)
    plt.figure()
    plt.scatter(x, y)
    plt.axhline(0, linestyle="--")  # y = 0 reference line
    plt.xlabel(x_name)  # matplotlib renders a None label as empty text
    plt.ylabel(y_name)
    plt.title(title)
    plt.grid(True)
    plt.show()
def plotPriceConfidence(condensed):
    """Scatter-plot ``diff_price`` against the confidence column of *condensed*.

    The x-axis column is ``confidence`` when present, otherwise
    ``resp_confidence``. Both plotted columns are coerced to numbers
    (unparseable entries become NaN), and rows missing either value are
    dropped before plotting. A dashed horizontal line marks diff_price == 0.

    Raises
    ------
    KeyError
        If *condensed* has neither confidence column, or lacks ``diff_price``.
    """
    available = condensed.columns
    if "confidence" in available:
        conf_col = "confidence"
    elif "resp_confidence" in available:
        conf_col = "resp_confidence"
    else:
        raise KeyError("No 'confidence' or 'resp_confidence' column found in condensed.")

    # work on a private copy holding only the two plotted columns
    points = condensed[[conf_col, "diff_price"]].copy()
    for column in (conf_col, "diff_price"):
        points[column] = pd.to_numeric(points[column], errors="coerce")
    points = points.dropna(subset=[conf_col, "diff_price"])

    # single matplotlib scatter, default colors
    plt.figure()
    plt.scatter(points[conf_col], points["diff_price"])
    plt.axhline(0, linestyle="--")  # zero-difference reference
    plt.xlabel(conf_col)
    plt.ylabel("diff_price")
    plt.title("diff_price vs. confidence")
    plt.grid(True)
    plt.show()
def histPriceDiff(condensed):
    """Plot one ``diff_price`` histogram per confidence band.

    The confidence column is ``confidence`` when present, otherwise
    ``resp_confidence``. It and ``diff_price`` are coerced to numbers.
    Unparseable entries and non-finite values (NaN, +/-inf) are dropped.
    If the largest remaining confidence is at most 1.01, the column is
    treated as a 0–1 score and multiplied by 100 so the band edges
    (100 / 90 / 80 / 50) apply.

    All histograms share 30 bins spanning the full diff_price range, so the
    bands can be compared directly. Each non-empty band is drawn in its own
    figure; empty bands are reported and skipped. Finally, the row count of
    every band is printed, plus a warning for rows whose confidence exceeds
    100 and therefore matches no band.

    Raises
    ------
    KeyError
        If *condensed* has neither confidence column, or lacks ``diff_price``.
    ValueError
        If no row has both a finite confidence and a finite diff_price.
    """
    conf_col = (
        "confidence" if "confidence" in condensed.columns
        else "resp_confidence" if "resp_confidence" in condensed.columns
        else None
    )
    if conf_col is None:
        raise KeyError("No 'confidence' or 'resp_confidence' column in condensed.")

    # --- prepare data ---
    df = condensed[[conf_col, "diff_price"]].copy()
    df[conf_col] = pd.to_numeric(df[conf_col], errors="coerce")
    df["diff_price"] = pd.to_numeric(df["diff_price"], errors="coerce")
    # The shared bin edges are derived from the data's min/max. An infinite
    # value would make those edges non-finite and break every histogram, so
    # infinities are treated as missing, just like NaN.
    df = df.replace([np.inf, -np.inf], np.nan).dropna(subset=[conf_col, "diff_price"])
    if df.empty:
        raise ValueError("No numeric diff_price values to plot.")

    # Heuristic: if no score exceeds 1.01, assume a 0–1 scale and rescale
    # to 0–100 so it lines up with the band edges below.
    if df[conf_col].max() <= 1.01:
        df[conf_col] = df[conf_col] * 100
    conf = df[conf_col]

    # --- define bands as (label, rows); together they cover every value <= 100 ---
    bands = [
        ("confidence == 100", df[conf == 100]),
        ("100 > confidence ≥ 90", df[(conf < 100) & (conf >= 90)]),
        ("90 > confidence ≥ 80", df[(conf < 90) & (conf >= 80)]),
        ("80 > confidence ≥ 50", df[(conf < 80) & (conf >= 50)]),
        ("50 > confidence", df[conf < 50]),
    ]
    # Rows with confidence above 100 match no band; count them so they are
    # reported instead of vanishing silently.
    n_unbanded = len(df) - sum(len(rows) for _, rows in bands)

    # --- common bins across all groups for fair comparison ---
    xmin, xmax = df["diff_price"].min(), df["diff_price"].max()
    if xmin == xmax:
        # degenerate case: widen the single value into a 1-unit range
        xmin, xmax = xmin - 0.5, xmax + 0.5
    bins = np.linspace(xmin, xmax, 31)  # 31 edges -> 30 bins

    # --- plot each histogram in its own figure (no subplots, no explicit colors) ---
    for title, rows in bands:
        if rows.empty:
            print(f"[skip] {title}: no rows")
            continue
        plt.figure()
        plt.hist(rows["diff_price"].values, bins=bins)
        plt.title(f"diff_price for {title}")
        plt.xlabel("diff_price")
        plt.ylabel("count")
        plt.grid(True)
        plt.show()

    # quick counts per band
    for title, rows in bands:
        print(f"{title}: {len(rows)} rows")
    if n_unbanded:
        print(f"[warn] {n_unbanded} rows have confidence > 100 and fall in no band")