Commit 96868b69 authored by Rémi Denis-Courmont's avatar Rémi Denis-Courmont

tls: use I/O vector for sending

parent 95ba2e83
...@@ -43,7 +43,7 @@ struct vlc_tls ...@@ -43,7 +43,7 @@ struct vlc_tls
int fd; int fd;
ssize_t (*recv)(struct vlc_tls *, void *, size_t); ssize_t (*recv)(struct vlc_tls *, void *, size_t);
ssize_t (*send)(struct vlc_tls *, const void *, size_t); ssize_t (*writev)(struct vlc_tls *, const struct iovec *, unsigned);
int (*shutdown)(struct vlc_tls *, bool duplex); int (*shutdown)(struct vlc_tls *, bool duplex);
void (*close)(vlc_tls_t *); void (*close)(vlc_tls_t *);
}; };
......
...@@ -190,20 +190,25 @@ static void vlc_h2_output_flush_unlocked(struct vlc_h2_output *out) ...@@ -190,20 +190,25 @@ static void vlc_h2_output_flush_unlocked(struct vlc_h2_output *out)
static ssize_t vlc_https_send(vlc_tls_t *tls, const void *buf, size_t len) static ssize_t vlc_https_send(vlc_tls_t *tls, const void *buf, size_t len)
{ {
struct pollfd ufd; struct pollfd ufd;
struct iovec iov;
size_t count = 0; size_t count = 0;
ufd.fd = tls->fd; ufd.fd = tls->fd;
ufd.events = POLLOUT; ufd.events = POLLOUT;
iov.iov_base = (void *)buf;
iov.iov_len = len;
while (count < len) while (count < len)
{ {
int canc = vlc_savecancel(); int canc = vlc_savecancel();
ssize_t val = tls->send(tls, (char *)buf + count, len - count); ssize_t val = tls->writev(tls, &iov, 1);
vlc_restorecancel(canc); vlc_restorecancel(canc);
if (val > 0) if (val > 0)
{ {
iov.iov_base = (char *)iov.iov_base + val;
iov.iov_len -= val;
count += val; count += val;
continue; continue;
} }
......
...@@ -40,11 +40,14 @@ static bool send_failure = false; ...@@ -40,11 +40,14 @@ static bool send_failure = false;
static bool expect_hello = true; static bool expect_hello = true;
static vlc_sem_t rx; static vlc_sem_t rx;
static ssize_t send_callback(vlc_tls_t *tls, const void *buf, size_t len) static ssize_t send_callback(vlc_tls_t *tls, const struct iovec *iov,
unsigned count)
{ {
const uint8_t *p = buf; assert(count == 1);
assert(tls->writev == send_callback);
assert(tls->send == send_callback); const uint8_t *p = iov->iov_base;
size_t len = iov->iov_len;
if (expect_hello) if (expect_hello)
{ {
...@@ -71,7 +74,7 @@ static ssize_t send_callback(vlc_tls_t *tls, const void *buf, size_t len) ...@@ -71,7 +74,7 @@ static ssize_t send_callback(vlc_tls_t *tls, const void *buf, size_t len)
static vlc_tls_t fake_tls = static vlc_tls_t fake_tls =
{ {
.send = send_callback, .writev = send_callback,
}; };
static struct vlc_h2_frame *frame(unsigned char c) static struct vlc_h2_frame *frame(unsigned char c)
......
...@@ -159,11 +159,28 @@ static ssize_t vlc_gnutls_writev (gnutls_transport_ptr_t ptr, ...@@ -159,11 +159,28 @@ static ssize_t vlc_gnutls_writev (gnutls_transport_ptr_t ptr,
} }
#endif #endif
static ssize_t gnutls_Send (vlc_tls_t *tls, const void *buf, size_t length) static ssize_t gnutls_Send (vlc_tls_t *tls, const struct iovec *iov,
unsigned count)
{ {
gnutls_session_t session = tls->sys; gnutls_session_t session = tls->sys;
ssize_t val = gnutls_record_send (session, buf, length); ssize_t val;
if (!gnutls_record_check_corked(session))
{
gnutls_record_cork(session);
while (count > 0)
{
val = gnutls_record_send(session, iov->iov_base, iov->iov_len);
if (val < (ssize_t)iov->iov_len)
break;
iov++;
count--;
}
}
val = gnutls_record_uncork(session, 0);
return (val < 0) ? gnutls_Error (tls, val) : val; return (val < 0) ? gnutls_Error (tls, val) : val;
} }
...@@ -178,9 +195,18 @@ static ssize_t gnutls_Recv (vlc_tls_t *tls, void *buf, size_t length) ...@@ -178,9 +195,18 @@ static ssize_t gnutls_Recv (vlc_tls_t *tls, void *buf, size_t length)
static int gnutls_Shutdown(vlc_tls_t *tls, bool duplex) static int gnutls_Shutdown(vlc_tls_t *tls, bool duplex)
{ {
gnutls_session_t session = tls->sys; gnutls_session_t session = tls->sys;
int val = gnutls_bye(session, duplex ? GNUTLS_SHUT_RDWR : GNUTLS_SHUT_WR); ssize_t val;
/* Flush any pending data */
val = gnutls_record_uncork(session, 0);
if (val < 0)
return gnutls_Error(tls, val);
return (val < 0) ? gnutls_Error(tls, val) : 0; val = gnutls_bye(session, duplex ? GNUTLS_SHUT_RDWR : GNUTLS_SHUT_WR);
if (val < 0)
return gnutls_Error(tls, val);
return 0;
} }
static void gnutls_Close (vlc_tls_t *tls) static void gnutls_Close (vlc_tls_t *tls)
...@@ -256,7 +282,7 @@ static int gnutls_SessionOpen(vlc_tls_creds_t *creds, vlc_tls_t *tls, int type, ...@@ -256,7 +282,7 @@ static int gnutls_SessionOpen(vlc_tls_creds_t *creds, vlc_tls_t *tls, int type,
gnutls_transport_set_vec_push_function (session, vlc_gnutls_writev); gnutls_transport_set_vec_push_function (session, vlc_gnutls_writev);
#endif #endif
tls->sys = session; tls->sys = session;
tls->send = gnutls_Send; tls->writev = gnutls_Send;
tls->recv = gnutls_Recv; tls->recv = gnutls_Recv;
tls->shutdown = gnutls_Shutdown; tls->shutdown = gnutls_Shutdown;
tls->close = gnutls_Close; tls->close = gnutls_Close;
......
...@@ -427,12 +427,16 @@ static int st_Handshake (vlc_tls_creds_t *crd, vlc_tls_t *session, ...@@ -427,12 +427,16 @@ static int st_Handshake (vlc_tls_creds_t *crd, vlc_tls_t *session,
/** /**
* Sends data through a TLS session. * Sends data through a TLS session.
*/ */
static ssize_t st_Send (vlc_tls_t *session, const void *buf, size_t length) static ssize_t st_Send (vlc_tls_t *session, const struct iovec *iov,
unsigned count)
{ {
vlc_tls_sys_t *sys = session->sys; vlc_tls_sys_t *sys = session->sys;
assert(sys); assert(sys);
OSStatus ret = noErr; OSStatus ret = noErr;
if (unlikely(count == 0))
return 0;
/* /*
* SSLWrite does not return the number of bytes actually written to * SSLWrite does not return the number of bytes actually written to
* the socket, but the number of bytes written to the internal cache. * the socket, but the number of bytes written to the internal cache.
...@@ -466,7 +470,8 @@ static ssize_t st_Send (vlc_tls_t *session, const void *buf, size_t length) ...@@ -466,7 +470,8 @@ static ssize_t st_Send (vlc_tls_t *session, const void *buf, size_t length)
} }
} else { } else {
ret = SSLWrite(sys->p_context, buf, length, &actualSize); ret = SSLWrite(sys->p_context, iov->iov_base, iov->iov_len,
&actualSize);
if (ret == errSSLWouldBlock) { if (ret == errSSLWouldBlock) {
sys->i_send_buffered_bytes = length; sys->i_send_buffered_bytes = length;
...@@ -560,7 +565,7 @@ static int st_SessionOpenCommon (vlc_tls_creds_t *crd, vlc_tls_t *session, ...@@ -560,7 +565,7 @@ static int st_SessionOpenCommon (vlc_tls_creds_t *crd, vlc_tls_t *session,
sys->p_context = NULL; sys->p_context = NULL;
session->sys = sys; session->sys = sys;
session->send = st_Send; session->writev = st_Send;
session->recv = st_Recv; session->recv = st_Recv;
session->shutdown = st_SessionShutdown; session->shutdown = st_SessionShutdown;
session->close = st_SessionClose; session->close = st_SessionClose;
......
...@@ -253,9 +253,12 @@ ssize_t vlc_tls_Read(vlc_tls_t *session, void *buf, size_t len, bool waitall) ...@@ -253,9 +253,12 @@ ssize_t vlc_tls_Read(vlc_tls_t *session, void *buf, size_t len, bool waitall)
ssize_t vlc_tls_Write(vlc_tls_t *session, const void *buf, size_t len) ssize_t vlc_tls_Write(vlc_tls_t *session, const void *buf, size_t len)
{ {
struct pollfd ufd; struct pollfd ufd;
struct iovec iov;
ufd.fd = session->fd; ufd.fd = session->fd;
ufd.events = POLLOUT; ufd.events = POLLOUT;
iov.iov_base = (void *)buf;
iov.iov_len = len;
for (size_t sent = 0;;) for (size_t sent = 0;;)
{ {
...@@ -265,14 +268,14 @@ ssize_t vlc_tls_Write(vlc_tls_t *session, const void *buf, size_t len) ...@@ -265,14 +268,14 @@ ssize_t vlc_tls_Write(vlc_tls_t *session, const void *buf, size_t len)
return -1; return -1;
} }
ssize_t val = session->send(session, buf, len); ssize_t val = session->writev(session, &iov, 1);
if (val > 0) if (val > 0)
{ {
buf = ((const char *)buf) + val; iov.iov_base = ((char *)iov.iov_base) + val;
len -= val; iov.iov_len -= val;
sent += val; sent += val;
} }
if (len == 0 || val == 0) if (iov.iov_len == 0 || val == 0)
return sent; return sent;
if (val == -1 && errno != EINTR && errno != EAGAIN) if (val == -1 && errno != EINTR && errno != EAGAIN)
return sent ? (ssize_t)sent : -1; return sent ? (ssize_t)sent : -1;
...@@ -317,9 +320,15 @@ static ssize_t vlc_tls_DummyReceive(vlc_tls_t *tls, void *buf, size_t len) ...@@ -317,9 +320,15 @@ static ssize_t vlc_tls_DummyReceive(vlc_tls_t *tls, void *buf, size_t len)
return recv(tls->fd, buf, len, 0); return recv(tls->fd, buf, len, 0);
} }
static ssize_t vlc_tls_DummySend(vlc_tls_t *tls, const void *buf, size_t len) static ssize_t vlc_tls_DummySend(vlc_tls_t *tls, const struct iovec *iov,
unsigned count)
{ {
return send(tls->fd, buf, len, MSG_NOSIGNAL); const struct msghdr msg =
{
.msg_iov = (struct iovec *)iov,
.msg_iovlen = count,
};
return sendmsg(tls->fd, &msg, MSG_NOSIGNAL);
} }
static int vlc_tls_DummyShutdown(vlc_tls_t *tls, bool duplex) static int vlc_tls_DummyShutdown(vlc_tls_t *tls, bool duplex)
...@@ -341,7 +350,7 @@ vlc_tls_t *vlc_tls_DummyCreate(vlc_object_t *obj, int fd) ...@@ -341,7 +350,7 @@ vlc_tls_t *vlc_tls_DummyCreate(vlc_object_t *obj, int fd)
session->obj = obj; session->obj = obj;
session->fd = fd; session->fd = fd;
session->recv = vlc_tls_DummyReceive; session->recv = vlc_tls_DummyReceive;
session->send = vlc_tls_DummySend; session->writev = vlc_tls_DummySend;
session->shutdown = vlc_tls_DummyShutdown; session->shutdown = vlc_tls_DummyShutdown;
session->close = vlc_tls_DummyClose; session->close = vlc_tls_DummyClose;
return session; return session;
......
...@@ -197,6 +197,7 @@ int main(void) ...@@ -197,6 +197,7 @@ int main(void)
/* Do some I/O */ /* Do some I/O */
char buf[12]; char buf[12];
struct iovec iov;
val = tls->recv(tls, buf, sizeof (buf)); val = tls->recv(tls, buf, sizeof (buf));
assert(val == -1 && errno == EAGAIN); assert(val == -1 && errno == EAGAIN);
...@@ -228,13 +229,16 @@ int main(void) ...@@ -228,13 +229,16 @@ int main(void)
size_t bytes = 0; size_t bytes = 0;
unsigned seed = 0; unsigned seed = 0;
iov.iov_base = data;
iov.iov_len = sizeof (data);
do do
{ {
for (size_t i = 0; i < sizeof (data); i++) for (size_t i = 0; i < sizeof (data); i++)
data[i] = rand_r(&seed); data[i] = rand_r(&seed);
bytes += sizeof (data); bytes += sizeof (data);
} }
while ((val = tls->send(tls, data, sizeof (data))) == sizeof (data)); while ((val = tls->writev(tls, &iov, 1)) == sizeof (data));
bytes -= sizeof (data); bytes -= sizeof (data);
if (val > 0) if (val > 0)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment