Skip to content

Commit

Permalink
Merge pull request #25646 from jakevdp:int4-casting
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 708404847
  • Loading branch information
Google-ML-Automation committed Dec 20, 2024
2 parents 7ecc947 + beee98a commit 043c260
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 26 deletions.
13 changes: 13 additions & 0 deletions jax/_src/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,19 @@ def supports_inf(dtype: DTypeLike) -> bool:
'c': complex_,
}

def bit_width(dtype: DTypeLike) -> int:
"""Number of bits per element for the dtype."""
# Note: we cannot use dtype.itemsize here because this is
# incorrect for sub-byte integer types.
if dtype == bool:
return 8 # physical bit layout for boolean dtype
elif issubdtype(dtype, np.integer):
return iinfo(dtype).bits
elif issubdtype(dtype, np.inexact):
return finfo(dtype).bits
else:
raise ValueError("unexpected input: {dtype=}")

# Trivial vectorspace datatype needed for tangent values of int/bool primals
float0: np.dtype = np.dtype([('float0', np.void, 0)])

Expand Down
17 changes: 10 additions & 7 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -3367,18 +3367,21 @@ def _bitcast_convert_type_shape_rule(operand, *, new_dtype):
old_dtype = dtypes.canonicalize_dtype(operand.dtype)
new_dtype = dtypes.canonicalize_dtype(new_dtype)

if old_dtype.itemsize == new_dtype.itemsize:
old_nbits = dtypes.bit_width(old_dtype)
new_nbits = dtypes.bit_width(new_dtype)

