Custom Types

Often, the behavior for a field needs to be customized to support a particular shape or validation method that ParamTools does not support out of the box. In this case, you may use the register_custom_type function to add your new type to the ParamTools type registry. Each type has a corresponding field that is used for serialization and deserialization. ParamTools will then use this field any time it is handling a value, label, or member that is of this type.

ParamTools is built on top of marshmallow, a general-purpose validation library. This means that you must implement a custom marshmallow field to go along with your new type. Please refer to the marshmallow docs if you have questions about the use of marshmallow in the examples below.

32-Bit Integer Example

ParamTools’s default integer field uses NumPy’s int64 type. This example shows you how to define an int32 type and reference it in your defaults.

First, let’s define the marshmallow field class:

import marshmallow as ma
import numpy as np

class Int32(ma.fields.Field):
    """Marshmallow field whose values are NumPy ``int32`` scalars.

    See https://numpy.org/devdocs/reference/arrays.dtypes.html

    This first version does no range checking of its own: whatever
    ``np.int32`` does with an out-of-range value is what you get.
    """

    # Small detail needed for this field to work with ``array_first=True``.
    np_type = np.int32

    def _serialize(self, value, *args, **kwargs):
        """Turn an ``np.int32`` back into a plain, JSON-serializable int."""
        plain_int = value.tolist()
        return plain_int

    def _deserialize(self, value, *args, **kwargs):
        """Cast the raw JSON value to an ``np.int32`` scalar."""
        return np.int32(value)

Now, register the field with ParamTools and reference the new type in our defaults JSON/dict object:

import paramtools as pt


# Make the "int32" type name available to ParamTools by registering an
# instance of the Int32 field under it.
pt.register_custom_type("int32", Int32())


class Params(pt.Parameters):
    # Setting "type" to "int32" makes ParamTools use the Int32 field
    # registered above for this parameter's values.
    defaults = {
        "small_int": {
            "title": "Small integer",
            "description": "Demonstrate how to define a custom type",
            "type": "int32",
            "value": 2,
        },
    }


params = Params(array_first=True)


print(f"value: {params.small_int}, type: {type(params.small_int)}")
value: 2, type: <class 'numpy.int32'>

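One problem with this approach is that out-of-range values are not caught during deserialization. Due to integer overflow, the deserialized result is not the number that we passed in — it’s negative! (As the DeprecationWarning below foreshadows, NumPy 2.0 and later raise an OverflowError for this conversion instead of wrapping around.)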

# 2**31 is exactly one more than the largest value an int32 can hold.
params.adjust({"small_int": 2**31})
/tmp/ipykernel_2025/800111478.py:18: DeprecationWarning: NumPy will stop allowing conversion of out-of-bound Python integers to integer arrays.  The conversion of 2147483648 to int32 will fail in the future.
For the old behavior, usually:
    np.array(value).astype(dtype)`
will give the desired result (the cast overflows).
  converted = np.int32(value)
OrderedDict([('small_int', [OrderedDict([('value', -2147483648)])])])

Marshmallow Validator

Fortunately, you can specify a custom validator with marshmallow or ParamTools. Making this work requires modifying the _deserialize method to check for overflow, like this:

class Int32(ma.fields.Field):
    """
    A custom type for np.int32.
    https://numpy.org/devdocs/reference/arrays.dtypes.html

    Values that do not fit in 32 bits are passed through as plain Python
    ints instead of being cast, so that a range validator can reject them
    with a clear error message.
    """
    # minor detail that makes this play nice with array_first
    np_type = np.int32

    def _serialize(self, value, *args, **kwargs):
        """Convert np.int32 to basic Python int."""
        return value.tolist()

    def _deserialize(self, value, *args, **kwargs):
        """Cast value from JSON to NumPy Int32.

        Returns an ``np.int32`` when ``value`` fits in the int32 range,
        otherwise the plain Python ``int`` so the range validator can
        display the error message.
        """
        as_int = int(value)
        bounds = np.iinfo(np.int32)

        # Check the bounds *before* casting. Calling np.int32() on an
        # out-of-range Python int either wraps around with a
        # DeprecationWarning (NumPy 1.24+) or raises OverflowError
        # (NumPy 2.0+), so "cast, then compare" is not reliable.
        if not bounds.min <= as_int <= bounds.max:
            return as_int

        return np.int32(as_int)

Now, let’s see how to use marshmallow to fix this problem:

import marshmallow as ma
import numpy as np
import paramtools as pt


# get the minimum and maximum values for 32 bit integers
# (-2147483648 and 2147483647) directly from NumPy.
min_int32 = np.iinfo(np.int32).min
max_int32 = np.iinfo(np.int32).max

# add int32 type to the paramtools type registry; the marshmallow Range
# validator rejects anything outside the int32 bounds.
pt.register_custom_type(
    "int32",
    Int32(validate=[
        ma.validate.Range(min=min_int32, max=max_int32)
    ])
)


class Params(pt.Parameters):
    defaults = {
        "small_int": {
            "title": "Small integer",
            "description": "Demonstrate how to define a custom type",
            "type": "int32",
            "value": 2
        }
    }


params = Params(array_first=True)

# Passing an np.int64 (rather than a plain Python int) avoids the NumPy
# out-of-bound Python integer DeprecationWarning shown earlier.
params.adjust(dict(
    small_int=np.int64(max_int32) + 1
))
---------------------------------------------------------------------------
ValidationError                           Traceback (most recent call last)
Cell In[5], line 31
     19     defaults = {
     20         "small_int": {
     21             "title": "Small integer",
   (...)
     25         }
     26     }
     29 params = Params(array_first=True)
---> 31 params.adjust(dict(
     32     small_int=np.int64(max_int32) + 1
     33 ))

File ~/work/ParamTools/ParamTools/paramtools/parameters.py:257, in Parameters.adjust(self, params_or_path, ignore_warnings, raise_errors, extend_adj, clobber)
    210 def adjust(
    211     self,
    212     params_or_path: Union[str, Mapping[str, List[ValueObject]]],
   (...)
    216     clobber: bool = True,
    217 ):
    218     """
    219     Deserialize and validate parameter adjustments. `params_or_path`
    220     can be a file path or a `dict` that has not been fully deserialized.
   (...)
    255         least one existing value item's corresponding label values.
    256     """
--> 257     return self._adjust(
    258         params_or_path,
    259         ignore_warnings=ignore_warnings,
    260         raise_errors=raise_errors,
    261         extend_adj=extend_adj,
    262         clobber=clobber,
    263     )

File ~/work/ParamTools/ParamTools/paramtools/parameters.py:375, in Parameters._adjust(self, params_or_path, ignore_warnings, raise_errors, extend_adj, deserialized, validate, clobber)
    371 # throw error if raise_errors is True or ignore_warnings is False
    372 if (raise_errors and has_errors) or (
    373     not ignore_warnings and has_warnings
    374 ):
--> 375     raise self.validation_error
    377 # Update attrs for params that were adjusted.
    378 self._set_state(params=parsed_params.keys())

ValidationError: {
    "errors": {
        "small_int": [
            "Must be greater than or equal to -2147483648 and less than or equal to 2147483647."
        ]
    }
}

ParamTools Validator

Finally, we will use ParamTools to solve this problem. We need to modify how we create our custom marshmallow field so that it’s wrapped by ParamTools’s PartialField. This makes it clear that your field still needs to be initialized, and that your custom field is able to receive validation information from the defaults configuration:

import paramtools as pt


# Register Int32 wrapped in a PartialField: the field is not initialized
# yet, and it can receive validation information from the defaults below.
pt.register_custom_type("int32", pt.PartialField(Int32))


class Params(pt.Parameters):
    defaults = {
        "small_int": {
            "title": "Small integer",
            "description": "Demonstrate how to define a custom type",
            "type": "int32",
            "value": 2,
            # ParamTools-level range check using the int32 bounds.
            "validators": {
                "range": {"min": -2147483648, "max": 2147483647},
            },
        },
    }


params = Params(array_first=True)

# 2**31 is one more than the int32 maximum, so the range validator rejects it.
params.adjust({"small_int": 2**31})
/tmp/ipykernel_2025/1243571737.py:15: DeprecationWarning: NumPy will stop allowing conversion of out-of-bound Python integers to integer arrays.  The conversion of 2147483648 to int32 will fail in the future.
For the old behavior, usually:
    np.array(value).astype(dtype)`
will give the desired result (the cast overflows).
  converted = np.int32(value)
---------------------------------------------------------------------------
ValidationError                           Traceback (most recent call last)
Cell In[6], line 27
     12     defaults = {
     13         "small_int": {
     14             "title": "Small integer",
   (...)
     21         }
     22     }
     25 params = Params(array_first=True)
---> 27 params.adjust(dict(
     28     small_int=2147483647 + 1
     29 ))

File ~/work/ParamTools/ParamTools/paramtools/parameters.py:257, in Parameters.adjust(self, params_or_path, ignore_warnings, raise_errors, extend_adj, clobber)
    210 def adjust(
    211     self,
    212     params_or_path: Union[str, Mapping[str, List[ValueObject]]],
   (...)
    216     clobber: bool = True,
    217 ):
    218     """
    219     Deserialize and validate parameter adjustments. `params_or_path`
    220     can be a file path or a `dict` that has not been fully deserialized.
   (...)
    255         least one existing value item's corresponding label values.
    256     """
--> 257     return self._adjust(
    258         params_or_path,
    259         ignore_warnings=ignore_warnings,
    260         raise_errors=raise_errors,
    261         extend_adj=extend_adj,
    262         clobber=clobber,
    263     )

File ~/work/ParamTools/ParamTools/paramtools/parameters.py:375, in Parameters._adjust(self, params_or_path, ignore_warnings, raise_errors, extend_adj, deserialized, validate, clobber)
    371 # throw error if raise_errors is True or ignore_warnings is False
    372 if (raise_errors and has_errors) or (
    373     not ignore_warnings and has_warnings
    374 ):
--> 375     raise self.validation_error
    377 # Update attrs for params that were adjusted.
    378 self._set_state(params=parsed_params.keys())

ValidationError: {
    "errors": {
        "small_int": [
            "small_int 2147483648 > max 2147483647 "
        ]
    }
}