aboutsummaryrefslogtreecommitdiff
path: root/tools/lguest/lguest.c
diff options
context:
space:
mode:
Diffstat (limited to 'tools/lguest/lguest.c')
-rw-r--r--tools/lguest/lguest.c123
1 files changed, 65 insertions, 58 deletions
diff --git a/tools/lguest/lguest.c b/tools/lguest/lguest.c
index f759f4f097c..32cf2ce15d6 100644
--- a/tools/lguest/lguest.c
+++ b/tools/lguest/lguest.c
@@ -42,14 +42,10 @@
#include <pwd.h>
#include <grp.h>
-#include <linux/virtio_config.h>
-#include <linux/virtio_net.h>
-#include <linux/virtio_blk.h>
-#include <linux/virtio_console.h>
-#include <linux/virtio_rng.h>
-#include <linux/virtio_ring.h>
-#include <asm/bootparam.h>
-#include "../../include/linux/lguest_launcher.h"
+#ifndef VIRTIO_F_ANY_LAYOUT
+#define VIRTIO_F_ANY_LAYOUT 27
+#endif
+
/*L:110
* We can ignore the 43 include files we need for this program, but I do want
* to draw attention to the use of kernel-style types.
@@ -65,6 +61,15 @@ typedef uint16_t u16;
typedef uint8_t u8;
/*:*/
+#include <linux/virtio_config.h>
+#include <linux/virtio_net.h>
+#include <linux/virtio_blk.h>
+#include <linux/virtio_console.h>
+#include <linux/virtio_rng.h>
+#include <linux/virtio_ring.h>
+#include <asm/bootparam.h>
+#include "../../include/linux/lguest_launcher.h"
+
#define BRIDGE_PFX "bridge:"
#ifndef SIOCBRADDIF
#define SIOCBRADDIF 0x89a2 /* add interface to bridge */
@@ -177,30 +182,8 @@ static struct termios orig_term;
* in precise order.
*/
#define wmb() __asm__ __volatile__("" : : : "memory")
-#define mb() __asm__ __volatile__("" : : : "memory")
-
-/*
- * Convert an iovec element to the given type.
- *
- * This is a fairly ugly trick: we need to know the size of the type and
- * alignment requirement to check the pointer is kosher. It's also nice to
- * have the name of the type in case we report failure.
- *
- * Typing those three things all the time is cumbersome and error prone, so we
- * have a macro which sets them all up and passes to the real function.
- */
-#define convert(iov, type) \
- ((type *)_convert((iov), sizeof(type), __alignof__(type), #type))
-
-static void *_convert(struct iovec *iov, size_t size, size_t align,
- const char *name)
-{
- if (iov->iov_len != size)
- errx(1, "Bad iovec size %zu for %s", iov->iov_len, name);
- if ((unsigned long)iov->iov_base % align != 0)
- errx(1, "Bad alignment %p for %s", iov->iov_base, name);
- return iov->iov_base;
-}
+#define rmb() __asm__ __volatile__("lock; addl $0,0(%%esp)" : : : "memory")
+#define mb() __asm__ __volatile__("lock; addl $0,0(%%esp)" : : : "memory")
/* Wrapper for the last available index. Makes it easier to change. */
#define lg_last_avail(vq) ((vq)->last_avail_idx)
@@ -228,7 +211,8 @@ static bool iov_empty(const struct iovec iov[], unsigned int num_iov)
}
/* Take len bytes from the front of this iovec. */
-static void iov_consume(struct iovec iov[], unsigned num_iov, unsigned len)
+static void iov_consume(struct iovec iov[], unsigned num_iov,
+ void *dest, unsigned len)
{
unsigned int i;
@@ -236,11 +220,16 @@ static void iov_consume(struct iovec iov[], unsigned num_iov, unsigned len)
unsigned int used;
used = iov[i].iov_len < len ? iov[i].iov_len : len;
+ if (dest) {
+ memcpy(dest, iov[i].iov_base, used);
+ dest += used;
+ }
iov[i].iov_base += used;
iov[i].iov_len -= used;
len -= used;
}
- assert(len == 0);
+ if (len != 0)
+ errx(1, "iovec too short!");
}
/* The device virtqueue descriptors are followed by feature bitmasks. */
@@ -693,6 +682,12 @@ static unsigned wait_for_vq_desc(struct virtqueue *vq,
errx(1, "Guest moved used index from %u to %u",
last_avail, vq->vring.avail->idx);
+ /*
+ * Make sure we read the descriptor number *after* we read the ring
+ * update; don't let the cpu or compiler change the order.
+ */
+ rmb();
+
/*
* Grab the next descriptor number they're advertising, and increment
* the index we've seen.
@@ -712,6 +707,12 @@ static unsigned wait_for_vq_desc(struct virtqueue *vq,
i = head;
/*
+ * We have to read the descriptor after we read the descriptor number,
+ * but there's a data dependency there so the CPU shouldn't reorder
+ * that: no rmb() required.
+ */
+
+ /*
* If this is an indirect entry, then this buffer contains a descriptor
* table which we handle as if it's any normal descriptor chain.
*/
@@ -864,7 +865,7 @@ static void console_output(struct virtqueue *vq)
warn("Write to stdout gave %i (%d)", len, errno);
break;
}
- iov_consume(iov, out, len);
+ iov_consume(iov, out, NULL, len);
}
/*
@@ -1299,6 +1300,7 @@ static struct device *new_device(const char *name, u16 type)
dev->feature_len = 0;
dev->num_vq = 0;
dev->running = false;
+ dev->next = NULL;
/*
* Append to device list. Prepending to a single-linked list is
@@ -1546,6 +1548,8 @@ static void setup_tun_net(char *arg)
add_feature(dev, VIRTIO_NET_F_HOST_ECN);
/* We handle indirect ring entries */
add_feature(dev, VIRTIO_RING_F_INDIRECT_DESC);
+ /* We're compliant with the damn spec. */
+ add_feature(dev, VIRTIO_F_ANY_LAYOUT);
set_config(dev, sizeof(conf), &conf);
/* We don't need the socket any more; setup is done. */
@@ -1590,9 +1594,9 @@ static void blk_request(struct virtqueue *vq)
{
struct vblk_info *vblk = vq->dev->priv;
unsigned int head, out_num, in_num, wlen;
- int ret;
+ int ret, i;
u8 *in;
- struct virtio_blk_outhdr *out;
+ struct virtio_blk_outhdr out;
struct iovec iov[vq->vring.num];
off64_t off;
@@ -1602,32 +1606,36 @@ static void blk_request(struct virtqueue *vq)
*/
head = wait_for_vq_desc(vq, iov, &out_num, &in_num);
- /*
- * Every block request should contain at least one output buffer
- * (detailing the location on disk and the type of request) and one
- * input buffer (to hold the result).
- */
- if (out_num == 0 || in_num == 0)
- errx(1, "Bad virtblk cmd %u out=%u in=%u",
- head, out_num, in_num);
+ /* Copy the output header from the front of the iov (adjusts iov) */
+ iov_consume(iov, out_num, &out, sizeof(out));
+
+ /* Find and trim end of iov input array, for our status byte. */
+ in = NULL;
+ for (i = out_num + in_num - 1; i >= out_num; i--) {
+ if (iov[i].iov_len > 0) {
+ in = iov[i].iov_base + iov[i].iov_len - 1;
+ iov[i].iov_len--;
+ break;
+ }
+ }
+ if (!in)
+ errx(1, "Bad virtblk cmd with no room for status");
- out = convert(&iov[0], struct virtio_blk_outhdr);
- in = convert(&iov[out_num+in_num-1], u8);
/*
* For historical reasons, block operations are expressed in 512 byte
* "sectors".
*/
- off = out->sector * 512;
+ off = out.sector * 512;
/*
* In general the virtio block driver is allowed to try SCSI commands.
* It'd be nice if we supported eject, for example, but we don't.
*/
- if (out->type & VIRTIO_BLK_T_SCSI_CMD) {
+ if (out.type & VIRTIO_BLK_T_SCSI_CMD) {
fprintf(stderr, "Scsi commands unsupported\n");
*in = VIRTIO_BLK_S_UNSUPP;
wlen = sizeof(*in);
- } else if (out->type & VIRTIO_BLK_T_OUT) {
+ } else if (out.type & VIRTIO_BLK_T_OUT) {
/*
* Write
*
@@ -1635,10 +1643,10 @@ static void blk_request(struct virtqueue *vq)
* if they try to write past end.
*/
if (lseek64(vblk->fd, off, SEEK_SET) != off)
- err(1, "Bad seek to sector %llu", out->sector);
+ err(1, "Bad seek to sector %llu", out.sector);
- ret = writev(vblk->fd, iov+1, out_num-1);
- verbose("WRITE to sector %llu: %i\n", out->sector, ret);
+ ret = writev(vblk->fd, iov, out_num);
+ verbose("WRITE to sector %llu: %i\n", out.sector, ret);
/*
* Grr... Now we know how long the descriptor they sent was, we
@@ -1654,7 +1662,7 @@ static void blk_request(struct virtqueue *vq)
wlen = sizeof(*in);
*in = (ret >= 0 ? VIRTIO_BLK_S_OK : VIRTIO_BLK_S_IOERR);
- } else if (out->type & VIRTIO_BLK_T_FLUSH) {
+ } else if (out.type & VIRTIO_BLK_T_FLUSH) {
/* Flush */
ret = fdatasync(vblk->fd);
verbose("FLUSH fdatasync: %i\n", ret);
@@ -1668,10 +1676,9 @@ static void blk_request(struct virtqueue *vq)
* if they try to read past end.
*/
if (lseek64(vblk->fd, off, SEEK_SET) != off)
- err(1, "Bad seek to sector %llu", out->sector);
+ err(1, "Bad seek to sector %llu", out.sector);
- ret = readv(vblk->fd, iov+1, in_num-1);
- verbose("READ from sector %llu: %i\n", out->sector, ret);
+ ret = readv(vblk->fd, iov + out_num, in_num);
if (ret >= 0) {
wlen = sizeof(*in) + ret;
*in = VIRTIO_BLK_S_OK;
@@ -1757,7 +1764,7 @@ static void rng_input(struct virtqueue *vq)
len = readv(rng_info->rfd, iov, in_num);
if (len <= 0)
err(1, "Read from /dev/random gave %i", len);
- iov_consume(iov, in_num, len);
+ iov_consume(iov, in_num, NULL, len);
totlen += len;
}