import copy
from collections import defaultdict
from typing import List, Dict, Any, Union, Generator
from paramtools.sorted_key_list import SortedKeyList
from paramtools.typing import ValueObject
def default_cmp_func(x):
    """
    Identity key function: return ``x`` exactly as it was given.

    ``Values.get_keyfunc`` falls back to this function when no
    label-specific key function has been supplied.
    """
    return x
class ValueItem:
    """
    Handles index-based look-ups on the Values class.

    ``values`` is the wrapped object; its ``values`` attribute is the mapping
    that is actually read.  ``index`` optionally maps positions to keys of
    that mapping.  When ``index`` is ``None`` the requested position is used
    as the key directly.  Every returned value object is a shallow ``dict``
    copy.
    """

    def __init__(self, values: "Values", index: List[int] = None):
        self.values = values
        # Keep a private copy so later changes to the caller's sequence do
        # not affect this view.
        self.index = None if index is None else list(index)

    def __getitem__(self, item):
        """
        Return one copied value object for a scalar ``item``, or a list of
        copies when ``item`` is a ``slice``.
        """
        store = self.values.values

        # Scalar look-up.
        if not isinstance(item, slice):
            key = item if self.index is None else self.index[item]
            return dict(store[key])

        # Slice without a position -> key mapping: positions are the keys.
        if self.index is None:
            positions = range(*item.indices(len(self.values)))
            return [dict(store[pos]) for pos in positions]

        # Slice with a mapping: translate every position through the index.
        positions = range(*item.indices(len(self.index)))
        return [dict(store[self.index[pos]]) for pos in positions]
class ValueBase:
    """
    Shared comparison/query interface.

    Subclasses provide ``cmp_attr`` (the object that actually answers the
    queries) and, for the label-based methods below, a ``label`` attribute
    naming the label that the supplied value is compared against.
    """

    @property
    def cmp_attr(self):
        """Object that queries are delegated to; subclasses must override."""
        raise NotImplementedError()

    # ----- operator forms: delegate without passing ``strict`` -----

    def __eq__(self, value=None, **labels):
        # ``labels`` is accepted but not forwarded.
        compare = self.cmp_attr.eq
        return compare(**{self.label: value})

    def __ne__(self, value):
        compare = self.cmp_attr.ne
        return compare(**{self.label: value})

    def __gt__(self, value):
        compare = self.cmp_attr.gt
        return compare(**{self.label: value})

    def __ge__(self, value):
        compare = self.cmp_attr.gte
        return compare(**{self.label: value})

    def __lt__(self, value):
        compare = self.cmp_attr.lt
        return compare(**{self.label: value})

    def __le__(self, value):
        compare = self.cmp_attr.lte
        return compare(**{self.label: value})

    # ----- container protocol: all based on ``cmp_attr`` -----

    def __len__(self):
        """Number of items produced by iterating ``cmp_attr``."""
        return sum(1 for _ in iter(self.cmp_attr))

    def __iter__(self):
        return iter(self.cmp_attr)

    def __getitem__(self, item):
        return self.cmp_attr[item]

    # ----- method forms: ``strict`` is forwarded positionally -----

    def eq(self, value, strict=True):
        compare = self.cmp_attr.eq
        return compare(strict, **{self.label: value})

    def ne(self, value, strict=True):
        compare = self.cmp_attr.ne
        return compare(strict, **{self.label: value})

    def gt(self, value, strict=True):
        compare = self.cmp_attr.gt
        return compare(strict, **{self.label: value})

    def gte(self, value, strict=True):
        compare = self.cmp_attr.gte
        return compare(strict, **{self.label: value})

    def lt(self, value, strict=True):
        compare = self.cmp_attr.lt
        return compare(strict, **{self.label: value})

    def lte(self, value, strict=True):
        compare = self.cmp_attr.lte
        return compare(strict, **{self.label: value})

    def isin(self, value, strict=True):
        compare = self.cmp_attr.isin
        return compare(strict, **{self.label: value})
