import inspect
import multiprocessing as mp
import warnings
from abc import ABC, abstractmethod
from functools import partial
from typing import List, Sequence, Tuple
import numpy as np
from scipy import optimize as sop
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map
from dosma import defaults
from dosma.data_io.med_volume import MedicalVolume
from dosma.defaults import preferences
__all__ = ["MonoExponentialFit", "curve_fit", "monoexponential", "biexponential"]
class _Fit(ABC):
"""Abstract class for fitting quantitative values.
"""
@abstractmethod
def fit(self) -> Tuple[MedicalVolume, MedicalVolume]:
"""Fit quantitative values per pixel across multiple volumes.
Pixels with errors in fitting are set to np.nan.
Returns:
tuple[MedicalVolume, MedicalVolume]: Quantitative value volume and
r-squared goodness of fit volume.
"""
pass
[docs]class MonoExponentialFit(_Fit):
"""Fit quantitative values using mono-exponential fit of model :math:`a*exp(t/tc)`.
Args:
ts (:obj:`array-like`): 1D array of times in milliseconds (typically echo times)
corresponding to different volumes.
subvolumes (list[MedicalVolumes]): Volumes (in order) corresponding to times in `ts`.
mask (:obj:`MedicalVolume`, optional): Mask of pixels to fit.
If specified, pixels outside of mask region are ignored and set to ``np.nan``.
Speeds fitting time as fewer fits are required.
bounds (:obj:`tuple[float, float]`, optional): Upper and lower bound for quantitative
values. Values outside those bounds will be set to ``np.nan``.
tc0 (:obj:`float`, optional): Initial time constant guess (in milliseconds).
decimal_precision (:obj:`int`, optional): Rounding precision after the decimal point.
"""
def __init__(
self,
ts: Sequence[float],
subvolumes: List[MedicalVolume],
mask: MedicalVolume = None,
bounds: Tuple[float] = (0, 100.0),
tc0: float = 30.0,
decimal_precision: int = 1,
verbose: bool = False,
num_workers: int = 0,
):
if (not isinstance(subvolumes, list)) or (
not all([isinstance(sv, MedicalVolume) for sv in subvolumes])
):
raise TypeError("`subvolumes` must be list of MedicalVolumes.")
if len(ts) != len(subvolumes):
raise ValueError(
"`len(ts)`={:d}, but `len(subvolumes)`={:d}".format(len(ts), len(subvolumes))
)
self.ts = ts
orientation = subvolumes[0].orientation
subvolumes = [sv.reformat(orientation) for sv in subvolumes]
self.subvolumes = subvolumes
if mask and not isinstance(mask, MedicalVolume):
raise TypeError("`mask` must be a MedicalVolume")
self.mask = mask.reformat(orientation)
self.verbose = verbose
self.num_workers = num_workers
if len(bounds) != 2:
raise ValueError("`bounds` should provide lower/upper bound in format (lb, ub)")
self.bounds = bounds
self.tc0 = tc0
self.decimal_precision = decimal_precision
[docs] def fit(self):
svs = []
msk = None
subvolumes = self.subvolumes
for sv in subvolumes[1:]:
assert subvolumes[0].is_same_dimensions(sv), "Dimension mismatch within subvolumes"
if self.mask:
assert subvolumes[0].is_same_dimensions(
self.mask, defaults.AFFINE_DECIMAL_PRECISION
), "Mask dimension mismatch"
msk = self.mask.volume
msk = msk.reshape(1, -1)
original_shape = subvolumes[0].volume.shape
affine = np.array(self.subvolumes[0].affine)
for i in range(len(self.ts)):
sv = subvolumes[i].volume
svr = sv.reshape((1, -1))
if msk is not None:
svr = svr * msk
svs.append(svr)
svs = np.concatenate(svs)
p0 = (1.0, -1 / self.tc0)
popt, r_squared = curve_fit(
monoexponential,
self.ts,
svs,
self.bounds,
p0=p0,
show_pbar=self.verbose,
num_workers=self.num_workers,
)
vals = 1 / np.abs(popt[:, 1])
map_unfiltered = vals.reshape(original_shape)
r_squared = r_squared.reshape(original_shape)
# All accepted values must meet an r-squared threshold of `DEFAULT_R2_THRESHOLD`.
tc_map = map_unfiltered * (r_squared >= preferences.fitting_r2_threshold)
# Filter calculated values that are below limit bounds.
tc_map[tc_map < self.bounds[0]] = np.nan
tc_map = np.nan_to_num(tc_map)
tc_map[tc_map > self.bounds[1]] = np.nan
tc_map = np.nan_to_num(tc_map)
tc_map = np.around(tc_map, self.decimal_precision)
time_constant_volume = MedicalVolume(tc_map, affine=affine)
rsquared_volume = MedicalVolume(r_squared, affine=affine)
return time_constant_volume, rsquared_volume
__EPSILON__ = 1e-8
[docs]def curve_fit(
func,
x,
y,
y_bounds=None,
p0=None,
maxfev=100,
ftol=1e-5,
eps=1e-8,
show_pbar=False,
num_workers=0,
**kwargs,
):
"""Use non-linear least squares to fit a function ``func`` to data.
Uses :func:`scipy.optimize.curve_fit` backbone.
Args:
func (callable): The model function, f(x, ...). It must take the independent variable
as the first argument and the parameters to fit as separate remaining arguments.
x (ndarray): The independent variable(s) where the data is measured.
Should usually be an M-length sequence or an (k,M)-shaped array for functions
with k predictors, but can actually be any object.
y (ndarray): The dependent data, a length M array - nominally func(xdata, ...) - or
an (M,N)-shaped array for N different sequences.
y_bounds (tuple, optional): Lower and upper bound on y values. Defaults to no bounds.
Sequences with observations out of this range will not be processed.
p0 (Sequence, optional): Initial guess for the parameters (length N). If None, then
the initial values will all be 1 (if the number of parameters for the
function can be determined using introspection, otherwise a ValueError is raised).
maxfev (int, optional): Maximum number of function evaluations before the termination.
If `bounds` argument for `scipy.optimize.curve_fit` is specified, this corresponds
to the `max_nfev` in the least squares algorithm
ftol (float): Tolerance for termination by the change of the cost function.
See `scipy.optimize.least_squares` for more details.
eps (float, optional): Epsilon for computing r-squared.
show_pbar (bool, optional): If `True`, show progress bar. Note this can increase runtime
slightly when using multiple workers.
kwargs: Keyword args for `scipy.optimize.curve_fit`.
"""
x = np.asarray(x)
y = np.asarray(y)
if y.ndim == 1:
y = y.view(y.shape + (1,))
N = y.shape[-1]
func_args = inspect.getargspec(func).args
nparams = len(func_args) - 2 if "self" in func_args else len(func_args) - 1
if "bounds" not in kwargs:
kwargs["maxfev"] = maxfev
elif "max_nfev" not in kwargs:
kwargs["max_nfev"] = maxfev
num_workers = min(num_workers, N)
fitter = partial(
_curve_fit,
x=x,
func=func,
y_bounds=y_bounds,
p0=p0,
ftol=ftol,
eps=eps,
show_pbar=show_pbar,
nparams=nparams,
**kwargs,
)
oob = y_bounds is not None and ((y < y_bounds[0]).any() or (y > y_bounds[1]).any())
if oob:
warnings.warn("Out of bounds values found. Failure in fit will result in np.nan")
popts = []
r_squared = []
if not num_workers:
for i in tqdm(range(N), disable=not show_pbar):
popt_, r2_ = fitter(y[:, i])
popts.append(popt_)
r_squared.append(r2_)
else:
if show_pbar:
data = process_map(fitter, y.T, max_workers=num_workers, tqdm_class=tqdm)
else:
with mp.Pool(num_workers) as p:
data = p.map(fitter, y.T)
popts, r_squared = [x[0] for x in data], [x[1] for x in data]
return np.stack(popts, axis=0), np.asarray(r_squared)
def _curve_fit(
y,
x,
func,
y_bounds=None,
p0=None,
maxfev=100,
ftol=1e-5,
eps=1e-8,
show_pbar=False,
nparams=None,
**kwargs,
):
def _fit_internal(_x, _y):
popt, _ = sop.curve_fit(func, _x, _y, p0=p0, maxfev=maxfev, ftol=ftol, **kwargs)
residuals = _y - func(_x, *popt)
ss_res = np.sum(residuals ** 2)
ss_tot = np.sum((_y - np.mean(_y)) ** 2)
r_squared = 1 - (ss_res / (ss_tot + eps))
return popt, r_squared
if nparams is None:
func_args = inspect.getargspec(func).args
nparams = len(func_args) - 2 if "self" in func_args else len(func_args) - 1
# import pdb; pdb.set_trace()
oob = y_bounds is not None and ((y < y_bounds[0]).any() or (y > y_bounds[1]).any())
if oob or (y == 0).all():
return (np.nan,) * nparams, 0
try:
popt_, r2_ = _fit_internal(x, y)
except RuntimeError:
popt_, r2_ = (np.nan,) * nparams, 0
return popt_, r2_
[docs]def monoexponential(x, a, b):
"""Function: :math:`f(x) = a * e^{b*x}`."""
return a * np.exp(b * x)
[docs]def biexponential(x, a1, b1, a2, b2):
"""Function: :math:`f(x) = a1*e^{b1*x} + a2*e^{b2*x}`."""
return a1 * np.exp(b1 * x) + a2 * np.exp(b2 * x)
def __fit_mono_exp__(x, y, p0=None):
def func(t, a, b):
exp = np.exp(b * t)
return a * exp
warnings.warn(
"__fit_mono_exp__ is deprecated since v0.12 and will no longer be "
"supported in v0.13. Use `curve_fit` instead.",
DeprecationWarning,
)
x = np.asarray(x)
y = np.asarray(y)
popt, _ = sop.curve_fit(func, x, y, p0=p0, maxfev=100, ftol=1e-5)
residuals = y - func(x, popt[0], popt[1])
ss_res = np.sum(residuals ** 2)
ss_tot = np.sum((y - np.mean(y)) ** 2)
r_squared = 1 - (ss_res / (ss_tot + __EPSILON__))
return popt, r_squared
def __fit_monoexp_tc__(x, ys, tc0, show_pbar=False):
warnings.warn(
"__fit_monoexp_tc__ is deprecated since v0.12 and will no longer be "
"supported in v0.13. Use `curve_fit` instead.",
DeprecationWarning,
)
p0 = (1.0, -1 / tc0)
time_constants = np.zeros([1, ys.shape[-1]])
r_squared = np.zeros([1, ys.shape[-1]])
warned_negative = False
for i in tqdm(range(ys.shape[1]), disable=not show_pbar):
y = ys[..., i]
if (y < 0).any() and not warned_negative:
warned_negative = True
warnings.warn(
"Negative values found. Failure in monoexponential fit will result in np.nan"
)
# Skip any negative values or all values that are 0s
if (y < 0).any() or (y == 0).all():
continue
try:
params, r2 = __fit_mono_exp__(x, y, p0=p0)
tc = 1 / abs(params[-1])
except RuntimeError:
tc, r2 = (np.nan, 0.0)
time_constants[..., i] = tc
r_squared[..., i] = r2
return time_constants, r_squared