Source code for pmrf._features

from typing import Sequence
import re

import skrf
import jax.numpy as jnp

from pmrf.constants import FeatureT, FeatureInputT
from pmrf.models.model import Model
from pmrf.frequency import Frequency
from pmrf.network_collection import NetworkCollection

def extract_features(
    source: Model | skrf.Network | NetworkCollection,
    frequency: Frequency | skrf.Frequency | None,
    features: FeatureInputT,
    sparam_kind: str = 'all',
    dtype: jnp.dtype = jnp.complex128,
) -> jnp.ndarray:
    """
    Extract frequency-dependent features from a model or network into a unified matrix.

    Features are combined column-by-column, resulting in a matrix where rows
    correspond to frequency points and columns to the requested features. This
    supports extracting scalars (e.g., 's11_mag'), complex parameters, or custom
    functions from both simulation models and measured networks.

    Parameters
    ----------
    source : Model, skrf.Network, or NetworkCollection
        The source object. A single ``skrf.Network`` is copied, given the
        empty name ``''`` and wrapped in a ``NetworkCollection`` so that it is
        addressed by the base label ``''``. Plain lists of networks are not
        accepted; wrap them in a ``NetworkCollection`` first.
    frequency : pmrf.Frequency or skrf.Frequency, optional
        The frequency points for extraction.

        * **Required** if `source` is a Model.
        * **Optional** if `source` is a Network or NetworkCollection; defaults
          to the collection's common frequency. If provided, every network
          whose frequency differs is interpolated to these points.
    features : str, list, dict, or tuple
        The features to extract. See Notes for formatting syntax.
    sparam_kind : {'transmission', 'reflection', 'all'} or None, optional
        Filters port expansion when generic features are requested (e.g.,
        's_re' without specific indices). ``None`` is treated as ``'all'``.
    dtype : data-type, optional
        The desired data-type for the output feature matrix.

    Returns
    -------
    feature_matrix : ndarray
        The extracted feature matrix of shape (M, N), where M is the number of
        frequency points and N is the total number of features.

    Raises
    ------
    TypeError
        If `source` is not a Model, skrf.Network or NetworkCollection.
    ValueError
        If `source` is a Model and `frequency` is None.

    Notes
    -----
    **Feature Syntax**

    Features can be specified using string aliases, explicit tuples, or
    dictionaries for sub-components:

    * **String Aliases:** Convenient shorthands for standard parameters.

      * *Example:* ``'s11_db'``, ``'a21_deg'``, ``'s11_mag'``.
      * *Custom:* Any model attribute dependent on frequency can be accessed
        by name (e.g., ``'my_custom_gain'``).

    * **Explicit Tuples:** Defined as ``(path, parameter, index)``.

      * *Example:* ``('', 's_db', (0, 0))`` is equivalent to ``'s11_db'``.
      * The empty string ``''`` denotes the base model.
      * Two-element forms ``(path, alias)`` and ``(parameter, index)`` are
        also accepted.

    * **Dictionaries:** Used to target submodels (via ``getattr``) or specific
      networks.

      * *Example:* ``{'lna': 's21_db'}`` extracts S21 from the 'lna' submodel.
      * *Nested:* ``{'front_end.lna': 's11'}`` accesses deep submodels.
      * *Measurement:* Maps to the label of a specific network in a collection.

    * **Lists:** A list containing any combination of the above will be
      flattened into columns.
    """
    # Validate up front so unsupported inputs (e.g. a bare list of networks)
    # fail with a clear TypeError instead of an error from feature parsing.
    if not isinstance(source, (skrf.Network, NetworkCollection, Model)):
        raise TypeError("Invalid type to extract_features")
    if sparam_kind is None:
        sparam_kind = 'all'

    # Flatten (and parse) the feature spec against the caller's original
    # source, so generic features are expanded over that object's ports.
    features = _format_features(features, source=source, sparam_kind=sparam_kind)

    if isinstance(source, skrf.Network):
        network = source.copy()
        # Empty name so that features carrying the default base label ''
        # refer to it (NetworkCollection is assumed to look networks up by name).
        network.name = ''
        source = NetworkCollection([network])
    elif not isinstance(source, NetworkCollection):
        # Remaining case is a Model (validated above).
        if frequency is None:
            raise ValueError("Frequency must be passed when extracting features from a model")
        if isinstance(frequency, skrf.Frequency):
            frequency = Frequency.from_skrf(frequency)
        return _extract_model_features(source, features, frequency, dtype=dtype)

    # Measured data: default to the networks' shared frequency, but honour a
    # caller-supplied grid (the extractor interpolates networks onto it).
    if frequency is None:
        frequency = source.common_frequency()
    return _extract_measured_features(source, features, frequency, dtype=dtype)
