x-algorithm/phoenix/grok.py
2026-01-20 02:31:49 +00:00

587 lines
18 KiB
Python

# Copyright 2026 X.AI Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from dataclasses import dataclass
from typing import NamedTuple, Optional, Sequence, Union
import haiku as hk
import jax
import jax.numpy as jnp
logger = logging.getLogger(__name__)
class TrainingState(NamedTuple):
    """Immutable bundle holding the state carried through training.

    At present the only member is the Haiku parameter tree.
    """

    # Haiku parameter tree of the model.
    params: hk.Params
def ffn_size(emb_size, widening_factor):
    """Return the hidden width used by the feed-forward block.

    The width is two thirds of ``int(widening_factor * emb_size)`` (integer
    division), rounded up to the next multiple of 8.

    Args:
        emb_size: Embedding (model) dimension.
        widening_factor: Multiplier applied to ``emb_size`` before the 2/3 scaling.

    Returns:
        The adjusted feed-forward size, always a multiple of 8.
    """
    _ffn_size = int(widening_factor * emb_size) * 2 // 3
    # ``-n % 8`` is the distance from n up to the next multiple of 8 (0 when
    # n is already aligned), so this rounds up without branching.
    _ffn_size += -_ffn_size % 8
    # Lazy %-style arguments: the message is only built if DEBUG is enabled.
    logger.debug("emb_size: %s adjusted ffn_size: %s", emb_size, _ffn_size)
    return _ffn_size
def make_recsys_attn_mask(
    seq_len: int,
    candidate_start_offset: int,
    dtype: jnp.dtype = jnp.float32,
) -> jax.Array:
    """Build the attention mask used when scoring recommendation candidates.

    The sequence is laid out as ``[user + history | candidates]``:

    - Rows before ``candidate_start_offset`` (user + history) follow the
      ordinary causal pattern.
    - Rows from ``candidate_start_offset`` onwards (candidates) may look at
      every user/history position and at themselves, but never at another
      candidate, so each candidate is scored independently of the others.

    Args:
        seq_len: Total sequence length (user + history + candidates).
        candidate_start_offset: Index of the first candidate position.
        dtype: Data type of the returned mask.

    Returns:
        Array of shape ``[1, 1, seq_len, seq_len]`` in which 1 means
        "may attend" and 0 means "blocked".
    """
    # Lower-triangular (causal) pattern over the complete sequence.
    all_ones = jnp.ones((1, 1, seq_len, seq_len), dtype=dtype)
    mask = jnp.tril(all_ones)

    # Block every candidate -> candidate pair (the bottom-right square) ...
    candidates = slice(candidate_start_offset, None)
    mask = mask.at[:, :, candidates, candidates].set(0)

    # ... then re-enable the diagonal so each candidate still sees itself.
    diagonal = jnp.arange(candidate_start_offset, seq_len)
    return mask.at[:, :, diagonal, diagonal].set(1)
class MHAOutput(NamedTuple):
    """Result returned by the multi-head attention modules."""

    # Embeddings produced by the attention operation.
    embeddings: jax.Array
class DecoderOutput(NamedTuple):
    """Result returned by a single decoder layer."""

    # Output embeddings of the layer.
    embeddings: jax.Array
class TransformerOutput(NamedTuple):
    """Result returned by the full transformer."""

    # Output embeddings after the last decoder layer.
    embeddings: jax.Array
@dataclass
class TransformerConfig:
    """Hyperparameters of the transformer plus a factory for the module.

    Attributes:
        emb_size: Embedding dimension (not forwarded by ``make``).
        key_size: Per-head key/query dimension.
        num_q_heads: Number of query heads.
        num_kv_heads: Number of key/value heads.
        num_layers: Number of decoder layers.
        widening_factor: Feed-forward widening multiplier.
        attn_output_multiplier: Scalar applied to the attention logits.
        name: Optional name (not forwarded by ``make``).
    """

    emb_size: int
    key_size: int
    num_q_heads: int
    num_kv_heads: int
    num_layers: int
    widening_factor: float = 4.0
    attn_output_multiplier: float = 1.0
    name: Optional[str] = None

    def make(self) -> "Transformer":
        """Instantiate the ``Transformer`` module described by this config."""
        # Only these fields are handed to the module; ``emb_size`` and
        # ``name`` stay on the config.
        forwarded = (
            "num_q_heads",
            "num_kv_heads",
            "widening_factor",
            "key_size",
            "attn_output_multiplier",
            "num_layers",
        )
        return Transformer(**{field: getattr(self, field) for field in forwarded})
def hk_rms_norm(
    x: jax.Array,
    fixed_scale=False,
) -> jax.Array:
    """Normalise ``x`` with a newly constructed ``RMSNorm`` module.

    Every call builds its own ``RMSNorm`` instance (``axis=-1``).

    Args:
        x: Array to normalise.
        fixed_scale: When true, the norm is built with ``create_scale=False``
            so no learnable scale parameter is created.

    Returns:
        The normalised array.
    """
    return RMSNorm(axis=-1, create_scale=not fixed_scale)(x)
class Linear(hk.Linear):
    """Linear layer that stores float32 parameters and computes in the input dtype.

    Weights and bias are registered as float32 Haiku parameters initialised to
    zero, and are cast to the dtype of the incoming activations before use.
    """

    def __init__(
        self,
        output_size: int,
        with_bias: bool = True,
        name: Optional[str] = None,
    ):
        """Create the layer.

        Args:
            output_size: Size of the last axis of the output.
            with_bias: Whether a bias vector is added after the matmul.
            name: Optional Haiku module name.
        """
        super().__init__(output_size=output_size, with_bias=with_bias, name=name)

    def __call__(  # type: ignore
        self,
        inputs: jax.Array,
    ) -> jax.Array:
        """Apply the linear map to the last axis of ``inputs``.

        Raises:
            ValueError: If ``inputs`` is a scalar (has an empty shape).
        """
        compute_dtype = inputs.dtype
        if not inputs.shape:
            raise ValueError("Input must not be scalar.")

        in_features = inputs.shape[-1]
        weight = hk.get_parameter(
            "w",
            [in_features, self.output_size],
            jnp.float32,
            init=hk.initializers.Constant(0),
        )
        # Parameters live in float32; the matmul runs in the activation dtype.
        result = jnp.dot(inputs, weight.astype(compute_dtype))

        if not self.with_bias:
            return result

        bias = hk.get_parameter(
            "b",
            [self.output_size],
            jnp.float32,
            init=hk.initializers.Constant(0),
        )
        bias = jnp.broadcast_to(bias, result.shape)
        return result + bias.astype(compute_dtype)
class RMSNorm(hk.RMSNorm):
    """RMS normalisation computed in float32 and returned in the input dtype.

    Note that ``__call__`` always reduces over the last axis; the ``axis``
    argument is only passed on to the parent constructor.
    """

    def __init__(
        self,
        axis: Union[int, Sequence[int], slice],
        eps: float = 1e-5,
        name: Optional[str] = None,
        create_scale: bool = True,
    ):
        """Create the module.

        Args:
            axis: Axis specification handed to ``hk.RMSNorm``.
            eps: Constant added to the mean square before the reciprocal sqrt.
            name: Optional Haiku module name.
            create_scale: Whether a learnable ``scale`` parameter is created.
        """
        super().__init__(axis, eps, create_scale=create_scale, name=name)

    def __call__(self, inputs: jax.Array):
        """Return ``inputs`` divided by its RMS over the last axis, times the scale."""
        out_dtype = inputs.dtype

        if self.create_scale:
            # One float32 gain per feature, zero-initialised.
            gain = hk.get_parameter(
                "scale",
                (inputs.shape[-1],),
                dtype=jnp.float32,
                init=hk.initializers.Constant(0),
            )
            gain = jnp.broadcast_to(gain.astype(jnp.float32), inputs.shape)
        else:
            gain = 1.0

        # All arithmetic is done in float32 regardless of the input dtype.
        x32 = inputs.astype(jnp.float32)
        gain = jnp.float32(gain)

        mean_sq = jnp.mean(jnp.square(x32), axis=[-1], keepdims=True)
        mean_sq = jnp.broadcast_to(mean_sq, x32.shape)
        normalised = x32 * jax.lax.rsqrt(mean_sq + self.eps)

        return (gain * normalised).astype(out_dtype)