if old_nbits == new_nbits:
return operand.shape
elif old_dtype.itemsize > new_dtype.itemsize:
return (*operand.shape, old_dtype.itemsize // new_dtype.itemsize)
elif old_nbits > new_nbits:
return (*operand.shape, old_nbits // new_nbits)
else:
dim_size = operand.shape[-1] if operand.shape else 1
if dim_size * old_dtype.itemsize != new_dtype.itemsize:
if dim_size * old_nbits != new_nbits:
raise ValueError(
f"Attempting to convert array of shape {operand.shape} "
f"from {old_dtype} of size {old_dtype.itemsize} "
f"to {new_dtype} of size {new_dtype.itemsize}, "
f"but {dim_size} * {old_dtype.itemsize} != {new_dtype.itemsize}")
f"from {old_dtype} of size {old_nbits} bits "
f"to {new_dtype} of size {new_nbits}, bits "
f"but {dim_size} * {old_nbits} != {new_nbits}")
return operand.shape[:-1]

def _bitcast_convert_type_dtype_rule(operand, *, new_dtype):
Expand Down
34 changes: 33 additions & 1 deletion jax/_src/lax_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,40 @@ def clz(x):
def convert_element_type(operand, dtype):
return np.asarray(operand, dtype=dtype)

def _bitcast_uint4_to_uint8(operand):
# Note: assumes little-endian byte order.
assert operand.dtype == 'uint4'
operand = operand.astype('uint8')
return operand[..., ::2] + (operand[..., 1::2] << 4)

def _bitcast_uint8_to_uint4(operand):
# Note: assumes little-endian byte order.
assert operand.dtype == 'uint8'
result = np.zeros((*operand.shape[:-1], operand.shape[-1] * 2), dtype='uint4')
result[..., ::2] = (operand & 0b00001111).astype('uint4')
result[..., 1::2] = ((operand & 0b11110000) >> 4).astype('uint4')
return result

def bitcast_convert_type(operand, dtype):
return np.asarray(operand).view(dtype)
operand = np.asarray(operand)
nbits_in = dtypes.bit_width(operand.dtype)
nbits_out = dtypes.bit_width(dtype)

if nbits_out > nbits_in:
assert operand.shape[-1] == nbits_out // nbits_in
out_shape = operand.shape[:-1]
elif nbits_out == nbits_in:
out_shape = operand.shape
else:
out_shape = (*operand.shape, nbits_in // nbits_out)

# Special handling for 4-bit integers.
if nbits_in == 4:
operand = _bitcast_uint4_to_uint8(operand.view('uint4'))
if nbits_out == 4:
operand = _bitcast_uint8_to_uint4(operand.view('uint8'))

return operand.view(dtype).reshape(out_shape)

def clamp(min, operand, max):
return np.clip(operand, np.clip(min, None, max), max).astype(operand.dtype)
Expand Down
42 changes: 24 additions & 18 deletions tests/lax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,43 +170,49 @@ def testConvertElementTypeAgainstNumpy(self, from_dtype, to_dtype):
self._CheckAgainstNumpy(numpy_op, op, args_maker)

@jtu.sample_product(
from_dtype=jtu.dtypes.all_floating + jtu.dtypes.all_integer + jtu.dtypes.all_unsigned,
to_dtype=jtu.dtypes.all_floating + jtu.dtypes.all_integer + jtu.dtypes.all_unsigned,
from_dtype=['int4', 'uint4'] + jtu.dtypes.all_floating + jtu.dtypes.all_integer + jtu.dtypes.all_unsigned,
to_dtype=['int4', 'uint4'] + jtu.dtypes.all_floating + jtu.dtypes.all_integer + jtu.dtypes.all_unsigned,
shape = [(), (2,), (2, 3)]
)
def testBitcastConvertType(self, from_dtype, to_dtype, shape):
rng = jtu.rand_default(self.rng())
itemsize_in = np.dtype(from_dtype).itemsize
itemsize_out = np.dtype(to_dtype).itemsize
if itemsize_in < itemsize_out:
shape = (*shape, itemsize_out // itemsize_in)
nbits_in = dtypes.bit_width(from_dtype)
nbits_out = dtypes.bit_width(to_dtype)
if nbits_in < nbits_out:
shape = (*shape, nbits_out // nbits_in)
args_maker = lambda: [rng(shape, from_dtype)]
op = lambda x: lax.bitcast_convert_type(x, to_dtype)
self._CompileAndCheck(op, args_maker)
jnp_op = lambda x: lax.bitcast_convert_type(x, to_dtype)
self._CompileAndCheck(jnp_op, args_maker)

# Test the shape and dtype of the output. We avoid testing the values here
# because the bitwise representation may vary from platform to platform.
out = op(*args_maker())
if itemsize_in == itemsize_out:
out = jnp_op(*args_maker())
if nbits_in == nbits_out:
expected_shape = shape
elif itemsize_in < itemsize_out:
elif nbits_in < nbits_out:
expected_shape = shape[:-1]
else:
expected_shape = (*shape, itemsize_in // itemsize_out)
expected_shape = (*shape, nbits_in // nbits_out)
self.assertEqual(out.dtype, to_dtype)
self.assertEqual(out.shape, expected_shape)

@jtu.sample_product(
[dict(from_dtype=from_dtype, to_dtype=to_dtype)
for from_dtype, to_dtype in itertools.product(
[np.float32, np.int32, "float32", "int32"], repeat=2)],
['int4', 'uint4', np.int8, np.uint8, np.int32, np.float16, np.float32],
repeat=2)],
shape=[(4,), (2, 4), (2, 3, 4)]
)
def testBitcastConvertTypeAgainstNumpy(self, from_dtype, to_dtype):
def testBitcastConvertTypeAgainstNumpy(self, from_dtype, to_dtype, shape):
nbits_in = dtypes.bit_width(from_dtype)
nbits_out = dtypes.bit_width(to_dtype)
if nbits_in < nbits_out:
shape = (*shape, nbits_out // nbits_in)
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng((2, 3), from_dtype)]
op = lambda x: lax.bitcast_convert_type(x, to_dtype)
numpy_op = lambda x: lax_reference.bitcast_convert_type(x, to_dtype)
self._CheckAgainstNumpy(numpy_op, op, args_maker)
args_maker = lambda: [rng(shape, from_dtype)]
jnp_op = lambda x: lax.bitcast_convert_type(x, to_dtype)
np_op = lambda x: lax_reference.bitcast_convert_type(x, to_dtype)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)

@jtu.sample_product(
[dict(from_dtype=from_dtype, to_dtype=to_dtype)
Expand Down

0 comments on commit 043c260

Please sign in to comment.