From beee98ab4a28c9224b7e40a7877dd822763e2450 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 20 Dec 2024 12:45:24 -0800 Subject: [PATCH] Add int4/uint4 support to bitcast_convert_type --- jax/_src/dtypes.py | 13 ++++++++++++ jax/_src/lax/lax.py | 17 +++++++++------- jax/_src/lax_reference.py | 34 ++++++++++++++++++++++++++++++- tests/lax_test.py | 42 ++++++++++++++++++++++----------------- 4 files changed, 80 insertions(+), 26 deletions(-) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 95ef5af34e0c..52cb3d87bbda 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -205,6 +205,19 @@ def supports_inf(dtype: DTypeLike) -> bool: 'c': complex_, } +def bit_width(dtype: DTypeLike) -> int: + """Number of bits per element for the dtype.""" + # Note: we cannot use dtype.itemsize here because this is + # incorrect for sub-byte integer types. + if dtype == bool: + return 8 # physical bit layout for boolean dtype + elif issubdtype(dtype, np.integer): + return iinfo(dtype).bits + elif issubdtype(dtype, np.inexact): + return finfo(dtype).bits + else: + raise ValueError("unexpected input: {dtype=}") + # Trivial vectorspace datatype needed for tangent values of int/bool primals float0: np.dtype = np.dtype([('float0', np.void, 0)]) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 06a612ff4c93..a6f74c40095f 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -3367,18 +3367,21 @@ def _bitcast_convert_type_shape_rule(operand, *, new_dtype): old_dtype = dtypes.canonicalize_dtype(operand.dtype) new_dtype = dtypes.canonicalize_dtype(new_dtype) - if old_dtype.itemsize == new_dtype.itemsize: + old_nbits = dtypes.bit_width(old_dtype) + new_nbits = dtypes.bit_width(new_dtype) + + if old_nbits == new_nbits: return operand.shape - elif old_dtype.itemsize > new_dtype.itemsize: - return (*operand.shape, old_dtype.itemsize // new_dtype.itemsize) + elif old_nbits > new_nbits: + return (*operand.shape, old_nbits // new_nbits) else: dim_size = operand.shape[-1] if operand.shape else 1 - if dim_size * old_dtype.itemsize != new_dtype.itemsize: + if dim_size * old_nbits != new_nbits: raise ValueError( f"Attempting to convert array of shape {operand.shape} " - f"from {old_dtype} of size {old_dtype.itemsize} " - f"to {new_dtype} of size {new_dtype.itemsize}, " - f"but {dim_size} * {old_dtype.itemsize} != {new_dtype.itemsize}") + f"from {old_dtype} of size {old_nbits} bits " + f"to {new_dtype} of size {new_nbits}, bits " + f"but {dim_size} * {old_nbits} != {new_nbits}") return operand.shape[:-1] def _bitcast_convert_type_dtype_rule(operand, *, new_dtype): diff --git a/jax/_src/lax_reference.py b/jax/_src/lax_reference.py index 81209f6ed34a..4d4c24b0500e 100644 --- a/jax/_src/lax_reference.py +++ b/jax/_src/lax_reference.py @@ -173,8 +173,40 @@ def clz(x): def convert_element_type(operand, dtype): return np.asarray(operand, dtype=dtype) +def _bitcast_uint4_to_uint8(operand): + # Note: assumes little-endian byte order. + assert operand.dtype == 'uint4' + operand = operand.astype('uint8') + return operand[..., ::2] + (operand[..., 1::2] << 4) + +def _bitcast_uint8_to_uint4(operand): + # Note: assumes little-endian byte order. + assert operand.dtype == 'uint8' + result = np.zeros((*operand.shape[:-1], operand.shape[-1] * 2), dtype='uint4') + result[..., ::2] = (operand & 0b00001111).astype('uint4') + result[..., 1::2] = ((operand & 0b11110000) >> 4).astype('uint4') + return result + def bitcast_convert_type(operand, dtype): - return np.asarray(operand).view(dtype) + operand = np.asarray(operand) + nbits_in = dtypes.bit_width(operand.dtype) + nbits_out = dtypes.bit_width(dtype) + + if nbits_out > nbits_in: + assert operand.shape[-1] == nbits_out // nbits_in + out_shape = operand.shape[:-1] + elif nbits_out == nbits_in: + out_shape = operand.shape + else: + out_shape = (*operand.shape, nbits_in // nbits_out) + + # Special handling for 4-bit integers. + if nbits_in == 4: + operand = _bitcast_uint4_to_uint8(operand.view('uint4')) + if nbits_out == 4: + operand = _bitcast_uint8_to_uint4(operand.view('uint8')) + + return operand.view(dtype).reshape(out_shape) def clamp(min, operand, max): return np.clip(operand, np.clip(min, None, max), max).astype(operand.dtype) diff --git a/tests/lax_test.py b/tests/lax_test.py index 3b71a8c3698b..88e712503512 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -170,43 +170,49 @@ def testConvertElementTypeAgainstNumpy(self, from_dtype, to_dtype): self._CheckAgainstNumpy(numpy_op, op, args_maker) @jtu.sample_product( - from_dtype=jtu.dtypes.all_floating + jtu.dtypes.all_integer + jtu.dtypes.all_unsigned, - to_dtype=jtu.dtypes.all_floating + jtu.dtypes.all_integer + jtu.dtypes.all_unsigned, + from_dtype=['int4', 'uint4'] + jtu.dtypes.all_floating + jtu.dtypes.all_integer + jtu.dtypes.all_unsigned, + to_dtype=['int4', 'uint4'] + jtu.dtypes.all_floating + jtu.dtypes.all_integer + jtu.dtypes.all_unsigned, shape = [(), (2,), (2, 3)] ) def testBitcastConvertType(self, from_dtype, to_dtype, shape): rng = jtu.rand_default(self.rng()) - itemsize_in = np.dtype(from_dtype).itemsize - itemsize_out = np.dtype(to_dtype).itemsize - if itemsize_in < itemsize_out: - shape = (*shape, itemsize_out // itemsize_in) + nbits_in = dtypes.bit_width(from_dtype) + nbits_out = dtypes.bit_width(to_dtype) + if nbits_in < nbits_out: + shape = (*shape, nbits_out // nbits_in) args_maker = lambda: [rng(shape, from_dtype)] - op = lambda x: lax.bitcast_convert_type(x, to_dtype) - self._CompileAndCheck(op, args_maker) + jnp_op = lambda x: lax.bitcast_convert_type(x, to_dtype) + self._CompileAndCheck(jnp_op, args_maker) # Test the shape and dtype of the output. We avoid testing the values here # because the bitwise representation may vary from platform to platform. - out = op(*args_maker()) - if itemsize_in == itemsize_out: + out = jnp_op(*args_maker()) + if nbits_in == nbits_out: expected_shape = shape - elif itemsize_in < itemsize_out: + elif nbits_in < nbits_out: expected_shape = shape[:-1] else: - expected_shape = (*shape, itemsize_in // itemsize_out) + expected_shape = (*shape, nbits_in // nbits_out) self.assertEqual(out.dtype, to_dtype) self.assertEqual(out.shape, expected_shape) @jtu.sample_product( [dict(from_dtype=from_dtype, to_dtype=to_dtype) for from_dtype, to_dtype in itertools.product( - [np.float32, np.int32, "float32", "int32"], repeat=2)], + ['int4', 'uint4', np.int8, np.uint8, np.int32, np.float16, np.float32], + repeat=2)], + shape=[(4,), (2, 4), (2, 3, 4)] ) - def testBitcastConvertTypeAgainstNumpy(self, from_dtype, to_dtype): + def testBitcastConvertTypeAgainstNumpy(self, from_dtype, to_dtype, shape): + nbits_in = dtypes.bit_width(from_dtype) + nbits_out = dtypes.bit_width(to_dtype) + if nbits_in < nbits_out: + shape = (*shape, nbits_out // nbits_in) rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng((2, 3), from_dtype)] - op = lambda x: lax.bitcast_convert_type(x, to_dtype) - numpy_op = lambda x: lax_reference.bitcast_convert_type(x, to_dtype) - self._CheckAgainstNumpy(numpy_op, op, args_maker) + args_maker = lambda: [rng(shape, from_dtype)] + jnp_op = lambda x: lax.bitcast_convert_type(x, to_dtype) + np_op = lambda x: lax_reference.bitcast_convert_type(x, to_dtype) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) @jtu.sample_product( [dict(from_dtype=from_dtype, to_dtype=to_dtype)