diff --git a/tentacle/src/quic/mod.rs b/tentacle/src/quic/mod.rs index a4212070..04e4c213 100644 --- a/tentacle/src/quic/mod.rs +++ b/tentacle/src/quic/mod.rs @@ -1,3 +1,5 @@ +#![allow(missing_docs)] + #[allow(missing_docs)] pub mod identity_mol; @@ -13,3 +15,5 @@ pub mod error; #[allow(missing_docs)] /// Verifier for rustls pub mod verifier; + +pub mod stream; diff --git a/tentacle/src/quic/stream.rs b/tentacle/src/quic/stream.rs new file mode 100644 index 00000000..366e3a04 --- /dev/null +++ b/tentacle/src/quic/stream.rs @@ -0,0 +1,43 @@ +use std::pin::Pin; + +use tokio::io::{AsyncRead, AsyncWrite}; + +#[derive(Debug)] +pub struct QuicBiStream { + pub(crate) send: quinn::SendStream, + pub(crate) recv: quinn::RecvStream, +} + +impl AsyncRead for QuicBiStream { + fn poll_read( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + Pin::new(&mut self.recv).poll_read(cx, buf) + } +} + +impl AsyncWrite for QuicBiStream { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + AsyncWrite::poll_write(Pin::new(&mut self.send), cx, buf) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + Pin::new(&mut self.send).poll_flush(cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + Pin::new(&mut self.send).poll_shutdown(cx) + } +} diff --git a/tentacle/src/session.rs b/tentacle/src/session.rs index 5d275f17..63697f72 100644 --- a/tentacle/src/session.rs +++ b/tentacle/src/session.rs @@ -16,7 +16,7 @@ use yamux::{Control, Session as YamuxSession, StreamHandle}; use crate::{ ProtocolId, SessionId, StreamId, SubstreamReadPart, buffer::{Buffer, PriorityBuffer, SendResult}, - channel::{QuickSinkExt, mpsc as priority_mpsc, mpsc::Priority}, + channel::{QuickSinkExt, mpsc::{self as priority_mpsc, Priority}}, context::SessionContext, error::{HandshakeErrorKind, ProtocolHandleErrorKind, TransportErrorKind}, multiaddr::Multiaddr, @@ -28,7 +28,7 @@ use crate::{ config::{Meta, SessionConfig}, future_task::BoxedFutureTask, }, - substream::{ProtocolEvent, SubstreamBuilder, SubstreamWritePartBuilder}, + substream::{ProtocolEvent, SubstreamBuilder, SubstreamInner, SubstreamWritePartBuilder}, transports::MultiIncoming, }; @@ -260,7 +260,7 @@ impl Session { procedure: impl Future< Output = Result< ( - Framed, + Framed, String, Option, ), @@ -325,7 +325,7 @@ impl Session { let task = async move { let handle = match control.open_stream().await { - Ok(handle) => handle, + Ok(handle) => SubstreamInner::Yamux(handle), Err(e) => { debug!("session {} open stream error: {}", id, e); return Err(io::ErrorKind::BrokenPipe.into()); @@ -398,7 +398,7 @@ impl Session { }) .collect(); - let task = server_select(substream, proto_metas); + let task = server_select(SubstreamInner::Yamux(substream), proto_metas); self.select_procedure(task); } @@ -407,7 +407,7 @@ impl Session { cx: &mut Context, name: String, version: String, - substream: Box>, + substream: Box>, ) { let proto = match self.protocol_configs_by_name.get(&name) { Some(proto) => proto, diff --git a/tentacle/src/substream.rs b/tentacle/src/substream.rs index 391e366e..c85c6e8e 100644 --- a/tentacle/src/substream.rs +++ b/tentacle/src/substream.rs @@ -10,7 +10,7 @@ use std::{ sync::{Arc, atomic::Ordering}, task::{Context, Poll}, }; -use tokio::io::AsyncWrite; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio_util::codec::{Framed, length_delimited::LengthDelimitedCodec}; use crate::{ @@ -22,9 +22,59 @@ use crate::{ protocol_handle_stream::{ServiceProtocolEvent, SessionProtocolEvent}, service::config::SessionConfig, traits::Codec, - yamux::StreamHandle, }; +#[derive(Debug)] +pub(crate) enum SubstreamInner { + Yamux(yamux::StreamHandle), + #[cfg(feature = "quic")] + Quic(crate::quic::stream::QuicBiStream), +} + +impl AsyncRead for SubstreamInner { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match self.get_mut() { + Self::Yamux(s) => Pin::new(s).poll_read(cx, buf), + #[cfg(feature = "quic")] + Self::Quic(s) => Pin::new(s).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for SubstreamInner { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match self.get_mut() { + Self::Yamux(s) => Pin::new(s).poll_write(cx, buf), + #[cfg(feature = "quic")] + Self::Quic(s) => Pin::new(s).poll_write(cx, buf), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + Self::Yamux(s) => Pin::new(s).poll_flush(cx), + #[cfg(feature = "quic")] + Self::Quic(s) => Pin::new(s).poll_flush(cx), + } + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + Self::Yamux(s) => Pin::new(s).poll_shutdown(cx), + #[cfg(feature = "quic")] + Self::Quic(s) => Pin::new(s).poll_shutdown(cx), + } + } +} + /// Event generated/received by the protocol stream #[derive(Debug)] pub(crate) enum ProtocolEvent { @@ -33,7 +83,7 @@ pub(crate) enum ProtocolEvent { /// Protocol name proto_name: String, /// Yamux sub stream handle handshake framed - substream: Box>, + substream: Box>, /// Protocol version version: String, }, @@ -65,7 +115,7 @@ pub(crate) enum ProtocolEvent { /// Each custom protocol in a session corresponds to a sub stream /// Can be seen as the route of each protocol pub(crate) struct Substream { - substream: Framed, + substream: Framed, id: StreamId, proto_id: ProtocolId, @@ -562,7 +612,7 @@ impl SubstreamBuilder { self } - pub fn build(self, substream: Framed) -> Substream + pub fn build(self, substream: Framed) -> Substream where U: Codec, { @@ -592,7 +642,7 @@ impl SubstreamBuilder { /* Code organization under read-write separation */ pub(crate) struct SubstreamWritePart { - substream: SplitSink, bytes::Bytes>, + substream: SplitSink, bytes::Bytes>, id: StreamId, proto_id: ProtocolId, @@ -863,7 +913,7 @@ where /// Protocol Stream read part pub struct SubstreamReadPart { - pub(crate) substream: SplitStream>>, + pub(crate) substream: SplitStream>>, pub(crate) before_receive: Option, pub(crate) proto_id: ProtocolId, pub(crate) stream_id: StreamId, @@ -966,7 +1016,7 @@ impl SubstreamWritePartBuilder { pub fn build( self, - substream: SplitSink, bytes::Bytes>, + substream: SplitSink, bytes::Bytes>, ) -> SubstreamWritePart where U: Codec,