Open-source X Recommendation Algorithm

CI agent 2026-01-20 02:31:49 +00:00
commit aaa167b3de
79 changed files with 8816 additions and 0 deletions

.gitignore
@@ -0,0 +1 @@
__pycache__/

CODE_OF_CONDUCT.md
@@ -0,0 +1 @@
Be excellent to each other.

LICENSE
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

README.md
@@ -0,0 +1,325 @@
# X For You Feed Algorithm
This repository contains the core recommendation system powering the "For You" feed on X. It combines in-network content (from accounts you follow) with out-of-network content (discovered through ML-based retrieval) and ranks everything using a Grok-based transformer model.
> **Note:** The transformer implementation is ported from the [Grok-1 open source release](https://github.com/xai-org/grok-1) by xAI, adapted for recommendation system use cases.
## Table of Contents
- [Overview](#overview)
- [System Architecture](#system-architecture)
- [Components](#components)
- [Home Mixer](#home-mixer)
- [Thunder](#thunder)
- [Phoenix](#phoenix)
- [Candidate Pipeline](#candidate-pipeline)
- [How It Works](#how-it-works)
- [Pipeline Stages](#pipeline-stages)
- [Scoring and Ranking](#scoring-and-ranking)
- [Filtering](#filtering)
- [Key Design Decisions](#key-design-decisions)
- [License](#license)
---
## Overview
The For You feed algorithm retrieves, ranks, and filters posts from two sources:
1. **In-Network (Thunder)**: Posts from accounts you follow
2. **Out-of-Network (Phoenix Retrieval)**: Posts discovered from a global corpus
Both sources are combined and ranked together using **Phoenix**, a Grok-based transformer model that predicts engagement probabilities for each post. The final score is a weighted combination of these predicted engagements.
We have eliminated all hand-engineered features and most heuristics from the system. The Grok-based transformer does the heavy lifting: it learns from your engagement history (what you liked, replied to, shared, etc.) and uses that to determine which content is relevant to you.
---
## System Architecture
```
┌─────────────────────────────────────────────────────────────────────────────────────────────┐
│ FOR YOU FEED REQUEST │
└─────────────────────────────────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────────────────────────────────┐
│ HOME MIXER │
│ (Orchestration Layer) │
├─────────────────────────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────────────────────────────────────────┐ │
│ │ QUERY HYDRATION │ │
│ │ ┌──────────────────────────┐ ┌──────────────────────────────────────────────┐ │ │
│ │ │ User Action Sequence │ │ User Features │ │ │
│ │ │ (engagement history) │ │ (following list, preferences, etc.) │ │ │
│ │ └──────────────────────────┘ └──────────────────────────────────────────────┘ │ │
│ └─────────────────────────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────────────────────────┐ │
│ │ CANDIDATE SOURCES │ │
│ │ ┌─────────────────────────────┐ ┌────────────────────────────────┐ │ │
│ │ │ THUNDER │ │ PHOENIX RETRIEVAL │ │ │
│ │ │ (In-Network Posts) │ │ (Out-of-Network Posts) │ │ │
│ │ │ │ │ │ │ │
│ │ │ Posts from accounts │ │ ML-based similarity search │ │ │
│ │ │ you follow │ │ across global corpus │ │ │
│ │ └─────────────────────────────┘ └────────────────────────────────┘ │ │
│ └─────────────────────────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────────────────────────┐ │
│ │ HYDRATION │ │
│ │ Fetch additional data: core post metadata, author info, media entities, etc. │ │
│ └─────────────────────────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────────────────────────┐ │
│ │ FILTERING │ │
│ │ Remove: duplicates, old posts, self-posts, blocked authors, muted keywords, etc. │ │
│ └─────────────────────────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────────────────────────┐ │
│ │ SCORING │ │
│ │ ┌──────────────────────────┐ │ │
│ │ │ Phoenix Scorer │ Grok-based Transformer predicts: │ │
│ │ │ (ML Predictions) │ P(like), P(reply), P(repost), P(click)... │ │
│ │ └──────────────────────────┘ │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌──────────────────────────┐ │ │
│ │ │ Weighted Scorer │ Weighted Score = Σ (weight × P(action)) │ │
│ │ │ (Combine predictions) │ │ │
│ │ └──────────────────────────┘ │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌──────────────────────────┐ │ │
│ │ │ Author Diversity │ Attenuate repeated author scores │ │
│ │ │ Scorer │ to ensure feed diversity │ │
│ │ └──────────────────────────┘ │ │
│ └─────────────────────────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────────────────────────┐ │
│ │ SELECTION │ │
│ │ Sort by final score, select top K candidates │ │
│ └─────────────────────────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────────────────────────┐ │
│ │ FILTERING (Post-Selection) │ │
│ │ Visibility filtering (deleted/spam/violence/gore etc) │ │
│ └─────────────────────────────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────────────────────────────────┐
│ RANKED FEED RESPONSE │
└─────────────────────────────────────────────────────────────────────────────────────────────┘
```
---
## Components
### Home Mixer
**Location:** [`home-mixer/`](home-mixer/)
The orchestration layer that assembles the For You feed. It leverages the `CandidatePipeline` framework with the following stages:
| Stage | Description |
|-------|-------------|
| Query Hydrators | Fetch user context (engagement history, following list) |
| Sources | Retrieve candidates from Thunder and Phoenix |
| Hydrators | Enrich candidates with additional data |
| Filters | Remove ineligible candidates |
| Scorers | Predict engagement and compute final scores |
| Selector | Sort by score and select top K |
| Post-Selection Filters | Final visibility and dedup checks |
| Side Effects | Cache request info for future use |
The server exposes a gRPC endpoint (`ScoredPostsService`) that returns ranked posts for a given user.
---
### Thunder
**Location:** [`thunder/`](thunder/)
An in-memory post store and realtime ingestion pipeline that tracks recent posts from all users. It:
- Consumes post create/delete events from Kafka
- Maintains per-user stores for original posts, replies/reposts, and video posts
- Serves "in-network" post candidates from accounts the requesting user follows
- Automatically trims posts older than the retention period
Thunder enables sub-millisecond lookups for in-network content without hitting an external database.
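Conceptually, Thunder reduces to a map from author ID to a time-ordered queue of recent posts. The sketch below illustrates that idea with plain standard-library types; the names (`StoredPost`, `InMemoryPostStore`) and the single-map layout are illustrative assumptions, whereas the real store keeps separate per-user stores for originals, replies/reposts, and video posts.

```rust
use std::collections::{HashMap, VecDeque};
use std::time::{Duration, SystemTime};

struct StoredPost {
    post_id: i64,
    created_at: SystemTime,
}

/// Hypothetical per-author post store with age-based trimming.
#[derive(Default)]
struct InMemoryPostStore {
    posts_by_author: HashMap<u64, VecDeque<StoredPost>>,
}

impl InMemoryPostStore {
    /// Append a new post for an author (e.g. from a Kafka create event).
    fn insert(&mut self, author_id: u64, post: StoredPost) {
        self.posts_by_author.entry(author_id).or_default().push_back(post);
    }

    /// Drop an author's posts that are older than the retention period.
    fn trim(&mut self, author_id: u64, now: SystemTime, retention: Duration) {
        if let Some(posts) = self.posts_by_author.get_mut(&author_id) {
            while let Some(front) = posts.front() {
                match now.duration_since(front.created_at) {
                    Ok(age) if age > retention => {
                        posts.pop_front();
                    }
                    _ => break,
                }
            }
        }
    }

    /// Collect recent post IDs from every account the viewer follows.
    fn in_network_candidates(&self, followed: &[u64]) -> Vec<i64> {
        followed
            .iter()
            .filter_map(|id| self.posts_by_author.get(id))
            .flat_map(|posts| posts.iter().map(|p| p.post_id))
            .collect()
    }
}
```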
---
### Phoenix
**Location:** [`phoenix/`](phoenix/)
The ML component with two main functions:
#### 1. Retrieval (Two-Tower Model)
Finds relevant out-of-network posts:
- **User Tower**: Encodes user features and engagement history into an embedding
- **Candidate Tower**: Encodes all posts into embeddings
- **Similarity Search**: Retrieves the top-K posts via dot-product similarity (see the sketch below)
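Under simplifying assumptions (dense `f32` embeddings, brute-force scan), the retrieval step reduces to scoring every candidate embedding against the user embedding and keeping the best K. This is a minimal sketch only; a production system would use an approximate-nearest-neighbor index rather than a linear scan.

```rust
/// Score each candidate embedding against the user embedding by dot
/// product and return the indices of the top K matches.
fn top_k_by_dot_product(user: &[f32], candidates: &[Vec<f32>], k: usize) -> Vec<usize> {
    let mut scored: Vec<(usize, f32)> = candidates
        .iter()
        .enumerate()
        .map(|(i, c)| {
            let score: f32 = user.iter().zip(c).map(|(u, v)| u * v).sum();
            (i, score)
        })
        .collect();
    // Sort descending by score; NaNs compare as equal.
    scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    scored.truncate(k);
    scored.into_iter().map(|(i, _)| i).collect()
}
```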
#### 2. Ranking (Transformer with Candidate Isolation)
Predicts engagement probabilities for each candidate:
- Takes user context (engagement history) and candidate posts as input
- Uses special attention masking so candidates cannot attend to each other (see the mask sketch below)
- Outputs probabilities for each action type (like, reply, repost, click, etc.)
See [`phoenix/README.md`](phoenix/README.md) for detailed architecture documentation.
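To make candidate isolation concrete, here is a sketch of the boolean attention mask such a model could use, assuming one token per candidate appended after the user-context tokens; the real model's token layout and masking details may differ.

```rust
/// Build a boolean attention mask for a sequence laid out as
/// [user context tokens | candidate tokens]. `mask[q][k]` is true when
/// query position q may attend to key position k: the context attends
/// causally to itself, and each candidate attends to all context tokens
/// plus itself, but never to other candidates.
fn candidate_isolation_mask(n_context: usize, n_candidates: usize) -> Vec<Vec<bool>> {
    let total = n_context + n_candidates;
    let mut mask = vec![vec![false; total]; total];
    for q in 0..total {
        for k in 0..total {
            mask[q][k] = if q < n_context {
                k <= q // causal attention within the user context
            } else {
                k < n_context || k == q // candidates see context + self only
            };
        }
    }
    mask
}
```

Because a candidate row can only see the shared context and itself, its score does not depend on which other candidates happen to be in the batch.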
---
### Candidate Pipeline
**Location:** [`candidate-pipeline/`](candidate-pipeline/)
A reusable framework for building recommendation pipelines. Defines traits for:
| Trait | Purpose |
|-------|---------|
| `Source` | Fetch candidates from a data source |
| `Hydrator` | Enrich candidates with additional features |
| `Filter` | Remove candidates that shouldn't be shown |
| `Scorer` | Compute scores for ranking |
| `Selector` | Sort and select top candidates |
| `SideEffect` | Run async side effects (caching, logging) |
The framework runs sources and hydrators in parallel where possible, with configurable error handling and logging.
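As an example of the trait surface, here is a minimal, hypothetical filter that drops candidates older than a threshold. The `Query`/`Candidate` types and field names are placeholders (and `MaxAgeFilter` is distinct from the repo's actual `AgeFilter`); only `filter` must be implemented, since `enable` and `name` have defaults.

```rust
use tonic::async_trait;
use xai_candidate_pipeline::filter::{Filter, FilterResult};

// Placeholder query/candidate types for illustration.
#[derive(Clone)]
struct Query {
    now_ms: i64,
    max_age_ms: i64,
}

#[derive(Clone)]
struct Candidate {
    created_at_ms: i64,
}

struct MaxAgeFilter;

#[async_trait]
impl Filter<Query, Candidate> for MaxAgeFilter {
    async fn filter(
        &self,
        query: &Query,
        candidates: Vec<Candidate>,
    ) -> Result<FilterResult<Candidate>, String> {
        // Partition candidates into kept (young enough) and removed.
        let (kept, removed): (Vec<_>, Vec<_>) = candidates
            .into_iter()
            .partition(|c| query.now_ms - c.created_at_ms <= query.max_age_ms);
        Ok(FilterResult { kept, removed })
    }
}
```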
---
## How It Works
### Pipeline Stages
1. **Query Hydration**: Fetch the user's recent engagement history and metadata (e.g., following list)
2. **Candidate Sourcing**: Retrieve candidates from:
- **Thunder**: Recent posts from followed accounts (in-network)
- **Phoenix Retrieval**: ML-discovered posts from the global corpus (out-of-network)
3. **Candidate Hydration**: Enrich candidates with:
- Core post data (text, media, etc.)
- Author information (username, verification status)
- Video duration (for video posts)
- Subscription status
4. **Pre-Scoring Filters**: Remove posts that are:
- Duplicates
- Too old
- From the viewer themselves
- From blocked/muted accounts
- Containing muted keywords
- Previously seen or recently served
- Ineligible subscription content
5. **Scoring**: Apply multiple scorers sequentially:
- **Phoenix Scorer**: Get ML predictions from the Phoenix transformer model
- **Weighted Scorer**: Combine predictions into a final relevance score
- **Author Diversity Scorer**: Attenuate repeated author scores for diversity
- **OON Scorer**: Adjust scores for out-of-network content
6. **Selection**: Sort by score and select the top K candidates
7. **Post-Selection Processing**: Final validation of post candidates to be served
---
### Scoring and Ranking
The Phoenix Grok-based transformer model predicts probabilities for multiple engagement types:
```
Predictions:
├── P(favorite)
├── P(reply)
├── P(repost)
├── P(quote)
├── P(click)
├── P(profile_click)
├── P(video_view)
├── P(photo_expand)
├── P(share)
├── P(dwell)
├── P(follow_author)
├── P(not_interested)
├── P(block_author)
├── P(mute_author)
└── P(report)
```
The **Weighted Scorer** combines these into a final score:
```
Final Score = Σ (weight_i × P(action_i))
```
Positive actions (like, repost, share) have positive weights. Negative actions (block, mute, report) have negative weights, pushing down content the user would likely dislike.
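A minimal sketch of this combination step, using made-up placeholder weights rather than production values:

```rust
use std::collections::HashMap;

/// Combine per-action engagement probabilities into one relevance score.
/// The weights below are illustrative placeholders, not production values.
fn weighted_score(predictions: &HashMap<&str, f64>) -> f64 {
    let weights: HashMap<&str, f64> = HashMap::from([
        ("favorite", 1.0),
        ("reply", 2.0),
        ("repost", 1.5),
        ("share", 1.5),
        ("not_interested", -2.0),
        ("block_author", -4.0),
        ("report", -5.0),
    ]);
    weights
        .iter()
        .map(|(action, w)| w * predictions.get(action).copied().unwrap_or(0.0))
        .sum()
}
```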
---
### Filtering
Filters run at two stages:
**Pre-Scoring Filters:**
| Filter | Purpose |
|--------|---------|
| `DropDuplicatesFilter` | Remove duplicate post IDs |
| `CoreDataHydrationFilter` | Remove posts that failed to hydrate core metadata |
| `AgeFilter` | Remove posts older than threshold |
| `SelfpostFilter` | Remove user's own posts |
| `RepostDeduplicationFilter` | Dedupe reposts of same content |
| `IneligibleSubscriptionFilter` | Remove paywalled content user can't access |
| `PreviouslySeenPostsFilter` | Remove posts user has already seen |
| `PreviouslyServedPostsFilter` | Remove posts already served in session |
| `MutedKeywordFilter` | Remove posts with user's muted keywords |
| `AuthorSocialgraphFilter` | Remove posts from blocked/muted authors |
**Post-Selection Filters:**
| Filter | Purpose |
|--------|---------|
| `VFFilter` | Remove posts flagged by visibility filtering (deleted, spam, violent content, gore, etc.) |
| `DedupConversationFilter` | Deduplicate multiple branches of the same conversation thread |
---
## Key Design Decisions
### 1. No Hand-Engineered Features
The system relies entirely on the Grok-based transformer to learn relevance from user engagement sequences; there is no manual feature engineering for content relevance. This significantly reduces complexity in our data pipelines and serving infrastructure.
### 2. Candidate Isolation in Ranking
During transformer inference, candidates cannot attend to each other—only to the user context. This ensures the score for a post doesn't depend on which other posts are in the batch, making scores consistent and cacheable.
### 3. Hash-Based Embeddings
Both retrieval and ranking use multiple hash functions for embedding lookup.
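One common construction for hash-based embedding lookup, shown here as a sketch under assumptions rather than Phoenix's exact scheme, hashes an ID with several independent seeds into a fixed-size table and sums the selected rows, so arbitrary IDs receive embeddings without a per-ID vocabulary table:

```rust
/// Look up an embedding for an arbitrary ID by hashing it with several
/// independent seeds into a fixed-size table and summing the rows.
/// Illustrative only; the actual Phoenix scheme may differ.
fn hashed_embedding(id: u64, table: &[Vec<f32>], seeds: &[u64], dim: usize) -> Vec<f32> {
    let mut out = vec![0.0; dim];
    for &seed in seeds {
        // Simple multiplicative mix; a real system would use a proper hash.
        let h = (id ^ seed).wrapping_mul(0x9E37_79B9_7F4A_7C15);
        let row = &table[(h % table.len() as u64) as usize];
        for (o, v) in out.iter_mut().zip(row) {
            *o += v;
        }
    }
    out
}
```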
### 4. Multi-Action Prediction
Rather than predicting a single "relevance" score, the model predicts probabilities for many actions.
### 5. Composable Pipeline Architecture
The `candidate-pipeline` crate provides a flexible framework for building recommendation pipelines with:
- Separation of pipeline execution and monitoring from business logic
- Parallel execution of independent stages and graceful error handling
- Easy addition of new sources, hydrations, filters, and scorers
---
## License
This project is licensed under the Apache License 2.0. See [LICENSE](LICENSE) for details.

@@ -0,0 +1,329 @@
use crate::filter::Filter;
use crate::hydrator::Hydrator;
use crate::query_hydrator::QueryHydrator;
use crate::scorer::Scorer;
use crate::selector::Selector;
use crate::side_effect::{SideEffect, SideEffectInput};
use crate::source::Source;
use futures::future::join_all;
use log::{error, info, warn};
use std::sync::Arc;
use tonic::async_trait;
#[derive(Copy, Clone, Debug)]
pub enum PipelineStage {
QueryHydrator,
Source,
Hydrator,
PostSelectionHydrator,
Filter,
PostSelectionFilter,
Scorer,
}
pub struct PipelineResult<Q, C> {
pub retrieved_candidates: Vec<C>,
pub filtered_candidates: Vec<C>,
pub selected_candidates: Vec<C>,
pub query: Arc<Q>,
}
/// Provides a stable request identifier for logging/tracing.
pub trait HasRequestId {
fn request_id(&self) -> &str;
}
#[async_trait]
pub trait CandidatePipeline<Q, C>: Send + Sync
where
Q: HasRequestId + Clone + Send + Sync + 'static,
C: Clone + Send + Sync + 'static,
{
fn query_hydrators(&self) -> &[Box<dyn QueryHydrator<Q>>];
fn sources(&self) -> &[Box<dyn Source<Q, C>>];
fn hydrators(&self) -> &[Box<dyn Hydrator<Q, C>>];
fn filters(&self) -> &[Box<dyn Filter<Q, C>>];
fn scorers(&self) -> &[Box<dyn Scorer<Q, C>>];
fn selector(&self) -> &dyn Selector<Q, C>;
fn post_selection_hydrators(&self) -> &[Box<dyn Hydrator<Q, C>>];
fn post_selection_filters(&self) -> &[Box<dyn Filter<Q, C>>];
fn side_effects(&self) -> Arc<Vec<Box<dyn SideEffect<Q, C>>>>;
fn result_size(&self) -> usize;
async fn execute(&self, query: Q) -> PipelineResult<Q, C> {
let hydrated_query = self.hydrate_query(query).await;
let candidates = self.fetch_candidates(&hydrated_query).await;
let hydrated_candidates = self.hydrate(&hydrated_query, candidates).await;
let (kept_candidates, mut filtered_candidates) = self
.filter(&hydrated_query, hydrated_candidates.clone())
.await;
let scored_candidates = self.score(&hydrated_query, kept_candidates).await;
let selected_candidates = self.select(&hydrated_query, scored_candidates);
let post_selection_hydrated_candidates = self
.hydrate_post_selection(&hydrated_query, selected_candidates)
.await;
let (mut final_candidates, post_selection_filtered_candidates) = self
.filter_post_selection(&hydrated_query, post_selection_hydrated_candidates)
.await;
filtered_candidates.extend(post_selection_filtered_candidates);
final_candidates.truncate(self.result_size());
let arc_hydrated_query = Arc::new(hydrated_query);
let input = Arc::new(SideEffectInput {
query: arc_hydrated_query.clone(),
selected_candidates: final_candidates.clone(),
});
self.run_side_effects(input);
PipelineResult {
retrieved_candidates: hydrated_candidates,
filtered_candidates,
selected_candidates: final_candidates,
query: arc_hydrated_query,
}
}
/// Run all query hydrators in parallel and merge results into the query.
async fn hydrate_query(&self, query: Q) -> Q {
let request_id = query.request_id().to_string();
let hydrators: Vec<_> = self
.query_hydrators()
.iter()
.filter(|h| h.enable(&query))
.collect();
let hydrate_futures = hydrators.iter().map(|h| h.hydrate(&query));
let results = join_all(hydrate_futures).await;
let mut hydrated_query = query;
for (hydrator, result) in hydrators.iter().zip(results) {
match result {
Ok(hydrated) => {
hydrator.update(&mut hydrated_query, hydrated);
}
Err(err) => {
error!(
"request_id={} stage={:?} component={} failed: {}",
request_id,
PipelineStage::QueryHydrator,
hydrator.name(),
err
);
}
}
}
hydrated_query
}
/// Run all candidate sources in parallel and collect results.
async fn fetch_candidates(&self, query: &Q) -> Vec<C> {
let request_id = query.request_id().to_string();
let sources: Vec<_> = self.sources().iter().filter(|s| s.enable(query)).collect();
let source_futures = sources.iter().map(|s| s.get_candidates(query));
let results = join_all(source_futures).await;
let mut collected = Vec::new();
for (source, result) in sources.iter().zip(results) {
match result {
Ok(mut candidates) => {
info!(
"request_id={} stage={:?} component={} fetched {} candidates",
request_id,
PipelineStage::Source,
source.name(),
candidates.len()
);
collected.append(&mut candidates);
}
Err(err) => {
error!(
"request_id={} stage={:?} component={} failed: {}",
request_id,
PipelineStage::Source,
source.name(),
err
);
}
}
}
collected
}
/// Run all candidate hydrators in parallel and merge results into candidates.
async fn hydrate(&self, query: &Q, candidates: Vec<C>) -> Vec<C> {
self.run_hydrators(query, candidates, self.hydrators(), PipelineStage::Hydrator)
.await
}
/// Run post-selection candidate hydrators in parallel and merge results into candidates.
async fn hydrate_post_selection(&self, query: &Q, candidates: Vec<C>) -> Vec<C> {
self.run_hydrators(
query,
candidates,
self.post_selection_hydrators(),
PipelineStage::PostSelectionHydrator,
)
.await
}
/// Shared helper to hydrate with a provided hydrator list.
async fn run_hydrators(
&self,
query: &Q,
mut candidates: Vec<C>,
hydrators: &[Box<dyn Hydrator<Q, C>>],
stage: PipelineStage,
) -> Vec<C> {
let request_id = query.request_id().to_string();
let hydrators: Vec<_> = hydrators.iter().filter(|h| h.enable(query)).collect();
let expected_len = candidates.len();
let hydrate_futures = hydrators.iter().map(|h| h.hydrate(query, &candidates));
let results = join_all(hydrate_futures).await;
for (hydrator, result) in hydrators.iter().zip(results) {
match result {
Ok(hydrated) => {
if hydrated.len() == expected_len {
hydrator.update_all(&mut candidates, hydrated);
} else {
warn!(
"request_id={} stage={:?} component={} skipped: length_mismatch expected={} got={}",
request_id,
stage,
hydrator.name(),
expected_len,
hydrated.len()
);
}
}
Err(err) => {
error!(
"request_id={} stage={:?} component={} failed: {}",
request_id,
stage,
hydrator.name(),
err
);
}
}
}
candidates
}
/// Run all filters sequentially. Each filter partitions candidates into kept and removed.
async fn filter(&self, query: &Q, candidates: Vec<C>) -> (Vec<C>, Vec<C>) {
self.run_filters(query, candidates, self.filters(), PipelineStage::Filter)
.await
}
/// Run post-selection filters sequentially on the already-scored, selected candidates.
async fn filter_post_selection(&self, query: &Q, candidates: Vec<C>) -> (Vec<C>, Vec<C>) {
self.run_filters(
query,
candidates,
self.post_selection_filters(),
PipelineStage::PostSelectionFilter,
)
.await
}
/// Shared helper to run filters sequentially from a provided filter list.
async fn run_filters(
&self,
query: &Q,
mut candidates: Vec<C>,
filters: &[Box<dyn Filter<Q, C>>],
stage: PipelineStage,
) -> (Vec<C>, Vec<C>) {
let request_id = query.request_id().to_string();
let mut all_removed = Vec::new();
for filter in filters.iter().filter(|f| f.enable(query)) {
let backup = candidates.clone();
match filter.filter(query, candidates).await {
Ok(result) => {
candidates = result.kept;
all_removed.extend(result.removed);
}
Err(err) => {
error!(
"request_id={} stage={:?} component={} failed: {}",
request_id,
stage,
filter.name(),
err
);
candidates = backup;
}
}
}
info!(
"request_id={} stage={:?} kept {}, removed {}",
request_id,
stage,
candidates.len(),
all_removed.len()
);
(candidates, all_removed)
}
/// Run all scorers sequentially and apply their results to candidates.
async fn score(&self, query: &Q, mut candidates: Vec<C>) -> Vec<C> {
let request_id = query.request_id().to_string();
let expected_len = candidates.len();
for scorer in self.scorers().iter().filter(|s| s.enable(query)) {
match scorer.score(query, &candidates).await {
Ok(scored) => {
if scored.len() == expected_len {
scorer.update_all(&mut candidates, scored);
} else {
warn!(
"request_id={} stage={:?} component={} skipped: length_mismatch expected={} got={}",
request_id,
PipelineStage::Scorer,
scorer.name(),
expected_len,
scored.len()
);
}
}
Err(err) => {
error!(
"request_id={} stage={:?} component={} failed: {}",
request_id,
PipelineStage::Scorer,
scorer.name(),
err
);
}
}
}
candidates
}
/// Select (sort/truncate) candidates using the configured selector
fn select(&self, query: &Q, candidates: Vec<C>) -> Vec<C> {
if self.selector().enable(query) {
self.selector().select(query, candidates)
} else {
candidates
}
}
/// Run all side effects in parallel on a detached task, without blocking the response.
fn run_side_effects(&self, input: Arc<SideEffectInput<Q, C>>) {
let side_effects = self.side_effects();
tokio::spawn(async move {
let futures = side_effects
.iter()
.filter(|se| se.enable(input.query.clone()))
.map(|se| se.run(input.clone()));
let _ = join_all(futures).await;
});
}
}

@@ -0,0 +1,32 @@
use std::any::{Any, type_name_of_val};
use tonic::async_trait;
use crate::util;
pub struct FilterResult<C> {
pub kept: Vec<C>,
pub removed: Vec<C>,
}
/// Filters run sequentially and partition candidates into kept and removed sets
#[async_trait]
pub trait Filter<Q, C>: Any + Send + Sync
where
Q: Clone + Send + Sync + 'static,
C: Clone + Send + Sync + 'static,
{
/// Decide if this filter should run for the given query
fn enable(&self, _query: &Q) -> bool {
true
}
/// Filter candidates by evaluating each against some criteria.
/// Returns a FilterResult containing kept candidates (which continue to the next stage)
/// and removed candidates (which are excluded from further processing).
async fn filter(&self, query: &Q, candidates: Vec<C>) -> Result<FilterResult<C>, String>;
/// Returns a stable name for logging/metrics.
fn name(&self) -> &'static str {
util::short_type_name(type_name_of_val(self))
}
}

@@ -0,0 +1,39 @@
use crate::util;
use std::any::{Any, type_name_of_val};
use tonic::async_trait;
// Hydrators run in parallel and update candidate fields
#[async_trait]
pub trait Hydrator<Q, C>: Any + Send + Sync
where
Q: Clone + Send + Sync + 'static,
C: Clone + Send + Sync + 'static,
{
/// Decide if this hydrator should run for the given query
fn enable(&self, _query: &Q) -> bool {
true
}
/// Hydrate candidates by performing async operations.
/// Returns candidates with this hydrator's fields populated.
///
/// IMPORTANT: The returned vector must have the same candidates in the same order as the input.
/// Dropping candidates in a hydrator is not allowed - use a filter stage instead.
async fn hydrate(&self, query: &Q, candidates: &[C]) -> Result<Vec<C>, String>;
/// Update a single candidate with the hydrated fields.
/// Only the fields this hydrator is responsible for should be copied.
fn update(&self, candidate: &mut C, hydrated: C);
/// Update all candidates with the hydrated fields from `hydrated`.
/// Default implementation iterates and calls `update` for each pair.
fn update_all(&self, candidates: &mut [C], hydrated: Vec<C>) {
for (c, h) in candidates.iter_mut().zip(hydrated) {
self.update(c, h);
}
}
fn name(&self) -> &'static str {
util::short_type_name(type_name_of_val(self))
}
}

@@ -0,0 +1,9 @@
pub mod candidate_pipeline;
pub mod filter;
pub mod hydrator;
pub mod query_hydrator;
pub mod scorer;
pub mod selector;
pub mod side_effect;
pub mod source;
pub mod util;

@@ -0,0 +1,27 @@
use std::any::{Any, type_name_of_val};
use tonic::async_trait;
use crate::util;
#[async_trait]
pub trait QueryHydrator<Q>: Any + Send + Sync
where
Q: Clone + Send + Sync + 'static,
{
/// Decide if this query hydrator should run for the given query
fn enable(&self, _query: &Q) -> bool {
true
}
/// Hydrate the query by performing async operations.
/// Returns a new query with this hydrator's fields populated.
async fn hydrate(&self, query: &Q) -> Result<Q, String>;
/// Update the query with the hydrated fields.
/// Only the fields this hydrator is responsible for should be copied.
fn update(&self, query: &mut Q, hydrated: Q);
fn name(&self) -> &'static str {
util::short_type_name(type_name_of_val(self))
}
}

@@ -0,0 +1,39 @@
use crate::util;
use std::any::type_name_of_val;
use tonic::async_trait;
/// Scorers update candidate fields (like a score field) and run sequentially
#[async_trait]
pub trait Scorer<Q, C>: Send + Sync
where
Q: Clone + Send + Sync + 'static,
C: Clone + Send + Sync + 'static,
{
/// Decide if this scorer should run for the given query
fn enable(&self, _query: &Q) -> bool {
true
}
/// Score candidates by performing async operations.
/// Returns candidates with this scorer's fields populated.
///
/// IMPORTANT: The returned vector must have the same candidates in the same order as the input.
/// Dropping candidates in a scorer is not allowed - use a filter stage instead.
async fn score(&self, query: &Q, candidates: &[C]) -> Result<Vec<C>, String>;
/// Update a single candidate with the scored fields.
/// Only the fields this scorer is responsible for should be copied.
fn update(&self, candidate: &mut C, scored: C);
/// Update all candidates with the scored fields from `scored`.
/// Default implementation iterates and calls `update` for each pair.
fn update_all(&self, candidates: &mut [C], scored: Vec<C>) {
for (c, s) in candidates.iter_mut().zip(scored) {
self.update(c, s);
}
}
fn name(&self) -> &'static str {
util::short_type_name(type_name_of_val(self))
}
}

@@ -0,0 +1,45 @@
use crate::util;
use std::any::type_name_of_val;
pub trait Selector<Q, C>: Send + Sync
where
Q: Clone + Send + Sync + 'static,
C: Clone + Send + Sync + 'static,
{
/// Default selection: sort and truncate based on provided configs
fn select(&self, _query: &Q, candidates: Vec<C>) -> Vec<C> {
let mut sorted = self.sort(candidates);
if let Some(limit) = self.size() {
sorted.truncate(limit);
}
sorted
}
/// Decide if this selector should run for the given query
fn enable(&self, _query: &Q) -> bool {
true
}
/// Extract the score from a candidate to use for sorting.
fn score(&self, candidate: &C) -> f64;
/// Sort candidates by their scores in descending order.
fn sort(&self, candidates: Vec<C>) -> Vec<C> {
let mut sorted = candidates;
sorted.sort_by(|a, b| {
self.score(b)
.partial_cmp(&self.score(a))
.unwrap_or(std::cmp::Ordering::Equal)
});
sorted
}
/// Optionally provide a size to select. Defaults to no truncation if not overridden.
fn size(&self) -> Option<usize> {
None
}
fn name(&self) -> &'static str {
util::short_type_name(type_name_of_val(self))
}
}

@@ -0,0 +1,29 @@
use crate::util;
use std::any::type_name_of_val;
use std::sync::Arc;
use tonic::async_trait;
// A side effect is a fire-and-forget action that runs after selection and does not affect the returned pipeline result.
#[derive(Clone)]
pub struct SideEffectInput<Q, C> {
pub query: Arc<Q>,
pub selected_candidates: Vec<C>,
}
#[async_trait]
pub trait SideEffect<Q, C>: Send + Sync
where
Q: Clone + Send + Sync + 'static,
C: Clone + Send + Sync + 'static,
{
/// Decide if this side-effect should be run
fn enable(&self, _query: Arc<Q>) -> bool {
true
}
async fn run(&self, input: Arc<SideEffectInput<Q, C>>) -> Result<(), String>;
fn name(&self) -> &'static str {
util::short_type_name(type_name_of_val(self))
}
}

@@ -0,0 +1,22 @@
use std::any::{Any, type_name_of_val};
use tonic::async_trait;
use crate::util;
#[async_trait]
pub trait Source<Q, C>: Any + Send + Sync
where
Q: Clone + Send + Sync + 'static,
C: Clone + Send + Sync + 'static,
{
/// Decide if this source should run for the given query
fn enable(&self, _query: &Q) -> bool {
true
}
async fn get_candidates(&self, query: &Q) -> Result<Vec<C>, String>;
fn name(&self) -> &'static str {
util::short_type_name(type_name_of_val(self))
}
}

@@ -0,0 +1,58 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::clients::tweet_entity_service_client::TESClient;
use std::sync::Arc;
use tonic::async_trait;
use xai_candidate_pipeline::hydrator::Hydrator;
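/// Fetches core post data (text, author, reply/retweet relationships) from the Tweet Entity Service (TES) and copies it onto each candidate.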
pub struct CoreDataCandidateHydrator {
pub tes_client: Arc<dyn TESClient + Send + Sync>,
}
impl CoreDataCandidateHydrator {
pub async fn new(tes_client: Arc<dyn TESClient + Send + Sync>) -> Self {
Self { tes_client }
}
}
#[async_trait]
impl Hydrator<ScoredPostsQuery, PostCandidate> for CoreDataCandidateHydrator {
#[xai_stats_macro::receive_stats]
async fn hydrate(
&self,
_query: &ScoredPostsQuery,
candidates: &[PostCandidate],
) -> Result<Vec<PostCandidate>, String> {
let client = &self.tes_client;
let tweet_ids = candidates.iter().map(|c| c.tweet_id).collect::<Vec<_>>();
let post_features = client.get_tweet_core_datas(tweet_ids.clone()).await;
let post_features = post_features.map_err(|e| e.to_string())?;
let mut hydrated_candidates = Vec::with_capacity(candidates.len());
for tweet_id in tweet_ids {
let post_features = post_features.get(&tweet_id);
let core_data = post_features.and_then(|x| x.as_ref());
let text = core_data.map(|x| x.text.clone());
let hydrated = PostCandidate {
author_id: core_data.map(|x| x.author_id).unwrap_or_default(),
retweeted_user_id: core_data.and_then(|x| x.source_user_id),
retweeted_tweet_id: core_data.and_then(|x| x.source_tweet_id),
in_reply_to_tweet_id: core_data.and_then(|x| x.in_reply_to_tweet_id),
tweet_text: text.unwrap_or_default(),
..Default::default()
};
hydrated_candidates.push(hydrated);
}
Ok(hydrated_candidates)
}
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
candidate.retweeted_user_id = hydrated.retweeted_user_id;
candidate.retweeted_tweet_id = hydrated.retweeted_tweet_id;
candidate.in_reply_to_tweet_id = hydrated.in_reply_to_tweet_id;
candidate.tweet_text = hydrated.tweet_text;
}
}

@@ -0,0 +1,81 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::clients::gizmoduck_client::GizmoduckClient;
use std::sync::Arc;
use tonic::async_trait;
use xai_candidate_pipeline::hydrator::Hydrator;
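/// Enriches candidates with author data (screen name, follower count) from Gizmoduck, the user metadata service.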
pub struct GizmoduckCandidateHydrator {
pub gizmoduck_client: Arc<dyn GizmoduckClient + Send + Sync>,
}
impl GizmoduckCandidateHydrator {
pub async fn new(gizmoduck_client: Arc<dyn GizmoduckClient + Send + Sync>) -> Self {
Self { gizmoduck_client }
}
}
#[async_trait]
impl Hydrator<ScoredPostsQuery, PostCandidate> for GizmoduckCandidateHydrator {
#[xai_stats_macro::receive_stats]
async fn hydrate(
&self,
_query: &ScoredPostsQuery,
candidates: &[PostCandidate],
) -> Result<Vec<PostCandidate>, String> {
let client = &self.gizmoduck_client;
let author_ids: Vec<i64> = candidates.iter().map(|c| c.author_id as i64).collect();
let retweet_user_ids: Vec<i64> = candidates
.iter()
.filter_map(|c| c.retweeted_user_id)
.map(|x| x as i64)
.collect();
let mut user_ids_to_fetch = Vec::with_capacity(author_ids.len() + retweet_user_ids.len());
user_ids_to_fetch.extend(author_ids);
user_ids_to_fetch.extend(retweet_user_ids);
// Vec::dedup only removes adjacent duplicates, so sort before deduplicating.
user_ids_to_fetch.sort_unstable();
user_ids_to_fetch.dedup();
let users = client.get_users(user_ids_to_fetch).await;
let users = users.map_err(|e| e.to_string())?;
let mut hydrated_candidates = Vec::with_capacity(candidates.len());
for candidate in candidates {
let user = users
.get(&(candidate.author_id as i64))
.and_then(|user| user.as_ref());
let user_counts = user.and_then(|user| user.user.as_ref().map(|u| &u.counts));
let user_profile = user.and_then(|user| user.user.as_ref().map(|u| &u.profile));
let author_followers_count: Option<i32> =
user_counts.map(|x| x.followers_count).map(|x| x as i32);
let author_screen_name: Option<String> = user_profile.map(|x| x.screen_name.clone());
let retweet_user = candidate
.retweeted_user_id
.and_then(|retweeted_user_id| users.get(&(retweeted_user_id as i64)))
.and_then(|user| user.as_ref());
let retweet_profile =
retweet_user.and_then(|user| user.user.as_ref().map(|u| &u.profile));
let retweeted_screen_name: Option<String> =
retweet_profile.map(|x| x.screen_name.clone());
let hydrated = PostCandidate {
author_followers_count,
author_screen_name,
retweeted_screen_name,
..Default::default()
};
hydrated_candidates.push(hydrated);
}
Ok(hydrated_candidates)
}
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
candidate.author_followers_count = hydrated.author_followers_count;
candidate.author_screen_name = hydrated.author_screen_name;
candidate.retweeted_screen_name = hydrated.retweeted_screen_name;
}
}

@@ -0,0 +1,44 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use std::collections::HashSet;
use tonic::async_trait;
use xai_candidate_pipeline::hydrator::Hydrator;
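/// Marks each candidate as in-network when its author is the viewer or an account the viewer follows.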
pub struct InNetworkCandidateHydrator;
#[async_trait]
impl Hydrator<ScoredPostsQuery, PostCandidate> for InNetworkCandidateHydrator {
#[xai_stats_macro::receive_stats]
async fn hydrate(
&self,
query: &ScoredPostsQuery,
candidates: &[PostCandidate],
) -> Result<Vec<PostCandidate>, String> {
let viewer_id = query.user_id as u64;
let followed_ids: HashSet<u64> = query
.user_features
.followed_user_ids
.iter()
.copied()
.map(|id| id as u64)
.collect();
let hydrated_candidates = candidates
.iter()
.map(|candidate| {
let is_self = candidate.author_id == viewer_id;
let is_in_network = is_self || followed_ids.contains(&candidate.author_id);
PostCandidate {
in_network: Some(is_in_network),
..Default::default()
}
})
.collect();
Ok(hydrated_candidates)
}
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
candidate.in_network = hydrated.in_network;
}
}

@@ -0,0 +1,6 @@
pub mod core_data_candidate_hydrator;
pub mod gizmoduck_hydrator;
pub mod in_network_candidate_hydrator;
pub mod subscription_hydrator;
pub mod vf_candidate_hydrator;
pub mod video_duration_candidate_hydrator;

@@ -0,0 +1,50 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::clients::tweet_entity_service_client::TESClient;
use std::sync::Arc;
use tonic::async_trait;
use xai_candidate_pipeline::hydrator::Hydrator;
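/// Attaches the subscription author ID (if any) for each candidate from the Tweet Entity Service.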
pub struct SubscriptionHydrator {
pub tes_client: Arc<dyn TESClient + Send + Sync>,
}
impl SubscriptionHydrator {
pub async fn new(tes_client: Arc<dyn TESClient + Send + Sync>) -> Self {
Self { tes_client }
}
}
#[async_trait]
impl Hydrator<ScoredPostsQuery, PostCandidate> for SubscriptionHydrator {
#[xai_stats_macro::receive_stats]
async fn hydrate(
&self,
_query: &ScoredPostsQuery,
candidates: &[PostCandidate],
) -> Result<Vec<PostCandidate>, String> {
let client = &self.tes_client;
let tweet_ids = candidates.iter().map(|c| c.tweet_id).collect::<Vec<_>>();
let post_features = client.get_subscription_author_ids(tweet_ids.clone()).await;
let post_features = post_features.map_err(|e| e.to_string())?;
let mut hydrated_candidates = Vec::with_capacity(candidates.len());
for tweet_id in tweet_ids {
let post_features = post_features.get(&tweet_id);
let subscription_author_id = post_features.and_then(|x| *x);
let hydrated = PostCandidate {
subscription_author_id,
..Default::default()
};
hydrated_candidates.push(hydrated);
}
Ok(hydrated_candidates)
}
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
candidate.subscription_author_id = hydrated.subscription_author_id;
}
}

@@ -0,0 +1,101 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use futures::future::join;
use std::collections::HashMap;
use std::sync::Arc;
use tonic::async_trait;
use xai_candidate_pipeline::hydrator::Hydrator;
use xai_twittercontext_proto::GetTwitterContextViewer;
use xai_twittercontext_proto::TwitterContextViewer;
use xai_visibility_filtering::models::FilteredReason;
use xai_visibility_filtering::vf_client::SafetyLevel;
use xai_visibility_filtering::vf_client::SafetyLevel::{TimelineHome, TimelineHomeRecommendations};
use xai_visibility_filtering::vf_client::VisibilityFilteringClient;
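/// Fetches visibility-filtering (VF) verdicts for candidates, using the TimelineHome safety level for in-network posts and TimelineHomeRecommendations for out-of-network posts.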
pub struct VFCandidateHydrator {
pub vf_client: Arc<dyn VisibilityFilteringClient + Send + Sync>,
}
impl VFCandidateHydrator {
pub async fn new(vf_client: Arc<dyn VisibilityFilteringClient + Send + Sync>) -> Self {
Self { vf_client }
}
async fn fetch_vf_results(
client: &Arc<dyn VisibilityFilteringClient + Send + Sync>,
tweet_ids: Vec<i64>,
safety_level: SafetyLevel,
for_user_id: i64,
context: Option<TwitterContextViewer>,
) -> Result<HashMap<i64, Option<FilteredReason>>, String> {
if tweet_ids.is_empty() {
return Ok(HashMap::new());
}
client
.get_result(tweet_ids, safety_level, for_user_id, context)
.await
.map_err(|e| e.to_string())
}
}
#[async_trait]
impl Hydrator<ScoredPostsQuery, PostCandidate> for VFCandidateHydrator {
#[xai_stats_macro::receive_stats]
async fn hydrate(
&self,
query: &ScoredPostsQuery,
candidates: &[PostCandidate],
) -> Result<Vec<PostCandidate>, String> {
let context = query.get_viewer();
let user_id = query.user_id;
let client = &self.vf_client;
let mut in_network_ids = Vec::new();
let mut oon_ids = Vec::new();
for candidate in candidates.iter() {
if candidate.in_network.unwrap_or(false) {
in_network_ids.push(candidate.tweet_id);
} else {
oon_ids.push(candidate.tweet_id);
}
}
let in_network_future = Self::fetch_vf_results(
client,
in_network_ids,
TimelineHome,
user_id,
context.clone(),
);
let oon_future = Self::fetch_vf_results(
client,
oon_ids,
TimelineHomeRecommendations,
user_id,
context,
);
let (in_network_result, oon_result) = join(in_network_future, oon_future).await;
let mut result: HashMap<i64, Option<FilteredReason>> = HashMap::new();
result.extend(in_network_result?);
result.extend(oon_result?);
let mut hydrated_candidates = Vec::with_capacity(candidates.len());
for candidate in candidates {
let visibility_reason = result.get(&candidate.tweet_id);
let visibility_reason = visibility_reason.unwrap_or(&None);
let hydrated = PostCandidate {
visibility_reason: visibility_reason.clone(),
..Default::default()
};
hydrated_candidates.push(hydrated);
}
Ok(hydrated_candidates)
}
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
candidate.visibility_reason = hydrated.visibility_reason;
}
}

@@ -0,0 +1,62 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::candidate_features::MediaInfo;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::clients::tweet_entity_service_client::TESClient;
use std::sync::Arc;
use tonic::async_trait;
use xai_candidate_pipeline::hydrator::Hydrator;
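/// Extracts video duration (in milliseconds) from TES media entities for candidates that contain video.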
pub struct VideoDurationCandidateHydrator {
pub tes_client: Arc<dyn TESClient + Send + Sync>,
}
impl VideoDurationCandidateHydrator {
pub async fn new(tes_client: Arc<dyn TESClient + Send + Sync>) -> Self {
Self { tes_client }
}
}
#[async_trait]
impl Hydrator<ScoredPostsQuery, PostCandidate> for VideoDurationCandidateHydrator {
#[xai_stats_macro::receive_stats]
async fn hydrate(
&self,
_query: &ScoredPostsQuery,
candidates: &[PostCandidate],
) -> Result<Vec<PostCandidate>, String> {
let client = &self.tes_client;
let tweet_ids = candidates.iter().map(|c| c.tweet_id).collect::<Vec<_>>();
let post_features = client.get_tweet_media_entities(tweet_ids.clone()).await;
let post_features = post_features.map_err(|e| e.to_string())?;
let mut hydrated_candidates = Vec::with_capacity(candidates.len());
for tweet_id in tweet_ids {
let post_features = post_features.get(&tweet_id);
let media_entities = post_features.and_then(|x| x.as_ref());
let video_duration_ms = media_entities.and_then(|entities| {
entities.iter().find_map(|entity| {
if let Some(MediaInfo::VideoInfo(video_info)) = &entity.media_info {
Some(video_info.duration_millis)
} else {
None
}
})
});
let hydrated = PostCandidate {
video_duration_ms,
..Default::default()
};
hydrated_candidates.push(hydrated);
}
Ok(hydrated_candidates)
}
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
candidate.video_duration_ms = hydrated.video_duration_ms;
}
}

@@ -0,0 +1,70 @@
use std::collections::HashMap;
use xai_home_mixer_proto as pb;
use xai_visibility_filtering::models as vf;
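/// A candidate post flowing through the pipeline; hydrators and scorers populate fields incrementally.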
#[derive(Clone, Debug, Default)]
pub struct PostCandidate {
pub tweet_id: i64,
pub author_id: u64,
pub tweet_text: String,
pub in_reply_to_tweet_id: Option<u64>,
pub retweeted_tweet_id: Option<u64>,
pub retweeted_user_id: Option<u64>,
pub phoenix_scores: PhoenixScores,
pub prediction_request_id: Option<u64>,
pub last_scored_at_ms: Option<u64>,
pub weighted_score: Option<f64>,
pub score: Option<f64>,
pub served_type: Option<pb::ServedType>,
pub in_network: Option<bool>,
pub ancestors: Vec<u64>,
pub video_duration_ms: Option<i32>,
pub author_followers_count: Option<i32>,
pub author_screen_name: Option<String>,
pub retweeted_screen_name: Option<String>,
pub visibility_reason: Option<vf::FilteredReason>,
pub subscription_author_id: Option<u64>,
}
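/// Engagement probabilities predicted by the Phoenix model for a single candidate.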
#[derive(Clone, Debug, Default)]
pub struct PhoenixScores {
pub favorite_score: Option<f64>,
pub reply_score: Option<f64>,
pub retweet_score: Option<f64>,
pub photo_expand_score: Option<f64>,
pub click_score: Option<f64>,
pub profile_click_score: Option<f64>,
pub vqv_score: Option<f64>,
pub share_score: Option<f64>,
pub share_via_dm_score: Option<f64>,
pub share_via_copy_link_score: Option<f64>,
pub dwell_score: Option<f64>,
pub quote_score: Option<f64>,
pub quoted_click_score: Option<f64>,
pub follow_author_score: Option<f64>,
pub not_interested_score: Option<f64>,
pub block_author_score: Option<f64>,
pub mute_author_score: Option<f64>,
pub report_score: Option<f64>,
// Continuous actions
pub dwell_time: Option<f64>,
}
pub trait CandidateHelpers {
fn get_screen_names(&self) -> HashMap<u64, String>;
}
impl CandidateHelpers for PostCandidate {
fn get_screen_names(&self) -> HashMap<u64, String> {
let mut screen_names = HashMap::<u64, String>::new();
if let Some(author_screen_name) = self.author_screen_name.clone() {
screen_names.insert(self.author_id, author_screen_name);
}
if let (Some(retweeted_screen_name), Some(retweeted_user_id)) =
(self.retweeted_screen_name.clone(), self.retweeted_user_id)
{
screen_names.insert(retweeted_user_id, retweeted_screen_name);
}
screen_names
}
}

@@ -0,0 +1,78 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all = "camelCase")]
pub struct PureCoreData {
pub author_id: u64,
pub text: String,
pub source_tweet_id: Option<u64>,
pub source_user_id: Option<u64>,
pub in_reply_to_tweet_id: Option<u64>,
pub in_reply_to_user_id: Option<u64>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all = "camelCase")]
pub struct ExclusiveTweetControl {
pub conversation_author_id: i64,
}
pub type MediaEntities = Vec<MediaEntity>;
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all = "camelCase")]
pub struct MediaEntity {
pub media_info: Option<MediaInfo>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub enum MediaInfo {
VideoInfo(VideoInfo),
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all = "camelCase")]
pub struct VideoInfo {
pub duration_millis: i32,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all = "camelCase")]
pub struct Share {
pub source_tweet_id: u64,
pub source_user_id: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all = "camelCase")]
pub struct Reply {
pub in_reply_to_tweet_id: Option<u64>,
pub in_reply_to_user_id: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all = "camelCase")]
pub struct GizmoduckUserCounts {
pub followers_count: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all = "camelCase")]
pub struct GizmoduckUserProfile {
pub screen_name: String,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all = "camelCase")]
pub struct GizmoduckUser {
pub user_id: u64,
pub profile: GizmoduckUserProfile,
pub counts: GizmoduckUserCounts,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all = "camelCase")]
pub struct GizmoduckUserResult {
pub user: Option<GizmoduckUser>,
}

@@ -0,0 +1,5 @@
pub mod candidate;
pub mod candidate_features;
pub mod phoenix_candidate_pipeline;
pub mod query;
pub mod query_features;

@@ -0,0 +1,255 @@
use crate::candidate_hydrators::core_data_candidate_hydrator::CoreDataCandidateHydrator;
use crate::candidate_hydrators::gizmoduck_hydrator::GizmoduckCandidateHydrator;
use crate::candidate_hydrators::in_network_candidate_hydrator::InNetworkCandidateHydrator;
use crate::candidate_hydrators::subscription_hydrator::SubscriptionHydrator;
use crate::candidate_hydrators::vf_candidate_hydrator::VFCandidateHydrator;
use crate::candidate_hydrators::video_duration_candidate_hydrator::VideoDurationCandidateHydrator;
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::clients::gizmoduck_client::{GizmoduckClient, ProdGizmoduckClient};
use crate::clients::phoenix_prediction_client::{
PhoenixPredictionClient, ProdPhoenixPredictionClient,
};
use crate::clients::phoenix_retrieval_client::{
PhoenixRetrievalClient, ProdPhoenixRetrievalClient,
};
use crate::clients::s2s::{S2S_CHAIN_PATH, S2S_CRT_PATH, S2S_KEY_PATH};
use crate::clients::socialgraph_client::SocialGraphClient;
use crate::clients::strato_client::{ProdStratoClient, StratoClient};
use crate::clients::thunder_client::ThunderClient;
use crate::clients::tweet_entity_service_client::{ProdTESClient, TESClient};
use crate::clients::uas_fetcher::UserActionSequenceFetcher;
use crate::filters::age_filter::AgeFilter;
use crate::filters::author_socialgraph_filter::AuthorSocialgraphFilter;
use crate::filters::core_data_hydration_filter::CoreDataHydrationFilter;
use crate::filters::dedup_conversation_filter::DedupConversationFilter;
use crate::filters::drop_duplicates_filter::DropDuplicatesFilter;
use crate::filters::ineligible_subscription_filter::IneligibleSubscriptionFilter;
use crate::filters::muted_keyword_filter::MutedKeywordFilter;
use crate::filters::previously_seen_posts_filter::PreviouslySeenPostsFilter;
use crate::filters::previously_served_posts_filter::PreviouslyServedPostsFilter;
use crate::filters::retweet_deduplication_filter::RetweetDeduplicationFilter;
use crate::filters::self_tweet_filter::SelfTweetFilter;
use crate::filters::vf_filter::VFFilter;
use crate::params;
use crate::query_hydrators::user_action_seq_query_hydrator::UserActionSeqQueryHydrator;
use crate::query_hydrators::user_features_query_hydrator::UserFeaturesQueryHydrator;
use crate::scorers::author_diversity_scorer::AuthorDiversityScorer;
use crate::scorers::oon_scorer::OONScorer;
use crate::scorers::phoenix_scorer::PhoenixScorer;
use crate::scorers::weighted_scorer::WeightedScorer;
use crate::selectors::TopKScoreSelector;
use crate::side_effects::cache_request_info_side_effect::CacheRequestInfoSideEffect;
use crate::sources::phoenix_source::PhoenixSource;
use crate::sources::thunder_source::ThunderSource;
use std::sync::Arc;
use std::time::Duration;
use tonic::async_trait;
use xai_candidate_pipeline::candidate_pipeline::CandidatePipeline;
use xai_candidate_pipeline::filter::Filter;
use xai_candidate_pipeline::hydrator::Hydrator;
use xai_candidate_pipeline::query_hydrator::QueryHydrator;
use xai_candidate_pipeline::scorer::Scorer;
use xai_candidate_pipeline::selector::Selector;
use xai_candidate_pipeline::side_effect::SideEffect;
use xai_candidate_pipeline::source::Source;
use xai_visibility_filtering::vf_client::{
ProdVisibilityFilteringClient, VisibilityFilteringClient,
};
pub struct PhoenixCandidatePipeline {
query_hydrators: Vec<Box<dyn QueryHydrator<ScoredPostsQuery>>>,
sources: Vec<Box<dyn Source<ScoredPostsQuery, PostCandidate>>>,
hydrators: Vec<Box<dyn Hydrator<ScoredPostsQuery, PostCandidate>>>,
filters: Vec<Box<dyn Filter<ScoredPostsQuery, PostCandidate>>>,
scorers: Vec<Box<dyn Scorer<ScoredPostsQuery, PostCandidate>>>,
selector: TopKScoreSelector,
post_selection_hydrators: Vec<Box<dyn Hydrator<ScoredPostsQuery, PostCandidate>>>,
post_selection_filters: Vec<Box<dyn Filter<ScoredPostsQuery, PostCandidate>>>,
side_effects: Arc<Vec<Box<dyn SideEffect<ScoredPostsQuery, PostCandidate>>>>,
}
impl PhoenixCandidatePipeline {
async fn build_with_clients(
uas_fetcher: Arc<UserActionSequenceFetcher>,
phoenix_client: Arc<dyn PhoenixPredictionClient + Send + Sync>,
phoenix_retrieval_client: Arc<dyn PhoenixRetrievalClient + Send + Sync>,
thunder_client: Arc<ThunderClient>,
strato_client: Arc<dyn StratoClient + Send + Sync>,
tes_client: Arc<dyn TESClient + Send + Sync>,
gizmoduck_client: Arc<dyn GizmoduckClient + Send + Sync>,
vf_client: Arc<dyn VisibilityFilteringClient + Send + Sync>,
) -> PhoenixCandidatePipeline {
// Query Hydrators
let query_hydrators: Vec<Box<dyn QueryHydrator<ScoredPostsQuery>>> = vec![
Box::new(UserActionSeqQueryHydrator::new(uas_fetcher)),
Box::new(UserFeaturesQueryHydrator {
strato_client: strato_client.clone(),
}),
];
// Sources
let phoenix_source = Box::new(PhoenixSource {
phoenix_retrieval_client,
});
let thunder_source = Box::new(ThunderSource { thunder_client });
let sources: Vec<Box<dyn Source<ScoredPostsQuery, PostCandidate>>> =
vec![phoenix_source, thunder_source];
// Hydrators
let hydrators: Vec<Box<dyn Hydrator<ScoredPostsQuery, PostCandidate>>> = vec![
Box::new(InNetworkCandidateHydrator),
Box::new(CoreDataCandidateHydrator::new(tes_client.clone()).await),
Box::new(VideoDurationCandidateHydrator::new(tes_client.clone()).await),
Box::new(SubscriptionHydrator::new(tes_client.clone()).await),
Box::new(GizmoduckCandidateHydrator::new(gizmoduck_client).await),
];
// Filters
let filters: Vec<Box<dyn Filter<ScoredPostsQuery, PostCandidate>>> = vec![
Box::new(DropDuplicatesFilter),
Box::new(CoreDataHydrationFilter),
Box::new(AgeFilter::new(Duration::from_secs(params::MAX_POST_AGE))),
Box::new(SelfTweetFilter),
Box::new(RetweetDeduplicationFilter),
Box::new(IneligibleSubscriptionFilter),
Box::new(PreviouslySeenPostsFilter),
Box::new(PreviouslyServedPostsFilter),
Box::new(MutedKeywordFilter::new()),
Box::new(AuthorSocialgraphFilter),
];
// Scorers
let phoenix_scorer = Box::new(PhoenixScorer { phoenix_client });
let weighted_scorer = Box::new(WeightedScorer);
let author_diversity_scorer = Box::new(AuthorDiversityScorer::default());
let oon_scorer = Box::new(OONScorer);
let scorers: Vec<Box<dyn Scorer<ScoredPostsQuery, PostCandidate>>> = vec![
phoenix_scorer,
weighted_scorer,
author_diversity_scorer,
oon_scorer,
];
// Selector
let selector = TopKScoreSelector;
// Post-selection hydrators
let post_selection_hydrators: Vec<Box<dyn Hydrator<ScoredPostsQuery, PostCandidate>>> =
vec![Box::new(VFCandidateHydrator::new(vf_client.clone()).await)];
// Post-selection filters
let post_selection_filters: Vec<Box<dyn Filter<ScoredPostsQuery, PostCandidate>>> =
vec![Box::new(VFFilter), Box::new(DedupConversationFilter)];
// Side Effects
let side_effects: Arc<Vec<Box<dyn SideEffect<ScoredPostsQuery, PostCandidate>>>> =
Arc::new(vec![Box::new(CacheRequestInfoSideEffect { strato_client })]);
PhoenixCandidatePipeline {
query_hydrators,
hydrators,
filters,
sources,
scorers,
selector,
post_selection_hydrators,
post_selection_filters,
side_effects,
}
}
pub async fn prod() -> PhoenixCandidatePipeline {
let uas_fetcher =
Arc::new(UserActionSequenceFetcher::new().expect("Failed to create UAS fetcher"));
let _sgs_client = Arc::new(SocialGraphClient::new());
let phoenix_client = Arc::new(
ProdPhoenixPredictionClient::new()
.await
.expect("Failed to create Phoenix prediction client"),
);
let phoenix_retrieval_client = Arc::new(
ProdPhoenixRetrievalClient::new()
.await
.expect("Failed to create Phoenix retrieval client"),
);
let thunder_client = Arc::new(ThunderClient::new().await);
let strato_client = Arc::new(
ProdStratoClient::new()
.await
.expect("Failed to create Strato client"),
);
let tes_client = Arc::new(
ProdTESClient::new()
.await
.expect("Failed to create TES client"),
);
let gizmoduck_client = Arc::new(
ProdGizmoduckClient::new()
.await
.expect("Failed to create Gizmoduck client"),
);
let vf_client = Arc::new(
ProdVisibilityFilteringClient::new(
S2S_CHAIN_PATH.clone(),
S2S_CRT_PATH.clone(),
S2S_KEY_PATH.clone()
)
.await
.expect("Failed to create VF client"),
);
PhoenixCandidatePipeline::build_with_clients(
uas_fetcher,
phoenix_client,
phoenix_retrieval_client,
thunder_client,
strato_client,
tes_client,
gizmoduck_client,
vf_client,
)
.await
}
}
#[async_trait]
impl CandidatePipeline<ScoredPostsQuery, PostCandidate> for PhoenixCandidatePipeline {
fn query_hydrators(&self) -> &[Box<dyn QueryHydrator<ScoredPostsQuery>>] {
&self.query_hydrators
}
fn sources(&self) -> &[Box<dyn Source<ScoredPostsQuery, PostCandidate>>] {
&self.sources
}
fn hydrators(&self) -> &[Box<dyn Hydrator<ScoredPostsQuery, PostCandidate>>] {
&self.hydrators
}
fn filters(&self) -> &[Box<dyn Filter<ScoredPostsQuery, PostCandidate>>] {
&self.filters
}
fn scorers(&self) -> &[Box<dyn Scorer<ScoredPostsQuery, PostCandidate>>] {
&self.scorers
}
fn selector(&self) -> &dyn Selector<ScoredPostsQuery, PostCandidate> {
&self.selector
}
fn post_selection_hydrators(&self) -> &[Box<dyn Hydrator<ScoredPostsQuery, PostCandidate>>] {
&self.post_selection_hydrators
}
fn post_selection_filters(&self) -> &[Box<dyn Filter<ScoredPostsQuery, PostCandidate>>] {
&self.post_selection_filters
}
fn side_effects(&self) -> Arc<Vec<Box<dyn SideEffect<ScoredPostsQuery, PostCandidate>>>> {
Arc::clone(&self.side_effects)
}
fn result_size(&self) -> usize {
params::RESULT_SIZE
}
}

View File

@ -0,0 +1,69 @@
use crate::candidate_pipeline::query_features::UserFeatures;
use crate::util::request_util::generate_request_id;
use xai_candidate_pipeline::candidate_pipeline::HasRequestId;
use xai_home_mixer_proto::ImpressionBloomFilterEntry;
use xai_twittercontext_proto::{GetTwitterContextViewer, TwitterContextViewer};
#[derive(Clone, Default, Debug)]
pub struct ScoredPostsQuery {
pub user_id: i64,
pub client_app_id: i32,
pub country_code: String,
pub language_code: String,
pub seen_ids: Vec<i64>,
pub served_ids: Vec<i64>,
pub in_network_only: bool,
pub is_bottom_request: bool,
pub bloom_filter_entries: Vec<ImpressionBloomFilterEntry>,
pub user_action_sequence: Option<xai_recsys_proto::UserActionSequence>,
pub user_features: UserFeatures,
pub request_id: String,
}
impl ScoredPostsQuery {
pub fn new(
user_id: i64,
client_app_id: i32,
country_code: String,
language_code: String,
seen_ids: Vec<i64>,
served_ids: Vec<i64>,
in_network_only: bool,
is_bottom_request: bool,
bloom_filter_entries: Vec<ImpressionBloomFilterEntry>,
) -> Self {
let request_id = format!("{}-{}", generate_request_id(), user_id);
Self {
user_id,
client_app_id,
country_code,
language_code,
seen_ids,
served_ids,
in_network_only,
is_bottom_request,
bloom_filter_entries,
user_action_sequence: None,
user_features: UserFeatures::default(),
request_id,
}
}
}
impl GetTwitterContextViewer for ScoredPostsQuery {
fn get_viewer(&self) -> Option<TwitterContextViewer> {
Some(TwitterContextViewer {
user_id: self.user_id,
client_application_id: self.client_app_id as i64,
request_country_code: self.country_code.clone(),
request_language_code: self.language_code.clone(),
..Default::default()
})
}
}
impl HasRequestId for ScoredPostsQuery {
fn request_id(&self) -> &str {
&self.request_id
}
}

View File

@ -0,0 +1,11 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all = "camelCase")]
pub struct UserFeatures {
pub muted_keywords: Vec<String>,
pub blocked_user_ids: Vec<i64>,
pub muted_user_ids: Vec<i64>,
pub followed_user_ids: Vec<i64>,
pub subscribed_user_ids: Vec<i64>,
}

View File

@ -0,0 +1,38 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::util::snowflake;
use std::time::Duration;
use tonic::async_trait;
use xai_candidate_pipeline::filter::{Filter, FilterResult};
/// Filter that removes tweets older than a specified duration.
pub struct AgeFilter {
pub max_age: Duration,
}
impl AgeFilter {
pub fn new(max_age: Duration) -> Self {
Self { max_age }
}
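// Snowflake IDs encode their creation timestamp, so a post's age can be derived from the ID alone.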
fn is_within_age(&self, tweet_id: i64) -> bool {
snowflake::duration_since_creation_opt(tweet_id)
.map(|age| age <= self.max_age)
.unwrap_or(false)
}
}
#[async_trait]
impl Filter<ScoredPostsQuery, PostCandidate> for AgeFilter {
async fn filter(
&self,
_query: &ScoredPostsQuery,
candidates: Vec<PostCandidate>,
) -> Result<FilterResult<PostCandidate>, String> {
let (kept, removed): (Vec<_>, Vec<_>) = candidates
.into_iter()
.partition(|c| self.is_within_age(c.tweet_id));
Ok(FilterResult { kept, removed })
}
}

View File

@ -0,0 +1,42 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use tonic::async_trait;
use xai_candidate_pipeline::filter::{Filter, FilterResult};
/// Removes candidates whose author the viewer has blocked or muted.
pub struct AuthorSocialgraphFilter;
#[async_trait]
impl Filter<ScoredPostsQuery, PostCandidate> for AuthorSocialgraphFilter {
async fn filter(
&self,
query: &ScoredPostsQuery,
candidates: Vec<PostCandidate>,
) -> Result<FilterResult<PostCandidate>, String> {
let viewer_blocked_user_ids = query.user_features.blocked_user_ids.clone();
let viewer_muted_user_ids = query.user_features.muted_user_ids.clone();
if viewer_blocked_user_ids.is_empty() && viewer_muted_user_ids.is_empty() {
return Ok(FilterResult {
kept: candidates,
removed: Vec::new(),
});
}
let mut kept: Vec<PostCandidate> = Vec::new();
let mut removed: Vec<PostCandidate> = Vec::new();
for candidate in candidates {
let author_id = candidate.author_id as i64;
let muted = viewer_muted_user_ids.contains(&author_id);
let blocked = viewer_blocked_user_ids.contains(&author_id);
if muted || blocked {
removed.push(candidate);
} else {
kept.push(candidate);
}
}
Ok(FilterResult { kept, removed })
}
}

View File

@ -0,0 +1,20 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use tonic::async_trait;
use xai_candidate_pipeline::filter::{Filter, FilterResult};
pub struct CoreDataHydrationFilter;
#[async_trait]
impl Filter<ScoredPostsQuery, PostCandidate> for CoreDataHydrationFilter {
async fn filter(
&self,
_query: &ScoredPostsQuery,
candidates: Vec<PostCandidate>,
) -> Result<FilterResult<PostCandidate>, String> {
let (kept, removed) = candidates
.into_iter()
.partition(|c| c.author_id != 0 && !c.tweet_text.trim().is_empty());
Ok(FilterResult { kept, removed })
}
}

View File

@ -0,0 +1,51 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use std::collections::HashMap;
use tonic::async_trait;
use xai_candidate_pipeline::filter::{Filter, FilterResult};
/// Keeps only the highest-scored candidate per conversation tree.
pub struct DedupConversationFilter;
#[async_trait]
impl Filter<ScoredPostsQuery, PostCandidate> for DedupConversationFilter {
async fn filter(
&self,
_query: &ScoredPostsQuery,
candidates: Vec<PostCandidate>,
) -> Result<FilterResult<PostCandidate>, String> {
let mut kept: Vec<PostCandidate> = Vec::new();
let mut removed: Vec<PostCandidate> = Vec::new();
let mut best_per_convo: HashMap<u64, (usize, f64)> = HashMap::new();
for candidate in candidates {
let conversation_id = get_conversation_id(&candidate);
let score = candidate.score.unwrap_or(0.0);
if let Some((kept_idx, best_score)) = best_per_convo.get_mut(&conversation_id) {
if score > *best_score {
let previous = std::mem::replace(&mut kept[*kept_idx], candidate);
removed.push(previous);
*best_score = score;
} else {
removed.push(candidate);
}
} else {
let idx = kept.len();
best_per_convo.insert(conversation_id, (idx, score));
kept.push(candidate);
}
}
Ok(FilterResult { kept, removed })
}
}
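/// Snowflake IDs increase over time, so the smallest ancestor ID is the thread's
/// oldest post, i.e. the conversation root; a post with no ancestors keys on its own ID.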
fn get_conversation_id(candidate: &PostCandidate) -> u64 {
candidate
.ancestors
.iter()
.copied()
.min()
.unwrap_or(candidate.tweet_id as u64)
}

View File

@ -0,0 +1,30 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use std::collections::HashSet;
use tonic::async_trait;
use xai_candidate_pipeline::filter::{Filter, FilterResult};
pub struct DropDuplicatesFilter;
#[async_trait]
impl Filter<ScoredPostsQuery, PostCandidate> for DropDuplicatesFilter {
async fn filter(
&self,
_query: &ScoredPostsQuery,
candidates: Vec<PostCandidate>,
) -> Result<FilterResult<PostCandidate>, String> {
let mut seen_ids = HashSet::new();
let mut kept = Vec::new();
let mut removed = Vec::new();
for candidate in candidates {
if seen_ids.insert(candidate.tweet_id) {
kept.push(candidate);
} else {
removed.push(candidate);
}
}
Ok(FilterResult { kept, removed })
}
}

View File

@ -0,0 +1,34 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use std::collections::HashSet;
use tonic::async_trait;
use xai_candidate_pipeline::filter::{Filter, FilterResult};
/// Filters out subscription-only posts from authors the viewer is not subscribed to.
pub struct IneligibleSubscriptionFilter;
#[async_trait]
impl Filter<ScoredPostsQuery, PostCandidate> for IneligibleSubscriptionFilter {
async fn filter(
&self,
query: &ScoredPostsQuery,
candidates: Vec<PostCandidate>,
) -> Result<FilterResult<PostCandidate>, String> {
let subscribed_user_ids: HashSet<u64> = query
.user_features
.subscribed_user_ids
.iter()
.map(|id| *id as u64)
.collect();
let (kept, removed): (Vec<_>, Vec<_>) =
candidates
.into_iter()
.partition(|candidate| match candidate.subscription_author_id {
Some(author_id) => subscribed_user_ids.contains(&author_id),
None => true,
});
Ok(FilterResult { kept, removed })
}
}

13
home-mixer/filters/mod.rs Normal file
View File

@ -0,0 +1,13 @@
pub mod age_filter;
pub mod author_socialgraph_filter;
pub mod core_data_hydration_filter;
pub mod dedup_conversation_filter;
pub mod drop_duplicates_filter;
pub mod ineligible_subscription_filter;
pub mod muted_keyword_filter;
pub mod previously_seen_posts_filter;
pub mod previously_served_posts_filter;
pub mod retweet_deduplication_filter;
pub mod self_tweet_filter;
pub mod vf_filter;

View File

@ -0,0 +1,59 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use std::sync::Arc;
use tonic::async_trait;
use xai_candidate_pipeline::filter::{Filter, FilterResult};
use xai_post_text::{MatchTweetGroup, TokenSequence, TweetTokenizer, UserMutes};
pub struct MutedKeywordFilter {
pub tokenizer: Arc<TweetTokenizer>,
}
impl MutedKeywordFilter {
pub fn new() -> Self {
let tokenizer = TweetTokenizer::new();
Self {
tokenizer: Arc::new(tokenizer),
}
}
}
#[async_trait]
impl Filter<ScoredPostsQuery, PostCandidate> for MutedKeywordFilter {
#[xai_stats_macro::receive_stats]
async fn filter(
&self,
query: &ScoredPostsQuery,
candidates: Vec<PostCandidate>,
) -> Result<FilterResult<PostCandidate>, String> {
let muted_keywords = query.user_features.muted_keywords.clone();
if muted_keywords.is_empty() {
return Ok(FilterResult {
kept: candidates,
removed: vec![],
});
}
let token_sequences: Vec<TokenSequence> = muted_keywords
.iter()
.map(|k| self.tokenizer.tokenize(k))
.collect();
let user_mutes = UserMutes::new(token_sequences);
let matcher = MatchTweetGroup::new(user_mutes);
let mut kept = Vec::new();
let mut removed = Vec::new();
for candidate in candidates {
let tweet_text_token_sequence = self.tokenizer.tokenize(&candidate.tweet_text);
if matcher.matches(&tweet_text_token_sequence) {
// Matches muted keywords - should be removed/filtered out
removed.push(candidate);
} else {
// Does not match muted keywords - keep it
kept.push(candidate);
}
}
Ok(FilterResult { kept, removed })
}
}

View File

@ -0,0 +1,36 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::util::bloom_filter::BloomFilter;
use crate::util::candidates_util::get_related_post_ids;
use tonic::async_trait;
use xai_candidate_pipeline::filter::{Filter, FilterResult};
/// Filter out previously seen posts using a Bloom Filter and
/// the seen IDs sent in the request directly from the client
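/// Bloom filter lookups can produce false positives but never false negatives,
/// so this may occasionally over-filter an unseen post but never misses a recorded one.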
pub struct PreviouslySeenPostsFilter;
#[async_trait]
impl Filter<ScoredPostsQuery, PostCandidate> for PreviouslySeenPostsFilter {
async fn filter(
&self,
query: &ScoredPostsQuery,
candidates: Vec<PostCandidate>,
) -> Result<FilterResult<PostCandidate>, String> {
let bloom_filters = query
.bloom_filter_entries
.iter()
.map(BloomFilter::from_entry)
.collect::<Vec<_>>();
let (removed, kept): (Vec<_>, Vec<_>) = candidates.into_iter().partition(|c| {
get_related_post_ids(c).iter().any(|&post_id| {
query.seen_ids.contains(&post_id)
|| bloom_filters
.iter()
.any(|filter| filter.may_contain(post_id))
})
});
Ok(FilterResult { kept, removed })
}
}

View File

@ -0,0 +1,28 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::util::candidates_util::get_related_post_ids;
use tonic::async_trait;
use xai_candidate_pipeline::filter::{Filter, FilterResult};
pub struct PreviouslyServedPostsFilter;
#[async_trait]
impl Filter<ScoredPostsQuery, PostCandidate> for PreviouslyServedPostsFilter {
fn enable(&self, query: &ScoredPostsQuery) -> bool {
query.is_bottom_request
}
async fn filter(
&self,
query: &ScoredPostsQuery,
candidates: Vec<PostCandidate>,
) -> Result<FilterResult<PostCandidate>, String> {
let (removed, kept): (Vec<_>, Vec<_>) = candidates.into_iter().partition(|c| {
get_related_post_ids(c)
.iter()
.any(|id| query.served_ids.contains(id))
});
Ok(FilterResult { kept, removed })
}
}

View File

@ -0,0 +1,42 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use std::collections::HashSet;
use tonic::async_trait;
use xai_candidate_pipeline::filter::{Filter, FilterResult};
/// Deduplicates retweets, keeping only the first occurrence of a tweet
/// (whether as an original or as a retweet).
pub struct RetweetDeduplicationFilter;
#[async_trait]
impl Filter<ScoredPostsQuery, PostCandidate> for RetweetDeduplicationFilter {
async fn filter(
&self,
_query: &ScoredPostsQuery,
candidates: Vec<PostCandidate>,
) -> Result<FilterResult<PostCandidate>, String> {
let mut seen_tweet_ids: HashSet<u64> = HashSet::new();
let mut kept = Vec::new();
let mut removed = Vec::new();
for candidate in candidates {
match candidate.retweeted_tweet_id {
Some(retweeted_id) => {
// Keep only the first sighting of the retweeted tweet; drop it if it was
// already seen (as an original or as another retweet)
if seen_tweet_ids.insert(retweeted_id) {
kept.push(candidate);
} else {
removed.push(candidate);
}
}
None => {
// Mark this original tweet ID as seen so retweets of it get filtered
seen_tweet_ids.insert(candidate.tweet_id as u64);
kept.push(candidate);
}
}
}
Ok(FilterResult { kept, removed })
}
}

View File

@ -0,0 +1,23 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use tonic::async_trait;
use xai_candidate_pipeline::filter::{Filter, FilterResult};
/// Filter that removes tweets where the author is the viewer.
pub struct SelfTweetFilter;
#[async_trait]
impl Filter<ScoredPostsQuery, PostCandidate> for SelfTweetFilter {
async fn filter(
&self,
query: &ScoredPostsQuery,
candidates: Vec<PostCandidate>,
) -> Result<FilterResult<PostCandidate>, String> {
let viewer_id = query.user_id as u64;
let (kept, removed): (Vec<_>, Vec<_>) = candidates
.into_iter()
.partition(|c| c.author_id != viewer_id);
Ok(FilterResult { kept, removed })
}
}

View File

@ -0,0 +1,33 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use tonic::async_trait;
use xai_candidate_pipeline::filter::{Filter, FilterResult};
use xai_visibility_filtering::models::{Action, FilteredReason};
pub struct VFFilter;
#[async_trait]
impl Filter<ScoredPostsQuery, PostCandidate> for VFFilter {
#[xai_stats_macro::receive_stats]
async fn filter(
&self,
_query: &ScoredPostsQuery,
candidates: Vec<PostCandidate>,
) -> Result<FilterResult<PostCandidate>, String> {
let (removed, kept): (Vec<_>, Vec<_>) = candidates
.into_iter()
.partition(|c| should_drop(&c.visibility_reason));
Ok(FilterResult { kept, removed })
}
}
fn should_drop(reason: &Option<FilteredReason>) -> bool {
match reason {
Some(FilteredReason::SafetyResult(safety_result)) => {
matches!(safety_result.action, Action::Drop(_))
}
Some(_) => true,
None => false,
}
}

14
home-mixer/lib.rs Normal file
View File

@ -0,0 +1,14 @@
mod candidate_hydrators;
mod candidate_pipeline;
pub mod clients; // Excluded from open source release for security reasons
mod filters;
pub mod params; // Excluded from open source release for security reasons
mod query_hydrators;
pub mod scorers;
mod selectors;
mod server;
mod side_effects;
mod sources;
pub mod util; // Excluded from open source release for security reasons
pub use server::HomeMixerServer;

78
home-mixer/main.rs Normal file
View File

@ -0,0 +1,78 @@
use clap::Parser;
use log::info;
use std::time::Duration;
use tonic::codec::CompressionEncoding;
use tonic::service::RoutesBuilder;
use tonic_reflection::server::Builder;
use xai_home_mixer_proto as pb;
use xai_http_server::{CancellationToken, GrpcConfig, HttpServer};
use xai_home_mixer::HomeMixerServer;
use xai_home_mixer::params;
#[derive(Parser, Debug)]
#[command(about = "HomeMixer gRPC Server")]
struct Args {
#[arg(long)]
grpc_port: u16,
#[arg(long)]
metrics_port: u16,
#[arg(long)]
reload_interval_minutes: u64,
#[arg(long)]
chunk_size: usize,
}
#[xai_stats_macro::main(name = "home-mixer")]
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let args = Args::parse();
xai_init_utils::init().log();
xai_init_utils::init().rustls();
info!(
"Starting server with gRPC port: {}, metrics port: {}, reload interval: {} minutes, chunk size: {}",
args.grpc_port, args.metrics_port, args.reload_interval_minutes, args.chunk_size,
);
// Create the service implementation
let service = HomeMixerServer::new().await;
// Register gRPC reflection so clients can discover the proto API
let reflection_service = Builder::configure()
.register_encoded_file_descriptor_set(pb::FILE_DESCRIPTOR_SET)
.build_v1()?;
let mut grpc_routes = RoutesBuilder::default();
grpc_routes.add_service(
pb::scored_posts_service_server::ScoredPostsServiceServer::new(service)
.max_decoding_message_size(params::MAX_GRPC_MESSAGE_SIZE)
.max_encoding_message_size(params::MAX_GRPC_MESSAGE_SIZE)
.accept_compressed(CompressionEncoding::Gzip)
.accept_compressed(CompressionEncoding::Zstd)
.send_compressed(CompressionEncoding::Gzip)
.send_compressed(CompressionEncoding::Zstd),
);
grpc_routes.add_service(reflection_service);
let grpc_config = GrpcConfig::new(args.grpc_port, grpc_routes.routes());
let http_router = axum::Router::default();
let mut server = HttpServer::new(
args.metrics_port,
http_router,
Some(grpc_config),
CancellationToken::new(),
Duration::from_secs(20),
)
.await?;
server.set_readiness(true);
info!("Server ready");
server.wait_for_termination().await;
info!("Server shutdown complete");
Ok(())
}

View File

@ -0,0 +1,2 @@
pub mod user_action_seq_query_hydrator;
pub mod user_features_query_hydrator;

View File

@ -0,0 +1,188 @@
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::clients::uas_fetcher::{UserActionSequenceFetcher, UserActionSequenceOps};
use crate::params as p;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use tonic::async_trait;
use xai_candidate_pipeline::query_hydrator::QueryHydrator;
use xai_recsys_aggregation::aggregation::{DefaultAggregator, UserActionAggregator};
use xai_recsys_aggregation::filters::{
AggregatedActionFilter, DenseAggregatedActionFilter, KeepOriginalUserActionFilter,
UserActionFilter,
};
use xai_recsys_proto::{
AggregatedUserActionList, Mask, MaskType, UserActionSequence, UserActionSequenceDataContainer,
UserActionSequenceMeta, user_action_sequence_data_container::Data as ProtoDataContainer,
};
use xai_uas_thrift::convert::thrift_to_proto_aggregated_user_action;
use xai_uas_thrift::user_action_sequence::{
AggregatedUserAction as ThriftAggregatedUserAction,
UserActionSequence as ThriftUserActionSequence,
UserActionSequenceMeta as ThriftUserActionSequenceMeta,
};
/// Hydrate a sequence that captures the user's recent actions
pub struct UserActionSeqQueryHydrator {
pub uas_fetcher: Arc<UserActionSequenceFetcher>,
global_filter: Arc<dyn UserActionFilter>,
aggregator: Arc<dyn UserActionAggregator>,
post_filters: Vec<Arc<dyn AggregatedActionFilter>>,
}
impl UserActionSeqQueryHydrator {
pub fn new(uas_fetcher: Arc<UserActionSequenceFetcher>) -> Self {
Self {
uas_fetcher,
global_filter: Arc::new(KeepOriginalUserActionFilter::new()),
aggregator: Arc::new(DefaultAggregator),
post_filters: vec![Arc::new(DenseAggregatedActionFilter::new())],
}
}
}
#[async_trait]
impl QueryHydrator<ScoredPostsQuery> for UserActionSeqQueryHydrator {
#[xai_stats_macro::receive_stats]
async fn hydrate(&self, query: &ScoredPostsQuery) -> Result<ScoredPostsQuery, String> {
let uas_thrift = self
.uas_fetcher
.get_by_user_id(query.user_id)
.await
.map_err(|e| format!("Failed to fetch user action sequence: {}", e))?;
let aggregated_uas_proto =
self.aggregate_user_action_sequence(query.user_id, uas_thrift)?;
Ok(ScoredPostsQuery {
user_action_sequence: Some(aggregated_uas_proto),
..Default::default()
})
}
fn update(&self, query: &mut ScoredPostsQuery, hydrated: ScoredPostsQuery) {
query.user_action_sequence = hydrated.user_action_sequence;
}
fn name(&self) -> &'static str {
std::any::type_name::<Self>()
}
}
impl UserActionSeqQueryHydrator {
fn aggregate_user_action_sequence(
&self,
user_id: i64,
uas_thrift: ThriftUserActionSequence,
) -> Result<UserActionSequence, String> {
// Extract user_actions from thrift sequence
let thrift_user_actions = uas_thrift.user_actions.clone().unwrap_or_default();
if thrift_user_actions.is_empty() {
return Err(format!("No user actions found for user {}", user_id));
}
// Pre-aggregation filter
let filtered_actions = self.global_filter.run(thrift_user_actions);
if filtered_actions.is_empty() {
return Err(format!(
"No user actions remaining after filtering for user {}",
user_id
));
}
// Aggregate
let mut aggregated_actions =
self.aggregator
.run(&filtered_actions, p::UAS_WINDOW_TIME_MS, 0);
// Post-aggregation filters
for filter in &self.post_filters {
aggregated_actions = filter.run(aggregated_actions);
}
// Truncate to max sequence length (keep last N items)
if aggregated_actions.len() > p::UAS_MAX_SEQUENCE_LENGTH {
let drain_count = aggregated_actions.len() - p::UAS_MAX_SEQUENCE_LENGTH;
aggregated_actions.drain(0..drain_count);
}
// Convert to proto format
let original_metadata = uas_thrift.metadata.clone().unwrap_or_default();
convert_to_proto_sequence(
user_id,
original_metadata,
aggregated_actions,
self.aggregator.name(),
)
}
}
fn convert_to_proto_sequence(
user_id: i64,
original_metadata: ThriftUserActionSequenceMeta,
aggregated_actions: Vec<ThriftAggregatedUserAction>,
aggregator_name: &str,
) -> Result<UserActionSequence, String> {
if aggregated_actions.is_empty() {
return Err("Cannot create sequence from empty aggregated actions".to_string());
}
let first_sequence_time = aggregated_actions
.first()
.and_then(|a| a.impressed_time_ms)
.unwrap_or(0) as u64;
let last_sequence_time = aggregated_actions
.last()
.and_then(|a| a.impressed_time_ms)
.unwrap_or(0) as u64;
// Preserve lastModifiedEpochMs and lastKafkaPublishEpochMs from original metadata
let last_modified_epoch_ms = original_metadata.last_modified_epoch_ms.unwrap_or(0) as u64;
let previous_kafka_publish_epoch_ms =
original_metadata.last_kafka_publish_epoch_ms.unwrap_or(0) as u64;
let proto_metadata = UserActionSequenceMeta {
length: aggregated_actions.len() as u64,
first_sequence_time,
last_sequence_time,
last_modified_epoch_ms,
previous_kafka_publish_epoch_ms,
};
// Convert thrift aggregated actions to proto
let mut proto_agg_actions = Vec::with_capacity(aggregated_actions.len());
for action in aggregated_actions {
proto_agg_actions.push(
thrift_to_proto_aggregated_user_action(action)
.map_err(|e| format!("Failed to convert aggregated action: {}", e))?,
);
}
let aggregation_time_ms = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_millis() as u64;
let agg_list = AggregatedUserActionList {
aggregated_user_actions: proto_agg_actions,
aggregation_provider: aggregator_name.to_string(),
aggregation_time_ms,
};
let mask = Mask {
mask_type: MaskType::NewEvent as i32,
mask: vec![false; agg_list.aggregated_user_actions.len()],
};
// Build the final UserActionSequence
Ok(UserActionSequence {
user_id: user_id as u64,
metadata: Some(proto_metadata),
user_actions_data: Some(UserActionSequenceDataContainer {
data: Some(ProtoDataContainer::OrderedAggregatedUserActionsList(
agg_list,
)),
}),
masks: vec![mask],
..Default::default()
})
}

View File

@ -0,0 +1,41 @@
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::candidate_pipeline::query_features::UserFeatures;
use crate::clients::strato_client::StratoClient;
use std::sync::Arc;
use tonic::async_trait;
use xai_candidate_pipeline::query_hydrator::QueryHydrator;
use xai_strato::{StratoResult, StratoValue, decode};
pub struct UserFeaturesQueryHydrator {
pub strato_client: Arc<dyn StratoClient + Send + Sync>,
}
#[async_trait]
impl QueryHydrator<ScoredPostsQuery> for UserFeaturesQueryHydrator {
#[xai_stats_macro::receive_stats]
async fn hydrate(&self, query: &ScoredPostsQuery) -> Result<ScoredPostsQuery, String> {
let user_id = query.user_id;
let client = &self.strato_client;
let result = client
.get_user_features(user_id)
.await
.map_err(|e| e.to_string())?;
let decoded: StratoResult<StratoValue<UserFeatures>> = decode(&result);
match decoded {
StratoResult::Ok(v) => {
let user_features = v.v.unwrap_or_default();
Ok(ScoredPostsQuery {
user_features,
..Default::default()
})
}
StratoResult::Err(_) => Err("Error received from strato".to_string()),
}
}
fn update(&self, query: &mut ScoredPostsQuery, hydrated: ScoredPostsQuery) {
query.user_features = hydrated.user_features;
}
fn name(&self) -> &'static str {
std::any::type_name::<Self>()
}
}

View File

@ -0,0 +1,73 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::params as p;
use std::cmp::Ordering;
use std::collections::HashMap;
use tonic::async_trait;
use xai_candidate_pipeline::scorer::Scorer;
/// Diversify authors served within a single feed response
pub struct AuthorDiversityScorer {
decay_factor: f64,
floor: f64,
}
impl Default for AuthorDiversityScorer {
fn default() -> Self {
Self::new(p::AUTHOR_DIVERSITY_DECAY, p::AUTHOR_DIVERSITY_FLOOR)
}
}
impl AuthorDiversityScorer {
pub fn new(decay_factor: f64, floor: f64) -> Self {
Self {
decay_factor,
floor,
}
}
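// Geometric decay: an author's first post keeps its full score (multiplier 1.0);
// each subsequent post from the same author decays toward `floor`.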
fn multiplier(&self, position: usize) -> f64 {
(1.0 - self.floor) * self.decay_factor.powf(position as f64) + self.floor
}
}
#[async_trait]
impl Scorer<ScoredPostsQuery, PostCandidate> for AuthorDiversityScorer {
#[xai_stats_macro::receive_stats]
async fn score(
&self,
_query: &ScoredPostsQuery,
candidates: &[PostCandidate],
) -> Result<Vec<PostCandidate>, String> {
let mut author_counts: HashMap<u64, usize> = HashMap::new();
let mut scored = vec![PostCandidate::default(); candidates.len()];
let mut ordered: Vec<(usize, &PostCandidate)> = candidates.iter().enumerate().collect();
ordered.sort_by(|(_, a), (_, b)| {
let a_score = a.weighted_score.unwrap_or(f64::NEG_INFINITY);
let b_score = b.weighted_score.unwrap_or(f64::NEG_INFINITY);
b_score.partial_cmp(&a_score).unwrap_or(Ordering::Equal)
});
for (original_idx, candidate) in ordered {
let entry = author_counts.entry(candidate.author_id).or_insert(0);
let position = *entry;
*entry += 1;
let multiplier = self.multiplier(position);
let adjusted_score = candidate.weighted_score.map(|score| score * multiplier);
let updated = PostCandidate {
score: adjusted_score,
..Default::default()
};
scored[original_idx] = updated;
}
Ok(scored)
}
fn update(&self, candidate: &mut PostCandidate, scored: PostCandidate) {
candidate.score = scored.score;
}
}

View File

@ -0,0 +1,4 @@
pub mod author_diversity_scorer;
pub mod oon_scorer;
pub mod phoenix_scorer;
pub mod weighted_scorer;

View File

@ -0,0 +1,38 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::params as p;
use tonic::async_trait;
use xai_candidate_pipeline::scorer::Scorer;
/// Prioritize in-network candidates over out-of-network candidates.
pub struct OONScorer;
#[async_trait]
impl Scorer<ScoredPostsQuery, PostCandidate> for OONScorer {
async fn score(
&self,
_query: &ScoredPostsQuery,
candidates: &[PostCandidate],
) -> Result<Vec<PostCandidate>, String> {
let scored = candidates
.iter()
.map(|c| {
let updated_score = c.score.map(|base_score| match c.in_network {
Some(false) => base_score * p::OON_WEIGHT_FACTOR,
_ => base_score,
});
PostCandidate {
score: updated_score,
..Default::default()
}
})
.collect();
Ok(scored)
}
fn update(&self, candidate: &mut PostCandidate, scored: PostCandidate) {
candidate.score = scored.score;
}
}

View File

@ -0,0 +1,176 @@
use crate::candidate_pipeline::candidate::{PhoenixScores, PostCandidate};
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::clients::phoenix_prediction_client::PhoenixPredictionClient;
use crate::util::request_util;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use tonic::async_trait;
use xai_candidate_pipeline::scorer::Scorer;
use xai_recsys_proto::{ActionName, ContinuousActionName};
pub struct PhoenixScorer {
pub phoenix_client: Arc<dyn PhoenixPredictionClient + Send + Sync>,
}
#[async_trait]
impl Scorer<ScoredPostsQuery, PostCandidate> for PhoenixScorer {
#[xai_stats_macro::receive_stats]
async fn score(
&self,
query: &ScoredPostsQuery,
candidates: &[PostCandidate],
) -> Result<Vec<PostCandidate>, String> {
let user_id = query.user_id as u64;
let prediction_request_id = request_util::generate_request_id();
let last_scored_at_ms = Self::current_timestamp_millis();
if let Some(sequence) = &query.user_action_sequence {
let tweet_infos: Vec<xai_recsys_proto::TweetInfo> = candidates
.iter()
.map(|c| {
let tweet_id = c.retweeted_tweet_id.unwrap_or(c.tweet_id as u64);
let author_id = c.retweeted_user_id.unwrap_or(c.author_id);
xai_recsys_proto::TweetInfo {
tweet_id,
author_id,
..Default::default()
}
})
.collect();
let result = self
.phoenix_client
.predict(user_id, sequence.clone(), tweet_infos)
.await;
if let Ok(response) = result {
let predictions_map = self.build_predictions_map(&response);
let scored_candidates = candidates
.iter()
.map(|c| {
// For retweets, look up predictions using the original tweet id
let lookup_tweet_id = c.retweeted_tweet_id.unwrap_or(c.tweet_id as u64);
let phoenix_scores = predictions_map
.get(&lookup_tweet_id)
.map(|preds| self.extract_phoenix_scores(preds))
.unwrap_or_default();
PostCandidate {
phoenix_scores,
prediction_request_id: Some(prediction_request_id),
last_scored_at_ms,
..Default::default()
}
})
.collect();
return Ok(scored_candidates);
}
}
// Return candidates unchanged if no scoring could be done
Ok(candidates.to_vec())
}
fn update(&self, candidate: &mut PostCandidate, scored: PostCandidate) {
candidate.phoenix_scores = scored.phoenix_scores;
candidate.prediction_request_id = scored.prediction_request_id;
candidate.last_scored_at_ms = scored.last_scored_at_ms;
}
}
impl PhoenixScorer {
/// Builds Map[tweet_id -> ActionPredictions]
fn build_predictions_map(
&self,
response: &xai_recsys_proto::PredictNextActionsResponse,
) -> HashMap<u64, ActionPredictions> {
let mut predictions_map = HashMap::new();
let Some(distribution_set) = response.distribution_sets.first() else {
return predictions_map;
};
for distribution in &distribution_set.candidate_distributions {
let Some(candidate) = &distribution.candidate else {
continue;
};
let tweet_id = candidate.tweet_id;
let action_probs: HashMap<usize, f64> = distribution
.top_log_probs
.iter()
.enumerate()
.map(|(idx, log_prob)| (idx, (*log_prob as f64).exp()))
.collect();
let continuous_values: HashMap<usize, f64> = distribution
.continuous_actions_values
.iter()
.enumerate()
.map(|(idx, value)| (idx, *value as f64))
.collect();
predictions_map.insert(
tweet_id,
ActionPredictions {
action_probs,
continuous_values,
},
);
}
predictions_map
}
fn extract_phoenix_scores(&self, p: &ActionPredictions) -> PhoenixScores {
PhoenixScores {
favorite_score: p.get(ActionName::ServerTweetFav),
reply_score: p.get(ActionName::ServerTweetReply),
retweet_score: p.get(ActionName::ServerTweetRetweet),
photo_expand_score: p.get(ActionName::ClientTweetPhotoExpand),
click_score: p.get(ActionName::ClientTweetClick),
profile_click_score: p.get(ActionName::ClientTweetClickProfile),
vqv_score: p.get(ActionName::ClientTweetVideoQualityView),
share_score: p.get(ActionName::ClientTweetShare),
share_via_dm_score: p.get(ActionName::ClientTweetClickSendViaDirectMessage),
share_via_copy_link_score: p.get(ActionName::ClientTweetShareViaCopyLink),
dwell_score: p.get(ActionName::ClientTweetRecapDwelled),
quote_score: p.get(ActionName::ServerTweetQuote),
quoted_click_score: p.get(ActionName::ClientQuotedTweetClick),
follow_author_score: p.get(ActionName::ClientTweetFollowAuthor),
not_interested_score: p.get(ActionName::ClientTweetNotInterestedIn),
block_author_score: p.get(ActionName::ClientTweetBlockAuthor),
mute_author_score: p.get(ActionName::ClientTweetMuteAuthor),
report_score: p.get(ActionName::ClientTweetReport),
dwell_time: p.get_continuous(ContinuousActionName::DwellTime),
}
}
fn current_timestamp_millis() -> Option<u64> {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.ok()
.map(|duration| duration.as_millis() as u64)
}
}
struct ActionPredictions {
/// Map of action index -> probability (exp of log prob)
action_probs: HashMap<usize, f64>,
/// Map of continuous action index -> value
continuous_values: HashMap<usize, f64>,
}
impl ActionPredictions {
fn get(&self, action: ActionName) -> Option<f64> {
self.action_probs.get(&(action as usize)).copied()
}
fn get_continuous(&self, action: ContinuousActionName) -> Option<f64> {
self.continuous_values.get(&(action as usize)).copied()
}
}

View File

@ -0,0 +1,92 @@
use crate::candidate_pipeline::candidate::{PhoenixScores, PostCandidate};
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::params as p;
use crate::util::score_normalizer::normalize_score;
use tonic::async_trait;
use xai_candidate_pipeline::scorer::Scorer;
pub struct WeightedScorer;
#[async_trait]
impl Scorer<ScoredPostsQuery, PostCandidate> for WeightedScorer {
#[xai_stats_macro::receive_stats]
async fn score(
&self,
_query: &ScoredPostsQuery,
candidates: &[PostCandidate],
) -> Result<Vec<PostCandidate>, String> {
let scored = candidates
.iter()
.map(|c| {
let weighted_score = Self::compute_weighted_score(c);
let normalized_weighted_score = normalize_score(c, weighted_score);
PostCandidate {
weighted_score: Some(normalized_weighted_score),
..Default::default()
}
})
.collect();
Ok(scored)
}
fn update(&self, candidate: &mut PostCandidate, scored: PostCandidate) {
candidate.weighted_score = scored.weighted_score;
}
}
impl WeightedScorer {
fn apply(score: Option<f64>, weight: f64) -> f64 {
score.unwrap_or(0.0) * weight
}
fn compute_weighted_score(candidate: &PostCandidate) -> f64 {
let s: &PhoenixScores = &candidate.phoenix_scores;
let vqv_weight = Self::vqv_weight_eligibility(candidate);
let combined_score = Self::apply(s.favorite_score, p::FAVORITE_WEIGHT)
+ Self::apply(s.reply_score, p::REPLY_WEIGHT)
+ Self::apply(s.retweet_score, p::RETWEET_WEIGHT)
+ Self::apply(s.photo_expand_score, p::PHOTO_EXPAND_WEIGHT)
+ Self::apply(s.click_score, p::CLICK_WEIGHT)
+ Self::apply(s.profile_click_score, p::PROFILE_CLICK_WEIGHT)
+ Self::apply(s.vqv_score, vqv_weight)
+ Self::apply(s.share_score, p::SHARE_WEIGHT)
+ Self::apply(s.share_via_dm_score, p::SHARE_VIA_DM_WEIGHT)
+ Self::apply(s.share_via_copy_link_score, p::SHARE_VIA_COPY_LINK_WEIGHT)
+ Self::apply(s.dwell_score, p::DWELL_WEIGHT)
+ Self::apply(s.quote_score, p::QUOTE_WEIGHT)
+ Self::apply(s.quoted_click_score, p::QUOTED_CLICK_WEIGHT)
+ Self::apply(s.dwell_time, p::CONT_DWELL_TIME_WEIGHT)
+ Self::apply(s.follow_author_score, p::FOLLOW_AUTHOR_WEIGHT)
+ Self::apply(s.not_interested_score, p::NOT_INTERESTED_WEIGHT)
+ Self::apply(s.block_author_score, p::BLOCK_AUTHOR_WEIGHT)
+ Self::apply(s.mute_author_score, p::MUTE_AUTHOR_WEIGHT)
+ Self::apply(s.report_score, p::REPORT_WEIGHT);
Self::offset_score(combined_score)
}
fn vqv_weight_eligibility(candidate: &PostCandidate) -> f64 {
if candidate
.video_duration_ms
.is_some_and(|ms| ms > p::MIN_VIDEO_DURATION_MS)
{
p::VQV_WEIGHT
} else {
0.0
}
}
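// Map the combined score onto a non-negative scale: totals driven negative by
// the negatively weighted actions are rescaled relative to the summed negative
// weights, while non-negative totals are shifted up by NEGATIVE_SCORES_OFFSET,
// preserving the ordering within each branch.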
fn offset_score(combined_score: f64) -> f64 {
if p::WEIGHTS_SUM == 0.0 {
combined_score.max(0.0)
} else if combined_score < 0.0 {
(combined_score + p::NEGATIVE_WEIGHTS_SUM) / p::WEIGHTS_SUM * p::NEGATIVE_SCORES_OFFSET
} else {
combined_score + p::NEGATIVE_SCORES_OFFSET
}
}
}

View File

@ -0,0 +1,3 @@
mod top_k_score_selector;
pub use top_k_score_selector::TopKScoreSelector;

View File

@ -0,0 +1,15 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::params;
use xai_candidate_pipeline::selector::Selector;
pub struct TopKScoreSelector;
impl Selector<ScoredPostsQuery, PostCandidate> for TopKScoreSelector {
fn score(&self, candidate: &PostCandidate) -> f64 {
candidate.score.unwrap_or(f64::NEG_INFINITY)
}
fn size(&self) -> Option<usize> {
Some(params::TOP_K_CANDIDATES_TO_SELECT)
}
}

83
home-mixer/server.rs Normal file
View File

@ -0,0 +1,83 @@
use crate::candidate_pipeline::candidate::CandidateHelpers;
use crate::candidate_pipeline::phoenix_candidate_pipeline::PhoenixCandidatePipeline;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use log::info;
use std::sync::Arc;
use std::time::Instant;
use tonic::{Request, Response, Status};
use xai_candidate_pipeline::candidate_pipeline::CandidatePipeline;
use xai_home_mixer_proto as pb;
use xai_home_mixer_proto::{ScoredPost, ScoredPostsResponse};
pub struct HomeMixerServer {
phx_candidate_pipeline: Arc<PhoenixCandidatePipeline>,
}
impl HomeMixerServer {
pub async fn new() -> Self {
HomeMixerServer {
phx_candidate_pipeline: Arc::new(PhoenixCandidatePipeline::prod().await),
}
}
}
#[tonic::async_trait]
impl pb::scored_posts_service_server::ScoredPostsService for HomeMixerServer {
#[xai_stats_macro::receive_stats]
async fn get_scored_posts(
&self,
request: Request<pb::ScoredPostsQuery>,
) -> Result<Response<ScoredPostsResponse>, Status> {
let proto_query = request.into_inner();
if proto_query.viewer_id == 0 {
return Err(Status::invalid_argument("viewer_id must be specified"));
}
let start = Instant::now();
let query = ScoredPostsQuery::new(
proto_query.viewer_id,
proto_query.client_app_id,
proto_query.country_code,
proto_query.language_code,
proto_query.seen_ids,
proto_query.served_ids,
proto_query.in_network_only,
proto_query.is_bottom_request,
proto_query.bloom_filter_entries,
);
info!("Scored Posts request - request_id {}", query.request_id);
let pipeline_result = self.phx_candidate_pipeline.execute(query).await;
let scored_posts: Vec<ScoredPost> = pipeline_result
.selected_candidates
.into_iter()
.map(|candidate| {
let screen_names = candidate.get_screen_names();
ScoredPost {
tweet_id: candidate.tweet_id as u64,
author_id: candidate.author_id,
retweeted_tweet_id: candidate.retweeted_tweet_id.unwrap_or(0),
retweeted_user_id: candidate.retweeted_user_id.unwrap_or(0),
in_reply_to_tweet_id: candidate.in_reply_to_tweet_id.unwrap_or(0),
score: candidate.score.unwrap_or(0.0) as f32,
in_network: candidate.in_network.unwrap_or(false),
served_type: candidate.served_type.map(|t| t as i32).unwrap_or_default(),
last_scored_timestamp_ms: candidate.last_scored_at_ms.unwrap_or(0),
prediction_request_id: candidate.prediction_request_id.unwrap_or(0),
ancestors: candidate.ancestors,
screen_names,
visibility_reason: candidate.visibility_reason.map(|r| r.into()),
}
})
.collect();
info!(
"Scored Posts response - request_id {} - {} posts ({} ms)",
pipeline_result.query.request_id,
scored_posts.len(),
start.elapsed().as_millis()
);
Ok(Response::new(ScoredPostsResponse { scored_posts }))
}
}

View File

@ -0,0 +1,42 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::clients::strato_client::StratoClient;
use std::env;
use std::sync::Arc;
use tonic::async_trait;
use xai_candidate_pipeline::side_effect::{SideEffect, SideEffectInput};
use xai_strato::{StratoResult, StratoValue, decode};
pub struct CacheRequestInfoSideEffect {
pub strato_client: Arc<dyn StratoClient + Send + Sync>,
}
#[async_trait]
impl SideEffect<ScoredPostsQuery, PostCandidate> for CacheRequestInfoSideEffect {
fn enable(&self, query: Arc<ScoredPostsQuery>) -> bool {
env::var("APP_ENV").unwrap_or_default() == "prod" && !query.in_network_only
}
async fn run(
&self,
input: Arc<SideEffectInput<ScoredPostsQuery, PostCandidate>>,
) -> Result<(), String> {
let user_id: i64 = input.query.user_id;
let post_ids: Vec<i64> = input
.selected_candidates
.iter()
.map(|c| c.tweet_id)
.collect();
let client = &self.strato_client;
let res = client
.store_request_info(user_id, post_ids)
.await
.map_err(|e| e.to_string())?;
let decoded: StratoResult<StratoValue<()>> = decode(&res);
match decoded {
StratoResult::Ok(_) => Ok(()),
StratoResult::Err(_) => Err("error received from strato".to_string()),
}
}
}

View File

@ -0,0 +1 @@
pub mod cache_request_info_side_effect;

View File

@ -0,0 +1,2 @@
pub mod phoenix_source;
pub mod thunder_source;

View File

@ -0,0 +1,51 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::clients::phoenix_retrieval_client::PhoenixRetrievalClient;
use crate::params as p;
use std::sync::Arc;
use tonic::async_trait;
use xai_candidate_pipeline::source::Source;
use xai_home_mixer_proto as pb;
pub struct PhoenixSource {
pub phoenix_retrieval_client: Arc<dyn PhoenixRetrievalClient + Send + Sync>,
}
#[async_trait]
impl Source<ScoredPostsQuery, PostCandidate> for PhoenixSource {
fn enable(&self, query: &ScoredPostsQuery) -> bool {
!query.in_network_only
}
#[xai_stats_macro::receive_stats]
async fn get_candidates(&self, query: &ScoredPostsQuery) -> Result<Vec<PostCandidate>, String> {
let user_id = query.user_id as u64;
let sequence = query
.user_action_sequence
.as_ref()
.ok_or_else(|| "PhoenixSource: missing user_action_sequence".to_string())?;
let response = self
.phoenix_retrieval_client
.retrieve(user_id, sequence.clone(), p::PHOENIX_MAX_RESULTS)
.await
.map_err(|e| format!("PhoenixSource: {}", e))?;
let candidates: Vec<PostCandidate> = response
.top_k_candidates
.into_iter()
.flat_map(|scored_candidates| scored_candidates.candidates)
.filter_map(|scored_candidate| scored_candidate.candidate)
.map(|tweet_info| PostCandidate {
tweet_id: tweet_info.tweet_id as i64,
author_id: tweet_info.author_id,
in_reply_to_tweet_id: Some(tweet_info.in_reply_to_tweet_id),
served_type: Some(pb::ServedType::ForYouPhoenixRetrieval),
..Default::default()
})
.collect();
Ok(candidates)
}
}

View File

@ -0,0 +1,74 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::clients::thunder_client::{ThunderClient, ThunderCluster};
use crate::params as p;
use std::sync::Arc;
use tonic::async_trait;
use xai_candidate_pipeline::source::Source;
use xai_home_mixer_proto as pb;
use xai_thunder_proto::GetInNetworkPostsRequest;
use xai_thunder_proto::in_network_posts_service_client::InNetworkPostsServiceClient;
pub struct ThunderSource {
pub thunder_client: Arc<ThunderClient>,
}
#[async_trait]
impl Source<ScoredPostsQuery, PostCandidate> for ThunderSource {
#[xai_stats_macro::receive_stats]
async fn get_candidates(&self, query: &ScoredPostsQuery) -> Result<Vec<PostCandidate>, String> {
let cluster = ThunderCluster::Amp;
let channel = self
.thunder_client
.get_random_channel(cluster)
.ok_or_else(|| "ThunderSource: no available channel".to_string())?;
let mut client = InNetworkPostsServiceClient::new(channel.clone());
let following_list = &query.user_features.followed_user_ids;
let request = GetInNetworkPostsRequest {
user_id: query.user_id as u64,
following_user_ids: following_list.iter().map(|&id| id as u64).collect(),
max_results: p::THUNDER_MAX_RESULTS,
exclude_tweet_ids: vec![],
algorithm: "default".to_string(),
debug: false,
is_video_request: false,
};
let response = client
.get_in_network_posts(request)
.await
.map_err(|e| format!("ThunderSource: {}", e))?;
let candidates: Vec<PostCandidate> = response
.into_inner()
.posts
.into_iter()
.map(|post| {
let in_reply_to_tweet_id = post
.in_reply_to_post_id
.and_then(|id| u64::try_from(id).ok());
let conversation_id = post.conversation_id.and_then(|id| u64::try_from(id).ok());
let mut ancestors = Vec::new();
if let Some(reply_to) = in_reply_to_tweet_id {
ancestors.push(reply_to);
if let Some(root) = conversation_id.filter(|&root| root != reply_to) {
ancestors.push(root);
}
}
PostCandidate {
tweet_id: post.post_id,
author_id: post.author_id as u64,
in_reply_to_tweet_id,
ancestors,
served_type: Some(pb::ServedType::ForYouInNetwork),
..Default::default()
}
})
.collect();
Ok(candidates)
}
}

206
phoenix/README.md Normal file
View File

@ -0,0 +1,206 @@
# Phoenix: Recommendation System
This repository contains JAX example code for the Phoenix recommendation system, which powers content ranking and retrieval. Phoenix uses transformer-based architectures for both **retrieval** (finding relevant candidates from millions of items) and **ranking** (ordering a smaller set of candidates by predicted engagement).
> **Note:** The sample transformer implementation in this repository is ported from the [Grok-1 open source release](https://github.com/xai-org/grok-1) by xAI. The core transformer architecture comes from Grok-1, adapted here for recommendation system use cases with custom input embeddings and attention masking for candidate isolation. This code is representative of the model used internally with the exception of specific scaling optimizations.
## Table of Contents
- [Overview](#overview)
- [Architecture](#architecture)
- [Two-Stage Recommendation Pipeline](#two-stage-recommendation-pipeline)
- [Retrieval: Two-Tower Model](#retrieval-two-tower-model)
- [Ranking: Transformer with Candidate Isolation](#ranking-transformer-with-candidate-isolation)
- [Key Design Decisions](#key-design-decisions)
- [Running the Code](#running-the-code)
- [License](#license)
---
## Overview
Phoenix is a recommendation system that predicts user engagement (likes, reposts, replies, etc.) for content. It operates in two stages:
1. **Retrieval**: Efficiently narrow down millions of candidates to hundreds using approximate nearest neighbor (ANN) search
2. **Ranking**: Score and order the retrieved candidates using a more expressive transformer model
---
## Architecture
### Two-Stage Recommendation Pipeline
```
┌─────────────────────────────────────────────────────────────────────────────────┐
│ RECOMMENDATION PIPELINE │
├─────────────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌──────────┐ ┌─────────────────────┐ ┌─────────────────────┐ │
│ │ │ │ │ │ │ │
│ │ User │────▶│ STAGE 1: │────▶│ STAGE 2: │────▶ Feed│
│ │ Request │ │ RETRIEVAL │ │ RANKING │ │
│ │ │ │ (Two-Tower) │ │ (Transformer) │ │
│ └──────────┘ │ │ │ │ │
│ │ Millions → 1000s │ │ 1000s → Ranked │ │
│ └─────────────────────┘ └─────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────────┘
```
---
### Retrieval: Two-Tower Model
The retrieval stage uses a **two-tower architecture** that enables efficient similarity search at scale.
#### How Retrieval Works
1. **User Tower**: Encodes user features and engagement history through a transformer to produce a normalized user embedding `[B, D]`
2. **Candidate Tower**: Computes normalized embeddings for all items in the corpus `[N, D]`
3. **Similarity Search**: Retrieves top-K candidates using dot product similarity
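Conceptually, these three steps reduce to a dot product over normalized embeddings followed by a top-K selection. A minimal brute-force sketch in JAX (the function name and shapes are illustrative assumptions; production retrieval replaces the exhaustive scan with ANN search):
```python
import jax
import jax.numpy as jnp
def retrieve_top_k(user_emb, candidate_embs, k):
    """user_emb: [B, D], candidate_embs: [N, D]; both assumed L2-normalized."""
    scores = user_emb @ candidate_embs.T            # [B, N] dot-product similarity
    top_scores, top_idx = jax.lax.top_k(scores, k)  # [B, k] best items per user
    return top_scores, top_idx
```
Because both towers are normalized, the dot product equals cosine similarity, which is what lets the candidate embeddings be indexed offline by an ANN structure.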
---
### Ranking: Transformer with Candidate Isolation
The ranking model uses a transformer architecture where **candidates cannot attend to each other** during inference. This is a critical design choice: it ensures that a candidate's score doesn't depend on which other candidates are in the batch.
#### Ranking Model Architecture
```
PHOENIX RANKING MODEL
┌────────────────────────────────────────────────────────────────────────────┐
│ │
│ OUTPUT LOGITS │
│ [B, num_candidates, num_actions] │
│ │ │
│ │ Unembedding │
│ │ Projection │
│ │ │
│ ┌───────────────┴───────────────┐ │
│ │ │ │
│ │ Extract Candidate Outputs │ │
│ │ (positions after history) │ │
│ │ │ │
│ └───────────────┬───────────────┘ │
│ │ │
│ ┌───────────────┴───────────────┐ │
│ │ │ │
│ │ Transformer │ │
│ │ (with special masking) │ │
│ │ │ │
│ │ Candidates CANNOT attend │ │
│ │ to each other │ │
│ │ │ │
│ └───────────────┬───────────────┘ │
│ │ │
│ ┌───────────────────────────────┼───────────────────────────────┐ │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ ┌──────────┐ ┌─────────────────┐ ┌────────────┐ │
│ │ User │ │ History │ │ Candidates │ │
│ │Embedding │ │ Embeddings │ │ Embeddings │ │
│ │ [B, 1] │ │ [B, S, D] │ │ [B, C, D] │ │
│ │ │ │ │ │ │ │
│ │ User │ │ Posts + Authors │ │ Posts + │ │
│ │ Hashes │ │ + Actions + │ │ Authors + │ │
│ │ │ │ Product Surface │ │ Product │ │
│ └──────────┘ └─────────────────┘ │ Surface │ │
│ └────────────┘ │
│ │
└────────────────────────────────────────────────────────────────────────────┘
```
#### Attention Mask: Candidate Isolation
A key detail is the **attention mask** that prevents candidates from attending to each other while still allowing them to attend to the user and history:
```
ATTENTION MASK VISUALIZATION
Keys (what we attend TO)
─────────────────────────────────────────────▶
│ User │ History (S positions) │ Candidates (C positions) │
┌────┼──────┼─────────────────────────────┼───────────────────────────────┤
│ │ │ │ │
│ U │ ✓ │ ✓ ✓ ✓ ✓ ✓ ✓ ✓ │ ✗ ✗ ✗ ✗ ✗ ✗ ✗ │
│ │ │ │ │
├────┼──────┼─────────────────────────────┼───────────────────────────────┤
Q │ │ │ │ │
u │ H │ ✓ │ ✓ ✓ ✓ ✓ ✓ ✓ ✓ │ ✗ ✗ ✗ ✗ ✗ ✗ ✗ │
e │ i │ ✓ │ ✓ ✓ ✓ ✓ ✓ ✓ ✓ │ ✗ ✗ ✗ ✗ ✗ ✗ ✗ │
r │ s │ ✓ │ ✓ ✓ ✓ ✓ ✓ ✓ ✓ │ ✗ ✗ ✗ ✗ ✗ ✗ ✗ │
i │ t │ ✓ │ ✓ ✓ ✓ ✓ ✓ ✓ ✓ │ ✗ ✗ ✗ ✗ ✗ ✗ ✗ │
e │ │ │ │ │
s ├────┼──────┼─────────────────────────────┼───────────────────────────────┤
│ │ │ │ DIAGONAL ONLY (self-attend) │
│ │ C │ ✓ │ ✓ ✓ ✓ ✓ ✓ ✓ ✓ │ ✓ ✗ ✗ ✗ ✗ ✗ ✗ │
│ │ a │ ✓ │ ✓ ✓ ✓ ✓ ✓ ✓ ✓ │ ✗ ✓ ✗ ✗ ✗ ✗ ✗ │
│ │ n │ ✓ │ ✓ ✓ ✓ ✓ ✓ ✓ ✓ │ ✗ ✗ ✓ ✗ ✗ ✗ ✗ │
│ │ d │ ✓ │ ✓ ✓ ✓ ✓ ✓ ✓ ✓ │ ✗ ✗ ✗ ✓ ✗ ✗ ✗ │
│ │ i │ ✓ │ ✓ ✓ ✓ ✓ ✓ ✓ ✓ │ ✗ ✗ ✗ ✗ ✓ ✗ ✗ │
│ │ d │ ✓ │ ✓ ✓ ✓ ✓ ✓ ✓ ✓ │ ✗ ✗ ✗ ✗ ✗ ✓ ✗ │
▼ │ s │ ✓ │ ✓ ✓ ✓ ✓ ✓ ✓ ✓ │ ✗ ✗ ✗ ✗ ✗ ✗ ✓ │
│ │ │ │ │
└────┴──────┴─────────────────────────────┴───────────────────────────────┘
✓ = Can attend (1) ✗ = Cannot attend (0)
Legend:
├─ User + History: Causal attention among themselves (each position attends to itself and earlier positions)
├─ Candidates → User/History: Candidates CAN attend to user and history
└─ Candidates → Candidates: Candidates CANNOT attend to each other (only self)
```
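For concreteness, here is a toy invocation of `make_recsys_attn_mask` from `phoenix/grok.py` (included later in this commit), using one user position, three history positions, and three candidates:
```python
import jax.numpy as jnp
from grok import make_recsys_attn_mask

# 1 user + 3 history + 3 candidates = 7 positions; candidates start at offset 4.
mask = make_recsys_attn_mask(seq_len=7, candidate_start_offset=4)
print(mask[0, 0].astype(jnp.int32))
# [[1 0 0 0 0 0 0]   <- user attends to itself (causal prefix)
#  [1 1 0 0 0 0 0]   <- history: causal over user + earlier history
#  [1 1 1 0 0 0 0]
#  [1 1 1 1 0 0 0]
#  [1 1 1 1 1 0 0]   <- candidates: user + history + self only
#  [1 1 1 1 0 1 0]
#  [1 1 1 1 0 0 1]]
```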
---
## Key Design Decisions
### 1. Hash-Based Embeddings
Both models use multiple hash functions per entity (user, post, author) for embedding lookup. Each entity's hash embeddings are concatenated and projected back to the model dimension, which mitigates hash collisions without requiring a full-vocabulary embedding table.
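As a rough sketch of the idea (the hash functions and table size below are hypothetical; the real pipeline passes pre-looked-up embeddings into `block_user_reduce` and friends in `phoenix/recsys_model.py`):
```python
import numpy as np

emb_size, num_hashes, table_size = 128, 2, 100_000
rng = np.random.default_rng(0)
table = rng.normal(size=(table_size, emb_size)).astype(np.float32)

def entity_hashes(entity_id: int) -> np.ndarray:
    # Hypothetical hash functions; hash value 0 is reserved for padding.
    h1 = 1 + (entity_id * 2654435761) % (table_size - 1)
    h2 = 1 + (entity_id * 97 + 31) % (table_size - 1)
    return np.array([h1, h2])

looked_up = table[entity_hashes(42)]                 # [num_hashes, D]
flat = looked_up.reshape(1, num_hashes * emb_size)   # concatenate hash embeddings
# A learned projection of shape [num_hashes * D, D] (proj_mat_1 in
# block_user_reduce) then combines them into one [B, 1, D] representation.
```
Using several hashes per entity keeps the tables small while making it unlikely that two entities collide under every hash simultaneously.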
### 2. Shared Architecture
The retrieval user tower uses the same transformer architecture as the ranking model, so both stages share the backbone defined in `phoenix/grok.py`.
### 3. Multi-Action Prediction
The ranking model predicts multiple engagement types simultaneously:
```
Output: [B, num_candidates, num_actions]
┌─────────────────────────────────────┐
│ Like │ Repost │ Reply │ Click │ ... │
└─────────────────────────────────────┘
```
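The logits are converted to per-action probabilities with an elementwise sigmoid, and candidates can then be ordered by any action column; the ranking runner in `phoenix/runners.py` sorts by `favorite_score` (action index 0). A minimal sketch with random logits:
```python
import jax
import jax.numpy as jnp

logits = jax.random.normal(jax.random.PRNGKey(0), (1, 8, 19))  # [B, C, num_actions]
probs = jax.nn.sigmoid(logits)             # independent probability per action
favorite = probs[:, :, 0]                  # ACTIONS[0] == "favorite_score"
ranked = jnp.argsort(-favorite, axis=-1)   # candidate indices, best first
```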
---
## Running the Code
### Installation
Install [uv](https://docs.astral.sh/uv/getting-started/installation/); it manages the Python environment and dependencies for the commands below.
### Running the Ranker
```shell
uv run run_ranker.py
```
### Running Retrieval
```shell
uv run run_retrieval.py
```
### Running Tests
```shell
uv run pytest test_recsys_model.py test_recsys_retrieval_model.py
```
586
phoenix/grok.py Normal file
View File
@ -0,0 +1,586 @@
# Copyright 2026 X.AI Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from dataclasses import dataclass
from typing import NamedTuple, Optional, Sequence, Union
import haiku as hk
import jax
import jax.numpy as jnp
logger = logging.getLogger(__name__)
class TrainingState(NamedTuple):
"""Container for the training state."""
params: hk.Params
def ffn_size(emb_size, widening_factor):
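# GLU-style sizing: two-thirds of the widened dimension (DenseBlock gates h_w1 * h_v below), rounded up to a multiple of 8.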
_ffn_size = int(widening_factor * emb_size) * 2 // 3
_ffn_size = _ffn_size + (8 - _ffn_size) % 8 # ensure it's a multiple of 8
logger.debug(f"emb_size: {emb_size} adjusted ffn_size: {_ffn_size}")
return _ffn_size
def make_recsys_attn_mask(
seq_len: int,
candidate_start_offset: int,
dtype: jnp.dtype = jnp.float32,
) -> jax.Array:
"""Create attention mask for recommendation system inference.
Creates a mask where:
- Positions 0 to candidate_start_offset-1 (user+history): causal attention
- Positions candidate_start_offset onwards (candidates): can attend to user+history
and themselves (self-attention), but NOT to other candidates
This ensures each candidate is scored independently based on user+history context.
Args:
seq_len: Total sequence length (user + history + candidates)
candidate_start_offset: Position where candidates start in the sequence
dtype: Data type for the mask
Returns:
Attention mask of shape [1, 1, seq_len, seq_len] where 1 means "can attend"
"""
# Start with causal mask for the full sequence
causal_mask = jnp.tril(jnp.ones((1, 1, seq_len, seq_len), dtype=dtype))
# Zero out candidate-to-candidate attention (bottom-right block)
attn_mask = causal_mask.at[:, :, candidate_start_offset:, candidate_start_offset:].set(0)
# Add back self-attention for candidates (diagonal of the candidate block)
candidate_indices = jnp.arange(candidate_start_offset, seq_len)
attn_mask = attn_mask.at[:, :, candidate_indices, candidate_indices].set(1)
return attn_mask
class MHAOutput(NamedTuple):
"""Outputs of the multi-head attention operation."""
embeddings: jax.Array
class DecoderOutput(NamedTuple):
embeddings: jax.Array
class TransformerOutput(NamedTuple):
embeddings: jax.Array
@dataclass
class TransformerConfig:
emb_size: int
key_size: int
num_q_heads: int
num_kv_heads: int
num_layers: int
widening_factor: float = 4.0
attn_output_multiplier: float = 1.0
name: Optional[str] = None
def make(self) -> "Transformer":
return Transformer(
num_q_heads=self.num_q_heads,
num_kv_heads=self.num_kv_heads,
widening_factor=self.widening_factor,
key_size=self.key_size,
attn_output_multiplier=self.attn_output_multiplier,
num_layers=self.num_layers,
)
def hk_rms_norm(
x: jax.Array,
fixed_scale=False,
) -> jax.Array:
"""Applies a unique LayerNorm to x with default settings."""
ln = RMSNorm(axis=-1, create_scale=not fixed_scale)
return ln(x)
class Linear(hk.Linear):
def __init__(
self,
output_size: int,
with_bias: bool = True,
name: Optional[str] = None,
):
super().__init__(
output_size=output_size,
with_bias=with_bias,
name=name,
)
def __call__( # type: ignore
self,
inputs: jax.Array,
) -> jax.Array:
"""Computes a linear transform of the input."""
fprop_dtype = inputs.dtype
if not inputs.shape:
raise ValueError("Input must not be scalar.")
input_size = inputs.shape[-1]
output_size = self.output_size
w = hk.get_parameter(
"w", [input_size, output_size], jnp.float32, init=hk.initializers.Constant(0)
)
out = jnp.dot(inputs, w.astype(fprop_dtype))
if self.with_bias:
b = hk.get_parameter(
"b", [self.output_size], jnp.float32, init=hk.initializers.Constant(0)
)
b = jnp.broadcast_to(b, out.shape)
out = out + b.astype(fprop_dtype)
return out
class RMSNorm(hk.RMSNorm):
def __init__(
self,
axis: Union[int, Sequence[int], slice],
eps: float = 1e-5,
name: Optional[str] = None,
create_scale: bool = True,
):
super().__init__(axis, eps, create_scale=create_scale, name=name)
def __call__(self, inputs: jax.Array):
fprop_dtype = inputs.dtype
param_shape = (inputs.shape[-1],)
if self.create_scale:
scale = hk.get_parameter(
"scale",
param_shape,
dtype=jnp.float32,
init=hk.initializers.Constant(0),
)
scale = jnp.broadcast_to(scale.astype(jnp.float32), inputs.shape)
else:
scale = 1.0
inputs = inputs.astype(jnp.float32)
scale = jnp.float32(scale)
mean_squared = jnp.mean(jnp.square(inputs), axis=[-1], keepdims=True)
mean_squared = jnp.broadcast_to(mean_squared, inputs.shape)
normed_inputs = inputs * jax.lax.rsqrt(mean_squared + self.eps)
outputs = scale * normed_inputs
return outputs.astype(fprop_dtype)
def rotate_half(
x: jax.Array,
) -> jax.Array:
"""Obtain the rotated counterpart of each feature"""
x1, x2 = jnp.split(x, 2, axis=-1)
return jnp.concatenate((-x2, x1), axis=-1)
class RotaryEmbedding(hk.Module):
"""Applies rotary embeddings (RoPE) to the input sequence tensor,
as described in https://arxiv.org/abs/2104.09864.
Attributes:
dim (int): Dimensionality of the feature vectors
base_exponent (int): Base exponent to compute embeddings from
"""
def __init__(
self,
dim: int,
name: Optional[str] = None,
base_exponent: int = 10000,
):
super().__init__(name)
self.dim = dim
self.base_exponent = base_exponent
assert self.dim % 2 == 0
def __call__(
self,
x: jax.Array,
seq_dim: int,
offset: jax.Array,
const_position: Optional[int] = None,
t: Optional[jax.Array] = None,
) -> jax.Array:
fprop_dtype = x.dtype
# Compute the per-dimension frequencies
exponents = jnp.arange(0, self.dim, 2, dtype=jnp.float32)
inv_freq = jnp.asarray(
1.0 / (self.base_exponent ** (exponents / self.dim)), dtype=jnp.float32
)
if jnp.shape(offset) == ():
# Offset can be a scalar or one offset per batch element.
offset = jnp.expand_dims(offset, 0)
# Compute the per element phase (to pass into sin and cos)
if const_position is not None:
t = const_position * jnp.ones(
(
1,
x.shape[seq_dim],
),
dtype=jnp.float32,
)
elif t is None:
t = jnp.arange(x.shape[seq_dim], dtype=jnp.float32) + jnp.expand_dims(offset, -1)
phase = jnp.einsum("bi,j->bij", t, inv_freq)
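# Duplicate the phases across the two feature halves and add a broadcast axis over heads: [B, T, 1, dim].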
phase = jnp.tile(phase, reps=(1, 2))[:, :, None, :]
x = x * jnp.cos(phase) + rotate_half(x) * jnp.sin(phase)
x = x.astype(fprop_dtype)
return x
class MultiHeadAttention(hk.Module):
def __init__(
self,
num_q_heads: int,
num_kv_heads: int,
key_size: int,
*,
with_bias: bool = True,
value_size: Optional[int] = None,
model_size: Optional[int] = None,
attn_output_multiplier: float = 1.0,
name: Optional[str] = None,
):
super().__init__(name=name)
self.num_q_heads = num_q_heads
self.num_kv_heads = num_kv_heads
self.key_size = key_size
self.value_size = value_size or key_size
self.model_size = model_size or key_size * num_q_heads
self.attn_output_multiplier = attn_output_multiplier
self.with_bias = with_bias
def __call__(
self,
query: jax.Array,
key: jax.Array,
value: jax.Array,
mask: jax.Array,
) -> MHAOutput:
# In shape hints below, we suppress the leading dims [...] for brevity.
# Hence e.g. [A, B] should be read in every case as [..., A, B].
projection = self._linear_projection
# Check that the keys and values have consistent batch size and sequence length.
assert key.shape[:2] == value.shape[:2], f"key/value shape: {key.shape}/{value.shape}"
if mask is not None:
assert mask.ndim == 4
assert mask.shape[0] in {
1,
query.shape[0],
}, f"mask/query shape: {mask.shape}/{query.shape}"
assert key.shape[0] in {
1,
query.shape[0],
}, f"key/query shape: {key.shape}/{query.shape}"
assert mask.shape[1] == 1
assert mask.shape[2] in {
1,
query.shape[1],
}, f"mask/query shape: {mask.shape}/{query.shape}"
assert mask.shape[3] in {
1,
key.shape[1],
}, f"mask/query shape: {mask.shape}/{key.shape}"
# Compute key/query/values (overload K/Q/V to denote the respective sizes).
assert self.num_q_heads % self.num_kv_heads == 0
query_heads = projection(query, self.key_size, self.num_q_heads, name="query")
key_heads = projection(key, self.key_size, self.num_kv_heads, name="key")
value_heads = projection(value, self.value_size, self.num_kv_heads, name="value")
rotate = RotaryEmbedding(dim=self.key_size, base_exponent=int(1e4))
key_heads = rotate(key_heads, seq_dim=1, offset=0)
query_heads = rotate(query_heads, seq_dim=1, offset=0)
b, t, h, d = query_heads.shape
_, _, kv_h, _ = key_heads.shape
assert h % kv_h == 0, f"query_heads {h} must be a multiple of kv_heads {kv_h}"
query_heads = jnp.reshape(query_heads, (b, t, kv_h, h // kv_h, d))
# Compute attention weights.
# Attention softmax is always carried out in fp32.
attn_logits = jnp.einsum("...thHd,...Thd->...hHtT", query_heads, key_heads).astype(
jnp.float32
)
attn_logits *= self.attn_output_multiplier
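# Soft-cap the attention logits at +/-30 via tanh to keep the softmax numerically stable.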
max_attn_val = jnp.array(30.0, dtype=attn_logits.dtype)
attn_logits = max_attn_val * jnp.tanh(attn_logits / max_attn_val)
if mask is not None:
# Insert a singleton axis for the grouped query heads: [B, 1, 1, T', T].
mask = mask[:, :, None, :, :]
if mask.ndim != attn_logits.ndim:
raise ValueError(
f"Mask dimensionality {mask.ndim} must match logits dimensionality "
f"{attn_logits.ndim} for {mask.shape}/{attn_logits.shape}."
)
attn_logits = jnp.where(mask, attn_logits, -1e30)
attn_weights = jax.nn.softmax(attn_logits).astype(query.dtype) # [H, T', T]
# Weight the values by the attention and flatten the head vectors.
attn = jnp.einsum("...hHtT,...Thd->...thHd", attn_weights, value_heads)
leading_dims = attn.shape[:2]
attn = jnp.reshape(attn, (*leading_dims, -1)) # [T', H*V]
# Apply another projection to get the final embeddings.
final_projection = Linear(self.model_size, with_bias=False)
return MHAOutput(final_projection(attn))
@hk.transparent
def _linear_projection(
self,
x: jax.Array,
head_size: int,
num_heads: int,
name: Optional[str] = None,
) -> jax.Array:
y = Linear(num_heads * head_size, with_bias=False, name=name)(x)
*leading_dims, _ = x.shape
return y.reshape((*leading_dims, num_heads, head_size))
@dataclass
class MHABlock(hk.Module):
"""A MHA Block"""
num_q_heads: int
num_kv_heads: int
key_size: int
attn_output_multiplier: float = 1.0
@hk.transparent
def __call__(
self,
inputs: jax.Array, # [B, T, D]
mask: jax.Array, # [B, 1, T, T] or [B, 1, 1, T] or [1, 1, 1, 1]
) -> MHAOutput:
_, _, model_size = inputs.shape
assert mask.ndim == 4, f"shape: {mask.shape}"
assert mask.shape[2] in {1, inputs.shape[1]}, str(mask.shape)
assert mask.shape[3] in {1, inputs.shape[1]}, str(mask.shape)
side_input = inputs
def attn_block(query, key, value, mask) -> MHAOutput:
return MultiHeadAttention(
num_q_heads=self.num_q_heads,
num_kv_heads=self.num_kv_heads,
key_size=self.key_size,
model_size=model_size,
attn_output_multiplier=self.attn_output_multiplier,
)(query, key, value, mask)
attn_output = attn_block(inputs, side_input, side_input, mask)
h_attn = attn_output.embeddings
return MHAOutput(embeddings=h_attn)
@dataclass
class DenseBlock(hk.Module):
num_q_heads: int
num_kv_heads: int
key_size: int
widening_factor: float = 4.0
@hk.transparent
def __call__(
self,
inputs: jax.Array, # [B, T, D]
) -> jax.Array: # [B, T, D]
_, _, model_size = inputs.shape
h_v = Linear(
ffn_size(model_size, self.widening_factor),
with_bias=False,
name="linear_v",
)(inputs)
h_w1 = jax.nn.gelu(
Linear(
ffn_size(model_size, self.widening_factor),
with_bias=False,
)(inputs)
)
h_dense = Linear(model_size, with_bias=False)(h_w1 * h_v)
return h_dense
@dataclass
class DecoderLayer(hk.Module):
"""A transformer stack."""
num_q_heads: int
num_kv_heads: int
key_size: int
num_layers: int
layer_index: Optional[int] = None
widening_factor: float = 4.0
name: Optional[str] = None
attn_output_multiplier: float = 1.0
def __call__(
self,
inputs: jax.Array, # [B, T, D]
mask: jax.Array, # [B, 1, T, T] or [B, 1, 1, T]
padding_mask: Optional[jax.Array],
) -> DecoderOutput:
"""Transforms input embedding sequences to output embedding sequences."""
del padding_mask # Unused.
def layer_norm(x):
return hk_rms_norm(x)
h = inputs
attn_output = MHABlock(
num_q_heads=self.num_q_heads,
num_kv_heads=self.num_kv_heads,
key_size=self.key_size,
attn_output_multiplier=self.attn_output_multiplier,
)(layer_norm(h), mask)
h_attn = attn_output.embeddings
h_attn = layer_norm(h_attn)
h += h_attn
def base_dense_block(h):
h = DenseBlock(
num_q_heads=self.num_q_heads,
num_kv_heads=self.num_kv_heads,
key_size=self.key_size,
widening_factor=self.widening_factor,
)(h)
return h
h_dense = base_dense_block(layer_norm(h))
h_dense = layer_norm(h_dense)
h += h_dense
return DecoderOutput(
embeddings=h,
)
def layer_norm(x):
return hk_rms_norm(x)
@dataclass
class Transformer(hk.Module):
"""A transformer stack."""
num_q_heads: int
num_kv_heads: int
key_size: int
widening_factor: float
attn_output_multiplier: float
num_layers: int
name: Optional[str] = None
def __call__(
self,
embeddings: jax.Array, # [B, T, D]
mask: jax.Array, # [B, T]
candidate_start_offset: Optional[int] = None,
) -> TransformerOutput:
"""Transforms input embedding sequences to output embedding sequences.
Args:
embeddings: Input embeddings of shape [B, T, D]
mask: Padding mask of shape [B, T], True for valid positions
candidate_start_offset: If provided, positions >= this offset are treated as
candidates that can only attend to positions before the offset (user+history)
and themselves (self-attention), but not to other candidates.
Used for recommendation system inference.
Returns:
TransformerOutput containing the output embeddings.
"""
fprop_dtype = embeddings.dtype
_, seq_len, _ = embeddings.shape
padding_mask = mask.copy()
mask = mask[:, None, None, :] # [B, H=1, T'=1, T]
if candidate_start_offset is not None:
# Use recommendation system attention mask where candidates attend to
# user+history and themselves, but not to other candidates
attn_mask = make_recsys_attn_mask(seq_len, candidate_start_offset, fprop_dtype)
mask = mask * attn_mask
else:
# Standard causal mask for autoregressive sequence modelling
causal_mask = jnp.tril(jnp.ones((1, 1, seq_len, seq_len))).astype(
fprop_dtype
) # [B=1, H=1, T, T]
mask = mask * causal_mask # [B, H=1, T, T]
h = embeddings
def block(
h,
mask,
padding_mask,
layer_index: Optional[int] = None,
widening_factor: Optional[int] = None,
name: Optional[str] = None,
) -> DecoderOutput:
return DecoderLayer(
num_q_heads=self.num_q_heads,
num_kv_heads=self.num_kv_heads,
key_size=self.key_size,
widening_factor=widening_factor or self.widening_factor,
num_layers=self.num_layers,
attn_output_multiplier=self.attn_output_multiplier,
name=name,
layer_index=layer_index,
)(h, mask, padding_mask)
for i in range(self.num_layers):
decoder_output = block(
h,
mask,
padding_mask,
layer_index=i,
name=f"decoder_layer_{i}",
)
h = decoder_output.embeddings
return TransformerOutput(
embeddings=h,
)
38
phoenix/pyproject.toml Normal file
View File
@ -0,0 +1,38 @@
[project]
name = "grok-1"
version = "0.1.0"
description = "Grok-1 model"
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
"dm-haiku>=0.0.13",
"jax==0.8.1",
"numpy>=1.26.4",
"pyright>=1.1.408",
]
[tool.uv]
environments = [
"sys_platform == 'darwin'",
"sys_platform == 'linux'",
]
[tool.ruff]
indent-width = 4
line-length = 100
[tool.ruff.lint]
ignore = [
"E722",
"E731",
"E741",
"F405",
"E402",
"F403",
]
select = ["ISC001"]
[dependency-groups]
dev = [
"pytest",
]
474
phoenix/recsys_model.py Normal file
View File
@ -0,0 +1,474 @@
# Copyright 2026 X.AI Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from dataclasses import dataclass
from typing import Any, NamedTuple, Optional, Tuple
import haiku as hk
import jax
import jax.numpy as jnp
from grok import (
TransformerConfig,
Transformer,
layer_norm,
)
logger = logging.getLogger(__name__)
@dataclass
class HashConfig:
"""Configuration for hash-based embeddings."""
num_user_hashes: int = 2
num_item_hashes: int = 2
num_author_hashes: int = 2
@dataclass
class RecsysEmbeddings:
"""Container for pre-looked-up embeddings from the embedding tables.
These embeddings are looked up from hash tables before being passed to the model.
The block_*_reduce functions will combine multiple hash embeddings into single representations.
"""
user_embeddings: jax.typing.ArrayLike
history_post_embeddings: jax.typing.ArrayLike
candidate_post_embeddings: jax.typing.ArrayLike
history_author_embeddings: jax.typing.ArrayLike
candidate_author_embeddings: jax.typing.ArrayLike
class RecsysModelOutput(NamedTuple):
"""Output of the recommendation model."""
logits: jax.Array
class RecsysBatch(NamedTuple):
"""Input batch for the recommendation model.
Contains the feature data (hashes, actions, product surfaces) but NOT the embeddings.
Embeddings are passed separately via RecsysEmbeddings.
"""
user_hashes: jax.typing.ArrayLike
history_post_hashes: jax.typing.ArrayLike
history_author_hashes: jax.typing.ArrayLike
history_actions: jax.typing.ArrayLike
history_product_surface: jax.typing.ArrayLike
candidate_post_hashes: jax.typing.ArrayLike
candidate_author_hashes: jax.typing.ArrayLike
candidate_product_surface: jax.typing.ArrayLike
def block_user_reduce(
user_hashes: jnp.ndarray,
user_embeddings: jnp.ndarray,
num_user_hashes: int,
emb_size: int,
embed_init_scale: float = 1.0,
) -> Tuple[jax.Array, jax.Array]:
"""Combine multiple user hash embeddings into a single user representation.
Args:
user_hashes: [B, num_user_hashes] - hash values (0 = invalid/padding)
user_embeddings: [B, num_user_hashes, D] - looked-up embeddings
num_user_hashes: number of hash functions used
emb_size: embedding dimension D
embed_init_scale: initialization scale for projection
Returns:
user_embedding: [B, 1, D] - combined user embedding
user_padding_mask: [B, 1] - True where user is valid
"""
B = user_embeddings.shape[0]
D = emb_size
user_embedding = user_embeddings.reshape((B, 1, num_user_hashes * D))
embed_init = hk.initializers.VarianceScaling(embed_init_scale, mode="fan_out")
proj_mat_1 = hk.get_parameter(
"proj_mat_1",
[num_user_hashes * D, D],
dtype=jnp.float32,
init=lambda shape, dtype: embed_init(list(reversed(shape)), dtype).T,
)
user_embedding = jnp.dot(user_embedding.astype(proj_mat_1.dtype), proj_mat_1).astype(
user_embeddings.dtype
)
# hash 0 is reserved for padding
user_padding_mask = (user_hashes[:, 0] != 0).reshape(B, 1).astype(jnp.bool_)
return user_embedding, user_padding_mask
def block_history_reduce(
history_post_hashes: jnp.ndarray,
history_post_embeddings: jnp.ndarray,
history_author_embeddings: jnp.ndarray,
history_product_surface_embeddings: jnp.ndarray,
history_actions_embeddings: jnp.ndarray,
num_item_hashes: int,
num_author_hashes: int,
embed_init_scale: float = 1.0,
) -> Tuple[jax.Array, jax.Array]:
"""Combine history embeddings (post, author, actions, product_surface) into sequence.
Args:
history_post_hashes: [B, S, num_item_hashes]
history_post_embeddings: [B, S, num_item_hashes, D]
history_author_embeddings: [B, S, num_author_hashes, D]
history_product_surface_embeddings: [B, S, D]
history_actions_embeddings: [B, S, D]
num_item_hashes: number of hash functions for items
num_author_hashes: number of hash functions for authors
embed_init_scale: initialization scale
Returns:
history_embeddings: [B, S, D]
history_padding_mask: [B, S]
"""
B, S, _, D = history_post_embeddings.shape
history_post_embeddings_reshaped = history_post_embeddings.reshape((B, S, num_item_hashes * D))
history_author_embeddings_reshaped = history_author_embeddings.reshape(
(B, S, num_author_hashes * D)
)
post_author_embedding = jnp.concatenate(
[
history_post_embeddings_reshaped,
history_author_embeddings_reshaped,
history_actions_embeddings,
history_product_surface_embeddings,
],
axis=-1,
)
embed_init = hk.initializers.VarianceScaling(embed_init_scale, mode="fan_out")
proj_mat_3 = hk.get_parameter(
"proj_mat_3",
[post_author_embedding.shape[-1], D],
dtype=jnp.float32,
init=lambda shape, dtype: embed_init(list(reversed(shape)), dtype).T,
)
history_embedding = jnp.dot(post_author_embedding.astype(proj_mat_3.dtype), proj_mat_3).astype(
post_author_embedding.dtype
)
history_embedding = history_embedding.reshape(B, S, D)
history_padding_mask = (history_post_hashes[:, :, 0] != 0).reshape(B, S)
return history_embedding, history_padding_mask
def block_candidate_reduce(
candidate_post_hashes: jnp.ndarray,
candidate_post_embeddings: jnp.ndarray,
candidate_author_embeddings: jnp.ndarray,
candidate_product_surface_embeddings: jnp.ndarray,
num_item_hashes: int,
num_author_hashes: int,
embed_init_scale: float = 1.0,
) -> Tuple[jax.Array, jax.Array]:
"""Combine candidate embeddings (post, author, product_surface) into sequence.
Args:
candidate_post_hashes: [B, C, num_item_hashes]
candidate_post_embeddings: [B, C, num_item_hashes, D]
candidate_author_embeddings: [B, C, num_author_hashes, D]
candidate_product_surface_embeddings: [B, C, D]
num_item_hashes: number of hash functions for items
num_author_hashes: number of hash functions for authors
embed_init_scale: initialization scale
Returns:
candidate_embeddings: [B, C, D]
candidate_padding_mask: [B, C]
"""
B, C, _, D = candidate_post_embeddings.shape
candidate_post_embeddings_reshaped = candidate_post_embeddings.reshape(
(B, C, num_item_hashes * D)
)
candidate_author_embeddings_reshaped = candidate_author_embeddings.reshape(
(B, C, num_author_hashes * D)
)
post_author_embedding = jnp.concatenate(
[
candidate_post_embeddings_reshaped,
candidate_author_embeddings_reshaped,
candidate_product_surface_embeddings,
],
axis=-1,
)
embed_init = hk.initializers.VarianceScaling(embed_init_scale, mode="fan_out")
proj_mat_2 = hk.get_parameter(
"proj_mat_2",
[post_author_embedding.shape[-1], D],
dtype=jnp.float32,
init=lambda shape, dtype: embed_init(list(reversed(shape)), dtype).T,
)
candidate_embedding = jnp.dot(
post_author_embedding.astype(proj_mat_2.dtype), proj_mat_2
).astype(post_author_embedding.dtype)
candidate_padding_mask = (candidate_post_hashes[:, :, 0] != 0).reshape(B, C).astype(jnp.bool_)
return candidate_embedding, candidate_padding_mask
@dataclass
class PhoenixModelConfig:
"""Configuration for the recommendation system model."""
model: TransformerConfig
emb_size: int
num_actions: int
history_seq_len: int = 128
candidate_seq_len: int = 32
name: Optional[str] = None
fprop_dtype: Any = jnp.bfloat16
hash_config: HashConfig = None # type: ignore
product_surface_vocab_size: int = 16
_initialized = False
def __post_init__(self):
if self.hash_config is None:
self.hash_config = HashConfig()
def initialize(self):
self._initialized = True
return self
def make(self):
if not self._initialized:
logger.warning(f"PhoenixModel {self.name} is not initialized. Initializing.")
self.initialize()
return PhoenixModel(
model=self.model.make(),
config=self,
fprop_dtype=self.fprop_dtype,
)
@dataclass
class PhoenixModel(hk.Module):
"""A transformer-based recommendation model for ranking candidates."""
model: Transformer
config: PhoenixModelConfig
fprop_dtype: Any = jnp.bfloat16
name: Optional[str] = None
def _get_action_embeddings(
self,
actions: jax.Array,
) -> jax.Array:
"""Convert multi-hot action vectors to embeddings.
Uses a learned projection matrix to map the signed action vector
to the embedding dimension. This works for any number of actions.
"""
config = self.config
_, _, num_actions = actions.shape
D = config.emb_size
embed_init = hk.initializers.VarianceScaling(1.0, mode="fan_out")
action_projection = hk.get_parameter(
"action_projection",
[num_actions, D],
dtype=jnp.float32,
init=embed_init,
)
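# Map multi-hot {0, 1} actions to signed {-1, +1}; rows with no actions are zeroed out below via valid_mask.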
actions_signed = (2 * actions - 1).astype(jnp.float32)
action_emb = jnp.dot(actions_signed.astype(action_projection.dtype), action_projection)
valid_mask = jnp.any(actions, axis=-1, keepdims=True)
action_emb = action_emb * valid_mask
return action_emb.astype(self.fprop_dtype)
def _single_hot_to_embeddings(
self,
input: jax.Array,
vocab_size: int,
emb_size: int,
name: str,
) -> jax.Array:
"""Convert single-hot indices to embeddings via lookup table.
Args:
input: [B, S] tensor of categorical indices
vocab_size: size of the vocabulary
emb_size: embedding dimension
name: name for the embedding table parameter
Returns:
embeddings: [B, S, emb_size]
"""
embed_init = hk.initializers.VarianceScaling(1.0, mode="fan_out")
embedding_table = hk.get_parameter(
name,
[vocab_size, emb_size],
dtype=jnp.float32,
init=embed_init,
)
input_one_hot = jax.nn.one_hot(input, vocab_size)
output = jnp.dot(input_one_hot, embedding_table)
return output.astype(self.fprop_dtype)
def _get_unembedding(self) -> jax.Array:
"""Get the unembedding matrix for decoding to logits."""
config = self.config
embed_init = hk.initializers.VarianceScaling(1.0, mode="fan_out")
unembed_mat = hk.get_parameter(
"unembeddings",
[config.emb_size, config.num_actions],
dtype=jnp.float32,
init=embed_init,
)
return unembed_mat
def build_inputs(
self,
batch: RecsysBatch,
recsys_embeddings: RecsysEmbeddings,
) -> Tuple[jax.Array, jax.Array, int]:
"""Build input embeddings from batch and pre-looked-up embeddings.
Args:
batch: RecsysBatch containing hashes, actions, product surfaces
recsys_embeddings: RecsysEmbeddings containing pre-looked-up embeddings
Returns:
embeddings: [B, 1 + history_len + num_candidates, D]
padding_mask: [B, 1 + history_len + num_candidates]
candidate_start_offset: int - position where candidates start
"""
config = self.config
hash_config = config.hash_config
history_product_surface_embeddings = self._single_hot_to_embeddings(
batch.history_product_surface, # type: ignore
config.product_surface_vocab_size,
config.emb_size,
"product_surface_embedding_table",
)
candidate_product_surface_embeddings = self._single_hot_to_embeddings(
batch.candidate_product_surface, # type: ignore
config.product_surface_vocab_size,
config.emb_size,
"product_surface_embedding_table",
)
history_actions_embeddings = self._get_action_embeddings(batch.history_actions) # type: ignore
user_embeddings, user_padding_mask = block_user_reduce(
batch.user_hashes, # type: ignore
recsys_embeddings.user_embeddings, # type: ignore
hash_config.num_user_hashes,
config.emb_size,
1.0,
)
history_embeddings, history_padding_mask = block_history_reduce(
batch.history_post_hashes, # type: ignore
recsys_embeddings.history_post_embeddings, # type: ignore
recsys_embeddings.history_author_embeddings, # type: ignore
history_product_surface_embeddings,
history_actions_embeddings,
hash_config.num_item_hashes,
hash_config.num_author_hashes,
1.0,
)
candidate_embeddings, candidate_padding_mask = block_candidate_reduce(
batch.candidate_post_hashes, # type: ignore
recsys_embeddings.candidate_post_embeddings, # type: ignore
recsys_embeddings.candidate_author_embeddings, # type: ignore
candidate_product_surface_embeddings,
hash_config.num_item_hashes,
hash_config.num_author_hashes,
1.0,
)
embeddings = jnp.concatenate(
[user_embeddings, history_embeddings, candidate_embeddings], axis=1
)
padding_mask = jnp.concatenate(
[user_padding_mask, history_padding_mask, candidate_padding_mask], axis=1
)
candidate_start_offset = user_padding_mask.shape[1] + history_padding_mask.shape[1]
return embeddings.astype(self.fprop_dtype), padding_mask, candidate_start_offset
def __call__(
self,
batch: RecsysBatch,
recsys_embeddings: RecsysEmbeddings,
) -> RecsysModelOutput:
"""Forward pass for ranking candidates.
Args:
batch: RecsysBatch containing hashes, actions, product surfaces
recsys_embeddings: RecsysEmbeddings containing pre-looked-up embeddings
Returns:
RecsysModelOutput containing logits for each candidate. Shape = [B, num_candidates, num_actions]
"""
embeddings, padding_mask, candidate_start_offset = self.build_inputs(
batch, recsys_embeddings
)
# transformer
model_output = self.model(
embeddings,
padding_mask,
candidate_start_offset=candidate_start_offset,
)
out_embeddings = model_output.embeddings
out_embeddings = layer_norm(out_embeddings)
candidate_embeddings = out_embeddings[:, candidate_start_offset:, :]
unembeddings = self._get_unembedding()
logits = jnp.dot(candidate_embeddings.astype(unembeddings.dtype), unembeddings)
logits = logits.astype(self.fprop_dtype)
return RecsysModelOutput(logits=logits)
372
phoenix/recsys_retrieval_model.py Normal file
View File
@ -0,0 +1,372 @@
# Copyright 2026 X.AI Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from dataclasses import dataclass
from typing import Any, NamedTuple, Optional, Tuple
import haiku as hk
import jax
import jax.numpy as jnp
from grok import TransformerConfig, Transformer
from recsys_model import (
HashConfig,
RecsysBatch,
RecsysEmbeddings,
block_history_reduce,
block_user_reduce,
)
logger = logging.getLogger(__name__)
EPS = 1e-12
INF = 1e12
class RetrievalOutput(NamedTuple):
"""Output of the retrieval model."""
user_representation: jax.Array
top_k_indices: jax.Array
top_k_scores: jax.Array
@dataclass
class CandidateTower(hk.Module):
"""Candidate tower that projects post+author embeddings to a shared embedding space.
This tower takes the concatenated embeddings of a post and its author,
and projects them to a normalized representation suitable for similarity search.
"""
emb_size: int
name: Optional[str] = None
def __call__(self, post_author_embedding: jax.Array) -> jax.Array:
"""Project post+author embeddings to normalized representation.
Args:
post_author_embedding: Concatenated post and author embeddings
Shape: [B, C, num_hashes, D] or [B, num_hashes, D]
Returns:
Normalized candidate representation
Shape: [B, C, D] or [B, D]
"""
if len(post_author_embedding.shape) == 4:
B, C, _, _ = post_author_embedding.shape
post_author_embedding = jnp.reshape(post_author_embedding, (B, C, -1))
else:
B, _, _ = post_author_embedding.shape
post_author_embedding = jnp.reshape(post_author_embedding, (B, -1))
embed_init = hk.initializers.VarianceScaling(1.0, mode="fan_out")
proj_1 = hk.get_parameter(
"candidate_tower_projection_1",
[post_author_embedding.shape[-1], self.emb_size * 2],
dtype=jnp.float32,
init=embed_init,
)
proj_2 = hk.get_parameter(
"candidate_tower_projection_2",
[self.emb_size * 2, self.emb_size],
dtype=jnp.float32,
init=embed_init,
)
hidden = jnp.dot(post_author_embedding.astype(proj_1.dtype), proj_1)
hidden = jax.nn.silu(hidden)
candidate_embeddings = jnp.dot(hidden.astype(proj_2.dtype), proj_2)
candidate_norm_sq = jnp.sum(candidate_embeddings**2, axis=-1, keepdims=True)
candidate_norm = jnp.sqrt(jnp.maximum(candidate_norm_sq, EPS))
candidate_representation = candidate_embeddings / candidate_norm
return candidate_representation.astype(post_author_embedding.dtype)
@dataclass
class PhoenixRetrievalModelConfig:
"""Configuration for the Phoenix Retrieval Model.
This model uses the same transformer architecture as the Phoenix ranker
for encoding user representations.
"""
model: TransformerConfig
emb_size: int
history_seq_len: int = 128
candidate_seq_len: int = 32
name: Optional[str] = None
fprop_dtype: Any = jnp.bfloat16
hash_config: HashConfig = None # type: ignore
product_surface_vocab_size: int = 16
_initialized: bool = False
def __post_init__(self):
if self.hash_config is None:
self.hash_config = HashConfig()
def initialize(self):
self._initialized = True
return self
def make(self):
if not self._initialized:
logger.warning(f"PhoenixRetrievalModel {self.name} is not initialized. Initializing.")
self.initialize()
return PhoenixRetrievalModel(
model=self.model.make(),
config=self,
fprop_dtype=self.fprop_dtype,
)
@dataclass
class PhoenixRetrievalModel(hk.Module):
"""A two-tower retrieval model using the Phoenix transformer for user encoding.
This model implements the two-tower architecture for efficient retrieval:
- User Tower: Encodes user features + history using the Phoenix transformer
- Candidate Tower: Projects candidate embeddings to a shared space
The user and candidate representations are L2-normalized, enabling efficient
approximate nearest neighbor (ANN) search using dot product similarity.
"""
model: Transformer
config: PhoenixRetrievalModelConfig
fprop_dtype: Any = jnp.bfloat16
name: Optional[str] = None
def _get_action_embeddings(
self,
actions: jax.Array,
) -> jax.Array:
"""Convert multi-hot action vectors to embeddings."""
config = self.config
_, _, num_actions = actions.shape
D = config.emb_size
embed_init = hk.initializers.VarianceScaling(1.0, mode="fan_out")
action_projection = hk.get_parameter(
"action_projection",
[num_actions, D],
dtype=jnp.float32,
init=embed_init,
)
actions_signed = (2 * actions - 1).astype(jnp.float32)
action_emb = jnp.dot(actions_signed.astype(action_projection.dtype), action_projection)
valid_mask = jnp.any(actions, axis=-1, keepdims=True)
action_emb = action_emb * valid_mask
return action_emb.astype(self.fprop_dtype)
def _single_hot_to_embeddings(
self,
input: jax.Array,
vocab_size: int,
emb_size: int,
name: str,
) -> jax.Array:
"""Convert single-hot indices to embeddings via lookup table."""
embed_init = hk.initializers.VarianceScaling(1.0, mode="fan_out")
embedding_table = hk.get_parameter(
name,
[vocab_size, emb_size],
dtype=jnp.float32,
init=embed_init,
)
input_one_hot = jax.nn.one_hot(input, vocab_size)
output = jnp.dot(input_one_hot, embedding_table)
return output.astype(self.fprop_dtype)
def build_user_representation(
self,
batch: RecsysBatch,
recsys_embeddings: RecsysEmbeddings,
) -> Tuple[jax.Array, jax.Array]:
"""Build user representation from user features and history.
Uses the Phoenix transformer to encode user + history embeddings
into a single user representation vector.
Args:
batch: RecsysBatch containing hashes, actions, product surfaces
recsys_embeddings: RecsysEmbeddings containing pre-looked-up embeddings
Returns:
user_representation: L2-normalized user embedding [B, D]
user_norm: Pre-normalization L2 norm [B, 1]
"""
config = self.config
hash_config = config.hash_config
history_product_surface_embeddings = self._single_hot_to_embeddings(
batch.history_product_surface, # type: ignore
config.product_surface_vocab_size,
config.emb_size,
"product_surface_embedding_table",
)
history_actions_embeddings = self._get_action_embeddings(batch.history_actions) # type: ignore
user_embeddings, user_padding_mask = block_user_reduce(
batch.user_hashes, # type: ignore
recsys_embeddings.user_embeddings, # type: ignore
hash_config.num_user_hashes,
config.emb_size,
1.0,
)
history_embeddings, history_padding_mask = block_history_reduce(
batch.history_post_hashes, # type: ignore
recsys_embeddings.history_post_embeddings, # type: ignore
recsys_embeddings.history_author_embeddings, # type: ignore
history_product_surface_embeddings,
history_actions_embeddings,
hash_config.num_item_hashes,
hash_config.num_author_hashes,
1.0,
)
embeddings = jnp.concatenate([user_embeddings, history_embeddings], axis=1)
padding_mask = jnp.concatenate([user_padding_mask, history_padding_mask], axis=1)
model_output = self.model(
embeddings.astype(self.fprop_dtype),
padding_mask,
candidate_start_offset=None,
)
user_outputs = model_output.embeddings
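# Masked mean-pool over valid (non-padding) positions, then L2-normalize to obtain the user representation.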
mask_float = padding_mask.astype(jnp.float32)[:, :, None] # [B, T, 1]
user_embeddings_masked = user_outputs * mask_float
user_embedding_sum = jnp.sum(user_embeddings_masked, axis=1) # [B, D]
mask_sum = jnp.sum(mask_float, axis=1) # [B, 1]
user_representation = user_embedding_sum / jnp.maximum(mask_sum, 1.0)
user_norm_sq = jnp.sum(user_representation**2, axis=-1, keepdims=True)
user_norm = jnp.sqrt(jnp.maximum(user_norm_sq, EPS))
user_representation = user_representation / user_norm
return user_representation, user_norm
def build_candidate_representation(
self,
batch: RecsysBatch,
recsys_embeddings: RecsysEmbeddings,
) -> Tuple[jax.Array, jax.Array]:
"""Build candidate (item) representations.
Projects post + author embeddings to a shared embedding space
using the candidate tower MLP.
Args:
batch: RecsysBatch containing candidate hashes
recsys_embeddings: RecsysEmbeddings containing pre-looked-up embeddings
Returns:
candidate_representation: L2-normalized candidate embeddings [B, C, D]
candidate_padding_mask: Valid candidate mask [B, C]
"""
config = self.config
candidate_post_embeddings = recsys_embeddings.candidate_post_embeddings
candidate_author_embeddings = recsys_embeddings.candidate_author_embeddings
post_author_embedding = jnp.concatenate(
[candidate_post_embeddings, candidate_author_embeddings], axis=2
)
candidate_tower = CandidateTower(
emb_size=config.emb_size,
)
candidate_representation = candidate_tower(post_author_embedding)
candidate_padding_mask = (batch.candidate_post_hashes[:, :, 0] != 0).astype(jnp.bool_) # type: ignore
return candidate_representation, candidate_padding_mask
def __call__(
self,
batch: RecsysBatch,
recsys_embeddings: RecsysEmbeddings,
corpus_embeddings: jax.Array,
top_k: int,
corpus_mask: Optional[jax.Array] = None,
) -> RetrievalOutput:
"""Retrieve top-k candidates from corpus for each user.
Args:
batch: RecsysBatch containing hashes, actions, product surfaces
recsys_embeddings: RecsysEmbeddings containing pre-looked-up embeddings
corpus_embeddings: [N, D] normalized corpus candidate embeddings
top_k: Number of candidates to retrieve
corpus_mask: [N] optional mask for valid corpus entries
Returns:
RetrievalOutput containing user representation and top-k results
"""
user_representation, _ = self.build_user_representation(batch, recsys_embeddings)
top_k_indices, top_k_scores = self._retrieve_top_k(
user_representation, corpus_embeddings, top_k, corpus_mask
)
return RetrievalOutput(
user_representation=user_representation,
top_k_indices=top_k_indices,
top_k_scores=top_k_scores,
)
def _retrieve_top_k(
self,
user_representation: jax.Array,
corpus_embeddings: jax.Array,
top_k: int,
corpus_mask: Optional[jax.Array] = None,
) -> Tuple[jax.Array, jax.Array]:
"""Retrieve top-k candidates from a corpus for each user.
Args:
user_representation: [B, D] normalized user embeddings
corpus_embeddings: [N, D] normalized corpus candidate embeddings
top_k: Number of candidates to retrieve
corpus_mask: [N] optional mask for valid corpus entries
Returns:
top_k_indices: [B, K] indices of top-k candidates
top_k_scores: [B, K] similarity scores of top-k candidates
"""
scores = jnp.matmul(user_representation, corpus_embeddings.T)
if corpus_mask is not None:
scores = jnp.where(corpus_mask[None, :], scores, -INF)
top_k_scores, top_k_indices = jax.lax.top_k(scores, top_k)
return top_k_indices, top_k_scores
121
phoenix/run_ranker.py Normal file
View File
@ -0,0 +1,121 @@
# Copyright 2026 X.AI Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
from grok import TransformerConfig
from recsys_model import PhoenixModelConfig, HashConfig
from runners import RecsysInferenceRunner, ModelRunner, create_example_batch, ACTIONS
def main():
# Model configuration
emb_size = 128 # Embedding dimension
num_actions = len(ACTIONS) # Number of explicit engagement actions
history_seq_len = 32 # Max history length
candidate_seq_len = 8 # Max candidates to rank
# Hash configuration
hash_config = HashConfig(
num_user_hashes=2,
num_item_hashes=2,
num_author_hashes=2,
)
recsys_model = PhoenixModelConfig(
emb_size=emb_size,
num_actions=num_actions,
history_seq_len=history_seq_len,
candidate_seq_len=candidate_seq_len,
hash_config=hash_config,
product_surface_vocab_size=16,
model=TransformerConfig(
emb_size=emb_size,
widening_factor=2,
key_size=64,
num_q_heads=2,
num_kv_heads=2,
num_layers=2,
attn_output_multiplier=0.125,
),
)
# Create inference runner
inference_runner = RecsysInferenceRunner(
runner=ModelRunner(
model=recsys_model,
bs_per_device=0.125,
),
name="recsys_local",
)
print("Initializing model...")
inference_runner.initialize()
print("Model initialized!")
# Create example batch with simulated posts
print("\n" + "=" * 70)
print("RECOMMENDATION SYSTEM DEMO")
print("=" * 70)
batch_size = 1
example_batch, example_embeddings = create_example_batch(
batch_size=batch_size,
emb_size=emb_size,
history_len=history_seq_len,
num_candidates=candidate_seq_len,
num_actions=num_actions,
num_user_hashes=hash_config.num_user_hashes,
num_item_hashes=hash_config.num_item_hashes,
num_author_hashes=hash_config.num_author_hashes,
product_surface_vocab_size=recsys_model.product_surface_vocab_size,
)
action_names = [action.replace("_", " ").title() for action in ACTIONS]
# Count valid history items (where first post hash is non-zero)
valid_history_count = int((example_batch.history_post_hashes[:, :, 0] != 0).sum()) # type: ignore
print(f"\nUser has viewed {valid_history_count} posts in their history")
print(f"Ranking {candidate_seq_len} candidate posts...")
# Rank candidates
ranking_output = inference_runner.rank(example_batch, example_embeddings)
# Display results
scores = np.array(ranking_output.scores[0]) # [num_candidates, num_actions]
ranked_indices = np.array(ranking_output.ranked_indices[0]) # [num_candidates]
print("\n" + "-" * 70)
print("RANKING RESULTS (ordered by predicted 'Favorite Score' probability)")
print("-" * 70)
for rank, idx in enumerate(ranked_indices):
idx = int(idx)
print(f"\nRank {rank + 1}: ")
print(" Predicted engagement probabilities:")
for action_idx, action_name in enumerate(action_names):
prob = float(scores[idx, action_idx])
bar = "" * int(prob * 20) + "" * (20 - int(prob * 20))
print(f" {action_name:24s}: {bar} {prob:.3f}")
print("\n" + "=" * 70)
print("Demo complete!")
print("=" * 70)
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
main()
149
phoenix/run_retrieval.py Normal file
View File
@ -0,0 +1,149 @@
# Copyright 2026 X.AI Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
from grok import TransformerConfig
from recsys_model import HashConfig
from recsys_retrieval_model import PhoenixRetrievalModelConfig
from runners import (
RecsysRetrievalInferenceRunner,
RetrievalModelRunner,
create_example_batch,
create_example_corpus,
ACTIONS,
)
def main():
# Model configuration - same architecture as Phoenix ranker
emb_size = 128 # Embedding dimension
num_actions = len(ACTIONS) # Number of explicit engagement actions
history_seq_len = 32 # Max history length
candidate_seq_len = 8 # Max candidates per batch (for training)
# Hash configuration
hash_config = HashConfig(
num_user_hashes=2,
num_item_hashes=2,
num_author_hashes=2,
)
# Configure the retrieval model - uses same transformer as Phoenix
retrieval_model_config = PhoenixRetrievalModelConfig(
emb_size=emb_size,
history_seq_len=history_seq_len,
candidate_seq_len=candidate_seq_len,
hash_config=hash_config,
product_surface_vocab_size=16,
model=TransformerConfig(
emb_size=emb_size,
widening_factor=2,
key_size=64,
num_q_heads=2,
num_kv_heads=2,
num_layers=2,
attn_output_multiplier=0.125,
),
)
# Create inference runner
inference_runner = RecsysRetrievalInferenceRunner(
runner=RetrievalModelRunner(
model=retrieval_model_config,
bs_per_device=0.125,
),
name="retrieval_local",
)
print("Initializing retrieval model...")
inference_runner.initialize()
print("Model initialized!")
# Create example batch with simulated user and history
print("\n" + "=" * 70)
print("RETRIEVAL SYSTEM DEMO")
print("=" * 70)
batch_size = 2 # Two users for demo
example_batch, example_embeddings = create_example_batch(
batch_size=batch_size,
emb_size=emb_size,
history_len=history_seq_len,
num_candidates=candidate_seq_len,
num_actions=num_actions,
num_user_hashes=hash_config.num_user_hashes,
num_item_hashes=hash_config.num_item_hashes,
num_author_hashes=hash_config.num_author_hashes,
product_surface_vocab_size=16,
)
# Count valid history items
valid_history_count = int((example_batch.history_post_hashes[:, :, 0] != 0).sum()) # type: ignore
print(f"\nUsers have viewed {valid_history_count} posts total in their history")
# Step 1: Create a corpus of candidate posts
print("\n" + "-" * 70)
print("STEP 1: Creating Candidate Corpus")
print("-" * 70)
corpus_size = 1000 # Simulated corpus of 1000 posts
corpus_embeddings, corpus_post_ids = create_example_corpus(
corpus_size=corpus_size,
emb_size=emb_size,
seed=456,
)
print(f"Corpus size: {corpus_size} posts")
print(f"Corpus embeddings shape: {corpus_embeddings.shape}")
# Set corpus for retrieval
inference_runner.set_corpus(corpus_embeddings, corpus_post_ids)
# Step 2: Retrieve top-k candidates for each user
print("\n" + "-" * 70)
print("STEP 2: Retrieving Top-K Candidates")
print("-" * 70)
top_k = 10
retrieval_output = inference_runner.retrieve(
example_batch,
example_embeddings,
top_k=top_k,
)
print(f"\nRetrieved top {top_k} candidates for each of {batch_size} users:")
top_k_indices = np.array(retrieval_output.top_k_indices)
top_k_scores = np.array(retrieval_output.top_k_scores)
for user_idx in range(batch_size):
print(f"\n User {user_idx + 1}:")
print(f" {'Rank':<6} {'Post ID':<12} {'Score':<12}")
print(f" {'-' * 30}")
for rank in range(top_k):
post_id = top_k_indices[user_idx, rank]
score = top_k_scores[user_idx, rank]
bar = "" * int((score + 1) * 10) + "" * (20 - int((score + 1) * 10))
print(f" {rank + 1:<6} {post_id:<12} {bar} {score:.4f}")
print("\n" + "=" * 70)
print("Demo complete!")
print("=" * 70)
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
main()
729
phoenix/runners.py Normal file
View File
@ -0,0 +1,729 @@
# Copyright 2026 X.AI Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, List, NamedTuple, Optional, Tuple
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
from grok import TrainingState
from recsys_retrieval_model import PhoenixRetrievalModelConfig
from recsys_retrieval_model import RetrievalOutput as ModelRetrievalOutput
from recsys_model import (
PhoenixModelConfig,
RecsysBatch,
RecsysEmbeddings,
RecsysModelOutput,
)
rank_logger = logging.getLogger("rank")
def create_dummy_batch_from_config(
hash_config: Any,
history_len: int,
num_candidates: int,
num_actions: int,
batch_size: int = 1,
) -> RecsysBatch:
"""Create a dummy batch for initialization.
Args:
hash_config: HashConfig with num_user_hashes, num_item_hashes, num_author_hashes
history_len: History sequence length
num_candidates: Number of candidates
num_actions: Number of action types
batch_size: Batch size
Returns:
RecsysBatch with zeros
"""
return RecsysBatch(
user_hashes=np.zeros((batch_size, hash_config.num_user_hashes), dtype=np.int32),
history_post_hashes=np.zeros(
(batch_size, history_len, hash_config.num_item_hashes), dtype=np.int32
),
history_author_hashes=np.zeros(
(batch_size, history_len, hash_config.num_author_hashes), dtype=np.int32
),
history_actions=np.zeros((batch_size, history_len, num_actions), dtype=np.float32),
history_product_surface=np.zeros((batch_size, history_len), dtype=np.int32),
candidate_post_hashes=np.zeros(
(batch_size, num_candidates, hash_config.num_item_hashes), dtype=np.int32
),
candidate_author_hashes=np.zeros(
(batch_size, num_candidates, hash_config.num_author_hashes), dtype=np.int32
),
candidate_product_surface=np.zeros((batch_size, num_candidates), dtype=np.int32),
)
def create_dummy_embeddings_from_config(
hash_config: Any,
emb_size: int,
history_len: int,
num_candidates: int,
batch_size: int = 1,
) -> RecsysEmbeddings:
"""Create dummy embeddings for initialization.
Args:
hash_config: HashConfig with num_user_hashes, num_item_hashes, num_author_hashes
emb_size: Embedding dimension
history_len: History sequence length
num_candidates: Number of candidates
batch_size: Batch size
Returns:
RecsysEmbeddings with zeros
"""
return RecsysEmbeddings(
user_embeddings=np.zeros(
(batch_size, hash_config.num_user_hashes, emb_size), dtype=np.float32
),
history_post_embeddings=np.zeros(
(batch_size, history_len, hash_config.num_item_hashes, emb_size), dtype=np.float32
),
candidate_post_embeddings=np.zeros(
(batch_size, num_candidates, hash_config.num_item_hashes, emb_size),
dtype=np.float32,
),
history_author_embeddings=np.zeros(
(batch_size, history_len, hash_config.num_author_hashes, emb_size), dtype=np.float32
),
candidate_author_embeddings=np.zeros(
(batch_size, num_candidates, hash_config.num_author_hashes, emb_size),
dtype=np.float32,
),
)
@dataclass
class BaseModelRunner(ABC):
"""Base class for model runners with shared initialization logic."""
bs_per_device: float = 2.0
rng_seed: int = 42
@property
@abstractmethod
def model(self) -> Any:
"""Return the model config."""
pass
@property
def _model_name(self) -> str:
"""Return model name for logging."""
return "model"
@abstractmethod
def make_forward_fn(self):
"""Create the forward function. Must be implemented by subclasses."""
pass
def initialize(self):
"""Initialize the model runner."""
self.model.initialize()
self.model.fprop_dtype = jnp.bfloat16
num_local_gpus = len(jax.local_devices())
self.batch_size = max(1, int(self.bs_per_device * num_local_gpus))
rank_logger.info(f"Initializing {self._model_name}...")
self.forward = self.make_forward_fn()
@dataclass
class BaseInferenceRunner(ABC):
"""Base class for inference runners with shared dummy data creation."""
name: str
@property
@abstractmethod
def runner(self) -> BaseModelRunner:
"""Return the underlying model runner."""
pass
def _get_num_actions(self) -> int:
"""Get number of actions. Override in subclasses if needed."""
model_config = self.runner.model
if hasattr(model_config, "num_actions"):
return model_config.num_actions
return len(ACTIONS)  # fall back to the full action list defined below
def create_dummy_batch(self, batch_size: int = 1) -> RecsysBatch:
"""Create a dummy batch for initialization."""
model_config = self.runner.model
return create_dummy_batch_from_config(
hash_config=model_config.hash_config,
history_len=model_config.history_seq_len,
num_candidates=model_config.candidate_seq_len,
num_actions=self._get_num_actions(),
batch_size=batch_size,
)
def create_dummy_embeddings(self, batch_size: int = 1) -> RecsysEmbeddings:
"""Create dummy embeddings for initialization."""
model_config = self.runner.model
return create_dummy_embeddings_from_config(
hash_config=model_config.hash_config,
emb_size=model_config.emb_size,
history_len=model_config.history_seq_len,
num_candidates=model_config.candidate_seq_len,
batch_size=batch_size,
)
@abstractmethod
def initialize(self):
"""Initialize the inference runner. Must be implemented by subclasses."""
pass
ACTIONS: List[str] = [
"favorite_score",
"reply_score",
"repost_score",
"photo_expand_score",
"click_score",
"profile_click_score",
"vqv_score",
"share_score",
"share_via_dm_score",
"share_via_copy_link_score",
"dwell_score",
"quote_score",
"quoted_click_score",
"follow_author_score",
"not_interested_score",
"block_author_score",
"mute_author_score",
"report_score",
"dwell_time",
]
class RankingOutput(NamedTuple):
"""Output from ranking candidates.
Contains both the raw scores array and individual probability fields
for each engagement type.
"""
scores: jax.Array
ranked_indices: jax.Array
p_favorite_score: jax.Array
p_reply_score: jax.Array
p_repost_score: jax.Array
p_photo_expand_score: jax.Array
p_click_score: jax.Array
p_profile_click_score: jax.Array
p_vqv_score: jax.Array
p_share_score: jax.Array
p_share_via_dm_score: jax.Array
p_share_via_copy_link_score: jax.Array
p_dwell_score: jax.Array
p_quote_score: jax.Array
p_quoted_click_score: jax.Array
p_follow_author_score: jax.Array
p_not_interested_score: jax.Array
p_block_author_score: jax.Array
p_mute_author_score: jax.Array
p_report_score: jax.Array
p_dwell_time: jax.Array
@dataclass
class ModelRunner(BaseModelRunner):
"""Runner for the recommendation ranking model."""
_model: PhoenixModelConfig = None # type: ignore
def __init__(self, model: PhoenixModelConfig, bs_per_device: float = 2.0, rng_seed: int = 42):
self._model = model
self.bs_per_device = bs_per_device
self.rng_seed = rng_seed
@property
def model(self) -> PhoenixModelConfig:
return self._model
@property
def _model_name(self) -> str:
return "ranking model"
def make_forward_fn(self): # type: ignore
def forward(batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings):
out = self.model.make()(batch, recsys_embeddings)
return out
return hk.transform(forward)
def init(
self, rng: jax.Array, data: RecsysBatch, embeddings: RecsysEmbeddings
) -> TrainingState:
assert self.forward is not None
rng, init_rng = jax.random.split(rng)
params = self.forward.init(init_rng, data, embeddings)
return TrainingState(params=params)
def load_or_init(
self,
init_data: RecsysBatch,
init_embeddings: RecsysEmbeddings,
):
rng = jax.random.PRNGKey(self.rng_seed)
state = self.init(rng, init_data, init_embeddings)
return state
@dataclass
class RecsysInferenceRunner(BaseInferenceRunner):
"""Inference runner for the recommendation ranking model."""
_runner: ModelRunner
def __init__(self, runner: ModelRunner, name: str):
self.name = name
self._runner = runner
@property
def runner(self) -> ModelRunner:
return self._runner
def initialize(self):
"""Initialize the inference runner."""
runner = self.runner
dummy_batch = self.create_dummy_batch(batch_size=1)
dummy_embeddings = self.create_dummy_embeddings(batch_size=1)
runner.initialize()
state = runner.load_or_init(dummy_batch, dummy_embeddings)
self.params = state.params
@functools.lru_cache
def model():
return runner.model.make()
def hk_forward(
batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings
) -> RecsysModelOutput:
return model()(batch, recsys_embeddings)
def hk_rank_candidates(
batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings
) -> RankingOutput:
"""Rank candidates by their predicted engagement scores."""
output = hk_forward(batch, recsys_embeddings)
logits = output.logits
probs = jax.nn.sigmoid(logits)
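# Rank by the first output channel, which corresponds to ACTIONS[0]
# ("favorite_score").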
primary_scores = probs[:, :, 0]
ranked_indices = jnp.argsort(-primary_scores, axis=-1)
return RankingOutput(
scores=probs,
ranked_indices=ranked_indices,
p_favorite_score=probs[:, :, 0],
p_reply_score=probs[:, :, 1],
p_repost_score=probs[:, :, 2],
p_photo_expand_score=probs[:, :, 3],
p_click_score=probs[:, :, 4],
p_profile_click_score=probs[:, :, 5],
p_vqv_score=probs[:, :, 6],
p_share_score=probs[:, :, 7],
p_share_via_dm_score=probs[:, :, 8],
p_share_via_copy_link_score=probs[:, :, 9],
p_dwell_score=probs[:, :, 10],
p_quote_score=probs[:, :, 11],
p_quoted_click_score=probs[:, :, 12],
p_follow_author_score=probs[:, :, 13],
p_not_interested_score=probs[:, :, 14],
p_block_author_score=probs[:, :, 15],
p_mute_author_score=probs[:, :, 16],
p_report_score=probs[:, :, 17],
p_dwell_time=probs[:, :, 18],
)
rank_ = hk.without_apply_rng(hk.transform(hk_rank_candidates))
self.rank_candidates = rank_.apply
def rank(self, batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings) -> RankingOutput:
"""Rank candidates for the given batch.
Args:
batch: RecsysBatch containing hashes, actions, product surfaces
recsys_embeddings: RecsysEmbeddings containing pre-looked-up embeddings
Returns:
RankingOutput with scores and ranked indices
"""
return self.rank_candidates(self.params, batch, recsys_embeddings)
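# Illustrative usage (a sketch; assumes a PhoenixModelConfig named `config`
# has been constructed elsewhere):
#
#   runner = RecsysInferenceRunner(ModelRunner(model=config), name="ranking")
#   runner.initialize()
#   batch, embeddings = create_example_batch(
#       batch_size=2,
#       emb_size=config.emb_size,
#       history_len=config.history_seq_len,
#       num_candidates=config.candidate_seq_len,
#       num_actions=len(ACTIONS),
#   )
#   output = runner.rank(batch, embeddings)
#   top = output.ranked_indices[:, 0]  # best candidate per user by favorite score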
def create_example_batch(
batch_size: int,
emb_size: int,
history_len: int,
num_candidates: int,
num_actions: int,
num_user_hashes: int = 2,
num_item_hashes: int = 2,
num_author_hashes: int = 2,
product_surface_vocab_size: int = 16,
num_user_embeddings: int = 100000,
num_post_embeddings: int = 100000,
num_author_embeddings: int = 100000,
) -> Tuple[RecsysBatch, RecsysEmbeddings]:
"""Create an example batch with random data for testing.
This simulates a recommendation scenario where:
- We have a user with some embedding
- The user has interacted with some posts in their history
- We want to rank a set of candidate posts
Note on embedding table sizes:
The num_*_embeddings parameters define the size of the embedding tables for each
entity type. Hash values are generated in the range [1, num_*_embeddings) to ensure
they can be used as valid indices into the corresponding embedding tables.
Hash value 0 is reserved for padding/invalid entries.
Returns:
Tuple of (RecsysBatch, RecsysEmbeddings)
"""
rng = np.random.default_rng(42)
user_hashes = rng.integers(1, num_user_embeddings, size=(batch_size, num_user_hashes)).astype(
np.int32
)
history_post_hashes = rng.integers(
1, num_post_embeddings, size=(batch_size, history_len, num_item_hashes)
).astype(np.int32)
history_author_hashes = rng.integers(
    1, num_author_embeddings, size=(batch_size, history_len, num_author_hashes)
).astype(np.int32)
# Truncate each user's history to a random valid length, zero-padding post
# and author hashes together so the same tail positions are marked invalid.
for b in range(batch_size):
    valid_len = rng.integers(history_len // 2, history_len + 1)
    history_post_hashes[b, valid_len:, :] = 0
    history_author_hashes[b, valid_len:, :] = 0
history_actions = (rng.random(size=(batch_size, history_len, num_actions)) > 0.7).astype(
np.float32
)
history_product_surface = rng.integers(
0, product_surface_vocab_size, size=(batch_size, history_len)
).astype(np.int32)
candidate_post_hashes = rng.integers(
1, num_post_embeddings, size=(batch_size, num_candidates, num_item_hashes)
).astype(np.int32)
candidate_author_hashes = rng.integers(
1, num_author_embeddings, size=(batch_size, num_candidates, num_author_hashes)
).astype(np.int32)
candidate_product_surface = rng.integers(
0, product_surface_vocab_size, size=(batch_size, num_candidates)
).astype(np.int32)
batch = RecsysBatch(
user_hashes=user_hashes,
history_post_hashes=history_post_hashes,
history_author_hashes=history_author_hashes,
history_actions=history_actions,
history_product_surface=history_product_surface,
candidate_post_hashes=candidate_post_hashes,
candidate_author_hashes=candidate_author_hashes,
candidate_product_surface=candidate_product_surface,
)
embeddings = RecsysEmbeddings(
user_embeddings=rng.normal(size=(batch_size, num_user_hashes, emb_size)).astype(np.float32),
history_post_embeddings=rng.normal(
size=(batch_size, history_len, num_item_hashes, emb_size)
).astype(np.float32),
candidate_post_embeddings=rng.normal(
size=(batch_size, num_candidates, num_item_hashes, emb_size)
).astype(np.float32),
history_author_embeddings=rng.normal(
size=(batch_size, history_len, num_author_hashes, emb_size)
).astype(np.float32),
candidate_author_embeddings=rng.normal(
size=(batch_size, num_candidates, num_author_hashes, emb_size)
).astype(np.float32),
)
return batch, embeddings
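# Example (illustrative): two users, a 16-item history, eight candidates.
#
#   batch, embeddings = create_example_batch(
#       batch_size=2, emb_size=64, history_len=16,
#       num_candidates=8, num_actions=len(ACTIONS),
#   )
#   # batch.history_post_hashes: (2, 16, 2) int32, zero-padded at the tail;
#   # embeddings.history_post_embeddings: (2, 16, 2, 64) float32.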
class RetrievalOutput(NamedTuple):
"""Output from retrieval inference.
Contains user representations and retrieved candidates.
"""
user_representation: jax.Array
top_k_indices: jax.Array
top_k_scores: jax.Array
@dataclass
class RetrievalModelRunner(BaseModelRunner):
"""Runner for the Phoenix retrieval model."""
_model: PhoenixRetrievalModelConfig = None # type: ignore
def __init__(
self,
model: PhoenixRetrievalModelConfig,
bs_per_device: float = 2.0,
rng_seed: int = 42,
):
self._model = model
self.bs_per_device = bs_per_device
self.rng_seed = rng_seed
@property
def model(self) -> PhoenixRetrievalModelConfig:
return self._model
@property
def _model_name(self) -> str:
return "retrieval model"
def make_forward_fn(self): # type: ignore
def forward(
batch: RecsysBatch,
recsys_embeddings: RecsysEmbeddings,
corpus_embeddings: jax.Array,
top_k: int,
) -> ModelRetrievalOutput:
model = self.model.make()
out = model(batch, recsys_embeddings, corpus_embeddings, top_k)
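# Also trace the candidate tower so its parameters are created at init
# time; the representation itself is unused here.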
_ = model.build_candidate_representation(batch, recsys_embeddings)
return out
return hk.transform(forward)
def init(
self,
rng: jax.Array,
data: RecsysBatch,
embeddings: RecsysEmbeddings,
corpus_embeddings: jax.Array,
top_k: int,
) -> TrainingState:
assert self.forward is not None
rng, init_rng = jax.random.split(rng)
params = self.forward.init(init_rng, data, embeddings, corpus_embeddings, top_k)
return TrainingState(params=params)
def load_or_init(
self,
init_data: RecsysBatch,
init_embeddings: RecsysEmbeddings,
corpus_embeddings: jax.Array,
top_k: int,
):
rng = jax.random.PRNGKey(self.rng_seed)
state = self.init(rng, init_data, init_embeddings, corpus_embeddings, top_k)
return state
@dataclass
class RecsysRetrievalInferenceRunner(BaseInferenceRunner):
"""Inference runner for the Phoenix retrieval model.
This runner provides methods for:
1. Encoding users to get user representations
2. Encoding candidates to get candidate embeddings
3. Retrieving top-k candidates from a corpus
"""
_runner: RetrievalModelRunner = None # type: ignore
corpus_embeddings: jax.Array | None = None
corpus_post_ids: jax.Array | None = None
def __init__(self, runner: RetrievalModelRunner, name: str):
self.name = name
self._runner = runner
self.corpus_embeddings = None
self.corpus_post_ids = None
@property
def runner(self) -> RetrievalModelRunner:
return self._runner
def initialize(self):
"""Initialize the retrieval inference runner."""
runner = self.runner
dummy_batch = self.create_dummy_batch(batch_size=1)
dummy_embeddings = self.create_dummy_embeddings(batch_size=1)
dummy_corpus = jnp.zeros((10, runner.model.emb_size), dtype=jnp.float32)
dummy_top_k = 5
runner.initialize()
state = runner.load_or_init(dummy_batch, dummy_embeddings, dummy_corpus, dummy_top_k)
self.params = state.params
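# As in the ranking runner, build the module once and share it across the
# transformed functions below.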
@functools.lru_cache
def model():
return runner.model.make()
def hk_encode_user(batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings) -> jax.Array:
"""Encode user to get user representation."""
m = model()
user_rep, _ = m.build_user_representation(batch, recsys_embeddings)
return user_rep
def hk_encode_candidates(
batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings
) -> jax.Array:
"""Encode candidates to get candidate representations."""
m = model()
cand_rep, _ = m.build_candidate_representation(batch, recsys_embeddings)
return cand_rep
def hk_retrieve(
batch: RecsysBatch,
recsys_embeddings: RecsysEmbeddings,
corpus_embeddings: jax.Array,
top_k: int,
) -> "RetrievalOutput":
"""Retrieve top-k candidates from corpus."""
m = model()
return m(batch, recsys_embeddings, corpus_embeddings, top_k)
encode_user_ = hk.without_apply_rng(hk.transform(hk_encode_user))
encode_candidates_ = hk.without_apply_rng(hk.transform(hk_encode_candidates))
retrieve_ = hk.without_apply_rng(hk.transform(hk_retrieve))
self.encode_user_fn = encode_user_.apply
self.encode_candidates_fn = encode_candidates_.apply
self.retrieve_fn = retrieve_.apply
def encode_user(self, batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings) -> jax.Array:
"""Encode users to get user representations.
Args:
batch: RecsysBatch containing user and history information
recsys_embeddings: RecsysEmbeddings containing pre-looked-up embeddings
Returns:
User representations [B, D]
"""
return self.encode_user_fn(self.params, batch, recsys_embeddings)
def encode_candidates(
self, batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings
) -> jax.Array:
"""Encode candidates to get candidate representations.
Args:
batch: RecsysBatch containing candidate information
recsys_embeddings: RecsysEmbeddings containing pre-looked-up embeddings
Returns:
Candidate representations [B, C, D]
"""
return self.encode_candidates_fn(self.params, batch, recsys_embeddings)
def set_corpus(
self,
corpus_embeddings: jax.Array,
corpus_post_ids: jax.Array,
):
"""Set the corpus embeddings for retrieval.
Args:
corpus_embeddings: Pre-computed candidate embeddings [N, D]
corpus_post_ids: Post IDs corresponding to the embeddings [N]
"""
self.corpus_embeddings = corpus_embeddings
self.corpus_post_ids = corpus_post_ids
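# Illustrative indexing pattern (a sketch): encode candidates offline,
# flatten to an [N, D] corpus, then register it for retrieval.
#
#   cand_rep = runner.encode_candidates(batch, embeddings)  # [B, C, D]
#   corpus = cand_rep.reshape(-1, cand_rep.shape[-1])       # [B * C, D]
#   runner.set_corpus(corpus, post_ids)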
def retrieve(
self,
batch: RecsysBatch,
recsys_embeddings: RecsysEmbeddings,
top_k: int = 100,
corpus_embeddings: Optional[jax.Array] = None,
) -> RetrievalOutput:
"""Retrieve top-k candidates for users.
Args:
batch: RecsysBatch containing user and history information
recsys_embeddings: RecsysEmbeddings containing pre-looked-up embeddings
top_k: Number of candidates to retrieve per user
corpus_embeddings: Optional corpus embeddings; falls back to the corpus registered via set_corpus()
Returns:
RetrievalOutput with user representations and top-k candidates
"""
if corpus_embeddings is None:
    corpus_embeddings = self.corpus_embeddings
if corpus_embeddings is None:
    raise ValueError(
        "No corpus embeddings available: pass corpus_embeddings or call set_corpus() first."
    )
return self.retrieve_fn(self.params, batch, recsys_embeddings, corpus_embeddings, top_k)
def create_example_corpus(
corpus_size: int,
emb_size: int,
seed: int = 123,
) -> Tuple[jax.Array, jax.Array]:
"""Create example corpus embeddings for testing retrieval.
Args:
corpus_size: Number of candidates in corpus
emb_size: Embedding dimension
seed: Random seed
Returns:
Tuple of (corpus_embeddings [N, D], corpus_post_ids [N])
"""
rng = np.random.default_rng(seed)
corpus_embeddings = rng.normal(size=(corpus_size, emb_size)).astype(np.float32)
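# Unit-normalize so that dot-product retrieval scores are cosine similarities.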
norms = np.linalg.norm(corpus_embeddings, axis=-1, keepdims=True)
corpus_embeddings = corpus_embeddings / np.maximum(norms, 1e-12)
corpus_post_ids = np.arange(corpus_size, dtype=np.int64)
return jnp.array(corpus_embeddings), jnp.array(corpus_post_ids)
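# End-to-end retrieval sketch (illustrative; `retrieval_config` is assumed to
# be a PhoenixRetrievalModelConfig such as the one built in the tests):
#
#   runner = RecsysRetrievalInferenceRunner(
#       RetrievalModelRunner(model=retrieval_config), name="retrieval")
#   runner.initialize()
#   corpus_embeddings, corpus_post_ids = create_example_corpus(10_000, retrieval_config.emb_size)
#   runner.set_corpus(corpus_embeddings, corpus_post_ids)
#   out = runner.retrieve(batch, embeddings, top_k=100)
#   retrieved_ids = corpus_post_ids[out.top_k_indices]  # [B, 100] post IDs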

View File

@ -0,0 +1,187 @@
# Copyright 2026 X.AI Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax.numpy as jnp
import numpy as np
import pytest
from grok import make_recsys_attn_mask
class TestMakeRecsysAttnMask:
"""Tests for the make_recsys_attn_mask function."""
def test_output_shape(self):
"""Test that the output has the correct shape [1, 1, seq_len, seq_len]."""
seq_len = 10
candidate_start_offset = 5
mask = make_recsys_attn_mask(seq_len, candidate_start_offset)
assert mask.shape == (1, 1, seq_len, seq_len)
def test_user_history_has_causal_attention(self):
"""Test that user+history positions (before candidate_start_offset) have causal attention."""
seq_len = 8
candidate_start_offset = 5
mask = make_recsys_attn_mask(seq_len, candidate_start_offset)
mask_2d = mask[0, 0]
for i in range(candidate_start_offset):
for j in range(candidate_start_offset):
if j <= i:
assert mask_2d[i, j] == 1, f"Position {i} should attend to position {j}"
else:
assert mask_2d[i, j] == 0, (
f"Position {i} should NOT attend to future position {j}"
)
def test_candidates_attend_to_user_history(self):
"""Test that candidates can attend to all user+history positions."""
seq_len = 8
candidate_start_offset = 5
mask = make_recsys_attn_mask(seq_len, candidate_start_offset)
mask_2d = mask[0, 0]
for candidate_pos in range(candidate_start_offset, seq_len):
for history_pos in range(candidate_start_offset):
assert mask_2d[candidate_pos, history_pos] == 1, (
f"Candidate at {candidate_pos} should attend to user+history at {history_pos}"
)
def test_candidates_attend_to_themselves(self):
"""Test that candidates can attend to themselves (self-attention)."""
seq_len = 8
candidate_start_offset = 5
mask = make_recsys_attn_mask(seq_len, candidate_start_offset)
mask_2d = mask[0, 0]
for candidate_pos in range(candidate_start_offset, seq_len):
assert mask_2d[candidate_pos, candidate_pos] == 1, (
f"Candidate at {candidate_pos} should attend to itself"
)
def test_candidates_do_not_attend_to_other_candidates(self):
"""Test that candidates cannot attend to other candidates."""
seq_len = 8
candidate_start_offset = 5
mask = make_recsys_attn_mask(seq_len, candidate_start_offset)
mask_2d = mask[0, 0]
for query_pos in range(candidate_start_offset, seq_len):
for key_pos in range(candidate_start_offset, seq_len):
if query_pos != key_pos:
assert mask_2d[query_pos, key_pos] == 0, (
f"Candidate at {query_pos} should NOT attend to candidate at {key_pos}"
)
def test_full_mask_structure(self):
"""Test the complete mask structure with a small example."""
# Sequence: [user, h1, h2, c1, c2, c3]
# Positions: 0 1 2 3 4 5
seq_len = 6
candidate_start_offset = 3
mask = make_recsys_attn_mask(seq_len, candidate_start_offset)
mask_2d = mask[0, 0]
# Expected mask structure:
# Query positions are rows, key positions are columns
# 1 = can attend, 0 = cannot attend
#
# Keys: u h1 h2 c1 c2 c3
# Query u : 1 0 0 0 0 0
# Query h1 : 1 1 0 0 0 0
# Query h2 : 1 1 1 0 0 0
# Query c1 : 1 1 1 1 0 0 <- c1 attends to user+history + self
# Query c2 : 1 1 1 0 1 0 <- c2 attends to user+history + self
# Query c3 : 1 1 1 0 0 1 <- c3 attends to user+history + self
expected = np.array(
[
[1, 0, 0, 0, 0, 0], # user
[1, 1, 0, 0, 0, 0], # h1
[1, 1, 1, 0, 0, 0], # h2
[1, 1, 1, 1, 0, 0], # c1: user+history + self
[1, 1, 1, 0, 1, 0], # c2: user+history + self
[1, 1, 1, 0, 0, 1], # c3: user+history + self
],
dtype=np.float32,
)
np.testing.assert_array_equal(
np.array(mask_2d),
expected,
err_msg="Full mask structure does not match expected pattern",
)
def test_dtype_preserved(self):
"""Test that the specified dtype is used."""
seq_len = 5
candidate_start_offset = 3
mask_f32 = make_recsys_attn_mask(seq_len, candidate_start_offset, dtype=jnp.float32)
mask_f16 = make_recsys_attn_mask(seq_len, candidate_start_offset, dtype=jnp.float16)
assert mask_f32.dtype == jnp.float32
assert mask_f16.dtype == jnp.float16
def test_single_candidate(self):
"""Test edge case with a single candidate."""
seq_len = 4
candidate_start_offset = 3
mask = make_recsys_attn_mask(seq_len, candidate_start_offset)
mask_2d = mask[0, 0]
expected = np.array(
[
[1, 0, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 0],
[1, 1, 1, 1],
],
dtype=np.float32,
)
np.testing.assert_array_equal(np.array(mask_2d), expected)
def test_all_candidates(self):
"""Test edge case where all positions except first are candidates."""
seq_len = 4
candidate_start_offset = 1
mask = make_recsys_attn_mask(seq_len, candidate_start_offset)
mask_2d = mask[0, 0]
expected = np.array(
[
[1, 0, 0, 0], # user
[1, 1, 0, 0], # c1: user + self
[1, 0, 1, 0], # c2: user + self
[1, 0, 0, 1], # c3: user + self
],
dtype=np.float32,
)
np.testing.assert_array_equal(np.array(mask_2d), expected)
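# A minimal reference implementation consistent with the expected patterns
# above (a reader-facing sketch, not the actual grok.make_recsys_attn_mask
# source): causal attention over the user+history prefix, plus candidates
# that attend to the full prefix and to themselves only.
def _make_recsys_attn_mask_sketch(seq_len, candidate_start_offset, dtype=jnp.float32):
    causal = jnp.tril(jnp.ones((seq_len, seq_len)))
    is_candidate = jnp.arange(seq_len) >= candidate_start_offset
    cand_block = jnp.outer(is_candidate, is_candidate)
    # Remove candidate->candidate attention except each candidate's own diagonal.
    mask = causal * (1 - cand_block) + jnp.eye(seq_len) * cand_block
    return mask[None, None].astype(dtype)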
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@ -0,0 +1,359 @@
# Copyright 2026 X.AI Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the Phoenix Retrieval Model."""
import unittest
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
from grok import TransformerConfig
from recsys_model import HashConfig
from recsys_retrieval_model import (
CandidateTower,
PhoenixRetrievalModelConfig,
)
from runners import (
RecsysRetrievalInferenceRunner,
RetrievalModelRunner,
create_example_batch,
create_example_corpus,
)
class TestCandidateTower(unittest.TestCase):
"""Tests for the CandidateTower module."""
def test_candidate_tower_output_shape(self):
"""Test that candidate tower produces correct output shape."""
emb_size = 64
batch_size = 4
num_candidates = 8
num_hashes = 4
def forward(x):
tower = CandidateTower(emb_size=emb_size)
return tower(x)
forward_fn = hk.without_apply_rng(hk.transform(forward))
rng = jax.random.PRNGKey(0)
x = jax.random.normal(rng, (batch_size, num_candidates, num_hashes, emb_size))
params = forward_fn.init(rng, x)
output = forward_fn.apply(params, x)
self.assertEqual(output.shape, (batch_size, num_candidates, emb_size))
def test_candidate_tower_normalized(self):
"""Test that candidate tower output is L2 normalized."""
emb_size = 64
batch_size = 4
num_candidates = 8
num_hashes = 4
def forward(x):
tower = CandidateTower(emb_size=emb_size)
return tower(x)
forward_fn = hk.without_apply_rng(hk.transform(forward))
rng = jax.random.PRNGKey(0)
x = jax.random.normal(rng, (batch_size, num_candidates, num_hashes, emb_size))
params = forward_fn.init(rng, x)
output = forward_fn.apply(params, x)
norms = jnp.sqrt(jnp.sum(output**2, axis=-1))
np.testing.assert_array_almost_equal(norms, jnp.ones_like(norms), decimal=5)
def test_candidate_tower_mean_pooling(self):
"""Test candidate tower with mean pooling (no linear projection)."""
emb_size = 64
batch_size = 4
num_candidates = 8
num_hashes = 4
def forward(x):
tower = CandidateTower(emb_size=emb_size)
return tower(x)
forward_fn = hk.without_apply_rng(hk.transform(forward))
rng = jax.random.PRNGKey(0)
x = jax.random.normal(rng, (batch_size, num_candidates, num_hashes, emb_size))
params = forward_fn.init(rng, x)
output = forward_fn.apply(params, x)
self.assertEqual(output.shape, (batch_size, num_candidates, emb_size))
norms = jnp.sqrt(jnp.sum(output**2, axis=-1))
np.testing.assert_array_almost_equal(norms, jnp.ones_like(norms), decimal=5)
class TestPhoenixRetrievalModel(unittest.TestCase):
"""Tests for the full Phoenix Retrieval Model."""
def setUp(self):
"""Set up test fixtures."""
self.emb_size = 64
self.history_seq_len = 16
self.candidate_seq_len = 8
self.batch_size = 2
self.num_actions = 19
self.corpus_size = 100
self.top_k = 10
self.hash_config = HashConfig(
num_user_hashes=2,
num_item_hashes=2,
num_author_hashes=2,
)
self.config = PhoenixRetrievalModelConfig(
emb_size=self.emb_size,
history_seq_len=self.history_seq_len,
candidate_seq_len=self.candidate_seq_len,
hash_config=self.hash_config,
product_surface_vocab_size=16,
model=TransformerConfig(
emb_size=self.emb_size,
widening_factor=2,
key_size=32,
num_q_heads=2,
num_kv_heads=2,
num_layers=1,
attn_output_multiplier=0.125,
),
)
def _create_test_batch(self) -> tuple:
"""Create test batch and embeddings."""
return create_example_batch(
batch_size=self.batch_size,
emb_size=self.emb_size,
history_len=self.history_seq_len,
num_candidates=self.candidate_seq_len,
num_actions=self.num_actions,
num_user_hashes=self.hash_config.num_user_hashes,
num_item_hashes=self.hash_config.num_item_hashes,
num_author_hashes=self.hash_config.num_author_hashes,
product_surface_vocab_size=16,
)
def _create_test_corpus(self):
"""Create test corpus embeddings."""
return create_example_corpus(self.corpus_size, self.emb_size)
def test_model_forward(self):
"""Test model forward pass produces correct output shapes."""
def forward(batch, embeddings, corpus_embeddings, top_k):
model = self.config.make()
return model(batch, embeddings, corpus_embeddings, top_k)
forward_fn = hk.without_apply_rng(hk.transform(forward))
batch, embeddings = self._create_test_batch()
corpus_embeddings, _ = self._create_test_corpus()
rng = jax.random.PRNGKey(0)
params = forward_fn.init(rng, batch, embeddings, corpus_embeddings, self.top_k)
output = forward_fn.apply(params, batch, embeddings, corpus_embeddings, self.top_k)
self.assertEqual(output.user_representation.shape, (self.batch_size, self.emb_size))
self.assertEqual(output.top_k_indices.shape, (self.batch_size, self.top_k))
self.assertEqual(output.top_k_scores.shape, (self.batch_size, self.top_k))
def test_user_representation_normalized(self):
"""Test that user representations are L2 normalized."""
def forward(batch, embeddings, corpus_embeddings, top_k):
model = self.config.make()
return model(batch, embeddings, corpus_embeddings, top_k)
forward_fn = hk.without_apply_rng(hk.transform(forward))
batch, embeddings = self._create_test_batch()
corpus_embeddings, _ = self._create_test_corpus()
rng = jax.random.PRNGKey(0)
params = forward_fn.init(rng, batch, embeddings, corpus_embeddings, self.top_k)
output = forward_fn.apply(params, batch, embeddings, corpus_embeddings, self.top_k)
norms = jnp.sqrt(jnp.sum(output.user_representation**2, axis=-1))
np.testing.assert_array_almost_equal(norms, jnp.ones(self.batch_size), decimal=5)
def test_candidate_representation_normalized(self):
"""Test that candidate representations from build_candidate_representation are L2 normalized."""
def forward(batch, embeddings):
model = self.config.make()
cand_rep, _ = model.build_candidate_representation(batch, embeddings)
return cand_rep
forward_fn = hk.without_apply_rng(hk.transform(forward))
batch, embeddings = self._create_test_batch()
rng = jax.random.PRNGKey(0)
params = forward_fn.init(rng, batch, embeddings)
cand_rep = forward_fn.apply(params, batch, embeddings)
norms = jnp.sqrt(jnp.sum(cand_rep**2, axis=-1))
np.testing.assert_array_almost_equal(
norms, jnp.ones((self.batch_size, self.candidate_seq_len)), decimal=5
)
def test_retrieve_top_k(self):
"""Test top-k retrieval through __call__."""
def forward(batch, embeddings, corpus_embeddings, top_k):
model = self.config.make()
return model(batch, embeddings, corpus_embeddings, top_k)
forward_fn = hk.without_apply_rng(hk.transform(forward))
batch, embeddings = self._create_test_batch()
corpus_embeddings, _ = self._create_test_corpus()
rng = jax.random.PRNGKey(0)
params = forward_fn.init(rng, batch, embeddings, corpus_embeddings, self.top_k)
output = forward_fn.apply(params, batch, embeddings, corpus_embeddings, self.top_k)
self.assertEqual(output.top_k_indices.shape, (self.batch_size, self.top_k))
self.assertEqual(output.top_k_scores.shape, (self.batch_size, self.top_k))
self.assertTrue(jnp.all(output.top_k_indices >= 0))
self.assertTrue(jnp.all(output.top_k_indices < self.corpus_size))
for b in range(self.batch_size):
scores = np.array(output.top_k_scores[b])
self.assertTrue(np.all(scores[:-1] >= scores[1:]))
class TestRetrievalInferenceRunner(unittest.TestCase):
"""Tests for the retrieval inference runner."""
def setUp(self):
"""Set up test fixtures."""
self.emb_size = 64
self.history_seq_len = 16
self.candidate_seq_len = 8
self.batch_size = 2
self.num_actions = 19
self.hash_config = HashConfig(
num_user_hashes=2,
num_item_hashes=2,
num_author_hashes=2,
)
self.config = PhoenixRetrievalModelConfig(
emb_size=self.emb_size,
history_seq_len=self.history_seq_len,
candidate_seq_len=self.candidate_seq_len,
hash_config=self.hash_config,
product_surface_vocab_size=16,
model=TransformerConfig(
emb_size=self.emb_size,
widening_factor=2,
key_size=32,
num_q_heads=2,
num_kv_heads=2,
num_layers=1,
attn_output_multiplier=0.125,
),
)
def test_runner_initialization(self):
"""Test that runner initializes correctly."""
runner = RecsysRetrievalInferenceRunner(
runner=RetrievalModelRunner(
model=self.config,
bs_per_device=0.125,
),
name="test_retrieval",
)
runner.initialize()
self.assertIsNotNone(runner.params)
def test_runner_encode_user(self):
"""Test user encoding through runner."""
runner = RecsysRetrievalInferenceRunner(
runner=RetrievalModelRunner(
model=self.config,
bs_per_device=0.125,
),
name="test_retrieval",
)
runner.initialize()
batch, embeddings = create_example_batch(
batch_size=self.batch_size,
emb_size=self.emb_size,
history_len=self.history_seq_len,
num_candidates=self.candidate_seq_len,
num_actions=self.num_actions,
num_user_hashes=self.hash_config.num_user_hashes,
num_item_hashes=self.hash_config.num_item_hashes,
num_author_hashes=self.hash_config.num_author_hashes,
)
user_rep = runner.encode_user(batch, embeddings)
self.assertEqual(user_rep.shape, (self.batch_size, self.emb_size))
def test_runner_retrieve(self):
"""Test retrieval through runner."""
runner = RecsysRetrievalInferenceRunner(
runner=RetrievalModelRunner(
model=self.config,
bs_per_device=0.125,
),
name="test_retrieval",
)
runner.initialize()
batch, embeddings = create_example_batch(
batch_size=self.batch_size,
emb_size=self.emb_size,
history_len=self.history_seq_len,
num_candidates=self.candidate_seq_len,
num_actions=self.num_actions,
num_user_hashes=self.hash_config.num_user_hashes,
num_item_hashes=self.hash_config.num_item_hashes,
num_author_hashes=self.hash_config.num_author_hashes,
)
corpus_size = 100
corpus_embeddings, corpus_post_ids = create_example_corpus(corpus_size, self.emb_size)
runner.set_corpus(corpus_embeddings, corpus_post_ids)
top_k = 10
output = runner.retrieve(batch, embeddings, top_k=top_k)
self.assertEqual(output.user_representation.shape, (self.batch_size, self.emb_size))
self.assertEqual(output.top_k_indices.shape, (self.batch_size, top_k))
self.assertEqual(output.top_k_scores.shape, (self.batch_size, top_k))
if __name__ == "__main__":
unittest.main()

372
phoenix/uv.lock Normal file
View File

@ -0,0 +1,372 @@
version = 1
revision = 3
requires-python = ">=3.11"
resolution-markers = [
"python_full_version >= '3.13' and sys_platform == 'darwin'",
"python_full_version == '3.12.*' and sys_platform == 'darwin'",
"python_full_version < '3.12' and sys_platform == 'darwin'",
"python_full_version >= '3.13' and sys_platform == 'linux'",
"python_full_version == '3.12.*' and sys_platform == 'linux'",
"python_full_version < '3.12' and sys_platform == 'linux'",
]
supported-markers = [
"sys_platform == 'darwin'",
"sys_platform == 'linux'",
]
[[package]]
name = "absl-py"
version = "2.3.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/10/2a/c93173ffa1b39c1d0395b7e842bbdc62e556ca9d8d3b5572926f3e4ca752/absl_py-2.3.1.tar.gz", hash = "sha256:a97820526f7fbfd2ec1bce83f3f25e3a14840dac0d8e02a0b71cd75db3f77fc9", size = 116588, upload-time = "2025-07-03T09:31:44.05Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/8f/aa/ba0014cc4659328dc818a28827be78e6d97312ab0cb98105a770924dc11e/absl_py-2.3.1-py3-none-any.whl", hash = "sha256:eeecf07f0c2a93ace0772c92e596ace6d3d3996c042b2128459aaae2a76de11d", size = 135811, upload-time = "2025-07-03T09:31:42.253Z" },
]
[[package]]
name = "dm-haiku"
version = "0.0.16"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "absl-py", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "jmp", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "tabulate", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/2a/fc/daf4689198f4c0af8b71611f39fcd5d68ce0ae59fa919b9e58192a7d70f5/dm_haiku-0.0.16.tar.gz", hash = "sha256:1830b0ce63c5cef2fb3a63a13033c9d8f612ee7f896f2b0b25a6ba484f5fad28", size = 263092, upload-time = "2025-12-17T15:55:35.145Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/14/91/0f53835d0292a74e6b37e68125b669827e2a75a26e01c34741d6c13cca6c/dm_haiku-0.0.16-py3-none-any.whl", hash = "sha256:cc355d4d5aaa85af20e5a23ccd278bc751232ac8e5971261bed39318c07d744f", size = 374267, upload-time = "2025-12-17T15:55:33.9Z" },
]
[[package]]
name = "grok-1"
version = "0.1.0"
source = { virtual = "." }
dependencies = [
{ name = "dm-haiku", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "jax", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pyright", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
[package.dev-dependencies]
dev = [
{ name = "pytest", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
[package.metadata]
requires-dist = [
{ name = "dm-haiku", specifier = ">=0.0.13" },
{ name = "jax", specifier = "==0.8.1" },
{ name = "numpy", specifier = ">=1.26.4" },
{ name = "pyright", specifier = ">=1.1.408" },
]
[package.metadata.requires-dev]
dev = [{ name = "pytest" }]
[[package]]
name = "iniconfig"
version = "2.3.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" },
]
[[package]]
name = "jax"
version = "0.8.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "jaxlib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "ml-dtypes", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "opt-einsum", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "scipy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/32/82/84fd2c662e4d410a34b0402de9b56bb69d7f72d1b875c3ae0edc07df18cc/jax-0.8.1.tar.gz", hash = "sha256:e53f67b15315f5e154851a7fd77a192b59c6c75b3f7ac56e214296765391cca7", size = 2509320, upload-time = "2025-11-18T19:50:02.609Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/f9/e7/19b8cfc8963b2e10a01a4db7bb27ec5fa39ecd024bc62f8e2d1de5625a9d/jax-0.8.1-py3-none-any.whl", hash = "sha256:4cbdc5548f3095cdd69d38e4337950b2fc1f250a740a0234d190e4a319077564", size = 2922137, upload-time = "2025-11-18T19:47:43.693Z" },
]
[[package]]
name = "jaxlib"
version = "0.8.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "ml-dtypes", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "scipy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/fe/8b/9babcf487c6f1b533bca9611124c4d9593367c058a96d326c7e70db7d334/jaxlib-0.8.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:865add56139883405f3f15c9b0de6a64ab8f4aa549dff196b72dbc86be6ccc1f", size = 55719927, upload-time = "2025-11-18T19:48:42.679Z" },
{ url = "https://files.pythonhosted.org/packages/df/0c/b8c67272647ea151b0ac651e43faa846b4987d971058683dcce8abf68bca/jaxlib-0.8.1-cp311-cp311-manylinux_2_27_aarch64.whl", hash = "sha256:ff32b6320d729131efaf22939825b52d75957c84c32af2b0b1bdb33cf27ba75f", size = 74208199, upload-time = "2025-11-18T19:48:45.848Z" },
{ url = "https://files.pythonhosted.org/packages/8f/d0/5b83d614eddb58a2cc97fb948bfeb84509b90da04e808273bf9ae89ad6c1/jaxlib-0.8.1-cp311-cp311-manylinux_2_27_x86_64.whl", hash = "sha256:22f489fb5c8be0da7be5e4957a10936b3760a169668f8b25c5d09c51c3ef47f6", size = 80247963, upload-time = "2025-11-18T19:48:49.443Z" },
{ url = "https://files.pythonhosted.org/packages/d9/9d/59b36e2f348e599d5812743f263ca54aa03be1a4c9dfc11504d19864b72d/jaxlib-0.8.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:88bde0f535eeea6689e0cd57d40b7660d5206ac95c7d42e09562a109b963a49f", size = 55728156, upload-time = "2025-11-18T19:48:56.254Z" },
{ url = "https://files.pythonhosted.org/packages/7e/73/2aa891de9f5f4c60ba3c63bda97ec4ace50ffb900ff3bf750ce42c514a3b/jaxlib-0.8.1-cp312-cp312-manylinux_2_27_aarch64.whl", hash = "sha256:bed1e94ae8c7c16bca4476d8d7f582f0d1a102a4e69c3a9bd2069a0dc42274a9", size = 74209108, upload-time = "2025-11-18T19:48:59.572Z" },
{ url = "https://files.pythonhosted.org/packages/eb/4b/3c7e373d81219ee7493c1581c85a926c413ddeb3794cff87a37023a337e4/jaxlib-0.8.1-cp312-cp312-manylinux_2_27_x86_64.whl", hash = "sha256:af4924189fc53b69237715b56ebcbfc71bb91ca16184143dcef0d430c8173de6", size = 80256943, upload-time = "2025-11-18T19:49:02.92Z" },
{ url = "https://files.pythonhosted.org/packages/f8/67/97c62849b5d8fc075f902201ff136ad224a2ef113d1fa655ece0ffe8b2a4/jaxlib-0.8.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a0349f6e8179dc897d33aeb90ec66b4a8041330fbbba8d071dc6167cd2271539", size = 55726611, upload-time = "2025-11-18T19:49:09.162Z" },
{ url = "https://files.pythonhosted.org/packages/fd/2a/9fb7599e43d66958b6a9859e045b605afea31f7fd96cfa35a7a8e978b0f8/jaxlib-0.8.1-cp313-cp313-manylinux_2_27_aarch64.whl", hash = "sha256:bd697c171ace1e2e9d6ed910a78f385b3c4095cee290b0255aa58848f2acdeab", size = 74207596, upload-time = "2025-11-18T19:49:12.39Z" },
{ url = "https://files.pythonhosted.org/packages/7d/61/ab5c98641e15f9844dd49efbf6f22c6a9c5d17304319e5be8c51a1dfd088/jaxlib-0.8.1-cp313-cp313-manylinux_2_27_x86_64.whl", hash = "sha256:d245bd6a279c72ca5f796df84cdd64d7c9c8abc4b8d89adf4acf45898dab958b", size = 80254560, upload-time = "2025-11-18T19:49:16.172Z" },
{ url = "https://files.pythonhosted.org/packages/97/65/e7c625f1fdb54d45ac248d8398a28d6c02528c31feaa6e1c146a08192d77/jaxlib-0.8.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:4933298fcfb07a5aa2d1fed21c111d07cea50e6f180dba2cdb5463c13fb98f2f", size = 55835933, upload-time = "2025-11-18T19:49:27.362Z" },
{ url = "https://files.pythonhosted.org/packages/1f/04/e09ff7b5ba0af93501cb196c65103a30e5050083203c1ff581f18718a356/jaxlib-0.8.1-cp313-cp313t-manylinux_2_27_aarch64.whl", hash = "sha256:f2f11491b077d05249d63813e811401194a41edc8e9cc60af8f4b554057cfad0", size = 74323389, upload-time = "2025-11-18T19:49:30.457Z" },
{ url = "https://files.pythonhosted.org/packages/44/9f/8b7f6ad9eebf8946e73049dae85f86544f5743bc8b2190898415646fa7ec/jaxlib-0.8.1-cp313-cp313t-manylinux_2_27_x86_64.whl", hash = "sha256:7a5d381fad89622750fae29fab83c0847e2931ad8d6a34dc13b28fc4d67f75a3", size = 80358249, upload-time = "2025-11-18T19:49:33.682Z" },
{ url = "https://files.pythonhosted.org/packages/47/6d/75943de28285afcc8d62e89c3e0efc0abdb7e7a72a9e967c3555fc9a35af/jaxlib-0.8.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:90e48973f8dbded7edc8728be84c01ae00412190187fb06622abfa4edd42c0a8", size = 55729587, upload-time = "2025-11-18T19:49:36.952Z" },
{ url = "https://files.pythonhosted.org/packages/2c/ce/9e68ca9f646039d687a94066a5e3e195fc70cebdfbe44945b3c53ceed321/jaxlib-0.8.1-cp314-cp314-manylinux_2_27_aarch64.whl", hash = "sha256:1a4001ed3ba9ed5a812da1b16f52eebb5d473a4480c1523828c7bd3dae8d1375", size = 74222294, upload-time = "2025-11-18T19:49:40.418Z" },
{ url = "https://files.pythonhosted.org/packages/3c/0f/988a413cbf610610cb14783a6e0964a854d0f388ccafe9b4e61c2c188b88/jaxlib-0.8.1-cp314-cp314-manylinux_2_27_x86_64.whl", hash = "sha256:fdbbf2336c08bbf8f30548e204c8c9d77f8b2a3a5b7fc7985749246feb8852b0", size = 80268801, upload-time = "2025-11-18T19:49:44.943Z" },
{ url = "https://files.pythonhosted.org/packages/07/9b/f6f01d79f519b0cbd09a6c751844b1e0294fc53ea0b09882466b21169ea5/jaxlib-0.8.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:92c41c9b9862c08521eb90515a7c5bcc840c6d30f86230cebf94aea2d6a0af81", size = 55834325, upload-time = "2025-11-18T19:49:52.541Z" },
{ url = "https://files.pythonhosted.org/packages/61/c7/13d13a6f0b0d2e91431d6a031129d51ea4b23af23bb947882234ed003f09/jaxlib-0.8.1-cp314-cp314t-manylinux_2_27_aarch64.whl", hash = "sha256:1bc76edec2bc74a7adb5e29329ece51a67c57cd011a06d55d07da62fbabe3389", size = 74320131, upload-time = "2025-11-18T19:49:56.208Z" },
{ url = "https://files.pythonhosted.org/packages/cd/8a/6cad418c0f11ce0cffa2b74b81fb76e6cf30247288fea75a372b6b163f2e/jaxlib-0.8.1-cp314-cp314t-manylinux_2_27_x86_64.whl", hash = "sha256:117f2fe2c19479e560ad85a3ef2fcc0b1d24816456f0d039f865c2acbab63b5a", size = 80360481, upload-time = "2025-11-18T19:50:00.065Z" },
]
[[package]]
name = "jmp"
version = "0.0.4"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ab/b0/e90fbbffef4b345329c878a69f0336d3edc5a1f9fcba193931aca2132d62/jmp-0.0.4.tar.gz", hash = "sha256:5dfeb0fd7c7a9f72a70fff0aab9d0cbfae32a809c02f4037ff3485ceb33e1730", size = 18582, upload-time = "2023-01-30T12:47:13.634Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/27/e5/cce82de2831e5aff9332d8d624bb57188f1b2af6ccf6979caf898a8a4348/jmp-0.0.4-py3-none-any.whl", hash = "sha256:6aa7adbddf2bd574b28c7faf6e81a735eb11f53386447896909c6968dc36807d", size = 18274, upload-time = "2023-01-30T12:47:11.931Z" },
]
[[package]]
name = "ml-dtypes"
version = "0.5.4"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/0e/4a/c27b42ed9b1c7d13d9ba8b6905dece787d6259152f2309338aed29b2447b/ml_dtypes-0.5.4.tar.gz", hash = "sha256:8ab06a50fb9bf9666dd0fe5dfb4676fa2b0ac0f31ecff72a6c3af8e22c063453", size = 692314, upload-time = "2025-11-17T22:32:31.031Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c6/5e/712092cfe7e5eb667b8ad9ca7c54442f21ed7ca8979745f1000e24cf8737/ml_dtypes-0.5.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:6c7ecb74c4bd71db68a6bea1edf8da8c34f3d9fe218f038814fd1d310ac76c90", size = 679734, upload-time = "2025-11-17T22:31:39.223Z" },
{ url = "https://files.pythonhosted.org/packages/4f/cf/912146dfd4b5c0eea956836c01dcd2fce6c9c844b2691f5152aca196ce4f/ml_dtypes-0.5.4-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bc11d7e8c44a65115d05e2ab9989d1e045125d7be8e05a071a48bc76eb6d6040", size = 5056165, upload-time = "2025-11-17T22:31:41.071Z" },
{ url = "https://files.pythonhosted.org/packages/a9/80/19189ea605017473660e43762dc853d2797984b3c7bf30ce656099add30c/ml_dtypes-0.5.4-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:19b9a53598f21e453ea2fbda8aa783c20faff8e1eeb0d7ab899309a0053f1483", size = 5034975, upload-time = "2025-11-17T22:31:42.758Z" },
{ url = "https://files.pythonhosted.org/packages/a8/b8/3c70881695e056f8a32f8b941126cf78775d9a4d7feba8abcb52cb7b04f2/ml_dtypes-0.5.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:a174837a64f5b16cab6f368171a1a03a27936b31699d167684073ff1c4237dac", size = 676927, upload-time = "2025-11-17T22:31:48.182Z" },
{ url = "https://files.pythonhosted.org/packages/54/0f/428ef6881782e5ebb7eca459689448c0394fa0a80bea3aa9262cba5445ea/ml_dtypes-0.5.4-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a7f7c643e8b1320fd958bf098aa7ecf70623a42ec5154e3be3be673f4c34d900", size = 5028464, upload-time = "2025-11-17T22:31:50.135Z" },
{ url = "https://files.pythonhosted.org/packages/3a/cb/28ce52eb94390dda42599c98ea0204d74799e4d8047a0eb559b6fd648056/ml_dtypes-0.5.4-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9ad459e99793fa6e13bd5b7e6792c8f9190b4e5a1b45c63aba14a4d0a7f1d5ff", size = 5009002, upload-time = "2025-11-17T22:31:52.001Z" },
{ url = "https://files.pythonhosted.org/packages/d9/a1/4008f14bbc616cfb1ac5b39ea485f9c63031c4634ab3f4cf72e7541f816a/ml_dtypes-0.5.4-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8c760d85a2f82e2bed75867079188c9d18dae2ee77c25a54d60e9cc79be1bc48", size = 676888, upload-time = "2025-11-17T22:31:56.907Z" },
{ url = "https://files.pythonhosted.org/packages/d3/b7/dff378afc2b0d5a7d6cd9d3209b60474d9819d1189d347521e1688a60a53/ml_dtypes-0.5.4-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ce756d3a10d0c4067172804c9cc276ba9cc0ff47af9078ad439b075d1abdc29b", size = 5036993, upload-time = "2025-11-17T22:31:58.497Z" },
{ url = "https://files.pythonhosted.org/packages/eb/33/40cd74219417e78b97c47802037cf2d87b91973e18bb968a7da48a96ea44/ml_dtypes-0.5.4-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:533ce891ba774eabf607172254f2e7260ba5f57bdd64030c9a4fcfbd99815d0d", size = 5010956, upload-time = "2025-11-17T22:31:59.931Z" },
{ url = "https://files.pythonhosted.org/packages/4f/74/e9ddb35fd1dd43b1106c20ced3f53c2e8e7fc7598c15638e9f80677f81d4/ml_dtypes-0.5.4-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:304ad47faa395415b9ccbcc06a0350800bc50eda70f0e45326796e27c62f18b6", size = 702083, upload-time = "2025-11-17T22:32:04.08Z" },
{ url = "https://files.pythonhosted.org/packages/74/f5/667060b0aed1aa63166b22897fdf16dca9eb704e6b4bbf86848d5a181aa7/ml_dtypes-0.5.4-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6a0df4223b514d799b8a1629c65ddc351b3efa833ccf7f8ea0cf654a61d1e35d", size = 5354111, upload-time = "2025-11-17T22:32:05.546Z" },
{ url = "https://files.pythonhosted.org/packages/40/49/0f8c498a28c0efa5f5c95a9e374c83ec1385ca41d0e85e7cf40e5d519a21/ml_dtypes-0.5.4-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:531eff30e4d368cb6255bc2328d070e35836aa4f282a0fb5f3a0cd7260257298", size = 5366453, upload-time = "2025-11-17T22:32:07.115Z" },
{ url = "https://files.pythonhosted.org/packages/72/4e/1339dc6e2557a344f5ba5590872e80346f76f6cb2ac3dd16e4666e88818c/ml_dtypes-0.5.4-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:2b857d3af6ac0d39db1de7c706e69c7f9791627209c3d6dedbfca8c7e5faec22", size = 673781, upload-time = "2025-11-17T22:32:11.364Z" },
{ url = "https://files.pythonhosted.org/packages/04/f9/067b84365c7e83bda15bba2b06c6ca250ce27b20630b1128c435fb7a09aa/ml_dtypes-0.5.4-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:805cef3a38f4eafae3a5bf9ebdcdb741d0bcfd9e1bd90eb54abd24f928cd2465", size = 5036145, upload-time = "2025-11-17T22:32:12.783Z" },
{ url = "https://files.pythonhosted.org/packages/c6/bb/82c7dcf38070b46172a517e2334e665c5bf374a262f99a283ea454bece7c/ml_dtypes-0.5.4-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:14a4fd3228af936461db66faccef6e4f41c1d82fcc30e9f8d58a08916b1d811f", size = 5010230, upload-time = "2025-11-17T22:32:14.38Z" },
{ url = "https://files.pythonhosted.org/packages/cd/02/48aa7d84cc30ab4ee37624a2fd98c56c02326785750cd212bc0826c2f15b/ml_dtypes-0.5.4-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:bfc534409c5d4b0bf945af29e5d0ab075eae9eecbb549ff8a29280db822f34f9", size = 702085, upload-time = "2025-11-17T22:32:18.175Z" },
{ url = "https://files.pythonhosted.org/packages/5a/e7/85cb99fe80a7a5513253ec7faa88a65306be071163485e9a626fce1b6e84/ml_dtypes-0.5.4-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2314892cdc3fcf05e373d76d72aaa15fda9fb98625effa73c1d646f331fcecb7", size = 5355358, upload-time = "2025-11-17T22:32:19.7Z" },
{ url = "https://files.pythonhosted.org/packages/79/2b/a826ba18d2179a56e144aef69e57fb2ab7c464ef0b2111940ee8a3a223a2/ml_dtypes-0.5.4-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0d2ffd05a2575b1519dc928c0b93c06339eb67173ff53acb00724502cda231cf", size = 5366332, upload-time = "2025-11-17T22:32:21.193Z" },
]
[[package]]
name = "nodeenv"
version = "1.10.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/24/bf/d1bda4f6168e0b2e9e5958945e01910052158313224ada5ce1fb2e1113b8/nodeenv-1.10.0.tar.gz", hash = "sha256:996c191ad80897d076bdfba80a41994c2b47c68e224c542b48feba42ba00f8bb", size = 55611, upload-time = "2025-12-20T14:08:54.006Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/88/b2/d0896bdcdc8d28a7fc5717c305f1a861c26e18c05047949fb371034d98bd/nodeenv-1.10.0-py2.py3-none-any.whl", hash = "sha256:5bb13e3eed2923615535339b3c620e76779af4cb4c6a90deccc9e36b274d3827", size = 23438, upload-time = "2025-12-20T14:08:52.782Z" },
]
[[package]]
name = "numpy"
version = "2.4.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/24/62/ae72ff66c0f1fd959925b4c11f8c2dea61f47f6acaea75a08512cdfe3fed/numpy-2.4.1.tar.gz", hash = "sha256:a1ceafc5042451a858231588a104093474c6a5c57dcc724841f5c888d237d690", size = 20721320, upload-time = "2026-01-10T06:44:59.619Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/a5/34/2b1bc18424f3ad9af577f6ce23600319968a70575bd7db31ce66731bbef9/numpy-2.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0cce2a669e3c8ba02ee563c7835f92c153cf02edff1ae05e1823f1dde21b16a5", size = 16944563, upload-time = "2026-01-10T06:42:14.615Z" },
{ url = "https://files.pythonhosted.org/packages/2c/57/26e5f97d075aef3794045a6ca9eada6a4ed70eb9a40e7a4a93f9ac80d704/numpy-2.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:899d2c18024984814ac7e83f8f49d8e8180e2fbe1b2e252f2e7f1d06bea92425", size = 12645658, upload-time = "2026-01-10T06:42:17.298Z" },
{ url = "https://files.pythonhosted.org/packages/8e/ba/80fc0b1e3cb2fd5c6143f00f42eb67762aa043eaa05ca924ecc3222a7849/numpy-2.4.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:09aa8a87e45b55a1c2c205d42e2808849ece5c484b2aab11fecabec3841cafba", size = 5474132, upload-time = "2026-01-10T06:42:19.637Z" },
{ url = "https://files.pythonhosted.org/packages/40/ae/0a5b9a397f0e865ec171187c78d9b57e5588afc439a04ba9cab1ebb2c945/numpy-2.4.1-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:edee228f76ee2dab4579fad6f51f6a305de09d444280109e0f75df247ff21501", size = 6804159, upload-time = "2026-01-10T06:42:21.44Z" },
{ url = "https://files.pythonhosted.org/packages/86/9c/841c15e691c7085caa6fd162f063eff494099c8327aeccd509d1ab1e36ab/numpy-2.4.1-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a92f227dbcdc9e4c3e193add1a189a9909947d4f8504c576f4a732fd0b54240a", size = 14708058, upload-time = "2026-01-10T06:42:23.546Z" },
{ url = "https://files.pythonhosted.org/packages/5d/9d/7862db06743f489e6a502a3b93136d73aea27d97b2cf91504f70a27501d6/numpy-2.4.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:538bf4ec353709c765ff75ae616c34d3c3dca1a68312727e8f2676ea644f8509", size = 16651501, upload-time = "2026-01-10T06:42:25.909Z" },
{ url = "https://files.pythonhosted.org/packages/a6/9c/6fc34ebcbd4015c6e5f0c0ce38264010ce8a546cb6beacb457b84a75dfc8/numpy-2.4.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ac08c63cb7779b85e9d5318e6c3518b424bc1f364ac4cb2c6136f12e5ff2dccc", size = 16492627, upload-time = "2026-01-10T06:42:28.938Z" },
{ url = "https://files.pythonhosted.org/packages/aa/63/2494a8597502dacda439f61b3c0db4da59928150e62be0e99395c3ad23c5/numpy-2.4.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4f9c360ecef085e5841c539a9a12b883dff005fbd7ce46722f5e9cef52634d82", size = 18585052, upload-time = "2026-01-10T06:42:31.312Z" },
{ url = "https://files.pythonhosted.org/packages/78/7f/ec53e32bf10c813604edf07a3682616bd931d026fcde7b6d13195dfb684a/numpy-2.4.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d3703409aac693fa82c0aee023a1ae06a6e9d065dba10f5e8e80f642f1e9d0a2", size = 16656888, upload-time = "2026-01-10T06:42:40.913Z" },
{ url = "https://files.pythonhosted.org/packages/b8/e0/1f9585d7dae8f14864e948fd7fa86c6cb72dee2676ca2748e63b1c5acfe0/numpy-2.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7211b95ca365519d3596a1d8688a95874cc94219d417504d9ecb2df99fa7bfa8", size = 12373956, upload-time = "2026-01-10T06:42:43.091Z" },
{ url = "https://files.pythonhosted.org/packages/8e/43/9762e88909ff2326f5e7536fa8cb3c49fb03a7d92705f23e6e7f553d9cb3/numpy-2.4.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:5adf01965456a664fc727ed69cc71848f28d063217c63e1a0e200a118d5eec9a", size = 5202567, upload-time = "2026-01-10T06:42:45.107Z" },
{ url = "https://files.pythonhosted.org/packages/4b/ee/34b7930eb61e79feb4478800a4b95b46566969d837546aa7c034c742ef98/numpy-2.4.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:26f0bcd9c79a00e339565b303badc74d3ea2bd6d52191eeca5f95936cad107d0", size = 6549459, upload-time = "2026-01-10T06:42:48.152Z" },
{ url = "https://files.pythonhosted.org/packages/79/e3/5f115fae982565771be994867c89bcd8d7208dbfe9469185497d70de5ddf/numpy-2.4.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0093e85df2960d7e4049664b26afc58b03236e967fb942354deef3208857a04c", size = 14404859, upload-time = "2026-01-10T06:42:49.947Z" },
{ url = "https://files.pythonhosted.org/packages/d9/7d/9c8a781c88933725445a859cac5d01b5871588a15969ee6aeb618ba99eee/numpy-2.4.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7ad270f438cbdd402c364980317fb6b117d9ec5e226fff5b4148dd9aa9fc6e02", size = 16371419, upload-time = "2026-01-10T06:42:52.409Z" },
{ url = "https://files.pythonhosted.org/packages/a6/d2/8aa084818554543f17cf4162c42f162acbd3bb42688aefdba6628a859f77/numpy-2.4.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:297c72b1b98100c2e8f873d5d35fb551fce7040ade83d67dd51d38c8d42a2162", size = 16182131, upload-time = "2026-01-10T06:42:54.694Z" },
{ url = "https://files.pythonhosted.org/packages/60/db/0425216684297c58a8df35f3284ef56ec4a043e6d283f8a59c53562caf1b/numpy-2.4.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:cf6470d91d34bf669f61d515499859fa7a4c2f7c36434afb70e82df7217933f9", size = 18295342, upload-time = "2026-01-10T06:42:56.991Z" },
{ url = "https://files.pythonhosted.org/packages/04/68/732d4b7811c00775f3bd522a21e8dd5a23f77eb11acdeb663e4a4ebf0ef4/numpy-2.4.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:d797454e37570cfd61143b73b8debd623c3c0952959adb817dd310a483d58a1b", size = 16652495, upload-time = "2026-01-10T06:43:06.283Z" },
{ url = "https://files.pythonhosted.org/packages/20/ca/857722353421a27f1465652b2c66813eeeccea9d76d5f7b74b99f298e60e/numpy-2.4.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:82c55962006156aeef1629b953fd359064aa47e4d82cfc8e67f0918f7da3344f", size = 12368657, upload-time = "2026-01-10T06:43:09.094Z" },
{ url = "https://files.pythonhosted.org/packages/81/0d/2377c917513449cc6240031a79d30eb9a163d32a91e79e0da47c43f2c0c8/numpy-2.4.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:71abbea030f2cfc3092a0ff9f8c8fdefdc5e0bf7d9d9c99663538bb0ecdac0b9", size = 5197256, upload-time = "2026-01-10T06:43:13.634Z" },
{ url = "https://files.pythonhosted.org/packages/17/39/569452228de3f5de9064ac75137082c6214be1f5c532016549a7923ab4b5/numpy-2.4.1-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:5b55aa56165b17aaf15520beb9cbd33c9039810e0d9643dd4379e44294c7303e", size = 6545212, upload-time = "2026-01-10T06:43:15.661Z" },
{ url = "https://files.pythonhosted.org/packages/8c/a4/77333f4d1e4dac4395385482557aeecf4826e6ff517e32ca48e1dafbe42a/numpy-2.4.1-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c0faba4a331195bfa96f93dd9dfaa10b2c7aa8cda3a02b7fd635e588fe821bf5", size = 14402871, upload-time = "2026-01-10T06:43:17.324Z" },
{ url = "https://files.pythonhosted.org/packages/ba/87/d341e519956273b39d8d47969dd1eaa1af740615394fe67d06f1efa68773/numpy-2.4.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d3e3087f53e2b4428766b54932644d148613c5a595150533ae7f00dab2f319a8", size = 16359305, upload-time = "2026-01-10T06:43:19.376Z" },
{ url = "https://files.pythonhosted.org/packages/32/91/789132c6666288eaa20ae8066bb99eba1939362e8f1a534949a215246e97/numpy-2.4.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:49e792ec351315e16da54b543db06ca8a86985ab682602d90c60ef4ff4db2a9c", size = 16181909, upload-time = "2026-01-10T06:43:21.808Z" },
{ url = "https://files.pythonhosted.org/packages/cf/b8/090b8bd27b82a844bb22ff8fdf7935cb1980b48d6e439ae116f53cdc2143/numpy-2.4.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:79e9e06c4c2379db47f3f6fc7a8652e7498251789bf8ff5bd43bf478ef314ca2", size = 18284380, upload-time = "2026-01-10T06:43:23.957Z" },
{ url = "https://files.pythonhosted.org/packages/da/a1/354583ac5c4caa566de6ddfbc42744409b515039e085fab6e0ff942e0df5/numpy-2.4.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:f93bc6892fe7b0663e5ffa83b61aab510aacffd58c16e012bb9352d489d90cb7", size = 12496156, upload-time = "2026-01-10T06:43:34.237Z" },
{ url = "https://files.pythonhosted.org/packages/51/b0/42807c6e8cce58c00127b1dc24d365305189991f2a7917aa694a109c8d7d/numpy-2.4.1-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:178de8f87948163d98a4c9ab5bee4ce6519ca918926ec8df195af582de28544d", size = 5324663, upload-time = "2026-01-10T06:43:36.211Z" },
{ url = "https://files.pythonhosted.org/packages/fe/55/7a621694010d92375ed82f312b2f28017694ed784775269115323e37f5e2/numpy-2.4.1-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:98b35775e03ab7f868908b524fc0a84d38932d8daf7b7e1c3c3a1b6c7a2c9f15", size = 6645224, upload-time = "2026-01-10T06:43:37.884Z" },
{ url = "https://files.pythonhosted.org/packages/50/96/9fa8635ed9d7c847d87e30c834f7109fac5e88549d79ef3324ab5c20919f/numpy-2.4.1-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:941c2a93313d030f219f3a71fd3d91a728b82979a5e8034eb2e60d394a2b83f9", size = 14462352, upload-time = "2026-01-10T06:43:39.479Z" },
{ url = "https://files.pythonhosted.org/packages/03/d1/8cf62d8bb2062da4fb82dd5d49e47c923f9c0738032f054e0a75342faba7/numpy-2.4.1-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:529050522e983e00a6c1c6b67411083630de8b57f65e853d7b03d9281b8694d2", size = 16407279, upload-time = "2026-01-10T06:43:41.93Z" },
{ url = "https://files.pythonhosted.org/packages/86/1c/95c86e17c6b0b31ce6ef219da00f71113b220bcb14938c8d9a05cee0ff53/numpy-2.4.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:2302dc0224c1cbc49bb94f7064f3f923a971bfae45c33870dcbff63a2a550505", size = 16248316, upload-time = "2026-01-10T06:43:44.121Z" },
{ url = "https://files.pythonhosted.org/packages/30/b4/e7f5ff8697274c9d0fa82398b6a372a27e5cef069b37df6355ccb1f1db1a/numpy-2.4.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:9171a42fcad32dcf3fa86f0a4faa5e9f8facefdb276f54b8b390d90447cff4e2", size = 18329884, upload-time = "2026-01-10T06:43:46.613Z" },
{ url = "https://files.pythonhosted.org/packages/1b/a7/ef08d25698e0e4b4efbad8d55251d20fe2a15f6d9aa7c9b30cd03c165e6f/numpy-2.4.1-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:3869ea1ee1a1edc16c29bbe3a2f2a4e515cc3a44d43903ad41e0cacdbaf733dc", size = 16652046, upload-time = "2026-01-10T06:43:54.797Z" },
{ url = "https://files.pythonhosted.org/packages/8f/39/e378b3e3ca13477e5ac70293ec027c438d1927f18637e396fe90b1addd72/numpy-2.4.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:e867df947d427cdd7a60e3e271729090b0f0df80f5f10ab7dd436f40811699c3", size = 12378858, upload-time = "2026-01-10T06:43:57.099Z" },
{ url = "https://files.pythonhosted.org/packages/c3/74/7ec6154f0006910ed1fdbb7591cf4432307033102b8a22041599935f8969/numpy-2.4.1-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:e3bd2cb07841166420d2fa7146c96ce00cb3410664cbc1a6be028e456c4ee220", size = 5207417, upload-time = "2026-01-10T06:43:59.037Z" },
{ url = "https://files.pythonhosted.org/packages/f7/b7/053ac11820d84e42f8feea5cb81cc4fcd1091499b45b1ed8c7415b1bf831/numpy-2.4.1-cp314-cp314-macosx_14_0_x86_64.whl", hash = "sha256:f0a90aba7d521e6954670550e561a4cb925713bd944445dbe9e729b71f6cabee", size = 6542643, upload-time = "2026-01-10T06:44:01.852Z" },
{ url = "https://files.pythonhosted.org/packages/c0/c4/2e7908915c0e32ca636b92e4e4a3bdec4cb1e7eb0f8aedf1ed3c68a0d8cd/numpy-2.4.1-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5d558123217a83b2d1ba316b986e9248a1ed1971ad495963d555ccd75dcb1556", size = 14418963, upload-time = "2026-01-10T06:44:04.047Z" },
{ url = "https://files.pythonhosted.org/packages/eb/c0/3ed5083d94e7ffd7c404e54619c088e11f2e1939a9544f5397f4adb1b8ba/numpy-2.4.1-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2f44de05659b67d20499cbc96d49f2650769afcb398b79b324bb6e297bfe3844", size = 16363811, upload-time = "2026-01-10T06:44:06.207Z" },
{ url = "https://files.pythonhosted.org/packages/0e/68/42b66f1852bf525050a67315a4fb94586ab7e9eaa541b1bef530fab0c5dd/numpy-2.4.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:69e7419c9012c4aaf695109564e3387f1259f001b4326dfa55907b098af082d3", size = 16197643, upload-time = "2026-01-10T06:44:08.33Z" },
{ url = "https://files.pythonhosted.org/packages/d2/40/e8714fc933d85f82c6bfc7b998a0649ad9769a32f3494ba86598aaf18a48/numpy-2.4.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:2ffd257026eb1b34352e749d7cc1678b5eeec3e329ad8c9965a797e08ccba205", size = 18289601, upload-time = "2026-01-10T06:44:10.841Z" },
{ url = "https://files.pythonhosted.org/packages/de/bc/ea3f2c96fcb382311827231f911723aeff596364eb6e1b6d1d91128aa29b/numpy-2.4.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:4e53170557d37ae404bf8d542ca5b7c629d6efa1117dac6a83e394142ea0a43f", size = 12498774, upload-time = "2026-01-10T06:44:19.467Z" },
{ url = "https://files.pythonhosted.org/packages/aa/ab/ef9d939fe4a812648c7a712610b2ca6140b0853c5efea361301006c02ae5/numpy-2.4.1-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:a73044b752f5d34d4232f25f18160a1cc418ea4507f5f11e299d8ac36875f8a0", size = 5327274, upload-time = "2026-01-10T06:44:23.189Z" },
{ url = "https://files.pythonhosted.org/packages/bd/31/d381368e2a95c3b08b8cf7faac6004849e960f4a042d920337f71cef0cae/numpy-2.4.1-cp314-cp314t-macosx_14_0_x86_64.whl", hash = "sha256:fb1461c99de4d040666ca0444057b06541e5642f800b71c56e6ea92d6a853a0c", size = 6648306, upload-time = "2026-01-10T06:44:25.012Z" },
{ url = "https://files.pythonhosted.org/packages/c8/e5/0989b44ade47430be6323d05c23207636d67d7362a1796ccbccac6773dd2/numpy-2.4.1-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:423797bdab2eeefbe608d7c1ec7b2b4fd3c58d51460f1ee26c7500a1d9c9ee93", size = 14464653, upload-time = "2026-01-10T06:44:26.706Z" },
{ url = "https://files.pythonhosted.org/packages/10/a7/cfbe475c35371cae1358e61f20c5f075badc18c4797ab4354140e1d283cf/numpy-2.4.1-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:52b5f61bdb323b566b528899cc7db2ba5d1015bda7ea811a8bcf3c89c331fa42", size = 16405144, upload-time = "2026-01-10T06:44:29.378Z" },
{ url = "https://files.pythonhosted.org/packages/f8/a3/0c63fe66b534888fa5177cc7cef061541064dbe2b4b60dcc60ffaf0d2157/numpy-2.4.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:42d7dd5fa36d16d52a84f821eb96031836fd405ee6955dd732f2023724d0aa01", size = 16247425, upload-time = "2026-01-10T06:44:31.721Z" },
{ url = "https://files.pythonhosted.org/packages/6b/2b/55d980cfa2c93bd40ff4c290bf824d792bd41d2fe3487b07707559071760/numpy-2.4.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:e7b6b5e28bbd47b7532698e5db2fe1db693d84b58c254e4389d99a27bb9b8f6b", size = 18330053, upload-time = "2026-01-10T06:44:34.617Z" },
{ url = "https://files.pythonhosted.org/packages/1e/48/d86f97919e79314a1cdee4c832178763e6e98e623e123d0bada19e92c15a/numpy-2.4.1-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:8ad35f20be147a204e28b6a0575fbf3540c5e5f802634d4258d55b1ff5facce1", size = 16822202, upload-time = "2026-01-10T06:44:43.738Z" },
{ url = "https://files.pythonhosted.org/packages/51/e9/1e62a7f77e0f37dcfb0ad6a9744e65df00242b6ea37dfafb55debcbf5b55/numpy-2.4.1-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:8097529164c0f3e32bb89412a0905d9100bf434d9692d9fc275e18dcf53c9344", size = 12569985, upload-time = "2026-01-10T06:44:45.945Z" },
{ url = "https://files.pythonhosted.org/packages/c7/7e/914d54f0c801342306fdcdce3e994a56476f1b818c46c47fc21ae968088c/numpy-2.4.1-pp311-pypy311_pp73-macosx_14_0_arm64.whl", hash = "sha256:ea66d2b41ca4a1630aae5507ee0a71647d3124d1741980138aa8f28f44dac36e", size = 5398484, upload-time = "2026-01-10T06:44:48.012Z" },
{ url = "https://files.pythonhosted.org/packages/1c/d8/9570b68584e293a33474e7b5a77ca404f1dcc655e40050a600dee81d27fb/numpy-2.4.1-pp311-pypy311_pp73-macosx_14_0_x86_64.whl", hash = "sha256:d3f8f0df9f4b8be57b3bf74a1d087fec68f927a2fab68231fdb442bf2c12e426", size = 6713216, upload-time = "2026-01-10T06:44:49.725Z" },
{ url = "https://files.pythonhosted.org/packages/33/9b/9dd6e2db8d49eb24f86acaaa5258e5f4c8ed38209a4ee9de2d1a0ca25045/numpy-2.4.1-pp311-pypy311_pp73-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2023ef86243690c2791fd6353e5b4848eedaa88ca8a2d129f462049f6d484696", size = 14538937, upload-time = "2026-01-10T06:44:51.498Z" },
{ url = "https://files.pythonhosted.org/packages/53/87/d5bd995b0f798a37105b876350d346eea5838bd8f77ea3d7a48392f3812b/numpy-2.4.1-pp311-pypy311_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8361ea4220d763e54cff2fbe7d8c93526b744f7cd9ddab47afeff7e14e8503be", size = 16479830, upload-time = "2026-01-10T06:44:53.931Z" },
]
[[package]]
name = "opt-einsum"
version = "3.4.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/8c/b9/2ac072041e899a52f20cf9510850ff58295003aa75525e58343591b0cbfb/opt_einsum-3.4.0.tar.gz", hash = "sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac", size = 63004, upload-time = "2024-09-26T14:33:24.483Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl", hash = "sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd", size = 71932, upload-time = "2024-09-26T14:33:23.039Z" },
]
[[package]]
name = "packaging"
version = "25.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/a1/d4/1fc4078c65507b51b96ca8f8c3ba19e6a61c8253c72794544580a7b6c24d/packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f", size = 165727, upload-time = "2025-04-19T11:48:59.673Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469, upload-time = "2025-04-19T11:48:57.875Z" },
]
[[package]]
name = "pluggy"
version = "1.6.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
]
[[package]]
name = "pygments"
version = "2.19.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" },
]
[[package]]
name = "pyright"
version = "1.1.408"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nodeenv", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/74/b2/5db700e52554b8f025faa9c3c624c59f1f6c8841ba81ab97641b54322f16/pyright-1.1.408.tar.gz", hash = "sha256:f28f2321f96852fa50b5829ea492f6adb0e6954568d1caa3f3af3a5f555eb684", size = 4400578, upload-time = "2026-01-08T08:07:38.795Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/0c/82/a2c93e32800940d9573fb28c346772a14778b84ba7524e691b324620ab89/pyright-1.1.408-py3-none-any.whl", hash = "sha256:090b32865f4fdb1e0e6cd82bf5618480d48eecd2eb2e70f960982a3d9a4c17c1", size = 6399144, upload-time = "2026-01-08T08:07:37.082Z" },
]
[[package]]
name = "pytest"
version = "9.0.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "iniconfig", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "packaging", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pluggy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pygments", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/d1/db/7ef3487e0fb0049ddb5ce41d3a49c235bf9ad299b6a25d5780a89f19230f/pytest-9.0.2.tar.gz", hash = "sha256:75186651a92bd89611d1d9fc20f0b4345fd827c41ccd5c299a868a05d70edf11", size = 1568901, upload-time = "2025-12-06T21:30:51.014Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" },
]
[[package]]
name = "scipy"
version = "1.17.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/56/3e/9cca699f3486ce6bc12ff46dc2031f1ec8eb9ccc9a320fdaf925f1417426/scipy-1.17.0.tar.gz", hash = "sha256:2591060c8e648d8b96439e111ac41fd8342fdeff1876be2e19dea3fe8930454e", size = 30396830, upload-time = "2026-01-10T21:34:23.009Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/1e/4b/c89c131aa87cad2b77a54eb0fb94d633a842420fa7e919dc2f922037c3d8/scipy-1.17.0-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:2abd71643797bd8a106dff97894ff7869eeeb0af0f7a5ce02e4227c6a2e9d6fd", size = 31381316, upload-time = "2026-01-10T21:24:33.42Z" },
{ url = "https://files.pythonhosted.org/packages/5e/5f/a6b38f79a07d74989224d5f11b55267714707582908a5f1ae854cf9a9b84/scipy-1.17.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:ef28d815f4d2686503e5f4f00edc387ae58dfd7a2f42e348bb53359538f01558", size = 27966760, upload-time = "2026-01-10T21:24:38.911Z" },
{ url = "https://files.pythonhosted.org/packages/c1/20/095ad24e031ee8ed3c5975954d816b8e7e2abd731e04f8be573de8740885/scipy-1.17.0-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:272a9f16d6bb4667e8b50d25d71eddcc2158a214df1b566319298de0939d2ab7", size = 20138701, upload-time = "2026-01-10T21:24:43.249Z" },
{ url = "https://files.pythonhosted.org/packages/89/11/4aad2b3858d0337756f3323f8960755704e530b27eb2a94386c970c32cbe/scipy-1.17.0-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:7204fddcbec2fe6598f1c5fdf027e9f259106d05202a959a9f1aecf036adc9f6", size = 22480574, upload-time = "2026-01-10T21:24:47.266Z" },
{ url = "https://files.pythonhosted.org/packages/85/bd/f5af70c28c6da2227e510875cadf64879855193a687fb19951f0f44cfd6b/scipy-1.17.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:fc02c37a5639ee67d8fb646ffded6d793c06c5622d36b35cfa8fe5ececb8f042", size = 32862414, upload-time = "2026-01-10T21:24:52.566Z" },
{ url = "https://files.pythonhosted.org/packages/ef/df/df1457c4df3826e908879fe3d76bc5b6e60aae45f4ee42539512438cfd5d/scipy-1.17.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:dac97a27520d66c12a34fd90a4fe65f43766c18c0d6e1c0a80f114d2260080e4", size = 35112380, upload-time = "2026-01-10T21:24:58.433Z" },
{ url = "https://files.pythonhosted.org/packages/5f/bb/88e2c16bd1dd4de19d80d7c5e238387182993c2fb13b4b8111e3927ad422/scipy-1.17.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ebb7446a39b3ae0fe8f416a9a3fdc6fba3f11c634f680f16a239c5187bc487c0", size = 34922676, upload-time = "2026-01-10T21:25:04.287Z" },
{ url = "https://files.pythonhosted.org/packages/02/ba/5120242cc735f71fc002cff0303d536af4405eb265f7c60742851e7ccfe9/scipy-1.17.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:474da16199f6af66601a01546144922ce402cb17362e07d82f5a6cf8f963e449", size = 37507599, upload-time = "2026-01-10T21:25:09.851Z" },
{ url = "https://files.pythonhosted.org/packages/0b/11/7241a63e73ba5a516f1930ac8d5b44cbbfabd35ac73a2d08ca206df007c4/scipy-1.17.0-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:0d5018a57c24cb1dd828bcf51d7b10e65986d549f52ef5adb6b4d1ded3e32a57", size = 31364580, upload-time = "2026-01-10T21:25:25.717Z" },
{ url = "https://files.pythonhosted.org/packages/ed/1d/5057f812d4f6adc91a20a2d6f2ebcdb517fdbc87ae3acc5633c9b97c8ba5/scipy-1.17.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:88c22af9e5d5a4f9e027e26772cc7b5922fab8bcc839edb3ae33de404feebd9e", size = 27969012, upload-time = "2026-01-10T21:25:30.921Z" },
{ url = "https://files.pythonhosted.org/packages/e3/21/f6ec556c1e3b6ec4e088da667d9987bb77cc3ab3026511f427dc8451187d/scipy-1.17.0-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:f3cd947f20fe17013d401b64e857c6b2da83cae567adbb75b9dcba865abc66d8", size = 20140691, upload-time = "2026-01-10T21:25:34.802Z" },
{ url = "https://files.pythonhosted.org/packages/7a/fe/5e5ad04784964ba964a96f16c8d4676aa1b51357199014dce58ab7ec5670/scipy-1.17.0-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:e8c0b331c2c1f531eb51f1b4fc9ba709521a712cce58f1aa627bc007421a5306", size = 22463015, upload-time = "2026-01-10T21:25:39.277Z" },
{ url = "https://files.pythonhosted.org/packages/4a/69/7c347e857224fcaf32a34a05183b9d8a7aca25f8f2d10b8a698b8388561a/scipy-1.17.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5194c445d0a1c7a6c1a4a4681b6b7c71baad98ff66d96b949097e7513c9d6742", size = 32724197, upload-time = "2026-01-10T21:25:44.084Z" },
{ url = "https://files.pythonhosted.org/packages/d1/fe/66d73b76d378ba8cc2fe605920c0c75092e3a65ae746e1e767d9d020a75a/scipy-1.17.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9eeb9b5f5997f75507814ed9d298ab23f62cf79f5a3ef90031b1ee2506abdb5b", size = 35009148, upload-time = "2026-01-10T21:25:50.591Z" },
{ url = "https://files.pythonhosted.org/packages/af/07/07dec27d9dc41c18d8c43c69e9e413431d20c53a0339c388bcf72f353c4b/scipy-1.17.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:40052543f7bbe921df4408f46003d6f01c6af109b9e2c8a66dd1cf6cf57f7d5d", size = 34798766, upload-time = "2026-01-10T21:25:59.41Z" },
{ url = "https://files.pythonhosted.org/packages/81/61/0470810c8a093cdacd4ba7504b8a218fd49ca070d79eca23a615f5d9a0b0/scipy-1.17.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:0cf46c8013fec9d3694dc572f0b54100c28405d55d3e2cb15e2895b25057996e", size = 37405953, upload-time = "2026-01-10T21:26:07.75Z" },
{ url = "https://files.pythonhosted.org/packages/0c/51/3468fdfd49387ddefee1636f5cf6d03ce603b75205bf439bbf0e62069bfd/scipy-1.17.0-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:65ec32f3d32dfc48c72df4291345dae4f048749bc8d5203ee0a3f347f96c5ce6", size = 31344101, upload-time = "2026-01-10T21:26:30.25Z" },
{ url = "https://files.pythonhosted.org/packages/b2/9a/9406aec58268d437636069419e6977af953d1e246df941d42d3720b7277b/scipy-1.17.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:1f9586a58039d7229ce77b52f8472c972448cded5736eaf102d5658bbac4c269", size = 27950385, upload-time = "2026-01-10T21:26:36.801Z" },
{ url = "https://files.pythonhosted.org/packages/4f/98/e7342709e17afdfd1b26b56ae499ef4939b45a23a00e471dfb5375eea205/scipy-1.17.0-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:9fad7d3578c877d606b1150135c2639e9de9cecd3705caa37b66862977cc3e72", size = 20122115, upload-time = "2026-01-10T21:26:42.107Z" },
{ url = "https://files.pythonhosted.org/packages/fd/0e/9eeeb5357a64fd157cbe0302c213517c541cc16b8486d82de251f3c68ede/scipy-1.17.0-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:423ca1f6584fc03936972b5f7c06961670dbba9f234e71676a7c7ccf938a0d61", size = 22442402, upload-time = "2026-01-10T21:26:48.029Z" },
{ url = "https://files.pythonhosted.org/packages/c9/10/be13397a0e434f98e0c79552b2b584ae5bb1c8b2be95db421533bbca5369/scipy-1.17.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:fe508b5690e9eaaa9467fc047f833af58f1152ae51a0d0aed67aa5801f4dd7d6", size = 32696338, upload-time = "2026-01-10T21:26:55.521Z" },
{ url = "https://files.pythonhosted.org/packages/63/1e/12fbf2a3bb240161651c94bb5cdd0eae5d4e8cc6eaeceb74ab07b12a753d/scipy-1.17.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6680f2dfd4f6182e7d6db161344537da644d1cf85cf293f015c60a17ecf08752", size = 34977201, upload-time = "2026-01-10T21:27:03.501Z" },
{ url = "https://files.pythonhosted.org/packages/19/5b/1a63923e23ccd20bd32156d7dd708af5bbde410daa993aa2500c847ab2d2/scipy-1.17.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:eec3842ec9ac9de5917899b277428886042a93db0b227ebbe3a333b64ec7643d", size = 34777384, upload-time = "2026-01-10T21:27:11.423Z" },
{ url = "https://files.pythonhosted.org/packages/39/22/b5da95d74edcf81e540e467202a988c50fef41bd2011f46e05f72ba07df6/scipy-1.17.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:d7425fcafbc09a03731e1bc05581f5fad988e48c6a861f441b7ab729a49a55ea", size = 37379586, upload-time = "2026-01-10T21:27:20.171Z" },
{ url = "https://files.pythonhosted.org/packages/20/b6/7feaa252c21cc7aff335c6c55e1b90ab3e3306da3f048109b8b639b94648/scipy-1.17.0-cp313-cp313t-macosx_10_14_x86_64.whl", hash = "sha256:ec0827aa4d36cb79ff1b81de898e948a51ac0b9b1c43e4a372c0508c38c0f9a3", size = 31693194, upload-time = "2026-01-10T21:27:27.454Z" },
{ url = "https://files.pythonhosted.org/packages/76/bb/bbb392005abce039fb7e672cb78ac7d158700e826b0515cab6b5b60c26fb/scipy-1.17.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:819fc26862b4b3c73a60d486dbb919202f3d6d98c87cf20c223511429f2d1a97", size = 28365415, upload-time = "2026-01-10T21:27:34.26Z" },
{ url = "https://files.pythonhosted.org/packages/37/da/9d33196ecc99fba16a409c691ed464a3a283ac454a34a13a3a57c0d66f3a/scipy-1.17.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:363ad4ae2853d88ebcde3ae6ec46ccca903ea9835ee8ba543f12f575e7b07e4e", size = 20537232, upload-time = "2026-01-10T21:27:40.306Z" },
{ url = "https://files.pythonhosted.org/packages/56/9d/f4b184f6ddb28e9a5caea36a6f98e8ecd2a524f9127354087ce780885d83/scipy-1.17.0-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:979c3a0ff8e5ba254d45d59ebd38cde48fce4f10b5125c680c7a4bfe177aab07", size = 22791051, upload-time = "2026-01-10T21:27:46.539Z" },
{ url = "https://files.pythonhosted.org/packages/9b/9d/025cccdd738a72140efc582b1641d0dd4caf2e86c3fb127568dc80444e6e/scipy-1.17.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:130d12926ae34399d157de777472bf82e9061c60cc081372b3118edacafe1d00", size = 32815098, upload-time = "2026-01-10T21:27:54.389Z" },
{ url = "https://files.pythonhosted.org/packages/48/5f/09b879619f8bca15ce392bfc1894bd9c54377e01d1b3f2f3b595a1b4d945/scipy-1.17.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6e886000eb4919eae3a44f035e63f0fd8b651234117e8f6f29bad1cd26e7bc45", size = 35031342, upload-time = "2026-01-10T21:28:03.012Z" },
{ url = "https://files.pythonhosted.org/packages/f2/9a/f0f0a9f0aa079d2f106555b984ff0fbb11a837df280f04f71f056ea9c6e4/scipy-1.17.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:13c4096ac6bc31d706018f06a49abe0485f96499deb82066b94d19b02f664209", size = 34893199, upload-time = "2026-01-10T21:28:10.832Z" },
{ url = "https://files.pythonhosted.org/packages/90/b8/4f0f5cf0c5ea4d7548424e6533e6b17d164f34a6e2fb2e43ffebb6697b06/scipy-1.17.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:cacbaddd91fcffde703934897c5cd2c7cb0371fac195d383f4e1f1c5d3f3bd04", size = 37438061, upload-time = "2026-01-10T21:28:19.684Z" },
{ url = "https://files.pythonhosted.org/packages/1a/2d/51006cd369b8e7879e1c630999a19d1fbf6f8b5ed3e33374f29dc87e53b3/scipy-1.17.0-cp314-cp314-macosx_10_14_x86_64.whl", hash = "sha256:c17514d11b78be8f7e6331b983a65a7f5ca1fd037b95e27b280921fe5606286a", size = 31346803, upload-time = "2026-01-10T21:28:57.24Z" },
{ url = "https://files.pythonhosted.org/packages/d6/2e/2349458c3ce445f53a6c93d4386b1c4c5c0c540917304c01222ff95ff317/scipy-1.17.0-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:4e00562e519c09da34c31685f6acc3aa384d4d50604db0f245c14e1b4488bfa2", size = 27967182, upload-time = "2026-01-10T21:29:04.107Z" },
{ url = "https://files.pythonhosted.org/packages/5e/7c/df525fbfa77b878d1cfe625249529514dc02f4fd5f45f0f6295676a76528/scipy-1.17.0-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:f7df7941d71314e60a481e02d5ebcb3f0185b8d799c70d03d8258f6c80f3d467", size = 20139125, upload-time = "2026-01-10T21:29:10.179Z" },
{ url = "https://files.pythonhosted.org/packages/33/11/fcf9d43a7ed1234d31765ec643b0515a85a30b58eddccc5d5a4d12b5f194/scipy-1.17.0-cp314-cp314-macosx_14_0_x86_64.whl", hash = "sha256:aabf057c632798832f071a8dde013c2e26284043934f53b00489f1773b33527e", size = 22443554, upload-time = "2026-01-10T21:29:15.888Z" },
{ url = "https://files.pythonhosted.org/packages/80/5c/ea5d239cda2dd3d31399424967a24d556cf409fbea7b5b21412b0fd0a44f/scipy-1.17.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a38c3337e00be6fd8a95b4ed66b5d988bac4ec888fd922c2ea9fe5fb1603dd67", size = 32757834, upload-time = "2026-01-10T21:29:23.406Z" },
{ url = "https://files.pythonhosted.org/packages/b8/7e/8c917cc573310e5dc91cbeead76f1b600d3fb17cf0969db02c9cf92e3cfa/scipy-1.17.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:00fb5f8ec8398ad90215008d8b6009c9db9fa924fd4c7d6be307c6f945f9cd73", size = 34995775, upload-time = "2026-01-10T21:29:31.915Z" },
{ url = "https://files.pythonhosted.org/packages/c5/43/176c0c3c07b3f7df324e7cdd933d3e2c4898ca202b090bd5ba122f9fe270/scipy-1.17.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:f2a4942b0f5f7c23c7cd641a0ca1955e2ae83dedcff537e3a0259096635e186b", size = 34841240, upload-time = "2026-01-10T21:29:39.995Z" },
{ url = "https://files.pythonhosted.org/packages/44/8c/d1f5f4b491160592e7f084d997de53a8e896a3ac01cd07e59f43ca222744/scipy-1.17.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:dbf133ced83889583156566d2bdf7a07ff89228fe0c0cb727f777de92092ec6b", size = 37394463, upload-time = "2026-01-10T21:29:48.723Z" },
{ url = "https://files.pythonhosted.org/packages/e9/01/f58916b9d9ae0112b86d7c3b10b9e685625ce6e8248df139d0fcb17f7397/scipy-1.17.0-cp314-cp314t-macosx_10_14_x86_64.whl", hash = "sha256:2b531f57e09c946f56ad0b4a3b2abee778789097871fc541e267d2eca081cff1", size = 31706502, upload-time = "2026-01-10T21:29:56.326Z" },
{ url = "https://files.pythonhosted.org/packages/59/8e/2912a87f94a7d1f8b38aabc0faf74b82d3b6c9e22be991c49979f0eceed8/scipy-1.17.0-cp314-cp314t-macosx_12_0_arm64.whl", hash = "sha256:13e861634a2c480bd237deb69333ac79ea1941b94568d4b0efa5db5e263d4fd1", size = 28380854, upload-time = "2026-01-10T21:30:01.554Z" },
{ url = "https://files.pythonhosted.org/packages/bd/1c/874137a52dddab7d5d595c1887089a2125d27d0601fce8c0026a24a92a0b/scipy-1.17.0-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:eb2651271135154aa24f6481cbae5cc8af1f0dd46e6533fb7b56aa9727b6a232", size = 20552752, upload-time = "2026-01-10T21:30:05.93Z" },
{ url = "https://files.pythonhosted.org/packages/3f/f0/7518d171cb735f6400f4576cf70f756d5b419a07fe1867da34e2c2c9c11b/scipy-1.17.0-cp314-cp314t-macosx_14_0_x86_64.whl", hash = "sha256:c5e8647f60679790c2f5c76be17e2e9247dc6b98ad0d3b065861e082c56e078d", size = 22803972, upload-time = "2026-01-10T21:30:10.651Z" },
{ url = "https://files.pythonhosted.org/packages/7c/74/3498563a2c619e8a3ebb4d75457486c249b19b5b04a30600dfd9af06bea5/scipy-1.17.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5fb10d17e649e1446410895639f3385fd2bf4c3c7dfc9bea937bddcbc3d7b9ba", size = 32829770, upload-time = "2026-01-10T21:30:16.359Z" },
{ url = "https://files.pythonhosted.org/packages/48/d1/7b50cedd8c6c9d6f706b4b36fa8544d829c712a75e370f763b318e9638c1/scipy-1.17.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8547e7c57f932e7354a2319fab613981cde910631979f74c9b542bb167a8b9db", size = 35051093, upload-time = "2026-01-10T21:30:22.987Z" },
{ url = "https://files.pythonhosted.org/packages/e2/82/a2d684dfddb87ba1b3ea325df7c3293496ee9accb3a19abe9429bce94755/scipy-1.17.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:33af70d040e8af9d5e7a38b5ed3b772adddd281e3062ff23fec49e49681c38cf", size = 34909905, upload-time = "2026-01-10T21:30:28.704Z" },
{ url = "https://files.pythonhosted.org/packages/ef/5e/e565bd73991d42023eb82bb99e51c5b3d9e2c588ca9d4b3e2cc1d3ca62a6/scipy-1.17.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:f9eb55bb97d00f8b7ab95cb64f873eb0bf54d9446264d9f3609130381233483f", size = 37457743, upload-time = "2026-01-10T21:30:34.819Z" },
]
[[package]]
name = "tabulate"
version = "0.9.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/ec/fe/802052aecb21e3797b8f7902564ab6ea0d60ff8ca23952079064155d1ae1/tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c", size = 81090, upload-time = "2022-10-06T17:21:48.54Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/40/44/4a5f08c96eb108af5cb50b41f76142f0afa346dfa99d5296fe7202a11854/tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f", size = 35252, upload-time = "2022-10-06T17:21:44.262Z" },
]
[[package]]
name = "typing-extensions"
version = "4.15.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/72/94/1a15dd82efb362ac84269196e94cf00f187f7ed21c242792a923cdb1c61f/typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466", size = 109391, upload-time = "2025-08-25T13:49:26.313Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614, upload-time = "2025-08-25T13:49:24.86Z" },
]

26
thunder/deserializer.rs Normal file
View File

@ -0,0 +1,26 @@
use crate::schema::{events::Event, tweet_events::TweetEvent};
use anyhow::{Context, Result};
use prost::Message;
use thrift::protocol::{TBinaryInputProtocol, TSerializable};
use xai_thunder_proto::InNetworkEvent;
/// Deserialize a Thrift binary message into TweetEvent
pub fn deserialize_tweet_event(payload: &[u8]) -> Result<TweetEvent> {
let mut cursor = std::io::Cursor::new(payload);
let mut protocol = TBinaryInputProtocol::new(&mut cursor, true);
TweetEvent::read_from_in_protocol(&mut protocol).context("Failed to deserialize TweetEvent")
}
/// Deserialize a Thrift binary message into Event
pub fn deserialize_event(payload: &[u8]) -> Result<Event> {
let mut cursor = std::io::Cursor::new(payload);
let mut protocol = TBinaryInputProtocol::new(&mut cursor, true);
Event::read_from_in_protocol(&mut protocol).context("Failed to deserialize Event")
}
/// Deserialize a proto binary message into InNetworkEvent
pub fn deserialize_tweet_event_v2(payload: &[u8]) -> Result<InNetworkEvent> {
InNetworkEvent::decode(payload).context("Failed to deserialize InNetworkEvent")
}
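// --- Hedged example (not part of the original commit): a minimal round-trip
// check for the proto path, assuming the prost-generated `InNetworkEvent`
// derives `Default` and `PartialEq`, as prost output normally does.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn in_network_event_round_trips() {
        let event = InNetworkEvent::default();
        let payload = event.encode_to_vec();
        let decoded = deserialize_tweet_event_v2(&payload).expect("decode should succeed");
        assert_eq!(decoded, event);
    }
}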

3
thunder/kafka/mod.rs Normal file
View File

@ -0,0 +1,3 @@
pub mod tweet_events_listener;
pub mod tweet_events_listener_v2;
pub mod utils;

390
thunder/kafka/tweet_events_listener.rs Normal file
View File

@ -0,0 +1,390 @@
use anyhow::{Context, Result};
use log::{error, info, warn};
use prost::Message;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use tokio::sync::RwLock;
use xai_kafka::{KafkaMessage, config::KafkaConsumerConfig, consumer::KafkaConsumer};
use xai_kafka::{KafkaProducer, KafkaProducerConfig};
use xai_thunder_proto::{
InNetworkEvent, LightPost, TweetCreateEvent, TweetDeleteEvent, in_network_event,
};
use crate::{
args::Args,
config::MIN_VIDEO_DURATION_MS,
deserializer::deserialize_tweet_event,
kafka::utils::{create_kafka_consumer, deserialize_kafka_messages},
metrics,
schema::{tweet::Tweet, tweet_events::TweetEventData},
};
/// Counter for logging batch processing every Nth time
static BATCH_LOG_COUNTER: AtomicUsize = AtomicUsize::new(0);
/// Monitor Kafka partition lag and update metrics
async fn monitor_partition_lag(
consumer: Arc<RwLock<KafkaConsumer>>,
topic: String,
interval_secs: u64,
) {
let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));
loop {
interval.tick().await;
let consumer = consumer.read().await;
match consumer.get_partition_lags().await {
Ok(lag_info) => {
for partition_lag in lag_info {
let partition_str = partition_lag.partition_id.to_string();
metrics::KAFKA_PARTITION_LAG
.with_label_values(&[&topic, &partition_str])
.set(partition_lag.lag as f64);
}
}
Err(e) => {
warn!("Failed to get partition lag info: {}", e);
}
}
}
}
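/// A post is video-eligible only when it carries exactly one media item,
/// that item is a video, and the video lasts at least MIN_VIDEO_DURATION_MS.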
fn is_eligible_video(tweet: &Tweet) -> bool {
let Some(media) = tweet.media.as_ref() else {
return false;
};
let [first_media] = media.as_slice() else {
return false;
};
let Some(crate::schema::tweet_media::MediaInfo::VideoInfo(video_info)) =
first_media.media_info.as_ref()
else {
return false;
};
video_info
.duration_millis
.map(|d| d >= MIN_VIDEO_DURATION_MS)
.unwrap_or(false)
}
/// Start the partition lag monitoring task in the background
pub fn start_partition_lag_monitor(
consumer: Arc<RwLock<KafkaConsumer>>,
topic: String,
interval_secs: u64,
) {
tokio::spawn(async move {
info!(
"Starting partition lag monitoring task for topic '{}' (interval: {}s)",
topic, interval_secs
);
monitor_partition_lag(consumer, topic, interval_secs).await;
});
}
/// Start the tweet event processing loop in the background with configurable number of threads
pub async fn start_tweet_event_processing(
base_config: KafkaConsumerConfig,
producer_config: KafkaProducerConfig,
args: &Args,
) {
let num_partitions = args.tweet_events_num_partitions as usize;
let kafka_num_threads = args.kafka_num_threads;
// Use all available partitions
let partitions_to_use: Vec<i32> = (0..num_partitions as i32).collect();
let partitions_per_thread = num_partitions.div_ceil(kafka_num_threads);
info!(
"Starting {} message processing threads for {} partitions ({} partitions per thread)",
kafka_num_threads, num_partitions, partitions_per_thread
);
let producer = if !args.is_serving {
info!("Kafka producer enabled, starting producer...");
let producer = Arc::new(RwLock::new(KafkaProducer::new(producer_config)));
if let Err(e) = producer.write().await.start().await {
panic!("Failed to start Kafka producer: {:#}", e);
}
Some(producer)
} else {
info!("Kafka producer disabled, skipping producer initialization");
None
};
spawn_processing_threads(base_config, partitions_to_use, producer, args);
}
/// Spawn multiple processing threads, each handling a subset of partitions
fn spawn_processing_threads(
base_config: KafkaConsumerConfig,
partitions_to_use: Vec<i32>,
producer: Option<Arc<RwLock<KafkaProducer>>>,
args: &Args,
) {
let total_partitions = partitions_to_use.len();
let partitions_per_thread = total_partitions.div_ceil(args.kafka_num_threads);
for thread_id in 0..args.kafka_num_threads {
let start_idx = thread_id * partitions_per_thread;
let end_idx = ((thread_id + 1) * partitions_per_thread).min(total_partitions);
if start_idx >= total_partitions {
break;
}
let thread_partitions = partitions_to_use[start_idx..end_idx].to_vec();
let mut thread_config = base_config.clone();
thread_config.partitions = Some(thread_partitions.clone());
let producer_clone = producer.as_ref().map(Arc::clone);
let topic = thread_config.base_config.topic.clone();
let lag_monitor_interval_secs = args.lag_monitor_interval_secs;
let batch_size = args.kafka_batch_size;
let post_retention_sec = args.post_retention_seconds;
tokio::spawn(async move {
info!(
"Starting message processing thread {} for partitions {:?}",
thread_id, thread_partitions
);
match create_kafka_consumer(thread_config).await {
Ok(consumer) => {
// Start partition lag monitoring for this thread's partitions
start_partition_lag_monitor(
Arc::clone(&consumer),
topic,
lag_monitor_interval_secs,
);
if let Err(e) = process_tweet_events(
consumer,
batch_size,
producer_clone,
post_retention_sec as i64,
)
.await
{
panic!(
"Tweet events processing thread {} exited unexpectedly: {:#}. This is a critical failure - the feeder cannot function without tweet event processing.",
thread_id, e
);
}
}
Err(e) => {
panic!(
"Failed to create consumer for thread {}: {:#}",
thread_id, e
);
}
}
});
}
}
/// Process a batch of messages: deserialize, extract posts, and forward them to the Kafka producer
async fn process_message_batch(
messages: Vec<KafkaMessage>,
batch_num: usize,
producer: Option<Arc<RwLock<KafkaProducer>>>,
post_retention_sec: i64,
) -> Result<()> {
let results = deserialize_kafka_messages(messages, deserialize_tweet_event)?;
let mut create_tweets = Vec::new();
let mut delete_tweets = Vec::new();
let mut first_post_id = 0;
let mut first_user_id = 0;
let len_posts = results.len();
let now_secs = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs() as i64;
for tweet_event in results {
let data = tweet_event.data.unwrap();
match data {
TweetEventData::TweetCreateEvent(create_event) => {
first_post_id = create_event.tweet.as_ref().unwrap().id.unwrap();
first_user_id = create_event.user.as_ref().unwrap().id.unwrap();
let tweet = create_event.tweet.as_ref().unwrap();
let core_data = tweet.core_data.as_ref().unwrap();
if let Some(nullcast) = core_data.nullcast
&& nullcast
{
continue;
}
create_tweets.push(LightPost {
post_id: tweet.id.unwrap(),
author_id: create_event.user.as_ref().unwrap().id.unwrap(),
created_at: core_data.created_at_secs.unwrap(),
in_reply_to_post_id: core_data
.reply
.as_ref()
.and_then(|r| r.in_reply_to_status_id),
in_reply_to_user_id: core_data
.reply
.as_ref()
.and_then(|r| r.in_reply_to_user_id),
is_retweet: core_data.share.is_some(),
is_reply: core_data.reply.is_some(),
source_post_id: core_data.share.as_ref().and_then(|s| s.source_status_id),
source_user_id: core_data.share.as_ref().and_then(|s| s.source_user_id),
has_video: is_eligible_video(tweet),
conversation_id: core_data.conversation_id,
});
}
TweetEventData::TweetDeleteEvent(delete_event) => {
let created_at_secs = delete_event
.tweet
.as_ref()
.unwrap()
.core_data
.as_ref()
.unwrap()
.created_at_secs
.unwrap();
if now_secs - created_at_secs > post_retention_sec {
continue;
}
delete_tweets.push(delete_event.tweet.as_ref().unwrap().id.unwrap());
}
TweetEventData::QuotedTweetDeleteEvent(delete_event) => {
delete_tweets.push(delete_event.quoting_tweet_id.unwrap());
}
_ => {
log::info!("Other non post creation/deletion event")
}
}
}
// Send each LightPost as an InNetworkEvent to the producer in separate tasks (only if producer is enabled)
if let Some(ref producer) = producer {
let mut send_tasks = Vec::with_capacity(create_tweets.len());
for light_post in &create_tweets {
let event = InNetworkEvent {
event_variant: Some(in_network_event::EventVariant::TweetCreateEvent(
TweetCreateEvent {
post_id: light_post.post_id,
author_id: light_post.author_id,
created_at: light_post.created_at,
in_reply_to_post_id: light_post.in_reply_to_post_id,
in_reply_to_user_id: light_post.in_reply_to_user_id,
is_retweet: light_post.is_retweet,
is_reply: light_post.is_reply,
source_post_id: light_post.source_post_id,
source_user_id: light_post.source_user_id,
has_video: light_post.has_video,
conversation_id: light_post.conversation_id,
},
)),
};
let payload = event.encode_to_vec();
let producer_clone = Arc::clone(producer);
send_tasks.push(tokio::spawn(async move {
let producer_lock = producer_clone.read().await;
if let Err(e) = producer_lock.send(&payload).await {
warn!("Failed to send InNetworkEvent to producer: {:#}", e);
}
}));
}
for post_id in &delete_tweets {
let event = InNetworkEvent {
event_variant: Some(in_network_event::EventVariant::TweetDeleteEvent(
TweetDeleteEvent {
post_id: *post_id,
deleted_at: now_secs,
},
)),
};
let payload = event.encode_to_vec();
let producer_clone = Arc::clone(producer);
send_tasks.push(tokio::spawn(async move {
let producer_lock = producer_clone.read().await;
if let Err(e) = producer_lock.send(&payload).await {
warn!("Failed to send InNetworkEvent to producer: {:#}", e);
}
}));
}
// Wait for all send tasks to complete
for task in send_tasks {
if let Err(e) = task.await {
error!("Error writing to kafka {}", e);
}
}
}
// Log a milestone every 1000th batch
let batch_count = BATCH_LOG_COUNTER.fetch_add(1, Ordering::Relaxed);
if batch_count.is_multiple_of(1000) {
info!(
"Batch processing milestone: processed {} batches total, latest batch {} had {} posts (first: post_id={}, user_id={})",
batch_count + 1,
batch_num,
len_posts,
first_post_id,
first_user_id
);
}
Ok(())
}
/// Main message processing loop that polls Kafka, batches messages, and stores posts
async fn process_tweet_events(
consumer: Arc<RwLock<KafkaConsumer>>,
batch_size: usize,
producer: Option<Arc<RwLock<KafkaProducer>>>,
post_retention_sec: i64,
) -> Result<()> {
let mut message_buffer = Vec::new();
let mut batch_num = 0;
loop {
let poll_result = {
let mut consumer_lock = consumer.write().await;
consumer_lock.poll(100).await
};
match poll_result {
Ok(messages) => {
message_buffer.extend(messages);
// Process batch when we have enough messages
if message_buffer.len() >= batch_size {
batch_num += 1;
let messages = std::mem::take(&mut message_buffer);
let producer_clone = producer.clone();
// Process the batch inline and propagate any error
process_message_batch(messages, batch_num, producer_clone, post_retention_sec)
.await
.context("Error processing tweet event batch")?;
consumer.write().await.commit_offsets()?;
}
}
Err(e) => {
warn!("Error polling messages: {:#}", e);
metrics::KAFKA_POLL_ERRORS.inc();
tokio::time::sleep(Duration::from_millis(100)).await;
}
}
}
}
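// --- Hedged sketch (not in the original commit): the thread/partition
// assignment used in spawn_processing_threads, extracted as a pure function
// so the chunking is easy to verify. `split_partitions` is a hypothetical
// helper for illustration, not part of this module's API.
#[cfg(test)]
mod tests {
    fn split_partitions(num_partitions: usize, num_threads: usize) -> Vec<Vec<i32>> {
        let partitions: Vec<i32> = (0..num_partitions as i32).collect();
        let per_thread = num_partitions.div_ceil(num_threads);
        (0..num_threads)
            .map(|t| {
                let start = t * per_thread;
                let end = ((t + 1) * per_thread).min(num_partitions);
                if start >= num_partitions {
                    Vec::new()
                } else {
                    partitions[start..end].to_vec()
                }
            })
            .collect()
    }

    #[test]
    fn partitions_are_chunked_with_any_remainder_in_the_last_chunk() {
        // 10 partitions over 3 threads -> ceil(10 / 3) = 4 per thread: 4 + 4 + 2.
        let split = split_partitions(10, 3);
        assert_eq!(split[0], vec![0, 1, 2, 3]);
        assert_eq!(split[1], vec![4, 5, 6, 7]);
        assert_eq!(split[2], vec![8, 9]);
    }
}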

249
thunder/kafka/tweet_events_listener_v2.rs Normal file
View File

@ -0,0 +1,249 @@
use anyhow::Result;
use log::{info, warn};
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::{RwLock, Semaphore};
use xai_kafka::{KafkaMessage, config::KafkaConsumerConfig, consumer::KafkaConsumer};
use xai_thunder_proto::{LightPost, TweetDeleteEvent, in_network_event};
use crate::{
args::Args,
deserializer::deserialize_tweet_event_v2,
kafka::utils::{create_kafka_consumer, deserialize_kafka_messages},
metrics,
posts::post_store::PostStore,
};
/// Counter for logging deserialization every Nth time
static DESER_LOG_COUNTER: AtomicUsize = AtomicUsize::new(0);
/// Start the tweet event processing loop in the background with configurable number of threads
pub async fn start_tweet_event_processing_v2(
base_config: KafkaConsumerConfig,
post_store: Arc<PostStore>,
args: &Args,
tx: tokio::sync::mpsc::Sender<i64>,
) {
let num_partitions = args.kafka_tweet_events_v2_num_partitions;
let kafka_num_threads = args.kafka_num_threads;
// Use all available partitions
let partitions_to_use: Vec<i32> = (0..num_partitions as i32).collect();
let partitions_per_thread = num_partitions.div_ceil(kafka_num_threads);
info!(
"Starting {} message processing threads for {} partitions ({} partitions per thread)",
kafka_num_threads, num_partitions, partitions_per_thread
);
spawn_processing_threads_v2(base_config, partitions_to_use, post_store, args, tx);
}
/// Spawn multiple processing threads, each handling a subset of partitions
fn spawn_processing_threads_v2(
base_config: KafkaConsumerConfig,
partitions_to_use: Vec<i32>,
post_store: Arc<PostStore>,
args: &Args,
tx: tokio::sync::mpsc::Sender<i64>,
) {
let total_partitions = partitions_to_use.len();
let partitions_per_thread = total_partitions.div_ceil(args.kafka_num_threads);
// Create shared semaphore to prevent too many tweet_events partition updates at the same time
let semaphore = Arc::new(Semaphore::new(3));
for thread_id in 0..args.kafka_num_threads {
let start_idx = thread_id * partitions_per_thread;
let end_idx = ((thread_id + 1) * partitions_per_thread).min(total_partitions);
if start_idx >= total_partitions {
break;
}
let thread_partitions = partitions_to_use[start_idx..end_idx].to_vec();
let mut thread_config = base_config.clone();
thread_config.partitions = Some(thread_partitions.clone());
let post_store_clone = Arc::clone(&post_store);
let topic = thread_config.base_config.topic.clone();
let lag_monitor_interval_secs = args.lag_monitor_interval_secs;
let batch_size = args.kafka_batch_size;
let tx_clone = tx.clone();
let semaphore_clone = Arc::clone(&semaphore);
tokio::spawn(async move {
info!(
"Starting message processing thread {} for partitions {:?}",
thread_id, thread_partitions
);
match create_kafka_consumer(thread_config).await {
Ok(consumer) => {
// Start partition lag monitoring for this thread's partitions
crate::kafka::tweet_events_listener::start_partition_lag_monitor(
Arc::clone(&consumer),
topic,
lag_monitor_interval_secs,
);
if let Err(e) = process_tweet_events_v2(
consumer,
post_store_clone,
batch_size,
tx_clone,
semaphore_clone,
)
.await
{
panic!(
"Tweet events processing thread {} exited unexpectedly: {:#}. This is a critical failure - the feeder cannot function without tweet event processing.",
thread_id, e
);
}
}
Err(e) => {
panic!(
"Failed to create consumer for thread {}: {:#}",
thread_id, e
);
}
}
});
}
}
/// Deserialize a single batch of messages into post creations and deletions
fn deserialize_batch(
messages: Vec<KafkaMessage>,
) -> Result<(Vec<LightPost>, Vec<TweetDeleteEvent>)> {
let start_time = Instant::now();
let num_messages = messages.len();
let results = deserialize_kafka_messages(messages, deserialize_tweet_event_v2)?;
let deser_elapsed = start_time.elapsed();
if DESER_LOG_COUNTER
.fetch_add(1, Ordering::Relaxed)
.is_multiple_of(1000)
{
info!(
"Deserialized {} messages in {:?} ({:.2} msgs/sec)",
num_messages,
deser_elapsed,
num_messages as f64 / deser_elapsed.as_secs_f64()
);
}
let mut create_tweets = Vec::with_capacity(results.len());
let mut delete_tweets = Vec::with_capacity(10);
for tweet_event in results {
match tweet_event.event_variant.unwrap() {
in_network_event::EventVariant::TweetCreateEvent(create_event) => {
create_tweets.push(LightPost {
post_id: create_event.post_id,
author_id: create_event.author_id,
created_at: create_event.created_at,
in_reply_to_post_id: create_event.in_reply_to_post_id,
in_reply_to_user_id: create_event.in_reply_to_user_id,
is_retweet: create_event.is_retweet,
is_reply: create_event.is_reply
|| create_event.in_reply_to_post_id.is_some()
|| create_event.in_reply_to_user_id.is_some(),
source_post_id: create_event.source_post_id,
source_user_id: create_event.source_user_id,
has_video: create_event.has_video,
conversation_id: create_event.conversation_id,
});
}
in_network_event::EventVariant::TweetDeleteEvent(delete_event) => {
delete_tweets.push(delete_event);
}
}
}
Ok((create_tweets, delete_tweets))
}
/// Main message processing loop that polls Kafka, batches messages, and stores posts
async fn process_tweet_events_v2(
consumer: Arc<RwLock<KafkaConsumer>>,
post_store: Arc<PostStore>,
batch_size: usize,
tx: tokio::sync::mpsc::Sender<i64>,
semaphore: Arc<Semaphore>,
) -> Result<()> {
let mut message_buffer = Vec::new();
let mut batch_count = 0_usize;
let mut init_data_downloaded = false;
loop {
let poll_result = {
let mut consumer_lock = consumer.write().await;
consumer_lock.poll(batch_size).await
};
match poll_result {
Ok(messages) => {
let catchup_sender = if !init_data_downloaded {
let consumer_lock = consumer.read().await;
if let Ok(lags) = consumer_lock.get_partition_lags().await {
let total_lag: i64 = lags.iter().map(|l| l.lag).sum();
if total_lag < (lags.len() * batch_size) as i64 {
init_data_downloaded = true;
Some((tx.clone(), total_lag))
} else {
None
}
} else {
None
}
} else {
None
};
message_buffer.extend(messages);
// Process batch when we have enough messages
if message_buffer.len() >= batch_size {
batch_count += 1;
let messages = std::mem::take(&mut message_buffer);
let post_store_clone = Arc::clone(&post_store);
// Once init data is downloaded, acquire a semaphore permit so batch processing leaves enough CPU for serving requests
let permit = if init_data_downloaded {
Some(semaphore.clone().acquire_owned().await.unwrap())
} else {
None
};
// Send batch to blocking thread pool for processing
let _ = tokio::task::spawn_blocking(move || {
let _permit = permit; // Hold permit until task completes
match deserialize_batch(messages) {
Err(e) => warn!("Error processing batch {}: {:#}", batch_count, e),
Ok((light_posts, delete_posts)) => {
post_store_clone.insert_posts(light_posts);
post_store_clone.mark_as_deleted(delete_posts);
}
};
})
.await;
if let Some((sender, lag)) = catchup_sender {
info!("Completed kafka init for a single thread");
if let Err(e) = sender.send(lag).await {
log::error!("error sending {}", e);
}
}
}
}
Err(e) => {
warn!("Error polling messages: {:#}", e);
metrics::KAFKA_POLL_ERRORS.inc();
tokio::time::sleep(Duration::from_millis(100)).await;
}
}
}
}
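// --- Hedged sketch (not in the original commit): the catch-up condition
// from process_tweet_events_v2, restated as a pure predicate.
// `is_caught_up` is a hypothetical helper for illustration only.
#[cfg(test)]
mod tests {
    /// Catch-up is declared once total lag drops below one batch per partition.
    fn is_caught_up(partition_lags: &[i64], batch_size: usize) -> bool {
        let total_lag: i64 = partition_lags.iter().sum();
        total_lag < (partition_lags.len() * batch_size) as i64
    }

    #[test]
    fn catch_up_threshold_is_one_batch_per_partition() {
        // 4 partitions with batch_size 100 -> threshold is 400 messages of total lag.
        assert!(is_caught_up(&[50, 80, 90, 120], 100)); // 340 < 400
        assert!(!is_caught_up(&[200, 150, 100, 50], 100)); // 500 >= 400
    }
}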

48
thunder/kafka/utils.rs Normal file
View File

@ -0,0 +1,48 @@
use anyhow::{Context, Result};
use std::sync::Arc;
use tokio::sync::RwLock;
use xai_kafka::{KafkaMessage, config::KafkaConsumerConfig, consumer::KafkaConsumer};
use crate::metrics;
/// Create and start a Kafka consumer with the given configuration
pub async fn create_kafka_consumer(
config: KafkaConsumerConfig,
) -> Result<Arc<RwLock<KafkaConsumer>>> {
let mut consumer = KafkaConsumer::new(config);
consumer
.start()
.await
.context("Failed to start Kafka consumer")?;
Ok(Arc::new(RwLock::new(consumer)))
}
/// Process a batch of Kafka messages and deserialize them using the provided deserializer function
pub fn deserialize_kafka_messages<T, F>(
messages: Vec<KafkaMessage>,
deserializer: F,
) -> Result<Vec<T>>
where
F: Fn(&[u8]) -> Result<T>,
{
let _timer = metrics::Timer::new(metrics::BATCH_PROCESSING_TIME.clone());
let mut kafka_data = Vec::with_capacity(messages.len());
for msg in messages.iter() {
if let Some(payload) = &msg.payload {
match deserializer(payload) {
Ok(deserialized_msg) => {
kafka_data.push(deserialized_msg);
}
Err(e) => {
log::error!("Failed to parse Kafka message: {}", e);
metrics::KAFKA_MESSAGES_FAILED_PARSE.inc();
}
}
}
}
Ok(kafka_data)
}
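// --- Hedged usage sketch (not in the original commit). It assumes
// `KafkaMessage` exposes the public `payload: Option<Vec<u8>>` field used
// above and that it implements `Default`; the latter is an unverified
// assumption about the internal xai_kafka crate.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn failed_parses_are_dropped_not_returned() {
        let messages = vec![
            KafkaMessage { payload: Some(b"4".to_vec()), ..Default::default() },
            KafkaMessage { payload: Some(b"not a number".to_vec()), ..Default::default() },
            KafkaMessage { payload: None, ..Default::default() },
        ];
        // Toy deserializer: parse each payload as a decimal i64.
        let parsed = deserialize_kafka_messages(messages, |bytes: &[u8]| -> Result<i64> {
            Ok(std::str::from_utf8(bytes)?.trim().parse::<i64>()?)
        })
        .unwrap();
        // The unparseable payload is counted in metrics and skipped; the
        // missing payload is ignored entirely.
        assert_eq!(parsed, vec![4]);
    }
}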

115
thunder/kafka_utils.rs Normal file
View File

@ -0,0 +1,115 @@
use anyhow::{Context, Result};
use std::sync::Arc;
use xai_kafka::KafkaProducerConfig;
use xai_kafka::config::{KafkaConfig, KafkaConsumerConfig, SslConfig};
use xai_wily::WilyConfig;
use crate::{
args,
kafka::{
tweet_events_listener::start_tweet_event_processing,
tweet_events_listener_v2::start_tweet_event_processing_v2,
},
};
const TWEET_EVENT_TOPIC: &str = "";
const TWEET_EVENT_DEST: &str = "";
const IN_NETWORK_EVENTS_DEST: &str = "";
const IN_NETWORK_EVENTS_TOPIC: &str = "";
pub async fn start_kafka(
args: &args::Args,
post_store: Arc<crate::posts::post_store::PostStore>,
user: &str,
tx: tokio::sync::mpsc::Sender<i64>,
) -> Result<()> {
let sasl_password = std::env::var("")
.ok()
.or(args.sasl_password.clone())
.context("SASL password must be set via the env var or --sasl-password")?;
let producer_sasl_password = std::env::var("")
.ok()
.or(args.producer_sasl_password.clone());
if args.is_serving {
let unique_id = uuid::Uuid::new_v4().to_string();
let v2_tweet_events_consumer_config = KafkaConsumerConfig {
base_config: KafkaConfig {
dest: args.in_network_events_consumer_dest.clone(),
topic: IN_NETWORK_EVENTS_TOPIC.to_string(),
wily_config: Some(WilyConfig::default()),
ssl: Some(SslConfig {
security_protocol: args.security_protocol.clone(),
sasl_mechanism: Some(args.producer_sasl_mechanism.clone()),
sasl_username: Some(args.producer_sasl_username.clone()),
sasl_password: producer_sasl_password.clone(),
}),
..Default::default()
},
group_id: format!("{}-{}", args.kafka_group_id, unique_id),
auto_offset_reset: args.auto_offset_reset.clone(),
fetch_timeout_ms: args.fetch_timeout_ms,
max_partition_fetch_bytes: Some(1024 * 1024 * 100),
skip_to_latest: args.skip_to_latest,
..Default::default()
};
// Start Kafka background tasks
start_tweet_event_processing_v2(
v2_tweet_events_consumer_config,
Arc::clone(&post_store),
args,
tx,
)
.await;
}
// Only start Kafka processing and background tasks if not in serving mode
if !args.is_serving {
// Create Kafka consumer config
let tweet_events_consumer_config = KafkaConsumerConfig {
base_config: KafkaConfig {
dest: TWEET_EVENT_DEST.to_string(),
topic: TWEET_EVENT_TOPIC.to_string(),
wily_config: Some(WilyConfig::default()),
ssl: Some(SslConfig {
security_protocol: args.security_protocol.clone(),
sasl_mechanism: Some(args.sasl_mechanism.clone()),
sasl_username: Some(args.sasl_username.clone()),
sasl_password: Some(sasl_password.clone()),
}),
..Default::default()
},
group_id: format!("{}-{}", args.kafka_group_id, user),
auto_offset_reset: args.auto_offset_reset.clone(),
enable_auto_commit: false,
fetch_timeout_ms: args.fetch_timeout_ms,
max_partition_fetch_bytes: Some(1024 * 1024 * 10),
partitions: None,
skip_to_latest: args.skip_to_latest,
..Default::default()
};
let producer_config = KafkaProducerConfig {
base_config: KafkaConfig {
dest: IN_NETWORK_EVENTS_DEST.to_string(),
topic: IN_NETWORK_EVENTS_TOPIC.to_string(),
wily_config: Some(WilyConfig::default()),
ssl: Some(SslConfig {
security_protocol: args.security_protocol.clone(),
sasl_mechanism: Some(args.producer_sasl_mechanism.clone()),
sasl_username: Some(args.producer_sasl_username.clone()),
sasl_password: producer_sasl_password.clone(),
}),
..Default::default()
},
..Default::default()
};
start_tweet_event_processing(tweet_events_consumer_config, producer_config, args).await;
}
Ok(())
}

11
thunder/lib.rs Normal file
View File

@ -0,0 +1,11 @@
pub mod args;
pub mod config;
pub mod deserializer;
pub mod kafka;
pub mod kafka_utils;
pub mod metrics;
pub mod o2;
pub mod posts;
pub mod schema;
pub mod strato_client;
pub mod thunder_service;

100
thunder/main.rs Normal file
View File

@ -0,0 +1,100 @@
use anyhow::{Context, Result};
use axum::Router;
use clap::Parser;
use log::info;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tonic::service::Routes;
use xai_http_server::{CancellationToken, GrpcConfig, HttpServer};
use thunder::{
args, kafka_utils, posts::post_store::PostStore, strato_client::StratoClient,
thunder_service::ThunderServiceImpl,
};
#[tokio::main]
async fn main() -> Result<()> {
env_logger::init();
let args = args::Args::parse();
// Initialize PostStore
let post_store = Arc::new(PostStore::new(
args.post_retention_seconds,
args.request_timeout_ms,
));
info!(
"Initialized PostStore for in-memory post storage (retention: {} seconds / {:.1} days, request_timeout: {}ms)",
args.post_retention_seconds,
args.post_retention_seconds as f64 / 86400.0,
args.request_timeout_ms
);
// Initialize StratoClient for fetching following lists
let strato_client = Arc::new(StratoClient::new());
info!("Initialized StratoClient");
// Create ThunderService with the PostStore, StratoClient, and concurrency limit
let thunder_service = ThunderServiceImpl::new(
Arc::clone(&post_store),
Arc::clone(&strato_client),
args.max_concurrent_requests,
);
info!(
"Initialized with max_concurrent_requests={}",
args.max_concurrent_requests
);
let routes = Routes::new(thunder_service.server());
// Set up gRPC config
let grpc_config = GrpcConfig::new(args.grpc_port, routes);
// Create HTTP server with gRPC support
let mut http_server = HttpServer::new(
args.http_port,
Router::new(),
Some(grpc_config),
CancellationToken::new(),
Duration::from_secs(10),
)
.await
.context("Failed to create HTTP server")?;
if args.enable_profiling {
xai_profiling::spawn_server(3000, CancellationToken::new()).await;
}
// Create channel for Kafka catch-up signals (one per consumer thread)
let (tx, mut rx) = tokio::sync::mpsc::channel::<i64>(args.kafka_num_threads);
kafka_utils::start_kafka(&args, post_store.clone(), "", tx).await?;
if args.is_serving {
// Wait for a Kafka catch-up signal from each consumer thread
let start = Instant::now();
for _ in 0..args.kafka_num_threads {
rx.recv().await;
}
info!("Kafka init took {:?}", start.elapsed());
post_store.finalize_init().await?;
// Start stats logger
Arc::clone(&post_store).start_stats_logger();
info!("Started PostStore stats logger",);
// Start auto-trim task to remove posts older than retention period
Arc::clone(&post_store).start_auto_trim(2); // Run every 2 minutes
info!(
"Started PostStore auto-trim task (interval: 2 minutes, retention: {:.1} days)",
args.post_retention_seconds as f64 / 86400.0
);
}
http_server.set_readiness(true);
info!("HTTP/gRPC server is ready");
// Wait for termination signal
http_server.wait_for_termination().await;
info!("Server terminated");
Ok(())
}

1
thunder/posts/mod.rs Normal file
View File

@ -0,0 +1 @@
pub mod post_store;

526
thunder/posts/post_store.rs Normal file
View File

@ -0,0 +1,526 @@
use anyhow::Result;
use dashmap::DashMap;
use log::info;
use std::collections::{HashSet, VecDeque};
use std::sync::Arc;
use std::time::{Duration, Instant};
use xai_thunder_proto::{LightPost, TweetDeleteEvent};
use crate::config::{
DELETE_EVENT_KEY, MAX_ORIGINAL_POSTS_PER_AUTHOR, MAX_REPLY_POSTS_PER_AUTHOR,
MAX_TINY_POSTS_PER_USER_SCAN, MAX_VIDEO_POSTS_PER_AUTHOR,
};
use crate::metrics::{
POST_STORE_DELETED_POSTS, POST_STORE_DELETED_POSTS_FILTERED, POST_STORE_ENTITY_COUNT,
POST_STORE_POSTS_RETURNED, POST_STORE_POSTS_RETURNED_RATIO, POST_STORE_REQUEST_TIMEOUTS,
POST_STORE_REQUESTS, POST_STORE_TOTAL_POSTS, POST_STORE_USER_COUNT,
};
/// Minimal post reference stored in user timelines (only ID and timestamp)
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TinyPost {
pub post_id: i64,
pub created_at: i64,
}
impl TinyPost {
/// Create a new TinyPost from a post ID and creation timestamp
pub fn new(post_id: i64, created_at: i64) -> Self {
TinyPost {
post_id,
created_at,
}
}
}
/// A thread-safe store for posts grouped by user ID
/// Note: LightPost is now defined in the protobuf schema (in-network.proto)
#[derive(Clone)]
pub struct PostStore {
/// Full post data indexed by post_id
posts: Arc<DashMap<i64, LightPost>>,
/// Maps user_id to a deque of TinyPost references for original posts (non-reply, non-retweet)
original_posts_by_user: Arc<DashMap<i64, VecDeque<TinyPost>>>,
/// Maps user_id to a deque of TinyPost references for replies and retweets
secondary_posts_by_user: Arc<DashMap<i64, VecDeque<TinyPost>>>,
/// Maps user_id to a deque of TinyPost references for video posts
video_posts_by_user: Arc<DashMap<i64, VecDeque<TinyPost>>>,
deleted_posts: Arc<DashMap<i64, bool>>,
/// Retention period for posts in seconds
retention_seconds: u64,
/// Request timeout for get_posts_by_users iteration (0 = no timeout)
request_timeout: Duration,
}
impl PostStore {
/// Creates a new empty PostStore with the specified retention period and request timeout
pub fn new(retention_seconds: u64, request_timeout_ms: u64) -> Self {
PostStore {
posts: Arc::new(DashMap::new()),
original_posts_by_user: Arc::new(DashMap::new()),
secondary_posts_by_user: Arc::new(DashMap::new()),
video_posts_by_user: Arc::new(DashMap::new()),
deleted_posts: Arc::new(DashMap::new()),
retention_seconds,
request_timeout: Duration::from_millis(request_timeout_ms),
}
}
pub fn mark_as_deleted(&self, posts: Vec<TweetDeleteEvent>) {
for post in posts.into_iter() {
self.posts.remove(&post.post_id);
self.deleted_posts.insert(post.post_id, true);
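// Also record the delete in the sentinel DELETE_EVENT_KEY timeline, using deleted_at as its timestamp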
let mut user_posts_entry = self
.original_posts_by_user
.entry(DELETE_EVENT_KEY)
.or_default();
user_posts_entry.push_back(TinyPost {
post_id: post.post_id,
created_at: post.deleted_at,
});
}
}
/// Inserts posts into the post store
pub fn insert_posts(&self, mut posts: Vec<LightPost>) {
// Filter to keep only posts created in the last retention_seconds and not from the future
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs() as i64;
posts.retain(|p| {
p.created_at < current_time
&& current_time - p.created_at <= (self.retention_seconds as i64)
});
// Sort remaining posts by created_at timestamp
posts.sort_unstable_by_key(|p| p.created_at);
self.insert_posts_internal(posts);
}
pub async fn finalize_init(&self) -> Result<()> {
self.sort_all_user_posts().await;
self.trim_old_posts().await;
// This is needed because the relative order of create/delete events can be lost in the feeder
for entry in self.deleted_posts.iter() {
self.posts.remove(entry.key());
}
Ok(())
}
fn insert_posts_internal(&self, posts: Vec<LightPost>) {
for post in posts {
let post_id = post.post_id;
let author_id = post.author_id;
let created_at = post.created_at;
let is_original = !post.is_reply && !post.is_retweet;
if self.deleted_posts.contains_key(&post_id) {
continue;
}
// Store the full post data
let old = self.posts.insert(post_id, post);
if old.is_some() {
// if already stored - don't add it again
continue;
}
// Create a TinyPost reference for the timeline
let tiny_post = TinyPost::new(post_id, created_at);
// Use entry API to get mutable access to the appropriate user's posts timeline
if is_original {
let mut user_posts_entry =
self.original_posts_by_user.entry(author_id).or_default();
user_posts_entry.push_back(tiny_post.clone());
} else {
let mut user_posts_entry =
self.secondary_posts_by_user.entry(author_id).or_default();
user_posts_entry.push_back(tiny_post.clone());
}
let mut video_eligible = post.has_video;
// A retweet inherits video eligibility from its source post, provided the
// source is not a reply
if !video_eligible
&& post.is_retweet
&& let Some(source_post_id) = post.source_post_id
&& let Some(source_post) = self.posts.get(&source_post_id)
{
video_eligible = !source_post.is_reply && source_post.has_video;
}
// Replies are never eligible for the video timeline
if post.is_reply {
video_eligible = false;
}
// Also add to video posts timeline if post has video
if video_eligible {
let mut user_posts_entry = self.video_posts_by_user.entry(author_id).or_default();
user_posts_entry.push_back(tiny_post);
}
}
}
/// Retrieves video posts from multiple users
pub fn get_videos_by_users(
&self,
user_ids: &[i64],
exclude_tweet_ids: &HashSet<i64>,
start_time: Instant,
request_user_id: i64,
) -> Vec<LightPost> {
let video_posts = self.get_posts_from_map(
&self.video_posts_by_user,
user_ids,
MAX_VIDEO_POSTS_PER_AUTHOR,
exclude_tweet_ids,
&HashSet::new(),
start_time,
request_user_id,
);
POST_STORE_POSTS_RETURNED.observe(video_posts.len() as f64);
video_posts
}
/// Retrieves all posts from multiple users
pub fn get_all_posts_by_users(
&self,
user_ids: &[i64],
exclude_tweet_ids: &HashSet<i64>,
start_time: Instant,
request_user_id: i64,
) -> Vec<LightPost> {
let following_users_set: HashSet<i64> = user_ids.iter().copied().collect();
let mut all_posts = self.get_posts_from_map(
&self.original_posts_by_user,
user_ids,
MAX_ORIGINAL_POSTS_PER_AUTHOR,
exclude_tweet_ids,
&HashSet::new(),
start_time,
request_user_id,
);
let secondary_posts = self.get_posts_from_map(
&self.secondary_posts_by_user,
user_ids,
MAX_REPLY_POSTS_PER_AUTHOR,
exclude_tweet_ids,
&following_users_set,
start_time,
request_user_id,
);
all_posts.extend(secondary_posts);
POST_STORE_POSTS_RETURNED.observe(all_posts.len() as f64);
all_posts
}
#[allow(clippy::too_many_arguments)]
pub fn get_posts_from_map(
&self,
posts_map: &Arc<DashMap<i64, VecDeque<TinyPost>>>,
user_ids: &[i64],
max_per_user: usize,
exclude_tweet_ids: &HashSet<i64>,
following_users: &HashSet<i64>,
start_time: Instant,
request_user_id: i64,
) -> Vec<LightPost> {
POST_STORE_REQUESTS.inc();
let mut light_posts = Vec::new();
let mut total_eligible: usize = 0;
for (i, user_id) in user_ids.iter().enumerate() {
if !self.request_timeout.is_zero() && start_time.elapsed() >= self.request_timeout {
log::error!(
"Timed out fetching posts for user={}; Processed: {}/{}. Stage: {}",
request_user_id,
i,
user_ids.len(),
if following_users.is_empty() {
"original"
} else {
"secondary"
}
);
POST_STORE_REQUEST_TIMEOUTS.inc();
break;
}
if let Some(user_posts_ref) = posts_map.get(user_id) {
let user_posts = user_posts_ref.value();
total_eligible += user_posts.len();
// Start from newest posts (reverse iterator)
// Cap the scan so we don't walk all the way back through a long-inactive user's history
let tiny_posts_iter = user_posts
.iter()
.rev()
.filter(|post| !exclude_tweet_ids.contains(&post.post_id))
.take(MAX_TINY_POSTS_PER_USER_SCAN);
// Perform light doc lookup to get full LightPost data. This will also filter deleted posts
// Note: We copy the value immediately to release the read lock and avoid potential
// deadlock when acquiring nested read locks while a writer is waiting.
let light_post_iter_1 = tiny_posts_iter
.filter_map(|tiny_post| self.posts.get(&tiny_post.post_id).map(|r| *r.value()));
let light_post_iter = light_post_iter_1.filter(|post| {
if self.deleted_posts.get(&post.post_id).is_some() {
POST_STORE_DELETED_POSTS_FILTERED.inc();
false
} else {
true
}
});
// Drop retweets of the requesting user's own posts
let light_post_iter = light_post_iter.filter(|post| {
!(post.is_retweet && post.source_user_id == Some(request_user_id))
});
let filtered_post_iter = light_post_iter.filter(|post| {
// An empty following set means this is the original-posts pass; no reply filtering
if following_users.is_empty() {
return true;
}
// For the secondary pass, keep non-replies, replies to original posts, and
// second-level replies back to the conversation root that target a followed user
post.in_reply_to_post_id.is_none_or(|reply_to_post_id| {
if let Some(replied_to_post) = self.posts.get(&reply_to_post_id) {
if !replied_to_post.is_retweet && !replied_to_post.is_reply {
return true;
}
return post.conversation_id.is_some_and(|convo_id| {
let reply_to_reply_to_original =
replied_to_post.in_reply_to_post_id == Some(convo_id);
let reply_to_followed_user = post
.in_reply_to_user_id
.map(|uid| following_users.contains(&uid))
.unwrap_or(false);
reply_to_reply_to_original && reply_to_followed_user
});
}
false
})
});
light_posts.extend(filtered_post_iter.take(max_per_user));
}
}
// Track ratio of returned posts to eligible posts
if total_eligible > 0 {
let ratio = light_posts.len() as f64 / total_eligible as f64;
POST_STORE_POSTS_RETURNED_RATIO.observe(ratio);
}
light_posts
}
/// Start a background task that periodically logs PostStore statistics
pub fn start_stats_logger(self: Arc<Self>) {
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(5));
loop {
interval.tick().await;
let user_count = self.original_posts_by_user.len();
let total_posts = self.posts.len();
let deleted_posts = self.deleted_posts.len();
// Sum up all VecDeque sizes for each map
let original_posts_count: usize = self
.original_posts_by_user
.iter()
.map(|entry| entry.value().len())
.sum();
let secondary_posts_count: usize = self
.secondary_posts_by_user
.iter()
.map(|entry| entry.value().len())
.sum();
let video_posts_count: usize = self
.video_posts_by_user
.iter()
.map(|entry| entry.value().len())
.sum();
// Update Prometheus gauges
POST_STORE_USER_COUNT.set(user_count as f64);
POST_STORE_TOTAL_POSTS.set(total_posts as f64);
POST_STORE_DELETED_POSTS.set(deleted_posts as f64);
// Update entity count gauge with labels
POST_STORE_ENTITY_COUNT
.with_label_values(&["users"])
.set(user_count as f64);
POST_STORE_ENTITY_COUNT
.with_label_values(&["posts"])
.set(total_posts as f64);
POST_STORE_ENTITY_COUNT
.with_label_values(&["original_posts"])
.set(original_posts_count as f64);
POST_STORE_ENTITY_COUNT
.with_label_values(&["secondary_posts"])
.set(secondary_posts_count as f64);
POST_STORE_ENTITY_COUNT
.with_label_values(&["video_posts"])
.set(video_posts_count as f64);
POST_STORE_ENTITY_COUNT
.with_label_values(&["deleted_posts"])
.set(deleted_posts as f64);
info!(
"PostStore Stats: {} users, {} total posts, {} deleted posts",
user_count, total_posts, deleted_posts
);
}
});
}
/// Start a background task that periodically trims old posts
pub fn start_auto_trim(self: Arc<Self>, interval_minutes: u64) {
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(interval_minutes * 60));
loop {
interval.tick().await;
let trimmed = self.trim_old_posts().await;
if trimmed > 0 {
info!("Auto-trim: removed {} old posts", trimmed);
}
}
});
}
/// Trims posts older than the retention period from all users (called by the auto-trim task and at init)
/// Returns the number of posts trimmed
pub async fn trim_old_posts(&self) -> usize {
let posts_map = Arc::clone(&self.posts);
let original_posts_by_user = Arc::clone(&self.original_posts_by_user);
let secondary_posts_by_user = Arc::clone(&self.secondary_posts_by_user);
let video_posts_by_user = Arc::clone(&self.video_posts_by_user);
let deleted_posts = Arc::clone(&self.deleted_posts);
let retention_seconds = self.retention_seconds;
tokio::task::spawn_blocking(move || {
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs();
let mut total_trimmed = 0;
// Helper closure to trim posts from a given map
let trim_map = |posts_by_user: &DashMap<i64, VecDeque<TinyPost>>,
posts_map: &DashMap<i64, LightPost>,
deleted_posts: &DashMap<i64, bool>|
-> usize {
let mut trimmed = 0;
let mut users_to_remove = Vec::new();
for mut entry in posts_by_user.iter_mut() {
let user_id = *entry.key();
let user_posts = entry.value_mut();
while let Some(oldest_post) = user_posts.front() {
// saturating_sub guards against clock skew and future-dated timestamps
let age = current_time.saturating_sub(oldest_post.created_at as u64);
if age > retention_seconds {
let trimmed_post = user_posts.pop_front().unwrap();
posts_map.remove(&trimmed_post.post_id);
if user_id == DELETE_EVENT_KEY {
deleted_posts.remove(&trimmed_post.post_id);
}
trimmed += 1;
} else {
break;
}
}
// Reclaim memory once the deque is less than half full, keeping ~50% headroom
if user_posts.capacity() > user_posts.len() * 2 {
let new_cap = user_posts.len() as f32 * 1.5_f32;
user_posts.shrink_to(new_cap as usize);
}
if user_posts.is_empty() {
users_to_remove.push(user_id);
}
}
for user_id in users_to_remove {
posts_by_user.remove_if(&user_id, |_, posts| posts.is_empty());
}
trimmed
};
total_trimmed += trim_map(&original_posts_by_user, &posts_map, &deleted_posts);
total_trimmed += trim_map(&secondary_posts_by_user, &posts_map, &deleted_posts);
// Video timelines reference posts counted in the two maps above, so the
// return value is intentionally discarded to avoid double counting
trim_map(&video_posts_by_user, &posts_map, &deleted_posts);
total_trimmed
})
.await
.expect("spawn_blocking failed")
}
/// Sorts all user post lists by creation time (oldest first, matching insertion order; readers iterate in reverse to get newest first)
pub async fn sort_all_user_posts(&self) {
let original_posts_by_user = Arc::clone(&self.original_posts_by_user);
let secondary_posts_by_user = Arc::clone(&self.secondary_posts_by_user);
let video_posts_by_user = Arc::clone(&self.video_posts_by_user);
tokio::task::spawn_blocking(move || {
// Sort every timeline map ascending by created_at
for map in [
&original_posts_by_user,
&secondary_posts_by_user,
&video_posts_by_user,
] {
for mut entry in map.iter_mut() {
entry
.value_mut()
.make_contiguous()
.sort_unstable_by_key(|a| a.created_at);
}
}
})
.await
.expect("spawn_blocking failed");
}
/// Clears all posts from the store (note: deleted-post tombstones in deleted_posts are not cleared)
pub fn clear(&self) {
self.posts.clear();
self.original_posts_by_user.clear();
self.secondary_posts_by_user.clear();
self.video_posts_by_user.clear();
info!("PostStore cleared");
}
}
impl Default for PostStore {
fn default() -> Self {
// Default to 2 days retention, no timeout
Self::new(2 * 24 * 60 * 60, 0)
}
}
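// --- Editor's usage sketch (not part of the original file) ---
// A minimal round-trip test: insert one fresh original post, then fetch it
// back for its author. It assumes LightPost, as a prost-generated type,
// implements Default, so unset fields can be filled with `..Default::default()`.
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn insert_then_fetch_roundtrip() {
let store = PostStore::new(3600, 0); // 1 hour retention, no request timeout
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs() as i64;
store.insert_posts(vec![LightPost {
post_id: 1,
author_id: 42,
created_at: now - 60, // one minute old: inside the retention window
..Default::default()
}]);
let posts = store.get_all_posts_by_users(&[42], &HashSet::new(), Instant::now(), 7);
assert_eq!(posts.len(), 1);
assert_eq!(posts[0].post_id, 1);
}
}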

339
thunder/thunder_service.rs Normal file
View File

@ -0,0 +1,339 @@
use lazy_static::lazy_static;
use log::{debug, info, warn};
use std::cmp::Reverse;
use std::collections::HashSet;
use std::sync::Arc;
use std::time::{Instant, SystemTime, UNIX_EPOCH};
use tokio::sync::Semaphore;
use tonic::{Request, Response, Status};
use xai_thunder_proto::{
GetInNetworkPostsRequest, GetInNetworkPostsResponse, LightPost,
in_network_posts_service_server::{InNetworkPostsService, InNetworkPostsServiceServer},
};
use crate::config::{
MAX_INPUT_LIST_SIZE, MAX_POSTS_TO_RETURN, MAX_VIDEOS_TO_RETURN,
};
use crate::metrics::{
GET_IN_NETWORK_POSTS_COUNT, GET_IN_NETWORK_POSTS_DURATION,
GET_IN_NETWORK_POSTS_DURATION_WITHOUT_STRATO, GET_IN_NETWORK_POSTS_EXCLUDED_SIZE,
GET_IN_NETWORK_POSTS_FOLLOWING_SIZE, GET_IN_NETWORK_POSTS_FOUND_FRESHNESS_SECONDS,
GET_IN_NETWORK_POSTS_FOUND_POSTS_PER_AUTHOR, GET_IN_NETWORK_POSTS_FOUND_REPLY_RATIO,
GET_IN_NETWORK_POSTS_FOUND_TIME_RANGE_SECONDS, GET_IN_NETWORK_POSTS_FOUND_UNIQUE_AUTHORS,
GET_IN_NETWORK_POSTS_MAX_RESULTS, IN_FLIGHT_REQUESTS, REJECTED_REQUESTS, Timer,
};
use crate::posts::post_store::PostStore;
use crate::strato_client::StratoClient;
pub struct ThunderServiceImpl {
/// PostStore for retrieving posts by user ID
post_store: Arc<PostStore>,
/// StratoClient for fetching following lists when not provided
strato_client: Arc<StratoClient>,
/// Semaphore to limit concurrent requests and prevent overload
request_semaphore: Arc<Semaphore>,
}
impl ThunderServiceImpl {
pub fn new(
post_store: Arc<PostStore>,
strato_client: Arc<StratoClient>,
max_concurrent_requests: usize,
) -> Self {
info!(
"Initializing ThunderService with max_concurrent_requests={}",
max_concurrent_requests
);
Self {
post_store,
strato_client,
request_semaphore: Arc::new(Semaphore::new(max_concurrent_requests)),
}
}
/// Create a gRPC server for this service
pub fn server(self) -> InNetworkPostsServiceServer<Self> {
InNetworkPostsServiceServer::new(self)
.accept_compressed(tonic::codec::CompressionEncoding::Zstd)
.send_compressed(tonic::codec::CompressionEncoding::Zstd)
}
/// Analyze found posts, calculate statistics, and report metrics
/// The `stage` parameter is used as a label to differentiate between stages (e.g., "post_store", "scored")
fn analyze_and_report_post_statistics(posts: &[LightPost], stage: &str) {
if posts.is_empty() {
debug!("[{}] No posts found for analysis", stage);
return;
}
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs() as i64;
// Time since most recent post
let time_since_most_recent = posts
.iter()
.map(|post| post.created_at)
.max()
.map(|most_recent| now - most_recent);
// Time since oldest post
let time_since_oldest = posts
.iter()
.map(|post| post.created_at)
.min()
.map(|oldest| now - oldest);
// Count replies vs original posts
let reply_count = posts.iter().filter(|post| post.is_reply).count();
let original_count = posts.len() - reply_count;
// Unique authors
let unique_authors: HashSet<_> = posts.iter().map(|post| post.author_id).collect();
let unique_author_count = unique_authors.len();
// Report metrics with stage label
if let Some(freshness) = time_since_most_recent {
GET_IN_NETWORK_POSTS_FOUND_FRESHNESS_SECONDS
.with_label_values(&[stage])
.observe(freshness as f64);
}
if let (Some(oldest), Some(newest)) = (time_since_oldest, time_since_most_recent) {
let time_range = oldest - newest;
GET_IN_NETWORK_POSTS_FOUND_TIME_RANGE_SECONDS
.with_label_values(&[stage])
.observe(time_range as f64);
}
let reply_ratio = reply_count as f64 / posts.len() as f64;
GET_IN_NETWORK_POSTS_FOUND_REPLY_RATIO
.with_label_values(&[stage])
.observe(reply_ratio);
GET_IN_NETWORK_POSTS_FOUND_UNIQUE_AUTHORS
.with_label_values(&[stage])
.observe(unique_author_count as f64);
if unique_author_count > 0 {
let posts_per_author = posts.len() as f64 / unique_author_count as f64;
GET_IN_NETWORK_POSTS_FOUND_POSTS_PER_AUTHOR
.with_label_values(&[stage])
.observe(posts_per_author);
}
// Log statistics with stage label
debug!(
"[{}] Post statistics: total={}, original={}, replies={}, unique_authors={}, posts_per_author={:.2}, reply_ratio={:.2}, time_since_most_recent={:?}s, time_range={:?}s",
stage,
posts.len(),
original_count,
reply_count,
unique_author_count,
if unique_author_count > 0 {
posts.len() as f64 / unique_author_count as f64
} else {
0.0
},
reply_ratio,
time_since_most_recent,
if let (Some(o), Some(n)) = (time_since_oldest, time_since_most_recent) {
Some(o - n)
} else {
None
}
);
}
}
#[tonic::async_trait]
impl InNetworkPostsService for ThunderServiceImpl {
/// Get posts from users in the network
async fn get_in_network_posts(
&self,
request: Request<GetInNetworkPostsRequest>,
) -> Result<Response<GetInNetworkPostsResponse>, Status> {
// Try to acquire semaphore permit without blocking
// If we're at capacity, reject immediately with RESOURCE_EXHAUSTED
let _permit = match self.request_semaphore.try_acquire() {
Ok(permit) => {
IN_FLIGHT_REQUESTS.inc();
permit
}
Err(_) => {
REJECTED_REQUESTS.inc();
return Err(Status::resource_exhausted(
"Server at capacity, please retry",
));
}
};
// Use a guard to decrement in_flight_requests when the request completes
struct InFlightGuard;
impl Drop for InFlightGuard {
fn drop(&mut self) {
IN_FLIGHT_REQUESTS.dec();
}
}
let _in_flight_guard = InFlightGuard;
// Start timer for total latency
let _total_timer = Timer::new(GET_IN_NETWORK_POSTS_DURATION.clone());
let req = request.into_inner();
if req.debug {
info!(
"Received GetInNetworkPosts request: user_id={}, following_count={}, exclude_tweet_ids={}",
req.user_id,
req.following_user_ids.len(),
req.exclude_tweet_ids.len(),
);
}
// If the following list was not provided, fetch it from Strato (debug requests only)
let following_user_ids = if req.following_user_ids.is_empty() && req.debug {
info!(
"Following list is empty, fetching from Strato for user {}",
req.user_id
);
match self
.strato_client
.fetch_following_list(req.user_id as i64, MAX_INPUT_LIST_SIZE as i32)
.await
{
Ok(following_list) => {
info!(
"Fetched {} following users from Strato for user {}",
following_list.len(),
req.user_id
);
following_list.into_iter().map(|id| id as u64).collect()
}
Err(e) => {
warn!(
"Failed to fetch following list from Strato for user {}: {}",
req.user_id, e
);
return Err(Status::internal(format!(
"Failed to fetch following list: {}",
e
)));
}
}
} else {
req.following_user_ids
};
// Record metrics for request parameters
GET_IN_NETWORK_POSTS_FOLLOWING_SIZE.observe(following_user_ids.len() as f64);
GET_IN_NETWORK_POSTS_EXCLUDED_SIZE.observe(req.exclude_tweet_ids.len() as f64);
// Start timer for latency without strato call
let _processing_timer = Timer::new(GET_IN_NETWORK_POSTS_DURATION_WITHOUT_STRATO.clone());
// Default max_results if not specified
let max_results = if req.max_results > 0 {
req.max_results as usize
} else if req.is_video_request {
MAX_VIDEOS_TO_RETURN
} else {
MAX_POSTS_TO_RETURN
};
GET_IN_NETWORK_POSTS_MAX_RESULTS.observe(max_results as f64);
// Limit following_user_ids and exclude_tweet_ids to first K entries
let following_count = following_user_ids.len();
if following_count > MAX_INPUT_LIST_SIZE {
warn!(
"Limiting following_user_ids from {} to {} entries for user {}",
following_count, MAX_INPUT_LIST_SIZE, req.user_id
);
}
let following_user_ids: Vec<u64> = following_user_ids
.into_iter()
.take(MAX_INPUT_LIST_SIZE)
.collect();
let exclude_count = req.exclude_tweet_ids.len();
if exclude_count > MAX_INPUT_LIST_SIZE {
warn!(
"Limiting exclude_tweet_ids from {} to {} entries for user {}",
exclude_count, MAX_INPUT_LIST_SIZE, req.user_id
);
}
let exclude_tweet_ids: Vec<u64> = req
.exclude_tweet_ids
.into_iter()
.take(MAX_INPUT_LIST_SIZE)
.collect();
// Clone Arc references needed inside spawn_blocking
let post_store = Arc::clone(&self.post_store);
let request_user_id = req.user_id as i64;
// Use spawn_blocking to avoid blocking tokio's async runtime
let proto_posts = tokio::task::spawn_blocking(move || {
// Create exclude tweet IDs set for efficient filtering of previously seen posts
let exclude_tweet_ids: HashSet<i64> =
exclude_tweet_ids.iter().map(|&id| id as i64).collect();
// Convert the u64 IDs from the proto request to the i64 keys PostStore uses
let following_user_ids: Vec<i64> =
following_user_ids.iter().map(|&id| id as i64).collect();
let start_time = Instant::now();
// Fetch all posts (original + secondary) for the followed users
let all_posts: Vec<LightPost> = if req.is_video_request {
post_store.get_videos_by_users(
&following_user_ids,
&exclude_tweet_ids,
start_time,
request_user_id,
)
} else {
post_store.get_all_posts_by_users(
&following_user_ids,
&exclude_tweet_ids,
start_time,
request_user_id,
)
};
// Analyze posts and report statistics after querying post_store
ThunderServiceImpl::analyze_and_report_post_statistics(&all_posts, "retrieved");
let scored_posts = score_recent(all_posts, max_results);
// Analyze posts and report statistics after scoring
ThunderServiceImpl::analyze_and_report_post_statistics(&scored_posts, "scored");
scored_posts
})
.await
.map_err(|e| Status::internal(format!("Failed to process posts: {}", e)))?;
if req.debug {
info!(
"Returning {} posts for user {}",
proto_posts.len(),
req.user_id
);
}
// Record the number of posts returned
GET_IN_NETWORK_POSTS_COUNT.observe(proto_posts.len() as f64);
let response = GetInNetworkPostsResponse { posts: proto_posts };
Ok(Response::new(response))
}
}
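// --- Editor's sketch (not part of the original file) ---
// The handler above admits requests with try_acquire rather than acquire: at
// capacity it fails fast with RESOURCE_EXHAUSTED instead of queueing. A
// standalone illustration of that pattern; `try_admit` is a hypothetical
// helper, not an API of this service.
#[allow(dead_code)]
fn try_admit(semaphore: &Semaphore) -> Option<tokio::sync::SemaphorePermit<'_>> {
// Some(permit): the caller holds the permit for the request's lifetime and
// releases capacity on drop. None: reject immediately so the client can retry.
semaphore.try_acquire().ok()
}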
/// Score posts by recency (created_at timestamp, newer posts first)
fn score_recent(mut light_posts: Vec<LightPost>, max_results: usize) -> Vec<LightPost> {
light_posts.sort_unstable_by_key(|post| Reverse(post.created_at));
// Limit to max results
light_posts.into_iter().take(max_results).collect()
}
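// --- Editor's usage sketch (not part of the original file) ---
// Exercises score_recent's ordering and truncation; assumes LightPost
// implements Default (prost-generated types do).
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn score_recent_orders_newest_first_and_truncates() {
let posts: Vec<LightPost> = (1..=5)
.map(|i| LightPost {
post_id: i,
created_at: i, // post 5 is the newest
..Default::default()
})
.collect();
let scored = score_recent(posts, 3);
let ids: Vec<i64> = scored.iter().map(|p| p.post_id).collect();
assert_eq!(ids, vec![5, 4, 3]);
}
}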