xref: /unit/src/wasm-wasi-component/src/lib.rs (revision 2757:563f8f63b9b2)
1 use anyhow::{bail, Context, Result};
2 use bytes::{Bytes, BytesMut};
3 use http_body_util::combinators::BoxBody;
4 use http_body_util::{BodyExt, Full};
5 use std::ffi::{CStr, CString};
6 use std::mem::MaybeUninit;
7 use std::ptr;
8 use std::sync::OnceLock;
9 use tokio::sync::mpsc;
10 use wasmtime::component::{Component, InstancePre, Linker, ResourceTable};
11 use wasmtime::{Config, Engine, Store};
12 use wasmtime_wasi::preview2::{
13     DirPerms, FilePerms, WasiCtx, WasiCtxBuilder, WasiView,
14 };
15 use wasmtime_wasi::{ambient_authority, Dir};
16 use wasmtime_wasi_http::bindings::http::types::ErrorCode;
17 use wasmtime_wasi_http::{WasiHttpCtx, WasiHttpView};
18 
19 #[allow(
20     non_camel_case_types,
21     non_upper_case_globals,
22     non_snake_case,
23     dead_code
24 )]
25 mod bindings {
26     include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
27 
nxt_string(s: &'static str) -> nxt_str_t28     pub const fn nxt_string(s: &'static str) -> nxt_str_t {
29         nxt_str_t {
30             start: s.as_ptr().cast_mut(),
31             length: s.len(),
32         }
33     }
34 
nxt_unit_sptr_get(sptr: &nxt_unit_sptr_t) -> *const u835     pub unsafe fn nxt_unit_sptr_get(sptr: &nxt_unit_sptr_t) -> *const u8 {
36         sptr.base.as_ptr().offset(sptr.offset as isize)
37     }
38 }
39 
40 #[no_mangle]
41 pub static mut nxt_app_module: bindings::nxt_app_module_t = {
42     const COMPAT: [u32; 2] = [bindings::NXT_VERNUM, bindings::NXT_DEBUG];
43     let version = "0.1\0";
44     bindings::nxt_app_module_t {
45         compat: COMPAT.as_ptr().cast_mut(),
46         compat_length: COMPAT.len() * 4,
47         mounts: ptr::null(),
48         nmounts: 0,
49         type_: bindings::nxt_string("wasm-wasi-component"),
50         version: version.as_ptr().cast(),
51         setup: Some(setup),
52         start: Some(start),
53     }
54 };
55 
56 static GLOBAL_CONFIG: OnceLock<GlobalConfig> = OnceLock::new();
57 static GLOBAL_STATE: OnceLock<GlobalState> = OnceLock::new();
58 
setup( task: *mut bindings::nxt_task_t, _process: *mut bindings::nxt_process_t, conf: *mut bindings::nxt_common_app_conf_t, ) -> bindings::nxt_int_t59 unsafe extern "C" fn setup(
60     task: *mut bindings::nxt_task_t,
61     // TODO: should this get used?
62     _process: *mut bindings::nxt_process_t,
63     conf: *mut bindings::nxt_common_app_conf_t,
64 ) -> bindings::nxt_int_t {
65     handle_result(task, || {
66         let wasm_conf = &(*conf).u.wasm_wc;
67         let component = CStr::from_ptr(wasm_conf.component).to_str()?;
68         let mut dirs = Vec::new();
69         if !wasm_conf.access.is_null() {
70             let dirs_ptr = bindings::nxt_conf_get_object_member(
71                 wasm_conf.access,
72                 &mut bindings::nxt_string("filesystem"),
73                 ptr::null_mut(),
74             );
75             for i in 0..bindings::nxt_conf_object_members_count(dirs_ptr) {
76                 let value = bindings::nxt_conf_get_array_element(
77                     dirs_ptr,
78                     i.try_into().unwrap(),
79                 );
80                 let mut s = bindings::nxt_string("");
81                 bindings::nxt_conf_get_string(value, &mut s);
82                 dirs.push(
83                     std::str::from_utf8(std::slice::from_raw_parts(
84                         s.start, s.length,
85                     ))?
86                     .to_string(),
87                 );
88             }
89         }
90 
91         let result = GLOBAL_CONFIG.set(GlobalConfig {
92             component: component.to_string(),
93             dirs,
94         });
95         assert!(result.is_ok());
96         Ok(())
97     })
98 }
99 
start( task: *mut bindings::nxt_task_t, data: *mut bindings::nxt_process_data_t, ) -> bindings::nxt_int_t100 unsafe extern "C" fn start(
101     task: *mut bindings::nxt_task_t,
102     data: *mut bindings::nxt_process_data_t,
103 ) -> bindings::nxt_int_t {
104     handle_result(task, || {
105         let config = GLOBAL_CONFIG.get().unwrap();
106         let state = GlobalState::new(&config)
107             .context("failed to create initial state")?;
108         let res = GLOBAL_STATE.set(state);
109         assert!(res.is_ok());
110 
111         let conf = (*data).app;
112         let mut wasm_init = MaybeUninit::uninit();
113         let ret =
114             bindings::nxt_unit_default_init(task, wasm_init.as_mut_ptr(), conf);
115         if ret != bindings::NXT_OK as bindings::nxt_int_t {
116             bail!("nxt_unit_default_init() failed");
117         }
118         let mut wasm_init = wasm_init.assume_init();
119         wasm_init.callbacks.request_handler = Some(request_handler);
120 
121         let unit_ctx = bindings::nxt_unit_init(&mut wasm_init);
122         if unit_ctx.is_null() {
123             bail!("nxt_unit_init() failed");
124         }
125 
126         bindings::nxt_unit_run(unit_ctx);
127         bindings::nxt_unit_done(unit_ctx);
128 
129         Ok(())
130     })
131 }
132 
handle_result( task: *mut bindings::nxt_task_t, func: impl FnOnce() -> Result<()>, ) -> bindings::nxt_int_t133 unsafe fn handle_result(
134     task: *mut bindings::nxt_task_t,
135     func: impl FnOnce() -> Result<()>,
136 ) -> bindings::nxt_int_t {
137     let rc = match func() {
138         Ok(()) => bindings::NXT_OK as bindings::nxt_int_t,
139         Err(e) => {
140             alert(task, &format!("{e:?}"));
141             bindings::NXT_ERROR as bindings::nxt_int_t
142         }
143     };
144     return rc;
145 
146     unsafe fn alert(task: *mut bindings::nxt_task_t, msg: &str) {
147         let log = (*task).log;
148         let msg = CString::new(msg).unwrap();
149         ((*log).handler).unwrap()(
150             bindings::NXT_LOG_ALERT as bindings::nxt_uint_t,
151             log,
152             "%s\0".as_ptr().cast(),
153             msg.as_ptr(),
154         );
155     }
156 }
157 
request_handler( info: *mut bindings::nxt_unit_request_info_t, )158 unsafe extern "C" fn request_handler(
159     info: *mut bindings::nxt_unit_request_info_t,
160 ) {
161     // Enqueue this request to get processed by the Tokio event loop, and
162     // otherwise immediately return.
163     let state = GLOBAL_STATE.get().unwrap();
164     state.sender.blocking_send(NxtRequestInfo { info }).unwrap();
165 }
166 
/// Application settings read from Unit's configuration during `setup`.
struct GlobalConfig {
    /// Host directories preopened for the guest (from `access.filesystem`).
    dirs: Vec<String>,
    /// Filesystem path of the WebAssembly component to load.
    component: String,
}
171 
172 struct GlobalState {
173     engine: Engine,
174     component: InstancePre<StoreState>,
175     global_config: &'static GlobalConfig,
176     sender: mpsc::Sender<NxtRequestInfo>,
177 }
178 
179 impl GlobalState {
new(global_config: &'static GlobalConfig) -> Result<GlobalState>180     fn new(global_config: &'static GlobalConfig) -> Result<GlobalState> {
181         // Configure Wasmtime, e.g. the component model and async support are
182         // enabled here. Other configuration can include:
183         //
184         // * Epochs/fuel - enables async yielding to prevent any one request
185         //   starving others.
186         // * Pooling allocator - accelerates instantiation at the cost of a
187         //   large virtual memory reservation.
188         // * Memory limits/etc.
189         let mut config = Config::new();
190         config.wasm_component_model(true);
191         config.async_support(true);
192         let engine = Engine::new(&config)?;
193 
194         // Compile the binary component on disk in Wasmtime. This is then
195         // pre-instantiated with host APIs defined by WASI. The result of
196         // this is a "pre-instantiated instance" which can be used to
197         // repeatedly instantiate later on. This will frontload
198         // compilation/linking/type-checking/etc to happen once rather than on
199         // each request.
200         let component = Component::from_file(&engine, &global_config.component)
201             .context("failed to compile component")?;
202         let mut linker = Linker::<StoreState>::new(&engine);
203         wasmtime_wasi::preview2::command::add_to_linker(&mut linker)?;
204         wasmtime_wasi_http::proxy::add_only_http_to_linker(&mut linker)?;
205         let component = linker
206             .instantiate_pre(&component)
207             .context("failed to pre-instantiate the provided component")?;
208 
209         // Spin up the Tokio async runtime in a separate thread with a
210         // communication channel into it. This thread will send requests to
211         // Tokio and the results will be calculated there.
212         let (sender, receiver) = mpsc::channel(10);
213         std::thread::spawn(|| GlobalState::run(receiver));
214 
215         Ok(GlobalState {
216             engine,
217             component,
218             sender,
219             global_config,
220         })
221     }
222 
223     /// Worker thread that executes the Tokio runtime, infinitely receiving
224     /// messages from the provided `receiver` and handling those requests.
225     ///
226     /// Each request is handled in a separate subtask so processing can all
227     /// happen concurrently.
run(mut receiver: mpsc::Receiver<NxtRequestInfo>)228     fn run(mut receiver: mpsc::Receiver<NxtRequestInfo>) {
229         let rt = tokio::runtime::Runtime::new().unwrap();
230         rt.block_on(async {
231             while let Some(msg) = receiver.recv().await {
232                 let state = GLOBAL_STATE.get().unwrap();
233                 tokio::task::spawn(async move {
234                     state.handle(msg).await.expect("failed to handle request")
235                 });
236             }
237         });
238     }
239 
handle(&'static self, mut info: NxtRequestInfo) -> Result<()>240     async fn handle(&'static self, mut info: NxtRequestInfo) -> Result<()> {
241         // Create a "Store" which is the unit of per-request isolation in
242         // Wasmtime.
243         let data = StoreState {
244             ctx: {
245                 let mut cx = WasiCtxBuilder::new();
246                 // NB: while useful for debugging untrusted code probably
247                 // shouldn't get raw access to stdout/stderr.
248                 cx.inherit_stdout();
249                 cx.inherit_stderr();
250                 for dir in self.global_config.dirs.iter() {
251                     let fd = Dir::open_ambient_dir(dir, ambient_authority())
252                         .with_context(|| {
253                             format!("failed to open directory '{dir}'")
254                         })?;
255                     cx.preopened_dir(
256                         fd,
257                         DirPerms::all(),
258                         FilePerms::all(),
259                         dir,
260                     );
261                 }
262                 cx.build()
263             },
264             table: ResourceTable::default(),
265             http: WasiHttpCtx,
266         };
267         let mut store = Store::new(&self.engine, data);
268 
269         // Convert the `nxt_*` representation into the representation required
270         // by Wasmtime's `wasi-http` implementation using the Rust `http`
271         // crate.
272         let request = self.to_request_builder(&info)?;
273         let body = self.to_request_body(&mut info);
274         let request = request.body(body)?;
275 
276         let (sender, receiver) = tokio::sync::oneshot::channel();
277 
278         // Instantiate the WebAssembly component and invoke its `handle`
279         // function which receives a request and where to put a response.
280         //
281         // Note that this is done in a sub-task to work concurrently with
282         // writing the response when it's available. This enables wasm to
283         // generate headers, write those below, and then compute the body
284         // afterwards.
285         let task = tokio::spawn(async move {
286             let (proxy, _) = wasmtime_wasi_http::proxy::Proxy::instantiate_pre(
287                 &mut store,
288                 &self.component,
289             )
290             .await
291             .context("failed to instantiate")?;
292             let req = store.data_mut().new_incoming_request(request)?;
293             let out = store.data_mut().new_response_outparam(sender)?;
294             proxy
295                 .wasi_http_incoming_handler()
296                 .call_handle(&mut store, req, out)
297                 .await
298                 .context("failed to invoke wasm `handle`")?;
299             Ok::<_, anyhow::Error>(())
300         });
301 
302         // Wait for the wasm to produce the initial response. If this succeeds
303         // then propagate that failure. If this fails then wait for the above
304         // task to complete to see if it failed, otherwise panic since that's
305         // unexpected.
306         let response = match receiver.await {
307             Ok(response) => response.context("response generation failed")?,
308             Err(_) => {
309                 task.await.unwrap()?;
310                 panic!("sender of response disappeared");
311             }
312         };
313 
314         // Send the headers/status which will extract the body for the next
315         // phase.
316         let body = self.send_response(&mut info, response);
317 
318         // Send the body, a blocking operation, over time as it becomes
319         // available.
320         self.send_response_body(&mut info, body)
321             .await
322             .context("failed to write response body")?;
323 
324         // Join on completion of the wasm task which should be done by this
325         // point.
326         task.await.unwrap()?;
327 
328         // And finally signal that we're done.
329         info.request_done();
330 
331         Ok(())
332     }
333 
to_request_builder( &self, info: &NxtRequestInfo, ) -> Result<http::request::Builder>334     fn to_request_builder(
335         &self,
336         info: &NxtRequestInfo,
337     ) -> Result<http::request::Builder> {
338         let mut request = http::Request::builder();
339 
340         request = request.method(info.method());
341         request = match info.version() {
342             "HTTP/0.9" => request.version(http::Version::HTTP_09),
343             "HTTP/1.0" => request.version(http::Version::HTTP_10),
344             "HTTP/1.1" => request.version(http::Version::HTTP_11),
345             "HTTP/2.0" => request.version(http::Version::HTTP_2),
346             "HTTP/3.0" => request.version(http::Version::HTTP_3),
347             version => {
348                 println!("unknown version: {version}");
349                 request
350             }
351         };
352 
353         let uri = http::Uri::builder()
354             .scheme(if info.tls() { "https" } else { "http" })
355             .authority(info.server_name())
356             .path_and_query(info.target())
357             .build()
358             .context("failed to build URI")?;
359         request = request.uri(uri);
360 
361         for (name, value) in info.fields() {
362             request = request.header(name, value);
363         }
364         Ok(request)
365     }
366 
to_request_body( &self, info: &mut NxtRequestInfo, ) -> BoxBody<Bytes, ErrorCode>367     fn to_request_body(
368         &self,
369         info: &mut NxtRequestInfo,
370     ) -> BoxBody<Bytes, ErrorCode> {
371         // TODO: should convert the body into a form of `Stream` to become an
372         // async stream of frames. The return value can represent that here
373         // but for now this slurps up the entire body into memory and puts it
374         // all in a single `BytesMut` which is then converted to `Bytes`.
375         let mut body =
376             BytesMut::with_capacity(info.content_length().try_into().unwrap());
377 
378         // TODO: can this perform a partial read?
379         // TODO: how to make this async at the nxt level?
380         info.request_read(&mut body);
381 
382         Full::new(body.freeze()).map_err(|e| match e {}).boxed()
383     }
384 
send_response<T>( &self, info: &mut NxtRequestInfo, response: http::Response<T>, ) -> T385     fn send_response<T>(
386         &self,
387         info: &mut NxtRequestInfo,
388         response: http::Response<T>,
389     ) -> T {
390         info.init_response(
391             response.status().as_u16(),
392             response.headers().len().try_into().unwrap(),
393             response
394                 .headers()
395                 .iter()
396                 .map(|(k, v)| k.as_str().len() + v.len())
397                 .sum::<usize>()
398                 .try_into()
399                 .unwrap(),
400         );
401         for (k, v) in response.headers() {
402             info.add_field(k.as_str().as_bytes(), v.as_bytes());
403         }
404         info.send_response();
405 
406         response.into_body()
407     }
408 
send_response_body( &self, info: &mut NxtRequestInfo, mut body: BoxBody<Bytes, ErrorCode>, ) -> Result<()>409     async fn send_response_body(
410         &self,
411         info: &mut NxtRequestInfo,
412         mut body: BoxBody<Bytes, ErrorCode>,
413     ) -> Result<()> {
414         loop {
415             // Acquire the next frame, and because nothing is actually async
416             // at the moment this should never block meaning that the
417             // `Pending` case should not happen.
418             let frame = match body.frame().await {
419                 Some(Ok(frame)) => frame,
420                 Some(Err(e)) => break Err(e.into()),
421                 None => break Ok(()),
422             };
423             match frame.data_ref() {
424                 Some(data) => {
425                     info.response_write(&data);
426                 }
427                 None => {
428                     // TODO: what to do with trailers?
429                 }
430             }
431         }
432     }
433 }
434 
435 struct NxtRequestInfo {
436     info: *mut bindings::nxt_unit_request_info_t,
437 }
438 
439 // TODO: is this actually safe?
440 unsafe impl Send for NxtRequestInfo {}
441 unsafe impl Sync for NxtRequestInfo {}
442 
443 impl NxtRequestInfo {
method(&self) -> &str444     fn method(&self) -> &str {
445         unsafe {
446             let raw = (*self.info).request;
447             self.get_str(&(*raw).method, (*raw).method_length.into())
448         }
449     }
450 
tls(&self) -> bool451     fn tls(&self) -> bool {
452         unsafe { (*(*self.info).request).tls != 0 }
453     }
454 
version(&self) -> &str455     fn version(&self) -> &str {
456         unsafe {
457             let raw = (*self.info).request;
458             self.get_str(&(*raw).version, (*raw).version_length.into())
459         }
460     }
461 
server_name(&self) -> &str462     fn server_name(&self) -> &str {
463         unsafe {
464             let raw = (*self.info).request;
465             self.get_str(&(*raw).server_name, (*raw).server_name_length.into())
466         }
467     }
468 
target(&self) -> &str469     fn target(&self) -> &str {
470         unsafe {
471             let raw = (*self.info).request;
472             self.get_str(&(*raw).target, (*raw).target_length.into())
473         }
474     }
475 
content_length(&self) -> u64476     fn content_length(&self) -> u64 {
477         unsafe {
478             let raw_request = (*self.info).request;
479             (*raw_request).content_length
480         }
481     }
482 
fields(&self) -> impl Iterator<Item = (&str, &str)>483     fn fields(&self) -> impl Iterator<Item = (&str, &str)> {
484         unsafe {
485             let raw = (*self.info).request;
486             (0..(*raw).fields_count).map(move |i| {
487                 let field = (*raw).fields.as_ptr().add(i as usize);
488                 let name =
489                     self.get_str(&(*field).name, (*field).name_length.into());
490                 let value =
491                     self.get_str(&(*field).value, (*field).value_length.into());
492                 (name, value)
493             })
494         }
495     }
496 
request_read(&mut self, dst: &mut BytesMut)497     fn request_read(&mut self, dst: &mut BytesMut) {
498         unsafe {
499             let rest = dst.spare_capacity_mut();
500             let mut total_bytes_read = 0;
501             loop {
502                 let amt = bindings::nxt_unit_request_read(
503                     self.info,
504                     rest.as_mut_ptr().wrapping_add(total_bytes_read).cast(),
505                     32 * 1024 * 1024,
506                 );
507                 total_bytes_read += amt as usize;
508                 if total_bytes_read >= rest.len() {
509                     break;
510                 }
511             }
512             // TODO: handle failure when `amt` is negative
513             let total_bytes_read: usize = total_bytes_read.try_into().unwrap();
514             dst.set_len(dst.len() + total_bytes_read);
515         }
516     }
517 
response_write(&mut self, data: &[u8])518     fn response_write(&mut self, data: &[u8]) {
519         unsafe {
520             let rc = bindings::nxt_unit_response_write(
521                 self.info,
522                 data.as_ptr().cast(),
523                 data.len(),
524             );
525             assert_eq!(rc, 0);
526         }
527     }
528 
init_response(&mut self, status: u16, headers: u32, headers_size: u32)529     fn init_response(&mut self, status: u16, headers: u32, headers_size: u32) {
530         unsafe {
531             let rc = bindings::nxt_unit_response_init(
532                 self.info,
533                 status,
534                 headers,
535                 headers_size,
536             );
537             assert_eq!(rc, 0);
538         }
539     }
540 
add_field(&mut self, key: &[u8], val: &[u8])541     fn add_field(&mut self, key: &[u8], val: &[u8]) {
542         unsafe {
543             let rc = bindings::nxt_unit_response_add_field(
544                 self.info,
545                 key.as_ptr().cast(),
546                 key.len().try_into().unwrap(),
547                 val.as_ptr().cast(),
548                 val.len().try_into().unwrap(),
549             );
550             assert_eq!(rc, 0);
551         }
552     }
553 
send_response(&mut self)554     fn send_response(&mut self) {
555         unsafe {
556             let rc = bindings::nxt_unit_response_send(self.info);
557             assert_eq!(rc, 0);
558         }
559     }
560 
request_done(self)561     fn request_done(self) {
562         unsafe {
563             bindings::nxt_unit_request_done(
564                 self.info,
565                 bindings::NXT_UNIT_OK as i32,
566             );
567         }
568     }
569 
get_str( &self, ptr: &bindings::nxt_unit_sptr_t, len: u32, ) -> &str570     unsafe fn get_str(
571         &self,
572         ptr: &bindings::nxt_unit_sptr_t,
573         len: u32,
574     ) -> &str {
575         let ptr = bindings::nxt_unit_sptr_get(ptr);
576         let slice = std::slice::from_raw_parts(ptr, len.try_into().unwrap());
577         std::str::from_utf8(slice).unwrap()
578     }
579 }
580 
581 struct StoreState {
582     ctx: WasiCtx,
583     http: WasiHttpCtx,
584     table: ResourceTable,
585 }
586 
587 impl WasiView for StoreState {
table(&self) -> &ResourceTable588     fn table(&self) -> &ResourceTable {
589         &self.table
590     }
table_mut(&mut self) -> &mut ResourceTable591     fn table_mut(&mut self) -> &mut ResourceTable {
592         &mut self.table
593     }
ctx(&self) -> &WasiCtx594     fn ctx(&self) -> &WasiCtx {
595         &self.ctx
596     }
ctx_mut(&mut self) -> &mut WasiCtx597     fn ctx_mut(&mut self) -> &mut WasiCtx {
598         &mut self.ctx
599     }
600 }
601 
602 impl WasiHttpView for StoreState {
ctx(&mut self) -> &mut WasiHttpCtx603     fn ctx(&mut self) -> &mut WasiHttpCtx {
604         &mut self.http
605     }
table(&mut self) -> &mut ResourceTable606     fn table(&mut self) -> &mut ResourceTable {
607         &mut self.table
608     }
609 }
610 
611 impl StoreState {}
612