aboutsummaryrefslogtreecommitdiff
path: root/drivers/vhost
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/vhost')
-rw-r--r--drivers/vhost/Kconfig27
-rw-r--r--drivers/vhost/Makefile8
-rw-r--r--drivers/vhost/net.c784
-rw-r--r--drivers/vhost/scsi.c2367
-rw-r--r--drivers/vhost/test.c338
-rw-r--r--drivers/vhost/test.h7
-rw-r--r--drivers/vhost/vhost.c665
-rw-r--r--drivers/vhost/vhost.h83
-rw-r--r--drivers/vhost/vringh.c1010
9 files changed, 4732 insertions, 557 deletions
diff --git a/drivers/vhost/Kconfig b/drivers/vhost/Kconfig
index e4e2fd1b510..017a1e8a8f6 100644
--- a/drivers/vhost/Kconfig
+++ b/drivers/vhost/Kconfig
@@ -1,6 +1,8 @@
config VHOST_NET
- tristate "Host kernel accelerator for virtio net (EXPERIMENTAL)"
- depends on NET && EVENTFD && (TUN || !TUN) && (MACVTAP || !MACVTAP) && EXPERIMENTAL
+ tristate "Host kernel accelerator for virtio net"
+ depends on NET && EVENTFD && (TUN || !TUN) && (MACVTAP || !MACVTAP)
+ select VHOST
+ select VHOST_RING
---help---
This kernel module can be loaded in host kernel to accelerate
guest networking with virtio_net. Not to be confused with virtio_net
@@ -9,3 +11,24 @@ config VHOST_NET
To compile this driver as a module, choose M here: the module will
be called vhost_net.
+config VHOST_SCSI
+ tristate "VHOST_SCSI TCM fabric driver"
+ depends on TARGET_CORE && EVENTFD && m
+ select VHOST
+ select VHOST_RING
+ default n
+ ---help---
+ Say M here to enable the vhost_scsi TCM fabric module
+ for use with virtio-scsi guests
+
+config VHOST_RING
+ tristate
+ ---help---
+ This option is selected by any driver which needs to access
+ the host side of a virtio ring.
+
+config VHOST
+ tristate
+ ---help---
+ This option is selected by any driver which needs to access
+ the core of vhost.
diff --git a/drivers/vhost/Makefile b/drivers/vhost/Makefile
index 72dd02050bb..e0441c34db1 100644
--- a/drivers/vhost/Makefile
+++ b/drivers/vhost/Makefile
@@ -1,2 +1,8 @@
obj-$(CONFIG_VHOST_NET) += vhost_net.o
-vhost_net-y := vhost.o net.o
+vhost_net-y := net.o
+
+obj-$(CONFIG_VHOST_SCSI) += vhost_scsi.o
+vhost_scsi-y := scsi.o
+
+obj-$(CONFIG_VHOST_RING) += vringh.o
+obj-$(CONFIG_VHOST) += vhost.o
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index 4b4da5b86ff..8dae2f724a3 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -10,57 +10,237 @@
#include <linux/eventfd.h>
#include <linux/vhost.h>
#include <linux/virtio_net.h>
-#include <linux/mmu_context.h>
#include <linux/miscdevice.h>
#include <linux/module.h>
+#include <linux/moduleparam.h>
#include <linux/mutex.h>
#include <linux/workqueue.h>
-#include <linux/rcupdate.h>
#include <linux/file.h>
#include <linux/slab.h>
+#include <linux/vmalloc.h>
#include <linux/net.h>
#include <linux/if_packet.h>
#include <linux/if_arp.h>
#include <linux/if_tun.h>
#include <linux/if_macvlan.h>
+#include <linux/if_vlan.h>
#include <net/sock.h>
#include "vhost.h"
+static int experimental_zcopytx = 1;
+module_param(experimental_zcopytx, int, 0444);
+MODULE_PARM_DESC(experimental_zcopytx, "Enable Zero Copy TX;"
+ " 1 -Enable; 0 - Disable");
+
/* Max number of bytes transferred before requeueing the job.
* Using this limit prevents one virtqueue from starving others. */
#define VHOST_NET_WEIGHT 0x80000
+/* MAX number of TX used buffers for outstanding zerocopy */
+#define VHOST_MAX_PEND 128
+#define VHOST_GOODCOPY_LEN 256
+
+/*
+ * For transmit, used buffer len is unused; we override it to track buffer
+ * status internally; used for zerocopy tx only.
+ */
+/* Lower device DMA failed */
+#define VHOST_DMA_FAILED_LEN 3
+/* Lower device DMA done */
+#define VHOST_DMA_DONE_LEN 2
+/* Lower device DMA in progress */
+#define VHOST_DMA_IN_PROGRESS 1
+/* Buffer unused */
+#define VHOST_DMA_CLEAR_LEN 0
+
+#define VHOST_DMA_IS_DONE(len) ((len) >= VHOST_DMA_DONE_LEN)
+
+enum {
+ VHOST_NET_FEATURES = VHOST_FEATURES |
+ (1ULL << VHOST_NET_F_VIRTIO_NET_HDR) |
+ (1ULL << VIRTIO_NET_F_MRG_RXBUF),
+};
+
enum {
VHOST_NET_VQ_RX = 0,
VHOST_NET_VQ_TX = 1,
VHOST_NET_VQ_MAX = 2,
};
-enum vhost_net_poll_state {
- VHOST_NET_POLL_DISABLED = 0,
- VHOST_NET_POLL_STARTED = 1,
- VHOST_NET_POLL_STOPPED = 2,
+struct vhost_net_ubuf_ref {
+ /* refcount follows semantics similar to kref:
+ * 0: object is released
+ * 1: no outstanding ubufs
+ * >1: outstanding ubufs
+ */
+ atomic_t refcount;
+ wait_queue_head_t wait;
+ struct vhost_virtqueue *vq;
+};
+
+struct vhost_net_virtqueue {
+ struct vhost_virtqueue vq;
+ /* hdr is used to store the virtio header.
+ * Since each iovec has >= 1 byte length, we never need more than
+ * header length entries to store the header. */
+ struct iovec hdr[sizeof(struct virtio_net_hdr_mrg_rxbuf)];
+ size_t vhost_hlen;
+ size_t sock_hlen;
+ /* vhost zerocopy support fields below: */
+ /* last used idx for outstanding DMA zerocopy buffers */
+ int upend_idx;
+ /* first used idx for DMA done zerocopy buffers */
+ int done_idx;
+ /* an array of userspace buffers info */
+ struct ubuf_info *ubuf_info;
+ /* Reference counting for outstanding ubufs.
+ * Protected by vq mutex. Writers must also take device mutex. */
+ struct vhost_net_ubuf_ref *ubufs;
};
struct vhost_net {
struct vhost_dev dev;
- struct vhost_virtqueue vqs[VHOST_NET_VQ_MAX];
+ struct vhost_net_virtqueue vqs[VHOST_NET_VQ_MAX];
struct vhost_poll poll[VHOST_NET_VQ_MAX];
- /* Tells us whether we are polling a socket for TX.
- * We only do this when socket buffer fills up.
+ /* Number of TX recently submitted.
* Protected by tx vq lock. */
- enum vhost_net_poll_state tx_poll_state;
+ unsigned tx_packets;
+ /* Number of times zerocopy TX recently failed.
+ * Protected by tx vq lock. */
+ unsigned tx_zcopy_err;
+ /* Flush in progress. Protected by tx vq lock. */
+ bool tx_flush;
};
+static unsigned vhost_net_zcopy_mask __read_mostly;
+
+static void vhost_net_enable_zcopy(int vq)
+{
+ vhost_net_zcopy_mask |= 0x1 << vq;
+}
+
+static struct vhost_net_ubuf_ref *
+vhost_net_ubuf_alloc(struct vhost_virtqueue *vq, bool zcopy)
+{
+ struct vhost_net_ubuf_ref *ubufs;
+ /* No zero copy backend? Nothing to count. */
+ if (!zcopy)
+ return NULL;
+ ubufs = kmalloc(sizeof(*ubufs), GFP_KERNEL);
+ if (!ubufs)
+ return ERR_PTR(-ENOMEM);
+ atomic_set(&ubufs->refcount, 1);
+ init_waitqueue_head(&ubufs->wait);
+ ubufs->vq = vq;
+ return ubufs;
+}
+
+static int vhost_net_ubuf_put(struct vhost_net_ubuf_ref *ubufs)
+{
+ int r = atomic_sub_return(1, &ubufs->refcount);
+ if (unlikely(!r))
+ wake_up(&ubufs->wait);
+ return r;
+}
+
+static void vhost_net_ubuf_put_and_wait(struct vhost_net_ubuf_ref *ubufs)
+{
+ vhost_net_ubuf_put(ubufs);
+ wait_event(ubufs->wait, !atomic_read(&ubufs->refcount));
+}
+
+static void vhost_net_ubuf_put_wait_and_free(struct vhost_net_ubuf_ref *ubufs)
+{
+ vhost_net_ubuf_put_and_wait(ubufs);
+ kfree(ubufs);
+}
+
+static void vhost_net_clear_ubuf_info(struct vhost_net *n)
+{
+ int i;
+
+ for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
+ kfree(n->vqs[i].ubuf_info);
+ n->vqs[i].ubuf_info = NULL;
+ }
+}
+
+static int vhost_net_set_ubuf_info(struct vhost_net *n)
+{
+ bool zcopy;
+ int i;
+
+ for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
+ zcopy = vhost_net_zcopy_mask & (0x1 << i);
+ if (!zcopy)
+ continue;
+ n->vqs[i].ubuf_info = kmalloc(sizeof(*n->vqs[i].ubuf_info) *
+ UIO_MAXIOV, GFP_KERNEL);
+ if (!n->vqs[i].ubuf_info)
+ goto err;
+ }
+ return 0;
+
+err:
+ vhost_net_clear_ubuf_info(n);
+ return -ENOMEM;
+}
+
+static void vhost_net_vq_reset(struct vhost_net *n)
+{
+ int i;
+
+ vhost_net_clear_ubuf_info(n);
+
+ for (i = 0; i < VHOST_NET_VQ_MAX; i++) {
+ n->vqs[i].done_idx = 0;
+ n->vqs[i].upend_idx = 0;
+ n->vqs[i].ubufs = NULL;
+ n->vqs[i].vhost_hlen = 0;
+ n->vqs[i].sock_hlen = 0;
+ }
+
+}
+
+static void vhost_net_tx_packet(struct vhost_net *net)
+{
+ ++net->tx_packets;
+ if (net->tx_packets < 1024)
+ return;
+ net->tx_packets = 0;
+ net->tx_zcopy_err = 0;
+}
+
+static void vhost_net_tx_err(struct vhost_net *net)
+{
+ ++net->tx_zcopy_err;
+}
+
+static bool vhost_net_tx_select_zcopy(struct vhost_net *net)
+{
+ /* TX flush waits for outstanding DMAs to be done.
+ * Don't start new DMAs.
+ */
+ return !net->tx_flush &&
+ net->tx_packets / 64 >= net->tx_zcopy_err;
+}
+
+static bool vhost_sock_zcopy(struct socket *sock)
+{
+ return unlikely(experimental_zcopytx) &&
+ sock_flag(sock->sk, SOCK_ZEROCOPY);
+}
+
/* Pop first len bytes from iovec. Return number of segments used. */
static int move_iovec_hdr(struct iovec *from, struct iovec *to,
size_t len, int iov_count)
{
int seg = 0;
size_t size;
+
while (len && seg < iov_count) {
size = min(from->iov_len, len);
to->iov_base = from->iov_base;
@@ -80,6 +260,7 @@ static void copy_iovec_hdr(const struct iovec *from, struct iovec *to,
{
int seg = 0;
size_t size;
+
while (len && seg < iovcount) {
size = min(from->iov_len, len);
to->iov_base = from->iov_base;
@@ -91,29 +272,69 @@ static void copy_iovec_hdr(const struct iovec *from, struct iovec *to,
}
}
-/* Caller must have TX VQ lock */
-static void tx_poll_stop(struct vhost_net *net)
+/* In case of DMA done not in order in lower device driver for some reason.
+ * upend_idx is used to track end of used idx, done_idx is used to track head
+ * of used idx. Once lower device DMA done contiguously, we will signal KVM
+ * guest used idx.
+ */
+static void vhost_zerocopy_signal_used(struct vhost_net *net,
+ struct vhost_virtqueue *vq)
{
- if (likely(net->tx_poll_state != VHOST_NET_POLL_STARTED))
- return;
- vhost_poll_stop(net->poll + VHOST_NET_VQ_TX);
- net->tx_poll_state = VHOST_NET_POLL_STOPPED;
+ struct vhost_net_virtqueue *nvq =
+ container_of(vq, struct vhost_net_virtqueue, vq);
+ int i, add;
+ int j = 0;
+
+ for (i = nvq->done_idx; i != nvq->upend_idx; i = (i + 1) % UIO_MAXIOV) {
+ if (vq->heads[i].len == VHOST_DMA_FAILED_LEN)
+ vhost_net_tx_err(net);
+ if (VHOST_DMA_IS_DONE(vq->heads[i].len)) {
+ vq->heads[i].len = VHOST_DMA_CLEAR_LEN;
+ ++j;
+ } else
+ break;
+ }
+ while (j) {
+ add = min(UIO_MAXIOV - nvq->done_idx, j);
+ vhost_add_used_and_signal_n(vq->dev, vq,
+ &vq->heads[nvq->done_idx], add);
+ nvq->done_idx = (nvq->done_idx + add) % UIO_MAXIOV;
+ j -= add;
+ }
}
-/* Caller must have TX VQ lock */
-static void tx_poll_start(struct vhost_net *net, struct socket *sock)
+static void vhost_zerocopy_callback(struct ubuf_info *ubuf, bool success)
{
- if (unlikely(net->tx_poll_state != VHOST_NET_POLL_STOPPED))
- return;
- vhost_poll_start(net->poll + VHOST_NET_VQ_TX, sock->file);
- net->tx_poll_state = VHOST_NET_POLL_STARTED;
+ struct vhost_net_ubuf_ref *ubufs = ubuf->ctx;
+ struct vhost_virtqueue *vq = ubufs->vq;
+ int cnt;
+
+ rcu_read_lock_bh();
+
+ /* set len to mark this desc buffers done DMA */
+ vq->heads[ubuf->desc].len = success ?
+ VHOST_DMA_DONE_LEN : VHOST_DMA_FAILED_LEN;
+ cnt = vhost_net_ubuf_put(ubufs);
+
+ /*
+ * Trigger polling thread if guest stopped submitting new buffers:
+ * in this case, the refcount after decrement will eventually reach 1.
+ * We also trigger polling periodically after each 16 packets
+ * (the value 16 here is more or less arbitrary, it's tuned to trigger
+ * less than 10% of times).
+ */
+ if (cnt <= 1 || !(cnt % 16))
+ vhost_poll_queue(&vq->poll);
+
+ rcu_read_unlock_bh();
}
/* Expects to be always run from workqueue - which acts as
* read-size critical section for our kind of RCU. */
static void handle_tx(struct vhost_net *net)
{
- struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_TX];
+ struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX];
+ struct vhost_virtqueue *vq = &nvq->vq;
unsigned out, in, s;
int head;
struct msghdr msg = {
@@ -125,33 +346,35 @@ static void handle_tx(struct vhost_net *net)
.msg_flags = MSG_DONTWAIT,
};
size_t len, total_len = 0;
- int err, wmem;
+ int err;
size_t hdr_size;
struct socket *sock;
+ struct vhost_net_ubuf_ref *uninitialized_var(ubufs);
+ bool zcopy, zcopy_used;
- sock = rcu_dereference_check(vq->private_data,
- lockdep_is_held(&vq->mutex));
+ mutex_lock(&vq->mutex);
+ sock = vq->private_data;
if (!sock)
- return;
-
- wmem = atomic_read(&sock->sk->sk_wmem_alloc);
- if (wmem >= sock->sk->sk_sndbuf) {
- mutex_lock(&vq->mutex);
- tx_poll_start(net, sock);
- mutex_unlock(&vq->mutex);
- return;
- }
+ goto out;
- use_mm(net->dev.mm);
- mutex_lock(&vq->mutex);
- vhost_disable_notify(vq);
+ vhost_disable_notify(&net->dev, vq);
- if (wmem < sock->sk->sk_sndbuf / 2)
- tx_poll_stop(net);
- hdr_size = vq->vhost_hlen;
+ hdr_size = nvq->vhost_hlen;
+ zcopy = nvq->ubufs;
for (;;) {
- head = vhost_get_vq_desc(&net->dev, vq, vq->iov,
+ /* Release DMAs done buffers first */
+ if (zcopy)
+ vhost_zerocopy_signal_used(net, vq);
+
+ /* If more outstanding DMAs, queue the work.
+ * Handle upend_idx wrap around
+ */
+ if (unlikely((nvq->upend_idx + vq->num - VHOST_MAX_PEND)
+ % UIO_MAXIOV == nvq->done_idx))
+ break;
+
+ head = vhost_get_vq_desc(vq, vq->iov,
ARRAY_SIZE(vq->iov),
&out, &in,
NULL, NULL);
@@ -160,14 +383,8 @@ static void handle_tx(struct vhost_net *net)
break;
/* Nothing new? Wait for eventfd to tell us they refilled. */
if (head == vq->num) {
- wmem = atomic_read(&sock->sk->sk_wmem_alloc);
- if (wmem >= sock->sk->sk_sndbuf * 3 / 4) {
- tx_poll_start(net, sock);
- set_bit(SOCK_ASYNC_NOSPACE, &sock->flags);
- break;
- }
- if (unlikely(vhost_enable_notify(vq))) {
- vhost_disable_notify(vq);
+ if (unlikely(vhost_enable_notify(&net->dev, vq))) {
+ vhost_disable_notify(&net->dev, vq);
continue;
}
break;
@@ -178,48 +395,85 @@ static void handle_tx(struct vhost_net *net)
break;
}
/* Skip header. TODO: support TSO. */
- s = move_iovec_hdr(vq->iov, vq->hdr, hdr_size, out);
+ s = move_iovec_hdr(vq->iov, nvq->hdr, hdr_size, out);
msg.msg_iovlen = out;
len = iov_length(vq->iov, out);
/* Sanity check */
if (!len) {
vq_err(vq, "Unexpected header len for TX: "
"%zd expected %zd\n",
- iov_length(vq->hdr, s), hdr_size);
+ iov_length(nvq->hdr, s), hdr_size);
break;
}
+
+ zcopy_used = zcopy && len >= VHOST_GOODCOPY_LEN
+ && (nvq->upend_idx + 1) % UIO_MAXIOV !=
+ nvq->done_idx
+ && vhost_net_tx_select_zcopy(net);
+
+ /* use msg_control to pass vhost zerocopy ubuf info to skb */
+ if (zcopy_used) {
+ struct ubuf_info *ubuf;
+ ubuf = nvq->ubuf_info + nvq->upend_idx;
+
+ vq->heads[nvq->upend_idx].id = head;
+ vq->heads[nvq->upend_idx].len = VHOST_DMA_IN_PROGRESS;
+ ubuf->callback = vhost_zerocopy_callback;
+ ubuf->ctx = nvq->ubufs;
+ ubuf->desc = nvq->upend_idx;
+ msg.msg_control = ubuf;
+ msg.msg_controllen = sizeof(ubuf);
+ ubufs = nvq->ubufs;
+ atomic_inc(&ubufs->refcount);
+ nvq->upend_idx = (nvq->upend_idx + 1) % UIO_MAXIOV;
+ } else {
+ msg.msg_control = NULL;
+ ubufs = NULL;
+ }
/* TODO: Check specific error and bomb out unless ENOBUFS? */
err = sock->ops->sendmsg(NULL, sock, &msg, len);
if (unlikely(err < 0)) {
+ if (zcopy_used) {
+ vhost_net_ubuf_put(ubufs);
+ nvq->upend_idx = ((unsigned)nvq->upend_idx - 1)
+ % UIO_MAXIOV;
+ }
vhost_discard_vq_desc(vq, 1);
- tx_poll_start(net, sock);
break;
}
if (err != len)
pr_debug("Truncated TX packet: "
" len %d != %zd\n", err, len);
- vhost_add_used_and_signal(&net->dev, vq, head, 0);
+ if (!zcopy_used)
+ vhost_add_used_and_signal(&net->dev, vq, head, 0);
+ else
+ vhost_zerocopy_signal_used(net, vq);
total_len += len;
+ vhost_net_tx_packet(net);
if (unlikely(total_len >= VHOST_NET_WEIGHT)) {
vhost_poll_queue(&vq->poll);
break;
}
}
-
+out:
mutex_unlock(&vq->mutex);
- unuse_mm(net->dev.mm);
}
static int peek_head_len(struct sock *sk)
{
struct sk_buff *head;
int len = 0;
+ unsigned long flags;
- lock_sock(sk);
+ spin_lock_irqsave(&sk->sk_receive_queue.lock, flags);
head = skb_peek(&sk->sk_receive_queue);
- if (head)
+ if (likely(head)) {
len = head->len;
- release_sock(sk);
+ if (vlan_tx_tag_present(head))
+ len += VLAN_HLEN;
+ }
+
+ spin_unlock_irqrestore(&sk->sk_receive_queue.lock, flags);
return len;
}
@@ -230,6 +484,7 @@ static int peek_head_len(struct sock *sk)
* @iovcount - returned count of io vectors we fill
* @log - vhost log
* @log_num - log offset
+ * @quota - headcount quota, 1 for big buffer
* returns number of buffer heads allocated, negative on error
*/
static int get_rx_bufs(struct vhost_virtqueue *vq,
@@ -237,7 +492,8 @@ static int get_rx_bufs(struct vhost_virtqueue *vq,
int datalen,
unsigned *iovcount,
struct vhost_log *log,
- unsigned *log_num)
+ unsigned *log_num,
+ unsigned int quota)
{
unsigned int out, in;
int seg = 0;
@@ -245,14 +501,18 @@ static int get_rx_bufs(struct vhost_virtqueue *vq,
unsigned d;
int r, nlogs = 0;
- while (datalen > 0) {
+ while (datalen > 0 && headcount < quota) {
if (unlikely(seg >= UIO_MAXIOV)) {
r = -ENOBUFS;
goto err;
}
- d = vhost_get_vq_desc(vq->dev, vq, vq->iov + seg,
+ r = vhost_get_vq_desc(vq, vq->iov + seg,
ARRAY_SIZE(vq->iov) - seg, &out,
&in, log, log_num);
+ if (unlikely(r < 0))
+ goto err;
+
+ d = r;
if (d == vq->num) {
r = 0;
goto err;
@@ -277,6 +537,12 @@ static int get_rx_bufs(struct vhost_virtqueue *vq,
*iovcount = seg;
if (unlikely(log))
*log_num = nlogs;
+
+ /* Detect overrun */
+ if (unlikely(datalen > 0)) {
+ r = UIO_MAXIOV + 1;
+ goto err;
+ }
return headcount;
err:
vhost_discard_vq_desc(vq, headcount);
@@ -285,120 +551,10 @@ err:
/* Expects to be always run from workqueue - which acts as
* read-size critical section for our kind of RCU. */
-static void handle_rx_big(struct vhost_net *net)
-{
- struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_RX];
- unsigned out, in, log, s;
- int head;
- struct vhost_log *vq_log;
- struct msghdr msg = {
- .msg_name = NULL,
- .msg_namelen = 0,
- .msg_control = NULL, /* FIXME: get and handle RX aux data. */
- .msg_controllen = 0,
- .msg_iov = vq->iov,
- .msg_flags = MSG_DONTWAIT,
- };
-
- struct virtio_net_hdr hdr = {
- .flags = 0,
- .gso_type = VIRTIO_NET_HDR_GSO_NONE
- };
-
- size_t len, total_len = 0;
- int err;
- size_t hdr_size;
- struct socket *sock = rcu_dereference(vq->private_data);
- if (!sock || skb_queue_empty(&sock->sk->sk_receive_queue))
- return;
-
- use_mm(net->dev.mm);
- mutex_lock(&vq->mutex);
- vhost_disable_notify(vq);
- hdr_size = vq->vhost_hlen;
-
- vq_log = unlikely(vhost_has_feature(&net->dev, VHOST_F_LOG_ALL)) ?
- vq->log : NULL;
-
- for (;;) {
- head = vhost_get_vq_desc(&net->dev, vq, vq->iov,
- ARRAY_SIZE(vq->iov),
- &out, &in,
- vq_log, &log);
- /* On error, stop handling until the next kick. */
- if (unlikely(head < 0))
- break;
- /* OK, now we need to know about added descriptors. */
- if (head == vq->num) {
- if (unlikely(vhost_enable_notify(vq))) {
- /* They have slipped one in as we were
- * doing that: check again. */
- vhost_disable_notify(vq);
- continue;
- }
- /* Nothing new? Wait for eventfd to tell us
- * they refilled. */
- break;
- }
- /* We don't need to be notified again. */
- if (out) {
- vq_err(vq, "Unexpected descriptor format for RX: "
- "out %d, int %d\n",
- out, in);
- break;
- }
- /* Skip header. TODO: support TSO/mergeable rx buffers. */
- s = move_iovec_hdr(vq->iov, vq->hdr, hdr_size, in);
- msg.msg_iovlen = in;
- len = iov_length(vq->iov, in);
- /* Sanity check */
- if (!len) {
- vq_err(vq, "Unexpected header len for RX: "
- "%zd expected %zd\n",
- iov_length(vq->hdr, s), hdr_size);
- break;
- }
- err = sock->ops->recvmsg(NULL, sock, &msg,
- len, MSG_DONTWAIT | MSG_TRUNC);
- /* TODO: Check specific error and bomb out unless EAGAIN? */
- if (err < 0) {
- vhost_discard_vq_desc(vq, 1);
- break;
- }
- /* TODO: Should check and handle checksum. */
- if (err > len) {
- pr_debug("Discarded truncated rx packet: "
- " len %d > %zd\n", err, len);
- vhost_discard_vq_desc(vq, 1);
- continue;
- }
- len = err;
- err = memcpy_toiovec(vq->hdr, (unsigned char *)&hdr, hdr_size);
- if (err) {
- vq_err(vq, "Unable to write vnet_hdr at addr %p: %d\n",
- vq->iov->iov_base, err);
- break;
- }
- len += hdr_size;
- vhost_add_used_and_signal(&net->dev, vq, head, len);
- if (unlikely(vq_log))
- vhost_log_write(vq, vq_log, log, len);
- total_len += len;
- if (unlikely(total_len >= VHOST_NET_WEIGHT)) {
- vhost_poll_queue(&vq->poll);
- break;
- }
- }
-
- mutex_unlock(&vq->mutex);
- unuse_mm(net->dev.mm);
-}
-
-/* Expects to be always run from workqueue - which acts as
- * read-size critical section for our kind of RCU. */
-static void handle_rx_mergeable(struct vhost_net *net)
+static void handle_rx(struct vhost_net *net)
{
- struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_RX];
+ struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_RX];
+ struct vhost_virtqueue *vq = &nvq->vq;
unsigned uninitialized_var(in), log;
struct vhost_log *vq_log;
struct msghdr msg = {
@@ -409,43 +565,53 @@ static void handle_rx_mergeable(struct vhost_net *net)
.msg_iov = vq->iov,
.msg_flags = MSG_DONTWAIT,
};
-
struct virtio_net_hdr_mrg_rxbuf hdr = {
.hdr.flags = 0,
.hdr.gso_type = VIRTIO_NET_HDR_GSO_NONE
};
-
size_t total_len = 0;
- int err, headcount;
+ int err, mergeable;
+ s16 headcount;
size_t vhost_hlen, sock_hlen;
size_t vhost_len, sock_len;
- struct socket *sock = rcu_dereference(vq->private_data);
- if (!sock || skb_queue_empty(&sock->sk->sk_receive_queue))
- return;
+ struct socket *sock;
- use_mm(net->dev.mm);
mutex_lock(&vq->mutex);
- vhost_disable_notify(vq);
- vhost_hlen = vq->vhost_hlen;
- sock_hlen = vq->sock_hlen;
+ sock = vq->private_data;
+ if (!sock)
+ goto out;
+ vhost_disable_notify(&net->dev, vq);
- vq_log = unlikely(vhost_has_feature(&net->dev, VHOST_F_LOG_ALL)) ?
+ vhost_hlen = nvq->vhost_hlen;
+ sock_hlen = nvq->sock_hlen;
+
+ vq_log = unlikely(vhost_has_feature(vq, VHOST_F_LOG_ALL)) ?
vq->log : NULL;
+ mergeable = vhost_has_feature(vq, VIRTIO_NET_F_MRG_RXBUF);
while ((sock_len = peek_head_len(sock->sk))) {
sock_len += sock_hlen;
vhost_len = sock_len + vhost_hlen;
headcount = get_rx_bufs(vq, vq->heads, vhost_len,
- &in, vq_log, &log);
+ &in, vq_log, &log,
+ likely(mergeable) ? UIO_MAXIOV : 1);
/* On error, stop handling until the next kick. */
if (unlikely(headcount < 0))
break;
+ /* On overrun, truncate and discard */
+ if (unlikely(headcount > UIO_MAXIOV)) {
+ msg.msg_iovlen = 1;
+ err = sock->ops->recvmsg(NULL, sock, &msg,
+ 1, MSG_DONTWAIT | MSG_TRUNC);
+ pr_debug("Discarded rx packet: len %zd\n", sock_len);
+ continue;
+ }
/* OK, now we need to know about added descriptors. */
if (!headcount) {
- if (unlikely(vhost_enable_notify(vq))) {
+ if (unlikely(vhost_enable_notify(&net->dev, vq))) {
/* They have slipped one in as we were
* doing that: check again. */
- vhost_disable_notify(vq);
+ vhost_disable_notify(&net->dev, vq);
continue;
}
/* Nothing new? Wait for eventfd to tell us
@@ -455,11 +621,11 @@ static void handle_rx_mergeable(struct vhost_net *net)
/* We don't need to be notified again. */
if (unlikely((vhost_hlen)))
/* Skip header. TODO: support TSO. */
- move_iovec_hdr(vq->iov, vq->hdr, vhost_hlen, in);
+ move_iovec_hdr(vq->iov, nvq->hdr, vhost_hlen, in);
else
/* Copy the header for use in VIRTIO_NET_F_MRG_RXBUF:
- * needed because sendmsg can modify msg_iov. */
- copy_iovec_hdr(vq->iov, vq->hdr, sock_hlen, in);
+ * needed because recvmsg can modify msg_iov. */
+ copy_iovec_hdr(vq->iov, nvq->hdr, sock_hlen, in);
msg.msg_iovlen = in;
err = sock->ops->recvmsg(NULL, sock, &msg,
sock_len, MSG_DONTWAIT | MSG_TRUNC);
@@ -473,15 +639,15 @@ static void handle_rx_mergeable(struct vhost_net *net)
continue;
}
if (unlikely(vhost_hlen) &&
- memcpy_toiovecend(vq->hdr, (unsigned char *)&hdr, 0,
+ memcpy_toiovecend(nvq->hdr, (unsigned char *)&hdr, 0,
vhost_hlen)) {
vq_err(vq, "Unable to write vnet_hdr at addr %p\n",
vq->iov->iov_base);
break;
}
/* TODO: Should check and handle checksum. */
- if (vhost_has_feature(&net->dev, VIRTIO_NET_F_MRG_RXBUF) &&
- memcpy_toiovecend(vq->hdr, (unsigned char *)&headcount,
+ if (likely(mergeable) &&
+ memcpy_toiovecend(nvq->hdr, (unsigned char *)&headcount,
offsetof(typeof(hdr), num_buffers),
sizeof hdr.num_buffers)) {
vq_err(vq, "Failed num_buffers write");
@@ -498,17 +664,8 @@ static void handle_rx_mergeable(struct vhost_net *net)
break;
}
}
-
+out:
mutex_unlock(&vq->mutex);
- unuse_mm(net->dev.mm);
-}
-
-static void handle_rx(struct vhost_net *net)
-{
- if (vhost_has_feature(&net->dev, VIRTIO_NET_F_MRG_RXBUF))
- handle_rx_mergeable(net);
- else
- handle_rx_big(net);
}
static void handle_tx_kick(struct vhost_work *work)
@@ -545,25 +702,40 @@ static void handle_rx_net(struct vhost_work *work)
static int vhost_net_open(struct inode *inode, struct file *f)
{
- struct vhost_net *n = kmalloc(sizeof *n, GFP_KERNEL);
+ struct vhost_net *n;
struct vhost_dev *dev;
- int r;
+ struct vhost_virtqueue **vqs;
+ int i;
- if (!n)
+ n = kmalloc(sizeof *n, GFP_KERNEL | __GFP_NOWARN | __GFP_REPEAT);
+ if (!n) {
+ n = vmalloc(sizeof *n);
+ if (!n)
+ return -ENOMEM;
+ }
+ vqs = kmalloc(VHOST_NET_VQ_MAX * sizeof(*vqs), GFP_KERNEL);
+ if (!vqs) {
+ kvfree(n);
return -ENOMEM;
+ }
dev = &n->dev;
- n->vqs[VHOST_NET_VQ_TX].handle_kick = handle_tx_kick;
- n->vqs[VHOST_NET_VQ_RX].handle_kick = handle_rx_kick;
- r = vhost_dev_init(dev, n->vqs, VHOST_NET_VQ_MAX);
- if (r < 0) {
- kfree(n);
- return r;
+ vqs[VHOST_NET_VQ_TX] = &n->vqs[VHOST_NET_VQ_TX].vq;
+ vqs[VHOST_NET_VQ_RX] = &n->vqs[VHOST_NET_VQ_RX].vq;
+ n->vqs[VHOST_NET_VQ_TX].vq.handle_kick = handle_tx_kick;
+ n->vqs[VHOST_NET_VQ_RX].vq.handle_kick = handle_rx_kick;
+ for (i = 0; i < VHOST_NET_VQ_MAX; i++) {
+ n->vqs[i].ubufs = NULL;
+ n->vqs[i].ubuf_info = NULL;
+ n->vqs[i].upend_idx = 0;
+ n->vqs[i].done_idx = 0;
+ n->vqs[i].vhost_hlen = 0;
+ n->vqs[i].sock_hlen = 0;
}
+ vhost_dev_init(dev, vqs, VHOST_NET_VQ_MAX);
vhost_poll_init(n->poll + VHOST_NET_VQ_TX, handle_tx_net, POLLOUT, dev);
vhost_poll_init(n->poll + VHOST_NET_VQ_RX, handle_rx_net, POLLIN, dev);
- n->tx_poll_state = VHOST_NET_POLL_DISABLED;
f->private_data = n;
@@ -573,29 +745,27 @@ static int vhost_net_open(struct inode *inode, struct file *f)
static void vhost_net_disable_vq(struct vhost_net *n,
struct vhost_virtqueue *vq)
{
+ struct vhost_net_virtqueue *nvq =
+ container_of(vq, struct vhost_net_virtqueue, vq);
+ struct vhost_poll *poll = n->poll + (nvq - n->vqs);
if (!vq->private_data)
return;
- if (vq == n->vqs + VHOST_NET_VQ_TX) {
- tx_poll_stop(n);
- n->tx_poll_state = VHOST_NET_POLL_DISABLED;
- } else
- vhost_poll_stop(n->poll + VHOST_NET_VQ_RX);
+ vhost_poll_stop(poll);
}
-static void vhost_net_enable_vq(struct vhost_net *n,
+static int vhost_net_enable_vq(struct vhost_net *n,
struct vhost_virtqueue *vq)
{
+ struct vhost_net_virtqueue *nvq =
+ container_of(vq, struct vhost_net_virtqueue, vq);
+ struct vhost_poll *poll = n->poll + (nvq - n->vqs);
struct socket *sock;
- sock = rcu_dereference_protected(vq->private_data,
- lockdep_is_held(&vq->mutex));
+ sock = vq->private_data;
if (!sock)
- return;
- if (vq == n->vqs + VHOST_NET_VQ_TX) {
- n->tx_poll_state = VHOST_NET_POLL_STOPPED;
- tx_poll_start(n, sock);
- } else
- vhost_poll_start(n->poll + VHOST_NET_VQ_RX, sock->file);
+ return 0;
+
+ return vhost_poll_start(poll, sock->file);
}
static struct socket *vhost_net_stop_vq(struct vhost_net *n,
@@ -604,10 +774,9 @@ static struct socket *vhost_net_stop_vq(struct vhost_net *n,
struct socket *sock;
mutex_lock(&vq->mutex);
- sock = rcu_dereference_protected(vq->private_data,
- lockdep_is_held(&vq->mutex));
+ sock = vq->private_data;
vhost_net_disable_vq(n, vq);
- rcu_assign_pointer(vq->private_data, NULL);
+ vq->private_data = NULL;
mutex_unlock(&vq->mutex);
return sock;
}
@@ -615,20 +784,31 @@ static struct socket *vhost_net_stop_vq(struct vhost_net *n,
static void vhost_net_stop(struct vhost_net *n, struct socket **tx_sock,
struct socket **rx_sock)
{
- *tx_sock = vhost_net_stop_vq(n, n->vqs + VHOST_NET_VQ_TX);
- *rx_sock = vhost_net_stop_vq(n, n->vqs + VHOST_NET_VQ_RX);
+ *tx_sock = vhost_net_stop_vq(n, &n->vqs[VHOST_NET_VQ_TX].vq);
+ *rx_sock = vhost_net_stop_vq(n, &n->vqs[VHOST_NET_VQ_RX].vq);
}
static void vhost_net_flush_vq(struct vhost_net *n, int index)
{
vhost_poll_flush(n->poll + index);
- vhost_poll_flush(&n->dev.vqs[index].poll);
+ vhost_poll_flush(&n->vqs[index].vq.poll);
}
static void vhost_net_flush(struct vhost_net *n)
{
vhost_net_flush_vq(n, VHOST_NET_VQ_TX);
vhost_net_flush_vq(n, VHOST_NET_VQ_RX);
+ if (n->vqs[VHOST_NET_VQ_TX].ubufs) {
+ mutex_lock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
+ n->tx_flush = true;
+ mutex_unlock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
+ /* Wait for all lower device DMAs done. */
+ vhost_net_ubuf_put_and_wait(n->vqs[VHOST_NET_VQ_TX].ubufs);
+ mutex_lock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
+ n->tx_flush = false;
+ atomic_set(&n->vqs[VHOST_NET_VQ_TX].ubufs->refcount, 1);
+ mutex_unlock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
+ }
}
static int vhost_net_release(struct inode *inode, struct file *f)
@@ -639,15 +819,20 @@ static int vhost_net_release(struct inode *inode, struct file *f)
vhost_net_stop(n, &tx_sock, &rx_sock);
vhost_net_flush(n);
- vhost_dev_cleanup(&n->dev);
+ vhost_dev_stop(&n->dev);
+ vhost_dev_cleanup(&n->dev, false);
+ vhost_net_vq_reset(n);
if (tx_sock)
- fput(tx_sock->file);
+ sockfd_put(tx_sock);
if (rx_sock)
- fput(rx_sock->file);
+ sockfd_put(rx_sock);
+ /* Make sure no callbacks are outstanding */
+ synchronize_rcu_bh();
/* We do an extra flush before freeing memory,
* since jobs can re-queue themselves. */
vhost_net_flush(n);
- kfree(n);
+ kfree(n->dev.vqs);
+ kvfree(n);
return 0;
}
@@ -659,6 +844,7 @@ static struct socket *get_raw_socket(int fd)
} uaddr;
int uaddr_len = sizeof uaddr, r;
struct socket *sock = sockfd_lookup(fd, &r);
+
if (!sock)
return ERR_PTR(-ENOTSOCK);
@@ -679,7 +865,7 @@ static struct socket *get_raw_socket(int fd)
}
return sock;
err:
- fput(sock->file);
+ sockfd_put(sock);
return ERR_PTR(r);
}
@@ -687,6 +873,7 @@ static struct socket *get_tap_socket(int fd)
{
struct file *file = fget(fd);
struct socket *sock;
+
if (!file)
return ERR_PTR(-EBADF);
sock = tun_get_socket(file);
@@ -701,6 +888,7 @@ static struct socket *get_tap_socket(int fd)
static struct socket *get_socket(int fd)
{
struct socket *sock;
+
/* special case to disable backend */
if (fd == -1)
return NULL;
@@ -717,6 +905,8 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
{
struct socket *sock, *oldsock;
struct vhost_virtqueue *vq;
+ struct vhost_net_virtqueue *nvq;
+ struct vhost_net_ubuf_ref *ubufs, *oldubufs = NULL;
int r;
mutex_lock(&n->dev.mutex);
@@ -728,7 +918,8 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
r = -ENOBUFS;
goto err;
}
- vq = n->vqs + index;
+ vq = &n->vqs[index].vq;
+ nvq = &n->vqs[index];
mutex_lock(&vq->mutex);
/* Verify that ring has been setup correctly. */
@@ -743,24 +934,56 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
}
/* start polling new socket */
- oldsock = rcu_dereference_protected(vq->private_data,
- lockdep_is_held(&vq->mutex));
+ oldsock = vq->private_data;
if (sock != oldsock) {
- vhost_net_disable_vq(n, vq);
- rcu_assign_pointer(vq->private_data, sock);
- vhost_net_enable_vq(n, vq);
+ ubufs = vhost_net_ubuf_alloc(vq,
+ sock && vhost_sock_zcopy(sock));
+ if (IS_ERR(ubufs)) {
+ r = PTR_ERR(ubufs);
+ goto err_ubufs;
+ }
+
+ vhost_net_disable_vq(n, vq);
+ vq->private_data = sock;
+ r = vhost_init_used(vq);
+ if (r)
+ goto err_used;
+ r = vhost_net_enable_vq(n, vq);
+ if (r)
+ goto err_used;
+
+ oldubufs = nvq->ubufs;
+ nvq->ubufs = ubufs;
+
+ n->tx_packets = 0;
+ n->tx_zcopy_err = 0;
+ n->tx_flush = false;
}
mutex_unlock(&vq->mutex);
+ if (oldubufs) {
+ vhost_net_ubuf_put_wait_and_free(oldubufs);
+ mutex_lock(&vq->mutex);
+ vhost_zerocopy_signal_used(n, vq);
+ mutex_unlock(&vq->mutex);
+ }
+
if (oldsock) {
vhost_net_flush_vq(n, index);
- fput(oldsock->file);
+ sockfd_put(oldsock);
}
mutex_unlock(&n->dev.mutex);
return 0;
+err_used:
+ vq->private_data = oldsock;
+ vhost_net_enable_vq(n, vq);
+ if (ubufs)
+ vhost_net_ubuf_put_wait_and_free(ubufs);
+err_ubufs:
+ sockfd_put(sock);
err_vq:
mutex_unlock(&vq->mutex);
err:
@@ -773,19 +996,27 @@ static long vhost_net_reset_owner(struct vhost_net *n)
struct socket *tx_sock = NULL;
struct socket *rx_sock = NULL;
long err;
+ struct vhost_memory *memory;
+
mutex_lock(&n->dev.mutex);
err = vhost_dev_check_owner(&n->dev);
if (err)
goto done;
+ memory = vhost_dev_reset_owner_prepare();
+ if (!memory) {
+ err = -ENOMEM;
+ goto done;
+ }
vhost_net_stop(n, &tx_sock, &rx_sock);
vhost_net_flush(n);
- err = vhost_dev_reset_owner(&n->dev);
+ vhost_dev_reset_owner(&n->dev, memory);
+ vhost_net_vq_reset(n);
done:
mutex_unlock(&n->dev.mutex);
if (tx_sock)
- fput(tx_sock->file);
+ sockfd_put(tx_sock);
if (rx_sock)
- fput(rx_sock->file);
+ sockfd_put(rx_sock);
return err;
}
@@ -812,19 +1043,38 @@ static int vhost_net_set_features(struct vhost_net *n, u64 features)
mutex_unlock(&n->dev.mutex);
return -EFAULT;
}
- n->dev.acked_features = features;
- smp_wmb();
for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
- mutex_lock(&n->vqs[i].mutex);
+ mutex_lock(&n->vqs[i].vq.mutex);
+ n->vqs[i].vq.acked_features = features;
n->vqs[i].vhost_hlen = vhost_hlen;
n->vqs[i].sock_hlen = sock_hlen;
- mutex_unlock(&n->vqs[i].mutex);
+ mutex_unlock(&n->vqs[i].vq.mutex);
}
- vhost_net_flush(n);
mutex_unlock(&n->dev.mutex);
return 0;
}
+static long vhost_net_set_owner(struct vhost_net *n)
+{
+ int r;
+
+ mutex_lock(&n->dev.mutex);
+ if (vhost_dev_has_owner(&n->dev)) {
+ r = -EBUSY;
+ goto out;
+ }
+ r = vhost_net_set_ubuf_info(n);
+ if (r)
+ goto out;
+ r = vhost_dev_set_owner(&n->dev);
+ if (r)
+ vhost_net_clear_ubuf_info(n);
+ vhost_net_flush(n);
+out:
+ mutex_unlock(&n->dev.mutex);
+ return r;
+}
+
static long vhost_net_ioctl(struct file *f, unsigned int ioctl,
unsigned long arg)
{
@@ -834,28 +1084,34 @@ static long vhost_net_ioctl(struct file *f, unsigned int ioctl,
struct vhost_vring_file backend;
u64 features;
int r;
+
switch (ioctl) {
case VHOST_NET_SET_BACKEND:
if (copy_from_user(&backend, argp, sizeof backend))
return -EFAULT;
return vhost_net_set_backend(n, backend.index, backend.fd);
case VHOST_GET_FEATURES:
- features = VHOST_FEATURES;
+ features = VHOST_NET_FEATURES;
if (copy_to_user(featurep, &features, sizeof features))
return -EFAULT;
return 0;
case VHOST_SET_FEATURES:
if (copy_from_user(&features, featurep, sizeof features))
return -EFAULT;
- if (features & ~VHOST_FEATURES)
+ if (features & ~VHOST_NET_FEATURES)
return -EOPNOTSUPP;
return vhost_net_set_features(n, features);
case VHOST_RESET_OWNER:
return vhost_net_reset_owner(n);
+ case VHOST_SET_OWNER:
+ return vhost_net_set_owner(n);
default:
mutex_lock(&n->dev.mutex);
- r = vhost_dev_ioctl(&n->dev, ioctl, arg);
- vhost_net_flush(n);
+ r = vhost_dev_ioctl(&n->dev, ioctl, argp);
+ if (r == -ENOIOCTLCMD)
+ r = vhost_vring_ioctl(&n->dev, ioctl, argp);
+ else
+ vhost_net_flush(n);
mutex_unlock(&n->dev.mutex);
return r;
}
@@ -881,13 +1137,15 @@ static const struct file_operations vhost_net_fops = {
};
static struct miscdevice vhost_net_misc = {
- MISC_DYNAMIC_MINOR,
- "vhost-net",
- &vhost_net_fops,
+ .minor = VHOST_NET_MINOR,
+ .name = "vhost-net",
+ .fops = &vhost_net_fops,
};
static int vhost_net_init(void)
{
+ if (experimental_zcopytx)
+ vhost_net_enable_zcopy(VHOST_NET_VQ_TX);
return misc_register(&vhost_net_misc);
}
module_init(vhost_net_init);
@@ -902,3 +1160,5 @@ MODULE_VERSION("0.0.1");
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Michael S. Tsirkin");
MODULE_DESCRIPTION("Host kernel accelerator for virtio net");
+MODULE_ALIAS_MISCDEV(VHOST_NET_MINOR);
+MODULE_ALIAS("devname:vhost-net");
diff --git a/drivers/vhost/scsi.c b/drivers/vhost/scsi.c
new file mode 100644
index 00000000000..69906cacd04
--- /dev/null
+++ b/drivers/vhost/scsi.c
@@ -0,0 +1,2367 @@
+/*******************************************************************************
+ * Vhost kernel TCM fabric driver for virtio SCSI initiators
+ *
+ * (C) Copyright 2010-2013 Datera, Inc.
+ * (C) Copyright 2010-2012 IBM Corp.
+ *
+ * Licensed to the Linux Foundation under the General Public License (GPL) version 2.
+ *
+ * Authors: Nicholas A. Bellinger <nab@daterainc.com>
+ * Stefan Hajnoczi <stefanha@linux.vnet.ibm.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 2 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ ****************************************************************************/
+
+#include <linux/module.h>
+#include <linux/moduleparam.h>
+#include <generated/utsrelease.h>
+#include <linux/utsname.h>
+#include <linux/init.h>
+#include <linux/slab.h>
+#include <linux/kthread.h>
+#include <linux/types.h>
+#include <linux/string.h>
+#include <linux/configfs.h>
+#include <linux/ctype.h>
+#include <linux/compat.h>
+#include <linux/eventfd.h>
+#include <linux/fs.h>
+#include <linux/miscdevice.h>
+#include <asm/unaligned.h>
+#include <scsi/scsi.h>
+#include <scsi/scsi_tcq.h>
+#include <target/target_core_base.h>
+#include <target/target_core_fabric.h>
+#include <target/target_core_fabric_configfs.h>
+#include <target/target_core_configfs.h>
+#include <target/configfs_macros.h>
+#include <linux/vhost.h>
+#include <linux/virtio_scsi.h>
+#include <linux/llist.h>
+#include <linux/bitmap.h>
+#include <linux/percpu_ida.h>
+
+#include "vhost.h"
+
+#define TCM_VHOST_VERSION "v0.1"
+#define TCM_VHOST_NAMELEN 256
+#define TCM_VHOST_MAX_CDB_SIZE 32
+#define TCM_VHOST_DEFAULT_TAGS 256
+#define TCM_VHOST_PREALLOC_SGLS 2048
+#define TCM_VHOST_PREALLOC_UPAGES 2048
+#define TCM_VHOST_PREALLOC_PROT_SGLS 512
+
+struct vhost_scsi_inflight {
+ /* Wait for the flush operation to finish */
+ struct completion comp;
+ /* Refcount for the inflight reqs */
+ struct kref kref;
+};
+
+struct tcm_vhost_cmd {
+ /* Descriptor from vhost_get_vq_desc() for virt_queue segment */
+ int tvc_vq_desc;
+ /* virtio-scsi initiator task attribute */
+ int tvc_task_attr;
+ /* virtio-scsi initiator data direction */
+ enum dma_data_direction tvc_data_direction;
+ /* Expected data transfer length from virtio-scsi header */
+ u32 tvc_exp_data_len;
+ /* The Tag from include/linux/virtio_scsi.h:struct virtio_scsi_cmd_req */
+ u64 tvc_tag;
+ /* The number of scatterlists associated with this cmd */
+ u32 tvc_sgl_count;
+ u32 tvc_prot_sgl_count;
+ /* Saved unpacked SCSI LUN for tcm_vhost_submission_work() */
+ u32 tvc_lun;
+ /* Pointer to the SGL formatted memory from virtio-scsi */
+ struct scatterlist *tvc_sgl;
+ struct scatterlist *tvc_prot_sgl;
+ struct page **tvc_upages;
+ /* Pointer to response */
+ struct virtio_scsi_cmd_resp __user *tvc_resp;
+ /* Pointer to vhost_scsi for our device */
+ struct vhost_scsi *tvc_vhost;
+ /* Pointer to vhost_virtqueue for the cmd */
+ struct vhost_virtqueue *tvc_vq;
+ /* Pointer to vhost nexus memory */
+ struct tcm_vhost_nexus *tvc_nexus;
+ /* The TCM I/O descriptor that is accessed via container_of() */
+ struct se_cmd tvc_se_cmd;
+ /* work item used for cmwq dispatch to tcm_vhost_submission_work() */
+ struct work_struct work;
+ /* Copy of the incoming SCSI command descriptor block (CDB) */
+ unsigned char tvc_cdb[TCM_VHOST_MAX_CDB_SIZE];
+ /* Sense buffer that will be mapped into outgoing status */
+ unsigned char tvc_sense_buf[TRANSPORT_SENSE_BUFFER];
+ /* Completed commands list, serviced from vhost worker thread */
+ struct llist_node tvc_completion_list;
+ /* Used to track inflight cmd */
+ struct vhost_scsi_inflight *inflight;
+};
+
+struct tcm_vhost_nexus {
+ /* Pointer to TCM session for I_T Nexus */
+ struct se_session *tvn_se_sess;
+};
+
+struct tcm_vhost_nacl {
+ /* Binary World Wide unique Port Name for Vhost Initiator port */
+ u64 iport_wwpn;
+ /* ASCII formatted WWPN for Sas Initiator port */
+ char iport_name[TCM_VHOST_NAMELEN];
+ /* Returned by tcm_vhost_make_nodeacl() */
+ struct se_node_acl se_node_acl;
+};
+
+struct tcm_vhost_tpg {
+ /* Vhost port target portal group tag for TCM */
+ u16 tport_tpgt;
+ /* Used to track number of TPG Port/Lun Links wrt to explict I_T Nexus shutdown */
+ int tv_tpg_port_count;
+ /* Used for vhost_scsi device reference to tpg_nexus, protected by tv_tpg_mutex */
+ int tv_tpg_vhost_count;
+ /* list for tcm_vhost_list */
+ struct list_head tv_tpg_list;
+ /* Used to protect access for tpg_nexus */
+ struct mutex tv_tpg_mutex;
+ /* Pointer to the TCM VHost I_T Nexus for this TPG endpoint */
+ struct tcm_vhost_nexus *tpg_nexus;
+ /* Pointer back to tcm_vhost_tport */
+ struct tcm_vhost_tport *tport;
+ /* Returned by tcm_vhost_make_tpg() */
+ struct se_portal_group se_tpg;
+ /* Pointer back to vhost_scsi, protected by tv_tpg_mutex */
+ struct vhost_scsi *vhost_scsi;
+};
+
+struct tcm_vhost_tport {
+ /* SCSI protocol the tport is providing */
+ u8 tport_proto_id;
+ /* Binary World Wide unique Port Name for Vhost Target port */
+ u64 tport_wwpn;
+ /* ASCII formatted WWPN for Vhost Target port */
+ char tport_name[TCM_VHOST_NAMELEN];
+ /* Returned by tcm_vhost_make_tport() */
+ struct se_wwn tport_wwn;
+};
+
+struct tcm_vhost_evt {
+ /* event to be sent to guest */
+ struct virtio_scsi_event event;
+ /* event list, serviced from vhost worker thread */
+ struct llist_node list;
+};
+
+enum {
+ VHOST_SCSI_VQ_CTL = 0,
+ VHOST_SCSI_VQ_EVT = 1,
+ VHOST_SCSI_VQ_IO = 2,
+};
+
+enum {
+ VHOST_SCSI_FEATURES = VHOST_FEATURES | (1ULL << VIRTIO_SCSI_F_HOTPLUG) |
+ (1ULL << VIRTIO_SCSI_F_T10_PI)
+};
+
+#define VHOST_SCSI_MAX_TARGET 256
+#define VHOST_SCSI_MAX_VQ 128
+#define VHOST_SCSI_MAX_EVENT 128
+
+struct vhost_scsi_virtqueue {
+ struct vhost_virtqueue vq;
+ /*
+ * Reference counting for inflight reqs, used for flush operation. At
+ * each time, one reference tracks new commands submitted, while we
+ * wait for another one to reach 0.
+ */
+ struct vhost_scsi_inflight inflights[2];
+ /*
+ * Indicate current inflight in use, protected by vq->mutex.
+ * Writers must also take dev mutex and flush under it.
+ */
+ int inflight_idx;
+};
+
+struct vhost_scsi {
+ /* Protected by vhost_scsi->dev.mutex */
+ struct tcm_vhost_tpg **vs_tpg;
+ char vs_vhost_wwpn[TRANSPORT_IQN_LEN];
+
+ struct vhost_dev dev;
+ struct vhost_scsi_virtqueue vqs[VHOST_SCSI_MAX_VQ];
+
+ struct vhost_work vs_completion_work; /* cmd completion work item */
+ struct llist_head vs_completion_list; /* cmd completion queue */
+
+ struct vhost_work vs_event_work; /* evt injection work item */
+ struct llist_head vs_event_list; /* evt injection queue */
+
+ bool vs_events_missed; /* any missed events, protected by vq->mutex */
+ int vs_events_nr; /* num of pending events, protected by vq->mutex */
+};
+
+/* Local pointer to allocated TCM configfs fabric module */
+static struct target_fabric_configfs *tcm_vhost_fabric_configfs;
+
+static struct workqueue_struct *tcm_vhost_workqueue;
+
+/* Global spinlock to protect tcm_vhost TPG list for vhost IOCTL access */
+static DEFINE_MUTEX(tcm_vhost_mutex);
+static LIST_HEAD(tcm_vhost_list);
+
+static int iov_num_pages(struct iovec *iov)
+{
+ return (PAGE_ALIGN((unsigned long)iov->iov_base + iov->iov_len) -
+ ((unsigned long)iov->iov_base & PAGE_MASK)) >> PAGE_SHIFT;
+}
+
+static void tcm_vhost_done_inflight(struct kref *kref)
+{
+ struct vhost_scsi_inflight *inflight;
+
+ inflight = container_of(kref, struct vhost_scsi_inflight, kref);
+ complete(&inflight->comp);
+}
+
+static void tcm_vhost_init_inflight(struct vhost_scsi *vs,
+ struct vhost_scsi_inflight *old_inflight[])
+{
+ struct vhost_scsi_inflight *new_inflight;
+ struct vhost_virtqueue *vq;
+ int idx, i;
+
+ for (i = 0; i < VHOST_SCSI_MAX_VQ; i++) {
+ vq = &vs->vqs[i].vq;
+
+ mutex_lock(&vq->mutex);
+
+ /* store old infight */
+ idx = vs->vqs[i].inflight_idx;
+ if (old_inflight)
+ old_inflight[i] = &vs->vqs[i].inflights[idx];
+
+ /* setup new infight */
+ vs->vqs[i].inflight_idx = idx ^ 1;
+ new_inflight = &vs->vqs[i].inflights[idx ^ 1];
+ kref_init(&new_inflight->kref);
+ init_completion(&new_inflight->comp);
+
+ mutex_unlock(&vq->mutex);
+ }
+}
+
+static struct vhost_scsi_inflight *
+tcm_vhost_get_inflight(struct vhost_virtqueue *vq)
+{
+ struct vhost_scsi_inflight *inflight;
+ struct vhost_scsi_virtqueue *svq;
+
+ svq = container_of(vq, struct vhost_scsi_virtqueue, vq);
+ inflight = &svq->inflights[svq->inflight_idx];
+ kref_get(&inflight->kref);
+
+ return inflight;
+}
+
+static void tcm_vhost_put_inflight(struct vhost_scsi_inflight *inflight)
+{
+ kref_put(&inflight->kref, tcm_vhost_done_inflight);
+}
+
+static int tcm_vhost_check_true(struct se_portal_group *se_tpg)
+{
+ return 1;
+}
+
+static int tcm_vhost_check_false(struct se_portal_group *se_tpg)
+{
+ return 0;
+}
+
+static char *tcm_vhost_get_fabric_name(void)
+{
+ return "vhost";
+}
+
+static u8 tcm_vhost_get_fabric_proto_ident(struct se_portal_group *se_tpg)
+{
+ struct tcm_vhost_tpg *tpg = container_of(se_tpg,
+ struct tcm_vhost_tpg, se_tpg);
+ struct tcm_vhost_tport *tport = tpg->tport;
+
+ switch (tport->tport_proto_id) {
+ case SCSI_PROTOCOL_SAS:
+ return sas_get_fabric_proto_ident(se_tpg);
+ case SCSI_PROTOCOL_FCP:
+ return fc_get_fabric_proto_ident(se_tpg);
+ case SCSI_PROTOCOL_ISCSI:
+ return iscsi_get_fabric_proto_ident(se_tpg);
+ default:
+ pr_err("Unknown tport_proto_id: 0x%02x, using"
+ " SAS emulation\n", tport->tport_proto_id);
+ break;
+ }
+
+ return sas_get_fabric_proto_ident(se_tpg);
+}
+
+static char *tcm_vhost_get_fabric_wwn(struct se_portal_group *se_tpg)
+{
+ struct tcm_vhost_tpg *tpg = container_of(se_tpg,
+ struct tcm_vhost_tpg, se_tpg);
+ struct tcm_vhost_tport *tport = tpg->tport;
+
+ return &tport->tport_name[0];
+}
+
+static u16 tcm_vhost_get_tag(struct se_portal_group *se_tpg)
+{
+ struct tcm_vhost_tpg *tpg = container_of(se_tpg,
+ struct tcm_vhost_tpg, se_tpg);
+ return tpg->tport_tpgt;
+}
+
+static u32 tcm_vhost_get_default_depth(struct se_portal_group *se_tpg)
+{
+ return 1;
+}
+
+static u32
+tcm_vhost_get_pr_transport_id(struct se_portal_group *se_tpg,
+ struct se_node_acl *se_nacl,
+ struct t10_pr_registration *pr_reg,
+ int *format_code,
+ unsigned char *buf)
+{
+ struct tcm_vhost_tpg *tpg = container_of(se_tpg,
+ struct tcm_vhost_tpg, se_tpg);
+ struct tcm_vhost_tport *tport = tpg->tport;
+
+ switch (tport->tport_proto_id) {
+ case SCSI_PROTOCOL_SAS:
+ return sas_get_pr_transport_id(se_tpg, se_nacl, pr_reg,
+ format_code, buf);
+ case SCSI_PROTOCOL_FCP:
+ return fc_get_pr_transport_id(se_tpg, se_nacl, pr_reg,
+ format_code, buf);
+ case SCSI_PROTOCOL_ISCSI:
+ return iscsi_get_pr_transport_id(se_tpg, se_nacl, pr_reg,
+ format_code, buf);
+ default:
+ pr_err("Unknown tport_proto_id: 0x%02x, using"
+ " SAS emulation\n", tport->tport_proto_id);
+ break;
+ }
+
+ return sas_get_pr_transport_id(se_tpg, se_nacl, pr_reg,
+ format_code, buf);
+}
+
+static u32
+tcm_vhost_get_pr_transport_id_len(struct se_portal_group *se_tpg,
+ struct se_node_acl *se_nacl,
+ struct t10_pr_registration *pr_reg,
+ int *format_code)
+{
+ struct tcm_vhost_tpg *tpg = container_of(se_tpg,
+ struct tcm_vhost_tpg, se_tpg);
+ struct tcm_vhost_tport *tport = tpg->tport;
+
+ switch (tport->tport_proto_id) {
+ case SCSI_PROTOCOL_SAS:
+ return sas_get_pr_transport_id_len(se_tpg, se_nacl, pr_reg,
+ format_code);
+ case SCSI_PROTOCOL_FCP:
+ return fc_get_pr_transport_id_len(se_tpg, se_nacl, pr_reg,
+ format_code);
+ case SCSI_PROTOCOL_ISCSI:
+ return iscsi_get_pr_transport_id_len(se_tpg, se_nacl, pr_reg,
+ format_code);
+ default:
+ pr_err("Unknown tport_proto_id: 0x%02x, using"
+ " SAS emulation\n", tport->tport_proto_id);
+ break;
+ }
+
+ return sas_get_pr_transport_id_len(se_tpg, se_nacl, pr_reg,
+ format_code);
+}
+
+static char *
+tcm_vhost_parse_pr_out_transport_id(struct se_portal_group *se_tpg,
+ const char *buf,
+ u32 *out_tid_len,
+ char **port_nexus_ptr)
+{
+ struct tcm_vhost_tpg *tpg = container_of(se_tpg,
+ struct tcm_vhost_tpg, se_tpg);
+ struct tcm_vhost_tport *tport = tpg->tport;
+
+ switch (tport->tport_proto_id) {
+ case SCSI_PROTOCOL_SAS:
+ return sas_parse_pr_out_transport_id(se_tpg, buf, out_tid_len,
+ port_nexus_ptr);
+ case SCSI_PROTOCOL_FCP:
+ return fc_parse_pr_out_transport_id(se_tpg, buf, out_tid_len,
+ port_nexus_ptr);
+ case SCSI_PROTOCOL_ISCSI:
+ return iscsi_parse_pr_out_transport_id(se_tpg, buf, out_tid_len,
+ port_nexus_ptr);
+ default:
+ pr_err("Unknown tport_proto_id: 0x%02x, using"
+ " SAS emulation\n", tport->tport_proto_id);
+ break;
+ }
+
+ return sas_parse_pr_out_transport_id(se_tpg, buf, out_tid_len,
+ port_nexus_ptr);
+}
+
+static struct se_node_acl *
+tcm_vhost_alloc_fabric_acl(struct se_portal_group *se_tpg)
+{
+ struct tcm_vhost_nacl *nacl;
+
+ nacl = kzalloc(sizeof(struct tcm_vhost_nacl), GFP_KERNEL);
+ if (!nacl) {
+ pr_err("Unable to allocate struct tcm_vhost_nacl\n");
+ return NULL;
+ }
+
+ return &nacl->se_node_acl;
+}
+
+static void
+tcm_vhost_release_fabric_acl(struct se_portal_group *se_tpg,
+ struct se_node_acl *se_nacl)
+{
+ struct tcm_vhost_nacl *nacl = container_of(se_nacl,
+ struct tcm_vhost_nacl, se_node_acl);
+ kfree(nacl);
+}
+
+static u32 tcm_vhost_tpg_get_inst_index(struct se_portal_group *se_tpg)
+{
+ return 1;
+}
+
+static void tcm_vhost_release_cmd(struct se_cmd *se_cmd)
+{
+ struct tcm_vhost_cmd *tv_cmd = container_of(se_cmd,
+ struct tcm_vhost_cmd, tvc_se_cmd);
+ struct se_session *se_sess = se_cmd->se_sess;
+ int i;
+
+ if (tv_cmd->tvc_sgl_count) {
+ for (i = 0; i < tv_cmd->tvc_sgl_count; i++)
+ put_page(sg_page(&tv_cmd->tvc_sgl[i]));
+ }
+ if (tv_cmd->tvc_prot_sgl_count) {
+ for (i = 0; i < tv_cmd->tvc_prot_sgl_count; i++)
+ put_page(sg_page(&tv_cmd->tvc_prot_sgl[i]));
+ }
+
+ tcm_vhost_put_inflight(tv_cmd->inflight);
+ percpu_ida_free(&se_sess->sess_tag_pool, se_cmd->map_tag);
+}
+
+static int tcm_vhost_shutdown_session(struct se_session *se_sess)
+{
+ return 0;
+}
+
+static void tcm_vhost_close_session(struct se_session *se_sess)
+{
+ return;
+}
+
+static u32 tcm_vhost_sess_get_index(struct se_session *se_sess)
+{
+ return 0;
+}
+
+static int tcm_vhost_write_pending(struct se_cmd *se_cmd)
+{
+ /* Go ahead and process the write immediately */
+ target_execute_cmd(se_cmd);
+ return 0;
+}
+
+static int tcm_vhost_write_pending_status(struct se_cmd *se_cmd)
+{
+ return 0;
+}
+
+static void tcm_vhost_set_default_node_attrs(struct se_node_acl *nacl)
+{
+ return;
+}
+
+static u32 tcm_vhost_get_task_tag(struct se_cmd *se_cmd)
+{
+ return 0;
+}
+
+static int tcm_vhost_get_cmd_state(struct se_cmd *se_cmd)
+{
+ return 0;
+}
+
+static void vhost_scsi_complete_cmd(struct tcm_vhost_cmd *cmd)
+{
+ struct vhost_scsi *vs = cmd->tvc_vhost;
+
+ llist_add(&cmd->tvc_completion_list, &vs->vs_completion_list);
+
+ vhost_work_queue(&vs->dev, &vs->vs_completion_work);
+}
+
+static int tcm_vhost_queue_data_in(struct se_cmd *se_cmd)
+{
+ struct tcm_vhost_cmd *cmd = container_of(se_cmd,
+ struct tcm_vhost_cmd, tvc_se_cmd);
+ vhost_scsi_complete_cmd(cmd);
+ return 0;
+}
+
+static int tcm_vhost_queue_status(struct se_cmd *se_cmd)
+{
+ struct tcm_vhost_cmd *cmd = container_of(se_cmd,
+ struct tcm_vhost_cmd, tvc_se_cmd);
+ vhost_scsi_complete_cmd(cmd);
+ return 0;
+}
+
+static void tcm_vhost_queue_tm_rsp(struct se_cmd *se_cmd)
+{
+ return;
+}
+
+static void tcm_vhost_aborted_task(struct se_cmd *se_cmd)
+{
+ return;
+}
+
+static void tcm_vhost_free_evt(struct vhost_scsi *vs, struct tcm_vhost_evt *evt)
+{
+ vs->vs_events_nr--;
+ kfree(evt);
+}
+
+static struct tcm_vhost_evt *
+tcm_vhost_allocate_evt(struct vhost_scsi *vs,
+ u32 event, u32 reason)
+{
+ struct vhost_virtqueue *vq = &vs->vqs[VHOST_SCSI_VQ_EVT].vq;
+ struct tcm_vhost_evt *evt;
+
+ if (vs->vs_events_nr > VHOST_SCSI_MAX_EVENT) {
+ vs->vs_events_missed = true;
+ return NULL;
+ }
+
+ evt = kzalloc(sizeof(*evt), GFP_KERNEL);
+ if (!evt) {
+ vq_err(vq, "Failed to allocate tcm_vhost_evt\n");
+ vs->vs_events_missed = true;
+ return NULL;
+ }
+
+ evt->event.event = event;
+ evt->event.reason = reason;
+ vs->vs_events_nr++;
+
+ return evt;
+}
+
+static void vhost_scsi_free_cmd(struct tcm_vhost_cmd *cmd)
+{
+ struct se_cmd *se_cmd = &cmd->tvc_se_cmd;
+
+ /* TODO locking against target/backend threads? */
+ transport_generic_free_cmd(se_cmd, 0);
+
+}
+
+static int vhost_scsi_check_stop_free(struct se_cmd *se_cmd)
+{
+ return target_put_sess_cmd(se_cmd->se_sess, se_cmd);
+}
+
+static void
+tcm_vhost_do_evt_work(struct vhost_scsi *vs, struct tcm_vhost_evt *evt)
+{
+ struct vhost_virtqueue *vq = &vs->vqs[VHOST_SCSI_VQ_EVT].vq;
+ struct virtio_scsi_event *event = &evt->event;
+ struct virtio_scsi_event __user *eventp;
+ unsigned out, in;
+ int head, ret;
+
+ if (!vq->private_data) {
+ vs->vs_events_missed = true;
+ return;
+ }
+
+again:
+ vhost_disable_notify(&vs->dev, vq);
+ head = vhost_get_vq_desc(vq, vq->iov,
+ ARRAY_SIZE(vq->iov), &out, &in,
+ NULL, NULL);
+ if (head < 0) {
+ vs->vs_events_missed = true;
+ return;
+ }
+ if (head == vq->num) {
+ if (vhost_enable_notify(&vs->dev, vq))
+ goto again;
+ vs->vs_events_missed = true;
+ return;
+ }
+
+ if ((vq->iov[out].iov_len != sizeof(struct virtio_scsi_event))) {
+ vq_err(vq, "Expecting virtio_scsi_event, got %zu bytes\n",
+ vq->iov[out].iov_len);
+ vs->vs_events_missed = true;
+ return;
+ }
+
+ if (vs->vs_events_missed) {
+ event->event |= VIRTIO_SCSI_T_EVENTS_MISSED;
+ vs->vs_events_missed = false;
+ }
+
+ eventp = vq->iov[out].iov_base;
+ ret = __copy_to_user(eventp, event, sizeof(*event));
+ if (!ret)
+ vhost_add_used_and_signal(&vs->dev, vq, head, 0);
+ else
+ vq_err(vq, "Faulted on tcm_vhost_send_event\n");
+}
+
+static void tcm_vhost_evt_work(struct vhost_work *work)
+{
+ struct vhost_scsi *vs = container_of(work, struct vhost_scsi,
+ vs_event_work);
+ struct vhost_virtqueue *vq = &vs->vqs[VHOST_SCSI_VQ_EVT].vq;
+ struct tcm_vhost_evt *evt;
+ struct llist_node *llnode;
+
+ mutex_lock(&vq->mutex);
+ llnode = llist_del_all(&vs->vs_event_list);
+ while (llnode) {
+ evt = llist_entry(llnode, struct tcm_vhost_evt, list);
+ llnode = llist_next(llnode);
+ tcm_vhost_do_evt_work(vs, evt);
+ tcm_vhost_free_evt(vs, evt);
+ }
+ mutex_unlock(&vq->mutex);
+}
+
+/* Fill in status and signal that we are done processing this command
+ *
+ * This is scheduled in the vhost work queue so we are called with the owner
+ * process mm and can access the vring.
+ */
+static void vhost_scsi_complete_cmd_work(struct vhost_work *work)
+{
+ struct vhost_scsi *vs = container_of(work, struct vhost_scsi,
+ vs_completion_work);
+ DECLARE_BITMAP(signal, VHOST_SCSI_MAX_VQ);
+ struct virtio_scsi_cmd_resp v_rsp;
+ struct tcm_vhost_cmd *cmd;
+ struct llist_node *llnode;
+ struct se_cmd *se_cmd;
+ int ret, vq;
+
+ bitmap_zero(signal, VHOST_SCSI_MAX_VQ);
+ llnode = llist_del_all(&vs->vs_completion_list);
+ while (llnode) {
+ cmd = llist_entry(llnode, struct tcm_vhost_cmd,
+ tvc_completion_list);
+ llnode = llist_next(llnode);
+ se_cmd = &cmd->tvc_se_cmd;
+
+ pr_debug("%s tv_cmd %p resid %u status %#02x\n", __func__,
+ cmd, se_cmd->residual_count, se_cmd->scsi_status);
+
+ memset(&v_rsp, 0, sizeof(v_rsp));
+ v_rsp.resid = se_cmd->residual_count;
+ /* TODO is status_qualifier field needed? */
+ v_rsp.status = se_cmd->scsi_status;
+ v_rsp.sense_len = se_cmd->scsi_sense_length;
+ memcpy(v_rsp.sense, cmd->tvc_sense_buf,
+ v_rsp.sense_len);
+ ret = copy_to_user(cmd->tvc_resp, &v_rsp, sizeof(v_rsp));
+ if (likely(ret == 0)) {
+ struct vhost_scsi_virtqueue *q;
+ vhost_add_used(cmd->tvc_vq, cmd->tvc_vq_desc, 0);
+ q = container_of(cmd->tvc_vq, struct vhost_scsi_virtqueue, vq);
+ vq = q - vs->vqs;
+ __set_bit(vq, signal);
+ } else
+ pr_err("Faulted on virtio_scsi_cmd_resp\n");
+
+ vhost_scsi_free_cmd(cmd);
+ }
+
+ vq = -1;
+ while ((vq = find_next_bit(signal, VHOST_SCSI_MAX_VQ, vq + 1))
+ < VHOST_SCSI_MAX_VQ)
+ vhost_signal(&vs->dev, &vs->vqs[vq].vq);
+}
+
+static struct tcm_vhost_cmd *
+vhost_scsi_get_tag(struct vhost_virtqueue *vq, struct tcm_vhost_tpg *tpg,
+ unsigned char *cdb, u64 scsi_tag, u16 lun, u8 task_attr,
+ u32 exp_data_len, int data_direction)
+{
+ struct tcm_vhost_cmd *cmd;
+ struct tcm_vhost_nexus *tv_nexus;
+ struct se_session *se_sess;
+ struct scatterlist *sg, *prot_sg;
+ struct page **pages;
+ int tag;
+
+ tv_nexus = tpg->tpg_nexus;
+ if (!tv_nexus) {
+ pr_err("Unable to locate active struct tcm_vhost_nexus\n");
+ return ERR_PTR(-EIO);
+ }
+ se_sess = tv_nexus->tvn_se_sess;
+
+ tag = percpu_ida_alloc(&se_sess->sess_tag_pool, TASK_RUNNING);
+ if (tag < 0) {
+ pr_err("Unable to obtain tag for tcm_vhost_cmd\n");
+ return ERR_PTR(-ENOMEM);
+ }
+
+ cmd = &((struct tcm_vhost_cmd *)se_sess->sess_cmd_map)[tag];
+ sg = cmd->tvc_sgl;
+ prot_sg = cmd->tvc_prot_sgl;
+ pages = cmd->tvc_upages;
+ memset(cmd, 0, sizeof(struct tcm_vhost_cmd));
+
+ cmd->tvc_sgl = sg;
+ cmd->tvc_prot_sgl = prot_sg;
+ cmd->tvc_upages = pages;
+ cmd->tvc_se_cmd.map_tag = tag;
+ cmd->tvc_tag = scsi_tag;
+ cmd->tvc_lun = lun;
+ cmd->tvc_task_attr = task_attr;
+ cmd->tvc_exp_data_len = exp_data_len;
+ cmd->tvc_data_direction = data_direction;
+ cmd->tvc_nexus = tv_nexus;
+ cmd->inflight = tcm_vhost_get_inflight(vq);
+
+ memcpy(cmd->tvc_cdb, cdb, TCM_VHOST_MAX_CDB_SIZE);
+
+ return cmd;
+}
+
+/*
+ * Map a user memory range into a scatterlist
+ *
+ * Returns the number of scatterlist entries used or -errno on error.
+ */
+static int
+vhost_scsi_map_to_sgl(struct tcm_vhost_cmd *tv_cmd,
+ struct scatterlist *sgl,
+ unsigned int sgl_count,
+ struct iovec *iov,
+ struct page **pages,
+ bool write)
+{
+ unsigned int npages = 0, pages_nr, offset, nbytes;
+ struct scatterlist *sg = sgl;
+ void __user *ptr = iov->iov_base;
+ size_t len = iov->iov_len;
+ int ret, i;
+
+ pages_nr = iov_num_pages(iov);
+ if (pages_nr > sgl_count) {
+ pr_err("vhost_scsi_map_to_sgl() pages_nr: %u greater than"
+ " sgl_count: %u\n", pages_nr, sgl_count);
+ return -ENOBUFS;
+ }
+ if (pages_nr > TCM_VHOST_PREALLOC_UPAGES) {
+ pr_err("vhost_scsi_map_to_sgl() pages_nr: %u greater than"
+ " preallocated TCM_VHOST_PREALLOC_UPAGES: %u\n",
+ pages_nr, TCM_VHOST_PREALLOC_UPAGES);
+ return -ENOBUFS;
+ }
+
+ ret = get_user_pages_fast((unsigned long)ptr, pages_nr, write, pages);
+ /* No pages were pinned */
+ if (ret < 0)
+ goto out;
+ /* Less pages pinned than wanted */
+ if (ret != pages_nr) {
+ for (i = 0; i < ret; i++)
+ put_page(pages[i]);
+ ret = -EFAULT;
+ goto out;
+ }
+
+ while (len > 0) {
+ offset = (uintptr_t)ptr & ~PAGE_MASK;
+ nbytes = min_t(unsigned int, PAGE_SIZE - offset, len);
+ sg_set_page(sg, pages[npages], nbytes, offset);
+ ptr += nbytes;
+ len -= nbytes;
+ sg++;
+ npages++;
+ }
+
+out:
+ return ret;
+}
+
+static int
+vhost_scsi_map_iov_to_sgl(struct tcm_vhost_cmd *cmd,
+ struct iovec *iov,
+ int niov,
+ bool write)
+{
+ struct scatterlist *sg = cmd->tvc_sgl;
+ unsigned int sgl_count = 0;
+ int ret, i;
+
+ for (i = 0; i < niov; i++)
+ sgl_count += iov_num_pages(&iov[i]);
+
+ if (sgl_count > TCM_VHOST_PREALLOC_SGLS) {
+ pr_err("vhost_scsi_map_iov_to_sgl() sgl_count: %u greater than"
+ " preallocated TCM_VHOST_PREALLOC_SGLS: %u\n",
+ sgl_count, TCM_VHOST_PREALLOC_SGLS);
+ return -ENOBUFS;
+ }
+
+ pr_debug("%s sg %p sgl_count %u\n", __func__, sg, sgl_count);
+ sg_init_table(sg, sgl_count);
+ cmd->tvc_sgl_count = sgl_count;
+
+ pr_debug("Mapping iovec %p for %u pages\n", &iov[0], sgl_count);
+
+ for (i = 0; i < niov; i++) {
+ ret = vhost_scsi_map_to_sgl(cmd, sg, sgl_count, &iov[i],
+ cmd->tvc_upages, write);
+ if (ret < 0) {
+ for (i = 0; i < cmd->tvc_sgl_count; i++)
+ put_page(sg_page(&cmd->tvc_sgl[i]));
+
+ cmd->tvc_sgl_count = 0;
+ return ret;
+ }
+ sg += ret;
+ sgl_count -= ret;
+ }
+ return 0;
+}
+
+static int
+vhost_scsi_map_iov_to_prot(struct tcm_vhost_cmd *cmd,
+ struct iovec *iov,
+ int niov,
+ bool write)
+{
+ struct scatterlist *prot_sg = cmd->tvc_prot_sgl;
+ unsigned int prot_sgl_count = 0;
+ int ret, i;
+
+ for (i = 0; i < niov; i++)
+ prot_sgl_count += iov_num_pages(&iov[i]);
+
+ if (prot_sgl_count > TCM_VHOST_PREALLOC_PROT_SGLS) {
+ pr_err("vhost_scsi_map_iov_to_prot() sgl_count: %u greater than"
+ " preallocated TCM_VHOST_PREALLOC_PROT_SGLS: %u\n",
+ prot_sgl_count, TCM_VHOST_PREALLOC_PROT_SGLS);
+ return -ENOBUFS;
+ }
+
+ pr_debug("%s prot_sg %p prot_sgl_count %u\n", __func__,
+ prot_sg, prot_sgl_count);
+ sg_init_table(prot_sg, prot_sgl_count);
+ cmd->tvc_prot_sgl_count = prot_sgl_count;
+
+ for (i = 0; i < niov; i++) {
+ ret = vhost_scsi_map_to_sgl(cmd, prot_sg, prot_sgl_count, &iov[i],
+ cmd->tvc_upages, write);
+ if (ret < 0) {
+ for (i = 0; i < cmd->tvc_prot_sgl_count; i++)
+ put_page(sg_page(&cmd->tvc_prot_sgl[i]));
+
+ cmd->tvc_prot_sgl_count = 0;
+ return ret;
+ }
+ prot_sg += ret;
+ prot_sgl_count -= ret;
+ }
+ return 0;
+}
+
+static void tcm_vhost_submission_work(struct work_struct *work)
+{
+ struct tcm_vhost_cmd *cmd =
+ container_of(work, struct tcm_vhost_cmd, work);
+ struct tcm_vhost_nexus *tv_nexus;
+ struct se_cmd *se_cmd = &cmd->tvc_se_cmd;
+ struct scatterlist *sg_ptr, *sg_prot_ptr = NULL;
+ int rc;
+
+ /* FIXME: BIDI operation */
+ if (cmd->tvc_sgl_count) {
+ sg_ptr = cmd->tvc_sgl;
+
+ if (cmd->tvc_prot_sgl_count)
+ sg_prot_ptr = cmd->tvc_prot_sgl;
+ else
+ se_cmd->prot_pto = true;
+ } else {
+ sg_ptr = NULL;
+ }
+ tv_nexus = cmd->tvc_nexus;
+
+ rc = target_submit_cmd_map_sgls(se_cmd, tv_nexus->tvn_se_sess,
+ cmd->tvc_cdb, &cmd->tvc_sense_buf[0],
+ cmd->tvc_lun, cmd->tvc_exp_data_len,
+ cmd->tvc_task_attr, cmd->tvc_data_direction,
+ TARGET_SCF_ACK_KREF, sg_ptr, cmd->tvc_sgl_count,
+ NULL, 0, sg_prot_ptr, cmd->tvc_prot_sgl_count);
+ if (rc < 0) {
+ transport_send_check_condition_and_sense(se_cmd,
+ TCM_LOGICAL_UNIT_COMMUNICATION_FAILURE, 0);
+ transport_generic_free_cmd(se_cmd, 0);
+ }
+}
+
+static void
+vhost_scsi_send_bad_target(struct vhost_scsi *vs,
+ struct vhost_virtqueue *vq,
+ int head, unsigned out)
+{
+ struct virtio_scsi_cmd_resp __user *resp;
+ struct virtio_scsi_cmd_resp rsp;
+ int ret;
+
+ memset(&rsp, 0, sizeof(rsp));
+ rsp.response = VIRTIO_SCSI_S_BAD_TARGET;
+ resp = vq->iov[out].iov_base;
+ ret = __copy_to_user(resp, &rsp, sizeof(rsp));
+ if (!ret)
+ vhost_add_used_and_signal(&vs->dev, vq, head, 0);
+ else
+ pr_err("Faulted on virtio_scsi_cmd_resp\n");
+}
+
+static void
+vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)
+{
+ struct tcm_vhost_tpg **vs_tpg;
+ struct virtio_scsi_cmd_req v_req;
+ struct virtio_scsi_cmd_req_pi v_req_pi;
+ struct tcm_vhost_tpg *tpg;
+ struct tcm_vhost_cmd *cmd;
+ u64 tag;
+ u32 exp_data_len, data_first, data_num, data_direction, prot_first;
+ unsigned out, in, i;
+ int head, ret, data_niov, prot_niov, prot_bytes;
+ size_t req_size;
+ u16 lun;
+ u8 *target, *lunp, task_attr;
+ bool hdr_pi;
+ void *req, *cdb;
+
+ mutex_lock(&vq->mutex);
+ /*
+ * We can handle the vq only after the endpoint is setup by calling the
+ * VHOST_SCSI_SET_ENDPOINT ioctl.
+ */
+ vs_tpg = vq->private_data;
+ if (!vs_tpg)
+ goto out;
+
+ vhost_disable_notify(&vs->dev, vq);
+
+ for (;;) {
+ head = vhost_get_vq_desc(vq, vq->iov,
+ ARRAY_SIZE(vq->iov), &out, &in,
+ NULL, NULL);
+ pr_debug("vhost_get_vq_desc: head: %d, out: %u in: %u\n",
+ head, out, in);
+ /* On error, stop handling until the next kick. */
+ if (unlikely(head < 0))
+ break;
+ /* Nothing new? Wait for eventfd to tell us they refilled. */
+ if (head == vq->num) {
+ if (unlikely(vhost_enable_notify(&vs->dev, vq))) {
+ vhost_disable_notify(&vs->dev, vq);
+ continue;
+ }
+ break;
+ }
+
+ /* FIXME: BIDI operation */
+ if (out == 1 && in == 1) {
+ data_direction = DMA_NONE;
+ data_first = 0;
+ data_num = 0;
+ } else if (out == 1 && in > 1) {
+ data_direction = DMA_FROM_DEVICE;
+ data_first = out + 1;
+ data_num = in - 1;
+ } else if (out > 1 && in == 1) {
+ data_direction = DMA_TO_DEVICE;
+ data_first = 1;
+ data_num = out - 1;
+ } else {
+ vq_err(vq, "Invalid buffer layout out: %u in: %u\n",
+ out, in);
+ break;
+ }
+
+ /*
+ * Check for a sane resp buffer so we can report errors to
+ * the guest.
+ */
+ if (unlikely(vq->iov[out].iov_len !=
+ sizeof(struct virtio_scsi_cmd_resp))) {
+ vq_err(vq, "Expecting virtio_scsi_cmd_resp, got %zu"
+ " bytes\n", vq->iov[out].iov_len);
+ break;
+ }
+
+ if (vhost_has_feature(vq, VIRTIO_SCSI_F_T10_PI)) {
+ req = &v_req_pi;
+ lunp = &v_req_pi.lun[0];
+ target = &v_req_pi.lun[1];
+ req_size = sizeof(v_req_pi);
+ hdr_pi = true;
+ } else {
+ req = &v_req;
+ lunp = &v_req.lun[0];
+ target = &v_req.lun[1];
+ req_size = sizeof(v_req);
+ hdr_pi = false;
+ }
+
+ if (unlikely(vq->iov[0].iov_len < req_size)) {
+ pr_err("Expecting virtio-scsi header: %zu, got %zu\n",
+ req_size, vq->iov[0].iov_len);
+ break;
+ }
+ ret = memcpy_fromiovecend(req, &vq->iov[0], 0, req_size);
+ if (unlikely(ret)) {
+ vq_err(vq, "Faulted on virtio_scsi_cmd_req\n");
+ break;
+ }
+
+ /* virtio-scsi spec requires byte 0 of the lun to be 1 */
+ if (unlikely(*lunp != 1)) {
+ vhost_scsi_send_bad_target(vs, vq, head, out);
+ continue;
+ }
+
+ tpg = ACCESS_ONCE(vs_tpg[*target]);
+
+ /* Target does not exist, fail the request */
+ if (unlikely(!tpg)) {
+ vhost_scsi_send_bad_target(vs, vq, head, out);
+ continue;
+ }
+
+ data_niov = data_num;
+ prot_niov = prot_first = prot_bytes = 0;
+ /*
+ * Determine if any protection information iovecs are preceeding
+ * the actual data payload, and adjust data_first + data_niov
+ * values accordingly for vhost_scsi_map_iov_to_sgl() below.
+ *
+ * Also extract virtio_scsi header bits for vhost_scsi_get_tag()
+ */
+ if (hdr_pi) {
+ if (v_req_pi.pi_bytesout) {
+ if (data_direction != DMA_TO_DEVICE) {
+ vq_err(vq, "Received non zero do_pi_niov"
+ ", but wrong data_direction\n");
+ goto err_cmd;
+ }
+ prot_bytes = v_req_pi.pi_bytesout;
+ } else if (v_req_pi.pi_bytesin) {
+ if (data_direction != DMA_FROM_DEVICE) {
+ vq_err(vq, "Received non zero di_pi_niov"
+ ", but wrong data_direction\n");
+ goto err_cmd;
+ }
+ prot_bytes = v_req_pi.pi_bytesin;
+ }
+ if (prot_bytes) {
+ int tmp = 0;
+
+ for (i = 0; i < data_num; i++) {
+ tmp += vq->iov[data_first + i].iov_len;
+ prot_niov++;
+ if (tmp >= prot_bytes)
+ break;
+ }
+ prot_first = data_first;
+ data_first += prot_niov;
+ data_niov = data_num - prot_niov;
+ }
+ tag = v_req_pi.tag;
+ task_attr = v_req_pi.task_attr;
+ cdb = &v_req_pi.cdb[0];
+ lun = ((v_req_pi.lun[2] << 8) | v_req_pi.lun[3]) & 0x3FFF;
+ } else {
+ tag = v_req.tag;
+ task_attr = v_req.task_attr;
+ cdb = &v_req.cdb[0];
+ lun = ((v_req.lun[2] << 8) | v_req.lun[3]) & 0x3FFF;
+ }
+ exp_data_len = 0;
+ for (i = 0; i < data_niov; i++)
+ exp_data_len += vq->iov[data_first + i].iov_len;
+ /*
+ * Check that the recieved CDB size does not exceeded our
+ * hardcoded max for vhost-scsi
+ *
+ * TODO what if cdb was too small for varlen cdb header?
+ */
+ if (unlikely(scsi_command_size(cdb) > TCM_VHOST_MAX_CDB_SIZE)) {
+ vq_err(vq, "Received SCSI CDB with command_size: %d that"
+ " exceeds SCSI_MAX_VARLEN_CDB_SIZE: %d\n",
+ scsi_command_size(cdb), TCM_VHOST_MAX_CDB_SIZE);
+ goto err_cmd;
+ }
+
+ cmd = vhost_scsi_get_tag(vq, tpg, cdb, tag, lun, task_attr,
+ exp_data_len + prot_bytes,
+ data_direction);
+ if (IS_ERR(cmd)) {
+ vq_err(vq, "vhost_scsi_get_tag failed %ld\n",
+ PTR_ERR(cmd));
+ goto err_cmd;
+ }
+
+ pr_debug("Allocated tv_cmd: %p exp_data_len: %d, data_direction"
+ ": %d\n", cmd, exp_data_len, data_direction);
+
+ cmd->tvc_vhost = vs;
+ cmd->tvc_vq = vq;
+ cmd->tvc_resp = vq->iov[out].iov_base;
+
+ pr_debug("vhost_scsi got command opcode: %#02x, lun: %d\n",
+ cmd->tvc_cdb[0], cmd->tvc_lun);
+
+ if (prot_niov) {
+ ret = vhost_scsi_map_iov_to_prot(cmd,
+ &vq->iov[prot_first], prot_niov,
+ data_direction == DMA_FROM_DEVICE);
+ if (unlikely(ret)) {
+ vq_err(vq, "Failed to map iov to"
+ " prot_sgl\n");
+ goto err_free;
+ }
+ }
+ if (data_direction != DMA_NONE) {
+ ret = vhost_scsi_map_iov_to_sgl(cmd,
+ &vq->iov[data_first], data_niov,
+ data_direction == DMA_FROM_DEVICE);
+ if (unlikely(ret)) {
+ vq_err(vq, "Failed to map iov to sgl\n");
+ goto err_free;
+ }
+ }
+ /*
+ * Save the descriptor from vhost_get_vq_desc() to be used to
+ * complete the virtio-scsi request in TCM callback context via
+ * tcm_vhost_queue_data_in() and tcm_vhost_queue_status()
+ */
+ cmd->tvc_vq_desc = head;
+ /*
+ * Dispatch tv_cmd descriptor for cmwq execution in process
+ * context provided by tcm_vhost_workqueue. This also ensures
+ * tv_cmd is executed on the same kworker CPU as this vhost
+ * thread to gain positive L2 cache locality effects..
+ */
+ INIT_WORK(&cmd->work, tcm_vhost_submission_work);
+ queue_work(tcm_vhost_workqueue, &cmd->work);
+ }
+
+ mutex_unlock(&vq->mutex);
+ return;
+
+err_free:
+ vhost_scsi_free_cmd(cmd);
+err_cmd:
+ vhost_scsi_send_bad_target(vs, vq, head, out);
+out:
+ mutex_unlock(&vq->mutex);
+}
+
+static void vhost_scsi_ctl_handle_kick(struct vhost_work *work)
+{
+ pr_debug("%s: The handling func for control queue.\n", __func__);
+}
+
+static void
+tcm_vhost_send_evt(struct vhost_scsi *vs,
+ struct tcm_vhost_tpg *tpg,
+ struct se_lun *lun,
+ u32 event,
+ u32 reason)
+{
+ struct tcm_vhost_evt *evt;
+
+ evt = tcm_vhost_allocate_evt(vs, event, reason);
+ if (!evt)
+ return;
+
+ if (tpg && lun) {
+ /* TODO: share lun setup code with virtio-scsi.ko */
+ /*
+ * Note: evt->event is zeroed when we allocate it and
+ * lun[4-7] need to be zero according to virtio-scsi spec.
+ */
+ evt->event.lun[0] = 0x01;
+ evt->event.lun[1] = tpg->tport_tpgt & 0xFF;
+ if (lun->unpacked_lun >= 256)
+ evt->event.lun[2] = lun->unpacked_lun >> 8 | 0x40 ;
+ evt->event.lun[3] = lun->unpacked_lun & 0xFF;
+ }
+
+ llist_add(&evt->list, &vs->vs_event_list);
+ vhost_work_queue(&vs->dev, &vs->vs_event_work);
+}
+
+static void vhost_scsi_evt_handle_kick(struct vhost_work *work)
+{
+ struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
+ poll.work);
+ struct vhost_scsi *vs = container_of(vq->dev, struct vhost_scsi, dev);
+
+ mutex_lock(&vq->mutex);
+ if (!vq->private_data)
+ goto out;
+
+ if (vs->vs_events_missed)
+ tcm_vhost_send_evt(vs, NULL, NULL, VIRTIO_SCSI_T_NO_EVENT, 0);
+out:
+ mutex_unlock(&vq->mutex);
+}
+
+static void vhost_scsi_handle_kick(struct vhost_work *work)
+{
+ struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
+ poll.work);
+ struct vhost_scsi *vs = container_of(vq->dev, struct vhost_scsi, dev);
+
+ vhost_scsi_handle_vq(vs, vq);
+}
+
+static void vhost_scsi_flush_vq(struct vhost_scsi *vs, int index)
+{
+ vhost_poll_flush(&vs->vqs[index].vq.poll);
+}
+
+/* Callers must hold dev mutex */
+static void vhost_scsi_flush(struct vhost_scsi *vs)
+{
+ struct vhost_scsi_inflight *old_inflight[VHOST_SCSI_MAX_VQ];
+ int i;
+
+ /* Init new inflight and remember the old inflight */
+ tcm_vhost_init_inflight(vs, old_inflight);
+
+ /*
+ * The inflight->kref was initialized to 1. We decrement it here to
+ * indicate the start of the flush operation so that it will reach 0
+ * when all the reqs are finished.
+ */
+ for (i = 0; i < VHOST_SCSI_MAX_VQ; i++)
+ kref_put(&old_inflight[i]->kref, tcm_vhost_done_inflight);
+
+ /* Flush both the vhost poll and vhost work */
+ for (i = 0; i < VHOST_SCSI_MAX_VQ; i++)
+ vhost_scsi_flush_vq(vs, i);
+ vhost_work_flush(&vs->dev, &vs->vs_completion_work);
+ vhost_work_flush(&vs->dev, &vs->vs_event_work);
+
+ /* Wait for all reqs issued before the flush to be finished */
+ for (i = 0; i < VHOST_SCSI_MAX_VQ; i++)
+ wait_for_completion(&old_inflight[i]->comp);
+}
+
+/*
+ * Called from vhost_scsi_ioctl() context to walk the list of available
+ * tcm_vhost_tpg with an active struct tcm_vhost_nexus
+ *
+ * The lock nesting rule is:
+ * tcm_vhost_mutex -> vs->dev.mutex -> tpg->tv_tpg_mutex -> vq->mutex
+ */
+static int
+vhost_scsi_set_endpoint(struct vhost_scsi *vs,
+ struct vhost_scsi_target *t)
+{
+ struct tcm_vhost_tport *tv_tport;
+ struct tcm_vhost_tpg *tpg;
+ struct tcm_vhost_tpg **vs_tpg;
+ struct vhost_virtqueue *vq;
+ int index, ret, i, len;
+ bool match = false;
+
+ mutex_lock(&tcm_vhost_mutex);
+ mutex_lock(&vs->dev.mutex);
+
+ /* Verify that ring has been setup correctly. */
+ for (index = 0; index < vs->dev.nvqs; ++index) {
+ /* Verify that ring has been setup correctly. */
+ if (!vhost_vq_access_ok(&vs->vqs[index].vq)) {
+ ret = -EFAULT;
+ goto out;
+ }
+ }
+
+ len = sizeof(vs_tpg[0]) * VHOST_SCSI_MAX_TARGET;
+ vs_tpg = kzalloc(len, GFP_KERNEL);
+ if (!vs_tpg) {
+ ret = -ENOMEM;
+ goto out;
+ }
+ if (vs->vs_tpg)
+ memcpy(vs_tpg, vs->vs_tpg, len);
+
+ list_for_each_entry(tpg, &tcm_vhost_list, tv_tpg_list) {
+ mutex_lock(&tpg->tv_tpg_mutex);
+ if (!tpg->tpg_nexus) {
+ mutex_unlock(&tpg->tv_tpg_mutex);
+ continue;
+ }
+ if (tpg->tv_tpg_vhost_count != 0) {
+ mutex_unlock(&tpg->tv_tpg_mutex);
+ continue;
+ }
+ tv_tport = tpg->tport;
+
+ if (!strcmp(tv_tport->tport_name, t->vhost_wwpn)) {
+ if (vs->vs_tpg && vs->vs_tpg[tpg->tport_tpgt]) {
+ kfree(vs_tpg);
+ mutex_unlock(&tpg->tv_tpg_mutex);
+ ret = -EEXIST;
+ goto out;
+ }
+ tpg->tv_tpg_vhost_count++;
+ tpg->vhost_scsi = vs;
+ vs_tpg[tpg->tport_tpgt] = tpg;
+ smp_mb__after_atomic();
+ match = true;
+ }
+ mutex_unlock(&tpg->tv_tpg_mutex);
+ }
+
+ if (match) {
+ memcpy(vs->vs_vhost_wwpn, t->vhost_wwpn,
+ sizeof(vs->vs_vhost_wwpn));
+ for (i = 0; i < VHOST_SCSI_MAX_VQ; i++) {
+ vq = &vs->vqs[i].vq;
+ mutex_lock(&vq->mutex);
+ vq->private_data = vs_tpg;
+ vhost_init_used(vq);
+ mutex_unlock(&vq->mutex);
+ }
+ ret = 0;
+ } else {
+ ret = -EEXIST;
+ }
+
+ /*
+ * Act as synchronize_rcu to make sure access to
+ * old vs->vs_tpg is finished.
+ */
+ vhost_scsi_flush(vs);
+ kfree(vs->vs_tpg);
+ vs->vs_tpg = vs_tpg;
+
+out:
+ mutex_unlock(&vs->dev.mutex);
+ mutex_unlock(&tcm_vhost_mutex);
+ return ret;
+}
+
+static int
+vhost_scsi_clear_endpoint(struct vhost_scsi *vs,
+ struct vhost_scsi_target *t)
+{
+ struct tcm_vhost_tport *tv_tport;
+ struct tcm_vhost_tpg *tpg;
+ struct vhost_virtqueue *vq;
+ bool match = false;
+ int index, ret, i;
+ u8 target;
+
+ mutex_lock(&tcm_vhost_mutex);
+ mutex_lock(&vs->dev.mutex);
+ /* Verify that ring has been setup correctly. */
+ for (index = 0; index < vs->dev.nvqs; ++index) {
+ if (!vhost_vq_access_ok(&vs->vqs[index].vq)) {
+ ret = -EFAULT;
+ goto err_dev;
+ }
+ }
+
+ if (!vs->vs_tpg) {
+ ret = 0;
+ goto err_dev;
+ }
+
+ for (i = 0; i < VHOST_SCSI_MAX_TARGET; i++) {
+ target = i;
+ tpg = vs->vs_tpg[target];
+ if (!tpg)
+ continue;
+
+ mutex_lock(&tpg->tv_tpg_mutex);
+ tv_tport = tpg->tport;
+ if (!tv_tport) {
+ ret = -ENODEV;
+ goto err_tpg;
+ }
+
+ if (strcmp(tv_tport->tport_name, t->vhost_wwpn)) {
+ pr_warn("tv_tport->tport_name: %s, tpg->tport_tpgt: %hu"
+ " does not match t->vhost_wwpn: %s, t->vhost_tpgt: %hu\n",
+ tv_tport->tport_name, tpg->tport_tpgt,
+ t->vhost_wwpn, t->vhost_tpgt);
+ ret = -EINVAL;
+ goto err_tpg;
+ }
+ tpg->tv_tpg_vhost_count--;
+ tpg->vhost_scsi = NULL;
+ vs->vs_tpg[target] = NULL;
+ match = true;
+ mutex_unlock(&tpg->tv_tpg_mutex);
+ }
+ if (match) {
+ for (i = 0; i < VHOST_SCSI_MAX_VQ; i++) {
+ vq = &vs->vqs[i].vq;
+ mutex_lock(&vq->mutex);
+ vq->private_data = NULL;
+ mutex_unlock(&vq->mutex);
+ }
+ }
+ /*
+ * Act as synchronize_rcu to make sure access to
+ * old vs->vs_tpg is finished.
+ */
+ vhost_scsi_flush(vs);
+ kfree(vs->vs_tpg);
+ vs->vs_tpg = NULL;
+ WARN_ON(vs->vs_events_nr);
+ mutex_unlock(&vs->dev.mutex);
+ mutex_unlock(&tcm_vhost_mutex);
+ return 0;
+
+err_tpg:
+ mutex_unlock(&tpg->tv_tpg_mutex);
+err_dev:
+ mutex_unlock(&vs->dev.mutex);
+ mutex_unlock(&tcm_vhost_mutex);
+ return ret;
+}
+
+static int vhost_scsi_set_features(struct vhost_scsi *vs, u64 features)
+{
+ struct vhost_virtqueue *vq;
+ int i;
+
+ if (features & ~VHOST_SCSI_FEATURES)
+ return -EOPNOTSUPP;
+
+ mutex_lock(&vs->dev.mutex);
+ if ((features & (1 << VHOST_F_LOG_ALL)) &&
+ !vhost_log_access_ok(&vs->dev)) {
+ mutex_unlock(&vs->dev.mutex);
+ return -EFAULT;
+ }
+
+ for (i = 0; i < VHOST_SCSI_MAX_VQ; i++) {
+ vq = &vs->vqs[i].vq;
+ mutex_lock(&vq->mutex);
+ vq->acked_features = features;
+ mutex_unlock(&vq->mutex);
+ }
+ mutex_unlock(&vs->dev.mutex);
+ return 0;
+}
+
+static int vhost_scsi_open(struct inode *inode, struct file *f)
+{
+ struct vhost_scsi *vs;
+ struct vhost_virtqueue **vqs;
+ int r = -ENOMEM, i;
+
+ vs = kzalloc(sizeof(*vs), GFP_KERNEL | __GFP_NOWARN | __GFP_REPEAT);
+ if (!vs) {
+ vs = vzalloc(sizeof(*vs));
+ if (!vs)
+ goto err_vs;
+ }
+
+ vqs = kmalloc(VHOST_SCSI_MAX_VQ * sizeof(*vqs), GFP_KERNEL);
+ if (!vqs)
+ goto err_vqs;
+
+ vhost_work_init(&vs->vs_completion_work, vhost_scsi_complete_cmd_work);
+ vhost_work_init(&vs->vs_event_work, tcm_vhost_evt_work);
+
+ vs->vs_events_nr = 0;
+ vs->vs_events_missed = false;
+
+ vqs[VHOST_SCSI_VQ_CTL] = &vs->vqs[VHOST_SCSI_VQ_CTL].vq;
+ vqs[VHOST_SCSI_VQ_EVT] = &vs->vqs[VHOST_SCSI_VQ_EVT].vq;
+ vs->vqs[VHOST_SCSI_VQ_CTL].vq.handle_kick = vhost_scsi_ctl_handle_kick;
+ vs->vqs[VHOST_SCSI_VQ_EVT].vq.handle_kick = vhost_scsi_evt_handle_kick;
+ for (i = VHOST_SCSI_VQ_IO; i < VHOST_SCSI_MAX_VQ; i++) {
+ vqs[i] = &vs->vqs[i].vq;
+ vs->vqs[i].vq.handle_kick = vhost_scsi_handle_kick;
+ }
+ vhost_dev_init(&vs->dev, vqs, VHOST_SCSI_MAX_VQ);
+
+ tcm_vhost_init_inflight(vs, NULL);
+
+ f->private_data = vs;
+ return 0;
+
+err_vqs:
+ kvfree(vs);
+err_vs:
+ return r;
+}
+
+static int vhost_scsi_release(struct inode *inode, struct file *f)
+{
+ struct vhost_scsi *vs = f->private_data;
+ struct vhost_scsi_target t;
+
+ mutex_lock(&vs->dev.mutex);
+ memcpy(t.vhost_wwpn, vs->vs_vhost_wwpn, sizeof(t.vhost_wwpn));
+ mutex_unlock(&vs->dev.mutex);
+ vhost_scsi_clear_endpoint(vs, &t);
+ vhost_dev_stop(&vs->dev);
+ vhost_dev_cleanup(&vs->dev, false);
+ /* Jobs can re-queue themselves in evt kick handler. Do extra flush. */
+ vhost_scsi_flush(vs);
+ kfree(vs->dev.vqs);
+ kvfree(vs);
+ return 0;
+}
+
+static long
+vhost_scsi_ioctl(struct file *f,
+ unsigned int ioctl,
+ unsigned long arg)
+{
+ struct vhost_scsi *vs = f->private_data;
+ struct vhost_scsi_target backend;
+ void __user *argp = (void __user *)arg;
+ u64 __user *featurep = argp;
+ u32 __user *eventsp = argp;
+ u32 events_missed;
+ u64 features;
+ int r, abi_version = VHOST_SCSI_ABI_VERSION;
+ struct vhost_virtqueue *vq = &vs->vqs[VHOST_SCSI_VQ_EVT].vq;
+
+ switch (ioctl) {
+ case VHOST_SCSI_SET_ENDPOINT:
+ if (copy_from_user(&backend, argp, sizeof backend))
+ return -EFAULT;
+ if (backend.reserved != 0)
+ return -EOPNOTSUPP;
+
+ return vhost_scsi_set_endpoint(vs, &backend);
+ case VHOST_SCSI_CLEAR_ENDPOINT:
+ if (copy_from_user(&backend, argp, sizeof backend))
+ return -EFAULT;
+ if (backend.reserved != 0)
+ return -EOPNOTSUPP;
+
+ return vhost_scsi_clear_endpoint(vs, &backend);
+ case VHOST_SCSI_GET_ABI_VERSION:
+ if (copy_to_user(argp, &abi_version, sizeof abi_version))
+ return -EFAULT;
+ return 0;
+ case VHOST_SCSI_SET_EVENTS_MISSED:
+ if (get_user(events_missed, eventsp))
+ return -EFAULT;
+ mutex_lock(&vq->mutex);
+ vs->vs_events_missed = events_missed;
+ mutex_unlock(&vq->mutex);
+ return 0;
+ case VHOST_SCSI_GET_EVENTS_MISSED:
+ mutex_lock(&vq->mutex);
+ events_missed = vs->vs_events_missed;
+ mutex_unlock(&vq->mutex);
+ if (put_user(events_missed, eventsp))
+ return -EFAULT;
+ return 0;
+ case VHOST_GET_FEATURES:
+ features = VHOST_SCSI_FEATURES;
+ if (copy_to_user(featurep, &features, sizeof features))
+ return -EFAULT;
+ return 0;
+ case VHOST_SET_FEATURES:
+ if (copy_from_user(&features, featurep, sizeof features))
+ return -EFAULT;
+ return vhost_scsi_set_features(vs, features);
+ default:
+ mutex_lock(&vs->dev.mutex);
+ r = vhost_dev_ioctl(&vs->dev, ioctl, argp);
+ /* TODO: flush backend after dev ioctl. */
+ if (r == -ENOIOCTLCMD)
+ r = vhost_vring_ioctl(&vs->dev, ioctl, argp);
+ mutex_unlock(&vs->dev.mutex);
+ return r;
+ }
+}
+
+#ifdef CONFIG_COMPAT
+static long vhost_scsi_compat_ioctl(struct file *f, unsigned int ioctl,
+ unsigned long arg)
+{
+ return vhost_scsi_ioctl(f, ioctl, (unsigned long)compat_ptr(arg));
+}
+#endif
+
+static const struct file_operations vhost_scsi_fops = {
+ .owner = THIS_MODULE,
+ .release = vhost_scsi_release,
+ .unlocked_ioctl = vhost_scsi_ioctl,
+#ifdef CONFIG_COMPAT
+ .compat_ioctl = vhost_scsi_compat_ioctl,
+#endif
+ .open = vhost_scsi_open,
+ .llseek = noop_llseek,
+};
+
+static struct miscdevice vhost_scsi_misc = {
+ MISC_DYNAMIC_MINOR,
+ "vhost-scsi",
+ &vhost_scsi_fops,
+};
+
+static int __init vhost_scsi_register(void)
+{
+ return misc_register(&vhost_scsi_misc);
+}
+
+static int vhost_scsi_deregister(void)
+{
+ return misc_deregister(&vhost_scsi_misc);
+}
+
+static char *tcm_vhost_dump_proto_id(struct tcm_vhost_tport *tport)
+{
+ switch (tport->tport_proto_id) {
+ case SCSI_PROTOCOL_SAS:
+ return "SAS";
+ case SCSI_PROTOCOL_FCP:
+ return "FCP";
+ case SCSI_PROTOCOL_ISCSI:
+ return "iSCSI";
+ default:
+ break;
+ }
+
+ return "Unknown";
+}
+
+static void
+tcm_vhost_do_plug(struct tcm_vhost_tpg *tpg,
+ struct se_lun *lun, bool plug)
+{
+
+ struct vhost_scsi *vs = tpg->vhost_scsi;
+ struct vhost_virtqueue *vq;
+ u32 reason;
+
+ if (!vs)
+ return;
+
+ mutex_lock(&vs->dev.mutex);
+
+ if (plug)
+ reason = VIRTIO_SCSI_EVT_RESET_RESCAN;
+ else
+ reason = VIRTIO_SCSI_EVT_RESET_REMOVED;
+
+ vq = &vs->vqs[VHOST_SCSI_VQ_EVT].vq;
+ mutex_lock(&vq->mutex);
+ if (vhost_has_feature(vq, VIRTIO_SCSI_F_HOTPLUG))
+ tcm_vhost_send_evt(vs, tpg, lun,
+ VIRTIO_SCSI_T_TRANSPORT_RESET, reason);
+ mutex_unlock(&vq->mutex);
+ mutex_unlock(&vs->dev.mutex);
+}
+
+static void tcm_vhost_hotplug(struct tcm_vhost_tpg *tpg, struct se_lun *lun)
+{
+ tcm_vhost_do_plug(tpg, lun, true);
+}
+
+static void tcm_vhost_hotunplug(struct tcm_vhost_tpg *tpg, struct se_lun *lun)
+{
+ tcm_vhost_do_plug(tpg, lun, false);
+}
+
+static int tcm_vhost_port_link(struct se_portal_group *se_tpg,
+ struct se_lun *lun)
+{
+ struct tcm_vhost_tpg *tpg = container_of(se_tpg,
+ struct tcm_vhost_tpg, se_tpg);
+
+ mutex_lock(&tcm_vhost_mutex);
+
+ mutex_lock(&tpg->tv_tpg_mutex);
+ tpg->tv_tpg_port_count++;
+ mutex_unlock(&tpg->tv_tpg_mutex);
+
+ tcm_vhost_hotplug(tpg, lun);
+
+ mutex_unlock(&tcm_vhost_mutex);
+
+ return 0;
+}
+
+static void tcm_vhost_port_unlink(struct se_portal_group *se_tpg,
+ struct se_lun *lun)
+{
+ struct tcm_vhost_tpg *tpg = container_of(se_tpg,
+ struct tcm_vhost_tpg, se_tpg);
+
+ mutex_lock(&tcm_vhost_mutex);
+
+ mutex_lock(&tpg->tv_tpg_mutex);
+ tpg->tv_tpg_port_count--;
+ mutex_unlock(&tpg->tv_tpg_mutex);
+
+ tcm_vhost_hotunplug(tpg, lun);
+
+ mutex_unlock(&tcm_vhost_mutex);
+}
+
+static struct se_node_acl *
+tcm_vhost_make_nodeacl(struct se_portal_group *se_tpg,
+ struct config_group *group,
+ const char *name)
+{
+ struct se_node_acl *se_nacl, *se_nacl_new;
+ struct tcm_vhost_nacl *nacl;
+ u64 wwpn = 0;
+ u32 nexus_depth;
+
+ /* tcm_vhost_parse_wwn(name, &wwpn, 1) < 0)
+ return ERR_PTR(-EINVAL); */
+ se_nacl_new = tcm_vhost_alloc_fabric_acl(se_tpg);
+ if (!se_nacl_new)
+ return ERR_PTR(-ENOMEM);
+
+ nexus_depth = 1;
+ /*
+ * se_nacl_new may be released by core_tpg_add_initiator_node_acl()
+ * when converting a NodeACL from demo mode -> explict
+ */
+ se_nacl = core_tpg_add_initiator_node_acl(se_tpg, se_nacl_new,
+ name, nexus_depth);
+ if (IS_ERR(se_nacl)) {
+ tcm_vhost_release_fabric_acl(se_tpg, se_nacl_new);
+ return se_nacl;
+ }
+ /*
+ * Locate our struct tcm_vhost_nacl and set the FC Nport WWPN
+ */
+ nacl = container_of(se_nacl, struct tcm_vhost_nacl, se_node_acl);
+ nacl->iport_wwpn = wwpn;
+
+ return se_nacl;
+}
+
+static void tcm_vhost_drop_nodeacl(struct se_node_acl *se_acl)
+{
+ struct tcm_vhost_nacl *nacl = container_of(se_acl,
+ struct tcm_vhost_nacl, se_node_acl);
+ core_tpg_del_initiator_node_acl(se_acl->se_tpg, se_acl, 1);
+ kfree(nacl);
+}
+
+static void tcm_vhost_free_cmd_map_res(struct tcm_vhost_nexus *nexus,
+ struct se_session *se_sess)
+{
+ struct tcm_vhost_cmd *tv_cmd;
+ unsigned int i;
+
+ if (!se_sess->sess_cmd_map)
+ return;
+
+ for (i = 0; i < TCM_VHOST_DEFAULT_TAGS; i++) {
+ tv_cmd = &((struct tcm_vhost_cmd *)se_sess->sess_cmd_map)[i];
+
+ kfree(tv_cmd->tvc_sgl);
+ kfree(tv_cmd->tvc_prot_sgl);
+ kfree(tv_cmd->tvc_upages);
+ }
+}
+
+static int tcm_vhost_make_nexus(struct tcm_vhost_tpg *tpg,
+ const char *name)
+{
+ struct se_portal_group *se_tpg;
+ struct se_session *se_sess;
+ struct tcm_vhost_nexus *tv_nexus;
+ struct tcm_vhost_cmd *tv_cmd;
+ unsigned int i;
+
+ mutex_lock(&tpg->tv_tpg_mutex);
+ if (tpg->tpg_nexus) {
+ mutex_unlock(&tpg->tv_tpg_mutex);
+ pr_debug("tpg->tpg_nexus already exists\n");
+ return -EEXIST;
+ }
+ se_tpg = &tpg->se_tpg;
+
+ tv_nexus = kzalloc(sizeof(struct tcm_vhost_nexus), GFP_KERNEL);
+ if (!tv_nexus) {
+ mutex_unlock(&tpg->tv_tpg_mutex);
+ pr_err("Unable to allocate struct tcm_vhost_nexus\n");
+ return -ENOMEM;
+ }
+ /*
+ * Initialize the struct se_session pointer and setup tagpool
+ * for struct tcm_vhost_cmd descriptors
+ */
+ tv_nexus->tvn_se_sess = transport_init_session_tags(
+ TCM_VHOST_DEFAULT_TAGS,
+ sizeof(struct tcm_vhost_cmd),
+ TARGET_PROT_DIN_PASS | TARGET_PROT_DOUT_PASS);
+ if (IS_ERR(tv_nexus->tvn_se_sess)) {
+ mutex_unlock(&tpg->tv_tpg_mutex);
+ kfree(tv_nexus);
+ return -ENOMEM;
+ }
+ se_sess = tv_nexus->tvn_se_sess;
+ for (i = 0; i < TCM_VHOST_DEFAULT_TAGS; i++) {
+ tv_cmd = &((struct tcm_vhost_cmd *)se_sess->sess_cmd_map)[i];
+
+ tv_cmd->tvc_sgl = kzalloc(sizeof(struct scatterlist) *
+ TCM_VHOST_PREALLOC_SGLS, GFP_KERNEL);
+ if (!tv_cmd->tvc_sgl) {
+ mutex_unlock(&tpg->tv_tpg_mutex);
+ pr_err("Unable to allocate tv_cmd->tvc_sgl\n");
+ goto out;
+ }
+
+ tv_cmd->tvc_upages = kzalloc(sizeof(struct page *) *
+ TCM_VHOST_PREALLOC_UPAGES, GFP_KERNEL);
+ if (!tv_cmd->tvc_upages) {
+ mutex_unlock(&tpg->tv_tpg_mutex);
+ pr_err("Unable to allocate tv_cmd->tvc_upages\n");
+ goto out;
+ }
+
+ tv_cmd->tvc_prot_sgl = kzalloc(sizeof(struct scatterlist) *
+ TCM_VHOST_PREALLOC_PROT_SGLS, GFP_KERNEL);
+ if (!tv_cmd->tvc_prot_sgl) {
+ mutex_unlock(&tpg->tv_tpg_mutex);
+ pr_err("Unable to allocate tv_cmd->tvc_prot_sgl\n");
+ goto out;
+ }
+ }
+ /*
+ * Since we are running in 'demo mode' this call with generate a
+ * struct se_node_acl for the tcm_vhost struct se_portal_group with
+ * the SCSI Initiator port name of the passed configfs group 'name'.
+ */
+ tv_nexus->tvn_se_sess->se_node_acl = core_tpg_check_initiator_node_acl(
+ se_tpg, (unsigned char *)name);
+ if (!tv_nexus->tvn_se_sess->se_node_acl) {
+ mutex_unlock(&tpg->tv_tpg_mutex);
+ pr_debug("core_tpg_check_initiator_node_acl() failed"
+ " for %s\n", name);
+ goto out;
+ }
+ /*
+ * Now register the TCM vhost virtual I_T Nexus as active with the
+ * call to __transport_register_session()
+ */
+ __transport_register_session(se_tpg, tv_nexus->tvn_se_sess->se_node_acl,
+ tv_nexus->tvn_se_sess, tv_nexus);
+ tpg->tpg_nexus = tv_nexus;
+
+ mutex_unlock(&tpg->tv_tpg_mutex);
+ return 0;
+
+out:
+ tcm_vhost_free_cmd_map_res(tv_nexus, se_sess);
+ transport_free_session(se_sess);
+ kfree(tv_nexus);
+ return -ENOMEM;
+}
+
+static int tcm_vhost_drop_nexus(struct tcm_vhost_tpg *tpg)
+{
+ struct se_session *se_sess;
+ struct tcm_vhost_nexus *tv_nexus;
+
+ mutex_lock(&tpg->tv_tpg_mutex);
+ tv_nexus = tpg->tpg_nexus;
+ if (!tv_nexus) {
+ mutex_unlock(&tpg->tv_tpg_mutex);
+ return -ENODEV;
+ }
+
+ se_sess = tv_nexus->tvn_se_sess;
+ if (!se_sess) {
+ mutex_unlock(&tpg->tv_tpg_mutex);
+ return -ENODEV;
+ }
+
+ if (tpg->tv_tpg_port_count != 0) {
+ mutex_unlock(&tpg->tv_tpg_mutex);
+ pr_err("Unable to remove TCM_vhost I_T Nexus with"
+ " active TPG port count: %d\n",
+ tpg->tv_tpg_port_count);
+ return -EBUSY;
+ }
+
+ if (tpg->tv_tpg_vhost_count != 0) {
+ mutex_unlock(&tpg->tv_tpg_mutex);
+ pr_err("Unable to remove TCM_vhost I_T Nexus with"
+ " active TPG vhost count: %d\n",
+ tpg->tv_tpg_vhost_count);
+ return -EBUSY;
+ }
+
+ pr_debug("TCM_vhost_ConfigFS: Removing I_T Nexus to emulated"
+ " %s Initiator Port: %s\n", tcm_vhost_dump_proto_id(tpg->tport),
+ tv_nexus->tvn_se_sess->se_node_acl->initiatorname);
+
+ tcm_vhost_free_cmd_map_res(tv_nexus, se_sess);
+ /*
+ * Release the SCSI I_T Nexus to the emulated vhost Target Port
+ */
+ transport_deregister_session(tv_nexus->tvn_se_sess);
+ tpg->tpg_nexus = NULL;
+ mutex_unlock(&tpg->tv_tpg_mutex);
+
+ kfree(tv_nexus);
+ return 0;
+}
+
+static ssize_t tcm_vhost_tpg_show_nexus(struct se_portal_group *se_tpg,
+ char *page)
+{
+ struct tcm_vhost_tpg *tpg = container_of(se_tpg,
+ struct tcm_vhost_tpg, se_tpg);
+ struct tcm_vhost_nexus *tv_nexus;
+ ssize_t ret;
+
+ mutex_lock(&tpg->tv_tpg_mutex);
+ tv_nexus = tpg->tpg_nexus;
+ if (!tv_nexus) {
+ mutex_unlock(&tpg->tv_tpg_mutex);
+ return -ENODEV;
+ }
+ ret = snprintf(page, PAGE_SIZE, "%s\n",
+ tv_nexus->tvn_se_sess->se_node_acl->initiatorname);
+ mutex_unlock(&tpg->tv_tpg_mutex);
+
+ return ret;
+}
+
+static ssize_t tcm_vhost_tpg_store_nexus(struct se_portal_group *se_tpg,
+ const char *page,
+ size_t count)
+{
+ struct tcm_vhost_tpg *tpg = container_of(se_tpg,
+ struct tcm_vhost_tpg, se_tpg);
+ struct tcm_vhost_tport *tport_wwn = tpg->tport;
+ unsigned char i_port[TCM_VHOST_NAMELEN], *ptr, *port_ptr;
+ int ret;
+ /*
+ * Shutdown the active I_T nexus if 'NULL' is passed..
+ */
+ if (!strncmp(page, "NULL", 4)) {
+ ret = tcm_vhost_drop_nexus(tpg);
+ return (!ret) ? count : ret;
+ }
+ /*
+ * Otherwise make sure the passed virtual Initiator port WWN matches
+ * the fabric protocol_id set in tcm_vhost_make_tport(), and call
+ * tcm_vhost_make_nexus().
+ */
+ if (strlen(page) >= TCM_VHOST_NAMELEN) {
+ pr_err("Emulated NAA Sas Address: %s, exceeds"
+ " max: %d\n", page, TCM_VHOST_NAMELEN);
+ return -EINVAL;
+ }
+ snprintf(&i_port[0], TCM_VHOST_NAMELEN, "%s", page);
+
+ ptr = strstr(i_port, "naa.");
+ if (ptr) {
+ if (tport_wwn->tport_proto_id != SCSI_PROTOCOL_SAS) {
+ pr_err("Passed SAS Initiator Port %s does not"
+ " match target port protoid: %s\n", i_port,
+ tcm_vhost_dump_proto_id(tport_wwn));
+ return -EINVAL;
+ }
+ port_ptr = &i_port[0];
+ goto check_newline;
+ }
+ ptr = strstr(i_port, "fc.");
+ if (ptr) {
+ if (tport_wwn->tport_proto_id != SCSI_PROTOCOL_FCP) {
+ pr_err("Passed FCP Initiator Port %s does not"
+ " match target port protoid: %s\n", i_port,
+ tcm_vhost_dump_proto_id(tport_wwn));
+ return -EINVAL;
+ }
+ port_ptr = &i_port[3]; /* Skip over "fc." */
+ goto check_newline;
+ }
+ ptr = strstr(i_port, "iqn.");
+ if (ptr) {
+ if (tport_wwn->tport_proto_id != SCSI_PROTOCOL_ISCSI) {
+ pr_err("Passed iSCSI Initiator Port %s does not"
+ " match target port protoid: %s\n", i_port,
+ tcm_vhost_dump_proto_id(tport_wwn));
+ return -EINVAL;
+ }
+ port_ptr = &i_port[0];
+ goto check_newline;
+ }
+ pr_err("Unable to locate prefix for emulated Initiator Port:"
+ " %s\n", i_port);
+ return -EINVAL;
+ /*
+ * Clear any trailing newline for the NAA WWN
+ */
+check_newline:
+ if (i_port[strlen(i_port)-1] == '\n')
+ i_port[strlen(i_port)-1] = '\0';
+
+ ret = tcm_vhost_make_nexus(tpg, port_ptr);
+ if (ret < 0)
+ return ret;
+
+ return count;
+}
+
+TF_TPG_BASE_ATTR(tcm_vhost, nexus, S_IRUGO | S_IWUSR);
+
+static struct configfs_attribute *tcm_vhost_tpg_attrs[] = {
+ &tcm_vhost_tpg_nexus.attr,
+ NULL,
+};
+
+static struct se_portal_group *
+tcm_vhost_make_tpg(struct se_wwn *wwn,
+ struct config_group *group,
+ const char *name)
+{
+ struct tcm_vhost_tport *tport = container_of(wwn,
+ struct tcm_vhost_tport, tport_wwn);
+
+ struct tcm_vhost_tpg *tpg;
+ unsigned long tpgt;
+ int ret;
+
+ if (strstr(name, "tpgt_") != name)
+ return ERR_PTR(-EINVAL);
+ if (kstrtoul(name + 5, 10, &tpgt) || tpgt > UINT_MAX)
+ return ERR_PTR(-EINVAL);
+
+ tpg = kzalloc(sizeof(struct tcm_vhost_tpg), GFP_KERNEL);
+ if (!tpg) {
+ pr_err("Unable to allocate struct tcm_vhost_tpg");
+ return ERR_PTR(-ENOMEM);
+ }
+ mutex_init(&tpg->tv_tpg_mutex);
+ INIT_LIST_HEAD(&tpg->tv_tpg_list);
+ tpg->tport = tport;
+ tpg->tport_tpgt = tpgt;
+
+ ret = core_tpg_register(&tcm_vhost_fabric_configfs->tf_ops, wwn,
+ &tpg->se_tpg, tpg, TRANSPORT_TPG_TYPE_NORMAL);
+ if (ret < 0) {
+ kfree(tpg);
+ return NULL;
+ }
+ mutex_lock(&tcm_vhost_mutex);
+ list_add_tail(&tpg->tv_tpg_list, &tcm_vhost_list);
+ mutex_unlock(&tcm_vhost_mutex);
+
+ return &tpg->se_tpg;
+}
+
+static void tcm_vhost_drop_tpg(struct se_portal_group *se_tpg)
+{
+ struct tcm_vhost_tpg *tpg = container_of(se_tpg,
+ struct tcm_vhost_tpg, se_tpg);
+
+ mutex_lock(&tcm_vhost_mutex);
+ list_del(&tpg->tv_tpg_list);
+ mutex_unlock(&tcm_vhost_mutex);
+ /*
+ * Release the virtual I_T Nexus for this vhost TPG
+ */
+ tcm_vhost_drop_nexus(tpg);
+ /*
+ * Deregister the se_tpg from TCM..
+ */
+ core_tpg_deregister(se_tpg);
+ kfree(tpg);
+}
+
+static struct se_wwn *
+tcm_vhost_make_tport(struct target_fabric_configfs *tf,
+ struct config_group *group,
+ const char *name)
+{
+ struct tcm_vhost_tport *tport;
+ char *ptr;
+ u64 wwpn = 0;
+ int off = 0;
+
+ /* if (tcm_vhost_parse_wwn(name, &wwpn, 1) < 0)
+ return ERR_PTR(-EINVAL); */
+
+ tport = kzalloc(sizeof(struct tcm_vhost_tport), GFP_KERNEL);
+ if (!tport) {
+ pr_err("Unable to allocate struct tcm_vhost_tport");
+ return ERR_PTR(-ENOMEM);
+ }
+ tport->tport_wwpn = wwpn;
+ /*
+ * Determine the emulated Protocol Identifier and Target Port Name
+ * based on the incoming configfs directory name.
+ */
+ ptr = strstr(name, "naa.");
+ if (ptr) {
+ tport->tport_proto_id = SCSI_PROTOCOL_SAS;
+ goto check_len;
+ }
+ ptr = strstr(name, "fc.");
+ if (ptr) {
+ tport->tport_proto_id = SCSI_PROTOCOL_FCP;
+ off = 3; /* Skip over "fc." */
+ goto check_len;
+ }
+ ptr = strstr(name, "iqn.");
+ if (ptr) {
+ tport->tport_proto_id = SCSI_PROTOCOL_ISCSI;
+ goto check_len;
+ }
+
+ pr_err("Unable to locate prefix for emulated Target Port:"
+ " %s\n", name);
+ kfree(tport);
+ return ERR_PTR(-EINVAL);
+
+check_len:
+ if (strlen(name) >= TCM_VHOST_NAMELEN) {
+ pr_err("Emulated %s Address: %s, exceeds"
+ " max: %d\n", name, tcm_vhost_dump_proto_id(tport),
+ TCM_VHOST_NAMELEN);
+ kfree(tport);
+ return ERR_PTR(-EINVAL);
+ }
+ snprintf(&tport->tport_name[0], TCM_VHOST_NAMELEN, "%s", &name[off]);
+
+ pr_debug("TCM_VHost_ConfigFS: Allocated emulated Target"
+ " %s Address: %s\n", tcm_vhost_dump_proto_id(tport), name);
+
+ return &tport->tport_wwn;
+}
+
+static void tcm_vhost_drop_tport(struct se_wwn *wwn)
+{
+ struct tcm_vhost_tport *tport = container_of(wwn,
+ struct tcm_vhost_tport, tport_wwn);
+
+ pr_debug("TCM_VHost_ConfigFS: Deallocating emulated Target"
+ " %s Address: %s\n", tcm_vhost_dump_proto_id(tport),
+ tport->tport_name);
+
+ kfree(tport);
+}
+
+static ssize_t
+tcm_vhost_wwn_show_attr_version(struct target_fabric_configfs *tf,
+ char *page)
+{
+ return sprintf(page, "TCM_VHOST fabric module %s on %s/%s"
+ "on "UTS_RELEASE"\n", TCM_VHOST_VERSION, utsname()->sysname,
+ utsname()->machine);
+}
+
+TF_WWN_ATTR_RO(tcm_vhost, version);
+
+static struct configfs_attribute *tcm_vhost_wwn_attrs[] = {
+ &tcm_vhost_wwn_version.attr,
+ NULL,
+};
+
+static struct target_core_fabric_ops tcm_vhost_ops = {
+ .get_fabric_name = tcm_vhost_get_fabric_name,
+ .get_fabric_proto_ident = tcm_vhost_get_fabric_proto_ident,
+ .tpg_get_wwn = tcm_vhost_get_fabric_wwn,
+ .tpg_get_tag = tcm_vhost_get_tag,
+ .tpg_get_default_depth = tcm_vhost_get_default_depth,
+ .tpg_get_pr_transport_id = tcm_vhost_get_pr_transport_id,
+ .tpg_get_pr_transport_id_len = tcm_vhost_get_pr_transport_id_len,
+ .tpg_parse_pr_out_transport_id = tcm_vhost_parse_pr_out_transport_id,
+ .tpg_check_demo_mode = tcm_vhost_check_true,
+ .tpg_check_demo_mode_cache = tcm_vhost_check_true,
+ .tpg_check_demo_mode_write_protect = tcm_vhost_check_false,
+ .tpg_check_prod_mode_write_protect = tcm_vhost_check_false,
+ .tpg_alloc_fabric_acl = tcm_vhost_alloc_fabric_acl,
+ .tpg_release_fabric_acl = tcm_vhost_release_fabric_acl,
+ .tpg_get_inst_index = tcm_vhost_tpg_get_inst_index,
+ .release_cmd = tcm_vhost_release_cmd,
+ .check_stop_free = vhost_scsi_check_stop_free,
+ .shutdown_session = tcm_vhost_shutdown_session,
+ .close_session = tcm_vhost_close_session,
+ .sess_get_index = tcm_vhost_sess_get_index,
+ .sess_get_initiator_sid = NULL,
+ .write_pending = tcm_vhost_write_pending,
+ .write_pending_status = tcm_vhost_write_pending_status,
+ .set_default_node_attributes = tcm_vhost_set_default_node_attrs,
+ .get_task_tag = tcm_vhost_get_task_tag,
+ .get_cmd_state = tcm_vhost_get_cmd_state,
+ .queue_data_in = tcm_vhost_queue_data_in,
+ .queue_status = tcm_vhost_queue_status,
+ .queue_tm_rsp = tcm_vhost_queue_tm_rsp,
+ .aborted_task = tcm_vhost_aborted_task,
+ /*
+ * Setup callers for generic logic in target_core_fabric_configfs.c
+ */
+ .fabric_make_wwn = tcm_vhost_make_tport,
+ .fabric_drop_wwn = tcm_vhost_drop_tport,
+ .fabric_make_tpg = tcm_vhost_make_tpg,
+ .fabric_drop_tpg = tcm_vhost_drop_tpg,
+ .fabric_post_link = tcm_vhost_port_link,
+ .fabric_pre_unlink = tcm_vhost_port_unlink,
+ .fabric_make_np = NULL,
+ .fabric_drop_np = NULL,
+ .fabric_make_nodeacl = tcm_vhost_make_nodeacl,
+ .fabric_drop_nodeacl = tcm_vhost_drop_nodeacl,
+};
+
+static int tcm_vhost_register_configfs(void)
+{
+ struct target_fabric_configfs *fabric;
+ int ret;
+
+ pr_debug("TCM_VHOST fabric module %s on %s/%s"
+ " on "UTS_RELEASE"\n", TCM_VHOST_VERSION, utsname()->sysname,
+ utsname()->machine);
+ /*
+ * Register the top level struct config_item_type with TCM core
+ */
+ fabric = target_fabric_configfs_init(THIS_MODULE, "vhost");
+ if (IS_ERR(fabric)) {
+ pr_err("target_fabric_configfs_init() failed\n");
+ return PTR_ERR(fabric);
+ }
+ /*
+ * Setup fabric->tf_ops from our local tcm_vhost_ops
+ */
+ fabric->tf_ops = tcm_vhost_ops;
+ /*
+ * Setup default attribute lists for various fabric->tf_cit_tmpl
+ */
+ fabric->tf_cit_tmpl.tfc_wwn_cit.ct_attrs = tcm_vhost_wwn_attrs;
+ fabric->tf_cit_tmpl.tfc_tpg_base_cit.ct_attrs = tcm_vhost_tpg_attrs;
+ fabric->tf_cit_tmpl.tfc_tpg_attrib_cit.ct_attrs = NULL;
+ fabric->tf_cit_tmpl.tfc_tpg_param_cit.ct_attrs = NULL;
+ fabric->tf_cit_tmpl.tfc_tpg_np_base_cit.ct_attrs = NULL;
+ fabric->tf_cit_tmpl.tfc_tpg_nacl_base_cit.ct_attrs = NULL;
+ fabric->tf_cit_tmpl.tfc_tpg_nacl_attrib_cit.ct_attrs = NULL;
+ fabric->tf_cit_tmpl.tfc_tpg_nacl_auth_cit.ct_attrs = NULL;
+ fabric->tf_cit_tmpl.tfc_tpg_nacl_param_cit.ct_attrs = NULL;
+ /*
+ * Register the fabric for use within TCM
+ */
+ ret = target_fabric_configfs_register(fabric);
+ if (ret < 0) {
+ pr_err("target_fabric_configfs_register() failed"
+ " for TCM_VHOST\n");
+ return ret;
+ }
+ /*
+ * Setup our local pointer to *fabric
+ */
+ tcm_vhost_fabric_configfs = fabric;
+ pr_debug("TCM_VHOST[0] - Set fabric -> tcm_vhost_fabric_configfs\n");
+ return 0;
+};
+
+static void tcm_vhost_deregister_configfs(void)
+{
+ if (!tcm_vhost_fabric_configfs)
+ return;
+
+ target_fabric_configfs_deregister(tcm_vhost_fabric_configfs);
+ tcm_vhost_fabric_configfs = NULL;
+ pr_debug("TCM_VHOST[0] - Cleared tcm_vhost_fabric_configfs\n");
+};
+
+static int __init tcm_vhost_init(void)
+{
+ int ret = -ENOMEM;
+ /*
+ * Use our own dedicated workqueue for submitting I/O into
+ * target core to avoid contention within system_wq.
+ */
+ tcm_vhost_workqueue = alloc_workqueue("tcm_vhost", 0, 0);
+ if (!tcm_vhost_workqueue)
+ goto out;
+
+ ret = vhost_scsi_register();
+ if (ret < 0)
+ goto out_destroy_workqueue;
+
+ ret = tcm_vhost_register_configfs();
+ if (ret < 0)
+ goto out_vhost_scsi_deregister;
+
+ return 0;
+
+out_vhost_scsi_deregister:
+ vhost_scsi_deregister();
+out_destroy_workqueue:
+ destroy_workqueue(tcm_vhost_workqueue);
+out:
+ return ret;
+};
+
+static void tcm_vhost_exit(void)
+{
+ tcm_vhost_deregister_configfs();
+ vhost_scsi_deregister();
+ destroy_workqueue(tcm_vhost_workqueue);
+};
+
+MODULE_DESCRIPTION("VHOST_SCSI series fabric driver");
+MODULE_ALIAS("tcm_vhost");
+MODULE_LICENSE("GPL");
+module_init(tcm_vhost_init);
+module_exit(tcm_vhost_exit);
diff --git a/drivers/vhost/test.c b/drivers/vhost/test.c
new file mode 100644
index 00000000000..d9c501eaa6c
--- /dev/null
+++ b/drivers/vhost/test.c
@@ -0,0 +1,338 @@
+/* Copyright (C) 2009 Red Hat, Inc.
+ * Author: Michael S. Tsirkin <mst@redhat.com>
+ *
+ * This work is licensed under the terms of the GNU GPL, version 2.
+ *
+ * test virtio server in host kernel.
+ */
+
+#include <linux/compat.h>
+#include <linux/eventfd.h>
+#include <linux/vhost.h>
+#include <linux/miscdevice.h>
+#include <linux/module.h>
+#include <linux/mutex.h>
+#include <linux/workqueue.h>
+#include <linux/file.h>
+#include <linux/slab.h>
+
+#include "test.h"
+#include "vhost.h"
+
+/* Max number of bytes transferred before requeueing the job.
+ * Using this limit prevents one virtqueue from starving others. */
+#define VHOST_TEST_WEIGHT 0x80000
+
+enum {
+ VHOST_TEST_VQ = 0,
+ VHOST_TEST_VQ_MAX = 1,
+};
+
+struct vhost_test {
+ struct vhost_dev dev;
+ struct vhost_virtqueue vqs[VHOST_TEST_VQ_MAX];
+};
+
+/* Expects to be always run from workqueue - which acts as
+ * read-size critical section for our kind of RCU. */
+static void handle_vq(struct vhost_test *n)
+{
+ struct vhost_virtqueue *vq = &n->vqs[VHOST_TEST_VQ];
+ unsigned out, in;
+ int head;
+ size_t len, total_len = 0;
+ void *private;
+
+ mutex_lock(&vq->mutex);
+ private = vq->private_data;
+ if (!private) {
+ mutex_unlock(&vq->mutex);
+ return;
+ }
+
+ vhost_disable_notify(&n->dev, vq);
+
+ for (;;) {
+ head = vhost_get_vq_desc(vq, vq->iov,
+ ARRAY_SIZE(vq->iov),
+ &out, &in,
+ NULL, NULL);
+ /* On error, stop handling until the next kick. */
+ if (unlikely(head < 0))
+ break;
+ /* Nothing new? Wait for eventfd to tell us they refilled. */
+ if (head == vq->num) {
+ if (unlikely(vhost_enable_notify(&n->dev, vq))) {
+ vhost_disable_notify(&n->dev, vq);
+ continue;
+ }
+ break;
+ }
+ if (in) {
+ vq_err(vq, "Unexpected descriptor format for TX: "
+ "out %d, int %d\n", out, in);
+ break;
+ }
+ len = iov_length(vq->iov, out);
+ /* Sanity check */
+ if (!len) {
+ vq_err(vq, "Unexpected 0 len for TX\n");
+ break;
+ }
+ vhost_add_used_and_signal(&n->dev, vq, head, 0);
+ total_len += len;
+ if (unlikely(total_len >= VHOST_TEST_WEIGHT)) {
+ vhost_poll_queue(&vq->poll);
+ break;
+ }
+ }
+
+ mutex_unlock(&vq->mutex);
+}
+
+static void handle_vq_kick(struct vhost_work *work)
+{
+ struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
+ poll.work);
+ struct vhost_test *n = container_of(vq->dev, struct vhost_test, dev);
+
+ handle_vq(n);
+}
+
+static int vhost_test_open(struct inode *inode, struct file *f)
+{
+ struct vhost_test *n = kmalloc(sizeof *n, GFP_KERNEL);
+ struct vhost_dev *dev;
+ struct vhost_virtqueue **vqs;
+
+ if (!n)
+ return -ENOMEM;
+ vqs = kmalloc(VHOST_TEST_VQ_MAX * sizeof(*vqs), GFP_KERNEL);
+ if (!vqs) {
+ kfree(n);
+ return -ENOMEM;
+ }
+
+ dev = &n->dev;
+ vqs[VHOST_TEST_VQ] = &n->vqs[VHOST_TEST_VQ];
+ n->vqs[VHOST_TEST_VQ].handle_kick = handle_vq_kick;
+ vhost_dev_init(dev, vqs, VHOST_TEST_VQ_MAX);
+
+ f->private_data = n;
+
+ return 0;
+}
+
+static void *vhost_test_stop_vq(struct vhost_test *n,
+ struct vhost_virtqueue *vq)
+{
+ void *private;
+
+ mutex_lock(&vq->mutex);
+ private = vq->private_data;
+ vq->private_data = NULL;
+ mutex_unlock(&vq->mutex);
+ return private;
+}
+
+static void vhost_test_stop(struct vhost_test *n, void **privatep)
+{
+ *privatep = vhost_test_stop_vq(n, n->vqs + VHOST_TEST_VQ);
+}
+
+static void vhost_test_flush_vq(struct vhost_test *n, int index)
+{
+ vhost_poll_flush(&n->vqs[index].poll);
+}
+
+static void vhost_test_flush(struct vhost_test *n)
+{
+ vhost_test_flush_vq(n, VHOST_TEST_VQ);
+}
+
+static int vhost_test_release(struct inode *inode, struct file *f)
+{
+ struct vhost_test *n = f->private_data;
+ void *private;
+
+ vhost_test_stop(n, &private);
+ vhost_test_flush(n);
+ vhost_dev_cleanup(&n->dev, false);
+ /* We do an extra flush before freeing memory,
+ * since jobs can re-queue themselves. */
+ vhost_test_flush(n);
+ kfree(n);
+ return 0;
+}
+
+static long vhost_test_run(struct vhost_test *n, int test)
+{
+ void *priv, *oldpriv;
+ struct vhost_virtqueue *vq;
+ int r, index;
+
+ if (test < 0 || test > 1)
+ return -EINVAL;
+
+ mutex_lock(&n->dev.mutex);
+ r = vhost_dev_check_owner(&n->dev);
+ if (r)
+ goto err;
+
+ for (index = 0; index < n->dev.nvqs; ++index) {
+ /* Verify that ring has been setup correctly. */
+ if (!vhost_vq_access_ok(&n->vqs[index])) {
+ r = -EFAULT;
+ goto err;
+ }
+ }
+
+ for (index = 0; index < n->dev.nvqs; ++index) {
+ vq = n->vqs + index;
+ mutex_lock(&vq->mutex);
+ priv = test ? n : NULL;
+
+ /* start polling new socket */
+ oldpriv = vq->private_data;
+ vq->private_data = priv;
+
+ r = vhost_init_used(&n->vqs[index]);
+
+ mutex_unlock(&vq->mutex);
+
+ if (r)
+ goto err;
+
+ if (oldpriv) {
+ vhost_test_flush_vq(n, index);
+ }
+ }
+
+ mutex_unlock(&n->dev.mutex);
+ return 0;
+
+err:
+ mutex_unlock(&n->dev.mutex);
+ return r;
+}
+
+static long vhost_test_reset_owner(struct vhost_test *n)
+{
+ void *priv = NULL;
+ long err;
+ struct vhost_memory *memory;
+
+ mutex_lock(&n->dev.mutex);
+ err = vhost_dev_check_owner(&n->dev);
+ if (err)
+ goto done;
+ memory = vhost_dev_reset_owner_prepare();
+ if (!memory) {
+ err = -ENOMEM;
+ goto done;
+ }
+ vhost_test_stop(n, &priv);
+ vhost_test_flush(n);
+ vhost_dev_reset_owner(&n->dev, memory);
+done:
+ mutex_unlock(&n->dev.mutex);
+ return err;
+}
+
+static int vhost_test_set_features(struct vhost_test *n, u64 features)
+{
+ struct vhost_virtqueue *vq;
+
+ mutex_lock(&n->dev.mutex);
+ if ((features & (1 << VHOST_F_LOG_ALL)) &&
+ !vhost_log_access_ok(&n->dev)) {
+ mutex_unlock(&n->dev.mutex);
+ return -EFAULT;
+ }
+ vq = &n->vqs[VHOST_TEST_VQ];
+ mutex_lock(&vq->mutex);
+ vq->acked_features = features;
+ mutex_unlock(&vq->mutex);
+ mutex_unlock(&n->dev.mutex);
+ return 0;
+}
+
+static long vhost_test_ioctl(struct file *f, unsigned int ioctl,
+ unsigned long arg)
+{
+ struct vhost_test *n = f->private_data;
+ void __user *argp = (void __user *)arg;
+ u64 __user *featurep = argp;
+ int test;
+ u64 features;
+ int r;
+ switch (ioctl) {
+ case VHOST_TEST_RUN:
+ if (copy_from_user(&test, argp, sizeof test))
+ return -EFAULT;
+ return vhost_test_run(n, test);
+ case VHOST_GET_FEATURES:
+ features = VHOST_FEATURES;
+ if (copy_to_user(featurep, &features, sizeof features))
+ return -EFAULT;
+ return 0;
+ case VHOST_SET_FEATURES:
+ if (copy_from_user(&features, featurep, sizeof features))
+ return -EFAULT;
+ if (features & ~VHOST_FEATURES)
+ return -EOPNOTSUPP;
+ return vhost_test_set_features(n, features);
+ case VHOST_RESET_OWNER:
+ return vhost_test_reset_owner(n);
+ default:
+ mutex_lock(&n->dev.mutex);
+ r = vhost_dev_ioctl(&n->dev, ioctl, argp);
+ if (r == -ENOIOCTLCMD)
+ r = vhost_vring_ioctl(&n->dev, ioctl, argp);
+ vhost_test_flush(n);
+ mutex_unlock(&n->dev.mutex);
+ return r;
+ }
+}
+
+#ifdef CONFIG_COMPAT
+static long vhost_test_compat_ioctl(struct file *f, unsigned int ioctl,
+ unsigned long arg)
+{
+ return vhost_test_ioctl(f, ioctl, (unsigned long)compat_ptr(arg));
+}
+#endif
+
+static const struct file_operations vhost_test_fops = {
+ .owner = THIS_MODULE,
+ .release = vhost_test_release,
+ .unlocked_ioctl = vhost_test_ioctl,
+#ifdef CONFIG_COMPAT
+ .compat_ioctl = vhost_test_compat_ioctl,
+#endif
+ .open = vhost_test_open,
+ .llseek = noop_llseek,
+};
+
+static struct miscdevice vhost_test_misc = {
+ MISC_DYNAMIC_MINOR,
+ "vhost-test",
+ &vhost_test_fops,
+};
+
+static int vhost_test_init(void)
+{
+ return misc_register(&vhost_test_misc);
+}
+module_init(vhost_test_init);
+
+static void vhost_test_exit(void)
+{
+ misc_deregister(&vhost_test_misc);
+}
+module_exit(vhost_test_exit);
+
+MODULE_VERSION("0.0.1");
+MODULE_LICENSE("GPL v2");
+MODULE_AUTHOR("Michael S. Tsirkin");
+MODULE_DESCRIPTION("Host kernel side for virtio simulator");
diff --git a/drivers/vhost/test.h b/drivers/vhost/test.h
new file mode 100644
index 00000000000..1fef5df8215
--- /dev/null
+++ b/drivers/vhost/test.h
@@ -0,0 +1,7 @@
+#ifndef LINUX_VHOST_TEST_H
+#define LINUX_VHOST_TEST_H
+
+/* Start a given test on the virtio null device. 0 stops all tests. */
+#define VHOST_TEST_RUN _IOW(VHOST_VIRTIO, 0x31, int)
+
+#endif
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index 94701ff3a23..c90f4374442 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -4,7 +4,7 @@
* Author: Michael S. Tsirkin <mst@redhat.com>
*
* Inspiration, some code, and most witty comments come from
- * Documentation/lguest/lguest.c, by Rusty Russell
+ * Documentation/virtual/lguest/lguest.c, by Rusty Russell
*
* This work is licensed under the terms of the GNU GPL, version 2.
*
@@ -13,23 +13,18 @@
#include <linux/eventfd.h>
#include <linux/vhost.h>
-#include <linux/virtio_net.h>
+#include <linux/uio.h>
#include <linux/mm.h>
+#include <linux/mmu_context.h>
#include <linux/miscdevice.h>
#include <linux/mutex.h>
-#include <linux/rcupdate.h>
#include <linux/poll.h>
#include <linux/file.h>
#include <linux/highmem.h>
#include <linux/slab.h>
#include <linux/kthread.h>
#include <linux/cgroup.h>
-
-#include <linux/net.h>
-#include <linux/if_packet.h>
-#include <linux/if_arp.h>
-
-#include <net/sock.h>
+#include <linux/module.h>
#include "vhost.h"
@@ -38,12 +33,15 @@ enum {
VHOST_MEMORY_F_LOG = 0x1,
};
+#define vhost_used_event(vq) ((u16 __user *)&vq->avail->ring[vq->num])
+#define vhost_avail_event(vq) ((u16 __user *)&vq->used->ring[vq->num])
+
static void vhost_poll_func(struct file *file, wait_queue_head_t *wqh,
poll_table *pt)
{
struct vhost_poll *poll;
- poll = container_of(pt, struct vhost_poll, table);
+ poll = container_of(pt, struct vhost_poll, table);
poll->wqh = wqh;
add_wait_queue(wqh, &poll->wait);
}
@@ -60,7 +58,7 @@ static int vhost_poll_wakeup(wait_queue_t *wait, unsigned mode, int sync,
return 0;
}
-static void vhost_work_init(struct vhost_work *work, vhost_work_fn_t fn)
+void vhost_work_init(struct vhost_work *work, vhost_work_fn_t fn)
{
INIT_LIST_HEAD(&work->node);
work->fn = fn;
@@ -68,6 +66,7 @@ static void vhost_work_init(struct vhost_work *work, vhost_work_fn_t fn)
work->flushing = 0;
work->queue_seq = work->done_seq = 0;
}
+EXPORT_SYMBOL_GPL(vhost_work_init);
/* Init poll structure */
void vhost_poll_init(struct vhost_poll *poll, vhost_work_fn_t fn,
@@ -77,48 +76,73 @@ void vhost_poll_init(struct vhost_poll *poll, vhost_work_fn_t fn,
init_poll_funcptr(&poll->table, vhost_poll_func);
poll->mask = mask;
poll->dev = dev;
+ poll->wqh = NULL;
vhost_work_init(&poll->work, fn);
}
+EXPORT_SYMBOL_GPL(vhost_poll_init);
/* Start polling a file. We add ourselves to file's wait queue. The caller must
* keep a reference to a file until after vhost_poll_stop is called. */
-void vhost_poll_start(struct vhost_poll *poll, struct file *file)
+int vhost_poll_start(struct vhost_poll *poll, struct file *file)
{
unsigned long mask;
+ int ret = 0;
+
+ if (poll->wqh)
+ return 0;
+
mask = file->f_op->poll(file, &poll->table);
if (mask)
vhost_poll_wakeup(&poll->wait, 0, 0, (void *)mask);
+ if (mask & POLLERR) {
+ if (poll->wqh)
+ remove_wait_queue(poll->wqh, &poll->wait);
+ ret = -EINVAL;
+ }
+
+ return ret;
}
+EXPORT_SYMBOL_GPL(vhost_poll_start);
/* Stop polling a file. After this function returns, it becomes safe to drop the
* file reference. You must also flush afterwards. */
void vhost_poll_stop(struct vhost_poll *poll)
{
- remove_wait_queue(poll->wqh, &poll->wait);
+ if (poll->wqh) {
+ remove_wait_queue(poll->wqh, &poll->wait);
+ poll->wqh = NULL;
+ }
}
+EXPORT_SYMBOL_GPL(vhost_poll_stop);
-static void vhost_work_flush(struct vhost_dev *dev, struct vhost_work *work)
+static bool vhost_work_seq_done(struct vhost_dev *dev, struct vhost_work *work,
+ unsigned seq)
{
- unsigned seq;
int left;
+
+ spin_lock_irq(&dev->work_lock);
+ left = seq - work->done_seq;
+ spin_unlock_irq(&dev->work_lock);
+ return left <= 0;
+}
+
+void vhost_work_flush(struct vhost_dev *dev, struct vhost_work *work)
+{
+ unsigned seq;
int flushing;
spin_lock_irq(&dev->work_lock);
seq = work->queue_seq;
work->flushing++;
spin_unlock_irq(&dev->work_lock);
- wait_event(work->done, ({
- spin_lock_irq(&dev->work_lock);
- left = seq - work->done_seq <= 0;
- spin_unlock_irq(&dev->work_lock);
- left;
- }));
+ wait_event(work->done, vhost_work_seq_done(dev, work, seq));
spin_lock_irq(&dev->work_lock);
flushing = --work->flushing;
spin_unlock_irq(&dev->work_lock);
BUG_ON(flushing < 0);
}
+EXPORT_SYMBOL_GPL(vhost_work_flush);
/* Flush any work that has been scheduled. When calling this, don't hold any
* locks that are also used by the callback. */
@@ -126,9 +150,9 @@ void vhost_poll_flush(struct vhost_poll *poll)
{
vhost_work_flush(poll->dev, &poll->work);
}
+EXPORT_SYMBOL_GPL(vhost_poll_flush);
-static inline void vhost_work_queue(struct vhost_dev *dev,
- struct vhost_work *work)
+void vhost_work_queue(struct vhost_dev *dev, struct vhost_work *work)
{
unsigned long flags;
@@ -136,15 +160,19 @@ static inline void vhost_work_queue(struct vhost_dev *dev,
if (list_empty(&work->node)) {
list_add_tail(&work->node, &dev->work_list);
work->queue_seq++;
+ spin_unlock_irqrestore(&dev->work_lock, flags);
wake_up_process(dev->worker);
+ } else {
+ spin_unlock_irqrestore(&dev->work_lock, flags);
}
- spin_unlock_irqrestore(&dev->work_lock, flags);
}
+EXPORT_SYMBOL_GPL(vhost_work_queue);
void vhost_poll_queue(struct vhost_poll *poll)
{
vhost_work_queue(poll->dev, &poll->work);
}
+EXPORT_SYMBOL_GPL(vhost_poll_queue);
static void vhost_vq_reset(struct vhost_dev *dev,
struct vhost_virtqueue *vq)
@@ -156,13 +184,13 @@ static void vhost_vq_reset(struct vhost_dev *dev,
vq->last_avail_idx = 0;
vq->avail_idx = 0;
vq->last_used_idx = 0;
- vq->used_flags = 0;
+ vq->signalled_used = 0;
+ vq->signalled_used_valid = false;
vq->used_flags = 0;
vq->log_used = false;
vq->log_addr = -1ull;
- vq->vhost_hlen = 0;
- vq->sock_hlen = 0;
vq->private_data = NULL;
+ vq->acked_features = 0;
vq->log_base = NULL;
vq->error_ctx = NULL;
vq->error = NULL;
@@ -170,6 +198,7 @@ static void vhost_vq_reset(struct vhost_dev *dev,
vq->call_ctx = NULL;
vq->call = NULL;
vq->log_ctx = NULL;
+ vq->memory = NULL;
}
static int vhost_worker(void *data)
@@ -177,6 +206,10 @@ static int vhost_worker(void *data)
struct vhost_dev *dev = data;
struct vhost_work *work = NULL;
unsigned uninitialized_var(seq);
+ mm_segment_t oldfs = get_fs();
+
+ set_fs(USER_DS);
+ use_mm(dev->mm);
for (;;) {
/* mb paired w/ kthread_stop */
@@ -192,7 +225,7 @@ static int vhost_worker(void *data)
if (kthread_should_stop()) {
spin_unlock_irq(&dev->work_lock);
__set_current_state(TASK_RUNNING);
- return 0;
+ break;
}
if (!list_empty(&dev->work_list)) {
work = list_first_entry(&dev->work_list,
@@ -206,54 +239,62 @@ static int vhost_worker(void *data)
if (work) {
__set_current_state(TASK_RUNNING);
work->fn(work);
+ if (need_resched())
+ schedule();
} else
schedule();
}
+ unuse_mm(dev->mm);
+ set_fs(oldfs);
+ return 0;
+}
+
+static void vhost_vq_free_iovecs(struct vhost_virtqueue *vq)
+{
+ kfree(vq->indirect);
+ vq->indirect = NULL;
+ kfree(vq->log);
+ vq->log = NULL;
+ kfree(vq->heads);
+ vq->heads = NULL;
}
/* Helper to allocate iovec buffers for all vqs. */
static long vhost_dev_alloc_iovecs(struct vhost_dev *dev)
{
+ struct vhost_virtqueue *vq;
int i;
+
for (i = 0; i < dev->nvqs; ++i) {
- dev->vqs[i].indirect = kmalloc(sizeof *dev->vqs[i].indirect *
- UIO_MAXIOV, GFP_KERNEL);
- dev->vqs[i].log = kmalloc(sizeof *dev->vqs[i].log * UIO_MAXIOV,
- GFP_KERNEL);
- dev->vqs[i].heads = kmalloc(sizeof *dev->vqs[i].heads *
- UIO_MAXIOV, GFP_KERNEL);
-
- if (!dev->vqs[i].indirect || !dev->vqs[i].log ||
- !dev->vqs[i].heads)
+ vq = dev->vqs[i];
+ vq->indirect = kmalloc(sizeof *vq->indirect * UIO_MAXIOV,
+ GFP_KERNEL);
+ vq->log = kmalloc(sizeof *vq->log * UIO_MAXIOV, GFP_KERNEL);
+ vq->heads = kmalloc(sizeof *vq->heads * UIO_MAXIOV, GFP_KERNEL);
+ if (!vq->indirect || !vq->log || !vq->heads)
goto err_nomem;
}
return 0;
+
err_nomem:
- for (; i >= 0; --i) {
- kfree(dev->vqs[i].indirect);
- kfree(dev->vqs[i].log);
- kfree(dev->vqs[i].heads);
- }
+ for (; i >= 0; --i)
+ vhost_vq_free_iovecs(dev->vqs[i]);
return -ENOMEM;
}
static void vhost_dev_free_iovecs(struct vhost_dev *dev)
{
int i;
- for (i = 0; i < dev->nvqs; ++i) {
- kfree(dev->vqs[i].indirect);
- dev->vqs[i].indirect = NULL;
- kfree(dev->vqs[i].log);
- dev->vqs[i].log = NULL;
- kfree(dev->vqs[i].heads);
- dev->vqs[i].heads = NULL;
- }
+
+ for (i = 0; i < dev->nvqs; ++i)
+ vhost_vq_free_iovecs(dev->vqs[i]);
}
-long vhost_dev_init(struct vhost_dev *dev,
- struct vhost_virtqueue *vqs, int nvqs)
+void vhost_dev_init(struct vhost_dev *dev,
+ struct vhost_virtqueue **vqs, int nvqs)
{
+ struct vhost_virtqueue *vq;
int i;
dev->vqs = vqs;
@@ -268,19 +309,19 @@ long vhost_dev_init(struct vhost_dev *dev,
dev->worker = NULL;
for (i = 0; i < dev->nvqs; ++i) {
- dev->vqs[i].log = NULL;
- dev->vqs[i].indirect = NULL;
- dev->vqs[i].heads = NULL;
- dev->vqs[i].dev = dev;
- mutex_init(&dev->vqs[i].mutex);
- vhost_vq_reset(dev, dev->vqs + i);
- if (dev->vqs[i].handle_kick)
- vhost_poll_init(&dev->vqs[i].poll,
- dev->vqs[i].handle_kick, POLLIN, dev);
+ vq = dev->vqs[i];
+ vq->log = NULL;
+ vq->indirect = NULL;
+ vq->heads = NULL;
+ vq->dev = dev;
+ mutex_init(&vq->mutex);
+ vhost_vq_reset(dev, vq);
+ if (vq->handle_kick)
+ vhost_poll_init(&vq->poll, vq->handle_kick,
+ POLLIN, dev);
}
-
- return 0;
}
+EXPORT_SYMBOL_GPL(vhost_dev_init);
/* Caller should have device mutex */
long vhost_dev_check_owner(struct vhost_dev *dev)
@@ -288,40 +329,52 @@ long vhost_dev_check_owner(struct vhost_dev *dev)
/* Are you the owner? If not, I don't think you mean to do that */
return dev->mm == current->mm ? 0 : -EPERM;
}
+EXPORT_SYMBOL_GPL(vhost_dev_check_owner);
struct vhost_attach_cgroups_struct {
- struct vhost_work work;
- struct task_struct *owner;
- int ret;
+ struct vhost_work work;
+ struct task_struct *owner;
+ int ret;
};
static void vhost_attach_cgroups_work(struct vhost_work *work)
{
- struct vhost_attach_cgroups_struct *s;
- s = container_of(work, struct vhost_attach_cgroups_struct, work);
- s->ret = cgroup_attach_task_all(s->owner, current);
+ struct vhost_attach_cgroups_struct *s;
+
+ s = container_of(work, struct vhost_attach_cgroups_struct, work);
+ s->ret = cgroup_attach_task_all(s->owner, current);
}
static int vhost_attach_cgroups(struct vhost_dev *dev)
{
- struct vhost_attach_cgroups_struct attach;
- attach.owner = current;
- vhost_work_init(&attach.work, vhost_attach_cgroups_work);
- vhost_work_queue(dev, &attach.work);
- vhost_work_flush(dev, &attach.work);
- return attach.ret;
+ struct vhost_attach_cgroups_struct attach;
+
+ attach.owner = current;
+ vhost_work_init(&attach.work, vhost_attach_cgroups_work);
+ vhost_work_queue(dev, &attach.work);
+ vhost_work_flush(dev, &attach.work);
+ return attach.ret;
+}
+
+/* Caller should have device mutex */
+bool vhost_dev_has_owner(struct vhost_dev *dev)
+{
+ return dev->mm;
}
+EXPORT_SYMBOL_GPL(vhost_dev_has_owner);
/* Caller should have device mutex */
-static long vhost_dev_set_owner(struct vhost_dev *dev)
+long vhost_dev_set_owner(struct vhost_dev *dev)
{
struct task_struct *worker;
int err;
+
/* Is there an owner already? */
- if (dev->mm) {
+ if (vhost_dev_has_owner(dev)) {
err = -EBUSY;
goto err_mm;
}
+
/* No owner, become one */
dev->mm = get_task_mm(current);
worker = kthread_create(vhost_worker, dev, "vhost-%d", current->pid);
@@ -352,44 +405,62 @@ err_worker:
err_mm:
return err;
}
+EXPORT_SYMBOL_GPL(vhost_dev_set_owner);
-/* Caller should have device mutex */
-long vhost_dev_reset_owner(struct vhost_dev *dev)
+struct vhost_memory *vhost_dev_reset_owner_prepare(void)
{
- struct vhost_memory *memory;
+ return kmalloc(offsetof(struct vhost_memory, regions), GFP_KERNEL);
+}
+EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare);
- /* Restore memory to default empty mapping. */
- memory = kmalloc(offsetof(struct vhost_memory, regions), GFP_KERNEL);
- if (!memory)
- return -ENOMEM;
+/* Caller should have device mutex */
+void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_memory *memory)
+{
+ int i;
- vhost_dev_cleanup(dev);
+ vhost_dev_cleanup(dev, true);
+ /* Restore memory to default empty mapping. */
memory->nregions = 0;
- RCU_INIT_POINTER(dev->memory, memory);
- return 0;
+ dev->memory = memory;
+ /* We don't need VQ locks below since vhost_dev_cleanup makes sure
+ * VQs aren't running.
+ */
+ for (i = 0; i < dev->nvqs; ++i)
+ dev->vqs[i]->memory = memory;
}
+EXPORT_SYMBOL_GPL(vhost_dev_reset_owner);
-/* Caller should have device mutex */
-void vhost_dev_cleanup(struct vhost_dev *dev)
+void vhost_dev_stop(struct vhost_dev *dev)
{
int i;
+
for (i = 0; i < dev->nvqs; ++i) {
- if (dev->vqs[i].kick && dev->vqs[i].handle_kick) {
- vhost_poll_stop(&dev->vqs[i].poll);
- vhost_poll_flush(&dev->vqs[i].poll);
+ if (dev->vqs[i]->kick && dev->vqs[i]->handle_kick) {
+ vhost_poll_stop(&dev->vqs[i]->poll);
+ vhost_poll_flush(&dev->vqs[i]->poll);
}
- if (dev->vqs[i].error_ctx)
- eventfd_ctx_put(dev->vqs[i].error_ctx);
- if (dev->vqs[i].error)
- fput(dev->vqs[i].error);
- if (dev->vqs[i].kick)
- fput(dev->vqs[i].kick);
- if (dev->vqs[i].call_ctx)
- eventfd_ctx_put(dev->vqs[i].call_ctx);
- if (dev->vqs[i].call)
- fput(dev->vqs[i].call);
- vhost_vq_reset(dev, dev->vqs + i);
+ }
+}
+EXPORT_SYMBOL_GPL(vhost_dev_stop);
+
+/* Caller should have device mutex if and only if locked is set */
+void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
+{
+ int i;
+
+ for (i = 0; i < dev->nvqs; ++i) {
+ if (dev->vqs[i]->error_ctx)
+ eventfd_ctx_put(dev->vqs[i]->error_ctx);
+ if (dev->vqs[i]->error)
+ fput(dev->vqs[i]->error);
+ if (dev->vqs[i]->kick)
+ fput(dev->vqs[i]->kick);
+ if (dev->vqs[i]->call_ctx)
+ eventfd_ctx_put(dev->vqs[i]->call_ctx);
+ if (dev->vqs[i]->call)
+ fput(dev->vqs[i]->call);
+ vhost_vq_reset(dev, dev->vqs[i]);
}
vhost_dev_free_iovecs(dev);
if (dev->log_ctx)
@@ -399,23 +470,23 @@ void vhost_dev_cleanup(struct vhost_dev *dev)
fput(dev->log_file);
dev->log_file = NULL;
/* No one will access memory at this point */
- kfree(rcu_dereference_protected(dev->memory,
- lockdep_is_held(&dev->mutex)));
- RCU_INIT_POINTER(dev->memory, NULL);
- if (dev->mm)
- mmput(dev->mm);
- dev->mm = NULL;
-
+ kfree(dev->memory);
+ dev->memory = NULL;
WARN_ON(!list_empty(&dev->work_list));
if (dev->worker) {
kthread_stop(dev->worker);
dev->worker = NULL;
}
+ if (dev->mm)
+ mmput(dev->mm);
+ dev->mm = NULL;
}
+EXPORT_SYMBOL_GPL(vhost_dev_cleanup);
static int log_access_ok(void __user *log_base, u64 addr, unsigned long sz)
{
u64 a = addr / VHOST_PAGE_SIZE / 8;
+
/* Make sure 64 bit math will not overflow. */
if (a > ULONG_MAX - (unsigned long)log_base ||
a + (unsigned long)log_base > ULONG_MAX)
@@ -456,72 +527,75 @@ static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem,
int log_all)
{
int i;
+
for (i = 0; i < d->nvqs; ++i) {
int ok;
- mutex_lock(&d->vqs[i].mutex);
+ bool log;
+
+ mutex_lock(&d->vqs[i]->mutex);
+ log = log_all || vhost_has_feature(d->vqs[i], VHOST_F_LOG_ALL);
/* If ring is inactive, will check when it's enabled. */
- if (d->vqs[i].private_data)
- ok = vq_memory_access_ok(d->vqs[i].log_base, mem,
- log_all);
+ if (d->vqs[i]->private_data)
+ ok = vq_memory_access_ok(d->vqs[i]->log_base, mem, log);
else
ok = 1;
- mutex_unlock(&d->vqs[i].mutex);
+ mutex_unlock(&d->vqs[i]->mutex);
if (!ok)
return 0;
}
return 1;
}
-static int vq_access_ok(unsigned int num,
+static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
struct vring_desc __user *desc,
struct vring_avail __user *avail,
struct vring_used __user *used)
{
+ size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
return access_ok(VERIFY_READ, desc, num * sizeof *desc) &&
access_ok(VERIFY_READ, avail,
- sizeof *avail + num * sizeof *avail->ring) &&
+ sizeof *avail + num * sizeof *avail->ring + s) &&
access_ok(VERIFY_WRITE, used,
- sizeof *used + num * sizeof *used->ring);
+ sizeof *used + num * sizeof *used->ring + s);
}
/* Can we log writes? */
/* Caller should have device mutex but not vq mutex */
int vhost_log_access_ok(struct vhost_dev *dev)
{
- struct vhost_memory *mp;
-
- mp = rcu_dereference_protected(dev->memory,
- lockdep_is_held(&dev->mutex));
- return memory_access_ok(dev, mp, 1);
+ return memory_access_ok(dev, dev->memory, 1);
}
+EXPORT_SYMBOL_GPL(vhost_log_access_ok);
/* Verify access for write logging. */
/* Caller should have vq mutex and device mutex */
-static int vq_log_access_ok(struct vhost_virtqueue *vq, void __user *log_base)
+static int vq_log_access_ok(struct vhost_virtqueue *vq,
+ void __user *log_base)
{
- struct vhost_memory *mp;
+ size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
- mp = rcu_dereference_protected(vq->dev->memory,
- lockdep_is_held(&vq->mutex));
- return vq_memory_access_ok(log_base, mp,
- vhost_has_feature(vq->dev, VHOST_F_LOG_ALL)) &&
+ return vq_memory_access_ok(log_base, vq->memory,
+ vhost_has_feature(vq, VHOST_F_LOG_ALL)) &&
(!vq->log_used || log_access_ok(log_base, vq->log_addr,
sizeof *vq->used +
- vq->num * sizeof *vq->used->ring));
+ vq->num * sizeof *vq->used->ring + s));
}
/* Can we start vq? */
/* Caller should have vq mutex and device mutex */
int vhost_vq_access_ok(struct vhost_virtqueue *vq)
{
- return vq_access_ok(vq->num, vq->desc, vq->avail, vq->used) &&
+ return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used) &&
vq_log_access_ok(vq, vq->log_base);
}
+EXPORT_SYMBOL_GPL(vhost_vq_access_ok);
static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
{
struct vhost_memory mem, *newmem, *oldmem;
unsigned long size = offsetof(struct vhost_memory, regions);
+ int i;
+
if (copy_from_user(&mem, m, size))
return -EFAULT;
if (mem.padding)
@@ -539,31 +613,27 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
return -EFAULT;
}
- if (!memory_access_ok(d, newmem, vhost_has_feature(d, VHOST_F_LOG_ALL))) {
+ if (!memory_access_ok(d, newmem, 0)) {
kfree(newmem);
return -EFAULT;
}
- oldmem = rcu_dereference_protected(d->memory,
- lockdep_is_held(&d->mutex));
- rcu_assign_pointer(d->memory, newmem);
- synchronize_rcu();
+ oldmem = d->memory;
+ d->memory = newmem;
+
+ /* All memory accesses are done under some VQ mutex. */
+ for (i = 0; i < d->nvqs; ++i) {
+ mutex_lock(&d->vqs[i]->mutex);
+ d->vqs[i]->memory = newmem;
+ mutex_unlock(&d->vqs[i]->mutex);
+ }
kfree(oldmem);
return 0;
}
-static int init_used(struct vhost_virtqueue *vq,
- struct vring_used __user *used)
+long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp)
{
- int r = put_user(vq->used_flags, &used->flags);
- if (r)
- return r;
- return get_user(vq->last_used_idx, &used->idx);
-}
-
-static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
-{
- struct file *eventfp, *filep = NULL,
- *pollstart = NULL, *pollstop = NULL;
+ struct file *eventfp, *filep = NULL;
+ bool pollstart = false, pollstop = false;
struct eventfd_ctx *ctx = NULL;
u32 __user *idxp = argp;
struct vhost_virtqueue *vq;
@@ -579,7 +649,7 @@ static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
if (idx >= d->nvqs)
return -ENOBUFS;
- vq = d->vqs + idx;
+ vq = d->vqs[idx];
mutex_lock(&vq->mutex);
@@ -654,7 +724,7 @@ static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
* If it is not, we don't as size might not have been setup.
* We will verify when backend is configured. */
if (vq->private_data) {
- if (!vq_access_ok(vq->num,
+ if (!vq_access_ok(vq, vq->num,
(void __user *)(unsigned long)a.desc_user_addr,
(void __user *)(unsigned long)a.avail_user_addr,
(void __user *)(unsigned long)a.used_user_addr)) {
@@ -672,10 +742,6 @@ static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
}
}
- r = init_used(vq, (struct vring_used __user *)(unsigned long)
- a.used_user_addr);
- if (r)
- break;
vq->log_used = !!(a.flags & (0x1 << VHOST_VRING_F_LOG));
vq->desc = (void __user *)(unsigned long)a.desc_user_addr;
vq->avail = (void __user *)(unsigned long)a.avail_user_addr;
@@ -693,8 +759,8 @@ static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
break;
}
if (eventfp != vq->kick) {
- pollstop = filep = vq->kick;
- pollstart = vq->kick = eventfp;
+ pollstop = (filep = vq->kick) != NULL;
+ pollstart = (vq->kick = eventfp) != NULL;
} else
filep = eventfp;
break;
@@ -749,7 +815,7 @@ static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
fput(filep);
if (pollstart && vq->handle_kick)
- vhost_poll_start(&vq->poll, vq->kick);
+ r = vhost_poll_start(&vq->poll, vq->kick);
mutex_unlock(&vq->mutex);
@@ -757,11 +823,11 @@ static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
vhost_poll_flush(&vq->poll);
return r;
}
+EXPORT_SYMBOL_GPL(vhost_vring_ioctl);
/* Caller must have device mutex */
-long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, unsigned long arg)
+long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
{
- void __user *argp = (void __user *)arg;
struct file *eventfp, *filep = NULL;
struct eventfd_ctx *ctx = NULL;
u64 p;
@@ -795,7 +861,7 @@ long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, unsigned long arg)
for (i = 0; i < d->nvqs; ++i) {
struct vhost_virtqueue *vq;
void __user *base = (void __user *)(unsigned long)p;
- vq = d->vqs + i;
+ vq = d->vqs[i];
mutex_lock(&vq->mutex);
/* If ring is inactive, will check when it's enabled. */
if (vq->private_data && !vq_log_access_ok(vq, base))
@@ -822,9 +888,9 @@ long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, unsigned long arg)
} else
filep = eventfp;
for (i = 0; i < d->nvqs; ++i) {
- mutex_lock(&d->vqs[i].mutex);
- d->vqs[i].log_ctx = d->log_ctx;
- mutex_unlock(&d->vqs[i].mutex);
+ mutex_lock(&d->vqs[i]->mutex);
+ d->vqs[i]->log_ctx = d->log_ctx;
+ mutex_unlock(&d->vqs[i]->mutex);
}
if (ctx)
eventfd_ctx_put(ctx);
@@ -832,18 +898,20 @@ long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, unsigned long arg)
fput(filep);
break;
default:
- r = vhost_set_vring(d, ioctl, argp);
+ r = -ENOIOCTLCMD;
break;
}
done:
return r;
}
+EXPORT_SYMBOL_GPL(vhost_dev_ioctl);
static const struct vhost_memory_region *find_region(struct vhost_memory *mem,
__u64 addr, __u32 len)
{
struct vhost_memory_region *reg;
int i;
+
/* linear search is not brilliant, but we really have on the order of 6
* regions in practice */
for (i = 0; i < mem->nregions; ++i) {
@@ -866,13 +934,14 @@ static int set_bit_to_user(int nr, void __user *addr)
void *base;
int bit = nr + (log % PAGE_SIZE) * 8;
int r;
+
r = get_user_pages_fast(log, 1, 1, &page);
if (r < 0)
return r;
BUG_ON(r != 1);
- base = kmap_atomic(page, KM_USER0);
+ base = kmap_atomic(page);
set_bit(bit, base);
- kunmap_atomic(base, KM_USER0);
+ kunmap_atomic(base);
set_page_dirty_lock(page);
put_page(page);
return 0;
@@ -881,14 +950,16 @@ static int set_bit_to_user(int nr, void __user *addr)
static int log_write(void __user *log_base,
u64 write_address, u64 write_length)
{
+ u64 write_page = write_address / VHOST_PAGE_SIZE;
int r;
+
if (!write_length)
return 0;
- write_address /= VHOST_PAGE_SIZE;
+ write_length += write_address % VHOST_PAGE_SIZE;
for (;;) {
u64 base = (u64)(unsigned long)log_base;
- u64 log = base + write_address / 8;
- int bit = write_address % 8;
+ u64 log = base + write_page / 8;
+ int bit = write_page % 8;
if ((u64)(unsigned long)log != log)
return -EFAULT;
r = set_bit_to_user(bit, (void __user *)(unsigned long)log);
@@ -897,7 +968,7 @@ static int log_write(void __user *log_base,
if (write_length <= VHOST_PAGE_SIZE)
break;
write_length -= VHOST_PAGE_SIZE;
- write_address += VHOST_PAGE_SIZE;
+ write_page += 1;
}
return r;
}
@@ -925,8 +996,61 @@ int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
BUG();
return 0;
}
+EXPORT_SYMBOL_GPL(vhost_log_write);
+
+static int vhost_update_used_flags(struct vhost_virtqueue *vq)
+{
+ void __user *used;
+ if (__put_user(vq->used_flags, &vq->used->flags) < 0)
+ return -EFAULT;
+ if (unlikely(vq->log_used)) {
+ /* Make sure the flag is seen before log. */
+ smp_wmb();
+ /* Log used flag write. */
+ used = &vq->used->flags;
+ log_write(vq->log_base, vq->log_addr +
+ (used - (void __user *)vq->used),
+ sizeof vq->used->flags);
+ if (vq->log_ctx)
+ eventfd_signal(vq->log_ctx, 1);
+ }
+ return 0;
+}
+
+static int vhost_update_avail_event(struct vhost_virtqueue *vq, u16 avail_event)
+{
+ if (__put_user(vq->avail_idx, vhost_avail_event(vq)))
+ return -EFAULT;
+ if (unlikely(vq->log_used)) {
+ void __user *used;
+ /* Make sure the event is seen before log. */
+ smp_wmb();
+ /* Log avail event write */
+ used = vhost_avail_event(vq);
+ log_write(vq->log_base, vq->log_addr +
+ (used - (void __user *)vq->used),
+ sizeof *vhost_avail_event(vq));
+ if (vq->log_ctx)
+ eventfd_signal(vq->log_ctx, 1);
+ }
+ return 0;
+}
-static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
+int vhost_init_used(struct vhost_virtqueue *vq)
+{
+ int r;
+ if (!vq->private_data)
+ return 0;
+
+ r = vhost_update_used_flags(vq);
+ if (r)
+ return r;
+ vq->signalled_used_valid = false;
+ return get_user(vq->last_used_idx, &vq->used->idx);
+}
+EXPORT_SYMBOL_GPL(vhost_init_used);
+
+static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
struct iovec iov[], int iov_size)
{
const struct vhost_memory_region *reg;
@@ -935,9 +1059,7 @@ static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
u64 s = 0;
int ret = 0;
- rcu_read_lock();
-
- mem = rcu_dereference(dev->memory);
+ mem = vq->memory;
while ((u64)len > s) {
u64 size;
if (unlikely(ret >= iov_size)) {
@@ -951,7 +1073,7 @@ static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
}
_iov = iov + ret;
size = reg->memory_size - addr + reg->guest_phys_addr;
- _iov->iov_len = min((u64)len, size);
+ _iov->iov_len = min((u64)len - s, size);
_iov->iov_base = (void __user *)(unsigned long)
(reg->userspace_addr + addr - reg->guest_phys_addr);
s += size;
@@ -959,7 +1081,6 @@ static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
++ret;
}
- rcu_read_unlock();
return ret;
}
@@ -984,7 +1105,7 @@ static unsigned next_desc(struct vring_desc *desc)
return next;
}
-static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
+static int get_indirect(struct vhost_virtqueue *vq,
struct iovec iov[], unsigned int iov_size,
unsigned int *out_num, unsigned int *in_num,
struct vhost_log *log, unsigned int *log_num,
@@ -1003,7 +1124,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
return -EINVAL;
}
- ret = translate_desc(dev, indirect->addr, indirect->len, vq->indirect,
+ ret = translate_desc(vq, indirect->addr, indirect->len, vq->indirect,
UIO_MAXIOV);
if (unlikely(ret < 0)) {
vq_err(vq, "Translation failure %d in indirect.\n", ret);
@@ -1031,8 +1152,8 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
i, count);
return -EINVAL;
}
- if (unlikely(memcpy_fromiovec((unsigned char *)&desc, vq->indirect,
- sizeof desc))) {
+ if (unlikely(memcpy_fromiovec((unsigned char *)&desc,
+ vq->indirect, sizeof desc))) {
vq_err(vq, "Failed indirect descriptor: idx %d, %zx\n",
i, (size_t)indirect->addr + i * sizeof desc);
return -EINVAL;
@@ -1043,7 +1164,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
return -EINVAL;
}
- ret = translate_desc(dev, desc.addr, desc.len, iov + iov_count,
+ ret = translate_desc(vq, desc.addr, desc.len, iov + iov_count,
iov_size - iov_count);
if (unlikely(ret < 0)) {
vq_err(vq, "Translation failure %d indirect idx %d\n",
@@ -1080,7 +1201,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
* This function returns the descriptor number found, or vq->num (which is
* never a valid descriptor number) if none was found. A negative code is
* returned on error. */
-int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
+int vhost_get_vq_desc(struct vhost_virtqueue *vq,
struct iovec iov[], unsigned int iov_size,
unsigned int *out_num, unsigned int *in_num,
struct vhost_log *log, unsigned int *log_num)
@@ -1092,7 +1213,7 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
/* Check it isn't doing very strange things with descriptor numbers. */
last_avail_idx = vq->last_avail_idx;
- if (unlikely(get_user(vq->avail_idx, &vq->avail->idx))) {
+ if (unlikely(__get_user(vq->avail_idx, &vq->avail->idx))) {
vq_err(vq, "Failed to access avail idx at %p\n",
&vq->avail->idx);
return -EFAULT;
@@ -1113,8 +1234,8 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
/* Grab the next descriptor number they're advertising, and increment
* the index we've seen. */
- if (unlikely(get_user(head,
- &vq->avail->ring[last_avail_idx % vq->num]))) {
+ if (unlikely(__get_user(head,
+ &vq->avail->ring[last_avail_idx % vq->num]))) {
vq_err(vq, "Failed to read head: idx %d address %p\n",
last_avail_idx,
&vq->avail->ring[last_avail_idx % vq->num]);
@@ -1147,14 +1268,14 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
i, vq->num, head);
return -EINVAL;
}
- ret = copy_from_user(&desc, vq->desc + i, sizeof desc);
+ ret = __copy_from_user(&desc, vq->desc + i, sizeof desc);
if (unlikely(ret)) {
vq_err(vq, "Failed to get descriptor: idx %d addr %p\n",
i, vq->desc + i);
return -EFAULT;
}
if (desc.flags & VRING_DESC_F_INDIRECT) {
- ret = get_indirect(dev, vq, iov, iov_size,
+ ret = get_indirect(vq, iov, iov_size,
out_num, in_num,
log, log_num, &desc);
if (unlikely(ret < 0)) {
@@ -1165,7 +1286,7 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
continue;
}
- ret = translate_desc(dev, desc.addr, desc.len, iov + iov_count,
+ ret = translate_desc(vq, desc.addr, desc.len, iov + iov_count,
iov_size - iov_count);
if (unlikely(ret < 0)) {
vq_err(vq, "Translation failure %d descriptor idx %d\n",
@@ -1195,67 +1316,51 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
/* On success, increment avail index. */
vq->last_avail_idx++;
+
+ /* Assume notifications from guest are disabled at this point,
+ * if they aren't we would need to update avail_event index. */
+ BUG_ON(!(vq->used_flags & VRING_USED_F_NO_NOTIFY));
return head;
}
+EXPORT_SYMBOL_GPL(vhost_get_vq_desc);
/* Reverse the effect of vhost_get_vq_desc. Useful for error handling. */
void vhost_discard_vq_desc(struct vhost_virtqueue *vq, int n)
{
vq->last_avail_idx -= n;
}
+EXPORT_SYMBOL_GPL(vhost_discard_vq_desc);
/* After we've used one of their buffers, we tell them about it. We'll then
* want to notify the guest, using eventfd. */
int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len)
{
- struct vring_used_elem __user *used;
+ struct vring_used_elem heads = { head, len };
- /* The virtqueue contains a ring of used buffers. Get a pointer to the
- * next entry in that used ring. */
- used = &vq->used->ring[vq->last_used_idx % vq->num];
- if (put_user(head, &used->id)) {
- vq_err(vq, "Failed to write used id");
- return -EFAULT;
- }
- if (put_user(len, &used->len)) {
- vq_err(vq, "Failed to write used len");
- return -EFAULT;
- }
- /* Make sure buffer is written before we update index. */
- smp_wmb();
- if (put_user(vq->last_used_idx + 1, &vq->used->idx)) {
- vq_err(vq, "Failed to increment used idx");
- return -EFAULT;
- }
- if (unlikely(vq->log_used)) {
- /* Make sure data is seen before log. */
- smp_wmb();
- /* Log used ring entry write. */
- log_write(vq->log_base,
- vq->log_addr +
- ((void __user *)used - (void __user *)vq->used),
- sizeof *used);
- /* Log used index update. */
- log_write(vq->log_base,
- vq->log_addr + offsetof(struct vring_used, idx),
- sizeof vq->used->idx);
- if (vq->log_ctx)
- eventfd_signal(vq->log_ctx, 1);
- }
- vq->last_used_idx++;
- return 0;
+ return vhost_add_used_n(vq, &heads, 1);
}
+EXPORT_SYMBOL_GPL(vhost_add_used);
static int __vhost_add_used_n(struct vhost_virtqueue *vq,
struct vring_used_elem *heads,
unsigned count)
{
struct vring_used_elem __user *used;
+ u16 old, new;
int start;
start = vq->last_used_idx % vq->num;
used = vq->used->ring + start;
- if (copy_to_user(used, heads, count * sizeof *used)) {
+ if (count == 1) {
+ if (__put_user(heads[0].id, &used->id)) {
+ vq_err(vq, "Failed to write used id");
+ return -EFAULT;
+ }
+ if (__put_user(heads[0].len, &used->len)) {
+ vq_err(vq, "Failed to write used len");
+ return -EFAULT;
+ }
+ } else if (__copy_to_user(used, heads, count * sizeof *used)) {
vq_err(vq, "Failed to write used");
return -EFAULT;
}
@@ -1268,7 +1373,14 @@ static int __vhost_add_used_n(struct vhost_virtqueue *vq,
((void __user *)used - (void __user *)vq->used),
count * sizeof *used);
}
- vq->last_used_idx += count;
+ old = vq->last_used_idx;
+ new = (vq->last_used_idx += count);
+ /* If the driver never bothers to signal in a very long while,
+ * used index might wrap around. If that happens, invalidate
+ * signalled_used index we stored. TODO: make sure driver
+ * signals at least once in 2^16 and remove this. */
+ if (unlikely((u16)(new - vq->signalled_used) < (u16)(new - old)))
+ vq->signalled_used_valid = false;
return 0;
}
@@ -1306,31 +1418,52 @@ int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
}
return r;
}
+EXPORT_SYMBOL_GPL(vhost_add_used_n);
-/* This actually signals the guest, using eventfd. */
-void vhost_signal(struct vhost_dev *dev, struct vhost_virtqueue *vq)
+static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
{
- __u16 flags;
+ __u16 old, new, event;
+ bool v;
/* Flush out used index updates. This is paired
* with the barrier that the Guest executes when enabling
* interrupts. */
smp_mb();
- if (get_user(flags, &vq->avail->flags)) {
- vq_err(vq, "Failed to get flags");
- return;
+ if (vhost_has_feature(vq, VIRTIO_F_NOTIFY_ON_EMPTY) &&
+ unlikely(vq->avail_idx == vq->last_avail_idx))
+ return true;
+
+ if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
+ __u16 flags;
+ if (__get_user(flags, &vq->avail->flags)) {
+ vq_err(vq, "Failed to get flags");
+ return true;
+ }
+ return !(flags & VRING_AVAIL_F_NO_INTERRUPT);
}
+ old = vq->signalled_used;
+ v = vq->signalled_used_valid;
+ new = vq->signalled_used = vq->last_used_idx;
+ vq->signalled_used_valid = true;
- /* If they don't want an interrupt, don't signal, unless empty. */
- if ((flags & VRING_AVAIL_F_NO_INTERRUPT) &&
- (vq->avail_idx != vq->last_avail_idx ||
- !vhost_has_feature(dev, VIRTIO_F_NOTIFY_ON_EMPTY)))
- return;
+ if (unlikely(!v))
+ return true;
+
+ if (get_user(event, vhost_used_event(vq))) {
+ vq_err(vq, "Failed to get used event idx");
+ return true;
+ }
+ return vring_need_event(event, new, old);
+}
+/* This actually signals the guest, using eventfd. */
+void vhost_signal(struct vhost_dev *dev, struct vhost_virtqueue *vq)
+{
/* Signal the Guest tell them we used something up. */
- if (vq->call_ctx)
+ if (vq->call_ctx && vhost_notify(dev, vq))
eventfd_signal(vq->call_ctx, 1);
}
+EXPORT_SYMBOL_GPL(vhost_signal);
/* And here's the combo meal deal. Supersize me! */
void vhost_add_used_and_signal(struct vhost_dev *dev,
@@ -1340,6 +1473,7 @@ void vhost_add_used_and_signal(struct vhost_dev *dev,
vhost_add_used(vq, head, len);
vhost_signal(dev, vq);
}
+EXPORT_SYMBOL_GPL(vhost_add_used_and_signal);
/* multi-buffer version of vhost_add_used_and_signal */
void vhost_add_used_and_signal_n(struct vhost_dev *dev,
@@ -1349,25 +1483,36 @@ void vhost_add_used_and_signal_n(struct vhost_dev *dev,
vhost_add_used_n(vq, heads, count);
vhost_signal(dev, vq);
}
+EXPORT_SYMBOL_GPL(vhost_add_used_and_signal_n);
/* OK, now we need to know about added descriptors. */
-bool vhost_enable_notify(struct vhost_virtqueue *vq)
+bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
{
u16 avail_idx;
int r;
+
if (!(vq->used_flags & VRING_USED_F_NO_NOTIFY))
return false;
vq->used_flags &= ~VRING_USED_F_NO_NOTIFY;
- r = put_user(vq->used_flags, &vq->used->flags);
- if (r) {
- vq_err(vq, "Failed to enable notification at %p: %d\n",
- &vq->used->flags, r);
- return false;
+ if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
+ r = vhost_update_used_flags(vq);
+ if (r) {
+ vq_err(vq, "Failed to enable notification at %p: %d\n",
+ &vq->used->flags, r);
+ return false;
+ }
+ } else {
+ r = vhost_update_avail_event(vq, vq->avail_idx);
+ if (r) {
+ vq_err(vq, "Failed to update avail event index at %p: %d\n",
+ vhost_avail_event(vq), r);
+ return false;
+ }
}
/* They could have slipped one in as we were doing that: make
* sure it's written, then check again. */
smp_mb();
- r = get_user(avail_idx, &vq->avail->idx);
+ r = __get_user(avail_idx, &vq->avail->idx);
if (r) {
vq_err(vq, "Failed to check avail idx at %p: %d\n",
&vq->avail->idx, r);
@@ -1376,16 +1521,38 @@ bool vhost_enable_notify(struct vhost_virtqueue *vq)
return avail_idx != vq->avail_idx;
}
+EXPORT_SYMBOL_GPL(vhost_enable_notify);
/* We don't need to be notified again. */
-void vhost_disable_notify(struct vhost_virtqueue *vq)
+void vhost_disable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
{
int r;
+
if (vq->used_flags & VRING_USED_F_NO_NOTIFY)
return;
vq->used_flags |= VRING_USED_F_NO_NOTIFY;
- r = put_user(vq->used_flags, &vq->used->flags);
- if (r)
- vq_err(vq, "Failed to enable notification at %p: %d\n",
- &vq->used->flags, r);
+ if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
+ r = vhost_update_used_flags(vq);
+ if (r)
+ vq_err(vq, "Failed to enable notification at %p: %d\n",
+ &vq->used->flags, r);
+ }
+}
+EXPORT_SYMBOL_GPL(vhost_disable_notify);
+
+static int __init vhost_init(void)
+{
+ return 0;
}
+
+static void __exit vhost_exit(void)
+{
+}
+
+module_init(vhost_init);
+module_exit(vhost_exit);
+
+MODULE_VERSION("0.0.1");
+MODULE_LICENSE("GPL v2");
+MODULE_AUTHOR("Michael S. Tsirkin");
+MODULE_DESCRIPTION("Host kernel accelerator for virtio");
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index 073d06ae091..3eda654b8f5 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -7,11 +7,10 @@
#include <linux/mutex.h>
#include <linux/poll.h>
#include <linux/file.h>
-#include <linux/skbuff.h>
#include <linux/uio.h>
#include <linux/virtio_config.h>
#include <linux/virtio_ring.h>
-#include <asm/atomic.h>
+#include <linux/atomic.h>
struct vhost_device;
@@ -38,18 +37,25 @@ struct vhost_poll {
struct vhost_dev *dev;
};
+void vhost_work_init(struct vhost_work *work, vhost_work_fn_t fn);
+void vhost_work_queue(struct vhost_dev *dev, struct vhost_work *work);
+
void vhost_poll_init(struct vhost_poll *poll, vhost_work_fn_t fn,
unsigned long mask, struct vhost_dev *dev);
-void vhost_poll_start(struct vhost_poll *poll, struct file *file);
+int vhost_poll_start(struct vhost_poll *poll, struct file *file);
void vhost_poll_stop(struct vhost_poll *poll);
void vhost_poll_flush(struct vhost_poll *poll);
void vhost_poll_queue(struct vhost_poll *poll);
+void vhost_work_flush(struct vhost_dev *dev, struct vhost_work *work);
+long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp);
struct vhost_log {
u64 addr;
u64 len;
};
+struct vhost_virtqueue;
+
/* The virtqueue structure describes a queue attached to a device. */
struct vhost_virtqueue {
struct vhost_dev *dev;
@@ -84,41 +90,33 @@ struct vhost_virtqueue {
/* Used flags */
u16 used_flags;
+ /* Last used index value we have signalled on */
+ u16 signalled_used;
+
+ /* Last used index value we have signalled on */
+ bool signalled_used_valid;
+
/* Log writes to used structure. */
bool log_used;
u64 log_addr;
struct iovec iov[UIO_MAXIOV];
- /* hdr is used to store the virtio header.
- * Since each iovec has >= 1 byte length, we never need more than
- * header length entries to store the header. */
- struct iovec hdr[sizeof(struct virtio_net_hdr_mrg_rxbuf)];
struct iovec *indirect;
- size_t vhost_hlen;
- size_t sock_hlen;
struct vring_used_elem *heads;
- /* We use a kind of RCU to access private pointer.
- * All readers access it from worker, which makes it possible to
- * flush the vhost_work instead of synchronize_rcu. Therefore readers do
- * not need to call rcu_read_lock/rcu_read_unlock: the beginning of
- * vhost_work execution acts instead of rcu_read_lock() and the end of
- * vhost_work execution acts instead of rcu_read_lock().
- * Writers use virtqueue mutex. */
- void __rcu *private_data;
+ /* Protected by virtqueue mutex. */
+ struct vhost_memory *memory;
+ void *private_data;
+ unsigned acked_features;
/* Log write descriptors */
void __user *log_base;
struct vhost_log *log;
};
struct vhost_dev {
- /* Readers use RCU to access memory table pointer
- * log base pointer and features.
- * Writers use mutex below.*/
- struct vhost_memory __rcu *memory;
+ struct vhost_memory *memory;
struct mm_struct *mm;
struct mutex mutex;
- unsigned acked_features;
- struct vhost_virtqueue *vqs;
+ struct vhost_virtqueue **vqs;
int nvqs;
struct file *log_file;
struct eventfd_ctx *log_ctx;
@@ -127,20 +125,26 @@ struct vhost_dev {
struct task_struct *worker;
};
-long vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue *vqs, int nvqs);
+void vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue **vqs, int nvqs);
+long vhost_dev_set_owner(struct vhost_dev *dev);
+bool vhost_dev_has_owner(struct vhost_dev *dev);
long vhost_dev_check_owner(struct vhost_dev *);
-long vhost_dev_reset_owner(struct vhost_dev *);
-void vhost_dev_cleanup(struct vhost_dev *);
-long vhost_dev_ioctl(struct vhost_dev *, unsigned int ioctl, unsigned long arg);
+struct vhost_memory *vhost_dev_reset_owner_prepare(void);
+void vhost_dev_reset_owner(struct vhost_dev *, struct vhost_memory *);
+void vhost_dev_cleanup(struct vhost_dev *, bool locked);
+void vhost_dev_stop(struct vhost_dev *);
+long vhost_dev_ioctl(struct vhost_dev *, unsigned int ioctl, void __user *argp);
+long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp);
int vhost_vq_access_ok(struct vhost_virtqueue *vq);
int vhost_log_access_ok(struct vhost_dev *);
-int vhost_get_vq_desc(struct vhost_dev *, struct vhost_virtqueue *,
+int vhost_get_vq_desc(struct vhost_virtqueue *,
struct iovec iov[], unsigned int iov_count,
unsigned int *out_num, unsigned int *in_num,
struct vhost_log *log, unsigned int *log_num);
void vhost_discard_vq_desc(struct vhost_virtqueue *, int n);
+int vhost_init_used(struct vhost_virtqueue *);
int vhost_add_used(struct vhost_virtqueue *, unsigned int head, int len);
int vhost_add_used_n(struct vhost_virtqueue *, struct vring_used_elem *heads,
unsigned count);
@@ -149,8 +153,8 @@ void vhost_add_used_and_signal(struct vhost_dev *, struct vhost_virtqueue *,
void vhost_add_used_and_signal_n(struct vhost_dev *, struct vhost_virtqueue *,
struct vring_used_elem *heads, unsigned count);
void vhost_signal(struct vhost_dev *, struct vhost_virtqueue *);
-void vhost_disable_notify(struct vhost_virtqueue *);
-bool vhost_enable_notify(struct vhost_virtqueue *);
+void vhost_disable_notify(struct vhost_dev *, struct vhost_virtqueue *);
+bool vhost_enable_notify(struct vhost_dev *, struct vhost_virtqueue *);
int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
unsigned int log_num, u64 len);
@@ -162,21 +166,14 @@ int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
} while (0)
enum {
- VHOST_FEATURES = (1 << VIRTIO_F_NOTIFY_ON_EMPTY) |
- (1 << VIRTIO_RING_F_INDIRECT_DESC) |
- (1 << VHOST_F_LOG_ALL) |
- (1 << VHOST_NET_F_VIRTIO_NET_HDR) |
- (1 << VIRTIO_NET_F_MRG_RXBUF),
+ VHOST_FEATURES = (1ULL << VIRTIO_F_NOTIFY_ON_EMPTY) |
+ (1ULL << VIRTIO_RING_F_INDIRECT_DESC) |
+ (1ULL << VIRTIO_RING_F_EVENT_IDX) |
+ (1ULL << VHOST_F_LOG_ALL),
};
-static inline int vhost_has_feature(struct vhost_dev *dev, int bit)
+static inline int vhost_has_feature(struct vhost_virtqueue *vq, int bit)
{
- unsigned acked_features;
-
- acked_features =
- rcu_dereference_index_check(dev->acked_features,
- lockdep_is_held(&dev->mutex));
- return acked_features & (1 << bit);
+ return vq->acked_features & (1 << bit);
}
-
#endif
diff --git a/drivers/vhost/vringh.c b/drivers/vhost/vringh.c
new file mode 100644
index 00000000000..5174ebac288
--- /dev/null
+++ b/drivers/vhost/vringh.c
@@ -0,0 +1,1010 @@
+/*
+ * Helpers for the host side of a virtio ring.
+ *
+ * Since these may be in userspace, we use (inline) accessors.
+ */
+#include <linux/module.h>
+#include <linux/vringh.h>
+#include <linux/virtio_ring.h>
+#include <linux/kernel.h>
+#include <linux/ratelimit.h>
+#include <linux/uaccess.h>
+#include <linux/slab.h>
+#include <linux/export.h>
+
+static __printf(1,2) __cold void vringh_bad(const char *fmt, ...)
+{
+ static DEFINE_RATELIMIT_STATE(vringh_rs,
+ DEFAULT_RATELIMIT_INTERVAL,
+ DEFAULT_RATELIMIT_BURST);
+ if (__ratelimit(&vringh_rs)) {
+ va_list ap;
+ va_start(ap, fmt);
+ printk(KERN_NOTICE "vringh:");
+ vprintk(fmt, ap);
+ va_end(ap);
+ }
+}
+
+/* Returns vring->num if empty, -ve on error. */
+static inline int __vringh_get_head(const struct vringh *vrh,
+ int (*getu16)(u16 *val, const u16 *p),
+ u16 *last_avail_idx)
+{
+ u16 avail_idx, i, head;
+ int err;
+
+ err = getu16(&avail_idx, &vrh->vring.avail->idx);
+ if (err) {
+ vringh_bad("Failed to access avail idx at %p",
+ &vrh->vring.avail->idx);
+ return err;
+ }
+
+ if (*last_avail_idx == avail_idx)
+ return vrh->vring.num;
+
+ /* Only get avail ring entries after they have been exposed by guest. */
+ virtio_rmb(vrh->weak_barriers);
+
+ i = *last_avail_idx & (vrh->vring.num - 1);
+
+ err = getu16(&head, &vrh->vring.avail->ring[i]);
+ if (err) {
+ vringh_bad("Failed to read head: idx %d address %p",
+ *last_avail_idx, &vrh->vring.avail->ring[i]);
+ return err;
+ }
+
+ if (head >= vrh->vring.num) {
+ vringh_bad("Guest says index %u > %u is available",
+ head, vrh->vring.num);
+ return -EINVAL;
+ }
+
+ (*last_avail_idx)++;
+ return head;
+}
+
+/* Copy some bytes to/from the iovec. Returns num copied. */
+static inline ssize_t vringh_iov_xfer(struct vringh_kiov *iov,
+ void *ptr, size_t len,
+ int (*xfer)(void *addr, void *ptr,
+ size_t len))
+{
+ int err, done = 0;
+
+ while (len && iov->i < iov->used) {
+ size_t partlen;
+
+ partlen = min(iov->iov[iov->i].iov_len, len);
+ err = xfer(iov->iov[iov->i].iov_base, ptr, partlen);
+ if (err)
+ return err;
+ done += partlen;
+ len -= partlen;
+ ptr += partlen;
+ iov->consumed += partlen;
+ iov->iov[iov->i].iov_len -= partlen;
+ iov->iov[iov->i].iov_base += partlen;
+
+ if (!iov->iov[iov->i].iov_len) {
+ /* Fix up old iov element then increment. */
+ iov->iov[iov->i].iov_len = iov->consumed;
+ iov->iov[iov->i].iov_base -= iov->consumed;
+
+ iov->consumed = 0;
+ iov->i++;
+ }
+ }
+ return done;
+}
+
+/* May reduce *len if range is shorter. */
+static inline bool range_check(struct vringh *vrh, u64 addr, size_t *len,
+ struct vringh_range *range,
+ bool (*getrange)(struct vringh *,
+ u64, struct vringh_range *))
+{
+ if (addr < range->start || addr > range->end_incl) {
+ if (!getrange(vrh, addr, range))
+ return false;
+ }
+ BUG_ON(addr < range->start || addr > range->end_incl);
+
+ /* To end of memory? */
+ if (unlikely(addr + *len == 0)) {
+ if (range->end_incl == -1ULL)
+ return true;
+ goto truncate;
+ }
+
+ /* Otherwise, don't wrap. */
+ if (addr + *len < addr) {
+ vringh_bad("Wrapping descriptor %zu@0x%llx",
+ *len, (unsigned long long)addr);
+ return false;
+ }
+
+ if (unlikely(addr + *len - 1 > range->end_incl))
+ goto truncate;
+ return true;
+
+truncate:
+ *len = range->end_incl + 1 - addr;
+ return true;
+}
+
+static inline bool no_range_check(struct vringh *vrh, u64 addr, size_t *len,
+ struct vringh_range *range,
+ bool (*getrange)(struct vringh *,
+ u64, struct vringh_range *))
+{
+ return true;
+}
+
+/* No reason for this code to be inline. */
+static int move_to_indirect(int *up_next, u16 *i, void *addr,
+ const struct vring_desc *desc,
+ struct vring_desc **descs, int *desc_max)
+{
+ /* Indirect tables can't have indirect. */
+ if (*up_next != -1) {
+ vringh_bad("Multilevel indirect %u->%u", *up_next, *i);
+ return -EINVAL;
+ }
+
+ if (unlikely(desc->len % sizeof(struct vring_desc))) {
+ vringh_bad("Strange indirect len %u", desc->len);
+ return -EINVAL;
+ }
+
+ /* We will check this when we follow it! */
+ if (desc->flags & VRING_DESC_F_NEXT)
+ *up_next = desc->next;
+ else
+ *up_next = -2;
+ *descs = addr;
+ *desc_max = desc->len / sizeof(struct vring_desc);
+
+ /* Now, start at the first indirect. */
+ *i = 0;
+ return 0;
+}
+
+static int resize_iovec(struct vringh_kiov *iov, gfp_t gfp)
+{
+ struct kvec *new;
+ unsigned int flag, new_num = (iov->max_num & ~VRINGH_IOV_ALLOCATED) * 2;
+
+ if (new_num < 8)
+ new_num = 8;
+
+ flag = (iov->max_num & VRINGH_IOV_ALLOCATED);
+ if (flag)
+ new = krealloc(iov->iov, new_num * sizeof(struct iovec), gfp);
+ else {
+ new = kmalloc(new_num * sizeof(struct iovec), gfp);
+ if (new) {
+ memcpy(new, iov->iov,
+ iov->max_num * sizeof(struct iovec));
+ flag = VRINGH_IOV_ALLOCATED;
+ }
+ }
+ if (!new)
+ return -ENOMEM;
+ iov->iov = new;
+ iov->max_num = (new_num | flag);
+ return 0;
+}
+
+static u16 __cold return_from_indirect(const struct vringh *vrh, int *up_next,
+ struct vring_desc **descs, int *desc_max)
+{
+ u16 i = *up_next;
+
+ *up_next = -1;
+ *descs = vrh->vring.desc;
+ *desc_max = vrh->vring.num;
+ return i;
+}
+
+static int slow_copy(struct vringh *vrh, void *dst, const void *src,
+ bool (*rcheck)(struct vringh *vrh, u64 addr, size_t *len,
+ struct vringh_range *range,
+ bool (*getrange)(struct vringh *vrh,
+ u64,
+ struct vringh_range *)),
+ bool (*getrange)(struct vringh *vrh,
+ u64 addr,
+ struct vringh_range *r),
+ struct vringh_range *range,
+ int (*copy)(void *dst, const void *src, size_t len))
+{
+ size_t part, len = sizeof(struct vring_desc);
+
+ do {
+ u64 addr;
+ int err;
+
+ part = len;
+ addr = (u64)(unsigned long)src - range->offset;
+
+ if (!rcheck(vrh, addr, &part, range, getrange))
+ return -EINVAL;
+
+ err = copy(dst, src, part);
+ if (err)
+ return err;
+
+ dst += part;
+ src += part;
+ len -= part;
+ } while (len);
+ return 0;
+}
+
+static inline int
+__vringh_iov(struct vringh *vrh, u16 i,
+ struct vringh_kiov *riov,
+ struct vringh_kiov *wiov,
+ bool (*rcheck)(struct vringh *vrh, u64 addr, size_t *len,
+ struct vringh_range *range,
+ bool (*getrange)(struct vringh *, u64,
+ struct vringh_range *)),
+ bool (*getrange)(struct vringh *, u64, struct vringh_range *),
+ gfp_t gfp,
+ int (*copy)(void *dst, const void *src, size_t len))
+{
+ int err, count = 0, up_next, desc_max;
+ struct vring_desc desc, *descs;
+ struct vringh_range range = { -1ULL, 0 }, slowrange;
+ bool slow = false;
+
+ /* We start traversing vring's descriptor table. */
+ descs = vrh->vring.desc;
+ desc_max = vrh->vring.num;
+ up_next = -1;
+
+ if (riov)
+ riov->i = riov->used = 0;
+ else if (wiov)
+ wiov->i = wiov->used = 0;
+ else
+ /* You must want something! */
+ BUG();
+
+ for (;;) {
+ void *addr;
+ struct vringh_kiov *iov;
+ size_t len;
+
+ if (unlikely(slow))
+ err = slow_copy(vrh, &desc, &descs[i], rcheck, getrange,
+ &slowrange, copy);
+ else
+ err = copy(&desc, &descs[i], sizeof(desc));
+ if (unlikely(err))
+ goto fail;
+
+ if (unlikely(desc.flags & VRING_DESC_F_INDIRECT)) {
+ /* Make sure it's OK, and get offset. */
+ len = desc.len;
+ if (!rcheck(vrh, desc.addr, &len, &range, getrange)) {
+ err = -EINVAL;
+ goto fail;
+ }
+
+ if (unlikely(len != desc.len)) {
+ slow = true;
+ /* We need to save this range to use offset */
+ slowrange = range;
+ }
+
+ addr = (void *)(long)(desc.addr + range.offset);
+ err = move_to_indirect(&up_next, &i, addr, &desc,
+ &descs, &desc_max);
+ if (err)
+ goto fail;
+ continue;
+ }
+
+ if (count++ == vrh->vring.num) {
+ vringh_bad("Descriptor loop in %p", descs);
+ err = -ELOOP;
+ goto fail;
+ }
+
+ if (desc.flags & VRING_DESC_F_WRITE)
+ iov = wiov;
+ else {
+ iov = riov;
+ if (unlikely(wiov && wiov->i)) {
+ vringh_bad("Readable desc %p after writable",
+ &descs[i]);
+ err = -EINVAL;
+ goto fail;
+ }
+ }
+
+ if (!iov) {
+ vringh_bad("Unexpected %s desc",
+ !wiov ? "writable" : "readable");
+ err = -EPROTO;
+ goto fail;
+ }
+
+ again:
+ /* Make sure it's OK, and get offset. */
+ len = desc.len;
+ if (!rcheck(vrh, desc.addr, &len, &range, getrange)) {
+ err = -EINVAL;
+ goto fail;
+ }
+ addr = (void *)(unsigned long)(desc.addr + range.offset);
+
+ if (unlikely(iov->used == (iov->max_num & ~VRINGH_IOV_ALLOCATED))) {
+ err = resize_iovec(iov, gfp);
+ if (err)
+ goto fail;
+ }
+
+ iov->iov[iov->used].iov_base = addr;
+ iov->iov[iov->used].iov_len = len;
+ iov->used++;
+
+ if (unlikely(len != desc.len)) {
+ desc.len -= len;
+ desc.addr += len;
+ goto again;
+ }
+
+ if (desc.flags & VRING_DESC_F_NEXT) {
+ i = desc.next;
+ } else {
+ /* Just in case we need to finish traversing above. */
+ if (unlikely(up_next > 0)) {
+ i = return_from_indirect(vrh, &up_next,
+ &descs, &desc_max);
+ slow = false;
+ } else
+ break;
+ }
+
+ if (i >= desc_max) {
+ vringh_bad("Chained index %u > %u", i, desc_max);
+ err = -EINVAL;
+ goto fail;
+ }
+ }
+
+ return 0;
+
+fail:
+ return err;
+}
+
+static inline int __vringh_complete(struct vringh *vrh,
+ const struct vring_used_elem *used,
+ unsigned int num_used,
+ int (*putu16)(u16 *p, u16 val),
+ int (*putused)(struct vring_used_elem *dst,
+ const struct vring_used_elem
+ *src, unsigned num))
+{
+ struct vring_used *used_ring;
+ int err;
+ u16 used_idx, off;
+
+ used_ring = vrh->vring.used;
+ used_idx = vrh->last_used_idx + vrh->completed;
+
+ off = used_idx % vrh->vring.num;
+
+ /* Compiler knows num_used == 1 sometimes, hence extra check */
+ if (num_used > 1 && unlikely(off + num_used >= vrh->vring.num)) {
+ u16 part = vrh->vring.num - off;
+ err = putused(&used_ring->ring[off], used, part);
+ if (!err)
+ err = putused(&used_ring->ring[0], used + part,
+ num_used - part);
+ } else
+ err = putused(&used_ring->ring[off], used, num_used);
+
+ if (err) {
+ vringh_bad("Failed to write %u used entries %u at %p",
+ num_used, off, &used_ring->ring[off]);
+ return err;
+ }
+
+ /* Make sure buffer is written before we update index. */
+ virtio_wmb(vrh->weak_barriers);
+
+ err = putu16(&vrh->vring.used->idx, used_idx + num_used);
+ if (err) {
+ vringh_bad("Failed to update used index at %p",
+ &vrh->vring.used->idx);
+ return err;
+ }
+
+ vrh->completed += num_used;
+ return 0;
+}
+
+
+static inline int __vringh_need_notify(struct vringh *vrh,
+ int (*getu16)(u16 *val, const u16 *p))
+{
+ bool notify;
+ u16 used_event;
+ int err;
+
+ /* Flush out used index update. This is paired with the
+ * barrier that the Guest executes when enabling
+ * interrupts. */
+ virtio_mb(vrh->weak_barriers);
+
+ /* Old-style, without event indices. */
+ if (!vrh->event_indices) {
+ u16 flags;
+ err = getu16(&flags, &vrh->vring.avail->flags);
+ if (err) {
+ vringh_bad("Failed to get flags at %p",
+ &vrh->vring.avail->flags);
+ return err;
+ }
+ return (!(flags & VRING_AVAIL_F_NO_INTERRUPT));
+ }
+
+ /* Modern: we know when other side wants to know. */
+ err = getu16(&used_event, &vring_used_event(&vrh->vring));
+ if (err) {
+ vringh_bad("Failed to get used event idx at %p",
+ &vring_used_event(&vrh->vring));
+ return err;
+ }
+
+ /* Just in case we added so many that we wrap. */
+ if (unlikely(vrh->completed > 0xffff))
+ notify = true;
+ else
+ notify = vring_need_event(used_event,
+ vrh->last_used_idx + vrh->completed,
+ vrh->last_used_idx);
+
+ vrh->last_used_idx += vrh->completed;
+ vrh->completed = 0;
+ return notify;
+}
+
+static inline bool __vringh_notify_enable(struct vringh *vrh,
+ int (*getu16)(u16 *val, const u16 *p),
+ int (*putu16)(u16 *p, u16 val))
+{
+ u16 avail;
+
+ if (!vrh->event_indices) {
+ /* Old-school; update flags. */
+ if (putu16(&vrh->vring.used->flags, 0) != 0) {
+ vringh_bad("Clearing used flags %p",
+ &vrh->vring.used->flags);
+ return true;
+ }
+ } else {
+ if (putu16(&vring_avail_event(&vrh->vring),
+ vrh->last_avail_idx) != 0) {
+ vringh_bad("Updating avail event index %p",
+ &vring_avail_event(&vrh->vring));
+ return true;
+ }
+ }
+
+ /* They could have slipped one in as we were doing that: make
+ * sure it's written, then check again. */
+ virtio_mb(vrh->weak_barriers);
+
+ if (getu16(&avail, &vrh->vring.avail->idx) != 0) {
+ vringh_bad("Failed to check avail idx at %p",
+ &vrh->vring.avail->idx);
+ return true;
+ }
+
+ /* This is unlikely, so we just leave notifications enabled
+ * (if we're using event_indices, we'll only get one
+ * notification anyway). */
+ return avail == vrh->last_avail_idx;
+}
+
+static inline void __vringh_notify_disable(struct vringh *vrh,
+ int (*putu16)(u16 *p, u16 val))
+{
+ if (!vrh->event_indices) {
+ /* Old-school; update flags. */
+ if (putu16(&vrh->vring.used->flags, VRING_USED_F_NO_NOTIFY)) {
+ vringh_bad("Setting used flags %p",
+ &vrh->vring.used->flags);
+ }
+ }
+}
+
+/* Userspace access helpers: in this case, addresses are really userspace. */
+static inline int getu16_user(u16 *val, const u16 *p)
+{
+ return get_user(*val, (__force u16 __user *)p);
+}
+
+static inline int putu16_user(u16 *p, u16 val)
+{
+ return put_user(val, (__force u16 __user *)p);
+}
+
+static inline int copydesc_user(void *dst, const void *src, size_t len)
+{
+ return copy_from_user(dst, (__force void __user *)src, len) ?
+ -EFAULT : 0;
+}
+
+static inline int putused_user(struct vring_used_elem *dst,
+ const struct vring_used_elem *src,
+ unsigned int num)
+{
+ return copy_to_user((__force void __user *)dst, src,
+ sizeof(*dst) * num) ? -EFAULT : 0;
+}
+
+static inline int xfer_from_user(void *src, void *dst, size_t len)
+{
+ return copy_from_user(dst, (__force void __user *)src, len) ?
+ -EFAULT : 0;
+}
+
+static inline int xfer_to_user(void *dst, void *src, size_t len)
+{
+ return copy_to_user((__force void __user *)dst, src, len) ?
+ -EFAULT : 0;
+}
+
+/**
+ * vringh_init_user - initialize a vringh for a userspace vring.
+ * @vrh: the vringh to initialize.
+ * @features: the feature bits for this ring.
+ * @num: the number of elements.
+ * @weak_barriers: true if we only need memory barriers, not I/O.
+ * @desc: the userpace descriptor pointer.
+ * @avail: the userpace avail pointer.
+ * @used: the userpace used pointer.
+ *
+ * Returns an error if num is invalid: you should check pointers
+ * yourself!
+ */
+int vringh_init_user(struct vringh *vrh, u32 features,
+ unsigned int num, bool weak_barriers,
+ struct vring_desc __user *desc,
+ struct vring_avail __user *avail,
+ struct vring_used __user *used)
+{
+ /* Sane power of 2 please! */
+ if (!num || num > 0xffff || (num & (num - 1))) {
+ vringh_bad("Bad ring size %u", num);
+ return -EINVAL;
+ }
+
+ vrh->event_indices = (features & (1 << VIRTIO_RING_F_EVENT_IDX));
+ vrh->weak_barriers = weak_barriers;
+ vrh->completed = 0;
+ vrh->last_avail_idx = 0;
+ vrh->last_used_idx = 0;
+ vrh->vring.num = num;
+ /* vring expects kernel addresses, but only used via accessors. */
+ vrh->vring.desc = (__force struct vring_desc *)desc;
+ vrh->vring.avail = (__force struct vring_avail *)avail;
+ vrh->vring.used = (__force struct vring_used *)used;
+ return 0;
+}
+EXPORT_SYMBOL(vringh_init_user);
+
+/**
+ * vringh_getdesc_user - get next available descriptor from userspace ring.
+ * @vrh: the userspace vring.
+ * @riov: where to put the readable descriptors (or NULL)
+ * @wiov: where to put the writable descriptors (or NULL)
+ * @getrange: function to call to check ranges.
+ * @head: head index we received, for passing to vringh_complete_user().
+ *
+ * Returns 0 if there was no descriptor, 1 if there was, or -errno.
+ *
+ * Note that on error return, you can tell the difference between an
+ * invalid ring and a single invalid descriptor: in the former case,
+ * *head will be vrh->vring.num. You may be able to ignore an invalid
+ * descriptor, but there's not much you can do with an invalid ring.
+ *
+ * Note that you may need to clean up riov and wiov, even on error!
+ */
+int vringh_getdesc_user(struct vringh *vrh,
+ struct vringh_iov *riov,
+ struct vringh_iov *wiov,
+ bool (*getrange)(struct vringh *vrh,
+ u64 addr, struct vringh_range *r),
+ u16 *head)
+{
+ int err;
+
+ *head = vrh->vring.num;
+ err = __vringh_get_head(vrh, getu16_user, &vrh->last_avail_idx);
+ if (err < 0)
+ return err;
+
+ /* Empty... */
+ if (err == vrh->vring.num)
+ return 0;
+
+ /* We need the layouts to be the identical for this to work */
+ BUILD_BUG_ON(sizeof(struct vringh_kiov) != sizeof(struct vringh_iov));
+ BUILD_BUG_ON(offsetof(struct vringh_kiov, iov) !=
+ offsetof(struct vringh_iov, iov));
+ BUILD_BUG_ON(offsetof(struct vringh_kiov, i) !=
+ offsetof(struct vringh_iov, i));
+ BUILD_BUG_ON(offsetof(struct vringh_kiov, used) !=
+ offsetof(struct vringh_iov, used));
+ BUILD_BUG_ON(offsetof(struct vringh_kiov, max_num) !=
+ offsetof(struct vringh_iov, max_num));
+ BUILD_BUG_ON(sizeof(struct iovec) != sizeof(struct kvec));
+ BUILD_BUG_ON(offsetof(struct iovec, iov_base) !=
+ offsetof(struct kvec, iov_base));
+ BUILD_BUG_ON(offsetof(struct iovec, iov_len) !=
+ offsetof(struct kvec, iov_len));
+ BUILD_BUG_ON(sizeof(((struct iovec *)NULL)->iov_base)
+ != sizeof(((struct kvec *)NULL)->iov_base));
+ BUILD_BUG_ON(sizeof(((struct iovec *)NULL)->iov_len)
+ != sizeof(((struct kvec *)NULL)->iov_len));
+
+ *head = err;
+ err = __vringh_iov(vrh, *head, (struct vringh_kiov *)riov,
+ (struct vringh_kiov *)wiov,
+ range_check, getrange, GFP_KERNEL, copydesc_user);
+ if (err)
+ return err;
+
+ return 1;
+}
+EXPORT_SYMBOL(vringh_getdesc_user);
+
+/**
+ * vringh_iov_pull_user - copy bytes from vring_iov.
+ * @riov: the riov as passed to vringh_getdesc_user() (updated as we consume)
+ * @dst: the place to copy.
+ * @len: the maximum length to copy.
+ *
+ * Returns the bytes copied <= len or a negative errno.
+ */
+ssize_t vringh_iov_pull_user(struct vringh_iov *riov, void *dst, size_t len)
+{
+ return vringh_iov_xfer((struct vringh_kiov *)riov,
+ dst, len, xfer_from_user);
+}
+EXPORT_SYMBOL(vringh_iov_pull_user);
+
+/**
+ * vringh_iov_push_user - copy bytes into vring_iov.
+ * @wiov: the wiov as passed to vringh_getdesc_user() (updated as we consume)
+ * @dst: the place to copy.
+ * @len: the maximum length to copy.
+ *
+ * Returns the bytes copied <= len or a negative errno.
+ */
+ssize_t vringh_iov_push_user(struct vringh_iov *wiov,
+ const void *src, size_t len)
+{
+ return vringh_iov_xfer((struct vringh_kiov *)wiov,
+ (void *)src, len, xfer_to_user);
+}
+EXPORT_SYMBOL(vringh_iov_push_user);
+
+/**
+ * vringh_abandon_user - we've decided not to handle the descriptor(s).
+ * @vrh: the vring.
+ * @num: the number of descriptors to put back (ie. num
+ * vringh_get_user() to undo).
+ *
+ * The next vringh_get_user() will return the old descriptor(s) again.
+ */
+void vringh_abandon_user(struct vringh *vrh, unsigned int num)
+{
+ /* We only update vring_avail_event(vr) when we want to be notified,
+ * so we haven't changed that yet. */
+ vrh->last_avail_idx -= num;
+}
+EXPORT_SYMBOL(vringh_abandon_user);
+
+/**
+ * vringh_complete_user - we've finished with descriptor, publish it.
+ * @vrh: the vring.
+ * @head: the head as filled in by vringh_getdesc_user.
+ * @len: the length of data we have written.
+ *
+ * You should check vringh_need_notify_user() after one or more calls
+ * to this function.
+ */
+int vringh_complete_user(struct vringh *vrh, u16 head, u32 len)
+{
+ struct vring_used_elem used;
+
+ used.id = head;
+ used.len = len;
+ return __vringh_complete(vrh, &used, 1, putu16_user, putused_user);
+}
+EXPORT_SYMBOL(vringh_complete_user);
+
+/**
+ * vringh_complete_multi_user - we've finished with many descriptors.
+ * @vrh: the vring.
+ * @used: the head, length pairs.
+ * @num_used: the number of used elements.
+ *
+ * You should check vringh_need_notify_user() after one or more calls
+ * to this function.
+ */
+int vringh_complete_multi_user(struct vringh *vrh,
+ const struct vring_used_elem used[],
+ unsigned num_used)
+{
+ return __vringh_complete(vrh, used, num_used,
+ putu16_user, putused_user);
+}
+EXPORT_SYMBOL(vringh_complete_multi_user);
+
+/**
+ * vringh_notify_enable_user - we want to know if something changes.
+ * @vrh: the vring.
+ *
+ * This always enables notifications, but returns false if there are
+ * now more buffers available in the vring.
+ */
+bool vringh_notify_enable_user(struct vringh *vrh)
+{
+ return __vringh_notify_enable(vrh, getu16_user, putu16_user);
+}
+EXPORT_SYMBOL(vringh_notify_enable_user);
+
+/**
+ * vringh_notify_disable_user - don't tell us if something changes.
+ * @vrh: the vring.
+ *
+ * This is our normal running state: we disable and then only enable when
+ * we're going to sleep.
+ */
+void vringh_notify_disable_user(struct vringh *vrh)
+{
+ __vringh_notify_disable(vrh, putu16_user);
+}
+EXPORT_SYMBOL(vringh_notify_disable_user);
+
+/**
+ * vringh_need_notify_user - must we tell the other side about used buffers?
+ * @vrh: the vring we've called vringh_complete_user() on.
+ *
+ * Returns -errno or 0 if we don't need to tell the other side, 1 if we do.
+ */
+int vringh_need_notify_user(struct vringh *vrh)
+{
+ return __vringh_need_notify(vrh, getu16_user);
+}
+EXPORT_SYMBOL(vringh_need_notify_user);
+
+/* Kernelspace access helpers. */
+static inline int getu16_kern(u16 *val, const u16 *p)
+{
+ *val = ACCESS_ONCE(*p);
+ return 0;
+}
+
+static inline int putu16_kern(u16 *p, u16 val)
+{
+ ACCESS_ONCE(*p) = val;
+ return 0;
+}
+
+static inline int copydesc_kern(void *dst, const void *src, size_t len)
+{
+ memcpy(dst, src, len);
+ return 0;
+}
+
+static inline int putused_kern(struct vring_used_elem *dst,
+ const struct vring_used_elem *src,
+ unsigned int num)
+{
+ memcpy(dst, src, num * sizeof(*dst));
+ return 0;
+}
+
+static inline int xfer_kern(void *src, void *dst, size_t len)
+{
+ memcpy(dst, src, len);
+ return 0;
+}
+
+/**
+ * vringh_init_kern - initialize a vringh for a kernelspace vring.
+ * @vrh: the vringh to initialize.
+ * @features: the feature bits for this ring.
+ * @num: the number of elements.
+ * @weak_barriers: true if we only need memory barriers, not I/O.
+ * @desc: the userpace descriptor pointer.
+ * @avail: the userpace avail pointer.
+ * @used: the userpace used pointer.
+ *
+ * Returns an error if num is invalid.
+ */
+int vringh_init_kern(struct vringh *vrh, u32 features,
+ unsigned int num, bool weak_barriers,
+ struct vring_desc *desc,
+ struct vring_avail *avail,
+ struct vring_used *used)
+{
+ /* Sane power of 2 please! */
+ if (!num || num > 0xffff || (num & (num - 1))) {
+ vringh_bad("Bad ring size %u", num);
+ return -EINVAL;
+ }
+
+ vrh->event_indices = (features & (1 << VIRTIO_RING_F_EVENT_IDX));
+ vrh->weak_barriers = weak_barriers;
+ vrh->completed = 0;
+ vrh->last_avail_idx = 0;
+ vrh->last_used_idx = 0;
+ vrh->vring.num = num;
+ vrh->vring.desc = desc;
+ vrh->vring.avail = avail;
+ vrh->vring.used = used;
+ return 0;
+}
+EXPORT_SYMBOL(vringh_init_kern);
+
+/**
+ * vringh_getdesc_kern - get next available descriptor from kernelspace ring.
+ * @vrh: the kernelspace vring.
+ * @riov: where to put the readable descriptors (or NULL)
+ * @wiov: where to put the writable descriptors (or NULL)
+ * @head: head index we received, for passing to vringh_complete_kern().
+ * @gfp: flags for allocating larger riov/wiov.
+ *
+ * Returns 0 if there was no descriptor, 1 if there was, or -errno.
+ *
+ * Note that on error return, you can tell the difference between an
+ * invalid ring and a single invalid descriptor: in the former case,
+ * *head will be vrh->vring.num. You may be able to ignore an invalid
+ * descriptor, but there's not much you can do with an invalid ring.
+ *
+ * Note that you may need to clean up riov and wiov, even on error!
+ */
+int vringh_getdesc_kern(struct vringh *vrh,
+ struct vringh_kiov *riov,
+ struct vringh_kiov *wiov,
+ u16 *head,
+ gfp_t gfp)
+{
+ int err;
+
+ err = __vringh_get_head(vrh, getu16_kern, &vrh->last_avail_idx);
+ if (err < 0)
+ return err;
+
+ /* Empty... */
+ if (err == vrh->vring.num)
+ return 0;
+
+ *head = err;
+ err = __vringh_iov(vrh, *head, riov, wiov, no_range_check, NULL,
+ gfp, copydesc_kern);
+ if (err)
+ return err;
+
+ return 1;
+}
+EXPORT_SYMBOL(vringh_getdesc_kern);
+
+/**
+ * vringh_iov_pull_kern - copy bytes from vring_iov.
+ * @riov: the riov as passed to vringh_getdesc_kern() (updated as we consume)
+ * @dst: the place to copy.
+ * @len: the maximum length to copy.
+ *
+ * Returns the bytes copied <= len or a negative errno.
+ */
+ssize_t vringh_iov_pull_kern(struct vringh_kiov *riov, void *dst, size_t len)
+{
+ return vringh_iov_xfer(riov, dst, len, xfer_kern);
+}
+EXPORT_SYMBOL(vringh_iov_pull_kern);
+
+/**
+ * vringh_iov_push_kern - copy bytes into vring_iov.
+ * @wiov: the wiov as passed to vringh_getdesc_kern() (updated as we consume)
+ * @dst: the place to copy.
+ * @len: the maximum length to copy.
+ *
+ * Returns the bytes copied <= len or a negative errno.
+ */
+ssize_t vringh_iov_push_kern(struct vringh_kiov *wiov,
+ const void *src, size_t len)
+{
+ return vringh_iov_xfer(wiov, (void *)src, len, xfer_kern);
+}
+EXPORT_SYMBOL(vringh_iov_push_kern);
+
+/**
+ * vringh_abandon_kern - we've decided not to handle the descriptor(s).
+ * @vrh: the vring.
+ * @num: the number of descriptors to put back (ie. num
+ * vringh_get_kern() to undo).
+ *
+ * The next vringh_get_kern() will return the old descriptor(s) again.
+ */
+void vringh_abandon_kern(struct vringh *vrh, unsigned int num)
+{
+ /* We only update vring_avail_event(vr) when we want to be notified,
+ * so we haven't changed that yet. */
+ vrh->last_avail_idx -= num;
+}
+EXPORT_SYMBOL(vringh_abandon_kern);
+
+/**
+ * vringh_complete_kern - we've finished with descriptor, publish it.
+ * @vrh: the vring.
+ * @head: the head as filled in by vringh_getdesc_kern.
+ * @len: the length of data we have written.
+ *
+ * You should check vringh_need_notify_kern() after one or more calls
+ * to this function.
+ */
+int vringh_complete_kern(struct vringh *vrh, u16 head, u32 len)
+{
+ struct vring_used_elem used;
+
+ used.id = head;
+ used.len = len;
+
+ return __vringh_complete(vrh, &used, 1, putu16_kern, putused_kern);
+}
+EXPORT_SYMBOL(vringh_complete_kern);
+
+/**
+ * vringh_notify_enable_kern - we want to know if something changes.
+ * @vrh: the vring.
+ *
+ * This always enables notifications, but returns false if there are
+ * now more buffers available in the vring.
+ */
+bool vringh_notify_enable_kern(struct vringh *vrh)
+{
+ return __vringh_notify_enable(vrh, getu16_kern, putu16_kern);
+}
+EXPORT_SYMBOL(vringh_notify_enable_kern);
+
+/**
+ * vringh_notify_disable_kern - don't tell us if something changes.
+ * @vrh: the vring.
+ *
+ * This is our normal running state: we disable and then only enable when
+ * we're going to sleep.
+ */
+void vringh_notify_disable_kern(struct vringh *vrh)
+{
+ __vringh_notify_disable(vrh, putu16_kern);
+}
+EXPORT_SYMBOL(vringh_notify_disable_kern);
+
+/**
+ * vringh_need_notify_kern - must we tell the other side about used buffers?
+ * @vrh: the vring we've called vringh_complete_kern() on.
+ *
+ * Returns -errno or 0 if we don't need to tell the other side, 1 if we do.
+ */
+int vringh_need_notify_kern(struct vringh *vrh)
+{
+ return __vringh_need_notify(vrh, getu16_kern);
+}
+EXPORT_SYMBOL(vringh_need_notify_kern);
+
+MODULE_LICENSE("GPL");