"""
Helpers for testing code using gavo.formal
"""
#c Copyright 2008-2023, the GAVO project <gavo@ari.uni-heidelberg.de>
#c
#c This program is free software, covered by the GNU GPL. See the
#c COPYING file in the source distribution.
import io
import math
from twisted.internet import defer
from twisted.python import failure
from twisted.python import urlpath
from twisted.trial.unittest import TestCase as TrialTest
from twisted.web import http
from twisted.web import resource
from twisted.web import server
from twisted.web.test.requesthelper import DummyRequest
def debug(arg):
    """enters the python debugger, then hands back arg untouched.

    Because arg is returned unchanged, this can be inserted anywhere a
    pass-through callable is accepted (e.g., a callback chain) to get a
    look at the value flowing through.
    """
    import pdb
    pdb.set_trace()
    return arg
def bytify(s):
    """returns s as utf-8 bytes when it is a str; anything else is
    returned as is.
    """
    return s.encode("utf-8") if isinstance(s, str) else s
def debytify(b):
    """returns b decoded as utf-8 when it is bytes; anything else is
    returned as is.
    """
    return b.decode("utf-8") if isinstance(b, bytes) else b
def assertHasStrings(content, strings, inverse=False):
    """asserts that all strings in the list strings are in content.

    If inverse is True, it asserts the strings are not in content instead.

    For generality, both content and strings will be bytified if they're
    not bytes already, and only then compared.

    When the check fails, the bytified content is dumped to a file
    remote.data in the current directory and an AssertionError naming the
    offending string is raised.

    An explicit raise is used rather than assert statements so the check
    still works when python runs with -O (which strips asserts).
    """
    content = bytify(content)
    for s in strings:
        present = bytify(s) in content
        if inverse and present:
            message = f"'{s}' in remote.data"
        elif not inverse and not present:
            message = f"'{s}' not in remote.data"
        else:
            continue

        # leave the offending document around for post-mortem inspection
        with open("remote.data", "wb") as f:
            f.write(content)
        raise AssertionError(message)
def raiseException(failure):
    """re-raises the exception carried by failure.

    failure needs a raiseException method (as twisted Failures have).
    Note that the parameter name shadows the failure module within this
    function.
    """
    reraise = failure.raiseException
    reraise()
def bytify_seq(s):
    """returns a new list containing the items of s, each bytified.

    s may be None (returned unchanged), a list, or a tuple.  As a
    convenience, a single bytes or str value is treated like a one-element
    list.  Anything else raises an Exception.
    """
    if s is None:
        return None

    items = [s] if isinstance(s, (bytes, str)) else s
    if isinstance(items, (list, tuple)):
        return [bytify(element) for element in items]

    raise Exception("bytify_seq really wants a sequence, not %s"%repr(s))
class FakeFile:
    """a stand-in for a file upload.

    file_name and payload may each be bytes or str; str values are
    utf-8-encoded.  Instances expose file_name (the encoded name) and
    file_object (an io.BytesIO positioned at the start of the payload).
    """
    def __init__(self, file_name, payload):
        encodedName = bytify(file_name)
        encodedPayload = bytify(payload)
        self.file_name = encodedName
        self.file_object = io.BytesIO(encodedPayload)
