"""
Functionality for logging function evaluations during optimization.
This is all very ugly, but I couldn't find a better way to track function
evaluations that works with multiple levels of parallelization,
deep-copying, and pickling.
"""
from __future__ import annotations
import functools
import heapq
import logging
import os
import tempfile
import threading
from collections.abc import Callable, Sequence
from contextlib import contextmanager
from multiprocessing import Manager
from multiprocessing.managers import ListProxy
from pathlib import Path
from typing import Any
import numpy as np
import pypesto.ensemble
__all__ = ["EvalLogger", "TopKSelector", "ThresholdSelector"]
logger = logging.getLogger(__name__)
def _reconstruct_method_interceptor(
    obj: Any, method_name: str, handler
) -> MethodInterceptorProxy:
    """
    Rebuild a :class:`MethodInterceptorProxy` from its constructor arguments.

    Pickle requires an importable, module-level callable to recreate an
    object; ``MethodInterceptorProxy.__reduce__`` returns this function
    together with ``(obj, method_name, handler)``.

    :param obj: The wrapped object.
    :param method_name: Name of the intercepted method.
    :param handler: Callable that receives the intercepted calls.
    :returns: A new proxy around `obj`.
    """
    rebuilt = MethodInterceptorProxy(
        obj=obj, method_name=method_name, handler=handler
    )
    return rebuilt
class MethodInterceptorProxy:
    """
    Proxy that forwards attribute access to a wrapped object while routing
    calls to one selected method through a handler.

    The proxy reports the wrapped object's class via ``__class__`` so that
    ``isinstance()`` checks keep working, and it defines custom pickling and
    deep-copying so that the handler reference is preserved.
    """

    #: Names of the attributes stored on the proxy itself. Everything else
    #: is read from / written to the wrapped object.
    own_attrs = ("_obj", "_method_name", "_handler")

    def __init__(self, obj: Any, method_name: str, handler: Callable):
        """
        Wrap `obj` and intercept calls to `method_name`.

        `handler` is called as handler(orig_callable, *args, **kwargs).

        :param obj: The object to wrap.
        :param method_name: Name of the method to intercept.
            Use ``"__call__"`` to intercept calling the object itself.
        :param handler: Callable receiving the original callable followed by
            the call arguments.
        :raises ValueError: If `obj` already has an attribute named like one
            of the proxy's internal attributes (see :attr:`own_attrs`).
        """
        for attr in self.own_attrs:
            if hasattr(obj, attr):
                raise ValueError(
                    f"Cannot proxy object with attribute {attr!r}"
                )
        self._obj = obj
        self._method_name = method_name
        self._handler = handler

    @property
    def __class__(self):
        """
        Make the proxy appear as the wrapped object's class for
        isinstance() checks and similar uses.
        """
        return self._obj.__class__

    def __repr__(self):
        # We might want to forward __repr__ as well, but for debugging it's
        # useful to see that this is a proxy.
        return f"<Proxy of {repr(self._obj)}>"

    def __getattr__(self, name):
        # __getattr__ is only invoked when normal lookup failed. If that
        # happens for one of the proxy's own attributes, the proxy is not
        # (fully) initialized; reading ``self._obj`` below would re-enter
        # __getattr__ and recurse until RecursionError. Fail cleanly instead.
        if name in self.own_attrs:
            raise AttributeError(name)

        # forward everything except the intercepted method
        attr = getattr(self._obj, name)
        if name == self._method_name and callable(attr):

            @functools.wraps(attr)
            def wrapped(*args, **kwargs):
                return self._handler(attr, *args, **kwargs)

            return wrapped
        return attr

    def __setattr__(self, name, value):
        # keep proxy internals on the proxy;
        # forward other sets to wrapped object
        if name in self.own_attrs:
            object.__setattr__(self, name, value)
        else:
            setattr(self._obj, name, value)

    def __call__(self, *args, **kwargs):
        # support calling the proxy itself (intercept __call__)
        if self._method_name == "__call__":
            orig = getattr(self._obj, "__call__", self._obj)  # noqa B004
            return self._handler(orig, *args, **kwargs)

        # otherwise, if underlying object is callable and not intercepted,
        # call it
        orig = getattr(self._obj, "__call__", None)  # noqa B004
        if callable(orig):
            return orig(*args, **kwargs)
        raise TypeError(f"{type(self).__name__} object is not callable")

    def __reduce__(self):
        """
        Control pickling: return a top-level callable and args so pickle
        doesn't try to reconstruct the object based on the (overridden)
        `__class__`.
        """
        return (
            _reconstruct_method_interceptor,
            (self._obj, self._method_name, self._handler),
        )

    def __deepcopy__(self, memo):
        """
        Deepcopy the proxy but keep the *original* handler reference.

        Only the wrapped object is deep-copied; the handler is shared between
        the original and the copy.
        """
        import copy

        if id(self) in memo:
            return memo[id(self)]

        obj_copy = copy.deepcopy(self._obj, memo)
        proxy_copy = MethodInterceptorProxy(
            obj_copy, self._method_name, self._handler
        )
        memo[id(self)] = proxy_copy
        return proxy_copy
@contextmanager
def temp_swap_attr(container: Any, attr_name: str, proxy: Any):
    """
    Temporarily replace `container.attr_name` with `proxy`
    and restore afterwards.

    The previous state is restored even if the body of the ``with`` block
    raises: the old value is put back, or, if the attribute did not exist
    before, it is deleted again.

    :param container: Object whose attribute is swapped.
    :param attr_name: Name of the attribute to replace.
    :param proxy: Temporary value for the attribute.
    """
    had_attr = hasattr(container, attr_name)
    old = getattr(container, attr_name, None)
    setattr(container, attr_name, proxy)
    try:
        yield
    finally:
        # Runs on normal exit and when an exception is thrown into the
        # generator at the ``yield``, so the swap never outlives the block.
        if had_attr:
            setattr(container, attr_name, old)
        else:
            try:
                delattr(container, attr_name)
            except Exception:
                # best-effort cleanup; don't mask an in-flight exception
                logger.exception("delattr failed in temp_swap_attr cleanup")
# [docs]  (stray documentation-viewer link marker; not part of the code)
class EvalLogger:
    """
    Log function evaluations.

    Evaluations are appended as ``(x, fx)`` tuples to a shared list that is
    protected by a shared lock. An optional selector consumes that list.

    **This is experimental and is likely to change in future releases.**
    """

    def __init__(
        self,
        selector: EvalSelectorBase | None = None,
        _shared_evals: ListProxy | None = None,
        _shared_lock: threading.Lock | None = None,
    ):
        """
        Initialize.

        :param selector:
            Optional selector / handler for evaluations.
            See :class:`TopKSelector` and :class:`ThresholdSelector`.
            If not provided, all objective evaluations will be kept in memory.
        :param _shared_evals:
            Existing shared list to store evaluations in. Must be passed
            together with `_shared_lock`.
        :param _shared_lock:
            Lock protecting `_shared_evals`. Must be passed together with
            `_shared_evals`.
        :raises ValueError:
            If only one of `_shared_evals` and `_shared_lock` is provided.
        """
        if (_shared_evals is None) != (_shared_lock is None):
            raise ValueError(
                "`_shared_evals` and `_shared_lock` must both be provided "
                "or both be None."
            )

        # If no `shared_evals` is provided, create a private Manager
        # and a manager.list() that will be used to store evaluations.
        # The Manager object is kept on the instance for the process
        # lifetime but is excluded from pickling. Same for the selector.
        if _shared_evals is None:
            self._manager = Manager()
            self._shared_evals = self._manager.list()
            self._shared_lock = self._manager.Lock()
        else:
            # when passed an existing manager list proxy (e.g. on unpickle)
            self._manager = None
            self._shared_evals = _shared_evals
            self._shared_lock = _shared_lock

        self.selector = selector

    @property
    def evals(self) -> list:
        """A plain-list copy of the evaluations currently in the shared list."""
        return list(self._shared_evals)

    def log(self, x: np.ndarray, fx: float) -> None:
        """Log an evaluation.

        If a selector is set and its background ingestion is not running
        yet, it is started here before the record is appended.

        :param x: Parameter vector
        :param fx: Function value
        """
        if self.selector is not None and not self.selector.is_running:
            # start as late as possible to avoid issues with forking processes
            # elsewhere
            self.selector.start_background_ingest(
                self._shared_evals, self._shared_lock
            )

        with self._shared_lock:
            self._shared_evals.append((x, fx))

    def _obj_call_wrapper(self, orig_bound, x, *args, **kwargs):
        """
        Instance method wrapper for `Objective.__call__`.

        Must be pickleable, so cannot be a closure.

        :param orig_bound: The original callable to evaluate.
        :param x: Parameter vector; forwarded as first argument and logged.
        :returns: Whatever `orig_bound` returned.
        """
        # We might have to support different return types here in the future
        # (e.g., residuals). For now, assume scalar objective value.
        fx = orig_bound(x, *args, **kwargs)
        self.log(x, fx)
        return fx

    @contextmanager
    def attach(self, problem: pypesto.Problem):
        """Context manager to attach the `EvalLogger` to an `Objective`.

        While active, ``problem.objective`` is replaced by a proxy that logs
        every call. On exit -- also when the body raises -- the selector's
        background ingestion (if any) is stopped and the remaining entries
        are ingested.

        :param problem: The problem that contains the objective whose
            evaluations are to be logged.
        """
        proxy = MethodInterceptorProxy(
            problem.objective, "__call__", self._obj_call_wrapper
        )
        try:
            with temp_swap_attr(problem, "objective", proxy):
                yield
        finally:
            # Without the ``finally``, a failing optimization would leave the
            # background thread running and drop not-yet-ingested records.
            if self.selector is not None:
                self.selector.stop_background_ingest()
                # ingest any remaining entries
                self.selector.ingest_from_shared(
                    self._shared_evals, self._shared_lock
                )

    def __getstate__(self):
        """
        Pickle everything except the Manager and selector.
        """
        return {
            "_shared_evals": self._shared_evals,
            "_shared_lock": self._shared_lock,
        }

    def __setstate__(self, state):
        """Restore the shared list and lock; Manager and selector stay unset."""
        self._shared_evals = state.get("_shared_evals")
        self._shared_lock = state.get("_shared_lock")
        self._manager = None
        self.selector = None
