Module trainlog.io

Core IO abstractions - reading & writing JSON Lines (https://jsonlines.org/).

This module provides JsonLinesIO, a simplified stream reader/writer for JSON Lines, which also adds support for serializing/deserializing numpy arrays.

Expand source code
"""Core IO abstractions - reading & writing JSON Lines (https://jsonlines.org/).

This module provides `JsonLinesIO`, a simplified stream reader/writer for JSON lines,
which also adds support for serializing/deserializing numpy arrays.
"""

from __future__ import annotations

import functools as ft
import gzip as gzip_
import json
import os
import shutil
import typing
from types import TracebackType
from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    Iterable,
    Iterator,
    Optional,
    TextIO,
    Tuple,
    Type,
    TypeVar,
)

# Type of the objects read from / written to a `JsonLinesIO` stream.
T = TypeVar("T")

# Marker key that `numpy_to_dict` adds (always with value 0) to flag a dict as a
# serialized numpy array; the decoder checks for this key and `numpy_from_dict`
# asserts the value is 0 (NOTE(review): presumably a format version — confirm).
NUMPY_DICT_KEY = "__numpy_dict"


def numpy_to_dict(array: Any) -> Dict[str, Any]:
    """Serialize a numpy array as a JSON-able dict; `numpy_from_dict` is the inverse.

    The dict holds the `NUMPY_DICT_KEY` marker (value 0), the array's shape, the
    name of its base dtype, and its elements flattened in row-major (C) order.
    """
    encoded: Dict[str, Any] = {NUMPY_DICT_KEY: 0}
    encoded["shape"] = list(array.shape)
    encoded["dtype"] = array.dtype.base.name
    encoded["items"] = array.flatten().tolist()
    return encoded


def numpy_from_dict(dict_: Dict[str, Any]) -> Any:
    """Convert a JSON-able dictionary from `numpy_to_dict` back to a numpy array.

    Check for the marker key to recognise such a dictionary:

        if trainlog.io.NUMPY_DICT_KEY in dict_:
            array = numpy_from_dict(dict_)

    numpy is imported inside the function rather than at module level.
    """
    import numpy as np  # type: ignore  # pylint: disable=import-outside-toplevel

    # `numpy_to_dict` always writes the marker with value 0.
    assert dict_[NUMPY_DICT_KEY] == 0
    target_shape = tuple(dict_["shape"])
    element_type = np.dtype(dict_["dtype"])
    return np.array(dict_["items"], dtype=element_type).reshape(target_shape)


class JSONEncoderWithNumpy(json.JSONEncoder):
    """JSON encoder that additionally serializes numpy arrays via `numpy_to_dict`."""

    def default(self, o: Any) -> Any:
        """Encode numpy arrays as dicts; defer anything else to `json.JSONEncoder`.

        Arrays are recognised by their type name, so numpy is not imported here.
        Any other unsupported object makes the base class raise TypeError.
        """
        if type(o).__name__ != "ndarray":
            return super().default(o)
        return numpy_to_dict(o)


# Signature of a json `object_hook`: called with each decoded JSON object as a dict.
ObjectHook = Callable[[Dict[str, Any]], Any]
# Signature of a json `object_pairs_hook`: called with each decoded JSON object as
# its (key, value) pairs.
ObjectPairsHook = Callable[[Iterable[Tuple[str, Any]]], Any]


