aboutsummaryrefslogtreecommitdiff
path: root/drivers/vhost/net.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/vhost/net.c')
-rw-r--r--drivers/vhost/net.c253
1 files changed, 133 insertions, 120 deletions
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index f80d3dd41d8..8dae2f724a3 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -15,9 +15,9 @@
#include <linux/moduleparam.h>
#include <linux/mutex.h>
#include <linux/workqueue.h>
-#include <linux/rcupdate.h>
#include <linux/file.h>
#include <linux/slab.h>
+#include <linux/vmalloc.h>
#include <linux/net.h>
#include <linux/if_packet.h>
@@ -71,7 +71,12 @@ enum {
};
struct vhost_net_ubuf_ref {
- struct kref kref;
+ /* refcount follows semantics similar to kref:
+ * 0: object is released
+ * 1: no outstanding ubufs
+ * >1: outstanding ubufs
+ */
+ atomic_t refcount;
wait_queue_head_t wait;
struct vhost_virtqueue *vq;
};
@@ -117,14 +122,6 @@ static void vhost_net_enable_zcopy(int vq)
vhost_net_zcopy_mask |= 0x1 << vq;
}
-static void vhost_net_zerocopy_done_signal(struct kref *kref)
-{
- struct vhost_net_ubuf_ref *ubufs;
-
- ubufs = container_of(kref, struct vhost_net_ubuf_ref, kref);
- wake_up(&ubufs->wait);
-}
-
static struct vhost_net_ubuf_ref *
vhost_net_ubuf_alloc(struct vhost_virtqueue *vq, bool zcopy)
{
@@ -135,21 +132,29 @@ vhost_net_ubuf_alloc(struct vhost_virtqueue *vq, bool zcopy)
ubufs = kmalloc(sizeof(*ubufs), GFP_KERNEL);
if (!ubufs)
return ERR_PTR(-ENOMEM);
- kref_init(&ubufs->kref);
+ atomic_set(&ubufs->refcount, 1);
init_waitqueue_head(&ubufs->wait);
ubufs->vq = vq;
return ubufs;
}
-static void vhost_net_ubuf_put(struct vhost_net_ubuf_ref *ubufs)
+static int vhost_net_ubuf_put(struct vhost_net_ubuf_ref *ubufs)
{
- kref_put(&ubufs->kref, vhost_net_zerocopy_done_signal);
+ int r = atomic_sub_return(1, &ubufs->refcount);
+ if (unlikely(!r))
+ wake_up(&ubufs->wait);
+ return r;
}
static void vhost_net_ubuf_put_and_wait(struct vhost_net_ubuf_ref *ubufs)
{
- kref_put(&ubufs->kref, vhost_net_zerocopy_done_signal);
- wait_event(ubufs->wait, !atomic_read(&ubufs->kref.refcount));
+ vhost_net_ubuf_put(ubufs);
+ wait_event(ubufs->wait, !atomic_read(&ubufs->refcount));
+}
+
+static void vhost_net_ubuf_put_wait_and_free(struct vhost_net_ubuf_ref *ubufs)
+{
+ vhost_net_ubuf_put_and_wait(ubufs);
kfree(ubufs);
}
@@ -163,7 +168,7 @@ static void vhost_net_clear_ubuf_info(struct vhost_net *n)
}
}
-int vhost_net_set_ubuf_info(struct vhost_net *n)
+static int vhost_net_set_ubuf_info(struct vhost_net *n)
{
bool zcopy;
int i;
@@ -184,7 +189,7 @@ err:
return -ENOMEM;
}
-void vhost_net_vq_reset(struct vhost_net *n)
+static void vhost_net_vq_reset(struct vhost_net *n)
{
int i;
@@ -272,12 +277,12 @@ static void copy_iovec_hdr(const struct iovec *from, struct iovec *to,
* of used idx. Once lower device DMA done contiguously, we will signal KVM
* guest used idx.
*/
-static int vhost_zerocopy_signal_used(struct vhost_net *net,
- struct vhost_virtqueue *vq)
+static void vhost_zerocopy_signal_used(struct vhost_net *net,
+ struct vhost_virtqueue *vq)
{
struct vhost_net_virtqueue *nvq =
container_of(vq, struct vhost_net_virtqueue, vq);
- int i;
+ int i, add;
int j = 0;
for (i = nvq->done_idx; i != nvq->upend_idx; i = (i + 1) % UIO_MAXIOV) {
@@ -285,37 +290,43 @@ static int vhost_zerocopy_signal_used(struct vhost_net *net,
vhost_net_tx_err(net);
if (VHOST_DMA_IS_DONE(vq->heads[i].len)) {
vq->heads[i].len = VHOST_DMA_CLEAR_LEN;
- vhost_add_used_and_signal(vq->dev, vq,
- vq->heads[i].id, 0);
++j;
} else
break;
}
- if (j)
- nvq->done_idx = i;
- return j;
+ while (j) {
+ add = min(UIO_MAXIOV - nvq->done_idx, j);
+ vhost_add_used_and_signal_n(vq->dev, vq,
+ &vq->heads[nvq->done_idx], add);
+ nvq->done_idx = (nvq->done_idx + add) % UIO_MAXIOV;
+ j -= add;
+ }
}
static void vhost_zerocopy_callback(struct ubuf_info *ubuf, bool success)
{
struct vhost_net_ubuf_ref *ubufs = ubuf->ctx;
struct vhost_virtqueue *vq = ubufs->vq;
- int cnt = atomic_read(&ubufs->kref.refcount);
+ int cnt;
+
+ rcu_read_lock_bh();
+
+ /* set len to mark this desc buffers done DMA */
+ vq->heads[ubuf->desc].len = success ?
+ VHOST_DMA_DONE_LEN : VHOST_DMA_FAILED_LEN;
+ cnt = vhost_net_ubuf_put(ubufs);
/*
* Trigger polling thread if guest stopped submitting new buffers:
- * in this case, the refcount after decrement will eventually reach 1
- * so here it is 2.
+ * in this case, the refcount after decrement will eventually reach 1.
* We also trigger polling periodically after each 16 packets
* (the value 16 here is more or less arbitrary, it's tuned to trigger
* less than 10% of times).
*/
- if (cnt <= 2 || !(cnt % 16))
+ if (cnt <= 1 || !(cnt % 16))
vhost_poll_queue(&vq->poll);
- /* set len to mark this desc buffers done DMA */
- vq->heads[ubuf->desc].len = success ?
- VHOST_DMA_DONE_LEN : VHOST_DMA_FAILED_LEN;
- vhost_net_ubuf_put(ubufs);
+
+ rcu_read_unlock_bh();
}
/* Expects to be always run from workqueue - which acts as
@@ -341,12 +352,11 @@ static void handle_tx(struct vhost_net *net)
struct vhost_net_ubuf_ref *uninitialized_var(ubufs);
bool zcopy, zcopy_used;
- /* TODO: check that we are running from vhost_worker? */
- sock = rcu_dereference_check(vq->private_data, 1);
+ mutex_lock(&vq->mutex);
+ sock = vq->private_data;
if (!sock)
- return;
+ goto out;
- mutex_lock(&vq->mutex);
vhost_disable_notify(&net->dev, vq);
hdr_size = nvq->vhost_hlen;
@@ -357,7 +367,14 @@ static void handle_tx(struct vhost_net *net)
if (zcopy)
vhost_zerocopy_signal_used(net, vq);
- head = vhost_get_vq_desc(&net->dev, vq, vq->iov,
+ /* If more outstanding DMAs, queue the work.
+ * Handle upend_idx wrap around
+ */
+ if (unlikely((nvq->upend_idx + vq->num - VHOST_MAX_PEND)
+ % UIO_MAXIOV == nvq->done_idx))
+ break;
+
+ head = vhost_get_vq_desc(vq, vq->iov,
ARRAY_SIZE(vq->iov),
&out, &in,
NULL, NULL);
@@ -366,17 +383,6 @@ static void handle_tx(struct vhost_net *net)
break;
/* Nothing new? Wait for eventfd to tell us they refilled. */
if (head == vq->num) {
- int num_pends;
-
- /* If more outstanding DMAs, queue the work.
- * Handle upend_idx wrap around
- */
- num_pends = likely(nvq->upend_idx >= nvq->done_idx) ?
- (nvq->upend_idx - nvq->done_idx) :
- (nvq->upend_idx + UIO_MAXIOV -
- nvq->done_idx);
- if (unlikely(num_pends > VHOST_MAX_PEND))
- break;
if (unlikely(vhost_enable_notify(&net->dev, vq))) {
vhost_disable_notify(&net->dev, vq);
continue;
@@ -399,43 +405,36 @@ static void handle_tx(struct vhost_net *net)
iov_length(nvq->hdr, s), hdr_size);
break;
}
- zcopy_used = zcopy && (len >= VHOST_GOODCOPY_LEN ||
- nvq->upend_idx != nvq->done_idx);
+
+ zcopy_used = zcopy && len >= VHOST_GOODCOPY_LEN
+ && (nvq->upend_idx + 1) % UIO_MAXIOV !=
+ nvq->done_idx
+ && vhost_net_tx_select_zcopy(net);
/* use msg_control to pass vhost zerocopy ubuf info to skb */
if (zcopy_used) {
+ struct ubuf_info *ubuf;
+ ubuf = nvq->ubuf_info + nvq->upend_idx;
+
vq->heads[nvq->upend_idx].id = head;
- if (!vhost_net_tx_select_zcopy(net) ||
- len < VHOST_GOODCOPY_LEN) {
- /* copy don't need to wait for DMA done */
- vq->heads[nvq->upend_idx].len =
- VHOST_DMA_DONE_LEN;
- msg.msg_control = NULL;
- msg.msg_controllen = 0;
- ubufs = NULL;
- } else {
- struct ubuf_info *ubuf;
- ubuf = nvq->ubuf_info + nvq->upend_idx;
-
- vq->heads[nvq->upend_idx].len =
- VHOST_DMA_IN_PROGRESS;
- ubuf->callback = vhost_zerocopy_callback;
- ubuf->ctx = nvq->ubufs;
- ubuf->desc = nvq->upend_idx;
- msg.msg_control = ubuf;
- msg.msg_controllen = sizeof(ubuf);
- ubufs = nvq->ubufs;
- kref_get(&ubufs->kref);
- }
+ vq->heads[nvq->upend_idx].len = VHOST_DMA_IN_PROGRESS;
+ ubuf->callback = vhost_zerocopy_callback;
+ ubuf->ctx = nvq->ubufs;
+ ubuf->desc = nvq->upend_idx;
+ msg.msg_control = ubuf;
+ msg.msg_controllen = sizeof(ubuf);
+ ubufs = nvq->ubufs;
+ atomic_inc(&ubufs->refcount);
nvq->upend_idx = (nvq->upend_idx + 1) % UIO_MAXIOV;
- } else
+ } else {
msg.msg_control = NULL;
+ ubufs = NULL;
+ }
/* TODO: Check specific error and bomb out unless ENOBUFS? */
err = sock->ops->sendmsg(NULL, sock, &msg, len);
if (unlikely(err < 0)) {
if (zcopy_used) {
- if (ubufs)
- vhost_net_ubuf_put(ubufs);
+ vhost_net_ubuf_put(ubufs);
nvq->upend_idx = ((unsigned)nvq->upend_idx - 1)
% UIO_MAXIOV;
}
@@ -456,7 +455,7 @@ static void handle_tx(struct vhost_net *net)
break;
}
}
-
+out:
mutex_unlock(&vq->mutex);
}
@@ -507,9 +506,13 @@ static int get_rx_bufs(struct vhost_virtqueue *vq,
r = -ENOBUFS;
goto err;
}
- d = vhost_get_vq_desc(vq->dev, vq, vq->iov + seg,
+ r = vhost_get_vq_desc(vq, vq->iov + seg,
ARRAY_SIZE(vq->iov) - seg, &out,
&in, log, log_num);
+ if (unlikely(r < 0))
+ goto err;
+
+ d = r;
if (d == vq->num) {
r = 0;
goto err;
@@ -534,6 +537,12 @@ static int get_rx_bufs(struct vhost_virtqueue *vq,
*iovcount = seg;
if (unlikely(log))
*log_num = nlogs;
+
+ /* Detect overrun */
+ if (unlikely(datalen > 0)) {
+ r = UIO_MAXIOV + 1;
+ goto err;
+ }
return headcount;
err:
vhost_discard_vq_desc(vq, headcount);
@@ -565,20 +574,20 @@ static void handle_rx(struct vhost_net *net)
s16 headcount;
size_t vhost_hlen, sock_hlen;
size_t vhost_len, sock_len;
- /* TODO: check that we are running from vhost_worker? */
- struct socket *sock = rcu_dereference_check(vq->private_data, 1);
-
- if (!sock)
- return;
+ struct socket *sock;
mutex_lock(&vq->mutex);
+ sock = vq->private_data;
+ if (!sock)
+ goto out;
vhost_disable_notify(&net->dev, vq);
+
vhost_hlen = nvq->vhost_hlen;
sock_hlen = nvq->sock_hlen;
- vq_log = unlikely(vhost_has_feature(&net->dev, VHOST_F_LOG_ALL)) ?
+ vq_log = unlikely(vhost_has_feature(vq, VHOST_F_LOG_ALL)) ?
vq->log : NULL;
- mergeable = vhost_has_feature(&net->dev, VIRTIO_NET_F_MRG_RXBUF);
+ mergeable = vhost_has_feature(vq, VIRTIO_NET_F_MRG_RXBUF);
while ((sock_len = peek_head_len(sock->sk))) {
sock_len += sock_hlen;
@@ -589,6 +598,14 @@ static void handle_rx(struct vhost_net *net)
/* On error, stop handling until the next kick. */
if (unlikely(headcount < 0))
break;
+ /* On overrun, truncate and discard */
+ if (unlikely(headcount > UIO_MAXIOV)) {
+ msg.msg_iovlen = 1;
+ err = sock->ops->recvmsg(NULL, sock, &msg,
+ 1, MSG_DONTWAIT | MSG_TRUNC);
+ pr_debug("Discarded rx packet: len %zd\n", sock_len);
+ continue;
+ }
/* OK, now we need to know about added descriptors. */
if (!headcount) {
if (unlikely(vhost_enable_notify(&net->dev, vq))) {
@@ -647,7 +664,7 @@ static void handle_rx(struct vhost_net *net)
break;
}
}
-
+out:
mutex_unlock(&vq->mutex);
}
@@ -685,16 +702,20 @@ static void handle_rx_net(struct vhost_work *work)
static int vhost_net_open(struct inode *inode, struct file *f)
{
- struct vhost_net *n = kmalloc(sizeof *n, GFP_KERNEL);
+ struct vhost_net *n;
struct vhost_dev *dev;
struct vhost_virtqueue **vqs;
- int r, i;
+ int i;
- if (!n)
- return -ENOMEM;
+ n = kmalloc(sizeof *n, GFP_KERNEL | __GFP_NOWARN | __GFP_REPEAT);
+ if (!n) {
+ n = vmalloc(sizeof *n);
+ if (!n)
+ return -ENOMEM;
+ }
vqs = kmalloc(VHOST_NET_VQ_MAX * sizeof(*vqs), GFP_KERNEL);
if (!vqs) {
- kfree(n);
+ kvfree(n);
return -ENOMEM;
}
@@ -711,12 +732,7 @@ static int vhost_net_open(struct inode *inode, struct file *f)
n->vqs[i].vhost_hlen = 0;
n->vqs[i].sock_hlen = 0;
}
- r = vhost_dev_init(dev, vqs, VHOST_NET_VQ_MAX);
- if (r < 0) {
- kfree(n);
- kfree(vqs);
- return r;
- }
+ vhost_dev_init(dev, vqs, VHOST_NET_VQ_MAX);
vhost_poll_init(n->poll + VHOST_NET_VQ_TX, handle_tx_net, POLLOUT, dev);
vhost_poll_init(n->poll + VHOST_NET_VQ_RX, handle_rx_net, POLLIN, dev);
@@ -745,8 +761,7 @@ static int vhost_net_enable_vq(struct vhost_net *n,
struct vhost_poll *poll = n->poll + (nvq - n->vqs);
struct socket *sock;
- sock = rcu_dereference_protected(vq->private_data,
- lockdep_is_held(&vq->mutex));
+ sock = vq->private_data;
if (!sock)
return 0;
@@ -759,10 +774,9 @@ static struct socket *vhost_net_stop_vq(struct vhost_net *n,
struct socket *sock;
mutex_lock(&vq->mutex);
- sock = rcu_dereference_protected(vq->private_data,
- lockdep_is_held(&vq->mutex));
+ sock = vq->private_data;
vhost_net_disable_vq(n, vq);
- rcu_assign_pointer(vq->private_data, NULL);
+ vq->private_data = NULL;
mutex_unlock(&vq->mutex);
return sock;
}
@@ -792,7 +806,7 @@ static void vhost_net_flush(struct vhost_net *n)
vhost_net_ubuf_put_and_wait(n->vqs[VHOST_NET_VQ_TX].ubufs);
mutex_lock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
n->tx_flush = false;
- kref_init(&n->vqs[VHOST_NET_VQ_TX].ubufs->kref);
+ atomic_set(&n->vqs[VHOST_NET_VQ_TX].ubufs->refcount, 1);
mutex_unlock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
}
}
@@ -809,14 +823,16 @@ static int vhost_net_release(struct inode *inode, struct file *f)
vhost_dev_cleanup(&n->dev, false);
vhost_net_vq_reset(n);
if (tx_sock)
- fput(tx_sock->file);
+ sockfd_put(tx_sock);
if (rx_sock)
- fput(rx_sock->file);
+ sockfd_put(rx_sock);
+ /* Make sure no callbacks are outstanding */
+ synchronize_rcu_bh();
/* We do an extra flush before freeing memory,
* since jobs can re-queue themselves. */
vhost_net_flush(n);
kfree(n->dev.vqs);
- kfree(n);
+ kvfree(n);
return 0;
}
@@ -849,7 +865,7 @@ static struct socket *get_raw_socket(int fd)
}
return sock;
err:
- fput(sock->file);
+ sockfd_put(sock);
return ERR_PTR(r);
}
@@ -918,8 +934,7 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
}
/* start polling new socket */
- oldsock = rcu_dereference_protected(vq->private_data,
- lockdep_is_held(&vq->mutex));
+ oldsock = vq->private_data;
if (sock != oldsock) {
ubufs = vhost_net_ubuf_alloc(vq,
sock && vhost_sock_zcopy(sock));
@@ -929,7 +944,7 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
}
vhost_net_disable_vq(n, vq);
- rcu_assign_pointer(vq->private_data, sock);
+ vq->private_data = sock;
r = vhost_init_used(vq);
if (r)
goto err_used;
@@ -948,7 +963,7 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
mutex_unlock(&vq->mutex);
if (oldubufs) {
- vhost_net_ubuf_put_and_wait(oldubufs);
+ vhost_net_ubuf_put_wait_and_free(oldubufs);
mutex_lock(&vq->mutex);
vhost_zerocopy_signal_used(n, vq);
mutex_unlock(&vq->mutex);
@@ -956,19 +971,19 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
if (oldsock) {
vhost_net_flush_vq(n, index);
- fput(oldsock->file);
+ sockfd_put(oldsock);
}
mutex_unlock(&n->dev.mutex);
return 0;
err_used:
- rcu_assign_pointer(vq->private_data, oldsock);
+ vq->private_data = oldsock;
vhost_net_enable_vq(n, vq);
if (ubufs)
- vhost_net_ubuf_put_and_wait(ubufs);
+ vhost_net_ubuf_put_wait_and_free(ubufs);
err_ubufs:
- fput(sock->file);
+ sockfd_put(sock);
err_vq:
mutex_unlock(&vq->mutex);
err:
@@ -999,9 +1014,9 @@ static long vhost_net_reset_owner(struct vhost_net *n)
done:
mutex_unlock(&n->dev.mutex);
if (tx_sock)
- fput(tx_sock->file);
+ sockfd_put(tx_sock);
if (rx_sock)
- fput(rx_sock->file);
+ sockfd_put(rx_sock);
return err;
}
@@ -1028,15 +1043,13 @@ static int vhost_net_set_features(struct vhost_net *n, u64 features)
mutex_unlock(&n->dev.mutex);
return -EFAULT;
}
- n->dev.acked_features = features;
- smp_wmb();
for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
mutex_lock(&n->vqs[i].vq.mutex);
+ n->vqs[i].vq.acked_features = features;
n->vqs[i].vhost_hlen = vhost_hlen;
n->vqs[i].sock_hlen = sock_hlen;
mutex_unlock(&n->vqs[i].vq.mutex);
}
- vhost_net_flush(n);
mutex_unlock(&n->dev.mutex);
return 0;
}