# Source code for oct2py.io

"""io handling."""
# Copyright (c) oct2py developers.
# Distributed under the terms of the MIT License.


import dis
import inspect
import threading

import numpy as np

try:
    from scipy.io import loadmat, savemat  # type:ignore[import-untyped]
    from scipy.io.matlab import MatlabFunction, MatlabObject  # type:ignore[import-untyped]
    from scipy.sparse import spmatrix  # type:ignore[import-untyped]
except ImportError:
    try:  # noqa
        from scipy.io.matlab.mio5 import MatlabFunction, MatlabObject  # type:ignore[import-untyped]
    except ImportError:
        pass


# pandas is optional.  When it cannot be imported, empty stand-in classes
# with the same names are defined so that the
# ``isinstance(data, (DataFrame, Series))`` check in ``_encode`` still works
# (it simply never matches).
try:
    from pandas import DataFrame, Series
except Exception:

    class Series:  # type:ignore[no-redef]
        """Stand-in for ``pandas.Series`` used when pandas fails to import."""

        pass

    class DataFrame:  # type:ignore[no-redef]
        """Stand-in for ``pandas.DataFrame`` used when pandas fails to import."""

        pass


from .dynamic import OctaveFunctionPtr, OctaveUserClass, OctaveVariablePtr
from .utils import Oct2PyError

_WRITE_LOCK = threading.Lock()


def read_file(path, session=None):
    """Load a MAT-file and convert each stored entry for use in Python.

    Parameters
    ----------
    path
        Location of the file, passed straight to ``scipy.io.loadmat``.
    session
        Optional Octave session forwarded to ``_extract`` for every value.

    Returns
    -------
    dict
        One entry per key returned by ``loadmat``, with each value run
        through ``_extract``.

    Raises
    ------
    Oct2PyError
        If ``loadmat`` raises ``UnicodeDecodeError``; the original error
        text is kept as the message.
    """
    try:
        raw = loadmat(path, struct_as_record=True)
    except UnicodeDecodeError as err:
        raise Oct2PyError(str(err)) from None
    return {name: _extract(content, session) for name, content in raw.items()}


def write_file(obj, path, oned_as="row", convert_to_float=True):
    """Save a Python object to an Octave file on the given path.

    Parameters
    ----------
    obj
        The Python value to store; it is first converted with ``_encode``.
    path
        Destination file, passed to ``scipy.io.savemat`` with
        ``appendmat=False`` so no extension is added.
    oned_as
        Forwarded to ``savemat`` (default ``"row"``).
    convert_to_float
        Forwarded to ``_encode``.

    Raises
    ------
    Oct2PyError
        If ``savemat`` raises ``KeyError``.  This matches the error type
        raised by ``read_file``; it is assumed to derive from ``Exception``
        so existing ``except Exception`` handlers keep working.
    """
    data = _encode(obj, convert_to_float)
    try:
        # scipy.io.savemat is not thread-safe.
        # See https://github.com/scipy/scipy/issues/7260
        with _WRITE_LOCK:
            savemat(path, data, appendmat=False, oned_as=oned_as, long_field_names=True)
    except KeyError:  # pragma: no cover
        msg = "could not save mat file"
        raise Oct2PyError(msg) from None


