pmrf.fitting.FrequentistContext

class pmrf.fitting.FrequentistContext(model, measured, frequency, features, measured_features, output_path=None, output_root=None, sparam_kind=None, logger=None, cost_function=None)

Bases: FitContext

Context object for frequentist fitting. In addition to the FitContext fields, it holds the cost function that defines the error metric to be minimized.

Parameters:
  • model (Model)

  • measured (Network | NetworkCollection)

  • frequency (Frequency)

  • features (list[tuple[str, str, tuple[int, int]]])

  • measured_features (ndarray)

  • output_path (str | None)

  • output_root (str | None)

  • sparam_kind (str | None)

  • logger (Logger | None)

  • cost_function (Module | None)
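The annotation on features fixes its shape, but this page does not say what the two strings or the integer pair mean (the pair may be a port-index pair, but that is an assumption). The stand-alone sketch below only checks that a value matches the annotated type list[tuple[str, str, tuple[int, int]]]; the function name is_valid_features and the sample entries are hypothetical placeholders, not real pmrf feature names.

```python
from typing import Any


def is_valid_features(features: Any) -> bool:
    """Check that `features` matches list[tuple[str, str, tuple[int, int]]]."""
    if not isinstance(features, list):
        return False
    for item in features:
        if not (isinstance(item, tuple) and len(item) == 3):
            return False
        first, second, pair = item
        # The first two entries must be strings (their meaning is not documented here).
        if not (isinstance(first, str) and isinstance(second, str)):
            return False
        # The last entry must be a 2-tuple of integers.
        if not (isinstance(pair, tuple) and len(pair) == 2
                and all(isinstance(i, int) for i in pair)):
            return False
    return True


# Hypothetical placeholder entries; real feature names come from pmrf.
features = [("feature_a", "kind_x", (0, 0)), ("feature_b", "kind_y", (1, 0))]
print(is_valid_features(features))              # True
print(is_valid_features([("a", "b", (0,))]))    # False
```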

cost_function

The sequence of functions defining the error metric, applied to the error between the measured and model features.

Type:

eqx.Module or None
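As a rough illustration of "a sequence of functions defining the error metric", the sketch below composes steps that turn a list of (possibly complex) feature errors into one scalar. It deliberately avoids Equinox and JAX: ComposedCost is a hypothetical frozen dataclass standing in for an eqx.Module (Equinox modules are themselves immutable dataclasses), and the choice of magnitude-squared followed by a mean is an assumption, not pmrf's default metric.

```python
from dataclasses import dataclass
from typing import Callable, Sequence


@dataclass(frozen=True)
class ComposedCost:
    """Hypothetical stand-in for an eqx.Module cost: applies each step in order."""
    steps: Sequence[Callable]

    def __call__(self, errors):
        value = errors
        for step in self.steps:
            value = step(value)
        return value


def magnitude_squared(errors):
    # |e|^2 for each residual; works for complex values such as S-parameter differences.
    return [e.real ** 2 + e.imag ** 2 for e in errors]


def mean(values):
    return sum(values) / len(values)


cost = ComposedCost(steps=(magnitude_squared, mean))
residuals = [1 + 1j, 0.5 - 0.5j, 0j]
print(cost(residuals))  # mean of [2.0, 0.5, 0.0]
```

Keeping the steps as data (rather than hard-coding one formula) makes it easy to swap the reduction, for example a sum instead of a mean, without touching the rest of the pipeline.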

__init__(model, measured, frequency, features, measured_features, output_path=None, output_root=None, sparam_kind=None, logger=None, cost_function=None)

Parameters:
  • model (Model)

  • measured (Network | NetworkCollection)

  • frequency (Frequency)

  • features (list[tuple[str, str, tuple[int, int]]])

  • measured_features (ndarray)

  • output_path (str | None)

  • output_root (str | None)

  • sparam_kind (str | None)

  • logger (Logger | None)

  • cost_function (Module | None)

Return type:

None

Methods

__delattr__(name, /)

Implement delattr(self, name).

__dir__()

Default dir() implementation.

__eq__(other)

Return self==value.

__format__(format_spec, /)

Default object formatter.

__ge__(value, /)

Return self>=value.

__getattribute__(name, /)

Return getattr(self, name).

__gt__(value, /)

Return self>value.

__init__(model, measured, frequency, ...[, ...])

__init_subclass__

This method is called when a class is subclassed.

__le__(value, /)

Return self<=value.

__lt__(value, /)

Return self<value.

__ne__(value, /)

Return self!=value.

__new__(**kwargs)

__reduce__()

Helper for pickle.

__reduce_ex__(protocol, /)

Helper for pickle.

__repr__()

Return repr(self).

__setattr__(name, value, /)

Implement setattr(self, name, value).

__sizeof__()

Size of object in memory, in bytes.

__str__()

Return str(self).

__subclasshook__

Abstract classes can override this to customize issubclass().

bounds()

Retrieve the lower and upper bounds for all model parameters.

make_cost_function([as_numpy])

Create the cost function to be minimized.

make_feature_function([as_numpy])

Create a JIT-compiled function to extract features from model parameters.

model_param_names()

Get the names of the flat parameters of the model.

settings([solver_kwargs, fitter_kwargs])

Create a FitSettings object from the current context.
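The methods model_param_names(), bounds() and make_cost_function() are designed to be used together by an optimizer. The sketch below shows one plausible way to combine them, using a hypothetical StubContext that only mimics those three method names (it is not the real class) and a simple bounded random search in place of a real optimizer. That the parameter names, the bounds and the flat parameter vector share one ordering is an assumption.

```python
import random


class StubContext:
    """Hypothetical stand-in exposing the same method names as FrequentistContext."""

    def model_param_names(self):
        return ["R", "L"]  # hypothetical parameter names

    def bounds(self):
        return ([0.0, 0.0], [10.0, 5.0])  # (lower_bounds, upper_bounds)

    def make_cost_function(self, as_numpy=False):
        # as_numpy is accepted but ignored by this stub; plain floats are used.
        target = (3.0, 1.5)

        def cost(x):
            return (x[0] - target[0]) ** 2 + (x[1] - target[1]) ** 2

        return cost


def random_search(ctx, n_iter=2000, seed=0):
    """Minimal bounded random search; a stand-in for a real optimizer."""
    rng = random.Random(seed)
    lower, upper = ctx.bounds()
    cost = ctx.make_cost_function(as_numpy=True)
    # Start from the centre of the box defined by the bounds.
    best_x = [(lo + hi) / 2 for lo, hi in zip(lower, upper)]
    best_c = cost(best_x)
    for _ in range(n_iter):
        x = [rng.uniform(lo, hi) for lo, hi in zip(lower, upper)]
        c = cost(x)
        if c < best_c:
            best_x, best_c = x, c
    # Pair each flat parameter with its name for readability.
    return dict(zip(ctx.model_param_names(), best_x)), best_c


params, final_cost = random_search(StubContext())
print(params, final_cost)
```

With the real class, a gradient-free or bounded SciPy optimizer would normally replace random_search; the point here is only the data flow between the three methods.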

Attributes

__annotations__

__dataclass_fields__

__dataclass_params__

__dict__

__doc__

__hash__

__match_args__

__module__

__weakref__

list of weak references to the object (if defined)

cost_function

logger

output_path

output_root

sparam_kind

model

measured

frequency

features

measured_features

cost_function: Module | None = None
make_cost_function(as_numpy=False)

Create the cost function to be minimized.

The cost function computes the error between the measured and model features, applies the cost_function transformation to that error, and returns a scalar value.

Parameters:

as_numpy (bool, optional, default=False) – If True, return a function that works with NumPy arrays; otherwise, return one that works with JAX arrays.

Returns:

The JIT-compiled cost function taking flat parameters and returning a scalar cost.

Return type:

callable
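The steps described above (model features from flat parameters, error against measured features, transformation to a scalar) can be sketched without JAX as follows. make_cost_function_sketch, the straight-line toy model and the sum-of-squares transform are all hypothetical; the real method also JIT-compiles the result, and the sign convention model minus measured is an assumption (it does not matter for a squared metric).

```python
from typing import Callable, Sequence


def make_cost_function_sketch(
    feature_fn: Callable[[Sequence[float]], list],
    measured_features: list,
    transform: Callable[[list], float],
) -> Callable[[Sequence[float]], float]:
    """Hypothetical re-creation of the documented steps, without JIT or JAX."""

    def cost(flat_params):
        model_features = feature_fn(flat_params)
        # Error between model and measured features, element by element.
        errors = [m - d for m, d in zip(model_features, measured_features)]
        # Reduce the error to a scalar with the configured transformation.
        return transform(errors)

    return cost


# Toy model: the "features" are a straight line a*f + b at three frequencies.
freqs = [1.0, 2.0, 3.0]


def line_features(p):
    return [p[0] * f + p[1] for f in freqs]


measured = [3.0, 5.0, 7.0]  # generated by a=2, b=1


def sum_of_squares(errs):
    return sum(abs(e) ** 2 for e in errs)


cost = make_cost_function_sketch(line_features, measured, sum_of_squares)
print(cost([2.0, 1.0]))  # 0.0
print(cost([2.0, 0.0]))  # 3.0
```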

bounds()

Retrieve the lower and upper bounds for all model parameters.

Returns:

A tuple containing (lower_bounds, upper_bounds).

Return type:

tuple of jnp.ndarray

Raises:

Exception – If any parameter is not associated with a parameter group.
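A minimal sketch of the documented behaviour: every flat parameter takes its bounds from its parameter group, and a parameter without a group raises an Exception. The Param record, the group-name-to-bounds mapping and the example groups are hypothetical; pmrf's actual parameter-group objects are not described on this page. The clip_to_bounds helper shows one common use of the returned (lower_bounds, upper_bounds) pair.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple


@dataclass(frozen=True)
class Param:
    """Hypothetical flat parameter record: a name plus an optional group name."""
    name: str
    group: Optional[str] = None


def bounds_sketch(
    params: List[Param], group_bounds: Dict[str, Tuple[float, float]]
) -> Tuple[List[float], List[float]]:
    """Collect (lower, upper) bounds for each parameter from its group."""
    lower, upper = [], []
    for p in params:
        if p.group is None or p.group not in group_bounds:
            # Mirrors the documented failure mode: ungrouped parameters are an error.
            raise Exception(f"parameter {p.name!r} is not associated with a parameter group")
        lo, hi = group_bounds[p.group]
        lower.append(lo)
        upper.append(hi)
    return lower, upper


def clip_to_bounds(x, lower, upper):
    """Project a flat parameter vector back inside the box [lower, upper]."""
    return [min(max(v, lo), hi) for v, lo, hi in zip(x, lower, upper)]


groups = {"resistance": (0.0, 100.0), "length": (0.001, 0.5)}  # hypothetical groups
params = [Param("R1", "resistance"), Param("len", "length")]
lower, upper = bounds_sketch(params, groups)
print(lower, upper)                                # [0.0, 0.001] [100.0, 0.5]
print(clip_to_bounds([150.0, 0.2], lower, upper))  # [100.0, 0.2]
```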
