"""
I/O helpers, e.g. for loading and saving ParamRF models.
"""
import os
import jsonpickle
from typing import BinaryIO
from pmrf.models.model import Model
# Dict keys that only appear in the serialized state of a ParamRF model. A plain
# dict carrying any of them is taken to be a model jsonpickle failed to rebuild.
# NOTE(review): this is a heuristic -- a legitimate user dict with e.g. a "z0"
# key would also match.
_DEGRADED_MODEL_MARKERS = ("_param_groups", "_separator", "z0")


def _verify_no_degraded_models(obj, current_path="root", _seen=None):
    """
    Recursively search ``obj`` for submodels that were degraded into dicts.

    Walks dicts, lists, tuples and dataclass-like instances (which includes
    Equinox modules) and raises on the first dict that carries one of the
    ``_DEGRADED_MODEL_MARKERS`` keys.

    Parameters
    ----------
    obj : object
        The decoded object (sub)graph to inspect.
    current_path : str, default "root"
        Human-readable location of ``obj``, used in the error message.
    _seen : set of int, optional
        ids of containers already visited. This guards against shared or
        cyclic references in the decoded graph, which would otherwise cause
        repeated work or unbounded recursion.

    Raises
    ------
    TypeError
        If a degraded submodel dict is found.
    """
    if _seen is None:
        _seen = set()

    is_container = isinstance(obj, (dict, list, tuple))
    # Dataclass *classes* also expose __dataclass_fields__; only walk instances.
    is_dataclass_instance = (
        hasattr(obj, "__dataclass_fields__") and not isinstance(obj, type)
    )
    if not (is_container or is_dataclass_instance):
        return
    if id(obj) in _seen:
        return
    _seen.add(id(obj))

    if isinstance(obj, dict):
        if any(marker in obj for marker in _DEGRADED_MODEL_MARKERS):
            raise TypeError(
                f"Degraded submodel found at path '{current_path}'. "
                "A nested model failed to instantiate and became a dictionary. "
                "Did you move or rename a component class in your codebase?"
            )
        for key, value in obj.items():
            _verify_no_degraded_models(value, f"{current_path}.{key}", _seen)
    elif isinstance(obj, (list, tuple)):
        for index, value in enumerate(obj):
            _verify_no_degraded_models(value, f"{current_path}[{index}]", _seen)
    else:
        # __dataclass_fields__ also lists InitVar/unset fields that may not
        # exist on the instance, so fall back to None rather than raising.
        for name in obj.__dataclass_fields__:
            _verify_no_degraded_models(
                getattr(obj, name, None), f"{current_path}.{name}", _seen
            )


def load(
    source: str | os.PathLike | BinaryIO, *, verify_nested: bool = False
) -> Model:
    """
    Load a ParamRF model from a file or file-like object.

    Parameters
    ----------
    source : str, os.PathLike or file-like
        The path to the saved model file, or an open file-like object
        containing the model data (text or binary).
    verify_nested : bool, default False
        If True, also walk the decoded object graph and raise if any nested
        submodel was degraded into a dictionary. Off by default, which matches
        the previous behaviour. Note that the nested check is a key-name
        heuristic and can flag legitimate dicts that happen to use one of the
        marker keys (e.g. ``"z0"``).

    Returns
    -------
    Model
        The deserialized ParamRF model instance.

    Raises
    ------
    TypeError
        If the root object is not a ``Model``. This usually happens because its
        class was moved, renamed or deleted. When ``verify_nested`` is True, it
        is also raised if a nested submodel degraded into a dictionary.

    Warnings
    --------
    ``jsonpickle.decode`` reconstructs whatever classes the file names, so
    loading an untrusted file can execute arbitrary code. Only load files
    from trusted sources.
    """
    if isinstance(source, (str, os.PathLike)):
        with open(source, "r", encoding="utf8") as f:
            data = f.read()
    else:
        # May be str (text stream) or bytes (binary stream); jsonpickle's JSON
        # backend accepts either.
        data = source.read()

    # SECURITY: see the Warnings section -- never feed untrusted data here.
    decoded = jsonpickle.decode(data)

    # If jsonpickle cannot import the recorded class, it silently returns
    # a dict instead of the model.
    if not isinstance(decoded, Model):
        raise TypeError(
            f"Failed to load model. Expected a ParamRF Model, but got '{type(decoded).__name__}'. "
            "This almost always happens because the model's class was moved to a different module, "
            "renamed, or deleted. jsonpickle cannot find the import path and silently degraded it to a dict."
        )

    if verify_nested:
        _verify_no_degraded_models(decoded)

    return decoded
def save(target: str | os.PathLike | BinaryIO, model: Model):
    """
    Save a ParamRF model to a file or file-like object.

    The model is converted with ``model._saveable()`` and then encoded to a
    JSON string with jsonpickle.

    Parameters
    ----------
    target : str, os.PathLike or file-like
        Path of the file to write (overwritten if it exists), or an open,
        writable stream. Text streams receive the JSON as ``str``. Binary
        streams receive it encoded as UTF-8, the same encoding used when
        ``target`` is a path.
    model : Model
        The model to save.
    """
    data = jsonpickle.encode(model._saveable())

    if isinstance(target, (str, os.PathLike)):
        with open(target, "w", encoding="utf8") as f:
            f.write(data)
        return

    try:
        target.write(data)
    except TypeError:
        # Binary streams (open(..., "wb"), io.BytesIO, ...) reject str, even
        # though the signature advertises BinaryIO; hand them UTF-8 bytes.
        target.write(data.encode("utf8"))