Source code for nanite.model.core

from collections import OrderedDict
import inspect
import warnings

import numpy as np

from . import residuals


class ModelError(Exception):
    """Base class for errors raised while loading/checking nanite models

    Derives from :class:`Exception` (not :class:`BaseException`) so that
    regular ``except Exception`` handlers can catch model errors.
    """


class ModelIncompleteError(ModelError):
    """Raised when a model module is missing required attributes"""


class ModelImplementationError(ModelError):
    """Raised when a model module's attributes are inconsistent"""


class ModelImportError(ModelError):
    """Model import error (not raised in this module; presumably used
    when importing/registering a model fails - confirm with callers)"""


class ModelImplementationWarning(UserWarning):
    """Warning for model modules that deviate from expected conventions"""
class NaniteFitModel:
    """Validated wrapper around a nanite model module

    The module passed to the constructor is checked for completeness and
    consistency. Default ``model``/``residual`` functions are attached to
    it when missing, and its attributes are exposed on the instance.
    """

    def __init__(self, model_module):
        """Initialize the model with an imported Python module

        Parameters
        ----------
        model_module: module
            Python module implementing the model (see `_module_check`
            for the attributes it must define)

        Raises
        ------
        ModelIncompleteError
            If the module lacks required attributes
        ModelImplementationError
            If the module's attributes are inconsistent
        """
        self.module = model_module
        self._module_check()
        self._module_autocomplete()
        # propagate all module attributes to this instance
        # standard parameters
        self.get_parameter_defaults = self.module.get_parameter_defaults
        self.model_doc = self.module.model_doc
        self.model_key = self.module.model_key
        self.model_name = self.module.model_name
        self.parameter_keys = self.module.parameter_keys
        self.parameter_names = self.module.parameter_names
        self.parameter_units = self.module.parameter_units
        self.valid_axes_x = self.module.valid_axes_x
        self.valid_axes_y = self.module.valid_axes_y
        # optional ancillary parameters
        if hasattr(self.module, "compute_ancillaries"):
            self.has_module_ancillaries = True
            self.parameter_anc_keys = self.module.parameter_anc_keys
            self.parameter_anc_names = self.module.parameter_anc_names
            self.parameter_anc_units = self.module.parameter_anc_units
        else:
            self.has_module_ancillaries = False
        # model function
        self.model = self.module.model
        # residuals
        self.residual = self.module.residual

    def __str__(self):
        return f"NaniteFitModel '{self.model_key}'"

    def __repr__(self):
        return f"<NaniteFitModel '{self.model_key}' at {hex(id(self))}>"

    @staticmethod
    def _has_padded_units(units):
        """Return True if any unit string has leading/trailing whitespace

        The comparison is element-wise, so `units` may be a list or a
        tuple.
        """
        return any(unit != unit.strip() for unit in units)

    def _module_autocomplete(self):
        """Add any missing attributes to the underlying model module

        If the module does not define ``residual`` or ``model``, default
        wrappers around ``model_func`` are created with the `residuals`
        module and attached to the module object.
        """
        # check for residuals function
        if not hasattr(self.module, "residual"):
            # use the default residual function
            self.module.residual = residuals.get_default_residuals_wrapper(
                model_function=self.module.model_func)
        # check for modeling function
        if not hasattr(self.module, "model"):
            # use the default modeling function
            self.module.model = residuals.get_default_modeling_wrapper(
                model_function=self.module.model_func)

    def _module_check(self):
        """Checks whether the model's module is set up correctly

        Raises
        ------
        ModelIncompleteError
            If required attributes are missing. When the module defines
            ``compute_ancillaries``, this also applies to the ancillary
            attributes.
        ModelImplementationError
            If the parameter lists have different lengths, the parameter
            names are not unique, or ``parameter_keys`` does not match
            the keys of ``get_parameter_defaults`` in order

        Warns
        -----
        ModelImplementationWarning
            If units contain leading/trailing whitespace, or if the
            arguments of ``model_func`` (after the abscissa) are not in
            the order of ``parameter_keys``
        """
        # sanity checks ("model_func" is needed below and for the
        # default model/residual wrappers)
        required = [
            "get_parameter_defaults",
            "model_doc",
            "model_key",
            "model_name",
            "parameter_keys",
            "parameter_names",
            "parameter_units",
            "valid_axes_x",
            "valid_axes_y",
            "model_func",
        ]
        missing = [attr for attr in required
                   if not hasattr(self.module, attr)]
        if missing:
            raise ModelIncompleteError(
                f"Model `{self.module}` is missing the following "
                + f"attributes: {', '.join(missing)}")

        model_key = self.module.model_key

        # check for completeness of ancillary parameter recipe
        if hasattr(self.module, "compute_ancillaries"):
            required_anc = ["parameter_anc_keys",
                            "parameter_anc_names",
                            "parameter_anc_units",
                            ]
            missing_anc = [attr for attr in required_anc
                           if not hasattr(self.module, attr)]
            if missing_anc:
                raise ModelIncompleteError(
                    f"Model `{model_key}` is missing the following "
                    + f"attributes: {', '.join(missing_anc)}")

        # check length of modeling lists
        num_keys = len(self.module.parameter_keys)
        if num_keys != len(self.module.parameter_names):
            raise ModelImplementationError(
                "'parameter_keys' and 'parameter_names' have different "
                + f"lengths for model '{model_key}'!")
        if num_keys != len(self.module.parameter_units):
            raise ModelImplementationError(
                "'parameter_keys' and 'parameter_units' have different "
                + f"lengths for model '{model_key}'!")

        # check for spaces in units (element-wise, so tuples work too)
        if self._has_padded_units(self.module.parameter_units):
            warnings.warn("The `parameter_units` should not contain leading "
                          + f"or trailing spaces. Please check {model_key}!",
                          ModelImplementationWarning)
        if (hasattr(self.module, "parameter_anc_units")
                and self._has_padded_units(self.module.parameter_anc_units)):
            warnings.warn(
                "The `parameter_anc_units` should not contain leading "
                + f"or trailing spaces. Please check {model_key}!",
                ModelImplementationWarning)

        # check for label uniqueness
        if (len(self.module.parameter_names)
                != len(set(self.module.parameter_names))):
            raise ModelImplementationError(
                f"'parameter_names' should be unique for '{model_key}'!")

        # checks for model parameters
        p_def = list(self.module.get_parameter_defaults().keys())
        p_arg = list(inspect.signature(
            self.module.model_func).parameters.keys())
        for ii, key in enumerate(self.module.parameter_keys):
            if ii >= len(p_def):
                raise ModelImplementationError(
                    "Please check 'parameter_keys' and "
                    + f"'get_parameter_defaults' of the model '{model_key}'. "
                    + f"Key {key} is missing in 'get_parameter_defaults'!")
            if key != p_def[ii]:
                raise ModelImplementationError(
                    "Please check 'parameter_keys' and "
                    + f"'get_parameter_defaults' of the model '{model_key}'. "
                    + f"Keys {key} and {p_def[ii]} are not in order!")
            # The first argument of the model function is the abscissa,
            # hence the offset of one. A missing argument counts as a
            # mismatch.
            arg = p_arg[ii + 1] if ii + 1 < len(p_arg) else None
            if key != arg:
                warnings.warn(
                    "Please make sure that the parameters of the model "
                    + "function are in the same order as in 'parameter_keys' "
                    + f"for the model '{model_key}'! "
                    + "The abscissa (usually `delta`) should come first. "
                    + "This warning may become an Exception in the future!",
                    ModelImplementationWarning)

    def compute_ancillaries(self, fd):
        """Compute ancillary parameters for a force-distance dataset

        Ancillary parameters include parameters that:

        - are unrelated to fitting: They may just be important parameters
          to the user.
        - require the entire dataset: They cannot be extracted during
          fitting, because they require more than just the approach xor
          retract curve to compute (e.g. hysteresis, jump of retract curve
          at maximum indentation). They may, additionally, depend on
          initial fit parameters set by the user.
        - require a fit: They are dependent on fitting parameters but
          are not required during fitting.

        Notes
        -----
        If an ancillary parameter name matches that of a fitting
        parameter, then it is assumed that it can be used for fitting.
        Please see :func:`nanite.indent.Indentation.get_initial_fit_parameters`
        and :func:`nanite.fit.guess_initial_parameters`.

        Ancillary parameters are set to `np.nan` if they cannot be
        computed.

        Parameters
        ----------
        fd: nanite.indent.Indentation
            The force-distance data for which to compute the ancillary
            parameters

        Returns
        -------
        ancillaries: collections.OrderedDict
            key-value dictionary of ancillary parameters
        """
        # TODO:
        # - ancillaries are not cached yet (some ancillaries might depend on
        #   fitting interval or other initial parameters - take that into
        #   account)
        # - "max_indent" actually belongs to "common_ancillaries" (see fit.py)
        anc_ord = OrderedDict()
        # general: each entry is (label, unit, compute function)
        for key, (_, _, gmeth) in ANCILLARY_COMMON.items():
            anc_ord[key] = gmeth(fd)
        # from module
        if self.has_module_ancillaries:
            anc_md = self.module.compute_ancillaries(fd)
            for kk in self.parameter_anc_keys:
                anc_ord[kk] = anc_md[kk]
        return anc_ord

    def get_anc_parm_keys(self):
        """Return the key names of a model's ancillary parameters

        The common ancillary keys come first, followed by the module's
        own ancillary keys (if any). A new list is returned each time.
        """
        akeys = list(ANCILLARY_COMMON.keys())
        if self.has_module_ancillaries:
            akeys += self.parameter_anc_keys
        return akeys

    def _get_parm_property(self, key, prop):
        """Return the label (`prop="name"`) or unit (`prop="unit"`) of `key`

        Model parameters are searched first, then the module's ancillary
        parameters, then the common ancillary parameters.

        Raises
        ------
        KeyError
            If `key` is not known to this model
        """
        # prop -> (model attribute, module ancillary attribute,
        #          index into the ANCILLARY_COMMON value tuple)
        lookup = {
            "name": ("parameter_names", "parameter_anc_names", 0),
            "unit": ("parameter_units", "parameter_anc_units", 1),
        }
        attr, anc_attr, common_idx = lookup[prop]
        if key in self.parameter_keys:
            idx = self.parameter_keys.index(key)
            return getattr(self, attr)[idx]
        elif (self.has_module_ancillaries
              and key in self.parameter_anc_keys):
            idx = self.parameter_anc_keys.index(key)
            return getattr(self, anc_attr)[idx]
        elif key in ANCILLARY_COMMON:
            return ANCILLARY_COMMON[key][common_idx]
        else:
            raise KeyError(
                f"Could not find parameter {prop} for '{key}' in '{self}'!")

    def get_parm_name(self, key):
        """Return parameter label

        Parameters
        ----------
        key: str
            The parameter key (e.g. "E")

        Returns
        -------
        parm_name: str
            The parameter label (e.g. "Young's Modulus")

        Raises
        ------
        KeyError
            If `key` is not known to this model
        """
        return self._get_parm_property(key, "name")

    def get_parm_unit(self, key):
        """Return parameter unit

        Parameters
        ----------
        key: str
            The parameter key (e.g. "E")

        Returns
        -------
        parm_unit: str
            The parameter unit (e.g. "Pa")

        Raises
        ------
        KeyError
            If `key` is not known to this model
        """
        return self._get_parm_property(key, "unit")