def _format_features(
    features: FeatureInputT,
    *,
    base_label='',
    source: Model | skrf.Network | NetworkCollection | None = None,
    sparam_kind: str = 'all',
) -> list[FeatureT]:
    """
    Flatten a feature specification into a list of ``(label, feature, ports)`` tuples.

    Parameters
    ----------
    features : FeatureInputT
        A single alias string, a tuple form, a dict mapping sub-object paths
        (or network labels) to nested specs, or a sequence mixing any of these.
    base_label : str, optional
        Label given to entries that do not name their own. For dict keys it is
        prefixed (dot-separated) to the key so labels remain full paths from
        the root source.
    source : Model, skrf.Network, NetworkCollection or None, optional
        Object the features refer to. Used to resolve dict keys and, when every
        entry lacks port indices and shares one label, to expand the entries
        over ``source.port_tuples``. No expansion happens when it is None.
    sparam_kind : {'all', 'reflection', 'transmission'} or None, optional
        Which port pairs to keep during expansion; None behaves like 'all'.

    Returns
    -------
    list of tuple
        ``(label, feature, ports)`` entries; ``ports == (-1, -1)`` means the
        entry carries no port indices.

    Raises
    ------
    ValueError
        If an alias is malformed, or `sparam_kind` is unknown when port
        expansion is needed.
    """
    # Dicts map a sub-object to the features requested from it: recurse with
    # that sub-object as the source and its full dotted path as the label.
    if isinstance(features, dict):
        features_out = []
        for attr, attr_features in features.items():
            label = f'{base_label}.{attr}' if base_label else attr
            features_out.extend(_format_features(
                attr_features,
                base_label=label,
                source=_resolve_sub_source(source, attr),
                sparam_kind=sparam_kind,
            ))
        return features_out

    # Normalise to a sequence of raw entries (a string is one alias, not a sequence of chars).
    if isinstance(features, str) or not isinstance(features, Sequence):
        raw_features = [features]
    else:
        raw_features = features

    features_out = []
    for raw_feature in raw_features:
        if isinstance(raw_feature, dict):
            features_out.extend(_format_features(
                raw_feature, base_label=base_label, source=source, sparam_kind=sparam_kind,
            ))
        else:
            features_out.append(_parse_raw_feature(raw_feature, base_label))

    # Generic features (no ports anywhere, single label) are expanded over the
    # source's port pairs: one column per (port pair, feature), ports outermost.
    # TODO: restrict expansion to features defined on the base Model class
    # (a `defining_class` check is not working yet, e.g. for s_mag).
    if source is not None and features_out:
        first_label = features_out[0][0]
        no_ports_passed = all(ports == (-1, -1) for _, _, ports in features_out)
        all_labels_equal = all(label == first_label for label, _, _ in features_out)
        if no_ports_passed and all_labels_equal:
            port_tuples = _filter_port_tuples(source.port_tuples, sparam_kind)
            features_out = [
                (first_label, feature, ports)
                for ports in port_tuples
                for _, feature, _ in features_out
            ]

    return features_out


def _resolve_sub_source(source, attr):
    """Return the object a dict key in a feature spec refers to (None if there is no source).

    NetworkCollection keys are looked up directly (network labels may contain
    dots); for any other source, ``attr`` is a dot-separated attribute path.
    """
    if source is None:
        return None
    if isinstance(source, NetworkCollection):
        return source[attr]
    obj = source
    for part in attr.split('.'):
        obj = getattr(obj, part)
    return obj


def _parse_raw_feature(raw_feature, base_label: str) -> tuple:
    """Convert one non-dict feature entry into a ``(label, feature, ports)`` tuple.

    Accepted forms:
      1) 'alias' or 'path.to.submodel.alias'
      2) ('label', 'alias')
      3) ('feature', ports)
      4) ('label', 'feature', ports)
    """
    # Option 1: the text before the last '.' is the (dotted) label.
    if isinstance(raw_feature, str):
        path, sep, alias = raw_feature.rpartition('.')
        label = path if sep else base_label
        feature, ports = _parse_feature_alias(alias)
        return (label, feature, ports)
    # Option 2
    if len(raw_feature) == 2 and isinstance(raw_feature[1], str):
        label, alias = raw_feature[0], raw_feature[1]
        feature, ports = _parse_feature_alias(alias)
        return (label, feature, ports)
    # Option 3
    if len(raw_feature) == 2:
        feature, ports = raw_feature
        return (base_label, feature, ports)
    # Option 4
    label, feature, ports = raw_feature
    return (label, feature, ports)


def _filter_port_tuples(port_tuples, sparam_kind):
    """Select the port pairs used for expansion according to ``sparam_kind``."""
    if sparam_kind is None or sparam_kind == 'all':
        return list(port_tuples)
    if sparam_kind == 'reflection':
        return [pt for pt in port_tuples if pt[0] == pt[1]]
    if sparam_kind == 'transmission':
        return [pt for pt in port_tuples if pt[0] != pt[1]]
    raise ValueError('Unknown S-parameter type for port expansion')


