mirror of
https://github.com/xai-org/x-algorithm.git
synced 2026-02-13 03:05:06 +01:00
587 lines
18 KiB
Python
587 lines
18 KiB
Python
# Copyright 2026 X.AI Corp.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from typing import NamedTuple, Optional, Sequence, Union
|
|
|
|
import haiku as hk
|
|
import jax
|
|
import jax.numpy as jnp
|
|
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TrainingState(NamedTuple):
    """Immutable bundle holding the state carried through training."""

    # Haiku parameter tree of the model.
    params: hk.Params
|
|
|
|
|
|
def ffn_size(emb_size, widening_factor):
    """Return the feed-forward hidden width for a given embedding width.

    The width is ``int(widening_factor * emb_size) * 2 // 3``, rounded up to
    the next multiple of 8.

    Args:
        emb_size: Embedding (model) dimension.
        widening_factor: Multiplier applied to ``emb_size`` before the 2/3 scaling.

    Returns:
        The adjusted hidden size; always a multiple of 8.
    """
    _ffn_size = int(widening_factor * emb_size) * 2 // 3
    # (8 - n) % 8 is the gap from n up to the next multiple of 8 (0 when aligned).
    _ffn_size += (8 - _ffn_size) % 8
    # Lazy %-style arguments: the message is only formatted when DEBUG is enabled.
    logger.debug("emb_size: %s adjusted ffn_size: %s", emb_size, _ffn_size)
    return _ffn_size
|
|
|
|
|
|
def make_recsys_attn_mask(
    seq_len: int,
    candidate_start_offset: int,
    dtype: jnp.dtype = jnp.float32,
) -> jax.Array:
    """Build the attention mask used for recommendation-system inference.

    Rows are query positions, columns are key positions, and a value of 1
    means "may attend":

    - Rows before ``candidate_start_offset`` (user + history) follow the plain
      causal pattern.
    - Rows from ``candidate_start_offset`` onwards (candidates) see every
      user + history column and their own column, but no other candidate.

    Each candidate is therefore scored from the user + history context only.

    Args:
        seq_len: Total sequence length (user + history + candidates).
        candidate_start_offset: Index of the first candidate position.
        dtype: Data type of the returned mask.

    Returns:
        Array of shape [1, 1, seq_len, seq_len].
    """
    start = candidate_start_offset

    # Lower-triangular (causal) pattern over the whole sequence.
    grid = jnp.tril(jnp.ones((seq_len, seq_len), dtype=dtype))

    # Remove every candidate -> candidate link (bottom-right block) ...
    grid = grid.at[start:, start:].set(0)

    # ... then give each candidate back its own diagonal entry.
    own = jnp.arange(start, seq_len)
    grid = grid.at[own, own].set(1)

    # Prepend the two singleton axes callers expect in front of [T, T].
    return grid[None, None, :, :]
|
|
|
|
|
|
class MHAOutput(NamedTuple):
    """Result of a multi-head attention call."""

    # Embeddings produced by the attention operation.
    embeddings: jax.Array
|
|
|
|
|
|
class DecoderOutput(NamedTuple):
    """Result of a single decoder layer call."""

    # Embeddings produced by the decoder layer.
    embeddings: jax.Array
|
|
|
|
|
|
class TransformerOutput(NamedTuple):
    """Result of a full transformer call."""

    # Embeddings produced by the last decoder layer.
    embeddings: jax.Array
|
|
|
|
|
|
@dataclass
class TransformerConfig:
    """Hyper-parameters from which a ``Transformer`` module is built."""

    emb_size: int
    key_size: int
    num_q_heads: int
    num_kv_heads: int
    num_layers: int
    widening_factor: float = 4.0

    attn_output_multiplier: float = 1.0

    name: Optional[str] = None

    def make(self) -> "Transformer":
        """Create the ``Transformer`` described by this configuration.

        Note that ``emb_size`` and ``name`` are not forwarded to the module.
        """
        transformer_kwargs = dict(
            num_q_heads=self.num_q_heads,
            num_kv_heads=self.num_kv_heads,
            widening_factor=self.widening_factor,
            key_size=self.key_size,
            attn_output_multiplier=self.attn_output_multiplier,
            num_layers=self.num_layers,
        )
        return Transformer(**transformer_kwargs)
