mesh_rpc/
message.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! TTRPC message handling.
5
6use crate::service::Status;
7use anyhow::Context;
8use futures::AsyncRead;
9use futures::AsyncReadExt;
10use futures::AsyncWrite;
11use futures::AsyncWriteExt;
12use mesh::payload::Protobuf;
13use std::io::ErrorKind;
14use thiserror::Error;
15use zerocopy::BigEndian;
16use zerocopy::FromBytes;
17use zerocopy::FromZeros;
18use zerocopy::Immutable;
19use zerocopy::IntoBytes;
20use zerocopy::KnownLayout;
21use zerocopy::U32;
22
23/// The wire format header for a message.
24#[repr(C, packed)]
25#[derive(Debug, Copy, Clone, IntoBytes, Immutable, KnownLayout, FromBytes)]
26struct MessageHeader {
27    length: U32<BigEndian>,
28    stream_id: U32<BigEndian>,
29    message_type: u8,
30    flags: u8,
31}
32
/// Header `message_type` value identifying a request message.
pub const MESSAGE_TYPE_REQUEST: u8 = 0x01;
/// Header `message_type` value identifying a response message.
pub const MESSAGE_TYPE_RESPONSE: u8 = 0x02;
35
/// Largest payload length, in bytes, that this implementation accepts when
/// receiving a ttrpc message.
///
/// The ttrpc spec sets the maximum at 4MB, which is too small for our use
/// cases, so a larger limit is used here. The limit only matters on the
/// receiving side: the reference implementation also enforces it only when
/// receiving, which means every receiver already has to cope with oversized
/// messages (by rejecting them). Accepting more than the spec requires should
/// therefore be a compatible relaxation.
///
/// 16MB - 1 is nevertheless a hard ceiling, because the spec reserves the top
/// byte of the length field for possible future use. (It is doubtful that
/// this can ever happen, since existing senders do not validate the top byte
/// at all, but we avoid depending on messages larger than this.)
const MAX_MESSAGE_SIZE: usize = (1 << 24) - 1;
54
55#[derive(Debug, Error)]
56#[error("message length {0} exceeds maximum allowed size {MAX_MESSAGE_SIZE}")]
57pub struct TooLongError(usize);
58
59pub struct ReadResult {
60    pub stream_id: u32,
61    pub message_type: u8,
62    pub payload: Result<Vec<u8>, TooLongError>,
63}
64
65pub async fn read_message(
66    reader: &mut (impl AsyncRead + Unpin),
67) -> std::io::Result<Option<ReadResult>> {
68    let mut header = MessageHeader::new_zeroed();
69    match reader.read_exact(header.as_mut_bytes()).await {
70        Ok(_) => (),
71        Err(err) if err.kind() == ErrorKind::UnexpectedEof => {
72            return Ok(None);
73        }
74        Err(err) => return Err(err),
75    }
76
77    let stream_id = header.stream_id.get();
78    let length = header.length.get() as usize;
79    let payload = if length <= MAX_MESSAGE_SIZE {
80        let mut buf = vec![0; length];
81        reader.read_exact(&mut buf).await?;
82        Ok(buf)
83    } else {
84        // Discard the message that was too long.
85        futures::io::copy(reader.take(length as u64), &mut futures::io::sink()).await?;
86        Err(TooLongError(length))
87    };
88
89    Ok(Some(ReadResult {
90        stream_id,
91        message_type: header.message_type,
92        payload,
93    }))
94}
95
96pub async fn write_message(
97    writer: &mut (impl AsyncWrite + Unpin),
98    stream_id: u32,
99    message_type: u8,
100    payload: &[u8],
101) -> anyhow::Result<()> {
102    let header = MessageHeader {
103        stream_id: stream_id.into(),
104        message_type,
105        length: (payload.len() as u32).into(),
106        flags: 0,
107    };
108
109    writer
110        .write_all(header.as_bytes())
111        .await
112        .context("failed writing message header")?;
113
114    writer
115        .write_all(payload)
116        .await
117        .context("failed writing message payload")?;
118
119    Ok(())
120}
121
122/// A request message payload.
123#[derive(Protobuf)]
124pub struct Request {
125    pub service: String,
126    pub method: String,
127    pub payload: Vec<u8>,
128    pub timeout_nano: u64,
129    pub metadata: Vec<(String, String)>,
130}
131
132/// A response message payload.
133#[derive(Protobuf)]
134pub enum Response {
135    #[mesh(1, transparent)]
136    Status(Status),
137    #[mesh(2, transparent)]
138    Payload(Vec<u8>),
139}