def compute_anc_max_indent(fd):
    """Compute ancillary parameter 'Maximum indentation'

    The value is the fitted contact point minus the tip position at the
    index where the approach-curve fit is largest.

    Parameters
    ----------
    fd: nanite.indent.Indentation
        Force-distance data. It must contain "tip position" and "fit"
        and have a fitted "contact_point" parameter.

    Returns
    -------
    mival: float
        Maximum indentation, or `np.nan` if it cannot be computed
    """
    if not ("tip position" in fd
            and "fit" in fd
            and "params_fitted" in fd.fit_properties
            and "contact_point" in fd.fit_properties["params_fitted"]):
        return np.nan
    fit = np.asarray(fd.appr["fit"])
    # NaN entries must not be picked as the maximum. Without any finite
    # fit value (or with an empty fit) there is no maximum to locate.
    if np.all(np.isnan(fit)):
        return np.nan
    cp = fd.fit_properties["params_fitted"]["contact_point"].value
    idmax = np.nanargmax(fit)
    mi = fd.appr["tip position"][idmax]
    return cp - mi
#: Common ancillary parameters; each value is a tuple
#: (label, unit, function computing the value from a dataset `fd`)
ANCILLARY_COMMON = OrderedDict([
    ("max_indent", ("Maximum indentation", "m", compute_anc_max_indent)),
])