mirror of
https://github.com/xai-org/x-algorithm.git
synced 2026-02-13 03:05:06 +01:00
84 lines
3.3 KiB
Rust
84 lines
3.3 KiB
Rust
use crate::candidate_pipeline::candidate::CandidateHelpers;
|
|
use crate::candidate_pipeline::phoenix_candidate_pipeline::PhoenixCandidatePipeline;
|
|
use crate::candidate_pipeline::query::ScoredPostsQuery;
|
|
use log::info;
|
|
use std::sync::Arc;
|
|
use std::time::Instant;
|
|
use tonic::{Request, Response, Status};
|
|
use xai_candidate_pipeline::candidate_pipeline::CandidatePipeline;
|
|
use xai_home_mixer_proto as pb;
|
|
use xai_home_mixer_proto::{ScoredPost, ScoredPostsResponse};
|
|
|
|
pub struct HomeMixerServer {
|
|
phx_candidate_pipeline: Arc<PhoenixCandidatePipeline>,
|
|
}
|
|
|
|
impl HomeMixerServer {
|
|
pub async fn new() -> Self {
|
|
HomeMixerServer {
|
|
phx_candidate_pipeline: Arc::new(PhoenixCandidatePipeline::prod().await),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[tonic::async_trait]
|
|
impl pb::scored_posts_service_server::ScoredPostsService for HomeMixerServer {
|
|
#[xai_stats_macro::receive_stats]
|
|
async fn get_scored_posts(
|
|
&self,
|
|
request: Request<pb::ScoredPostsQuery>,
|
|
) -> Result<Response<ScoredPostsResponse>, Status> {
|
|
let proto_query = request.into_inner();
|
|
|
|
if proto_query.viewer_id == 0 {
|
|
return Err(Status::invalid_argument("viewer_id must be specified"));
|
|
}
|
|
|
|
let start = Instant::now();
|
|
let query = ScoredPostsQuery::new(
|
|
proto_query.viewer_id,
|
|
proto_query.client_app_id,
|
|
proto_query.country_code,
|
|
proto_query.language_code,
|
|
proto_query.seen_ids,
|
|
proto_query.served_ids,
|
|
proto_query.in_network_only,
|
|
proto_query.is_bottom_request,
|
|
proto_query.bloom_filter_entries,
|
|
);
|
|
info!("Scored Posts request - request_id {}", query.request_id);
|
|
let pipeline_result = self.phx_candidate_pipeline.execute(query).await;
|
|
|
|
let scored_posts: Vec<ScoredPost> = pipeline_result
|
|
.selected_candidates
|
|
.into_iter()
|
|
.map(|candidate| {
|
|
let screen_names = candidate.get_screen_names();
|
|
ScoredPost {
|
|
tweet_id: candidate.tweet_id as u64,
|
|
author_id: candidate.author_id,
|
|
retweeted_tweet_id: candidate.retweeted_tweet_id.unwrap_or(0),
|
|
retweeted_user_id: candidate.retweeted_user_id.unwrap_or(0),
|
|
in_reply_to_tweet_id: candidate.in_reply_to_tweet_id.unwrap_or(0),
|
|
score: candidate.score.unwrap_or(0.0) as f32,
|
|
in_network: candidate.in_network.unwrap_or(false),
|
|
served_type: candidate.served_type.map(|t| t as i32).unwrap_or_default(),
|
|
last_scored_timestamp_ms: candidate.last_scored_at_ms.unwrap_or(0),
|
|
prediction_request_id: candidate.prediction_request_id.unwrap_or(0),
|
|
ancestors: candidate.ancestors,
|
|
screen_names,
|
|
visibility_reason: candidate.visibility_reason.map(|r| r.into()),
|
|
}
|
|
})
|
|
.collect();
|
|
|
|
info!(
|
|
"Scored Posts response - request_id {} - {} posts ({} ms)",
|
|
pipeline_result.query.request_id,
|
|
scored_posts.len(),
|
|
start.elapsed().as_millis()
|
|
);
|
|
Ok(Response::new(ScoredPostsResponse { scored_posts }))
|
|
}
|
|
}
|