Commit dbcc112e authored by Joerg Roedel's avatar Joerg Roedel Committed by Ingo Molnar

AMD IOMMU: check for invalid device pointers

Currently AMD IOMMU code triggers a BUG_ON if NULL is passed as the
device. This is inconsistent with other IOMMU implementations.
Signed-off-by: default avatarJoerg Roedel <joerg.roedel@amd.com>
Signed-off-by: default avatarIngo Molnar <mingo@elte.hu>
parent 07a2c01a
...@@ -645,6 +645,18 @@ static void set_device_domain(struct amd_iommu *iommu, ...@@ -645,6 +645,18 @@ static void set_device_domain(struct amd_iommu *iommu,
* *
*****************************************************************************/ *****************************************************************************/
/*
* This function checks if the driver got a valid device from the caller to
* avoid dereferencing invalid pointers.
*/
static bool check_device(struct device *dev)
{
if (!dev || !dev->dma_mask)
return false;
return true;
}
/* /*
* In the dma_ops path we only have the struct device. This function * In the dma_ops path we only have the struct device. This function
* finds the corresponding IOMMU, the protection domain and the * finds the corresponding IOMMU, the protection domain and the
...@@ -661,18 +673,19 @@ static int get_device_resources(struct device *dev, ...@@ -661,18 +673,19 @@ static int get_device_resources(struct device *dev,
struct pci_dev *pcidev; struct pci_dev *pcidev;
u16 _bdf; u16 _bdf;
BUG_ON(!dev || dev->bus != &pci_bus_type || !dev->dma_mask); *iommu = NULL;
*domain = NULL;
*bdf = 0xffff;
if (dev->bus != &pci_bus_type)
return 0;
pcidev = to_pci_dev(dev); pcidev = to_pci_dev(dev);
_bdf = calc_devid(pcidev->bus->number, pcidev->devfn); _bdf = calc_devid(pcidev->bus->number, pcidev->devfn);
/* device not translated by any IOMMU in the system? */ /* device not translated by any IOMMU in the system? */
if (_bdf > amd_iommu_last_bdf) { if (_bdf > amd_iommu_last_bdf)
*iommu = NULL;
*domain = NULL;
*bdf = 0xffff;
return 0; return 0;
}
*bdf = amd_iommu_alias_table[_bdf]; *bdf = amd_iommu_alias_table[_bdf];
...@@ -826,6 +839,9 @@ static dma_addr_t map_single(struct device *dev, phys_addr_t paddr, ...@@ -826,6 +839,9 @@ static dma_addr_t map_single(struct device *dev, phys_addr_t paddr,
u16 devid; u16 devid;
dma_addr_t addr; dma_addr_t addr;
if (!check_device(dev))
return bad_dma_address;
get_device_resources(dev, &iommu, &domain, &devid); get_device_resources(dev, &iommu, &domain, &devid);
if (iommu == NULL || domain == NULL) if (iommu == NULL || domain == NULL)
...@@ -860,7 +876,8 @@ static void unmap_single(struct device *dev, dma_addr_t dma_addr, ...@@ -860,7 +876,8 @@ static void unmap_single(struct device *dev, dma_addr_t dma_addr,
struct protection_domain *domain; struct protection_domain *domain;
u16 devid; u16 devid;
if (!get_device_resources(dev, &iommu, &domain, &devid)) if (!check_device(dev) ||
!get_device_resources(dev, &iommu, &domain, &devid))
/* device not handled by any AMD IOMMU */ /* device not handled by any AMD IOMMU */
return; return;
...@@ -910,6 +927,9 @@ static int map_sg(struct device *dev, struct scatterlist *sglist, ...@@ -910,6 +927,9 @@ static int map_sg(struct device *dev, struct scatterlist *sglist,
phys_addr_t paddr; phys_addr_t paddr;
int mapped_elems = 0; int mapped_elems = 0;
if (!check_device(dev))
return 0;
get_device_resources(dev, &iommu, &domain, &devid); get_device_resources(dev, &iommu, &domain, &devid);
if (!iommu || !domain) if (!iommu || !domain)
...@@ -967,7 +987,8 @@ static void unmap_sg(struct device *dev, struct scatterlist *sglist, ...@@ -967,7 +987,8 @@ static void unmap_sg(struct device *dev, struct scatterlist *sglist,
u16 devid; u16 devid;
int i; int i;
if (!get_device_resources(dev, &iommu, &domain, &devid)) if (!check_device(dev) ||
!get_device_resources(dev, &iommu, &domain, &devid))
return; return;
spin_lock_irqsave(&domain->lock, flags); spin_lock_irqsave(&domain->lock, flags);
...@@ -999,6 +1020,9 @@ static void *alloc_coherent(struct device *dev, size_t size, ...@@ -999,6 +1020,9 @@ static void *alloc_coherent(struct device *dev, size_t size,
u16 devid; u16 devid;
phys_addr_t paddr; phys_addr_t paddr;
if (!check_device(dev))
return NULL;
virt_addr = (void *)__get_free_pages(flag, get_order(size)); virt_addr = (void *)__get_free_pages(flag, get_order(size));
if (!virt_addr) if (!virt_addr)
return 0; return 0;
...@@ -1047,6 +1071,9 @@ static void free_coherent(struct device *dev, size_t size, ...@@ -1047,6 +1071,9 @@ static void free_coherent(struct device *dev, size_t size,
struct protection_domain *domain; struct protection_domain *domain;
u16 devid; u16 devid;
if (!check_device(dev))
return;
get_device_resources(dev, &iommu, &domain, &devid); get_device_resources(dev, &iommu, &domain, &devid);
if (!iommu || !domain) if (!iommu || !domain)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment