x-algorithm/home-mixer/server.rs
2026-01-20 02:31:49 +00:00

84 lines
3.3 KiB
Rust

use crate::candidate_pipeline::candidate::CandidateHelpers;
use crate::candidate_pipeline::phoenix_candidate_pipeline::PhoenixCandidatePipeline;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use log::info;
use std::sync::Arc;
use std::time::Instant;
use tonic::{Request, Response, Status};
use xai_candidate_pipeline::candidate_pipeline::CandidatePipeline;
use xai_home_mixer_proto as pb;
use xai_home_mixer_proto::{ScoredPost, ScoredPostsResponse};
/// Server state backing the `ScoredPostsService` implementation in this file.
///
/// Holds the candidate pipeline that `get_scored_posts` executes once per request.
pub struct HomeMixerServer {
/// Reference-counted handle to the Phoenix candidate pipeline; created in `new()`.
phx_candidate_pipeline: Arc<PhoenixCandidatePipeline>,
}
impl HomeMixerServer {
    /// Builds a server by awaiting `PhoenixCandidatePipeline::prod()` and storing the
    /// resulting pipeline behind an `Arc`.
    ///
    /// NOTE(review): `prod()` presumably wires up the production pipeline configuration;
    /// confirm against `PhoenixCandidatePipeline`.
    pub async fn new() -> Self {
        let pipeline = PhoenixCandidatePipeline::prod().await;
        Self {
            phx_candidate_pipeline: Arc::new(pipeline),
        }
    }
}
#[tonic::async_trait]
impl pb::scored_posts_service_server::ScoredPostsService for HomeMixerServer {
    // Handles a scored-posts request.
    //
    // Steps:
    //   1. Reject requests whose `viewer_id` is 0 with `Status::invalid_argument`.
    //   2. Build a `ScoredPostsQuery` from the request fields and run the candidate pipeline.
    //   3. Convert every selected candidate into a proto `ScoredPost`, substituting
    //      0 / 0.0 / false / the default value for optional fields that are `None`.
    //
    // The logged duration starts after validation, so it covers query construction,
    // pipeline execution and response mapping.
    #[xai_stats_macro::receive_stats]
    async fn get_scored_posts(
        &self,
        request: Request<pb::ScoredPostsQuery>,
    ) -> Result<Response<ScoredPostsResponse>, Status> {
        let req = request.into_inner();

        // A zero viewer id is treated as "not provided".
        let viewer_id = req.viewer_id;
        if viewer_id == 0 {
            return Err(Status::invalid_argument("viewer_id must be specified"));
        }

        let started_at = Instant::now();

        let pipeline_query = ScoredPostsQuery::new(
            viewer_id,
            req.client_app_id,
            req.country_code,
            req.language_code,
            req.seen_ids,
            req.served_ids,
            req.in_network_only,
            req.is_bottom_request,
            req.bloom_filter_entries,
        );
        info!(
            "Scored Posts request - request_id {}",
            pipeline_query.request_id
        );

        let outcome = self.phx_candidate_pipeline.execute(pipeline_query).await;

        let posts: Vec<ScoredPost> = outcome
            .selected_candidates
            .into_iter()
            .map(|candidate| {
                // Must be read before `candidate.ancestors` is moved into the struct below.
                let names = candidate.get_screen_names();
                ScoredPost {
                    // NOTE(review): `as` cast; if `tweet_id` is a signed type a negative
                    // value would wrap — confirm the field's type and valid range.
                    tweet_id: candidate.tweet_id as u64,
                    author_id: candidate.author_id,
                    retweeted_tweet_id: candidate.retweeted_tweet_id.unwrap_or_default(),
                    retweeted_user_id: candidate.retweeted_user_id.unwrap_or_default(),
                    in_reply_to_tweet_id: candidate.in_reply_to_tweet_id.unwrap_or_default(),
                    score: candidate.score.unwrap_or_default() as f32,
                    in_network: candidate.in_network.unwrap_or_default(),
                    served_type: candidate.served_type.map_or(0, |kind| kind as i32),
                    last_scored_timestamp_ms: candidate.last_scored_at_ms.unwrap_or_default(),
                    prediction_request_id: candidate.prediction_request_id.unwrap_or_default(),
                    ancestors: candidate.ancestors,
                    screen_names: names,
                    visibility_reason: candidate.visibility_reason.map(Into::into),
                }
            })
            .collect();

        // The request id is read from the query handed back in the pipeline result.
        info!(
            "Scored Posts response - request_id {} - {} posts ({} ms)",
            outcome.query.request_id,
            posts.len(),
            started_at.elapsed().as_millis()
        );

        Ok(Response::new(ScoredPostsResponse {
            scored_posts: posts,
        }))
    }
}