class FakeRequest(DummyRequest):
    """A request for test purposes.

    We furnish t.w's DummyRequest with some extra facilities to let us
    be a bit lazy in having rather macro tests.

    Also, stock twisted DummyRequest produces an endless loop with push
    producers (which is what we have), so we fix that, too.

    You can pass in args as a str -> str mapping; the strings will be
    encoded as utf-8 so request.args is bytes -> [bytes].  For
    convenience, we'll turn single values to lists.

    For uploads, you can pass (single) args with FakeFile-valued arguments.
    These end up in request.files, a mapping of str argument names to
    lists of FakeFiles.

    The constructor accepts headers, avatar, currentSegments, and cookies,
    but currently ignores them.
    """
    method = b"GET"
    session = None
    startedWriting = 0
    # some code tests for a live connection using client
    client = True

    def __init__(self, uri=b'', headers=None, args=None, avatar=None,
            currentSegments=None, cookies=None,
            user="", password="", isSecure=False):
        uri = bytify(uri)
        if uri.startswith(b"/"):
            uri = uri[1:]
        postpath = uri.split(b"/") if uri else []
        DummyRequest.__init__(self, uri)

        self.files, self.args = {}, {}
        for k, v in (args or {}).items():
            if isinstance(v, FakeFile):
                self.files[debytify(k)] = [v]
            else:
                self.args[bytify(k)] = bytify_seq(v)

        self.uri = uri
        self.postpath = postpath
        self.code = 200
        self.user, self.password = user, password
        # fires with (accumulated content, request) on finish()
        self.deferred = defer.Deferred()
        self.accumulator = b""
        self.prepath = []
        self.finished = False
        # honour the constructor argument; this is what isSecure() returns
        self.secure = isSecure
        self.channel = 1 # must be non-None for custom hangup detection
        self.lastModified = None

    def write(self, data):
        """appends data (bytified) to self.accumulator.

        On the first write, a last-modified response header is set if
        setLastModified has been called before.
        """
        if not self.startedWriting:
            if self.lastModified is not None:
                self.responseHeaders.setRawHeaders(
                    b"last-modified",
                    [http.datetimeToString(self.lastModified)])
            self.startedWriting = True
        self.accumulator += bytify(data)

    def notifyFinish(self):
        """returns the deferred that finish() fires."""
        return self.deferred

    def prePathURL(self):
        """returns an http URL (as str) built from getHost() and prepath.
        """
        return 'http://%s/%s'%(self.getHost(), '/'.join(self.prepath))

    def setLastModified(self, when):
        """records when (rounded up to whole seconds) as the last
        modification date and evaluates any If-Modified-Since header.

        Returns http.CACHED (and sets the NOT_MODIFIED response code) if
        the client's copy is current, None otherwise.
        """
        # copied from twisted.web.server.Request
        when = int(math.ceil(when))
        if (not self.lastModified) or (self.lastModified < when):
            self.lastModified = when

        modifiedSince = self.getHeader(b"if-modified-since")
        if modifiedSince:
            firstPart = modifiedSince.split(b";", 1)[0]
            try:
                modifiedSince = http.stringToDatetime(firstPart)
            except ValueError:
                return None
            if modifiedSince >= self.lastModified:
                self.setResponseCode(http.NOT_MODIFIED)
                return http.CACHED
        return None

    def finish(self):
        """marks the request finished and fires self.deferred with
        (accumulated content, request).
        """
        self.finished = True
        self.deferred.callback((self.accumulator, self))

    def finishCallback(self, arg):
        """re-raises arg if it is a Failure; otherwise finishes the request
        unless that has already happened.
        """
        if isinstance(arg, failure.Failure):
            arg.raiseException()
        if not self.finished:
            self.finish()

    def setHost(self, host, port):
        """sets the host returned by getHost (port is ignored)."""
        self.host = host

    def getHost(self):
        """returns what was passed to setHost (AttributeError if
        setHost has not been called).
        """
        return self.host

    def setResponseCode(self, code, message=None):
        """sets the response code.

        The code is kept in self.code, which is what the helpers here
        inspect; it is mirrored into responseCode/responseMessage so
        DummyRequest-style accessors agree.  message is optional so callers
        using twisted's two-argument form work, too.
        """
        self.code = code
        self.responseCode = code
        self.responseMessage = message

    def URLPath(self):
        """returns a twisted URLPath built from self.path."""
        return urlpath.URLPath.fromString(self.path.decode("utf-8"))

    @property
    def path(self):
        # the (bytes) uri with a leading slash removed by the constructor
        return self.uri

    def isSecure(self):
        """returns the isSecure value passed to the constructor."""
        return self.secure

    def getLocationValue(self):
        """returns a location header if this requests redirects, and raises
        an AssertionError otherwise.
        """
        if not self.code or self.code//100!=3:
            raise AssertionError("Trying to get a redirection target for"
                " request with status %s"%self.code)
        return self.getResponseHeader("location")

    def processWithRoot(self, page):
        """runs this request on page.

        This is probably a bad idea all around, and we should just be using
        trial.  But since sync tests are quite a bit more convenient,
        here this is.  Of course, it only works if resource effectively
        renders sync (or has a renderSync method).

        Returns a truthy render result other than NOT_DONE_YET directly;
        otherwise returns what has been written to the request.
        """
        rsc = resource.getChildForRequest(page, self)
        res = getattr(rsc, "renderSync", rsc.render)(self)
        if res:
            if isinstance(res, int) and res==server.NOT_DONE_YET:
                # this will only work if the thing is actually sync.
                # see servicetest._syncvosi for an inspiration there.
                # But in that case, accumulator will have it all.
                pass
            else:
                return res
        return self.accumulator

    def registerProducer(self, producer, isPush):
        """remembers producer; only pull producers are handed on to
        DummyRequest (its push handling loops endlessly, see class docs).
        """
        self.producer = producer
        if not isPush:
            DummyRequest.registerProducer(
                self, producer, isPush)

    def unregisterProducer(self):
        """drops the producer registered before and clears the channel.
        """
        # stop twisted pull producers, too
        self.go = 0
        self.channel = None
        del self.producer

    def addUpload(self, name, content):
        """adds a FakeFile with content as an upload under name.

        As in the constructor, files is keyed by the str form of name.
        """
        self.files.setdefault(debytify(name), []).append(
            FakeFile(name, content))