class QueryResult(ValueBase):
    """
    The outcome of a query: a reference to the queried ``Values`` object plus
    the index entries that matched.

    Iterating yields only the matching value objects.  The comparison
    *methods* (``eq``, ``gt``, ...) are delegated to the underlying ``Values``
    object, i.e. they query the full data set again; use ``as_values`` to
    query within the matches.  The comparison *operators* are disabled.

    NOTE: ``len()`` is inherited from ``ValueBase`` and iterates ``cmp_attr``
    (the underlying ``Values``), so it does not count the matches; use
    ``len(result.tolist())`` for that.
    """

    def __init__(self, values: "Values", index: List[Any]):
        self.values = values
        self.index = index

    def __and__(self, queryresult: "QueryResult"):
        """Keep the index entries present in both results."""
        res = set(self.index) & set(queryresult.index)
        return QueryResult(self.values, res)

    def __or__(self, queryresult: "QueryResult"):
        """Keep the index entries present in either result."""
        res = set(self.index) | set(queryresult.index)
        return QueryResult(self.values, res)

    def __repr__(self):
        vo_repr = "\n ".join(
            str(dict(self.values.values[i])) for i in (self.index or [])
        )
        return f"QueryResult([\n {vo_repr}\n])"

    def __iter__(self):
        for i in self.index:
            yield self.values.values[i]

    def __getitem__(self, item):
        raise NotImplementedError(
            "Use .isel to do index-based look ups or as_values to chain queries."
        )

    @property
    def isel(self):
        """Positional access to the matching value objects."""
        return ValueItem(self.values, self.index)

    def tolist(self):
        """Return the matching value objects as a list."""
        return [self.values.values[i] for i in self.index]

    def eq(self, strict=True, **labels):
        return self.cmp_attr.eq(strict, **labels)

    def ne(self, strict=True, **labels):
        return self.cmp_attr.ne(strict, **labels)

    def gt(self, strict=True, **labels):
        return self.cmp_attr.gt(strict, **labels)

    def gte(self, strict=True, **labels):
        return self.cmp_attr.gte(strict, **labels)

    def lt(self, strict=True, **labels):
        return self.cmp_attr.lt(strict, **labels)

    def lte(self, strict=True, **labels):
        return self.cmp_attr.lte(strict, **labels)

    def isin(self, strict=True, **labels):
        return self.cmp_attr.isin(strict, **labels)

    # A QueryResult has no ``label``, so the operator forms inherited from
    # ValueBase cannot work and are disabled explicitly.
    def __eq__(self, *args, **kwargs):
        raise NotImplementedError()

    def __ne__(self, *args, **kwargs):
        raise NotImplementedError()

    def __gt__(self, *args, **kwargs):
        raise NotImplementedError()

    def __ge__(self, *args, **kwargs):
        raise NotImplementedError()

    def __lt__(self, *args, **kwargs):
        raise NotImplementedError()

    def __le__(self, *args, **kwargs):
        raise NotImplementedError()

    def as_values(self):
        """
        Build a new ``Values`` object holding only the matching value
        objects, keyed by their original index.
        """
        # Snapshot the index as a list: ``&``/``|`` produce a set, and the
        # new Values object must not share (and later mutate) this result's
        # index container.
        index = list(self.index)
        return Values(
            values=[self.values.values[i] for i in index],
            index=index,
            keyfuncs=self.values.keyfuncs,
        )

    def delete(self):
        """Delete the matching value objects from the underlying Values."""
        # ``Values.delete`` treats "no index given" as "delete everything",
        # so an empty result must not be forwarded to it.
        if not self.index:
            return
        self.values.delete(*self.index, inplace=True)

    @property
    def cmp_attr(self):
        return self.values