|
|
|
|
|
|
def hk_rms_norm(
    x: jax.Array,
    fixed_scale=False,
) -> jax.Array:
    """Normalise ``x`` over its last axis with a freshly constructed ``RMSNorm``.

    A new module is built on every call, so each call site gets its own
    normalisation module.

    Args:
        x: Array to normalise.
        fixed_scale: When true, no learnable scale parameter is created.

    Returns:
        The normalised array.
    """
    norm = RMSNorm(axis=-1, create_scale=not fixed_scale)
    normalised = norm(x)
    return normalised
|
|
|
|
|
|
class Linear(hk.Linear):
    """``hk.Linear`` variant with float32, zero-initialised parameters.

    Parameters are stored in float32 and cast to the input dtype when applied.
    """

    def __init__(
        self,
        output_size: int,
        with_bias: bool = True,
        name: Optional[str] = None,
    ):
        super().__init__(output_size=output_size, with_bias=with_bias, name=name)

    def __call__(  # type: ignore
        self,
        inputs: jax.Array,
    ) -> jax.Array:
        """Computes a linear transform of the input.

        Args:
            inputs: Array whose last axis is the feature axis.

        Returns:
            Array with the last axis replaced by ``output_size``, in the dtype
            of ``inputs``.

        Raises:
            ValueError: If ``inputs`` is a scalar (has an empty shape).
        """
        compute_dtype = inputs.dtype
        if not inputs.shape:
            raise ValueError("Input must not be scalar.")

        in_features = inputs.shape[-1]
        zeros = hk.initializers.Constant(0)

        weight = hk.get_parameter(
            "w", [in_features, self.output_size], jnp.float32, init=zeros
        )
        projected = jnp.dot(inputs, weight.astype(compute_dtype))

        if not self.with_bias:
            return projected

        bias = hk.get_parameter(
            "b", [self.output_size], jnp.float32, init=hk.initializers.Constant(0)
        )
        bias = jnp.broadcast_to(bias, projected.shape)
        return projected + bias.astype(compute_dtype)
|
|
|
|
|
|
class RMSNorm(hk.RMSNorm):
    """RMS normalisation computed in float32 over the last axis.

    The optional scale parameter is stored in float32 and initialised to zero.
    """

    def __init__(
        self,
        axis: Union[int, Sequence[int], slice],
        eps: float = 1e-5,
        name: Optional[str] = None,
        create_scale: bool = True,
    ):
        super().__init__(axis, eps, create_scale=create_scale, name=name)

    def __call__(self, inputs: jax.Array):
        """Normalise ``inputs`` and return the result in the input dtype."""
        out_dtype = inputs.dtype
        feature_shape = (inputs.shape[-1],)

        if self.create_scale:
            gain = hk.get_parameter(
                "scale",
                feature_shape,
                dtype=jnp.float32,
                init=hk.initializers.Constant(0),
            )
            gain = jnp.broadcast_to(gain.astype(jnp.float32), inputs.shape)
        else:
            gain = 1.0

        # All arithmetic below is carried out in float32.
        x32 = inputs.astype(jnp.float32)
        gain = jnp.float32(gain)

        # NOTE: the reduction is always over the last axis, regardless of the
        # ``axis`` value given to the constructor.
        mean_sq = jnp.mean(jnp.square(x32), axis=-1, keepdims=True)
        mean_sq = jnp.broadcast_to(mean_sq, x32.shape)

        normalised = x32 * jax.lax.rsqrt(mean_sq + self.eps)
        result = gain * normalised
        return result.astype(out_dtype)
|
|
|
|
|
|
def rotate_half(
    x: jax.Array,
) -> jax.Array:
    """Return the rotated counterpart of each feature.

    The last axis is split into two equal halves ``(a, b)`` and reassembled
    as ``(-b, a)``.
    """
    first_half, second_half = jnp.split(x, 2, axis=-1)
    pieces = [jnp.negative(second_half), first_half]
    return jnp.concatenate(pieces, axis=-1)
|
|
|
|
|
|
class RotaryEmbedding(hk.Module):
    """Applies rotary embeddings (RoPE) to the input sequence tensor,
    as described in https://arxiv.org/abs/2104.09864.

    Attributes:
        dim (int): Dimensionality of the feature vectors
        base_exponent (int): Base exponent to compute embeddings from
    """

    def __init__(
        self,
        dim: int,
        name: Optional[str] = None,
        base_exponent: int = 10000,
    ):
        super().__init__(name)
        self.dim = dim
        self.base_exponent = base_exponent
        # Features are rotated in pairs, so the width must be even.
        assert self.dim % 2 == 0, f"dim must be even, got {self.dim}"

    def __call__(
        self,
        x: jax.Array,
        seq_dim: int,
        offset: Union[int, jax.Array],
        const_position: Optional[int] = None,
        t: Optional[jax.Array] = None,
    ) -> jax.Array:
        """Rotate ``x`` by position-dependent phases.

        Args:
            x: Input tensor; the phase is broadcast as [b, T, 1, dim], so a
                [batch, seq, heads, dim] layout is assumed (TODO confirm for
                callers outside this file).
            seq_dim: Axis of ``x`` that holds the sequence positions.
            offset: Scalar, or one value per batch element, added to the
                running positions when neither ``const_position`` nor ``t``
                is supplied.
            const_position: If not None, every position uses this constant
                index. Position 0 is honoured (it previously fell through to
                the running positions because the check was truthiness-based).
            t: Optional explicit positions of rank 2; used only when
                ``const_position`` is None.

        Returns:
            The rotated tensor, cast back to the dtype of ``x``.
        """
        fprop_dtype = x.dtype
        # Compute the per-dimension frequencies
        exponents = jnp.arange(0, self.dim, 2, dtype=jnp.float32)
        inv_freq = jnp.asarray(
            1.0 / (self.base_exponent ** (exponents / self.dim)), dtype=jnp.float32
        )

        if jnp.shape(offset) == ():
            # Offset can be a scalar or one offset per batch element.
            offset = jnp.expand_dims(offset, 0)

        # Compute the per element phase (to pass into sin and cos)
        if const_position is not None:
            t = const_position * jnp.ones(
                (1, x.shape[seq_dim]),
                dtype=jnp.float32,
            )
        elif t is None:
            t = jnp.arange(x.shape[seq_dim], dtype=jnp.float32) + jnp.expand_dims(offset, -1)
        phase = jnp.einsum("bi,j->bij", t, inv_freq)
        # Duplicate the dim/2 phases to cover the full feature width, then add
        # a singleton axis in front of the feature axis for broadcasting.
        phase = jnp.tile(phase, reps=(1, 2))[:, :, None, :]

        x = x * jnp.cos(phase) + rotate_half(x) * jnp.sin(phase)
        x = x.astype(fprop_dtype)

        return x
