Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions tentacle/src/quic/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#![allow(missing_docs)]

#[allow(missing_docs)]
pub mod identity_mol;

Expand All @@ -13,3 +15,5 @@ pub mod error;
#[allow(missing_docs)]
/// Verifier for rustls
pub mod verifier;

pub mod stream;
43 changes: 43 additions & 0 deletions tentacle/src/quic/stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
use std::pin::Pin;

use tokio::io::{AsyncRead, AsyncWrite};

#[derive(Debug)]
pub struct QuicBiStream {
pub(crate) send: quinn::SendStream,
pub(crate) recv: quinn::RecvStream,
}

impl AsyncRead for QuicBiStream {
fn poll_read(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
Pin::new(&mut self.recv).poll_read(cx, buf)
}
}

impl AsyncWrite for QuicBiStream {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<std::io::Result<usize>> {
AsyncWrite::poll_write(Pin::new(&mut self.send), cx, buf)
}

fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
Pin::new(&mut self.send).poll_flush(cx)
}

fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
Pin::new(&mut self.send).poll_shutdown(cx)
}
}
12 changes: 6 additions & 6 deletions tentacle/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use yamux::{Control, Session as YamuxSession, StreamHandle};
use crate::{
ProtocolId, SessionId, StreamId, SubstreamReadPart,
buffer::{Buffer, PriorityBuffer, SendResult},
channel::{QuickSinkExt, mpsc as priority_mpsc, mpsc::Priority},
channel::{QuickSinkExt, mpsc::{self as priority_mpsc, Priority}},
context::SessionContext,
error::{HandshakeErrorKind, ProtocolHandleErrorKind, TransportErrorKind},
multiaddr::Multiaddr,
Expand All @@ -28,7 +28,7 @@ use crate::{
config::{Meta, SessionConfig},
future_task::BoxedFutureTask,
},
substream::{ProtocolEvent, SubstreamBuilder, SubstreamWritePartBuilder},
substream::{ProtocolEvent, SubstreamBuilder, SubstreamInner, SubstreamWritePartBuilder},
transports::MultiIncoming,
};

Expand Down Expand Up @@ -260,7 +260,7 @@ impl Session {
procedure: impl Future<
Output = Result<
(
Framed<StreamHandle, LengthDelimitedCodec>,
Framed<SubstreamInner, LengthDelimitedCodec>,
String,
Option<String>,
),
Expand Down Expand Up @@ -325,7 +325,7 @@ impl Session {

let task = async move {
let handle = match control.open_stream().await {
Ok(handle) => handle,
Ok(handle) => SubstreamInner::Yamux(handle),
Err(e) => {
debug!("session {} open stream error: {}", id, e);
return Err(io::ErrorKind::BrokenPipe.into());
Expand Down Expand Up @@ -398,7 +398,7 @@ impl Session {
})
.collect();

let task = server_select(substream, proto_metas);
let task = server_select(SubstreamInner::Yamux(substream), proto_metas);
self.select_procedure(task);
}

Expand All @@ -407,7 +407,7 @@ impl Session {
cx: &mut Context,
name: String,
version: String,
substream: Box<Framed<StreamHandle, LengthDelimitedCodec>>,
substream: Box<Framed<SubstreamInner, LengthDelimitedCodec>>,
) {
let proto = match self.protocol_configs_by_name.get(&name) {
Some(proto) => proto,
Expand Down
66 changes: 58 additions & 8 deletions tentacle/src/substream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use std::{
sync::{Arc, atomic::Ordering},
task::{Context, Poll},
};
use tokio::io::AsyncWrite;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_util::codec::{Framed, length_delimited::LengthDelimitedCodec};

use crate::{
Expand All @@ -22,9 +22,59 @@ use crate::{
protocol_handle_stream::{ServiceProtocolEvent, SessionProtocolEvent},
service::config::SessionConfig,
traits::Codec,
yamux::StreamHandle,
};

#[derive(Debug)]
pub(crate) enum SubstreamInner {
Yamux(yamux::StreamHandle),
#[cfg(feature = "quic")]
Quic(crate::quic::stream::QuicBiStream),
}

impl AsyncRead for SubstreamInner {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Yamux(s) => Pin::new(s).poll_read(cx, buf),
#[cfg(feature = "quic")]
Self::Quic(s) => Pin::new(s).poll_read(cx, buf),
}
}
}

impl AsyncWrite for SubstreamInner {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match self.get_mut() {
Self::Yamux(s) => Pin::new(s).poll_write(cx, buf),
#[cfg(feature = "quic")]
Self::Quic(s) => Pin::new(s).poll_write(cx, buf),
}
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Yamux(s) => Pin::new(s).poll_flush(cx),
#[cfg(feature = "quic")]
Self::Quic(s) => Pin::new(s).poll_flush(cx),
}
}

fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Yamux(s) => Pin::new(s).poll_shutdown(cx),
#[cfg(feature = "quic")]
Self::Quic(s) => Pin::new(s).poll_shutdown(cx),
}
}
}

/// Event generated/received by the protocol stream
#[derive(Debug)]
pub(crate) enum ProtocolEvent {
Expand All @@ -33,7 +83,7 @@ pub(crate) enum ProtocolEvent {
/// Protocol name
proto_name: String,
/// Yamux sub stream handle handshake framed
substream: Box<Framed<StreamHandle, LengthDelimitedCodec>>,
substream: Box<Framed<SubstreamInner, LengthDelimitedCodec>>,
/// Protocol version
version: String,
},
Expand Down Expand Up @@ -65,7 +115,7 @@ pub(crate) enum ProtocolEvent {
/// Each custom protocol in a session corresponds to a sub stream
/// Can be seen as the route of each protocol
pub(crate) struct Substream<U> {
substream: Framed<StreamHandle, U>,
substream: Framed<SubstreamInner, U>,
id: StreamId,
proto_id: ProtocolId,

Expand Down Expand Up @@ -562,7 +612,7 @@ impl SubstreamBuilder {
self
}

pub fn build<U>(self, substream: Framed<StreamHandle, U>) -> Substream<U>
pub fn build<U>(self, substream: Framed<SubstreamInner, U>) -> Substream<U>
where
U: Codec,
{
Expand Down Expand Up @@ -592,7 +642,7 @@ impl SubstreamBuilder {
/* Code organization under read-write separation */

pub(crate) struct SubstreamWritePart<U> {
substream: SplitSink<Framed<StreamHandle, U>, bytes::Bytes>,
substream: SplitSink<Framed<SubstreamInner, U>, bytes::Bytes>,
id: StreamId,
proto_id: ProtocolId,

Expand Down Expand Up @@ -863,7 +913,7 @@ where

/// Protocol Stream read part
pub struct SubstreamReadPart {
pub(crate) substream: SplitStream<Framed<StreamHandle, Box<dyn Codec + Send + 'static>>>,
pub(crate) substream: SplitStream<Framed<SubstreamInner, Box<dyn Codec + Send + 'static>>>,
pub(crate) before_receive: Option<BeforeReceive>,
pub(crate) proto_id: ProtocolId,
pub(crate) stream_id: StreamId,
Expand Down Expand Up @@ -966,7 +1016,7 @@ impl SubstreamWritePartBuilder {

pub fn build<U>(
self,
substream: SplitSink<Framed<StreamHandle, U>, bytes::Bytes>,
substream: SplitSink<Framed<SubstreamInner, U>, bytes::Bytes>,
) -> SubstreamWritePart<U>
where
U: Codec,
Expand Down
Loading