class Slice(ValueBase):
    """
    View of a single label across all value objects of a ``Values`` instance.

    Comparisons (``slice_ == 5``, ``slice_.gt(5)``, ...) are answered by the
    underlying ``Values`` object using this slice's label.
    """

    def __init__(self, values: "Values", label: str):
        self.values = values
        self.label = label

    @property
    def cmp_attr(self):
        return self.values

    def __getitem__(self, item):
        """
        Return the label's value(s).

        A ``slice`` selects value objects by position (in storage order) and
        yields ``None`` for objects that lack the label.  Any other ``item``
        is used as a key into the underlying index -> value-object mapping
        and raises ``KeyError`` if the key or the label is missing.
        """
        if isinstance(item, slice):
            # Slice the stored objects themselves rather than using
            # positions as keys: the index keys are not guaranteed to be
            # 0..n-1 (e.g. after a delete or with a custom index).
            rows = list(self.values.values.values())
            return [vo.get(self.label, None) for vo in rows[item]]
        return self.values.values[item][self.label]

    @property
    def isel(self):
        raise NotImplementedError(
            "Access values of a Slice object directly: parameters['label'][1]"
        )

    def __repr__(self):
        vo_repr = "\n ".join(
            str(dict(self.values.values[i])) for i in self.values.values
        )
        return f"Slice([\n {vo_repr}\n], \nlabel={self.label})"
