diff options
Diffstat (limited to 'virt/kvm/vfio.c')
| -rw-r--r-- | virt/kvm/vfio.c | 33 |
1 files changed, 23 insertions, 10 deletions
diff --git a/virt/kvm/vfio.c b/virt/kvm/vfio.c index ca4260e3503..ba1a93f935c 100644 --- a/virt/kvm/vfio.c +++ b/virt/kvm/vfio.c @@ -59,6 +59,22 @@ static void kvm_vfio_group_put_external_user(struct vfio_group *vfio_group) symbol_put(vfio_group_put_external_user); } +static bool kvm_vfio_group_is_coherent(struct vfio_group *vfio_group) +{ + long (*fn)(struct vfio_group *, unsigned long); + long ret; + + fn = symbol_get(vfio_external_check_extension); + if (!fn) + return false; + + ret = fn(vfio_group, VFIO_DMA_CC_IOMMU); + + symbol_put(vfio_external_check_extension); + + return ret > 0; +} + /* * Groups can use the same or different IOMMU domains. If the same then * adding a new group may change the coherency of groups we've previously @@ -75,13 +91,10 @@ static void kvm_vfio_update_coherency(struct kvm_device *dev) mutex_lock(&kv->lock); list_for_each_entry(kvg, &kv->group_list, node) { - /* - * TODO: We need an interface to check the coherency of - * the IOMMU domain this group is using. For now, assume - * it's always noncoherent. - */ - noncoherent = true; - break; + if (!kvm_vfio_group_is_coherent(kvg->vfio_group)) { + noncoherent = true; + break; + } } if (noncoherent != kv->noncoherent) { @@ -101,14 +114,14 @@ static int kvm_vfio_set_group(struct kvm_device *dev, long attr, u64 arg) struct kvm_vfio *kv = dev->private; struct vfio_group *vfio_group; struct kvm_vfio_group *kvg; - void __user *argp = (void __user *)arg; + int32_t __user *argp = (int32_t __user *)(unsigned long)arg; struct fd f; int32_t fd; int ret; switch (attr) { case KVM_DEV_VFIO_GROUP_ADD: - if (get_user(fd, (int32_t __user *)argp)) + if (get_user(fd, argp)) return -EFAULT; f = fdget(fd); @@ -148,7 +161,7 @@ static int kvm_vfio_set_group(struct kvm_device *dev, long attr, u64 arg) return 0; case KVM_DEV_VFIO_GROUP_DEL: - if (get_user(fd, (int32_t __user *)argp)) + if (get_user(fd, argp)) return -EFAULT; f = fdget(fd); |
