diff options
Diffstat (limited to 'drivers/vhost/vhost.c')
-rw-r--r-- | drivers/vhost/vhost.c | 233 |
1 files changed, 79 insertions, 154 deletions
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c index f44340b41494..d450e16c5c25 100644 --- a/drivers/vhost/vhost.c +++ b/drivers/vhost/vhost.c @@ -50,10 +50,6 @@ enum { #define vhost_used_event(vq) ((__virtio16 __user *)&vq->avail->ring[vq->num]) #define vhost_avail_event(vq) ((__virtio16 __user *)&vq->used->ring[vq->num]) -INTERVAL_TREE_DEFINE(struct vhost_umem_node, - rb, __u64, __subtree_last, - START, LAST, static inline, vhost_umem_interval_tree); - #ifdef CONFIG_VHOST_CROSS_ENDIAN_LEGACY static void vhost_disable_cross_endian(struct vhost_virtqueue *vq) { @@ -457,7 +453,9 @@ static size_t vhost_get_desc_size(struct vhost_virtqueue *vq, void vhost_dev_init(struct vhost_dev *dev, struct vhost_virtqueue **vqs, int nvqs, - int iov_limit, int weight, int byte_weight) + int iov_limit, int weight, int byte_weight, + int (*msg_handler)(struct vhost_dev *dev, + struct vhost_iotlb_msg *msg)) { struct vhost_virtqueue *vq; int i; @@ -473,6 +471,7 @@ void vhost_dev_init(struct vhost_dev *dev, dev->iov_limit = iov_limit; dev->weight = weight; dev->byte_weight = byte_weight; + dev->msg_handler = msg_handler; init_llist_head(&dev->work_list); init_waitqueue_head(&dev->wait); INIT_LIST_HEAD(&dev->read_list); @@ -581,21 +580,25 @@ err_mm: } EXPORT_SYMBOL_GPL(vhost_dev_set_owner); -struct vhost_umem *vhost_dev_reset_owner_prepare(void) +static struct vhost_iotlb *iotlb_alloc(void) +{ + return vhost_iotlb_alloc(max_iotlb_entries, + VHOST_IOTLB_FLAG_RETIRE); +} + +struct vhost_iotlb *vhost_dev_reset_owner_prepare(void) { - return kvzalloc(sizeof(struct vhost_umem), GFP_KERNEL); + return iotlb_alloc(); } EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare); /* Caller should have device mutex */ -void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_umem *umem) +void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_iotlb *umem) { int i; vhost_dev_cleanup(dev); - /* Restore memory to default empty mapping. */ - INIT_LIST_HEAD(&umem->umem_list); dev->umem = umem; /* We don't need VQ locks below since vhost_dev_cleanup makes sure * VQs aren't running. @@ -618,28 +621,6 @@ void vhost_dev_stop(struct vhost_dev *dev) } EXPORT_SYMBOL_GPL(vhost_dev_stop); -static void vhost_umem_free(struct vhost_umem *umem, - struct vhost_umem_node *node) -{ - vhost_umem_interval_tree_remove(node, &umem->umem_tree); - list_del(&node->link); - kfree(node); - umem->numem--; -} - -static void vhost_umem_clean(struct vhost_umem *umem) -{ - struct vhost_umem_node *node, *tmp; - - if (!umem) - return; - - list_for_each_entry_safe(node, tmp, &umem->umem_list, link) - vhost_umem_free(umem, node); - - kvfree(umem); -} - static void vhost_clear_msg(struct vhost_dev *dev) { struct vhost_msg_node *node, *n; @@ -677,9 +658,9 @@ void vhost_dev_cleanup(struct vhost_dev *dev) eventfd_ctx_put(dev->log_ctx); dev->log_ctx = NULL; /* No one will access memory at this point */ - vhost_umem_clean(dev->umem); + vhost_iotlb_free(dev->umem); dev->umem = NULL; - vhost_umem_clean(dev->iotlb); + vhost_iotlb_free(dev->iotlb); dev->iotlb = NULL; vhost_clear_msg(dev); wake_up_interruptible_poll(&dev->wait, EPOLLIN | EPOLLRDNORM); @@ -715,27 +696,26 @@ static bool vhost_overflow(u64 uaddr, u64 size) } /* Caller should have vq mutex and device mutex. */ -static bool vq_memory_access_ok(void __user *log_base, struct vhost_umem *umem, +static bool vq_memory_access_ok(void __user *log_base, struct vhost_iotlb *umem, int log_all) { - struct vhost_umem_node *node; + struct vhost_iotlb_map *map; if (!umem) return false; - list_for_each_entry(node, &umem->umem_list, link) { - unsigned long a = node->userspace_addr; + list_for_each_entry(map, &umem->list, link) { + unsigned long a = map->addr; - if (vhost_overflow(node->userspace_addr, node->size)) + if (vhost_overflow(map->addr, map->size)) return false; - if (!access_ok((void __user *)a, - node->size)) + if (!access_ok((void __user *)a, map->size)) return false; else if (log_all && !log_access_ok(log_base, - node->start, - node->size)) + map->start, + map->size)) return false; } return true; @@ -745,17 +725,17 @@ static inline void __user *vhost_vq_meta_fetch(struct vhost_virtqueue *vq, u64 addr, unsigned int size, int type) { - const struct vhost_umem_node *node = vq->meta_iotlb[type]; + const struct vhost_iotlb_map *map = vq->meta_iotlb[type]; - if (!node) + if (!map) return NULL; - return (void *)(uintptr_t)(node->userspace_addr + addr - node->start); + return (void *)(uintptr_t)(map->addr + addr - map->start); } /* Can we switch to this memory table? */ /* Caller should have device mutex but not vq mutex */ -static bool memory_access_ok(struct vhost_dev *d, struct vhost_umem *umem, +static bool memory_access_ok(struct vhost_dev *d, struct vhost_iotlb *umem, int log_all) { int i; @@ -1020,47 +1000,6 @@ static inline int vhost_get_desc(struct vhost_virtqueue *vq, return vhost_copy_from_user(vq, desc, vq->desc + idx, sizeof(*desc)); } -static int vhost_new_umem_range(struct vhost_umem *umem, - u64 start, u64 size, u64 end, - u64 userspace_addr, int perm) -{ - struct vhost_umem_node *tmp, *node; - - if (!size) - return -EFAULT; - - node = kmalloc(sizeof(*node), GFP_ATOMIC); - if (!node) - return -ENOMEM; - - if (umem->numem == max_iotlb_entries) { - tmp = list_first_entry(&umem->umem_list, typeof(*tmp), link); - vhost_umem_free(umem, tmp); - } - - node->start = start; - node->size = size; - node->last = end; - node->userspace_addr = userspace_addr; - node->perm = perm; - INIT_LIST_HEAD(&node->link); - list_add_tail(&node->link, &umem->umem_list); - vhost_umem_interval_tree_insert(node, &umem->umem_tree); - umem->numem++; - - return 0; -} - -static void vhost_del_umem_range(struct vhost_umem *umem, - u64 start, u64 end) -{ - struct vhost_umem_node *node; - - while ((node = vhost_umem_interval_tree_iter_first(&umem->umem_tree, - start, end))) - vhost_umem_free(umem, node); -} - static void vhost_iotlb_notify_vq(struct vhost_dev *d, struct vhost_iotlb_msg *msg) { @@ -1117,9 +1056,9 @@ static int vhost_process_iotlb_msg(struct vhost_dev *dev, break; } vhost_vq_meta_reset(dev); - if (vhost_new_umem_range(dev->iotlb, msg->iova, msg->size, - msg->iova + msg->size - 1, - msg->uaddr, msg->perm)) { + if (vhost_iotlb_add_range(dev->iotlb, msg->iova, + msg->iova + msg->size - 1, + msg->uaddr, msg->perm)) { ret = -ENOMEM; break; } @@ -1131,8 +1070,8 @@ static int vhost_process_iotlb_msg(struct vhost_dev *dev, break; } vhost_vq_meta_reset(dev); - vhost_del_umem_range(dev->iotlb, msg->iova, - msg->iova + msg->size - 1); + vhost_iotlb_del_range(dev->iotlb, msg->iova, + msg->iova + msg->size - 1); break; default: ret = -EINVAL; @@ -1178,7 +1117,12 @@ ssize_t vhost_chr_write_iter(struct vhost_dev *dev, ret = -EINVAL; goto done; } - if (vhost_process_iotlb_msg(dev, &msg)) { + + if (dev->msg_handler) + ret = dev->msg_handler(dev, &msg); + else + ret = vhost_process_iotlb_msg(dev, &msg); + if (ret) { ret = -EFAULT; goto done; } @@ -1311,44 +1255,42 @@ static bool vq_access_ok(struct vhost_virtqueue *vq, unsigned int num, } static void vhost_vq_meta_update(struct vhost_virtqueue *vq, - const struct vhost_umem_node *node, + const struct vhost_iotlb_map *map, int type) { int access = (type == VHOST_ADDR_USED) ? VHOST_ACCESS_WO : VHOST_ACCESS_RO; - if (likely(node->perm & access)) - vq->meta_iotlb[type] = node; + if (likely(map->perm & access)) + vq->meta_iotlb[type] = map; } static bool iotlb_access_ok(struct vhost_virtqueue *vq, int access, u64 addr, u64 len, int type) { - const struct vhost_umem_node *node; - struct vhost_umem *umem = vq->iotlb; + const struct vhost_iotlb_map *map; + struct vhost_iotlb *umem = vq->iotlb; u64 s = 0, size, orig_addr = addr, last = addr + len - 1; if (vhost_vq_meta_fetch(vq, addr, len, type)) return true; while (len > s) { - node = vhost_umem_interval_tree_iter_first(&umem->umem_tree, - addr, - last); - if (node == NULL || node->start > addr) { + map = vhost_iotlb_itree_first(umem, addr, last); + if (map == NULL || map->start > addr) { vhost_iotlb_miss(vq, addr, access); return false; - } else if (!(node->perm & access)) { + } else if (!(map->perm & access)) { /* Report the possible access violation by * request another translation from userspace. */ return false; } - size = node->size - addr + node->start; + size = map->size - addr + map->start; if (orig_addr == addr && size >= len) - vhost_vq_meta_update(vq, node, type); + vhost_vq_meta_update(vq, map, type); s += size; addr += size; @@ -1364,12 +1306,12 @@ int vq_meta_prefetch(struct vhost_virtqueue *vq) if (!vq->iotlb) return 1; - return iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->desc, + return iotlb_access_ok(vq, VHOST_MAP_RO, (u64)(uintptr_t)vq->desc, vhost_get_desc_size(vq, num), VHOST_ADDR_DESC) && - iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->avail, + iotlb_access_ok(vq, VHOST_MAP_RO, (u64)(uintptr_t)vq->avail, vhost_get_avail_size(vq, num), VHOST_ADDR_AVAIL) && - iotlb_access_ok(vq, VHOST_ACCESS_WO, (u64)(uintptr_t)vq->used, + iotlb_access_ok(vq, VHOST_MAP_WO, (u64)(uintptr_t)vq->used, vhost_get_used_size(vq, num), VHOST_ADDR_USED); } EXPORT_SYMBOL_GPL(vq_meta_prefetch); @@ -1408,25 +1350,11 @@ bool vhost_vq_access_ok(struct vhost_virtqueue *vq) } EXPORT_SYMBOL_GPL(vhost_vq_access_ok); -static struct vhost_umem *vhost_umem_alloc(void) -{ - struct vhost_umem *umem = kvzalloc(sizeof(*umem), GFP_KERNEL); - - if (!umem) - return NULL; - - umem->umem_tree = RB_ROOT_CACHED; - umem->numem = 0; - INIT_LIST_HEAD(&umem->umem_list); - - return umem; -} - static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m) { struct vhost_memory mem, *newmem; struct vhost_memory_region *region; - struct vhost_umem *newumem, *oldumem; + struct vhost_iotlb *newumem, *oldumem; unsigned long size = offsetof(struct vhost_memory, regions); int i; @@ -1448,7 +1376,7 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m) return -EFAULT; } - newumem = vhost_umem_alloc(); + newumem = iotlb_alloc(); if (!newumem) { kvfree(newmem); return -ENOMEM; @@ -1457,13 +1385,12 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m) for (region = newmem->regions; region < newmem->regions + mem.nregions; region++) { - if (vhost_new_umem_range(newumem, - region->guest_phys_addr, - region->memory_size, - region->guest_phys_addr + - region->memory_size - 1, - region->userspace_addr, - VHOST_ACCESS_RW)) + if (vhost_iotlb_add_range(newumem, + region->guest_phys_addr, + region->guest_phys_addr + + region->memory_size - 1, + region->userspace_addr, + VHOST_MAP_RW)) goto err; } @@ -1481,11 +1408,11 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m) } kvfree(newmem); - vhost_umem_clean(oldumem); + vhost_iotlb_free(oldumem); return 0; err: - vhost_umem_clean(newumem); + vhost_iotlb_free(newumem); kvfree(newmem); return -EFAULT; } @@ -1726,10 +1653,10 @@ EXPORT_SYMBOL_GPL(vhost_vring_ioctl); int vhost_init_device_iotlb(struct vhost_dev *d, bool enabled) { - struct vhost_umem *niotlb, *oiotlb; + struct vhost_iotlb *niotlb, *oiotlb; int i; - niotlb = vhost_umem_alloc(); + niotlb = iotlb_alloc(); if (!niotlb) return -ENOMEM; @@ -1745,7 +1672,7 @@ int vhost_init_device_iotlb(struct vhost_dev *d, bool enabled) mutex_unlock(&vq->mutex); } - vhost_umem_clean(oiotlb); + vhost_iotlb_free(oiotlb); return 0; } @@ -1875,8 +1802,8 @@ static int log_write(void __user *log_base, static int log_write_hva(struct vhost_virtqueue *vq, u64 hva, u64 len) { - struct vhost_umem *umem = vq->umem; - struct vhost_umem_node *u; + struct vhost_iotlb *umem = vq->umem; + struct vhost_iotlb_map *u; u64 start, end, l, min; int r; bool hit = false; @@ -1886,16 +1813,15 @@ static int log_write_hva(struct vhost_virtqueue *vq, u64 hva, u64 len) /* More than one GPAs can be mapped into a single HVA. So * iterate all possible umems here to be safe. */ - list_for_each_entry(u, &umem->umem_list, link) { - if (u->userspace_addr > hva - 1 + len || - u->userspace_addr - 1 + u->size < hva) + list_for_each_entry(u, &umem->list, link) { + if (u->addr > hva - 1 + len || + u->addr - 1 + u->size < hva) continue; - start = max(u->userspace_addr, hva); - end = min(u->userspace_addr - 1 + u->size, - hva - 1 + len); + start = max(u->addr, hva); + end = min(u->addr - 1 + u->size, hva - 1 + len); l = end - start + 1; r = log_write(vq->log_base, - u->start + start - u->userspace_addr, + u->start + start - u->addr, l); if (r < 0) return r; @@ -2046,9 +1972,9 @@ EXPORT_SYMBOL_GPL(vhost_vq_init_access); static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len, struct iovec iov[], int iov_size, int access) { - const struct vhost_umem_node *node; + const struct vhost_iotlb_map *map; struct vhost_dev *dev = vq->dev; - struct vhost_umem *umem = dev->iotlb ? dev->iotlb : dev->umem; + struct vhost_iotlb *umem = dev->iotlb ? dev->iotlb : dev->umem; struct iovec *_iov; u64 s = 0; int ret = 0; @@ -2060,25 +1986,24 @@ static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len, break; } - node = vhost_umem_interval_tree_iter_first(&umem->umem_tree, - addr, addr + len - 1); - if (node == NULL || node->start > addr) { + map = vhost_iotlb_itree_first(umem, addr, addr + len - 1); + if (map == NULL || map->start > addr) { if (umem != dev->iotlb) { ret = -EFAULT; break; } ret = -EAGAIN; break; - } else if (!(node->perm & access)) { + } else if (!(map->perm & access)) { ret = -EPERM; break; } _iov = iov + ret; - size = node->size - addr + node->start; + size = map->size - addr + map->start; _iov->iov_len = min((u64)len - s, size); _iov->iov_base = (void __user *)(unsigned long) - (node->userspace_addr + addr - node->start); + (map->addr + addr - map->start); s += size; addr += size; ++ret; |