class EvalSelectorBase:
    """
    Base class for objective evaluation selectors.

    This is intended for filtering and potentially checkpointing objective
    function evaluations in combination with :class:`EvalLogger`.

    Subclasses must implement :meth:`process` to process the incoming function
    evaluation records. ``EvalSelectorBase`` consumes entries from a shared
    list and passes them to :meth:`process`.

    :param dim: Problem dimension (length of parameter vector).
    :param path: Optional filesystem path used by subclasses.
    :param dtype: Numpy dtype used for internal numeric storage.
    """

    def __init__(self, dim: int, path: Path | str | None = None, dtype=float):
        """
        Initialize the selector.

        :param dim: Problem dimension (length of parameter vector).
        :param path: Optional filesystem path used by subclasses.
        :param dtype: Numpy dtype used for internal numeric storage.
        """
        # TODO: consider making that easier to use and inferring dim from first
        #  record? defer initialization until then.
        self.dim = int(dim)
        self.path: Path | None = Path(path) if path is not None else None
        self.dtype = dtype

        #: Threading lock protecting internal mutable state.
        self._lock = threading.Lock()
        #: Background ingestion thread.
        self._bg_thread: threading.Thread | None = None
        #: Event used to request background thread stop.
        self._stop_bg = threading.Event()

    def _normalize_x(self, x: Any) -> np.ndarray:
        """
        Accept list/tuple/np.ndarray and return a contiguous np.ndarray with
        dtype.

        :raises ValueError: If the converted array is not of shape ``(dim,)``.
        """
        arr = np.asarray(x, dtype=self.dtype)
        if arr.shape != (self.dim,):
            raise ValueError(
                f"x must have shape ({self.dim},), got {arr.shape}"
            )
        return np.ascontiguousarray(arr)

    @staticmethod
    def _key_from_array(arr: np.ndarray) -> tuple:
        """
        Create a hashable key from a parameter vector array.
        """
        return tuple(float(v) for v in arr)

    @property
    def is_running(self) -> bool:
        """Return True if background ingestion thread is running.

        False otherwise, including when a previously started thread has
        terminated. This matches the liveness check used by
        :meth:`start_background_ingest`.
        """
        thread = self._bg_thread
        return thread is not None and thread.is_alive()

    @staticmethod
    def _drain(shared_list) -> list:
        """Copy all entries out of `shared_list` and clear it."""
        snap = list(shared_list)
        if snap:
            shared_list[:] = []
        return snap

    def ingest_from_shared(
        self, shared_list: ListProxy, shared_lock: threading.Lock | None
    ) -> int:
        """
        Consume entries from a shared list.

        All entries are removed from `shared_list` and passed to
        :meth:`process` one by one. Exceptions raised while handling an
        entry are logged and do not abort the ingestion.

        :param shared_list: ``multiprocessing.Manager().list()``
            containing records to ingest.
        :param shared_lock: Lock protecting access to `shared_list`.
            If ``None``, the list is accessed without locking.
        :returns: Number of entries taken from the list, not counting
            ``None`` entries. Entries whose processing failed are included.
        """
        if shared_lock is None:
            snap = self._drain(shared_list)
        else:
            # only hold the lock while copying/clearing, not while processing
            with shared_lock:
                snap = self._drain(shared_list)

        consumed = 0
        for item in snap:
            if item is None:
                continue
            try:
                x, fx = item
                self.process(x, float(fx))
            except Exception:
                logger.exception(f"Failed to process {item!r}")
            consumed += 1
        return consumed

    def start_background_ingest(
        self,
        shared_list: ListProxy,
        shared_lock: threading.Lock,
        interval: float = 1.0,
    ):
        """
        Start a daemon thread that periodically ingests from the shared list.

        Does nothing if a background thread is already alive.

        :param shared_list: ``multiprocessing.Manager().list()``
            containing records to ingest.
        :param shared_lock: Lock protecting access to `shared_list`.
        :param interval: Time interval between ingests in seconds.
        """
        if self._bg_thread is not None and self._bg_thread.is_alive():
            return
        self._stop_bg.clear()

        def worker():
            while not self._stop_bg.is_set():
                self.ingest_from_shared(shared_list, shared_lock)
                # doubles as an interruptible sleep: returns early once the
                # stop event is set
                self._stop_bg.wait(interval)

        t = threading.Thread(target=worker, daemon=True)
        self._bg_thread = t
        t.start()

    def stop_background_ingest(self):
        """
        Stop the background ingestion thread if running.

        Waits up to two seconds for the thread to finish; the thread
        reference is dropped afterwards in any case.
        """
        if self._bg_thread is None:
            return
        self._stop_bg.set()
        self._bg_thread.join(timeout=2.0)
        self._bg_thread = None

    def process(self, x: Sequence[float], fx: float) -> None:
        """
        Process an objective evaluation.

        Must be implemented by subclasses.

        :param x: Parameter vector.
        :param fx: Function value.
        """
        # TODO: consider changing (x, fx) to a more extendable record type
        raise NotImplementedError()
