From 95ed63f791656fc19e36ae68bc328e367958c76b Mon Sep 17 00:00:00 2001
From: Arthur Kepner <akepner@sgi.com>
Date: Mon, 20 Mar 2006 21:26:56 -0800
Subject: [PATCH] [NET] pktgen: Fix races between control/worker threads.

There's a race in pktgen which can lead to a double
free of a pktgen_dev's skb. If a worker thread is in
the midst of doing fill_packet(), and the controlling
thread gets a "stop" message, the already freed skb
can be freed once again in pktgen_stop_device(). This
patch gives all responsibility for cleaning up a
pktgen_dev's skb to the associated worker thread.

Signed-off-by: Arthur Kepner <akepner@sgi.com>
Acked-by: Robert Olsson <Robert.Olsson@data.slu.se>
Signed-off-by: David S. Miller <davem@davemloft.net>
---
 net/core/pktgen.c | 135 +++++++++++++++++++++++++++++++++++++---------
 1 file changed, 110 insertions(+), 25 deletions(-)

diff --git a/net/core/pktgen.c b/net/core/pktgen.c
index da16f8fd149..6586321b018 100644
--- a/net/core/pktgen.c
+++ b/net/core/pktgen.c
@@ -153,7 +153,7 @@
 #include <asm/timex.h>
 
 
-#define VERSION  "pktgen v2.63: Packet Generator for packet performance testing.\n"
+#define VERSION  "pktgen v2.64: Packet Generator for packet performance testing.\n"
 
 /* #define PG_DEBUG(a) a */
 #define PG_DEBUG(a) 
@@ -176,7 +176,8 @@
 #define T_TERMINATE   (1<<0)  
 #define T_STOP        (1<<1)  /* Stop run */
 #define T_RUN         (1<<2)  /* Start run */
-#define T_REMDEV      (1<<3)  /* Remove all devs */
+#define T_REMDEVALL   (1<<3)  /* Remove all devs */
+#define T_REMDEV      (1<<4)  /* Remove one dev */
 
 /* Locks */
 #define   thread_lock()        down(&pktgen_sem)