|
|
|
|
|
|
class MultiHeadAttention(hk.Module):
    """Multi-head attention with grouped key/value heads and rotary embeddings.

    Queries use ``num_q_heads`` heads while keys and values use
    ``num_kv_heads`` heads; ``num_q_heads`` must be a multiple of
    ``num_kv_heads``. Logits are scaled by ``attn_output_multiplier`` and
    soft-capped to (-30, 30) with ``tanh`` before the softmax.
    """

    def __init__(
        self,
        num_q_heads: int,
        num_kv_heads: int,
        key_size: int,
        *,
        with_bias: bool = True,
        value_size: Optional[int] = None,
        model_size: Optional[int] = None,
        attn_output_multiplier: float = 1.0,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.key_size = key_size
        # Fall back to sizes derived from key_size when not given explicitly.
        self.value_size = value_size or key_size
        self.model_size = model_size or key_size * num_q_heads
        self.attn_output_multiplier = attn_output_multiplier
        self.with_bias = with_bias

    def __call__(
        self,
        query: jax.Array,
        key: jax.Array,
        value: jax.Array,
        mask: Optional[jax.Array],
    ) -> MHAOutput:
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: Query inputs; the first two axes are batch and sequence.
            key: Key inputs; batch size must be 1 or equal to the query's.
            value: Value inputs; must share batch and sequence sizes with ``key``.
            mask: Optional rank-4 mask, broadcastable to
                [batch, 1, query_len, key_len]. Non-zero / True entries may be
                attended to. ``None`` disables masking.

        Returns:
            ``MHAOutput`` whose embeddings have ``model_size`` features.

        Raises:
            ValueError: If the expanded mask rank differs from the logits rank.
        """
        # In shape hints below, we suppress the leading dims [...] for brevity.
        # Hence e.g. [A, B] should be read in every case as [..., A, B].
        projection = self._linear_projection

        # Check that the keys and values have consistent batch size and sequence length.
        assert key.shape[:2] == value.shape[:2], f"key/value shape: {key.shape}/{value.shape}"
        # The key batch must be broadcastable against the query batch.
        assert key.shape[0] in {
            1,
            query.shape[0],
        }, f"key/query shape: {key.shape}/{query.shape}"

        if mask is not None:
            assert mask.ndim == 4
            assert mask.shape[0] in {
                1,
                query.shape[0],
            }, f"mask/query shape: {mask.shape}/{query.shape}"
            assert mask.shape[1] == 1
            assert mask.shape[2] in {
                1,
                query.shape[1],
            }, f"mask/query shape: {mask.shape}/{query.shape}"
            assert mask.shape[3] in {
                1,
                key.shape[1],
            }, f"mask/key shape: {mask.shape}/{key.shape}"

        # Compute key/query/values (overload K/Q/V to denote the respective sizes).
        assert self.num_q_heads % self.num_kv_heads == 0
        query_heads = projection(query, self.key_size, self.num_q_heads, name="query")
        key_heads = projection(key, self.key_size, self.num_kv_heads, name="key")
        value_heads = projection(value, self.value_size, self.num_kv_heads, name="value")

        rotate = RotaryEmbedding(dim=self.key_size, base_exponent=int(1e4))
        key_heads = rotate(key_heads, seq_dim=1, offset=0)
        query_heads = rotate(query_heads, seq_dim=1, offset=0)

        b, t, h, d = query_heads.shape
        _, _, kv_h, _ = key_heads.shape
        assert h % kv_h == 0, f"query_heads {h} must be a multiple of kv_heads {kv_h}"

        # Group the query heads by the key/value head they share.
        query_heads = jnp.reshape(query_heads, (b, t, kv_h, h // kv_h, d))

        # Compute attention weights.
        # Attention softmax is always carried out in fp32.
        attn_logits = jnp.einsum("...thHd,...Thd->...hHtT", query_heads, key_heads).astype(
            jnp.float32
        )
        attn_logits *= self.attn_output_multiplier
        # Soft-cap the logits to the open interval (-30, 30).
        max_attn_val = jnp.array(30.0, dtype=attn_logits.dtype)
        attn_logits = max_attn_val * jnp.tanh(attn_logits / max_attn_val)

        if mask is not None:
            # Insert an axis for the query-heads-per-kv-head dimension of the logits.
            mask = mask[:, :, None, :, :]
            if mask.ndim != attn_logits.ndim:
                raise ValueError(
                    f"Mask dimensionality {mask.ndim} must match logits dimensionality "
                    f"{attn_logits.ndim} for {mask.shape}/{attn_logits.shape}."
                )
            attn_logits = jnp.where(mask, attn_logits, -1e30)
        attn_weights = jax.nn.softmax(attn_logits).astype(query.dtype)  # [H, T', T]

        # Weight the values by the attention and flatten the head vectors.
        attn = jnp.einsum("...hHtT,...Thd->...thHd", attn_weights, value_heads)
        leading_dims = attn.shape[:2]
        attn = jnp.reshape(attn, (*leading_dims, -1))  # [T', H*V]

        # Apply another projection to get the final embeddings.
        final_projection = Linear(self.model_size, with_bias=False)
        return MHAOutput(final_projection(attn))

    @hk.transparent
    def _linear_projection(
        self,
        x: jax.Array,
        head_size: int,
        num_heads: int,
        name: Optional[str] = None,
    ) -> jax.Array:
        """Project ``x`` and split the last axis into [num_heads, head_size]."""
        y = Linear(num_heads * head_size, with_bias=False, name=name)(x)
        *leading_dims, _ = x.shape
        return y.reshape((*leading_dims, num_heads, head_size))
|
|
|
|
|
|
@dataclass
class MHABlock(hk.Module):
    """Self-attention block: queries, keys and values all come from one input."""

    num_q_heads: int
    num_kv_heads: int
    key_size: int
    attn_output_multiplier: float = 1.0

    @hk.transparent
    def __call__(
        self,
        inputs: jax.Array,  # [B, T, D]
        mask: jax.Array,  # [B, 1, T, T] or [B, 1, 1, T] or B[1, 1, 1, 1]
    ) -> MHAOutput:
        """Run multi-head self-attention over ``inputs`` under ``mask``."""
        _, seq_len, model_size = inputs.shape

        # The mask must be rank 4, and its last two axes must either
        # broadcast (size 1) or match the sequence length.
        assert mask.ndim == 4, f"shape: {mask.shape}"
        allowed_sizes = {1, seq_len}
        assert mask.shape[2] in allowed_sizes, str(mask.shape)
        assert mask.shape[3] in allowed_sizes, str(mask.shape)

        attention = MultiHeadAttention(
            num_q_heads=self.num_q_heads,
            num_kv_heads=self.num_kv_heads,
            key_size=self.key_size,
            model_size=model_size,
            attn_output_multiplier=self.attn_output_multiplier,
        )
        # Self-attention: the same tensor serves as query, key and value.
        attended = attention(inputs, inputs, inputs, mask)

        return MHAOutput(embeddings=attended.embeddings)
|
|
|
|
|
|
@dataclass
class DenseBlock(hk.Module):
    """Gated feed-forward block: ``Linear(gelu(Linear(x)) * Linear_v(x))``.

    All three projections are bias-free.
    """

    num_q_heads: int
    num_kv_heads: int
    key_size: int
    widening_factor: float = 4.0

    @hk.transparent
    def __call__(
        self,
        inputs: jax.Array,  # [B, T, D]
    ) -> jax.Array:  # [B, T, D]
        """Apply the gated feed-forward transform to ``inputs``."""
        _, _, model_size = inputs.shape

        # Value branch (explicitly named "linear_v").
        value_proj = Linear(
            ffn_size(model_size, self.widening_factor),
            with_bias=False,
            name="linear_v",
        )
        value_branch = value_proj(inputs)

        # Gate branch: an unnamed projection followed by GELU.
        gate_proj = Linear(
            ffn_size(model_size, self.widening_factor),
            with_bias=False,
        )
        gate_branch = jax.nn.gelu(gate_proj(inputs))

        # Project the gated product back to the model width.
        out_proj = Linear(model_size, with_bias=False)
        return out_proj(gate_branch * value_branch)
|
|
|
|
|
|
@dataclass
class DecoderLayer(hk.Module):
    """One decoder layer: an attention sub-block, then a feed-forward sub-block.

    Each sub-block normalises its input with RMSNorm, normalises its own
    output with a second RMSNorm, and adds the result to the residual stream.
    """

    num_q_heads: int
    num_kv_heads: int
    key_size: int
    num_layers: int
    layer_index: Optional[int] = None
    widening_factor: float = 4.0
    name: Optional[str] = None
    attn_output_multiplier: float = 1.0

    def __call__(
        self,
        inputs: jax.Array,  # [B, T, D]
        mask: jax.Array,  # [B, 1, T, T] or [B, 1, 1, T]
        padding_mask: Optional[jax.Array],
    ) -> DecoderOutput:
        """Transforms input embedding sequences to output embedding sequences."""
        del padding_mask  # Unused.

        h = inputs

        # --- Attention sub-block ---
        attention = MHABlock(
            num_q_heads=self.num_q_heads,
            num_kv_heads=self.num_kv_heads,
            key_size=self.key_size,
            attn_output_multiplier=self.attn_output_multiplier,
        )
        attn_branch = attention(hk_rms_norm(h), mask).embeddings
        attn_branch = hk_rms_norm(attn_branch)
        h += attn_branch

        # --- Feed-forward sub-block ---
        ffn_input = hk_rms_norm(h)
        feed_forward = DenseBlock(
            num_q_heads=self.num_q_heads,
            num_kv_heads=self.num_kv_heads,
            key_size=self.key_size,
            widening_factor=self.widening_factor,
        )
        ffn_branch = hk_rms_norm(feed_forward(ffn_input))
        h += ffn_branch

        return DecoderOutput(embeddings=h)
|
|
|
|
|
|
def layer_norm(x):
    """Normalise ``x`` by delegating to ``hk_rms_norm`` with default settings."""
    normalised = hk_rms_norm(x)
    return normalised
|
|
|
|
|
|
@dataclass
class Transformer(hk.Module):
    """A stack of ``num_layers`` decoder layers applied in sequence."""

    num_q_heads: int
    num_kv_heads: int
    key_size: int
    widening_factor: float
    attn_output_multiplier: float
    num_layers: int
    name: Optional[str] = None

    def __call__(
        self,
        embeddings: jax.Array,  # [B, T, D]
        mask: jax.Array,  # [B, T]
        candidate_start_offset: Optional[int] = None,
    ) -> TransformerOutput:
        """Transforms input embedding sequences to output embedding sequences.

        Args:
            embeddings: Input embeddings of shape [B, T, D]
            mask: Padding mask of shape [B, T], True for valid positions
            candidate_start_offset: If provided, positions >= this offset are treated as
                candidates that can only attend to positions before the offset (user+history)
                and themselves (self-attention), but not to other candidates.
                Used for recommendation system inference.

        Returns:
            TransformerOutput containing the output embeddings.
        """
        fprop_dtype = embeddings.dtype
        _, seq_len, _ = embeddings.shape

        # Keep the original [B, T] padding mask for the layers, and expand a
        # broadcastable view of it for attention.
        padding_mask = mask.copy()
        key_padding = mask[:, None, None, :]  # [B, H=1, T'=1, T]

        if candidate_start_offset is None:
            # Standard causal mask for autoregressive sequence modelling
            structure = jnp.tril(jnp.ones((1, 1, seq_len, seq_len))).astype(
                fprop_dtype
            )  # [B=1, H=1, T, T]
        else:
            # Recommendation-system mask: candidates attend to user+history
            # and themselves, but not to other candidates
            structure = make_recsys_attn_mask(seq_len, candidate_start_offset, fprop_dtype)

        attn_mask = key_padding * structure  # [B, H=1, T, T]

        hidden = embeddings
        for layer_idx in range(self.num_layers):
            layer = DecoderLayer(
                num_q_heads=self.num_q_heads,
                num_kv_heads=self.num_kv_heads,
                key_size=self.key_size,
                widening_factor=self.widening_factor,
                num_layers=self.num_layers,
                attn_output_multiplier=self.attn_output_multiplier,
                name=f"decoder_layer_{layer_idx}",
                layer_index=layer_idx,
            )
            hidden = layer(hidden, attn_mask, padding_mask).embeddings

        return TransformerOutput(embeddings=hidden)
|