# [docs]  (stray documentation-viewer link marker; not part of the code)
class TopKSelector(EvalSelectorBase):
    """
    Maintain the K-best unique parameter vectors seen so far.

    "Best" means lowest function value. Evaluations with a NaN function
    value are ignored.
    """

    def __init__(
        self, *, k: int, dim: int, path: Path | str | None = None, dtype=float
    ):
        """
        Initialize.

        :param k: The number of best entries to keep.
            (Best means lowest function value.)
        :param dim: Problem dimension (length of parameter vector).
        :param path: Optional filesystem path used by :meth:`save`.
        :param dtype: Numpy dtype used for internal numeric storage.
        :raises ValueError: If `k` is smaller than 1.
        """
        super().__init__(dim=dim, path=path, dtype=dtype)
        self._k = int(k)
        if self._k < 1:
            # k == 0 would otherwise only fail later, on the first
            # `process` call, when peeking into the empty heap.
            raise ValueError(f"k must be a positive integer, got {k!r}")

        #: Buffer for parameter vectors.
        self._x = np.empty((self._k, self.dim), dtype=self.dtype)
        #: Buffer for function values.
        self._fx = np.empty(self._k, dtype=self.dtype)
        #: Buffer for pseudo-timestamps. Currently: insertion order.
        self._ts = np.empty(self._k, dtype=np.int64)
        #: Validity mask for slots. (Not all slots may be used yet.)
        self._valid = np.zeros(self._k, dtype=bool)

        #: Min-heap of stored entries to avoid full sorting on insert.
        #  Entries are tuples (-fx, ts, slot), so the heap top is the entry
        #  with the largest (worst) fx.
        self._heap: list[tuple] = []
        #: Pseudo-timestamp counter for entries. Used for stable sorting
        #  of heap entries with equal fx.
        self._counter = 0
        #: Next free slot index in _x, _fx, _ts buffers.
        self._next_slot = 0
        #: stored parameter vectors as hashable keys to enforce uniqueness
        self._seen: set[tuple] = set()

    def process(self, x: Sequence[float], fx: float) -> None:
        """
        Process an objective evaluation.

        Add (x, fx) to top-K if x is not already stored. Once K entries are
        stored, a new entry replaces the current worst one only if its
        function value is strictly lower.

        :param x: Parameter vector of length ``dim``.
        :param fx: Function value. NaN values are ignored.
        :raises ValueError: If `x` does not have shape ``(dim,)``.
        """
        x_arr = self._normalize_x(x)
        if np.isnan(fx):
            # NaN compares False against everything: it would pass the
            # "worse than worst" rejection below, evict a valid entry and
            # break the heap ordering.
            return
        key = self._key_from_array(x_arr)

        with self._lock:
            # reject duplicate parameter vectors
            if key in self._seen:
                return

            # determine slot to use
            if self._next_slot < self._k:
                # less than k entries stored: use next free slot
                slot = self._next_slot
                self._next_slot += 1
            else:
                # k entries already stored: evict the worst if new is better.
                # worst is at top of min-heap (as -fx)
                worst_fx = -self._heap[0][0]
                if fx >= worst_fx:
                    # worse than or equal to worst stored: reject
                    return
                # remove worst from heap and seen-set
                _, _, slot = heapq.heappop(self._heap)
                old_key = self._key_from_array(self._x[slot])
                self._seen.discard(old_key)

            # store new entry
            ts = self._counter
            self._counter += 1
            self._x[slot] = x_arr
            self._fx[slot] = fx
            self._ts[slot] = ts
            self._valid[slot] = True
            heapq.heappush(self._heap, (-float(fx), int(ts), int(slot)))
            self._seen.add(key)

    def snapshot(self) -> dict[str, np.ndarray]:
        """Create a snapshot of the stored entries.

        :returns: Dict with keys ``"x"``, ``"fx"`` and ``"ts"``, holding
            copies of the stored entries sorted by ascending function value
            (ties: by insertion order).
        """
        with self._lock:
            # Heap entries carry buffer slot indices, so index the buffers
            # directly by slot. Every heap entry refers to an occupied slot.
            order = np.asarray(
                [
                    slot
                    for _, _, slot in sorted(
                        self._heap, key=lambda t: (-t[0], t[1])
                    )
                ],
                dtype=np.intp,
            )
            return {
                "x": np.ascontiguousarray(self._x[order]),
                "fx": np.ascontiguousarray(self._fx[order]),
                "ts": np.ascontiguousarray(self._ts[order]),
            }

    # TODO: snapshotting frequency -- who controls?
    #   after how many ingests? seconds? ...?
    def save(self):
        """Save the stored entries as numpy ``.npz`` file.

        Does nothing if no ``path`` was configured. Failures while writing
        or renaming are logged, not raised.
        """
        if self.path is None:
            return

        snapshot = self.snapshot()

        # first save to a temp file to avoid corrupted files on crashes,
        # then atomically rename
        dirpath = self.path.parent
        fd, tmpname = tempfile.mkstemp(
            prefix="._topk_", suffix=".npz", dir=dirpath
        )
        os.close(fd)
        try:
            np.savez(tmpname, **snapshot)
            os.replace(tmpname, self.path)
        except Exception:
            logger.exception("Failed to save snapshot to %s", self.path)
            # don't leave the temporary file behind
            try:
                Path(tmpname).unlink(missing_ok=True)
            except OSError:
                logger.warning("Could not remove temporary file %s", tmpname)

    def to_ensemble(self) -> pypesto.ensemble.Ensemble:
        """Create a :class:`pypesto.Ensemble` from the stored entries."""
        snapshot = self.snapshot()
        # NOTE(review): snapshot["x"] has shape (n_entries, dim). Confirm
        # that this is the layout `Ensemble(x_vectors=...)` expects and not
        # its transpose.
        x = snapshot["x"]
        ensemble = pypesto.ensemble.Ensemble(x_vectors=x)
        return ensemble