@@ -218,6 +219,8 @@ struct pktgen_dev {
          * we will do a random selection from within the range.
          */
         __u32 flags;     
+	int removal_mark;	/* non-zero => the device is marked for
+				 * removal by worker thread */
 
         int min_pkt_size;    /* = ETH_ZLEN; */
         int max_pkt_size;    /* = ETH_ZLEN; */
@@ -481,7 +484,7 @@ static void pktgen_stop_all_threads_ifs(void);
 static int pktgen_stop_device(struct pktgen_dev *pkt_dev);
 static void pktgen_stop(struct pktgen_thread* t);
 static void pktgen_clear_counters(struct pktgen_dev *pkt_dev);
-static struct pktgen_dev *pktgen_NN_threads(const char* dev_name, int remove);
+static int pktgen_mark_device(const char* ifname);
 static unsigned int scan_ip6(const char *s,char ip[16]);
 static unsigned int fmt_ip6(char *s,const char ip[16]);
 
@@ -1406,7 +1409,7 @@ static ssize_t pktgen_thread_write(struct file *file,
 
         if (!strcmp(name, "rem_device_all")) {
 		thread_lock();
-		t->control |= T_REMDEV;
+		t->control |= T_REMDEVALL;
 		thread_unlock();
 		schedule_timeout_interruptible(msecs_to_jiffies(125));  /* Propagate thread->control  */
 		ret = count;
@@ -1457,7 +1460,8 @@ static struct pktgen_dev *__pktgen_NN_threads(const char* ifname, int remove)
 		if (pkt_dev) {
 		                if(remove) { 
 				        if_lock(t);
-				        pktgen_remove_device(t, pkt_dev);
+					pkt_dev->removal_mark = 1;
+					t->control |=  T_REMDEV;
 				        if_unlock(t);
 				}
 			break;
@@ -1467,13 +1471,44 @@ static struct pktgen_dev *__pktgen_NN_threads(const char* ifname, int remove)
         return pkt_dev;
 }
 
-static struct pktgen_dev *pktgen_NN_threads(const char* ifname, int remove) 
+/*
+ * mark a device for removal
+ */
+static int pktgen_mark_device(const char* ifname)
 {
 	struct pktgen_dev *pkt_dev = NULL;
+	const int max_tries = 10, msec_per_try = 125;
+	int i = 0;
+	int ret = 0;
+
 	thread_lock();
-	pkt_dev = __pktgen_NN_threads(ifname, remove);
-        thread_unlock();
-	return pkt_dev;
+        PG_DEBUG(printk("pktgen: pktgen_mark_device marking %s for removal\n",
+			ifname));
+
+	while(1) {
+
+		pkt_dev = __pktgen_NN_threads(ifname, REMOVE);
+		if (pkt_dev == NULL) break; /* success */
+
+		thread_unlock();
+		PG_DEBUG(printk("pktgen: pktgen_mark_device waiting for %s "
+			"to disappear....\n", ifname));
+		schedule_timeout_interruptible(msecs_to_jiffies(msec_per_try));
+		thread_lock();
+
+		if (++i >= max_tries) {
+			printk("pktgen_mark_device: timed out after waiting "
+				"%d msec for device %s to be removed\n",
+				msec_per_try*i, ifname);
+			ret = 1;
+			break;
+		}
+
+	}
+
+	thread_unlock();
+
+	return ret;
 }
 
 static int pktgen_device_event(struct notifier_block *unused, unsigned long event, void *ptr) 
@@ -1493,7 +1528,7 @@ static int pktgen_device_event(struct notifier_block *unused, unsigned long even
 		break;
 		
 	case NETDEV_UNREGISTER:
-                pktgen_NN_threads(dev->name, REMOVE);
+                pktgen_mark_device(dev->name);
 		break;
 	};
 
@@ -2303,11 +2338,11 @@ static void pktgen_stop_all_threads_ifs(void)
 {
         struct pktgen_thread *t = pktgen_threads;
 
-	PG_DEBUG(printk("pktgen: entering pktgen_stop_all_threads.\n"));
+	PG_DEBUG(printk("pktgen: entering pktgen_stop_all_threads_ifs.\n"));
 
 	thread_lock();
 	while(t) {
-		pktgen_stop(t);
+		t->control |= T_STOP;
 		t = t->next;
 	}
 	thread_unlock();
@@ -2431,7 +2466,9 @@ static void show_results(struct pktgen_dev *pkt_dev, int nr_frags)
 
 static int pktgen_stop_device(struct pktgen_dev *pkt_dev) 
 {
-	
+	int nr_frags = pkt_dev->skb ?
+			skb_shinfo(pkt_dev->skb)->nr_frags: -1;
+
         if (!pkt_dev->running) {
                 printk("pktgen: interface: %s is already stopped\n", pkt_dev->ifname);
                 return -EINVAL;
@@ -2440,13 +2477,8 @@ static int pktgen_stop_device(struct pktgen_dev *pkt_dev)
         pkt_dev->stopped_at = getCurUs();
         pkt_dev->running = 0;
 
-	show_results(pkt_dev, skb_shinfo(pkt_dev->skb)->nr_frags);
-
-	if (pkt_dev->skb) 
-		kfree_skb(pkt_dev->skb);
+	show_results(pkt_dev, nr_frags);
 
-	pkt_dev->skb = NULL;
-	
         return 0;
 }
 
@@ -2469,26 +2501,66 @@ static struct pktgen_dev *next_to_run(struct pktgen_thread *t )
 static void pktgen_stop(struct pktgen_thread *t) {
         struct pktgen_dev *next = NULL;
 
-	PG_DEBUG(printk("pktgen: entering pktgen_stop.\n"));
+	PG_DEBUG(printk("pktgen: entering pktgen_stop\n"));
 
         if_lock(t);
 
-        for(next=t->if_list; next; next=next->next)
+        for(next=t->if_list; next; next=next->next) {
                 pktgen_stop_device(next);
+		if (next->skb)
+			kfree_skb(next->skb);
+
+		next->skb = NULL;
+	}
 
         if_unlock(t);
 }
 
+/*
+ * one of our devices needs to be removed - find it
+ * and remove it
+ */
+static void pktgen_rem_one_if(struct pktgen_thread *t)
+{
+	struct pktgen_dev *cur, *next = NULL;
+
+	PG_DEBUG(printk("pktgen: entering pktgen_rem_one_if\n"));
+
+	if_lock(t);
+
+	for(cur=t->if_list; cur; cur=next) {
+		next = cur->next;
+
+		if (!cur->removal_mark) continue;
+
+		if (cur->skb)
+			kfree_skb(cur->skb);
+		cur->skb = NULL;
+
+		pktgen_remove_device(t, cur);
+
+		break;
+	}
+
+	if_unlock(t);
+}
+
 static void pktgen_rem_all_ifs(struct pktgen_thread *t) 
 {
         struct pktgen_dev *cur, *next = NULL;
-        
-        /* Remove all devices, free mem */
  
+        /* Remove all devices, free mem */
+
+	PG_DEBUG(printk("pktgen: entering pktgen_rem_all_ifs\n"));
         if_lock(t);
 
         for(cur=t->if_list; cur; cur=next) { 
 		next = cur->next;
+
+		if (cur->skb)
+			kfree_skb(cur->skb);
+		cur->skb = NULL;
+
 		pktgen_remove_device(t, cur);
 	}
 
@@ -2550,6 +2622,9 @@ static __inline__ void pktgen_xmit(struct pktgen_dev *pkt_dev)
 		
 		if (!netif_running(odev)) {
 			pktgen_stop_device(pkt_dev);
+			if (pkt_dev->skb)
+				kfree_skb(pkt_dev->skb);
+			pkt_dev->skb = NULL;
 			goto out;
 		}
 		if (need_resched()) 
@@ -2581,7 +2656,7 @@ static __inline__ void pktgen_xmit(struct pktgen_dev *pkt_dev)
 			pkt_dev->clone_count = 0; /* reset counter */
 		}
 	}
-	
+
 	spin_lock_bh(&odev->xmit_lock);
 	if (!netif_queue_stopped(odev)) {
 
@@ -2644,6 +2719,9 @@ retry_now:
                 
 		/* Done with this */
 		pktgen_stop_device(pkt_dev);
+		if (pkt_dev->skb)
+			kfree_skb(pkt_dev->skb);
+		pkt_dev->skb = NULL;
 	} 
  out:;
  }
@@ -2685,6 +2763,7 @@ static void pktgen_thread_worker(struct pktgen_thread *t)
 	t->control &= ~(T_TERMINATE);
 	t->control &= ~(T_RUN);
 	t->control &= ~(T_STOP);
+	t->control &= ~(T_REMDEVALL);
 	t->control &= ~(T_REMDEV);
 
         t->pid = current->pid;        
@@ -2748,8 +2827,13 @@ static void pktgen_thread_worker(struct pktgen_thread *t)
 			t->control &= ~(T_RUN);
 		}
 
-		if(t->control & T_REMDEV) {
+		if(t->control & T_REMDEVALL) {
 			pktgen_rem_all_ifs(t);
+			t->control &= ~(T_REMDEVALL);
+		}
+
+		if(t->control & T_REMDEV) {
+			pktgen_rem_one_if(t);
 			t->control &= ~(T_REMDEV);
 		}
 
@@ -2833,6 +2917,7 @@ static int pktgen_add_device(struct pktgen_thread *t, const char* ifname)
 	}
 	memset(pkt_dev->flows, 0, MAX_CFLOWS*sizeof(struct flow_state));
 
+	pkt_dev->removal_mark = 0;
 	pkt_dev->min_pkt_size = ETH_ZLEN;
 	pkt_dev->max_pkt_size = ETH_ZLEN;
 	pkt_dev->nfrags = 0;
-- 
2.25.4