x-algorithm/home-mixer/sources/phoenix_source.rs
2026-01-20 02:31:49 +00:00

52 lines
1.8 KiB
Rust

use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::clients::phoenix_retrieval_client::PhoenixRetrievalClient;
use crate::params as p;
use std::sync::Arc;
use tonic::async_trait;
use xai_candidate_pipeline::source::Source;
use xai_home_mixer_proto as pb;
/// Candidate source that obtains posts through a [`PhoenixRetrievalClient`].
///
/// Implements `Source<ScoredPostsQuery, PostCandidate>`; all retrieval work is
/// delegated to the client held in `phoenix_retrieval_client`.
pub struct PhoenixSource {
/// Reference-counted handle to the retrieval client; the trait object is
/// bounded by `Send + Sync`.
pub phoenix_retrieval_client: Arc<dyn PhoenixRetrievalClient + Send + Sync>,
}
#[async_trait]
impl Source<ScoredPostsQuery, PostCandidate> for PhoenixSource {
    /// Decides whether this source runs for `query`.
    ///
    /// Returns the negation of `query.in_network_only`: the source is skipped
    /// when that flag is set and enabled otherwise.
    fn enable(&self, query: &ScoredPostsQuery) -> bool {
        !query.in_network_only
    }

    /// Retrieves candidates from the Phoenix retrieval client and converts
    /// them into [`PostCandidate`]s.
    ///
    /// The client is called with the query's user id, a clone of its user
    /// action sequence, and `p::PHOENIX_MAX_RESULTS`. Every group in the
    /// response's `top_k_candidates` is flattened in order; entries whose
    /// `candidate` is `None` are skipped.
    ///
    /// # Errors
    ///
    /// Returns `Err` when `query.user_action_sequence` is `None`, or when the
    /// client call fails (the client error is prefixed with `"PhoenixSource: "`).
    #[xai_stats_macro::receive_stats]
    async fn get_candidates(
        &self,
        query: &ScoredPostsQuery,
    ) -> Result<Vec<PostCandidate>, String> {
        // Without an action sequence there is nothing to send to the client.
        let action_sequence = match query.user_action_sequence.as_ref() {
            Some(seq) => seq,
            None => return Err("PhoenixSource: missing user_action_sequence".to_string()),
        };

        // NOTE(review): `as` is an unchecked numeric conversion; confirm that
        // user ids are always representable as `u64`.
        let requester_id = query.user_id as u64;

        // The sequence is cloned because it is handed to `retrieve` by value
        // while `query` is only borrowed.
        let response = self
            .phoenix_retrieval_client
            .retrieve(requester_id, action_sequence.clone(), p::PHOENIX_MAX_RESULTS)
            .await
            .map_err(|err| format!("PhoenixSource: {}", err))?;

        let mut candidates: Vec<PostCandidate> = Vec::new();
        for group in response.top_k_candidates {
            for scored in group.candidates {
                // Scored entries carrying no candidate payload are dropped.
                if let Some(info) = scored.candidate {
                    candidates.push(PostCandidate {
                        // NOTE(review): unchecked `as` conversion; confirm tweet ids fit in `i64`.
                        tweet_id: info.tweet_id as i64,
                        author_id: info.author_id,
                        // Always wrapped in `Some`, whatever value the response carries.
                        // NOTE(review): confirm how "not a reply" is represented upstream.
                        in_reply_to_tweet_id: Some(info.in_reply_to_tweet_id),
                        served_type: Some(pb::ServedType::ForYouPhoenixRetrieval),
                        ..Default::default()
                    });
                }
            }
        }

        Ok(candidates)
    }
}