def _doRender(page, request):
result = page.render(request)
if isinstance(result, int) and result==server.NOT_DONE_YET:
# the thing is set up in a way that eventually some deferred
# will fire and complete
return request.deferred
elif isinstance(result, bytes):
request.write(result)
request.finish()
return request.deferred
else:
raise Exception("Unsupported render result: %s"%repr(result))
def _buildRequest(
        method,
        path,
        args,
        moreHeaders=None,
        requestClass=None):
    """returns a request for method on path with args.

    requestClass defaults to FakeRequest.  moreHeaders, if given, maps
    header names to single values; each is set as a raw request header.
    """
    factory = FakeRequest if requestClass is None else requestClass
    req = factory(path, args=args)
    req.headers = {}
    for headerName, headerValue in (moreHeaders or {}).items():
        req.requestHeaders.setRawHeaders(headerName, [headerValue])
    req.method = bytify(method)
    return req
def runQuery(page,
        method,
        path,
        args,
        moreHeaders=None,
        requestMogrifier=None,
        requestClass=None,
        produceErrorDocument=None):
    """runs a query on a page.

    The request is built by _buildRequest from method, path, args,
    moreHeaders and requestClass (FakeRequest by default).
    NOTE(review): the original docs said the query should look like it is
    coming from localhost; nothing in this function arranges for that.

    requestMogrifier, if given, is called with the request before it is
    processed and may modify it in place.

    The function returns a deferred firing a pair of the result (bytes)
    and the request (from which you can glean headers and such).

    produceErrorDocument must be a callable accepting a failure and the
    request if you want to exercise your error handling, too.  If you
    don't pass it in, exceptions raised while locating or rendering the
    resource are re-raised synchronously.
    """
    req = _buildRequest(
        method, path, args, moreHeaders=moreHeaders, requestClass=requestClass)
    if requestMogrifier is not None:
        requestMogrifier(req)

    try:
        rsc = resource.getChildForRequest(page, req)
        return _doRender(rsc, req)
    except Exception as ex:
        if produceErrorDocument:
            produceErrorDocument(failure.Failure(ex), req)
            # the error handler is expected to finish req, which fires this
            return req.deferred
        raise
