From bb65eb3138044df8e4e351d5b721603b040f1778 Mon Sep 17 00:00:00 2001
From: Chris Hines <chris.hines@monash.edu>
Date: Tue, 2 Oct 2018 17:29:02 +1000
Subject: [PATCH] improved, fixed a shutdown bug wth teh connection was never
 going to authorise

---
 twsproxy/__init__.py        | 38 +++++++++++++++++++++++--------------
 twsproxy/__main__.py        |  3 +++
 twsproxy/server/__init__.py | 11 ++++++-----
 3 files changed, 33 insertions(+), 19 deletions(-)

diff --git a/twsproxy/__init__.py b/twsproxy/__init__.py
index 23a670c..10722c8 100644
--- a/twsproxy/__init__.py
+++ b/twsproxy/__init__.py
@@ -2,7 +2,9 @@ import threading
 import socket, array
 import select
 import logging
+#TES =  'http://localhost:8080/'
 TES =  'http://localhost:5000/'
+failthresh = 10
 class TWSProxy(threading.Thread):
 
     TIMEOUT = 10
@@ -30,7 +32,9 @@ class TWSProxy(threading.Thread):
         bytessofar = 0
         header=bytearray(TWSProxy.MAXHEADERS)
         keepreading = True
+        initcount=0
         while keepreading:
+            initcount=initcount+1
             r,w,e = select.select([self.csock],[],[],5)
             if len(r) > 0:
                 partial = self.csock.recv(TWSProxy.MAXBUFF)
@@ -42,6 +46,8 @@ class TWSProxy(threading.Thread):
             else:
                 port = TWSProxy.verifyauth(header[0:bytessofar])
                 keepreading = False
+            if initcount > failthresh:
+                keepreading = False
 
         if port is not None:
             self.ssock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@@ -83,19 +89,17 @@ class TWSProxy(threading.Thread):
     def verifyauth(header):
         import re
         import requests
-        print("in verifyauth")
         token = b'twsproxyauth=(?P<authtok>\w+)[\W|$]'
         m = re.search(token,header)
         if m:
-            print(m.groupdict()['authtok'])
             authtok = m.groupdict()['authtok']
             s = requests.Session()
             url = TES+'tunnelstat/'+authtok.decode()
-            print(url)
-            r = s.get(url)
-
-            print(r.text)
-            port = r.json()
+            try:
+                r = s.get(url)
+                port = r.json()
+            except:
+                raise Exception('unable to get a port number for the authtok')
             return port
         return None
 #        if m:
@@ -175,21 +179,22 @@ class TWSProxy(threading.Thread):
         logger = logging.getLogger()
         closed = False
         name = threading.current_thread().name
-        while not closed:
+        failcount=0
+        failthresh = 10
+        while not closed and failcount < failthresh:
             r,w,e = select.select([src],[],[],TWSProxy.TIMEOUT)
             if len(r) > 0:
+                failcount=0
                 buff = None
                 msglength = -1
                 try:
                     buff = src.recv(TWSProxy.MAXBUFF)
                     if buff is None:
-                        print("buff is none ... is this normal?")
                         continue
                 except ConnectionResetError as e:
                     close = True
                     continue
                 except Exception as e:
-                    print(e)
                     import traceback
                     print(traceback.format_exc())
 #                    closed = True
@@ -201,15 +206,20 @@ class TWSProxy(threading.Thread):
                     dest.shutdown(shuttype)
                     initshutdown.set()
                     closed = True
-            if len(w) == 0 and len(r) == 0 and len(r) == 0:
-                if initshutdown.isSet():
-                    print("or possibly initshutdown")
+            else: 
+                failcount=failcount+1
+
+        if failcount > failthresh:
+            dest.shutdown(shuttype)
+            initshutdown.set()
 
 def main():
     from . import server
     import logging
+    logging.basicConfig(filename="/var/log/tws.log",format="%(asctime)s %(levelname)s:%(process)s: %(message)s")
     logger = logging.getLogger()
     logger.setLevel(logging.DEBUG)
-
+    logger.debug("starting TWS proxy")
     server = server.TWSServer()
+    logger.debug("initialised server object")
     server.run()
diff --git a/twsproxy/__main__.py b/twsproxy/__main__.py
index b03d1ff..aee234e 100644
--- a/twsproxy/__main__.py
+++ b/twsproxy/__main__.py
@@ -1,7 +1,10 @@
 from . import server
 import logging
+logging.basicConfig(filename="/var/log/tws.log",format="%(asctime)s %(levelname)s:%(process)s: %(message)s")
 logger = logging.getLogger()
 logger.setLevel(logging.DEBUG)
 
+logger.debug("starting TWS proxy")
+print("starting TWS proxy")
 server = server.TWSServer()
 server.run()
diff --git a/twsproxy/server/__init__.py b/twsproxy/server/__init__.py
index 18495f9..b46119b 100644
--- a/twsproxy/server/__init__.py
+++ b/twsproxy/server/__init__.py
@@ -1,5 +1,6 @@
 import socket
 from .. import TWSProxy
+import logging
 class TWSServer:
     import socket
 
@@ -7,6 +8,8 @@ class TWSServer:
     MAXCONN = 5
 
     def run(self):
+        logger = logging.getLogger()
+        logger.debug("starting up server")
         serversocket = socket.socket(
                     socket.AF_INET, socket.SOCK_STREAM)
         #bind the socket to a public host,
@@ -16,18 +19,17 @@ class TWSServer:
                 serversocket.bind(('127.0.0.1', port))
         #become a server socket
                 serversocket.listen(self.MAXCONN)
-                print("bind success on ",port)
+                logger.debug("Server listening on port {}".format(port))
                 break
             except Exception as e:
-                print("bind failure")
                 print(e)
                 pass
-        print("listening on port",port)
         openconnections = []
+        logger.debug("waiting for a connection")
         while 1:
             (clientsocket, address) = serversocket.accept()
+            logger.debug("accepted connection on {}".format(clientsocket))
             clientsocket.setblocking(True)
-            print("accepting connection")
             tunnel = TWSProxy(clientsocket)
             tunnel.daemon = True
             tunnel.start()
@@ -36,5 +38,4 @@ class TWSServer:
                 if not c.is_alive():
                     c.join()
                     openconnections.remove(c)
-            print("there are ",len(openconnections),"current connections")
 
-- 
GitLab