class JSONDecoderWithNumpy(json.JSONDecoder):
    """JSON decoder that restores numpy arrays encoded by `numpy_to_dict`.

    A user-supplied `object_hook` / `object_pairs_hook` is chained after the
    numpy check, so it only sees objects that are not serialized arrays.
    """

    @staticmethod
    def create_object_hook(
        next_hook: Optional[ObjectHook], dict_: Dict[str, Any]
    ) -> Any:
        """Chained `object_hook` that tries `numpy_from_dict` before `next_hook`.

        `__init__` binds `next_hook` with `functools.partial`. A dict that is not
        a serialized array goes to `next_hook`, or is returned as-is when there
        is no `next_hook`.
        """
        if NUMPY_DICT_KEY not in dict_:
            return dict_ if next_hook is None else next_hook(dict_)
        return numpy_from_dict(dict_)

    @staticmethod
    def create_object_pairs_hook(
        next_hook: ObjectPairsHook, obj_pairs: Iterable[Tuple[str, Any]]
    ) -> Any:
        """Chained `object_pairs_hook` that tries `numpy_from_dict` before `next_hook`.

        `__init__` binds `next_hook` with `functools.partial`; this hook is only
        installed when the caller supplied an `object_pairs_hook`.
        """
        # Materialize the pairs once: they are typed as a generic Iterable, and a
        # one-shot iterator would be exhausted by `dict()` and then reach
        # `next_hook` empty.
        pairs = list(obj_pairs)
        dict_ = dict(pairs)
        if NUMPY_DICT_KEY in dict_:
            return numpy_from_dict(dict_)
        return next_hook(pairs)

    def __init__(
        self,
        *,
        object_hook: Optional[ObjectHook] = None,
        object_pairs_hook: Optional[ObjectPairsHook] = None,
        **args: Any
    ):
        """Accept the usual `json.JSONDecoder` keyword arguments.

        The numpy-aware `object_hook` is always installed (wrapping any given
        `object_hook`). A given `object_pairs_hook` is wrapped too and, as in
        `json.JSONDecoder`, then takes priority over `object_hook`.
        """
        object_pairs_hook = (
            ft.partial(self.create_object_pairs_hook, object_pairs_hook)
            if object_pairs_hook is not None
            else None
        )
        super().__init__(
            object_hook=ft.partial(self.create_object_hook, object_hook),
            object_pairs_hook=object_pairs_hook,
            **args
        )


class JsonLinesIO(Generic[T]):
    """Reader/writer for JSON Lines files.

    See https://jsonlines.org/.

    Similar to `TextIO`, but writes "JSON-able" objects rather than strings. By
    default objects are encoded with `JSONEncoderWithNumpy` using compact
    `(",", ":")` separators and decoded with `JSONDecoderWithNumpy`; a `cls` or
    `separators` entry in `dump_args` / `load_args` overrides these defaults.
    """

    stream: TextIO  # the wrapped text stream; closed by `close()` / `__exit__`
    dump_args: Dict[str, Any]  # keyword arguments for `json.dumps` in `write()`
    load_args: Dict[str, Any]  # keyword arguments for `json.loads` in `read()`

    def __init__(
        self,
        stream: TextIO,
        dump_args: Optional[Dict[str, Any]] = None,
        load_args: Optional[Dict[str, Any]] = None,
    ):
        """Wrap `stream`.

        stream -- text stream to read JSON lines from and/or write them to.
        dump_args -- extra keyword arguments for `json.dumps` (copied, not mutated).
        load_args -- extra keyword arguments for `json.loads` (copied, not mutated).
        """
        self.stream = stream
        self.dump_args = dump_args.copy() if dump_args else {}
        self.dump_args.setdefault("separators", (",", ":"))
        self.dump_args.setdefault("cls", JSONEncoderWithNumpy)
        self.load_args = load_args.copy() if load_args else {}
        self.load_args.setdefault("cls", JSONDecoderWithNumpy)

    def __enter__(self) -> JsonLinesIO[T]:
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        # Close the stream when leaving the `with` block; exceptions are not suppressed.
        self.close()

    def __iter__(self) -> Iterator[T]:
        return self.objects()

    def close(self) -> None:
        """Close the underlying text stream."""
        stream = self.stream
        stream.close()

    def flush(self) -> None:
        """Flush the underlying text stream."""
        stream = self.stream
        stream.flush()

    def write(self, obj: T) -> None:
        """Write an object to the file, as a JSON entry on a single line.

        The line is fully serialized before anything is written, so an object
        that fails to encode (e.g. TypeError for an unsupported type) leaves the
        stream unchanged instead of holding a truncated line.
        """
        # `json.dump` writes chunks while it encodes; an error part way through
        # would leave a partial, unparseable line in the file.
        line = json.dumps(obj, **self.dump_args)
        self.stream.write(line + "\n")

    def read(self) -> T:
        """Read and decode the next line of the stream.

        Throws EOFError if there are no more lines in the stream.
        """
        line = self.stream.readline()
        if line:
            return typing.cast(T, json.loads(line, **self.load_args))
        raise EOFError(
            "Attempting to read JSON data past the end of stream", self.stream
        )

    def objects(self) -> Iterator[T]:
        """An iterator over the remaining objects in the file, ending at EOF."""
        while True:
            # Guard only `read()`: a try around the `yield` would also swallow an
            # EOFError thrown into this generator by its consumer.
            try:
                obj = self.read()
            except EOFError:
                return
            yield obj


def open_maybe_gzip(
    path: str,
    mode: str = "r",
    gzip: Optional[bool] = None,  # pylint: disable=redefined-outer-name
    *,
    encoding: Optional[str] = None,
) -> TextIO:
    """Open a file, but use gzip.open if appropriate.

    path -- file to open.
    mode -- as for `open()`; for gzip files text mode is used unless "b" is given.
    gzip -- Treat the file as GZIP? If `None`, autodetect based on path extension
            (".gz" or ".gzip").
    encoding -- text encoding for text modes; `None` keeps the platform default
                used by `open()` / `gzip.open()`. JSON Lines files are UTF-8, so
                pass "utf-8" to read/write them portably.
    """
    if gzip or (gzip is None and os.path.splitext(path)[-1] in (".gz", ".gzip")):
        # Mode should default to text, for consistency with `open()`
        gzip_mode = mode if "b" in mode or "t" in mode else mode + "t"
        return typing.cast(TextIO, gzip_.open(path, gzip_mode, encoding=encoding))
    return typing.cast(TextIO, open(path, mode, encoding=encoding))


def read_jsonlines(
    path: str, load_args: Optional[Dict[str, Any]] = None
) -> Iterator[T]:
    """Yield each object stored in a local JSON Lines file.

    path -- filesystem path; gzip compression is autodetected from the extension
            by `open_maybe_gzip`.
    load_args -- extra keyword arguments forwarded to `json.loads`.

    This is a generator: the file is only opened when iteration starts, and is
    closed when the generator finishes or is closed.
    """
    stream = open_maybe_gzip(path)
    with JsonLinesIO[T](stream, load_args=load_args) as reader:
        yield from reader


def write_jsonlines(
    path: str, objects: Iterable[T], dump_args: Optional[Dict[str, Any]] = None
) -> None:
    """Write `objects` to a local JSON Lines file, one JSON value per line.

    path -- filesystem path, opened with mode "w" (existing content is replaced);
            gzip compression is autodetected from the extension by `open_maybe_gzip`.
    dump_args -- extra keyword arguments for `JsonLinesIO`'s JSON serialization.
    """
    stream = open_maybe_gzip(path, "w")
    with JsonLinesIO[T](stream, dump_args=dump_args) as writer:
        for item in objects:
            writer.write(item)


def gzip(
    path: str, extension: str = ".gz", delete: bool = True, chunk_size: int = 1024
) -> None:
    """Gzip a local file (by default deleting the original afterwards).

    path -- file to compress; the output is written to `path + extension`.
    extension -- suffix for the compressed file; must be non-empty.
    delete -- remove `path` once the compressed copy has been fully written.
    chunk_size -- read size passed to `shutil.copyfileobj` (0 uses its default).

    Raises ValueError if `extension` is empty.
    """
    if not extension:
        # Explicit check rather than `assert` (stripped by `python -O`): with an
        # empty extension the destination is the source itself, which would be
        # truncated by the gzip writer and then deleted.
        raise ValueError("cannot write the gzip to the file being read")
    with open(path, "rb") as srcf, gzip_.open(str(path) + extension, "wb") as destf:
        # Unlike a bare `readinto` loop, copyfileobj treats chunk_size=0 as "use the
        # default buffer size" instead of writing an empty archive.
        shutil.copyfileobj(srcf, destf, chunk_size)
    if delete:
        os.remove(path)

Functions

def gzip(path: str, extension: str = '.gz', delete: bool = True, chunk_size: int = 1024) ‑> NoneType

Gzip a local file (by default deleting the original afterwards).

Expand source code
def gzip(
    path: str, extension: str = ".gz", delete: bool = True, chunk_size: int = 1024
) -> None:
    """Gzip a local file (by default deleting the original afterwards).

    path -- file to compress; the output is written to `path + extension`.
    extension -- suffix for the compressed file; must be non-empty.
    delete -- remove `path` once the compressed copy has been fully written.
    chunk_size -- read size passed to `shutil.copyfileobj` (0 uses its default).

    Raises ValueError if `extension` is empty.
    """
    if not extension:
        # Explicit check rather than `assert` (stripped by `python -O`): with an
        # empty extension the destination is the source itself, which would be
        # truncated by the gzip writer and then deleted.
        raise ValueError("cannot write the gzip to the file being read")
    with open(path, "rb") as srcf, gzip_.open(str(path) + extension, "wb") as destf:
        # Unlike a bare `readinto` loop, copyfileobj treats chunk_size=0 as "use the
        # default buffer size" instead of writing an empty archive.
        shutil.copyfileobj(srcf, destf, chunk_size)
    if delete:
        os.remove(path)
def numpy_from_dict(dict_: Dict[str, Any]) ‑> Any

Convert a JSON-able dictionary from numpy_to_dict() back to a numpy array.

The following check identifies a valid dict:

if trainlog.io.NUMPY_DICT_KEY in dict_:
    array = numpy_from_dict(dict_)
Expand source code
def numpy_from_dict(dict_: Dict[str, Any]) -> Any:
    """Convert a JSON-able dictionary from `numpy_to_dict` back to a numpy array.

    Check for the marker key to recognise such a dictionary:

        if trainlog.io.NUMPY_DICT_KEY in dict_:
            array = numpy_from_dict(dict_)

    numpy is imported inside the function rather than at module level.
    """
    import numpy as np  # type: ignore  # pylint: disable=import-outside-toplevel

    # `numpy_to_dict` always writes the marker with value 0.
    assert dict_[NUMPY_DICT_KEY] == 0
    target_shape = tuple(dict_["shape"])
    element_type = np.dtype(dict_["dtype"])
    return np.array(dict_["items"], dtype=element_type).reshape(target_shape)
def numpy_to_dict(array: Any) ‑> Dict[str, Any]

Convert a numpy array to a JSON-able dictionary, numpy_from_dict() restores.

Expand source code
def numpy_to_dict(array: Any) -> Dict[str, Any]:
    """Serialize a numpy array as a JSON-able dict; `numpy_from_dict` is the inverse.

    The dict holds the `NUMPY_DICT_KEY` marker (value 0), the array's shape, the
    name of its base dtype, and its elements flattened in row-major (C) order.
    """
    encoded: Dict[str, Any] = {NUMPY_DICT_KEY: 0}
    encoded["shape"] = list(array.shape)
    encoded["dtype"] = array.dtype.base.name
    encoded["items"] = array.flatten().tolist()
    return encoded
def open_maybe_gzip(path: str, mode: str = 'r', gzip: Optional[bool] = None) ‑> TextIO

Open a file, but use gzip.open if appropriate.

gzip – Treat the file as GZIP? If None, autodetect based on path extension.

Expand source code
def open_maybe_gzip(
    path: str,
    mode: str = "r",
    gzip: Optional[bool] = None,  # pylint: disable=redefined-outer-name
    *,
    encoding: Optional[str] = None,
) -> TextIO:
    """Open a file, but use gzip.open if appropriate.

    path -- file to open.
    mode -- as for `open()`; for gzip files text mode is used unless "b" is given.
    gzip -- Treat the file as GZIP? If `None`, autodetect based on path extension
            (".gz" or ".gzip").
    encoding -- text encoding for text modes; `None` keeps the platform default
                used by `open()` / `gzip.open()`. JSON Lines files are UTF-8, so
                pass "utf-8" to read/write them portably.
    """
    if gzip or (gzip is None and os.path.splitext(path)[-1] in (".gz", ".gzip")):
        # Mode should default to text, for consistency with `open()`
        gzip_mode = mode if "b" in mode or "t" in mode else mode + "t"
        return typing.cast(TextIO, gzip_.open(path, gzip_mode, encoding=encoding))
    return typing.cast(TextIO, open(path, mode, encoding=encoding))