# Letters, then an optional *pair* of port digits, then any suffix. The two
# digits are matched together so a lone digit stays part of the name
# (e.g. 'z0' or 'gain2' are not truncated to 'z' / 'gain').
_FEATURE_ALIAS_RE = re.compile(r'^([a-zA-Z]+)(?:(\d)(\d))?(.*)$')


def _parse_feature_alias(alias: str) -> tuple[str, tuple[int, int]]:
    """
    Convert a feature alias like 's11_mag' to a feature tuple like ('s_mag', (0, 0)).

    Supports:
    - Feature names with or without two-digit (1-based) port numbers.
    - Optional suffixes (e.g., '_mag', '_db').
    - Returns (-1, -1) for features without ports.

    Raises
    ------
    ValueError
        If the alias does not start with letters, or a port digit is 0
        (which would collide with the (-1, -1) "no ports" marker).
    """
    match = _FEATURE_ALIAS_RE.match(alias)
    if not match:
        raise ValueError(f"Invalid feature alias format: '{alias}'")
    prefix, port1, port2, suffix = match.groups()
    if port1 is None:
        return (prefix + suffix, (-1, -1))
    if port1 == '0' or port2 == '0':
        raise ValueError(f"Port numbers in feature alias '{alias}' are 1-based; got 0")
    return (prefix + suffix, (int(port1) - 1, int(port2) - 1))


def _extract_model_features(model: Model, features: list[FeatureT], freq: Frequency, dtype: jnp.dtype) -> jnp.ndarray:
    """Evaluate each feature on ``model`` (or the submodel its label names) at ``freq``.

    Returns an array of shape ``(len(freq), len(features))`` with the given dtype.
    Properties whose name has 'mn' at positions 2-3 are called as
    ``prop(freq, m, n)``; ported properties are called as ``prop(freq)`` and
    indexed ``[:, m, n]``; port-less ones are used as ``prop(freq)``.

    Raises
    ------
    ValueError
        If a port index exceeds the addressed (sub)model's port count.
    """
    X = jnp.zeros((len(freq), len(features)), dtype=dtype)
    for d, feature in enumerate(features):
        label, prop, (m, n) = feature[0], feature[1], feature[2]

        # Walk the dotted label down to the submodel that owns the feature.
        feature_model = model
        if label != '':
            for sublabel in label.split('.'):
                feature_model = getattr(feature_model, sublabel)

        # Validate ports against the addressed submodel (falling back to the
        # root model if the sub-object does not expose nports).
        nports = getattr(feature_model, 'nports', model.nports)
        if m >= nports or n >= nports:
            raise ValueError(f'Property {prop} for ports {m+1}{n+1} requested but network is a {nports}-port')

        xfn = getattr(feature_model, prop)
        if prop[2:4] == 'mn':
            x = xfn(freq, m, n)
        elif m != -1 and n != -1:
            x = xfn(freq)[:, m, n]
        else:
            x = xfn(freq)
        X = X.at[:, d].set(x)
    return X


def _extract_measured_features(networks: NetworkCollection, features: list[FeatureT], freq: Frequency | skrf.Frequency, dtype: jnp.dtype) -> jnp.ndarray:
    """Read each feature from the network its label names, interpolated onto ``freq``.

    Returns an array of shape ``(len(freq), len(features))`` with the given dtype.
    Each network is interpolated at most once per call, even when several
    features come from it.

    Raises
    ------
    ValueError
        If a port index exceeds the network's port count, or a port-less
        feature names a per-port (3-D) quantity on a multi-port network.
    """
    n_frequencies = len(freq)
    if isinstance(freq, Frequency):
        freq = freq.to_skrf()

    X = jnp.zeros((n_frequencies, len(features)), dtype=dtype)
    resampled = {}  # label -> network on the target frequency grid
    for d, feature in enumerate(features):
        label, prop, (m, n) = feature[0], feature[1], feature[2]

        if label not in resampled:
            ntwk = networks[label]
            if ntwk.frequency != freq:
                ntwk = ntwk.interpolate(freq)
            resampled[label] = ntwk
        ntwk = resampled[label]

        if m >= ntwk.nports or n >= ntwk.nports:
            raise ValueError(f'Property {prop} for ports {m+1}{n+1} requested but network is a {ntwk.nports}-port')

        # Networks expose full matrices, so e.g. 's_mn_db' maps to 's_db'.
        if prop[2:4] == 'mn':
            i = prop.index('_')
            prop_new = prop[0:i]
            if len(prop) > i + 3:
                prop_new += prop[i+3:]
            prop = prop_new

        values = getattr(ntwk, prop)
        if m == -1 and n == -1:
            if jnp.ndim(values) != 3:
                # Frequency-only quantity: use as-is.
                x = values
            elif ntwk.nports == 1:
                x = values[:, 0, 0]
            else:
                # Indexing with (-1, -1) would silently pick the last port.
                raise ValueError(f'Property {prop} requires port indices for a {ntwk.nports}-port network')
        else:
            x = values[:, m, n]
        X = X.at[:, d].set(x)
    return X