class Values(ValueBase):
    """
    The Values class is used to query and update parameter values.

    For more information, checkout the
    `Viewing Data <https://paramtools.dev/api/viewing-data.html>`_ docs.

    Value objects are stored in ``self.values`` keyed by an index entry;
    ``self.index`` lists those entries and ``self.skls`` holds one
    ``SortedKeyList`` per label, which is what answers comparison queries.
    """

    def __init__(
        self,
        values: List[ValueObject],
        keyfuncs: Dict[str, Any] = None,
        skls: Dict[str, SortedKeyList] = None,
        index: List[Any] = None,
    ):
        # Always own a private list: the caller's container may be shared
        # (or be a set), and ``add``/``delete`` mutate the index in place.
        self.index = list(index) if index else list(range(len(values)))
        self.values = {ix: value for ix, value in zip(self.index, values)}
        self.keyfuncs = keyfuncs
        self.label = "value"
        if skls is not None:
            self.skls = skls
        else:
            self.skls = self.build_skls(self.values, keyfuncs or {})

    def build_skls(self, values, keyfuncs):
        """
        Build one SortedKeyList per label found in ``values`` (a mapping of
        index entry -> value object).
        """
        label_values = defaultdict(list)
        label_index = defaultdict(list)
        for ix, vo in values.items():
            for label, value in vo.items():
                label_values[label].append(value)
                label_index[label].append(ix)

        skls = {}
        for label in label_values:
            keyfunc = self.get_keyfunc(label, keyfuncs)
            skls[label] = SortedKeyList(
                label_values[label], keyfunc, label_index[label]
            )
        return skls

    def update_skls(self, values):
        """
        Add the value objects in ``values`` (index entry -> value object) to
        the existing SortedKeyLists, creating lists for new labels.
        """
        # TODO: remove existing values with clashing index
        for ix, vo in values.items():
            for label, value in vo.items():
                if self.skls.get(label, None) is not None:
                    self.skls[label].add(value, index=ix)
                else:
                    self.skls[label] = SortedKeyList(
                        [value],
                        keyfunc=self.get_keyfunc(label, self.keyfuncs),
                        index=[ix],
                    )

    def get_keyfunc(self, label, keyfuncs):
        """
        Return the key function registered for ``label``, falling back to
        ``default_cmp_func``.
        """
        # ``keyfuncs`` is None when the object was built without key
        # functions; treat that the same as an empty mapping.
        keyfunc = (keyfuncs or {}).get(label)
        return keyfunc or default_cmp_func

    def _cmp(self, op, strict, **labels):
        """
        Run the comparison ``op`` for the first ``label=value`` pair given.

        With ``strict`` an unknown label raises ``KeyError``.  Without it, an
        unknown label matches everything and value objects that lack the
        label are added to the matches.
        """
        label, value = list(labels.items())[0]
        skl = self.skls.get(label, None)
        if skl is None and strict:
            raise KeyError(f"Unknown label: {label}.")
        elif skl is None and not strict:
            return QueryResult(self, list(self.index))

        skl_result = getattr(self.skls[label], op)(value)
        if not strict:
            match_index = skl_result.index if skl_result else []
            missing = self.missing(label)
            # Union as sets so this works whatever container type the
            # SortedKeyList result uses for its index.
            match_index = set(match_index) | set(missing.index)
        elif skl_result is None:
            match_index = []
        else:
            match_index = skl_result.index
        return QueryResult(self, match_index)

    def __getitem__(self, label):
        """Return a ``Slice`` for ``label``; unknown labels raise KeyError."""
        if label not in self.skls:
            raise KeyError(f"Unknown label: {label}")
        return Slice(self, label)

    def missing(self, label: str):
        """Return the value objects that do not define ``label``."""
        index = list(set(self.index) - set(self.skls[label].index))
        return QueryResult(self, index)

    def eq(self, strict=True, **labels):
        """
        Returns values that match the given label:

        .. code-block:: Python

            params.sel["my_param"].eq(my_label=5)
            params.sel["my_param"]["my_label"] == 5
        """
        return self._cmp("eq", strict, **labels)

    def ne(self, strict=True, **labels):
        """
        Returns values that do not match the given label:

        .. code-block:: Python

            params.sel["my_param"].ne(my_label=5)
            params.sel["my_param"]["my_label"] != 5
        """
        return self._cmp("ne", strict, **labels)

    def gt(self, strict=True, **labels):
        """
        Returns values that have label values greater than the label value:

        .. code-block:: Python

            params.sel["my_param"].gt(my_label=5)
            params.sel["my_param"]["my_label"] > 5
        """
        return self._cmp("gt", strict, **labels)

    def gte(self, strict=True, **labels):
        """
        Returns values that have label values greater than or equal to the
        label value:

        .. code-block:: Python

            params.sel["my_param"].gte(my_label=5)
            params.sel["my_param"]["my_label"] >= 5
        """
        return self._cmp("gte", strict, **labels)

    def lt(self, strict=True, **labels):
        """
        Returns values that have label values less than the label value:

        .. code-block:: Python

            params.sel["my_param"].lt(my_label=5)
            params.sel["my_param"]["my_label"] < 5
        """
        return self._cmp("lt", strict, **labels)

    def lte(self, strict=True, **labels):
        """
        Returns values that have label values less than or equal to the
        label value:

        .. code-block:: Python

            params.sel["my_param"].lte(my_label=5)
            params.sel["my_param"]["my_label"] <= 5
        """
        return self._cmp("lte", strict, **labels)

    def isin(self, strict=True, **labels):
        """
        Returns values whose label value equals any of the given values:

        .. code-block:: Python

            params.sel["my_param"].isin(my_label=[5, 6])
        """
        label, values = list(labels.items())[0]
        matches = [self.eq(strict=strict, **{label: value}) for value in values]
        if not matches:
            # Nothing to look for: return an empty result that is still
            # bound to this object instead of one bound to ``None``.
            return QueryResult(self, [])
        return union(matches)

    def add(
        self, values: List[ValueObject], index: List[Any] = None, inplace=False
    ):
        """
        Add value objects.

        ``index`` optionally supplies the index entry for each new value
        object (same length as ``values``); otherwise entries continue after
        the current maximum.  An entry that already exists replaces the
        stored value object.  With ``inplace`` this object is updated and
        ``None`` is returned; otherwise a new ``Values`` object is returned
        and this one is left unchanged.
        """
        if index is not None:
            assert len(index) == len(values)
            new_index = list(index)
        else:
            max_index = max(self.index) if self.index else 0
            new_index = [max_index + i + 1 for i in range(len(values))]

        new_values = {ix: value for ix, value in zip(new_index, values)}

        if inplace:
            if set(new_values).isdisjoint(self.values):
                self.update_skls(new_values)
                self.values.update(new_values)
                self.index += list(new_values)
            else:
                # Some entries replace existing ones.  Rebuild the lookup
                # structures so no stale or duplicate entries remain.
                self.values.update(new_values)
                self.skls = self.build_skls(self.values, self.keyfuncs)
                self.index = list(self.values)
        else:
            updated_values = dict(self.values)
            updated_values.update(new_values)
            return Values(
                list(updated_values.values()),
                keyfuncs=self.keyfuncs,
                skls=self.build_skls(updated_values, self.keyfuncs),
                # Derive the index from the merged mapping so it always
                # lines up with the values, even when entries were replaced.
                index=list(updated_values),
            )

    def delete(self, *index, inplace=False):
        """
        Delete the value objects stored under the given index entries.

        Calling it without any entry deletes everything.  With ``inplace``
        this object is updated and ``None`` is returned; otherwise a new
        ``Values`` object built from a deep copy is returned.  An unknown
        entry raises ``KeyError``.
        """
        if not index:
            index = list(self.index)

        if inplace:
            for ix in index:
                self.values.pop(ix)
                self.index.remove(ix)
            self.skls = self.build_skls(self.values, self.keyfuncs)
        else:
            new_index = list(self.index)
            new_values = copy.deepcopy(self.values)
            for ix in index:
                new_values.pop(ix)
                new_index.remove(ix)
            return Values(
                list(new_values.values()),
                keyfuncs=self.keyfuncs,
                index=new_index,
            )

    @property
    def cmp_attr(self):
        return self

    @property
    def isel(self):
        """
        Select values by their index:

        .. code-block:: Python

            params.sel["my_param"].isel[0]
            params.sel["my_param"].isel[:5]
        """
        return ValueItem(self, self.index)

    @property
    def labels(self):
        """Names of all labels that have a SortedKeyList."""
        return list(self.skls.keys())

    def __eq__(self, other):
        """
        Compare the stored value objects, in order, against another
        ``ValueBase`` or a plain list.  Other types raise ``TypeError``.
        """
        if isinstance(other, ValueBase):
            return list(self) == list(other)
        elif isinstance(other, list):
            return list(self) == other
        else:
            raise TypeError(f"Unable to compare Values against {type(other)}")

    def __and__(self, queryresult: "QueryResult"):
        """
        Combine queries with logical 'and':

        .. code-block:: Python

            my_param = params.sel["my_param"]
            (my_param["my_label"] == 5) & (my_param["oth_label"] == "hello")
        """
        res = set(self.index) & set(queryresult.index)
        return QueryResult(self, res)

    def __or__(self, queryresult: "QueryResult"):
        """
        Combine queries with logical 'or':

        .. code-block:: Python

            my_param = params.sel["my_param"]
            (my_param["my_label"] == 5) | (my_param["oth_label"] == "hello")
        """
        res = set(self.index) | set(queryresult.index)
        return QueryResult(self, res)

    def __iter__(self):
        yield from self.values.values()

    def __repr__(self):
        vo_repr = (
            ",\n ".join(str(dict(self.values[i])) for i in self.index) + ","
        )
        return f"Values([\n {vo_repr}\n])"
def union(
    queryresults: Union[List["ValueBase"], Generator["ValueBase", None, None]]
):
    """
    Combine all given query results with ``|``.

    Returns the combined result, or an empty ``QueryResult`` that is not
    bound to any Values object when ``queryresults`` yields nothing.
    """
    result = None
    for queryresult in queryresults:
        if result is None:
            result = queryresult
        else:
            result = result | queryresult
    # Test for "nothing was supplied" explicitly.  The truth value of a
    # result is its length, which would wrongly discard a valid result.
    if result is None:
        return QueryResult(None, [])
    return result
def intersection(
    queryresults: Union[List["ValueBase"], Generator["ValueBase", None, None]]
):
    """
    Combine all given query results with ``&``.

    Returns the combined result, or an empty ``QueryResult`` that is not
    bound to any Values object when ``queryresults`` yields nothing.
    """
    result = None
    for queryresult in queryresults:
        if result is None:
            result = queryresult
        else:
            result = result & queryresult
    # Test for "nothing was supplied" explicitly.  The truth value of a
    # result is its length, which would wrongly discard a valid result.
    if result is None:
        return QueryResult(None, [])
    return result