def read_jsonlines(path: str, load_args: Optional[Dict[str, Any]] = None) ‑> Iterator[~T]

Read JSON Lines from a local filesystem path.

Expand source code
def read_jsonlines(
    path: str, load_args: Optional[Dict[str, Any]] = None
) -> Iterator[T]:
    """Yield each object stored in a local JSON Lines file.

    path -- filesystem path; gzip compression is autodetected from the extension
            by `open_maybe_gzip`.
    load_args -- extra keyword arguments forwarded to `json.loads`.

    This is a generator: the file is only opened when iteration starts, and is
    closed when the generator finishes or is closed.
    """
    stream = open_maybe_gzip(path)
    with JsonLinesIO[T](stream, load_args=load_args) as reader:
        yield from reader
def write_jsonlines(path: str, objects: Iterable[T], dump_args: Optional[Dict[str, Any]] = None) ‑> NoneType

Write JSON Lines to a local filesystem path.

Expand source code
def write_jsonlines(
    path: str, objects: Iterable[T], dump_args: Optional[Dict[str, Any]] = None
) -> None:
    """Write `objects` to a local JSON Lines file, one JSON value per line.

    path -- filesystem path, opened with mode "w" (existing content is replaced);
            gzip compression is autodetected from the extension by `open_maybe_gzip`.
    dump_args -- extra keyword arguments for `JsonLinesIO`'s JSON serialization.
    """
    stream = open_maybe_gzip(path, "w")
    with JsonLinesIO[T](stream, dump_args=dump_args) as writer:
        for item in objects:
            writer.write(item)

Classes

class JSONDecoderWithNumpy (*, object_hook: Optional[ObjectHook] = None, object_pairs_hook: Optional[ObjectPairsHook] = None, **args: Any)

Add numpy array support using numpy_from_dict().

object_hook, if specified, will be called with the result of every JSON object decoded and its return value will be used in place of the given dict. This can be used to provide custom deserializations (e.g. to support JSON-RPC class hinting).

object_pairs_hook, if specified, will be called with the result of every JSON object decoded with an ordered list of pairs. The return value of object_pairs_hook will be used instead of the dict. This feature can be used to implement custom decoders. If object_hook is also defined, the object_pairs_hook takes priority.

parse_float, if specified, will be called with the string of every JSON float to be decoded. By default this is equivalent to float(num_str). This can be used to use another datatype or parser for JSON floats (e.g. decimal.Decimal).

parse_int, if specified, will be called with the string of every JSON int to be decoded. By default this is equivalent to int(num_str). This can be used to use another datatype or parser for JSON integers (e.g. float).

parse_constant, if specified, will be called with one of the following strings: -Infinity, Infinity, NaN. This can be used to raise an exception if invalid JSON numbers are encountered.

If strict is false (true is the default), then control characters will be allowed inside strings. Control characters in this context are those with character codes in the 0-31 range, including '\t' (tab), '\n', '\r' and '\0'.

Expand source code
class JSONDecoderWithNumpy(json.JSONDecoder):
    """JSON decoder that restores numpy arrays encoded by `numpy_to_dict`.

    A user-supplied `object_hook` / `object_pairs_hook` is chained after the
    numpy check, so it only sees objects that are not serialized arrays.
    """

    @staticmethod
    def create_object_hook(
        next_hook: Optional[ObjectHook], dict_: Dict[str, Any]
    ) -> Any:
        """Chained `object_hook` that tries `numpy_from_dict` before `next_hook`.

        `__init__` binds `next_hook` with `functools.partial`. A dict that is not
        a serialized array goes to `next_hook`, or is returned as-is when there
        is no `next_hook`.
        """
        if NUMPY_DICT_KEY not in dict_:
            return dict_ if next_hook is None else next_hook(dict_)
        return numpy_from_dict(dict_)

    @staticmethod
    def create_object_pairs_hook(
        next_hook: ObjectPairsHook, obj_pairs: Iterable[Tuple[str, Any]]
    ) -> Any:
        """Chained `object_pairs_hook` that tries `numpy_from_dict` before `next_hook`.

        `__init__` binds `next_hook` with `functools.partial`; this hook is only
        installed when the caller supplied an `object_pairs_hook`.
        """
        # Materialize the pairs once: they are typed as a generic Iterable, and a
        # one-shot iterator would be exhausted by `dict()` and then reach
        # `next_hook` empty.
        pairs = list(obj_pairs)
        dict_ = dict(pairs)
        if NUMPY_DICT_KEY in dict_:
            return numpy_from_dict(dict_)
        return next_hook(pairs)

    def __init__(
        self,
        *,
        object_hook: Optional[ObjectHook] = None,
        object_pairs_hook: Optional[ObjectPairsHook] = None,
        **args: Any
    ):
        """Accept the usual `json.JSONDecoder` keyword arguments.

        The numpy-aware `object_hook` is always installed (wrapping any given
        `object_hook`). A given `object_pairs_hook` is wrapped too and, as in
        `json.JSONDecoder`, then takes priority over `object_hook`.
        """
        object_pairs_hook = (
            ft.partial(self.create_object_pairs_hook, object_pairs_hook)
            if object_pairs_hook is not None
            else None
        )
        super().__init__(
            object_hook=ft.partial(self.create_object_hook, object_hook),
            object_pairs_hook=object_pairs_hook,
            **args
        )

Ancestors

  • json.decoder.JSONDecoder

Static methods

def create_object_hook(next_hook: Optional[ObjectHook], dict_: Dict[str, Any]) ‑> Any

Create a chained object_hook, which tries numpy_from_dict() first.

Expand source code
@staticmethod
def create_object_hook(
    next_hook: Optional[ObjectHook], dict_: Dict[str, Any]
) -> Any:
    """Chained `object_hook` that tries `numpy_from_dict` before `next_hook`.

    `__init__` binds `next_hook` with `functools.partial`. A dict that is not
    a serialized array goes to `next_hook`, or is returned as-is when there
    is no `next_hook`.
    """
    if NUMPY_DICT_KEY not in dict_:
        return dict_ if next_hook is None else next_hook(dict_)
    return numpy_from_dict(dict_)
def create_object_pairs_hook(next_hook: ObjectPairsHook, obj_pairs: Iterable[Tuple[str, Any]]) ‑> Any

Create a chained object_pairs_hook, which tries numpy_from_dict() first.

Expand source code
@staticmethod
def create_object_pairs_hook(
    next_hook: ObjectPairsHook, obj_pairs: Iterable[Tuple[str, Any]]
) -> Any:
    """Chained `object_pairs_hook` that tries `numpy_from_dict` before `next_hook`.

    `__init__` binds `next_hook` with `functools.partial`; this hook is only
    installed when the caller supplied an `object_pairs_hook`.
    """
    # Materialize the pairs once: they are typed as a generic Iterable, and a
    # one-shot iterator would be exhausted by `dict()` and then reach
    # `next_hook` empty.
    pairs = list(obj_pairs)
    dict_ = dict(pairs)
    if NUMPY_DICT_KEY in dict_:
        return numpy_from_dict(dict_)
    return next_hook(pairs)
class JSONEncoderWithNumpy (*, skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, sort_keys=False, indent=None, separators=None, default=None)

Add numpy array support using numpy_to_dict().

Constructor for JSONEncoder, with sensible defaults.

If skipkeys is false, then it is a TypeError to attempt encoding of keys that are not str, int, float or None. If skipkeys is True, such items are simply skipped.

