disk_blob/blob/
http.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

//! HTTP blob implementation based on [`hyper`], [`tokio`], and
//! [`hyper_tls`].
//!
//! In the future, it may be better to use `pal_async` instead. This will
//! require a new, unreleased version of `hyper`, and a bunch of infrastructure
//! to support initiating TCP connections the way `hyper` expects.

use super::Blob;
use anyhow::Context as _;
use async_trait::async_trait;
use http::uri::Scheme;
use http_body_util::BodyExt;
use http_body_util::Empty;
use hyper::Request;
use hyper::StatusCode;
use hyper::Uri;
use hyper_tls::HttpsConnector;
use hyper_util::client::legacy::Client;
use hyper_util::client::legacy::connect::HttpConnector;
use hyper_util::rt::TokioExecutor;
use inspect::Inspect;
use once_cell::sync::OnceCell;
use std::fmt::Debug;
use std::io;

/// A blob backed by an HTTP/HTTPS connection.
///
/// The blob length is discovered once, at construction time, from the
/// `Content-Length` header of a `HEAD` response; reads are then served by
/// issuing requests with a `Range` header against the same URI.
#[derive(Debug, Inspect)]
pub struct HttpBlob {
    /// HTTP(S) client used for every request issued for this blob.
    #[inspect(skip)]
    client: Client<HttpsConnector<HttpConnector>, Empty<&'static [u8]>>,
    /// HTTP version reported by the `HEAD` response received during
    /// construction.
    #[inspect(debug)]
    version: http::Version,
    /// URI that answered the `HEAD` request with `200 OK`, i.e. the URI
    /// reached after following any redirects.
    #[inspect(display)]
    uri: Uri,
    /// Blob length in bytes, parsed from the `Content-Length` header.
    len: u64,
    /// Handle to the tokio runtime on which client requests are spawned.
    #[inspect(skip)]
    tokio_handle: tokio::runtime::Handle,
}

/// Shared tokio runtime, created lazily by the first call to
/// [`HttpBlob::new`]. Client requests are spawned onto it through a cloned
/// handle, so callers do not need to be running inside tokio themselves.
static TOKIO_RUNTIME: OnceCell<tokio::runtime::Runtime> = OnceCell::new();

impl HttpBlob {
    /// Connects to `url` and returns an object to access it as a blob.
    pub async fn new(url: &str) -> anyhow::Result<Self> {
        let mut uri: Uri = url.parse()?;

        let connector = HttpsConnector::new();
        let builder = Client::builder(TokioExecutor::new());
        let client = builder.build(connector);

        let handle = TOKIO_RUNTIME
            .get_or_try_init(tokio::runtime::Runtime::new)
            .context("failed to initialize tokio")?
            .handle()
            .clone();

        let mut redirect_count = 0;
        let response = loop {
            if redirect_count > 5 {
                anyhow::bail!("too many redirects");
            }

            let response = handle
                .spawn(
                    client.request(
                        Request::builder()
                            .uri(&uri)
                            .method("HEAD")
                            .body(Empty::new())
                            .unwrap(),
                    ),
                )
                .await
                .unwrap()
                .context("failed to query blob size")?;

            let next_uri: Uri = match response.status() {
                StatusCode::OK => break response,
                StatusCode::MOVED_PERMANENTLY
                | StatusCode::FOUND
                | StatusCode::TEMPORARY_REDIRECT
                | StatusCode::PERMANENT_REDIRECT => response
                    .headers()
                    .get("Location")
                    .context("missing redirect URL")?
                    .to_str()
                    .context("couldn't parse redirect URL")?
                    .parse()
                    .context("couldn't parse redirect URL")?,
                status => {
                    anyhow::bail!("failed to query blob size: {status}");
                }
            };

            if uri.scheme() == Some(&Scheme::HTTPS) && next_uri.scheme() != Some(&Scheme::HTTPS) {
                anyhow::bail!("https redirected to http");
            }

            uri = next_uri;
            redirect_count += 1;
        };

        let len = response
            .headers()
            .get("Content-Length")
            .context("missing blob length")?
            .to_str()
            .context("couldn't parse blob length")?
            .parse()
            .context("couldn't parse blob length")?;

        let version = response.version();

        Ok(Self {
            client,
            version,
            uri,
            len,
            tokio_handle: handle,
        })
    }
}

#[async_trait]
impl Blob for HttpBlob {
    /// Fills `buf` with the blob contents starting at byte `offset`.
    ///
    /// Issues a single request with a `Range` header covering exactly
    /// `buf.len()` bytes and copies the response body into `buf`. An empty
    /// `buf` succeeds immediately without contacting the server.
    ///
    /// # Errors
    ///
    /// * [`io::ErrorKind::InvalidInput`] if `offset + buf.len()` does not fit
    ///   in a `u64`.
    /// * [`io::ErrorKind::UnexpectedEof`] if the server returns fewer bytes
    ///   than requested.
    /// * [`io::ErrorKind::Other`] if the request fails, the response status
    ///   is not a success, a body frame cannot be read, or the server sends
    ///   more data than was requested.
    ///
    /// # Panics
    ///
    /// Panics if the task spawned to perform the request panics or is
    /// cancelled.
    async fn read(&self, mut buf: &mut [u8], offset: u64) -> io::Result<()> {
        // A zero-length read has no valid `Range` representation: the
        // inclusive end `offset + 0 - 1` would underflow at offset 0 and
        // describe an inverted range everywhere else. There is nothing to
        // fetch, so succeed without a round trip.
        if buf.is_empty() {
            return Ok(());
        }

        // Inclusive index of the last requested byte. `buf` is non-empty
        // here, so the subtraction cannot underflow; the addition is checked
        // so a huge offset is reported instead of wrapping.
        let last = offset.checked_add(buf.len() as u64 - 1).ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "read range exceeds the maximum blob offset",
            )
        })?;

        let mut response = self
            .tokio_handle
            .spawn(
                self.client.request(
                    Request::builder()
                        .uri(&self.uri)
                        .header(hyper::header::RANGE, format!("bytes={}-{}", offset, last))
                        .body(Empty::new())
                        .unwrap(),
                ),
            )
            .await
            .unwrap()
            .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;

        if !response.status().is_success() {
            return Err(io::Error::new(
                io::ErrorKind::Other,
                response.status().to_string(),
            ));
        }

        // Copy each data frame into the front of the remaining buffer,
        // shrinking `buf` as it fills.
        while let Some(frame) = response.body_mut().frame().await {
            let frame = frame.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
            if let Some(data) = frame.data_ref() {
                let len = data.len();
                // More data than requested means the server ignored (or
                // mishandled) the `Range` header.
                if len > buf.len() {
                    return Err(io::Error::new(
                        io::ErrorKind::Other,
                        "server did not respect range query",
                    ));
                }
                let (this, rest) = buf.split_at_mut(len);
                this.copy_from_slice(data);
                buf = rest;
            }
        }

        // The body ended before the whole buffer was filled.
        if !buf.is_empty() {
            return Err(io::ErrorKind::UnexpectedEof.into());
        }

        Ok(())
    }

    /// Returns the blob length in bytes, as reported by the server's
    /// `Content-Length` header when the blob was opened.
    fn len(&self) -> u64 {
        self.len
    }
}