import copy
import itertools
from collections import OrderedDict, defaultdict
from contextlib import contextmanager
import functools
from typing import Optional, Dict, List, Any, Union, Mapping
import warnings
import numpy as np
from marshmallow import ValidationError as MarshmallowValidationError
from paramtools import utils
from paramtools import contrib
from paramtools.schema import ParamToolsSchema
from paramtools.schema_factory import SchemaFactory
from paramtools.sorted_key_list import SortedKeyList
from paramtools.typing import ValueObject, FileDictStringLike
from paramtools.exceptions import (
ParamToolsError,
SparseValueObjectsException,
ValidationError,
InconsistentLabelsException,
collision_list,
ParameterNameCollisionException,
)
from paramtools.values import Values, union, intersection
class ParameterSlice:
__slots__ = ("parameters", "_cache", "_key_cache")
def __init__(self, parameters):
self.parameters = parameters
self._cache = {}
self._key_cache = {}
def __getitem__(self, parameter_or_values):
keyfuncs = dict(self.parameters.keyfuncs)
if (
isinstance(parameter_or_values, str)
and parameter_or_values in self._cache
):
return self._cache[parameter_or_values]
elif isinstance(parameter_or_values, str):
data = self.parameters._data.get(parameter_or_values)
if data is None:
raise ValueError(f"Unknown parameter: {parameter_or_values}.")
try:
keyfunc = self._key_cache.get(parameter_or_values, None)
if keyfunc is None:
keyfunc = self.parameters._validator_schema.field_keyfunc(
parameter_or_values
)
self._key_cache[parameter_or_values] = keyfunc
keyfuncs["value"] = keyfunc
values = Values(data["value"], keyfuncs=keyfuncs)
self._cache[parameter_or_values] = values
return values
except contrib.validate.ValidationError as ve:
raise ParamToolsError(
f"There was an error retrieving the field for {parameter_or_values}",
{},
) from ve
else:
return Values(parameter_or_values, keyfuncs=keyfuncs)
class Parameters:
defaults = None
array_first: bool = False
label_to_extend: str = None
uses_extend_func: bool = False
index_rates: Dict = {}
def __init__(
self,
initial_state: Optional[dict] = None,
index_rates: Optional[dict] = None,
sort_values: bool = True,
**ops,
):
schemafactory = SchemaFactory(self.defaults)
(
self._defaults_schema,
self._validator_schema,
self._schema,
self._data,
) = schemafactory.schemas()
self.label_validators = schemafactory.label_validators
self.keyfuncs = {}
for label, lv in self.label_validators.items():
cmp_funcs = getattr(lv, "cmp_funcs", None)
if cmp_funcs is not None:
self.keyfuncs[label] = cmp_funcs()["key"]
self._stateless_label_grid = OrderedDict()
for name, v in self.label_validators.items():
if hasattr(v, "grid"):
self._stateless_label_grid[name] = v.grid()
else:
self._stateless_label_grid[name] = []
self.label_grid = copy.deepcopy(self._stateless_label_grid)
self._validator_schema.context["spec"] = self
self._warnings = {}
self._errors = {}
self._defer_validation = False
self._state = self.parse_labels(**(initial_state or {}))
self.index_rates = index_rates or self.index_rates
self.sel = ParameterSlice(self)
# set operators in order of importance:
# __init__ arg: most important
# class attribute: middle importance
# schema operators: least important
# default value if the three above are not specified.
default_ops = [
("array_first", False),
("label_to_extend", None),
("uses_extend_func", False),
]
schema_ops = self._schema.get("operators", {})
for name, default in default_ops:
if name in ops:
setattr(self, name, ops.get(name))
elif getattr(self, name, None) != default:
setattr(self, name, getattr(self, name))
elif name in schema_ops:
setattr(self, name, schema_ops[name])
else:
setattr(self, name, default)
if self.label_to_extend:
prev_array_first = self.array_first
self.array_first = False
self.set_state()
self.extend()
if prev_array_first:
self.array_first = True
self.set_state()
else:
self.set_state()
if "operators" not in self._schema:
self._schema["operators"] = {}
self._schema["operators"].update(self.operators)
if sort_values:
self.sort_values()
def __getitem__(self, parameter):
raise AttributeError(
f'Use params.sel["{parameter}"] instead of params["{parameter}"].'
)
def set_state(self, **labels):
"""
Sets state for the Parameters instance. The `_state`, `label_grid`, and
parameter attributes are all updated with the new state.
Use the `view_state` method to inspect the current state of the instance,
and use the `clear_state` method to revert to the default state.
**Raises**
- `ValidationError` if the labels kwargs contain labels that are not
specified in schema.json or if the label values fail the
validator set for the corresponding label in schema.json.
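For example, a sketch assuming a subclass whose defaults define an integer
`year` label (the subclass and label name are illustrative, not something
this class requires):
.. code-block:: Python

    params = MyParams()          # hypothetical subclass with a "year" label
    params.set_state(year=2024)
    params.view_state()          # {"year": [2024]}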
"""
self._set_state(**labels)
def clear_state(self):
"""
Reset the state of the `Parameters` instance.
"""
self._state = {}
self.label_grid = copy.deepcopy(self._stateless_label_grid)
self.set_state()
def view_state(self):
"""
Access the label state of the ``Parameters`` instance.
"""
return {label: value for label, value in self._state.items()}
def read_params(
self,
params_or_path: FileDictStringLike,
storage_options: Optional[Dict[str, Any]] = None,
):
"""
Read JSON data of the form:
- Python `dict`.
- JSON string.
- Local file path.
- Any URL readable by fsspec. For example:
- s3: `s3://paramtools-test/defaults.json`
- gcs: `gs://paramtools-dev/defaults.json`
- http: `https://somedomain.com/defaults.json`
- github: `github://PSLmodels:ParamTools@master/paramtools/tests/defaults.json`
**Returns**
- `params`: Python Dict created from JSON file.
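For example, each of the following should be readable (the local path is a
placeholder, and the S3 call reuses the test bucket URL listed above while
assuming anonymous access):
.. code-block:: Python

    params.read_params({"min_param": [{"value": 4}]})
    params.read_params('{"min_param": [{"value": 4}]}')
    params.read_params("defaults.json")
    params.read_params(
        "s3://paramtools-test/defaults.json",
        storage_options={"anon": True},
    )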
"""
return utils.read_json(params_or_path, storage_options)
def adjust(
self,
params_or_path: Union[str, Mapping[str, List[ValueObject]]],
ignore_warnings: bool = False,
raise_errors: bool = True,
extend_adj: bool = True,
clobber: bool = True,
):
"""
Deserialize and validate parameter adjustments. `params_or_path`
can be a file path or a `dict` that has not been fully deserialized.
The adjusted values replace the current values stored in the
corresponding parameter attributes.
If `clobber` is `True` and extend mode is on, then all future values
for a given parameter will be replaced by the values in the adjustment.
If `clobber` is `False` and extend mode is on, then user-defined values
will not be replaced by values in this adjustment. Only values that
were added automatically via the extend method will be updated.
This simply calls a private method `_adjust` to do the update. Creating
this layer on top of `_adjust` makes it easy to subclass `Parameters` and
implement custom `adjust` methods.
**Parameters**
- `params_or_path`: Adjustment that is either a `dict`, file path, or
JSON string.
- `ignore_warnings`: Whether to raise an error on warnings or ignore them.
- `raise_errors`: Either raise errors or simply store the error messages.
- `extend_adj`: If in extend mode, this is a flag indicating whether to
extend the adjustment values or not.
- `clobber`: If in extend mode, this is a flag indicating whether to
override all values, including user-defined values, or to only
override automatically created values.
**Returns**
- `params`: Parsed, validated parameters.
**Raises**
- `marshmallow.exceptions.ValidationError` if data is not valid.
- `ParameterUpdateException` if label values do not match at
least one existing value item's corresponding label values.
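A minimal sketch, reusing the `Params` class defined in the `transaction`
docstring below (`min_param` has no labels, so a bare value and a list of
value objects are interchangeable):
.. code-block:: Python

    params = Params()
    params.adjust({"min_param": 1})
    params.adjust({"min_param": [{"value": 1}]})  # equivalent form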
"""
return self._adjust(
params_or_path,
ignore_warnings=ignore_warnings,
raise_errors=raise_errors,
extend_adj=extend_adj,
clobber=clobber,
)
def _adjust(
self,
params_or_path,
ignore_warnings=False,
raise_errors=True,
extend_adj=True,
deserialized=False,
validate=True,
clobber=True,
):
"""
Internal method for performing adjustments.
"""
# Validate user adjustments.
if deserialized:
parsed_params = {}
try:
parsed_params = self._validator_schema.load(
params_or_path, ignore_warnings, deserialized=True
)
except MarshmallowValidationError as ve:
self._parse_validation_messages(ve.messages, params_or_path)
else:
params = self.read_params(params_or_path)
parsed_params = {}
try:
parsed_params = self._validator_schema.load(
params, ignore_warnings
)
except MarshmallowValidationError as ve:
self._parse_validation_messages(ve.messages, params)
if not self._errors:
if self.label_to_extend is not None and extend_adj:
extend_grid = self._stateless_label_grid[self.label_to_extend]
to_delete = defaultdict(list)
backup = {}
for param, vos in parsed_params.items():
for vo in utils.grid_sort(
vos, self.label_to_extend, extend_grid
):
if self.label_to_extend in vo:
if clobber:
queryset = self.sel[param]
else:
queryset = self.sel[param]["_auto"] == True
queryset &= queryset.gt(
strict=False,
**{
self.label_to_extend: vo[
self.label_to_extend
]
},
)
other_labels = utils.filter_labels(
vo,
drop=[self.label_to_extend, "value", "_auto"],
)
if other_labels:
queryset &= intersection(
queryset.eq(strict=False, **{label: value})
for label, value in other_labels.items()
)
to_delete[param] += list(queryset)
# make copy of value objects since they
# are about to be modified
backup[param] = copy.deepcopy(self._data[param]["value"])
try:
array_first = self.array_first
self.array_first = False
# delete params that will be overwritten by extend.
self.delete(
to_delete,
extend_adj=False,
raise_errors=True,
ignore_warnings=ignore_warnings,
)
# set user adjustments.
self._adjust(
parsed_params,
extend_adj=False,
raise_errors=True,
ignore_warnings=ignore_warnings,
)
self.extend(
params=parsed_params.keys(),
ignore_warnings=ignore_warnings,
raise_errors=True,
)
except ValidationError:
for param in backup:
self._data[param]["value"] = backup[param]
finally:
self.array_first = array_first
else:
for param, value in parsed_params.items():
self._update_param(param, value)
self._validator_schema.context["spec"] = self
has_errors = bool(self._errors.get("messages"))
has_warnings = bool(self._warnings.get("messages"))
# throw error if raise_errors is True or ignore_warnings is False
if (raise_errors and has_errors) or (
not ignore_warnings and has_warnings
):
raise self.validation_error
# Update attrs for params that were adjusted.
self._set_state(params=parsed_params.keys())
return parsed_params
@contextmanager
def transaction(
self, defer_validation=True, raise_errors=False, ignore_warnings=False
):
"""
Rollback any changes to parameter state after the context block closes.
.. code-block:: Python
import paramtools
class Params(paramtools.Parameters):
defaults = {
"min_param": {
"title": "Min param",
"description": "Must be less than 'max_param'",
"type": "int",
"value": 2,
"validators": {
"range": {"max": "max_param"}
}
},
"max_param": {
"title": "Max param",
"type": "int",
"value": 3
}
}
params = Params()
with params.transaction():
params.adjust({"min_param": 4})
params.adjust({"max_param": 5})
**Parameters:**
- `defer_validation`: Defer schema-level validation until the end of the block.
- `ignore_warnings`: Whether to raise an error on warnings or ignore them.
- `raise_errors`: Either raise errors or simply store the error messages.
"""
_data = copy.deepcopy(self._data)
_ops = dict(self.operators)
_state = dict(self.view_state())
try:
self._defer_validation = defer_validation
yield self
except Exception as e:
self._data = _data
raise e
finally:
self._state = _state
self._ops = _ops
self._defer_validation = False
if defer_validation:
self.validate(
self.specification(use_state=False, meta_data=False),
ignore_warnings=ignore_warnings,
raise_errors=raise_errors,
)
def validate(self, params, raise_errors=True, ignore_warnings=False):
"""
Validate parameter adjustment without modifying existing values.
For example, validate the current parameter values:
.. code-block:: Python
params.validate(
params.specification(use_state=False)
)
**Parameters:**
- `params`: Parameters to validate.
- `ignore_warnings`: Whether to raise an error on warnings or ignore them.
- `raise_errors`: Either raise errors or simply store the error messages.
"""
try:
self._validator_schema.load(
params, ignore_warnings, deserialized=True
)
except MarshmallowValidationError as ve:
self._parse_validation_messages(ve.messages, params)
has_errors = bool(self._errors.get("messages"))
has_warnings = bool(self._warnings.get("messages"))
if (raise_errors and has_errors) or (
not ignore_warnings and has_warnings
):
raise self.validation_error
def delete(
self,
params_or_path,
ignore_warnings=False,
raise_errors=True,
extend_adj=True,
):
"""
Delete value objects in `params_or_path`.
**Returns**
- Adjustment for deleting parameters.
**Raises**
- `marshmallow.exceptions.ValidationError` if data is not valid.
- `ParameterUpdateException` if label values do not match at
least one existing value item's corresponding label values.
"""
return self._delete(
params_or_path,
ignore_warnings=ignore_warnings,
raise_errors=raise_errors,
extend_adj=extend_adj,
)
def _delete(
self,
params_or_path,
ignore_warnings=False,
raise_errors=True,
extend_adj=True,
):
"""
Internal method that sets the 'value' member for all value objects
to None. Value objects with 'value' set to None are deleted.
"""
params = self.read_params(params_or_path)
# Validate user adjustments.
parsed_params = {}
try:
parsed_params = self._validator_schema.load(
params, ignore_warnings=True
)
except MarshmallowValidationError as ve:
self._parse_validation_messages(ve.messages, params)
to_delete = {}
for param, vos in parsed_params.items():
to_delete[param] = [dict(vo, **{"value": None}) for vo in vos]
self._update_param(param, to_delete[param])
if self.label_to_extend is not None and extend_adj:
self.extend()
self._validator_schema.context["spec"] = self
has_errors = bool(self._errors.get("messages"))
has_warnings = bool(self._warnings.get("messages"))
# throw error if raise_errors is True or ignore_warnings is False
if (raise_errors and has_errors) or (
not ignore_warnings and has_warnings
):
raise self.validation_error
# Update attrs for params that were adjusted.
self._set_state(params=to_delete.keys())
return to_delete
@property
def errors(self):
if not self._errors:
return {}
return {
param: utils.ravel(messages)
for param, messages in self._errors["messages"].items()
}
@property
def warnings(self):
if not self._warnings:
return {}
return {
param: utils.ravel(messages)
for param, messages in self._warnings["messages"].items()
}
@property
def validation_error(self):
messages = {
"errors": self._errors.get("messages", {}),
"warnings": self._warnings.get("messages", {}),
}
labels = {
"errors": self._errors.get("labels", {}),
"warnings": self._warnings.get("labels", {}),
}
return ValidationError(messages=messages, labels=labels)
@property
def schema(self):
pre = dict(self._schema)
pre["operators"] = self.operators
return ParamToolsSchema().dump(pre)
@property
def operators(self):
return {
"array_first": self.array_first,
"label_to_extend": self.label_to_extend,
"uses_extend_func": self.uses_extend_func,
}
def dump(self, sort_values: bool = True, use_state: bool = True):
"""
Dump a representation of this instance to JSON. This makes it
possible to load this instance's data after sending the data
across the wire or from another programming language. The
dumped values will be queried using this instance's state.
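One way the result might be used, as a sketch (treating the dumped `dict` as
the `defaults` of a fresh class is an assumption about the round trip, not
something this method enforces):
.. code-block:: Python

    serialized = params.dump()
    # ...send `serialized` across the wire, then rebuild:
    class Restored(paramtools.Parameters):
        defaults = serialized
    restored = Restored()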
"""
spec = self.specification(
meta_data=True,
include_empty=True,
serializable=True,
sort_values=sort_values,
use_state=use_state,
)
result = {"schema": self.schema}
result.update(spec)
return result
def specification(
self,
use_state: bool = True,
meta_data: bool = False,
include_empty: bool = False,
serializable: bool = False,
sort_values: bool = False,
**labels,
):
"""
Query value(s) of all parameters along labels specified in
`labels`.
**Parameters**
- `use_state`: Use the instance's state for the select operation.
- `meta_data`: Include information like the parameter
`description` and title.
- `include_empty`: Include parameters that do not meet the label query.
- `serializable`: Return data that is compatible with `json.dumps`.
- `sort_values`: Sort values by the `label` order.
**Returns**
- `dict` of parameter names and data.
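A rough illustration with the `Params` class from the `transaction` docstring
(the exact output depends on the defaults and state):
.. code-block:: Python

    params = Params()
    params.specification()
    # OrderedDict([('min_param', [{'value': 2}]),
    #              ('max_param', [{'value': 3}])])
    params.specification(meta_data=True)["min_param"]["title"]
    # 'Min param'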
"""
if use_state:
labels.update(self._state)
all_params = OrderedDict()
for param in self._validator_schema.fields:
result = self.select_eq(param, False, **labels)
if sort_values and result:
result = self.sort_values(
data={param: result}, has_meta_data=False
)[param]
if result or include_empty:
if meta_data:
param_data = self._data[param]
result = dict(param_data, **{"value": result})
# Add "value" key to match marshmallow schema format.
elif serializable:
result = {"value": result}
all_params[param] = result
if serializable:
ser = self._defaults_schema.dump(all_params)
# Unpack the values after serialization if meta_data not specified.
if not meta_data:
ser = {param: value["value"] for param, value in ser.items()}
return ser
else:
return all_params
def to_array(self, param, **labels):
"""
Convert a Value object to an n-labeled array. The list of Value
objects must span the specified parameter space. The parameter space
is defined by inspecting the label validators in schema.json
and the state attribute of the Parameters instance.
**Parameters**
- `param`: Name of parameter that will be used to create array.
- `labels`: Optionally, override instance state.
**Returns**
- `arr`: NumPy array created from list of value objects.
**Raises**
- `InconsistentLabelsException`: Value objects do not have consistent
labels.
- `SparseValueObjectsException`: Value object does not span the
entire space specified by the Order object.
- `ParamToolsError`: Parameter is an array type and has labels.
This is not supported by ParamTools when using array_first.
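A sketch with a hypothetical `rate` parameter indexed by a hypothetical
`year` label (names and values are illustrative only):
.. code-block:: Python

    list(params.sel["rate"])
    # [{"year": 2024, "value": 0.1}, {"year": 2025, "value": 0.12}]
    params.to_array("rate")
    # array([0.1 , 0.12])
    params.to_array("rate", year=2024)
    # array([0.1])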
"""
label_grid = copy.deepcopy(self.label_grid)
state = copy.deepcopy(self._state)
if labels:
parsed_labels = self.parse_labels(**labels)
label_grid.update(parsed_labels)
state.update(parsed_labels)
if state:
value_items = list(
intersection(
self.sel[param].isin(strict=False, **{label: values})
for label, values in state.items()
)
)
else:
value_items = list(self.sel[param])
if not value_items:
return np.array([])
label_order, value_order = self._resolve_order(
param, value_items, label_grid
)
shape = []
for label in label_order:
shape.append(len(value_order[label]))
shape = tuple(shape)
# Compare len value items with the expected length if they are full.
# In the future, sparse objects should be supported by filling in the
# unspecified labels.
number_dims = self._data[param].get("number_dims", 0)
if not shape and number_dims > 0:
return np.array(
value_items[0]["value"], dtype=self._numpy_type(param)
)
elif shape and number_dims > 0:
raise ParamToolsError(
f"\nParameter '{param}' is an array parameter with {number_dims} dimension(s) and "
f"has labels: {', '.join(label_order)}.\n\nParamTools does not "
f"support the use of 'array_first' with array parameters that use labels. "
f"\nYou may be able to describe this parameter's values with additional "
f"labels\nand the 'label_to_extend' operator."
)
elif not shape and number_dims == 0:
data_type = self._numpy_type(param)
value = value_items[0]["value"]
if data_type == object:
return value
else:
return data_type(value)
exp_full_shape = functools.reduce(lambda x, y: x * y, shape)
act_full_shape = len(value_items)
if act_full_shape != exp_full_shape:
# maintains label value order over value objects.
exp_grid = list(itertools.product(*value_order.values()))
# preserve label value order for each value object by
# iterating over label_order.
actual = list(
[tuple(vo[d] for d in label_order) for vo in value_items]
)
missing = "\n\t".join(
[str(d) for d in exp_grid if d not in actual]
)
counter = defaultdict(int)
extra = []
duplicates = []
for comb in actual:
counter[comb] += 1
if counter[comb] > 1:
duplicates.append((comb, counter[comb]))
if comb not in exp_grid:
extra.append(comb)
msg = ""
if missing:
msg += f"Missing combinations:\n\t{missing}"
if extra:
msg += f"Extra combinations:\n\t{extra}"
if duplicates:
msg += f"Duplicate combinations:\n\t{duplicates}"
raise SparseValueObjectsException(
f"The Value objects for {param} do not span the specified "
f"parameter space. {msg}"
)
def list_2_tuple(x):
return tuple(x) if isinstance(x, list) else x
arr = np.empty(shape, dtype=self._numpy_type(param))
for vi in value_items:
# ix stores the indices of `arr` that need to be filled in.
ix = [[] for i in range(len(label_order))]
for label_pos, label_name in enumerate(label_order):
# assume value_items is dense in the sense that it spans
# the label space.
ix[label_pos].append(
value_order[label_name].index(vi[label_name])
)
ix = tuple(map(list_2_tuple, ix))
arr[ix] = vi["value"]
return arr
def from_array(self, param, array=None, **labels):
"""
Convert a NumPy array to a list of value objects.
**Parameters**
- `param`: Name of parameter to convert to a list of value objects.
- `array`: Optionally, provide a NumPy array to convert into a list
of value objects. If not specified, the value at `self.param` will
be used.
- `labels`: Optionally, override instance state.
**Returns**
- List of `ValueObjects`
**Raises**
- `InconsistentLabelsException`: Value objects do not have consistent
labels.
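A sketch of the round trip, continuing the hypothetical `rate` and `year`
example from the `to_array` docstring:
.. code-block:: Python

    arr = params.to_array("rate")
    params.from_array("rate", arr * 2)
    # [{"year": 2024, "value": 0.2}, {"year": 2025, "value": 0.24}]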
"""
if array is None:
array = getattr(self, param)
if not isinstance(array, np.ndarray):
raise TypeError(
"A NumPy Ndarray should be passed to this method "
"or the instance attribute should be an array."
)
label_grid = copy.deepcopy(self.label_grid)
state = copy.deepcopy(self._state)
if labels:
parsed_labels = self.parse_labels(**labels)
label_grid.update(parsed_labels)
state.update(parsed_labels)
if state:
value_items = list(
intersection(
self.sel[param].isin(strict=False, **{label: value})
for label, value in state.items()
)
)
else:
value_items = list(self.sel[param])
label_order, value_order = self._resolve_order(
param, value_items, label_grid
)
label_values = itertools.product(*value_order.values())
label_indices = itertools.product(
*map(lambda x: range(len(x)), value_order.values())
)
value_items = []
for dv, di in zip(label_values, label_indices):
vi = {label_order[j]: dv[j] for j in range(len(dv))}
vi["value"] = array[di]
value_items.append(vi)
return value_items
def extend(
self,
label: Optional[str] = None,
label_values: Optional[List[Any]] = None,
params: Optional[List[str]] = None,
raise_errors: bool = True,
ignore_warnings: bool = False,
):
"""
Extend parameters along `label`.
**Parameters**
- `label`: Label to extend values along. By default, `label_to_extend`
is used.
- `label_values`: values of `label` to extend. By default, this is a grid
created from the valid values of `label_to_extend`.
- `params`: Parameters to extend. By default, all parameters are extended.
- `raise_errors`: Whether `adjust` should raise or store errors.
- `ignore_warnings`: Whether `adjust` should raise or ignore warnings.
**Raises**
- `InconsistentLabelsException`: Value objects do not have consistent
labels.
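A sketch, assuming `label_to_extend = "year"` and a hypothetical `rate`
parameter (`adjust` calls this automatically; calling it directly is mostly
useful after lower-level edits):
.. code-block:: Python

    params.extend()                    # fill every parameter along "year"
    params.extend(params=["rate"])     # only extend "rate"
    params.extend(label="year", label_values=[2029, 2030])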
"""
if label is None:
label = self.label_to_extend
else:
label = label
spec = self.specification(meta_data=True)
if params is not None:
spec = {
param: self._data[param]
for param, data in spec.items()
if param in params
}
full_extend_grid = self._stateless_label_grid[label]
if label_values is not None:
labels = self.parse_labels(**{label: label_values})
extend_grid = labels[label]
else:
extend_grid = self._stateless_label_grid[label]
cmp_funcs = self.label_validators[label].cmp_funcs(choices=extend_grid)
adjustment = defaultdict(list)
for param, data in spec.items():
if not any(label in vo for vo in data["value"]):
continue
extended_vos = set()
for vo in sorted(
data["value"], key=lambda val: cmp_funcs["key"](val[label])
):
hashable_vo = utils.hashable_value_object(vo)
if hashable_vo in extended_vos:
continue
else:
extended_vos.add(hashable_vo)
queryset = self.sel[param].gt(
strict=False, **{label: vo[label]}
)
other_labels = utils.filter_labels(
vo, drop=["value", label, "_auto"]
)
if other_labels:
queryset &= intersection(
queryset.eq(strict=False, **{oth_label: value})
for oth_label, value in other_labels.items()
)
extended_vos.update(
map(utils.hashable_value_object, list(queryset))
)
values = queryset.as_values().add(values=[vo])
defined_vals = {eq_vo[label] for eq_vo in queryset}
missing_vals = sorted(
set(extend_grid) - defined_vals, key=cmp_funcs["key"]
)
if not missing_vals:
continue
extended = defaultdict(list)
for vo in values:
extended[vo[label]].append(vo)
skl = SortedKeyList(extended.keys(), cmp_funcs["key"])
for val in missing_vals:
lte_val = skl.lte(val)
if lte_val is not None:
closest_val = lte_val.values[-1]
else:
closest_val = skl.gte(val).values[0]
if closest_val in extended:
value_objects = extended.pop(closest_val)
else:
value_objects = values.eq(
strict=False, **{label: closest_val}
)
# In practice, value_objects has length one.
# Theoretically, there could be multiple if the initial value
# object had fewer labels than later value objects and thus
# matched multiple value objects.
for value_object in value_objects:
ext = dict(value_object, **{label: val})
ext = self.extend_func(
param, ext, value_object, full_extend_grid, label
)
extended_vos.add(
utils.hashable_value_object(value_object)
)
extended[val].append(ext)
skl.add(val)
adjustment[param].append(OrderedDict(ext, _auto=True))
# Ensure that the adjust method of paramtools.Parameters is used
# in case the child class also implements adjust.
return self._adjust(
adjustment,
extend_adj=False,
ignore_warnings=ignore_warnings,
raise_errors=raise_errors,
deserialized=True,
)
def extend_func(
self,
param: str,
extend_vo: ValueObject,
known_vo: ValueObject,
extend_grid: List,
label: str,
):
"""
Function for applying indexing rates to parameter values as they
are extended. Projects may implement their own `extend_func` by
overriding this one. Projects need to write their own `get_index_rate`
method for returning the correct indexing rate for a given parameter
and value of `label`.
**Returns**
- `extend_vo`: New `ValueObject`.
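A minimal sketch of an override (defaults omitted; unlike the default body
below, this sketch drops the `uses_extend_func`/`indexed` guard and uses a
made-up rounding rule):
.. code-block:: Python

    class MyParams(paramtools.Parameters):
        def extend_func(self, param, extend_vo, known_vo, extend_grid, label):
            # copy the known value forward, rounded to whole units,
            # instead of applying index rates.
            extend_vo["value"] = round(known_vo["value"])
            return extend_vo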
"""
if not self.uses_extend_func or not self._data[param].get(
"indexed", False
):
return extend_vo
known_val = known_vo[label]
known_ix = extend_grid.index(known_val)
toext_val = extend_vo[label]
toext_ix = extend_grid.index(toext_val)
if toext_ix > known_ix:
# grow the value according to the index rate supplied by the
# user-defined `get_index_rate` method.
for ix in range(known_ix, toext_ix):
v = extend_vo["value"] * (
1 + self.get_index_rate(param, extend_grid[ix])
)
extend_vo["value"] = np.round(v, 2) if v < 9e99 else 9e99
else:
# shrink the value according to the index rate supplied by the
# user-defined `get_index_rate` method.
for ix in reversed(range(toext_ix, known_ix)):
v = (
extend_vo["value"]
* (1 + self.get_index_rate(param, extend_grid[ix])) ** -1
)
extend_vo["value"] = np.round(v, 2) if v < 9e99 else 9e99
return extend_vo
def get_index_rate(self, param: str, lte_val: Any):
"""
Return the value of the index_rates dictionary matching the
label to extend value, `lte_val`.
Projects may find it convenient to override this method with their own
`index_rate` method.
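A minimal sketch of an override that looks rates up by parameter name as well
as label value (the nested `index_rates` layout here is an assumption, not a
requirement of this class):
.. code-block:: Python

    class MyParams(paramtools.Parameters):
        index_rates = {"rate": {2024: 0.02, 2025: 0.03}}
        def get_index_rate(self, param, lte_val):
            return self.index_rates[param][lte_val]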
"""
return self.index_rates[lte_val]
def parse_labels(self, **labels):
"""
Parse and validate labels.
**Returns**
- Parsed and validated labels.
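For example, with a hypothetical integer `year` label:
.. code-block:: Python

    params.parse_labels(year="2024")        # {"year": [2024]}
    params.parse_labels(year=[2024, 2025])  # {"year": [2024, 2025]}
    params.parse_labels(bogus=1)            # raises ValidationError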
"""
parsed = defaultdict(list)
messages = {}
for name, values in labels.items():
if name not in self.label_validators:
messages[name] = f"{name} is not a valid label."
continue
if not isinstance(values, list):
list_values = [values]
else:
list_values = values
assert isinstance(list_values, list)
for value in list_values:
try:
parsed[name].append(
self.label_validators[name].deserialize(value)
)
except MarshmallowValidationError as ve:
messages[name] = str(ve)
if messages:
raise ValidationError({"errors": messages}, labels=None)
return parsed
def _set_state(self, params=None, **labels):
"""
Private method for setting the state on a Parameters instance. Internal
methods can set which params will be updated. This is helpful when a set
of parameters are adjusted and only their attributes need to be updated.
"""
labels = self.parse_labels(**labels)
self._state.update(labels)
for label_name, label_value in self._state.items():
assert isinstance(label_value, list)
self.label_grid[label_name] = label_value
spec = self.specification(include_empty=True, **self._state)
if params is not None:
spec = {param: spec[param] for param in params}
for name, value in spec.items():
self.sel._cache.pop(name, None)
if name in collision_list:
raise ParameterNameCollisionException(
f"The paramter name, '{name}', is already used by the Parameters object."
)
if self.array_first:
setattr(self, name, self.to_array(name))
else:
setattr(self, name, value)
def _resolve_order(self, param, value_items, label_grid):
"""
Resolve the order of the labels and their values by
inspecting data in the label grid values.
The labels to be used are the ones that are specified
for each value object. Note that the labels must be specified
_consistently_ for all value objects, i.e. none can be added or omitted
for any value object in the list.
**Returns**
- `label_order`: The label order.
- `value_order`: The values, in order, for each label.
**Raises**
- `InconsistentLabelsException`: Value objects do not have consistent
labels.
"""
used = utils.consistent_labels(value_items)
if used is None:
raise InconsistentLabelsException(
"Labels were added or omitted for some value object(s)."
)
label_order, value_order = [], {}
for label_name, label_values in label_grid.items():
if label_name in used:
label_order.append(label_name)
value_order[label_name] = label_values
return label_order, value_order
def _numpy_type(self, param):
"""
Get the numpy type for a given parameter.
"""
return (
self._validator_schema.fields[param].schema.fields["value"].np_type
)
def _select(self, param, op, strict, **labels):
if "exact_match" in labels:
warnings.warn(
"'exact_match' has been deprecated in favor of 'strict'."
)
strict = labels.pop("exact_match")
res = self.sel[param]
for label, value in labels.items():
if isinstance(value, list):
res &= union(
self.sel[param]._cmp(op, strict, **{label: element})
for element in value
)
else:
res &= self.sel[param]._cmp(op, strict, **{label: value})
return list(res)
def select_eq(self, param, strict=True, **labels):
return self._select(param, "eq", strict, **labels)
def select_ne(self, param, strict=True, **labels):
return self._select(param, "ne", strict, **labels)
def select_gt(self, param, strict=True, **labels):
return self._select(param, "gt", strict, **labels)
def select_gte(self, param, strict=True, **labels):
return self._select(param, "gte", strict, **labels)
def select_lt(self, param, strict=True, **labels):
return self._select(param, "lt", strict, **labels)
def select_lte(self, param, strict=True, **labels):
return self._select(param, "lte", strict, **labels)
def _update_param(self, param, new_values):
"""
Update the current parameter values with those specified by
the adjustment. The values that need to be updated are chosen
by finding all value items with label values matching the
label values specified in the adjustment. If the value is
set to None, then that value object will be removed.
Note: _update_param used to raise a ParameterUpdateException if one of the new
values did not match at least one of the current value objects. However,
this was dropped to better support the case where the parameters are being
extended along some label to fill the parameter space. An exception could
be raised if a new value object contains a label that is not used in the
current value objects for the parameter. However, it seems like it could be
expensive to check this case, especially when a project is extending parameters.
For now, no exceptions are raised by this method.
"""
param_values = self.sel[param]
if len(list(param_values)) == 0:
self._data[param]["value"] = new_values
return
for new_vo in new_values:
labels = utils.filter_labels(new_vo, drop=["value"])
if not labels:
if new_vo["value"] is not None:
for curr_vo in self._data[param]["value"]:
curr_vo["value"] = new_vo["value"]
else:
param_values.delete(inplace=True)
continue
to_update = intersection(
param_values.eq(strict=True, **{label: value})
for label, value in labels.items()
if label in param_values.labels and label != "_auto"
)
if len(list(to_update)) > 0:
if new_vo["value"] is None:
to_update.delete()
else:
for curr_vo in to_update:
curr_vo["value"] = new_vo["value"]
if new_vo.get("_auto") is None:
curr_vo.pop("_auto", None)
else:
if new_vo["value"] is not None:
param_values.add([new_vo], inplace=True)
self.sel._cache[param] = param_values
self._data[param]["value"][:] = list(param_values)
def _parse_validation_messages(self, messages, params):
"""Parse validation messages from marshmallow"""
if messages.get("warnings"):
self._warnings.update(
self._parse_errors(messages.pop("warnings"), params)
)
self._errors.update(self._parse_errors(messages, params))
def _parse_errors(self, messages, params):
"""
Parse the error messages given by marshmallow.
Marshmallow error structure:
{
"list_param": {
0: {
"value": {
0: [err message for first item in value list]
i: [err message for i-th item in value list]
}
},
i-th value object: {
"value": {
0: [...],
...
}
},
}
"nonlist_param": {
0: {
"value": [err message]
},
...
}
}
self._errors structure:
{
"messages": {
"param": [
["value": {0: [msg0, msg1, ...], other_bad_ix: ...},
"label0": {0: msg, ...} // if errors on label values.
],
...
},
"label": {
"param": [
{label_name: label_value, other_label_name: other_label_value},
...
// list indices correspond to the error messages' indices
// of the error messages caused by the value of this value
// object.
]
}
}
"""
error_info = {
"messages": defaultdict(dict),
"labels": defaultdict(dict),
}
for pname, data in messages.items():
if pname == "_schema":
error_info["messages"]["schema"] = [
f"Data format error: {data}"
]
continue
if data == ["Unknown field."]:
error_info["messages"]["schema"] = [f"Unknown field: {pname}"]
continue
param_data = utils.ensure_value_object(params[pname])
error_labels = []
formatted_errors = []
for ix, marshmessages in data.items():
error_labels.append(
utils.filter_labels(param_data[ix], drop=["value"])
)
formatted_errors_ix = []
for _, messages in marshmessages.items():
if messages:
if isinstance(messages, list):
formatted_errors_ix += messages
else:
for _, messagelist in messages.items():
formatted_errors_ix += messagelist
formatted_errors.append(formatted_errors_ix)
error_info["messages"][pname] = formatted_errors
error_info["labels"][pname] = error_labels
return error_info
def __iter__(self):
return iter(self._data)
def keys(self):
"""
Return parameter names.
"""
return self._data.keys()
def items(self):
"""
Iterate using python dictionary .items() syntax.
"""
for param in self:
yield param, getattr(self, param)
return
def to_dict(self):
"""
Return instance as python dictionary.
"""
return dict(self.items())
def sort_values(self, data=None, has_meta_data=True):
"""
Sort value objects for all parameters in `data` according
to the order specified in `schema`.
**Parameters**
- `data`: Parameter data to be sorted. This should be a
`dict` of parameter names and values. If `data` is `None`,
the current values will be sorted.
- `has_meta_data`: Whether parameter values should be accessed
directly or through the "value" attribute.
**Returns**
- Sorted data.
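A rough example with a hypothetical `year` label whose grid is ascending
(value objects are reordered to follow the label grid):
.. code-block:: Python

    data = {"rate": [{"year": 2025, "value": 0.2},
                     {"year": 2024, "value": 0.1}]}
    params.sort_values(data=data, has_meta_data=False)
    # {"rate": [{"year": 2024, "value": 0.1}, {"year": 2025, "value": 0.2}]}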
"""
def keyfunc(vo, label, label_values):
if label in vo and label_values:
return label_values.index(vo[label])
else:
return -1
if data is None:
data = self._data
update_attrs = True
if not has_meta_data:
raise ParamToolsError(
"has_meta_data must be True if data is not specified."
)
else:
update_attrs = False
# nothing to do if labels aren't specified
if not self._stateless_label_grid:
return data
# iterate over labels so that the first label's order
# takes precedence.
label_grid = self._stateless_label_grid
for param in data:
for label in reversed(label_grid):
label_values = label_grid[label]
pfunc = functools.partial(
keyfunc, label=label, label_values=label_values
)
if has_meta_data:
data[param]["value"] = sorted(
data[param]["value"], key=pfunc
)
else:
data[param] = sorted(data[param], key=pfunc)
# Only update attributes when array first is off, since
# value order will not affect how arrays are constructed.
if update_attrs and not self.array_first:
self.sel._cache.pop(param, None)
if self._state:
attr_vals = self.sel[param]
active = intersection(
attr_vals[label].isin(value)
for label, value in self._state.items()
if label in attr_vals.labels
)
else:
active = data[param]["value"]
sorted_values = self.sort_values(
{param: list(active)}, has_meta_data=False
)[param]
setattr(self, param, sorted_values)
return data