If ensure_ascii is true, the output is guaranteed to be str objects with all incoming non-ASCII characters escaped. If ensure_ascii is false, the output can contain non-ASCII characters.

If check_circular is true, then lists, dicts, and custom encoded objects will be checked for circular references during encoding to prevent an infinite recursion (which would cause an OverflowError). Otherwise, no such check takes place.

If allow_nan is true, then NaN, Infinity, and -Infinity will be encoded as such. This behavior is not JSON specification compliant, but is consistent with most JavaScript based encoders and decoders. Otherwise, it will be a ValueError to encode such floats.

If sort_keys is true, then the output of dictionaries will be sorted by key; this is useful for regression tests to ensure that JSON serializations can be compared on a day-to-day basis.

If indent is a non-negative integer, then JSON array elements and object members will be pretty-printed with that indent level. An indent level of 0 will only insert newlines. None is the most compact representation.

If specified, separators should be an (item_separator, key_separator) tuple. The default is (', ', ': ') if indent is None and (',', ': ') otherwise. To get the most compact JSON representation, you should specify (',', ':') to eliminate whitespace.

If specified, default is a function that gets called for objects that can't otherwise be serialized. It should return a JSON encodable version of the object or raise a TypeError.

Expand source code
class JSONEncoderWithNumpy(json.JSONEncoder):
    """JSON encoder that additionally serializes numpy arrays via `numpy_to_dict`."""

    def default(self, o: Any) -> Any:
        """Encode numpy arrays as dicts; defer anything else to `json.JSONEncoder`.

        Arrays are recognised by their type name, so numpy is not imported here.
        Any other unsupported object makes the base class raise TypeError.
        """
        if type(o).__name__ != "ndarray":
            return super().default(o)
        return numpy_to_dict(o)

Ancestors

  • json.encoder.JSONEncoder

Methods

def default(self, o: Any) ‑> Any

Implement this method in a subclass such that it returns a serializable object for o, or calls the base implementation (to raise a TypeError).

For example, to support arbitrary iterators, you could implement default like this::

# Illustration quoted from the json.JSONEncoder.default docstring (not part of this
# module): encode any iterable object as a JSON list.
def default(self, o):
    try:
        iterable = iter(o)
    except TypeError:
        pass
    else:
        return list(iterable)
    # Let the base class default method raise the TypeError
    return JSONEncoder.default(self, o)
Expand source code
def default(self, o: Any) -> Any:
    """Encode numpy arrays as dicts; defer anything else to `json.JSONEncoder`.

    Arrays are recognised by their type name, so numpy is not imported here.
    Any other unsupported object makes the base class raise TypeError.
    """
    if type(o).__name__ != "ndarray":
        return super().default(o)
    return numpy_to_dict(o)
class JsonLinesIO (stream: TextIO, dump_args: Optional[Dict[str, Any]] = None, load_args: Optional[Dict[str, Any]] = None)

Reader/writer for JSON Lines files.

See https://jsonlines.org/.

Similar to TextIO, but writes "JSON-able" objects rather than strings.

