Notice

This is a raw view of the Python source code due to an error in generating the documentation.

Date of Conversion: 2025-10-11 23:32:44

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Aug  6 18:30:27 2025

@author: olivier.vitrac@gmail.com

    Full P' model

        E = A * (P')² + B * P' + C
        P' = (-B - sqrt(B² - 4A(C - E))) / (2A)

        where:
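        - E = logP * ln(10) - S = Xw - Xo  (difference of the Flory-Huggins
          coefficients of solute i in water and in octanol)
        - S = entropic contribution = S(Vi, Vw) - S(Vi, Vo), where
          S(Vi, Vj) = 1 - [r - n(r)], r = Vi/Vj, and n(r) = r/rcritical - 1
          for r >= rcritical (0 otherwise); without the n(r) correction this
          reduces to S = - (Vi/Vw - Vi/Vo)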
"""

# %% 1. Set Up & Imports
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from figprint import print_figure
# Reference molar volumes (cm^3/mol), Miller estimates obtained from
# migrant("water").molarvolumeMiller and migrant("octanol").molarvolumeMiller.
Vw, Vo = 19.588376948550433, 150.26143432234372  # water, octanol

def n_r(Vi, Vj, rcritical=3):
    """
    Compressibility correction n(r) for the volume ratio r = Vi / Vj.

        n(r) = r / rcritical - 1   if r >= rcritical
        n(r) = 0                   otherwise

    The correction is continuous at r = rcritical (where it equals 0) and
    grows linearly with r beyond that threshold.

    Parameters
    ----------
    Vi : array_like
        Molar volumes of solutes i (cm^3/mol).
    Vj : float or array_like
        Molar volume(s) of phase j (cm^3/mol), e.g., Vw or Vo.
    rcritical : float, optional
        Critical volume ratio above which the correction applies
        (default 3). Must be strictly positive.

    Returns
    -------
    n : ndarray
        Compressibility correction(s), same shape as broadcast(Vi, Vj).

    Raises
    ------
    ValueError
        If ``rcritical`` is not strictly positive (or is NaN).
    """
    # A non-positive threshold would make every ratio "critical" and divide
    # by zero / flip the sign of the correction.
    if not rcritical > 0:
        raise ValueError(f"rcritical must be > 0, got {rcritical!r}")
    Vi = np.asarray(Vi, dtype=float)
    Vj = np.asarray(Vj, dtype=float)
    r = Vi / Vj
    n = np.zeros_like(r)
    # Only ratios at or above the threshold receive a non-zero correction.
    mask = r >= rcritical
    n[mask] = r[mask] / rcritical - 1.0
    return n


def S(Vi, Vj, rcritical=3):
    """
    Entropic contribution S(Vi, Vj) = 1 - [ r_ij - n(r_ij) ], r_ij = Vi / Vj.

    n(r) is the compressibility correction computed by ``n_r``; ``rcritical``
    is forwarded to it so that the threshold can be tuned without duplicating
    the formula.

    Parameters
    ----------
    Vi, Vj : array_like
        Molar volumes (cm^3/mol) of solute (i) and phase (j).
    rcritical : float, optional
        Critical volume ratio passed to ``n_r`` (default 3, same as ``n_r``).

    Returns
    -------
    S_val : ndarray
        Entropic contribution(s), same shape as broadcast(Vi, Vj).
    """
    r = np.asarray(Vi, dtype=float) / np.asarray(Vj, dtype=float)
    return 1.0 - (r - n_r(Vi, Vj, rcritical=rcritical))

# %% 2. Tuned Reference Dataset (8 solvents)
# One row per solvent: (name, polarity index P', logP, Miller molar volume).
# P' values carry manual "tweaks" (base + tweak in the comments) to get a
# smooth scale; logP values are literature data; molar volumes were obtained
# with [float(migrant(s).molarvolumeMiller) for s in solvents].
_tuned_rows = [
    ("Water",           10.2, -1.38, 19.588376948550433),
    ("Methanol",         8.1, -0.77, 35.44756135405305),    # P' = 5.1 + 3.0
    ("Ethanol",          6.0, -0.24, 51.52475281040704),    # P' = 4.3 + 1.7
    ("Acetone",          5.6, -0.21, 65.40976129570748),    # P' = 5.1 + 0.5
    ("Acetonitrile",     6.8, -0.22, 45.75175043917132),    # P' = 5.8 + 1.0
    ("Dichloromethane",  3.1,  1.25, 96.7449080359383),
    ("Toluene",          1.8,  2.73, 105.21479478698275),
    ("n-Hexane",         0.0,  3.90, 98.21183889765543),
]
# Transpose the table into one plain list per quantity
solvents, polarity_index, logP_values, vol_values = (
    list(column) for column in zip(*_tuned_rows)
)

# Entropy correction (water minus octanol), then E = logP*ln(10) - S,
# i.e. chi_{i,w} - chi_{i,o}
ref_S = S(vol_values, Vw) - S(vol_values, Vo)
ref_E = np.array(logP_values) * 2.302585092994046 - ref_S

# Small DataFrame for easy viewing, ordered by increasing P'
df = pd.DataFrame(
    {
        "Solvent": solvents,
        "Polarity Index (P')": polarity_index,
        "logP": logP_values,
        "Miller's molar volume": vol_values,
        "E: Xw-Xo": list(ref_E),
    }
).sort_values(by="Polarity Index (P')", ascending=True)

df

# --- Linear fit E = C1 - C2*P', octanol P'_o and its neighbouring solvents ---
# Writing E = alpha*(Pw - Po)*(Pw + Po - 2*P') gives
#   C1 = alpha*(Pw**2 - Po**2)  and  C2 = 2*alpha*(Pw - Po),
# hence Po = 2*C1/C2 - Pw and alpha = C2 / (2*(Pw - Po)).
col = "Polarity Index (P')"
Pw = 10.2  # polarity index of water
E = df["E: Xw-Xo"].to_numpy(float)
P = df[col].to_numpy(float)

C1, C2 = np.linalg.lstsq(np.c_[np.ones_like(P), -P], E, rcond=None)[0]
Po = 2*C1/C2 - Pw
alpha = C2 / (2*(Pw - Po))

# Neighbours of Po on the sorted P' scale: 'below' is the last solvent with
# P' < Po, 'above' the first one with P' >= Po (None beyond either end).
d = df.sort_values(col).reset_index(drop=True)
Ps = d[col].to_numpy(float)
i = np.searchsorted(Ps, Po, side="left")
below = None
above = None
if i > 0:
    below = d.iloc[i - 1]
if i < len(d):
    above = d.iloc[i]

print(f"P'_o = {Po:.3f}   alpha = {alpha:.5f}")
for tag, row in (("below:", below), ("above:", above)):
    print(tag, None if row is None else f"{row['Solvent']} @ P'={row[col]:.2f}")


# %% 2b. Rapid control: E vs P' with the linear fit and the predicted octanol
import os

fig1 = plt.figure(figsize=(8, 5))

# data points (tuned set) with solvent labels
plt.scatter(polarity_index, ref_E, label="Tuned data (8 solvents)")
for name, p_val, e_val in zip(solvents, polarity_index, ref_E):
    plt.annotate(name, (p_val, e_val),
                 fontsize=8, xytext=(5,5), textcoords='offset points')

# regression line E = C1 - C2 * P'
Pgrid = np.linspace(min(P)-0.2, max(P)+0.2, 200)
Egrid = C1 - C2 * Pgrid
plt.plot(Pgrid, Egrid, linestyle='-', linewidth=1.5,
         label=f"Fit: E = {C1:.2f} - {C2:.2f}·P'")

# octanol (predicted) point + guide line
Eo = C1 - C2 * Po
plt.scatter([Po], [Eo], s=70, marker='o', edgecolor='k', facecolor='orange', label="Octanol (pred)")
plt.axvline(Po, linestyle='--', linewidth=1.0, color='orange', alpha=0.7)

# annotation for octanol
plt.annotate(f"Octanol\nP'={Po:.2f}",
             xy=(Po, Eo), xycoords='data',
             xytext=(10, 10), textcoords='offset points',
             fontsize=9, color='orange',
             arrowprops=dict(arrowstyle="->", color='orange', lw=1.0))

# cosmetics
plt.xlabel("Polarity Index (P')")
plt.ylabel(r"$\chi_{i,w}-\chi_{i,o}$")
plt.title("Tuned Dataset: E vs. P' with regression and octanol prediction")
plt.axhline(0, color='gray', linestyle='--', linewidth=0.8)
plt.grid(True, linestyle='--', alpha=0.6)
plt.legend()

# save: savefig does not create missing folders, so make sure "images" exists
plt.tight_layout()
os.makedirs("images", exist_ok=True)
fig1.savefig("images/polarityindexv2_fig1.png", dpi=300, bbox_inches='tight')
plt.show()

# %% 3. Tuning: rcritical = 1/s
# --- tune s in n(r) = s*r - 1 (applied for r >= 1/s), then compute P'_o,
#     alpha and the neighbouring solvents ---

def S_with_s(Vi, Vj, s=0.2):
    """
    Entropic contribution 1 - [r - n(r)], r = Vi/Vj, with the compressibility
    correction written in terms of the slope s = 1/rcritical:
    n(r) = s*r - 1 when r >= 1/s, and 0 otherwise.
    """
    ratio = np.asarray(Vi, float) / np.asarray(Vj, float)
    correction = np.where(ratio >= 1.0/s, s*ratio - 1.0, 0.0)
    return 1.0 - (ratio - correction)

def fit_for_s(s):
    """
    Fit E = C1 - C2*P' on the tuned set for a compressibility slope s.

    Returns
    -------
    tuple
        (rmse, Po, alpha, C1, C2): RMSE of the linear fit residuals, predicted
        octanol polarity index, alpha coefficient and the two fit constants.
    """
    P_water = 10.2
    entropic = S_with_s(vol_values, Vw, s) - S_with_s(vol_values, Vo, s)
    E_target = np.array(logP_values)*2.302585092994046 - entropic
    P_obs = np.array(polarity_index, float)
    design = np.c_[np.ones_like(P_obs), -P_obs]
    C1, C2 = np.linalg.lstsq(design, E_target, rcond=None)[0]
    Po = 2*C1/C2 - P_water
    alpha = C2 / (2*(P_water - Po))
    residuals = E_target - (C1 - C2*P_obs)
    rmse = np.sqrt(np.mean(residuals**2))
    return rmse, Po, alpha, C1, C2

# Quick grid search over s: keep the value giving the smallest fit RMSE
s_grid = np.linspace(0.1, 0.7, 61)
best = min(((*fit_for_s(s), s) for s in s_grid), key=lambda t: t[0])
rmse_tunned, Po_tunned, alpha_tunned, C1_tunned, C2_tunned, s_best = best

# Neighbouring solvents of the *tuned* octanol polarity index (Po_tunned);
# the untuned Po from the compact fit is left unchanged for later sections.
col = "Polarity Index (P')"
d = df.sort_values(col).reset_index(drop=True)
Ps = d[col].to_numpy(float)
i = np.searchsorted(Ps, Po_tunned, side="left")
below = d.iloc[i-1] if i > 0 else None
above = d.iloc[i] if i < len(d) else None

print(f"best s = {s_best:.3f}   P'_o = {Po_tunned:.3f}   alpha = {alpha_tunned:.5f}   RMSE={rmse_tunned:.3f}")
print("below:", None if below is None else f"{below['Solvent']} @ P'={below[col]:.2f}")
print("above:", None if above is None else f"{above['Solvent']} @ P'={above[col]:.2f}")


# %% 4. Extended Validation Dataset (~35 solvents)

ext_solvents = [
    "Pentane", "1,1,2-Trichlorotrifluoroethane", "Cyclopentane", "Heptane", "Hexane",
    "Iso-Octane", "Petroleum Ether", "Cyclohexane", "n-Butyl Chloride", "Toluene",
    "Methyl t-Butyl Ether", "o-Xylene", "Chlorobenzene", "o-Dichlorobenzene", "Ethyl Ether",
    "Dichloromethane", "Ethylene Dichloride", "n-Butyl Alcohol", "Isopropyl Alcohol",
    "n-Butyl Acetate", "Isobutyl Alcohol", "Methyl Isoamyl Ketone", "n-Propyl Alcohol",
    "Tetrahydrofuran", "Chloroform", "Methyl Isobutyl Ketone", "Ethyl Acetate",
    "Methyl n-Propyl Ketone", "Methyl Ethyl Ketone", "1,4-Dioxane", "Acetone", "Methanol",
    "Pyridine", "2-Methoxyethanol", "Acetonitrile", "Propylene Carbonate", "N,N-Dimethylformamide",
    "Dimethyl Acetamide", "N-Methylpyrrolidone", "Dimethyl Sulfoxide", "Water"
]

ext_polarity_index = [
    0.0, 0.0, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 1.0, 2.4, 2.5, 2.5, 2.7, 2.7, 2.8,
    3.1, 3.5, 3.9, 3.9, 4.0, 4.0, 4.0, 4.0, 4.0, 4.1, 4.2, 4.4, 4.5, 4.7, 4.8,
    5.1, 5.1, 5.3, 5.5, 5.8, 6.1, 6.4, 6.5, 6.7, 7.2, 10.2
]

ext_logP_values = [
    3.39, 4.30, 3.20, 4.66, 3.90, 4.50, 3.50, 3.44, 2.70, 2.73, 1.20, 3.12, 2.84, 3.38, 0.83,
    1.25, 1.48, 0.88, 0.05, 1.82, 0.79, 1.98, 0.25, 0.46, 1.97, 1.31, 0.73, 1.50, 0.29, -0.27,
    -0.24, -0.77, 0.65, -0.77, -0.22, -0.41, -1.01, -0.77, -0.38, -1.35, -1.38
]

# Molar volume values, obtained with:
#   replacements = {
#       "Petroleum Ether": "n-Hexane",  # or "n-Pentane" or "C5–C7 Alkanes mixture"
#   }
#   ext_vol_values = [float(migrant(replacements.get(s,s)).molarvolumeMiller)
#                     for s in ext_solvents]
ext_vol_values =[
 81.78592761446717, 218.56278542419872, 79.42845452645774,
 114.7067332200619, 98.21183889765543, 131.28307497280687,
 98.21183889765543, 95.84160035654774, 105.72057826620575,
 105.21479478698275, 100.52501450631739, 121.74044796002656,
 129.29478926007326, 170.22839840459193, 84.08695604709226,
 96.7449080359383, 113.24489554606926, 84.08695604709226,
 67.7541442592405, 133.5683188450608, 84.08695604709226,
 131.2357245895863, 67.7541442592405, 81.73922557439461,
 137.37167951549557, 114.6595686581247, 100.47803092828921,
 98.15314934465678, 81.73922557439461, 100.47803092828921,
 65.40976129570748, 35.44756135405305, 89.91185054664918,
 86.38982016238275, 45.75175043917132, 116.93589883743107,
 82.88364826856238, 99.31539198895744, 113.44527642606757,
 88.78810100209161, 19.588376948550433
 ]

# Entropy correction and E = logP*ln(10) - S (= chi_{i,w} - chi_{i,o})
ext_S = S(ext_vol_values,Vw) - S(ext_vol_values,Vo)
ext_E = np.array(ext_logP_values) * 2.302585092994046 - ext_S


# Filter out the solvents already in the "tuned" set. Names go through an
# alias table because both lists do not always spell a solvent the same way
# ("Hexane" here is "n-Hexane" in the tuned set); without it the tuned
# n-hexane point would leak into the validation set.
_tuned_aliases = {"Hexane": "n-Hexane"}
_tuned_names = set(solvents)
_keep = [k for k, name in enumerate(ext_solvents)
         if _tuned_aliases.get(name, name) not in _tuned_names]

validation_solvents = [ext_solvents[k] for k in _keep]
validation_polarity_index = [ext_polarity_index[k] for k in _keep]
validation_logP_values = [ext_logP_values[k] for k in _keep]
validation_vol_values = [ext_vol_values[k] for k in _keep]
validation_E_values = [ext_E[k] for k in _keep]

df_validation = pd.DataFrame({
    "Solvent": validation_solvents,
    "Polarity Index (P')": validation_polarity_index,
    "logP": validation_logP_values,
    "Miller's molar volume": validation_vol_values,
    'E: Xw-Xo': list(validation_E_values)
}).sort_values(by="Polarity Index (P')")

df_validation.head(12)


# %% 5. Fit a Quadratic Model: Xw-Xo = a (P')^2 + b (P') + c
# Least-squares parabola through the tuned points (E = Xw-Xo vs P').
coefficients = np.polyfit(polarity_index, ref_E, 2)
a, b, c = coefficients
quadratic_model = np.poly1d(coefficients)

print(f"Fitted quadratic coefficients: a={a:.4f}, b={b:.4f}, c={c:.4f}")
print(f"Model =>  Xw-Xo = {a:.4f}*(P')^2 + {b:.4f}*P' + {c:.4f}")

# Smooth curve over the tuned P' range, kept for plotting
x_range = np.linspace(min(polarity_index), max(polarity_index), 100)
y_fitted = quadratic_model(x_range)

# %% 6. Implemented Prediction Model
# Prediction model for a solute i:
#
#   E = logP*ln(10) - S is inverted through the smaller root of the parabola
#   E = A*P'**2 + B*P' + C (A, B, C fitted constants):
#      P' = (-B - sqrt(B**2 - 4*A*(C - E))) / (2*A)   for Emin < E < Emax
#   with saturation outside that range:
#      E <= Emin = C - B**2/(4*A)  ->  P' = 10.2
#      E >= Emax = C               ->  P' = 0.0
#   (predictions are finally clipped to [0, 10.2]).
#
# S is the entropic contribution associated with the difference of the natural
# logarithm of the solute activity coefficients between water and octanol:
#   S = S(Vi, Vw) - S(Vi, Vo)
# which, without the compressibility correction n(r), reduces to
#   S = - (1/Vw - 1/Vo) * Vi
# E = logP*ln(10) - S is then the difference of the Flory-Huggins coefficients
# between i+water and i+octanol: $\chi_{i,w} - \chi_{i,o}$.
#

# Reference molar volumes (Miller estimates)
Vw = 19.588376948550433  # migrant("water").molarvolumeMiller
Vo = 150.26143432234372  # migrant("octanol").molarvolumeMiller

# Hard-coded quadratic coefficients (not recomputed from the fit of section 5)
A, B, C = 0.07485019080020634, -2.268683501033584, 13.079540672499757
Emin, Emax = C - B**2 / (4*A), C  # ~ -4.111 (parabola vertex) and ~ 13.080

def predict_polarity_index(logP, Vi, A=A, B=B, C=C, Vw=Vw, Vo=Vo):
    """
    Predict the polarity index P' of a solute from its logP and molar volume.

    The target E = logP*ln(10) - [S(Vi, Vw) - S(Vi, Vo)] is mapped back onto
    P' through the smaller root of E = A*P'^2 + B*P' + C. Outside the open
    range (Emin, Emax), with Emin = C - B^2/(4A) and Emax = C, the prediction
    saturates at 10.2 (E <= Emin) or at 0.0 (E >= Emax). Every result is
    finally clipped to [0, 10.2].

    Parameters
    ----------
    logP : float or array_like
        Base-10 logP value(s) of the solute(s).
    Vi : float or array_like
        Molar volume(s) of the solute(s) in cm³/mol. Vi = 0 makes the
        entropic term vanish, so E = logP*ln(10).
    A, B, C : float
        Coefficients of the quadratic model Xw - Xo = A·P'^2 + B·P' + C.
        Defaults are the module-level values at definition time.
    Vw, Vo : float
        Molar volumes of water and octanol in cm³/mol.

    Returns
    -------
    P_pred : np.ndarray
        Predicted polarity index P', broadcast over logP and Vi.
    """
    logP_arr = np.asarray(logP)
    Vi_arr = np.asarray(Vi)

    # Target E (= Xw - Xo) once the entropic contribution is removed
    entropic = S(Vi_arr, Vw) - S(Vi_arr, Vo)
    E = logP_arr * np.log(10) - entropic

    # Bounds of the invertible range: parabola vertex and value at P' = 0
    E_lo = C - B**2 / (4*A)
    E_hi = C

    to_max = E <= E_lo   # saturates at the most polar value
    to_min = E >= E_hi   # saturates at P' = 0
    inside = np.logical_not(np.logical_or(to_max, to_min))

    P_pred = np.zeros_like(E)
    # For A > 0, E > E_lo implies B**2 - 4A(C - E) > 0
    discriminant = B**2 - 4*A*(C - E[inside])
    P_pred[inside] = (-B - np.sqrt(discriminant)) / (2*A)
    P_pred[to_max] = 10.2
    P_pred[to_min] = 0.0

    # Cap to the polarity-index scale (also absorbs roots beyond 10.2)
    return np.clip(P_pred, 0, 10.2)


# Comparison of prediction models: inverse quadratic, direct quadratic, linear
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.linear_model import LinearRegression

# Observed polarity indices of the tuned set
P_true = np.array(polarity_index)

# Target E = logP*ln(10) - S for the tuned set
E = np.array(logP_values) * np.log(10) - (S(vol_values,Vw) - S(vol_values,Vo))

# (1) inverse model: invert the quadratic E(P')
P_pred_inverse = predict_polarity_index(logP_values, vol_values)

# (2) direct regression: fit P' as a quadratic polynomial of E
direct_fit = np.polyfit(E, P_true, 2)
P_pred_direct = np.poly1d(direct_fit)(E)

rmse_inv, rmse_dir = (np.sqrt(mean_squared_error(P_true, pred))
                      for pred in (P_pred_inverse, P_pred_direct))
print(f"RMSE Inverse Model: {rmse_inv:.3f}")
print(f"RMSE Direct Fit (P' vs E): {rmse_dir:.3f}")

# (3) linear regression P' = a·E + b (same E as above, recomputed)
E_array = np.array(logP_values) * np.log(10) - (S(vol_values,Vw) - S(vol_values,Vo))
Pprime_array = np.array(polarity_index)

linreg = LinearRegression()
linreg.fit(E_array.reshape(-1, 1), Pprime_array)
a_lin, b_lin = linreg.coef_[0], linreg.intercept_
print(f"Linear fit:  P' = {a_lin:.4f}·E + {b_lin:.4f}")

# Evaluate the linear fit on the training points
P_pred_lin = linreg.predict(E_array.reshape(-1, 1))
rmse_lin = np.sqrt(np.mean((P_pred_lin - Pprime_array) ** 2))
print(f"RMSE Linear Fit: {rmse_lin:.3f}")

# %% 7. Alpha recalibration on Phat
# Model: E ≈ alpha * (Pw - Po) * (Pw + Po - 2*P'), evaluated at the predicted
# polarity indices Phat of the tuned set.
Phat = predict_polarity_index(logP_values, vol_values)
Ehat = quadratic_model(Phat)
Dhat = (Pw-Po)*(Pw+Po-2*Phat)
# least-squares slope of Ehat on Dhat (regression through the origin)
alpha2 = float(np.dot(Ehat, Dhat) / np.dot(Dhat, Dhat))

# Refine Po by 1-D search: for each candidate Po, alpha has a closed-form
# least-squares value; keep the candidate with the smallest RMSE on E
# (E: tuned-set target logP*ln(10) - S computed earlier in the script).
Po_grid = np.linspace(2.0, 6.0, 401)
alpha_po = []
rmse_po = []
for Po_try in Po_grid:
    D_try = (Pw - Po_try) * (Pw + Po_try - 2.0*Phat)
    a_try = np.dot(E, D_try) / np.dot(D_try, D_try)
    rmse_po.append(np.sqrt(np.mean((E - a_try*D_try)**2)))
    alpha_po.append(a_try)
k_best = int(np.argmin(rmse_po))
Po_recalibrated = Po_grid[k_best]
alpha3 = alpha_po[k_best]

print(f"alpha recalibration {alpha:.5f} -> {alpha2:.5f} -> {alpha3:.5f}")
print(f"P'_o recalibration {Po:.3f} -> {Po_recalibrated:.3f}")
# %% 8. Compare with the Extended Validation Dataset

fig2 = plt.figure(figsize=(9, 6))

# Plot tuned set
plt.scatter(polarity_index, ref_E, color='Crimson', label="Tuned (8 solvents)")
for i, solvent in enumerate(solvents):
    plt.annotate(
        solvent, (polarity_index[i], ref_E[i]),
        fontsize=8, xytext=(5,5), textcoords='offset points'
    )

# Plot validation set
plt.scatter(validation_polarity_index, validation_E_values,
            color='DeepSkyBlue', label="Validation (~35 solvents)")
for i, solvent in enumerate(validation_solvents):
    plt.annotate(
        solvent, (validation_polarity_index[i], validation_E_values[i]),
        fontsize=6, xytext=(5,5), textcoords='offset points'
    )

# Plot fitted curve. With Vi = 0 the entropic term vanishes (S(0, Vj) = 1 in
# both phases), so logP = E/ln(10) maps E_range directly onto P'.
E_range = np.linspace(-6, 23, 1000)
p_estimated = predict_polarity_index(E_range/np.log(10),0)
plt.plot(p_estimated, E_range, color='Crimson', linestyle='-.',
         label="Inverse Quadratic fit")

# octanol point predicted by the linear fit E = C1 - C2*P'
Eo = C1 - C2 * Po
plt.scatter([Po], [Eo], s=70, marker='o', edgecolor='k', facecolor='orange', label="Octanol (pred)")

# annotation for octanol
plt.annotate(f"Octanol\nP'={Po:.2f}",
             xy=(Po, Eo), xycoords='data',
             xytext=(10, 10), textcoords='offset points',
             fontsize=9, color='orange',
             arrowprops=dict(arrowstyle="->", color='orange', lw=1.0))

# raw strings: "\c" is not a valid Python escape sequence
plt.xlabel("Polarity Index (P')")
plt.ylabel(r"$\chi_{i,w}-\chi_{i,o}$")
plt.title(r"$\chi_{i,w}-\chi_{i,o}$ vs P': Quadratic Model and Validation")
plt.axhline(0, color='gray', linestyle='--', linewidth=0.8)
plt.grid(True, linestyle='--', alpha=0.6)
plt.legend()
plt.show()

# NOTE(review): matplotlib's Figure has no .print() method; it is presumably
# attached by the figprint module imported at the top of the file — confirm.
fig2.print("images/polarityindexv2_fig2")

# --- fraction of validation points that underestimate the model by > tol ---

# Model value at each validation P': evaluate E = A*P'^2 + B*P' + C directly.
# Interpolating the inverse curve (p_estimated vs E_range) is ill-defined at
# the saturated plateaus P' = 0 and P' = 10.2, where many E values share the
# same P' and an unstable sort makes the interpolated value arbitrary.
P_val = np.asarray(validation_polarity_index, dtype=float)
E_hat_val = A * P_val**2 + B * P_val + C

# "Underestimate the model by more than tol": the measured E lies more than
# tol*|model E| below the model. Using |E_hat| keeps the criterion meaningful
# when the model value is negative (near water).
tol = 0.2
E_val = np.asarray(validation_E_values, dtype=float)
under_mask = E_val < E_hat_val - tol * np.abs(E_hat_val)

frac_under = under_mask.mean()
print(f"Underestimation > {tol*100:.1f}% (validation set): {frac_under*100:.1f}% "
      f"({under_mask.sum()}/{under_mask.size})")

# List the offending substances
offenders = [s for s, m in zip(validation_solvents, under_mask) if m]
print(f"Underestimating substances (> {tol*100:.1f}% below model):")
print(", ".join(offenders) if offenders else "None")


# %% 9. Comparison of predicted P' values vs logP
# Predicted vs measured P' for the tuned set (8 solvents, training) and for
# the validation set (remaining solvents of the extended list).

P_train_pred = predict_polarity_index(np.array(logP_values), np.array(vol_values))
P_valid_pred = predict_polarity_index(np.array(validation_logP_values),
                                      np.array(validation_vol_values))

# One row per solvent, tagged with the set it belongs to
df_comparison = pd.DataFrame({
    "Solvent": [*solvents, *validation_solvents],
    "Set": ["Tuned"] * len(solvents) + ["Validation"] * len(validation_solvents),
    "Measured P'": [*polarity_index, *validation_polarity_index],
    "Predicted P'": np.concatenate([P_train_pred, P_valid_pred]),
})
# Signed prediction error
df_comparison["ΔP'"] = df_comparison["Predicted P'"] - df_comparison["Measured P'"]

df_comparison.sort_values(by="Set").head(10)

# %% 10. Plot Predicted vs Measured Polarity Index

fig3 = plt.figure(figsize=(9, 6))

# Identity line: points on it are predicted exactly
lims = [0, 11]
plt.plot(lims, lims, '--k', alpha=0.5, label='Perfect agreement')

# Split the comparison table into its two sets
df_train = df_comparison[df_comparison["Set"] == "Tuned"]
df_valid = df_comparison[df_comparison["Set"] == "Validation"]

# Tuned set: points + solvent labels
plt.scatter(df_train["Measured P'"], df_train["Predicted P'"],
            color="Crimson", label="Tuned Set (8 solvents)")
for name, p_meas, p_pred in zip(solvents, polarity_index, P_train_pred):
    plt.annotate(name, (p_meas, p_pred), fontsize=8,
                 xytext=(4, 4), textcoords='offset points')

# Validation set: points + solvent labels
plt.scatter(validation_polarity_index, P_valid_pred,
            color="DeepSkyBlue", label="Validation (~35 solvents)")
for name, p_meas, p_pred in zip(validation_solvents, validation_polarity_index, P_valid_pred):
    plt.annotate(name, (p_meas, p_pred), fontsize=6,
                 xytext=(3, 3), textcoords='offset points')

# Labels and aesthetics (equal axes so that the diagonal is at 45 degrees)
plt.xlabel("Measured Polarity Index (P')")
plt.ylabel("Predicted Polarity Index (P')")
plt.title("Prediction of Polarity Index from logP and Molar Volume")
plt.legend()
plt.grid(True, linestyle='--', alpha=0.6)
plt.xlim(lims)
plt.ylim(lims)
plt.gca().set_aspect('equal', adjustable='box')
plt.show()

# NOTE(review): matplotlib's Figure has no .print() method; it is presumably
# attached by the figprint module imported at the top of the file — confirm.
fig3.print("images/polarityindexv2_fig3")
# %% 11. Performance summary

from sklearn.metrics import mean_squared_error, r2_score

def print_model_stats(setname, measured, predicted):
    """Print the RMSE and R² of *predicted* against *measured* for one set."""
    rmse = np.sqrt(mean_squared_error(measured, predicted))
    r2 = r2_score(measured, predicted)
    report = (
        f"{setname} Set:",
        f"  RMSE  = {rmse:.3f}",
        f"  R²    = {r2:.3f}",
        "",
    )
    for line in report:
        print(line)

for _set_name, _frame in (("Tuned", df_train), ("Validation", df_valid)):
    print_model_stats(_set_name, _frame["Measured P'"], _frame["Predicted P'"])