Custom Types#
Often, the behavior for a field needs to be customized to support a particular shape or validation method that ParamTools does not support out of the box. In this case, you may use the `register_custom_type` function to add your new `type` to the ParamTools type registry. Each `type` has a corresponding `field` that is used for serialization and deserialization. ParamTools will then use this `field` any time it is handling a `value`, `label`, or `member` that is of this `type`.
ParamTools is built on top of `marshmallow`, a general-purpose validation library. This means that you must implement a custom `marshmallow` field to go along with your new type. Please refer to the `marshmallow` docs if you have questions about the use of `marshmallow` in the examples below.
32 Bit Integer Example#
ParamTools’s default integer field uses NumPy’s `int64` type. This example shows you how to define an `int32` type and reference it in your `defaults`.
First, let’s define the `marshmallow` field class:
import marshmallow as ma
import numpy as np
class Int32(ma.fields.Field):
    """
    A custom type for np.int32.
    https://numpy.org/devdocs/reference/arrays.dtypes.html

    This first version performs no bounds check: whatever arrives is handed
    straight to ``np.int32``, so values outside the int32 range are not
    caught by this field.
    """

    # minor detail that makes this play nice with array_first
    np_type = np.int32

    def _deserialize(self, value, *args, **kwargs):
        """Cast value from JSON to NumPy Int32."""
        return np.int32(value)

    def _serialize(self, value, *args, **kwargs):
        """Convert np.int32 to basic, serializable Python int."""
        # .tolist() on a NumPy scalar yields the equivalent built-in Python value.
        return value.tolist()
Now, register the type and reference it in our `defaults` JSON/dict object:
import paramtools as pt
# add int32 type to the paramtools type registry
pt.register_custom_type("int32", Int32())


class Params(pt.Parameters):
    # One parameter whose "type" names the custom type registered above.
    defaults = {
        "small_int": {
            "title": "Small integer",
            "description": "Demonstrate how to define a custom type",
            "type": "int32",
            "value": 2,
        },
    }


# With array_first=True, `params.small_int` evaluates to the deserialized value.
params = Params(array_first=True)
print(f"value: {params.small_int}, type: {type(params.small_int)}")
value: 2, type: <class 'numpy.int32'>
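One problem with this is that we could run into some deserialization issues. Due to integer overflow, our deserialized result is not the number that we passed in—it’s negative! As the warning below points out, NumPy also plans to stop allowing this out-of-bounds conversion, so in the future it will fail outright instead of silently wrapping around.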
# 2147483647 is the largest value an int32 can hold, so this is one past it.
params.adjust({"small_int": 2147483647 + 1})
/tmp/ipykernel_2025/800111478.py:18: DeprecationWarning: NumPy will stop allowing conversion of out-of-bound Python integers to integer arrays. The conversion of 2147483648 to int32 will fail in the future.
For the old behavior, usually:
np.array(value).astype(dtype)`
will give the desired result (the cast overflows).
converted = np.int32(value)
OrderedDict([('small_int', [OrderedDict([('value', -2147483648)])])])
Marshmallow Validator#
Fortunately, you can specify a custom validator with `marshmallow` or ParamTools. Making this work requires modifying the `_deserialize` method to check for overflow like this:
class Int32(ma.fields.Field):
    """
    A custom type for np.int32.
    https://numpy.org/devdocs/reference/arrays.dtypes.html

    Inputs outside the int32 range are returned as plain Python ints rather
    than being cast, so that an attached range validator can report them.
    """

    # minor detail that makes this play nice with array_first
    np_type = np.int32

    def _serialize(self, value, *args, **kwargs):
        """Convert np.int32 to basic Python int."""
        return value.tolist()

    def _deserialize(self, value, *args, **kwargs):
        """Cast value from JSON to NumPy Int32.

        The bounds check is done *before* casting. Calling ``np.int32`` on an
        out-of-range Python int overflows with a DeprecationWarning on older
        NumPy and raises ``OverflowError`` on NumPy 2, which would prevent the
        range validator from ever seeing the value.
        """
        as_int = int(value)
        bounds = np.iinfo(np.int32)
        if not bounds.min <= as_int <= bounds.max:
            # Overflow: keep the exact value and let the range validator
            # display the error message.
            return as_int
        return np.int32(as_int)
Now, let’s see how to use `marshmallow` to fix this problem:
import marshmallow as ma
import numpy as np
import paramtools as pt

# Get the minimum and maximum values for 32-bit integers from NumPy
# rather than hard-coding them.
min_int32 = np.iinfo(np.int32).min  # -2147483648
max_int32 = np.iinfo(np.int32).max  # 2147483647

# add int32 type to the paramtools type registry, attaching a marshmallow
# Range validator so that out-of-bounds values are reported as errors.
pt.register_custom_type(
    "int32",
    Int32(validate=[
        ma.validate.Range(min=min_int32, max=max_int32)
    ]),
)


class Params(pt.Parameters):
    defaults = {
        "small_int": {
            "title": "Small integer",
            "description": "Demonstrate how to define a custom type",
            "type": "int32",
            "value": 2,
        }
    }


params = Params(array_first=True)

# One past the int32 maximum, so the Range validator rejects it.
params.adjust(dict(
    small_int=np.int64(max_int32) + 1
))
---------------------------------------------------------------------------
ValidationError Traceback (most recent call last)
Cell In[5], line 31
19 defaults = {
20 "small_int": {
21 "title": "Small integer",
(...)
25 }
26 }
29 params = Params(array_first=True)
---> 31 params.adjust(dict(
32 small_int=np.int64(max_int32) + 1
33 ))
File ~/work/ParamTools/ParamTools/paramtools/parameters.py:257, in Parameters.adjust(self, params_or_path, ignore_warnings, raise_errors, extend_adj, clobber)
210 def adjust(
211 self,
212 params_or_path: Union[str, Mapping[str, List[ValueObject]]],
(...)
216 clobber: bool = True,
217 ):
218 """
219 Deserialize and validate parameter adjustments. `params_or_path`
220 can be a file path or a `dict` that has not been fully deserialized.
(...)
255 least one existing value item's corresponding label values.
256 """
--> 257 return self._adjust(
258 params_or_path,
259 ignore_warnings=ignore_warnings,
260 raise_errors=raise_errors,
261 extend_adj=extend_adj,
262 clobber=clobber,
263 )
File ~/work/ParamTools/ParamTools/paramtools/parameters.py:375, in Parameters._adjust(self, params_or_path, ignore_warnings, raise_errors, extend_adj, deserialized, validate, clobber)
371 # throw error if raise_errors is True or ignore_warnings is False
372 if (raise_errors and has_errors) or (
373 not ignore_warnings and has_warnings
374 ):
--> 375 raise self.validation_error
377 # Update attrs for params that were adjusted.
378 self._set_state(params=parsed_params.keys())
ValidationError: {
"errors": {
"small_int": [
"Must be greater than or equal to -2147483648 and less than or equal to 2147483647."
]
}
}
ParamTools Validator#
Finally, we will use ParamTools to solve this problem. We need to modify how we create our custom `marshmallow` field so that it’s wrapped by ParamTools’s `PartialField`. This makes it clear that your field still needs to be initialized, and that your custom field is able to receive validation information from the `defaults` configuration:
import paramtools as pt
# add int32 type to the paramtools type registry. Int32 is passed as a class
# wrapped in PartialField (not as an instance), so the field is initialized
# later with the validators declared in `defaults`.
pt.register_custom_type("int32", pt.PartialField(Int32))


class Params(pt.Parameters):
    defaults = {
        "small_int": {
            "title": "Small integer",
            "description": "Demonstrate how to define a custom type",
            "type": "int32",
            "value": 2,
            # ParamTools' own range validator enforces the int32 bounds.
            "validators": {
                "range": {"min": -2147483648, "max": 2147483647},
            },
        },
    }


params = Params(array_first=True)
# One past the int32 maximum, so the range validator rejects it.
params.adjust({"small_int": 2147483647 + 1})
/tmp/ipykernel_2025/1243571737.py:15: DeprecationWarning: NumPy will stop allowing conversion of out-of-bound Python integers to integer arrays. The conversion of 2147483648 to int32 will fail in the future.
For the old behavior, usually:
np.array(value).astype(dtype)`
will give the desired result (the cast overflows).
converted = np.int32(value)
---------------------------------------------------------------------------
ValidationError Traceback (most recent call last)
Cell In[6], line 27
12 defaults = {
13 "small_int": {
14 "title": "Small integer",
(...)
21 }
22 }
25 params = Params(array_first=True)
---> 27 params.adjust(dict(
28 small_int=2147483647 + 1
29 ))
File ~/work/ParamTools/ParamTools/paramtools/parameters.py:257, in Parameters.adjust(self, params_or_path, ignore_warnings, raise_errors, extend_adj, clobber)
210 def adjust(
211 self,
212 params_or_path: Union[str, Mapping[str, List[ValueObject]]],
(...)
216 clobber: bool = True,
217 ):
218 """
219 Deserialize and validate parameter adjustments. `params_or_path`
220 can be a file path or a `dict` that has not been fully deserialized.
(...)
255 least one existing value item's corresponding label values.
256 """
--> 257 return self._adjust(
258 params_or_path,
259 ignore_warnings=ignore_warnings,
260 raise_errors=raise_errors,
261 extend_adj=extend_adj,
262 clobber=clobber,
263 )
File ~/work/ParamTools/ParamTools/paramtools/parameters.py:375, in Parameters._adjust(self, params_or_path, ignore_warnings, raise_errors, extend_adj, deserialized, validate, clobber)
371 # throw error if raise_errors is True or ignore_warnings is False
372 if (raise_errors and has_errors) or (
373 not ignore_warnings and has_warnings
374 ):
--> 375 raise self.validation_error
377 # Update attrs for params that were adjusted.
378 self._set_state(params=parsed_params.keys())
ValidationError: {
"errors": {
"small_int": [
"small_int 2147483648 > max 2147483647 "
]
}
}