Expand source code
class JsonLinesIO(Generic[T]):
    """Reader/writer for JSON Lines files.

    See https://jsonlines.org/.

    Similar to `TextIO`, but writes "JSON-able" objects rather than strings. By
    default objects are encoded with `JSONEncoderWithNumpy` using compact
    `(",", ":")` separators and decoded with `JSONDecoderWithNumpy`; a `cls` or
    `separators` entry in `dump_args` / `load_args` overrides these defaults.
    """

    stream: TextIO  # the wrapped text stream; closed by `close()` / `__exit__`
    dump_args: Dict[str, Any]  # keyword arguments for `json.dumps` in `write()`
    load_args: Dict[str, Any]  # keyword arguments for `json.loads` in `read()`

    def __init__(
        self,
        stream: TextIO,
        dump_args: Optional[Dict[str, Any]] = None,
        load_args: Optional[Dict[str, Any]] = None,
    ):
        """Wrap `stream`.

        stream -- text stream to read JSON lines from and/or write them to.
        dump_args -- extra keyword arguments for `json.dumps` (copied, not mutated).
        load_args -- extra keyword arguments for `json.loads` (copied, not mutated).
        """
        self.stream = stream
        self.dump_args = dump_args.copy() if dump_args else {}
        self.dump_args.setdefault("separators", (",", ":"))
        self.dump_args.setdefault("cls", JSONEncoderWithNumpy)
        self.load_args = load_args.copy() if load_args else {}
        self.load_args.setdefault("cls", JSONDecoderWithNumpy)

    def __enter__(self) -> JsonLinesIO[T]:
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        # Close the stream when leaving the `with` block; exceptions are not suppressed.
        self.close()

    def __iter__(self) -> Iterator[T]:
        return self.objects()

    def close(self) -> None:
        """Close the underlying text stream."""
        stream = self.stream
        stream.close()

    def flush(self) -> None:
        """Flush the underlying text stream."""
        stream = self.stream
        stream.flush()

    def write(self, obj: T) -> None:
        """Write an object to the file, as a JSON entry on a single line.

        The line is fully serialized before anything is written, so an object
        that fails to encode (e.g. TypeError for an unsupported type) leaves the
        stream unchanged instead of holding a truncated line.
        """
        # `json.dump` writes chunks while it encodes; an error part way through
        # would leave a partial, unparseable line in the file.
        line = json.dumps(obj, **self.dump_args)
        self.stream.write(line + "\n")

    def read(self) -> T:
        """Read and decode the next line of the stream.

        Throws EOFError if there are no more lines in the stream.
        """
        line = self.stream.readline()
        if line:
            return typing.cast(T, json.loads(line, **self.load_args))
        raise EOFError(
            "Attempting to read JSON data past the end of stream", self.stream
        )

    def objects(self) -> Iterator[T]:
        """An iterator over the remaining objects in the file, ending at EOF."""
        while True:
            # Guard only `read()`: a try around the `yield` would also swallow an
            # EOFError thrown into this generator by its consumer.
            try:
                obj = self.read()
            except EOFError:
                return
            yield obj

Ancestors

  • typing.Generic

Class variables

var dump_args : Dict[str, Any]
var load_args : Dict[str, Any]
var stream : TextIO

Methods

def close(self) ‑> NoneType

Close the underlying text stream.

Expand source code
def close(self) -> None:
    """Close the underlying text stream."""
    stream = self.stream
    stream.close()
def flush(self) ‑> NoneType

Flush the underlying text stream.

Expand source code
def flush(self) -> None:
    """Flush the underlying text stream."""
    stream = self.stream
    stream.flush()
def objects(self) ‑> Iterator[~T]

An iterator over objects in the file.

Expand source code
def objects(self) -> Iterator[T]:
    """An iterator over the remaining objects in the file, ending at EOF."""
    while True:
        # Guard only `read()`: a try around the `yield` would also swallow an
        # EOFError thrown into this generator by its consumer.
        try:
            obj = self.read()
        except EOFError:
            return
        yield obj
def read(self) ‑> ~T

Read a single object from the file.

Throws EOFError if there are no more JSON objects in the file.

Expand source code
def read(self) -> T:
    """Read and decode the next line of the stream.

    Throws EOFError if there are no more lines in the stream.
    """
    line = self.stream.readline()
    if line:
        return typing.cast(T, json.loads(line, **self.load_args))
    raise EOFError(
        "Attempting to read JSON data past the end of stream", self.stream
    )
def write(self, obj: T) ‑> NoneType

Write an object to the file, as a JSON entry on a single line.

Expand source code
def write(self, obj: T) -> None:
    """Write an object to the file, as a JSON entry on a single line.

    The line is fully serialized before anything is written, so an object
    that fails to encode (e.g. TypeError for an unsupported type) leaves the
    stream unchanged instead of holding a truncated line.
    """
    # `json.dump` writes chunks while it encodes; an error part way through
    # would leave a partial, unparseable line in the file.
    line = json.dumps(obj, **self.dump_args)
    self.stream.write(line + "\n")