aboutsummaryrefslogtreecommitdiff
path: root/drivers/vhost
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/vhost')
-rw-r--r--drivers/vhost/net.c115
-rw-r--r--drivers/vhost/scsi.c393
-rw-r--r--drivers/vhost/test.c19
-rw-r--r--drivers/vhost/vhost.c101
-rw-r--r--drivers/vhost/vhost.h21
5 files changed, 393 insertions, 256 deletions
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index 831eb4fd197..8dae2f724a3 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -17,6 +17,7 @@
#include <linux/workqueue.h>
#include <linux/file.h>
#include <linux/slab.h>
+#include <linux/vmalloc.h>
#include <linux/net.h>
#include <linux/if_packet.h>
@@ -70,7 +71,12 @@ enum {
};
struct vhost_net_ubuf_ref {
- struct kref kref;
+ /* refcount follows semantics similar to kref:
+ * 0: object is released
+ * 1: no outstanding ubufs
+ * >1: outstanding ubufs
+ */
+ atomic_t refcount;
wait_queue_head_t wait;
struct vhost_virtqueue *vq;
};
@@ -116,14 +122,6 @@ static void vhost_net_enable_zcopy(int vq)
vhost_net_zcopy_mask |= 0x1 << vq;
}
-static void vhost_net_zerocopy_done_signal(struct kref *kref)
-{
- struct vhost_net_ubuf_ref *ubufs;
-
- ubufs = container_of(kref, struct vhost_net_ubuf_ref, kref);
- wake_up(&ubufs->wait);
-}
-
static struct vhost_net_ubuf_ref *
vhost_net_ubuf_alloc(struct vhost_virtqueue *vq, bool zcopy)
{
@@ -134,21 +132,24 @@ vhost_net_ubuf_alloc(struct vhost_virtqueue *vq, bool zcopy)
ubufs = kmalloc(sizeof(*ubufs), GFP_KERNEL);
if (!ubufs)
return ERR_PTR(-ENOMEM);
- kref_init(&ubufs->kref);
+ atomic_set(&ubufs->refcount, 1);
init_waitqueue_head(&ubufs->wait);
ubufs->vq = vq;
return ubufs;
}
-static void vhost_net_ubuf_put(struct vhost_net_ubuf_ref *ubufs)
+static int vhost_net_ubuf_put(struct vhost_net_ubuf_ref *ubufs)
{
- kref_put(&ubufs->kref, vhost_net_zerocopy_done_signal);
+ int r = atomic_sub_return(1, &ubufs->refcount);
+ if (unlikely(!r))
+ wake_up(&ubufs->wait);
+ return r;
}
static void vhost_net_ubuf_put_and_wait(struct vhost_net_ubuf_ref *ubufs)
{
- kref_put(&ubufs->kref, vhost_net_zerocopy_done_signal);
- wait_event(ubufs->wait, !atomic_read(&ubufs->kref.refcount));
+ vhost_net_ubuf_put(ubufs);
+ wait_event(ubufs->wait, !atomic_read(&ubufs->refcount));
}
static void vhost_net_ubuf_put_wait_and_free(struct vhost_net_ubuf_ref *ubufs)
@@ -306,23 +307,26 @@ static void vhost_zerocopy_callback(struct ubuf_info *ubuf, bool success)
{
struct vhost_net_ubuf_ref *ubufs = ubuf->ctx;
struct vhost_virtqueue *vq = ubufs->vq;
- int cnt = atomic_read(&ubufs->kref.refcount);
+ int cnt;
+
+ rcu_read_lock_bh();
/* set len to mark this desc buffers done DMA */
vq->heads[ubuf->desc].len = success ?
VHOST_DMA_DONE_LEN : VHOST_DMA_FAILED_LEN;
- vhost_net_ubuf_put(ubufs);
+ cnt = vhost_net_ubuf_put(ubufs);
/*
* Trigger polling thread if guest stopped submitting new buffers:
- * in this case, the refcount after decrement will eventually reach 1
- * so here it is 2.
+ * in this case, the refcount after decrement will eventually reach 1.
* We also trigger polling periodically after each 16 packets
* (the value 16 here is more or less arbitrary, it's tuned to trigger
* less than 10% of times).
*/
- if (cnt <= 2 || !(cnt % 16))
+ if (cnt <= 1 || !(cnt % 16))
vhost_poll_queue(&vq->poll);
+
+ rcu_read_unlock_bh();
}
/* Expects to be always run from workqueue - which acts as
@@ -370,7 +374,7 @@ static void handle_tx(struct vhost_net *net)
% UIO_MAXIOV == nvq->done_idx))
break;
- head = vhost_get_vq_desc(&net->dev, vq, vq->iov,
+ head = vhost_get_vq_desc(vq, vq->iov,
ARRAY_SIZE(vq->iov),
&out, &in,
NULL, NULL);
@@ -420,7 +424,7 @@ static void handle_tx(struct vhost_net *net)
msg.msg_control = ubuf;
msg.msg_controllen = sizeof(ubuf);
ubufs = nvq->ubufs;
- kref_get(&ubufs->kref);
+ atomic_inc(&ubufs->refcount);
nvq->upend_idx = (nvq->upend_idx + 1) % UIO_MAXIOV;
} else {
msg.msg_control = NULL;
@@ -502,9 +506,13 @@ static int get_rx_bufs(struct vhost_virtqueue *vq,
r = -ENOBUFS;
goto err;
}
- d = vhost_get_vq_desc(vq->dev, vq, vq->iov + seg,
+ r = vhost_get_vq_desc(vq, vq->iov + seg,
ARRAY_SIZE(vq->iov) - seg, &out,
&in, log, log_num);
+ if (unlikely(r < 0))
+ goto err;
+
+ d = r;
if (d == vq->num) {
r = 0;
goto err;
@@ -529,6 +537,12 @@ static int get_rx_bufs(struct vhost_virtqueue *vq,
*iovcount = seg;
if (unlikely(log))
*log_num = nlogs;
+
+ /* Detect overrun */
+ if (unlikely(datalen > 0)) {
+ r = UIO_MAXIOV + 1;
+ goto err;
+ }
return headcount;
err:
vhost_discard_vq_desc(vq, headcount);
@@ -571,9 +585,9 @@ static void handle_rx(struct vhost_net *net)
vhost_hlen = nvq->vhost_hlen;
sock_hlen = nvq->sock_hlen;
- vq_log = unlikely(vhost_has_feature(&net->dev, VHOST_F_LOG_ALL)) ?
+ vq_log = unlikely(vhost_has_feature(vq, VHOST_F_LOG_ALL)) ?
vq->log : NULL;
- mergeable = vhost_has_feature(&net->dev, VIRTIO_NET_F_MRG_RXBUF);
+ mergeable = vhost_has_feature(vq, VIRTIO_NET_F_MRG_RXBUF);
while ((sock_len = peek_head_len(sock->sk))) {
sock_len += sock_hlen;
@@ -584,6 +598,14 @@ static void handle_rx(struct vhost_net *net)
/* On error, stop handling until the next kick. */
if (unlikely(headcount < 0))
break;
+ /* On overrun, truncate and discard */
+ if (unlikely(headcount > UIO_MAXIOV)) {
+ msg.msg_iovlen = 1;
+ err = sock->ops->recvmsg(NULL, sock, &msg,
+ 1, MSG_DONTWAIT | MSG_TRUNC);
+ pr_debug("Discarded rx packet: len %zd\n", sock_len);
+ continue;
+ }
/* OK, now we need to know about added descriptors. */
if (!headcount) {
if (unlikely(vhost_enable_notify(&net->dev, vq))) {
@@ -680,16 +702,20 @@ static void handle_rx_net(struct vhost_work *work)
static int vhost_net_open(struct inode *inode, struct file *f)
{
- struct vhost_net *n = kmalloc(sizeof *n, GFP_KERNEL);
+ struct vhost_net *n;
struct vhost_dev *dev;
struct vhost_virtqueue **vqs;
- int r, i;
+ int i;
- if (!n)
- return -ENOMEM;
+ n = kmalloc(sizeof *n, GFP_KERNEL | __GFP_NOWARN | __GFP_REPEAT);
+ if (!n) {
+ n = vmalloc(sizeof *n);
+ if (!n)
+ return -ENOMEM;
+ }
vqs = kmalloc(VHOST_NET_VQ_MAX * sizeof(*vqs), GFP_KERNEL);
if (!vqs) {
- kfree(n);
+ kvfree(n);
return -ENOMEM;
}
@@ -706,12 +732,7 @@ static int vhost_net_open(struct inode *inode, struct file *f)
n->vqs[i].vhost_hlen = 0;
n->vqs[i].sock_hlen = 0;
}
- r = vhost_dev_init(dev, vqs, VHOST_NET_VQ_MAX);
- if (r < 0) {
- kfree(n);
- kfree(vqs);
- return r;
- }
+ vhost_dev_init(dev, vqs, VHOST_NET_VQ_MAX);
vhost_poll_init(n->poll + VHOST_NET_VQ_TX, handle_tx_net, POLLOUT, dev);
vhost_poll_init(n->poll + VHOST_NET_VQ_RX, handle_rx_net, POLLIN, dev);
@@ -785,7 +806,7 @@ static void vhost_net_flush(struct vhost_net *n)
vhost_net_ubuf_put_and_wait(n->vqs[VHOST_NET_VQ_TX].ubufs);
mutex_lock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
n->tx_flush = false;
- kref_init(&n->vqs[VHOST_NET_VQ_TX].ubufs->kref);
+ atomic_set(&n->vqs[VHOST_NET_VQ_TX].ubufs->refcount, 1);
mutex_unlock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
}
}
@@ -802,14 +823,16 @@ static int vhost_net_release(struct inode *inode, struct file *f)
vhost_dev_cleanup(&n->dev, false);
vhost_net_vq_reset(n);
if (tx_sock)
- fput(tx_sock->file);
+ sockfd_put(tx_sock);
if (rx_sock)
- fput(rx_sock->file);
+ sockfd_put(rx_sock);
+ /* Make sure no callbacks are outstanding */
+ synchronize_rcu_bh();
/* We do an extra flush before freeing memory,
* since jobs can re-queue themselves. */
vhost_net_flush(n);
kfree(n->dev.vqs);
- kfree(n);
+ kvfree(n);
return 0;
}
@@ -842,7 +865,7 @@ static struct socket *get_raw_socket(int fd)
}
return sock;
err:
- fput(sock->file);
+ sockfd_put(sock);
return ERR_PTR(r);
}
@@ -948,7 +971,7 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
if (oldsock) {
vhost_net_flush_vq(n, index);
- fput(oldsock->file);
+ sockfd_put(oldsock);
}
mutex_unlock(&n->dev.mutex);
@@ -960,7 +983,7 @@ err_used:
if (ubufs)
vhost_net_ubuf_put_wait_and_free(ubufs);
err_ubufs:
- fput(sock->file);
+ sockfd_put(sock);
err_vq:
mutex_unlock(&vq->mutex);
err:
@@ -991,9 +1014,9 @@ static long vhost_net_reset_owner(struct vhost_net *n)
done:
mutex_unlock(&n->dev.mutex);
if (tx_sock)
- fput(tx_sock->file);
+ sockfd_put(tx_sock);
if (rx_sock)
- fput(rx_sock->file);
+ sockfd_put(rx_sock);
return err;
}
@@ -1020,15 +1043,13 @@ static int vhost_net_set_features(struct vhost_net *n, u64 features)
mutex_unlock(&n->dev.mutex);
return -EFAULT;
}
- n->dev.acked_features = features;
- smp_wmb();
for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
mutex_lock(&n->vqs[i].vq.mutex);
+ n->vqs[i].vq.acked_features = features;
n->vqs[i].vhost_hlen = vhost_hlen;
n->vqs[i].sock_hlen = sock_hlen;
mutex_unlock(&n->vqs[i].vq.mutex);
}
- vhost_net_flush(n);
mutex_unlock(&n->dev.mutex);
return 0;
}
diff --git a/drivers/vhost/scsi.c b/drivers/vhost/scsi.c
index 592b31698fc..69906cacd04 100644
--- a/drivers/vhost/scsi.c
+++ b/drivers/vhost/scsi.c
@@ -57,7 +57,8 @@
#define TCM_VHOST_MAX_CDB_SIZE 32
#define TCM_VHOST_DEFAULT_TAGS 256
#define TCM_VHOST_PREALLOC_SGLS 2048
-#define TCM_VHOST_PREALLOC_PAGES 2048
+#define TCM_VHOST_PREALLOC_UPAGES 2048
+#define TCM_VHOST_PREALLOC_PROT_SGLS 512
struct vhost_scsi_inflight {
/* Wait for the flush operation to finish */
@@ -79,10 +80,12 @@ struct tcm_vhost_cmd {
u64 tvc_tag;
/* The number of scatterlists associated with this cmd */
u32 tvc_sgl_count;
+ u32 tvc_prot_sgl_count;
/* Saved unpacked SCSI LUN for tcm_vhost_submission_work() */
u32 tvc_lun;
/* Pointer to the SGL formatted memory from virtio-scsi */
struct scatterlist *tvc_sgl;
+ struct scatterlist *tvc_prot_sgl;
struct page **tvc_upages;
/* Pointer to response */
struct virtio_scsi_cmd_resp __user *tvc_resp;
@@ -166,7 +169,8 @@ enum {
};
enum {
- VHOST_SCSI_FEATURES = VHOST_FEATURES | (1ULL << VIRTIO_SCSI_F_HOTPLUG)
+ VHOST_SCSI_FEATURES = VHOST_FEATURES | (1ULL << VIRTIO_SCSI_F_HOTPLUG) |
+ (1ULL << VIRTIO_SCSI_F_T10_PI)
};
#define VHOST_SCSI_MAX_TARGET 256
@@ -456,12 +460,16 @@ static void tcm_vhost_release_cmd(struct se_cmd *se_cmd)
struct tcm_vhost_cmd *tv_cmd = container_of(se_cmd,
struct tcm_vhost_cmd, tvc_se_cmd);
struct se_session *se_sess = se_cmd->se_sess;
+ int i;
if (tv_cmd->tvc_sgl_count) {
- u32 i;
for (i = 0; i < tv_cmd->tvc_sgl_count; i++)
put_page(sg_page(&tv_cmd->tvc_sgl[i]));
}
+ if (tv_cmd->tvc_prot_sgl_count) {
+ for (i = 0; i < tv_cmd->tvc_prot_sgl_count; i++)
+ put_page(sg_page(&tv_cmd->tvc_prot_sgl[i]));
+ }
tcm_vhost_put_inflight(tv_cmd->inflight);
percpu_ida_free(&se_sess->sess_tag_pool, se_cmd->map_tag);
@@ -539,6 +547,11 @@ static void tcm_vhost_queue_tm_rsp(struct se_cmd *se_cmd)
return;
}
+static void tcm_vhost_aborted_task(struct se_cmd *se_cmd)
+{
+ return;
+}
+
static void tcm_vhost_free_evt(struct vhost_scsi *vs, struct tcm_vhost_evt *evt)
{
vs->vs_events_nr--;
@@ -601,7 +614,7 @@ tcm_vhost_do_evt_work(struct vhost_scsi *vs, struct tcm_vhost_evt *evt)
again:
vhost_disable_notify(&vs->dev, vq);
- head = vhost_get_vq_desc(&vs->dev, vq, vq->iov,
+ head = vhost_get_vq_desc(vq, vq->iov,
ARRAY_SIZE(vq->iov), &out, &in,
NULL, NULL);
if (head < 0) {
@@ -708,16 +721,14 @@ static void vhost_scsi_complete_cmd_work(struct vhost_work *work)
}
static struct tcm_vhost_cmd *
-vhost_scsi_get_tag(struct vhost_virtqueue *vq,
- struct tcm_vhost_tpg *tpg,
- struct virtio_scsi_cmd_req *v_req,
- u32 exp_data_len,
- int data_direction)
+vhost_scsi_get_tag(struct vhost_virtqueue *vq, struct tcm_vhost_tpg *tpg,
+ unsigned char *cdb, u64 scsi_tag, u16 lun, u8 task_attr,
+ u32 exp_data_len, int data_direction)
{
struct tcm_vhost_cmd *cmd;
struct tcm_vhost_nexus *tv_nexus;
struct se_session *se_sess;
- struct scatterlist *sg;
+ struct scatterlist *sg, *prot_sg;
struct page **pages;
int tag;
@@ -728,22 +739,32 @@ vhost_scsi_get_tag(struct vhost_virtqueue *vq,
}
se_sess = tv_nexus->tvn_se_sess;
- tag = percpu_ida_alloc(&se_sess->sess_tag_pool, GFP_KERNEL);
+ tag = percpu_ida_alloc(&se_sess->sess_tag_pool, TASK_RUNNING);
+ if (tag < 0) {
+ pr_err("Unable to obtain tag for tcm_vhost_cmd\n");
+ return ERR_PTR(-ENOMEM);
+ }
+
cmd = &((struct tcm_vhost_cmd *)se_sess->sess_cmd_map)[tag];
sg = cmd->tvc_sgl;
+ prot_sg = cmd->tvc_prot_sgl;
pages = cmd->tvc_upages;
memset(cmd, 0, sizeof(struct tcm_vhost_cmd));
cmd->tvc_sgl = sg;
+ cmd->tvc_prot_sgl = prot_sg;
cmd->tvc_upages = pages;
cmd->tvc_se_cmd.map_tag = tag;
- cmd->tvc_tag = v_req->tag;
- cmd->tvc_task_attr = v_req->task_attr;
+ cmd->tvc_tag = scsi_tag;
+ cmd->tvc_lun = lun;
+ cmd->tvc_task_attr = task_attr;
cmd->tvc_exp_data_len = exp_data_len;
cmd->tvc_data_direction = data_direction;
cmd->tvc_nexus = tv_nexus;
cmd->inflight = tcm_vhost_get_inflight(vq);
+ memcpy(cmd->tvc_cdb, cdb, TCM_VHOST_MAX_CDB_SIZE);
+
return cmd;
}
@@ -757,35 +778,28 @@ vhost_scsi_map_to_sgl(struct tcm_vhost_cmd *tv_cmd,
struct scatterlist *sgl,
unsigned int sgl_count,
struct iovec *iov,
- int write)
+ struct page **pages,
+ bool write)
{
unsigned int npages = 0, pages_nr, offset, nbytes;
struct scatterlist *sg = sgl;
void __user *ptr = iov->iov_base;
size_t len = iov->iov_len;
- struct page **pages;
int ret, i;
- if (sgl_count > TCM_VHOST_PREALLOC_SGLS) {
- pr_err("vhost_scsi_map_to_sgl() psgl_count: %u greater than"
- " preallocated TCM_VHOST_PREALLOC_SGLS: %u\n",
- sgl_count, TCM_VHOST_PREALLOC_SGLS);
- return -ENOBUFS;
- }
-
pages_nr = iov_num_pages(iov);
- if (pages_nr > sgl_count)
+ if (pages_nr > sgl_count) {
+ pr_err("vhost_scsi_map_to_sgl() pages_nr: %u greater than"
+ " sgl_count: %u\n", pages_nr, sgl_count);
return -ENOBUFS;
-
- if (pages_nr > TCM_VHOST_PREALLOC_PAGES) {
+ }
+ if (pages_nr > TCM_VHOST_PREALLOC_UPAGES) {
pr_err("vhost_scsi_map_to_sgl() pages_nr: %u greater than"
- " preallocated TCM_VHOST_PREALLOC_PAGES: %u\n",
- pages_nr, TCM_VHOST_PREALLOC_PAGES);
+ " preallocated TCM_VHOST_PREALLOC_UPAGES: %u\n",
+ pages_nr, TCM_VHOST_PREALLOC_UPAGES);
return -ENOBUFS;
}
- pages = tv_cmd->tvc_upages;
-
ret = get_user_pages_fast((unsigned long)ptr, pages_nr, write, pages);
/* No pages were pinned */
if (ret < 0)
@@ -815,33 +829,32 @@ out:
static int
vhost_scsi_map_iov_to_sgl(struct tcm_vhost_cmd *cmd,
struct iovec *iov,
- unsigned int niov,
- int write)
+ int niov,
+ bool write)
{
- int ret;
- unsigned int i;
- u32 sgl_count;
- struct scatterlist *sg;
+ struct scatterlist *sg = cmd->tvc_sgl;
+ unsigned int sgl_count = 0;
+ int ret, i;
- /*
- * Find out how long sglist needs to be
- */
- sgl_count = 0;
for (i = 0; i < niov; i++)
sgl_count += iov_num_pages(&iov[i]);
- /* TODO overflow checking */
+ if (sgl_count > TCM_VHOST_PREALLOC_SGLS) {
+ pr_err("vhost_scsi_map_iov_to_sgl() sgl_count: %u greater than"
+ " preallocated TCM_VHOST_PREALLOC_SGLS: %u\n",
+ sgl_count, TCM_VHOST_PREALLOC_SGLS);
+ return -ENOBUFS;
+ }
- sg = cmd->tvc_sgl;
pr_debug("%s sg %p sgl_count %u\n", __func__, sg, sgl_count);
sg_init_table(sg, sgl_count);
-
cmd->tvc_sgl_count = sgl_count;
- pr_debug("Mapping %u iovecs for %u pages\n", niov, sgl_count);
+ pr_debug("Mapping iovec %p for %u pages\n", &iov[0], sgl_count);
+
for (i = 0; i < niov; i++) {
ret = vhost_scsi_map_to_sgl(cmd, sg, sgl_count, &iov[i],
- write);
+ cmd->tvc_upages, write);
if (ret < 0) {
for (i = 0; i < cmd->tvc_sgl_count; i++)
put_page(sg_page(&cmd->tvc_sgl[i]));
@@ -849,31 +862,70 @@ vhost_scsi_map_iov_to_sgl(struct tcm_vhost_cmd *cmd,
cmd->tvc_sgl_count = 0;
return ret;
}
-
sg += ret;
sgl_count -= ret;
}
return 0;
}
+static int
+vhost_scsi_map_iov_to_prot(struct tcm_vhost_cmd *cmd,
+ struct iovec *iov,
+ int niov,
+ bool write)
+{
+ struct scatterlist *prot_sg = cmd->tvc_prot_sgl;
+ unsigned int prot_sgl_count = 0;
+ int ret, i;
+
+ for (i = 0; i < niov; i++)
+ prot_sgl_count += iov_num_pages(&iov[i]);
+
+ if (prot_sgl_count > TCM_VHOST_PREALLOC_PROT_SGLS) {
+ pr_err("vhost_scsi_map_iov_to_prot() sgl_count: %u greater than"
+ " preallocated TCM_VHOST_PREALLOC_PROT_SGLS: %u\n",
+ prot_sgl_count, TCM_VHOST_PREALLOC_PROT_SGLS);
+ return -ENOBUFS;
+ }
+
+ pr_debug("%s prot_sg %p prot_sgl_count %u\n", __func__,
+ prot_sg, prot_sgl_count);
+ sg_init_table(prot_sg, prot_sgl_count);
+ cmd->tvc_prot_sgl_count = prot_sgl_count;
+
+ for (i = 0; i < niov; i++) {
+ ret = vhost_scsi_map_to_sgl(cmd, prot_sg, prot_sgl_count, &iov[i],
+ cmd->tvc_upages, write);
+ if (ret < 0) {
+ for (i = 0; i < cmd->tvc_prot_sgl_count; i++)
+ put_page(sg_page(&cmd->tvc_prot_sgl[i]));
+
+ cmd->tvc_prot_sgl_count = 0;
+ return ret;
+ }
+ prot_sg += ret;
+ prot_sgl_count -= ret;
+ }
+ return 0;
+}
+
static void tcm_vhost_submission_work(struct work_struct *work)
{
struct tcm_vhost_cmd *cmd =
container_of(work, struct tcm_vhost_cmd, work);
struct tcm_vhost_nexus *tv_nexus;
struct se_cmd *se_cmd = &cmd->tvc_se_cmd;
- struct scatterlist *sg_ptr, *sg_bidi_ptr = NULL;
- int rc, sg_no_bidi = 0;
+ struct scatterlist *sg_ptr, *sg_prot_ptr = NULL;
+ int rc;
+ /* FIXME: BIDI operation */
if (cmd->tvc_sgl_count) {
sg_ptr = cmd->tvc_sgl;
-/* FIXME: Fix BIDI operation in tcm_vhost_submission_work() */
-#if 0
- if (se_cmd->se_cmd_flags & SCF_BIDI) {
- sg_bidi_ptr = NULL;
- sg_no_bidi = 0;
- }
-#endif
+
+ if (cmd->tvc_prot_sgl_count)
+ sg_prot_ptr = cmd->tvc_prot_sgl;
+ else
+ se_cmd->prot_pto = true;
} else {
sg_ptr = NULL;
}
@@ -884,7 +936,7 @@ static void tcm_vhost_submission_work(struct work_struct *work)
cmd->tvc_lun, cmd->tvc_exp_data_len,
cmd->tvc_task_attr, cmd->tvc_data_direction,
TARGET_SCF_ACK_KREF, sg_ptr, cmd->tvc_sgl_count,
- sg_bidi_ptr, sg_no_bidi);
+ NULL, 0, sg_prot_ptr, cmd->tvc_prot_sgl_count);
if (rc < 0) {
transport_send_check_condition_and_sense(se_cmd,
TCM_LOGICAL_UNIT_COMMUNICATION_FAILURE, 0);
@@ -916,12 +968,18 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)
{
struct tcm_vhost_tpg **vs_tpg;
struct virtio_scsi_cmd_req v_req;
+ struct virtio_scsi_cmd_req_pi v_req_pi;
struct tcm_vhost_tpg *tpg;
struct tcm_vhost_cmd *cmd;
- u32 exp_data_len, data_first, data_num, data_direction;
+ u64 tag;
+ u32 exp_data_len, data_first, data_num, data_direction, prot_first;
unsigned out, in, i;
- int head, ret;
- u8 target;
+ int head, ret, data_niov, prot_niov, prot_bytes;
+ size_t req_size;
+ u16 lun;
+ u8 *target, *lunp, task_attr;
+ bool hdr_pi;
+ void *req, *cdb;
mutex_lock(&vq->mutex);
/*
@@ -935,7 +993,7 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)
vhost_disable_notify(&vs->dev, vq);
for (;;) {
- head = vhost_get_vq_desc(&vs->dev, vq, vq->iov,
+ head = vhost_get_vq_desc(vq, vq->iov,
ARRAY_SIZE(vq->iov), &out, &in,
NULL, NULL);
pr_debug("vhost_get_vq_desc: head: %d, out: %u in: %u\n",
@@ -952,7 +1010,7 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)
break;
}
-/* FIXME: BIDI operation */
+ /* FIXME: BIDI operation */
if (out == 1 && in == 1) {
data_direction = DMA_NONE;
data_first = 0;
@@ -982,23 +1040,38 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)
break;
}
- if (unlikely(vq->iov[0].iov_len != sizeof(v_req))) {
- vq_err(vq, "Expecting virtio_scsi_cmd_req, got %zu"
- " bytes\n", vq->iov[0].iov_len);
+ if (vhost_has_feature(vq, VIRTIO_SCSI_F_T10_PI)) {
+ req = &v_req_pi;
+ lunp = &v_req_pi.lun[0];
+ target = &v_req_pi.lun[1];
+ req_size = sizeof(v_req_pi);
+ hdr_pi = true;
+ } else {
+ req = &v_req;
+ lunp = &v_req.lun[0];
+ target = &v_req.lun[1];
+ req_size = sizeof(v_req);
+ hdr_pi = false;
+ }
+
+ if (unlikely(vq->iov[0].iov_len < req_size)) {
+ pr_err("Expecting virtio-scsi header: %zu, got %zu\n",
+ req_size, vq->iov[0].iov_len);
break;
}
- pr_debug("Calling __copy_from_user: vq->iov[0].iov_base: %p,"
- " len: %zu\n", vq->iov[0].iov_base, sizeof(v_req));
- ret = __copy_from_user(&v_req, vq->iov[0].iov_base,
- sizeof(v_req));
+ ret = memcpy_fromiovecend(req, &vq->iov[0], 0, req_size);
if (unlikely(ret)) {
vq_err(vq, "Faulted on virtio_scsi_cmd_req\n");
break;
}
- /* Extract the tpgt */
- target = v_req.lun[1];
- tpg = ACCESS_ONCE(vs_tpg[target]);
+ /* virtio-scsi spec requires byte 0 of the lun to be 1 */
+ if (unlikely(*lunp != 1)) {
+ vhost_scsi_send_bad_target(vs, vq, head, out);
+ continue;
+ }
+
+ tpg = ACCESS_ONCE(vs_tpg[*target]);
/* Target does not exist, fail the request */
if (unlikely(!tpg)) {
@@ -1006,17 +1079,79 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)
continue;
}
+ data_niov = data_num;
+ prot_niov = prot_first = prot_bytes = 0;
+ /*
+ * Determine if any protection information iovecs are preceeding
+ * the actual data payload, and adjust data_first + data_niov
+ * values accordingly for vhost_scsi_map_iov_to_sgl() below.
+ *
+ * Also extract virtio_scsi header bits for vhost_scsi_get_tag()
+ */
+ if (hdr_pi) {
+ if (v_req_pi.pi_bytesout) {
+ if (data_direction != DMA_TO_DEVICE) {
+ vq_err(vq, "Received non zero do_pi_niov"
+ ", but wrong data_direction\n");
+ goto err_cmd;
+ }
+ prot_bytes = v_req_pi.pi_bytesout;
+ } else if (v_req_pi.pi_bytesin) {
+ if (data_direction != DMA_FROM_DEVICE) {
+ vq_err(vq, "Received non zero di_pi_niov"
+ ", but wrong data_direction\n");
+ goto err_cmd;
+ }
+ prot_bytes = v_req_pi.pi_bytesin;
+ }
+ if (prot_bytes) {
+ int tmp = 0;
+
+ for (i = 0; i < data_num; i++) {
+ tmp += vq->iov[data_first + i].iov_len;
+ prot_niov++;
+ if (tmp >= prot_bytes)
+ break;
+ }
+ prot_first = data_first;
+ data_first += prot_niov;
+ data_niov = data_num - prot_niov;
+ }
+ tag = v_req_pi.tag;
+ task_attr = v_req_pi.task_attr;
+ cdb = &v_req_pi.cdb[0];
+ lun = ((v_req_pi.lun[2] << 8) | v_req_pi.lun[3]) & 0x3FFF;
+ } else {
+ tag = v_req.tag;
+ task_attr = v_req.task_attr;
+ cdb = &v_req.cdb[0];
+ lun = ((v_req.lun[2] << 8) | v_req.lun[3]) & 0x3FFF;
+ }
exp_data_len = 0;
- for (i = 0; i < data_num; i++)
+ for (i = 0; i < data_niov; i++)
exp_data_len += vq->iov[data_first + i].iov_len;
+ /*
+ * Check that the recieved CDB size does not exceeded our
+ * hardcoded max for vhost-scsi
+ *
+ * TODO what if cdb was too small for varlen cdb header?
+ */
+ if (unlikely(scsi_command_size(cdb) > TCM_VHOST_MAX_CDB_SIZE)) {
+ vq_err(vq, "Received SCSI CDB with command_size: %d that"
+ " exceeds SCSI_MAX_VARLEN_CDB_SIZE: %d\n",
+ scsi_command_size(cdb), TCM_VHOST_MAX_CDB_SIZE);
+ goto err_cmd;
+ }
- cmd = vhost_scsi_get_tag(vq, tpg, &v_req,
- exp_data_len, data_direction);
+ cmd = vhost_scsi_get_tag(vq, tpg, cdb, tag, lun, task_attr,
+ exp_data_len + prot_bytes,
+ data_direction);
if (IS_ERR(cmd)) {
vq_err(vq, "vhost_scsi_get_tag failed %ld\n",
PTR_ERR(cmd));
goto err_cmd;
}
+
pr_debug("Allocated tv_cmd: %p exp_data_len: %d, data_direction"
": %d\n", cmd, exp_data_len, data_direction);
@@ -1024,40 +1159,28 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)
cmd->tvc_vq = vq;
cmd->tvc_resp = vq->iov[out].iov_base;
- /*
- * Copy in the recieved CDB descriptor into cmd->tvc_cdb
- * that will be used by tcm_vhost_new_cmd_map() and down into
- * target_setup_cmd_from_cdb()
- */
- memcpy(cmd->tvc_cdb, v_req.cdb, TCM_VHOST_MAX_CDB_SIZE);
- /*
- * Check that the recieved CDB size does not exceeded our
- * hardcoded max for tcm_vhost
- */
- /* TODO what if cdb was too small for varlen cdb header? */
- if (unlikely(scsi_command_size(cmd->tvc_cdb) >
- TCM_VHOST_MAX_CDB_SIZE)) {
- vq_err(vq, "Received SCSI CDB with command_size: %d that"
- " exceeds SCSI_MAX_VARLEN_CDB_SIZE: %d\n",
- scsi_command_size(cmd->tvc_cdb),
- TCM_VHOST_MAX_CDB_SIZE);
- goto err_free;
- }
- cmd->tvc_lun = ((v_req.lun[2] << 8) | v_req.lun[3]) & 0x3FFF;
-
pr_debug("vhost_scsi got command opcode: %#02x, lun: %d\n",
cmd->tvc_cdb[0], cmd->tvc_lun);
+ if (prot_niov) {
+ ret = vhost_scsi_map_iov_to_prot(cmd,
+ &vq->iov[prot_first], prot_niov,
+ data_direction == DMA_FROM_DEVICE);
+ if (unlikely(ret)) {
+ vq_err(vq, "Failed to map iov to"
+ " prot_sgl\n");
+ goto err_free;
+ }
+ }
if (data_direction != DMA_NONE) {
ret = vhost_scsi_map_iov_to_sgl(cmd,
- &vq->iov[data_first], data_num,
- data_direction == DMA_TO_DEVICE);
+ &vq->iov[data_first], data_niov,
+ data_direction == DMA_FROM_DEVICE);
if (unlikely(ret)) {
vq_err(vq, "Failed to map iov to sgl\n");
goto err_free;
}
}
-
/*
* Save the descriptor from vhost_get_vq_desc() to be used to
* complete the virtio-scsi request in TCM callback context via
@@ -1239,7 +1362,7 @@ vhost_scsi_set_endpoint(struct vhost_scsi *vs,
tpg->tv_tpg_vhost_count++;
tpg->vhost_scsi = vs;
vs_tpg[tpg->tport_tpgt] = tpg;
- smp_mb__after_atomic_inc();
+ smp_mb__after_atomic();
match = true;
}
mutex_unlock(&tpg->tv_tpg_mutex);
@@ -1357,6 +1480,9 @@ err_dev:
static int vhost_scsi_set_features(struct vhost_scsi *vs, u64 features)
{
+ struct vhost_virtqueue *vq;
+ int i;
+
if (features & ~VHOST_SCSI_FEATURES)
return -EOPNOTSUPP;
@@ -1366,21 +1492,17 @@ static int vhost_scsi_set_features(struct vhost_scsi *vs, u64 features)
mutex_unlock(&vs->dev.mutex);
return -EFAULT;
}
- vs->dev.acked_features = features;
- smp_wmb();
- vhost_scsi_flush(vs);
+
+ for (i = 0; i < VHOST_SCSI_MAX_VQ; i++) {
+ vq = &vs->vqs[i].vq;
+ mutex_lock(&vq->mutex);
+ vq->acked_features = features;
+ mutex_unlock(&vq->mutex);
+ }
mutex_unlock(&vs->dev.mutex);
return 0;
}
-static void vhost_scsi_free(struct vhost_scsi *vs)
-{
- if (is_vmalloc_addr(vs))
- vfree(vs);
- else
- kfree(vs);
-}
-
static int vhost_scsi_open(struct inode *inode, struct file *f)
{
struct vhost_scsi *vs;
@@ -1412,20 +1534,15 @@ static int vhost_scsi_open(struct inode *inode, struct file *f)
vqs[i] = &vs->vqs[i].vq;
vs->vqs[i].vq.handle_kick = vhost_scsi_handle_kick;
}
- r = vhost_dev_init(&vs->dev, vqs, VHOST_SCSI_MAX_VQ);
+ vhost_dev_init(&vs->dev, vqs, VHOST_SCSI_MAX_VQ);
tcm_vhost_init_inflight(vs, NULL);
- if (r < 0)
- goto err_init;
-
f->private_data = vs;
return 0;
-err_init:
- kfree(vqs);
err_vqs:
- vhost_scsi_free(vs);
+ kvfree(vs);
err_vs:
return r;
}
@@ -1444,7 +1561,7 @@ static int vhost_scsi_release(struct inode *inode, struct file *f)
/* Jobs can re-queue themselves in evt kick handler. Do extra flush. */
vhost_scsi_flush(vs);
kfree(vs->dev.vqs);
- vhost_scsi_free(vs);
+ kvfree(vs);
return 0;
}
@@ -1580,10 +1697,6 @@ tcm_vhost_do_plug(struct tcm_vhost_tpg *tpg,
return;
mutex_lock(&vs->dev.mutex);
- if (!vhost_has_feature(&vs->dev, VIRTIO_SCSI_F_HOTPLUG)) {
- mutex_unlock(&vs->dev.mutex);
- return;
- }
if (plug)
reason = VIRTIO_SCSI_EVT_RESET_RESCAN;
@@ -1592,8 +1705,9 @@ tcm_vhost_do_plug(struct tcm_vhost_tpg *tpg,
vq = &vs->vqs[VHOST_SCSI_VQ_EVT].vq;
mutex_lock(&vq->mutex);
- tcm_vhost_send_evt(vs, tpg, lun,
- VIRTIO_SCSI_T_TRANSPORT_RESET, reason);
+ if (vhost_has_feature(vq, VIRTIO_SCSI_F_HOTPLUG))
+ tcm_vhost_send_evt(vs, tpg, lun,
+ VIRTIO_SCSI_T_TRANSPORT_RESET, reason);
mutex_unlock(&vq->mutex);
mutex_unlock(&vs->dev.mutex);
}
@@ -1701,6 +1815,7 @@ static void tcm_vhost_free_cmd_map_res(struct tcm_vhost_nexus *nexus,
tv_cmd = &((struct tcm_vhost_cmd *)se_sess->sess_cmd_map)[i];
kfree(tv_cmd->tvc_sgl);
+ kfree(tv_cmd->tvc_prot_sgl);
kfree(tv_cmd->tvc_upages);
}
}
@@ -1734,7 +1849,8 @@ static int tcm_vhost_make_nexus(struct tcm_vhost_tpg *tpg,
*/
tv_nexus->tvn_se_sess = transport_init_session_tags(
TCM_VHOST_DEFAULT_TAGS,
- sizeof(struct tcm_vhost_cmd));
+ sizeof(struct tcm_vhost_cmd),
+ TARGET_PROT_DIN_PASS | TARGET_PROT_DOUT_PASS);
if (IS_ERR(tv_nexus->tvn_se_sess)) {
mutex_unlock(&tpg->tv_tpg_mutex);
kfree(tv_nexus);
@@ -1753,12 +1869,20 @@ static int tcm_vhost_make_nexus(struct tcm_vhost_tpg *tpg,
}
tv_cmd->tvc_upages = kzalloc(sizeof(struct page *) *
- TCM_VHOST_PREALLOC_PAGES, GFP_KERNEL);
+ TCM_VHOST_PREALLOC_UPAGES, GFP_KERNEL);
if (!tv_cmd->tvc_upages) {
mutex_unlock(&tpg->tv_tpg_mutex);
pr_err("Unable to allocate tv_cmd->tvc_upages\n");
goto out;
}
+
+ tv_cmd->tvc_prot_sgl = kzalloc(sizeof(struct scatterlist) *
+ TCM_VHOST_PREALLOC_PROT_SGLS, GFP_KERNEL);
+ if (!tv_cmd->tvc_prot_sgl) {
+ mutex_unlock(&tpg->tv_tpg_mutex);
+ pr_err("Unable to allocate tv_cmd->tvc_prot_sgl\n");
+ goto out;
+ }
}
/*
* Since we are running in 'demo mode' this call with generate a
@@ -2125,6 +2249,7 @@ static struct target_core_fabric_ops tcm_vhost_ops = {
.queue_data_in = tcm_vhost_queue_data_in,
.queue_status = tcm_vhost_queue_status,
.queue_tm_rsp = tcm_vhost_queue_tm_rsp,
+ .aborted_task = tcm_vhost_aborted_task,
/*
* Setup callers for generic logic in target_core_fabric_configfs.c
*/
@@ -2163,15 +2288,15 @@ static int tcm_vhost_register_configfs(void)
/*
* Setup default attribute lists for various fabric->tf_cit_tmpl
*/
- TF_CIT_TMPL(fabric)->tfc_wwn_cit.ct_attrs = tcm_vhost_wwn_attrs;
- TF_CIT_TMPL(fabric)->tfc_tpg_base_cit.ct_attrs = tcm_vhost_tpg_attrs;
- TF_CIT_TMPL(fabric)->tfc_tpg_attrib_cit.ct_attrs = NULL;
- TF_CIT_TMPL(fabric)->tfc_tpg_param_cit.ct_attrs = NULL;
- TF_CIT_TMPL(fabric)->tfc_tpg_np_base_cit.ct_attrs = NULL;
- TF_CIT_TMPL(fabric)->tfc_tpg_nacl_base_cit.ct_attrs = NULL;
- TF_CIT_TMPL(fabric)->tfc_tpg_nacl_attrib_cit.ct_attrs = NULL;
- TF_CIT_TMPL(fabric)->tfc_tpg_nacl_auth_cit.ct_attrs = NULL;
- TF_CIT_TMPL(fabric)->tfc_tpg_nacl_param_cit.ct_attrs = NULL;
+ fabric->tf_cit_tmpl.tfc_wwn_cit.ct_attrs = tcm_vhost_wwn_attrs;
+ fabric->tf_cit_tmpl.tfc_tpg_base_cit.ct_attrs = tcm_vhost_tpg_attrs;
+ fabric->tf_cit_tmpl.tfc_tpg_attrib_cit.ct_attrs = NULL;
+ fabric->tf_cit_tmpl.tfc_tpg_param_cit.ct_attrs = NULL;
+ fabric->tf_cit_tmpl.tfc_tpg_np_base_cit.ct_attrs = NULL;
+ fabric->tf_cit_tmpl.tfc_tpg_nacl_base_cit.ct_attrs = NULL;
+ fabric->tf_cit_tmpl.tfc_tpg_nacl_attrib_cit.ct_attrs = NULL;
+ fabric->tf_cit_tmpl.tfc_tpg_nacl_auth_cit.ct_attrs = NULL;
+ fabric->tf_cit_tmpl.tfc_tpg_nacl_param_cit.ct_attrs = NULL;
/*
* Register the fabric for use within TCM
*/
diff --git a/drivers/vhost/test.c b/drivers/vhost/test.c
index 339eae85859..d9c501eaa6c 100644
--- a/drivers/vhost/test.c
+++ b/drivers/vhost/test.c
@@ -53,7 +53,7 @@ static void handle_vq(struct vhost_test *n)
vhost_disable_notify(&n->dev, vq);
for (;;) {
- head = vhost_get_vq_desc(&n->dev, vq, vq->iov,
+ head = vhost_get_vq_desc(vq, vq->iov,
ARRAY_SIZE(vq->iov),
&out, &in,
NULL, NULL);
@@ -104,7 +104,6 @@ static int vhost_test_open(struct inode *inode, struct file *f)
struct vhost_test *n = kmalloc(sizeof *n, GFP_KERNEL);
struct vhost_dev *dev;
struct vhost_virtqueue **vqs;
- int r;
if (!n)
return -ENOMEM;
@@ -117,12 +116,7 @@ static int vhost_test_open(struct inode *inode, struct file *f)
dev = &n->dev;
vqs[VHOST_TEST_VQ] = &n->vqs[VHOST_TEST_VQ];
n->vqs[VHOST_TEST_VQ].handle_kick = handle_vq_kick;
- r = vhost_dev_init(dev, vqs, VHOST_TEST_VQ_MAX);
- if (r < 0) {
- kfree(vqs);
- kfree(n);
- return r;
- }
+ vhost_dev_init(dev, vqs, VHOST_TEST_VQ_MAX);
f->private_data = n;
@@ -247,15 +241,18 @@ done:
static int vhost_test_set_features(struct vhost_test *n, u64 features)
{
+ struct vhost_virtqueue *vq;
+
mutex_lock(&n->dev.mutex);
if ((features & (1 << VHOST_F_LOG_ALL)) &&
!vhost_log_access_ok(&n->dev)) {
mutex_unlock(&n->dev.mutex);
return -EFAULT;
}
- n->dev.acked_features = features;
- smp_wmb();
- vhost_test_flush(n);
+ vq = &n->vqs[VHOST_TEST_VQ];
+ mutex_lock(&vq->mutex);
+ vq->acked_features = features;
+ mutex_unlock(&vq->mutex);
mutex_unlock(&n->dev.mutex);
return 0;
}
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index 69068e0d8f3..c90f4374442 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -18,7 +18,6 @@
#include <linux/mmu_context.h>
#include <linux/miscdevice.h>
#include <linux/mutex.h>
-#include <linux/rcupdate.h>
#include <linux/poll.h>
#include <linux/file.h>
#include <linux/highmem.h>
@@ -191,6 +190,7 @@ static void vhost_vq_reset(struct vhost_dev *dev,
vq->log_used = false;
vq->log_addr = -1ull;
vq->private_data = NULL;
+ vq->acked_features = 0;
vq->log_base = NULL;
vq->error_ctx = NULL;
vq->error = NULL;
@@ -198,6 +198,7 @@ static void vhost_vq_reset(struct vhost_dev *dev,
vq->call_ctx = NULL;
vq->call = NULL;
vq->log_ctx = NULL;
+ vq->memory = NULL;
}
static int vhost_worker(void *data)
@@ -290,7 +291,7 @@ static void vhost_dev_free_iovecs(struct vhost_dev *dev)
vhost_vq_free_iovecs(dev->vqs[i]);
}
-long vhost_dev_init(struct vhost_dev *dev,
+void vhost_dev_init(struct vhost_dev *dev,
struct vhost_virtqueue **vqs, int nvqs)
{
struct vhost_virtqueue *vq;
@@ -319,8 +320,6 @@ long vhost_dev_init(struct vhost_dev *dev,
vhost_poll_init(&vq->poll, vq->handle_kick,
POLLIN, dev);
}
-
- return 0;
}
EXPORT_SYMBOL_GPL(vhost_dev_init);
@@ -417,11 +416,18 @@ EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare);
/* Caller should have device mutex */
void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_memory *memory)
{
+ int i;
+
vhost_dev_cleanup(dev, true);
/* Restore memory to default empty mapping. */
memory->nregions = 0;
- RCU_INIT_POINTER(dev->memory, memory);
+ dev->memory = memory;
+ /* We don't need VQ locks below since vhost_dev_cleanup makes sure
+ * VQs aren't running.
+ */
+ for (i = 0; i < dev->nvqs; ++i)
+ dev->vqs[i]->memory = memory;
}
EXPORT_SYMBOL_GPL(vhost_dev_reset_owner);
@@ -464,10 +470,8 @@ void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
fput(dev->log_file);
dev->log_file = NULL;
/* No one will access memory at this point */
- kfree(rcu_dereference_protected(dev->memory,
- locked ==
- lockdep_is_held(&dev->mutex)));
- RCU_INIT_POINTER(dev->memory, NULL);
+ kfree(dev->memory);
+ dev->memory = NULL;
WARN_ON(!list_empty(&dev->work_list));
if (dev->worker) {
kthread_stop(dev->worker);
@@ -526,11 +530,13 @@ static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem,
for (i = 0; i < d->nvqs; ++i) {
int ok;
+ bool log;
+
mutex_lock(&d->vqs[i]->mutex);
+ log = log_all || vhost_has_feature(d->vqs[i], VHOST_F_LOG_ALL);
/* If ring is inactive, will check when it's enabled. */
if (d->vqs[i]->private_data)
- ok = vq_memory_access_ok(d->vqs[i]->log_base, mem,
- log_all);
+ ok = vq_memory_access_ok(d->vqs[i]->log_base, mem, log);
else
ok = 1;
mutex_unlock(&d->vqs[i]->mutex);
@@ -540,12 +546,12 @@ static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem,
return 1;
}
-static int vq_access_ok(struct vhost_dev *d, unsigned int num,
+static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
struct vring_desc __user *desc,
struct vring_avail __user *avail,
struct vring_used __user *used)
{
- size_t s = vhost_has_feature(d, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
+ size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
return access_ok(VERIFY_READ, desc, num * sizeof *desc) &&
access_ok(VERIFY_READ, avail,
sizeof *avail + num * sizeof *avail->ring + s) &&
@@ -557,26 +563,19 @@ static int vq_access_ok(struct vhost_dev *d, unsigned int num,
/* Caller should have device mutex but not vq mutex */
int vhost_log_access_ok(struct vhost_dev *dev)
{
- struct vhost_memory *mp;
-
- mp = rcu_dereference_protected(dev->memory,
- lockdep_is_held(&dev->mutex));
- return memory_access_ok(dev, mp, 1);
+ return memory_access_ok(dev, dev->memory, 1);
}
EXPORT_SYMBOL_GPL(vhost_log_access_ok);
/* Verify access for write logging. */
/* Caller should have vq mutex and device mutex */
-static int vq_log_access_ok(struct vhost_dev *d, struct vhost_virtqueue *vq,
+static int vq_log_access_ok(struct vhost_virtqueue *vq,
void __user *log_base)
{
- struct vhost_memory *mp;
- size_t s = vhost_has_feature(d, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
+ size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
- mp = rcu_dereference_protected(vq->dev->memory,
- lockdep_is_held(&vq->mutex));
- return vq_memory_access_ok(log_base, mp,
- vhost_has_feature(vq->dev, VHOST_F_LOG_ALL)) &&
+ return vq_memory_access_ok(log_base, vq->memory,
+ vhost_has_feature(vq, VHOST_F_LOG_ALL)) &&
(!vq->log_used || log_access_ok(log_base, vq->log_addr,
sizeof *vq->used +
vq->num * sizeof *vq->used->ring + s));
@@ -586,8 +585,8 @@ static int vq_log_access_ok(struct vhost_dev *d, struct vhost_virtqueue *vq,
/* Caller should have vq mutex and device mutex */
int vhost_vq_access_ok(struct vhost_virtqueue *vq)
{
- return vq_access_ok(vq->dev, vq->num, vq->desc, vq->avail, vq->used) &&
- vq_log_access_ok(vq->dev, vq, vq->log_base);
+ return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used) &&
+ vq_log_access_ok(vq, vq->log_base);
}
EXPORT_SYMBOL_GPL(vhost_vq_access_ok);
@@ -595,6 +594,7 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
{
struct vhost_memory mem, *newmem, *oldmem;
unsigned long size = offsetof(struct vhost_memory, regions);
+ int i;
if (copy_from_user(&mem, m, size))
return -EFAULT;
@@ -613,15 +613,19 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
return -EFAULT;
}
- if (!memory_access_ok(d, newmem,
- vhost_has_feature(d, VHOST_F_LOG_ALL))) {
+ if (!memory_access_ok(d, newmem, 0)) {
kfree(newmem);
return -EFAULT;
}
- oldmem = rcu_dereference_protected(d->memory,
- lockdep_is_held(&d->mutex));
- rcu_assign_pointer(d->memory, newmem);
- synchronize_rcu();
+ oldmem = d->memory;
+ d->memory = newmem;
+
+ /* All memory accesses are done under some VQ mutex. */
+ for (i = 0; i < d->nvqs; ++i) {
+ mutex_lock(&d->vqs[i]->mutex);
+ d->vqs[i]->memory = newmem;
+ mutex_unlock(&d->vqs[i]->mutex);
+ }
kfree(oldmem);
return 0;
}
@@ -720,7 +724,7 @@ long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp)
* If it is not, we don't as size might not have been setup.
* We will verify when backend is configured. */
if (vq->private_data) {
- if (!vq_access_ok(d, vq->num,
+ if (!vq_access_ok(vq, vq->num,
(void __user *)(unsigned long)a.desc_user_addr,
(void __user *)(unsigned long)a.avail_user_addr,
(void __user *)(unsigned long)a.used_user_addr)) {
@@ -860,7 +864,7 @@ long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
vq = d->vqs[i];
mutex_lock(&vq->mutex);
/* If ring is inactive, will check when it's enabled. */
- if (vq->private_data && !vq_log_access_ok(d, vq, base))
+ if (vq->private_data && !vq_log_access_ok(vq, base))
r = -EFAULT;
else
vq->log_base = base;
@@ -1046,7 +1050,7 @@ int vhost_init_used(struct vhost_virtqueue *vq)
}
EXPORT_SYMBOL_GPL(vhost_init_used);
-static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
+static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
struct iovec iov[], int iov_size)
{
const struct vhost_memory_region *reg;
@@ -1055,9 +1059,7 @@ static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
u64 s = 0;
int ret = 0;
- rcu_read_lock();
-
- mem = rcu_dereference(dev->memory);
+ mem = vq->memory;
while ((u64)len > s) {
u64 size;
if (unlikely(ret >= iov_size)) {
@@ -1079,7 +1081,6 @@ static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
++ret;
}
- rcu_read_unlock();
return ret;
}
@@ -1104,7 +1105,7 @@ static unsigned next_desc(struct vring_desc *desc)
return next;
}
-static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
+static int get_indirect(struct vhost_virtqueue *vq,
struct iovec iov[], unsigned int iov_size,
unsigned int *out_num, unsigned int *in_num,
struct vhost_log *log, unsigned int *log_num,
@@ -1123,7 +1124,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
return -EINVAL;
}
- ret = translate_desc(dev, indirect->addr, indirect->len, vq->indirect,
+ ret = translate_desc(vq, indirect->addr, indirect->len, vq->indirect,
UIO_MAXIOV);
if (unlikely(ret < 0)) {
vq_err(vq, "Translation failure %d in indirect.\n", ret);
@@ -1163,7 +1164,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
return -EINVAL;
}
- ret = translate_desc(dev, desc.addr, desc.len, iov + iov_count,
+ ret = translate_desc(vq, desc.addr, desc.len, iov + iov_count,
iov_size - iov_count);
if (unlikely(ret < 0)) {
vq_err(vq, "Translation failure %d indirect idx %d\n",
@@ -1200,7 +1201,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
* This function returns the descriptor number found, or vq->num (which is
* never a valid descriptor number) if none was found. A negative code is
* returned on error. */
-int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
+int vhost_get_vq_desc(struct vhost_virtqueue *vq,
struct iovec iov[], unsigned int iov_size,
unsigned int *out_num, unsigned int *in_num,
struct vhost_log *log, unsigned int *log_num)
@@ -1274,7 +1275,7 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
return -EFAULT;
}
if (desc.flags & VRING_DESC_F_INDIRECT) {
- ret = get_indirect(dev, vq, iov, iov_size,
+ ret = get_indirect(vq, iov, iov_size,
out_num, in_num,
log, log_num, &desc);
if (unlikely(ret < 0)) {
@@ -1285,7 +1286,7 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
continue;
}
- ret = translate_desc(dev, desc.addr, desc.len, iov + iov_count,
+ ret = translate_desc(vq, desc.addr, desc.len, iov + iov_count,
iov_size - iov_count);
if (unlikely(ret < 0)) {
vq_err(vq, "Translation failure %d descriptor idx %d\n",
@@ -1428,11 +1429,11 @@ static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
* interrupts. */
smp_mb();
- if (vhost_has_feature(dev, VIRTIO_F_NOTIFY_ON_EMPTY) &&
+ if (vhost_has_feature(vq, VIRTIO_F_NOTIFY_ON_EMPTY) &&
unlikely(vq->avail_idx == vq->last_avail_idx))
return true;
- if (!vhost_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
+ if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
__u16 flags;
if (__get_user(flags, &vq->avail->flags)) {
vq_err(vq, "Failed to get flags");
@@ -1493,7 +1494,7 @@ bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
if (!(vq->used_flags & VRING_USED_F_NO_NOTIFY))
return false;
vq->used_flags &= ~VRING_USED_F_NO_NOTIFY;
- if (!vhost_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
+ if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
r = vhost_update_used_flags(vq);
if (r) {
vq_err(vq, "Failed to enable notification at %p: %d\n",
@@ -1530,7 +1531,7 @@ void vhost_disable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
if (vq->used_flags & VRING_USED_F_NO_NOTIFY)
return;
vq->used_flags |= VRING_USED_F_NO_NOTIFY;
- if (!vhost_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
+ if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
r = vhost_update_used_flags(vq);
if (r)
vq_err(vq, "Failed to enable notification at %p: %d\n",
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index 4465ed5f316..3eda654b8f5 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -104,20 +104,18 @@ struct vhost_virtqueue {
struct iovec *indirect;
struct vring_used_elem *heads;
/* Protected by virtqueue mutex. */
+ struct vhost_memory *memory;
void *private_data;
+ unsigned acked_features;
/* Log write descriptors */
void __user *log_base;
struct vhost_log *log;
};
struct vhost_dev {
- /* Readers use RCU to access memory table pointer
- * log base pointer and features.
- * Writers use mutex below.*/
- struct vhost_memory __rcu *memory;
+ struct vhost_memory *memory;
struct mm_struct *mm;
struct mutex mutex;
- unsigned acked_features;
struct vhost_virtqueue **vqs;
int nvqs;
struct file *log_file;
@@ -127,7 +125,7 @@ struct vhost_dev {
struct task_struct *worker;
};
-long vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue **vqs, int nvqs);
+void vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue **vqs, int nvqs);
long vhost_dev_set_owner(struct vhost_dev *dev);
bool vhost_dev_has_owner(struct vhost_dev *dev);
long vhost_dev_check_owner(struct vhost_dev *);
@@ -140,7 +138,7 @@ long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp);
int vhost_vq_access_ok(struct vhost_virtqueue *vq);
int vhost_log_access_ok(struct vhost_dev *);
-int vhost_get_vq_desc(struct vhost_dev *, struct vhost_virtqueue *,
+int vhost_get_vq_desc(struct vhost_virtqueue *,
struct iovec iov[], unsigned int iov_count,
unsigned int *out_num, unsigned int *in_num,
struct vhost_log *log, unsigned int *log_num);
@@ -174,13 +172,8 @@ enum {
(1ULL << VHOST_F_LOG_ALL),
};
-static inline int vhost_has_feature(struct vhost_dev *dev, int bit)
+static inline int vhost_has_feature(struct vhost_virtqueue *vq, int bit)
{
- unsigned acked_features;
-
- /* TODO: check that we are running from vhost_worker or dev mutex is
- * held? */
- acked_features = rcu_dereference_index_check(dev->acked_features, 1);
- return acked_features & (1 << bit);
+ return vq->acked_features & (1 << bit);
}
#endif