# [docs]  (stray documentation-viewer link marker; not part of the code)
class ThresholdSelector(EvalSelectorBase):
"""
Maintain all unique parameter vectors below a certain threshold based
on the best function value seen so far.
"""
[docs]
def __init__(
self,
*,
dim: int,
mode: str,
threshold: float,
path: Path | str | None = None,
dtype=float,
_k: int = 100,
_chunk_size: int = 64,
):
"""
Initialize.
:param dim: Problem dimension (length of parameter vector).
:param path: Optional filesystem path for snapshots.
:param threshold: Threshold for accepting new entries.
Interpreted according to `mode`.
:param mode: 'abs' or 'rel' mode for thresholding.
If 'abs', new entries are accepted if
``fx - best_fx <= threshold``.
If 'rel', new entries are accepted if
``|(fx - best_fx) / best_fx | <= threshold``.
:param dtype: Numpy dtype used for internal numeric storage.
:param _k: Initial capacity.
:param _chunk_size: Minimum grow chunk size.
"""
super().__init__(dim=dim, path=path, dtype=dtype)
if mode not in ("abs", "rel"):
raise ValueError(f"Unknown threshold mode {mode!r}")
self._threshold = float(threshold)
self._mode = mode
self._chunk_size = max(1, int(_chunk_size))
#: Capacity of internal buffers (will grow as needed).
self._cap = max(self._chunk_size, int(_k))
#: Current number of stored entries.
self._len = 0
self._x = np.empty((self._cap, self.dim), dtype=self.dtype)
self._fx = np.empty(self._cap, dtype=self.dtype)
self._ts = np.empty(self._cap, dtype=np.int64)
self._valid = np.zeros(self._cap, dtype=bool)
#: Stored parameter vectors as hashable keys to enforce uniqueness.
# Maps parameter tuple -> slot index.
self._seen: dict[tuple, int] = {}
#: List of free slot indices.
self._free_slots: list[int] = []
#: Pseudo-timestamp counter for entries.
self._counter = 0
#: Best function value seen so far.
self._best_fx: float | None = None
def _grow_to(self, min_cap: int):
"""Grow internal buffers to at least min_cap."""
# allocate new buffers
new_cap = max(self._cap * 2, min_cap, self._cap + self._chunk_size)
x_new = np.empty((new_cap, self.dim), dtype=self.dtype)
fx_new = np.empty(new_cap, dtype=self.dtype)
ts_new = np.empty(new_cap, dtype=np.int64)
valid_new = np.zeros(new_cap, dtype=bool)
# initialize with old data
x_new[: self._cap] = self._x
fx_new[: self._cap] = self._fx
ts_new[: self._cap] = self._ts
valid_new[: self._cap] = self._valid
self._x, self._fx, self._ts = x_new, fx_new, ts_new
self._valid = valid_new
self._cap = new_cap
def _meets_threshold(self, fx: float) -> bool:
"""Check if `fx` meets the threshold criterion."""
best_fx = self._best_fx
if best_fx is None:
return True
match self._mode:
case "abs":
return fx - best_fx <= self._threshold
case "rel":
# TODO: handle best_fx == 0 case?
# let's see if relative thresholding is actually useful
return abs((fx - best_fx) / best_fx) <= self._threshold
case _:
raise ValueError(f"Unknown threshold mode {self._mode!r}")
def _store(self, slot: int, x: np.ndarray, fx: float):
"""Store (x, fx) in the slot with the given index."""
ts = self._counter
self._counter += 1
self._x[slot] = x
self._fx[slot] = fx
self._ts[slot] = ts
self._valid[slot] = True
self._len += 1
[docs]
def process(self, x: Sequence[float], fx: float) -> None:
"""
Process an objective evaluation.
Add (x, fx) if it meets the threshold and is not a duplicate.
"""
x_arr = self._normalize_x(x)
key = self._key_from_array(x_arr)
with self._lock:
if key in self._seen:
return
if not self._meets_threshold(fx):
return
# pick a slot
if self._free_slots:
# reuse a slot freed during pruning
slot = self._free_slots.pop()
else:
if self._len >= self._cap:
# capacity exhausted -- grow and use next slot
old_cap = self._cap
self._grow_to(self._cap + 1)
slot = old_cap
else:
# find next unused slot among existing capacity
slot = 0
while slot < self._cap and self._valid[slot]:
slot += 1
if slot >= self._cap:
raise AssertionError(
f"Slot {slot} exceeds capacity {self._cap}"
)
# store
self._store(slot, x_arr, fx)
self._seen[key] = slot
# update best and prune if necessary
if self._best_fx is None or fx < self._best_fx:
self._best_fx = fx
self._prune()
return
def _prune(self):
"""
Prune after fx_best changed.
Remove stored entries that no longer meet the threshold .
"""
# find entries to remove
to_remove: list[tuple[tuple, int]] = []
for key, slot in self._seen.items():
if not self._meets_threshold(self._fx[slot]):
to_remove.append((key, slot))
# remove them
for key, slot in to_remove:
# mark slot free
self._valid[slot] = False
del self._seen[key]
self._free_slots.append(slot)
self._len -= 1
[docs]
def snapshot(self):
"""Create a snapshot of the stored entries."""
with self._lock:
mask = self._valid
order = np.argsort(self._fx[mask])
return {
"x": np.ascontiguousarray(self._x[mask][order]),
"fx": np.ascontiguousarray(self._fx[mask][order]),
"ts": np.ascontiguousarray(self._ts[mask][order]),
}
[docs]
def save(self):
"""Save the stored entries as numpy ``.npz`` file."""
if self.path is None:
return
snapshot = self.snapshot()
# first save to a temp file to avoid corrupted files on crashes,
# then atomically rename
dirpath = self.path.parent
fd, tmpname = tempfile.mkstemp(
prefix="._thresh_", suffix=".npz", dir=dirpath
)
os.close(fd)
try:
np.savez(tmpname, **snapshot)
os.replace(tmpname, self.path)
except Exception:
logger.exception("Failed to save snapshot to %s", self.path)
[docs]
def to_ensemble(self) -> pypesto.ensemble.Ensemble:
"""Create a :class:`pypesto.Ensemble` from the stored entries."""
snapshot = self.snapshot()
x = snapshot["x"]
ensemble = pypesto.ensemble.Ensemble(x_vectors=x)
return ensemble