aboutsummaryrefslogtreecommitdiff
path: root/drivers/vhost
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/vhost')
-rw-r--r--drivers/vhost/net.c91
-rw-r--r--drivers/vhost/test.c5
-rw-r--r--drivers/vhost/vhost.c213
-rw-r--r--drivers/vhost/vhost.h34
4 files changed, 294 insertions, 49 deletions
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index e224a92baa1..882a51fe7b3 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -12,6 +12,7 @@
#include <linux/virtio_net.h>
#include <linux/miscdevice.h>
#include <linux/module.h>
+#include <linux/moduleparam.h>
#include <linux/mutex.h>
#include <linux/workqueue.h>
#include <linux/rcupdate.h>
@@ -28,10 +29,18 @@
#include "vhost.h"
+static int experimental_zcopytx;
+module_param(experimental_zcopytx, int, 0444);
+MODULE_PARM_DESC(experimental_zcopytx, "Enable Experimental Zero Copy TX");
+
/* Max number of bytes transferred before requeueing the job.
* Using this limit prevents one virtqueue from starving others. */
#define VHOST_NET_WEIGHT 0x80000
+/* MAX number of TX used buffers for outstanding zerocopy */
+#define VHOST_MAX_PEND 128
+#define VHOST_GOODCOPY_LEN 256
+
enum {
VHOST_NET_VQ_RX = 0,
VHOST_NET_VQ_TX = 1,
@@ -54,6 +63,12 @@ struct vhost_net {
enum vhost_net_poll_state tx_poll_state;
};
+static bool vhost_sock_zcopy(struct socket *sock)
+{
+ return unlikely(experimental_zcopytx) &&
+ sock_flag(sock->sk, SOCK_ZEROCOPY);
+}
+
/* Pop first len bytes from iovec. Return number of segments used. */
static int move_iovec_hdr(struct iovec *from, struct iovec *to,
size_t len, int iov_count)
@@ -129,6 +144,8 @@ static void handle_tx(struct vhost_net *net)
int err, wmem;
size_t hdr_size;
struct socket *sock;
+ struct vhost_ubuf_ref *uninitialized_var(ubufs);
+ bool zcopy;
/* TODO: check that we are running from vhost_worker? */
sock = rcu_dereference_check(vq->private_data, 1);
@@ -149,8 +166,13 @@ static void handle_tx(struct vhost_net *net)
if (wmem < sock->sk->sk_sndbuf / 2)
tx_poll_stop(net);
hdr_size = vq->vhost_hlen;
+ zcopy = vhost_sock_zcopy(sock);
for (;;) {
+ /* Release DMAs done buffers first */
+ if (zcopy)
+ vhost_zerocopy_signal_used(vq);
+
head = vhost_get_vq_desc(&net->dev, vq, vq->iov,
ARRAY_SIZE(vq->iov),
&out, &in,
@@ -160,12 +182,25 @@ static void handle_tx(struct vhost_net *net)
break;
/* Nothing new? Wait for eventfd to tell us they refilled. */
if (head == vq->num) {
+ int num_pends;
+
wmem = atomic_read(&sock->sk->sk_wmem_alloc);
if (wmem >= sock->sk->sk_sndbuf * 3 / 4) {
tx_poll_start(net, sock);
set_bit(SOCK_ASYNC_NOSPACE, &sock->flags);
break;
}
+ /* If more outstanding DMAs, queue the work.
+ * Handle upend_idx wrap around
+ */
+ num_pends = likely(vq->upend_idx >= vq->done_idx) ?
+ (vq->upend_idx - vq->done_idx) :
+ (vq->upend_idx + UIO_MAXIOV - vq->done_idx);
+ if (unlikely(num_pends > VHOST_MAX_PEND)) {
+ tx_poll_start(net, sock);
+ set_bit(SOCK_ASYNC_NOSPACE, &sock->flags);
+ break;
+ }
if (unlikely(vhost_enable_notify(&net->dev, vq))) {
vhost_disable_notify(&net->dev, vq);
continue;
@@ -188,9 +223,39 @@ static void handle_tx(struct vhost_net *net)
iov_length(vq->hdr, s), hdr_size);
break;
}
+ /* use msg_control to pass vhost zerocopy ubuf info to skb */
+ if (zcopy) {
+ vq->heads[vq->upend_idx].id = head;
+ if (len < VHOST_GOODCOPY_LEN) {
+ /* copy don't need to wait for DMA done */
+ vq->heads[vq->upend_idx].len =
+ VHOST_DMA_DONE_LEN;
+ msg.msg_control = NULL;
+ msg.msg_controllen = 0;
+ ubufs = NULL;
+ } else {
+ struct ubuf_info *ubuf = &vq->ubuf_info[head];
+
+ vq->heads[vq->upend_idx].len = len;
+ ubuf->callback = vhost_zerocopy_callback;
+ ubuf->arg = vq->ubufs;
+ ubuf->desc = vq->upend_idx;
+ msg.msg_control = ubuf;
+ msg.msg_controllen = sizeof(ubuf);
+ ubufs = vq->ubufs;
+ kref_get(&ubufs->kref);
+ }
+ vq->upend_idx = (vq->upend_idx + 1) % UIO_MAXIOV;
+ }
/* TODO: Check specific error and bomb out unless ENOBUFS? */
err = sock->ops->sendmsg(NULL, sock, &msg, len);
if (unlikely(err < 0)) {
+ if (zcopy) {
+ if (ubufs)
+ vhost_ubuf_put(ubufs);
+ vq->upend_idx = ((unsigned)vq->upend_idx - 1) %
+ UIO_MAXIOV;
+ }
vhost_discard_vq_desc(vq, 1);
tx_poll_start(net, sock);
break;
@@ -198,7 +263,8 @@ static void handle_tx(struct vhost_net *net)
if (err != len)
pr_debug("Truncated TX packet: "
" len %d != %zd\n", err, len);
- vhost_add_used_and_signal(&net->dev, vq, head, 0);
+ if (!zcopy)
+ vhost_add_used_and_signal(&net->dev, vq, head, 0);
total_len += len;
if (unlikely(total_len >= VHOST_NET_WEIGHT)) {
vhost_poll_queue(&vq->poll);
@@ -603,6 +669,7 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
{
struct socket *sock, *oldsock;
struct vhost_virtqueue *vq;
+ struct vhost_ubuf_ref *ubufs, *oldubufs = NULL;
int r;
mutex_lock(&n->dev.mutex);
@@ -632,13 +699,31 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
oldsock = rcu_dereference_protected(vq->private_data,
lockdep_is_held(&vq->mutex));
if (sock != oldsock) {
+ ubufs = vhost_ubuf_alloc(vq, sock && vhost_sock_zcopy(sock));
+ if (IS_ERR(ubufs)) {
+ r = PTR_ERR(ubufs);
+ goto err_ubufs;
+ }
+ oldubufs = vq->ubufs;
+ vq->ubufs = ubufs;
vhost_net_disable_vq(n, vq);
rcu_assign_pointer(vq->private_data, sock);
vhost_net_enable_vq(n, vq);
+
+ r = vhost_init_used(vq);
+ if (r)
+ goto err_vq;
}
mutex_unlock(&vq->mutex);
+ if (oldubufs) {
+ vhost_ubuf_put_and_wait(oldubufs);
+ mutex_lock(&vq->mutex);
+ vhost_zerocopy_signal_used(vq);
+ mutex_unlock(&vq->mutex);
+ }
+
if (oldsock) {
vhost_net_flush_vq(n, index);
fput(oldsock->file);
@@ -647,6 +732,8 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
mutex_unlock(&n->dev.mutex);
return 0;
+err_ubufs:
+ fput(sock->file);
err_vq:
mutex_unlock(&vq->mutex);
err:
@@ -776,6 +863,8 @@ static struct miscdevice vhost_net_misc = {
static int vhost_net_init(void)
{
+ if (experimental_zcopytx)
+ vhost_enable_zcopy(VHOST_NET_VQ_TX);
return misc_register(&vhost_net_misc);
}
module_init(vhost_net_init);
diff --git a/drivers/vhost/test.c b/drivers/vhost/test.c
index 734e1d74ad8..fc9a1d75281 100644
--- a/drivers/vhost/test.c
+++ b/drivers/vhost/test.c
@@ -195,8 +195,13 @@ static long vhost_test_run(struct vhost_test *n, int test)
lockdep_is_held(&vq->mutex));
rcu_assign_pointer(vq->private_data, priv);
+ r = vhost_init_used(&n->vqs[index]);
+
mutex_unlock(&vq->mutex);
+ if (r)
+ goto err;
+
if (oldpriv) {
vhost_test_flush_vq(n, index);
}
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index ea966b35635..c14c42b95ab 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -37,6 +37,8 @@ enum {
VHOST_MEMORY_F_LOG = 0x1,
};
+static unsigned vhost_zcopy_mask __read_mostly;
+
#define vhost_used_event(vq) ((u16 __user *)&vq->avail->ring[vq->num])
#define vhost_avail_event(vq) ((u16 __user *)&vq->used->ring[vq->num])
@@ -179,6 +181,9 @@ static void vhost_vq_reset(struct vhost_dev *dev,
vq->call_ctx = NULL;
vq->call = NULL;
vq->log_ctx = NULL;
+ vq->upend_idx = 0;
+ vq->done_idx = 0;
+ vq->ubufs = NULL;
}
static int vhost_worker(void *data)
@@ -225,10 +230,28 @@ static int vhost_worker(void *data)
return 0;
}
+static void vhost_vq_free_iovecs(struct vhost_virtqueue *vq)
+{
+ kfree(vq->indirect);
+ vq->indirect = NULL;
+ kfree(vq->log);
+ vq->log = NULL;
+ kfree(vq->heads);
+ vq->heads = NULL;
+ kfree(vq->ubuf_info);
+ vq->ubuf_info = NULL;
+}
+
+void vhost_enable_zcopy(int vq)
+{
+ vhost_zcopy_mask |= 0x1 << vq;
+}
+
/* Helper to allocate iovec buffers for all vqs. */
static long vhost_dev_alloc_iovecs(struct vhost_dev *dev)
{
int i;
+ bool zcopy;
for (i = 0; i < dev->nvqs; ++i) {
dev->vqs[i].indirect = kmalloc(sizeof *dev->vqs[i].indirect *
@@ -237,19 +260,21 @@ static long vhost_dev_alloc_iovecs(struct vhost_dev *dev)
GFP_KERNEL);
dev->vqs[i].heads = kmalloc(sizeof *dev->vqs[i].heads *
UIO_MAXIOV, GFP_KERNEL);
-
+ zcopy = vhost_zcopy_mask & (0x1 << i);
+ if (zcopy)
+ dev->vqs[i].ubuf_info =
+ kmalloc(sizeof *dev->vqs[i].ubuf_info *
+ UIO_MAXIOV, GFP_KERNEL);
if (!dev->vqs[i].indirect || !dev->vqs[i].log ||
- !dev->vqs[i].heads)
+ !dev->vqs[i].heads ||
+ (zcopy && !dev->vqs[i].ubuf_info))
goto err_nomem;
}
return 0;
err_nomem:
- for (; i >= 0; --i) {
- kfree(dev->vqs[i].indirect);
- kfree(dev->vqs[i].log);
- kfree(dev->vqs[i].heads);
- }
+ for (; i >= 0; --i)
+ vhost_vq_free_iovecs(&dev->vqs[i]);
return -ENOMEM;
}
@@ -257,14 +282,8 @@ static void vhost_dev_free_iovecs(struct vhost_dev *dev)
{
int i;
- for (i = 0; i < dev->nvqs; ++i) {
- kfree(dev->vqs[i].indirect);
- dev->vqs[i].indirect = NULL;
- kfree(dev->vqs[i].log);
- dev->vqs[i].log = NULL;
- kfree(dev->vqs[i].heads);
- dev->vqs[i].heads = NULL;
- }
+ for (i = 0; i < dev->nvqs; ++i)
+ vhost_vq_free_iovecs(&dev->vqs[i]);
}
long vhost_dev_init(struct vhost_dev *dev,
@@ -287,6 +306,7 @@ long vhost_dev_init(struct vhost_dev *dev,
dev->vqs[i].log = NULL;
dev->vqs[i].indirect = NULL;
dev->vqs[i].heads = NULL;
+ dev->vqs[i].ubuf_info = NULL;
dev->vqs[i].dev = dev;
mutex_init(&dev->vqs[i].mutex);
vhost_vq_reset(dev, dev->vqs + i);
@@ -390,6 +410,30 @@ long vhost_dev_reset_owner(struct vhost_dev *dev)
return 0;
}
+/* In case of DMA done not in order in lower device driver for some reason.
+ * upend_idx is used to track end of used idx, done_idx is used to track head
+ * of used idx. Once lower device DMA done contiguously, we will signal KVM
+ * guest used idx.
+ */
+int vhost_zerocopy_signal_used(struct vhost_virtqueue *vq)
+{
+ int i;
+ int j = 0;
+
+ for (i = vq->done_idx; i != vq->upend_idx; i = (i + 1) % UIO_MAXIOV) {
+ if ((vq->heads[i].len == VHOST_DMA_DONE_LEN)) {
+ vq->heads[i].len = VHOST_DMA_CLEAR_LEN;
+ vhost_add_used_and_signal(vq->dev, vq,
+ vq->heads[i].id, 0);
+ ++j;
+ } else
+ break;
+ }
+ if (j)
+ vq->done_idx = i;
+ return j;
+}
+
/* Caller should have device mutex */
void vhost_dev_cleanup(struct vhost_dev *dev)
{
@@ -400,6 +444,13 @@ void vhost_dev_cleanup(struct vhost_dev *dev)
vhost_poll_stop(&dev->vqs[i].poll);
vhost_poll_flush(&dev->vqs[i].poll);
}
+ /* Wait for all lower device DMAs done. */
+ if (dev->vqs[i].ubufs)
+ vhost_ubuf_put_and_wait(dev->vqs[i].ubufs);
+
+ /* Signal guest as appropriate. */
+ vhost_zerocopy_signal_used(&dev->vqs[i]);
+
if (dev->vqs[i].error_ctx)
eventfd_ctx_put(dev->vqs[i].error_ctx);
if (dev->vqs[i].error)
@@ -578,17 +629,6 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
return 0;
}
-static int init_used(struct vhost_virtqueue *vq,
- struct vring_used __user *used)
-{
- int r = put_user(vq->used_flags, &used->flags);
-
- if (r)
- return r;
- vq->signalled_used_valid = false;
- return get_user(vq->last_used_idx, &used->idx);
-}
-
static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
{
struct file *eventfp, *filep = NULL,
@@ -701,10 +741,6 @@ static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
}
}
- r = init_used(vq, (struct vring_used __user *)(unsigned long)
- a.used_user_addr);
- if (r)
- break;
vq->log_used = !!(a.flags & (0x1 << VHOST_VRING_F_LOG));
vq->desc = (void __user *)(unsigned long)a.desc_user_addr;
vq->avail = (void __user *)(unsigned long)a.avail_user_addr;
@@ -959,6 +995,57 @@ int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
return 0;
}
+static int vhost_update_used_flags(struct vhost_virtqueue *vq)
+{
+ void __user *used;
+ if (__put_user(vq->used_flags, &vq->used->flags) < 0)
+ return -EFAULT;
+ if (unlikely(vq->log_used)) {
+ /* Make sure the flag is seen before log. */
+ smp_wmb();
+ /* Log used flag write. */
+ used = &vq->used->flags;
+ log_write(vq->log_base, vq->log_addr +
+ (used - (void __user *)vq->used),
+ sizeof vq->used->flags);
+ if (vq->log_ctx)
+ eventfd_signal(vq->log_ctx, 1);
+ }
+ return 0;
+}
+
+static int vhost_update_avail_event(struct vhost_virtqueue *vq, u16 avail_event)
+{
+ if (__put_user(vq->avail_idx, vhost_avail_event(vq)))
+ return -EFAULT;
+ if (unlikely(vq->log_used)) {
+ void __user *used;
+ /* Make sure the event is seen before log. */
+ smp_wmb();
+ /* Log avail event write */
+ used = vhost_avail_event(vq);
+ log_write(vq->log_base, vq->log_addr +
+ (used - (void __user *)vq->used),
+ sizeof *vhost_avail_event(vq));
+ if (vq->log_ctx)
+ eventfd_signal(vq->log_ctx, 1);
+ }
+ return 0;
+}
+
+int vhost_init_used(struct vhost_virtqueue *vq)
+{
+ int r;
+ if (!vq->private_data)
+ return 0;
+
+ r = vhost_update_used_flags(vq);
+ if (r)
+ return r;
+ vq->signalled_used_valid = false;
+ return get_user(vq->last_used_idx, &vq->used->idx);
+}
+
static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
struct iovec iov[], int iov_size)
{
@@ -1430,34 +1517,20 @@ bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
return false;
vq->used_flags &= ~VRING_USED_F_NO_NOTIFY;
if (!vhost_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
- r = put_user(vq->used_flags, &vq->used->flags);
+ r = vhost_update_used_flags(vq);
if (r) {
vq_err(vq, "Failed to enable notification at %p: %d\n",
&vq->used->flags, r);
return false;
}
} else {
- r = put_user(vq->avail_idx, vhost_avail_event(vq));
+ r = vhost_update_avail_event(vq, vq->avail_idx);
if (r) {
vq_err(vq, "Failed to update avail event index at %p: %d\n",
vhost_avail_event(vq), r);
return false;
}
}
- if (unlikely(vq->log_used)) {
- void __user *used;
- /* Make sure data is seen before log. */
- smp_wmb();
- used = vhost_has_feature(dev, VIRTIO_RING_F_EVENT_IDX) ?
- &vq->used->flags : vhost_avail_event(vq);
- /* Log used flags or event index entry write. Both are 16 bit
- * fields. */
- log_write(vq->log_base, vq->log_addr +
- (used - (void __user *)vq->used),
- sizeof(u16));
- if (vq->log_ctx)
- eventfd_signal(vq->log_ctx, 1);
- }
/* They could have slipped one in as we were doing that: make
* sure it's written, then check again. */
smp_mb();
@@ -1480,9 +1553,55 @@ void vhost_disable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
return;
vq->used_flags |= VRING_USED_F_NO_NOTIFY;
if (!vhost_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
- r = put_user(vq->used_flags, &vq->used->flags);
+ r = vhost_update_used_flags(vq);
if (r)
vq_err(vq, "Failed to enable notification at %p: %d\n",
&vq->used->flags, r);
}
}
+
+static void vhost_zerocopy_done_signal(struct kref *kref)
+{
+ struct vhost_ubuf_ref *ubufs = container_of(kref, struct vhost_ubuf_ref,
+ kref);
+ wake_up(&ubufs->wait);
+}
+
+struct vhost_ubuf_ref *vhost_ubuf_alloc(struct vhost_virtqueue *vq,
+ bool zcopy)
+{
+ struct vhost_ubuf_ref *ubufs;
+ /* No zero copy backend? Nothing to count. */
+ if (!zcopy)
+ return NULL;
+ ubufs = kmalloc(sizeof *ubufs, GFP_KERNEL);
+ if (!ubufs)
+ return ERR_PTR(-ENOMEM);
+ kref_init(&ubufs->kref);
+ init_waitqueue_head(&ubufs->wait);
+ ubufs->vq = vq;
+ return ubufs;
+}
+
+void vhost_ubuf_put(struct vhost_ubuf_ref *ubufs)
+{
+ kref_put(&ubufs->kref, vhost_zerocopy_done_signal);
+}
+
+void vhost_ubuf_put_and_wait(struct vhost_ubuf_ref *ubufs)
+{
+ kref_put(&ubufs->kref, vhost_zerocopy_done_signal);
+ wait_event(ubufs->wait, !atomic_read(&ubufs->kref.refcount));
+ kfree(ubufs);
+}
+
+void vhost_zerocopy_callback(void *arg)
+{
+ struct ubuf_info *ubuf = arg;
+ struct vhost_ubuf_ref *ubufs = ubuf->arg;
+ struct vhost_virtqueue *vq = ubufs->vq;
+
+ /* set len = 1 to mark this desc buffers done DMA */
+ vq->heads[ubuf->desc].len = VHOST_DMA_DONE_LEN;
+ kref_put(&ubufs->kref, vhost_zerocopy_done_signal);
+}
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index 8e03379dd30..a801e2821d0 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -11,7 +11,12 @@
#include <linux/uio.h>
#include <linux/virtio_config.h>
#include <linux/virtio_ring.h>
-#include <asm/atomic.h>
+#include <linux/atomic.h>
+
+/* This is for zerocopy, used buffer len is set to 1 when lower device DMA
+ * done */
+#define VHOST_DMA_DONE_LEN 1
+#define VHOST_DMA_CLEAR_LEN 0
struct vhost_device;
@@ -50,6 +55,18 @@ struct vhost_log {
u64 len;
};
+struct vhost_virtqueue;
+
+struct vhost_ubuf_ref {
+ struct kref kref;
+ wait_queue_head_t wait;
+ struct vhost_virtqueue *vq;
+};
+
+struct vhost_ubuf_ref *vhost_ubuf_alloc(struct vhost_virtqueue *, bool zcopy);
+void vhost_ubuf_put(struct vhost_ubuf_ref *);
+void vhost_ubuf_put_and_wait(struct vhost_ubuf_ref *);
+
/* The virtqueue structure describes a queue attached to a device. */
struct vhost_virtqueue {
struct vhost_dev *dev;
@@ -114,6 +131,16 @@ struct vhost_virtqueue {
/* Log write descriptors */
void __user *log_base;
struct vhost_log *log;
+ /* vhost zerocopy support fields below: */
+ /* last used idx for outstanding DMA zerocopy buffers */
+ int upend_idx;
+ /* first used idx for DMA done zerocopy buffers */
+ int done_idx;
+ /* an array of userspace buffers info */
+ struct ubuf_info *ubuf_info;
+ /* Reference counting for outstanding ubufs.
+ * Protected by vq mutex. Writers must also take device mutex. */
+ struct vhost_ubuf_ref *ubufs;
};
struct vhost_dev {
@@ -147,6 +174,7 @@ int vhost_get_vq_desc(struct vhost_dev *, struct vhost_virtqueue *,
struct vhost_log *log, unsigned int *log_num);
void vhost_discard_vq_desc(struct vhost_virtqueue *, int n);
+int vhost_init_used(struct vhost_virtqueue *);
int vhost_add_used(struct vhost_virtqueue *, unsigned int head, int len);
int vhost_add_used_n(struct vhost_virtqueue *, struct vring_used_elem *heads,
unsigned count);
@@ -160,6 +188,8 @@ bool vhost_enable_notify(struct vhost_dev *, struct vhost_virtqueue *);
int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
unsigned int log_num, u64 len);
+void vhost_zerocopy_callback(void *arg);
+int vhost_zerocopy_signal_used(struct vhost_virtqueue *vq);
#define vq_err(vq, fmt, ...) do { \
pr_debug(pr_fmt(fmt), ##__VA_ARGS__); \
@@ -186,4 +216,6 @@ static inline int vhost_has_feature(struct vhost_dev *dev, int bit)
return acked_features & (1 << bit);
}
+void vhost_enable_zcopy(int vq);
+
#endif