class Struct(dict):  # type:ignore[type-arg]
    """
    Octave style struct, enhanced.

    Notes
    =====
    Supports dictionary and attribute style access.  Can be pickled,
    and supports code completion in a REPL.

    Examples
    ========
    >>> from pprint import pprint
    >>> from oct2py import Struct
    >>> a = Struct()
    >>> a.b = 'spam'  # a["b"] == 'spam'
    >>> a.c["d"] = 'eggs'  # a.c.d == 'eggs'
    >>> pprint(a)
    {'b': 'spam', 'c': {'d': 'eggs'}}
    """

    def __getattr__(self, attr):
        """Access the dictionary keys for unknown attributes."""
        try:
            # NOTE: keep this a direct ``self[attr]`` call; ``__getitem__``
            # inspects a fixed number of stack frames above itself.
            return self[attr]
        except KeyError:
            msg = "'Struct' object has no attribute %s" % attr
            raise AttributeError(msg) from None

    def __getitem__(self, attr):
        """Get a dict value; create a Struct if requesting a Struct member.

        Existing keys are returned as-is.  Missing keys that are not
        strings, or that start with an underscore, are never auto-created
        and raise ``KeyError``.  Any other missing key is created as an
        empty ``Struct`` only when ``_is_allowed`` accepts the calling
        frame (or the caller's caller); otherwise ``KeyError`` is raised.
        """
        # Do not create a key if the attribute starts with an underscore.
        # Non-string keys cannot be attribute names either, so they get the
        # plain dict lookup (and a KeyError) instead of failing on
        # ``startswith``.
        if attr in self or not isinstance(attr, str) or attr.startswith("_"):
            return dict.__getitem__(self, attr)

        frame = inspect.currentframe()
        if frame is None or frame.f_back is None:
            return None

        # step into the function that called us
        if frame.f_back.f_back and self._is_allowed(frame.f_back.f_back):  # noqa
            dict.__setitem__(self, attr, Struct())
        elif self._is_allowed(frame.f_back):
            dict.__setitem__(self, attr, Struct())
        return dict.__getitem__(self, attr)

    def _is_allowed(self, frame):
        """Check for allowed op code in the calling frame.

        Reads the byte three positions past the frame's last executed
        instruction offset and reports whether it equals one of the
        accepted opcode numbers.
        NOTE(review): this relies on the interpreter's bytecode layout;
        confirm behaviour on each supported Python version.
        """
        allowed = [dis.opmap["STORE_ATTR"], dis.opmap["LOAD_CONST"], dis.opmap.get("STOP_CODE", 0)]
        bytecode = frame.f_code.co_code
        instruction = bytecode[frame.f_lasti + 3]
        return instruction in allowed

    # Attribute assignment and deletion operate directly on the dict keys.
    __setattr__ = dict.__setitem__  # type:ignore[assignment]
    __delattr__ = dict.__delitem__  # type:ignore[assignment]

    @property
    def __dict__(self):
        """Return a copy of the mapping."""
        # Allow for code completion in a REPL.
        return self.copy()
class StructArray(np.recarray):  # type:ignore[type-arg]
    """A Python representation of an Octave structure array.

    Notes
    =====
    Accessing a record returns a Cell containing the values.
    This class is not meant to be directly created by the user.  It is
    created automatically for structure array values received from Octave.
    The last axis is squeezed if it is of size 1 to simplify element access.

    Examples
    ========
    >>> from oct2py import octave
    >>> # generate the struct array
    >>> octave.eval('x = struct("y", {1, 2}, "z", {3, 4});')
    >>> x = octave.pull('x')
    >>> x.y  # attribute access -> oct2py Cell
    Cell([[1.0, 2.0]])
    >>> x['z']  # item access -> oct2py Cell
    Cell([[3.0, 4.0]])
    >>> x[0, 0]  # index access -> numpy record
    (1.0, 3.0)
    >>> x[0, 1].z
    4.0
    """

    def __new__(cls, value, session=None):
        """Create a struct array from a value and optional Octave session.

        Without a session the input is only re-viewed as this class.  With
        a session every field of every record is passed through
        ``_extract`` and stored in a new array of the same shape.
        """
        value = np.asarray(value)

        # Squeeze the last element if it is 1.  A 0-d array has no last
        # axis, so it is skipped here and promoted to 1-d below.
        if value.ndim and value.shape[-1] == 1:
            value = value.squeeze(axis=-1)
        value = np.atleast_1d(value)

        if not session:
            return value.view(cls)

        # Extract the values.
        obj = np.empty(value.size, dtype=value.dtype).view(cls)
        for i, item in enumerate(value.ravel()):
            for name in value.dtype.names:
                obj[i][name] = _extract(item[name], session)
        return obj.reshape(value.shape)

    @property
    def fieldnames(self):
        """The field names of the struct array."""
        return self.dtype.names

    def __getattribute__(self, attr):
        """Return object arrays as cells and all other values unchanged."""
        result = np.recarray.__getattribute__(self, attr)
        if isinstance(result, np.ndarray) and result.dtype.kind == "O":
            return Cell(result)
        return result

    def __getitem__(self, item):
        """Return object arrays as cells and all other values unchanged."""
        result = np.recarray.__getitem__(self, item)
        if isinstance(result, np.ndarray) and result.dtype.kind == "O":
            return Cell(result)
        return result

    def __repr__(self):
        """A str repr for the struct array."""
        shape = self.shape
        # Show 1-d arrays as N x 1.
        if len(shape) == 1:
            shape = (shape[0], 1)
        header = "x".join(str(i) for i in shape) + " StructArray containing the fields:"
        return header + "".join("\n %s" % key for key in self.fieldnames)
class Cell(np.ndarray):  # type:ignore[type-arg]
    """A Python representation of an Octave cell array.

    Notes
    =====
    This class is not meant to be directly created by the user.  It is
    created automatically for cell array values received from Octave.
    The last axis is squeezed if it is of size 1 to simplify element access.

    Examples
    ========
    >>> from oct2py import octave
    >>> # generate the cell array
    >>> octave.eval("x = cell(2,2); x(:) = 1.0;")
    >>> x = octave.pull('x')
    >>> x
    Cell([[1.0, 1.0],
           [1.0, 1.0]])
    >>> x[0]
    Cell([1.0, 1.0])
    >>> x[0].tolist()
    [1.0, 1.0]
    """

    def __new__(cls, value, session=None):
        """Create a cell array from a value and optional Octave session.

        Without a session the (at least 2-d) object array is only re-viewed
        as this class.  With a session every element is passed through
        ``_extract`` and stored in a new array of the same shape.
        """
        # Use atleast_2d to preserve Octave size()
        value = np.atleast_2d(np.asarray(value, dtype=object))

        if not session:
            return value.view(cls)

        # Extract the values.
        obj = np.empty(value.size, dtype=object).view(cls)
        for i, item in enumerate(value.ravel()):
            obj[i] = _extract(item, session)
        obj = obj.reshape(value.shape)  # type:ignore[assignment]
        return obj

    def __repr__(self):
        """A string repr for the cell array."""
        msg = self.view(np.ndarray).__repr__()
        msg = msg.replace("array", "Cell", 1)
        return msg.replace(", dtype=object", "", 1)

    def __getitem__(self, key):
        """Get an element of the array.

        For a single-element cell, index ``0`` returns the element itself
        rather than a row.  Every other key uses normal ndarray indexing.
        """
        # Comparing a multi-element array key with 0 yields an array whose
        # truth value is ambiguous, so array keys skip the shortcut and go
        # straight to ndarray indexing.
        is_array_key = isinstance(key, np.ndarray) and key.ndim > 0
        if not is_array_key and self.size == 1 and key == 0:
            # Note:
            # Can't use `return super().ravel()[0]` here
            key = tuple([0] * self.ndim)
        return super().__getitem__(key)
def _extract(data, session=None):  # noqa
    """Convert the Octave values to values suitable for Python.

    Lists are converted element by element, non-array values are returned
    unchanged, and arrays are mapped according to their kind: user class
    instances, structs, struct arrays, cells, scalars, empties or plain
    arrays.  ``session`` is passed along to nested conversions and is
    required when a ``MatlabObject`` is encountered.
    """
    # Extract each item of a list.
    if isinstance(data, list):
        return [_extract(d, session) for d in data]

    # Ignore leaf objects.
    if not isinstance(data, np.ndarray):
        return data

    # Extract user defined classes.
    if isinstance(data, MatlabObject):
        cls = session._get_user_class(data.classname)
        return cls.from_value(data)

    # Extract struct data.
    if data.dtype.names:
        # Singular struct
        if data.size == 1:
            return _create_struct(data, session)
        # Struct array
        return StructArray(data, session)

    # Extract cells.
    if data.dtype.kind == "O":
        return Cell(data, session)

    # Compress singleton values.
    if data.size == 1:
        return data.item()

    # Compress empty values.
    if data.size == 0:
        if data.dtype.kind in "US":
            return ""
        return []

    # Return standard array.
    return data


def _create_struct(data, session):
    """Create a struct from session data.

    Builds a ``Struct`` with one key per field name of ``data``; each
    field value is converted with ``_extract``.
    """
    out = Struct()
    for name in data.dtype.names:
        item = data[name]
        # Extract values that are cells (they are doubly wrapped).
        if isinstance(item, np.ndarray) and item.dtype.kind == "O":
            item = item.squeeze().tolist()
        out[name] = _extract(item, session)
    return out


def _encode(data, convert_to_float):  # noqa
    """Convert the Python values to values suitable to send to Octave.

    ``convert_to_float`` controls whether integer ndarrays are cast to
    float64; it is passed unchanged to every nested call.  Plain Python
    ``int`` values are always converted to ``float``.

    Raises ``Oct2PyError`` for Octave function pointers and
    ``MatlabFunction`` values, which cannot be written.
    """
    ctf = convert_to_float

    # Handle variable pointer.
    if isinstance(data, (OctaveVariablePtr)):
        return _encode(data.value, ctf)

    # Handle a user defined object.
    if isinstance(data, OctaveUserClass):
        return _encode(OctaveUserClass.to_value(data), ctf)

    # Handle a function pointer.
    if isinstance(data, (OctaveFunctionPtr, MatlabFunction)):
        msg = "Cannot write Octave functions"
        raise Oct2PyError(msg)

    # Handle matlab objects.
    if isinstance(data, MatlabObject):  # type:ignore[unreachable]
        view = data.view(np.ndarray)
        out = MatlabObject(data, data.classname)
        for name in out.dtype.names:
            out[name] = _encode(view[name], ctf)
        return out

    # Integer objects should be converted to floats
    if isinstance(data, int):
        return float(data)

    # Handle pandas series and dataframes
    if isinstance(data, (DataFrame, Series)):
        return _encode(data.values, ctf)

    # Extract and encode values from dict-like objects.
    if isinstance(data, dict):
        return {key: _encode(value, ctf) for key, value in data.items()}

    # Send None as nan.  (``np.NaN`` was removed in NumPy 2.0.)
    if data is None:
        return np.nan

    # Sets are treated like lists.
    if isinstance(data, set):
        return _encode(list(data), ctf)

    # Lists can be interpreted as numeric arrays or cell arrays.
    if isinstance(data, list):
        if _is_simple_numeric(data):
            return _encode(np.array(data), ctf)
        return _encode(tuple(data), ctf)

    # Tuples are handled as cells.
    if isinstance(data, tuple):
        obj = np.empty(len(data), dtype=object)
        for i, item in enumerate(data):
            obj[i] = _encode(item, ctf)
        return obj

    # Sparse data must be floating type.
    if isinstance(data, spmatrix):
        return data.astype(np.float64)

    # Return other data types unchanged.
    if not isinstance(data, np.ndarray):
        return data

    # Extract and encode data from object-like arrays.
    if data.dtype.kind in "OV":
        out = np.empty(data.size, dtype=data.dtype)
        for i, item in enumerate(data.ravel()):
            if data.dtype.names:
                for name in data.dtype.names:
                    out[i][name] = _encode(item[name], ctf)
            else:
                out[i] = _encode(item, ctf)
        return out.reshape(data.shape)

    # Complex 128 is the highest supported by savemat.
    if data.dtype.name == "complex256":
        return data.astype(np.complex128)

    # Convert to float if applicable.
    if ctf and data.dtype.kind in "ui":
        return data.astype(np.float64)

    # Return standard array.
    return data


def _numeric_shape(data):
    """Return the rectangular shape of a nested numeric list, else None.

    ``None`` means the list contains a non-numeric leaf or is ragged
    (siblings that differ in length, nesting depth, or scalar/list kind).
    """
    inner = None
    for item in data:
        if isinstance(item, set):
            item = list(item)  # noqa
        if isinstance(item, list):
            item_shape = _numeric_shape(item)
            if item_shape is None:
                return None
        elif isinstance(item, (int, float, complex)):
            item_shape = ()
        else:
            return None
        # Numpy does not support creating an ndarray from ragged nested
        # sequences, so every sibling must have exactly the same shape.
        if inner is None:
            inner = item_shape
        elif inner != item_shape:
            return None
    return (len(data),) + (inner or ())


def _is_simple_numeric(data):
    """Test if a list contains simple numeric data.

    Returns True only when every leaf is an ``int``, ``float`` or
    ``complex`` and the nesting is rectangular, i.e. the list can be
    turned into a regular numeric ndarray.  Nested sets are treated like
    lists.  An empty list counts as simple numeric.
    """
    return _numeric_shape(data) is not None