Mirror of https://github.com/xai-org/x-algorithm.git (synced 2026-02-13 03:05:06 +01:00)
Open-source X Recommendation Algorithm

This commit is contained in: commit aaa167b3de

.gitignore (vendored, new file, 1 line)
@@ -0,0 +1 @@
__pycache__/

CODE_OF_CONDUCT.md (new file, 1 line)
@@ -0,0 +1 @@
Be excellent to each other.

LICENSE (new file, 201 lines)
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

README.md (new file, 325 lines)
@@ -0,0 +1,325 @@
# X For You Feed Algorithm

This repository contains the core recommendation system powering the "For You" feed on X. It combines in-network content (from accounts you follow) with out-of-network content (discovered through ML-based retrieval) and ranks everything using a Grok-based transformer model.

> **Note:** The transformer implementation is ported from the [Grok-1 open source release](https://github.com/xai-org/grok-1) by xAI, adapted for recommendation system use cases.

## Table of Contents

- [Overview](#overview)
- [System Architecture](#system-architecture)
- [Components](#components)
  - [Home Mixer](#home-mixer)
  - [Thunder](#thunder)
  - [Phoenix](#phoenix)
  - [Candidate Pipeline](#candidate-pipeline)
- [How It Works](#how-it-works)
  - [Pipeline Stages](#pipeline-stages)
  - [Scoring and Ranking](#scoring-and-ranking)
  - [Filtering](#filtering)
- [Key Design Decisions](#key-design-decisions)
- [License](#license)

---

## Overview

The For You feed algorithm retrieves, ranks, and filters posts from two sources:

1. **In-Network (Thunder)**: Posts from accounts you follow
2. **Out-of-Network (Phoenix Retrieval)**: Posts discovered from a global corpus

Both sources are combined and ranked together using **Phoenix**, a Grok-based transformer model that predicts engagement probabilities for each post. The final score is a weighted combination of these predicted engagements.

We have eliminated every single hand-engineered feature and most heuristics from the system. The Grok-based transformer does all the heavy lifting by understanding your engagement history (what you liked, replied to, shared, etc.) and using that to determine what content is relevant to you.

---

## System Architecture

```
FOR YOU FEED REQUEST
        │
        ▼
HOME MIXER (Orchestration Layer)
│
├─ QUERY HYDRATION
│    • User Action Sequence (engagement history)
│    • User Features (following list, preferences, etc.)
│
├─ CANDIDATE SOURCES
│    • THUNDER (In-Network Posts): posts from accounts you follow
│    • PHOENIX RETRIEVAL (Out-of-Network Posts): ML-based similarity search
│      across the global corpus
│
├─ HYDRATION
│    Fetch additional data: core post metadata, author info, media entities, etc.
│
├─ FILTERING
│    Remove: duplicates, old posts, self-posts, blocked authors, muted keywords, etc.
│
├─ SCORING
│    1. Phoenix Scorer (ML predictions): Grok-based transformer predicts
│       P(like), P(reply), P(repost), P(click), ...
│    2. Weighted Scorer (combine predictions): Weighted Score = Σ (weight × P(action))
│    3. Author Diversity Scorer: attenuate repeated author scores to ensure feed diversity
│
├─ SELECTION
│    Sort by final score, select top K candidates
│
└─ FILTERING (Post-Selection)
     Visibility filtering (deleted/spam/violence/gore, etc.)
        │
        ▼
RANKED FEED RESPONSE
```

---

## Components

### Home Mixer

**Location:** [`home-mixer/`](home-mixer/)

The orchestration layer that assembles the For You feed. It leverages the `CandidatePipeline` framework with the following stages:

| Stage | Description |
|-------|-------------|
| Query Hydrators | Fetch user context (engagement history, following list) |
| Sources | Retrieve candidates from Thunder and Phoenix |
| Hydrators | Enrich candidates with additional data |
| Filters | Remove ineligible candidates |
| Scorers | Predict engagement and compute final scores |
| Selector | Sort by score and select top K |
| Post-Selection Filters | Final visibility and dedup checks |
| Side Effects | Cache request info for future use |

The server exposes a gRPC endpoint (`ScoredPostsService`) that returns ranked posts for a given user.

---

### Thunder

**Location:** [`thunder/`](thunder/)

An in-memory post store and realtime ingestion pipeline that tracks recent posts from all users. It:

- Consumes post create/delete events from Kafka
- Maintains per-user stores for original posts, replies/reposts, and video posts
- Serves "in-network" post candidates from accounts the requesting user follows
- Automatically trims posts older than the retention period

Thunder enables sub-millisecond lookups for in-network content without hitting an external database.
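
A minimal sketch of this idea, assuming hypothetical `StoredPost`/`InMemoryPostStore` types and a single retention window; the real Thunder store, its Kafka ingestion path, and its separate per-category stores live in [`thunder/`](thunder/) and differ in detail:

```rust
use std::collections::{HashMap, VecDeque};
use std::time::{Duration, SystemTime};

type UserId = u64;
type PostId = u64;

// Hypothetical stored post: just enough to illustrate the idea.
struct StoredPost {
    post_id: PostId,
    created_at: SystemTime,
}

// Per-author queue of recent posts, newest at the back.
struct InMemoryPostStore {
    posts_by_author: HashMap<UserId, VecDeque<StoredPost>>,
    retention: Duration,
}

impl InMemoryPostStore {
    // Called from the ingestion path on a post-create event.
    fn insert(&mut self, author_id: UserId, post: StoredPost) {
        self.posts_by_author.entry(author_id).or_default().push_back(post);
    }

    // Drop posts older than the retention period.
    fn trim(&mut self, now: SystemTime) {
        for posts in self.posts_by_author.values_mut() {
            while let Some(front) = posts.front() {
                let age = now.duration_since(front.created_at).unwrap_or_default();
                if age > self.retention {
                    posts.pop_front();
                } else {
                    break;
                }
            }
        }
    }

    // Serve in-network candidates: recent posts from accounts the viewer follows.
    fn in_network_posts(&self, followed: &[UserId]) -> Vec<PostId> {
        followed
            .iter()
            .filter_map(|author| self.posts_by_author.get(author))
            .flat_map(|posts| posts.iter().map(|p| p.post_id))
            .collect()
    }
}
```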

---

### Phoenix

**Location:** [`phoenix/`](phoenix/)

The ML component with two main functions:

#### 1. Retrieval (Two-Tower Model)

Finds relevant out-of-network posts:

- **User Tower**: Encodes user features and engagement history into an embedding
- **Candidate Tower**: Encodes all posts into embeddings
- **Similarity Search**: Retrieves top-K posts via dot product similarity
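
To illustrate the similarity-search step, here is a brute-force sketch of top-K retrieval by dot product; the function and type names are illustrative, and production retrieval would use an approximate nearest-neighbor index rather than scanning the whole corpus:

```rust
// Score every candidate embedding against the user embedding and keep the top K.
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

fn top_k_by_dot_product(
    user_embedding: &[f32],
    candidate_embeddings: &[(u64, Vec<f32>)], // (post_id, embedding) pairs
    k: usize,
) -> Vec<(u64, f32)> {
    let mut scored: Vec<(u64, f32)> = candidate_embeddings
        .iter()
        .map(|(post_id, emb)| (*post_id, dot(user_embedding, emb)))
        .collect();
    // Sort by similarity, highest first, then keep the top K.
    scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    scored.truncate(k);
    scored
}
```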

#### 2. Ranking (Transformer with Candidate Isolation)

Predicts engagement probabilities for each candidate:

- Takes user context (engagement history) and candidate posts as input
- Uses special attention masking so candidates cannot attend to each other
- Outputs probabilities for each action type (like, reply, repost, click, etc.)

See [`phoenix/README.md`](phoenix/README.md) for detailed architecture documentation.

---

### Candidate Pipeline

**Location:** [`candidate-pipeline/`](candidate-pipeline/)

A reusable framework for building recommendation pipelines. Defines traits for:

| Trait | Purpose |
|-------|---------|
| `Source` | Fetch candidates from a data source |
| `Hydrator` | Enrich candidates with additional features |
| `Filter` | Remove candidates that shouldn't be shown |
| `Scorer` | Compute scores for ranking |
| `Selector` | Sort and select top candidates |
| `SideEffect` | Run async side effects (caching, logging) |

The framework runs sources and hydrators in parallel where possible, with configurable error handling and logging.
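
To give a feel for the framework, here is a hypothetical filter written against the `Filter` trait defined in `candidate-pipeline/filter.rs` (shown later in this commit); the `Query`/`Candidate` types and the `age_seconds` field are illustrative, not the real home-mixer types:

```rust
use tonic::async_trait;
use xai_candidate_pipeline::filter::{Filter, FilterResult};

// Illustrative query/candidate types; the real ones live in home-mixer.
#[derive(Clone)]
struct Query;

#[derive(Clone)]
struct Candidate {
    age_seconds: u64,
}

// Hypothetical filter: drop candidates older than a configured maximum age.
struct MaxAgeFilter {
    max_age_seconds: u64,
}

#[async_trait]
impl Filter<Query, Candidate> for MaxAgeFilter {
    async fn filter(
        &self,
        _query: &Query,
        candidates: Vec<Candidate>,
    ) -> Result<FilterResult<Candidate>, String> {
        // Partition into kept (young enough) and removed (too old).
        let (kept, removed): (Vec<_>, Vec<_>) = candidates
            .into_iter()
            .partition(|c| c.age_seconds <= self.max_age_seconds);
        Ok(FilterResult { kept, removed })
    }
}
```

Only `filter` has to be provided; `enable` and `name` come from the trait's default implementations.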

---

## How It Works

### Pipeline Stages

1. **Query Hydration**: Fetch the user's recent engagement history and metadata (e.g. following list)

2. **Candidate Sourcing**: Retrieve candidates from:
   - **Thunder**: Recent posts from followed accounts (in-network)
   - **Phoenix Retrieval**: ML-discovered posts from the global corpus (out-of-network)

3. **Candidate Hydration**: Enrich candidates with:
   - Core post data (text, media, etc.)
   - Author information (username, verification status)
   - Video duration (for video posts)
   - Subscription status

4. **Pre-Scoring Filters**: Remove posts that are:
   - Duplicates
   - Too old
   - From the viewer themselves
   - From blocked/muted accounts
   - Containing muted keywords
   - Previously seen or recently served
   - Ineligible subscription content

5. **Scoring**: Apply multiple scorers sequentially:
   - **Phoenix Scorer**: Get ML predictions from the Phoenix transformer model
   - **Weighted Scorer**: Combine predictions into a final relevance score
   - **Author Diversity Scorer**: Attenuate repeated author scores for diversity
   - **OON Scorer**: Adjust scores for out-of-network content

6. **Selection**: Sort by score and select the top K candidates

7. **Post-Selection Processing**: Final validation of post candidates to be served

---

### Scoring and Ranking

The Phoenix Grok-based transformer model predicts probabilities for multiple engagement types:

```
Predictions:
├── P(favorite)
├── P(reply)
├── P(repost)
├── P(quote)
├── P(click)
├── P(profile_click)
├── P(video_view)
├── P(photo_expand)
├── P(share)
├── P(dwell)
├── P(follow_author)
├── P(not_interested)
├── P(block_author)
├── P(mute_author)
└── P(report)
```

The **Weighted Scorer** combines these into a final score:

```
Final Score = Σ (weight_i × P(action_i))
```

Positive actions (like, repost, share) have positive weights. Negative actions (block, mute, report) have negative weights, pushing down content the user would likely dislike.
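
A minimal sketch of that combination; the action names and weight values below are illustrative only:

```rust
use std::collections::HashMap;

// One predicted engagement probability for a candidate post.
struct ActionPrediction {
    action: &'static str,
    probability: f64,
}

// Final Score = Σ (weight_i × P(action_i)); unknown actions get weight 0.
fn weighted_score(predictions: &[ActionPrediction], weights: &HashMap<&str, f64>) -> f64 {
    predictions
        .iter()
        .map(|p| weights.get(p.action).copied().unwrap_or(0.0) * p.probability)
        .sum()
}

fn example() -> f64 {
    // Illustrative weights: positive for engagement, negative for report.
    let weights = HashMap::from([("favorite", 1.0), ("reply", 2.0), ("report", -10.0)]);
    let predictions = vec![
        ActionPrediction { action: "favorite", probability: 0.12 },
        ActionPrediction { action: "reply", probability: 0.03 },
        ActionPrediction { action: "report", probability: 0.001 },
    ];
    // 1.0*0.12 + 2.0*0.03 + (-10.0)*0.001 = 0.17
    weighted_score(&predictions, &weights)
}
```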

---

### Filtering

Filters run at two stages:

**Pre-Scoring Filters:**

| Filter | Purpose |
|--------|---------|
| `DropDuplicatesFilter` | Remove duplicate post IDs |
| `CoreDataHydrationFilter` | Remove posts that failed to hydrate core metadata |
| `AgeFilter` | Remove posts older than threshold |
| `SelfpostFilter` | Remove user's own posts |
| `RepostDeduplicationFilter` | Dedupe reposts of the same content |
| `IneligibleSubscriptionFilter` | Remove paywalled content the user can't access |
| `PreviouslySeenPostsFilter` | Remove posts the user has already seen |
| `PreviouslyServedPostsFilter` | Remove posts already served in the session |
| `MutedKeywordFilter` | Remove posts with the user's muted keywords |
| `AuthorSocialgraphFilter` | Remove posts from blocked/muted authors |

**Post-Selection Filters:**

| Filter | Purpose |
|--------|---------|
| `VFFilter` | Remove posts that are deleted/spam/violence/gore, etc. |
| `DedupConversationFilter` | Deduplicate multiple branches of the same conversation thread |

---

## Key Design Decisions

### 1. No Hand-Engineered Features

The system relies entirely on the Grok-based transformer to learn relevance from user engagement sequences. There is no manual feature engineering for content relevance, which significantly reduces the complexity of our data pipelines and serving infrastructure.

### 2. Candidate Isolation in Ranking

During transformer inference, candidates cannot attend to each other; they attend only to the user context. This ensures the score for a post doesn't depend on which other posts are in the batch, making scores consistent and cacheable.
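
A small sketch of what such a mask can look like, assuming a sequence laid out as user-context tokens followed by candidate tokens, with causal attention inside the context; the actual masking scheme lives in the Phoenix model code and may differ:

```rust
// Build a boolean attention mask for [context tokens ... | candidate tokens ...].
// mask[q][k] == true means query position q may attend to key position k.
// Candidates see the full user context and themselves, but never each other.
fn candidate_isolation_mask(context_len: usize, num_candidates: usize) -> Vec<Vec<bool>> {
    let total = context_len + num_candidates;
    let mut mask = vec![vec![false; total]; total];
    for q in 0..total {
        for k in 0..total {
            mask[q][k] = if q < context_len {
                k <= q // causal attention within the user context (assumption)
            } else {
                k < context_len || k == q // candidate attends to context + itself only
            };
        }
    }
    mask
}
```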

### 3. Hash-Based Embeddings

Both retrieval and ranking use multiple hash functions for embedding lookup.
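
A minimal sketch of the idea, with an illustrative mixing hash and table layout; the real hash family, table sizes, and pooling are implementation details of Phoenix:

```rust
// Map an unbounded ID space (users, posts) into a fixed-size embedding table by
// hashing the ID with several seeds and summing the rows. IDs that collide under
// one hash are unlikely to collide under all of them, softening collision impact.
fn mix_hash(id: u64, seed: u64) -> u64 {
    // Simple 64-bit mixer; stands in for whatever hash family is actually used.
    let mut x = id ^ seed.wrapping_mul(0x9E37_79B9_7F4A_7C15);
    x ^= x >> 33;
    x = x.wrapping_mul(0xFF51_AFD7_ED55_8CCD);
    x ^= x >> 33;
    x
}

// `table` is a non-empty embedding table of rows with equal dimension.
fn hashed_embedding(id: u64, table: &[Vec<f32>], seeds: &[u64]) -> Vec<f32> {
    let dim = table[0].len();
    let mut out = vec![0.0f32; dim];
    for &seed in seeds {
        let row = (mix_hash(id, seed) as usize) % table.len();
        for (o, v) in out.iter_mut().zip(&table[row]) {
            *o += v;
        }
    }
    out
}
```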

### 4. Multi-Action Prediction

Rather than predicting a single "relevance" score, the model predicts probabilities for many actions.

### 5. Composable Pipeline Architecture

The `candidate-pipeline` crate provides a flexible framework for building recommendation pipelines with:

- Separation of pipeline execution and monitoring from business logic
- Parallel execution of independent stages and graceful error handling
- Easy addition of new sources, hydrators, filters, and scorers

---

## License

This project is licensed under the Apache License 2.0. See [LICENSE](LICENSE) for details.

candidate-pipeline/candidate_pipeline.rs (new file, 329 lines)
@@ -0,0 +1,329 @@
use crate::filter::Filter;
use crate::hydrator::Hydrator;
use crate::query_hydrator::QueryHydrator;
use crate::scorer::Scorer;
use crate::selector::Selector;
use crate::side_effect::{SideEffect, SideEffectInput};
use crate::source::Source;
use futures::future::join_all;
use log::{error, info, warn};
use std::sync::Arc;
use tonic::async_trait;

#[derive(Copy, Clone, Debug)]
pub enum PipelineStage {
    QueryHydrator,
    Source,
    Hydrator,
    PostSelectionHydrator,
    Filter,
    PostSelectionFilter,
    Scorer,
}

pub struct PipelineResult<Q, C> {
    pub retrieved_candidates: Vec<C>,
    pub filtered_candidates: Vec<C>,
    pub selected_candidates: Vec<C>,
    pub query: Arc<Q>,
}

/// Provides a stable request identifier for logging/tracing.
pub trait HasRequestId {
    fn request_id(&self) -> &str;
}

#[async_trait]
pub trait CandidatePipeline<Q, C>: Send + Sync
where
    Q: HasRequestId + Clone + Send + Sync + 'static,
    C: Clone + Send + Sync + 'static,
{
    fn query_hydrators(&self) -> &[Box<dyn QueryHydrator<Q>>];
    fn sources(&self) -> &[Box<dyn Source<Q, C>>];
    fn hydrators(&self) -> &[Box<dyn Hydrator<Q, C>>];
    fn filters(&self) -> &[Box<dyn Filter<Q, C>>];
    fn scorers(&self) -> &[Box<dyn Scorer<Q, C>>];
    fn selector(&self) -> &dyn Selector<Q, C>;
    fn post_selection_hydrators(&self) -> &[Box<dyn Hydrator<Q, C>>];
    fn post_selection_filters(&self) -> &[Box<dyn Filter<Q, C>>];
    fn side_effects(&self) -> Arc<Vec<Box<dyn SideEffect<Q, C>>>>;
    fn result_size(&self) -> usize;

    async fn execute(&self, query: Q) -> PipelineResult<Q, C> {
        let hydrated_query = self.hydrate_query(query).await;

        let candidates = self.fetch_candidates(&hydrated_query).await;

        let hydrated_candidates = self.hydrate(&hydrated_query, candidates).await;

        let (kept_candidates, mut filtered_candidates) = self
            .filter(&hydrated_query, hydrated_candidates.clone())
            .await;

        let scored_candidates = self.score(&hydrated_query, kept_candidates).await;

        let selected_candidates = self.select(&hydrated_query, scored_candidates);

        let post_selection_hydrated_candidates = self
            .hydrate_post_selection(&hydrated_query, selected_candidates)
            .await;

        let (mut final_candidates, post_selection_filtered_candidates) = self
            .filter_post_selection(&hydrated_query, post_selection_hydrated_candidates)
            .await;
        filtered_candidates.extend(post_selection_filtered_candidates);

        final_candidates.truncate(self.result_size());

        let arc_hydrated_query = Arc::new(hydrated_query);
        let input = Arc::new(SideEffectInput {
            query: arc_hydrated_query.clone(),
            selected_candidates: final_candidates.clone(),
        });
        self.run_side_effects(input);

        PipelineResult {
            retrieved_candidates: hydrated_candidates,
            filtered_candidates,
            selected_candidates: final_candidates,
            query: arc_hydrated_query,
        }
    }

    /// Run all query hydrators in parallel and merge results into the query.
    async fn hydrate_query(&self, query: Q) -> Q {
        let request_id = query.request_id().to_string();
        let hydrators: Vec<_> = self
            .query_hydrators()
            .iter()
            .filter(|h| h.enable(&query))
            .collect();
        let hydrate_futures = hydrators.iter().map(|h| h.hydrate(&query));
        let results = join_all(hydrate_futures).await;

        let mut hydrated_query = query;
        for (hydrator, result) in hydrators.iter().zip(results) {
            match result {
                Ok(hydrated) => {
                    hydrator.update(&mut hydrated_query, hydrated);
                }
                Err(err) => {
                    error!(
                        "request_id={} stage={:?} component={} failed: {}",
                        request_id,
                        PipelineStage::QueryHydrator,
                        hydrator.name(),
                        err
                    );
                }
            }
        }
        hydrated_query
    }

    /// Run all candidate sources in parallel and collect results.
    async fn fetch_candidates(&self, query: &Q) -> Vec<C> {
        let request_id = query.request_id().to_string();
        let sources: Vec<_> = self.sources().iter().filter(|s| s.enable(query)).collect();
        let source_futures = sources.iter().map(|s| s.get_candidates(query));
        let results = join_all(source_futures).await;

        let mut collected = Vec::new();
        for (source, result) in sources.iter().zip(results) {
            match result {
                Ok(mut candidates) => {
                    info!(
                        "request_id={} stage={:?} component={} fetched {} candidates",
                        request_id,
                        PipelineStage::Source,
                        source.name(),
                        candidates.len()
                    );
                    collected.append(&mut candidates);
                }
                Err(err) => {
                    error!(
                        "request_id={} stage={:?} component={} failed: {}",
                        request_id,
                        PipelineStage::Source,
                        source.name(),
                        err
                    );
                }
            }
        }
        collected
    }

    /// Run all candidate hydrators in parallel and merge results into candidates.
    async fn hydrate(&self, query: &Q, candidates: Vec<C>) -> Vec<C> {
        self.run_hydrators(query, candidates, self.hydrators(), PipelineStage::Hydrator)
            .await
    }

    /// Run post-selection candidate hydrators in parallel and merge results into candidates.
    async fn hydrate_post_selection(&self, query: &Q, candidates: Vec<C>) -> Vec<C> {
        self.run_hydrators(
            query,
            candidates,
            self.post_selection_hydrators(),
            PipelineStage::PostSelectionHydrator,
        )
        .await
    }

    /// Shared helper to hydrate with a provided hydrator list.
    async fn run_hydrators(
        &self,
        query: &Q,
        mut candidates: Vec<C>,
        hydrators: &[Box<dyn Hydrator<Q, C>>],
        stage: PipelineStage,
    ) -> Vec<C> {
        let request_id = query.request_id().to_string();
        let hydrators: Vec<_> = hydrators.iter().filter(|h| h.enable(query)).collect();
        let expected_len = candidates.len();
        let hydrate_futures = hydrators.iter().map(|h| h.hydrate(query, &candidates));
        let results = join_all(hydrate_futures).await;
        for (hydrator, result) in hydrators.iter().zip(results) {
            match result {
                Ok(hydrated) => {
                    if hydrated.len() == expected_len {
                        hydrator.update_all(&mut candidates, hydrated);
                    } else {
                        warn!(
                            "request_id={} stage={:?} component={} skipped: length_mismatch expected={} got={}",
                            request_id,
                            stage,
                            hydrator.name(),
                            expected_len,
                            hydrated.len()
                        );
                    }
                }
                Err(err) => {
                    error!(
                        "request_id={} stage={:?} component={} failed: {}",
                        request_id,
                        stage,
                        hydrator.name(),
                        err
                    );
                }
            }
        }
        candidates
    }

    /// Run all filters sequentially. Each filter partitions candidates into kept and removed.
    async fn filter(&self, query: &Q, candidates: Vec<C>) -> (Vec<C>, Vec<C>) {
        self.run_filters(query, candidates, self.filters(), PipelineStage::Filter)
            .await
    }

    /// Run post-scoring filters sequentially on already-scored candidates.
    async fn filter_post_selection(&self, query: &Q, candidates: Vec<C>) -> (Vec<C>, Vec<C>) {
        self.run_filters(
            query,
            candidates,
            self.post_selection_filters(),
            PipelineStage::PostSelectionFilter,
        )
        .await
    }

    // Shared helper to run filters sequentially from a provided filter list.
    async fn run_filters(
        &self,
        query: &Q,
        mut candidates: Vec<C>,
        filters: &[Box<dyn Filter<Q, C>>],
        stage: PipelineStage,
    ) -> (Vec<C>, Vec<C>) {
        let request_id = query.request_id().to_string();
        let mut all_removed = Vec::new();
        for filter in filters.iter().filter(|f| f.enable(query)) {
            let backup = candidates.clone();
            match filter.filter(query, candidates).await {
                Ok(result) => {
                    candidates = result.kept;
                    all_removed.extend(result.removed);
                }
                Err(err) => {
                    error!(
                        "request_id={} stage={:?} component={} failed: {}",
                        request_id,
                        stage,
                        filter.name(),
                        err
                    );
                    candidates = backup;
                }
            }
        }
        info!(
            "request_id={} stage={:?} kept {}, removed {}",
            request_id,
            stage,
            candidates.len(),
            all_removed.len()
        );
        (candidates, all_removed)
    }

    /// Run all scorers sequentially and apply their results to candidates.
    async fn score(&self, query: &Q, mut candidates: Vec<C>) -> Vec<C> {
        let request_id = query.request_id().to_string();
        let expected_len = candidates.len();
        for scorer in self.scorers().iter().filter(|s| s.enable(query)) {
            match scorer.score(query, &candidates).await {
                Ok(scored) => {
                    if scored.len() == expected_len {
                        scorer.update_all(&mut candidates, scored);
                    } else {
                        warn!(
                            "request_id={} stage={:?} component={} skipped: length_mismatch expected={} got={}",
                            request_id,
                            PipelineStage::Scorer,
                            scorer.name(),
                            expected_len,
                            scored.len()
                        );
                    }
                }
                Err(err) => {
                    error!(
                        "request_id={} stage={:?} component={} failed: {}",
                        request_id,
                        PipelineStage::Scorer,
                        scorer.name(),
                        err
                    );
                }
            }
        }
        candidates
    }

    /// Select (sort/truncate) candidates using the configured selector
    fn select(&self, query: &Q, candidates: Vec<C>) -> Vec<C> {
        if self.selector().enable(query) {
            self.selector().select(query, candidates)
        } else {
            candidates
        }
    }

    // Run all side effects in parallel
    fn run_side_effects(&self, input: Arc<SideEffectInput<Q, C>>) {
        let side_effects = self.side_effects();
        tokio::spawn(async move {
            let futures = side_effects
                .iter()
                .filter(|se| se.enable(input.query.clone()))
                .map(|se| se.run(input.clone()));
            let _ = join_all(futures).await;
        });
    }
}

candidate-pipeline/filter.rs (new file, 32 lines)
@@ -0,0 +1,32 @@
use std::any::{Any, type_name_of_val};
use tonic::async_trait;

use crate::util;

pub struct FilterResult<C> {
    pub kept: Vec<C>,
    pub removed: Vec<C>,
}

/// Filters run sequentially and partition candidates into kept and removed sets
#[async_trait]
pub trait Filter<Q, C>: Any + Send + Sync
where
    Q: Clone + Send + Sync + 'static,
    C: Clone + Send + Sync + 'static,
{
    /// Decide if this filter should run for the given query
    fn enable(&self, _query: &Q) -> bool {
        true
    }

    /// Filter candidates by evaluating each against some criteria.
    /// Returns a FilterResult containing kept candidates (which continue to the next stage)
    /// and removed candidates (which are excluded from further processing).
    async fn filter(&self, query: &Q, candidates: Vec<C>) -> Result<FilterResult<C>, String>;

    /// Returns a stable name for logging/metrics.
    fn name(&self) -> &'static str {
        util::short_type_name(type_name_of_val(self))
    }
}

candidate-pipeline/hydrator.rs (new file, 39 lines)
@@ -0,0 +1,39 @@
use crate::util;
use std::any::{Any, type_name_of_val};
use tonic::async_trait;

// Hydrators run in parallel and update candidate fields
#[async_trait]
pub trait Hydrator<Q, C>: Any + Send + Sync
where
    Q: Clone + Send + Sync + 'static,
    C: Clone + Send + Sync + 'static,
{
    /// Decide if this hydrator should run for the given query
    fn enable(&self, _query: &Q) -> bool {
        true
    }

    /// Hydrate candidates by performing async operations.
    /// Returns candidates with this hydrator's fields populated.
    ///
    /// IMPORTANT: The returned vector must have the same candidates in the same order as the input.
    /// Dropping candidates in a hydrator is not allowed - use a filter stage instead.
    async fn hydrate(&self, query: &Q, candidates: &[C]) -> Result<Vec<C>, String>;

    /// Update a single candidate with the hydrated fields.
    /// Only the fields this hydrator is responsible for should be copied.
    fn update(&self, candidate: &mut C, hydrated: C);

    /// Update all candidates with the hydrated fields from `hydrated`.
    /// Default implementation iterates and calls `update` for each pair.
    fn update_all(&self, candidates: &mut [C], hydrated: Vec<C>) {
        for (c, h) in candidates.iter_mut().zip(hydrated) {
            self.update(c, h);
        }
    }

    fn name(&self) -> &'static str {
        util::short_type_name(type_name_of_val(self))
    }
}

candidate-pipeline/lib.rs (new file, 9 lines)
@@ -0,0 +1,9 @@
pub mod candidate_pipeline;
pub mod filter;
pub mod hydrator;
pub mod query_hydrator;
pub mod scorer;
pub mod selector;
pub mod side_effect;
pub mod source;
pub mod util;

candidate-pipeline/query_hydrator.rs (new file, 27 lines)
@@ -0,0 +1,27 @@
use std::any::{Any, type_name_of_val};
use tonic::async_trait;

use crate::util;

#[async_trait]
pub trait QueryHydrator<Q>: Any + Send + Sync
where
    Q: Clone + Send + Sync + 'static,
{
    /// Decide if this query hydrator should run for the given query
    fn enable(&self, _query: &Q) -> bool {
        true
    }

    /// Hydrate the query by performing async operations.
    /// Returns a new query with this hydrator's fields populated.
    async fn hydrate(&self, query: &Q) -> Result<Q, String>;

    /// Update the query with the hydrated fields.
    /// Only the fields this hydrator is responsible for should be copied.
    fn update(&self, query: &mut Q, hydrated: Q);

    fn name(&self) -> &'static str {
        util::short_type_name(type_name_of_val(self))
    }
}

candidate-pipeline/scorer.rs (new file, 39 lines)
@@ -0,0 +1,39 @@
use crate::util;
use std::any::type_name_of_val;
use tonic::async_trait;

/// Scorers update candidate fields (like a score field) and run sequentially
#[async_trait]
pub trait Scorer<Q, C>: Send + Sync
where
    Q: Clone + Send + Sync + 'static,
    C: Clone + Send + Sync + 'static,
{
    /// Decide if this scorer should run for the given query
    fn enable(&self, _query: &Q) -> bool {
        true
    }

    /// Score candidates by performing async operations.
    /// Returns candidates with this scorer's fields populated.
    ///
    /// IMPORTANT: The returned vector must have the same candidates in the same order as the input.
    /// Dropping candidates in a scorer is not allowed - use a filter stage instead.
    async fn score(&self, query: &Q, candidates: &[C]) -> Result<Vec<C>, String>;

    /// Update a single candidate with the scored fields.
    /// Only the fields this scorer is responsible for should be copied.
    fn update(&self, candidate: &mut C, scored: C);

    /// Update all candidates with the scored fields from `scored`.
    /// Default implementation iterates and calls `update` for each pair.
    fn update_all(&self, candidates: &mut [C], scored: Vec<C>) {
        for (c, s) in candidates.iter_mut().zip(scored) {
            self.update(c, s);
        }
    }

    fn name(&self) -> &'static str {
        util::short_type_name(type_name_of_val(self))
    }
}

candidate-pipeline/selector.rs (new file, 45 lines)
@@ -0,0 +1,45 @@
use crate::util;
use std::any::type_name_of_val;

pub trait Selector<Q, C>: Send + Sync
where
    Q: Clone + Send + Sync + 'static,
    C: Clone + Send + Sync + 'static,
{
    /// Default selection: sort and truncate based on provided configs
    fn select(&self, _query: &Q, candidates: Vec<C>) -> Vec<C> {
        let mut sorted = self.sort(candidates);
        if let Some(limit) = self.size() {
            sorted.truncate(limit);
        }
        sorted
    }

    /// Decide if this selector should run for the given query
    fn enable(&self, _query: &Q) -> bool {
        true
    }

    /// Extract the score from a candidate to use for sorting.
    fn score(&self, candidate: &C) -> f64;

    /// Sort candidates by their scores in descending order.
    fn sort(&self, candidates: Vec<C>) -> Vec<C> {
        let mut sorted = candidates;
        sorted.sort_by(|a, b| {
            self.score(b)
                .partial_cmp(&self.score(a))
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        sorted
    }

    /// Optionally provide a size to select. Defaults to no truncation if not overridden.
    fn size(&self) -> Option<usize> {
        None
    }

    fn name(&self) -> &'static str {
        util::short_type_name(type_name_of_val(self))
    }
}
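
As a usage sketch, a concrete selector only has to provide a score extractor (and optionally a result size); the query/candidate types below are illustrative, not the real home-mixer types:

```rust
use xai_candidate_pipeline::selector::Selector;

#[derive(Clone)]
struct Query;

#[derive(Clone)]
struct ScoredCandidate {
    score: f64,
}

// Sorts candidates by their final score and keeps the top 100.
struct TopKSelector;

impl Selector<Query, ScoredCandidate> for TopKSelector {
    fn score(&self, candidate: &ScoredCandidate) -> f64 {
        candidate.score
    }

    fn size(&self) -> Option<usize> {
        Some(100) // illustrative limit; select() sorts then truncates to this
    }
}
```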

candidate-pipeline/side_effect.rs (new file, 29 lines)
@@ -0,0 +1,29 @@
use crate::util;
use std::any::type_name_of_val;
use std::sync::Arc;
use tonic::async_trait;

// A side-effect is an action that runs without affecting the pipeline result being returned
#[derive(Clone)]
pub struct SideEffectInput<Q, C> {
    pub query: Arc<Q>,
    pub selected_candidates: Vec<C>,
}

#[async_trait]
pub trait SideEffect<Q, C>: Send + Sync
where
    Q: Clone + Send + Sync + 'static,
    C: Clone + Send + Sync + 'static,
{
    /// Decide if this side-effect should be run
    fn enable(&self, _query: Arc<Q>) -> bool {
        true
    }

    async fn run(&self, input: Arc<SideEffectInput<Q, C>>) -> Result<(), String>;

    fn name(&self) -> &'static str {
        util::short_type_name(type_name_of_val(self))
    }
}

candidate-pipeline/source.rs (new file, 22 lines)
@@ -0,0 +1,22 @@
use std::any::{Any, type_name_of_val};
use tonic::async_trait;

use crate::util;

#[async_trait]
pub trait Source<Q, C>: Any + Send + Sync
where
    Q: Clone + Send + Sync + 'static,
    C: Clone + Send + Sync + 'static,
{
    /// Decide if this source should run for the given query
    fn enable(&self, _query: &Q) -> bool {
        true
    }

    async fn get_candidates(&self, query: &Q) -> Result<Vec<C>, String>;

    fn name(&self) -> &'static str {
        util::short_type_name(type_name_of_val(self))
    }
}
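
As a usage sketch, a minimal source only implements `get_candidates`; the fixed-list source below is illustrative (real sources call Thunder or Phoenix retrieval):

```rust
use tonic::async_trait;
use xai_candidate_pipeline::source::Source;

#[derive(Clone)]
struct Query;

#[derive(Clone)]
struct Candidate {
    post_id: u64,
}

// Hypothetical source that returns a fixed candidate list, e.g. for tests.
struct FixedSource {
    posts: Vec<Candidate>,
}

#[async_trait]
impl Source<Query, Candidate> for FixedSource {
    async fn get_candidates(&self, _query: &Query) -> Result<Vec<Candidate>, String> {
        Ok(self.posts.clone())
    }
}
```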

home-mixer/candidate_hydrators/core_data_candidate_hydrator.rs (new file, 58 lines)
@@ -0,0 +1,58 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::clients::tweet_entity_service_client::TESClient;
use std::sync::Arc;
use tonic::async_trait;
use xai_candidate_pipeline::hydrator::Hydrator;

pub struct CoreDataCandidateHydrator {
    pub tes_client: Arc<dyn TESClient + Send + Sync>,
}

impl CoreDataCandidateHydrator {
    pub async fn new(tes_client: Arc<dyn TESClient + Send + Sync>) -> Self {
        Self { tes_client }
    }
}

#[async_trait]
impl Hydrator<ScoredPostsQuery, PostCandidate> for CoreDataCandidateHydrator {
    #[xai_stats_macro::receive_stats]
    async fn hydrate(
        &self,
        _query: &ScoredPostsQuery,
        candidates: &[PostCandidate],
    ) -> Result<Vec<PostCandidate>, String> {
        let client = &self.tes_client;

        let tweet_ids = candidates.iter().map(|c| c.tweet_id).collect::<Vec<_>>();

        let post_features = client.get_tweet_core_datas(tweet_ids.clone()).await;
        let post_features = post_features.map_err(|e| e.to_string())?;

        let mut hydrated_candidates = Vec::with_capacity(candidates.len());
        for tweet_id in tweet_ids {
            let post_features = post_features.get(&tweet_id);
            let core_data = post_features.and_then(|x| x.as_ref());
            let text = core_data.map(|x| x.text.clone());
            let hydrated = PostCandidate {
                author_id: core_data.map(|x| x.author_id).unwrap_or_default(),
                retweeted_user_id: core_data.and_then(|x| x.source_user_id),
                retweeted_tweet_id: core_data.and_then(|x| x.source_tweet_id),
                in_reply_to_tweet_id: core_data.and_then(|x| x.in_reply_to_tweet_id),
                tweet_text: text.unwrap_or_default(),
                ..Default::default()
            };
            hydrated_candidates.push(hydrated);
        }

        Ok(hydrated_candidates)
    }

    fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
        candidate.retweeted_user_id = hydrated.retweeted_user_id;
        candidate.retweeted_tweet_id = hydrated.retweeted_tweet_id;
        candidate.in_reply_to_tweet_id = hydrated.in_reply_to_tweet_id;
        candidate.tweet_text = hydrated.tweet_text;
    }
}

home-mixer/candidate_hydrators/gizmoduck_hydrator.rs (new file, 81 lines)
@@ -0,0 +1,81 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::clients::gizmoduck_client::GizmoduckClient;
use std::sync::Arc;
use tonic::async_trait;
use xai_candidate_pipeline::hydrator::Hydrator;

pub struct GizmoduckCandidateHydrator {
    pub gizmoduck_client: Arc<dyn GizmoduckClient + Send + Sync>,
}

impl GizmoduckCandidateHydrator {
    pub async fn new(gizmoduck_client: Arc<dyn GizmoduckClient + Send + Sync>) -> Self {
        Self { gizmoduck_client }
    }
}

#[async_trait]
impl Hydrator<ScoredPostsQuery, PostCandidate> for GizmoduckCandidateHydrator {
    #[xai_stats_macro::receive_stats]
    async fn hydrate(
        &self,
        _query: &ScoredPostsQuery,
        candidates: &[PostCandidate],
    ) -> Result<Vec<PostCandidate>, String> {
        let client = &self.gizmoduck_client;

        let author_ids: Vec<_> = candidates.iter().map(|c| c.author_id).collect();
        let author_ids: Vec<_> = author_ids.iter().map(|&x| x as i64).collect();
        let retweet_user_ids: Vec<_> = candidates.iter().map(|c| c.retweeted_user_id).collect();
        let retweet_user_ids: Vec<_> = retweet_user_ids.iter().flatten().collect();
        let retweet_user_ids: Vec<_> = retweet_user_ids.iter().map(|&&x| x as i64).collect();

        let mut user_ids_to_fetch = Vec::with_capacity(author_ids.len() + retweet_user_ids.len());
        user_ids_to_fetch.extend(author_ids);
        user_ids_to_fetch.extend(retweet_user_ids);
        user_ids_to_fetch.dedup();

        let users = client.get_users(user_ids_to_fetch).await;
        let users = users.map_err(|e| e.to_string())?;

        let mut hydrated_candidates = Vec::with_capacity(candidates.len());

        for candidate in candidates {
            let user = users
                .get(&(candidate.author_id as i64))
                .and_then(|user| user.as_ref());
            let user_counts = user.and_then(|user| user.user.as_ref().map(|u| &u.counts));
            let user_profile = user.and_then(|user| user.user.as_ref().map(|u| &u.profile));

            let author_followers_count: Option<i32> =
                user_counts.map(|x| x.followers_count).map(|x| x as i32);
            let author_screen_name: Option<String> = user_profile.map(|x| x.screen_name.clone());

            let retweet_user = candidate
                .retweeted_user_id
                .and_then(|retweeted_user_id| users.get(&(retweeted_user_id as i64)))
                .and_then(|user| user.as_ref());
            let retweet_profile =
                retweet_user.and_then(|user| user.user.as_ref().map(|u| &u.profile));
            let retweeted_screen_name: Option<String> =
                retweet_profile.map(|x| x.screen_name.clone());

            let hydrated = PostCandidate {
                author_followers_count,
                author_screen_name,
                retweeted_screen_name,
                ..Default::default()
            };
            hydrated_candidates.push(hydrated);
        }

        Ok(hydrated_candidates)
    }

    fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
        candidate.author_followers_count = hydrated.author_followers_count;
        candidate.author_screen_name = hydrated.author_screen_name;
        candidate.retweeted_screen_name = hydrated.retweeted_screen_name;
    }
}
44
home-mixer/candidate_hydrators/in_network_candidate_hydrator.rs
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
use crate::candidate_pipeline::candidate::PostCandidate;
|
||||||
|
use crate::candidate_pipeline::query::ScoredPostsQuery;
|
||||||
|
use std::collections::HashSet;
|
||||||
|
use tonic::async_trait;
|
||||||
|
use xai_candidate_pipeline::hydrator::Hydrator;
|
||||||
|
|
||||||
|
pub struct InNetworkCandidateHydrator;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Hydrator<ScoredPostsQuery, PostCandidate> for InNetworkCandidateHydrator {
|
||||||
|
#[xai_stats_macro::receive_stats]
|
||||||
|
async fn hydrate(
|
||||||
|
&self,
|
||||||
|
query: &ScoredPostsQuery,
|
||||||
|
candidates: &[PostCandidate],
|
||||||
|
) -> Result<Vec<PostCandidate>, String> {
|
||||||
|
let viewer_id = query.user_id as u64;
|
||||||
|
let followed_ids: HashSet<u64> = query
|
||||||
|
.user_features
|
||||||
|
.followed_user_ids
|
||||||
|
.iter()
|
||||||
|
.copied()
|
||||||
|
.map(|id| id as u64)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let hydrated_candidates = candidates
|
||||||
|
.iter()
|
||||||
|
.map(|candidate| {
|
||||||
|
let is_self = candidate.author_id == viewer_id;
|
||||||
|
let is_in_network = is_self || followed_ids.contains(&candidate.author_id);
|
||||||
|
PostCandidate {
|
||||||
|
in_network: Some(is_in_network),
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(hydrated_candidates)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
|
||||||
|
candidate.in_network = hydrated.in_network;
|
||||||
|
}
|
||||||
|
}
|
||||||
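Each of these hydrators follows the same two-step contract: hydrate returns one sparse PostCandidate per input candidate, in input order, carrying only the fields that hydrator owns, and update copies those fields onto the real candidate. The pipeline runner itself is not part of this commit, so the sketch below uses a simplified stand-in trait and candidate type rather than the actual xai_candidate_pipeline::hydrator::Hydrator; it only illustrates the merge loop the contract implies.

// Sketch only: a simplified hydrator contract and the positional merge it implies.
struct Candidate {
    tweet_id: i64,
    in_network: Option<bool>,
}

trait SimpleHydrator {
    fn hydrate(&self, candidates: &[Candidate]) -> Result<Vec<Candidate>, String>;
    fn update(&self, candidate: &mut Candidate, hydrated: Candidate);
}

fn run_hydrator(h: &dyn SimpleHydrator, candidates: &mut [Candidate]) -> Result<(), String> {
    // Results are positional: the i-th hydrated record belongs to the i-th candidate.
    let hydrated = h.hydrate(candidates)?;
    for (candidate, partial) in candidates.iter_mut().zip(hydrated) {
        h.update(candidate, partial);
    }
    Ok(())
}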
6
home-mixer/candidate_hydrators/mod.rs
Normal file
@ -0,0 +1,6 @@
pub mod core_data_candidate_hydrator;
pub mod gizmoduck_hydrator;
pub mod in_network_candidate_hydrator;
pub mod subscription_hydrator;
pub mod vf_candidate_hydrator;
pub mod video_duration_candidate_hydrator;
50
home-mixer/candidate_hydrators/subscription_hydrator.rs
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
use crate::candidate_pipeline::candidate::PostCandidate;
|
||||||
|
use crate::candidate_pipeline::query::ScoredPostsQuery;
|
||||||
|
use crate::clients::tweet_entity_service_client::TESClient;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tonic::async_trait;
|
||||||
|
use xai_candidate_pipeline::hydrator::Hydrator;
|
||||||
|
|
||||||
|
pub struct SubscriptionHydrator {
|
||||||
|
pub tes_client: Arc<dyn TESClient + Send + Sync>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SubscriptionHydrator {
|
||||||
|
pub async fn new(tes_client: Arc<dyn TESClient + Send + Sync>) -> Self {
|
||||||
|
Self { tes_client }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Hydrator<ScoredPostsQuery, PostCandidate> for SubscriptionHydrator {
|
||||||
|
#[xai_stats_macro::receive_stats]
|
||||||
|
async fn hydrate(
|
||||||
|
&self,
|
||||||
|
_query: &ScoredPostsQuery,
|
||||||
|
candidates: &[PostCandidate],
|
||||||
|
) -> Result<Vec<PostCandidate>, String> {
|
||||||
|
let client = &self.tes_client;
|
||||||
|
|
||||||
|
let tweet_ids = candidates.iter().map(|c| c.tweet_id).collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let post_features = client.get_subscription_author_ids(tweet_ids.clone()).await;
|
||||||
|
let post_features = post_features.map_err(|e| e.to_string())?;
|
||||||
|
|
||||||
|
let mut hydrated_candidates = Vec::with_capacity(candidates.len());
|
||||||
|
for tweet_id in tweet_ids {
|
||||||
|
let post_features = post_features.get(&tweet_id);
|
||||||
|
let subscription_author_id = post_features.and_then(|x| *x);
|
||||||
|
let hydrated = PostCandidate {
|
||||||
|
subscription_author_id,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
hydrated_candidates.push(hydrated);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(hydrated_candidates)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
|
||||||
|
candidate.subscription_author_id = hydrated.subscription_author_id;
|
||||||
|
}
|
||||||
|
}
|
||||||
101
home-mixer/candidate_hydrators/vf_candidate_hydrator.rs
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
use crate::candidate_pipeline::candidate::PostCandidate;
|
||||||
|
use crate::candidate_pipeline::query::ScoredPostsQuery;
|
||||||
|
use futures::future::join;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tonic::async_trait;
|
||||||
|
use xai_candidate_pipeline::hydrator::Hydrator;
|
||||||
|
use xai_twittercontext_proto::GetTwitterContextViewer;
|
||||||
|
use xai_twittercontext_proto::TwitterContextViewer;
|
||||||
|
use xai_visibility_filtering::models::FilteredReason;
|
||||||
|
use xai_visibility_filtering::vf_client::SafetyLevel;
|
||||||
|
use xai_visibility_filtering::vf_client::SafetyLevel::{TimelineHome, TimelineHomeRecommendations};
|
||||||
|
use xai_visibility_filtering::vf_client::VisibilityFilteringClient;
|
||||||
|
|
||||||
|
pub struct VFCandidateHydrator {
|
||||||
|
pub vf_client: Arc<dyn VisibilityFilteringClient + Send + Sync>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl VFCandidateHydrator {
|
||||||
|
pub async fn new(vf_client: Arc<dyn VisibilityFilteringClient + Send + Sync>) -> Self {
|
||||||
|
Self { vf_client }
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn fetch_vf_results(
|
||||||
|
client: &Arc<dyn VisibilityFilteringClient + Send + Sync>,
|
||||||
|
tweet_ids: Vec<i64>,
|
||||||
|
safety_level: SafetyLevel,
|
||||||
|
for_user_id: i64,
|
||||||
|
context: Option<TwitterContextViewer>,
|
||||||
|
) -> Result<HashMap<i64, Option<FilteredReason>>, String> {
|
||||||
|
if tweet_ids.is_empty() {
|
||||||
|
return Ok(HashMap::new());
|
||||||
|
}
|
||||||
|
|
||||||
|
client
|
||||||
|
.get_result(tweet_ids, safety_level, for_user_id, context)
|
||||||
|
.await
|
||||||
|
.map_err(|e| e.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Hydrator<ScoredPostsQuery, PostCandidate> for VFCandidateHydrator {
|
||||||
|
#[xai_stats_macro::receive_stats]
|
||||||
|
async fn hydrate(
|
||||||
|
&self,
|
||||||
|
query: &ScoredPostsQuery,
|
||||||
|
candidates: &[PostCandidate],
|
||||||
|
) -> Result<Vec<PostCandidate>, String> {
|
||||||
|
let context = query.get_viewer();
|
||||||
|
let user_id = query.user_id;
|
||||||
|
let client = &self.vf_client;
|
||||||
|
|
||||||
|
let mut in_network_ids = Vec::new();
|
||||||
|
let mut oon_ids = Vec::new();
|
||||||
|
for candidate in candidates.iter() {
|
||||||
|
if candidate.in_network.unwrap_or(false) {
|
||||||
|
in_network_ids.push(candidate.tweet_id);
|
||||||
|
} else {
|
||||||
|
oon_ids.push(candidate.tweet_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let in_network_future = Self::fetch_vf_results(
|
||||||
|
client,
|
||||||
|
in_network_ids,
|
||||||
|
TimelineHome,
|
||||||
|
user_id,
|
||||||
|
context.clone(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let oon_future = Self::fetch_vf_results(
|
||||||
|
client,
|
||||||
|
oon_ids,
|
||||||
|
TimelineHomeRecommendations,
|
||||||
|
user_id,
|
||||||
|
context,
|
||||||
|
);
|
||||||
|
|
||||||
|
let (in_network_result, oon_result) = join(in_network_future, oon_future).await;
|
||||||
|
let mut result: HashMap<i64, Option<FilteredReason>> = HashMap::new();
|
||||||
|
result.extend(in_network_result?);
|
||||||
|
result.extend(oon_result?);
|
||||||
|
|
||||||
|
let mut hydrated_candidates = Vec::with_capacity(candidates.len());
|
||||||
|
for candidate in candidates {
|
||||||
|
let visibility_reason = result.get(&candidate.tweet_id);
|
||||||
|
let visibility_reason = visibility_reason.unwrap_or(&None);
|
||||||
|
let hydrated = PostCandidate {
|
||||||
|
visibility_reason: visibility_reason.clone(),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
hydrated_candidates.push(hydrated);
|
||||||
|
}
|
||||||
|
Ok(hydrated_candidates)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
|
||||||
|
candidate.visibility_reason = hydrated.visibility_reason;
|
||||||
|
}
|
||||||
|
}
|
||||||
62
home-mixer/candidate_hydrators/video_duration_candidate_hydrator.rs
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
use crate::candidate_pipeline::candidate::PostCandidate;
|
||||||
|
use crate::candidate_pipeline::candidate_features::MediaInfo;
|
||||||
|
use crate::candidate_pipeline::query::ScoredPostsQuery;
|
||||||
|
use crate::clients::tweet_entity_service_client::TESClient;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tonic::async_trait;
|
||||||
|
use xai_candidate_pipeline::hydrator::Hydrator;
|
||||||
|
|
||||||
|
pub struct VideoDurationCandidateHydrator {
|
||||||
|
pub tes_client: Arc<dyn TESClient + Send + Sync>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl VideoDurationCandidateHydrator {
|
||||||
|
pub async fn new(tes_client: Arc<dyn TESClient + Send + Sync>) -> Self {
|
||||||
|
Self { tes_client }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Hydrator<ScoredPostsQuery, PostCandidate> for VideoDurationCandidateHydrator {
|
||||||
|
#[xai_stats_macro::receive_stats]
|
||||||
|
async fn hydrate(
|
||||||
|
&self,
|
||||||
|
_query: &ScoredPostsQuery,
|
||||||
|
candidates: &[PostCandidate],
|
||||||
|
) -> Result<Vec<PostCandidate>, String> {
|
||||||
|
let client = &self.tes_client;
|
||||||
|
|
||||||
|
let tweet_ids = candidates.iter().map(|c| c.tweet_id).collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let post_features = client.get_tweet_media_entities(tweet_ids.clone()).await;
|
||||||
|
let post_features = post_features.map_err(|e| e.to_string())?;
|
||||||
|
|
||||||
|
let mut hydrated_candidates = Vec::with_capacity(candidates.len());
|
||||||
|
for tweet_id in tweet_ids {
|
||||||
|
let post_features = post_features.get(&tweet_id);
|
||||||
|
let media_entities = post_features.and_then(|x| x.as_ref());
|
||||||
|
|
||||||
|
let video_duration_ms = media_entities.and_then(|entities| {
|
||||||
|
entities.iter().find_map(|entity| {
|
||||||
|
if let Some(MediaInfo::VideoInfo(video_info)) = &entity.media_info {
|
||||||
|
Some(video_info.duration_millis)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
let hydrated = PostCandidate {
|
||||||
|
video_duration_ms,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
hydrated_candidates.push(hydrated);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(hydrated_candidates)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
|
||||||
|
candidate.video_duration_ms = hydrated.video_duration_ms;
|
||||||
|
}
|
||||||
|
}
|
||||||
70
home-mixer/candidate_pipeline/candidate.rs
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
use xai_home_mixer_proto as pb;
|
||||||
|
use xai_visibility_filtering::models as vf;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Default)]
|
||||||
|
pub struct PostCandidate {
|
||||||
|
pub tweet_id: i64,
|
||||||
|
pub author_id: u64,
|
||||||
|
pub tweet_text: String,
|
||||||
|
pub in_reply_to_tweet_id: Option<u64>,
|
||||||
|
pub retweeted_tweet_id: Option<u64>,
|
||||||
|
pub retweeted_user_id: Option<u64>,
|
||||||
|
pub phoenix_scores: PhoenixScores,
|
||||||
|
pub prediction_request_id: Option<u64>,
|
||||||
|
pub last_scored_at_ms: Option<u64>,
|
||||||
|
pub weighted_score: Option<f64>,
|
||||||
|
pub score: Option<f64>,
|
||||||
|
pub served_type: Option<pb::ServedType>,
|
||||||
|
pub in_network: Option<bool>,
|
||||||
|
pub ancestors: Vec<u64>,
|
||||||
|
pub video_duration_ms: Option<i32>,
|
||||||
|
pub author_followers_count: Option<i32>,
|
||||||
|
pub author_screen_name: Option<String>,
|
||||||
|
pub retweeted_screen_name: Option<String>,
|
||||||
|
pub visibility_reason: Option<vf::FilteredReason>,
|
||||||
|
pub subscription_author_id: Option<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Default)]
|
||||||
|
pub struct PhoenixScores {
|
||||||
|
pub favorite_score: Option<f64>,
|
||||||
|
pub reply_score: Option<f64>,
|
||||||
|
pub retweet_score: Option<f64>,
|
||||||
|
pub photo_expand_score: Option<f64>,
|
||||||
|
pub click_score: Option<f64>,
|
||||||
|
pub profile_click_score: Option<f64>,
|
||||||
|
pub vqv_score: Option<f64>,
|
||||||
|
pub share_score: Option<f64>,
|
||||||
|
pub share_via_dm_score: Option<f64>,
|
||||||
|
pub share_via_copy_link_score: Option<f64>,
|
||||||
|
pub dwell_score: Option<f64>,
|
||||||
|
pub quote_score: Option<f64>,
|
||||||
|
pub quoted_click_score: Option<f64>,
|
||||||
|
pub follow_author_score: Option<f64>,
|
||||||
|
pub not_interested_score: Option<f64>,
|
||||||
|
pub block_author_score: Option<f64>,
|
||||||
|
pub mute_author_score: Option<f64>,
|
||||||
|
pub report_score: Option<f64>,
|
||||||
|
// Continuous actions
|
||||||
|
pub dwell_time: Option<f64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait CandidateHelpers {
|
||||||
|
fn get_screen_names(&self) -> HashMap<u64, String>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CandidateHelpers for PostCandidate {
|
||||||
|
fn get_screen_names(&self) -> HashMap<u64, String> {
|
||||||
|
let mut screen_names = HashMap::<u64, String>::new();
|
||||||
|
if let Some(author_screen_name) = self.author_screen_name.clone() {
|
||||||
|
screen_names.insert(self.author_id, author_screen_name);
|
||||||
|
}
|
||||||
|
if let (Some(retweeted_screen_name), Some(retweeted_user_id)) =
|
||||||
|
(self.retweeted_screen_name.clone(), self.retweeted_user_id)
|
||||||
|
{
|
||||||
|
screen_names.insert(retweeted_user_id, retweeted_screen_name);
|
||||||
|
}
|
||||||
|
screen_names
|
||||||
|
}
|
||||||
|
}
|
||||||
78
home-mixer/candidate_pipeline/candidate_features.rs
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct PureCoreData {
|
||||||
|
pub author_id: u64,
|
||||||
|
pub text: String,
|
||||||
|
pub source_tweet_id: Option<u64>,
|
||||||
|
pub source_user_id: Option<u64>,
|
||||||
|
pub in_reply_to_tweet_id: Option<u64>,
|
||||||
|
pub in_reply_to_user_id: Option<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct ExclusiveTweetControl {
|
||||||
|
pub conversation_author_id: i64,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub type MediaEntities = Vec<MediaEntity>;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct MediaEntity {
|
||||||
|
pub media_info: Option<MediaInfo>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub enum MediaInfo {
|
||||||
|
VideoInfo(VideoInfo),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct VideoInfo {
|
||||||
|
pub duration_millis: i32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct Share {
|
||||||
|
pub source_tweet_id: u64,
|
||||||
|
pub source_user_id: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct Reply {
|
||||||
|
pub in_reply_to_tweet_id: Option<u64>,
|
||||||
|
pub in_reply_to_user_id: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct GizmoduckUserCounts {
|
||||||
|
pub followers_count: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct GizmoduckUserProfile {
|
||||||
|
pub screen_name: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct GizmoduckUser {
|
||||||
|
pub user_id: u64,
|
||||||
|
pub profile: GizmoduckUserProfile,
|
||||||
|
pub counts: GizmoduckUserCounts,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct GizmoduckUserResult {
|
||||||
|
pub user: Option<GizmoduckUser>,
|
||||||
|
}
|
||||||
5
home-mixer/candidate_pipeline/mod.rs
Normal file
@ -0,0 +1,5 @@
pub mod candidate;
pub mod candidate_features;
pub mod phoenix_candidate_pipeline;
pub mod query;
pub mod query_features;
255
home-mixer/candidate_pipeline/phoenix_candidate_pipeline.rs
Normal file
@ -0,0 +1,255 @@
|
|||||||
|
use crate::candidate_hydrators::core_data_candidate_hydrator::CoreDataCandidateHydrator;
|
||||||
|
use crate::candidate_hydrators::gizmoduck_hydrator::GizmoduckCandidateHydrator;
|
||||||
|
use crate::candidate_hydrators::in_network_candidate_hydrator::InNetworkCandidateHydrator;
|
||||||
|
use crate::candidate_hydrators::subscription_hydrator::SubscriptionHydrator;
|
||||||
|
use crate::candidate_hydrators::vf_candidate_hydrator::VFCandidateHydrator;
|
||||||
|
use crate::candidate_hydrators::video_duration_candidate_hydrator::VideoDurationCandidateHydrator;
|
||||||
|
use crate::candidate_pipeline::candidate::PostCandidate;
|
||||||
|
use crate::candidate_pipeline::query::ScoredPostsQuery;
|
||||||
|
use crate::clients::gizmoduck_client::{GizmoduckClient, ProdGizmoduckClient};
|
||||||
|
use crate::clients::phoenix_prediction_client::{
|
||||||
|
PhoenixPredictionClient, ProdPhoenixPredictionClient,
|
||||||
|
};
|
||||||
|
use crate::clients::phoenix_retrieval_client::{
|
||||||
|
PhoenixRetrievalClient, ProdPhoenixRetrievalClient,
|
||||||
|
};
|
||||||
|
use crate::clients::s2s::{S2S_CHAIN_PATH, S2S_CRT_PATH, S2S_KEY_PATH};
|
||||||
|
use crate::clients::socialgraph_client::SocialGraphClient;
|
||||||
|
use crate::clients::strato_client::{ProdStratoClient, StratoClient};
|
||||||
|
use crate::clients::thunder_client::ThunderClient;
|
||||||
|
use crate::clients::tweet_entity_service_client::{ProdTESClient, TESClient};
|
||||||
|
use crate::clients::uas_fetcher::UserActionSequenceFetcher;
|
||||||
|
use crate::filters::age_filter::AgeFilter;
|
||||||
|
use crate::filters::author_socialgraph_filter::AuthorSocialgraphFilter;
|
||||||
|
use crate::filters::core_data_hydration_filter::CoreDataHydrationFilter;
|
||||||
|
use crate::filters::dedup_conversation_filter::DedupConversationFilter;
|
||||||
|
use crate::filters::drop_duplicates_filter::DropDuplicatesFilter;
|
||||||
|
use crate::filters::ineligible_subscription_filter::IneligibleSubscriptionFilter;
|
||||||
|
use crate::filters::muted_keyword_filter::MutedKeywordFilter;
|
||||||
|
use crate::filters::previously_seen_posts_filter::PreviouslySeenPostsFilter;
|
||||||
|
use crate::filters::previously_served_posts_filter::PreviouslyServedPostsFilter;
|
||||||
|
use crate::filters::retweet_deduplication_filter::RetweetDeduplicationFilter;
|
||||||
|
use crate::filters::self_tweet_filter::SelfTweetFilter;
|
||||||
|
use crate::filters::vf_filter::VFFilter;
|
||||||
|
use crate::params;
|
||||||
|
use crate::query_hydrators::user_action_seq_query_hydrator::UserActionSeqQueryHydrator;
|
||||||
|
use crate::query_hydrators::user_features_query_hydrator::UserFeaturesQueryHydrator;
|
||||||
|
use crate::scorers::author_diversity_scorer::AuthorDiversityScorer;
|
||||||
|
use crate::scorers::oon_scorer::OONScorer;
|
||||||
|
use crate::scorers::phoenix_scorer::PhoenixScorer;
|
||||||
|
use crate::scorers::weighted_scorer::WeightedScorer;
|
||||||
|
use crate::selectors::TopKScoreSelector;
|
||||||
|
use crate::side_effects::cache_request_info_side_effect::CacheRequestInfoSideEffect;
|
||||||
|
use crate::sources::phoenix_source::PhoenixSource;
|
||||||
|
use crate::sources::thunder_source::ThunderSource;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
use tonic::async_trait;
|
||||||
|
use xai_candidate_pipeline::candidate_pipeline::CandidatePipeline;
|
||||||
|
use xai_candidate_pipeline::filter::Filter;
|
||||||
|
use xai_candidate_pipeline::hydrator::Hydrator;
|
||||||
|
use xai_candidate_pipeline::query_hydrator::QueryHydrator;
|
||||||
|
use xai_candidate_pipeline::scorer::Scorer;
|
||||||
|
use xai_candidate_pipeline::selector::Selector;
|
||||||
|
use xai_candidate_pipeline::side_effect::SideEffect;
|
||||||
|
use xai_candidate_pipeline::source::Source;
|
||||||
|
use xai_visibility_filtering::vf_client::{
|
||||||
|
ProdVisibilityFilteringClient, VisibilityFilteringClient,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub struct PhoenixCandidatePipeline {
|
||||||
|
query_hydrators: Vec<Box<dyn QueryHydrator<ScoredPostsQuery>>>,
|
||||||
|
sources: Vec<Box<dyn Source<ScoredPostsQuery, PostCandidate>>>,
|
||||||
|
hydrators: Vec<Box<dyn Hydrator<ScoredPostsQuery, PostCandidate>>>,
|
||||||
|
filters: Vec<Box<dyn Filter<ScoredPostsQuery, PostCandidate>>>,
|
||||||
|
scorers: Vec<Box<dyn Scorer<ScoredPostsQuery, PostCandidate>>>,
|
||||||
|
selector: TopKScoreSelector,
|
||||||
|
post_selection_hydrators: Vec<Box<dyn Hydrator<ScoredPostsQuery, PostCandidate>>>,
|
||||||
|
post_selection_filters: Vec<Box<dyn Filter<ScoredPostsQuery, PostCandidate>>>,
|
||||||
|
side_effects: Arc<Vec<Box<dyn SideEffect<ScoredPostsQuery, PostCandidate>>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PhoenixCandidatePipeline {
|
||||||
|
async fn build_with_clients(
|
||||||
|
uas_fetcher: Arc<UserActionSequenceFetcher>,
|
||||||
|
phoenix_client: Arc<dyn PhoenixPredictionClient + Send + Sync>,
|
||||||
|
phoenix_retrieval_client: Arc<dyn PhoenixRetrievalClient + Send + Sync>,
|
||||||
|
thunder_client: Arc<ThunderClient>,
|
||||||
|
strato_client: Arc<dyn StratoClient + Send + Sync>,
|
||||||
|
tes_client: Arc<dyn TESClient + Send + Sync>,
|
||||||
|
gizmoduck_client: Arc<dyn GizmoduckClient + Send + Sync>,
|
||||||
|
vf_client: Arc<dyn VisibilityFilteringClient + Send + Sync>,
|
||||||
|
) -> PhoenixCandidatePipeline {
|
||||||
|
// Query Hydrators
|
||||||
|
let query_hydrators: Vec<Box<dyn QueryHydrator<ScoredPostsQuery>>> = vec![
|
||||||
|
Box::new(UserActionSeqQueryHydrator::new(uas_fetcher)),
|
||||||
|
Box::new(UserFeaturesQueryHydrator {
|
||||||
|
strato_client: strato_client.clone(),
|
||||||
|
}),
|
||||||
|
];
|
||||||
|
|
||||||
|
// Sources
|
||||||
|
let phoenix_source = Box::new(PhoenixSource {
|
||||||
|
phoenix_retrieval_client,
|
||||||
|
});
|
||||||
|
let thunder_source = Box::new(ThunderSource { thunder_client });
|
||||||
|
let sources: Vec<Box<dyn Source<ScoredPostsQuery, PostCandidate>>> =
|
||||||
|
vec![phoenix_source, thunder_source];
|
||||||
|
|
||||||
|
// Hydrators
|
||||||
|
let hydrators: Vec<Box<dyn Hydrator<ScoredPostsQuery, PostCandidate>>> = vec![
|
||||||
|
Box::new(InNetworkCandidateHydrator),
|
||||||
|
Box::new(CoreDataCandidateHydrator::new(tes_client.clone()).await),
|
||||||
|
Box::new(VideoDurationCandidateHydrator::new(tes_client.clone()).await),
|
||||||
|
Box::new(SubscriptionHydrator::new(tes_client.clone()).await),
|
||||||
|
Box::new(GizmoduckCandidateHydrator::new(gizmoduck_client).await),
|
||||||
|
];
|
||||||
|
|
||||||
|
// Filters
|
||||||
|
let filters: Vec<Box<dyn Filter<ScoredPostsQuery, PostCandidate>>> = vec![
|
||||||
|
Box::new(DropDuplicatesFilter),
|
||||||
|
Box::new(CoreDataHydrationFilter),
|
||||||
|
Box::new(AgeFilter::new(Duration::from_secs(params::MAX_POST_AGE))),
|
||||||
|
Box::new(SelfTweetFilter),
|
||||||
|
Box::new(RetweetDeduplicationFilter),
|
||||||
|
Box::new(IneligibleSubscriptionFilter),
|
||||||
|
Box::new(PreviouslySeenPostsFilter),
|
||||||
|
Box::new(PreviouslyServedPostsFilter),
|
||||||
|
Box::new(MutedKeywordFilter::new()),
|
||||||
|
Box::new(AuthorSocialgraphFilter),
|
||||||
|
];
|
||||||
|
|
||||||
|
// Scorers
|
||||||
|
let phoenix_scorer = Box::new(PhoenixScorer { phoenix_client });
|
||||||
|
let weighted_scorer = Box::new(WeightedScorer);
|
||||||
|
let author_diversity_scorer = Box::new(AuthorDiversityScorer::default());
|
||||||
|
let oon_scorer = Box::new(OONScorer);
|
||||||
|
let scorers: Vec<Box<dyn Scorer<ScoredPostsQuery, PostCandidate>>> = vec![
|
||||||
|
phoenix_scorer,
|
||||||
|
weighted_scorer,
|
||||||
|
author_diversity_scorer,
|
||||||
|
oon_scorer,
|
||||||
|
];
|
||||||
|
|
||||||
|
// Selector
|
||||||
|
let selector = TopKScoreSelector;
|
||||||
|
|
||||||
|
// Post-selection hydrators
|
||||||
|
let post_selection_hydrators: Vec<Box<dyn Hydrator<ScoredPostsQuery, PostCandidate>>> =
|
||||||
|
vec![Box::new(VFCandidateHydrator::new(vf_client.clone()).await)];
|
||||||
|
|
||||||
|
// Post-selection filters
|
||||||
|
let post_selection_filters: Vec<Box<dyn Filter<ScoredPostsQuery, PostCandidate>>> =
|
||||||
|
vec![Box::new(VFFilter), Box::new(DedupConversationFilter)];
|
||||||
|
|
||||||
|
// Side Effects
|
||||||
|
let side_effects: Arc<Vec<Box<dyn SideEffect<ScoredPostsQuery, PostCandidate>>>> =
|
||||||
|
Arc::new(vec![Box::new(CacheRequestInfoSideEffect { strato_client })]);
|
||||||
|
|
||||||
|
PhoenixCandidatePipeline {
|
||||||
|
query_hydrators,
|
||||||
|
hydrators,
|
||||||
|
filters,
|
||||||
|
sources,
|
||||||
|
scorers,
|
||||||
|
selector,
|
||||||
|
post_selection_hydrators,
|
||||||
|
post_selection_filters,
|
||||||
|
side_effects,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn prod() -> PhoenixCandidatePipeline {
|
||||||
|
let uas_fetcher =
|
||||||
|
Arc::new(UserActionSequenceFetcher::new().expect("Failed to create UAS fetcher"));
|
||||||
|
let _sgs_client = Arc::new(SocialGraphClient::new());
|
||||||
|
let phoenix_client = Arc::new(
|
||||||
|
ProdPhoenixPredictionClient::new()
|
||||||
|
.await
|
||||||
|
.expect("Failed to create Phoenix prediction client"),
|
||||||
|
);
|
||||||
|
let phoenix_retrieval_client = Arc::new(
|
||||||
|
ProdPhoenixRetrievalClient::new()
|
||||||
|
.await
|
||||||
|
.expect("Failed to create Phoenix retrieval client"),
|
||||||
|
);
|
||||||
|
let thunder_client = Arc::new(ThunderClient::new().await);
|
||||||
|
let strato_client = Arc::new(
|
||||||
|
ProdStratoClient::new()
|
||||||
|
.await
|
||||||
|
.expect("Failed to create Strato client"),
|
||||||
|
);
|
||||||
|
let tes_client = Arc::new(
|
||||||
|
ProdTESClient::new()
|
||||||
|
.await
|
||||||
|
.expect("Failed to create TES client"),
|
||||||
|
);
|
||||||
|
let gizmoduck_client = Arc::new(
|
||||||
|
ProdGizmoduckClient::new()
|
||||||
|
.await
|
||||||
|
.expect("Failed to create Gizmoduck client"),
|
||||||
|
);
|
||||||
|
let vf_client = Arc::new(
|
||||||
|
ProdVisibilityFilteringClient::new(
|
||||||
|
S2S_CHAIN_PATH.clone(),
|
||||||
|
S2S_CRT_PATH.clone(),
|
||||||
|
S2S_KEY_PATH.clone()
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("Failed to create VF client"),
|
||||||
|
);
|
||||||
|
PhoenixCandidatePipeline::build_with_clients(
|
||||||
|
uas_fetcher,
|
||||||
|
phoenix_client,
|
||||||
|
phoenix_retrieval_client,
|
||||||
|
thunder_client,
|
||||||
|
strato_client,
|
||||||
|
tes_client,
|
||||||
|
gizmoduck_client,
|
||||||
|
vf_client,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl CandidatePipeline<ScoredPostsQuery, PostCandidate> for PhoenixCandidatePipeline {
|
||||||
|
fn query_hydrators(&self) -> &[Box<dyn QueryHydrator<ScoredPostsQuery>>] {
|
||||||
|
&self.query_hydrators
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sources(&self) -> &[Box<dyn Source<ScoredPostsQuery, PostCandidate>>] {
|
||||||
|
&self.sources
|
||||||
|
}
|
||||||
|
fn hydrators(&self) -> &[Box<dyn Hydrator<ScoredPostsQuery, PostCandidate>>] {
|
||||||
|
&self.hydrators
|
||||||
|
}
|
||||||
|
|
||||||
|
fn filters(&self) -> &[Box<dyn Filter<ScoredPostsQuery, PostCandidate>>] {
|
||||||
|
&self.filters
|
||||||
|
}
|
||||||
|
|
||||||
|
fn scorers(&self) -> &[Box<dyn Scorer<ScoredPostsQuery, PostCandidate>>] {
|
||||||
|
&self.scorers
|
||||||
|
}
|
||||||
|
|
||||||
|
fn selector(&self) -> &dyn Selector<ScoredPostsQuery, PostCandidate> {
|
||||||
|
&self.selector
|
||||||
|
}
|
||||||
|
|
||||||
|
fn post_selection_hydrators(&self) -> &[Box<dyn Hydrator<ScoredPostsQuery, PostCandidate>>] {
|
||||||
|
&self.post_selection_hydrators
|
||||||
|
}
|
||||||
|
|
||||||
|
fn post_selection_filters(&self) -> &[Box<dyn Filter<ScoredPostsQuery, PostCandidate>>] {
|
||||||
|
&self.post_selection_filters
|
||||||
|
}
|
||||||
|
|
||||||
|
fn side_effects(&self) -> Arc<Vec<Box<dyn SideEffect<ScoredPostsQuery, PostCandidate>>>> {
|
||||||
|
Arc::clone(&self.side_effects)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn result_size(&self) -> usize {
|
||||||
|
params::RESULT_SIZE
|
||||||
|
}
|
||||||
|
}
|
||||||
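Taken together, build_with_clients and the CandidatePipeline accessors above spell out the apparent stage order: query hydrators enrich the request, the Phoenix and Thunder sources produce candidates, the hydrator, filter, and scorer chains annotate, prune, and rank them, TopKScoreSelector cuts the list to result_size, and only the survivors go through the more expensive post-selection visibility hydration and filtering before the caching side effect runs.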
69
home-mixer/candidate_pipeline/query.rs
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
use crate::candidate_pipeline::query_features::UserFeatures;
|
||||||
|
use crate::util::request_util::generate_request_id;
|
||||||
|
use xai_candidate_pipeline::candidate_pipeline::HasRequestId;
|
||||||
|
use xai_home_mixer_proto::ImpressionBloomFilterEntry;
|
||||||
|
use xai_twittercontext_proto::{GetTwitterContextViewer, TwitterContextViewer};
|
||||||
|
|
||||||
|
#[derive(Clone, Default, Debug)]
|
||||||
|
pub struct ScoredPostsQuery {
|
||||||
|
pub user_id: i64,
|
||||||
|
pub client_app_id: i32,
|
||||||
|
pub country_code: String,
|
||||||
|
pub language_code: String,
|
||||||
|
pub seen_ids: Vec<i64>,
|
||||||
|
pub served_ids: Vec<i64>,
|
||||||
|
pub in_network_only: bool,
|
||||||
|
pub is_bottom_request: bool,
|
||||||
|
pub bloom_filter_entries: Vec<ImpressionBloomFilterEntry>,
|
||||||
|
pub user_action_sequence: Option<xai_recsys_proto::UserActionSequence>,
|
||||||
|
pub user_features: UserFeatures,
|
||||||
|
pub request_id: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ScoredPostsQuery {
|
||||||
|
pub fn new(
|
||||||
|
user_id: i64,
|
||||||
|
client_app_id: i32,
|
||||||
|
country_code: String,
|
||||||
|
language_code: String,
|
||||||
|
seen_ids: Vec<i64>,
|
||||||
|
served_ids: Vec<i64>,
|
||||||
|
in_network_only: bool,
|
||||||
|
is_bottom_request: bool,
|
||||||
|
bloom_filter_entries: Vec<ImpressionBloomFilterEntry>,
|
||||||
|
) -> Self {
|
||||||
|
let request_id = format!("{}-{}", generate_request_id(), user_id);
|
||||||
|
Self {
|
||||||
|
user_id,
|
||||||
|
client_app_id,
|
||||||
|
country_code,
|
||||||
|
language_code,
|
||||||
|
seen_ids,
|
||||||
|
served_ids,
|
||||||
|
in_network_only,
|
||||||
|
is_bottom_request,
|
||||||
|
bloom_filter_entries,
|
||||||
|
user_action_sequence: None,
|
||||||
|
user_features: UserFeatures::default(),
|
||||||
|
request_id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GetTwitterContextViewer for ScoredPostsQuery {
|
||||||
|
fn get_viewer(&self) -> Option<TwitterContextViewer> {
|
||||||
|
Some(TwitterContextViewer {
|
||||||
|
user_id: self.user_id,
|
||||||
|
client_application_id: self.client_app_id as i64,
|
||||||
|
request_country_code: self.country_code.clone(),
|
||||||
|
request_language_code: self.language_code.clone(),
|
||||||
|
..Default::default()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HasRequestId for ScoredPostsQuery {
|
||||||
|
fn request_id(&self) -> &str {
|
||||||
|
&self.request_id
|
||||||
|
}
|
||||||
|
}
|
||||||
11
home-mixer/candidate_pipeline/query_features.rs
Normal file
@ -0,0 +1,11 @@
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all = "camelCase")]
pub struct UserFeatures {
    pub muted_keywords: Vec<String>,
    pub blocked_user_ids: Vec<i64>,
    pub muted_user_ids: Vec<i64>,
    pub followed_user_ids: Vec<i64>,
    pub subscribed_user_ids: Vec<i64>,
}
38
home-mixer/filters/age_filter.rs
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
use crate::candidate_pipeline::candidate::PostCandidate;
|
||||||
|
use crate::candidate_pipeline::query::ScoredPostsQuery;
|
||||||
|
use crate::util::snowflake;
|
||||||
|
use std::time::Duration;
|
||||||
|
use tonic::async_trait;
|
||||||
|
use xai_candidate_pipeline::filter::{Filter, FilterResult};
|
||||||
|
|
||||||
|
/// Filter that removes tweets older than a specified duration.
|
||||||
|
pub struct AgeFilter {
|
||||||
|
pub max_age: Duration,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AgeFilter {
|
||||||
|
pub fn new(max_age: Duration) -> Self {
|
||||||
|
Self { max_age }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_within_age(&self, tweet_id: i64) -> bool {
|
||||||
|
snowflake::duration_since_creation_opt(tweet_id)
|
||||||
|
.map(|age| age <= self.max_age)
|
||||||
|
.unwrap_or(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Filter<ScoredPostsQuery, PostCandidate> for AgeFilter {
|
||||||
|
async fn filter(
|
||||||
|
&self,
|
||||||
|
_query: &ScoredPostsQuery,
|
||||||
|
candidates: Vec<PostCandidate>,
|
||||||
|
) -> Result<FilterResult<PostCandidate>, String> {
|
||||||
|
let (kept, removed): (Vec<_>, Vec<_>) = candidates
|
||||||
|
.into_iter()
|
||||||
|
.partition(|c| self.is_within_age(c.tweet_id));
|
||||||
|
|
||||||
|
Ok(FilterResult { kept, removed })
|
||||||
|
}
|
||||||
|
}
|
||||||
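The util::snowflake helper used above is among the modules excluded from this release, so the age computation itself is not shown. All it needs is the creation time embedded in the post ID; the sketch below assumes the standard Twitter snowflake layout (millisecond timestamp in the bits above the low 22, offset by the 1288834974657 ms epoch) and is not necessarily how the excluded module is written.

use std::time::{Duration, SystemTime, UNIX_EPOCH};

// Standard Twitter snowflake epoch in milliseconds (2010-11-04T01:42:54.657Z).
const SNOWFLAKE_EPOCH_MS: u64 = 1_288_834_974_657;

// Sketch of what a duration_since_creation_opt could look like: None for
// non-positive IDs or when the wall clock is behind the ID's embedded timestamp.
fn duration_since_creation_opt(tweet_id: i64) -> Option<Duration> {
    if tweet_id <= 0 {
        return None;
    }
    let created_ms = (tweet_id as u64 >> 22) + SNOWFLAKE_EPOCH_MS;
    let now_ms = SystemTime::now().duration_since(UNIX_EPOCH).ok()?.as_millis() as u64;
    now_ms.checked_sub(created_ms).map(Duration::from_millis)
}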
42
home-mixer/filters/author_socialgraph_filter.rs
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
use crate::candidate_pipeline::candidate::PostCandidate;
|
||||||
|
use crate::candidate_pipeline::query::ScoredPostsQuery;
|
||||||
|
use tonic::async_trait;
|
||||||
|
use xai_candidate_pipeline::filter::{Filter, FilterResult};
|
||||||
|
|
||||||
|
// Remove candidates that are blocked or muted by the viewer
|
||||||
|
pub struct AuthorSocialgraphFilter;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Filter<ScoredPostsQuery, PostCandidate> for AuthorSocialgraphFilter {
|
||||||
|
async fn filter(
|
||||||
|
&self,
|
||||||
|
query: &ScoredPostsQuery,
|
||||||
|
candidates: Vec<PostCandidate>,
|
||||||
|
) -> Result<FilterResult<PostCandidate>, String> {
|
||||||
|
let viewer_blocked_user_ids = query.user_features.blocked_user_ids.clone();
|
||||||
|
let viewer_muted_user_ids = query.user_features.muted_user_ids.clone();
|
||||||
|
|
||||||
|
if viewer_blocked_user_ids.is_empty() && viewer_muted_user_ids.is_empty() {
|
||||||
|
return Ok(FilterResult {
|
||||||
|
kept: candidates,
|
||||||
|
removed: Vec::new(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut kept: Vec<PostCandidate> = Vec::new();
|
||||||
|
let mut removed: Vec<PostCandidate> = Vec::new();
|
||||||
|
|
||||||
|
for candidate in candidates {
|
||||||
|
let author_id = candidate.author_id as i64;
|
||||||
|
let muted = viewer_muted_user_ids.contains(&author_id);
|
||||||
|
let blocked = viewer_blocked_user_ids.contains(&author_id);
|
||||||
|
if muted || blocked {
|
||||||
|
removed.push(candidate);
|
||||||
|
} else {
|
||||||
|
kept.push(candidate);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(FilterResult { kept, removed })
|
||||||
|
}
|
||||||
|
}
|
||||||
20
home-mixer/filters/core_data_hydration_filter.rs
Normal file
@ -0,0 +1,20 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use tonic::async_trait;
use xai_candidate_pipeline::filter::{Filter, FilterResult};

pub struct CoreDataHydrationFilter;

#[async_trait]
impl Filter<ScoredPostsQuery, PostCandidate> for CoreDataHydrationFilter {
    async fn filter(
        &self,
        _query: &ScoredPostsQuery,
        candidates: Vec<PostCandidate>,
    ) -> Result<FilterResult<PostCandidate>, String> {
        let (kept, removed) = candidates
            .into_iter()
            .partition(|c| c.author_id != 0 && !c.tweet_text.trim().is_empty());
        Ok(FilterResult { kept, removed })
    }
}
51
home-mixer/filters/dedup_conversation_filter.rs
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
use crate::candidate_pipeline::candidate::PostCandidate;
|
||||||
|
use crate::candidate_pipeline::query::ScoredPostsQuery;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use tonic::async_trait;
|
||||||
|
use xai_candidate_pipeline::filter::{Filter, FilterResult};
|
||||||
|
|
||||||
|
/// Keeps only the highest-scored candidate per branch of a conversation tree
|
||||||
|
pub struct DedupConversationFilter;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Filter<ScoredPostsQuery, PostCandidate> for DedupConversationFilter {
|
||||||
|
async fn filter(
|
||||||
|
&self,
|
||||||
|
_query: &ScoredPostsQuery,
|
||||||
|
candidates: Vec<PostCandidate>,
|
||||||
|
) -> Result<FilterResult<PostCandidate>, String> {
|
||||||
|
let mut kept: Vec<PostCandidate> = Vec::new();
|
||||||
|
let mut removed: Vec<PostCandidate> = Vec::new();
|
||||||
|
let mut best_per_convo: HashMap<u64, (usize, f64)> = HashMap::new();
|
||||||
|
|
||||||
|
for candidate in candidates {
|
||||||
|
let conversation_id = get_conversation_id(&candidate);
|
||||||
|
let score = candidate.score.unwrap_or(0.0);
|
||||||
|
|
||||||
|
if let Some((kept_idx, best_score)) = best_per_convo.get_mut(&conversation_id) {
|
||||||
|
if score > *best_score {
|
||||||
|
let previous = std::mem::replace(&mut kept[*kept_idx], candidate);
|
||||||
|
removed.push(previous);
|
||||||
|
*best_score = score;
|
||||||
|
} else {
|
||||||
|
removed.push(candidate);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let idx = kept.len();
|
||||||
|
best_per_convo.insert(conversation_id, (idx, score));
|
||||||
|
kept.push(candidate);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(FilterResult { kept, removed })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_conversation_id(candidate: &PostCandidate) -> u64 {
|
||||||
|
candidate
|
||||||
|
.ancestors
|
||||||
|
.iter()
|
||||||
|
.copied()
|
||||||
|
.min()
|
||||||
|
.unwrap_or(candidate.tweet_id as u64)
|
||||||
|
}
|
||||||
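Since post IDs are time-ordered snowflakes, the smallest ancestor ID is the conversation root, so every reply in the same tree collapses onto one key: for example, a candidate with ancestors [42, 57] and another with ancestors [42] both key on 42 (a post with no ancestors keys on its own tweet_id), and only the highest-scored candidate per key is kept.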
30
home-mixer/filters/drop_duplicates_filter.rs
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
use crate::candidate_pipeline::candidate::PostCandidate;
|
||||||
|
use crate::candidate_pipeline::query::ScoredPostsQuery;
|
||||||
|
use std::collections::HashSet;
|
||||||
|
use tonic::async_trait;
|
||||||
|
use xai_candidate_pipeline::filter::{Filter, FilterResult};
|
||||||
|
|
||||||
|
pub struct DropDuplicatesFilter;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Filter<ScoredPostsQuery, PostCandidate> for DropDuplicatesFilter {
|
||||||
|
async fn filter(
|
||||||
|
&self,
|
||||||
|
_query: &ScoredPostsQuery,
|
||||||
|
candidates: Vec<PostCandidate>,
|
||||||
|
) -> Result<FilterResult<PostCandidate>, String> {
|
||||||
|
let mut seen_ids = HashSet::new();
|
||||||
|
let mut kept = Vec::new();
|
||||||
|
let mut removed = Vec::new();
|
||||||
|
|
||||||
|
for candidate in candidates {
|
||||||
|
if seen_ids.insert(candidate.tweet_id) {
|
||||||
|
kept.push(candidate);
|
||||||
|
} else {
|
||||||
|
removed.push(candidate);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(FilterResult { kept, removed })
|
||||||
|
}
|
||||||
|
}
|
||||||
34
home-mixer/filters/ineligible_subscription_filter.rs
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
use crate::candidate_pipeline::candidate::PostCandidate;
|
||||||
|
use crate::candidate_pipeline::query::ScoredPostsQuery;
|
||||||
|
use std::collections::HashSet;
|
||||||
|
use tonic::async_trait;
|
||||||
|
use xai_candidate_pipeline::filter::{Filter, FilterResult};
|
||||||
|
|
||||||
|
/// Filters out subscription-only posts from authors the viewer is not subscribed to.
|
||||||
|
pub struct IneligibleSubscriptionFilter;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Filter<ScoredPostsQuery, PostCandidate> for IneligibleSubscriptionFilter {
|
||||||
|
async fn filter(
|
||||||
|
&self,
|
||||||
|
query: &ScoredPostsQuery,
|
||||||
|
candidates: Vec<PostCandidate>,
|
||||||
|
) -> Result<FilterResult<PostCandidate>, String> {
|
||||||
|
let subscribed_user_ids: HashSet<u64> = query
|
||||||
|
.user_features
|
||||||
|
.subscribed_user_ids
|
||||||
|
.iter()
|
||||||
|
.map(|id| *id as u64)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let (kept, removed): (Vec<_>, Vec<_>) =
|
||||||
|
candidates
|
||||||
|
.into_iter()
|
||||||
|
.partition(|candidate| match candidate.subscription_author_id {
|
||||||
|
Some(author_id) => subscribed_user_ids.contains(&author_id),
|
||||||
|
None => true,
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(FilterResult { kept, removed })
|
||||||
|
}
|
||||||
|
}
|
||||||
13
home-mixer/filters/mod.rs
Normal file
@ -0,0 +1,13 @@
pub mod age_filter;
pub mod author_socialgraph_filter;
pub mod core_data_hydration_filter;
pub mod dedup_conversation_filter;
pub mod drop_duplicates_filter;

pub mod ineligible_subscription_filter;
pub mod muted_keyword_filter;
pub mod previously_seen_posts_filter;
pub mod previously_served_posts_filter;
pub mod retweet_deduplication_filter;
pub mod self_tweet_filter;
pub mod vf_filter;
59
home-mixer/filters/muted_keyword_filter.rs
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
use crate::candidate_pipeline::candidate::PostCandidate;
|
||||||
|
use crate::candidate_pipeline::query::ScoredPostsQuery;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tonic::async_trait;
|
||||||
|
use xai_candidate_pipeline::filter::{Filter, FilterResult};
|
||||||
|
use xai_post_text::{MatchTweetGroup, TokenSequence, TweetTokenizer, UserMutes};
|
||||||
|
|
||||||
|
pub struct MutedKeywordFilter {
|
||||||
|
pub tokenizer: Arc<TweetTokenizer>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MutedKeywordFilter {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
let tokenizer = TweetTokenizer::new();
|
||||||
|
Self {
|
||||||
|
tokenizer: Arc::new(tokenizer),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Filter<ScoredPostsQuery, PostCandidate> for MutedKeywordFilter {
|
||||||
|
#[xai_stats_macro::receive_stats]
|
||||||
|
async fn filter(
|
||||||
|
&self,
|
||||||
|
query: &ScoredPostsQuery,
|
||||||
|
candidates: Vec<PostCandidate>,
|
||||||
|
) -> Result<FilterResult<PostCandidate>, String> {
|
||||||
|
let muted_keywords = query.user_features.muted_keywords.clone();
|
||||||
|
|
||||||
|
if muted_keywords.is_empty() {
|
||||||
|
return Ok(FilterResult {
|
||||||
|
kept: candidates,
|
||||||
|
removed: vec![],
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let tokenized = muted_keywords.iter().map(|k| self.tokenizer.tokenize(k));
|
||||||
|
let token_sequences: Vec<TokenSequence> = tokenized.collect::<Vec<_>>();
|
||||||
|
let user_mutes = UserMutes::new(token_sequences);
|
||||||
|
let matcher = MatchTweetGroup::new(user_mutes);
|
||||||
|
|
||||||
|
let mut kept = Vec::new();
|
||||||
|
let mut removed = Vec::new();
|
||||||
|
|
||||||
|
for candidate in candidates {
|
||||||
|
let tweet_text_token_sequence = self.tokenizer.tokenize(&candidate.tweet_text);
|
||||||
|
if matcher.matches(&tweet_text_token_sequence) {
|
||||||
|
// Matches muted keywords - should be removed/filtered out
|
||||||
|
removed.push(candidate);
|
||||||
|
} else {
|
||||||
|
// Does not match muted keywords - keep it
|
||||||
|
kept.push(candidate);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(FilterResult { kept, removed })
|
||||||
|
}
|
||||||
|
}
|
||||||
36
home-mixer/filters/previously_seen_posts_filter.rs
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
use crate::candidate_pipeline::candidate::PostCandidate;
|
||||||
|
use crate::candidate_pipeline::query::ScoredPostsQuery;
|
||||||
|
use crate::util::bloom_filter::BloomFilter;
|
||||||
|
use crate::util::candidates_util::get_related_post_ids;
|
||||||
|
use tonic::async_trait;
|
||||||
|
use xai_candidate_pipeline::filter::{Filter, FilterResult};
|
||||||
|
|
||||||
|
/// Filter out previously seen posts using a Bloom Filter and
|
||||||
|
/// the seen IDs sent in the request directly from the client
|
||||||
|
pub struct PreviouslySeenPostsFilter;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Filter<ScoredPostsQuery, PostCandidate> for PreviouslySeenPostsFilter {
|
||||||
|
async fn filter(
|
||||||
|
&self,
|
||||||
|
query: &ScoredPostsQuery,
|
||||||
|
candidates: Vec<PostCandidate>,
|
||||||
|
) -> Result<FilterResult<PostCandidate>, String> {
|
||||||
|
let bloom_filters = query
|
||||||
|
.bloom_filter_entries
|
||||||
|
.iter()
|
||||||
|
.map(BloomFilter::from_entry)
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let (removed, kept): (Vec<_>, Vec<_>) = candidates.into_iter().partition(|c| {
|
||||||
|
get_related_post_ids(c).iter().any(|&post_id| {
|
||||||
|
query.seen_ids.contains(&post_id)
|
||||||
|
|| bloom_filters
|
||||||
|
.iter()
|
||||||
|
.any(|filter| filter.may_contain(post_id))
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(FilterResult { kept, removed })
|
||||||
|
}
|
||||||
|
}
|
||||||
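The util::bloom_filter module is excluded from this release; the property the filter above relies on is that may_contain can only err toward true, so a false positive occasionally drops a post the user never saw, while a post recorded in the impression bloom filter is never let back through.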
28
home-mixer/filters/previously_served_posts_filter.rs
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
use crate::candidate_pipeline::candidate::PostCandidate;
|
||||||
|
use crate::candidate_pipeline::query::ScoredPostsQuery;
|
||||||
|
use crate::util::candidates_util::get_related_post_ids;
|
||||||
|
use tonic::async_trait;
|
||||||
|
use xai_candidate_pipeline::filter::{Filter, FilterResult};
|
||||||
|
|
||||||
|
pub struct PreviouslyServedPostsFilter;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Filter<ScoredPostsQuery, PostCandidate> for PreviouslyServedPostsFilter {
|
||||||
|
fn enable(&self, query: &ScoredPostsQuery) -> bool {
|
||||||
|
query.is_bottom_request
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn filter(
|
||||||
|
&self,
|
||||||
|
query: &ScoredPostsQuery,
|
||||||
|
candidates: Vec<PostCandidate>,
|
||||||
|
) -> Result<FilterResult<PostCandidate>, String> {
|
||||||
|
let (removed, kept): (Vec<_>, Vec<_>) = candidates.into_iter().partition(|c| {
|
||||||
|
get_related_post_ids(c)
|
||||||
|
.iter()
|
||||||
|
.any(|id| query.served_ids.contains(id))
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(FilterResult { kept, removed })
|
||||||
|
}
|
||||||
|
}
|
||||||
42
home-mixer/filters/retweet_deduplication_filter.rs
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
use crate::candidate_pipeline::candidate::PostCandidate;
|
||||||
|
use crate::candidate_pipeline::query::ScoredPostsQuery;
|
||||||
|
use std::collections::HashSet;
|
||||||
|
use tonic::async_trait;
|
||||||
|
use xai_candidate_pipeline::filter::{Filter, FilterResult};
|
||||||
|
|
||||||
|
/// Deduplicates retweets, keeping only the first occurrence of a tweet
|
||||||
|
/// (whether as an original or as a retweet).
|
||||||
|
pub struct RetweetDeduplicationFilter;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Filter<ScoredPostsQuery, PostCandidate> for RetweetDeduplicationFilter {
|
||||||
|
async fn filter(
|
||||||
|
&self,
|
||||||
|
_query: &ScoredPostsQuery,
|
||||||
|
candidates: Vec<PostCandidate>,
|
||||||
|
) -> Result<FilterResult<PostCandidate>, String> {
|
||||||
|
let mut seen_tweet_ids: HashSet<u64> = HashSet::new();
|
||||||
|
let mut kept = Vec::new();
|
||||||
|
let mut removed = Vec::new();
|
||||||
|
|
||||||
|
for candidate in candidates {
|
||||||
|
match candidate.retweeted_tweet_id {
|
||||||
|
Some(retweeted_id) => {
|
||||||
|
// Remove if we've already seen this tweet (as original or retweet)
|
||||||
|
if seen_tweet_ids.insert(retweeted_id) {
|
||||||
|
kept.push(candidate);
|
||||||
|
} else {
|
||||||
|
removed.push(candidate);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
// Mark this original tweet ID as seen so retweets of it get filtered
|
||||||
|
seen_tweet_ids.insert(candidate.tweet_id as u64);
|
||||||
|
kept.push(candidate);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(FilterResult { kept, removed })
|
||||||
|
}
|
||||||
|
}
|
||||||
23
home-mixer/filters/self_tweet_filter.rs
Normal file
@ -0,0 +1,23 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use tonic::async_trait;
use xai_candidate_pipeline::filter::{Filter, FilterResult};

/// Filter that removes tweets where the author is the viewer.
pub struct SelfTweetFilter;

#[async_trait]
impl Filter<ScoredPostsQuery, PostCandidate> for SelfTweetFilter {
    async fn filter(
        &self,
        query: &ScoredPostsQuery,
        candidates: Vec<PostCandidate>,
    ) -> Result<FilterResult<PostCandidate>, String> {
        let viewer_id = query.user_id as u64;
        let (kept, removed): (Vec<_>, Vec<_>) = candidates
            .into_iter()
            .partition(|c| c.author_id != viewer_id);

        Ok(FilterResult { kept, removed })
    }
}
33
home-mixer/filters/vf_filter.rs
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
use crate::candidate_pipeline::candidate::PostCandidate;
|
||||||
|
use crate::candidate_pipeline::query::ScoredPostsQuery;
|
||||||
|
use tonic::async_trait;
|
||||||
|
use xai_candidate_pipeline::filter::{Filter, FilterResult};
|
||||||
|
use xai_visibility_filtering::models::{Action, FilteredReason};
|
||||||
|
|
||||||
|
pub struct VFFilter;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Filter<ScoredPostsQuery, PostCandidate> for VFFilter {
|
||||||
|
#[xai_stats_macro::receive_stats]
|
||||||
|
async fn filter(
|
||||||
|
&self,
|
||||||
|
_query: &ScoredPostsQuery,
|
||||||
|
candidates: Vec<PostCandidate>,
|
||||||
|
) -> Result<FilterResult<PostCandidate>, String> {
|
||||||
|
let (removed, kept): (Vec<_>, Vec<_>) = candidates
|
||||||
|
.into_iter()
|
||||||
|
.partition(|c| should_drop(&c.visibility_reason));
|
||||||
|
|
||||||
|
Ok(FilterResult { kept, removed })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn should_drop(reason: &Option<FilteredReason>) -> bool {
|
||||||
|
match reason {
|
||||||
|
Some(FilteredReason::SafetyResult(safety_result)) => {
|
||||||
|
matches!(safety_result.action, Action::Drop(_))
|
||||||
|
}
|
||||||
|
Some(_) => true,
|
||||||
|
None => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
14
home-mixer/lib.rs
Normal file
@ -0,0 +1,14 @@
mod candidate_hydrators;
mod candidate_pipeline;
pub mod clients; // Excluded from open source release for security reasons
mod filters;
pub mod params; // Excluded from open source release for security reasons
mod query_hydrators;
pub mod scorers;
mod selectors;
mod server;
mod side_effects;
mod sources;
pub mod util; // Excluded from open source release for security reasons

pub use server::HomeMixerServer;
78
home-mixer/main.rs
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
use clap::Parser;
|
||||||
|
use log::info;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use tonic::codec::CompressionEncoding;
|
||||||
|
use tonic::service::RoutesBuilder;
|
||||||
|
use tonic_reflection::server::Builder;
|
||||||
|
|
||||||
|
use xai_home_mixer_proto as pb;
|
||||||
|
use xai_http_server::{CancellationToken, GrpcConfig, HttpServer};
|
||||||
|
|
||||||
|
use xai_home_mixer::HomeMixerServer;
|
||||||
|
use xai_home_mixer::params;
|
||||||
|
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(about = "HomeMixer gRPC Server")]
|
||||||
|
struct Args {
|
||||||
|
#[arg(long)]
|
||||||
|
grpc_port: u16,
|
||||||
|
#[arg(long)]
|
||||||
|
metrics_port: u16,
|
||||||
|
#[arg(long)]
|
||||||
|
reload_interval_minutes: u64,
|
||||||
|
#[arg(long)]
|
||||||
|
chunk_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[xai_stats_macro::main(name = "home-mixer")]
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> anyhow::Result<()> {
|
||||||
|
let args = Args::parse();
|
||||||
|
xai_init_utils::init().log();
|
||||||
|
xai_init_utils::init().rustls();
|
||||||
|
info!(
|
||||||
|
"Starting server with gRPC port: {}, metrics port: {}, reload interval: {} minutes, chunk size: {}",
|
||||||
|
args.grpc_port, args.metrics_port, args.reload_interval_minutes, args.chunk_size,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Create the service implementation
|
||||||
|
let service = HomeMixerServer::new().await;
|
||||||
|
// Keep a reference to stats_receiver before service is moved
|
||||||
|
let reflection_service = Builder::configure()
|
||||||
|
.register_encoded_file_descriptor_set(pb::FILE_DESCRIPTOR_SET)
|
||||||
|
.build_v1()?;
|
||||||
|
|
||||||
|
let mut grpc_routes = RoutesBuilder::default();
|
||||||
|
|
||||||
|
grpc_routes.add_service(
|
||||||
|
pb::scored_posts_service_server::ScoredPostsServiceServer::new(service)
|
||||||
|
.max_decoding_message_size(params::MAX_GRPC_MESSAGE_SIZE)
|
||||||
|
.max_encoding_message_size(params::MAX_GRPC_MESSAGE_SIZE)
|
||||||
|
.accept_compressed(CompressionEncoding::Gzip)
|
||||||
|
.accept_compressed(CompressionEncoding::Zstd)
|
||||||
|
.send_compressed(CompressionEncoding::Gzip)
|
||||||
|
.send_compressed(CompressionEncoding::Zstd),
|
||||||
|
);
|
||||||
|
|
||||||
|
grpc_routes.add_service(reflection_service);
|
||||||
|
|
||||||
|
let grpc_config = GrpcConfig::new(args.grpc_port, grpc_routes.routes());
|
||||||
|
|
||||||
|
let http_router = axum::Router::default();
|
||||||
|
|
||||||
|
let mut server = HttpServer::new(
|
||||||
|
args.metrics_port,
|
||||||
|
http_router,
|
||||||
|
Some(grpc_config),
|
||||||
|
CancellationToken::new(),
|
||||||
|
Duration::from_secs(20),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
server.set_readiness(true);
|
||||||
|
info!("Server ready");
|
||||||
|
server.wait_for_termination().await;
|
||||||
|
info!("Server shutdown complete");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
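All four CLI flags are declared without defaults, so clap treats them as required; a local invocation would look something like `home-mixer --grpc-port 50051 --metrics-port 9090 --reload-interval-minutes 60 --chunk-size 1000`, where the binary name and the values are illustrative rather than repository defaults.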
2
home-mixer/query_hydrators/mod.rs
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
pub mod user_action_seq_query_hydrator;
|
||||||
|
pub mod user_features_query_hydrator;
|
||||||
188
home-mixer/query_hydrators/user_action_seq_query_hydrator.rs
Normal file
@ -0,0 +1,188 @@
use crate::candidate_pipeline::query::ScoredPostsQuery;
|
||||||
|
use crate::clients::uas_fetcher::{UserActionSequenceFetcher, UserActionSequenceOps};
|
||||||
|
use crate::params as p;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
|
use tonic::async_trait;
|
||||||
|
use xai_candidate_pipeline::query_hydrator::QueryHydrator;
|
||||||
|
use xai_recsys_aggregation::aggregation::{DefaultAggregator, UserActionAggregator};
|
||||||
|
use xai_recsys_aggregation::filters::{
|
||||||
|
AggregatedActionFilter, DenseAggregatedActionFilter, KeepOriginalUserActionFilter,
|
||||||
|
UserActionFilter,
|
||||||
|
};
|
||||||
|
use xai_recsys_proto::{
|
||||||
|
AggregatedUserActionList, Mask, MaskType, UserActionSequence, UserActionSequenceDataContainer,
|
||||||
|
UserActionSequenceMeta, user_action_sequence_data_container::Data as ProtoDataContainer,
|
||||||
|
};
|
||||||
|
use xai_uas_thrift::convert::thrift_to_proto_aggregated_user_action;
|
||||||
|
use xai_uas_thrift::user_action_sequence::{
|
||||||
|
AggregatedUserAction as ThriftAggregatedUserAction,
|
||||||
|
UserActionSequence as ThriftUserActionSequence,
|
||||||
|
UserActionSequenceMeta as ThriftUserActionSequenceMeta,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Hydrate a sequence that captures the user's recent actions
|
||||||
|
pub struct UserActionSeqQueryHydrator {
|
||||||
|
pub uas_fetcher: Arc<UserActionSequenceFetcher>,
|
||||||
|
global_filter: Arc<dyn UserActionFilter>,
|
||||||
|
aggregator: Arc<dyn UserActionAggregator>,
|
||||||
|
post_filters: Vec<Arc<dyn AggregatedActionFilter>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UserActionSeqQueryHydrator {
|
||||||
|
pub fn new(uas_fetcher: Arc<UserActionSequenceFetcher>) -> Self {
|
||||||
|
Self {
|
||||||
|
uas_fetcher,
|
||||||
|
global_filter: Arc::new(KeepOriginalUserActionFilter::new()),
|
||||||
|
aggregator: Arc::new(DefaultAggregator),
|
||||||
|
post_filters: vec![Arc::new(DenseAggregatedActionFilter::new())],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl QueryHydrator<ScoredPostsQuery> for UserActionSeqQueryHydrator {
|
||||||
|
#[xai_stats_macro::receive_stats]
|
||||||
|
async fn hydrate(&self, query: &ScoredPostsQuery) -> Result<ScoredPostsQuery, String> {
|
||||||
|
let uas_thrift = self
|
||||||
|
.uas_fetcher
|
||||||
|
.get_by_user_id(query.user_id)
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Failed to fetch user action sequence: {}", e))?;
|
||||||
|
|
||||||
|
let aggregated_uas_proto =
|
||||||
|
self.aggregate_user_action_sequence(query.user_id, uas_thrift)?;
|
||||||
|
|
||||||
|
Ok(ScoredPostsQuery {
|
||||||
|
user_action_sequence: Some(aggregated_uas_proto),
|
||||||
|
..Default::default()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update(&self, query: &mut ScoredPostsQuery, hydrated: ScoredPostsQuery) {
|
||||||
|
query.user_action_sequence = hydrated.user_action_sequence;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
std::any::type_name::<Self>()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UserActionSeqQueryHydrator {
|
||||||
|
fn aggregate_user_action_sequence(
|
||||||
|
&self,
|
||||||
|
user_id: i64,
|
||||||
|
uas_thrift: ThriftUserActionSequence,
|
||||||
|
) -> Result<UserActionSequence, String> {
|
||||||
|
// Extract user_actions from thrift sequence
|
||||||
|
let thrift_user_actions = uas_thrift.user_actions.clone().unwrap_or_default();
|
||||||
|
if thrift_user_actions.is_empty() {
|
||||||
|
return Err(format!("No user actions found for user {}", user_id));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pre-aggregation filter
|
||||||
|
let filtered_actions = self.global_filter.run(thrift_user_actions);
|
||||||
|
if filtered_actions.is_empty() {
|
||||||
|
return Err(format!(
|
||||||
|
"No user actions remaining after filtering for user {}",
|
||||||
|
user_id
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Aggregate
|
||||||
|
let mut aggregated_actions =
|
||||||
|
self.aggregator
|
||||||
|
.run(&filtered_actions, p::UAS_WINDOW_TIME_MS, 0);
|
||||||
|
|
||||||
|
// Post-aggregation filters
|
||||||
|
for filter in &self.post_filters {
|
||||||
|
aggregated_actions = filter.run(aggregated_actions);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Truncate to max sequence length (keep last N items)
|
||||||
|
if aggregated_actions.len() > p::UAS_MAX_SEQUENCE_LENGTH {
|
||||||
|
let drain_count = aggregated_actions.len() - p::UAS_MAX_SEQUENCE_LENGTH;
|
||||||
|
aggregated_actions.drain(0..drain_count);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to proto format
|
||||||
|
let original_metadata = uas_thrift.metadata.clone().unwrap_or_default();
|
||||||
|
convert_to_proto_sequence(
|
||||||
|
user_id,
|
||||||
|
original_metadata,
|
||||||
|
aggregated_actions,
|
||||||
|
self.aggregator.name(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn convert_to_proto_sequence(
|
||||||
|
user_id: i64,
|
||||||
|
original_metadata: ThriftUserActionSequenceMeta,
|
||||||
|
aggregated_actions: Vec<ThriftAggregatedUserAction>,
|
||||||
|
aggregator_name: &str,
|
||||||
|
) -> Result<UserActionSequence, String> {
|
||||||
|
if aggregated_actions.is_empty() {
|
||||||
|
return Err("Cannot create sequence from empty aggregated actions".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
let first_sequence_time = aggregated_actions
|
||||||
|
.first()
|
||||||
|
.and_then(|a| a.impressed_time_ms)
|
||||||
|
.unwrap_or(0) as u64;
|
||||||
|
let last_sequence_time = aggregated_actions
|
||||||
|
.last()
|
||||||
|
.and_then(|a| a.impressed_time_ms)
|
||||||
|
.unwrap_or(0) as u64;
|
||||||
|
|
||||||
|
// Preserve lastModifiedEpochMs and lastKafkaPublishEpochMs from original metadata
|
||||||
|
let last_modified_epoch_ms = original_metadata.last_modified_epoch_ms.unwrap_or(0) as u64;
|
||||||
|
let previous_kafka_publish_epoch_ms =
|
||||||
|
original_metadata.last_kafka_publish_epoch_ms.unwrap_or(0) as u64;
|
||||||
|
|
||||||
|
let proto_metadata = UserActionSequenceMeta {
|
||||||
|
length: aggregated_actions.len() as u64,
|
||||||
|
first_sequence_time,
|
||||||
|
last_sequence_time,
|
||||||
|
last_modified_epoch_ms,
|
||||||
|
previous_kafka_publish_epoch_ms,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Convert thrift aggregated actions to proto
|
||||||
|
let mut proto_agg_actions = Vec::with_capacity(aggregated_actions.len());
|
||||||
|
for action in aggregated_actions {
|
||||||
|
proto_agg_actions.push(
|
||||||
|
thrift_to_proto_aggregated_user_action(action)
|
||||||
|
.map_err(|e| format!("Failed to convert aggregated action: {}", e))?,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let aggregation_time_ms = SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.unwrap()
|
||||||
|
.as_millis() as u64;
|
||||||
|
|
||||||
|
let agg_list = AggregatedUserActionList {
|
||||||
|
aggregated_user_actions: proto_agg_actions,
|
||||||
|
aggregation_provider: aggregator_name.to_string(),
|
||||||
|
aggregation_time_ms,
|
||||||
|
};
|
||||||
|
|
||||||
|
let mask = Mask {
|
||||||
|
mask_type: MaskType::NewEvent as i32,
|
||||||
|
mask: vec![false; agg_list.aggregated_user_actions.len()],
|
||||||
|
};
|
||||||
|
|
||||||
|
// Build the final UserActionSequence
|
||||||
|
Ok(UserActionSequence {
|
||||||
|
user_id: user_id as u64,
|
||||||
|
metadata: Some(proto_metadata),
|
||||||
|
user_actions_data: Some(UserActionSequenceDataContainer {
|
||||||
|
data: Some(ProtoDataContainer::OrderedAggregatedUserActionsList(
|
||||||
|
agg_list,
|
||||||
|
)),
|
||||||
|
}),
|
||||||
|
masks: vec![mask],
|
||||||
|
..Default::default()
|
||||||
|
})
|
||||||
|
}
|
||||||
41
home-mixer/query_hydrators/user_features_query_hydrator.rs
Normal file
@ -0,0 +1,41 @@
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::candidate_pipeline::query_features::UserFeatures;
use crate::clients::strato_client::StratoClient;
use std::sync::Arc;
use tonic::async_trait;
use xai_candidate_pipeline::query_hydrator::QueryHydrator;
use xai_strato::{StratoResult, StratoValue, decode};

pub struct UserFeaturesQueryHydrator {
    pub strato_client: Arc<dyn StratoClient + Send + Sync>,
}

#[async_trait]
impl QueryHydrator<ScoredPostsQuery> for UserFeaturesQueryHydrator {
    #[xai_stats_macro::receive_stats]
    async fn hydrate(&self, query: &ScoredPostsQuery) -> Result<ScoredPostsQuery, String> {
        let user_id = query.user_id;
        let client = &self.strato_client;
        let result = client.get_user_features(user_id);
        let result = result.await.map_err(|e| e.to_string())?;
        let decoded: StratoResult<StratoValue<UserFeatures>> = decode(&result);
        match decoded {
            StratoResult::Ok(v) => {
                let user_features = v.v.unwrap_or_default();
                Ok(ScoredPostsQuery {
                    user_features,
                    ..Default::default()
                })
            }
            StratoResult::Err(_) => Err("Error received from strato".to_string()),
        }
    }

    fn update(&self, query: &mut ScoredPostsQuery, hydrated: ScoredPostsQuery) {
        query.user_features = hydrated.user_features;
    }

    fn name(&self) -> &'static str {
        std::any::type_name::<Self>()
    }
}
73
home-mixer/scorers/author_diversity_scorer.rs
Normal file
@ -0,0 +1,73 @@
use crate::candidate_pipeline::candidate::PostCandidate;
|
||||||
|
use crate::candidate_pipeline::query::ScoredPostsQuery;
|
||||||
|
use crate::params as p;
|
||||||
|
use std::cmp::Ordering;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use tonic::async_trait;
|
||||||
|
use xai_candidate_pipeline::scorer::Scorer;
|
||||||
|
|
||||||
|
/// Diversify authors served within a single feed response
|
||||||
|
pub struct AuthorDiversityScorer {
|
||||||
|
decay_factor: f64,
|
||||||
|
floor: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for AuthorDiversityScorer {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new(p::AUTHOR_DIVERSITY_DECAY, p::AUTHOR_DIVERSITY_FLOOR)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AuthorDiversityScorer {
|
||||||
|
pub fn new(decay_factor: f64, floor: f64) -> Self {
|
||||||
|
Self {
|
||||||
|
decay_factor,
|
||||||
|
floor,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn multiplier(&self, position: usize) -> f64 {
|
||||||
|
(1.0 - self.floor) * self.decay_factor.powf(position as f64) + self.floor
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Scorer<ScoredPostsQuery, PostCandidate> for AuthorDiversityScorer {
|
||||||
|
#[xai_stats_macro::receive_stats]
|
||||||
|
async fn score(
|
||||||
|
&self,
|
||||||
|
_query: &ScoredPostsQuery,
|
||||||
|
candidates: &[PostCandidate],
|
||||||
|
) -> Result<Vec<PostCandidate>, String> {
|
||||||
|
let mut author_counts: HashMap<u64, usize> = HashMap::new();
|
||||||
|
let mut scored = vec![PostCandidate::default(); candidates.len()];
|
||||||
|
|
||||||
|
let mut ordered: Vec<(usize, &PostCandidate)> = candidates.iter().enumerate().collect();
|
||||||
|
ordered.sort_by(|(_, a), (_, b)| {
|
||||||
|
let a_score = a.weighted_score.unwrap_or(f64::NEG_INFINITY);
|
||||||
|
let b_score = b.weighted_score.unwrap_or(f64::NEG_INFINITY);
|
||||||
|
b_score.partial_cmp(&a_score).unwrap_or(Ordering::Equal)
|
||||||
|
});
|
||||||
|
|
||||||
|
for (original_idx, candidate) in ordered {
|
||||||
|
let entry = author_counts.entry(candidate.author_id).or_insert(0);
|
||||||
|
let position = *entry;
|
||||||
|
*entry += 1;
|
||||||
|
|
||||||
|
let multiplier = self.multiplier(position);
|
||||||
|
let adjusted_score = candidate.weighted_score.map(|score| score * multiplier);
|
||||||
|
|
||||||
|
let updated = PostCandidate {
|
||||||
|
score: adjusted_score,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
scored[original_idx] = updated;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(scored)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update(&self, candidate: &mut PostCandidate, scored: PostCandidate) {
|
||||||
|
candidate.score = scored.score;
|
||||||
|
}
|
||||||
|
}
|
||||||
4
home-mixer/scorers/mod.rs
Normal file
@ -0,0 +1,4 @@
pub mod author_diversity_scorer;
pub mod oon_scorer;
pub mod phoenix_scorer;
pub mod weighted_scorer;
38
home-mixer/scorers/oon_scorer.rs
Normal file
@ -0,0 +1,38 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::params as p;
use tonic::async_trait;
use xai_candidate_pipeline::scorer::Scorer;

// Prioritize in-network candidates over out-of-network candidates
pub struct OONScorer;

#[async_trait]
impl Scorer<ScoredPostsQuery, PostCandidate> for OONScorer {
    async fn score(
        &self,
        _query: &ScoredPostsQuery,
        candidates: &[PostCandidate],
    ) -> Result<Vec<PostCandidate>, String> {
        let scored = candidates
            .iter()
            .map(|c| {
                let updated_score = c.score.map(|base_score| match c.in_network {
                    Some(false) => base_score * p::OON_WEIGHT_FACTOR,
                    _ => base_score,
                });

                PostCandidate {
                    score: updated_score,
                    ..Default::default()
                }
            })
            .collect();

        Ok(scored)
    }

    fn update(&self, candidate: &mut PostCandidate, scored: PostCandidate) {
        candidate.score = scored.score;
    }
}
176
home-mixer/scorers/phoenix_scorer.rs
Normal file
@ -0,0 +1,176 @@
use crate::candidate_pipeline::candidate::{PhoenixScores, PostCandidate};
|
||||||
|
use crate::candidate_pipeline::query::ScoredPostsQuery;
|
||||||
|
use crate::clients::phoenix_prediction_client::PhoenixPredictionClient;
|
||||||
|
use crate::util::request_util;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
|
use tonic::async_trait;
|
||||||
|
use xai_candidate_pipeline::scorer::Scorer;
|
||||||
|
use xai_recsys_proto::{ActionName, ContinuousActionName};
|
||||||
|
|
||||||
|
pub struct PhoenixScorer {
|
||||||
|
pub phoenix_client: Arc<dyn PhoenixPredictionClient + Send + Sync>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Scorer<ScoredPostsQuery, PostCandidate> for PhoenixScorer {
|
||||||
|
#[xai_stats_macro::receive_stats]
|
||||||
|
async fn score(
|
||||||
|
&self,
|
||||||
|
query: &ScoredPostsQuery,
|
||||||
|
candidates: &[PostCandidate],
|
||||||
|
) -> Result<Vec<PostCandidate>, String> {
|
||||||
|
let user_id = query.user_id as u64;
|
||||||
|
let prediction_request_id = request_util::generate_request_id();
|
||||||
|
let last_scored_at_ms = Self::current_timestamp_millis();
|
||||||
|
|
||||||
|
if let Some(sequence) = &query.user_action_sequence {
|
||||||
|
let tweet_infos: Vec<xai_recsys_proto::TweetInfo> = candidates
|
||||||
|
.iter()
|
||||||
|
.map(|c| {
|
||||||
|
let tweet_id = c.retweeted_tweet_id.unwrap_or(c.tweet_id as u64);
|
||||||
|
let author_id = c.retweeted_user_id.unwrap_or(c.author_id);
|
||||||
|
xai_recsys_proto::TweetInfo {
|
||||||
|
tweet_id,
|
||||||
|
author_id,
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let result = self
|
||||||
|
.phoenix_client
|
||||||
|
.predict(user_id, sequence.clone(), tweet_infos)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
if let Ok(response) = result {
|
||||||
|
let predictions_map = self.build_predictions_map(&response);
|
||||||
|
|
||||||
|
let scored_candidates = candidates
|
||||||
|
.iter()
|
||||||
|
.map(|c| {
|
||||||
|
// For retweets, look up predictions using the original tweet id
|
||||||
|
let lookup_tweet_id = c.retweeted_tweet_id.unwrap_or(c.tweet_id as u64);
|
||||||
|
|
||||||
|
let phoenix_scores = predictions_map
|
||||||
|
.get(&lookup_tweet_id)
|
||||||
|
.map(|preds| self.extract_phoenix_scores(preds))
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
PostCandidate {
|
||||||
|
phoenix_scores,
|
||||||
|
prediction_request_id: Some(prediction_request_id),
|
||||||
|
last_scored_at_ms,
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
return Ok(scored_candidates);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return candidates unchanged if no scoring could be done
|
||||||
|
Ok(candidates.to_vec())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update(&self, candidate: &mut PostCandidate, scored: PostCandidate) {
|
||||||
|
candidate.phoenix_scores = scored.phoenix_scores;
|
||||||
|
candidate.prediction_request_id = scored.prediction_request_id;
|
||||||
|
candidate.last_scored_at_ms = scored.last_scored_at_ms;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PhoenixScorer {
|
||||||
|
/// Builds Map[tweet_id -> ActionPredictions]
|
||||||
|
fn build_predictions_map(
|
||||||
|
&self,
|
||||||
|
response: &xai_recsys_proto::PredictNextActionsResponse,
|
||||||
|
) -> HashMap<u64, ActionPredictions> {
|
||||||
|
let mut predictions_map = HashMap::new();
|
||||||
|
|
||||||
|
let Some(distribution_set) = response.distribution_sets.first() else {
|
||||||
|
return predictions_map;
|
||||||
|
};
|
||||||
|
|
||||||
|
for distribution in &distribution_set.candidate_distributions {
|
||||||
|
let Some(candidate) = &distribution.candidate else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
let tweet_id = candidate.tweet_id;
|
||||||
|
|
||||||
|
let action_probs: HashMap<usize, f64> = distribution
|
||||||
|
.top_log_probs
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(idx, log_prob)| (idx, (*log_prob as f64).exp()))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let continuous_values: HashMap<usize, f64> = distribution
|
||||||
|
.continuous_actions_values
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(idx, value)| (idx, *value as f64))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
predictions_map.insert(
|
||||||
|
tweet_id,
|
||||||
|
ActionPredictions {
|
||||||
|
action_probs,
|
||||||
|
continuous_values,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
predictions_map
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_phoenix_scores(&self, p: &ActionPredictions) -> PhoenixScores {
|
||||||
|
PhoenixScores {
|
||||||
|
favorite_score: p.get(ActionName::ServerTweetFav),
|
||||||
|
reply_score: p.get(ActionName::ServerTweetReply),
|
||||||
|
retweet_score: p.get(ActionName::ServerTweetRetweet),
|
||||||
|
photo_expand_score: p.get(ActionName::ClientTweetPhotoExpand),
|
||||||
|
click_score: p.get(ActionName::ClientTweetClick),
|
||||||
|
profile_click_score: p.get(ActionName::ClientTweetClickProfile),
|
||||||
|
vqv_score: p.get(ActionName::ClientTweetVideoQualityView),
|
||||||
|
share_score: p.get(ActionName::ClientTweetShare),
|
||||||
|
share_via_dm_score: p.get(ActionName::ClientTweetClickSendViaDirectMessage),
|
||||||
|
share_via_copy_link_score: p.get(ActionName::ClientTweetShareViaCopyLink),
|
||||||
|
dwell_score: p.get(ActionName::ClientTweetRecapDwelled),
|
||||||
|
quote_score: p.get(ActionName::ServerTweetQuote),
|
||||||
|
quoted_click_score: p.get(ActionName::ClientQuotedTweetClick),
|
||||||
|
follow_author_score: p.get(ActionName::ClientTweetFollowAuthor),
|
||||||
|
not_interested_score: p.get(ActionName::ClientTweetNotInterestedIn),
|
||||||
|
block_author_score: p.get(ActionName::ClientTweetBlockAuthor),
|
||||||
|
mute_author_score: p.get(ActionName::ClientTweetMuteAuthor),
|
||||||
|
report_score: p.get(ActionName::ClientTweetReport),
|
||||||
|
dwell_time: p.get_continuous(ContinuousActionName::DwellTime),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn current_timestamp_millis() -> Option<u64> {
|
||||||
|
SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.ok()
|
||||||
|
.map(|duration| duration.as_millis() as u64)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ActionPredictions {
|
||||||
|
/// Map of action index -> probability (exp of log prob)
|
||||||
|
action_probs: HashMap<usize, f64>,
|
||||||
|
/// Map of continuous action index -> value
|
||||||
|
continuous_values: HashMap<usize, f64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ActionPredictions {
|
||||||
|
fn get(&self, action: ActionName) -> Option<f64> {
|
||||||
|
self.action_probs.get(&(action as usize)).copied()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_continuous(&self, action: ContinuousActionName) -> Option<f64> {
|
||||||
|
self.continuous_values.get(&(action as usize)).copied()
|
||||||
|
}
|
||||||
|
}
|
||||||
92
home-mixer/scorers/weighted_scorer.rs
Normal file
@ -0,0 +1,92 @@
use crate::candidate_pipeline::candidate::{PhoenixScores, PostCandidate};
|
||||||
|
use crate::candidate_pipeline::query::ScoredPostsQuery;
|
||||||
|
use crate::params as p;
|
||||||
|
use crate::util::score_normalizer::normalize_score;
|
||||||
|
use tonic::async_trait;
|
||||||
|
use xai_candidate_pipeline::scorer::Scorer;
|
||||||
|
|
||||||
|
pub struct WeightedScorer;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Scorer<ScoredPostsQuery, PostCandidate> for WeightedScorer {
|
||||||
|
#[xai_stats_macro::receive_stats]
|
||||||
|
async fn score(
|
||||||
|
&self,
|
||||||
|
_query: &ScoredPostsQuery,
|
||||||
|
candidates: &[PostCandidate],
|
||||||
|
) -> Result<Vec<PostCandidate>, String> {
|
||||||
|
let scored = candidates
|
||||||
|
.iter()
|
||||||
|
.map(|c| {
|
||||||
|
let weighted_score = Self::compute_weighted_score(c);
|
||||||
|
let normalized_weighted_score = normalize_score(c, weighted_score);
|
||||||
|
|
||||||
|
PostCandidate {
|
||||||
|
weighted_score: Some(normalized_weighted_score),
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(scored)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update(&self, candidate: &mut PostCandidate, scored: PostCandidate) {
|
||||||
|
candidate.weighted_score = scored.weighted_score;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WeightedScorer {
|
||||||
|
fn apply(score: Option<f64>, weight: f64) -> f64 {
|
||||||
|
score.unwrap_or(0.0) * weight
|
||||||
|
}
|
||||||
|
|
||||||
|
fn compute_weighted_score(candidate: &PostCandidate) -> f64 {
|
||||||
|
let s: &PhoenixScores = &candidate.phoenix_scores;
|
||||||
|
|
||||||
|
let vqv_weight = Self::vqv_weight_eligibility(candidate);
|
||||||
|
|
||||||
|
let combined_score = Self::apply(s.favorite_score, p::FAVORITE_WEIGHT)
|
||||||
|
+ Self::apply(s.reply_score, p::REPLY_WEIGHT)
|
||||||
|
+ Self::apply(s.retweet_score, p::RETWEET_WEIGHT)
|
||||||
|
+ Self::apply(s.photo_expand_score, p::PHOTO_EXPAND_WEIGHT)
|
||||||
|
+ Self::apply(s.click_score, p::CLICK_WEIGHT)
|
||||||
|
+ Self::apply(s.profile_click_score, p::PROFILE_CLICK_WEIGHT)
|
||||||
|
+ Self::apply(s.vqv_score, vqv_weight)
|
||||||
|
+ Self::apply(s.share_score, p::SHARE_WEIGHT)
|
||||||
|
+ Self::apply(s.share_via_dm_score, p::SHARE_VIA_DM_WEIGHT)
|
||||||
|
+ Self::apply(s.share_via_copy_link_score, p::SHARE_VIA_COPY_LINK_WEIGHT)
|
||||||
|
+ Self::apply(s.dwell_score, p::DWELL_WEIGHT)
|
||||||
|
+ Self::apply(s.quote_score, p::QUOTE_WEIGHT)
|
||||||
|
+ Self::apply(s.quoted_click_score, p::QUOTED_CLICK_WEIGHT)
|
||||||
|
+ Self::apply(s.dwell_time, p::CONT_DWELL_TIME_WEIGHT)
|
||||||
|
+ Self::apply(s.follow_author_score, p::FOLLOW_AUTHOR_WEIGHT)
|
||||||
|
+ Self::apply(s.not_interested_score, p::NOT_INTERESTED_WEIGHT)
|
||||||
|
+ Self::apply(s.block_author_score, p::BLOCK_AUTHOR_WEIGHT)
|
||||||
|
+ Self::apply(s.mute_author_score, p::MUTE_AUTHOR_WEIGHT)
|
||||||
|
+ Self::apply(s.report_score, p::REPORT_WEIGHT);
|
||||||
|
|
||||||
|
Self::offset_score(combined_score)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn vqv_weight_eligibility(candidate: &PostCandidate) -> f64 {
|
||||||
|
if candidate
|
||||||
|
.video_duration_ms
|
||||||
|
.is_some_and(|ms| ms > p::MIN_VIDEO_DURATION_MS)
|
||||||
|
{
|
||||||
|
p::VQV_WEIGHT
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn offset_score(combined_score: f64) -> f64 {
|
||||||
|
if p::WEIGHTS_SUM == 0.0 {
|
||||||
|
combined_score.max(0.0)
|
||||||
|
} else if combined_score < 0.0 {
|
||||||
|
(combined_score + p::NEGATIVE_WEIGHTS_SUM) / p::WEIGHTS_SUM * p::NEGATIVE_SCORES_OFFSET
|
||||||
|
} else {
|
||||||
|
combined_score + p::NEGATIVE_SCORES_OFFSET
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
3
home-mixer/selectors/mod.rs
Normal file
@ -0,0 +1,3 @@
mod top_k_score_selector;

pub use top_k_score_selector::TopKScoreSelector;
15
home-mixer/selectors/top_k_score_selector.rs
Normal file
@ -0,0 +1,15 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::params;
use xai_candidate_pipeline::selector::Selector;

pub struct TopKScoreSelector;

impl Selector<ScoredPostsQuery, PostCandidate> for TopKScoreSelector {
    fn score(&self, candidate: &PostCandidate) -> f64 {
        candidate.score.unwrap_or(f64::NEG_INFINITY)
    }
    fn size(&self) -> Option<usize> {
        Some(params::TOP_K_CANDIDATES_TO_SELECT)
    }
}
83
home-mixer/server.rs
Normal file
@ -0,0 +1,83 @@
use crate::candidate_pipeline::candidate::CandidateHelpers;
|
||||||
|
use crate::candidate_pipeline::phoenix_candidate_pipeline::PhoenixCandidatePipeline;
|
||||||
|
use crate::candidate_pipeline::query::ScoredPostsQuery;
|
||||||
|
use log::info;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Instant;
|
||||||
|
use tonic::{Request, Response, Status};
|
||||||
|
use xai_candidate_pipeline::candidate_pipeline::CandidatePipeline;
|
||||||
|
use xai_home_mixer_proto as pb;
|
||||||
|
use xai_home_mixer_proto::{ScoredPost, ScoredPostsResponse};
|
||||||
|
|
||||||
|
pub struct HomeMixerServer {
|
||||||
|
phx_candidate_pipeline: Arc<PhoenixCandidatePipeline>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HomeMixerServer {
|
||||||
|
pub async fn new() -> Self {
|
||||||
|
HomeMixerServer {
|
||||||
|
phx_candidate_pipeline: Arc::new(PhoenixCandidatePipeline::prod().await),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tonic::async_trait]
|
||||||
|
impl pb::scored_posts_service_server::ScoredPostsService for HomeMixerServer {
|
||||||
|
#[xai_stats_macro::receive_stats]
|
||||||
|
async fn get_scored_posts(
|
||||||
|
&self,
|
||||||
|
request: Request<pb::ScoredPostsQuery>,
|
||||||
|
) -> Result<Response<ScoredPostsResponse>, Status> {
|
||||||
|
let proto_query = request.into_inner();
|
||||||
|
|
||||||
|
if proto_query.viewer_id == 0 {
|
||||||
|
return Err(Status::invalid_argument("viewer_id must be specified"));
|
||||||
|
}
|
||||||
|
|
||||||
|
let start = Instant::now();
|
||||||
|
let query = ScoredPostsQuery::new(
|
||||||
|
proto_query.viewer_id,
|
||||||
|
proto_query.client_app_id,
|
||||||
|
proto_query.country_code,
|
||||||
|
proto_query.language_code,
|
||||||
|
proto_query.seen_ids,
|
||||||
|
proto_query.served_ids,
|
||||||
|
proto_query.in_network_only,
|
||||||
|
proto_query.is_bottom_request,
|
||||||
|
proto_query.bloom_filter_entries,
|
||||||
|
);
|
||||||
|
info!("Scored Posts request - request_id {}", query.request_id);
|
||||||
|
let pipeline_result = self.phx_candidate_pipeline.execute(query).await;
|
||||||
|
|
||||||
|
let scored_posts: Vec<ScoredPost> = pipeline_result
|
||||||
|
.selected_candidates
|
||||||
|
.into_iter()
|
||||||
|
.map(|candidate| {
|
||||||
|
let screen_names = candidate.get_screen_names();
|
||||||
|
ScoredPost {
|
||||||
|
tweet_id: candidate.tweet_id as u64,
|
||||||
|
author_id: candidate.author_id,
|
||||||
|
retweeted_tweet_id: candidate.retweeted_tweet_id.unwrap_or(0),
|
||||||
|
retweeted_user_id: candidate.retweeted_user_id.unwrap_or(0),
|
||||||
|
in_reply_to_tweet_id: candidate.in_reply_to_tweet_id.unwrap_or(0),
|
||||||
|
score: candidate.score.unwrap_or(0.0) as f32,
|
||||||
|
in_network: candidate.in_network.unwrap_or(false),
|
||||||
|
served_type: candidate.served_type.map(|t| t as i32).unwrap_or_default(),
|
||||||
|
last_scored_timestamp_ms: candidate.last_scored_at_ms.unwrap_or(0),
|
||||||
|
prediction_request_id: candidate.prediction_request_id.unwrap_or(0),
|
||||||
|
ancestors: candidate.ancestors,
|
||||||
|
screen_names,
|
||||||
|
visibility_reason: candidate.visibility_reason.map(|r| r.into()),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
info!(
|
||||||
|
"Scored Posts response - request_id {} - {} posts ({} ms)",
|
||||||
|
pipeline_result.query.request_id,
|
||||||
|
scored_posts.len(),
|
||||||
|
start.elapsed().as_millis()
|
||||||
|
);
|
||||||
|
Ok(Response::new(ScoredPostsResponse { scored_posts }))
|
||||||
|
}
|
||||||
|
}
|
||||||
42
home-mixer/side_effects/cache_request_info_side_effect.rs
Normal file
@ -0,0 +1,42 @@
use crate::candidate_pipeline::candidate::PostCandidate;
|
||||||
|
use crate::candidate_pipeline::query::ScoredPostsQuery;
|
||||||
|
use crate::clients::strato_client::StratoClient;
|
||||||
|
use std::env;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tonic::async_trait;
|
||||||
|
use xai_candidate_pipeline::side_effect::{SideEffect, SideEffectInput};
|
||||||
|
use xai_strato::{StratoResult, StratoValue, decode};
|
||||||
|
|
||||||
|
pub struct CacheRequestInfoSideEffect {
|
||||||
|
pub strato_client: Arc<dyn StratoClient + Send + Sync>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl SideEffect<ScoredPostsQuery, PostCandidate> for CacheRequestInfoSideEffect {
|
||||||
|
fn enable(&self, query: Arc<ScoredPostsQuery>) -> bool {
|
||||||
|
env::var("APP_ENV").unwrap_or_default() == "prod" && !query.in_network_only
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn run(
|
||||||
|
&self,
|
||||||
|
input: Arc<SideEffectInput<ScoredPostsQuery, PostCandidate>>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
let user_id: i64 = input.query.user_id;
|
||||||
|
|
||||||
|
let post_ids: Vec<i64> = input
|
||||||
|
.selected_candidates
|
||||||
|
.iter()
|
||||||
|
.map(|c| c.tweet_id)
|
||||||
|
.collect();
|
||||||
|
let client = &self.strato_client;
|
||||||
|
let res = client
|
||||||
|
.store_request_info(user_id, post_ids)
|
||||||
|
.await
|
||||||
|
.map_err(|e| e.to_string())?;
|
||||||
|
let decoded: StratoResult<StratoValue<()>> = decode(&res);
|
||||||
|
match decoded {
|
||||||
|
StratoResult::Ok(_) => Ok(()),
|
||||||
|
StratoResult::Err(_) => Err("error received from strato".to_string()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
1
home-mixer/side_effects/mod.rs
Normal file
@ -0,0 +1 @@
pub mod cache_request_info_side_effect;
2
home-mixer/sources/mod.rs
Normal file
@ -0,0 +1,2 @@
pub mod phoenix_source;
pub mod thunder_source;
51
home-mixer/sources/phoenix_source.rs
Normal file
@ -0,0 +1,51 @@
use crate::candidate_pipeline::candidate::PostCandidate;
|
||||||
|
use crate::candidate_pipeline::query::ScoredPostsQuery;
|
||||||
|
use crate::clients::phoenix_retrieval_client::PhoenixRetrievalClient;
|
||||||
|
use crate::params as p;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tonic::async_trait;
|
||||||
|
use xai_candidate_pipeline::source::Source;
|
||||||
|
use xai_home_mixer_proto as pb;
|
||||||
|
|
||||||
|
pub struct PhoenixSource {
|
||||||
|
pub phoenix_retrieval_client: Arc<dyn PhoenixRetrievalClient + Send + Sync>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Source<ScoredPostsQuery, PostCandidate> for PhoenixSource {
|
||||||
|
fn enable(&self, query: &ScoredPostsQuery) -> bool {
|
||||||
|
!query.in_network_only
|
||||||
|
}
|
||||||
|
|
||||||
|
#[xai_stats_macro::receive_stats]
|
||||||
|
async fn get_candidates(&self, query: &ScoredPostsQuery) -> Result<Vec<PostCandidate>, String> {
|
||||||
|
let user_id = query.user_id as u64;
|
||||||
|
|
||||||
|
let sequence = query
|
||||||
|
.user_action_sequence
|
||||||
|
.as_ref()
|
||||||
|
.ok_or_else(|| "PhoenixSource: missing user_action_sequence".to_string())?;
|
||||||
|
|
||||||
|
let response = self
|
||||||
|
.phoenix_retrieval_client
|
||||||
|
.retrieve(user_id, sequence.clone(), p::PHOENIX_MAX_RESULTS)
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("PhoenixSource: {}", e))?;
|
||||||
|
|
||||||
|
let candidates: Vec<PostCandidate> = response
|
||||||
|
.top_k_candidates
|
||||||
|
.into_iter()
|
||||||
|
.flat_map(|scored_candidates| scored_candidates.candidates)
|
||||||
|
.filter_map(|scored_candidate| scored_candidate.candidate)
|
||||||
|
.map(|tweet_info| PostCandidate {
|
||||||
|
tweet_id: tweet_info.tweet_id as i64,
|
||||||
|
author_id: tweet_info.author_id,
|
||||||
|
in_reply_to_tweet_id: Some(tweet_info.in_reply_to_tweet_id),
|
||||||
|
served_type: Some(pb::ServedType::ForYouPhoenixRetrieval),
|
||||||
|
..Default::default()
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(candidates)
|
||||||
|
}
|
||||||
|
}
|
||||||
74
home-mixer/sources/thunder_source.rs
Normal file
@ -0,0 +1,74 @@
use crate::candidate_pipeline::candidate::PostCandidate;
|
||||||
|
use crate::candidate_pipeline::query::ScoredPostsQuery;
|
||||||
|
use crate::clients::thunder_client::{ThunderClient, ThunderCluster};
|
||||||
|
use crate::params as p;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tonic::async_trait;
|
||||||
|
use xai_candidate_pipeline::source::Source;
|
||||||
|
use xai_home_mixer_proto as pb;
|
||||||
|
use xai_thunder_proto::GetInNetworkPostsRequest;
|
||||||
|
use xai_thunder_proto::in_network_posts_service_client::InNetworkPostsServiceClient;
|
||||||
|
|
||||||
|
pub struct ThunderSource {
|
||||||
|
pub thunder_client: Arc<ThunderClient>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Source<ScoredPostsQuery, PostCandidate> for ThunderSource {
|
||||||
|
#[xai_stats_macro::receive_stats]
|
||||||
|
async fn get_candidates(&self, query: &ScoredPostsQuery) -> Result<Vec<PostCandidate>, String> {
|
||||||
|
let cluster = ThunderCluster::Amp;
|
||||||
|
let channel = self
|
||||||
|
.thunder_client
|
||||||
|
.get_random_channel(cluster)
|
||||||
|
.ok_or_else(|| "ThunderSource: no available channel".to_string())?;
|
||||||
|
|
||||||
|
let mut client = InNetworkPostsServiceClient::new(channel.clone());
|
||||||
|
let following_list = &query.user_features.followed_user_ids;
|
||||||
|
let request = GetInNetworkPostsRequest {
|
||||||
|
user_id: query.user_id as u64,
|
||||||
|
following_user_ids: following_list.iter().map(|&id| id as u64).collect(),
|
||||||
|
max_results: p::THUNDER_MAX_RESULTS,
|
||||||
|
exclude_tweet_ids: vec![],
|
||||||
|
algorithm: "default".to_string(),
|
||||||
|
debug: false,
|
||||||
|
is_video_request: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = client
|
||||||
|
.get_in_network_posts(request)
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("ThunderSource: {}", e))?;
|
||||||
|
|
||||||
|
let candidates: Vec<PostCandidate> = response
|
||||||
|
.into_inner()
|
||||||
|
.posts
|
||||||
|
.into_iter()
|
||||||
|
.map(|post| {
|
||||||
|
let in_reply_to_tweet_id = post
|
||||||
|
.in_reply_to_post_id
|
||||||
|
.and_then(|id| u64::try_from(id).ok());
|
||||||
|
let conversation_id = post.conversation_id.and_then(|id| u64::try_from(id).ok());
|
||||||
|
|
||||||
|
let mut ancestors = Vec::new();
|
||||||
|
if let Some(reply_to) = in_reply_to_tweet_id {
|
||||||
|
ancestors.push(reply_to);
|
||||||
|
if let Some(root) = conversation_id.filter(|&root| root != reply_to) {
|
||||||
|
ancestors.push(root);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
PostCandidate {
|
||||||
|
tweet_id: post.post_id,
|
||||||
|
author_id: post.author_id as u64,
|
||||||
|
in_reply_to_tweet_id,
|
||||||
|
ancestors,
|
||||||
|
served_type: Some(pb::ServedType::ForYouInNetwork),
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(candidates)
|
||||||
|
}
|
||||||
|
}
|
||||||
206
phoenix/README.md
Normal file
@ -0,0 +1,206 @@
# Phoenix: Recommendation System

This repository contains JAX example code for the Phoenix recommendation system, which powers content ranking and retrieval. Phoenix uses transformer-based architectures for both **retrieval** (finding relevant candidates from millions of items) and **ranking** (ordering a smaller set of candidates by predicted engagement).

> **Note:** The sample transformer implementation in this repository is ported from the [Grok-1 open source release](https://github.com/xai-org/grok-1) by xAI. The core transformer architecture comes from Grok-1, adapted here for recommendation system use cases with custom input embeddings and attention masking for candidate isolation. This code is representative of the model used internally with the exception of specific scaling optimizations.

## Table of Contents

- [Overview](#overview)
- [Architecture](#architecture)
  - [Two-Stage Recommendation Pipeline](#two-stage-recommendation-pipeline)
  - [Retrieval: Two-Tower Model](#retrieval-two-tower-model)
  - [Ranking: Transformer with Candidate Isolation](#ranking-transformer-with-candidate-isolation)
- [Key Design Decisions](#key-design-decisions)
- [Running the Code](#running-the-code)
- [License](#license)

---

## Overview

Phoenix is a recommendation system that predicts user engagement (likes, reposts, replies, etc.) for content. It operates in two stages:

1. **Retrieval**: Efficiently narrow down millions of candidates to hundreds using approximate nearest neighbor (ANN) search
2. **Ranking**: Score and order the retrieved candidates using a more expressive transformer model

---

## Architecture

### Two-Stage Recommendation Pipeline

```
┌────────────────────────────────────────────────────────────────────────────────┐
│                            RECOMMENDATION PIPELINE                              │
├────────────────────────────────────────────────────────────────────────────────┤
│                                                                                  │
│  ┌──────────┐     ┌─────────────────────┐     ┌─────────────────────┐           │
│  │          │     │                     │     │                     │           │
│  │   User   │────▶│      STAGE 1:       │────▶│      STAGE 2:       │────▶ Feed  │
│  │  Request │     │      RETRIEVAL      │     │       RANKING       │           │
│  │          │     │     (Two-Tower)     │     │    (Transformer)    │           │
│  └──────────┘     │                     │     │                     │           │
│                   │  Millions → 1000s   │     │   1000s → Ranked    │           │
│                   └─────────────────────┘     └─────────────────────┘           │
│                                                                                  │
└────────────────────────────────────────────────────────────────────────────────┘
```

---

### Retrieval: Two-Tower Model

The retrieval stage uses a **two-tower architecture** that enables efficient similarity search at scale.

#### How Retrieval Works

1. **User Tower**: Encodes user features and engagement history through a transformer to produce a normalized user embedding `[B, D]`
2. **Candidate Tower**: Computes normalized embeddings for all items in the corpus `[N, D]`
3. **Similarity Search**: Retrieves top-K candidates using dot product similarity (see the sketch below)
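
As a rough, self-contained sketch of step 3, brute-force top-K retrieval over normalized embeddings reduces to a matrix-vector product followed by `jax.lax.top_k`. The production system uses ANN search rather than scanning the full corpus, and the shapes and random embeddings below are made up purely for illustration:

```python
import jax
import jax.numpy as jnp

D, N, K = 64, 10_000, 5  # embedding dim, corpus size, candidates to keep
key_user, key_corpus = jax.random.split(jax.random.PRNGKey(0))

# Stand-ins for the two tower outputs; both sides are L2-normalized.
user_emb = jax.random.normal(key_user, (D,))
user_emb = user_emb / jnp.linalg.norm(user_emb)                              # [D]
item_embs = jax.random.normal(key_corpus, (N, D))
item_embs = item_embs / jnp.linalg.norm(item_embs, axis=-1, keepdims=True)   # [N, D]

# Dot-product similarity against every item, then keep the K best.
scores = item_embs @ user_emb                                                # [N]
top_scores, top_ids = jax.lax.top_k(scores, K)
print(top_ids, top_scores)
```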

---

### Ranking: Transformer with Candidate Isolation

The ranking model uses a transformer architecture where **candidates cannot attend to each other** during inference. This is a critical design choice that ensures the score for a candidate doesn't depend on which other candidates are in the batch.

#### Ranking Model Architecture

```
|
||||||
|
PHOENIX RANKING MODEL
|
||||||
|
┌────────────────────────────────────────────────────────────────────────────┐
|
||||||
|
│ │
|
||||||
|
│ OUTPUT LOGITS │
|
||||||
|
│ [B, num_candidates, num_actions] │
|
||||||
|
│ │ │
|
||||||
|
│ │ Unembedding │
|
||||||
|
│ │ Projection │
|
||||||
|
│ │ │
|
||||||
|
│ ┌───────────────┴───────────────┐ │
|
||||||
|
│ │ │ │
|
||||||
|
│ │ Extract Candidate Outputs │ │
|
||||||
|
│ │ (positions after history) │ │
|
||||||
|
│ │ │ │
|
||||||
|
│ └───────────────┬───────────────┘ │
|
||||||
|
│ │ │
|
||||||
|
│ ┌───────────────┴───────────────┐ │
|
||||||
|
│ │ │ │
|
||||||
|
│ │ Transformer │ │
|
||||||
|
│ │ (with special masking) │ │
|
||||||
|
│ │ │ │
|
||||||
|
│ │ Candidates CANNOT attend │ │
|
||||||
|
│ │ to each other │ │
|
||||||
|
│ │ │ │
|
||||||
|
│ └───────────────┬───────────────┘ │
|
||||||
|
│ │ │
|
||||||
|
│ ┌───────────────────────────────┼───────────────────────────────┐ │
|
||||||
|
│ │ │ │ │
|
||||||
|
│ ▼ ▼ ▼ │
|
||||||
|
│ ┌──────────┐ ┌─────────────────┐ ┌────────────┐ │
|
||||||
|
│ │ User │ │ History │ │ Candidates │ │
|
||||||
|
│ │Embedding │ │ Embeddings │ │ Embeddings │ │
|
||||||
|
│ │ [B, 1] │ │ [B, S, D] │ │ [B, C, D] │ │
|
||||||
|
│ │ │ │ │ │ │ │
|
||||||
|
│ │ User │ │ Posts + Authors │ │ Posts + │ │
|
||||||
|
│ │ Hashes │ │ + Actions + │ │ Authors + │ │
|
||||||
|
│ │ │ │ Product Surface │ │ Product │ │
|
||||||
|
│ └──────────┘ └─────────────────┘ │ Surface │ │
|
||||||
|
│ └────────────┘ │
|
||||||
|
│ │
|
||||||
|
└────────────────────────────────────────────────────────────────────────────┘
|
||||||
|
```

#### Attention Mask: Candidate Isolation

A key detail is the **attention mask** that prevents candidates from attending to each other while still allowing them to attend to the user and history:

```
                         ATTENTION MASK VISUALIZATION

                              Keys (what we attend TO)
          ─────────────────────────────────────────────▶

         │ User │    History (S positions)    │   Candidates (C positions)    │
    ┌────┼──────┼─────────────────────────────┼───────────────────────────────┤
    │    │      │                             │                               │
    │ U  │  ✓   │  ✓ ✓ ✓ ✓ ✓ ✓ ✓              │  ✗ ✗ ✗ ✗ ✗ ✗ ✗                │
    │    │      │                             │                               │
    ├────┼──────┼─────────────────────────────┼───────────────────────────────┤
 Q  │    │      │                             │                               │
 u  │ H  │  ✓   │  ✓ ✓ ✓ ✓ ✓ ✓ ✓              │  ✗ ✗ ✗ ✗ ✗ ✗ ✗                │
 e  │ i  │  ✓   │  ✓ ✓ ✓ ✓ ✓ ✓ ✓              │  ✗ ✗ ✗ ✗ ✗ ✗ ✗                │
 r  │ s  │  ✓   │  ✓ ✓ ✓ ✓ ✓ ✓ ✓              │  ✗ ✗ ✗ ✗ ✗ ✗ ✗                │
 i  │ t  │  ✓   │  ✓ ✓ ✓ ✓ ✓ ✓ ✓              │  ✗ ✗ ✗ ✗ ✗ ✗ ✗                │
 e  │    │      │                             │                               │
 s  ├────┼──────┼─────────────────────────────┼───────────────────────────────┤
    │    │      │                             │  DIAGONAL ONLY (self-attend)  │
 │  │ C  │  ✓   │  ✓ ✓ ✓ ✓ ✓ ✓ ✓              │  ✓ ✗ ✗ ✗ ✗ ✗ ✗                │
 │  │ a  │  ✓   │  ✓ ✓ ✓ ✓ ✓ ✓ ✓              │  ✗ ✓ ✗ ✗ ✗ ✗ ✗                │
 │  │ n  │  ✓   │  ✓ ✓ ✓ ✓ ✓ ✓ ✓              │  ✗ ✗ ✓ ✗ ✗ ✗ ✗                │
 │  │ d  │  ✓   │  ✓ ✓ ✓ ✓ ✓ ✓ ✓              │  ✗ ✗ ✗ ✓ ✗ ✗ ✗                │
 │  │ i  │  ✓   │  ✓ ✓ ✓ ✓ ✓ ✓ ✓              │  ✗ ✗ ✗ ✗ ✓ ✗ ✗                │
 │  │ d  │  ✓   │  ✓ ✓ ✓ ✓ ✓ ✓ ✓              │  ✗ ✗ ✗ ✗ ✗ ✓ ✗                │
 ▼  │ s  │  ✓   │  ✓ ✓ ✓ ✓ ✓ ✓ ✓              │  ✗ ✗ ✗ ✗ ✗ ✗ ✓                │
    │    │      │                             │                               │
    └────┴──────┴─────────────────────────────┴───────────────────────────────┘

    ✓ = Can attend (1)    ✗ = Cannot attend (0)

    Legend:
    ├─ User + History: Full bidirectional attention among themselves
    ├─ Candidates → User/History: Candidates CAN attend to user and history
    └─ Candidates → Candidates: Candidates CANNOT attend to each other (only self)
```
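
The mask shown above corresponds to what `make_recsys_attn_mask` in `phoenix/grok.py` builds for the candidate block; a toy invocation (sizes chosen only for illustration, and assuming the script is run from the `phoenix/` directory) looks like this:

```python
import jax.numpy as jnp

from grok import make_recsys_attn_mask

# 1 user position + 3 history positions, so candidates start at offset 4;
# a total sequence length of 7 leaves 3 candidate positions.
mask = make_recsys_attn_mask(seq_len=7, candidate_start_offset=4)

# Candidate rows (positions 4-6) have 1s over the user/history columns and on
# their own diagonal entry only, so each candidate is scored independently.
print(mask[0, 0].astype(jnp.int32))
```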

---

## Key Design Decisions

### 1. Hash-Based Embeddings

Both models use multiple hash functions for embedding lookup.
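
A minimal sketch of the idea, where a large raw ID space is folded into small tables by several hash functions and the looked-up vectors are summed. The table sizes, hash count, and summation below are illustrative assumptions, not the production configuration:

```python
import jax
import jax.numpy as jnp

NUM_HASHES, TABLE_SIZE, DIM = 2, 1_000, 32
tables = jax.random.normal(jax.random.PRNGKey(0), (NUM_HASHES, TABLE_SIZE, DIM))

def embed_id(raw_id: int) -> jax.Array:
    # Each hash function maps the raw id into its own small table; summing the
    # rows lets a huge id space share a few thousand learned vectors.
    vec = jnp.zeros(DIM)
    for h in range(NUM_HASHES):
        idx = hash((h, raw_id)) % TABLE_SIZE
        vec = vec + tables[h, idx]
    return vec

print(embed_id(123_456_789_012).shape)  # (32,)
```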

### 2. Shared Architecture

The retrieval user tower uses the same transformer architecture as the ranking model.

### 3. Multi-Action Prediction

The ranking model predicts multiple engagement types simultaneously:

```
Output: [B, num_candidates, num_actions]
                   │
                   ▼
┌─────────────────────────────────────┐
│ Like │ Repost │ Reply │ Click │ ... │
└─────────────────────────────────────┘
```
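
For intuition, the per-action outputs can be turned into probabilities and then collapsed into a single ranking score with per-action weights, which is what the serving layer (the home-mixer weighted scorer) does downstream. The sigmoid, the weights, and the shapes below are illustrative assumptions rather than the trained model's exact formulation:

```python
import jax
import jax.numpy as jnp

# Toy output for 1 request, 3 candidates, 4 action heads (like/repost/reply/click).
logits = jax.random.normal(jax.random.PRNGKey(0), (1, 3, 4))

# Treat each action head as an independent prediction and squash to [0, 1].
probs = jax.nn.sigmoid(logits)               # [B, num_candidates, num_actions]

# Collapse the heads into one score per candidate with made-up action weights.
weights = jnp.array([1.0, 0.8, 1.5, 0.2])
scores = probs @ weights                     # [B, num_candidates]
print(scores)
```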

---

## Running the Code

### Installation

Install [uv](https://docs.astral.sh/uv/getting-started/installation/)

### Running the Ranker

```shell
uv run run_ranker.py
```

### Running Retrieval

```shell
uv run run_retrieval.py
```

### Running Tests

```shell
uv run pytest test_recsys_model.py test_recsys_retrieval_model.py
```
586
phoenix/grok.py
Normal file
@ -0,0 +1,586 @@
# Copyright 2026 X.AI Corp.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import NamedTuple, Optional, Sequence, Union
|
||||||
|
|
||||||
|
import haiku as hk
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TrainingState(NamedTuple):
|
||||||
|
"""Container for the training state."""
|
||||||
|
|
||||||
|
params: hk.Params
|
||||||
|
|
||||||
|
|
||||||
|
def ffn_size(emb_size, widening_factor):
|
||||||
|
_ffn_size = int(widening_factor * emb_size) * 2 // 3
|
||||||
|
_ffn_size = _ffn_size + (8 - _ffn_size) % 8 # ensure it's a multiple of 8
|
||||||
|
logger.debug(f"emb_size: {emb_size} adjusted ffn_size: {_ffn_size}")
|
||||||
|
return _ffn_size
|
||||||
|
|
||||||
|
|
||||||
|
def make_recsys_attn_mask(
|
||||||
|
seq_len: int,
|
||||||
|
candidate_start_offset: int,
|
||||||
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
) -> jax.Array:
|
||||||
|
"""Create attention mask for recommendation system inference.
|
||||||
|
|
||||||
|
Creates a mask where:
|
||||||
|
- Positions 0 to candidate_start_offset-1 (user+history): causal attention
|
||||||
|
- Positions candidate_start_offset onwards (candidates): can attend to user+history
|
||||||
|
and themselves (self-attention), but NOT to other candidates
|
||||||
|
|
||||||
|
This ensures each candidate is scored independently based on user+history context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seq_len: Total sequence length (user + history + candidates)
|
||||||
|
candidate_start_offset: Position where candidates start in the sequence
|
||||||
|
dtype: Data type for the mask
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Attention mask of shape [1, 1, seq_len, seq_len] where 1 means "can attend"
|
||||||
|
"""
|
||||||
|
# Start with causal mask for the full sequence
|
||||||
|
causal_mask = jnp.tril(jnp.ones((1, 1, seq_len, seq_len), dtype=dtype))
|
||||||
|
|
||||||
|
# Zero out candidate-to-candidate attention (bottom-right block)
|
||||||
|
attn_mask = causal_mask.at[:, :, candidate_start_offset:, candidate_start_offset:].set(0)
|
||||||
|
|
||||||
|
# Add back self-attention for candidates (diagonal of the candidate block)
|
||||||
|
candidate_indices = jnp.arange(candidate_start_offset, seq_len)
|
||||||
|
attn_mask = attn_mask.at[:, :, candidate_indices, candidate_indices].set(1)
|
||||||
|
|
||||||
|
return attn_mask
|
||||||
|
|
||||||
|
|
||||||
|
class MHAOutput(NamedTuple):
|
||||||
|
"""Outputs of the multi-head attention operation."""
|
||||||
|
|
||||||
|
embeddings: jax.Array
|
||||||
|
|
||||||
|
|
||||||
|
class DecoderOutput(NamedTuple):
|
||||||
|
embeddings: jax.Array
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerOutput(NamedTuple):
|
||||||
|
embeddings: jax.Array
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TransformerConfig:
|
||||||
|
emb_size: int
|
||||||
|
key_size: int
|
||||||
|
num_q_heads: int
|
||||||
|
num_kv_heads: int
|
||||||
|
num_layers: int
|
||||||
|
widening_factor: float = 4.0
|
||||||
|
|
||||||
|
attn_output_multiplier: float = 1.0
|
||||||
|
|
||||||
|
name: Optional[str] = None
|
||||||
|
|
||||||
|
def make(self) -> "Transformer":
|
||||||
|
return Transformer(
|
||||||
|
num_q_heads=self.num_q_heads,
|
||||||
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
widening_factor=self.widening_factor,
|
||||||
|
key_size=self.key_size,
|
||||||
|
attn_output_multiplier=self.attn_output_multiplier,
|
||||||
|
num_layers=self.num_layers,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def hk_rms_norm(
|
||||||
|
x: jax.Array,
|
||||||
|
fixed_scale=False,
|
||||||
|
) -> jax.Array:
|
||||||
|
"""Applies a unique LayerNorm to x with default settings."""
|
||||||
|
ln = RMSNorm(axis=-1, create_scale=not fixed_scale)
|
||||||
|
return ln(x)
|
||||||
|
|
||||||
|
|
||||||
|
class Linear(hk.Linear):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
output_size: int,
|
||||||
|
with_bias: bool = True,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
output_size=output_size,
|
||||||
|
with_bias=with_bias,
|
||||||
|
name=name,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__( # type: ignore
|
||||||
|
self,
|
||||||
|
inputs: jax.Array,
|
||||||
|
) -> jax.Array:
|
||||||
|
"""Computes a linear transform of the input."""
|
||||||
|
|
||||||
|
fprop_dtype = inputs.dtype
|
||||||
|
if not inputs.shape:
|
||||||
|
raise ValueError("Input must not be scalar.")
|
||||||
|
|
||||||
|
input_size = inputs.shape[-1]
|
||||||
|
output_size = self.output_size
|
||||||
|
|
||||||
|
w = hk.get_parameter(
|
||||||
|
"w", [input_size, output_size], jnp.float32, init=hk.initializers.Constant(0)
|
||||||
|
)
|
||||||
|
|
||||||
|
out = jnp.dot(inputs, w.astype(fprop_dtype))
|
||||||
|
if self.with_bias:
|
||||||
|
b = hk.get_parameter(
|
||||||
|
"b", [self.output_size], jnp.float32, init=hk.initializers.Constant(0)
|
||||||
|
)
|
||||||
|
b = jnp.broadcast_to(b, out.shape)
|
||||||
|
out = out + b.astype(fprop_dtype)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(hk.RMSNorm):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
axis: Union[int, Sequence[int], slice],
|
||||||
|
eps: float = 1e-5,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
create_scale: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__(axis, eps, create_scale=create_scale, name=name)
|
||||||
|
|
||||||
|
def __call__(self, inputs: jax.Array):
|
||||||
|
fprop_dtype = inputs.dtype
|
||||||
|
param_shape = (inputs.shape[-1],)
|
||||||
|
if self.create_scale:
|
||||||
|
scale = hk.get_parameter(
|
||||||
|
"scale",
|
||||||
|
param_shape,
|
||||||
|
dtype=jnp.float32,
|
||||||
|
init=hk.initializers.Constant(0),
|
||||||
|
)
|
||||||
|
scale = jnp.broadcast_to(scale.astype(jnp.float32), inputs.shape)
|
||||||
|
else:
|
||||||
|
scale = 1.0
|
||||||
|
inputs = inputs.astype(jnp.float32)
|
||||||
|
scale = jnp.float32(scale)
|
||||||
|
mean_squared = jnp.mean(jnp.square(inputs), axis=[-1], keepdims=True)
|
||||||
|
mean_squared = jnp.broadcast_to(mean_squared, inputs.shape)
|
||||||
|
|
||||||
|
normed_inputs = inputs * jax.lax.rsqrt(mean_squared + self.eps)
|
||||||
|
|
||||||
|
outputs = scale * normed_inputs
|
||||||
|
|
||||||
|
return outputs.astype(fprop_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def rotate_half(
|
||||||
|
x: jax.Array,
|
||||||
|
) -> jax.Array:
|
||||||
|
"""Obtain the rotated counterpart of each feature"""
|
||||||
|
x1, x2 = jnp.split(x, 2, axis=-1)
|
||||||
|
return jnp.concatenate((-x2, x1), axis=-1)
|
||||||
|
|
||||||
|
|
||||||
|
class RotaryEmbedding(hk.Module):
|
||||||
|
"""Applies rotary embeddings (RoPE) to the input sequence tensor,
|
||||||
|
as described in https://arxiv.org/abs/2104.09864.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
dim (int): Dimensionality of the feature vectors
|
||||||
|
base_exponent (int): Base exponent to compute embeddings from
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
base_exponent: int = 10000,
|
||||||
|
):
|
||||||
|
super().__init__(name)
|
||||||
|
self.dim = dim
|
||||||
|
self.base_exponent = base_exponent
|
||||||
|
assert self.dim % 2 == 0
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x: jax.Array,
|
||||||
|
seq_dim: int,
|
||||||
|
offset: jax.Array,
|
||||||
|
const_position: Optional[int] = None,
|
||||||
|
t: Optional[jax.Array] = None,
|
||||||
|
) -> jax.Array:
|
||||||
|
fprop_dtype = x.dtype
|
||||||
|
# Compute the per-dimension frequencies
|
||||||
|
exponents = jnp.arange(0, self.dim, 2, dtype=jnp.float32)
|
||||||
|
inv_freq = jnp.asarray(
|
||||||
|
1.0 / (self.base_exponent ** (exponents / self.dim)), dtype=jnp.float32
|
||||||
|
)
|
||||||
|
|
||||||
|
if jnp.shape(offset) == ():
|
||||||
|
# Offset can be a scalar or one offset per batch element.
|
||||||
|
offset = jnp.expand_dims(offset, 0)
|
||||||
|
|
||||||
|
# Compute the per element phase (to pass into sin and cos)
|
||||||
|
if const_position:
|
||||||
|
t = const_position * jnp.ones(
|
||||||
|
(
|
||||||
|
1,
|
||||||
|
x.shape[seq_dim],
|
||||||
|
),
|
||||||
|
dtype=jnp.float32,
|
||||||
|
)
|
||||||
|
elif t is None:
|
||||||
|
t = jnp.arange(x.shape[seq_dim], dtype=jnp.float32) + jnp.expand_dims(offset, -1)
|
||||||
|
phase = jnp.einsum("bi,j->bij", t, inv_freq)
|
||||||
|
phase = jnp.tile(phase, reps=(1, 2))[:, :, None, :]
|
||||||
|
|
||||||
|
x = x * jnp.cos(phase) + rotate_half(x) * jnp.sin(phase)
|
||||||
|
x = x.astype(fprop_dtype)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class MultiHeadAttention(hk.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_q_heads: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
key_size: int,
|
||||||
|
*,
|
||||||
|
with_bias: bool = True,
|
||||||
|
value_size: Optional[int] = None,
|
||||||
|
model_size: Optional[int] = None,
|
||||||
|
attn_output_multiplier: float = 1.0,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
):
|
||||||
|
super().__init__(name=name)
|
||||||
|
self.num_q_heads = num_q_heads
|
||||||
|
self.num_kv_heads = num_kv_heads
|
||||||
|
self.key_size = key_size
|
||||||
|
self.value_size = value_size or key_size
|
||||||
|
self.model_size = model_size or key_size * num_q_heads
|
||||||
|
self.attn_output_multiplier = attn_output_multiplier
|
||||||
|
self.with_bias = with_bias
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
query: jax.Array,
|
||||||
|
key: jax.Array,
|
||||||
|
value: jax.Array,
|
||||||
|
mask: jax.Array,
|
||||||
|
) -> MHAOutput:
|
||||||
|
# In shape hints below, we suppress the leading dims [...] for brevity.
|
||||||
|
# Hence e.g. [A, B] should be read in every case as [..., A, B].
|
||||||
|
projection = self._linear_projection
|
||||||
|
|
||||||
|
# Check that the keys and values have consistent batch size and sequence length.
|
||||||
|
assert key.shape[:2] == value.shape[:2], f"key/value shape: {key.shape}/{value.shape}"
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
assert mask.ndim == 4
|
||||||
|
assert mask.shape[0] in {
|
||||||
|
1,
|
||||||
|
query.shape[0],
|
||||||
|
}, f"mask/query shape: {mask.shape}/{query.shape}"
|
||||||
|
assert key.shape[0] in {
|
||||||
|
1,
|
||||||
|
query.shape[0],
|
||||||
|
}, f"key/query shape: {key.shape}/{query.shape}"
|
||||||
|
assert mask.shape[1] == 1
|
||||||
|
assert mask.shape[2] in {
|
||||||
|
1,
|
||||||
|
query.shape[1],
|
||||||
|
}, f"mask/query shape: {mask.shape}/{query.shape}"
|
||||||
|
assert mask.shape[3] in {
|
||||||
|
1,
|
||||||
|
key.shape[1],
|
||||||
|
}, f"mask/query shape: {mask.shape}/{key.shape}"
|
||||||
|
|
||||||
|
# Compute key/query/values (overload K/Q/V to denote the respective sizes).
|
||||||
|
assert self.num_q_heads % self.num_kv_heads == 0
|
||||||
|
query_heads = projection(query, self.key_size, self.num_q_heads, name="query")
|
||||||
|
key_heads = projection(key, self.key_size, self.num_kv_heads, name="key")
|
||||||
|
value_heads = projection(value, self.value_size, self.num_kv_heads, name="value")
|
||||||
|
|
||||||
|
rotate = RotaryEmbedding(dim=self.key_size, base_exponent=int(1e4))
|
||||||
|
key_heads = rotate(key_heads, seq_dim=1, offset=0)
|
||||||
|
query_heads = rotate(query_heads, seq_dim=1, offset=0)
|
||||||
|
|
||||||
|
b, t, h, d = query_heads.shape
|
||||||
|
_, _, kv_h, _ = key_heads.shape
|
||||||
|
assert h % kv_h == 0, f"query_heads {h} must be a multiple of kv_heads {kv_h}"
|
||||||
|
|
||||||
|
query_heads = jnp.reshape(query_heads, (b, t, kv_h, h // kv_h, d))
|
||||||
|
|
||||||
|
# Compute attention weights.
|
||||||
|
# Attention softmax is always carried out in fp32.
|
||||||
|
attn_logits = jnp.einsum("...thHd,...Thd->...hHtT", query_heads, key_heads).astype(
|
||||||
|
jnp.float32
|
||||||
|
)
|
||||||
|
attn_logits *= self.attn_output_multiplier
|
||||||
|
max_attn_val = jnp.array(30.0, dtype=attn_logits.dtype)
|
||||||
|
attn_logits = max_attn_val * jnp.tanh(attn_logits / max_attn_val)
|
||||||
|
|
||||||
|
mask = mask[:, :, None, :, :]
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
if mask.ndim != attn_logits.ndim:
|
||||||
|
raise ValueError(
|
||||||
|
f"Mask dimensionality {mask.ndim} must match logits dimensionality "
|
||||||
|
f"{attn_logits.ndim} for {mask.shape}/{attn_logits.shape}."
|
||||||
|
)
|
||||||
|
attn_logits = jnp.where(mask, attn_logits, -1e30)
|
||||||
|
attn_weights = jax.nn.softmax(attn_logits).astype(query.dtype) # [H, T', T]
|
||||||
|
|
||||||
|
# Weight the values by the attention and flatten the head vectors.
|
||||||
|
attn = jnp.einsum("...hHtT,...Thd->...thHd", attn_weights, value_heads)
|
||||||
|
leading_dims = attn.shape[:2]
|
||||||
|
attn = jnp.reshape(attn, (*leading_dims, -1)) # [T', H*V]
|
||||||
|
|
||||||
|
# Apply another projection to get the final embeddings.
|
||||||
|
final_projection = Linear(self.model_size, with_bias=False)
|
||||||
|
return MHAOutput(final_projection(attn))
|
||||||
|
|
||||||
|
@hk.transparent
|
||||||
|
def _linear_projection(
|
||||||
|
self,
|
||||||
|
x: jax.Array,
|
||||||
|
head_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
) -> jax.Array:
|
||||||
|
y = Linear(num_heads * head_size, with_bias=False, name=name)(x)
|
||||||
|
*leading_dims, _ = x.shape
|
||||||
|
return y.reshape((*leading_dims, num_heads, head_size))
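# Shape check (illustrative only) for the grouped-query attention einsums used in
# __call__ above, with made-up sizes: 8 query heads sharing 2 KV heads, key_size 16.
import jax
import jax.numpy as jnp

b, t, h, kv_h, d = 1, 5, 8, 2, 16
q = jnp.zeros((b, t, kv_h, h // kv_h, d))   # query heads grouped per KV head
k = jnp.zeros((b, t, kv_h, d))
v = jnp.zeros((b, t, kv_h, d))

logits = jnp.einsum("...thHd,...Thd->...hHtT", q, k)                  # [B, kv_h, h//kv_h, T, T]
attn = jnp.einsum("...hHtT,...Thd->...thHd", jax.nn.softmax(logits), v)
print(logits.shape, attn.reshape(b, t, -1).shape)                     # (1, 2, 4, 5, 5) (1, 5, 128)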
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MHABlock(hk.Module):
|
||||||
|
"""A MHA Block"""
|
||||||
|
|
||||||
|
num_q_heads: int
|
||||||
|
num_kv_heads: int
|
||||||
|
key_size: int
|
||||||
|
attn_output_multiplier: float = 1.0
|
||||||
|
|
||||||
|
@hk.transparent
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
inputs: jax.Array, # [B, T, D]
|
||||||
|
mask: jax.Array,  # [B, 1, T, T] or [B, 1, 1, T] or [1, 1, 1, 1]
|
||||||
|
) -> MHAOutput:
|
||||||
|
_, _, model_size = inputs.shape
|
||||||
|
assert mask.ndim == 4, f"shape: {mask.shape}"
|
||||||
|
assert mask.shape[2] in {1, inputs.shape[1]}, str(mask.shape)
|
||||||
|
assert mask.shape[3] in {1, inputs.shape[1]}, str(mask.shape)
|
||||||
|
side_input = inputs
|
||||||
|
|
||||||
|
def attn_block(query, key, value, mask) -> MHAOutput:
|
||||||
|
return MultiHeadAttention(
|
||||||
|
num_q_heads=self.num_q_heads,
|
||||||
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
key_size=self.key_size,
|
||||||
|
model_size=model_size,
|
||||||
|
attn_output_multiplier=self.attn_output_multiplier,
|
||||||
|
)(query, key, value, mask)
|
||||||
|
|
||||||
|
attn_output = attn_block(inputs, side_input, side_input, mask)
|
||||||
|
h_attn = attn_output.embeddings
|
||||||
|
|
||||||
|
return MHAOutput(embeddings=h_attn)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DenseBlock(hk.Module):
|
||||||
|
num_q_heads: int
|
||||||
|
num_kv_heads: int
|
||||||
|
key_size: int
|
||||||
|
widening_factor: float = 4.0
|
||||||
|
|
||||||
|
@hk.transparent
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
inputs: jax.Array, # [B, T, D]
|
||||||
|
) -> jax.Array: # [B, T, D]
|
||||||
|
_, _, model_size = inputs.shape
|
||||||
|
h_v = Linear(
|
||||||
|
ffn_size(model_size, self.widening_factor),
|
||||||
|
with_bias=False,
|
||||||
|
name="linear_v",
|
||||||
|
)(inputs)
|
||||||
|
h_w1 = jax.nn.gelu(
|
||||||
|
Linear(
|
||||||
|
ffn_size(model_size, self.widening_factor),
|
||||||
|
with_bias=False,
|
||||||
|
)(inputs)
|
||||||
|
)
|
||||||
|
h_dense = Linear(model_size, with_bias=False)(h_w1 * h_v)
|
||||||
|
|
||||||
|
return h_dense
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DecoderLayer(hk.Module):
|
||||||
|
"""A transformer stack."""
|
||||||
|
|
||||||
|
num_q_heads: int
|
||||||
|
num_kv_heads: int
|
||||||
|
key_size: int
|
||||||
|
num_layers: int
|
||||||
|
layer_index: Optional[int] = None
|
||||||
|
widening_factor: float = 4.0
|
||||||
|
name: Optional[str] = None
|
||||||
|
attn_output_multiplier: float = 1.0
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
inputs: jax.Array, # [B, T, D]
|
||||||
|
mask: jax.Array, # [B, 1, T, T] or [B, 1, 1, T]
|
||||||
|
padding_mask: Optional[jax.Array],
|
||||||
|
) -> DecoderOutput:
|
||||||
|
"""Transforms input embedding sequences to output embedding sequences."""
|
||||||
|
del padding_mask # Unused.
|
||||||
|
|
||||||
|
def layer_norm(x):
|
||||||
|
return hk_rms_norm(x)
|
||||||
|
|
||||||
|
h = inputs
|
||||||
|
|
||||||
|
attn_output = MHABlock(
|
||||||
|
num_q_heads=self.num_q_heads,
|
||||||
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
key_size=self.key_size,
|
||||||
|
attn_output_multiplier=self.attn_output_multiplier,
|
||||||
|
)(layer_norm(h), mask)
|
||||||
|
h_attn = attn_output.embeddings
|
||||||
|
|
||||||
|
h_attn = layer_norm(h_attn)
|
||||||
|
h += h_attn
|
||||||
|
|
||||||
|
def base_dense_block(h):
|
||||||
|
h = DenseBlock(
|
||||||
|
num_q_heads=self.num_q_heads,
|
||||||
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
key_size=self.key_size,
|
||||||
|
widening_factor=self.widening_factor,
|
||||||
|
)(h)
|
||||||
|
return h
|
||||||
|
|
||||||
|
h_dense = base_dense_block(layer_norm(h))
|
||||||
|
|
||||||
|
h_dense = layer_norm(h_dense)
|
||||||
|
h += h_dense
|
||||||
|
|
||||||
|
return DecoderOutput(
|
||||||
|
embeddings=h,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def layer_norm(x):
|
||||||
|
return hk_rms_norm(x)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Transformer(hk.Module):
|
||||||
|
"""A transformer stack."""
|
||||||
|
|
||||||
|
num_q_heads: int
|
||||||
|
num_kv_heads: int
|
||||||
|
key_size: int
|
||||||
|
widening_factor: float
|
||||||
|
attn_output_multiplier: float
|
||||||
|
num_layers: int
|
||||||
|
name: Optional[str] = None
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
embeddings: jax.Array, # [B, T, D]
|
||||||
|
mask: jax.Array, # [B, T]
|
||||||
|
candidate_start_offset: Optional[int] = None,
|
||||||
|
) -> TransformerOutput:
|
||||||
|
"""Transforms input embedding sequences to output embedding sequences.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embeddings: Input embeddings of shape [B, T, D]
|
||||||
|
mask: Padding mask of shape [B, T], True for valid positions
|
||||||
|
candidate_start_offset: If provided, positions >= this offset are treated as
|
||||||
|
candidates that can only attend to positions before the offset (user+history)
|
||||||
|
and themselves (self-attention), but not to other candidates.
|
||||||
|
Used for recommendation system inference.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TransformerOutput containing the output embeddings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
fprop_dtype = embeddings.dtype
|
||||||
|
_, seq_len, _ = embeddings.shape
|
||||||
|
padding_mask = mask.copy()
|
||||||
|
mask = mask[:, None, None, :] # [B, H=1, T'=1, T]
|
||||||
|
|
||||||
|
if candidate_start_offset is not None:
|
||||||
|
# Use recommendation system attention mask where candidates attend to
|
||||||
|
# user+history and themselves, but not to other candidates
|
||||||
|
attn_mask = make_recsys_attn_mask(seq_len, candidate_start_offset, fprop_dtype)
|
||||||
|
mask = mask * attn_mask
|
||||||
|
else:
|
||||||
|
# Standard causal mask for autoregressive sequence modelling
|
||||||
|
causal_mask = jnp.tril(jnp.ones((1, 1, seq_len, seq_len))).astype(
|
||||||
|
fprop_dtype
|
||||||
|
) # [B=1, H=1, T, T]
|
||||||
|
mask = mask * causal_mask # [B, H=1, T, T]
|
||||||
|
|
||||||
|
h = embeddings
|
||||||
|
|
||||||
|
def block(
|
||||||
|
h,
|
||||||
|
mask,
|
||||||
|
padding_mask,
|
||||||
|
layer_index: Optional[int] = None,
|
||||||
|
widening_factor: Optional[int] = None,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
) -> DecoderOutput:
|
||||||
|
return DecoderLayer(
|
||||||
|
num_q_heads=self.num_q_heads,
|
||||||
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
key_size=self.key_size,
|
||||||
|
widening_factor=widening_factor or self.widening_factor,
|
||||||
|
num_layers=self.num_layers,
|
||||||
|
attn_output_multiplier=self.attn_output_multiplier,
|
||||||
|
name=name,
|
||||||
|
layer_index=layer_index,
|
||||||
|
)(h, mask, padding_mask)
|
||||||
|
|
||||||
|
for i in range(self.num_layers):
|
||||||
|
decoder_output = block(
|
||||||
|
h,
|
||||||
|
mask,
|
||||||
|
padding_mask,
|
||||||
|
layer_index=i,
|
||||||
|
name=f"decoder_layer_{i}",
|
||||||
|
)
|
||||||
|
h = decoder_output.embeddings
|
||||||
|
|
||||||
|
return TransformerOutput(
|
||||||
|
embeddings=h,
|
||||||
|
)
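# Hypothetical sketch of the attention pattern described in the docstring above for
# candidate_start_offset. The real make_recsys_attn_mask is defined in grok.py and
# is not shown in this diff; this only illustrates the intended structure:
# candidates attend to user+history and to themselves, never to other candidates.
import jax.numpy as jnp

def recsys_attn_mask_sketch(seq_len, candidate_start, dtype=jnp.float32):
    causal = jnp.tril(jnp.ones((seq_len, seq_len))).astype(bool)   # user+history: causal
    is_candidate = jnp.arange(seq_len) >= candidate_start
    # Candidate rows: columns before candidate_start (user+history) plus the diagonal.
    cand_rows = (jnp.arange(seq_len)[None, :] < candidate_start) | jnp.eye(seq_len, dtype=bool)
    mask = jnp.where(is_candidate[:, None], cand_rows, causal)
    return mask[None, None].astype(dtype)                           # [1, 1, T, T]

print(recsys_attn_mask_sketch(5, 3)[0, 0])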
|
||||||
38
phoenix/pyproject.toml
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
[project]
|
||||||
|
name = "grok-1"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Grok-1 model"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.11"
|
||||||
|
dependencies = [
|
||||||
|
"dm-haiku>=0.0.13",
|
||||||
|
"jax==0.8.1",
|
||||||
|
"numpy>=1.26.4",
|
||||||
|
"pyright>=1.1.408",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.uv]
|
||||||
|
environments = [
|
||||||
|
"sys_platform == 'darwin'",
|
||||||
|
"sys_platform == 'linux'",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
indent-width = 4
|
||||||
|
line-length = 100
|
||||||
|
|
||||||
|
[tool.ruff.lint]
|
||||||
|
ignore = [
|
||||||
|
"E722",
|
||||||
|
"E731",
|
||||||
|
"E741",
|
||||||
|
"F405",
|
||||||
|
"E402",
|
||||||
|
"F403",
|
||||||
|
]
|
||||||
|
select = ["ISC001"]
|
||||||
|
|
||||||
|
[dependency-groups]
|
||||||
|
dev = [
|
||||||
|
"pytest",
|
||||||
|
]
|
||||||
474
phoenix/recsys_model.py
Normal file
@ -0,0 +1,474 @@
|
|||||||
|
# Copyright 2026 X.AI Corp.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, NamedTuple, Optional, Tuple
|
||||||
|
|
||||||
|
import haiku as hk
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
from grok import (
|
||||||
|
TransformerConfig,
|
||||||
|
Transformer,
|
||||||
|
layer_norm,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HashConfig:
|
||||||
|
"""Configuration for hash-based embeddings."""
|
||||||
|
|
||||||
|
num_user_hashes: int = 2
|
||||||
|
num_item_hashes: int = 2
|
||||||
|
num_author_hashes: int = 2
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RecsysEmbeddings:
|
||||||
|
"""Container for pre-looked-up embeddings from the embedding tables.
|
||||||
|
|
||||||
|
These embeddings are looked up from hash tables before being passed to the model.
|
||||||
|
The block_*_reduce functions will combine multiple hash embeddings into single representations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
user_embeddings: jax.typing.ArrayLike
|
||||||
|
history_post_embeddings: jax.typing.ArrayLike
|
||||||
|
candidate_post_embeddings: jax.typing.ArrayLike
|
||||||
|
history_author_embeddings: jax.typing.ArrayLike
|
||||||
|
candidate_author_embeddings: jax.typing.ArrayLike
|
||||||
|
|
||||||
|
|
||||||
|
class RecsysModelOutput(NamedTuple):
|
||||||
|
"""Output of the recommendation model."""
|
||||||
|
|
||||||
|
logits: jax.Array
|
||||||
|
|
||||||
|
|
||||||
|
class RecsysBatch(NamedTuple):
|
||||||
|
"""Input batch for the recommendation model.
|
||||||
|
|
||||||
|
Contains the feature data (hashes, actions, product surfaces) but NOT the embeddings.
|
||||||
|
Embeddings are passed separately via RecsysEmbeddings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
user_hashes: jax.typing.ArrayLike
|
||||||
|
history_post_hashes: jax.typing.ArrayLike
|
||||||
|
history_author_hashes: jax.typing.ArrayLike
|
||||||
|
history_actions: jax.typing.ArrayLike
|
||||||
|
history_product_surface: jax.typing.ArrayLike
|
||||||
|
candidate_post_hashes: jax.typing.ArrayLike
|
||||||
|
candidate_author_hashes: jax.typing.ArrayLike
|
||||||
|
candidate_product_surface: jax.typing.ArrayLike
|
||||||
|
|
||||||
|
|
||||||
|
def block_user_reduce(
|
||||||
|
user_hashes: jnp.ndarray,
|
||||||
|
user_embeddings: jnp.ndarray,
|
||||||
|
num_user_hashes: int,
|
||||||
|
emb_size: int,
|
||||||
|
embed_init_scale: float = 1.0,
|
||||||
|
) -> Tuple[jax.Array, jax.Array]:
|
||||||
|
"""Combine multiple user hash embeddings into a single user representation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_hashes: [B, num_user_hashes] - hash values (0 = invalid/padding)
|
||||||
|
user_embeddings: [B, num_user_hashes, D] - looked-up embeddings
|
||||||
|
num_user_hashes: number of hash functions used
|
||||||
|
emb_size: embedding dimension D
|
||||||
|
embed_init_scale: initialization scale for projection
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
user_embedding: [B, 1, D] - combined user embedding
|
||||||
|
user_padding_mask: [B, 1] - True where user is valid
|
||||||
|
"""
|
||||||
|
B = user_embeddings.shape[0]
|
||||||
|
D = emb_size
|
||||||
|
|
||||||
|
user_embedding = user_embeddings.reshape((B, 1, num_user_hashes * D))
|
||||||
|
|
||||||
|
embed_init = hk.initializers.VarianceScaling(embed_init_scale, mode="fan_out")
|
||||||
|
proj_mat_1 = hk.get_parameter(
|
||||||
|
"proj_mat_1",
|
||||||
|
[num_user_hashes * D, D],
|
||||||
|
dtype=jnp.float32,
|
||||||
|
init=lambda shape, dtype: embed_init(list(reversed(shape)), dtype).T,
|
||||||
|
)
|
||||||
|
|
||||||
|
user_embedding = jnp.dot(user_embedding.astype(proj_mat_1.dtype), proj_mat_1).astype(
|
||||||
|
user_embeddings.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
# hash 0 is reserved for padding
|
||||||
|
user_padding_mask = (user_hashes[:, 0] != 0).reshape(B, 1).astype(jnp.bool_)
|
||||||
|
|
||||||
|
return user_embedding, user_padding_mask
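# Illustrative usage of block_user_reduce (values made up): two hash functions,
# embedding size 4, and a second user that is pure padding. The Haiku parameter
# it creates means the call must run inside hk.transform.
import haiku as hk
import jax
import jax.numpy as jnp

def forward(user_hashes, user_embeddings):
    return block_user_reduce(user_hashes, user_embeddings, num_user_hashes=2, emb_size=4)

fwd = hk.transform(forward)
user_hashes = jnp.array([[7, 11], [0, 0]])       # hash 0 marks padding
user_embeddings = jnp.ones((2, 2, 4))            # [B, num_user_hashes, D]
params = fwd.init(jax.random.PRNGKey(0), user_hashes, user_embeddings)
emb, mask = fwd.apply(params, None, user_hashes, user_embeddings)
print(emb.shape, mask[:, 0])                     # (2, 1, 4) [ True False]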
|
||||||
|
|
||||||
|
|
||||||
|
def block_history_reduce(
|
||||||
|
history_post_hashes: jnp.ndarray,
|
||||||
|
history_post_embeddings: jnp.ndarray,
|
||||||
|
history_author_embeddings: jnp.ndarray,
|
||||||
|
history_product_surface_embeddings: jnp.ndarray,
|
||||||
|
history_actions_embeddings: jnp.ndarray,
|
||||||
|
num_item_hashes: int,
|
||||||
|
num_author_hashes: int,
|
||||||
|
embed_init_scale: float = 1.0,
|
||||||
|
) -> Tuple[jax.Array, jax.Array]:
|
||||||
|
"""Combine history embeddings (post, author, actions, product_surface) into sequence.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
history_post_hashes: [B, S, num_item_hashes]
|
||||||
|
history_post_embeddings: [B, S, num_item_hashes, D]
|
||||||
|
history_author_embeddings: [B, S, num_author_hashes, D]
|
||||||
|
history_product_surface_embeddings: [B, S, D]
|
||||||
|
history_actions_embeddings: [B, S, D]
|
||||||
|
num_item_hashes: number of hash functions for items
|
||||||
|
num_author_hashes: number of hash functions for authors
|
||||||
|
|
||||||
|
embed_init_scale: initialization scale
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
history_embeddings: [B, S, D]
|
||||||
|
history_padding_mask: [B, S]
|
||||||
|
"""
|
||||||
|
B, S, _, D = history_post_embeddings.shape
|
||||||
|
|
||||||
|
history_post_embeddings_reshaped = history_post_embeddings.reshape((B, S, num_item_hashes * D))
|
||||||
|
history_author_embeddings_reshaped = history_author_embeddings.reshape(
|
||||||
|
(B, S, num_author_hashes * D)
|
||||||
|
)
|
||||||
|
|
||||||
|
post_author_embedding = jnp.concatenate(
|
||||||
|
[
|
||||||
|
history_post_embeddings_reshaped,
|
||||||
|
history_author_embeddings_reshaped,
|
||||||
|
history_actions_embeddings,
|
||||||
|
history_product_surface_embeddings,
|
||||||
|
],
|
||||||
|
axis=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
embed_init = hk.initializers.VarianceScaling(embed_init_scale, mode="fan_out")
|
||||||
|
proj_mat_3 = hk.get_parameter(
|
||||||
|
"proj_mat_3",
|
||||||
|
[post_author_embedding.shape[-1], D],
|
||||||
|
dtype=jnp.float32,
|
||||||
|
init=lambda shape, dtype: embed_init(list(reversed(shape)), dtype).T,
|
||||||
|
)
|
||||||
|
|
||||||
|
history_embedding = jnp.dot(post_author_embedding.astype(proj_mat_3.dtype), proj_mat_3).astype(
|
||||||
|
post_author_embedding.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
history_embedding = history_embedding.reshape(B, S, D)
|
||||||
|
|
||||||
|
history_padding_mask = (history_post_hashes[:, :, 0] != 0).reshape(B, S)
|
||||||
|
|
||||||
|
return history_embedding, history_padding_mask
|
||||||
|
|
||||||
|
|
||||||
|
def block_candidate_reduce(
|
||||||
|
candidate_post_hashes: jnp.ndarray,
|
||||||
|
candidate_post_embeddings: jnp.ndarray,
|
||||||
|
candidate_author_embeddings: jnp.ndarray,
|
||||||
|
candidate_product_surface_embeddings: jnp.ndarray,
|
||||||
|
num_item_hashes: int,
|
||||||
|
num_author_hashes: int,
|
||||||
|
embed_init_scale: float = 1.0,
|
||||||
|
) -> Tuple[jax.Array, jax.Array]:
|
||||||
|
"""Combine candidate embeddings (post, author, product_surface) into sequence.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
candidate_post_hashes: [B, C, num_item_hashes]
|
||||||
|
candidate_post_embeddings: [B, C, num_item_hashes, D]
|
||||||
|
candidate_author_embeddings: [B, C, num_author_hashes, D]
|
||||||
|
candidate_product_surface_embeddings: [B, C, D]
|
||||||
|
num_item_hashes: number of hash functions for items
|
||||||
|
num_author_hashes: number of hash functions for authors
|
||||||
|
|
||||||
|
embed_init_scale: initialization scale
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
candidate_embeddings: [B, C, D]
|
||||||
|
candidate_padding_mask: [B, C]
|
||||||
|
"""
|
||||||
|
B, C, _, D = candidate_post_embeddings.shape
|
||||||
|
|
||||||
|
candidate_post_embeddings_reshaped = candidate_post_embeddings.reshape(
|
||||||
|
(B, C, num_item_hashes * D)
|
||||||
|
)
|
||||||
|
candidate_author_embeddings_reshaped = candidate_author_embeddings.reshape(
|
||||||
|
(B, C, num_author_hashes * D)
|
||||||
|
)
|
||||||
|
|
||||||
|
post_author_embedding = jnp.concatenate(
|
||||||
|
[
|
||||||
|
candidate_post_embeddings_reshaped,
|
||||||
|
candidate_author_embeddings_reshaped,
|
||||||
|
candidate_product_surface_embeddings,
|
||||||
|
],
|
||||||
|
axis=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
embed_init = hk.initializers.VarianceScaling(embed_init_scale, mode="fan_out")
|
||||||
|
proj_mat_2 = hk.get_parameter(
|
||||||
|
"proj_mat_2",
|
||||||
|
[post_author_embedding.shape[-1], D],
|
||||||
|
dtype=jnp.float32,
|
||||||
|
init=lambda shape, dtype: embed_init(list(reversed(shape)), dtype).T,
|
||||||
|
)
|
||||||
|
|
||||||
|
candidate_embedding = jnp.dot(
|
||||||
|
post_author_embedding.astype(proj_mat_2.dtype), proj_mat_2
|
||||||
|
).astype(post_author_embedding.dtype)
|
||||||
|
|
||||||
|
candidate_padding_mask = (candidate_post_hashes[:, :, 0] != 0).reshape(B, C).astype(jnp.bool_)
|
||||||
|
|
||||||
|
return candidate_embedding, candidate_padding_mask
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PhoenixModelConfig:
|
||||||
|
"""Configuration for the recommendation system model."""
|
||||||
|
|
||||||
|
model: TransformerConfig
|
||||||
|
emb_size: int
|
||||||
|
num_actions: int
|
||||||
|
history_seq_len: int = 128
|
||||||
|
candidate_seq_len: int = 32
|
||||||
|
|
||||||
|
name: Optional[str] = None
|
||||||
|
fprop_dtype: Any = jnp.bfloat16
|
||||||
|
|
||||||
|
hash_config: HashConfig = None # type: ignore
|
||||||
|
|
||||||
|
product_surface_vocab_size: int = 16
|
||||||
|
|
||||||
|
_initialized = False
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.hash_config is None:
|
||||||
|
self.hash_config = HashConfig()
|
||||||
|
|
||||||
|
def initialize(self):
|
||||||
|
self._initialized = True
|
||||||
|
return self
|
||||||
|
|
||||||
|
def make(self):
|
||||||
|
if not self._initialized:
|
||||||
|
logger.warning(f"PhoenixModel {self.name} is not initialized. Initializing.")
|
||||||
|
self.initialize()
|
||||||
|
|
||||||
|
return PhoenixModel(
|
||||||
|
model=self.model.make(),
|
||||||
|
config=self,
|
||||||
|
fprop_dtype=self.fprop_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PhoenixModel(hk.Module):
|
||||||
|
"""A transformer-based recommendation model for ranking candidates."""
|
||||||
|
|
||||||
|
model: Transformer
|
||||||
|
config: PhoenixModelConfig
|
||||||
|
fprop_dtype: Any = jnp.bfloat16
|
||||||
|
name: Optional[str] = None
|
||||||
|
|
||||||
|
def _get_action_embeddings(
|
||||||
|
self,
|
||||||
|
actions: jax.Array,
|
||||||
|
) -> jax.Array:
|
||||||
|
"""Convert multi-hot action vectors to embeddings.
|
||||||
|
|
||||||
|
Uses a learned projection matrix to map the signed action vector
|
||||||
|
to the embedding dimension. This works for any number of actions.
|
||||||
|
"""
|
||||||
|
config = self.config
|
||||||
|
_, _, num_actions = actions.shape
|
||||||
|
D = config.emb_size
|
||||||
|
|
||||||
|
embed_init = hk.initializers.VarianceScaling(1.0, mode="fan_out")
|
||||||
|
action_projection = hk.get_parameter(
|
||||||
|
"action_projection",
|
||||||
|
[num_actions, D],
|
||||||
|
dtype=jnp.float32,
|
||||||
|
init=embed_init,
|
||||||
|
)
|
||||||
|
|
||||||
|
actions_signed = (2 * actions - 1).astype(jnp.float32)
|
||||||
|
|
||||||
|
action_emb = jnp.dot(actions_signed.astype(action_projection.dtype), action_projection)
|
||||||
|
|
||||||
|
valid_mask = jnp.any(actions, axis=-1, keepdims=True)
|
||||||
|
action_emb = action_emb * valid_mask
|
||||||
|
|
||||||
|
return action_emb.astype(self.fprop_dtype)
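# Toy illustration of the signed action encoding above: 1 -> +1, 0 -> -1, and
# positions with no recorded actions at all are zeroed out by the validity mask.
import jax.numpy as jnp

actions = jnp.array([[[1.0, 0.0, 1.0],    # engaged with actions 0 and 2
                      [0.0, 0.0, 0.0]]])  # no actions at this position
signed = 2 * actions - 1
valid = jnp.any(actions, axis=-1, keepdims=True)
print(signed * valid)                      # second row collapses to zeros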
|
||||||
|
|
||||||
|
def _single_hot_to_embeddings(
|
||||||
|
self,
|
||||||
|
input: jax.Array,
|
||||||
|
vocab_size: int,
|
||||||
|
emb_size: int,
|
||||||
|
name: str,
|
||||||
|
) -> jax.Array:
|
||||||
|
"""Convert single-hot indices to embeddings via lookup table.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input: [B, S] tensor of categorical indices
|
||||||
|
vocab_size: size of the vocabulary
|
||||||
|
emb_size: embedding dimension
|
||||||
|
name: name for the embedding table parameter
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
embeddings: [B, S, emb_size]
|
||||||
|
"""
|
||||||
|
embed_init = hk.initializers.VarianceScaling(1.0, mode="fan_out")
|
||||||
|
embedding_table = hk.get_parameter(
|
||||||
|
name,
|
||||||
|
[vocab_size, emb_size],
|
||||||
|
dtype=jnp.float32,
|
||||||
|
init=embed_init,
|
||||||
|
)
|
||||||
|
|
||||||
|
input_one_hot = jax.nn.one_hot(input, vocab_size)
|
||||||
|
output = jnp.dot(input_one_hot, embedding_table)
|
||||||
|
return output.astype(self.fprop_dtype)
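# Tiny illustration of the one-hot lookup above: index 2 in a vocabulary of 4
# selects row 2 of the embedding table (all values made up).
import jax
import jax.numpy as jnp

table = jnp.arange(8, dtype=jnp.float32).reshape(4, 2)   # [vocab=4, emb=2]
idx = jnp.array([[2, 0]])                                 # [B=1, S=2]
print(jnp.dot(jax.nn.one_hot(idx, 4), table))             # [[[4. 5.] [0. 1.]]]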
|
||||||
|
|
||||||
|
def _get_unembedding(self) -> jax.Array:
|
||||||
|
"""Get the unembedding matrix for decoding to logits."""
|
||||||
|
config = self.config
|
||||||
|
embed_init = hk.initializers.VarianceScaling(1.0, mode="fan_out")
|
||||||
|
unembed_mat = hk.get_parameter(
|
||||||
|
"unembeddings",
|
||||||
|
[config.emb_size, config.num_actions],
|
||||||
|
dtype=jnp.float32,
|
||||||
|
init=embed_init,
|
||||||
|
)
|
||||||
|
return unembed_mat
|
||||||
|
|
||||||
|
def build_inputs(
|
||||||
|
self,
|
||||||
|
batch: RecsysBatch,
|
||||||
|
recsys_embeddings: RecsysEmbeddings,
|
||||||
|
) -> Tuple[jax.Array, jax.Array, int]:
|
||||||
|
"""Build input embeddings from batch and pre-looked-up embeddings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch: RecsysBatch containing hashes, actions, product surfaces
|
||||||
|
recsys_embeddings: RecsysEmbeddings containing pre-looked-up embeddings
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
embeddings: [B, 1 + history_len + num_candidates, D]
|
||||||
|
padding_mask: [B, 1 + history_len + num_candidates]
|
||||||
|
candidate_start_offset: int - position where candidates start
|
||||||
|
"""
|
||||||
|
config = self.config
|
||||||
|
hash_config = config.hash_config
|
||||||
|
|
||||||
|
history_product_surface_embeddings = self._single_hot_to_embeddings(
|
||||||
|
batch.history_product_surface, # type: ignore
|
||||||
|
config.product_surface_vocab_size,
|
||||||
|
config.emb_size,
|
||||||
|
"product_surface_embedding_table",
|
||||||
|
)
|
||||||
|
candidate_product_surface_embeddings = self._single_hot_to_embeddings(
|
||||||
|
batch.candidate_product_surface, # type: ignore
|
||||||
|
config.product_surface_vocab_size,
|
||||||
|
config.emb_size,
|
||||||
|
"product_surface_embedding_table",
|
||||||
|
)
|
||||||
|
|
||||||
|
history_actions_embeddings = self._get_action_embeddings(batch.history_actions) # type: ignore
|
||||||
|
|
||||||
|
user_embeddings, user_padding_mask = block_user_reduce(
|
||||||
|
batch.user_hashes, # type: ignore
|
||||||
|
recsys_embeddings.user_embeddings, # type: ignore
|
||||||
|
hash_config.num_user_hashes,
|
||||||
|
config.emb_size,
|
||||||
|
1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
history_embeddings, history_padding_mask = block_history_reduce(
|
||||||
|
batch.history_post_hashes, # type: ignore
|
||||||
|
recsys_embeddings.history_post_embeddings, # type: ignore
|
||||||
|
recsys_embeddings.history_author_embeddings, # type: ignore
|
||||||
|
history_product_surface_embeddings,
|
||||||
|
history_actions_embeddings,
|
||||||
|
hash_config.num_item_hashes,
|
||||||
|
hash_config.num_author_hashes,
|
||||||
|
1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
candidate_embeddings, candidate_padding_mask = block_candidate_reduce(
|
||||||
|
batch.candidate_post_hashes, # type: ignore
|
||||||
|
recsys_embeddings.candidate_post_embeddings, # type: ignore
|
||||||
|
recsys_embeddings.candidate_author_embeddings, # type: ignore
|
||||||
|
candidate_product_surface_embeddings,
|
||||||
|
hash_config.num_item_hashes,
|
||||||
|
hash_config.num_author_hashes,
|
||||||
|
1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
embeddings = jnp.concatenate(
|
||||||
|
[user_embeddings, history_embeddings, candidate_embeddings], axis=1
|
||||||
|
)
|
||||||
|
padding_mask = jnp.concatenate(
|
||||||
|
[user_padding_mask, history_padding_mask, candidate_padding_mask], axis=1
|
||||||
|
)
|
||||||
|
|
||||||
|
candidate_start_offset = user_padding_mask.shape[1] + history_padding_mask.shape[1]
|
||||||
|
|
||||||
|
return embeddings.astype(self.fprop_dtype), padding_mask, candidate_start_offset
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
batch: RecsysBatch,
|
||||||
|
recsys_embeddings: RecsysEmbeddings,
|
||||||
|
) -> RecsysModelOutput:
|
||||||
|
"""Forward pass for ranking candidates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch: RecsysBatch containing hashes, actions, product surfaces
|
||||||
|
recsys_embeddings: RecsysEmbeddings containing pre-looked-up embeddings
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RecsysModelOutput containing logits for each candidate. Shape = [B, num_candidates, num_actions]
|
||||||
|
"""
|
||||||
|
embeddings, padding_mask, candidate_start_offset = self.build_inputs(
|
||||||
|
batch, recsys_embeddings
|
||||||
|
)
|
||||||
|
|
||||||
|
# transformer
|
||||||
|
model_output = self.model(
|
||||||
|
embeddings,
|
||||||
|
padding_mask,
|
||||||
|
candidate_start_offset=candidate_start_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
out_embeddings = model_output.embeddings
|
||||||
|
|
||||||
|
out_embeddings = layer_norm(out_embeddings)
|
||||||
|
|
||||||
|
candidate_embeddings = out_embeddings[:, candidate_start_offset:, :]
|
||||||
|
|
||||||
|
unembeddings = self._get_unembedding()
|
||||||
|
logits = jnp.dot(candidate_embeddings.astype(unembeddings.dtype), unembeddings)
|
||||||
|
logits = logits.astype(self.fprop_dtype)
|
||||||
|
|
||||||
|
return RecsysModelOutput(logits=logits)
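# Hedged post-processing sketch: one plausible way to consume the
# [B, num_candidates, num_actions] logits returned above. The actual scoring and
# ranking logic lives in runners.py (not this file); the per-action sigmoid and
# the sort key chosen here are assumptions for illustration only.
import jax
import jax.numpy as jnp

logits = jnp.array([[[2.0, -1.0], [0.5, 0.5], [-3.0, 1.0]]])  # [B=1, C=3, A=2]
probs = jax.nn.sigmoid(logits)          # independent per-action probabilities
primary = probs[..., 0]                 # rank by the first action's probability
print(jnp.argsort(-primary, axis=-1))   # [[0 1 2]]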
|
||||||
372
phoenix/recsys_retrieval_model.py
Normal file
@ -0,0 +1,372 @@
|
|||||||
|
# Copyright 2026 X.AI Corp.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, NamedTuple, Optional, Tuple
|
||||||
|
|
||||||
|
import haiku as hk
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
from grok import TransformerConfig, Transformer
|
||||||
|
from recsys_model import (
|
||||||
|
HashConfig,
|
||||||
|
RecsysBatch,
|
||||||
|
RecsysEmbeddings,
|
||||||
|
block_history_reduce,
|
||||||
|
block_user_reduce,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
EPS = 1e-12
|
||||||
|
INF = 1e12
|
||||||
|
|
||||||
|
|
||||||
|
class RetrievalOutput(NamedTuple):
|
||||||
|
"""Output of the retrieval model."""
|
||||||
|
|
||||||
|
user_representation: jax.Array
|
||||||
|
top_k_indices: jax.Array
|
||||||
|
top_k_scores: jax.Array
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CandidateTower(hk.Module):
|
||||||
|
"""Candidate tower that projects post+author embeddings to a shared embedding space.
|
||||||
|
|
||||||
|
This tower takes the concatenated embeddings of a post and its author,
|
||||||
|
and projects them to a normalized representation suitable for similarity search.
|
||||||
|
"""
|
||||||
|
|
||||||
|
emb_size: int
|
||||||
|
name: Optional[str] = None
|
||||||
|
|
||||||
|
def __call__(self, post_author_embedding: jax.Array) -> jax.Array:
|
||||||
|
"""Project post+author embeddings to normalized representation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
post_author_embedding: Concatenated post and author embeddings
|
||||||
|
Shape: [B, C, num_hashes, D] or [B, num_hashes, D]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Normalized candidate representation
|
||||||
|
Shape: [B, C, D] or [B, D]
|
||||||
|
"""
|
||||||
|
if len(post_author_embedding.shape) == 4:
|
||||||
|
B, C, _, _ = post_author_embedding.shape
|
||||||
|
post_author_embedding = jnp.reshape(post_author_embedding, (B, C, -1))
|
||||||
|
else:
|
||||||
|
B, _, _ = post_author_embedding.shape
|
||||||
|
post_author_embedding = jnp.reshape(post_author_embedding, (B, -1))
|
||||||
|
|
||||||
|
embed_init = hk.initializers.VarianceScaling(1.0, mode="fan_out")
|
||||||
|
|
||||||
|
proj_1 = hk.get_parameter(
|
||||||
|
"candidate_tower_projection_1",
|
||||||
|
[post_author_embedding.shape[-1], self.emb_size * 2],
|
||||||
|
dtype=jnp.float32,
|
||||||
|
init=embed_init,
|
||||||
|
)
|
||||||
|
|
||||||
|
proj_2 = hk.get_parameter(
|
||||||
|
"candidate_tower_projection_2",
|
||||||
|
[self.emb_size * 2, self.emb_size],
|
||||||
|
dtype=jnp.float32,
|
||||||
|
init=embed_init,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden = jnp.dot(post_author_embedding.astype(proj_1.dtype), proj_1)
|
||||||
|
hidden = jax.nn.silu(hidden)
|
||||||
|
candidate_embeddings = jnp.dot(hidden.astype(proj_2.dtype), proj_2)
|
||||||
|
|
||||||
|
candidate_norm_sq = jnp.sum(candidate_embeddings**2, axis=-1, keepdims=True)
|
||||||
|
candidate_norm = jnp.sqrt(jnp.maximum(candidate_norm_sq, EPS))
|
||||||
|
candidate_representation = candidate_embeddings / candidate_norm
|
||||||
|
|
||||||
|
return candidate_representation.astype(post_author_embedding.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PhoenixRetrievalModelConfig:
|
||||||
|
"""Configuration for the Phoenix Retrieval Model.
|
||||||
|
|
||||||
|
This model uses the same transformer architecture as the Phoenix ranker
|
||||||
|
for encoding user representations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model: TransformerConfig
|
||||||
|
emb_size: int
|
||||||
|
history_seq_len: int = 128
|
||||||
|
candidate_seq_len: int = 32
|
||||||
|
|
||||||
|
name: Optional[str] = None
|
||||||
|
fprop_dtype: Any = jnp.bfloat16
|
||||||
|
|
||||||
|
hash_config: HashConfig = None # type: ignore
|
||||||
|
|
||||||
|
product_surface_vocab_size: int = 16
|
||||||
|
|
||||||
|
_initialized: bool = False
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.hash_config is None:
|
||||||
|
self.hash_config = HashConfig()
|
||||||
|
|
||||||
|
def initialize(self):
|
||||||
|
self._initialized = True
|
||||||
|
return self
|
||||||
|
|
||||||
|
def make(self):
|
||||||
|
if not self._initialized:
|
||||||
|
logger.warning(f"PhoenixRetrievalModel {self.name} is not initialized. Initializing.")
|
||||||
|
self.initialize()
|
||||||
|
|
||||||
|
return PhoenixRetrievalModel(
|
||||||
|
model=self.model.make(),
|
||||||
|
config=self,
|
||||||
|
fprop_dtype=self.fprop_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PhoenixRetrievalModel(hk.Module):
|
||||||
|
"""A two-tower retrieval model using the Phoenix transformer for user encoding.
|
||||||
|
|
||||||
|
This model implements the two-tower architecture for efficient retrieval:
|
||||||
|
- User Tower: Encodes user features + history using the Phoenix transformer
|
||||||
|
- Candidate Tower: Projects candidate embeddings to a shared space
|
||||||
|
|
||||||
|
The user and candidate representations are L2-normalized, enabling efficient
|
||||||
|
approximate nearest neighbor (ANN) search using dot product similarity.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model: Transformer
|
||||||
|
config: PhoenixRetrievalModelConfig
|
||||||
|
fprop_dtype: Any = jnp.bfloat16
|
||||||
|
name: Optional[str] = None
|
||||||
|
|
||||||
|
def _get_action_embeddings(
|
||||||
|
self,
|
||||||
|
actions: jax.Array,
|
||||||
|
) -> jax.Array:
|
||||||
|
"""Convert multi-hot action vectors to embeddings."""
|
||||||
|
config = self.config
|
||||||
|
_, _, num_actions = actions.shape
|
||||||
|
D = config.emb_size
|
||||||
|
|
||||||
|
embed_init = hk.initializers.VarianceScaling(1.0, mode="fan_out")
|
||||||
|
action_projection = hk.get_parameter(
|
||||||
|
"action_projection",
|
||||||
|
[num_actions, D],
|
||||||
|
dtype=jnp.float32,
|
||||||
|
init=embed_init,
|
||||||
|
)
|
||||||
|
|
||||||
|
actions_signed = (2 * actions - 1).astype(jnp.float32)
|
||||||
|
action_emb = jnp.dot(actions_signed.astype(action_projection.dtype), action_projection)
|
||||||
|
|
||||||
|
valid_mask = jnp.any(actions, axis=-1, keepdims=True)
|
||||||
|
action_emb = action_emb * valid_mask
|
||||||
|
|
||||||
|
return action_emb.astype(self.fprop_dtype)
|
||||||
|
|
||||||
|
def _single_hot_to_embeddings(
|
||||||
|
self,
|
||||||
|
input: jax.Array,
|
||||||
|
vocab_size: int,
|
||||||
|
emb_size: int,
|
||||||
|
name: str,
|
||||||
|
) -> jax.Array:
|
||||||
|
"""Convert single-hot indices to embeddings via lookup table."""
|
||||||
|
embed_init = hk.initializers.VarianceScaling(1.0, mode="fan_out")
|
||||||
|
embedding_table = hk.get_parameter(
|
||||||
|
name,
|
||||||
|
[vocab_size, emb_size],
|
||||||
|
dtype=jnp.float32,
|
||||||
|
init=embed_init,
|
||||||
|
)
|
||||||
|
|
||||||
|
input_one_hot = jax.nn.one_hot(input, vocab_size)
|
||||||
|
output = jnp.dot(input_one_hot, embedding_table)
|
||||||
|
return output.astype(self.fprop_dtype)
|
||||||
|
|
||||||
|
def build_user_representation(
|
||||||
|
self,
|
||||||
|
batch: RecsysBatch,
|
||||||
|
recsys_embeddings: RecsysEmbeddings,
|
||||||
|
) -> Tuple[jax.Array, jax.Array]:
|
||||||
|
"""Build user representation from user features and history.
|
||||||
|
|
||||||
|
Uses the Phoenix transformer to encode user + history embeddings
|
||||||
|
into a single user representation vector.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch: RecsysBatch containing hashes, actions, product surfaces
|
||||||
|
recsys_embeddings: RecsysEmbeddings containing pre-looked-up embeddings
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
user_representation: L2-normalized user embedding [B, D]
|
||||||
|
user_norm: Pre-normalization L2 norm [B, 1]
|
||||||
|
"""
|
||||||
|
config = self.config
|
||||||
|
hash_config = config.hash_config
|
||||||
|
|
||||||
|
history_product_surface_embeddings = self._single_hot_to_embeddings(
|
||||||
|
batch.history_product_surface, # type: ignore
|
||||||
|
config.product_surface_vocab_size,
|
||||||
|
config.emb_size,
|
||||||
|
"product_surface_embedding_table",
|
||||||
|
)
|
||||||
|
|
||||||
|
history_actions_embeddings = self._get_action_embeddings(batch.history_actions) # type: ignore
|
||||||
|
|
||||||
|
user_embeddings, user_padding_mask = block_user_reduce(
|
||||||
|
batch.user_hashes, # type: ignore
|
||||||
|
recsys_embeddings.user_embeddings, # type: ignore
|
||||||
|
hash_config.num_user_hashes,
|
||||||
|
config.emb_size,
|
||||||
|
1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
history_embeddings, history_padding_mask = block_history_reduce(
|
||||||
|
batch.history_post_hashes, # type: ignore
|
||||||
|
recsys_embeddings.history_post_embeddings, # type: ignore
|
||||||
|
recsys_embeddings.history_author_embeddings, # type: ignore
|
||||||
|
history_product_surface_embeddings,
|
||||||
|
history_actions_embeddings,
|
||||||
|
hash_config.num_item_hashes,
|
||||||
|
hash_config.num_author_hashes,
|
||||||
|
1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
embeddings = jnp.concatenate([user_embeddings, history_embeddings], axis=1)
|
||||||
|
padding_mask = jnp.concatenate([user_padding_mask, history_padding_mask], axis=1)
|
||||||
|
|
||||||
|
model_output = self.model(
|
||||||
|
embeddings.astype(self.fprop_dtype),
|
||||||
|
padding_mask,
|
||||||
|
candidate_start_offset=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
user_outputs = model_output.embeddings
|
||||||
|
|
||||||
|
mask_float = padding_mask.astype(jnp.float32)[:, :, None] # [B, T, 1]
|
||||||
|
user_embeddings_masked = user_outputs * mask_float
|
||||||
|
user_embedding_sum = jnp.sum(user_embeddings_masked, axis=1) # [B, D]
|
||||||
|
mask_sum = jnp.sum(mask_float, axis=1) # [B, 1]
|
||||||
|
user_representation = user_embedding_sum / jnp.maximum(mask_sum, 1.0)
|
||||||
|
|
||||||
|
user_norm_sq = jnp.sum(user_representation**2, axis=-1, keepdims=True)
|
||||||
|
user_norm = jnp.sqrt(jnp.maximum(user_norm_sq, EPS))
|
||||||
|
user_representation = user_representation / user_norm
|
||||||
|
|
||||||
|
return user_representation, user_norm
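# Standalone sketch (made-up values) of the masked mean-pool + L2-normalize step
# above: padded positions are excluded from the mean, and the pooled vector is
# scaled to unit norm.
import jax.numpy as jnp

outputs = jnp.array([[[1.0, 0.0], [3.0, 4.0], [100.0, 100.0]]])  # [B=1, T=3, D=2]
padding_mask = jnp.array([[True, True, False]])                   # last position is padding
m = padding_mask.astype(jnp.float32)[:, :, None]
pooled = jnp.sum(outputs * m, axis=1) / jnp.maximum(jnp.sum(m, axis=1), 1.0)
user = pooled / jnp.sqrt(jnp.maximum(jnp.sum(pooled**2, axis=-1, keepdims=True), 1e-12))
print(pooled, jnp.linalg.norm(user, axis=-1))                     # [[2. 2.]] [1.]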
|
||||||
|
|
||||||
|
def build_candidate_representation(
|
||||||
|
self,
|
||||||
|
batch: RecsysBatch,
|
||||||
|
recsys_embeddings: RecsysEmbeddings,
|
||||||
|
) -> Tuple[jax.Array, jax.Array]:
|
||||||
|
"""Build candidate (item) representations.
|
||||||
|
|
||||||
|
Projects post + author embeddings to a shared embedding space
|
||||||
|
using the candidate tower MLP.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch: RecsysBatch containing candidate hashes
|
||||||
|
recsys_embeddings: RecsysEmbeddings containing pre-looked-up embeddings
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
candidate_representation: L2-normalized candidate embeddings [B, C, D]
|
||||||
|
candidate_padding_mask: Valid candidate mask [B, C]
|
||||||
|
"""
|
||||||
|
config = self.config
|
||||||
|
|
||||||
|
candidate_post_embeddings = recsys_embeddings.candidate_post_embeddings
|
||||||
|
candidate_author_embeddings = recsys_embeddings.candidate_author_embeddings
|
||||||
|
|
||||||
|
post_author_embedding = jnp.concatenate(
|
||||||
|
[candidate_post_embeddings, candidate_author_embeddings], axis=2
|
||||||
|
)
|
||||||
|
|
||||||
|
candidate_tower = CandidateTower(
|
||||||
|
emb_size=config.emb_size,
|
||||||
|
)
|
||||||
|
candidate_representation = candidate_tower(post_author_embedding)
|
||||||
|
|
||||||
|
candidate_padding_mask = (batch.candidate_post_hashes[:, :, 0] != 0).astype(jnp.bool_) # type: ignore
|
||||||
|
|
||||||
|
return candidate_representation, candidate_padding_mask
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
batch: RecsysBatch,
|
||||||
|
recsys_embeddings: RecsysEmbeddings,
|
||||||
|
corpus_embeddings: jax.Array,
|
||||||
|
top_k: int,
|
||||||
|
corpus_mask: Optional[jax.Array] = None,
|
||||||
|
) -> RetrievalOutput:
|
||||||
|
"""Retrieve top-k candidates from corpus for each user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch: RecsysBatch containing hashes, actions, product surfaces
|
||||||
|
recsys_embeddings: RecsysEmbeddings containing pre-looked-up embeddings
|
||||||
|
corpus_embeddings: [N, D] normalized corpus candidate embeddings
|
||||||
|
top_k: Number of candidates to retrieve
|
||||||
|
corpus_mask: [N] optional mask for valid corpus entries
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RetrievalOutput containing user representation and top-k results
|
||||||
|
"""
|
||||||
|
user_representation, _ = self.build_user_representation(batch, recsys_embeddings)
|
||||||
|
|
||||||
|
top_k_indices, top_k_scores = self._retrieve_top_k(
|
||||||
|
user_representation, corpus_embeddings, top_k, corpus_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
return RetrievalOutput(
|
||||||
|
user_representation=user_representation,
|
||||||
|
top_k_indices=top_k_indices,
|
||||||
|
top_k_scores=top_k_scores,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _retrieve_top_k(
|
||||||
|
self,
|
||||||
|
user_representation: jax.Array,
|
||||||
|
corpus_embeddings: jax.Array,
|
||||||
|
top_k: int,
|
||||||
|
corpus_mask: Optional[jax.Array] = None,
|
||||||
|
) -> Tuple[jax.Array, jax.Array]:
|
||||||
|
"""Retrieve top-k candidates from a corpus for each user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_representation: [B, D] normalized user embeddings
|
||||||
|
corpus_embeddings: [N, D] normalized corpus candidate embeddings
|
||||||
|
top_k: Number of candidates to retrieve
|
||||||
|
corpus_mask: [N] optional mask for valid corpus entries
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
top_k_indices: [B, K] indices of top-k candidates
|
||||||
|
top_k_scores: [B, K] similarity scores of top-k candidates
|
||||||
|
"""
|
||||||
|
scores = jnp.matmul(user_representation, corpus_embeddings.T)
|
||||||
|
|
||||||
|
if corpus_mask is not None:
|
||||||
|
scores = jnp.where(corpus_mask[None, :], scores, -INF)
|
||||||
|
|
||||||
|
top_k_scores, top_k_indices = jax.lax.top_k(scores, top_k)
|
||||||
|
|
||||||
|
return top_k_indices, top_k_scores
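# Small illustration of the dot-product top-k step above, with a hand-built
# 4-item corpus, one user vector, and the last corpus entry masked out
# (all numbers made up).
import jax
import jax.numpy as jnp

user = jnp.array([[1.0, 0.0]])                        # [B=1, D=2], unit norm
corpus = jnp.array([[1.0, 0.0],
                    [0.6, 0.8],
                    [0.0, 1.0],
                    [-1.0, 0.0]])                     # [N=4, D=2], rows unit norm
scores = jnp.matmul(user, corpus.T)                   # cosine similarity
scores = jnp.where(jnp.array([True, True, True, False])[None, :], scores, -1e12)
top_scores, top_idx = jax.lax.top_k(scores, 2)
print(top_idx, top_scores)                            # [[0 1]] [[1.  0.6]]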
|
||||||
121
phoenix/run_ranker.py
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
# Copyright 2026 X.AI Corp.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from grok import TransformerConfig
|
||||||
|
from recsys_model import PhoenixModelConfig, HashConfig
|
||||||
|
from runners import RecsysInferenceRunner, ModelRunner, create_example_batch, ACTIONS
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Model configuration
|
||||||
|
emb_size = 128 # Embedding dimension
|
||||||
|
num_actions = len(ACTIONS) # Number of explicit engagement actions
|
||||||
|
history_seq_len = 32 # Max history length
|
||||||
|
candidate_seq_len = 8 # Max candidates to rank
|
||||||
|
|
||||||
|
# Hash configuration
|
||||||
|
hash_config = HashConfig(
|
||||||
|
num_user_hashes=2,
|
||||||
|
num_item_hashes=2,
|
||||||
|
num_author_hashes=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
recsys_model = PhoenixModelConfig(
|
||||||
|
emb_size=emb_size,
|
||||||
|
num_actions=num_actions,
|
||||||
|
history_seq_len=history_seq_len,
|
||||||
|
candidate_seq_len=candidate_seq_len,
|
||||||
|
hash_config=hash_config,
|
||||||
|
product_surface_vocab_size=16,
|
||||||
|
model=TransformerConfig(
|
||||||
|
emb_size=emb_size,
|
||||||
|
widening_factor=2,
|
||||||
|
key_size=64,
|
||||||
|
num_q_heads=2,
|
||||||
|
num_kv_heads=2,
|
||||||
|
num_layers=2,
|
||||||
|
attn_output_multiplier=0.125,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create inference runner
|
||||||
|
inference_runner = RecsysInferenceRunner(
|
||||||
|
runner=ModelRunner(
|
||||||
|
model=recsys_model,
|
||||||
|
bs_per_device=0.125,
|
||||||
|
),
|
||||||
|
name="recsys_local",
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Initializing model...")
|
||||||
|
inference_runner.initialize()
|
||||||
|
print("Model initialized!")
|
||||||
|
|
||||||
|
# Create example batch with simulated posts
|
||||||
|
print("\n" + "=" * 70)
|
||||||
|
print("RECOMMENDATION SYSTEM DEMO")
|
||||||
|
print("=" * 70)
|
||||||
|
|
||||||
|
batch_size = 1
|
||||||
|
example_batch, example_embeddings = create_example_batch(
|
||||||
|
batch_size=batch_size,
|
||||||
|
emb_size=emb_size,
|
||||||
|
history_len=history_seq_len,
|
||||||
|
num_candidates=candidate_seq_len,
|
||||||
|
num_actions=num_actions,
|
||||||
|
num_user_hashes=hash_config.num_user_hashes,
|
||||||
|
num_item_hashes=hash_config.num_item_hashes,
|
||||||
|
num_author_hashes=hash_config.num_author_hashes,
|
||||||
|
product_surface_vocab_size=recsys_model.product_surface_vocab_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
action_names = [action.replace("_", " ").title() for action in ACTIONS]
|
||||||
|
|
||||||
|
# Count valid history items (where first post hash is non-zero)
|
||||||
|
valid_history_count = int((example_batch.history_post_hashes[:, :, 0] != 0).sum()) # type: ignore
|
||||||
|
print(f"\nUser has viewed {valid_history_count} posts in their history")
|
||||||
|
print(f"Ranking {candidate_seq_len} candidate posts...")
|
||||||
|
|
||||||
|
# Rank candidates
|
||||||
|
ranking_output = inference_runner.rank(example_batch, example_embeddings)
|
||||||
|
|
||||||
|
# Display results
|
||||||
|
scores = np.array(ranking_output.scores[0]) # [num_candidates, num_actions]
|
||||||
|
ranked_indices = np.array(ranking_output.ranked_indices[0]) # [num_candidates]
|
||||||
|
|
||||||
|
print("\n" + "-" * 70)
|
||||||
|
print("RANKING RESULTS (ordered by predicted 'Favorite Score' probability)")
|
||||||
|
print("-" * 70)
|
||||||
|
|
||||||
|
for rank, idx in enumerate(ranked_indices):
|
||||||
|
idx = int(idx)
|
||||||
|
print(f"\nRank {rank + 1}: ")
|
||||||
|
print(" Predicted engagement probabilities:")
|
||||||
|
for action_idx, action_name in enumerate(action_names):
|
||||||
|
prob = float(scores[idx, action_idx])
|
||||||
|
bar = "█" * int(prob * 20) + "░" * (20 - int(prob * 20))
|
||||||
|
print(f" {action_name:24s}: {bar} {prob:.3f}")
|
||||||
|
|
||||||
|
print("\n" + "=" * 70)
|
||||||
|
print("Demo complete!")
|
||||||
|
print("=" * 70)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
main()
|
||||||
149
phoenix/run_retrieval.py
Normal file
@ -0,0 +1,149 @@
|
|||||||
|
# Copyright 2026 X.AI Corp.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from grok import TransformerConfig
|
||||||
|
from recsys_model import HashConfig
|
||||||
|
from recsys_retrieval_model import PhoenixRetrievalModelConfig
|
||||||
|
from runners import (
|
||||||
|
RecsysRetrievalInferenceRunner,
|
||||||
|
RetrievalModelRunner,
|
||||||
|
create_example_batch,
|
||||||
|
create_example_corpus,
|
||||||
|
ACTIONS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Model configuration - same architecture as Phoenix ranker
|
||||||
|
emb_size = 128 # Embedding dimension
|
||||||
|
num_actions = len(ACTIONS) # Number of explicit engagement actions
|
||||||
|
history_seq_len = 32 # Max history length
|
||||||
|
candidate_seq_len = 8 # Max candidates per batch (for training)
|
||||||
|
|
||||||
|
# Hash configuration
|
||||||
|
hash_config = HashConfig(
|
||||||
|
num_user_hashes=2,
|
||||||
|
num_item_hashes=2,
|
||||||
|
num_author_hashes=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Configure the retrieval model - uses same transformer as Phoenix
|
||||||
|
retrieval_model_config = PhoenixRetrievalModelConfig(
|
||||||
|
emb_size=emb_size,
|
||||||
|
history_seq_len=history_seq_len,
|
||||||
|
candidate_seq_len=candidate_seq_len,
|
||||||
|
hash_config=hash_config,
|
||||||
|
product_surface_vocab_size=16,
|
||||||
|
model=TransformerConfig(
|
||||||
|
emb_size=emb_size,
|
||||||
|
widening_factor=2,
|
||||||
|
key_size=64,
|
||||||
|
num_q_heads=2,
|
||||||
|
num_kv_heads=2,
|
||||||
|
num_layers=2,
|
||||||
|
attn_output_multiplier=0.125,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create inference runner
|
||||||
|
inference_runner = RecsysRetrievalInferenceRunner(
|
||||||
|
runner=RetrievalModelRunner(
|
||||||
|
model=retrieval_model_config,
|
||||||
|
bs_per_device=0.125,
|
||||||
|
),
|
||||||
|
name="retrieval_local",
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Initializing retrieval model...")
|
||||||
|
inference_runner.initialize()
|
||||||
|
print("Model initialized!")
|
||||||
|
|
||||||
|
# Create example batch with simulated user and history
|
||||||
|
print("\n" + "=" * 70)
|
||||||
|
print("RETRIEVAL SYSTEM DEMO")
|
||||||
|
print("=" * 70)
|
||||||
|
|
||||||
|
batch_size = 2 # Two users for demo
|
||||||
|
example_batch, example_embeddings = create_example_batch(
|
||||||
|
batch_size=batch_size,
|
||||||
|
emb_size=emb_size,
|
||||||
|
history_len=history_seq_len,
|
||||||
|
num_candidates=candidate_seq_len,
|
||||||
|
num_actions=num_actions,
|
||||||
|
num_user_hashes=hash_config.num_user_hashes,
|
||||||
|
num_item_hashes=hash_config.num_item_hashes,
|
||||||
|
num_author_hashes=hash_config.num_author_hashes,
|
||||||
|
product_surface_vocab_size=16,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Count valid history items
|
||||||
|
valid_history_count = int((example_batch.history_post_hashes[:, :, 0] != 0).sum()) # type: ignore
|
||||||
|
print(f"\nUsers have viewed {valid_history_count} posts total in their history")
|
||||||
|
|
||||||
|
# Step 1: Create a corpus of candidate posts
|
||||||
|
print("\n" + "-" * 70)
|
||||||
|
print("STEP 1: Creating Candidate Corpus")
|
||||||
|
print("-" * 70)
|
||||||
|
|
||||||
|
corpus_size = 1000 # Simulated corpus of 1000 posts
|
||||||
|
corpus_embeddings, corpus_post_ids = create_example_corpus(
|
||||||
|
corpus_size=corpus_size,
|
||||||
|
emb_size=emb_size,
|
||||||
|
seed=456,
|
||||||
|
)
|
||||||
|
print(f"Corpus size: {corpus_size} posts")
|
||||||
|
print(f"Corpus embeddings shape: {corpus_embeddings.shape}")
|
||||||
|
|
||||||
|
# Set corpus for retrieval
|
||||||
|
inference_runner.set_corpus(corpus_embeddings, corpus_post_ids)
|
||||||
|
|
||||||
|
# Step 2: Retrieve top-k candidates for each user
|
||||||
|
print("\n" + "-" * 70)
|
||||||
|
print("STEP 2: Retrieving Top-K Candidates")
|
||||||
|
print("-" * 70)
|
||||||
|
|
||||||
|
top_k = 10
|
||||||
|
retrieval_output = inference_runner.retrieve(
|
||||||
|
example_batch,
|
||||||
|
example_embeddings,
|
||||||
|
top_k=top_k,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\nRetrieved top {top_k} candidates for each of {batch_size} users:")
|
||||||
|
|
||||||
|
top_k_indices = np.array(retrieval_output.top_k_indices)
|
||||||
|
top_k_scores = np.array(retrieval_output.top_k_scores)
|
||||||
|
|
||||||
|
for user_idx in range(batch_size):
|
||||||
|
print(f"\n User {user_idx + 1}:")
|
||||||
|
print(f" {'Rank':<6} {'Post ID':<12} {'Score':<12}")
|
||||||
|
print(f" {'-' * 30}")
|
||||||
|
for rank in range(top_k):
|
||||||
|
post_id = top_k_indices[user_idx, rank]
|
||||||
|
score = top_k_scores[user_idx, rank]
|
||||||
|
bar = "█" * int((score + 1) * 10) + "░" * (20 - int((score + 1) * 10))
|
||||||
|
print(f" {rank + 1:<6} {post_id:<12} {bar} {score:.4f}")
|
||||||
|
|
||||||
|
print("\n" + "=" * 70)
|
||||||
|
print("Demo complete!")
|
||||||
|
print("=" * 70)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
main()
|
||||||
729
phoenix/runners.py
Normal file
@ -0,0 +1,729 @@
|
|||||||
|
# Copyright 2026 X.AI Corp.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, List, NamedTuple, Optional, Tuple
|
||||||
|
|
||||||
|
import haiku as hk
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from grok import TrainingState
|
||||||
|
from recsys_retrieval_model import PhoenixRetrievalModelConfig
|
||||||
|
from recsys_retrieval_model import RetrievalOutput as ModelRetrievalOutput
|
||||||
|
|
||||||
|
from recsys_model import (
|
||||||
|
PhoenixModelConfig,
|
||||||
|
RecsysBatch,
|
||||||
|
RecsysEmbeddings,
|
||||||
|
RecsysModelOutput,
|
||||||
|
)
|
||||||
|
|
||||||
|
rank_logger = logging.getLogger("rank")
|
||||||
|
|
||||||
|
|
||||||
|
def create_dummy_batch_from_config(
|
||||||
|
hash_config: Any,
|
||||||
|
history_len: int,
|
||||||
|
num_candidates: int,
|
||||||
|
num_actions: int,
|
||||||
|
batch_size: int = 1,
|
||||||
|
) -> RecsysBatch:
|
||||||
|
"""Create a dummy batch for initialization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hash_config: HashConfig with num_user_hashes, num_item_hashes, num_author_hashes
|
||||||
|
history_len: History sequence length
|
||||||
|
num_candidates: Number of candidates
|
||||||
|
num_actions: Number of action types
|
||||||
|
batch_size: Batch size
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RecsysBatch with zeros
|
||||||
|
"""
|
||||||
|
return RecsysBatch(
|
||||||
|
user_hashes=np.zeros((batch_size, hash_config.num_user_hashes), dtype=np.int32),
|
||||||
|
history_post_hashes=np.zeros(
|
||||||
|
(batch_size, history_len, hash_config.num_item_hashes), dtype=np.int32
|
||||||
|
),
|
||||||
|
history_author_hashes=np.zeros(
|
||||||
|
(batch_size, history_len, hash_config.num_author_hashes), dtype=np.int32
|
||||||
|
),
|
||||||
|
history_actions=np.zeros((batch_size, history_len, num_actions), dtype=np.float32),
|
||||||
|
history_product_surface=np.zeros((batch_size, history_len), dtype=np.int32),
|
||||||
|
candidate_post_hashes=np.zeros(
|
||||||
|
(batch_size, num_candidates, hash_config.num_item_hashes), dtype=np.int32
|
||||||
|
),
|
||||||
|
candidate_author_hashes=np.zeros(
|
||||||
|
(batch_size, num_candidates, hash_config.num_author_hashes), dtype=np.int32
|
||||||
|
),
|
||||||
|
candidate_product_surface=np.zeros((batch_size, num_candidates), dtype=np.int32),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_dummy_embeddings_from_config(
|
||||||
|
hash_config: Any,
|
||||||
|
emb_size: int,
|
||||||
|
history_len: int,
|
||||||
|
num_candidates: int,
|
||||||
|
batch_size: int = 1,
|
||||||
|
) -> RecsysEmbeddings:
|
||||||
|
"""Create dummy embeddings for initialization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hash_config: HashConfig with num_user_hashes, num_item_hashes, num_author_hashes
|
||||||
|
emb_size: Embedding dimension
|
||||||
|
history_len: History sequence length
|
||||||
|
num_candidates: Number of candidates
|
||||||
|
batch_size: Batch size
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RecsysEmbeddings with zeros
|
||||||
|
"""
|
||||||
|
return RecsysEmbeddings(
|
||||||
|
user_embeddings=np.zeros(
|
||||||
|
(batch_size, hash_config.num_user_hashes, emb_size), dtype=np.float32
|
||||||
|
),
|
||||||
|
history_post_embeddings=np.zeros(
|
||||||
|
(batch_size, history_len, hash_config.num_item_hashes, emb_size), dtype=np.float32
|
||||||
|
),
|
||||||
|
candidate_post_embeddings=np.zeros(
|
||||||
|
(batch_size, num_candidates, hash_config.num_item_hashes, emb_size),
|
||||||
|
dtype=np.float32,
|
||||||
|
),
|
||||||
|
history_author_embeddings=np.zeros(
|
||||||
|
(batch_size, history_len, hash_config.num_author_hashes, emb_size), dtype=np.float32
|
||||||
|
),
|
||||||
|
candidate_author_embeddings=np.zeros(
|
||||||
|
(batch_size, num_candidates, hash_config.num_author_hashes, emb_size),
|
||||||
|
dtype=np.float32,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BaseModelRunner(ABC):
|
||||||
|
"""Base class for model runners with shared initialization logic."""
|
||||||
|
|
||||||
|
bs_per_device: float = 2.0
|
||||||
|
rng_seed: int = 42
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def model(self) -> Any:
|
||||||
|
"""Return the model config."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _model_name(self) -> str:
|
||||||
|
"""Return model name for logging."""
|
||||||
|
return "model"
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def make_forward_fn(self):
|
||||||
|
"""Create the forward function. Must be implemented by subclasses."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def initialize(self):
|
||||||
|
"""Initialize the model runner."""
|
||||||
|
self.model.initialize()
|
||||||
|
self.model.fprop_dtype = jnp.bfloat16
|
||||||
|
num_local_gpus = len(jax.local_devices())
|
||||||
|
|
||||||
|
self.batch_size = max(1, int(self.bs_per_device * num_local_gpus))
|
||||||
|
|
||||||
|
rank_logger.info(f"Initializing {self._model_name}...")
|
||||||
|
self.forward = self.make_forward_fn()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BaseInferenceRunner(ABC):
|
||||||
|
"""Base class for inference runners with shared dummy data creation."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def runner(self) -> BaseModelRunner:
|
||||||
|
"""Return the underlying model runner."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _get_num_actions(self) -> int:
|
||||||
|
"""Get number of actions. Override in subclasses if needed."""
|
||||||
|
model_config = self.runner.model
|
||||||
|
if hasattr(model_config, "num_actions"):
|
||||||
|
return model_config.num_actions
|
||||||
|
return 19  # fallback: matches len(ACTIONS) below
|
||||||
|
|
||||||
|
def create_dummy_batch(self, batch_size: int = 1) -> RecsysBatch:
|
||||||
|
"""Create a dummy batch for initialization."""
|
||||||
|
model_config = self.runner.model
|
||||||
|
return create_dummy_batch_from_config(
|
||||||
|
hash_config=model_config.hash_config,
|
||||||
|
history_len=model_config.history_seq_len,
|
||||||
|
num_candidates=model_config.candidate_seq_len,
|
||||||
|
num_actions=self._get_num_actions(),
|
||||||
|
batch_size=batch_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_dummy_embeddings(self, batch_size: int = 1) -> RecsysEmbeddings:
|
||||||
|
"""Create dummy embeddings for initialization."""
|
||||||
|
model_config = self.runner.model
|
||||||
|
return create_dummy_embeddings_from_config(
|
||||||
|
hash_config=model_config.hash_config,
|
||||||
|
emb_size=model_config.emb_size,
|
||||||
|
history_len=model_config.history_seq_len,
|
||||||
|
num_candidates=model_config.candidate_seq_len,
|
||||||
|
batch_size=batch_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def initialize(self):
|
||||||
|
"""Initialize the inference runner. Must be implemented by subclasses."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
ACTIONS: List[str] = [
|
||||||
|
"favorite_score",
|
||||||
|
"reply_score",
|
||||||
|
"repost_score",
|
||||||
|
"photo_expand_score",
|
||||||
|
"click_score",
|
||||||
|
"profile_click_score",
|
||||||
|
"vqv_score",
|
||||||
|
"share_score",
|
||||||
|
"share_via_dm_score",
|
||||||
|
"share_via_copy_link_score",
|
||||||
|
"dwell_score",
|
||||||
|
"quote_score",
|
||||||
|
"quoted_click_score",
|
||||||
|
"follow_author_score",
|
||||||
|
"not_interested_score",
|
||||||
|
"block_author_score",
|
||||||
|
"mute_author_score",
|
||||||
|
"report_score",
|
||||||
|
"dwell_time",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class RankingOutput(NamedTuple):
|
||||||
|
"""Output from ranking candidates.
|
||||||
|
|
||||||
|
Contains both the raw scores array and individual probability fields
|
||||||
|
for each engagement type.
|
||||||
|
"""
|
||||||
|
|
||||||
|
scores: jax.Array
|
||||||
|
|
||||||
|
ranked_indices: jax.Array
|
||||||
|
|
||||||
|
p_favorite_score: jax.Array
|
||||||
|
p_reply_score: jax.Array
|
||||||
|
p_repost_score: jax.Array
|
||||||
|
p_photo_expand_score: jax.Array
|
||||||
|
p_click_score: jax.Array
|
||||||
|
p_profile_click_score: jax.Array
|
||||||
|
p_vqv_score: jax.Array
|
||||||
|
p_share_score: jax.Array
|
||||||
|
p_share_via_dm_score: jax.Array
|
||||||
|
p_share_via_copy_link_score: jax.Array
|
||||||
|
p_dwell_score: jax.Array
|
||||||
|
p_quote_score: jax.Array
|
||||||
|
p_quoted_click_score: jax.Array
|
||||||
|
p_follow_author_score: jax.Array
|
||||||
|
p_not_interested_score: jax.Array
|
||||||
|
p_block_author_score: jax.Array
|
||||||
|
p_mute_author_score: jax.Array
|
||||||
|
p_report_score: jax.Array
|
||||||
|
p_dwell_time: jax.Array
|
||||||
|
|
||||||
|
|
||||||
|
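# Usage sketch (illustrative): the last axis of `scores` follows the order of
# ACTIONS above, which is how hk_rank_candidates below fills the p_* fields, so
# per-action probabilities for one (user, candidate) pair can be read out by name.
def scores_as_dict(output: RankingOutput, user_idx: int, candidate_idx: int) -> dict:
    """Map each action name to its predicted score for one (user, candidate) pair."""
    return {
        action: float(output.scores[user_idx, candidate_idx, i])
        for i, action in enumerate(ACTIONS)
    }
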
@dataclass
|
||||||
|
class ModelRunner(BaseModelRunner):
|
||||||
|
"""Runner for the recommendation ranking model."""
|
||||||
|
|
||||||
|
_model: PhoenixModelConfig = None # type: ignore
|
||||||
|
|
||||||
|
def __init__(self, model: PhoenixModelConfig, bs_per_device: float = 2.0, rng_seed: int = 42):
|
||||||
|
self._model = model
|
||||||
|
self.bs_per_device = bs_per_device
|
||||||
|
self.rng_seed = rng_seed
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model(self) -> PhoenixModelConfig:
|
||||||
|
return self._model
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _model_name(self) -> str:
|
||||||
|
return "ranking model"
|
||||||
|
|
||||||
|
def make_forward_fn(self): # type: ignore
|
||||||
|
def forward(batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings):
|
||||||
|
out = self.model.make()(batch, recsys_embeddings)
|
||||||
|
return out
|
||||||
|
|
||||||
|
return hk.transform(forward)
|
||||||
|
|
||||||
|
def init(
|
||||||
|
self, rng: jax.Array, data: RecsysBatch, embeddings: RecsysEmbeddings
|
||||||
|
) -> TrainingState:
|
||||||
|
assert self.forward is not None
|
||||||
|
rng, init_rng = jax.random.split(rng)
|
||||||
|
params = self.forward.init(init_rng, data, embeddings)
|
||||||
|
return TrainingState(params=params)
|
||||||
|
|
||||||
|
def load_or_init(
|
||||||
|
self,
|
||||||
|
init_data: RecsysBatch,
|
||||||
|
init_embeddings: RecsysEmbeddings,
|
||||||
|
):
|
||||||
|
rng = jax.random.PRNGKey(self.rng_seed)
|
||||||
|
state = self.init(rng, init_data, init_embeddings)
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RecsysInferenceRunner(BaseInferenceRunner):
|
||||||
|
"""Inference runner for the recommendation ranking model."""
|
||||||
|
|
||||||
|
_runner: ModelRunner
|
||||||
|
|
||||||
|
def __init__(self, runner: ModelRunner, name: str):
|
||||||
|
self.name = name
|
||||||
|
self._runner = runner
|
||||||
|
|
||||||
|
@property
|
||||||
|
def runner(self) -> ModelRunner:
|
||||||
|
return self._runner
|
||||||
|
|
||||||
|
def initialize(self):
|
||||||
|
"""Initialize the inference runner."""
|
||||||
|
runner = self.runner
|
||||||
|
|
||||||
|
dummy_batch = self.create_dummy_batch(batch_size=1)
|
||||||
|
dummy_embeddings = self.create_dummy_embeddings(batch_size=1)
|
||||||
|
|
||||||
|
runner.initialize()
|
||||||
|
|
||||||
|
state = runner.load_or_init(dummy_batch, dummy_embeddings)
|
||||||
|
self.params = state.params
|
||||||
|
|
||||||
|
@functools.lru_cache
|
||||||
|
def model():
|
||||||
|
return runner.model.make()
|
||||||
|
|
||||||
|
def hk_forward(
|
||||||
|
batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings
|
||||||
|
) -> RecsysModelOutput:
|
||||||
|
return model()(batch, recsys_embeddings)
|
||||||
|
|
||||||
|
def hk_rank_candidates(
|
||||||
|
batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings
|
||||||
|
) -> RankingOutput:
|
||||||
|
"""Rank candidates by their predicted engagement scores."""
|
||||||
|
output = hk_forward(batch, recsys_embeddings)
|
||||||
|
logits = output.logits
|
||||||
|
|
||||||
|
probs = jax.nn.sigmoid(logits)
|
||||||
|
|
||||||
|
primary_scores = probs[:, :, 0]
|
||||||
|
|
||||||
|
ranked_indices = jnp.argsort(-primary_scores, axis=-1)
|
||||||
|
|
||||||
|
return RankingOutput(
|
||||||
|
scores=probs,
|
||||||
|
ranked_indices=ranked_indices,
|
||||||
|
p_favorite_score=probs[:, :, 0],
|
||||||
|
p_reply_score=probs[:, :, 1],
|
||||||
|
p_repost_score=probs[:, :, 2],
|
||||||
|
p_photo_expand_score=probs[:, :, 3],
|
||||||
|
p_click_score=probs[:, :, 4],
|
||||||
|
p_profile_click_score=probs[:, :, 5],
|
||||||
|
p_vqv_score=probs[:, :, 6],
|
||||||
|
p_share_score=probs[:, :, 7],
|
||||||
|
p_share_via_dm_score=probs[:, :, 8],
|
||||||
|
p_share_via_copy_link_score=probs[:, :, 9],
|
||||||
|
p_dwell_score=probs[:, :, 10],
|
||||||
|
p_quote_score=probs[:, :, 11],
|
||||||
|
p_quoted_click_score=probs[:, :, 12],
|
||||||
|
p_follow_author_score=probs[:, :, 13],
|
||||||
|
p_not_interested_score=probs[:, :, 14],
|
||||||
|
p_block_author_score=probs[:, :, 15],
|
||||||
|
p_mute_author_score=probs[:, :, 16],
|
||||||
|
p_report_score=probs[:, :, 17],
|
||||||
|
p_dwell_time=probs[:, :, 18],
|
||||||
|
)
|
||||||
|
|
||||||
|
rank_ = hk.without_apply_rng(hk.transform(hk_rank_candidates))
|
||||||
|
self.rank_candidates = rank_.apply
|
||||||
|
|
||||||
|
def rank(self, batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings) -> RankingOutput:
|
||||||
|
"""Rank candidates for the given batch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch: RecsysBatch containing hashes, actions, product surfaces
|
||||||
|
recsys_embeddings: RecsysEmbeddings containing pre-looked-up embeddings
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RankingOutput with scores and ranked indices
|
||||||
|
"""
|
||||||
|
return self.rank_candidates(self.params, batch, recsys_embeddings)
|
||||||
|
|
||||||
|
|
||||||
|
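# Usage sketch (illustrative): wiring the ranking path end to end. Assumes a
# PhoenixModelConfig (defined in recsys_model.py) whose emb_size, history_seq_len,
# candidate_seq_len and hash_config fields are the ones read elsewhere in this file.
def example_rank_usage(config: PhoenixModelConfig) -> RankingOutput:
    runner = RecsysInferenceRunner(
        runner=ModelRunner(model=config, bs_per_device=0.125),
        name="rank_example",
    )
    runner.initialize()
    # create_example_batch is defined just below; 19 matches len(ACTIONS).
    batch, embeddings = create_example_batch(
        batch_size=2,
        emb_size=config.emb_size,
        history_len=config.history_seq_len,
        num_candidates=config.candidate_seq_len,
        num_actions=19,
        num_user_hashes=config.hash_config.num_user_hashes,
        num_item_hashes=config.hash_config.num_item_hashes,
        num_author_hashes=config.hash_config.num_author_hashes,
    )
    # ranked_indices orders candidates by the primary (favorite) probability.
    return runner.rank(batch, embeddings)
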
def create_example_batch(
|
||||||
|
batch_size: int,
|
||||||
|
emb_size: int,
|
||||||
|
history_len: int,
|
||||||
|
num_candidates: int,
|
||||||
|
num_actions: int,
|
||||||
|
num_user_hashes: int = 2,
|
||||||
|
num_item_hashes: int = 2,
|
||||||
|
num_author_hashes: int = 2,
|
||||||
|
product_surface_vocab_size: int = 16,
|
||||||
|
num_user_embeddings: int = 100000,
|
||||||
|
num_post_embeddings: int = 100000,
|
||||||
|
num_author_embeddings: int = 100000,
|
||||||
|
) -> Tuple[RecsysBatch, RecsysEmbeddings]:
|
||||||
|
"""Create an example batch with random data for testing.
|
||||||
|
|
||||||
|
This simulates a recommendation scenario where:
|
||||||
|
- We have a user with some embedding
|
||||||
|
- The user has interacted with some posts in their history
|
||||||
|
- We want to rank a set of candidate posts
|
||||||
|
|
||||||
|
Note on embedding table sizes:
|
||||||
|
The num_*_embeddings parameters define the size of the embedding tables for each
|
||||||
|
entity type. Hash values are generated in the range [1, num_*_embeddings) to ensure
|
||||||
|
they can be used as valid indices into the corresponding embedding tables.
|
||||||
|
Hash value 0 is reserved for padding/invalid entries.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (RecsysBatch, RecsysEmbeddings)
|
||||||
|
"""
|
||||||
|
rng = np.random.default_rng(42)
|
||||||
|
|
||||||
|
user_hashes = rng.integers(1, num_user_embeddings, size=(batch_size, num_user_hashes)).astype(
|
||||||
|
np.int32
|
||||||
|
)
|
||||||
|
|
||||||
|
history_post_hashes = rng.integers(
|
||||||
|
1, num_post_embeddings, size=(batch_size, history_len, num_item_hashes)
|
||||||
|
).astype(np.int32)
|
||||||
|
|
||||||
|
for b in range(batch_size):
|
||||||
|
valid_len = rng.integers(history_len // 2, history_len + 1)
|
||||||
|
history_post_hashes[b, valid_len:, :] = 0
|
||||||
|
|
||||||
|
history_author_hashes = rng.integers(
|
||||||
|
1, num_author_embeddings, size=(batch_size, history_len, num_author_hashes)
|
||||||
|
).astype(np.int32)
|
||||||
|
for b in range(batch_size):
|
||||||
|
valid_len = rng.integers(history_len // 2, history_len + 1)
|
||||||
|
history_author_hashes[b, valid_len:, :] = 0
|
||||||
|
|
||||||
|
history_actions = (rng.random(size=(batch_size, history_len, num_actions)) > 0.7).astype(
|
||||||
|
np.float32
|
||||||
|
)
|
||||||
|
|
||||||
|
history_product_surface = rng.integers(
|
||||||
|
0, product_surface_vocab_size, size=(batch_size, history_len)
|
||||||
|
).astype(np.int32)
|
||||||
|
|
||||||
|
candidate_post_hashes = rng.integers(
|
||||||
|
1, num_post_embeddings, size=(batch_size, num_candidates, num_item_hashes)
|
||||||
|
).astype(np.int32)
|
||||||
|
|
||||||
|
candidate_author_hashes = rng.integers(
|
||||||
|
1, num_author_embeddings, size=(batch_size, num_candidates, num_author_hashes)
|
||||||
|
).astype(np.int32)
|
||||||
|
|
||||||
|
candidate_product_surface = rng.integers(
|
||||||
|
0, product_surface_vocab_size, size=(batch_size, num_candidates)
|
||||||
|
).astype(np.int32)
|
||||||
|
|
||||||
|
batch = RecsysBatch(
|
||||||
|
user_hashes=user_hashes,
|
||||||
|
history_post_hashes=history_post_hashes,
|
||||||
|
history_author_hashes=history_author_hashes,
|
||||||
|
history_actions=history_actions,
|
||||||
|
history_product_surface=history_product_surface,
|
||||||
|
candidate_post_hashes=candidate_post_hashes,
|
||||||
|
candidate_author_hashes=candidate_author_hashes,
|
||||||
|
candidate_product_surface=candidate_product_surface,
|
||||||
|
)
|
||||||
|
|
||||||
|
embeddings = RecsysEmbeddings(
|
||||||
|
user_embeddings=rng.normal(size=(batch_size, num_user_hashes, emb_size)).astype(np.float32),
|
||||||
|
history_post_embeddings=rng.normal(
|
||||||
|
size=(batch_size, history_len, num_item_hashes, emb_size)
|
||||||
|
).astype(np.float32),
|
||||||
|
candidate_post_embeddings=rng.normal(
|
||||||
|
size=(batch_size, num_candidates, num_item_hashes, emb_size)
|
||||||
|
).astype(np.float32),
|
||||||
|
history_author_embeddings=rng.normal(
|
||||||
|
size=(batch_size, history_len, num_author_hashes, emb_size)
|
||||||
|
).astype(np.float32),
|
||||||
|
candidate_author_embeddings=rng.normal(
|
||||||
|
size=(batch_size, num_candidates, num_author_hashes, emb_size)
|
||||||
|
).astype(np.float32),
|
||||||
|
)
|
||||||
|
|
||||||
|
return batch, embeddings
|
||||||
|
|
||||||
|
|
||||||
|
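# Usage sketch (illustrative): hash value 0 marks padded slots (see the docstring
# above), so a per-position validity mask for the history can be read off the batch.
def history_validity_mask(batch: RecsysBatch) -> np.ndarray:
    """Return a [B, H] boolean array that is True where a real post occupies the slot."""
    return np.asarray(batch.history_post_hashes)[:, :, 0] != 0
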
class RetrievalOutput(NamedTuple):
|
||||||
|
"""Output from retrieval inference.
|
||||||
|
|
||||||
|
Contains user representations and retrieved candidates.
|
||||||
|
"""
|
||||||
|
|
||||||
|
user_representation: jax.Array
|
||||||
|
|
||||||
|
top_k_indices: jax.Array
|
||||||
|
|
||||||
|
top_k_scores: jax.Array
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RetrievalModelRunner(BaseModelRunner):
|
||||||
|
"""Runner for the Phoenix retrieval model."""
|
||||||
|
|
||||||
|
_model: PhoenixRetrievalModelConfig = None # type: ignore
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: PhoenixRetrievalModelConfig,
|
||||||
|
bs_per_device: float = 2.0,
|
||||||
|
rng_seed: int = 42,
|
||||||
|
):
|
||||||
|
self._model = model
|
||||||
|
self.bs_per_device = bs_per_device
|
||||||
|
self.rng_seed = rng_seed
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model(self) -> PhoenixRetrievalModelConfig:
|
||||||
|
return self._model
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _model_name(self) -> str:
|
||||||
|
return "retrieval model"
|
||||||
|
|
||||||
|
def make_forward_fn(self): # type: ignore
|
||||||
|
def forward(
|
||||||
|
batch: RecsysBatch,
|
||||||
|
recsys_embeddings: RecsysEmbeddings,
|
||||||
|
corpus_embeddings: jax.Array,
|
||||||
|
top_k: int,
|
||||||
|
) -> ModelRetrievalOutput:
|
||||||
|
model = self.model.make()
|
||||||
|
out = model(batch, recsys_embeddings, corpus_embeddings, top_k)
|
||||||
|
|
||||||
|
# Also run build_candidate_representation so the candidate tower's parameters
# are created when this forward function is initialized (its output is unused here).
_ = model.build_candidate_representation(batch, recsys_embeddings)
|
||||||
|
return out
|
||||||
|
|
||||||
|
return hk.transform(forward)
|
||||||
|
|
||||||
|
def init(
|
||||||
|
self,
|
||||||
|
rng: jax.Array,
|
||||||
|
data: RecsysBatch,
|
||||||
|
embeddings: RecsysEmbeddings,
|
||||||
|
corpus_embeddings: jax.Array,
|
||||||
|
top_k: int,
|
||||||
|
) -> TrainingState:
|
||||||
|
assert self.forward is not None
|
||||||
|
rng, init_rng = jax.random.split(rng)
|
||||||
|
params = self.forward.init(init_rng, data, embeddings, corpus_embeddings, top_k)
|
||||||
|
return TrainingState(params=params)
|
||||||
|
|
||||||
|
def load_or_init(
|
||||||
|
self,
|
||||||
|
init_data: RecsysBatch,
|
||||||
|
init_embeddings: RecsysEmbeddings,
|
||||||
|
corpus_embeddings: jax.Array,
|
||||||
|
top_k: int,
|
||||||
|
):
|
||||||
|
rng = jax.random.PRNGKey(self.rng_seed)
|
||||||
|
state = self.init(rng, init_data, init_embeddings, corpus_embeddings, top_k)
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RecsysRetrievalInferenceRunner(BaseInferenceRunner):
|
||||||
|
"""Inference runner for the Phoenix retrieval model.
|
||||||
|
|
||||||
|
This runner provides methods for:
|
||||||
|
1. Encoding users to get user representations
|
||||||
|
2. Encoding candidates to get candidate embeddings
|
||||||
|
3. Retrieving top-k candidates from a corpus
|
||||||
|
"""
|
||||||
|
|
||||||
|
_runner: RetrievalModelRunner = None # type: ignore
|
||||||
|
|
||||||
|
corpus_embeddings: jax.Array | None = None
|
||||||
|
corpus_post_ids: jax.Array | None = None
|
||||||
|
|
||||||
|
def __init__(self, runner: RetrievalModelRunner, name: str):
|
||||||
|
self.name = name
|
||||||
|
self._runner = runner
|
||||||
|
self.corpus_embeddings = None
|
||||||
|
self.corpus_post_ids = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def runner(self) -> RetrievalModelRunner:
|
||||||
|
return self._runner
|
||||||
|
|
||||||
|
def initialize(self):
|
||||||
|
"""Initialize the retrieval inference runner."""
|
||||||
|
runner = self.runner
|
||||||
|
|
||||||
|
dummy_batch = self.create_dummy_batch(batch_size=1)
|
||||||
|
dummy_embeddings = self.create_dummy_embeddings(batch_size=1)
|
||||||
|
dummy_corpus = jnp.zeros((10, runner.model.emb_size), dtype=jnp.float32)
|
||||||
|
dummy_top_k = 5
|
||||||
|
|
||||||
|
runner.initialize()
|
||||||
|
|
||||||
|
state = runner.load_or_init(dummy_batch, dummy_embeddings, dummy_corpus, dummy_top_k)
|
||||||
|
self.params = state.params
|
||||||
|
|
||||||
|
@functools.lru_cache
|
||||||
|
def model():
|
||||||
|
return runner.model.make()
|
||||||
|
|
||||||
|
def hk_encode_user(batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings) -> jax.Array:
|
||||||
|
"""Encode user to get user representation."""
|
||||||
|
m = model()
|
||||||
|
user_rep, _ = m.build_user_representation(batch, recsys_embeddings)
|
||||||
|
return user_rep
|
||||||
|
|
||||||
|
def hk_encode_candidates(
|
||||||
|
batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings
|
||||||
|
) -> jax.Array:
|
||||||
|
"""Encode candidates to get candidate representations."""
|
||||||
|
m = model()
|
||||||
|
cand_rep, _ = m.build_candidate_representation(batch, recsys_embeddings)
|
||||||
|
return cand_rep
|
||||||
|
|
||||||
|
def hk_retrieve(
|
||||||
|
batch: RecsysBatch,
|
||||||
|
recsys_embeddings: RecsysEmbeddings,
|
||||||
|
corpus_embeddings: jax.Array,
|
||||||
|
top_k: int,
|
||||||
|
) -> "RetrievalOutput":
|
||||||
|
"""Retrieve top-k candidates from corpus."""
|
||||||
|
m = model()
|
||||||
|
return m(batch, recsys_embeddings, corpus_embeddings, top_k)
|
||||||
|
|
||||||
|
encode_user_ = hk.without_apply_rng(hk.transform(hk_encode_user))
|
||||||
|
encode_candidates_ = hk.without_apply_rng(hk.transform(hk_encode_candidates))
|
||||||
|
retrieve_ = hk.without_apply_rng(hk.transform(hk_retrieve))
|
||||||
|
|
||||||
|
self.encode_user_fn = encode_user_.apply
|
||||||
|
self.encode_candidates_fn = encode_candidates_.apply
|
||||||
|
self.retrieve_fn = retrieve_.apply
|
||||||
|
|
||||||
|
def encode_user(self, batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings) -> jax.Array:
|
||||||
|
"""Encode users to get user representations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch: RecsysBatch containing user and history information
|
||||||
|
recsys_embeddings: RecsysEmbeddings containing pre-looked-up embeddings
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User representations [B, D]
|
||||||
|
"""
|
||||||
|
return self.encode_user_fn(self.params, batch, recsys_embeddings)
|
||||||
|
|
||||||
|
def encode_candidates(
|
||||||
|
self, batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings
|
||||||
|
) -> jax.Array:
|
||||||
|
"""Encode candidates to get candidate representations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch: RecsysBatch containing candidate information
|
||||||
|
recsys_embeddings: RecsysEmbeddings containing pre-looked-up embeddings
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Candidate representations [B, C, D]
|
||||||
|
"""
|
||||||
|
return self.encode_candidates_fn(self.params, batch, recsys_embeddings)
|
||||||
|
|
||||||
|
def set_corpus(
|
||||||
|
self,
|
||||||
|
corpus_embeddings: jax.Array,
|
||||||
|
corpus_post_ids: jax.Array,
|
||||||
|
):
|
||||||
|
"""Set the corpus embeddings for retrieval.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
corpus_embeddings: Pre-computed candidate embeddings [N, D]
|
||||||
|
corpus_post_ids: Post IDs corresponding to embeddings [N]
|
||||||
|
"""
|
||||||
|
self.corpus_embeddings = corpus_embeddings
|
||||||
|
self.corpus_post_ids = corpus_post_ids
|
||||||
|
|
||||||
|
def retrieve(
|
||||||
|
self,
|
||||||
|
batch: RecsysBatch,
|
||||||
|
recsys_embeddings: RecsysEmbeddings,
|
||||||
|
top_k: int = 100,
|
||||||
|
corpus_embeddings: Optional[jax.Array] = None,
|
||||||
|
) -> RetrievalOutput:
|
||||||
|
"""Retrieve top-k candidates for users.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch: RecsysBatch containing user and history information
|
||||||
|
recsys_embeddings: RecsysEmbeddings containing pre-looked-up embeddings
|
||||||
|
top_k: Number of candidates to retrieve per user
|
||||||
|
corpus_embeddings: Optional corpus embeddings (uses set_corpus if not provided)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RetrievalOutput with user representations and top-k candidates
|
||||||
|
"""
|
||||||
|
if corpus_embeddings is None:
|
||||||
|
corpus_embeddings = self.corpus_embeddings
|
||||||
|
|
||||||
|
return self.retrieve_fn(self.params, batch, recsys_embeddings, corpus_embeddings, top_k)
|
||||||
|
|
||||||
|
|
||||||
|
def create_example_corpus(
|
||||||
|
corpus_size: int,
|
||||||
|
emb_size: int,
|
||||||
|
seed: int = 123,
|
||||||
|
) -> Tuple[jax.Array, jax.Array]:
|
||||||
|
"""Create example corpus embeddings for testing retrieval.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
corpus_size: Number of candidates in corpus
|
||||||
|
emb_size: Embedding dimension
|
||||||
|
seed: Random seed
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (corpus_embeddings [N, D], corpus_post_ids [N])
|
||||||
|
"""
|
||||||
|
rng = np.random.default_rng(seed)
|
||||||
|
|
||||||
|
corpus_embeddings = rng.normal(size=(corpus_size, emb_size)).astype(np.float32)
|
||||||
|
norms = np.linalg.norm(corpus_embeddings, axis=-1, keepdims=True)
|
||||||
|
corpus_embeddings = corpus_embeddings / np.maximum(norms, 1e-12)
|
||||||
|
|
||||||
|
corpus_post_ids = np.arange(corpus_size, dtype=np.int64)
|
||||||
|
|
||||||
|
return jnp.array(corpus_embeddings), jnp.array(corpus_post_ids)
|
||||||
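# Usage sketch (illustrative): instead of the random corpus above, a corpus can be
# built from the model's own candidate tower — encode candidate batches, flatten the
# [B, C, D] representations to [N, D], and register them for retrieval. `post_ids`
# is assumed to hold one ID per flattened candidate.
def build_corpus_from_candidates(
    runner: RecsysRetrievalInferenceRunner,
    batch: RecsysBatch,
    embeddings: RecsysEmbeddings,
    post_ids: jax.Array,
) -> None:
    cand_rep = runner.encode_candidates(batch, embeddings)  # [B, C, D]
    corpus = cand_rep.reshape(-1, cand_rep.shape[-1])  # [B*C, D]
    runner.set_corpus(corpus, post_ids)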
187
phoenix/test_recsys_model.py
Normal file
@ -0,0 +1,187 @@
|
|||||||
|
# Copyright 2026 X.AI Corp.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from grok import make_recsys_attn_mask
|
||||||
|
|
||||||
|
|
||||||
|
class TestMakeRecsysAttnMask:
|
||||||
|
"""Tests for the make_recsys_attn_mask function."""
|
||||||
|
|
||||||
|
def test_output_shape(self):
|
||||||
|
"""Test that the output has the correct shape [1, 1, seq_len, seq_len]."""
|
||||||
|
seq_len = 10
|
||||||
|
candidate_start_offset = 5
|
||||||
|
|
||||||
|
mask = make_recsys_attn_mask(seq_len, candidate_start_offset)
|
||||||
|
|
||||||
|
assert mask.shape == (1, 1, seq_len, seq_len)
|
||||||
|
|
||||||
|
def test_user_history_has_causal_attention(self):
|
||||||
|
"""Test that user+history positions (before candidate_start_offset) have causal attention."""
|
||||||
|
seq_len = 8
|
||||||
|
candidate_start_offset = 5
|
||||||
|
|
||||||
|
mask = make_recsys_attn_mask(seq_len, candidate_start_offset)
|
||||||
|
mask_2d = mask[0, 0]
|
||||||
|
|
||||||
|
for i in range(candidate_start_offset):
|
||||||
|
for j in range(candidate_start_offset):
|
||||||
|
if j <= i:
|
||||||
|
assert mask_2d[i, j] == 1, f"Position {i} should attend to position {j}"
|
||||||
|
else:
|
||||||
|
assert mask_2d[i, j] == 0, (
|
||||||
|
f"Position {i} should NOT attend to future position {j}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_candidates_attend_to_user_history(self):
|
||||||
|
"""Test that candidates can attend to all user+history positions."""
|
||||||
|
seq_len = 8
|
||||||
|
candidate_start_offset = 5
|
||||||
|
|
||||||
|
mask = make_recsys_attn_mask(seq_len, candidate_start_offset)
|
||||||
|
mask_2d = mask[0, 0]
|
||||||
|
|
||||||
|
for candidate_pos in range(candidate_start_offset, seq_len):
|
||||||
|
for history_pos in range(candidate_start_offset):
|
||||||
|
assert mask_2d[candidate_pos, history_pos] == 1, (
|
||||||
|
f"Candidate at {candidate_pos} should attend to user+history at {history_pos}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_candidates_attend_to_themselves(self):
|
||||||
|
"""Test that candidates can attend to themselves (self-attention)."""
|
||||||
|
seq_len = 8
|
||||||
|
candidate_start_offset = 5
|
||||||
|
|
||||||
|
mask = make_recsys_attn_mask(seq_len, candidate_start_offset)
|
||||||
|
mask_2d = mask[0, 0]
|
||||||
|
|
||||||
|
for candidate_pos in range(candidate_start_offset, seq_len):
|
||||||
|
assert mask_2d[candidate_pos, candidate_pos] == 1, (
|
||||||
|
f"Candidate at {candidate_pos} should attend to itself"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_candidates_do_not_attend_to_other_candidates(self):
|
||||||
|
"""Test that candidates cannot attend to other candidates."""
|
||||||
|
seq_len = 8
|
||||||
|
candidate_start_offset = 5
|
||||||
|
|
||||||
|
mask = make_recsys_attn_mask(seq_len, candidate_start_offset)
|
||||||
|
mask_2d = mask[0, 0]
|
||||||
|
|
||||||
|
for query_pos in range(candidate_start_offset, seq_len):
|
||||||
|
for key_pos in range(candidate_start_offset, seq_len):
|
||||||
|
if query_pos != key_pos:
|
||||||
|
assert mask_2d[query_pos, key_pos] == 0, (
|
||||||
|
f"Candidate at {query_pos} should NOT attend to candidate at {key_pos}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_full_mask_structure(self):
|
||||||
|
"""Test the complete mask structure with a small example."""
|
||||||
|
# Sequence: [user, h1, h2, c1, c2, c3]
|
||||||
|
# Positions: 0 1 2 3 4 5
|
||||||
|
|
||||||
|
seq_len = 6
|
||||||
|
candidate_start_offset = 3
|
||||||
|
|
||||||
|
mask = make_recsys_attn_mask(seq_len, candidate_start_offset)
|
||||||
|
mask_2d = mask[0, 0]
|
||||||
|
|
||||||
|
# Expected mask structure:
|
||||||
|
# Query positions are rows, key positions are columns
|
||||||
|
# 1 = can attend, 0 = cannot attend
|
||||||
|
#
|
||||||
|
# Keys: u h1 h2 c1 c2 c3
|
||||||
|
# Query u : 1 0 0 0 0 0
|
||||||
|
# Query h1 : 1 1 0 0 0 0
|
||||||
|
# Query h2 : 1 1 1 0 0 0
|
||||||
|
# Query c1 : 1 1 1 1 0 0 <- c1 attends to user+history + self
|
||||||
|
# Query c2 : 1 1 1 0 1 0 <- c2 attends to user+history + self
|
||||||
|
# Query c3 : 1 1 1 0 0 1 <- c3 attends to user+history + self
|
||||||
|
|
||||||
|
expected = np.array(
|
||||||
|
[
|
||||||
|
[1, 0, 0, 0, 0, 0], # user
|
||||||
|
[1, 1, 0, 0, 0, 0], # h1
|
||||||
|
[1, 1, 1, 0, 0, 0], # h2
|
||||||
|
[1, 1, 1, 1, 0, 0], # c1: user+history + self
|
||||||
|
[1, 1, 1, 0, 1, 0], # c2: user+history + self
|
||||||
|
[1, 1, 1, 0, 0, 1], # c3: user+history + self
|
||||||
|
],
|
||||||
|
dtype=np.float32,
|
||||||
|
)
|
||||||
|
|
||||||
|
np.testing.assert_array_equal(
|
||||||
|
np.array(mask_2d),
|
||||||
|
expected,
|
||||||
|
err_msg="Full mask structure does not match expected pattern",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_dtype_preserved(self):
|
||||||
|
"""Test that the specified dtype is used."""
|
||||||
|
seq_len = 5
|
||||||
|
candidate_start_offset = 3
|
||||||
|
|
||||||
|
mask_f32 = make_recsys_attn_mask(seq_len, candidate_start_offset, dtype=jnp.float32)
|
||||||
|
mask_f16 = make_recsys_attn_mask(seq_len, candidate_start_offset, dtype=jnp.float16)
|
||||||
|
|
||||||
|
assert mask_f32.dtype == jnp.float32
|
||||||
|
assert mask_f16.dtype == jnp.float16
|
||||||
|
|
||||||
|
def test_single_candidate(self):
|
||||||
|
"""Test edge case with a single candidate."""
|
||||||
|
seq_len = 4
|
||||||
|
candidate_start_offset = 3
|
||||||
|
|
||||||
|
mask = make_recsys_attn_mask(seq_len, candidate_start_offset)
|
||||||
|
mask_2d = mask[0, 0]
|
||||||
|
|
||||||
|
expected = np.array(
|
||||||
|
[
|
||||||
|
[1, 0, 0, 0],
|
||||||
|
[1, 1, 0, 0],
|
||||||
|
[1, 1, 1, 0],
|
||||||
|
[1, 1, 1, 1],
|
||||||
|
],
|
||||||
|
dtype=np.float32,
|
||||||
|
)
|
||||||
|
|
||||||
|
np.testing.assert_array_equal(np.array(mask_2d), expected)
|
||||||
|
|
||||||
|
def test_all_candidates(self):
|
||||||
|
"""Test edge case where all positions except first are candidates."""
|
||||||
|
seq_len = 4
|
||||||
|
candidate_start_offset = 1
|
||||||
|
|
||||||
|
mask = make_recsys_attn_mask(seq_len, candidate_start_offset)
|
||||||
|
mask_2d = mask[0, 0]
|
||||||
|
|
||||||
|
expected = np.array(
|
||||||
|
[
|
||||||
|
[1, 0, 0, 0], # user
|
||||||
|
[1, 1, 0, 0], # c1: user + self
|
||||||
|
[1, 0, 1, 0], # c2: user + self
|
||||||
|
[1, 0, 0, 1], # c3: user + self
|
||||||
|
],
|
||||||
|
dtype=np.float32,
|
||||||
|
)
|
||||||
|
|
||||||
|
np.testing.assert_array_equal(np.array(mask_2d), expected)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v"])
|
||||||
359
phoenix/test_recsys_retrieval_model.py
Normal file
@ -0,0 +1,359 @@
|
|||||||
|
# Copyright 2026 X.AI Corp.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Tests for the Phoenix Retrieval Model."""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import haiku as hk
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from grok import TransformerConfig
|
||||||
|
from recsys_model import HashConfig
|
||||||
|
from recsys_retrieval_model import (
|
||||||
|
CandidateTower,
|
||||||
|
PhoenixRetrievalModelConfig,
|
||||||
|
)
|
||||||
|
from runners import (
|
||||||
|
RecsysRetrievalInferenceRunner,
|
||||||
|
RetrievalModelRunner,
|
||||||
|
create_example_batch,
|
||||||
|
create_example_corpus,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCandidateTower(unittest.TestCase):
|
||||||
|
"""Tests for the CandidateTower module."""
|
||||||
|
|
||||||
|
def test_candidate_tower_output_shape(self):
|
||||||
|
"""Test that candidate tower produces correct output shape."""
|
||||||
|
emb_size = 64
|
||||||
|
batch_size = 4
|
||||||
|
num_candidates = 8
|
||||||
|
num_hashes = 4
|
||||||
|
|
||||||
|
def forward(x):
|
||||||
|
tower = CandidateTower(emb_size=emb_size)
|
||||||
|
return tower(x)
|
||||||
|
|
||||||
|
forward_fn = hk.without_apply_rng(hk.transform(forward))
|
||||||
|
|
||||||
|
rng = jax.random.PRNGKey(0)
|
||||||
|
x = jax.random.normal(rng, (batch_size, num_candidates, num_hashes, emb_size))
|
||||||
|
|
||||||
|
params = forward_fn.init(rng, x)
|
||||||
|
output = forward_fn.apply(params, x)
|
||||||
|
|
||||||
|
self.assertEqual(output.shape, (batch_size, num_candidates, emb_size))
|
||||||
|
|
||||||
|
def test_candidate_tower_normalized(self):
|
||||||
|
"""Test that candidate tower output is L2 normalized."""
|
||||||
|
emb_size = 64
|
||||||
|
batch_size = 4
|
||||||
|
num_candidates = 8
|
||||||
|
num_hashes = 4
|
||||||
|
|
||||||
|
def forward(x):
|
||||||
|
tower = CandidateTower(emb_size=emb_size)
|
||||||
|
return tower(x)
|
||||||
|
|
||||||
|
forward_fn = hk.without_apply_rng(hk.transform(forward))
|
||||||
|
|
||||||
|
rng = jax.random.PRNGKey(0)
|
||||||
|
x = jax.random.normal(rng, (batch_size, num_candidates, num_hashes, emb_size))
|
||||||
|
|
||||||
|
params = forward_fn.init(rng, x)
|
||||||
|
output = forward_fn.apply(params, x)
|
||||||
|
|
||||||
|
norms = jnp.sqrt(jnp.sum(output**2, axis=-1))
|
||||||
|
np.testing.assert_array_almost_equal(norms, jnp.ones_like(norms), decimal=5)
|
||||||
|
|
||||||
|
def test_candidate_tower_mean_pooling(self):
|
||||||
|
"""Test candidate tower with mean pooling (no linear projection)."""
|
||||||
|
emb_size = 64
|
||||||
|
batch_size = 4
|
||||||
|
num_candidates = 8
|
||||||
|
num_hashes = 4
|
||||||
|
|
||||||
|
def forward(x):
|
||||||
|
tower = CandidateTower(emb_size=emb_size)
|
||||||
|
return tower(x)
|
||||||
|
|
||||||
|
forward_fn = hk.without_apply_rng(hk.transform(forward))
|
||||||
|
|
||||||
|
rng = jax.random.PRNGKey(0)
|
||||||
|
x = jax.random.normal(rng, (batch_size, num_candidates, num_hashes, emb_size))
|
||||||
|
|
||||||
|
params = forward_fn.init(rng, x)
|
||||||
|
output = forward_fn.apply(params, x)
|
||||||
|
|
||||||
|
self.assertEqual(output.shape, (batch_size, num_candidates, emb_size))
|
||||||
|
|
||||||
|
norms = jnp.sqrt(jnp.sum(output**2, axis=-1))
|
||||||
|
np.testing.assert_array_almost_equal(norms, jnp.ones_like(norms), decimal=5)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPhoenixRetrievalModel(unittest.TestCase):
|
||||||
|
"""Tests for the full Phoenix Retrieval Model."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
"""Set up test fixtures."""
|
||||||
|
self.emb_size = 64
|
||||||
|
self.history_seq_len = 16
|
||||||
|
self.candidate_seq_len = 8
|
||||||
|
self.batch_size = 2
|
||||||
|
self.num_actions = 19
|
||||||
|
self.corpus_size = 100
|
||||||
|
self.top_k = 10
|
||||||
|
|
||||||
|
self.hash_config = HashConfig(
|
||||||
|
num_user_hashes=2,
|
||||||
|
num_item_hashes=2,
|
||||||
|
num_author_hashes=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.config = PhoenixRetrievalModelConfig(
|
||||||
|
emb_size=self.emb_size,
|
||||||
|
history_seq_len=self.history_seq_len,
|
||||||
|
candidate_seq_len=self.candidate_seq_len,
|
||||||
|
hash_config=self.hash_config,
|
||||||
|
product_surface_vocab_size=16,
|
||||||
|
model=TransformerConfig(
|
||||||
|
emb_size=self.emb_size,
|
||||||
|
widening_factor=2,
|
||||||
|
key_size=32,
|
||||||
|
num_q_heads=2,
|
||||||
|
num_kv_heads=2,
|
||||||
|
num_layers=1,
|
||||||
|
attn_output_multiplier=0.125,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _create_test_batch(self) -> tuple:
|
||||||
|
"""Create test batch and embeddings."""
|
||||||
|
return create_example_batch(
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
emb_size=self.emb_size,
|
||||||
|
history_len=self.history_seq_len,
|
||||||
|
num_candidates=self.candidate_seq_len,
|
||||||
|
num_actions=self.num_actions,
|
||||||
|
num_user_hashes=self.hash_config.num_user_hashes,
|
||||||
|
num_item_hashes=self.hash_config.num_item_hashes,
|
||||||
|
num_author_hashes=self.hash_config.num_author_hashes,
|
||||||
|
product_surface_vocab_size=16,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _create_test_corpus(self):
|
||||||
|
"""Create test corpus embeddings."""
|
||||||
|
return create_example_corpus(self.corpus_size, self.emb_size)
|
||||||
|
|
||||||
|
def test_model_forward(self):
|
||||||
|
"""Test model forward pass produces correct output shapes."""
|
||||||
|
|
||||||
|
def forward(batch, embeddings, corpus_embeddings, top_k):
|
||||||
|
model = self.config.make()
|
||||||
|
return model(batch, embeddings, corpus_embeddings, top_k)
|
||||||
|
|
||||||
|
forward_fn = hk.without_apply_rng(hk.transform(forward))
|
||||||
|
|
||||||
|
batch, embeddings = self._create_test_batch()
|
||||||
|
corpus_embeddings, _ = self._create_test_corpus()
|
||||||
|
|
||||||
|
rng = jax.random.PRNGKey(0)
|
||||||
|
params = forward_fn.init(rng, batch, embeddings, corpus_embeddings, self.top_k)
|
||||||
|
output = forward_fn.apply(params, batch, embeddings, corpus_embeddings, self.top_k)
|
||||||
|
|
||||||
|
self.assertEqual(output.user_representation.shape, (self.batch_size, self.emb_size))
|
||||||
|
self.assertEqual(output.top_k_indices.shape, (self.batch_size, self.top_k))
|
||||||
|
self.assertEqual(output.top_k_scores.shape, (self.batch_size, self.top_k))
|
||||||
|
|
||||||
|
def test_user_representation_normalized(self):
|
||||||
|
"""Test that user representations are L2 normalized."""
|
||||||
|
|
||||||
|
def forward(batch, embeddings, corpus_embeddings, top_k):
|
||||||
|
model = self.config.make()
|
||||||
|
return model(batch, embeddings, corpus_embeddings, top_k)
|
||||||
|
|
||||||
|
forward_fn = hk.without_apply_rng(hk.transform(forward))
|
||||||
|
|
||||||
|
batch, embeddings = self._create_test_batch()
|
||||||
|
corpus_embeddings, _ = self._create_test_corpus()
|
||||||
|
|
||||||
|
rng = jax.random.PRNGKey(0)
|
||||||
|
params = forward_fn.init(rng, batch, embeddings, corpus_embeddings, self.top_k)
|
||||||
|
output = forward_fn.apply(params, batch, embeddings, corpus_embeddings, self.top_k)
|
||||||
|
|
||||||
|
norms = jnp.sqrt(jnp.sum(output.user_representation**2, axis=-1))
|
||||||
|
np.testing.assert_array_almost_equal(norms, jnp.ones(self.batch_size), decimal=5)
|
||||||
|
|
||||||
|
def test_candidate_representation_normalized(self):
|
||||||
|
"""Test that candidate representations from build_candidate_representation are L2 normalized."""
|
||||||
|
|
||||||
|
def forward(batch, embeddings):
|
||||||
|
model = self.config.make()
|
||||||
|
cand_rep, _ = model.build_candidate_representation(batch, embeddings)
|
||||||
|
return cand_rep
|
||||||
|
|
||||||
|
forward_fn = hk.without_apply_rng(hk.transform(forward))
|
||||||
|
|
||||||
|
batch, embeddings = self._create_test_batch()
|
||||||
|
|
||||||
|
rng = jax.random.PRNGKey(0)
|
||||||
|
params = forward_fn.init(rng, batch, embeddings)
|
||||||
|
cand_rep = forward_fn.apply(params, batch, embeddings)
|
||||||
|
|
||||||
|
norms = jnp.sqrt(jnp.sum(cand_rep**2, axis=-1))
|
||||||
|
np.testing.assert_array_almost_equal(
|
||||||
|
norms, jnp.ones((self.batch_size, self.candidate_seq_len)), decimal=5
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_retrieve_top_k(self):
|
||||||
|
"""Test top-k retrieval through __call__."""
|
||||||
|
|
||||||
|
def forward(batch, embeddings, corpus_embeddings, top_k):
|
||||||
|
model = self.config.make()
|
||||||
|
return model(batch, embeddings, corpus_embeddings, top_k)
|
||||||
|
|
||||||
|
forward_fn = hk.without_apply_rng(hk.transform(forward))
|
||||||
|
|
||||||
|
batch, embeddings = self._create_test_batch()
|
||||||
|
corpus_embeddings, _ = self._create_test_corpus()
|
||||||
|
|
||||||
|
rng = jax.random.PRNGKey(0)
|
||||||
|
params = forward_fn.init(rng, batch, embeddings, corpus_embeddings, self.top_k)
|
||||||
|
output = forward_fn.apply(params, batch, embeddings, corpus_embeddings, self.top_k)
|
||||||
|
|
||||||
|
self.assertEqual(output.top_k_indices.shape, (self.batch_size, self.top_k))
|
||||||
|
self.assertEqual(output.top_k_scores.shape, (self.batch_size, self.top_k))
|
||||||
|
|
||||||
|
self.assertTrue(jnp.all(output.top_k_indices >= 0))
|
||||||
|
self.assertTrue(jnp.all(output.top_k_indices < self.corpus_size))
|
||||||
|
|
||||||
|
for b in range(self.batch_size):
|
||||||
|
scores = np.array(output.top_k_scores[b])
|
||||||
|
self.assertTrue(np.all(scores[:-1] >= scores[1:]))
|
||||||
|
|
||||||
|
|
||||||
|
class TestRetrievalInferenceRunner(unittest.TestCase):
|
||||||
|
"""Tests for the retrieval inference runner."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
"""Set up test fixtures."""
|
||||||
|
self.emb_size = 64
|
||||||
|
self.history_seq_len = 16
|
||||||
|
self.candidate_seq_len = 8
|
||||||
|
self.batch_size = 2
|
||||||
|
self.num_actions = 19
|
||||||
|
|
||||||
|
self.hash_config = HashConfig(
|
||||||
|
num_user_hashes=2,
|
||||||
|
num_item_hashes=2,
|
||||||
|
num_author_hashes=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.config = PhoenixRetrievalModelConfig(
|
||||||
|
emb_size=self.emb_size,
|
||||||
|
history_seq_len=self.history_seq_len,
|
||||||
|
candidate_seq_len=self.candidate_seq_len,
|
||||||
|
hash_config=self.hash_config,
|
||||||
|
product_surface_vocab_size=16,
|
||||||
|
model=TransformerConfig(
|
||||||
|
emb_size=self.emb_size,
|
||||||
|
widening_factor=2,
|
||||||
|
key_size=32,
|
||||||
|
num_q_heads=2,
|
||||||
|
num_kv_heads=2,
|
||||||
|
num_layers=1,
|
||||||
|
attn_output_multiplier=0.125,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_runner_initialization(self):
|
||||||
|
"""Test that runner initializes correctly."""
|
||||||
|
runner = RecsysRetrievalInferenceRunner(
|
||||||
|
runner=RetrievalModelRunner(
|
||||||
|
model=self.config,
|
||||||
|
bs_per_device=0.125,
|
||||||
|
),
|
||||||
|
name="test_retrieval",
|
||||||
|
)
|
||||||
|
|
||||||
|
runner.initialize()
|
||||||
|
|
||||||
|
self.assertIsNotNone(runner.params)
|
||||||
|
|
||||||
|
def test_runner_encode_user(self):
|
||||||
|
"""Test user encoding through runner."""
|
||||||
|
runner = RecsysRetrievalInferenceRunner(
|
||||||
|
runner=RetrievalModelRunner(
|
||||||
|
model=self.config,
|
||||||
|
bs_per_device=0.125,
|
||||||
|
),
|
||||||
|
name="test_retrieval",
|
||||||
|
)
|
||||||
|
runner.initialize()
|
||||||
|
|
||||||
|
batch, embeddings = create_example_batch(
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
emb_size=self.emb_size,
|
||||||
|
history_len=self.history_seq_len,
|
||||||
|
num_candidates=self.candidate_seq_len,
|
||||||
|
num_actions=self.num_actions,
|
||||||
|
num_user_hashes=self.hash_config.num_user_hashes,
|
||||||
|
num_item_hashes=self.hash_config.num_item_hashes,
|
||||||
|
num_author_hashes=self.hash_config.num_author_hashes,
|
||||||
|
)
|
||||||
|
|
||||||
|
user_rep = runner.encode_user(batch, embeddings)
|
||||||
|
|
||||||
|
self.assertEqual(user_rep.shape, (self.batch_size, self.emb_size))
|
||||||
|
|
||||||
|
def test_runner_retrieve(self):
|
||||||
|
"""Test retrieval through runner."""
|
||||||
|
runner = RecsysRetrievalInferenceRunner(
|
||||||
|
runner=RetrievalModelRunner(
|
||||||
|
model=self.config,
|
||||||
|
bs_per_device=0.125,
|
||||||
|
),
|
||||||
|
name="test_retrieval",
|
||||||
|
)
|
||||||
|
runner.initialize()
|
||||||
|
|
||||||
|
batch, embeddings = create_example_batch(
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
emb_size=self.emb_size,
|
||||||
|
history_len=self.history_seq_len,
|
||||||
|
num_candidates=self.candidate_seq_len,
|
||||||
|
num_actions=self.num_actions,
|
||||||
|
num_user_hashes=self.hash_config.num_user_hashes,
|
||||||
|
num_item_hashes=self.hash_config.num_item_hashes,
|
||||||
|
num_author_hashes=self.hash_config.num_author_hashes,
|
||||||
|
)
|
||||||
|
|
||||||
|
corpus_size = 100
|
||||||
|
corpus_embeddings, corpus_post_ids = create_example_corpus(corpus_size, self.emb_size)
|
||||||
|
runner.set_corpus(corpus_embeddings, corpus_post_ids)
|
||||||
|
|
||||||
|
top_k = 10
|
||||||
|
output = runner.retrieve(batch, embeddings, top_k=top_k)
|
||||||
|
|
||||||
|
self.assertEqual(output.user_representation.shape, (self.batch_size, self.emb_size))
|
||||||
|
self.assertEqual(output.top_k_indices.shape, (self.batch_size, top_k))
|
||||||
|
self.assertEqual(output.top_k_scores.shape, (self.batch_size, top_k))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
372
phoenix/uv.lock
Normal file
@ -0,0 +1,372 @@
|
|||||||
|
version = 1
|
||||||
|
revision = 3
|
||||||
|
requires-python = ">=3.11"
|
||||||
|
resolution-markers = [
|
||||||
|
"python_full_version >= '3.13' and sys_platform == 'darwin'",
|
||||||
|
"python_full_version == '3.12.*' and sys_platform == 'darwin'",
|
||||||
|
"python_full_version < '3.12' and sys_platform == 'darwin'",
|
||||||
|
"python_full_version >= '3.13' and sys_platform == 'linux'",
|
||||||
|
"python_full_version == '3.12.*' and sys_platform == 'linux'",
|
||||||
|
"python_full_version < '3.12' and sys_platform == 'linux'",
|
||||||
|
]
|
||||||
|
supported-markers = [
|
||||||
|
"sys_platform == 'darwin'",
|
||||||
|
"sys_platform == 'linux'",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "absl-py"
|
||||||
|
version = "2.3.1"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/10/2a/c93173ffa1b39c1d0395b7e842bbdc62e556ca9d8d3b5572926f3e4ca752/absl_py-2.3.1.tar.gz", hash = "sha256:a97820526f7fbfd2ec1bce83f3f25e3a14840dac0d8e02a0b71cd75db3f77fc9", size = 116588, upload-time = "2025-07-03T09:31:44.05Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/8f/aa/ba0014cc4659328dc818a28827be78e6d97312ab0cb98105a770924dc11e/absl_py-2.3.1-py3-none-any.whl", hash = "sha256:eeecf07f0c2a93ace0772c92e596ace6d3d3996c042b2128459aaae2a76de11d", size = 135811, upload-time = "2025-07-03T09:31:42.253Z" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "dm-haiku"
|
||||||
|
version = "0.0.16"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "absl-py", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||||
|
{ name = "jmp", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||||
|
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||||
|
{ name = "tabulate", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/2a/fc/daf4689198f4c0af8b71611f39fcd5d68ce0ae59fa919b9e58192a7d70f5/dm_haiku-0.0.16.tar.gz", hash = "sha256:1830b0ce63c5cef2fb3a63a13033c9d8f612ee7f896f2b0b25a6ba484f5fad28", size = 263092, upload-time = "2025-12-17T15:55:35.145Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/14/91/0f53835d0292a74e6b37e68125b669827e2a75a26e01c34741d6c13cca6c/dm_haiku-0.0.16-py3-none-any.whl", hash = "sha256:cc355d4d5aaa85af20e5a23ccd278bc751232ac8e5971261bed39318c07d744f", size = 374267, upload-time = "2025-12-17T15:55:33.9Z" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "grok-1"
|
||||||
|
version = "0.1.0"
|
||||||
|
source = { virtual = "." }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "dm-haiku", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||||
|
{ name = "jax", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||||
|
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||||
|
{ name = "pyright", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dev-dependencies]
|
||||||
|
dev = [
|
||||||
|
{ name = "pytest", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.metadata]
|
||||||
|
requires-dist = [
|
||||||
|
{ name = "dm-haiku", specifier = ">=0.0.13" },
|
||||||
|
{ name = "jax", specifier = "==0.8.1" },
|
||||||
|
{ name = "numpy", specifier = ">=1.26.4" },
|
||||||
|
{ name = "pyright", specifier = ">=1.1.408" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.metadata.requires-dev]
|
||||||
|
dev = [{ name = "pytest" }]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "iniconfig"
|
||||||
|
version = "2.3.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "jax"
|
||||||
|
version = "0.8.1"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "jaxlib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||||
|
{ name = "ml-dtypes", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||||
|
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||||
|
{ name = "opt-einsum", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||||
|
{ name = "scipy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/32/82/84fd2c662e4d410a34b0402de9b56bb69d7f72d1b875c3ae0edc07df18cc/jax-0.8.1.tar.gz", hash = "sha256:e53f67b15315f5e154851a7fd77a192b59c6c75b3f7ac56e214296765391cca7", size = 2509320, upload-time = "2025-11-18T19:50:02.609Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/f9/e7/19b8cfc8963b2e10a01a4db7bb27ec5fa39ecd024bc62f8e2d1de5625a9d/jax-0.8.1-py3-none-any.whl", hash = "sha256:4cbdc5548f3095cdd69d38e4337950b2fc1f250a740a0234d190e4a319077564", size = 2922137, upload-time = "2025-11-18T19:47:43.693Z" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "jaxlib"
|
||||||
|
version = "0.8.1"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "ml-dtypes", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||||
|
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||||
|
{ name = "scipy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||||
|
]
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/fe/8b/9babcf487c6f1b533bca9611124c4d9593367c058a96d326c7e70db7d334/jaxlib-0.8.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:865add56139883405f3f15c9b0de6a64ab8f4aa549dff196b72dbc86be6ccc1f", size = 55719927, upload-time = "2025-11-18T19:48:42.679Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/df/0c/b8c67272647ea151b0ac651e43faa846b4987d971058683dcce8abf68bca/jaxlib-0.8.1-cp311-cp311-manylinux_2_27_aarch64.whl", hash = "sha256:ff32b6320d729131efaf22939825b52d75957c84c32af2b0b1bdb33cf27ba75f", size = 74208199, upload-time = "2025-11-18T19:48:45.848Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/8f/d0/5b83d614eddb58a2cc97fb948bfeb84509b90da04e808273bf9ae89ad6c1/jaxlib-0.8.1-cp311-cp311-manylinux_2_27_x86_64.whl", hash = "sha256:22f489fb5c8be0da7be5e4957a10936b3760a169668f8b25c5d09c51c3ef47f6", size = 80247963, upload-time = "2025-11-18T19:48:49.443Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/d9/9d/59b36e2f348e599d5812743f263ca54aa03be1a4c9dfc11504d19864b72d/jaxlib-0.8.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:88bde0f535eeea6689e0cd57d40b7660d5206ac95c7d42e09562a109b963a49f", size = 55728156, upload-time = "2025-11-18T19:48:56.254Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/7e/73/2aa891de9f5f4c60ba3c63bda97ec4ace50ffb900ff3bf750ce42c514a3b/jaxlib-0.8.1-cp312-cp312-manylinux_2_27_aarch64.whl", hash = "sha256:bed1e94ae8c7c16bca4476d8d7f582f0d1a102a4e69c3a9bd2069a0dc42274a9", size = 74209108, upload-time = "2025-11-18T19:48:59.572Z" },
{ url = "https://files.pythonhosted.org/packages/eb/4b/3c7e373d81219ee7493c1581c85a926c413ddeb3794cff87a37023a337e4/jaxlib-0.8.1-cp312-cp312-manylinux_2_27_x86_64.whl", hash = "sha256:af4924189fc53b69237715b56ebcbfc71bb91ca16184143dcef0d430c8173de6", size = 80256943, upload-time = "2025-11-18T19:49:02.92Z" },
{ url = "https://files.pythonhosted.org/packages/f8/67/97c62849b5d8fc075f902201ff136ad224a2ef113d1fa655ece0ffe8b2a4/jaxlib-0.8.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a0349f6e8179dc897d33aeb90ec66b4a8041330fbbba8d071dc6167cd2271539", size = 55726611, upload-time = "2025-11-18T19:49:09.162Z" },
{ url = "https://files.pythonhosted.org/packages/fd/2a/9fb7599e43d66958b6a9859e045b605afea31f7fd96cfa35a7a8e978b0f8/jaxlib-0.8.1-cp313-cp313-manylinux_2_27_aarch64.whl", hash = "sha256:bd697c171ace1e2e9d6ed910a78f385b3c4095cee290b0255aa58848f2acdeab", size = 74207596, upload-time = "2025-11-18T19:49:12.39Z" },
{ url = "https://files.pythonhosted.org/packages/7d/61/ab5c98641e15f9844dd49efbf6f22c6a9c5d17304319e5be8c51a1dfd088/jaxlib-0.8.1-cp313-cp313-manylinux_2_27_x86_64.whl", hash = "sha256:d245bd6a279c72ca5f796df84cdd64d7c9c8abc4b8d89adf4acf45898dab958b", size = 80254560, upload-time = "2025-11-18T19:49:16.172Z" },
{ url = "https://files.pythonhosted.org/packages/97/65/e7c625f1fdb54d45ac248d8398a28d6c02528c31feaa6e1c146a08192d77/jaxlib-0.8.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:4933298fcfb07a5aa2d1fed21c111d07cea50e6f180dba2cdb5463c13fb98f2f", size = 55835933, upload-time = "2025-11-18T19:49:27.362Z" },
{ url = "https://files.pythonhosted.org/packages/1f/04/e09ff7b5ba0af93501cb196c65103a30e5050083203c1ff581f18718a356/jaxlib-0.8.1-cp313-cp313t-manylinux_2_27_aarch64.whl", hash = "sha256:f2f11491b077d05249d63813e811401194a41edc8e9cc60af8f4b554057cfad0", size = 74323389, upload-time = "2025-11-18T19:49:30.457Z" },
{ url = "https://files.pythonhosted.org/packages/44/9f/8b7f6ad9eebf8946e73049dae85f86544f5743bc8b2190898415646fa7ec/jaxlib-0.8.1-cp313-cp313t-manylinux_2_27_x86_64.whl", hash = "sha256:7a5d381fad89622750fae29fab83c0847e2931ad8d6a34dc13b28fc4d67f75a3", size = 80358249, upload-time = "2025-11-18T19:49:33.682Z" },
{ url = "https://files.pythonhosted.org/packages/47/6d/75943de28285afcc8d62e89c3e0efc0abdb7e7a72a9e967c3555fc9a35af/jaxlib-0.8.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:90e48973f8dbded7edc8728be84c01ae00412190187fb06622abfa4edd42c0a8", size = 55729587, upload-time = "2025-11-18T19:49:36.952Z" },
{ url = "https://files.pythonhosted.org/packages/2c/ce/9e68ca9f646039d687a94066a5e3e195fc70cebdfbe44945b3c53ceed321/jaxlib-0.8.1-cp314-cp314-manylinux_2_27_aarch64.whl", hash = "sha256:1a4001ed3ba9ed5a812da1b16f52eebb5d473a4480c1523828c7bd3dae8d1375", size = 74222294, upload-time = "2025-11-18T19:49:40.418Z" },
{ url = "https://files.pythonhosted.org/packages/3c/0f/988a413cbf610610cb14783a6e0964a854d0f388ccafe9b4e61c2c188b88/jaxlib-0.8.1-cp314-cp314-manylinux_2_27_x86_64.whl", hash = "sha256:fdbbf2336c08bbf8f30548e204c8c9d77f8b2a3a5b7fc7985749246feb8852b0", size = 80268801, upload-time = "2025-11-18T19:49:44.943Z" },
{ url = "https://files.pythonhosted.org/packages/07/9b/f6f01d79f519b0cbd09a6c751844b1e0294fc53ea0b09882466b21169ea5/jaxlib-0.8.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:92c41c9b9862c08521eb90515a7c5bcc840c6d30f86230cebf94aea2d6a0af81", size = 55834325, upload-time = "2025-11-18T19:49:52.541Z" },
{ url = "https://files.pythonhosted.org/packages/61/c7/13d13a6f0b0d2e91431d6a031129d51ea4b23af23bb947882234ed003f09/jaxlib-0.8.1-cp314-cp314t-manylinux_2_27_aarch64.whl", hash = "sha256:1bc76edec2bc74a7adb5e29329ece51a67c57cd011a06d55d07da62fbabe3389", size = 74320131, upload-time = "2025-11-18T19:49:56.208Z" },
{ url = "https://files.pythonhosted.org/packages/cd/8a/6cad418c0f11ce0cffa2b74b81fb76e6cf30247288fea75a372b6b163f2e/jaxlib-0.8.1-cp314-cp314t-manylinux_2_27_x86_64.whl", hash = "sha256:117f2fe2c19479e560ad85a3ef2fcc0b1d24816456f0d039f865c2acbab63b5a", size = 80360481, upload-time = "2025-11-18T19:50:00.065Z" },
]

[[package]]
name = "jmp"
version = "0.0.4"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ab/b0/e90fbbffef4b345329c878a69f0336d3edc5a1f9fcba193931aca2132d62/jmp-0.0.4.tar.gz", hash = "sha256:5dfeb0fd7c7a9f72a70fff0aab9d0cbfae32a809c02f4037ff3485ceb33e1730", size = 18582, upload-time = "2023-01-30T12:47:13.634Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/27/e5/cce82de2831e5aff9332d8d624bb57188f1b2af6ccf6979caf898a8a4348/jmp-0.0.4-py3-none-any.whl", hash = "sha256:6aa7adbddf2bd574b28c7faf6e81a735eb11f53386447896909c6968dc36807d", size = 18274, upload-time = "2023-01-30T12:47:11.931Z" },
]

[[package]]
name = "ml-dtypes"
version = "0.5.4"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/0e/4a/c27b42ed9b1c7d13d9ba8b6905dece787d6259152f2309338aed29b2447b/ml_dtypes-0.5.4.tar.gz", hash = "sha256:8ab06a50fb9bf9666dd0fe5dfb4676fa2b0ac0f31ecff72a6c3af8e22c063453", size = 692314, upload-time = "2025-11-17T22:32:31.031Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c6/5e/712092cfe7e5eb667b8ad9ca7c54442f21ed7ca8979745f1000e24cf8737/ml_dtypes-0.5.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:6c7ecb74c4bd71db68a6bea1edf8da8c34f3d9fe218f038814fd1d310ac76c90", size = 679734, upload-time = "2025-11-17T22:31:39.223Z" },
{ url = "https://files.pythonhosted.org/packages/4f/cf/912146dfd4b5c0eea956836c01dcd2fce6c9c844b2691f5152aca196ce4f/ml_dtypes-0.5.4-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bc11d7e8c44a65115d05e2ab9989d1e045125d7be8e05a071a48bc76eb6d6040", size = 5056165, upload-time = "2025-11-17T22:31:41.071Z" },
{ url = "https://files.pythonhosted.org/packages/a9/80/19189ea605017473660e43762dc853d2797984b3c7bf30ce656099add30c/ml_dtypes-0.5.4-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:19b9a53598f21e453ea2fbda8aa783c20faff8e1eeb0d7ab899309a0053f1483", size = 5034975, upload-time = "2025-11-17T22:31:42.758Z" },
{ url = "https://files.pythonhosted.org/packages/a8/b8/3c70881695e056f8a32f8b941126cf78775d9a4d7feba8abcb52cb7b04f2/ml_dtypes-0.5.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:a174837a64f5b16cab6f368171a1a03a27936b31699d167684073ff1c4237dac", size = 676927, upload-time = "2025-11-17T22:31:48.182Z" },
{ url = "https://files.pythonhosted.org/packages/54/0f/428ef6881782e5ebb7eca459689448c0394fa0a80bea3aa9262cba5445ea/ml_dtypes-0.5.4-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a7f7c643e8b1320fd958bf098aa7ecf70623a42ec5154e3be3be673f4c34d900", size = 5028464, upload-time = "2025-11-17T22:31:50.135Z" },
{ url = "https://files.pythonhosted.org/packages/3a/cb/28ce52eb94390dda42599c98ea0204d74799e4d8047a0eb559b6fd648056/ml_dtypes-0.5.4-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9ad459e99793fa6e13bd5b7e6792c8f9190b4e5a1b45c63aba14a4d0a7f1d5ff", size = 5009002, upload-time = "2025-11-17T22:31:52.001Z" },
{ url = "https://files.pythonhosted.org/packages/d9/a1/4008f14bbc616cfb1ac5b39ea485f9c63031c4634ab3f4cf72e7541f816a/ml_dtypes-0.5.4-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8c760d85a2f82e2bed75867079188c9d18dae2ee77c25a54d60e9cc79be1bc48", size = 676888, upload-time = "2025-11-17T22:31:56.907Z" },
{ url = "https://files.pythonhosted.org/packages/d3/b7/dff378afc2b0d5a7d6cd9d3209b60474d9819d1189d347521e1688a60a53/ml_dtypes-0.5.4-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ce756d3a10d0c4067172804c9cc276ba9cc0ff47af9078ad439b075d1abdc29b", size = 5036993, upload-time = "2025-11-17T22:31:58.497Z" },
{ url = "https://files.pythonhosted.org/packages/eb/33/40cd74219417e78b97c47802037cf2d87b91973e18bb968a7da48a96ea44/ml_dtypes-0.5.4-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:533ce891ba774eabf607172254f2e7260ba5f57bdd64030c9a4fcfbd99815d0d", size = 5010956, upload-time = "2025-11-17T22:31:59.931Z" },
{ url = "https://files.pythonhosted.org/packages/4f/74/e9ddb35fd1dd43b1106c20ced3f53c2e8e7fc7598c15638e9f80677f81d4/ml_dtypes-0.5.4-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:304ad47faa395415b9ccbcc06a0350800bc50eda70f0e45326796e27c62f18b6", size = 702083, upload-time = "2025-11-17T22:32:04.08Z" },
{ url = "https://files.pythonhosted.org/packages/74/f5/667060b0aed1aa63166b22897fdf16dca9eb704e6b4bbf86848d5a181aa7/ml_dtypes-0.5.4-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6a0df4223b514d799b8a1629c65ddc351b3efa833ccf7f8ea0cf654a61d1e35d", size = 5354111, upload-time = "2025-11-17T22:32:05.546Z" },
{ url = "https://files.pythonhosted.org/packages/40/49/0f8c498a28c0efa5f5c95a9e374c83ec1385ca41d0e85e7cf40e5d519a21/ml_dtypes-0.5.4-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:531eff30e4d368cb6255bc2328d070e35836aa4f282a0fb5f3a0cd7260257298", size = 5366453, upload-time = "2025-11-17T22:32:07.115Z" },
{ url = "https://files.pythonhosted.org/packages/72/4e/1339dc6e2557a344f5ba5590872e80346f76f6cb2ac3dd16e4666e88818c/ml_dtypes-0.5.4-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:2b857d3af6ac0d39db1de7c706e69c7f9791627209c3d6dedbfca8c7e5faec22", size = 673781, upload-time = "2025-11-17T22:32:11.364Z" },
{ url = "https://files.pythonhosted.org/packages/04/f9/067b84365c7e83bda15bba2b06c6ca250ce27b20630b1128c435fb7a09aa/ml_dtypes-0.5.4-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:805cef3a38f4eafae3a5bf9ebdcdb741d0bcfd9e1bd90eb54abd24f928cd2465", size = 5036145, upload-time = "2025-11-17T22:32:12.783Z" },
{ url = "https://files.pythonhosted.org/packages/c6/bb/82c7dcf38070b46172a517e2334e665c5bf374a262f99a283ea454bece7c/ml_dtypes-0.5.4-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:14a4fd3228af936461db66faccef6e4f41c1d82fcc30e9f8d58a08916b1d811f", size = 5010230, upload-time = "2025-11-17T22:32:14.38Z" },
{ url = "https://files.pythonhosted.org/packages/cd/02/48aa7d84cc30ab4ee37624a2fd98c56c02326785750cd212bc0826c2f15b/ml_dtypes-0.5.4-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:bfc534409c5d4b0bf945af29e5d0ab075eae9eecbb549ff8a29280db822f34f9", size = 702085, upload-time = "2025-11-17T22:32:18.175Z" },
{ url = "https://files.pythonhosted.org/packages/5a/e7/85cb99fe80a7a5513253ec7faa88a65306be071163485e9a626fce1b6e84/ml_dtypes-0.5.4-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2314892cdc3fcf05e373d76d72aaa15fda9fb98625effa73c1d646f331fcecb7", size = 5355358, upload-time = "2025-11-17T22:32:19.7Z" },
{ url = "https://files.pythonhosted.org/packages/79/2b/a826ba18d2179a56e144aef69e57fb2ab7c464ef0b2111940ee8a3a223a2/ml_dtypes-0.5.4-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0d2ffd05a2575b1519dc928c0b93c06339eb67173ff53acb00724502cda231cf", size = 5366332, upload-time = "2025-11-17T22:32:21.193Z" },
]

[[package]]
name = "nodeenv"
version = "1.10.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/24/bf/d1bda4f6168e0b2e9e5958945e01910052158313224ada5ce1fb2e1113b8/nodeenv-1.10.0.tar.gz", hash = "sha256:996c191ad80897d076bdfba80a41994c2b47c68e224c542b48feba42ba00f8bb", size = 55611, upload-time = "2025-12-20T14:08:54.006Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/88/b2/d0896bdcdc8d28a7fc5717c305f1a861c26e18c05047949fb371034d98bd/nodeenv-1.10.0-py2.py3-none-any.whl", hash = "sha256:5bb13e3eed2923615535339b3c620e76779af4cb4c6a90deccc9e36b274d3827", size = 23438, upload-time = "2025-12-20T14:08:52.782Z" },
]

[[package]]
name = "numpy"
version = "2.4.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/24/62/ae72ff66c0f1fd959925b4c11f8c2dea61f47f6acaea75a08512cdfe3fed/numpy-2.4.1.tar.gz", hash = "sha256:a1ceafc5042451a858231588a104093474c6a5c57dcc724841f5c888d237d690", size = 20721320, upload-time = "2026-01-10T06:44:59.619Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/a5/34/2b1bc18424f3ad9af577f6ce23600319968a70575bd7db31ce66731bbef9/numpy-2.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0cce2a669e3c8ba02ee563c7835f92c153cf02edff1ae05e1823f1dde21b16a5", size = 16944563, upload-time = "2026-01-10T06:42:14.615Z" },
{ url = "https://files.pythonhosted.org/packages/2c/57/26e5f97d075aef3794045a6ca9eada6a4ed70eb9a40e7a4a93f9ac80d704/numpy-2.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:899d2c18024984814ac7e83f8f49d8e8180e2fbe1b2e252f2e7f1d06bea92425", size = 12645658, upload-time = "2026-01-10T06:42:17.298Z" },
{ url = "https://files.pythonhosted.org/packages/8e/ba/80fc0b1e3cb2fd5c6143f00f42eb67762aa043eaa05ca924ecc3222a7849/numpy-2.4.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:09aa8a87e45b55a1c2c205d42e2808849ece5c484b2aab11fecabec3841cafba", size = 5474132, upload-time = "2026-01-10T06:42:19.637Z" },
{ url = "https://files.pythonhosted.org/packages/40/ae/0a5b9a397f0e865ec171187c78d9b57e5588afc439a04ba9cab1ebb2c945/numpy-2.4.1-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:edee228f76ee2dab4579fad6f51f6a305de09d444280109e0f75df247ff21501", size = 6804159, upload-time = "2026-01-10T06:42:21.44Z" },
{ url = "https://files.pythonhosted.org/packages/86/9c/841c15e691c7085caa6fd162f063eff494099c8327aeccd509d1ab1e36ab/numpy-2.4.1-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a92f227dbcdc9e4c3e193add1a189a9909947d4f8504c576f4a732fd0b54240a", size = 14708058, upload-time = "2026-01-10T06:42:23.546Z" },
{ url = "https://files.pythonhosted.org/packages/5d/9d/7862db06743f489e6a502a3b93136d73aea27d97b2cf91504f70a27501d6/numpy-2.4.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:538bf4ec353709c765ff75ae616c34d3c3dca1a68312727e8f2676ea644f8509", size = 16651501, upload-time = "2026-01-10T06:42:25.909Z" },
{ url = "https://files.pythonhosted.org/packages/a6/9c/6fc34ebcbd4015c6e5f0c0ce38264010ce8a546cb6beacb457b84a75dfc8/numpy-2.4.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ac08c63cb7779b85e9d5318e6c3518b424bc1f364ac4cb2c6136f12e5ff2dccc", size = 16492627, upload-time = "2026-01-10T06:42:28.938Z" },
{ url = "https://files.pythonhosted.org/packages/aa/63/2494a8597502dacda439f61b3c0db4da59928150e62be0e99395c3ad23c5/numpy-2.4.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4f9c360ecef085e5841c539a9a12b883dff005fbd7ce46722f5e9cef52634d82", size = 18585052, upload-time = "2026-01-10T06:42:31.312Z" },
{ url = "https://files.pythonhosted.org/packages/78/7f/ec53e32bf10c813604edf07a3682616bd931d026fcde7b6d13195dfb684a/numpy-2.4.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d3703409aac693fa82c0aee023a1ae06a6e9d065dba10f5e8e80f642f1e9d0a2", size = 16656888, upload-time = "2026-01-10T06:42:40.913Z" },
{ url = "https://files.pythonhosted.org/packages/b8/e0/1f9585d7dae8f14864e948fd7fa86c6cb72dee2676ca2748e63b1c5acfe0/numpy-2.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7211b95ca365519d3596a1d8688a95874cc94219d417504d9ecb2df99fa7bfa8", size = 12373956, upload-time = "2026-01-10T06:42:43.091Z" },
{ url = "https://files.pythonhosted.org/packages/8e/43/9762e88909ff2326f5e7536fa8cb3c49fb03a7d92705f23e6e7f553d9cb3/numpy-2.4.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:5adf01965456a664fc727ed69cc71848f28d063217c63e1a0e200a118d5eec9a", size = 5202567, upload-time = "2026-01-10T06:42:45.107Z" },
{ url = "https://files.pythonhosted.org/packages/4b/ee/34b7930eb61e79feb4478800a4b95b46566969d837546aa7c034c742ef98/numpy-2.4.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:26f0bcd9c79a00e339565b303badc74d3ea2bd6d52191eeca5f95936cad107d0", size = 6549459, upload-time = "2026-01-10T06:42:48.152Z" },
{ url = "https://files.pythonhosted.org/packages/79/e3/5f115fae982565771be994867c89bcd8d7208dbfe9469185497d70de5ddf/numpy-2.4.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0093e85df2960d7e4049664b26afc58b03236e967fb942354deef3208857a04c", size = 14404859, upload-time = "2026-01-10T06:42:49.947Z" },
{ url = "https://files.pythonhosted.org/packages/d9/7d/9c8a781c88933725445a859cac5d01b5871588a15969ee6aeb618ba99eee/numpy-2.4.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7ad270f438cbdd402c364980317fb6b117d9ec5e226fff5b4148dd9aa9fc6e02", size = 16371419, upload-time = "2026-01-10T06:42:52.409Z" },
{ url = "https://files.pythonhosted.org/packages/a6/d2/8aa084818554543f17cf4162c42f162acbd3bb42688aefdba6628a859f77/numpy-2.4.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:297c72b1b98100c2e8f873d5d35fb551fce7040ade83d67dd51d38c8d42a2162", size = 16182131, upload-time = "2026-01-10T06:42:54.694Z" },
{ url = "https://files.pythonhosted.org/packages/60/db/0425216684297c58a8df35f3284ef56ec4a043e6d283f8a59c53562caf1b/numpy-2.4.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:cf6470d91d34bf669f61d515499859fa7a4c2f7c36434afb70e82df7217933f9", size = 18295342, upload-time = "2026-01-10T06:42:56.991Z" },
{ url = "https://files.pythonhosted.org/packages/04/68/732d4b7811c00775f3bd522a21e8dd5a23f77eb11acdeb663e4a4ebf0ef4/numpy-2.4.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:d797454e37570cfd61143b73b8debd623c3c0952959adb817dd310a483d58a1b", size = 16652495, upload-time = "2026-01-10T06:43:06.283Z" },
{ url = "https://files.pythonhosted.org/packages/20/ca/857722353421a27f1465652b2c66813eeeccea9d76d5f7b74b99f298e60e/numpy-2.4.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:82c55962006156aeef1629b953fd359064aa47e4d82cfc8e67f0918f7da3344f", size = 12368657, upload-time = "2026-01-10T06:43:09.094Z" },
{ url = "https://files.pythonhosted.org/packages/81/0d/2377c917513449cc6240031a79d30eb9a163d32a91e79e0da47c43f2c0c8/numpy-2.4.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:71abbea030f2cfc3092a0ff9f8c8fdefdc5e0bf7d9d9c99663538bb0ecdac0b9", size = 5197256, upload-time = "2026-01-10T06:43:13.634Z" },
{ url = "https://files.pythonhosted.org/packages/17/39/569452228de3f5de9064ac75137082c6214be1f5c532016549a7923ab4b5/numpy-2.4.1-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:5b55aa56165b17aaf15520beb9cbd33c9039810e0d9643dd4379e44294c7303e", size = 6545212, upload-time = "2026-01-10T06:43:15.661Z" },
{ url = "https://files.pythonhosted.org/packages/8c/a4/77333f4d1e4dac4395385482557aeecf4826e6ff517e32ca48e1dafbe42a/numpy-2.4.1-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c0faba4a331195bfa96f93dd9dfaa10b2c7aa8cda3a02b7fd635e588fe821bf5", size = 14402871, upload-time = "2026-01-10T06:43:17.324Z" },
{ url = "https://files.pythonhosted.org/packages/ba/87/d341e519956273b39d8d47969dd1eaa1af740615394fe67d06f1efa68773/numpy-2.4.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d3e3087f53e2b4428766b54932644d148613c5a595150533ae7f00dab2f319a8", size = 16359305, upload-time = "2026-01-10T06:43:19.376Z" },
{ url = "https://files.pythonhosted.org/packages/32/91/789132c6666288eaa20ae8066bb99eba1939362e8f1a534949a215246e97/numpy-2.4.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:49e792ec351315e16da54b543db06ca8a86985ab682602d90c60ef4ff4db2a9c", size = 16181909, upload-time = "2026-01-10T06:43:21.808Z" },
{ url = "https://files.pythonhosted.org/packages/cf/b8/090b8bd27b82a844bb22ff8fdf7935cb1980b48d6e439ae116f53cdc2143/numpy-2.4.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:79e9e06c4c2379db47f3f6fc7a8652e7498251789bf8ff5bd43bf478ef314ca2", size = 18284380, upload-time = "2026-01-10T06:43:23.957Z" },
{ url = "https://files.pythonhosted.org/packages/da/a1/354583ac5c4caa566de6ddfbc42744409b515039e085fab6e0ff942e0df5/numpy-2.4.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:f93bc6892fe7b0663e5ffa83b61aab510aacffd58c16e012bb9352d489d90cb7", size = 12496156, upload-time = "2026-01-10T06:43:34.237Z" },
{ url = "https://files.pythonhosted.org/packages/51/b0/42807c6e8cce58c00127b1dc24d365305189991f2a7917aa694a109c8d7d/numpy-2.4.1-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:178de8f87948163d98a4c9ab5bee4ce6519ca918926ec8df195af582de28544d", size = 5324663, upload-time = "2026-01-10T06:43:36.211Z" },
{ url = "https://files.pythonhosted.org/packages/fe/55/7a621694010d92375ed82f312b2f28017694ed784775269115323e37f5e2/numpy-2.4.1-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:98b35775e03ab7f868908b524fc0a84d38932d8daf7b7e1c3c3a1b6c7a2c9f15", size = 6645224, upload-time = "2026-01-10T06:43:37.884Z" },
{ url = "https://files.pythonhosted.org/packages/50/96/9fa8635ed9d7c847d87e30c834f7109fac5e88549d79ef3324ab5c20919f/numpy-2.4.1-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:941c2a93313d030f219f3a71fd3d91a728b82979a5e8034eb2e60d394a2b83f9", size = 14462352, upload-time = "2026-01-10T06:43:39.479Z" },
{ url = "https://files.pythonhosted.org/packages/03/d1/8cf62d8bb2062da4fb82dd5d49e47c923f9c0738032f054e0a75342faba7/numpy-2.4.1-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:529050522e983e00a6c1c6b67411083630de8b57f65e853d7b03d9281b8694d2", size = 16407279, upload-time = "2026-01-10T06:43:41.93Z" },
{ url = "https://files.pythonhosted.org/packages/86/1c/95c86e17c6b0b31ce6ef219da00f71113b220bcb14938c8d9a05cee0ff53/numpy-2.4.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:2302dc0224c1cbc49bb94f7064f3f923a971bfae45c33870dcbff63a2a550505", size = 16248316, upload-time = "2026-01-10T06:43:44.121Z" },
{ url = "https://files.pythonhosted.org/packages/30/b4/e7f5ff8697274c9d0fa82398b6a372a27e5cef069b37df6355ccb1f1db1a/numpy-2.4.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:9171a42fcad32dcf3fa86f0a4faa5e9f8facefdb276f54b8b390d90447cff4e2", size = 18329884, upload-time = "2026-01-10T06:43:46.613Z" },
{ url = "https://files.pythonhosted.org/packages/1b/a7/ef08d25698e0e4b4efbad8d55251d20fe2a15f6d9aa7c9b30cd03c165e6f/numpy-2.4.1-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:3869ea1ee1a1edc16c29bbe3a2f2a4e515cc3a44d43903ad41e0cacdbaf733dc", size = 16652046, upload-time = "2026-01-10T06:43:54.797Z" },
{ url = "https://files.pythonhosted.org/packages/8f/39/e378b3e3ca13477e5ac70293ec027c438d1927f18637e396fe90b1addd72/numpy-2.4.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:e867df947d427cdd7a60e3e271729090b0f0df80f5f10ab7dd436f40811699c3", size = 12378858, upload-time = "2026-01-10T06:43:57.099Z" },
{ url = "https://files.pythonhosted.org/packages/c3/74/7ec6154f0006910ed1fdbb7591cf4432307033102b8a22041599935f8969/numpy-2.4.1-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:e3bd2cb07841166420d2fa7146c96ce00cb3410664cbc1a6be028e456c4ee220", size = 5207417, upload-time = "2026-01-10T06:43:59.037Z" },
{ url = "https://files.pythonhosted.org/packages/f7/b7/053ac11820d84e42f8feea5cb81cc4fcd1091499b45b1ed8c7415b1bf831/numpy-2.4.1-cp314-cp314-macosx_14_0_x86_64.whl", hash = "sha256:f0a90aba7d521e6954670550e561a4cb925713bd944445dbe9e729b71f6cabee", size = 6542643, upload-time = "2026-01-10T06:44:01.852Z" },
{ url = "https://files.pythonhosted.org/packages/c0/c4/2e7908915c0e32ca636b92e4e4a3bdec4cb1e7eb0f8aedf1ed3c68a0d8cd/numpy-2.4.1-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5d558123217a83b2d1ba316b986e9248a1ed1971ad495963d555ccd75dcb1556", size = 14418963, upload-time = "2026-01-10T06:44:04.047Z" },
{ url = "https://files.pythonhosted.org/packages/eb/c0/3ed5083d94e7ffd7c404e54619c088e11f2e1939a9544f5397f4adb1b8ba/numpy-2.4.1-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2f44de05659b67d20499cbc96d49f2650769afcb398b79b324bb6e297bfe3844", size = 16363811, upload-time = "2026-01-10T06:44:06.207Z" },
{ url = "https://files.pythonhosted.org/packages/0e/68/42b66f1852bf525050a67315a4fb94586ab7e9eaa541b1bef530fab0c5dd/numpy-2.4.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:69e7419c9012c4aaf695109564e3387f1259f001b4326dfa55907b098af082d3", size = 16197643, upload-time = "2026-01-10T06:44:08.33Z" },
{ url = "https://files.pythonhosted.org/packages/d2/40/e8714fc933d85f82c6bfc7b998a0649ad9769a32f3494ba86598aaf18a48/numpy-2.4.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:2ffd257026eb1b34352e749d7cc1678b5eeec3e329ad8c9965a797e08ccba205", size = 18289601, upload-time = "2026-01-10T06:44:10.841Z" },
{ url = "https://files.pythonhosted.org/packages/de/bc/ea3f2c96fcb382311827231f911723aeff596364eb6e1b6d1d91128aa29b/numpy-2.4.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:4e53170557d37ae404bf8d542ca5b7c629d6efa1117dac6a83e394142ea0a43f", size = 12498774, upload-time = "2026-01-10T06:44:19.467Z" },
{ url = "https://files.pythonhosted.org/packages/aa/ab/ef9d939fe4a812648c7a712610b2ca6140b0853c5efea361301006c02ae5/numpy-2.4.1-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:a73044b752f5d34d4232f25f18160a1cc418ea4507f5f11e299d8ac36875f8a0", size = 5327274, upload-time = "2026-01-10T06:44:23.189Z" },
{ url = "https://files.pythonhosted.org/packages/bd/31/d381368e2a95c3b08b8cf7faac6004849e960f4a042d920337f71cef0cae/numpy-2.4.1-cp314-cp314t-macosx_14_0_x86_64.whl", hash = "sha256:fb1461c99de4d040666ca0444057b06541e5642f800b71c56e6ea92d6a853a0c", size = 6648306, upload-time = "2026-01-10T06:44:25.012Z" },
{ url = "https://files.pythonhosted.org/packages/c8/e5/0989b44ade47430be6323d05c23207636d67d7362a1796ccbccac6773dd2/numpy-2.4.1-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:423797bdab2eeefbe608d7c1ec7b2b4fd3c58d51460f1ee26c7500a1d9c9ee93", size = 14464653, upload-time = "2026-01-10T06:44:26.706Z" },
{ url = "https://files.pythonhosted.org/packages/10/a7/cfbe475c35371cae1358e61f20c5f075badc18c4797ab4354140e1d283cf/numpy-2.4.1-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:52b5f61bdb323b566b528899cc7db2ba5d1015bda7ea811a8bcf3c89c331fa42", size = 16405144, upload-time = "2026-01-10T06:44:29.378Z" },
{ url = "https://files.pythonhosted.org/packages/f8/a3/0c63fe66b534888fa5177cc7cef061541064dbe2b4b60dcc60ffaf0d2157/numpy-2.4.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:42d7dd5fa36d16d52a84f821eb96031836fd405ee6955dd732f2023724d0aa01", size = 16247425, upload-time = "2026-01-10T06:44:31.721Z" },
{ url = "https://files.pythonhosted.org/packages/6b/2b/55d980cfa2c93bd40ff4c290bf824d792bd41d2fe3487b07707559071760/numpy-2.4.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:e7b6b5e28bbd47b7532698e5db2fe1db693d84b58c254e4389d99a27bb9b8f6b", size = 18330053, upload-time = "2026-01-10T06:44:34.617Z" },
{ url = "https://files.pythonhosted.org/packages/1e/48/d86f97919e79314a1cdee4c832178763e6e98e623e123d0bada19e92c15a/numpy-2.4.1-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:8ad35f20be147a204e28b6a0575fbf3540c5e5f802634d4258d55b1ff5facce1", size = 16822202, upload-time = "2026-01-10T06:44:43.738Z" },
{ url = "https://files.pythonhosted.org/packages/51/e9/1e62a7f77e0f37dcfb0ad6a9744e65df00242b6ea37dfafb55debcbf5b55/numpy-2.4.1-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:8097529164c0f3e32bb89412a0905d9100bf434d9692d9fc275e18dcf53c9344", size = 12569985, upload-time = "2026-01-10T06:44:45.945Z" },
{ url = "https://files.pythonhosted.org/packages/c7/7e/914d54f0c801342306fdcdce3e994a56476f1b818c46c47fc21ae968088c/numpy-2.4.1-pp311-pypy311_pp73-macosx_14_0_arm64.whl", hash = "sha256:ea66d2b41ca4a1630aae5507ee0a71647d3124d1741980138aa8f28f44dac36e", size = 5398484, upload-time = "2026-01-10T06:44:48.012Z" },
{ url = "https://files.pythonhosted.org/packages/1c/d8/9570b68584e293a33474e7b5a77ca404f1dcc655e40050a600dee81d27fb/numpy-2.4.1-pp311-pypy311_pp73-macosx_14_0_x86_64.whl", hash = "sha256:d3f8f0df9f4b8be57b3bf74a1d087fec68f927a2fab68231fdb442bf2c12e426", size = 6713216, upload-time = "2026-01-10T06:44:49.725Z" },
{ url = "https://files.pythonhosted.org/packages/33/9b/9dd6e2db8d49eb24f86acaaa5258e5f4c8ed38209a4ee9de2d1a0ca25045/numpy-2.4.1-pp311-pypy311_pp73-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2023ef86243690c2791fd6353e5b4848eedaa88ca8a2d129f462049f6d484696", size = 14538937, upload-time = "2026-01-10T06:44:51.498Z" },
{ url = "https://files.pythonhosted.org/packages/53/87/d5bd995b0f798a37105b876350d346eea5838bd8f77ea3d7a48392f3812b/numpy-2.4.1-pp311-pypy311_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8361ea4220d763e54cff2fbe7d8c93526b744f7cd9ddab47afeff7e14e8503be", size = 16479830, upload-time = "2026-01-10T06:44:53.931Z" },
]

[[package]]
name = "opt-einsum"
version = "3.4.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/8c/b9/2ac072041e899a52f20cf9510850ff58295003aa75525e58343591b0cbfb/opt_einsum-3.4.0.tar.gz", hash = "sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac", size = 63004, upload-time = "2024-09-26T14:33:24.483Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl", hash = "sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd", size = 71932, upload-time = "2024-09-26T14:33:23.039Z" },
]

[[package]]
name = "packaging"
version = "25.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/a1/d4/1fc4078c65507b51b96ca8f8c3ba19e6a61c8253c72794544580a7b6c24d/packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f", size = 165727, upload-time = "2025-04-19T11:48:59.673Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469, upload-time = "2025-04-19T11:48:57.875Z" },
]

[[package]]
name = "pluggy"
version = "1.6.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
]

[[package]]
name = "pygments"
version = "2.19.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" },
]

[[package]]
name = "pyright"
version = "1.1.408"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "nodeenv", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/74/b2/5db700e52554b8f025faa9c3c624c59f1f6c8841ba81ab97641b54322f16/pyright-1.1.408.tar.gz", hash = "sha256:f28f2321f96852fa50b5829ea492f6adb0e6954568d1caa3f3af3a5f555eb684", size = 4400578, upload-time = "2026-01-08T08:07:38.795Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/0c/82/a2c93e32800940d9573fb28c346772a14778b84ba7524e691b324620ab89/pyright-1.1.408-py3-none-any.whl", hash = "sha256:090b32865f4fdb1e0e6cd82bf5618480d48eecd2eb2e70f960982a3d9a4c17c1", size = 6399144, upload-time = "2026-01-08T08:07:37.082Z" },
]

[[package]]
name = "pytest"
version = "9.0.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "iniconfig", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "packaging", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "pluggy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "pygments", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/d1/db/7ef3487e0fb0049ddb5ce41d3a49c235bf9ad299b6a25d5780a89f19230f/pytest-9.0.2.tar.gz", hash = "sha256:75186651a92bd89611d1d9fc20f0b4345fd827c41ccd5c299a868a05d70edf11", size = 1568901, upload-time = "2025-12-06T21:30:51.014Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" },
]

[[package]]
name = "scipy"
version = "1.17.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/56/3e/9cca699f3486ce6bc12ff46dc2031f1ec8eb9ccc9a320fdaf925f1417426/scipy-1.17.0.tar.gz", hash = "sha256:2591060c8e648d8b96439e111ac41fd8342fdeff1876be2e19dea3fe8930454e", size = 30396830, upload-time = "2026-01-10T21:34:23.009Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/1e/4b/c89c131aa87cad2b77a54eb0fb94d633a842420fa7e919dc2f922037c3d8/scipy-1.17.0-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:2abd71643797bd8a106dff97894ff7869eeeb0af0f7a5ce02e4227c6a2e9d6fd", size = 31381316, upload-time = "2026-01-10T21:24:33.42Z" },
{ url = "https://files.pythonhosted.org/packages/5e/5f/a6b38f79a07d74989224d5f11b55267714707582908a5f1ae854cf9a9b84/scipy-1.17.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:ef28d815f4d2686503e5f4f00edc387ae58dfd7a2f42e348bb53359538f01558", size = 27966760, upload-time = "2026-01-10T21:24:38.911Z" },
{ url = "https://files.pythonhosted.org/packages/c1/20/095ad24e031ee8ed3c5975954d816b8e7e2abd731e04f8be573de8740885/scipy-1.17.0-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:272a9f16d6bb4667e8b50d25d71eddcc2158a214df1b566319298de0939d2ab7", size = 20138701, upload-time = "2026-01-10T21:24:43.249Z" },
{ url = "https://files.pythonhosted.org/packages/89/11/4aad2b3858d0337756f3323f8960755704e530b27eb2a94386c970c32cbe/scipy-1.17.0-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:7204fddcbec2fe6598f1c5fdf027e9f259106d05202a959a9f1aecf036adc9f6", size = 22480574, upload-time = "2026-01-10T21:24:47.266Z" },
{ url = "https://files.pythonhosted.org/packages/85/bd/f5af70c28c6da2227e510875cadf64879855193a687fb19951f0f44cfd6b/scipy-1.17.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:fc02c37a5639ee67d8fb646ffded6d793c06c5622d36b35cfa8fe5ececb8f042", size = 32862414, upload-time = "2026-01-10T21:24:52.566Z" },
{ url = "https://files.pythonhosted.org/packages/ef/df/df1457c4df3826e908879fe3d76bc5b6e60aae45f4ee42539512438cfd5d/scipy-1.17.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:dac97a27520d66c12a34fd90a4fe65f43766c18c0d6e1c0a80f114d2260080e4", size = 35112380, upload-time = "2026-01-10T21:24:58.433Z" },
{ url = "https://files.pythonhosted.org/packages/5f/bb/88e2c16bd1dd4de19d80d7c5e238387182993c2fb13b4b8111e3927ad422/scipy-1.17.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ebb7446a39b3ae0fe8f416a9a3fdc6fba3f11c634f680f16a239c5187bc487c0", size = 34922676, upload-time = "2026-01-10T21:25:04.287Z" },
{ url = "https://files.pythonhosted.org/packages/02/ba/5120242cc735f71fc002cff0303d536af4405eb265f7c60742851e7ccfe9/scipy-1.17.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:474da16199f6af66601a01546144922ce402cb17362e07d82f5a6cf8f963e449", size = 37507599, upload-time = "2026-01-10T21:25:09.851Z" },
{ url = "https://files.pythonhosted.org/packages/0b/11/7241a63e73ba5a516f1930ac8d5b44cbbfabd35ac73a2d08ca206df007c4/scipy-1.17.0-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:0d5018a57c24cb1dd828bcf51d7b10e65986d549f52ef5adb6b4d1ded3e32a57", size = 31364580, upload-time = "2026-01-10T21:25:25.717Z" },
{ url = "https://files.pythonhosted.org/packages/ed/1d/5057f812d4f6adc91a20a2d6f2ebcdb517fdbc87ae3acc5633c9b97c8ba5/scipy-1.17.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:88c22af9e5d5a4f9e027e26772cc7b5922fab8bcc839edb3ae33de404feebd9e", size = 27969012, upload-time = "2026-01-10T21:25:30.921Z" },
{ url = "https://files.pythonhosted.org/packages/e3/21/f6ec556c1e3b6ec4e088da667d9987bb77cc3ab3026511f427dc8451187d/scipy-1.17.0-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:f3cd947f20fe17013d401b64e857c6b2da83cae567adbb75b9dcba865abc66d8", size = 20140691, upload-time = "2026-01-10T21:25:34.802Z" },
{ url = "https://files.pythonhosted.org/packages/7a/fe/5e5ad04784964ba964a96f16c8d4676aa1b51357199014dce58ab7ec5670/scipy-1.17.0-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:e8c0b331c2c1f531eb51f1b4fc9ba709521a712cce58f1aa627bc007421a5306", size = 22463015, upload-time = "2026-01-10T21:25:39.277Z" },
{ url = "https://files.pythonhosted.org/packages/4a/69/7c347e857224fcaf32a34a05183b9d8a7aca25f8f2d10b8a698b8388561a/scipy-1.17.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5194c445d0a1c7a6c1a4a4681b6b7c71baad98ff66d96b949097e7513c9d6742", size = 32724197, upload-time = "2026-01-10T21:25:44.084Z" },
{ url = "https://files.pythonhosted.org/packages/d1/fe/66d73b76d378ba8cc2fe605920c0c75092e3a65ae746e1e767d9d020a75a/scipy-1.17.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9eeb9b5f5997f75507814ed9d298ab23f62cf79f5a3ef90031b1ee2506abdb5b", size = 35009148, upload-time = "2026-01-10T21:25:50.591Z" },
{ url = "https://files.pythonhosted.org/packages/af/07/07dec27d9dc41c18d8c43c69e9e413431d20c53a0339c388bcf72f353c4b/scipy-1.17.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:40052543f7bbe921df4408f46003d6f01c6af109b9e2c8a66dd1cf6cf57f7d5d", size = 34798766, upload-time = "2026-01-10T21:25:59.41Z" },
{ url = "https://files.pythonhosted.org/packages/81/61/0470810c8a093cdacd4ba7504b8a218fd49ca070d79eca23a615f5d9a0b0/scipy-1.17.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:0cf46c8013fec9d3694dc572f0b54100c28405d55d3e2cb15e2895b25057996e", size = 37405953, upload-time = "2026-01-10T21:26:07.75Z" },
{ url = "https://files.pythonhosted.org/packages/0c/51/3468fdfd49387ddefee1636f5cf6d03ce603b75205bf439bbf0e62069bfd/scipy-1.17.0-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:65ec32f3d32dfc48c72df4291345dae4f048749bc8d5203ee0a3f347f96c5ce6", size = 31344101, upload-time = "2026-01-10T21:26:30.25Z" },
{ url = "https://files.pythonhosted.org/packages/b2/9a/9406aec58268d437636069419e6977af953d1e246df941d42d3720b7277b/scipy-1.17.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:1f9586a58039d7229ce77b52f8472c972448cded5736eaf102d5658bbac4c269", size = 27950385, upload-time = "2026-01-10T21:26:36.801Z" },
{ url = "https://files.pythonhosted.org/packages/4f/98/e7342709e17afdfd1b26b56ae499ef4939b45a23a00e471dfb5375eea205/scipy-1.17.0-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:9fad7d3578c877d606b1150135c2639e9de9cecd3705caa37b66862977cc3e72", size = 20122115, upload-time = "2026-01-10T21:26:42.107Z" },
{ url = "https://files.pythonhosted.org/packages/fd/0e/9eeeb5357a64fd157cbe0302c213517c541cc16b8486d82de251f3c68ede/scipy-1.17.0-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:423ca1f6584fc03936972b5f7c06961670dbba9f234e71676a7c7ccf938a0d61", size = 22442402, upload-time = "2026-01-10T21:26:48.029Z" },
{ url = "https://files.pythonhosted.org/packages/c9/10/be13397a0e434f98e0c79552b2b584ae5bb1c8b2be95db421533bbca5369/scipy-1.17.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:fe508b5690e9eaaa9467fc047f833af58f1152ae51a0d0aed67aa5801f4dd7d6", size = 32696338, upload-time = "2026-01-10T21:26:55.521Z" },
{ url = "https://files.pythonhosted.org/packages/63/1e/12fbf2a3bb240161651c94bb5cdd0eae5d4e8cc6eaeceb74ab07b12a753d/scipy-1.17.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6680f2dfd4f6182e7d6db161344537da644d1cf85cf293f015c60a17ecf08752", size = 34977201, upload-time = "2026-01-10T21:27:03.501Z" },
{ url = "https://files.pythonhosted.org/packages/19/5b/1a63923e23ccd20bd32156d7dd708af5bbde410daa993aa2500c847ab2d2/scipy-1.17.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:eec3842ec9ac9de5917899b277428886042a93db0b227ebbe3a333b64ec7643d", size = 34777384, upload-time = "2026-01-10T21:27:11.423Z" },
{ url = "https://files.pythonhosted.org/packages/39/22/b5da95d74edcf81e540e467202a988c50fef41bd2011f46e05f72ba07df6/scipy-1.17.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:d7425fcafbc09a03731e1bc05581f5fad988e48c6a861f441b7ab729a49a55ea", size = 37379586, upload-time = "2026-01-10T21:27:20.171Z" },
{ url = "https://files.pythonhosted.org/packages/20/b6/7feaa252c21cc7aff335c6c55e1b90ab3e3306da3f048109b8b639b94648/scipy-1.17.0-cp313-cp313t-macosx_10_14_x86_64.whl", hash = "sha256:ec0827aa4d36cb79ff1b81de898e948a51ac0b9b1c43e4a372c0508c38c0f9a3", size = 31693194, upload-time = "2026-01-10T21:27:27.454Z" },
{ url = "https://files.pythonhosted.org/packages/76/bb/bbb392005abce039fb7e672cb78ac7d158700e826b0515cab6b5b60c26fb/scipy-1.17.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:819fc26862b4b3c73a60d486dbb919202f3d6d98c87cf20c223511429f2d1a97", size = 28365415, upload-time = "2026-01-10T21:27:34.26Z" },
{ url = "https://files.pythonhosted.org/packages/37/da/9d33196ecc99fba16a409c691ed464a3a283ac454a34a13a3a57c0d66f3a/scipy-1.17.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:363ad4ae2853d88ebcde3ae6ec46ccca903ea9835ee8ba543f12f575e7b07e4e", size = 20537232, upload-time = "2026-01-10T21:27:40.306Z" },
{ url = "https://files.pythonhosted.org/packages/56/9d/f4b184f6ddb28e9a5caea36a6f98e8ecd2a524f9127354087ce780885d83/scipy-1.17.0-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:979c3a0ff8e5ba254d45d59ebd38cde48fce4f10b5125c680c7a4bfe177aab07", size = 22791051, upload-time = "2026-01-10T21:27:46.539Z" },
{ url = "https://files.pythonhosted.org/packages/9b/9d/025cccdd738a72140efc582b1641d0dd4caf2e86c3fb127568dc80444e6e/scipy-1.17.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:130d12926ae34399d157de777472bf82e9061c60cc081372b3118edacafe1d00", size = 32815098, upload-time = "2026-01-10T21:27:54.389Z" },
{ url = "https://files.pythonhosted.org/packages/48/5f/09b879619f8bca15ce392bfc1894bd9c54377e01d1b3f2f3b595a1b4d945/scipy-1.17.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6e886000eb4919eae3a44f035e63f0fd8b651234117e8f6f29bad1cd26e7bc45", size = 35031342, upload-time = "2026-01-10T21:28:03.012Z" },
{ url = "https://files.pythonhosted.org/packages/f2/9a/f0f0a9f0aa079d2f106555b984ff0fbb11a837df280f04f71f056ea9c6e4/scipy-1.17.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:13c4096ac6bc31d706018f06a49abe0485f96499deb82066b94d19b02f664209", size = 34893199, upload-time = "2026-01-10T21:28:10.832Z" },
{ url = "https://files.pythonhosted.org/packages/90/b8/4f0f5cf0c5ea4d7548424e6533e6b17d164f34a6e2fb2e43ffebb6697b06/scipy-1.17.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:cacbaddd91fcffde703934897c5cd2c7cb0371fac195d383f4e1f1c5d3f3bd04", size = 37438061, upload-time = "2026-01-10T21:28:19.684Z" },
{ url = "https://files.pythonhosted.org/packages/1a/2d/51006cd369b8e7879e1c630999a19d1fbf6f8b5ed3e33374f29dc87e53b3/scipy-1.17.0-cp314-cp314-macosx_10_14_x86_64.whl", hash = "sha256:c17514d11b78be8f7e6331b983a65a7f5ca1fd037b95e27b280921fe5606286a", size = 31346803, upload-time = "2026-01-10T21:28:57.24Z" },
{ url = "https://files.pythonhosted.org/packages/d6/2e/2349458c3ce445f53a6c93d4386b1c4c5c0c540917304c01222ff95ff317/scipy-1.17.0-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:4e00562e519c09da34c31685f6acc3aa384d4d50604db0f245c14e1b4488bfa2", size = 27967182, upload-time = "2026-01-10T21:29:04.107Z" },
{ url = "https://files.pythonhosted.org/packages/5e/7c/df525fbfa77b878d1cfe625249529514dc02f4fd5f45f0f6295676a76528/scipy-1.17.0-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:f7df7941d71314e60a481e02d5ebcb3f0185b8d799c70d03d8258f6c80f3d467", size = 20139125, upload-time = "2026-01-10T21:29:10.179Z" },
{ url = "https://files.pythonhosted.org/packages/33/11/fcf9d43a7ed1234d31765ec643b0515a85a30b58eddccc5d5a4d12b5f194/scipy-1.17.0-cp314-cp314-macosx_14_0_x86_64.whl", hash = "sha256:aabf057c632798832f071a8dde013c2e26284043934f53b00489f1773b33527e", size = 22443554, upload-time = "2026-01-10T21:29:15.888Z" },
{ url = "https://files.pythonhosted.org/packages/80/5c/ea5d239cda2dd3d31399424967a24d556cf409fbea7b5b21412b0fd0a44f/scipy-1.17.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a38c3337e00be6fd8a95b4ed66b5d988bac4ec888fd922c2ea9fe5fb1603dd67", size = 32757834, upload-time = "2026-01-10T21:29:23.406Z" },
{ url = "https://files.pythonhosted.org/packages/b8/7e/8c917cc573310e5dc91cbeead76f1b600d3fb17cf0969db02c9cf92e3cfa/scipy-1.17.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:00fb5f8ec8398ad90215008d8b6009c9db9fa924fd4c7d6be307c6f945f9cd73", size = 34995775, upload-time = "2026-01-10T21:29:31.915Z" },
{ url = "https://files.pythonhosted.org/packages/c5/43/176c0c3c07b3f7df324e7cdd933d3e2c4898ca202b090bd5ba122f9fe270/scipy-1.17.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:f2a4942b0f5f7c23c7cd641a0ca1955e2ae83dedcff537e3a0259096635e186b", size = 34841240, upload-time = "2026-01-10T21:29:39.995Z" },
{ url = "https://files.pythonhosted.org/packages/44/8c/d1f5f4b491160592e7f084d997de53a8e896a3ac01cd07e59f43ca222744/scipy-1.17.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:dbf133ced83889583156566d2bdf7a07ff89228fe0c0cb727f777de92092ec6b", size = 37394463, upload-time = "2026-01-10T21:29:48.723Z" },
{ url = "https://files.pythonhosted.org/packages/e9/01/f58916b9d9ae0112b86d7c3b10b9e685625ce6e8248df139d0fcb17f7397/scipy-1.17.0-cp314-cp314t-macosx_10_14_x86_64.whl", hash = "sha256:2b531f57e09c946f56ad0b4a3b2abee778789097871fc541e267d2eca081cff1", size = 31706502, upload-time = "2026-01-10T21:29:56.326Z" },
{ url = "https://files.pythonhosted.org/packages/59/8e/2912a87f94a7d1f8b38aabc0faf74b82d3b6c9e22be991c49979f0eceed8/scipy-1.17.0-cp314-cp314t-macosx_12_0_arm64.whl", hash = "sha256:13e861634a2c480bd237deb69333ac79ea1941b94568d4b0efa5db5e263d4fd1", size = 28380854, upload-time = "2026-01-10T21:30:01.554Z" },
{ url = "https://files.pythonhosted.org/packages/bd/1c/874137a52dddab7d5d595c1887089a2125d27d0601fce8c0026a24a92a0b/scipy-1.17.0-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:eb2651271135154aa24f6481cbae5cc8af1f0dd46e6533fb7b56aa9727b6a232", size = 20552752, upload-time = "2026-01-10T21:30:05.93Z" },
{ url = "https://files.pythonhosted.org/packages/3f/f0/7518d171cb735f6400f4576cf70f756d5b419a07fe1867da34e2c2c9c11b/scipy-1.17.0-cp314-cp314t-macosx_14_0_x86_64.whl", hash = "sha256:c5e8647f60679790c2f5c76be17e2e9247dc6b98ad0d3b065861e082c56e078d", size = 22803972, upload-time = "2026-01-10T21:30:10.651Z" },
{ url = "https://files.pythonhosted.org/packages/7c/74/3498563a2c619e8a3ebb4d75457486c249b19b5b04a30600dfd9af06bea5/scipy-1.17.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5fb10d17e649e1446410895639f3385fd2bf4c3c7dfc9bea937bddcbc3d7b9ba", size = 32829770, upload-time = "2026-01-10T21:30:16.359Z" },
{ url = "https://files.pythonhosted.org/packages/48/d1/7b50cedd8c6c9d6f706b4b36fa8544d829c712a75e370f763b318e9638c1/scipy-1.17.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8547e7c57f932e7354a2319fab613981cde910631979f74c9b542bb167a8b9db", size = 35051093, upload-time = "2026-01-10T21:30:22.987Z" },
{ url = "https://files.pythonhosted.org/packages/e2/82/a2d684dfddb87ba1b3ea325df7c3293496ee9accb3a19abe9429bce94755/scipy-1.17.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:33af70d040e8af9d5e7a38b5ed3b772adddd281e3062ff23fec49e49681c38cf", size = 34909905, upload-time = "2026-01-10T21:30:28.704Z" },
{ url = "https://files.pythonhosted.org/packages/ef/5e/e565bd73991d42023eb82bb99e51c5b3d9e2c588ca9d4b3e2cc1d3ca62a6/scipy-1.17.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:f9eb55bb97d00f8b7ab95cb64f873eb0bf54d9446264d9f3609130381233483f", size = 37457743, upload-time = "2026-01-10T21:30:34.819Z" },
]

[[package]]
name = "tabulate"
version = "0.9.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/ec/fe/802052aecb21e3797b8f7902564ab6ea0d60ff8ca23952079064155d1ae1/tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c", size = 81090, upload-time = "2022-10-06T17:21:48.54Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/40/44/4a5f08c96eb108af5cb50b41f76142f0afa346dfa99d5296fe7202a11854/tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f", size = 35252, upload-time = "2022-10-06T17:21:44.262Z" },
]

[[package]]
name = "typing-extensions"
version = "4.15.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/72/94/1a15dd82efb362ac84269196e94cf00f187f7ed21c242792a923cdb1c61f/typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466", size = 109391, upload-time = "2025-08-25T13:49:26.313Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614, upload-time = "2025-08-25T13:49:24.86Z" },
]
26
thunder/deserializer.rs
Normal file
@ -0,0 +1,26 @@
use crate::schema::{events::Event, tweet_events::TweetEvent};
use anyhow::{Context, Result};
use prost::Message;
use thrift::protocol::{TBinaryInputProtocol, TSerializable};
use xai_thunder_proto::InNetworkEvent;

/// Deserialize a Thrift binary message into TweetEvent
pub fn deserialize_tweet_event(payload: &[u8]) -> Result<TweetEvent> {
    let mut cursor = std::io::Cursor::new(payload);
    let mut protocol = TBinaryInputProtocol::new(&mut cursor, true);

    TweetEvent::read_from_in_protocol(&mut protocol).context("Failed to deserialize TweetEvent")
}

/// Deserialize a Thrift binary message into Event
pub fn deserialize_event(payload: &[u8]) -> Result<Event> {
    let mut cursor = std::io::Cursor::new(payload);
    let mut protocol = TBinaryInputProtocol::new(&mut cursor, true);

    Event::read_from_in_protocol(&mut protocol).context("Failed to deserialize Event")
}

/// Deserialize a proto binary message into InNetworkEvent
pub fn deserialize_tweet_event_v2(payload: &[u8]) -> Result<InNetworkEvent> {
    InNetworkEvent::decode(payload).context("Failed to deserialize InNetworkEvent")
}
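All three helpers take the raw bytes of a consumed Kafka message and return either a typed event or an anyhow error with context attached. A minimal calling sketch follows; the function name, the is_proto_v2 flag, and the hand-off comment are illustrative assumptions, not part of deserializer.rs:

// Illustrative sketch only (not part of the original file).
fn decode_example(raw_payload: &[u8], is_proto_v2: bool) -> anyhow::Result<()> {
    if is_proto_v2 {
        // Protobuf-encoded stream.
        let event = deserialize_tweet_event_v2(raw_payload)?;
        let _ = event; // hand off to the processing pipeline here
    } else {
        // Legacy Thrift-encoded stream.
        let event = deserialize_tweet_event(raw_payload)?;
        let _ = event;
    }
    Ok(())
}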
3
thunder/kafka/mod.rs
Normal file
@ -0,0 +1,3 @@
pub mod tweet_events_listener;
pub mod tweet_events_listener_v2;
pub mod utils;
390
thunder/kafka/tweet_events_listener.rs
Normal file
@ -0,0 +1,390 @@
use anyhow::{Context, Result};
use log::{error, info, warn};
use prost::Message;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use tokio::sync::RwLock;
use xai_kafka::{KafkaMessage, config::KafkaConsumerConfig, consumer::KafkaConsumer};
use xai_kafka::{KafkaProducer, KafkaProducerConfig};
use xai_thunder_proto::{
    InNetworkEvent, LightPost, TweetCreateEvent, TweetDeleteEvent, in_network_event,
};
use crate::{
    args::Args,
    config::MIN_VIDEO_DURATION_MS,
    deserializer::deserialize_tweet_event,
    kafka::utils::{create_kafka_consumer, deserialize_kafka_messages},
    metrics,
    schema::{tweet::Tweet, tweet_events::TweetEventData},
};

/// Counter for logging batch processing every Nth time
static BATCH_LOG_COUNTER: AtomicUsize = AtomicUsize::new(0);

/// Monitor Kafka partition lag and update metrics
async fn monitor_partition_lag(
    consumer: Arc<RwLock<KafkaConsumer>>,
    topic: String,
    interval_secs: u64,
) {
    let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));

    loop {
        interval.tick().await;

        let consumer = consumer.read().await;
        match consumer.get_partition_lags().await {
            Ok(lag_info) => {
                for partition_lag in lag_info {
                    let partition_str = partition_lag.partition_id.to_string();

                    metrics::KAFKA_PARTITION_LAG
                        .with_label_values(&[&topic, &partition_str])
                        .set(partition_lag.lag as f64);
                }
            }
            Err(e) => {
                warn!("Failed to get partition lag info: {}", e);
            }
        }
    }
}

fn is_eligible_video(tweet: &Tweet) -> bool {
|
||||||
|
let Some(media) = tweet.media.as_ref() else {
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
|
let [first_media] = media.as_slice() else {
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
|
let Some(crate::schema::tweet_media::MediaInfo::VideoInfo(video_info)) =
|
||||||
|
first_media.media_info.as_ref()
|
||||||
|
else {
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
|
video_info
|
||||||
|
.duration_millis
|
||||||
|
.map(|d| d >= MIN_VIDEO_DURATION_MS)
|
||||||
|
.unwrap_or(false)
|
||||||
|
}
|
||||||
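For clarity, the rule above admits a post only when it carries exactly one attached media item, that item is a video, and the video meets the minimum duration. A simplified sketch of the same predicate over toy types (illustrative only; the real `Tweet`/`MediaInfo` types live in crate::schema and MIN_VIDEO_DURATION_MS comes from crate::config; the value below is an assumption):

// Simplified model of the video-eligibility rule above.
struct SimpleVideo { duration_millis: Option<i64> }
#[allow(dead_code)]
enum SimpleMedia { Video(SimpleVideo), Photo }

const MIN_VIDEO_DURATION_MS: i64 = 10_000; // assumed value for illustration

fn is_eligible_video_simple(media: &[SimpleMedia]) -> bool {
    // Exactly one media item, it must be a video, and it must be long enough.
    match media {
        [SimpleMedia::Video(v)] => v
            .duration_millis
            .map(|d| d >= MIN_VIDEO_DURATION_MS)
            .unwrap_or(false),
        _ => false,
    }
}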
|
|
||||||
|
/// Start the partition lag monitoring task in the background
|
||||||
|
pub fn start_partition_lag_monitor(
|
||||||
|
consumer: Arc<RwLock<KafkaConsumer>>,
|
||||||
|
topic: String,
|
||||||
|
interval_secs: u64,
|
||||||
|
) {
|
||||||
|
tokio::spawn(async move {
|
||||||
|
info!(
|
||||||
|
"Starting partition lag monitoring task for topic '{}' (interval: {}s)",
|
||||||
|
topic, interval_secs
|
||||||
|
);
|
||||||
|
monitor_partition_lag(consumer, topic, interval_secs).await;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Start the tweet event processing loop in the background with configurable number of threads
|
||||||
|
pub async fn start_tweet_event_processing(
|
||||||
|
base_config: KafkaConsumerConfig,
|
||||||
|
producer_config: KafkaProducerConfig,
|
||||||
|
args: &Args,
|
||||||
|
) {
|
||||||
|
let num_partitions = args.tweet_events_num_partitions as usize;
|
||||||
|
let kafka_num_threads = args.kafka_num_threads;
|
||||||
|
|
||||||
|
// Use all available partitions
|
||||||
|
let partitions_to_use: Vec<i32> = (0..num_partitions as i32).collect();
|
||||||
|
let partitions_per_thread = num_partitions.div_ceil(kafka_num_threads);
|
||||||
|
|
||||||
|
info!(
|
||||||
|
"Starting {} message processing threads for {} partitions ({} partitions per thread)",
|
||||||
|
kafka_num_threads, num_partitions, partitions_per_thread
|
||||||
|
);
|
||||||
|
|
||||||
|
let producer = if !args.is_serving {
|
||||||
|
info!("Kafka producer enabled, starting producer...");
|
||||||
|
let producer = Arc::new(RwLock::new(KafkaProducer::new(producer_config)));
|
||||||
|
if let Err(e) = producer.write().await.start().await {
|
||||||
|
panic!("Failed to start Kafka producer: {:#}", e);
|
||||||
|
}
|
||||||
|
Some(producer)
|
||||||
|
} else {
|
||||||
|
info!("Kafka producer disabled, skipping producer initialization");
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
spawn_processing_threads(base_config, partitions_to_use, producer, args);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Spawn multiple processing threads, each handling a subset of partitions
|
||||||
|
fn spawn_processing_threads(
|
||||||
|
base_config: KafkaConsumerConfig,
|
||||||
|
partitions_to_use: Vec<i32>,
|
||||||
|
producer: Option<Arc<RwLock<KafkaProducer>>>,
|
||||||
|
args: &Args,
|
||||||
|
) {
|
||||||
|
let total_partitions = partitions_to_use.len();
|
||||||
|
let partitions_per_thread = total_partitions.div_ceil(args.kafka_num_threads);
|
||||||
|
|
||||||
|
for thread_id in 0..args.kafka_num_threads {
|
||||||
|
let start_idx = thread_id * partitions_per_thread;
|
||||||
|
let end_idx = ((thread_id + 1) * partitions_per_thread).min(total_partitions);
|
||||||
|
|
||||||
|
if start_idx >= total_partitions {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let thread_partitions = partitions_to_use[start_idx..end_idx].to_vec();
|
||||||
|
let mut thread_config = base_config.clone();
|
||||||
|
thread_config.partitions = Some(thread_partitions.clone());
|
||||||
|
|
||||||
|
let producer_clone = producer.as_ref().map(Arc::clone);
|
||||||
|
let topic = thread_config.base_config.topic.clone();
|
||||||
|
let lag_monitor_interval_secs = args.lag_monitor_interval_secs;
|
||||||
|
let batch_size = args.kafka_batch_size;
|
||||||
|
let post_retention_sec = args.post_retention_seconds;
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
info!(
|
||||||
|
"Starting message processing thread {} for partitions {:?}",
|
||||||
|
thread_id, thread_partitions
|
||||||
|
);
|
||||||
|
|
||||||
|
match create_kafka_consumer(thread_config).await {
|
||||||
|
Ok(consumer) => {
|
||||||
|
// Start partition lag monitoring for this thread's partitions
|
||||||
|
start_partition_lag_monitor(
|
||||||
|
Arc::clone(&consumer),
|
||||||
|
topic,
|
||||||
|
lag_monitor_interval_secs,
|
||||||
|
);
|
||||||
|
|
||||||
|
if let Err(e) = process_tweet_events(
|
||||||
|
consumer,
|
||||||
|
batch_size,
|
||||||
|
producer_clone,
|
||||||
|
post_retention_sec as i64,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
panic!(
|
||||||
|
"Tweet events processing thread {} exited unexpectedly: {:#}. This is a critical failure - the feeder cannot function without tweet event processing.",
|
||||||
|
thread_id, e
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
panic!(
|
||||||
|
"Failed to create consumer for thread {}: {:#}",
|
||||||
|
thread_id, e
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
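The thread/partition split above is plain integer chunking: `div_ceil` sizes each chunk so the last thread may receive fewer partitions and surplus threads get none. A standalone sketch of the same arithmetic (illustrative, not tied to the Kafka types):

// Chunk partition ids across worker threads the same way as above.
fn chunk_partitions(partitions: &[i32], num_threads: usize) -> Vec<Vec<i32>> {
    if partitions.is_empty() || num_threads == 0 {
        return Vec::new();
    }
    let per_thread = partitions.len().div_ceil(num_threads);
    partitions.chunks(per_thread).map(|c| c.to_vec()).collect()
}

// e.g. 10 partitions over 4 threads -> chunks of 3, 3, 3, 1;
// 10 partitions over 16 threads -> 10 single-partition chunks (6 threads idle).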
|
|
||||||
|
/// Process a batch of messages: deserialize, extract posts, and store them
|
||||||
|
async fn process_message_batch(
|
||||||
|
messages: Vec<KafkaMessage>,
|
||||||
|
batch_num: usize,
|
||||||
|
producer: Option<Arc<RwLock<KafkaProducer>>>,
|
||||||
|
post_retention_sec: i64,
|
||||||
|
) -> Result<()> {
|
||||||
|
let results = deserialize_kafka_messages(messages, deserialize_tweet_event)?;
|
||||||
|
|
||||||
|
let mut create_tweets = Vec::new();
|
||||||
|
let mut delete_tweets = Vec::new();
|
||||||
|
let mut first_post_id = 0;
|
||||||
|
let mut first_user_id = 0;
|
||||||
|
|
||||||
|
let len_posts = results.len();
|
||||||
|
|
||||||
|
let now_secs = std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap()
|
||||||
|
.as_secs() as i64;
|
||||||
|
|
||||||
|
for tweet_event in results {
|
||||||
|
let data = tweet_event.data.unwrap();
|
||||||
|
|
||||||
|
match data {
|
||||||
|
TweetEventData::TweetCreateEvent(create_event) => {
|
||||||
|
first_post_id = create_event.tweet.as_ref().unwrap().id.unwrap();
|
||||||
|
first_user_id = create_event.user.as_ref().unwrap().id.unwrap();
|
||||||
|
|
||||||
|
let tweet = create_event.tweet.as_ref().unwrap();
|
||||||
|
let core_data = tweet.core_data.as_ref().unwrap();
|
||||||
|
|
||||||
|
if let Some(nullcast) = core_data.nullcast
|
||||||
|
&& nullcast
|
||||||
|
{
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
create_tweets.push(LightPost {
|
||||||
|
post_id: tweet.id.unwrap(),
|
||||||
|
author_id: create_event.user.as_ref().unwrap().id.unwrap(),
|
||||||
|
created_at: core_data.created_at_secs.unwrap(),
|
||||||
|
in_reply_to_post_id: core_data
|
||||||
|
.reply
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|r| r.in_reply_to_status_id),
|
||||||
|
in_reply_to_user_id: core_data
|
||||||
|
.reply
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|r| r.in_reply_to_user_id),
|
||||||
|
is_retweet: core_data.share.is_some(),
|
||||||
|
is_reply: core_data.reply.is_some(),
|
||||||
|
source_post_id: core_data.share.as_ref().and_then(|s| s.source_status_id),
|
||||||
|
source_user_id: core_data.share.as_ref().and_then(|s| s.source_user_id),
|
||||||
|
has_video: is_eligible_video(tweet),
|
||||||
|
conversation_id: core_data.conversation_id,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
TweetEventData::TweetDeleteEvent(delete_event) => {
|
||||||
|
let created_at_secs = delete_event
|
||||||
|
.tweet
|
||||||
|
.as_ref()
|
||||||
|
.unwrap()
|
||||||
|
.core_data
|
||||||
|
.as_ref()
|
||||||
|
.unwrap()
|
||||||
|
.created_at_secs
|
||||||
|
.unwrap();
|
||||||
|
if now_secs - created_at_secs > post_retention_sec {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
delete_tweets.push(delete_event.tweet.as_ref().unwrap().id.unwrap());
|
||||||
|
}
|
||||||
|
TweetEventData::QuotedTweetDeleteEvent(delete_event) => {
|
||||||
|
delete_tweets.push(delete_event.quoting_tweet_id.unwrap());
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
log::info!("Other non post creation/deletion event")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send each LightPost as an InNetworkEvent to the producer in separate tasks (only if producer is enabled)
|
||||||
|
if let Some(ref producer) = producer {
|
||||||
|
let mut send_tasks = Vec::with_capacity(create_tweets.len());
|
||||||
|
for light_post in &create_tweets {
|
||||||
|
let event = InNetworkEvent {
|
||||||
|
event_variant: Some(in_network_event::EventVariant::TweetCreateEvent(
|
||||||
|
TweetCreateEvent {
|
||||||
|
post_id: light_post.post_id,
|
||||||
|
author_id: light_post.author_id,
|
||||||
|
created_at: light_post.created_at,
|
||||||
|
in_reply_to_post_id: light_post.in_reply_to_post_id,
|
||||||
|
in_reply_to_user_id: light_post.in_reply_to_user_id,
|
||||||
|
is_retweet: light_post.is_retweet,
|
||||||
|
is_reply: light_post.is_reply,
|
||||||
|
source_post_id: light_post.source_post_id,
|
||||||
|
source_user_id: light_post.source_user_id,
|
||||||
|
has_video: light_post.has_video,
|
||||||
|
conversation_id: light_post.conversation_id,
|
||||||
|
},
|
||||||
|
)),
|
||||||
|
};
|
||||||
|
let payload = event.encode_to_vec();
|
||||||
|
let producer_clone = Arc::clone(producer);
|
||||||
|
send_tasks.push(tokio::spawn(async move {
|
||||||
|
let producer_lock = producer_clone.read().await;
|
||||||
|
if let Err(e) = producer_lock.send(&payload).await {
|
||||||
|
warn!("Failed to send InNetworkEvent to producer: {:#}", e);
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
for post_id in &delete_tweets {
|
||||||
|
let event = InNetworkEvent {
|
||||||
|
event_variant: Some(in_network_event::EventVariant::TweetDeleteEvent(
|
||||||
|
TweetDeleteEvent {
|
||||||
|
post_id: *post_id,
|
||||||
|
deleted_at: now_secs,
|
||||||
|
},
|
||||||
|
)),
|
||||||
|
};
|
||||||
|
let payload = event.encode_to_vec();
|
||||||
|
let producer_clone = Arc::clone(producer);
|
||||||
|
send_tasks.push(tokio::spawn(async move {
|
||||||
|
let producer_lock = producer_clone.read().await;
|
||||||
|
if let Err(e) = producer_lock.send(&payload).await {
|
||||||
|
warn!("Failed to send InNetworkEvent to producer: {:#}", e);
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for all send tasks to complete
|
||||||
|
for task in send_tasks {
|
||||||
|
if let Err(e) = task.await {
|
||||||
|
error!("Error writing to kafka {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log every 1000th batch
|
||||||
|
let batch_count = BATCH_LOG_COUNTER.fetch_add(1, Ordering::Relaxed);
|
||||||
|
if batch_count.is_multiple_of(1000) {
|
||||||
|
info!(
|
||||||
|
"Batch processing milestone: processed {} batches total, latest batch {} had {} posts (first: post_id={}, user_id={})",
|
||||||
|
batch_count + 1,
|
||||||
|
batch_num,
|
||||||
|
len_posts,
|
||||||
|
first_post_id,
|
||||||
|
first_user_id
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Main message processing loop that polls Kafka, batches messages, and stores posts
|
||||||
|
async fn process_tweet_events(
|
||||||
|
consumer: Arc<RwLock<KafkaConsumer>>,
|
||||||
|
batch_size: usize,
|
||||||
|
producer: Option<Arc<RwLock<KafkaProducer>>>,
|
||||||
|
post_retention_sec: i64,
|
||||||
|
) -> Result<()> {
|
||||||
|
let mut message_buffer = Vec::new();
|
||||||
|
let mut batch_num = 0;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let poll_result = {
|
||||||
|
let mut consumer_lock = consumer.write().await;
|
||||||
|
consumer_lock.poll(100).await
|
||||||
|
};
|
||||||
|
|
||||||
|
match poll_result {
|
||||||
|
Ok(messages) => {
|
||||||
|
message_buffer.extend(messages);
|
||||||
|
|
||||||
|
// Process batch when we have enough messages
|
||||||
|
if message_buffer.len() >= batch_size {
|
||||||
|
batch_num += 1;
|
||||||
|
|
||||||
|
let messages = std::mem::take(&mut message_buffer);
|
||||||
|
let producer_clone = producer.clone();
|
||||||
|
|
||||||
|
// Process this batch before committing offsets
|
||||||
|
process_message_batch(messages, batch_num, producer_clone, post_retention_sec)
|
||||||
|
.await
|
||||||
|
.context("Error processing tweet event batch")?;
|
||||||
|
|
||||||
|
consumer.write().await.commit_offsets()?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Error polling messages: {:#}", e);
|
||||||
|
metrics::KAFKA_POLL_ERRORS.inc();
|
||||||
|
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
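The loop above gives at-least-once processing: offsets are committed only after a full batch has been handled, so a crash between processing and commit replays messages rather than dropping them. A stripped-down sketch of that accumulate-then-commit shape, with the consumer abstracted behind a trait (all names here are illustrative, not the real xai_kafka API):

// Illustrative batching loop; `Source` stands in for the real consumer.
trait Source {
    fn poll(&mut self, max: usize) -> Vec<Vec<u8>>;
    fn commit(&mut self);
}

fn run_batched<S: Source>(source: &mut S, batch_size: usize, mut handle: impl FnMut(Vec<Vec<u8>>)) {
    let mut buffer: Vec<Vec<u8>> = Vec::new();
    loop {
        buffer.extend(source.poll(100));
        if buffer.len() >= batch_size {
            // Hand off the whole buffer, then commit, so redelivery only
            // happens for batches that were not fully processed.
            handle(std::mem::take(&mut buffer));
            source.commit();
        }
    }
}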
249 thunder/kafka/tweet_events_listener_v2.rs Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
use anyhow::Result;
|
||||||
|
use log::{info, warn};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
use tokio::sync::{RwLock, Semaphore};
|
||||||
|
use xai_kafka::{KafkaMessage, config::KafkaConsumerConfig, consumer::KafkaConsumer};
|
||||||
|
|
||||||
|
use xai_thunder_proto::{LightPost, TweetDeleteEvent, in_network_event};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
args::Args,
|
||||||
|
deserializer::deserialize_tweet_event_v2,
|
||||||
|
kafka::utils::{create_kafka_consumer, deserialize_kafka_messages},
|
||||||
|
metrics,
|
||||||
|
posts::post_store::PostStore,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Counter for logging deserialization every Nth time
|
||||||
|
static DESER_LOG_COUNTER: AtomicUsize = AtomicUsize::new(0);
|
||||||
|
|
||||||
|
/// Start the tweet event processing loop in the background with configurable number of threads
|
||||||
|
pub async fn start_tweet_event_processing_v2(
|
||||||
|
base_config: KafkaConsumerConfig,
|
||||||
|
post_store: Arc<PostStore>,
|
||||||
|
args: &Args,
|
||||||
|
tx: tokio::sync::mpsc::Sender<i64>,
|
||||||
|
) {
|
||||||
|
let num_partitions = args.kafka_tweet_events_v2_num_partitions;
|
||||||
|
let kafka_num_threads = args.kafka_num_threads;
|
||||||
|
|
||||||
|
// Use all available partitions
|
||||||
|
let partitions_to_use: Vec<i32> = (0..num_partitions as i32).collect();
|
||||||
|
let partitions_per_thread = num_partitions.div_ceil(kafka_num_threads);
|
||||||
|
|
||||||
|
info!(
|
||||||
|
"Starting {} message processing threads for {} partitions ({} partitions per thread)",
|
||||||
|
kafka_num_threads, num_partitions, partitions_per_thread
|
||||||
|
);
|
||||||
|
|
||||||
|
spawn_processing_threads_v2(base_config, partitions_to_use, post_store, args, tx);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Spawn multiple processing threads, each handling a subset of partitions
|
||||||
|
fn spawn_processing_threads_v2(
|
||||||
|
base_config: KafkaConsumerConfig,
|
||||||
|
partitions_to_use: Vec<i32>,
|
||||||
|
post_store: Arc<PostStore>,
|
||||||
|
args: &Args,
|
||||||
|
tx: tokio::sync::mpsc::Sender<i64>,
|
||||||
|
) {
|
||||||
|
let total_partitions = partitions_to_use.len();
|
||||||
|
let partitions_per_thread = total_partitions.div_ceil(args.kafka_num_threads);
|
||||||
|
|
||||||
|
// Create shared semaphore to prevent too many tweet_events partition updates at the same time
|
||||||
|
let semaphore = Arc::new(Semaphore::new(3));
|
||||||
|
|
||||||
|
for thread_id in 0..args.kafka_num_threads {
|
||||||
|
let start_idx = thread_id * partitions_per_thread;
|
||||||
|
let end_idx = ((thread_id + 1) * partitions_per_thread).min(total_partitions);
|
||||||
|
|
||||||
|
if start_idx >= total_partitions {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let thread_partitions = partitions_to_use[start_idx..end_idx].to_vec();
|
||||||
|
let mut thread_config = base_config.clone();
|
||||||
|
thread_config.partitions = Some(thread_partitions.clone());
|
||||||
|
|
||||||
|
let post_store_clone = Arc::clone(&post_store);
|
||||||
|
let topic = thread_config.base_config.topic.clone();
|
||||||
|
let lag_monitor_interval_secs = args.lag_monitor_interval_secs;
|
||||||
|
let batch_size = args.kafka_batch_size;
|
||||||
|
let tx_clone = tx.clone();
|
||||||
|
let semaphore_clone = Arc::clone(&semaphore);
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
info!(
|
||||||
|
"Starting message processing thread {} for partitions {:?}",
|
||||||
|
thread_id, thread_partitions
|
||||||
|
);
|
||||||
|
|
||||||
|
match create_kafka_consumer(thread_config).await {
|
||||||
|
Ok(consumer) => {
|
||||||
|
// Start partition lag monitoring for this thread's partitions
|
||||||
|
crate::kafka::tweet_events_listener::start_partition_lag_monitor(
|
||||||
|
Arc::clone(&consumer),
|
||||||
|
topic,
|
||||||
|
lag_monitor_interval_secs,
|
||||||
|
);
|
||||||
|
|
||||||
|
if let Err(e) = process_tweet_events_v2(
|
||||||
|
consumer,
|
||||||
|
post_store_clone,
|
||||||
|
batch_size,
|
||||||
|
tx_clone,
|
||||||
|
semaphore_clone,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
panic!(
|
||||||
|
"Tweet events processing thread {} exited unexpectedly: {:#}. This is a critical failure - the feeder cannot function without tweet event processing.",
|
||||||
|
thread_id, e
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
panic!(
|
||||||
|
"Failed to create consumer for thread {}: {:#}",
|
||||||
|
thread_id, e
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Deserialize a single batch of messages and split them into created and deleted posts
|
||||||
|
fn deserialize_batch(
|
||||||
|
messages: Vec<KafkaMessage>,
|
||||||
|
) -> Result<(Vec<LightPost>, Vec<TweetDeleteEvent>)> {
|
||||||
|
let start_time = Instant::now();
|
||||||
|
let num_messages = messages.len();
|
||||||
|
let results = deserialize_kafka_messages(messages, deserialize_tweet_event_v2)?;
|
||||||
|
let deser_elapsed = start_time.elapsed();
|
||||||
|
if DESER_LOG_COUNTER
|
||||||
|
.fetch_add(1, Ordering::Relaxed)
|
||||||
|
.is_multiple_of(1000)
|
||||||
|
{
|
||||||
|
info!(
|
||||||
|
"Deserialized {} messages in {:?} ({:.2} msgs/sec)",
|
||||||
|
num_messages,
|
||||||
|
deser_elapsed,
|
||||||
|
num_messages as f64 / deser_elapsed.as_secs_f64()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut create_tweets = Vec::with_capacity(results.len());
|
||||||
|
let mut delete_tweets = Vec::with_capacity(10);
|
||||||
|
|
||||||
|
for tweet_event in results {
|
||||||
|
match tweet_event.event_variant.unwrap() {
|
||||||
|
in_network_event::EventVariant::TweetCreateEvent(create_event) => {
|
||||||
|
create_tweets.push(LightPost {
|
||||||
|
post_id: create_event.post_id,
|
||||||
|
author_id: create_event.author_id,
|
||||||
|
created_at: create_event.created_at,
|
||||||
|
in_reply_to_post_id: create_event.in_reply_to_post_id,
|
||||||
|
in_reply_to_user_id: create_event.in_reply_to_user_id,
|
||||||
|
is_retweet: create_event.is_retweet,
|
||||||
|
is_reply: create_event.is_reply
|
||||||
|
|| create_event.in_reply_to_post_id.is_some()
|
||||||
|
|| create_event.in_reply_to_user_id.is_some(),
|
||||||
|
source_post_id: create_event.source_post_id,
|
||||||
|
source_user_id: create_event.source_user_id,
|
||||||
|
has_video: create_event.has_video,
|
||||||
|
conversation_id: create_event.conversation_id,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
in_network_event::EventVariant::TweetDeleteEvent(delete_event) => {
|
||||||
|
delete_tweets.push(delete_event);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok((create_tweets, delete_tweets))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Main message processing loop that polls Kafka, batches messages, and stores posts
|
||||||
|
async fn process_tweet_events_v2(
|
||||||
|
consumer: Arc<RwLock<KafkaConsumer>>,
|
||||||
|
post_store: Arc<PostStore>,
|
||||||
|
batch_size: usize,
|
||||||
|
tx: tokio::sync::mpsc::Sender<i64>,
|
||||||
|
semaphore: Arc<Semaphore>,
|
||||||
|
) -> Result<()> {
|
||||||
|
let mut message_buffer = Vec::new();
|
||||||
|
let mut batch_count = 0_usize;
|
||||||
|
let mut init_data_downloaded = false;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let poll_result = {
|
||||||
|
let mut consumer_lock = consumer.write().await;
|
||||||
|
consumer_lock.poll(batch_size).await
|
||||||
|
};
|
||||||
|
|
||||||
|
match poll_result {
|
||||||
|
Ok(messages) => {
|
||||||
|
let catchup_sender = if !init_data_downloaded {
|
||||||
|
let consumer_lock = consumer.read().await;
|
||||||
|
if let Ok(lags) = consumer_lock.get_partition_lags().await {
|
||||||
|
let total_lag: i64 = lags.iter().map(|l| l.lag).sum();
|
||||||
|
if total_lag < (lags.len() * batch_size) as i64 {
|
||||||
|
init_data_downloaded = true;
|
||||||
|
Some((tx.clone(), total_lag))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
message_buffer.extend(messages);
|
||||||
|
|
||||||
|
// Process batch when we have enough messages
|
||||||
|
if message_buffer.len() >= batch_size {
|
||||||
|
batch_count += 1;
|
||||||
|
let messages = std::mem::take(&mut message_buffer);
|
||||||
|
let post_store_clone = Arc::clone(&post_store);
|
||||||
|
|
||||||
|
// Acquire semaphore permit if init data is downloaded to allow enough CPU for serving requests
|
||||||
|
let permit = if init_data_downloaded {
|
||||||
|
Some(semaphore.clone().acquire_owned().await.unwrap())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
// Send batch to blocking thread pool for processing
|
||||||
|
let _ = tokio::task::spawn_blocking(move || {
|
||||||
|
let _permit = permit; // Hold permit until task completes
|
||||||
|
match deserialize_batch(messages) {
|
||||||
|
Err(e) => warn!("Error processing batch {}: {:#}", batch_count, e),
|
||||||
|
Ok((light_posts, delete_posts)) => {
|
||||||
|
post_store_clone.insert_posts(light_posts);
|
||||||
|
post_store_clone.mark_as_deleted(delete_posts);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
if let Some((sender, lag)) = catchup_sender {
|
||||||
|
info!("Completed kafka init for a single thread");
|
||||||
|
if let Err(e) = sender.send(lag).await {
|
||||||
|
log::error!("error sending {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Error polling messages: {:#}", e);
|
||||||
|
metrics::KAFKA_POLL_ERRORS.inc();
|
||||||
|
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
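Catch-up detection in the v2 loop is a simple threshold: once the summed partition lag drops below one batch per assigned partition, the thread reports its remaining lag over the mpsc channel exactly once. A small sketch of that check (illustrative; `PartitionLag` mirrors the shape used above):

// Illustrative: decide whether this consumer thread has caught up.
struct PartitionLag { lag: i64 }

fn caught_up(lags: &[PartitionLag], batch_size: usize) -> Option<i64> {
    let total_lag: i64 = lags.iter().map(|l| l.lag).sum();
    // "Caught up" once the backlog is smaller than one batch per partition.
    (total_lag < (lags.len() * batch_size) as i64).then_some(total_lag)
}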
48 thunder/kafka/utils.rs Normal file
@@ -0,0 +1,48 @@
use anyhow::{Context, Result};
use std::sync::Arc;
use tokio::sync::RwLock;
use xai_kafka::{KafkaMessage, config::KafkaConsumerConfig, consumer::KafkaConsumer};

use crate::metrics;

/// Create and start a Kafka consumer with the given configuration
pub async fn create_kafka_consumer(
    config: KafkaConsumerConfig,
) -> Result<Arc<RwLock<KafkaConsumer>>> {
    let mut consumer = KafkaConsumer::new(config);
    consumer
        .start()
        .await
        .context("Failed to start Kafka consumer")?;

    Ok(Arc::new(RwLock::new(consumer)))
}

/// Process a batch of Kafka messages and deserialize them using the provided deserializer function
pub fn deserialize_kafka_messages<T, F>(
    messages: Vec<KafkaMessage>,
    deserializer: F,
) -> Result<Vec<T>>
where
    F: Fn(&[u8]) -> Result<T>,
{
    let _timer = metrics::Timer::new(metrics::BATCH_PROCESSING_TIME.clone());

    let mut kafka_data = Vec::with_capacity(messages.len());

    for msg in messages.iter() {
        if let Some(payload) = &msg.payload {
            match deserializer(payload) {
                Ok(deserialized_msg) => {
                    kafka_data.push(deserialized_msg);
                }
                Err(e) => {
                    log::error!("Failed to parse Kafka message: {}", e);
                    metrics::KAFKA_MESSAGES_FAILED_PARSE.inc();
                }
            }
        }
    }

    Ok(kafka_data)
}
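A hedged usage sketch: because the deserializer is just a `Fn(&[u8]) -> Result<T>`, callers can pass either of the free functions from deserializer.rs or an ad-hoc closure. The UTF-8 closure below is purely illustrative:

// Illustrative: any Fn(&[u8]) -> Result<T> works as the deserializer.
fn decode_utf8_batch(messages: Vec<KafkaMessage>) -> anyhow::Result<Vec<String>> {
    deserialize_kafka_messages(messages, |payload| {
        String::from_utf8(payload.to_vec()).map_err(anyhow::Error::from)
    })
}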
115 thunder/kafka_utils.rs Normal file
@@ -0,0 +1,115 @@
use anyhow::{Context, Result};
use std::sync::Arc;
use xai_kafka::KafkaProducerConfig;
use xai_kafka::config::{KafkaConfig, KafkaConsumerConfig, SslConfig};
use xai_wily::WilyConfig;

use crate::{
    args,
    kafka::{
        tweet_events_listener::start_tweet_event_processing,
        tweet_events_listener_v2::start_tweet_event_processing_v2,
    },
};

const TWEET_EVENT_TOPIC: &str = "";
const TWEET_EVENT_DEST: &str = "";

const IN_NETWORK_EVENTS_DEST: &str = "";
const IN_NETWORK_EVENTS_TOPIC: &str = "";

pub async fn start_kafka(
    args: &args::Args,
    post_store: Arc<crate::posts::post_store::PostStore>,
    user: &str,
    tx: tokio::sync::mpsc::Sender<i64>,
) -> Result<()> {
    let sasl_password = std::env::var("")
        .ok()
        .or(args.sasl_password.clone())?;

    let producer_sasl_password = std::env::var("")
        .ok()
        .or(args.producer_sasl_password.clone());

    if args.is_serving {
        let unique_id = uuid::Uuid::new_v4().to_string();

        let v2_tweet_events_consumer_config = KafkaConsumerConfig {
            base_config: KafkaConfig {
                dest: args.in_network_events_consumer_dest.clone(),
                topic: IN_NETWORK_EVENTS_TOPIC.to_string(),
                wily_config: Some(WilyConfig::default()),
                ssl: Some(SslConfig {
                    security_protocol: args.security_protocol.clone(),
                    sasl_mechanism: Some(args.producer_sasl_mechanism.clone()),
                    sasl_username: Some(args.producer_sasl_username.clone()),
                    sasl_password: producer_sasl_password.clone(),
                }),
                ..Default::default()
            },
            group_id: format!("{}-{}", args.kafka_group_id, unique_id),
            auto_offset_reset: args.auto_offset_reset.clone(),
            fetch_timeout_ms: args.fetch_timeout_ms,
            max_partition_fetch_bytes: Some(1024 * 1024 * 100),
            skip_to_latest: args.skip_to_latest,
            ..Default::default()
        };

        // Start Kafka background tasks
        start_tweet_event_processing_v2(
            v2_tweet_events_consumer_config,
            Arc::clone(&post_store),
            args,
            tx,
        )
        .await;
    }

    // Only start Kafka processing and background tasks if not in serving mode
    if !args.is_serving {
        // Create Kafka consumer config
        let tweet_events_consumer_config = KafkaConsumerConfig {
            base_config: KafkaConfig {
                dest: TWEET_EVENT_DEST.to_string(),
                topic: TWEET_EVENT_TOPIC.to_string(),
                wily_config: Some(WilyConfig::default()),
                ssl: Some(SslConfig {
                    security_protocol: args.security_protocol.clone(),
                    sasl_mechanism: Some(args.sasl_mechanism.clone()),
                    sasl_username: Some(args.sasl_username.clone()),
                    sasl_password: Some(sasl_password.clone()),
                }),
                ..Default::default()
            },
            group_id: format!("{}-{}", args.kafka_group_id, user),
            auto_offset_reset: args.auto_offset_reset.clone(),
            enable_auto_commit: false,
            fetch_timeout_ms: args.fetch_timeout_ms,
            max_partition_fetch_bytes: Some(1024 * 1024 * 10),
            partitions: None,
            skip_to_latest: args.skip_to_latest,
            ..Default::default()
        };

        let producer_config = KafkaProducerConfig {
            base_config: KafkaConfig {
                dest: IN_NETWORK_EVENTS_DEST.to_string(),
                topic: IN_NETWORK_EVENTS_TOPIC.to_string(),
                wily_config: Some(WilyConfig::default()),
                ssl: Some(SslConfig {
                    security_protocol: args.security_protocol.clone(),
                    sasl_mechanism: Some(args.producer_sasl_mechanism.clone()),
                    sasl_username: Some(args.producer_sasl_username.clone()),
                    sasl_password: producer_sasl_password.clone(),
                }),
                ..Default::default()
            },
            ..Default::default()
        };

        start_tweet_event_processing(tweet_events_consumer_config, producer_config, args).await;
    }

    Ok(())
}
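The secret handling above prefers an environment variable and falls back to the CLI flag; the variable names are redacted to empty strings in this export. A generic sketch of that fallback with a hypothetical variable name (not the real one):

// Illustrative: prefer an env var, fall back to a CLI-provided Option<String>.
// "EXAMPLE_SASL_PASSWORD" is a placeholder name, not the real variable.
fn resolve_secret(cli_value: Option<String>) -> Option<String> {
    std::env::var("EXAMPLE_SASL_PASSWORD").ok().or(cli_value)
}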
11 thunder/lib.rs Normal file
@@ -0,0 +1,11 @@
pub mod args;
pub mod config;
pub mod deserializer;
pub mod kafka;
pub mod kafka_utils;
pub mod metrics;
pub mod o2;
pub mod posts;
pub mod schema;
pub mod strato_client;
pub mod thunder_service;
100 thunder/main.rs Normal file
@@ -0,0 +1,100 @@
use anyhow::{Context, Result};
use axum::Router;
use clap::Parser;
use log::info;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tonic::service::Routes;
use xai_http_server::{CancellationToken, GrpcConfig, HttpServer};

use thunder::{
    args, kafka_utils, posts::post_store::PostStore, strato_client::StratoClient,
    thunder_service::ThunderServiceImpl,
};

#[tokio::main]
async fn main() -> Result<()> {
    env_logger::init();
    let args = args::Args::parse();

    // Initialize PostStore
    let post_store = Arc::new(PostStore::new(
        args.post_retention_seconds,
        args.request_timeout_ms,
    ));
    info!(
        "Initialized PostStore for in-memory post storage (retention: {} seconds / {:.1} days, request_timeout: {}ms)",
        args.post_retention_seconds,
        args.post_retention_seconds as f64 / 86400.0,
        args.request_timeout_ms
    );

    // Initialize StratoClient for fetching following lists
    let strato_client = Arc::new(StratoClient::new());
    info!("Initialized StratoClient");

    // Create ThunderService with the PostStore, StratoClient, and concurrency limit
    let thunder_service = ThunderServiceImpl::new(
        Arc::clone(&post_store),
        Arc::clone(&strato_client),
        args.max_concurrent_requests,
    );
    info!(
        "Initialized with max_concurrent_requests={}",
        args.max_concurrent_requests
    );
    let routes = Routes::new(thunder_service.server());

    // Set up gRPC config
    let grpc_config = GrpcConfig::new(args.grpc_port, routes);

    // Create HTTP server with gRPC support
    let mut http_server = HttpServer::new(
        args.http_port,
        Router::new(),
        Some(grpc_config),
        CancellationToken::new(),
        Duration::from_secs(10),
    )
    .await
    .context("Failed to create HTTP server")?;

    if args.enable_profiling {
        xai_profiling::spawn_server(3000, CancellationToken::new()).await;
    }

    // Create channel for post events
    let (tx, mut rx) = tokio::sync::mpsc::channel::<i64>(args.kafka_num_threads);
    kafka_utils::start_kafka(&args, post_store.clone(), "", tx).await?;

    if args.is_serving {
        // Wait for Kafka catchup signal
        let start = Instant::now();
        for _ in 0..args.kafka_num_threads {
            rx.recv().await;
        }
        info!("Kafka init took {:?}", start.elapsed());

        post_store.finalize_init().await?;

        // Start stats logger
        Arc::clone(&post_store).start_stats_logger();
        info!("Started PostStore stats logger");

        // Start auto-trim task to remove posts older than retention period
        Arc::clone(&post_store).start_auto_trim(2); // Run every 2 minutes
        info!(
            "Started PostStore auto-trim task (interval: 2 minutes, retention: {:.1} days)",
            args.post_retention_seconds as f64 / 86400.0
        );
    }

    http_server.set_readiness(true);
    info!("HTTP/gRPC server is ready");

    // Wait for termination signal
    http_server.wait_for_termination().await;
    info!("Server terminated");

    Ok(())
}
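In serving mode, startup blocks until every Kafka thread has signalled catch-up on the mpsc channel; the channel acts as a simple N-of-N barrier before the server flips readiness. A condensed sketch of that pattern (tokio only, no Kafka; the worker body is a placeholder):

// Illustrative: wait for `n` background workers to report readiness.
async fn wait_for_workers(n: usize) {
    let (tx, mut rx) = tokio::sync::mpsc::channel::<i64>(n);
    for id in 0..n {
        let tx = tx.clone();
        tokio::spawn(async move {
            // ... do catch-up work here ...
            let _ = tx.send(id as i64).await; // signal completion
        });
    }
    drop(tx); // keep only the worker senders alive
    for _ in 0..n {
        rx.recv().await; // one message per worker
    }
}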
1 thunder/posts/mod.rs Normal file
@@ -0,0 +1 @@
pub mod post_store;
526 thunder/posts/post_store.rs Normal file
@@ -0,0 +1,526 @@
|
|||||||
|
use anyhow::Result;
|
||||||
|
use dashmap::DashMap;
|
||||||
|
use log::info;
|
||||||
|
use std::collections::{HashSet, VecDeque};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
use xai_thunder_proto::{LightPost, TweetDeleteEvent};
|
||||||
|
|
||||||
|
use crate::config::{
|
||||||
|
DELETE_EVENT_KEY, MAX_ORIGINAL_POSTS_PER_AUTHOR, MAX_REPLY_POSTS_PER_AUTHOR,
|
||||||
|
MAX_TINY_POSTS_PER_USER_SCAN, MAX_VIDEO_POSTS_PER_AUTHOR,
|
||||||
|
};
|
||||||
|
use crate::metrics::{
|
||||||
|
POST_STORE_DELETED_POSTS, POST_STORE_DELETED_POSTS_FILTERED, POST_STORE_ENTITY_COUNT,
|
||||||
|
POST_STORE_POSTS_RETURNED, POST_STORE_POSTS_RETURNED_RATIO, POST_STORE_REQUEST_TIMEOUTS,
|
||||||
|
POST_STORE_REQUESTS, POST_STORE_TOTAL_POSTS, POST_STORE_USER_COUNT,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Minimal post reference stored in user timelines (only ID and timestamp)
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||||
|
pub struct TinyPost {
|
||||||
|
pub post_id: i64,
|
||||||
|
pub created_at: i64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TinyPost {
|
||||||
|
/// Create a new TinyPost from a post ID and creation timestamp
|
||||||
|
pub fn new(post_id: i64, created_at: i64) -> Self {
|
||||||
|
TinyPost {
|
||||||
|
post_id,
|
||||||
|
created_at,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A thread-safe store for posts grouped by user ID
|
||||||
|
/// Note: LightPost is now defined in the protobuf schema (in-network.proto)
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct PostStore {
|
||||||
|
/// Full post data indexed by post_id
|
||||||
|
posts: Arc<DashMap<i64, LightPost>>,
|
||||||
|
/// Maps user_id to a deque of TinyPost references for original posts (non-reply, non-retweet)
|
||||||
|
original_posts_by_user: Arc<DashMap<i64, VecDeque<TinyPost>>>,
|
||||||
|
/// Maps user_id to a deque of TinyPost references for replies and retweets
|
||||||
|
secondary_posts_by_user: Arc<DashMap<i64, VecDeque<TinyPost>>>,
|
||||||
|
/// Maps user_id to a deque of TinyPost references for video posts
|
||||||
|
video_posts_by_user: Arc<DashMap<i64, VecDeque<TinyPost>>>,
|
||||||
|
deleted_posts: Arc<DashMap<i64, bool>>,
|
||||||
|
/// Retention period for posts in seconds
|
||||||
|
retention_seconds: u64,
|
||||||
|
/// Request timeout for get_posts_by_users iteration (0 = no timeout)
|
||||||
|
request_timeout: Duration,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PostStore {
|
||||||
|
/// Creates a new empty PostStore with the specified retention period and request timeout
|
||||||
|
pub fn new(retention_seconds: u64, request_timeout_ms: u64) -> Self {
|
||||||
|
PostStore {
|
||||||
|
posts: Arc::new(DashMap::new()),
|
||||||
|
original_posts_by_user: Arc::new(DashMap::new()),
|
||||||
|
secondary_posts_by_user: Arc::new(DashMap::new()),
|
||||||
|
video_posts_by_user: Arc::new(DashMap::new()),
|
||||||
|
deleted_posts: Arc::new(DashMap::new()),
|
||||||
|
retention_seconds,
|
||||||
|
request_timeout: Duration::from_millis(request_timeout_ms),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn mark_as_deleted(&self, posts: Vec<TweetDeleteEvent>) {
|
||||||
|
for post in posts.into_iter() {
|
||||||
|
self.posts.remove(&post.post_id);
|
||||||
|
self.deleted_posts.insert(post.post_id, true);
|
||||||
|
|
||||||
|
let mut user_posts_entry = self
|
||||||
|
.original_posts_by_user
|
||||||
|
.entry(DELETE_EVENT_KEY)
|
||||||
|
.or_default();
|
||||||
|
user_posts_entry.push_back(TinyPost {
|
||||||
|
post_id: post.post_id,
|
||||||
|
created_at: post.deleted_at,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Inserts posts into the post store
|
||||||
|
pub fn insert_posts(&self, mut posts: Vec<LightPost>) {
|
||||||
|
// Filter to keep only posts created in the last retention_seconds and not from the future
|
||||||
|
let current_time = std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_secs() as i64;
|
||||||
|
posts.retain(|p| {
|
||||||
|
p.created_at < current_time
|
||||||
|
&& current_time - p.created_at <= (self.retention_seconds as i64)
|
||||||
|
});
|
||||||
|
|
||||||
|
// Sort remaining posts by created_at timestamp
|
||||||
|
posts.sort_unstable_by_key(|p| p.created_at);
|
||||||
|
|
||||||
|
Self::insert_posts_internal(self, posts);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn finalize_init(&self) -> Result<()> {
|
||||||
|
self.sort_all_user_posts().await;
|
||||||
|
self.trim_old_posts().await;
|
||||||
|
|
||||||
|
// This is needed because the relative order of create_event/delete_event can be lost in the feeder
|
||||||
|
for entry in self.deleted_posts.iter() {
|
||||||
|
self.posts.remove(entry.key());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn insert_posts_internal(&self, posts: Vec<LightPost>) {
|
||||||
|
for post in posts {
|
||||||
|
let post_id = post.post_id;
|
||||||
|
let author_id = post.author_id;
|
||||||
|
let created_at = post.created_at;
|
||||||
|
let is_original = !post.is_reply && !post.is_retweet;
|
||||||
|
|
||||||
|
if self.deleted_posts.contains_key(&post_id) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store the full post data
|
||||||
|
let old = self.posts.insert(post_id, post);
|
||||||
|
if old.is_some() {
|
||||||
|
// if already stored - don't add it again
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a TinyPost reference for the timeline
|
||||||
|
let tiny_post = TinyPost::new(post_id, created_at);
|
||||||
|
|
||||||
|
// Use entry API to get mutable access to the appropriate user's posts timeline
|
||||||
|
if is_original {
|
||||||
|
let mut user_posts_entry =
|
||||||
|
self.original_posts_by_user.entry(author_id).or_default();
|
||||||
|
user_posts_entry.push_back(tiny_post.clone());
|
||||||
|
} else {
|
||||||
|
let mut user_posts_entry =
|
||||||
|
self.secondary_posts_by_user.entry(author_id).or_default();
|
||||||
|
user_posts_entry.push_back(tiny_post.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut video_eligible = post.has_video;
|
||||||
|
|
||||||
|
// If this is a retweet and the retweeted post has video, mark has_video as true
|
||||||
|
if !video_eligible
|
||||||
|
&& post.is_retweet
|
||||||
|
&& let Some(source_post_id) = post.source_post_id
|
||||||
|
&& let Some(source_post) = self.posts.get(&source_post_id)
|
||||||
|
{
|
||||||
|
video_eligible = !source_post.is_reply && source_post.has_video;
|
||||||
|
}
|
||||||
|
|
||||||
|
if post.is_reply {
|
||||||
|
video_eligible = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Also add to video posts timeline if post has video
|
||||||
|
if video_eligible {
|
||||||
|
let mut user_posts_entry = self.video_posts_by_user.entry(author_id).or_default();
|
||||||
|
user_posts_entry.push_back(tiny_post);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Retrieves video posts from multiple users
|
||||||
|
pub fn get_videos_by_users(
|
||||||
|
&self,
|
||||||
|
user_ids: &[i64],
|
||||||
|
exclude_tweet_ids: &HashSet<i64>,
|
||||||
|
start_time: Instant,
|
||||||
|
request_user_id: i64,
|
||||||
|
) -> Vec<LightPost> {
|
||||||
|
let video_posts = self.get_posts_from_map(
|
||||||
|
&self.video_posts_by_user,
|
||||||
|
user_ids,
|
||||||
|
MAX_VIDEO_POSTS_PER_AUTHOR,
|
||||||
|
exclude_tweet_ids,
|
||||||
|
&HashSet::new(),
|
||||||
|
start_time,
|
||||||
|
request_user_id,
|
||||||
|
);
|
||||||
|
|
||||||
|
POST_STORE_POSTS_RETURNED.observe(video_posts.len() as f64);
|
||||||
|
video_posts
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Retrieves all posts from multiple users
|
||||||
|
pub fn get_all_posts_by_users(
|
||||||
|
&self,
|
||||||
|
user_ids: &[i64],
|
||||||
|
exclude_tweet_ids: &HashSet<i64>,
|
||||||
|
start_time: Instant,
|
||||||
|
request_user_id: i64,
|
||||||
|
) -> Vec<LightPost> {
|
||||||
|
let following_users_set: HashSet<i64> = user_ids.iter().copied().collect();
|
||||||
|
|
||||||
|
let mut all_posts = self.get_posts_from_map(
|
||||||
|
&self.original_posts_by_user,
|
||||||
|
user_ids,
|
||||||
|
MAX_ORIGINAL_POSTS_PER_AUTHOR,
|
||||||
|
exclude_tweet_ids,
|
||||||
|
&HashSet::new(),
|
||||||
|
start_time,
|
||||||
|
request_user_id,
|
||||||
|
);
|
||||||
|
|
||||||
|
let secondary_posts = self.get_posts_from_map(
|
||||||
|
&self.secondary_posts_by_user,
|
||||||
|
user_ids,
|
||||||
|
MAX_REPLY_POSTS_PER_AUTHOR,
|
||||||
|
exclude_tweet_ids,
|
||||||
|
&following_users_set,
|
||||||
|
start_time,
|
||||||
|
request_user_id,
|
||||||
|
);
|
||||||
|
|
||||||
|
all_posts.extend(secondary_posts);
|
||||||
|
POST_STORE_POSTS_RETURNED.observe(all_posts.len() as f64);
|
||||||
|
all_posts
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub fn get_posts_from_map(
|
||||||
|
&self,
|
||||||
|
posts_map: &Arc<DashMap<i64, VecDeque<TinyPost>>>,
|
||||||
|
user_ids: &[i64],
|
||||||
|
max_per_user: usize,
|
||||||
|
exclude_tweet_ids: &HashSet<i64>,
|
||||||
|
following_users: &HashSet<i64>,
|
||||||
|
start_time: Instant,
|
||||||
|
request_user_id: i64,
|
||||||
|
) -> Vec<LightPost> {
|
||||||
|
POST_STORE_REQUESTS.inc();
|
||||||
|
let mut light_posts = Vec::new();
|
||||||
|
|
||||||
|
let mut total_eligible: usize = 0;
|
||||||
|
|
||||||
|
for (i, user_id) in user_ids.iter().enumerate() {
|
||||||
|
if !self.request_timeout.is_zero() && start_time.elapsed() >= self.request_timeout {
|
||||||
|
log::error!(
|
||||||
|
"Timed out fetching posts for user={}; Processed: {}/{}. Stage: {}",
|
||||||
|
request_user_id,
|
||||||
|
i,
|
||||||
|
user_ids.len(),
|
||||||
|
if following_users.is_empty() {
|
||||||
|
"original"
|
||||||
|
} else {
|
||||||
|
"secondary"
|
||||||
|
}
|
||||||
|
);
|
||||||
|
POST_STORE_REQUEST_TIMEOUTS.inc();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(user_posts_ref) = posts_map.get(user_id) {
|
||||||
|
let user_posts = user_posts_ref.value();
|
||||||
|
total_eligible += user_posts.len();
|
||||||
|
|
||||||
|
// Start from newest posts (reverse iterator)
|
||||||
|
// Take a capped number to avoid scanning all the way back through a user's history
|
||||||
|
let tiny_posts_iter = user_posts
|
||||||
|
.iter()
|
||||||
|
.rev()
|
||||||
|
.filter(|post| !exclude_tweet_ids.contains(&post.post_id))
|
||||||
|
.take(MAX_TINY_POSTS_PER_USER_SCAN);
|
||||||
|
|
||||||
|
// Perform light doc lookup to get full LightPost data. This will also filter deleted posts
|
||||||
|
// Note: We copy the value immediately to release the read lock and avoid potential
|
||||||
|
// deadlock when acquiring nested read locks while a writer is waiting.
|
||||||
|
let light_post_iter_1 = tiny_posts_iter
|
||||||
|
.filter_map(|tiny_post| self.posts.get(&tiny_post.post_id).map(|r| *r.value()));
|
||||||
|
|
||||||
|
let light_post_iter = light_post_iter_1.filter(|post| {
|
||||||
|
if self.deleted_posts.get(&post.post_id).is_some() {
|
||||||
|
POST_STORE_DELETED_POSTS_FILTERED.inc();
|
||||||
|
false
|
||||||
|
} else {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let light_post_iter = light_post_iter.filter(|post| {
|
||||||
|
!(post.is_retweet && post.source_user_id == Some(request_user_id))
|
||||||
|
});
|
||||||
|
|
||||||
|
let filtered_post_iter = light_post_iter.filter(|post| {
|
||||||
|
if following_users.is_empty() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
post.in_reply_to_post_id.is_none_or(|reply_to_post_id| {
|
||||||
|
if let Some(replied_to_post) = self.posts.get(&reply_to_post_id) {
|
||||||
|
if !replied_to_post.is_retweet && !replied_to_post.is_reply {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return post.conversation_id.is_some_and(|convo_id| {
|
||||||
|
let reply_to_reply_to_original =
|
||||||
|
replied_to_post.in_reply_to_post_id == Some(convo_id);
|
||||||
|
let reply_to_followed_user = post
|
||||||
|
.in_reply_to_user_id
|
||||||
|
.map(|uid| following_users.contains(&uid))
|
||||||
|
.unwrap_or(false);
|
||||||
|
|
||||||
|
reply_to_reply_to_original && reply_to_followed_user
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
false
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
light_posts.extend(filtered_post_iter.take(max_per_user));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Track ratio of returned posts to eligible posts
|
||||||
|
if total_eligible > 0 {
|
||||||
|
let ratio = light_posts.len() as f64 / total_eligible as f64;
|
||||||
|
POST_STORE_POSTS_RETURNED_RATIO.observe(ratio);
|
||||||
|
}
|
||||||
|
|
||||||
|
light_posts
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Start a background task that periodically logs PostStore statistics
|
||||||
|
pub fn start_stats_logger(self: Arc<Self>) {
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let mut interval = tokio::time::interval(Duration::from_secs(5));
|
||||||
|
|
||||||
|
loop {
|
||||||
|
interval.tick().await;
|
||||||
|
|
||||||
|
let user_count = self.original_posts_by_user.len();
|
||||||
|
let total_posts = self.posts.len();
|
||||||
|
let deleted_posts = self.deleted_posts.len();
|
||||||
|
|
||||||
|
// Sum up all VecDeque sizes for each map
|
||||||
|
let original_posts_count: usize = self
|
||||||
|
.original_posts_by_user
|
||||||
|
.iter()
|
||||||
|
.map(|entry| entry.value().len())
|
||||||
|
.sum();
|
||||||
|
let secondary_posts_count: usize = self
|
||||||
|
.secondary_posts_by_user
|
||||||
|
.iter()
|
||||||
|
.map(|entry| entry.value().len())
|
||||||
|
.sum();
|
||||||
|
let video_posts_count: usize = self
|
||||||
|
.video_posts_by_user
|
||||||
|
.iter()
|
||||||
|
.map(|entry| entry.value().len())
|
||||||
|
.sum();
|
||||||
|
|
||||||
|
// Update Prometheus gauges
|
||||||
|
POST_STORE_USER_COUNT.set(user_count as f64);
|
||||||
|
POST_STORE_TOTAL_POSTS.set(total_posts as f64);
|
||||||
|
POST_STORE_DELETED_POSTS.set(deleted_posts as f64);
|
||||||
|
|
||||||
|
// Update entity count gauge with labels
|
||||||
|
POST_STORE_ENTITY_COUNT
|
||||||
|
.with_label_values(&["users"])
|
||||||
|
.set(user_count as f64);
|
||||||
|
POST_STORE_ENTITY_COUNT
|
||||||
|
.with_label_values(&["posts"])
|
||||||
|
.set(total_posts as f64);
|
||||||
|
POST_STORE_ENTITY_COUNT
|
||||||
|
.with_label_values(&["original_posts"])
|
||||||
|
.set(original_posts_count as f64);
|
||||||
|
POST_STORE_ENTITY_COUNT
|
||||||
|
.with_label_values(&["secondary_posts"])
|
||||||
|
.set(secondary_posts_count as f64);
|
||||||
|
POST_STORE_ENTITY_COUNT
|
||||||
|
.with_label_values(&["video_posts"])
|
||||||
|
.set(video_posts_count as f64);
|
||||||
|
POST_STORE_ENTITY_COUNT
|
||||||
|
.with_label_values(&["deleted_posts"])
|
||||||
|
.set(deleted_posts as f64);
|
||||||
|
|
||||||
|
info!(
|
||||||
|
"PostStore Stats: {} users, {} total posts, {} deleted posts",
|
||||||
|
user_count, total_posts, deleted_posts
|
||||||
|
);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Start a background task that periodically trims old posts
|
||||||
|
pub fn start_auto_trim(self: Arc<Self>, interval_minutes: u64) {
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let mut interval = tokio::time::interval(Duration::from_secs(interval_minutes * 60));
|
||||||
|
|
||||||
|
loop {
|
||||||
|
interval.tick().await;
|
||||||
|
let trimmed = self.trim_old_posts().await;
|
||||||
|
if trimmed > 0 {
|
||||||
|
info!("Auto-trim: removed {} old posts", trimmed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Manually trim posts older than retention period from all users
|
||||||
|
/// Returns the number of posts trimmed
|
||||||
|
pub async fn trim_old_posts(&self) -> usize {
|
||||||
|
let posts_map = Arc::clone(&self.posts);
|
||||||
|
let original_posts_by_user = Arc::clone(&self.original_posts_by_user);
|
||||||
|
let secondary_posts_by_user = Arc::clone(&self.secondary_posts_by_user);
|
||||||
|
let video_posts_by_user = Arc::clone(&self.video_posts_by_user);
|
||||||
|
let deleted_posts = Arc::clone(&self.deleted_posts);
|
||||||
|
let retention_seconds = self.retention_seconds;
|
||||||
|
|
||||||
|
tokio::task::spawn_blocking(move || {
|
||||||
|
let current_time = std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap()
|
||||||
|
.as_secs();
|
||||||
|
|
||||||
|
let mut total_trimmed = 0;
|
||||||
|
|
||||||
|
// Helper closure to trim posts from a given map
|
||||||
|
let trim_map = |posts_by_user: &DashMap<i64, VecDeque<TinyPost>>,
|
||||||
|
posts_map: &DashMap<i64, LightPost>,
|
||||||
|
deleted_posts: &DashMap<i64, bool>|
|
||||||
|
-> usize {
|
||||||
|
let mut trimmed = 0;
|
||||||
|
let mut users_to_remove = Vec::new();
|
||||||
|
|
||||||
|
for mut entry in posts_by_user.iter_mut() {
|
||||||
|
let user_id = *entry.key();
|
||||||
|
let user_posts = entry.value_mut();
|
||||||
|
|
||||||
|
while let Some(oldest_post) = user_posts.front() {
|
||||||
|
if current_time - (oldest_post.created_at as u64) > retention_seconds {
|
||||||
|
let trimmed_post = user_posts.pop_front().unwrap();
|
||||||
|
posts_map.remove(&trimmed_post.post_id);
|
||||||
|
|
||||||
|
if user_id == DELETE_EVENT_KEY {
|
||||||
|
deleted_posts.remove(&trimmed_post.post_id);
|
||||||
|
}
|
||||||
|
trimmed += 1;
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if user_posts.capacity() > user_posts.len() * 2 {
|
||||||
|
let new_cap = user_posts.len() as f32 * 1.5_f32;
|
||||||
|
user_posts.shrink_to(new_cap as usize);
|
||||||
|
}
|
||||||
|
|
||||||
|
if user_posts.is_empty() {
|
||||||
|
users_to_remove.push(user_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for user_id in users_to_remove {
|
||||||
|
posts_by_user.remove_if(&user_id, |_, posts| posts.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
trimmed
|
||||||
|
};
|
||||||
|
|
||||||
|
total_trimmed += trim_map(&original_posts_by_user, &posts_map, &deleted_posts);
|
||||||
|
total_trimmed += trim_map(&secondary_posts_by_user, &posts_map, &deleted_posts);
|
||||||
|
trim_map(&video_posts_by_user, &posts_map, &deleted_posts);
|
||||||
|
|
||||||
|
total_trimmed
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.expect("spawn_blocking failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sorts all user post lists by creation time (newest first)
|
||||||
|
pub async fn sort_all_user_posts(&self) {
|
||||||
|
let original_posts_by_user = Arc::clone(&self.original_posts_by_user);
|
||||||
|
let secondary_posts_by_user = Arc::clone(&self.secondary_posts_by_user);
|
||||||
|
let video_posts_by_user = Arc::clone(&self.video_posts_by_user);
|
||||||
|
|
||||||
|
tokio::task::spawn_blocking(move || {
|
||||||
|
// Sort original posts
|
||||||
|
for mut entry in original_posts_by_user.iter_mut() {
|
||||||
|
let user_posts = entry.value_mut();
|
||||||
|
user_posts
|
||||||
|
.make_contiguous()
|
||||||
|
.sort_unstable_by_key(|a| a.created_at);
|
||||||
|
}
|
||||||
|
// Sort secondary posts
|
||||||
|
for mut entry in secondary_posts_by_user.iter_mut() {
|
||||||
|
let user_posts = entry.value_mut();
|
||||||
|
user_posts
|
||||||
|
.make_contiguous()
|
||||||
|
.sort_unstable_by_key(|a| a.created_at);
|
||||||
|
}
|
||||||
|
// Sort video posts
|
||||||
|
for mut entry in video_posts_by_user.iter_mut() {
|
||||||
|
let user_posts = entry.value_mut();
|
||||||
|
user_posts
|
||||||
|
.make_contiguous()
|
||||||
|
.sort_unstable_by_key(|a| a.created_at);
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.expect("spawn_blocking failed");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clears all posts from the store
|
||||||
|
pub fn clear(&self) {
|
||||||
|
self.posts.clear();
|
||||||
|
self.original_posts_by_user.clear();
|
||||||
|
self.secondary_posts_by_user.clear();
|
||||||
|
self.video_posts_by_user.clear();
|
||||||
|
info!("PostStore cleared");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for PostStore {
|
||||||
|
fn default() -> Self {
|
||||||
|
// Default to 2 days retention, no timeout
|
||||||
|
Self::new(2 * 24 * 60 * 60, 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
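The retention rule used by both insert_posts and trim_old_posts is the same comparison against wall-clock seconds: a post is kept only if it is not from the future and not older than `retention_seconds`. A standalone sketch of that predicate (illustrative helper, not part of the commit):

// Illustrative: the keep/drop rule applied to a post's created_at (unix seconds).
fn within_retention(created_at: i64, now: i64, retention_seconds: u64) -> bool {
    created_at < now && now - created_at <= retention_seconds as i64
}

// e.g. with retention_seconds = 172_800 (2 days), a post created 3 days ago
// fails the second condition and is filtered on insert (or trimmed later).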
339 thunder/thunder_service.rs Normal file
@@ -0,0 +1,339 @@
|
|||||||
|
use lazy_static::lazy_static;
|
||||||
|
use log::{debug, info, warn};
|
||||||
|
use std::cmp::Reverse;
|
||||||
|
use std::collections::HashSet;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::{Instant, SystemTime, UNIX_EPOCH};
|
||||||
|
use tokio::sync::Semaphore;
|
||||||
|
use tonic::{Request, Response, Status};
|
||||||
|
|
||||||
|
use xai_thunder_proto::{
|
||||||
|
GetInNetworkPostsRequest, GetInNetworkPostsResponse, LightPost,
|
||||||
|
in_network_posts_service_server::{InNetworkPostsService, InNetworkPostsServiceServer},
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::config::{
|
||||||
|
MAX_INPUT_LIST_SIZE, MAX_POSTS_TO_RETURN, MAX_VIDEOS_TO_RETURN,
|
||||||
|
};
|
||||||
|
use crate::metrics::{
|
||||||
|
GET_IN_NETWORK_POSTS_COUNT, GET_IN_NETWORK_POSTS_DURATION,
|
||||||
|
GET_IN_NETWORK_POSTS_DURATION_WITHOUT_STRATO, GET_IN_NETWORK_POSTS_EXCLUDED_SIZE,
|
||||||
|
GET_IN_NETWORK_POSTS_FOLLOWING_SIZE, GET_IN_NETWORK_POSTS_FOUND_FRESHNESS_SECONDS,
|
||||||
|
GET_IN_NETWORK_POSTS_FOUND_POSTS_PER_AUTHOR, GET_IN_NETWORK_POSTS_FOUND_REPLY_RATIO,
|
||||||
|
GET_IN_NETWORK_POSTS_FOUND_TIME_RANGE_SECONDS, GET_IN_NETWORK_POSTS_FOUND_UNIQUE_AUTHORS,
|
||||||
|
GET_IN_NETWORK_POSTS_MAX_RESULTS, IN_FLIGHT_REQUESTS, REJECTED_REQUESTS, Timer,
|
||||||
|
};
|
||||||
|
use crate::posts::post_store::PostStore;
|
||||||
|
use crate::strato_client::StratoClient;
|
||||||
|
|
||||||
|
pub struct ThunderServiceImpl {
|
||||||
|
/// PostStore for retrieving posts by user ID
|
||||||
|
post_store: Arc<PostStore>,
|
||||||
|
/// StratoClient for fetching following lists when not provided
|
||||||
|
strato_client: Arc<StratoClient>,
|
||||||
|
/// Semaphore to limit concurrent requests and prevent overload
|
||||||
|
request_semaphore: Arc<Semaphore>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ThunderServiceImpl {
|
||||||
|
pub fn new(
|
||||||
|
post_store: Arc<PostStore>,
|
||||||
|
strato_client: Arc<StratoClient>,
|
||||||
|
max_concurrent_requests: usize,
|
||||||
|
) -> Self {
|
||||||
|
info!(
|
||||||
|
"Initializing ThunderService with max_concurrent_requests={}",
|
||||||
|
max_concurrent_requests
|
||||||
|
);
|
||||||
|
Self {
|
||||||
|
post_store,
|
||||||
|
strato_client,
|
||||||
|
request_semaphore: Arc::new(Semaphore::new(max_concurrent_requests)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a gRPC server for this service
|
||||||
|
pub fn server(self) -> InNetworkPostsServiceServer<Self> {
|
||||||
|
InNetworkPostsServiceServer::new(self)
|
||||||
|
.accept_compressed(tonic::codec::CompressionEncoding::Zstd)
|
||||||
|
.send_compressed(tonic::codec::CompressionEncoding::Zstd)
|
||||||
|
}
|
||||||

    /// Analyze found posts, calculate statistics, and report metrics
    /// The `stage` parameter is used as a label to differentiate between stages (e.g., "post_store", "scored")
    fn analyze_and_report_post_statistics(posts: &[LightPost], stage: &str) {
        if posts.is_empty() {
            debug!("[{}] No posts found for analysis", stage);
            return;
        }

        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs() as i64;

        // Time since most recent post
        let time_since_most_recent = posts
            .iter()
            .map(|post| post.created_at)
            .max()
            .map(|most_recent| now - most_recent);

        // Time since oldest post
        let time_since_oldest = posts
            .iter()
            .map(|post| post.created_at)
            .min()
            .map(|oldest| now - oldest);

        // Count replies vs original posts
        let reply_count = posts.iter().filter(|post| post.is_reply).count();
        let original_count = posts.len() - reply_count;

        // Unique authors
        let unique_authors: HashSet<_> = posts.iter().map(|post| post.author_id).collect();
        let unique_author_count = unique_authors.len();

        // Report metrics with stage label
        if let Some(freshness) = time_since_most_recent {
            GET_IN_NETWORK_POSTS_FOUND_FRESHNESS_SECONDS
                .with_label_values(&[stage])
                .observe(freshness as f64);
        }

        if let (Some(oldest), Some(newest)) = (time_since_oldest, time_since_most_recent) {
            let time_range = oldest - newest;
            GET_IN_NETWORK_POSTS_FOUND_TIME_RANGE_SECONDS
                .with_label_values(&[stage])
                .observe(time_range as f64);
        }

        let reply_ratio = reply_count as f64 / posts.len() as f64;
        GET_IN_NETWORK_POSTS_FOUND_REPLY_RATIO
            .with_label_values(&[stage])
            .observe(reply_ratio);

        GET_IN_NETWORK_POSTS_FOUND_UNIQUE_AUTHORS
            .with_label_values(&[stage])
            .observe(unique_author_count as f64);

        if unique_author_count > 0 {
            let posts_per_author = posts.len() as f64 / unique_author_count as f64;
            GET_IN_NETWORK_POSTS_FOUND_POSTS_PER_AUTHOR
                .with_label_values(&[stage])
                .observe(posts_per_author);
        }

        // Log statistics with stage label
        debug!(
            "[{}] Post statistics: total={}, original={}, replies={}, unique_authors={}, posts_per_author={:.2}, reply_ratio={:.2}, time_since_most_recent={:?}s, time_range={:?}s",
            stage,
            posts.len(),
            original_count,
            reply_count,
            unique_author_count,
            if unique_author_count > 0 {
                posts.len() as f64 / unique_author_count as f64
            } else {
                0.0
            },
            reply_ratio,
            time_since_most_recent,
            if let (Some(o), Some(n)) = (time_since_oldest, time_since_most_recent) {
                Some(o - n)
            } else {
                None
            }
        );
    }
}

#[tonic::async_trait]
impl InNetworkPostsService for ThunderServiceImpl {
    /// Get posts from users in the network
    async fn get_in_network_posts(
        &self,
        request: Request<GetInNetworkPostsRequest>,
    ) -> Result<Response<GetInNetworkPostsResponse>, Status> {
        // Try to acquire semaphore permit without blocking
        // If we're at capacity, reject immediately with RESOURCE_EXHAUSTED
        let _permit = match self.request_semaphore.try_acquire() {
            Ok(permit) => {
                IN_FLIGHT_REQUESTS.inc();
                permit
            }
            Err(_) => {
                REJECTED_REQUESTS.inc();
                return Err(Status::resource_exhausted(
                    "Server at capacity, please retry",
                ));
            }
        };

        // Use a guard to decrement in_flight_requests when the request completes
        struct InFlightGuard;
        impl Drop for InFlightGuard {
            fn drop(&mut self) {
                IN_FLIGHT_REQUESTS.dec();
            }
        }
        let _in_flight_guard = InFlightGuard;

        // Start timer for total latency
        let _total_timer = Timer::new(GET_IN_NETWORK_POSTS_DURATION.clone());

        let req = request.into_inner();

        if req.debug {
            info!(
                "Received GetInNetworkPosts request: user_id={}, following_count={}, exclude_tweet_ids={}",
                req.user_id,
                req.following_user_ids.len(),
                req.exclude_tweet_ids.len(),
            );
        }

        // If following_user_id list is empty, fetch it from Strato
        let following_user_ids = if req.following_user_ids.is_empty() && req.debug {
            info!(
                "Following list is empty, fetching from Strato for user {}",
                req.user_id
            );

            match self
                .strato_client
                .fetch_following_list(req.user_id as i64, MAX_INPUT_LIST_SIZE as i32)
                .await
            {
                Ok(following_list) => {
                    info!(
                        "Fetched {} following users from Strato for user {}",
                        following_list.len(),
                        req.user_id
                    );
                    following_list.into_iter().map(|id| id as u64).collect()
                }
                Err(e) => {
                    warn!(
                        "Failed to fetch following list from Strato for user {}: {}",
                        req.user_id, e
                    );
                    return Err(Status::internal(format!(
                        "Failed to fetch following list: {}",
                        e
                    )));
                }
            }
        } else {
            req.following_user_ids
        };

        // Record metrics for request parameters
        GET_IN_NETWORK_POSTS_FOLLOWING_SIZE.observe(following_user_ids.len() as f64);
        GET_IN_NETWORK_POSTS_EXCLUDED_SIZE.observe(req.exclude_tweet_ids.len() as f64);

        // Start timer for latency without strato call
        let _processing_timer = Timer::new(GET_IN_NETWORK_POSTS_DURATION_WITHOUT_STRATO.clone());

        // Default max_results if not specified
        let max_results = if req.max_results > 0 {
            req.max_results as usize
        } else if req.is_video_request {
            MAX_VIDEOS_TO_RETURN
        } else {
            MAX_POSTS_TO_RETURN
        };
        GET_IN_NETWORK_POSTS_MAX_RESULTS.observe(max_results as f64);

        // Limit following_user_ids and exclude_tweet_ids to first K entries
        let following_count = following_user_ids.len();
        if following_count > MAX_INPUT_LIST_SIZE {
            warn!(
                "Limiting following_user_ids from {} to {} entries for user {}",
                following_count, MAX_INPUT_LIST_SIZE, req.user_id
            );
        }
        let following_user_ids: Vec<u64> = following_user_ids
            .into_iter()
            .take(MAX_INPUT_LIST_SIZE)
            .collect();

        let exclude_count = req.exclude_tweet_ids.len();
        if exclude_count > MAX_INPUT_LIST_SIZE {
            warn!(
                "Limiting exclude_tweet_ids from {} to {} entries for user {}",
                exclude_count, MAX_INPUT_LIST_SIZE, req.user_id
            );
        }
        let exclude_tweet_ids: Vec<u64> = req
            .exclude_tweet_ids
            .into_iter()
            .take(MAX_INPUT_LIST_SIZE)
            .collect();

        // Clone Arc references needed inside spawn_blocking
        let post_store = Arc::clone(&self.post_store);
        let request_user_id = req.user_id as i64;

        // Use spawn_blocking to avoid blocking tokio's async runtime
        let proto_posts = tokio::task::spawn_blocking(move || {
            // Create exclude tweet IDs set for efficient filtering of previously seen posts
            let exclude_tweet_ids: HashSet<i64> =
                exclude_tweet_ids.iter().map(|&id| id as i64).collect();

            let start_time = Instant::now();

            // Fetch all posts (original + secondary) for the followed users
            let all_posts: Vec<LightPost> = if req.is_video_request {
                post_store.get_videos_by_users(
                    &following_user_ids,
                    &exclude_tweet_ids,
                    start_time,
                    request_user_id,
                )
            } else {
                post_store.get_all_posts_by_users(
                    &following_user_ids,
                    &exclude_tweet_ids,
                    start_time,
                    request_user_id,
                )
            };

            // Analyze posts and report statistics after querying post_store
            ThunderServiceImpl::analyze_and_report_post_statistics(&all_posts, "retrieved");

            let scored_posts = score_recent(all_posts, max_results);

            // Analyze posts and report statistics after scoring
            ThunderServiceImpl::analyze_and_report_post_statistics(&scored_posts, "scored");

            scored_posts
        })
        .await
        .map_err(|e| Status::internal(format!("Failed to process posts: {}", e)))?;

        if req.debug {
            info!(
                "Returning {} posts for user {}",
                proto_posts.len(),
                req.user_id
            );
        }

        // Record the number of posts returned
        GET_IN_NETWORK_POSTS_COUNT.observe(proto_posts.len() as f64);

        let response = GetInNetworkPostsResponse { posts: proto_posts };

        Ok(Response::new(response))
    }
}

/// Score posts by recency (created_at timestamp, newer posts first)
fn score_recent(mut light_posts: Vec<LightPost>, max_results: usize) -> Vec<LightPost> {
    light_posts.sort_unstable_by_key(|post| Reverse(post.created_at));

    // Limit to max results
    light_posts.into_iter().take(max_results).collect()
}
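
To show how the pieces above fit together, a hypothetical wiring sketch follows; the bind address, the concurrency limit of 128, and the idea that PostStore and StratoClient arrive pre-built from elsewhere in the crate are assumptions for illustration, not something this file defines:

    use std::sync::Arc;

    // Hypothetical entry point: build the zstd-compressed gRPC service and serve it with tonic.
    async fn run_thunder(
        post_store: Arc<PostStore>,
        strato_client: Arc<StratoClient>,
    ) -> Result<(), Box<dyn std::error::Error>> {
        // Bound concurrent requests at 128; excess callers get RESOURCE_EXHAUSTED, as above.
        let svc = ThunderServiceImpl::new(post_store, strato_client, 128).server();
        tonic::transport::Server::builder()
            .add_service(svc)
            .serve("0.0.0.0:50051".parse()?)
            .await?;
        Ok(())
    }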