Merge branch 'md-fixes' of https://git.kernel.org/pub/scm/linux/kernel/git/song/md...
[linux-2.6-microblaze.git] / drivers / dax / bus.c
index e3a3841..2278000 100644 (file)
@@ -56,6 +56,25 @@ static int dax_match_id(struct dax_device_driver *dax_drv, struct device *dev)
        return match;
 }
 
+static int dax_match_type(struct dax_device_driver *dax_drv, struct device *dev)
+{
+       enum dax_driver_type type = DAXDRV_DEVICE_TYPE;
+       struct dev_dax *dev_dax = to_dev_dax(dev);
+
+       if (dev_dax->region->res.flags & IORESOURCE_DAX_KMEM)
+               type = DAXDRV_KMEM_TYPE;
+
+       if (dax_drv->type == type)
+               return 1;
+
+       /* default to device mode if dax_kmem is disabled */
+       if (dax_drv->type == DAXDRV_DEVICE_TYPE &&
+           !IS_ENABLED(CONFIG_DEV_DAX_KMEM))
+               return 1;
+
+       return 0;
+}
+
 enum id_action {
        ID_REMOVE,
        ID_ADD,
@@ -216,14 +235,9 @@ static int dax_bus_match(struct device *dev, struct device_driver *drv)
 {
        struct dax_device_driver *dax_drv = to_dax_drv(drv);
 
-       /*
-        * All but the 'device-dax' driver, which has 'match_always'
-        * set, requires an exact id match.
-        */
-       if (dax_drv->match_always)
+       if (dax_match_id(dax_drv, dev))
                return 1;
-
-       return dax_match_id(dax_drv, dev);
+       return dax_match_type(dax_drv, dev);
 }
 
 /*
@@ -427,8 +441,8 @@ static void unregister_dev_dax(void *dev)
        dev_dbg(dev, "%s\n", __func__);
 
        kill_dev_dax(dev_dax);
-       free_dev_dax_ranges(dev_dax);
        device_del(dev);
+       free_dev_dax_ranges(dev_dax);
        put_device(dev);
 }
 
@@ -1413,13 +1427,10 @@ err_id:
 }
 EXPORT_SYMBOL_GPL(devm_create_dev_dax);
 
-static int match_always_count;
-
 int __dax_driver_register(struct dax_device_driver *dax_drv,
                struct module *module, const char *mod_name)
 {
        struct device_driver *drv = &dax_drv->drv;
-       int rc = 0;
 
        /*
         * dax_bus_probe() calls dax_drv->probe() unconditionally.
@@ -1434,26 +1445,7 @@ int __dax_driver_register(struct dax_device_driver *dax_drv,
        drv->mod_name = mod_name;
        drv->bus = &dax_bus_type;
 
-       /* there can only be one default driver */
-       mutex_lock(&dax_bus_lock);
-       match_always_count += dax_drv->match_always;
-       if (match_always_count > 1) {
-               match_always_count--;
-               WARN_ON(1);
-               rc = -EINVAL;
-       }
-       mutex_unlock(&dax_bus_lock);
-       if (rc)
-               return rc;
-
-       rc = driver_register(drv);
-       if (rc && dax_drv->match_always) {
-               mutex_lock(&dax_bus_lock);
-               match_always_count -= dax_drv->match_always;
-               mutex_unlock(&dax_bus_lock);
-       }
-
-       return rc;
+       return driver_register(drv);
 }
 EXPORT_SYMBOL_GPL(__dax_driver_register);
 
@@ -1463,7 +1455,6 @@ void dax_driver_unregister(struct dax_device_driver *dax_drv)
        struct dax_id *dax_id, *_id;
 
        mutex_lock(&dax_bus_lock);
-       match_always_count -= dax_drv->match_always;
        list_for_each_entry_safe(dax_id, _id, &dax_drv->ids, list) {
                list_del(&dax_id->list);
                kfree(dax_id);