1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ import abc
1516import asyncio
1617import gzip
1718import mimetypes
2021import threading
2122from contextlib import closing
2223from http import HTTPStatus
24+ from pathlib import Path
2325
24- from twisted .internet import reactor
26+ from OpenSSL import crypto
27+ from twisted .internet import reactor , ssl
2528from twisted .web import http
2629
30+ _dirname = Path (os .path .join (os .path .dirname (__file__ )))
2731
28- def find_free_port ():
32+
33+ def _find_free_port ():
2934 with closing (socket .socket (socket .AF_INET , socket .SOCK_STREAM )) as s :
3035 s .bind (("" , 0 ))
3136 s .setsockopt (socket .SOL_SOCKET , socket .SO_REUSEADDR , 1 )
3237 return s .getsockname ()[1 ]
3338
3439
3540class Server :
41+ protocol = "http"
42+
3643 def __init__ (self ):
37- self .PORT = find_free_port ()
38- self .EMPTY_PAGE = f"http ://localhost:{ self .PORT } /empty.html"
39- self .PREFIX = f"http ://localhost:{ self .PORT } "
40- self .CROSS_PROCESS_PREFIX = f"http ://127.0.0.1:{ self .PORT } "
44+ self .PORT = _find_free_port ()
45+ self .EMPTY_PAGE = f"{ self . protocol } ://localhost:{ self .PORT } /empty.html"
46+ self .PREFIX = f"{ self . protocol } ://localhost:{ self .PORT } "
47+ self .CROSS_PROCESS_PREFIX = f"{ self . protocol } ://127.0.0.1:{ self .PORT } "
4148 # On Windows, this list can be empty, reporting text/plain for scripts.
4249 mimetypes .add_type ("text/html" , ".html" )
4350 mimetypes .add_type ("text/css" , ".css" )
@@ -48,6 +55,10 @@ def __init__(self):
4855 def __repr__ (self ) -> str :
4956 return self .PREFIX
5057
58+ @abc .abstractmethod
59+ def listen (self , factory ):
60+ pass
61+
5162 def start (self ):
5263 request_subscribers = {}
5364 auth = {}
@@ -59,7 +70,7 @@ def start(self):
5970 self .csp = csp
6071 self .routes = routes
6172 self .gzip_routes = gzip_routes
62- static_path = os . path . join ( os . path . dirname ( __file__ ), "assets" )
73+ static_path = _dirname / "assets"
6374
6475 class TestServerHTTPHandler (http .Request ):
6576 def process (self ):
@@ -116,15 +127,7 @@ class MyHttp(http.HTTPChannel):
116127 class MyHttpFactory (http .HTTPFactory ):
117128 protocol = MyHttp
118129
119- reactor .listenTCP (self .PORT , MyHttpFactory ())
120- self .thread = threading .Thread (
121- target = lambda : reactor .run (installSignalHandlers = 0 )
122- )
123- self .thread .start ()
124-
125- def stop (self ):
126- reactor .stop ()
127- self .thread .join ()
130+ self .listen (MyHttpFactory ())
128131
129132 async def wait_for_request (self , path ):
130133 if path in self .request_subscribers :
@@ -161,4 +164,47 @@ def handle_redirect(request):
161164 self .set_route (from_ , handle_redirect )
162165
163166
164- server = Server ()
167+ class HTTPServer (Server ):
168+ def listen (self , factory ):
169+ reactor .listenTCP (self .PORT , factory )
170+
171+
172+ class HTTPSServer (Server ):
173+ protocol = "https"
174+
175+ def listen (self , factory ):
176+ cert = ssl .PrivateCertificate .fromCertificateAndKeyPair (
177+ ssl .Certificate .loadPEM (
178+ (_dirname / "testserver" / "cert.pem" ).read_bytes ()
179+ ),
180+ ssl .KeyPair .load (
181+ (_dirname / "testserver" / "key.pem" ).read_bytes (), crypto .FILETYPE_PEM
182+ ),
183+ )
184+ contextFactory = cert .options ()
185+ reactor .listenSSL (self .PORT , factory , contextFactory )
186+
187+
188+ class TestServer :
189+ def __init__ (self ) -> None :
190+ self .server = HTTPServer ()
191+ self .https_server = HTTPSServer ()
192+
193+ def start (self ) -> None :
194+ self .server .start ()
195+ self .https_server .start ()
196+ self .thread = threading .Thread (
197+ target = lambda : reactor .run (installSignalHandlers = 0 )
198+ )
199+ self .thread .start ()
200+
201+ def stop (self ) -> None :
202+ reactor .stop ()
203+ self .thread .join ()
204+
205+ def reset (self ) -> None :
206+ self .server .reset ()
207+ self .https_server .reset ()
208+
209+
210+ test_server = TestServer ()
0 commit comments