Diffstat (limited to 'drivers/vhost')
 drivers/vhost/net.c   | 115
 drivers/vhost/scsi.c  | 393
 drivers/vhost/test.c  |  19
 drivers/vhost/vhost.c | 101
 drivers/vhost/vhost.h |  21
 5 files changed, 393 insertions(+), 256 deletions(-)
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c index 831eb4fd197..8dae2f724a3 100644 --- a/drivers/vhost/net.c +++ b/drivers/vhost/net.c @@ -17,6 +17,7 @@  #include <linux/workqueue.h>  #include <linux/file.h>  #include <linux/slab.h> +#include <linux/vmalloc.h>  #include <linux/net.h>  #include <linux/if_packet.h> @@ -70,7 +71,12 @@ enum {  };  struct vhost_net_ubuf_ref { -	struct kref kref; +	/* refcount follows semantics similar to kref: +	 *  0: object is released +	 *  1: no outstanding ubufs +	 * >1: outstanding ubufs +	 */ +	atomic_t refcount;  	wait_queue_head_t wait;  	struct vhost_virtqueue *vq;  }; @@ -116,14 +122,6 @@ static void vhost_net_enable_zcopy(int vq)  	vhost_net_zcopy_mask |= 0x1 << vq;  } -static void vhost_net_zerocopy_done_signal(struct kref *kref) -{ -	struct vhost_net_ubuf_ref *ubufs; - -	ubufs = container_of(kref, struct vhost_net_ubuf_ref, kref); -	wake_up(&ubufs->wait); -} -  static struct vhost_net_ubuf_ref *  vhost_net_ubuf_alloc(struct vhost_virtqueue *vq, bool zcopy)  { @@ -134,21 +132,24 @@ vhost_net_ubuf_alloc(struct vhost_virtqueue *vq, bool zcopy)  	ubufs = kmalloc(sizeof(*ubufs), GFP_KERNEL);  	if (!ubufs)  		return ERR_PTR(-ENOMEM); -	kref_init(&ubufs->kref); +	atomic_set(&ubufs->refcount, 1);  	init_waitqueue_head(&ubufs->wait);  	ubufs->vq = vq;  	return ubufs;  } -static void vhost_net_ubuf_put(struct vhost_net_ubuf_ref *ubufs) +static int vhost_net_ubuf_put(struct vhost_net_ubuf_ref *ubufs)  { -	kref_put(&ubufs->kref, vhost_net_zerocopy_done_signal); +	int r = atomic_sub_return(1, &ubufs->refcount); +	if (unlikely(!r)) +		wake_up(&ubufs->wait); +	return r;  }  static void vhost_net_ubuf_put_and_wait(struct vhost_net_ubuf_ref *ubufs)  { -	kref_put(&ubufs->kref, vhost_net_zerocopy_done_signal); -	wait_event(ubufs->wait, !atomic_read(&ubufs->kref.refcount)); +	vhost_net_ubuf_put(ubufs); +	wait_event(ubufs->wait, !atomic_read(&ubufs->refcount));  }  static void vhost_net_ubuf_put_wait_and_free(struct vhost_net_ubuf_ref *ubufs) @@ -306,23 +307,26 @@ static void vhost_zerocopy_callback(struct ubuf_info *ubuf, bool success)  {  	struct vhost_net_ubuf_ref *ubufs = ubuf->ctx;  	struct vhost_virtqueue *vq = ubufs->vq; -	int cnt = atomic_read(&ubufs->kref.refcount); +	int cnt; + +	rcu_read_lock_bh();  	/* set len to mark this desc buffers done DMA */  	vq->heads[ubuf->desc].len = success ?  		VHOST_DMA_DONE_LEN : VHOST_DMA_FAILED_LEN; -	vhost_net_ubuf_put(ubufs); +	cnt = vhost_net_ubuf_put(ubufs);  	/*  	 * Trigger polling thread if guest stopped submitting new buffers: -	 * in this case, the refcount after decrement will eventually reach 1 -	 * so here it is 2. +	 * in this case, the refcount after decrement will eventually reach 1.  	 * We also trigger polling periodically after each 16 packets  	 * (the value 16 here is more or less arbitrary, it's tuned to trigger  	 * less than 10% of times).  	 
*/ -	if (cnt <= 2 || !(cnt % 16)) +	if (cnt <= 1 || !(cnt % 16))  		vhost_poll_queue(&vq->poll); + +	rcu_read_unlock_bh();  }  /* Expects to be always run from workqueue - which acts as @@ -370,7 +374,7 @@ static void handle_tx(struct vhost_net *net)  			      % UIO_MAXIOV == nvq->done_idx))  			break; -		head = vhost_get_vq_desc(&net->dev, vq, vq->iov, +		head = vhost_get_vq_desc(vq, vq->iov,  					 ARRAY_SIZE(vq->iov),  					 &out, &in,  					 NULL, NULL); @@ -420,7 +424,7 @@ static void handle_tx(struct vhost_net *net)  			msg.msg_control = ubuf;  			msg.msg_controllen = sizeof(ubuf);  			ubufs = nvq->ubufs; -			kref_get(&ubufs->kref); +			atomic_inc(&ubufs->refcount);  			nvq->upend_idx = (nvq->upend_idx + 1) % UIO_MAXIOV;  		} else {  			msg.msg_control = NULL; @@ -502,9 +506,13 @@ static int get_rx_bufs(struct vhost_virtqueue *vq,  			r = -ENOBUFS;  			goto err;  		} -		d = vhost_get_vq_desc(vq->dev, vq, vq->iov + seg, +		r = vhost_get_vq_desc(vq, vq->iov + seg,  				      ARRAY_SIZE(vq->iov) - seg, &out,  				      &in, log, log_num); +		if (unlikely(r < 0)) +			goto err; + +		d = r;  		if (d == vq->num) {  			r = 0;  			goto err; @@ -529,6 +537,12 @@ static int get_rx_bufs(struct vhost_virtqueue *vq,  	*iovcount = seg;  	if (unlikely(log))  		*log_num = nlogs; + +	/* Detect overrun */ +	if (unlikely(datalen > 0)) { +		r = UIO_MAXIOV + 1; +		goto err; +	}  	return headcount;  err:  	vhost_discard_vq_desc(vq, headcount); @@ -571,9 +585,9 @@ static void handle_rx(struct vhost_net *net)  	vhost_hlen = nvq->vhost_hlen;  	sock_hlen = nvq->sock_hlen; -	vq_log = unlikely(vhost_has_feature(&net->dev, VHOST_F_LOG_ALL)) ? +	vq_log = unlikely(vhost_has_feature(vq, VHOST_F_LOG_ALL)) ?  		vq->log : NULL; -	mergeable = vhost_has_feature(&net->dev, VIRTIO_NET_F_MRG_RXBUF); +	mergeable = vhost_has_feature(vq, VIRTIO_NET_F_MRG_RXBUF);  	while ((sock_len = peek_head_len(sock->sk))) {  		sock_len += sock_hlen; @@ -584,6 +598,14 @@ static void handle_rx(struct vhost_net *net)  		/* On error, stop handling until the next kick. */  		if (unlikely(headcount < 0))  			break; +		/* On overrun, truncate and discard */ +		if (unlikely(headcount > UIO_MAXIOV)) { +			msg.msg_iovlen = 1; +			err = sock->ops->recvmsg(NULL, sock, &msg, +						 1, MSG_DONTWAIT | MSG_TRUNC); +			pr_debug("Discarded rx packet: len %zd\n", sock_len); +			continue; +		}  		/* OK, now we need to know about added descriptors. 
*/  		if (!headcount) {  			if (unlikely(vhost_enable_notify(&net->dev, vq))) { @@ -680,16 +702,20 @@ static void handle_rx_net(struct vhost_work *work)  static int vhost_net_open(struct inode *inode, struct file *f)  { -	struct vhost_net *n = kmalloc(sizeof *n, GFP_KERNEL); +	struct vhost_net *n;  	struct vhost_dev *dev;  	struct vhost_virtqueue **vqs; -	int r, i; +	int i; -	if (!n) -		return -ENOMEM; +	n = kmalloc(sizeof *n, GFP_KERNEL | __GFP_NOWARN | __GFP_REPEAT); +	if (!n) { +		n = vmalloc(sizeof *n); +		if (!n) +			return -ENOMEM; +	}  	vqs = kmalloc(VHOST_NET_VQ_MAX * sizeof(*vqs), GFP_KERNEL);  	if (!vqs) { -		kfree(n); +		kvfree(n);  		return -ENOMEM;  	} @@ -706,12 +732,7 @@ static int vhost_net_open(struct inode *inode, struct file *f)  		n->vqs[i].vhost_hlen = 0;  		n->vqs[i].sock_hlen = 0;  	} -	r = vhost_dev_init(dev, vqs, VHOST_NET_VQ_MAX); -	if (r < 0) { -		kfree(n); -		kfree(vqs); -		return r; -	} +	vhost_dev_init(dev, vqs, VHOST_NET_VQ_MAX);  	vhost_poll_init(n->poll + VHOST_NET_VQ_TX, handle_tx_net, POLLOUT, dev);  	vhost_poll_init(n->poll + VHOST_NET_VQ_RX, handle_rx_net, POLLIN, dev); @@ -785,7 +806,7 @@ static void vhost_net_flush(struct vhost_net *n)  		vhost_net_ubuf_put_and_wait(n->vqs[VHOST_NET_VQ_TX].ubufs);  		mutex_lock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);  		n->tx_flush = false; -		kref_init(&n->vqs[VHOST_NET_VQ_TX].ubufs->kref); +		atomic_set(&n->vqs[VHOST_NET_VQ_TX].ubufs->refcount, 1);  		mutex_unlock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);  	}  } @@ -802,14 +823,16 @@ static int vhost_net_release(struct inode *inode, struct file *f)  	vhost_dev_cleanup(&n->dev, false);  	vhost_net_vq_reset(n);  	if (tx_sock) -		fput(tx_sock->file); +		sockfd_put(tx_sock);  	if (rx_sock) -		fput(rx_sock->file); +		sockfd_put(rx_sock); +	/* Make sure no callbacks are outstanding */ +	synchronize_rcu_bh();  	/* We do an extra flush before freeing memory,  	 * since jobs can re-queue themselves. 
*/  	vhost_net_flush(n);  	kfree(n->dev.vqs); -	kfree(n); +	kvfree(n);  	return 0;  } @@ -842,7 +865,7 @@ static struct socket *get_raw_socket(int fd)  	}  	return sock;  err: -	fput(sock->file); +	sockfd_put(sock);  	return ERR_PTR(r);  } @@ -948,7 +971,7 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)  	if (oldsock) {  		vhost_net_flush_vq(n, index); -		fput(oldsock->file); +		sockfd_put(oldsock);  	}  	mutex_unlock(&n->dev.mutex); @@ -960,7 +983,7 @@ err_used:  	if (ubufs)  		vhost_net_ubuf_put_wait_and_free(ubufs);  err_ubufs: -	fput(sock->file); +	sockfd_put(sock);  err_vq:  	mutex_unlock(&vq->mutex);  err: @@ -991,9 +1014,9 @@ static long vhost_net_reset_owner(struct vhost_net *n)  done:  	mutex_unlock(&n->dev.mutex);  	if (tx_sock) -		fput(tx_sock->file); +		sockfd_put(tx_sock);  	if (rx_sock) -		fput(rx_sock->file); +		sockfd_put(rx_sock);  	return err;  } @@ -1020,15 +1043,13 @@ static int vhost_net_set_features(struct vhost_net *n, u64 features)  		mutex_unlock(&n->dev.mutex);  		return -EFAULT;  	} -	n->dev.acked_features = features; -	smp_wmb();  	for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {  		mutex_lock(&n->vqs[i].vq.mutex); +		n->vqs[i].vq.acked_features = features;  		n->vqs[i].vhost_hlen = vhost_hlen;  		n->vqs[i].sock_hlen = sock_hlen;  		mutex_unlock(&n->vqs[i].vq.mutex);  	} -	vhost_net_flush(n);  	mutex_unlock(&n->dev.mutex);  	return 0;  } diff --git a/drivers/vhost/scsi.c b/drivers/vhost/scsi.c index 592b31698fc..69906cacd04 100644 --- a/drivers/vhost/scsi.c +++ b/drivers/vhost/scsi.c @@ -57,7 +57,8 @@  #define TCM_VHOST_MAX_CDB_SIZE 32  #define TCM_VHOST_DEFAULT_TAGS 256  #define TCM_VHOST_PREALLOC_SGLS 2048 -#define TCM_VHOST_PREALLOC_PAGES 2048 +#define TCM_VHOST_PREALLOC_UPAGES 2048 +#define TCM_VHOST_PREALLOC_PROT_SGLS 512  struct vhost_scsi_inflight {  	/* Wait for the flush operation to finish */ @@ -79,10 +80,12 @@ struct tcm_vhost_cmd {  	u64 tvc_tag;  	/* The number of scatterlists associated with this cmd */  	u32 tvc_sgl_count; +	u32 tvc_prot_sgl_count;  	/* Saved unpacked SCSI LUN for tcm_vhost_submission_work() */  	u32 tvc_lun;  	/* Pointer to the SGL formatted memory from virtio-scsi */  	struct scatterlist *tvc_sgl; +	struct scatterlist *tvc_prot_sgl;  	struct page **tvc_upages;  	/* Pointer to response */  	struct virtio_scsi_cmd_resp __user *tvc_resp; @@ -166,7 +169,8 @@ enum {  };  enum { -	VHOST_SCSI_FEATURES = VHOST_FEATURES | (1ULL << VIRTIO_SCSI_F_HOTPLUG) +	VHOST_SCSI_FEATURES = VHOST_FEATURES | (1ULL << VIRTIO_SCSI_F_HOTPLUG) | +					       (1ULL << VIRTIO_SCSI_F_T10_PI)  };  #define VHOST_SCSI_MAX_TARGET	256 @@ -456,12 +460,16 @@ static void tcm_vhost_release_cmd(struct se_cmd *se_cmd)  	struct tcm_vhost_cmd *tv_cmd = container_of(se_cmd,  				struct tcm_vhost_cmd, tvc_se_cmd);  	struct se_session *se_sess = se_cmd->se_sess; +	int i;  	if (tv_cmd->tvc_sgl_count) { -		u32 i;  		for (i = 0; i < tv_cmd->tvc_sgl_count; i++)  			put_page(sg_page(&tv_cmd->tvc_sgl[i]));  	} +	if (tv_cmd->tvc_prot_sgl_count) { +		for (i = 0; i < tv_cmd->tvc_prot_sgl_count; i++) +			put_page(sg_page(&tv_cmd->tvc_prot_sgl[i])); +	}  	tcm_vhost_put_inflight(tv_cmd->inflight);  	percpu_ida_free(&se_sess->sess_tag_pool, se_cmd->map_tag); @@ -539,6 +547,11 @@ static void tcm_vhost_queue_tm_rsp(struct se_cmd *se_cmd)  	return;  } +static void tcm_vhost_aborted_task(struct se_cmd *se_cmd) +{ +	return; +} +  static void tcm_vhost_free_evt(struct vhost_scsi *vs, struct tcm_vhost_evt *evt)  {  	vs->vs_events_nr--; @@ -601,7 +614,7 @@ 
tcm_vhost_do_evt_work(struct vhost_scsi *vs, struct tcm_vhost_evt *evt)  again:  	vhost_disable_notify(&vs->dev, vq); -	head = vhost_get_vq_desc(&vs->dev, vq, vq->iov, +	head = vhost_get_vq_desc(vq, vq->iov,  			ARRAY_SIZE(vq->iov), &out, &in,  			NULL, NULL);  	if (head < 0) { @@ -708,16 +721,14 @@ static void vhost_scsi_complete_cmd_work(struct vhost_work *work)  }  static struct tcm_vhost_cmd * -vhost_scsi_get_tag(struct vhost_virtqueue *vq, -			struct tcm_vhost_tpg *tpg, -			struct virtio_scsi_cmd_req *v_req, -			u32 exp_data_len, -			int data_direction) +vhost_scsi_get_tag(struct vhost_virtqueue *vq, struct tcm_vhost_tpg *tpg, +		   unsigned char *cdb, u64 scsi_tag, u16 lun, u8 task_attr, +		   u32 exp_data_len, int data_direction)  {  	struct tcm_vhost_cmd *cmd;  	struct tcm_vhost_nexus *tv_nexus;  	struct se_session *se_sess; -	struct scatterlist *sg; +	struct scatterlist *sg, *prot_sg;  	struct page **pages;  	int tag; @@ -728,22 +739,32 @@ vhost_scsi_get_tag(struct vhost_virtqueue *vq,  	}  	se_sess = tv_nexus->tvn_se_sess; -	tag = percpu_ida_alloc(&se_sess->sess_tag_pool, GFP_KERNEL); +	tag = percpu_ida_alloc(&se_sess->sess_tag_pool, TASK_RUNNING); +	if (tag < 0) { +		pr_err("Unable to obtain tag for tcm_vhost_cmd\n"); +		return ERR_PTR(-ENOMEM); +	} +  	cmd = &((struct tcm_vhost_cmd *)se_sess->sess_cmd_map)[tag];  	sg = cmd->tvc_sgl; +	prot_sg = cmd->tvc_prot_sgl;  	pages = cmd->tvc_upages;  	memset(cmd, 0, sizeof(struct tcm_vhost_cmd));  	cmd->tvc_sgl = sg; +	cmd->tvc_prot_sgl = prot_sg;  	cmd->tvc_upages = pages;  	cmd->tvc_se_cmd.map_tag = tag; -	cmd->tvc_tag = v_req->tag; -	cmd->tvc_task_attr = v_req->task_attr; +	cmd->tvc_tag = scsi_tag; +	cmd->tvc_lun = lun; +	cmd->tvc_task_attr = task_attr;  	cmd->tvc_exp_data_len = exp_data_len;  	cmd->tvc_data_direction = data_direction;  	cmd->tvc_nexus = tv_nexus;  	cmd->inflight = tcm_vhost_get_inflight(vq); +	memcpy(cmd->tvc_cdb, cdb, TCM_VHOST_MAX_CDB_SIZE); +  	return cmd;  } @@ -757,35 +778,28 @@ vhost_scsi_map_to_sgl(struct tcm_vhost_cmd *tv_cmd,  		      struct scatterlist *sgl,  		      unsigned int sgl_count,  		      struct iovec *iov, -		      int write) +		      struct page **pages, +		      bool write)  {  	unsigned int npages = 0, pages_nr, offset, nbytes;  	struct scatterlist *sg = sgl;  	void __user *ptr = iov->iov_base;  	size_t len = iov->iov_len; -	struct page **pages;  	int ret, i; -	if (sgl_count > TCM_VHOST_PREALLOC_SGLS) { -		pr_err("vhost_scsi_map_to_sgl() psgl_count: %u greater than" -		       " preallocated TCM_VHOST_PREALLOC_SGLS: %u\n", -			sgl_count, TCM_VHOST_PREALLOC_SGLS); -		return -ENOBUFS; -	} -  	pages_nr = iov_num_pages(iov); -	if (pages_nr > sgl_count) +	if (pages_nr > sgl_count) { +		pr_err("vhost_scsi_map_to_sgl() pages_nr: %u greater than" +		       " sgl_count: %u\n", pages_nr, sgl_count);  		return -ENOBUFS; - -	if (pages_nr > TCM_VHOST_PREALLOC_PAGES) { +	} +	if (pages_nr > TCM_VHOST_PREALLOC_UPAGES) {  		pr_err("vhost_scsi_map_to_sgl() pages_nr: %u greater than" -		       " preallocated TCM_VHOST_PREALLOC_PAGES: %u\n", -			pages_nr, TCM_VHOST_PREALLOC_PAGES); +		       " preallocated TCM_VHOST_PREALLOC_UPAGES: %u\n", +			pages_nr, TCM_VHOST_PREALLOC_UPAGES);  		return -ENOBUFS;  	} -	pages = tv_cmd->tvc_upages; -  	ret = get_user_pages_fast((unsigned long)ptr, pages_nr, write, pages);  	/* No pages were pinned */  	if (ret < 0) @@ -815,33 +829,32 @@ out:  static int  vhost_scsi_map_iov_to_sgl(struct tcm_vhost_cmd *cmd,  			  struct iovec *iov, -			  unsigned int niov, -			  int write) +			  
int niov, +			  bool write)  { -	int ret; -	unsigned int i; -	u32 sgl_count; -	struct scatterlist *sg; +	struct scatterlist *sg = cmd->tvc_sgl; +	unsigned int sgl_count = 0; +	int ret, i; -	/* -	 * Find out how long sglist needs to be -	 */ -	sgl_count = 0;  	for (i = 0; i < niov; i++)  		sgl_count += iov_num_pages(&iov[i]); -	/* TODO overflow checking */ +	if (sgl_count > TCM_VHOST_PREALLOC_SGLS) { +		pr_err("vhost_scsi_map_iov_to_sgl() sgl_count: %u greater than" +			" preallocated TCM_VHOST_PREALLOC_SGLS: %u\n", +			sgl_count, TCM_VHOST_PREALLOC_SGLS); +		return -ENOBUFS; +	} -	sg = cmd->tvc_sgl;  	pr_debug("%s sg %p sgl_count %u\n", __func__, sg, sgl_count);  	sg_init_table(sg, sgl_count); -  	cmd->tvc_sgl_count = sgl_count; -	pr_debug("Mapping %u iovecs for %u pages\n", niov, sgl_count); +	pr_debug("Mapping iovec %p for %u pages\n", &iov[0], sgl_count); +  	for (i = 0; i < niov; i++) {  		ret = vhost_scsi_map_to_sgl(cmd, sg, sgl_count, &iov[i], -					    write); +					    cmd->tvc_upages, write);  		if (ret < 0) {  			for (i = 0; i < cmd->tvc_sgl_count; i++)  				put_page(sg_page(&cmd->tvc_sgl[i])); @@ -849,31 +862,70 @@ vhost_scsi_map_iov_to_sgl(struct tcm_vhost_cmd *cmd,  			cmd->tvc_sgl_count = 0;  			return ret;  		} -  		sg += ret;  		sgl_count -= ret;  	}  	return 0;  } +static int +vhost_scsi_map_iov_to_prot(struct tcm_vhost_cmd *cmd, +			   struct iovec *iov, +			   int niov, +			   bool write) +{ +	struct scatterlist *prot_sg = cmd->tvc_prot_sgl; +	unsigned int prot_sgl_count = 0; +	int ret, i; + +	for (i = 0; i < niov; i++) +		prot_sgl_count += iov_num_pages(&iov[i]); + +	if (prot_sgl_count > TCM_VHOST_PREALLOC_PROT_SGLS) { +		pr_err("vhost_scsi_map_iov_to_prot() sgl_count: %u greater than" +			" preallocated TCM_VHOST_PREALLOC_PROT_SGLS: %u\n", +			prot_sgl_count, TCM_VHOST_PREALLOC_PROT_SGLS); +		return -ENOBUFS; +	} + +	pr_debug("%s prot_sg %p prot_sgl_count %u\n", __func__, +		 prot_sg, prot_sgl_count); +	sg_init_table(prot_sg, prot_sgl_count); +	cmd->tvc_prot_sgl_count = prot_sgl_count; + +	for (i = 0; i < niov; i++) { +		ret = vhost_scsi_map_to_sgl(cmd, prot_sg, prot_sgl_count, &iov[i], +					    cmd->tvc_upages, write); +		if (ret < 0) { +			for (i = 0; i < cmd->tvc_prot_sgl_count; i++) +				put_page(sg_page(&cmd->tvc_prot_sgl[i])); + +			cmd->tvc_prot_sgl_count = 0; +			return ret; +		} +		prot_sg += ret; +		prot_sgl_count -= ret; +	} +	return 0; +} +  static void tcm_vhost_submission_work(struct work_struct *work)  {  	struct tcm_vhost_cmd *cmd =  		container_of(work, struct tcm_vhost_cmd, work);  	struct tcm_vhost_nexus *tv_nexus;  	struct se_cmd *se_cmd = &cmd->tvc_se_cmd; -	struct scatterlist *sg_ptr, *sg_bidi_ptr = NULL; -	int rc, sg_no_bidi = 0; +	struct scatterlist *sg_ptr, *sg_prot_ptr = NULL; +	int rc; +	/* FIXME: BIDI operation */  	if (cmd->tvc_sgl_count) {  		sg_ptr = cmd->tvc_sgl; -/* FIXME: Fix BIDI operation in tcm_vhost_submission_work() */ -#if 0 -		if (se_cmd->se_cmd_flags & SCF_BIDI) { -			sg_bidi_ptr = NULL; -			sg_no_bidi = 0; -		} -#endif + +		if (cmd->tvc_prot_sgl_count) +			sg_prot_ptr = cmd->tvc_prot_sgl; +		else +			se_cmd->prot_pto = true;  	} else {  		sg_ptr = NULL;  	} @@ -884,7 +936,7 @@ static void tcm_vhost_submission_work(struct work_struct *work)  			cmd->tvc_lun, cmd->tvc_exp_data_len,  			cmd->tvc_task_attr, cmd->tvc_data_direction,  			TARGET_SCF_ACK_KREF, sg_ptr, cmd->tvc_sgl_count, -			sg_bidi_ptr, sg_no_bidi); +			NULL, 0, sg_prot_ptr, cmd->tvc_prot_sgl_count);  	if (rc < 0) {  		
transport_send_check_condition_and_sense(se_cmd,  				TCM_LOGICAL_UNIT_COMMUNICATION_FAILURE, 0); @@ -916,12 +968,18 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)  {  	struct tcm_vhost_tpg **vs_tpg;  	struct virtio_scsi_cmd_req v_req; +	struct virtio_scsi_cmd_req_pi v_req_pi;  	struct tcm_vhost_tpg *tpg;  	struct tcm_vhost_cmd *cmd; -	u32 exp_data_len, data_first, data_num, data_direction; +	u64 tag; +	u32 exp_data_len, data_first, data_num, data_direction, prot_first;  	unsigned out, in, i; -	int head, ret; -	u8 target; +	int head, ret, data_niov, prot_niov, prot_bytes; +	size_t req_size; +	u16 lun; +	u8 *target, *lunp, task_attr; +	bool hdr_pi; +	void *req, *cdb;  	mutex_lock(&vq->mutex);  	/* @@ -935,7 +993,7 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)  	vhost_disable_notify(&vs->dev, vq);  	for (;;) { -		head = vhost_get_vq_desc(&vs->dev, vq, vq->iov, +		head = vhost_get_vq_desc(vq, vq->iov,  					ARRAY_SIZE(vq->iov), &out, &in,  					NULL, NULL);  		pr_debug("vhost_get_vq_desc: head: %d, out: %u in: %u\n", @@ -952,7 +1010,7 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)  			break;  		} -/* FIXME: BIDI operation */ +		/* FIXME: BIDI operation */  		if (out == 1 && in == 1) {  			data_direction = DMA_NONE;  			data_first = 0; @@ -982,23 +1040,38 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)  			break;  		} -		if (unlikely(vq->iov[0].iov_len != sizeof(v_req))) { -			vq_err(vq, "Expecting virtio_scsi_cmd_req, got %zu" -				" bytes\n", vq->iov[0].iov_len); +		if (vhost_has_feature(vq, VIRTIO_SCSI_F_T10_PI)) { +			req = &v_req_pi; +			lunp = &v_req_pi.lun[0]; +			target = &v_req_pi.lun[1]; +			req_size = sizeof(v_req_pi); +			hdr_pi = true; +		} else { +			req = &v_req; +			lunp = &v_req.lun[0]; +			target = &v_req.lun[1]; +			req_size = sizeof(v_req); +			hdr_pi = false; +		} + +		if (unlikely(vq->iov[0].iov_len < req_size)) { +			pr_err("Expecting virtio-scsi header: %zu, got %zu\n", +			       req_size, vq->iov[0].iov_len);  			break;  		} -		pr_debug("Calling __copy_from_user: vq->iov[0].iov_base: %p," -			" len: %zu\n", vq->iov[0].iov_base, sizeof(v_req)); -		ret = __copy_from_user(&v_req, vq->iov[0].iov_base, -				sizeof(v_req)); +		ret = memcpy_fromiovecend(req, &vq->iov[0], 0, req_size);  		if (unlikely(ret)) {  			vq_err(vq, "Faulted on virtio_scsi_cmd_req\n");  			break;  		} -		/* Extract the tpgt */ -		target = v_req.lun[1]; -		tpg = ACCESS_ONCE(vs_tpg[target]); +		/* virtio-scsi spec requires byte 0 of the lun to be 1 */ +		if (unlikely(*lunp != 1)) { +			vhost_scsi_send_bad_target(vs, vq, head, out); +			continue; +		} + +		tpg = ACCESS_ONCE(vs_tpg[*target]);  		/* Target does not exist, fail the request */  		if (unlikely(!tpg)) { @@ -1006,17 +1079,79 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)  			continue;  		} +		data_niov = data_num; +		prot_niov = prot_first = prot_bytes = 0; +		/* +		 * Determine if any protection information iovecs are preceeding +		 * the actual data payload, and adjust data_first + data_niov +		 * values accordingly for vhost_scsi_map_iov_to_sgl() below. 
+		 * +		 * Also extract virtio_scsi header bits for vhost_scsi_get_tag() +		 */ +		if (hdr_pi) { +			if (v_req_pi.pi_bytesout) { +				if (data_direction != DMA_TO_DEVICE) { +					vq_err(vq, "Received non zero do_pi_niov" +						", but wrong data_direction\n"); +					goto err_cmd; +				} +				prot_bytes = v_req_pi.pi_bytesout; +			} else if (v_req_pi.pi_bytesin) { +				if (data_direction != DMA_FROM_DEVICE) { +					vq_err(vq, "Received non zero di_pi_niov" +						", but wrong data_direction\n"); +					goto err_cmd; +				} +				prot_bytes = v_req_pi.pi_bytesin; +			} +			if (prot_bytes) { +				int tmp = 0; + +				for (i = 0; i < data_num; i++) { +					tmp += vq->iov[data_first + i].iov_len; +					prot_niov++; +					if (tmp >= prot_bytes) +						break; +				} +				prot_first = data_first; +				data_first += prot_niov; +				data_niov = data_num - prot_niov; +			} +			tag = v_req_pi.tag; +			task_attr = v_req_pi.task_attr; +			cdb = &v_req_pi.cdb[0]; +			lun = ((v_req_pi.lun[2] << 8) | v_req_pi.lun[3]) & 0x3FFF; +		} else { +			tag = v_req.tag; +			task_attr = v_req.task_attr; +			cdb = &v_req.cdb[0]; +			lun = ((v_req.lun[2] << 8) | v_req.lun[3]) & 0x3FFF; +		}  		exp_data_len = 0; -		for (i = 0; i < data_num; i++) +		for (i = 0; i < data_niov; i++)  			exp_data_len += vq->iov[data_first + i].iov_len; +		/* +		 * Check that the recieved CDB size does not exceeded our +		 * hardcoded max for vhost-scsi +		 * +		 * TODO what if cdb was too small for varlen cdb header? +		 */ +		if (unlikely(scsi_command_size(cdb) > TCM_VHOST_MAX_CDB_SIZE)) { +			vq_err(vq, "Received SCSI CDB with command_size: %d that" +				" exceeds SCSI_MAX_VARLEN_CDB_SIZE: %d\n", +				scsi_command_size(cdb), TCM_VHOST_MAX_CDB_SIZE); +			goto err_cmd; +		} -		cmd = vhost_scsi_get_tag(vq, tpg, &v_req, -					 exp_data_len, data_direction); +		cmd = vhost_scsi_get_tag(vq, tpg, cdb, tag, lun, task_attr, +					 exp_data_len + prot_bytes, +					 data_direction);  		if (IS_ERR(cmd)) {  			vq_err(vq, "vhost_scsi_get_tag failed %ld\n",  					PTR_ERR(cmd));  			goto err_cmd;  		} +  		pr_debug("Allocated tv_cmd: %p exp_data_len: %d, data_direction"  			": %d\n", cmd, exp_data_len, data_direction); @@ -1024,40 +1159,28 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)  		cmd->tvc_vq = vq;  		cmd->tvc_resp = vq->iov[out].iov_base; -		/* -		 * Copy in the recieved CDB descriptor into cmd->tvc_cdb -		 * that will be used by tcm_vhost_new_cmd_map() and down into -		 * target_setup_cmd_from_cdb() -		 */ -		memcpy(cmd->tvc_cdb, v_req.cdb, TCM_VHOST_MAX_CDB_SIZE); -		/* -		 * Check that the recieved CDB size does not exceeded our -		 * hardcoded max for tcm_vhost -		 */ -		/* TODO what if cdb was too small for varlen cdb header? 
*/ -		if (unlikely(scsi_command_size(cmd->tvc_cdb) > -					TCM_VHOST_MAX_CDB_SIZE)) { -			vq_err(vq, "Received SCSI CDB with command_size: %d that" -				" exceeds SCSI_MAX_VARLEN_CDB_SIZE: %d\n", -				scsi_command_size(cmd->tvc_cdb), -				TCM_VHOST_MAX_CDB_SIZE); -			goto err_free; -		} -		cmd->tvc_lun = ((v_req.lun[2] << 8) | v_req.lun[3]) & 0x3FFF; -  		pr_debug("vhost_scsi got command opcode: %#02x, lun: %d\n",  			cmd->tvc_cdb[0], cmd->tvc_lun); +		if (prot_niov) { +			ret = vhost_scsi_map_iov_to_prot(cmd, +					&vq->iov[prot_first], prot_niov, +					data_direction == DMA_FROM_DEVICE); +			if (unlikely(ret)) { +				vq_err(vq, "Failed to map iov to" +					" prot_sgl\n"); +				goto err_free; +			} +		}  		if (data_direction != DMA_NONE) {  			ret = vhost_scsi_map_iov_to_sgl(cmd, -					&vq->iov[data_first], data_num, -					data_direction == DMA_TO_DEVICE); +					&vq->iov[data_first], data_niov, +					data_direction == DMA_FROM_DEVICE);  			if (unlikely(ret)) {  				vq_err(vq, "Failed to map iov to sgl\n");  				goto err_free;  			}  		} -  		/*  		 * Save the descriptor from vhost_get_vq_desc() to be used to  		 * complete the virtio-scsi request in TCM callback context via @@ -1239,7 +1362,7 @@ vhost_scsi_set_endpoint(struct vhost_scsi *vs,  			tpg->tv_tpg_vhost_count++;  			tpg->vhost_scsi = vs;  			vs_tpg[tpg->tport_tpgt] = tpg; -			smp_mb__after_atomic_inc(); +			smp_mb__after_atomic();  			match = true;  		}  		mutex_unlock(&tpg->tv_tpg_mutex); @@ -1357,6 +1480,9 @@ err_dev:  static int vhost_scsi_set_features(struct vhost_scsi *vs, u64 features)  { +	struct vhost_virtqueue *vq; +	int i; +  	if (features & ~VHOST_SCSI_FEATURES)  		return -EOPNOTSUPP; @@ -1366,21 +1492,17 @@ static int vhost_scsi_set_features(struct vhost_scsi *vs, u64 features)  		mutex_unlock(&vs->dev.mutex);  		return -EFAULT;  	} -	vs->dev.acked_features = features; -	smp_wmb(); -	vhost_scsi_flush(vs); + +	for (i = 0; i < VHOST_SCSI_MAX_VQ; i++) { +		vq = &vs->vqs[i].vq; +		mutex_lock(&vq->mutex); +		vq->acked_features = features; +		mutex_unlock(&vq->mutex); +	}  	mutex_unlock(&vs->dev.mutex);  	return 0;  } -static void vhost_scsi_free(struct vhost_scsi *vs) -{ -	if (is_vmalloc_addr(vs)) -		vfree(vs); -	else -		kfree(vs); -} -  static int vhost_scsi_open(struct inode *inode, struct file *f)  {  	struct vhost_scsi *vs; @@ -1412,20 +1534,15 @@ static int vhost_scsi_open(struct inode *inode, struct file *f)  		vqs[i] = &vs->vqs[i].vq;  		vs->vqs[i].vq.handle_kick = vhost_scsi_handle_kick;  	} -	r = vhost_dev_init(&vs->dev, vqs, VHOST_SCSI_MAX_VQ); +	vhost_dev_init(&vs->dev, vqs, VHOST_SCSI_MAX_VQ);  	tcm_vhost_init_inflight(vs, NULL); -	if (r < 0) -		goto err_init; -  	f->private_data = vs;  	return 0; -err_init: -	kfree(vqs);  err_vqs: -	vhost_scsi_free(vs); +	kvfree(vs);  err_vs:  	return r;  } @@ -1444,7 +1561,7 @@ static int vhost_scsi_release(struct inode *inode, struct file *f)  	/* Jobs can re-queue themselves in evt kick handler. Do extra flush. 
*/  	vhost_scsi_flush(vs);  	kfree(vs->dev.vqs); -	vhost_scsi_free(vs); +	kvfree(vs);  	return 0;  } @@ -1580,10 +1697,6 @@ tcm_vhost_do_plug(struct tcm_vhost_tpg *tpg,  		return;  	mutex_lock(&vs->dev.mutex); -	if (!vhost_has_feature(&vs->dev, VIRTIO_SCSI_F_HOTPLUG)) { -		mutex_unlock(&vs->dev.mutex); -		return; -	}  	if (plug)  		reason = VIRTIO_SCSI_EVT_RESET_RESCAN; @@ -1592,8 +1705,9 @@ tcm_vhost_do_plug(struct tcm_vhost_tpg *tpg,  	vq = &vs->vqs[VHOST_SCSI_VQ_EVT].vq;  	mutex_lock(&vq->mutex); -	tcm_vhost_send_evt(vs, tpg, lun, -			VIRTIO_SCSI_T_TRANSPORT_RESET, reason); +	if (vhost_has_feature(vq, VIRTIO_SCSI_F_HOTPLUG)) +		tcm_vhost_send_evt(vs, tpg, lun, +				   VIRTIO_SCSI_T_TRANSPORT_RESET, reason);  	mutex_unlock(&vq->mutex);  	mutex_unlock(&vs->dev.mutex);  } @@ -1701,6 +1815,7 @@ static void tcm_vhost_free_cmd_map_res(struct tcm_vhost_nexus *nexus,  		tv_cmd = &((struct tcm_vhost_cmd *)se_sess->sess_cmd_map)[i];  		kfree(tv_cmd->tvc_sgl); +		kfree(tv_cmd->tvc_prot_sgl);  		kfree(tv_cmd->tvc_upages);  	}  } @@ -1734,7 +1849,8 @@ static int tcm_vhost_make_nexus(struct tcm_vhost_tpg *tpg,  	 */  	tv_nexus->tvn_se_sess = transport_init_session_tags(  					TCM_VHOST_DEFAULT_TAGS, -					sizeof(struct tcm_vhost_cmd)); +					sizeof(struct tcm_vhost_cmd), +					TARGET_PROT_DIN_PASS | TARGET_PROT_DOUT_PASS);  	if (IS_ERR(tv_nexus->tvn_se_sess)) {  		mutex_unlock(&tpg->tv_tpg_mutex);  		kfree(tv_nexus); @@ -1753,12 +1869,20 @@ static int tcm_vhost_make_nexus(struct tcm_vhost_tpg *tpg,  		}  		tv_cmd->tvc_upages = kzalloc(sizeof(struct page *) * -					TCM_VHOST_PREALLOC_PAGES, GFP_KERNEL); +					TCM_VHOST_PREALLOC_UPAGES, GFP_KERNEL);  		if (!tv_cmd->tvc_upages) {  			mutex_unlock(&tpg->tv_tpg_mutex);  			pr_err("Unable to allocate tv_cmd->tvc_upages\n");  			goto out;  		} + +		tv_cmd->tvc_prot_sgl = kzalloc(sizeof(struct scatterlist) * +					TCM_VHOST_PREALLOC_PROT_SGLS, GFP_KERNEL); +		if (!tv_cmd->tvc_prot_sgl) { +			mutex_unlock(&tpg->tv_tpg_mutex); +			pr_err("Unable to allocate tv_cmd->tvc_prot_sgl\n"); +			goto out; +		}  	}  	/*  	 * Since we are running in 'demo mode' this call with generate a @@ -2125,6 +2249,7 @@ static struct target_core_fabric_ops tcm_vhost_ops = {  	.queue_data_in			= tcm_vhost_queue_data_in,  	.queue_status			= tcm_vhost_queue_status,  	.queue_tm_rsp			= tcm_vhost_queue_tm_rsp, +	.aborted_task			= tcm_vhost_aborted_task,  	/*  	 * Setup callers for generic logic in target_core_fabric_configfs.c  	 */ @@ -2163,15 +2288,15 @@ static int tcm_vhost_register_configfs(void)  	/*  	 * Setup default attribute lists for various fabric->tf_cit_tmpl  	 */ -	TF_CIT_TMPL(fabric)->tfc_wwn_cit.ct_attrs = tcm_vhost_wwn_attrs; -	TF_CIT_TMPL(fabric)->tfc_tpg_base_cit.ct_attrs = tcm_vhost_tpg_attrs; -	TF_CIT_TMPL(fabric)->tfc_tpg_attrib_cit.ct_attrs = NULL; -	TF_CIT_TMPL(fabric)->tfc_tpg_param_cit.ct_attrs = NULL; -	TF_CIT_TMPL(fabric)->tfc_tpg_np_base_cit.ct_attrs = NULL; -	TF_CIT_TMPL(fabric)->tfc_tpg_nacl_base_cit.ct_attrs = NULL; -	TF_CIT_TMPL(fabric)->tfc_tpg_nacl_attrib_cit.ct_attrs = NULL; -	TF_CIT_TMPL(fabric)->tfc_tpg_nacl_auth_cit.ct_attrs = NULL; -	TF_CIT_TMPL(fabric)->tfc_tpg_nacl_param_cit.ct_attrs = NULL; +	fabric->tf_cit_tmpl.tfc_wwn_cit.ct_attrs = tcm_vhost_wwn_attrs; +	fabric->tf_cit_tmpl.tfc_tpg_base_cit.ct_attrs = tcm_vhost_tpg_attrs; +	fabric->tf_cit_tmpl.tfc_tpg_attrib_cit.ct_attrs = NULL; +	fabric->tf_cit_tmpl.tfc_tpg_param_cit.ct_attrs = NULL; +	fabric->tf_cit_tmpl.tfc_tpg_np_base_cit.ct_attrs = NULL; +	
fabric->tf_cit_tmpl.tfc_tpg_nacl_base_cit.ct_attrs = NULL; +	fabric->tf_cit_tmpl.tfc_tpg_nacl_attrib_cit.ct_attrs = NULL; +	fabric->tf_cit_tmpl.tfc_tpg_nacl_auth_cit.ct_attrs = NULL; +	fabric->tf_cit_tmpl.tfc_tpg_nacl_param_cit.ct_attrs = NULL;  	/*  	 * Register the fabric for use within TCM  	 */ diff --git a/drivers/vhost/test.c b/drivers/vhost/test.c index 339eae85859..d9c501eaa6c 100644 --- a/drivers/vhost/test.c +++ b/drivers/vhost/test.c @@ -53,7 +53,7 @@ static void handle_vq(struct vhost_test *n)  	vhost_disable_notify(&n->dev, vq);  	for (;;) { -		head = vhost_get_vq_desc(&n->dev, vq, vq->iov, +		head = vhost_get_vq_desc(vq, vq->iov,  					 ARRAY_SIZE(vq->iov),  					 &out, &in,  					 NULL, NULL); @@ -104,7 +104,6 @@ static int vhost_test_open(struct inode *inode, struct file *f)  	struct vhost_test *n = kmalloc(sizeof *n, GFP_KERNEL);  	struct vhost_dev *dev;  	struct vhost_virtqueue **vqs; -	int r;  	if (!n)  		return -ENOMEM; @@ -117,12 +116,7 @@ static int vhost_test_open(struct inode *inode, struct file *f)  	dev = &n->dev;  	vqs[VHOST_TEST_VQ] = &n->vqs[VHOST_TEST_VQ];  	n->vqs[VHOST_TEST_VQ].handle_kick = handle_vq_kick; -	r = vhost_dev_init(dev, vqs, VHOST_TEST_VQ_MAX); -	if (r < 0) { -		kfree(vqs); -		kfree(n); -		return r; -	} +	vhost_dev_init(dev, vqs, VHOST_TEST_VQ_MAX);  	f->private_data = n; @@ -247,15 +241,18 @@ done:  static int vhost_test_set_features(struct vhost_test *n, u64 features)  { +	struct vhost_virtqueue *vq; +  	mutex_lock(&n->dev.mutex);  	if ((features & (1 << VHOST_F_LOG_ALL)) &&  	    !vhost_log_access_ok(&n->dev)) {  		mutex_unlock(&n->dev.mutex);  		return -EFAULT;  	} -	n->dev.acked_features = features; -	smp_wmb(); -	vhost_test_flush(n); +	vq = &n->vqs[VHOST_TEST_VQ]; +	mutex_lock(&vq->mutex); +	vq->acked_features = features; +	mutex_unlock(&vq->mutex);  	mutex_unlock(&n->dev.mutex);  	return 0;  } diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c index 69068e0d8f3..c90f4374442 100644 --- a/drivers/vhost/vhost.c +++ b/drivers/vhost/vhost.c @@ -18,7 +18,6 @@  #include <linux/mmu_context.h>  #include <linux/miscdevice.h>  #include <linux/mutex.h> -#include <linux/rcupdate.h>  #include <linux/poll.h>  #include <linux/file.h>  #include <linux/highmem.h> @@ -191,6 +190,7 @@ static void vhost_vq_reset(struct vhost_dev *dev,  	vq->log_used = false;  	vq->log_addr = -1ull;  	vq->private_data = NULL; +	vq->acked_features = 0;  	vq->log_base = NULL;  	vq->error_ctx = NULL;  	vq->error = NULL; @@ -198,6 +198,7 @@ static void vhost_vq_reset(struct vhost_dev *dev,  	vq->call_ctx = NULL;  	vq->call = NULL;  	vq->log_ctx = NULL; +	vq->memory = NULL;  }  static int vhost_worker(void *data) @@ -290,7 +291,7 @@ static void vhost_dev_free_iovecs(struct vhost_dev *dev)  		vhost_vq_free_iovecs(dev->vqs[i]);  } -long vhost_dev_init(struct vhost_dev *dev, +void vhost_dev_init(struct vhost_dev *dev,  		    struct vhost_virtqueue **vqs, int nvqs)  {  	struct vhost_virtqueue *vq; @@ -319,8 +320,6 @@ long vhost_dev_init(struct vhost_dev *dev,  			vhost_poll_init(&vq->poll, vq->handle_kick,  					POLLIN, dev);  	} - -	return 0;  }  EXPORT_SYMBOL_GPL(vhost_dev_init); @@ -417,11 +416,18 @@ EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare);  /* Caller should have device mutex */  void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_memory *memory)  { +	int i; +  	vhost_dev_cleanup(dev, true);  	/* Restore memory to default empty mapping. 
*/  	memory->nregions = 0; -	RCU_INIT_POINTER(dev->memory, memory); +	dev->memory = memory; +	/* We don't need VQ locks below since vhost_dev_cleanup makes sure +	 * VQs aren't running. +	 */ +	for (i = 0; i < dev->nvqs; ++i) +		dev->vqs[i]->memory = memory;  }  EXPORT_SYMBOL_GPL(vhost_dev_reset_owner); @@ -464,10 +470,8 @@ void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)  		fput(dev->log_file);  	dev->log_file = NULL;  	/* No one will access memory at this point */ -	kfree(rcu_dereference_protected(dev->memory, -					locked == -						lockdep_is_held(&dev->mutex))); -	RCU_INIT_POINTER(dev->memory, NULL); +	kfree(dev->memory); +	dev->memory = NULL;  	WARN_ON(!list_empty(&dev->work_list));  	if (dev->worker) {  		kthread_stop(dev->worker); @@ -526,11 +530,13 @@ static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem,  	for (i = 0; i < d->nvqs; ++i) {  		int ok; +		bool log; +  		mutex_lock(&d->vqs[i]->mutex); +		log = log_all || vhost_has_feature(d->vqs[i], VHOST_F_LOG_ALL);  		/* If ring is inactive, will check when it's enabled. */  		if (d->vqs[i]->private_data) -			ok = vq_memory_access_ok(d->vqs[i]->log_base, mem, -						 log_all); +			ok = vq_memory_access_ok(d->vqs[i]->log_base, mem, log);  		else  			ok = 1;  		mutex_unlock(&d->vqs[i]->mutex); @@ -540,12 +546,12 @@ static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem,  	return 1;  } -static int vq_access_ok(struct vhost_dev *d, unsigned int num, +static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,  			struct vring_desc __user *desc,  			struct vring_avail __user *avail,  			struct vring_used __user *used)  { -	size_t s = vhost_has_feature(d, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0; +	size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;  	return access_ok(VERIFY_READ, desc, num * sizeof *desc) &&  	       access_ok(VERIFY_READ, avail,  			 sizeof *avail + num * sizeof *avail->ring + s) && @@ -557,26 +563,19 @@ static int vq_access_ok(struct vhost_dev *d, unsigned int num,  /* Caller should have device mutex but not vq mutex */  int vhost_log_access_ok(struct vhost_dev *dev)  { -	struct vhost_memory *mp; - -	mp = rcu_dereference_protected(dev->memory, -				       lockdep_is_held(&dev->mutex)); -	return memory_access_ok(dev, mp, 1); +	return memory_access_ok(dev, dev->memory, 1);  }  EXPORT_SYMBOL_GPL(vhost_log_access_ok);  /* Verify access for write logging. */  /* Caller should have vq mutex and device mutex */ -static int vq_log_access_ok(struct vhost_dev *d, struct vhost_virtqueue *vq, +static int vq_log_access_ok(struct vhost_virtqueue *vq,  			    void __user *log_base)  { -	struct vhost_memory *mp; -	size_t s = vhost_has_feature(d, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0; +	size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 
2 : 0; -	mp = rcu_dereference_protected(vq->dev->memory, -				       lockdep_is_held(&vq->mutex)); -	return vq_memory_access_ok(log_base, mp, -			    vhost_has_feature(vq->dev, VHOST_F_LOG_ALL)) && +	return vq_memory_access_ok(log_base, vq->memory, +				   vhost_has_feature(vq, VHOST_F_LOG_ALL)) &&  		(!vq->log_used || log_access_ok(log_base, vq->log_addr,  					sizeof *vq->used +  					vq->num * sizeof *vq->used->ring + s)); @@ -586,8 +585,8 @@ static int vq_log_access_ok(struct vhost_dev *d, struct vhost_virtqueue *vq,  /* Caller should have vq mutex and device mutex */  int vhost_vq_access_ok(struct vhost_virtqueue *vq)  { -	return vq_access_ok(vq->dev, vq->num, vq->desc, vq->avail, vq->used) && -		vq_log_access_ok(vq->dev, vq, vq->log_base); +	return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used) && +		vq_log_access_ok(vq, vq->log_base);  }  EXPORT_SYMBOL_GPL(vhost_vq_access_ok); @@ -595,6 +594,7 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)  {  	struct vhost_memory mem, *newmem, *oldmem;  	unsigned long size = offsetof(struct vhost_memory, regions); +	int i;  	if (copy_from_user(&mem, m, size))  		return -EFAULT; @@ -613,15 +613,19 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)  		return -EFAULT;  	} -	if (!memory_access_ok(d, newmem, -			      vhost_has_feature(d, VHOST_F_LOG_ALL))) { +	if (!memory_access_ok(d, newmem, 0)) {  		kfree(newmem);  		return -EFAULT;  	} -	oldmem = rcu_dereference_protected(d->memory, -					   lockdep_is_held(&d->mutex)); -	rcu_assign_pointer(d->memory, newmem); -	synchronize_rcu(); +	oldmem = d->memory; +	d->memory = newmem; + +	/* All memory accesses are done under some VQ mutex. */ +	for (i = 0; i < d->nvqs; ++i) { +		mutex_lock(&d->vqs[i]->mutex); +		d->vqs[i]->memory = newmem; +		mutex_unlock(&d->vqs[i]->mutex); +	}  	kfree(oldmem);  	return 0;  } @@ -720,7 +724,7 @@ long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp)  		 * If it is not, we don't as size might not have been setup.  		 * We will verify when backend is configured. */  		if (vq->private_data) { -			if (!vq_access_ok(d, vq->num, +			if (!vq_access_ok(vq, vq->num,  				(void __user *)(unsigned long)a.desc_user_addr,  				(void __user *)(unsigned long)a.avail_user_addr,  				(void __user *)(unsigned long)a.used_user_addr)) { @@ -860,7 +864,7 @@ long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)  			vq = d->vqs[i];  			mutex_lock(&vq->mutex);  			/* If ring is inactive, will check when it's enabled. 
*/ -			if (vq->private_data && !vq_log_access_ok(d, vq, base)) +			if (vq->private_data && !vq_log_access_ok(vq, base))  				r = -EFAULT;  			else  				vq->log_base = base; @@ -1046,7 +1050,7 @@ int vhost_init_used(struct vhost_virtqueue *vq)  }  EXPORT_SYMBOL_GPL(vhost_init_used); -static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len, +static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,  			  struct iovec iov[], int iov_size)  {  	const struct vhost_memory_region *reg; @@ -1055,9 +1059,7 @@ static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,  	u64 s = 0;  	int ret = 0; -	rcu_read_lock(); - -	mem = rcu_dereference(dev->memory); +	mem = vq->memory;  	while ((u64)len > s) {  		u64 size;  		if (unlikely(ret >= iov_size)) { @@ -1079,7 +1081,6 @@ static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,  		++ret;  	} -	rcu_read_unlock();  	return ret;  } @@ -1104,7 +1105,7 @@ static unsigned next_desc(struct vring_desc *desc)  	return next;  } -static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq, +static int get_indirect(struct vhost_virtqueue *vq,  			struct iovec iov[], unsigned int iov_size,  			unsigned int *out_num, unsigned int *in_num,  			struct vhost_log *log, unsigned int *log_num, @@ -1123,7 +1124,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,  		return -EINVAL;  	} -	ret = translate_desc(dev, indirect->addr, indirect->len, vq->indirect, +	ret = translate_desc(vq, indirect->addr, indirect->len, vq->indirect,  			     UIO_MAXIOV);  	if (unlikely(ret < 0)) {  		vq_err(vq, "Translation failure %d in indirect.\n", ret); @@ -1163,7 +1164,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,  			return -EINVAL;  		} -		ret = translate_desc(dev, desc.addr, desc.len, iov + iov_count, +		ret = translate_desc(vq, desc.addr, desc.len, iov + iov_count,  				     iov_size - iov_count);  		if (unlikely(ret < 0)) {  			vq_err(vq, "Translation failure %d indirect idx %d\n", @@ -1200,7 +1201,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,   * This function returns the descriptor number found, or vq->num (which is   * never a valid descriptor number) if none was found.  A negative code is   * returned on error. */ -int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq, +int vhost_get_vq_desc(struct vhost_virtqueue *vq,  		      struct iovec iov[], unsigned int iov_size,  		      unsigned int *out_num, unsigned int *in_num,  		      struct vhost_log *log, unsigned int *log_num) @@ -1274,7 +1275,7 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,  			return -EFAULT;  		}  		if (desc.flags & VRING_DESC_F_INDIRECT) { -			ret = get_indirect(dev, vq, iov, iov_size, +			ret = get_indirect(vq, iov, iov_size,  					   out_num, in_num,  					   log, log_num, &desc);  			if (unlikely(ret < 0)) { @@ -1285,7 +1286,7 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,  			continue;  		} -		ret = translate_desc(dev, desc.addr, desc.len, iov + iov_count, +		ret = translate_desc(vq, desc.addr, desc.len, iov + iov_count,  				     iov_size - iov_count);  		if (unlikely(ret < 0)) {  			vq_err(vq, "Translation failure %d descriptor idx %d\n", @@ -1428,11 +1429,11 @@ static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)  	 * interrupts. 
*/  	smp_mb(); -	if (vhost_has_feature(dev, VIRTIO_F_NOTIFY_ON_EMPTY) && +	if (vhost_has_feature(vq, VIRTIO_F_NOTIFY_ON_EMPTY) &&  	    unlikely(vq->avail_idx == vq->last_avail_idx))  		return true; -	if (!vhost_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) { +	if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {  		__u16 flags;  		if (__get_user(flags, &vq->avail->flags)) {  			vq_err(vq, "Failed to get flags"); @@ -1493,7 +1494,7 @@ bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)  	if (!(vq->used_flags & VRING_USED_F_NO_NOTIFY))  		return false;  	vq->used_flags &= ~VRING_USED_F_NO_NOTIFY; -	if (!vhost_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) { +	if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {  		r = vhost_update_used_flags(vq);  		if (r) {  			vq_err(vq, "Failed to enable notification at %p: %d\n", @@ -1530,7 +1531,7 @@ void vhost_disable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)  	if (vq->used_flags & VRING_USED_F_NO_NOTIFY)  		return;  	vq->used_flags |= VRING_USED_F_NO_NOTIFY; -	if (!vhost_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) { +	if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {  		r = vhost_update_used_flags(vq);  		if (r)  			vq_err(vq, "Failed to enable notification at %p: %d\n", diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h index 4465ed5f316..3eda654b8f5 100644 --- a/drivers/vhost/vhost.h +++ b/drivers/vhost/vhost.h @@ -104,20 +104,18 @@ struct vhost_virtqueue {  	struct iovec *indirect;  	struct vring_used_elem *heads;  	/* Protected by virtqueue mutex. */ +	struct vhost_memory *memory;  	void *private_data; +	unsigned acked_features;  	/* Log write descriptors */  	void __user *log_base;  	struct vhost_log *log;  };  struct vhost_dev { -	/* Readers use RCU to access memory table pointer -	 * log base pointer and features. -	 * Writers use mutex below.*/ -	struct vhost_memory __rcu *memory; +	struct vhost_memory *memory;  	struct mm_struct *mm;  	struct mutex mutex; -	unsigned acked_features;  	struct vhost_virtqueue **vqs;  	int nvqs;  	struct file *log_file; @@ -127,7 +125,7 @@ struct vhost_dev {  	struct task_struct *worker;  }; -long vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue **vqs, int nvqs); +void vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue **vqs, int nvqs);  long vhost_dev_set_owner(struct vhost_dev *dev);  bool vhost_dev_has_owner(struct vhost_dev *dev);  long vhost_dev_check_owner(struct vhost_dev *); @@ -140,7 +138,7 @@ long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp);  int vhost_vq_access_ok(struct vhost_virtqueue *vq);  int vhost_log_access_ok(struct vhost_dev *); -int vhost_get_vq_desc(struct vhost_dev *, struct vhost_virtqueue *, +int vhost_get_vq_desc(struct vhost_virtqueue *,  		      struct iovec iov[], unsigned int iov_count,  		      unsigned int *out_num, unsigned int *in_num,  		      struct vhost_log *log, unsigned int *log_num); @@ -174,13 +172,8 @@ enum {  			 (1ULL << VHOST_F_LOG_ALL),  }; -static inline int vhost_has_feature(struct vhost_dev *dev, int bit) +static inline int vhost_has_feature(struct vhost_virtqueue *vq, int bit)  { -	unsigned acked_features; - -	/* TODO: check that we are running from vhost_worker or dev mutex is -	 * held? */ -	acked_features = rcu_dereference_index_check(dev->acked_features, 1); -	return acked_features & (1 << bit); +	return vq->acked_features & (1 << bit);  }  #endif  | 
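
A note on the refcounting change in net.c: the patch drops the kref from struct vhost_net_ubuf_ref in favour of a bare atomic_t whose value encodes the state directly (1 means no outstanding ubufs, greater than 1 means outstanding ubufs, 0 means released), so the zerocopy completion path can read the post-decrement count with a single atomic_sub_return(). Below is a minimal user-space sketch of the put / put-and-wait pattern; it uses C11 atomics and a spin loop in place of the kernel's atomic_t and wait queue, and the struct and helper names are invented for illustration.

```c
#include <stdatomic.h>
#include <stdio.h>

struct ubuf_ref {
	atomic_int refcount;	/* 1 = no outstanding ubufs, >1 = outstanding, 0 = released */
};

/* Drop one reference; only the transition to 0 wakes up waiters. */
static int ubuf_put(struct ubuf_ref *u)
{
	int r = atomic_fetch_sub(&u->refcount, 1) - 1;	/* post-decrement value */

	if (r == 0)
		printf("refcount reached 0: wake up the waiter\n");
	return r;
}

/* Drop the "idle" reference and wait for all outstanding ubufs to complete. */
static void ubuf_put_and_wait(struct ubuf_ref *u)
{
	ubuf_put(u);
	while (atomic_load(&u->refcount) != 0)
		;	/* the kernel sleeps on ubufs->wait here instead of spinning */
}

int main(void)
{
	struct ubuf_ref u = { .refcount = 1 };	/* freshly allocated, idle */

	atomic_fetch_add(&u.refcount, 1);	/* one zerocopy ubuf submitted */
	ubuf_put(&u);				/* its completion callback fires */
	ubuf_put_and_wait(&u);			/* device teardown */
	return 0;
}
```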
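vhost_net_open() now attempts kmalloc() with __GFP_NOWARN | __GFP_REPEAT and falls back to vmalloc() when a physically contiguous buffer is not available, with kvfree() used on the release and error paths so callers need not track which allocator succeeded. A rough user-space approximation of that control flow, using invented stand-ins (try_kmalloc, try_vmalloc, do_kvfree) for the kernel allocators:

```c
#include <stdio.h>
#include <stdlib.h>
#include <stdbool.h>

struct big_state { char payload[1 << 20]; };	/* stands in for struct vhost_net */

/* Pretend the physically contiguous allocation failed, to exercise the fallback. */
static void *try_kmalloc(size_t sz) { (void)sz; return NULL; }
static void *try_vmalloc(size_t sz) { return malloc(sz); }
static void  do_kvfree(void *p)     { free(p); }	/* kvfree() picks kfree()/vfree() by address */

int main(void)
{
	bool used_vmalloc = false;
	struct big_state *n = try_kmalloc(sizeof(*n));	/* kmalloc(..., GFP_KERNEL | __GFP_NOWARN | __GFP_REPEAT) */

	if (!n) {
		n = try_vmalloc(sizeof(*n));		/* fall back to a virtually contiguous buffer */
		used_vmalloc = true;
		if (!n)
			return 1;
	}
	printf("allocated %zu bytes via %s\n", sizeof(*n),
	       used_vmalloc ? "vmalloc" : "kmalloc");
	do_kvfree(n);	/* one free helper on every path, as in the patched release/error paths */
	return 0;
}
```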
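get_rx_bufs() also gains an overrun check: if the gathered descriptors cannot hold the pending packet (datalen is still positive), it returns UIO_MAXIOV + 1, a value a real headcount can never reach, and handle_rx() reacts by receiving the packet with MSG_TRUNC into a single iovec and discarding it instead of stalling the queue. A small sketch of that sentinel convention, with made-up sizes and a simplified stand-in for get_rx_bufs():

```c
#include <stdio.h>

#define UIO_MAXIOV 1024		/* kernel limit on iovecs per message */

/* Simplified stand-in for get_rx_bufs(): a value above UIO_MAXIOV is not a
 * headcount but a signal that the available buffers cannot hold the packet. */
static int get_bufs(size_t packet_len, size_t buffer_space, int headcount)
{
	if (packet_len > buffer_space)
		return UIO_MAXIOV + 1;	/* the overrun sentinel added by the patch */
	return headcount;
}

int main(void)
{
	int headcount = get_bufs(65536, 4096, 3);

	if (headcount > UIO_MAXIOV)
		/* handle_rx() recvmsg()s the packet with MSG_TRUNC and drops it */
		printf("overrun: truncate and discard this packet\n");
	else
		printf("received into %d buffers\n", headcount);
	return 0;
}
```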
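For the new VIRTIO_SCSI_F_T10_PI path, vhost_scsi_handle_vq() reads the larger virtio_scsi_cmd_req_pi header, walks the leading iovecs until it has covered pi_bytesout/pi_bytesin bytes of protection data, shifts data_first past them, and decodes the LUN from bytes 2 and 3 of the 8-byte virtio LUN field. The user-space sketch below mirrors the iovec-counting loop and the LUN decode from the diff; the iovec lengths and LUN bytes are made-up example values.

```c
#include <stdio.h>
#include <sys/uio.h>

/* Decode the virtio-scsi LUN field: byte 0 must be 1, bytes 2-3 hold the LUN. */
static unsigned decode_lun(const unsigned char lun[8])
{
	return ((lun[2] << 8) | lun[3]) & 0x3FFF;
}

/* Count how many leading iovecs are consumed by prot_bytes of PI data,
 * mirroring the loop the patch adds to vhost_scsi_handle_vq(). */
static int count_prot_iovs(const struct iovec *iov, int niov, size_t prot_bytes)
{
	size_t total = 0;
	int n = 0;

	for (int i = 0; i < niov; i++) {
		total += iov[i].iov_len;
		n++;
		if (total >= prot_bytes)
			break;
	}
	return n;
}

int main(void)
{
	unsigned char lun[8] = { 1, 0, 0x01, 0x23, 0, 0, 0, 0 };
	struct iovec iov[3] = {
		{ .iov_base = NULL, .iov_len = 512 },	/* protection information */
		{ .iov_base = NULL, .iov_len = 4096 },	/* data payload */
		{ .iov_base = NULL, .iov_len = 4096 },	/* data payload */
	};
	int prot_niov = count_prot_iovs(iov, 3, 512);

	printf("lun=%u prot_niov=%d data starts at iov[%d]\n",
	       decode_lun(lun), prot_niov, prot_niov);
	return 0;
}
```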
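Finally, the series moves acked_features (and the memory-table pointer) from vhost_dev, where readers needed RCU, into each vhost_virtqueue, where the per-queue mutex already serializes readers against the ioctl paths; vhost_has_feature() becomes a plain bit test on vq->acked_features, and translate_desc() reads vq->memory without rcu_read_lock(). A compilable sketch of that test follows, using the usual values of the two feature bits that appear in the diff (the struct name and bit values here are assumptions for illustration, not the kernel definitions).

```c
#include <stdio.h>

#define VHOST_F_LOG_ALL		26	/* feature bits as commonly defined; treat as examples */
#define VIRTIO_RING_F_EVENT_IDX	29

struct vq_like {
	unsigned acked_features;	/* written under the vq mutex in the real driver */
};

/* Mirrors the new vhost_has_feature(): a simple bit test, no RCU dereference. */
static int has_feature(const struct vq_like *vq, int bit)
{
	return vq->acked_features & (1u << bit);
}

int main(void)
{
	struct vq_like vq = { .acked_features = 1u << VIRTIO_RING_F_EVENT_IDX };

	printf("EVENT_IDX: %d, LOG_ALL: %d\n",
	       !!has_feature(&vq, VIRTIO_RING_F_EVENT_IDX),
	       !!has_feature(&vq, VHOST_F_LOG_ALL));
	return 0;
}
```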
