Jax float64 precision issues do not play ball with hypothesis #368

@ev-br

A typical failure, from `test_diff`:

```
self = <hypothesis.extra.array_api.ArrayStrategy object at 0x7e6a6cf6c990>, val = 2.112233982580733, val_0d = Array(2.1122339, dtype=float32)
strategy = FloatStrategy(min_value=2.0, max_value=64.0, allow_nan=False, smallest_nonzero_magnitude=2.2250738585072014e-308)

    def check_set_value(self, val, val_0d, strategy):
        if val == val and self.builtin(val_0d) != val:
            if self.builtin is float:
                assert self.finfo is not None  # for mypy
                try:
                    is_subnormal = 0 < abs(val) < self.finfo.smallest_normal
                except Exception:
                    # val may be a non-float that does not support the
                    # operations __lt__ and __abs__
                    is_subnormal = False
                if is_subnormal:
                    raise InvalidArgument(
                        f"Generated subnormal float {val} from strategy "
                        f"{strategy} resulted in {val_0d!r}, probably "
                        f"as a result of array module {self.xp.__name__} "
                        "being built with flush-to-zero compiler options. "
                        "Consider passing allow_subnormal=False."
                    )
>           raise InvalidArgument(
                f"Generated array element {val!r} from strategy {strategy} "
                f"cannot be represented with dtype {self.dtype}. "
                f"Array module {self.xp.__name__} instead "
                f"represents the element as {val_0d}. "
                "Consider using a more precise elements strategy, "
                "for example passing the width argument to floats()."
            )
E           hypothesis.errors.InvalidArgument: Generated array element 2.112233982580733 from strategy FloatStrategy(min_value=2.0, max_value=64.0, allow_nan=False, smallest_nonzero_magnitude=2.2250738585072014e-308) cannot be represented with dtype <class 'jax.numpy.float64'>. Array module jax.numpy instead represents the element as 2.112233877182007. Consider using a more precise elements strategy, for example passing the width argument to floats().
E           while generating 'x' from sampled_from((<class 'jax.numpy.uint8'>, <class 'jax.numpy.int8'>, <class 'jax.numpy.int16'>, <class 'jax.numpy.int32'>, <class 'jax.numpy.float32'>, <class 'jax.numpy.float64'>, <class 'jax.numpy.complex64'>, <class 'jax.numpy.complex128'>)).flatmap(lambda d: arrays(d, *args, elements=elements, **kwargs))
E           Explanation:
E               These lines were always and only run by failing examples:
E                   /home/ev-br/.conda/envs/array-api/lib/python3.11/site-packages/jax/_src/array.py:328
E                   /home/ev-br/.conda/envs/array-api/lib/python3.11/site-packages/jax/_src/array.py:651
E                   /home/ev-br/.conda/envs/array-api/lib/python3.11/site-packages/numpy/_core/getlimits.py:609
```
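
The check that fails is `self.builtin(val_0d) != val`: the generated element is stored in a 0-d array, converted back to a Python `float`, and must come back unchanged. Note that `val_0d` is reported as `Array(2.1122339, dtype=float32)` even though the requested dtype is `jax.numpy.float64`. This is most likely JAX's default 32-bit mode, in which float64 requests are silently truncated to float32 unless 64-bit support is enabled, e.g. with `jax.config.update("jax_enable_x64", True)`. The sketch below reproduces that round trip with NumPy only, as a stand-in for the JAX array construction (an assumption, not what Hypothesis literally calls); the helper `roundtrips` is hypothetical and not part of Hypothesis.

```python
import numpy as np


def roundtrips(val: float, dtype) -> bool:
    """Hypothetical helper mimicking the check in check_set_value.

    np.asarray(val, dtype=dtype) stands in for the array module building a
    0-d array; float(...) plays the role of self.builtin.
    """
    val_0d = np.asarray(val, dtype=dtype)
    return float(val_0d) == val


# The element from the failing example above.
val = 2.112233982580733

# Assumption: with JAX's default 32-bit mode, a float64 request is stored
# as float32, so the effective storage dtype is float32.
truncated = float(np.asarray(val, dtype=np.float32))
print(truncated)                     # nearest float32 value, differs from val

print(roundtrips(val, np.float64))   # True: genuine float64 keeps the value
print(roundtrips(val, np.float32))   # False: the mismatch Hypothesis reports

# A value that is already float32-representable (the kind of value that
# floats(..., width=32) generates) survives the float32 round trip.
print(roundtrips(truncated, np.float32))  # True
```

Under this reading there are two ways out: enable 64-bit mode in JAX before the tests run, or restrict the float64 elements strategy to float32-representable values via the `width` argument to `floats()`, as the error message suggests. The latter avoids the error but no longer exercises genuine float64 precision.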