def rotate_half(
    x: jax.Array,
) -> jax.Array:
    """Return the rotated partner of ``x`` along its last axis.

    The last axis is split into two equal halves ``(a, b)`` and the result is
    ``(-b, a)`` concatenated along that same axis.
    """
    front, back = jnp.split(x, 2, axis=-1)
    pieces = (jnp.negative(back), front)
    return jnp.concatenate(pieces, axis=-1)
class RotaryEmbedding(hk.Module):
    """Applies rotary embeddings (RoPE) to the input sequence tensor,
    as described in https://arxiv.org/abs/2104.09864.

    Attributes:
        dim (int): Dimensionality of the feature vectors (must be even)
        base_exponent (int): Base exponent to compute embeddings from
    """

    def __init__(
        self,
        dim: int,
        name: Optional[str] = None,
        base_exponent: int = 10000,
    ):
        super().__init__(name)
        self.dim = dim
        self.base_exponent = base_exponent
        # Features are rotated in pairs, so the dimension must be even.
        assert self.dim % 2 == 0

    def __call__(
        self,
        x: jax.Array,
        seq_dim: int,
        offset: jax.Array,
        const_position: Optional[int] = None,
        t: Optional[jax.Array] = None,
    ) -> jax.Array:
        """Rotate ``x`` by position-dependent phases.

        Args:
            x: Input tensor; its last axis must have size ``dim``.
                NOTE(review): the phase is broadcast as ``[B, T, 1, dim]``, so
                ``x`` is presumably ``[B, T, heads, dim]`` with ``seq_dim == 1``
                -- confirm against callers.
            seq_dim: Axis of ``x`` that holds the sequence positions.
            offset: Scalar, or one value per batch element, added to the
                positions ``0..T-1`` when positions are generated here.
            const_position: If not ``None``, every position uses this single
                value (including ``0``) and ``offset``/``t`` are ignored.
            t: Optional explicit positions, used when ``const_position`` is
                ``None``. Consumed as a 2-D ``[B, T]`` array by the einsum.

        Returns:
            The rotated tensor, cast back to the dtype of ``x``.
        """
        fprop_dtype = x.dtype

        # Compute the per-dimension frequencies
        exponents = jnp.arange(0, self.dim, 2, dtype=jnp.float32)
        inv_freq = jnp.asarray(
            1.0 / (self.base_exponent ** (exponents / self.dim)), dtype=jnp.float32
        )

        if jnp.shape(offset) == ():
            # Offset can be a scalar or one offset per batch element.
            offset = jnp.expand_dims(offset, 0)

        # Compute the per element phase (to pass into sin and cos)
        # Compare against None explicitly: a truthiness test would silently
        # ignore the perfectly valid constant position 0.
        if const_position is not None:
            t = const_position * jnp.ones(
                (
                    1,
                    x.shape[seq_dim],
                ),
                dtype=jnp.float32,
            )
        elif t is None:
            t = jnp.arange(x.shape[seq_dim], dtype=jnp.float32) + jnp.expand_dims(offset, -1)

        # [B, T] x [dim/2] -> [B, T, dim/2]
        phase = jnp.einsum("bi,j->bij", t, inv_freq)
        # Duplicate along the feature axis to match rotate_half's (first half,
        # second half) layout, then add a broadcast axis before the features.
        phase = jnp.tile(phase, reps=(1, 2))[:, :, None, :]

        x = x * jnp.cos(phase) + rotate_half(x) * jnp.sin(phase)
        x = x.astype(fprop_dtype)
        return x