class RenderTest(TrialTest):
    """a base class for tests of twisted web resources.

    Set renderer to the resource under test.  Set errorHandler to a
    callable accepting a failure and a request if rendering exceptions
    should be turned into error documents (it is passed on as runQuery's
    produceErrorDocument).

    The assert* methods return deferreds; test methods should return them
    so trial waits for them.
    """
    renderer = None # Override with the resource to be tested.
    errorHandler = None # override with a runQuery produceErrorDocument
    runQuery = staticmethod(runQuery)

    def assertStringsIn(self, result, strings, inverse=False,
            customTest=None):
        """checks a (content, request) pair as produced by runQuery.

        strings are checked using assertHasStrings; customTest, if given,
        is called with the content and may raise an AssertionError, in
        which case the content is dumped to remote.data.  The result is
        returned unchanged so this can sit in a callback chain.
        """
        # this wraps testhelpers.assertHasStrings to work better with
        # twisted results; in particular, we need to return the result.
        content = result[0]
        assertHasStrings(content, strings, inverse)
        try:
            if customTest is not None:
                customTest(content)
        except AssertionError:
            with open("remote.data", "wb") as f:
                f.write(content)
            raise
        return result

    def assertResultHasStrings(self, method, path, args, strings,
            rm=None, inverse=False, customTest=None):
        """runs method on path with args and checks the content for
        strings (or their absence if inverse is True).

        rm is a request mogrifier as accepted by runQuery.
        """
        return self.runQuery(
            self.renderer, method,
            path, args,
            requestMogrifier=rm,
            produceErrorDocument=self.errorHandler
        ).addCallback(self.assertStringsIn, strings, inverse=inverse,
            customTest=customTest)

    def assertGETHasStrings(self, path, args, strings, rm=None,
            customTest=None):
        """asserts a GET on path with args returns all of strings."""
        return self.assertResultHasStrings("GET",
            path, args, strings, rm=rm, customTest=customTest)

    def assertGETLacksStrings(self, path, args, strings, rm=None,
            customTest=None):
        """asserts a GET on path with args returns none of strings."""
        return self.assertResultHasStrings("GET",
            path, args, strings, rm=rm, inverse=True, customTest=customTest)

    def assertPOSTHasStrings(self, path, args, strings, rm=None,
            customTest=None):
        """asserts a POST on path with args returns all of strings."""
        return self.assertResultHasStrings("POST", path, args, strings,
            rm=rm, customTest=customTest)

    def assertStatus(self, path, status, args=None, rm=None):
        """asserts a GET on path with args results in the response code
        status.
        """
        # avoid a shared mutable default; runQuery still gets a dict
        if args is None:
            args = {}

        def check(res):
            self.assertEqual(res[1].code, status)
            return res

        return self.runQuery(
            self.renderer, "GET",
            path, args,
            requestMogrifier=rm,
            produceErrorDocument=self.errorHandler
        ).addCallback(check)

    def assertGETRaises(self, path, args, exc, alsoCheck=None):
        """asserts a GET on path with args fails with exc.

        alsoCheck, if given, is called with the trapped failure for further
        checks.
        """
        def cb(res):
            raise AssertionError("%s not raised (returned %s instead)"%(
                exc, res))

        def eb(flr):
            flr.trap(exc)
            if alsoCheck is not None:
                alsoCheck(flr)

        # maybeDeferred turns exceptions runQuery re-raises synchronously
        # into failures so eb sees them, too; addCallbacks keeps cb's
        # "not raised" error from being fed into (and trapped by) eb.
        return defer.maybeDeferred(self.runQuery, self.renderer, "GET",
            path, args,
            produceErrorDocument=self.errorHandler
        ).addCallbacks(cb, eb)

    def assertGETIsValid(self, path, args=None):
        """runs a GET on path with args and passes the (content, request)
        result to self.assertResponseIsValid.

        NOTE(review): assertResponseIsValid is not defined in this class;
        presumably subclasses or mixins provide it -- confirm.
        """
        # avoid a shared mutable default; runQuery still gets a dict
        if args is None:
            args = {}
        return self.runQuery(self.renderer, "GET",
            path, args,
            produceErrorDocument=self.errorHandler
        ).addCallback(self.assertResponseIsValid)