class MultiHeadAttention(hk.Module):
    """Grouped-query multi-head attention with rotary position embeddings.

    Queries use ``num_q_heads`` heads while keys/values use ``num_kv_heads``
    heads; each key/value head is shared by ``num_q_heads // num_kv_heads``
    query heads. Attention logits are computed in float32, scaled by
    ``attn_output_multiplier`` and soft-capped to (-30, 30) with ``tanh``.
    """

    def __init__(
        self,
        num_q_heads: int,
        num_kv_heads: int,
        key_size: int,
        *,
        with_bias: bool = True,
        value_size: Optional[int] = None,
        model_size: Optional[int] = None,
        attn_output_multiplier: float = 1.0,
        name: Optional[str] = None,
    ):
        """Create the attention module.

        Args:
            num_q_heads: Number of query heads.
            num_kv_heads: Number of key/value heads; must divide ``num_q_heads``.
            key_size: Per-head size of queries and keys.
            with_bias: Stored on the module (the projections in this class are
                all created with ``with_bias=False``).
            value_size: Per-head size of values; defaults to ``key_size``.
            model_size: Output size; defaults to ``key_size * num_q_heads``.
            attn_output_multiplier: Scalar applied to the raw attention logits.
            name: Optional Haiku module name.
        """
        super().__init__(name=name)
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.key_size = key_size
        self.value_size = value_size or key_size
        self.model_size = model_size or key_size * num_q_heads
        self.attn_output_multiplier = attn_output_multiplier
        self.with_bias = with_bias

    def __call__(
        self,
        query: jax.Array,
        key: jax.Array,
        value: jax.Array,
        mask: Optional[jax.Array],
    ) -> MHAOutput:
        """Run attention of ``query`` over ``key``/``value``.

        Args:
            query: Query inputs, ``[B, T', D]``.
            key: Key inputs, ``[B, T, D]``.
            value: Value inputs, ``[B, T, D]``.
            mask: 4-D mask broadcastable to ``[B, 1, T', T]`` where non-zero
                means "may attend", or ``None`` to attend everywhere.

        Returns:
            ``MHAOutput`` whose embeddings have shape ``[B, T', model_size]``.

        Raises:
            ValueError: If the expanded mask rank does not match the logits rank.
        """
        # In shape hints below, we suppress the leading dims [...] for brevity.
        # Hence e.g. [A, B] should be read in every case as [..., A, B].
        projection = self._linear_projection

        # Check that the keys and values have consistent batch size and sequence length.
        assert key.shape[:2] == value.shape[:2], f"key/value shape: {key.shape}/{value.shape}"

        if mask is not None:
            assert mask.ndim == 4
            assert mask.shape[0] in {
                1,
                query.shape[0],
            }, f"mask/query shape: {mask.shape}/{query.shape}"
            assert key.shape[0] in {
                1,
                query.shape[0],
            }, f"key/query shape: {key.shape}/{query.shape}"
            assert mask.shape[1] == 1
            assert mask.shape[2] in {
                1,
                query.shape[1],
            }, f"mask/query shape: {mask.shape}/{query.shape}"
            assert mask.shape[3] in {
                1,
                key.shape[1],
            }, f"mask/key shape: {mask.shape}/{key.shape}"

        # Compute key/query/values (overload K/Q/V to denote the respective sizes).
        assert self.num_q_heads % self.num_kv_heads == 0
        query_heads = projection(query, self.key_size, self.num_q_heads, name="query")
        key_heads = projection(key, self.key_size, self.num_kv_heads, name="key")
        value_heads = projection(value, self.value_size, self.num_kv_heads, name="value")

        rotate = RotaryEmbedding(dim=self.key_size, base_exponent=int(1e4))
        key_heads = rotate(key_heads, seq_dim=1, offset=0)
        query_heads = rotate(query_heads, seq_dim=1, offset=0)

        b, t, h, d = query_heads.shape
        _, _, kv_h, _ = key_heads.shape
        assert h % kv_h == 0, f"query_heads {h} must be a multiple of kv_heads {kv_h}"

        # Group the query heads by the key/value head they share.
        query_heads = jnp.reshape(query_heads, (b, t, kv_h, h // kv_h, d))

        # Compute attention weights.
        # Attention softmax is always carried out in fp32.
        attn_logits = jnp.einsum("...thHd,...Thd->...hHtT", query_heads, key_heads).astype(
            jnp.float32
        )
        attn_logits *= self.attn_output_multiplier
        # Soft-cap the logits to the open interval (-30, 30).
        max_attn_val = jnp.array(30.0, dtype=attn_logits.dtype)
        attn_logits = max_attn_val * jnp.tanh(attn_logits / max_attn_val)

        if mask is not None:
            # Insert a broadcast axis for the query-group dimension. This must
            # happen inside the None check; indexing None would raise.
            mask = mask[:, :, None, :, :]
            if mask.ndim != attn_logits.ndim:
                raise ValueError(
                    f"Mask dimensionality {mask.ndim} must match logits dimensionality "
                    f"{attn_logits.ndim} for {mask.shape}/{attn_logits.shape}."
                )
            attn_logits = jnp.where(mask, attn_logits, -1e30)
        attn_weights = jax.nn.softmax(attn_logits).astype(query.dtype)  # [H, T', T]

        # Weight the values by the attention and flatten the head vectors.
        attn = jnp.einsum("...hHtT,...Thd->...thHd", attn_weights, value_heads)
        leading_dims = attn.shape[:2]
        attn = jnp.reshape(attn, (*leading_dims, -1))  # [T', H*V]

        # Apply another projection to get the final embeddings.
        final_projection = Linear(self.model_size, with_bias=False)
        return MHAOutput(final_projection(attn))

    @hk.transparent
    def _linear_projection(
        self,
        x: jax.Array,
        head_size: int,
        num_heads: int,
        name: Optional[str] = None,
    ) -> jax.Array:
        """Project ``x`` and split the last axis into ``[num_heads, head_size]``."""
        y = Linear(num_heads * head_size, with_bias=False, name=name)(x)
        *leading_dims, _ = x.shape
        return y.reshape((*leading_dims, num_heads, head_size))
@dataclass
class MHABlock(hk.Module):
    """Self-attention block: the same tensor serves as query, key and value.

    Attributes:
        num_q_heads: Number of query heads.
        num_kv_heads: Number of key/value heads.
        key_size: Per-head key/query size.
        attn_output_multiplier: Scalar applied to the attention logits.
    """

    num_q_heads: int
    num_kv_heads: int
    key_size: int
    attn_output_multiplier: float = 1.0

    @hk.transparent
    def __call__(
        self,
        inputs: jax.Array,  # [B, T, D]
        mask: jax.Array,  # [B, 1, T, T] or [B, 1, 1, T] or [B, 1, 1, 1]
    ) -> MHAOutput:
        """Attend ``inputs`` to itself under ``mask`` and return the result."""
        _, seq_len, model_size = inputs.shape

        # The mask must be 4-D with its last two axes either 1 or T.
        assert mask.ndim == 4, f"shape: {mask.shape}"
        assert mask.shape[2] in {1, seq_len}, str(mask.shape)
        assert mask.shape[3] in {1, seq_len}, str(mask.shape)

        attention = MultiHeadAttention(
            num_q_heads=self.num_q_heads,
            num_kv_heads=self.num_kv_heads,
            key_size=self.key_size,
            model_size=model_size,
            attn_output_multiplier=self.attn_output_multiplier,
        )
        # Query, key and value are all the block input.
        attended = attention(inputs, inputs, inputs, mask)
        return MHAOutput(embeddings=attended.embeddings)
@dataclass
class DenseBlock(hk.Module):
    """Gated feed-forward block: ``W_out(gelu(W_gate x) * (W_v x))``.

    Attributes:
        num_q_heads: Not used by ``__call__``.
        num_kv_heads: Not used by ``__call__``.
        key_size: Not used by ``__call__``.
        widening_factor: Multiplier used to size the hidden projections.
    """

    num_q_heads: int
    num_kv_heads: int
    key_size: int
    widening_factor: float = 4.0

    @hk.transparent
    def __call__(
        self,
        inputs: jax.Array,  # [B, T, D]
    ) -> jax.Array:  # [B, T, D]
        """Apply the gated feed-forward transformation to ``inputs``."""
        _, _, model_size = inputs.shape

        def hidden_projection(name=None):
            # Bias-free projection up to the feed-forward width.
            return Linear(
                ffn_size(model_size, self.widening_factor),
                with_bias=False,
                name=name,
            )

        # Modules are created in the same order as before ("linear_v", then
        # the two unnamed Linears) so Haiku's automatic names are unchanged.
        value_branch = hidden_projection("linear_v")(inputs)
        gate_branch = jax.nn.gelu(hidden_projection()(inputs))

        output_projection = Linear(model_size, with_bias=False)
        return output_projection(gate_branch * value_branch)
@dataclass
class DecoderLayer(hk.Module):
    """One decoder layer: self-attention then a feed-forward block.

    Each sub-block is wrapped as ``h + norm(block(norm(h)))``, i.e. RMS
    normalisation is applied both before and after the sub-block.

    Attributes:
        num_q_heads: Number of query heads.
        num_kv_heads: Number of key/value heads.
        key_size: Per-head key/query size.
        num_layers: Total number of layers (not used by ``__call__``).
        layer_index: Index of this layer (not used by ``__call__``).
        widening_factor: Feed-forward widening multiplier.
        name: Optional Haiku module name.
        attn_output_multiplier: Scalar applied to the attention logits.
    """

    num_q_heads: int
    num_kv_heads: int
    key_size: int
    num_layers: int
    layer_index: Optional[int] = None
    widening_factor: float = 4.0
    name: Optional[str] = None
    attn_output_multiplier: float = 1.0

    def __call__(
        self,
        inputs: jax.Array,  # [B, T, D]
        mask: jax.Array,  # [B, 1, T, T] or [B, 1, 1, T]
        padding_mask: Optional[jax.Array],
    ) -> DecoderOutput:
        """Transforms input embedding sequences to output embedding sequences."""
        del padding_mask  # Unused.

        h = inputs

        # --- Attention sub-block -------------------------------------------
        attention = MHABlock(
            num_q_heads=self.num_q_heads,
            num_kv_heads=self.num_kv_heads,
            key_size=self.key_size,
            attn_output_multiplier=self.attn_output_multiplier,
        )
        attended = attention(hk_rms_norm(h), mask).embeddings
        h += hk_rms_norm(attended)

        # --- Feed-forward sub-block ----------------------------------------
        ffn_input = hk_rms_norm(h)
        feed_forward = DenseBlock(
            num_q_heads=self.num_q_heads,
            num_kv_heads=self.num_kv_heads,
            key_size=self.key_size,
            widening_factor=self.widening_factor,
        )
        h += hk_rms_norm(feed_forward(ffn_input))

        return DecoderOutput(embeddings=h)
def layer_norm(x):
    """Normalise ``x`` via ``hk_rms_norm`` with its default settings."""
    normalised = hk_rms_norm(x)
    return normalised
@dataclass
class Transformer(hk.Module):
    """A stack of ``num_layers`` decoder layers sharing one attention mask.

    Attributes:
        num_q_heads: Number of query heads per layer.
        num_kv_heads: Number of key/value heads per layer.
        key_size: Per-head key/query size.
        widening_factor: Feed-forward widening multiplier.
        attn_output_multiplier: Scalar applied to the attention logits.
        num_layers: Number of decoder layers.
        name: Optional Haiku module name.
    """

    num_q_heads: int
    num_kv_heads: int
    key_size: int
    widening_factor: float
    attn_output_multiplier: float
    num_layers: int
    name: Optional[str] = None

    def __call__(
        self,
        embeddings: jax.Array,  # [B, T, D]
        mask: jax.Array,  # [B, T]
        candidate_start_offset: Optional[int] = None,
    ) -> TransformerOutput:
        """Transforms input embedding sequences to output embedding sequences.

        Args:
            embeddings: Input embeddings of shape [B, T, D]
            mask: Padding mask of shape [B, T], True for valid positions
            candidate_start_offset: If provided, positions >= this offset are
                treated as candidates that can only attend to positions before
                the offset (user+history) and to themselves, but not to other
                candidates. Used for recommendation system inference.

        Returns:
            TransformerOutput containing the output embeddings.
        """
        compute_dtype = embeddings.dtype
        _, seq_len, _ = embeddings.shape

        padding_mask = mask.copy()
        key_padding = mask[:, None, None, :]  # [B, H=1, T'=1, T]

        if candidate_start_offset is None:
            # Standard causal mask for autoregressive sequence modelling.
            structure = jnp.tril(jnp.ones((1, 1, seq_len, seq_len))).astype(
                compute_dtype
            )  # [B=1, H=1, T, T]
        else:
            # Candidates attend to user+history and themselves only.
            structure = make_recsys_attn_mask(seq_len, candidate_start_offset, compute_dtype)
        attn_mask = key_padding * structure  # [B, H=1, T, T]

        hidden = embeddings
        for layer_idx in range(self.num_layers):
            layer = DecoderLayer(
                num_q_heads=self.num_q_heads,
                num_kv_heads=self.num_kv_heads,
                key_size=self.key_size,
                widening_factor=self.widening_factor,
                num_layers=self.num_layers,
                attn_output_multiplier=self.attn_output_multiplier,
                name=f"decoder_layer_{layer_idx}",
                layer_index=layer_idx,
            )
            hidden = layer(hidden, attn_mask, padding_mask).embeddings

        return TransformerOutput(embeddings=hidden)