"""
Helper functions and classes for unit tests and similar.
Whatever is useful to unit tests from here should be imported into
testhelpers, too. Unit test modules should not be forced to import
this.
"""
#c Copyright 2008-2023, the GAVO project <gavo@ari.uni-heidelberg.de>
#c
#c This program is free software, covered by the GNU GPL. See the
#c COPYING file in the source distribution.
import contextlib
import gzip
import os
import re
import tempfile
from lxml import etree
from gavo import base
from gavo import utils
from gavo.formal.testing import ( #noflake: exported names
FakeRequest, assertHasStrings)
from gavo.utils import stanxml
def _nukeNamespaces(xmlString):
"""removes namespace declarations from xmlString (which must be bytes).
This is for writing more compact tests and should of course not be
used outside of tests; in particular, you could easily fool the
mechanism to wreck your XML.
This always returns bytes.
"""
nsCleaner = re.compile(rb'^(</?)(?:[a-z0-9]+:)')
return re.sub(b"(?s)<[^>]*>",
lambda mat: nsCleaner.sub(rb"\1", mat.group()),
re.sub(b'xmlns="[^"]*"', b"", xmlString))
class _WrappedEtree:
"""a wrapper adding a few methods to an lxml etree.
This is done as a wrapper because you can't monkeypatch _Element.
See getXMLTree for what this about; it's essentially an implementation
detail of that function.
"""
def __init__(self, tree):
self._tree = tree
def __getattr__(self, name):
return getattr(self._tree, name)
def __getitem__(self, index):
return self._tree[index]
def uniqueXpath(self, path):
res = self.xpath(path)
assert len(res)==1, "Xpath %s gave %d matches"%(path, len(res))
return res[0]
def getById(self, id):
return self.uniqueXpath("//*[@id='%s']"%id)
def getByID(self, id):
return self.uniqueXpath("//*[@ID='%s']"%id)
def asString(self):
return etree.tostring(self._tree).decode("utf-8")
def getXMLTree(xmlString, debug=False):
    """parses xmlString into an lxml etree wrapped with some test
    conveniences, with namespaces on elements nuked.

    Before parsing, xmlString is passed through _nukeNamespaces, so
    elements can be addressed in xpaths without bothering with
    namespaces.  Since that function operates on bytes, bytes are what
    needs to be passed in here.

    Nuking namespaces is of course not a good idea in general, so think
    twice before using this in production code.

    With debug=True, the parsed tree is dumped through etree.dump before
    it is returned.

    The returned object passes unknown attributes on to the lxml element
    (so, for instance, its ``xpath`` method is available) and additionally
    has:

    * uniqueXpath(xpath), returning the single match for xpath and raising
      an assertion error if there are no or several matches.
    * getById(id), returning the unique element with that id and raising
      an assertion error if that doesn't exist.
    * getByID(id), as getById, but for VOTable-style ID.
    * asString(), returning a string representation of the tree.
    """
    strippedSource = _nukeNamespaces(xmlString)
    parsed = etree.fromstring(strippedSource)
    if debug:
        etree.dump(parsed)
    return _WrappedEtree(parsed)
class XSDResolver(etree.Resolver):
    """A resolver for external entities only returning in-tree files.

    Schema files are looked up by their file name in the directory
    basePath (``schemata``), located through base.getPathForDistFile.
    """
    def __init__(self):
        self.basePath = "schemata"

    def getPathForName(self, name):
        """returns the path of the distribution file for the schema name.

        Only the part of name after the last slash is used, so both
        plain file names and URLs ending in a file name work.
        """
        xsdName = name.split("/")[-1]
        return base.getPathForDistFile(
            os.path.join(self.basePath, xsdName))

    def resolve(self, url, pubid, context):
        """returns a resolved local file for url, or None if there is
        none.

        url may also be a namespace URI known to stanxml.NSRegistry; it is
        then translated to the corresponding schema location first.

        Any failure is reported through base.ui.notifyError rather than
        raised, and None is returned; according to that message, the
        caller will then fall back to network resources.
        """
        try:
            # resolve namespace URIs, too
            try:
                url = stanxml.NSRegistry.getSchemaForNS(url)
            except base.NotFoundError:
                # it's not a (known) namespace URI, try on
                pass

            path = self.getPathForName(url)
            res = self.resolve_filename(path, context)
            if res is not None:
                return res
        except Exception:
            # deliberately best-effort: fall through to the error message.
            # Exception rather than a bare except, so KeyboardInterrupt
            # and SystemExit are not swallowed here.
            pass

        base.ui.notifyError("Did not find local file for schema %s --"
            " this will fall back to network resources and thus probably"
            " be slow"%url)
        return None
# Module-level singletons: the resolver mapping schema names to in-tree
# files, and an lxml parser that has this resolver registered.  MyParser
# installs XSD_PARSER as lxml's default parser; getJointValidator uses
# RESOLVER directly to locate schema files.
RESOLVER = XSDResolver()
XSD_PARSER = etree.XMLParser()
XSD_PARSER.resolvers.add(RESOLVER)
@contextlib.contextmanager
def MyParser():
    """a context manager making XSD_PARSER lxml's default parser for the
    controlled block.

    If XSD_PARSER already is the default parser (e.g., in nested use),
    nothing is changed on entry or exit, so leaving an inner block does
    not pull the parser from under the outer one.  Otherwise, XSD_PARSER
    is installed on entry, and lxml's built-in default is restored on
    exit, also when the block raises.
    """
    # get_default_parser needs to be *called* here; comparing the function
    # object itself against XSD_PARSER is always false.
    if etree.get_default_parser() is XSD_PARSER:
        yield
    else:
        etree.set_default_parser(XSD_PARSER)
        try:
            yield
        finally:
            # no argument: reset to lxml's own default parser
            etree.set_default_parser()
class QNamer:
    """A hack that generates QNames through getattr.

    Construct with the desired namespace; after that, ``namer.foo``
    is an etree.QName for ``foo`` in that namespace.  Leading and trailing
    underscores are stripped from the attribute name, which lets you
    write names clashing with python keywords as, e.g., ``namer.import_``.
    """
    def __init__(self, ns):
        self.ns = ns

    def __getattr__(self, name):
        localName = name.strip("_")
        return etree.QName(self.ns, localName)
# QName factory for the XML Schema namespace; getJointValidator uses
# XS.schema and XS.import_ to build its combiner schema.
XS = QNamer("http://www.w3.org/2001/XMLSchema")
# File names of the schema files getDefaultValidator loads.  This must
# remain a list, as getDefaultValidator concatenates to it.
VO_SCHEMATA = [
    "Characterisation.xsd",
    "Colstats.xsd",  # remove once it's in VODataService
    "ConeSearch.xsd",
    "DaCHS.xsd",
    "DataModel.xsd",
    "DocRegExt.xsd",
    "eudat-core.xsd",
    "oai_dc.xsd",
    "OAI-PMH.xsd",
    "RegistryInterface.xsd",
    "SIA.xsd",
    "SLAP.xsd",
    "SSA.xsd",
    "StandardsRegExt.xsd",
    "stc.xsd",
    "stc-v1.20.xsd",
    "coords-v1.20.xsd",
    "region-v1.20.xsd",
    "TAPRegExt.xsd",
    "UWS.xsd",
    "VODataService.xsd",
    "VOEvent.xsd",
    "VOEventRegExt.xsd",
    "VORegistry.xsd",
    "VOResource.xsd",
    "VOSIAvailability.xsd",
    "VOSICapabilities.xsd",
    "VOSITables.xsd",
    "VOTable-1.1.xsd",
    "VOTable-1.2.xsd",
    "VOTable.xsd",
    "mivot.xsd",
    "vo-dml.xsd",
    "xlink.xsd",
    "XMLSchema.xsd",
    "xml.xsd",
]
def getJointValidator(schemaPaths):
    """returns an lxml validator containing the schemas in schemaPaths.

    Each item of schemaPaths is located through RESOLVER.getPathForName
    and parsed to find out its target namespace.  The validator is then
    built from a synthetic schema (target namespace ``urn:combiner``)
    having one ``import`` element per item; the schema locations given
    there are the items appended to
    ``http://vo.ari.uni-heidelberg.de/docs/schemata/``.

    All this happens with XSD_PARSER as lxml's default parser.
    """
    with MyParser():
        importAttrs = []
        for schemaName in schemaPaths:
            schemaRoot = etree.parse(
                RESOLVER.getPathForName(schemaName)).getroot()
            importAttrs.append({
                "namespace": schemaRoot.get("targetNamespace"),
                "schemaLocation":
                    "http://vo.ari.uni-heidelberg.de/docs/schemata/"+schemaName})

        combiner = etree.Element(
            XS.schema, attrib={"targetNamespace": "urn:combiner"})
        for attrs in importAttrs:
            etree.SubElement(combiner, XS.import_, attrib=attrs)

        return etree.XMLSchema(etree.ElementTree(combiner))
def getDefaultValidator(extraSchemata=None):
    """returns a validator that knows the schemata typically useful within
    the VO.

    extraSchemata, if given, is an iterable of further schema names to
    include after the ones from VO_SCHEMATA.

    This will currently only work if DaCHS is installed from an SVN
    checkout with setup.py develop.

    What's returned has a method assertValid(et) that raises an exception
    if the elementtree et is not valid.  You can simply call it to
    get back True for valid and False for invalid.
    """
    # None instead of a list literal as the default avoids a shared
    # mutable default argument; list() admits tuples and other iterables.
    return getJointValidator(VO_SCHEMATA+list(extraSchemata or ()))
def _makeLXMLValidator():
    """returns an lxml-based schema validating function for the VO XSDs

    This is not happening at import time as it is time-consuming, and the
    DaCHS server probably doesn't even validate anything.

    This is used below to build getXSDErrorsLXML.
    """
    VALIDATOR = getDefaultValidator()

    def getErrors(data, leaveOffending=False):
        """returns error messages for the XSD validation of the string in data.

        data can be serialised XML or something already parsed (which
        is assumed for anything having an ``xpath`` attribute).

        None is returned for valid documents, the validator's error log as
        a string otherwise.  Exceptions raised on the way (e.g., for
        malformed XML) are not propagated; their message is returned
        instead.

        With leaveOffending, invalid documents are written to
        badDocument.xml in the current directory.
        """
        try:
            with MyParser():
                if hasattr(data, "xpath"):
                    # we believe it's already parsed stuff
                    tree = data
                else:
                    tree = etree.fromstring(data)

                if VALIDATOR.validate(tree):
                    return None

                if leaveOffending:
                    if hasattr(data, "xpath"):
                        data = etree.tostring(data, encoding="utf-8")
                    elif isinstance(data, str):
                        # the dump file is opened in binary mode; without
                        # encoding, the resulting TypeError would replace
                        # the validation messages in the return value.
                        data = data.encode("utf-8")
                    with open("badDocument.xml", "wb") as of:
                        of.write(data)
                return str(VALIDATOR.error_log)
        except Exception as msg:
            # by contract, problems are returned as messages, not raised
            return str(msg)

    return getErrors
def getXSDErrorsLXML(data, leaveOffending=False):
    """returns error messages for the XSD validation of the string in data.

    The validating function is built by _makeLXMLValidator on the first
    call and then kept in the ``validate`` attribute of this function
    for later calls.
    """
    validate = getattr(getXSDErrorsLXML, "validate", None)
    if validate is None:
        validate = getXSDErrorsLXML.validate = _makeLXMLValidator()
    return validate(data, leaveOffending)


# the name the rest of this module uses for the XSD validation function
getXSDErrors = getXSDErrorsLXML
class XSDTestMixin:
    """provides a assertValidates method doing XSD validation.

    assertValidates raises an assertion error with the validator's
    messages on an error.  You can optionally pass a leaveOffending
    argument to make the method store the offending document in
    badDocument.xml.
    """
    def assertValidates(self, xmlSource, leaveOffending=False):
        """raises an AssertionError carrying the messages getXSDErrors
        returns for xmlSource, if there are any.
        """
        problems = getXSDErrors(xmlSource, leaveOffending)
        if not problems:
            return
        raise AssertionError(problems)
def getMemDiffer(ofClass=base.Structure):
    """returns a function that, when called, returns a list of the
    instances of ofClass that have appeared since getMemDiffer was called.

    "Appeared" is decided on the basis of the ids of the objects the
    garbage collector knows about at the time of this call.

    If you watch everything, things get hairy because of course the state
    of this function (for instance) also creates references.  Hence, pass
    ofClass to choose what the function will track.

    This will call a gc.collect itself (and wouldn't make sense without
    that), both now and whenever the returned function is called.
    """
    import gc

    def isWatched(candidate):
        # a ReferenceError here means the object is already essentially
        # gone, so there is no point worrying about it.
        try:
            return isinstance(candidate, ofClass)
        except ReferenceError:
            return False

    gc.collect()
    seen_ids = set()
    for candidate in gc.get_objects():
        if isWatched(candidate):
            seen_ids.add(id(candidate))
    # don't keep the last object inspected alive through the loop variable
    del candidate

    def getNewObjects():
        gc.collect()
        newObjects = []
        for candidate in gc.get_objects():
            if id(candidate) in seen_ids:
                continue
            if isWatched(candidate):
                newObjects.append(candidate)
        return newObjects

    return getNewObjects
def getUnreferenced(items):
    """returns a list of elements in items that do not have a reference
    from any other in items.

    References are what gc.get_referrers reports; only direct referrers
    that are themselves elements of items count.  The order of items is
    kept in the result.
    """
    import gc

    knownIds = {id(item) for item in items}
    return [item for item in items
        if knownIds.isdisjoint(
            id(referrer) for referrer in gc.get_referrers(item))]
def debugReferenceChain(ob):
    """a sort-of-interactive way to investigate where ob is referenced.

    This prints the current object and then goes through its referrers
    (as reported by gc.get_referrers) one at a time, prompting for a
    command after each of them:

    * d -- enter pdb (look at ob, perhaps at nob)
    * u -- follow, i.e., make the referrer just shown the current object
    * x -- continue execution (i.e., return from this function)
    * ? -- print the available commands

    Any other input (including an empty one) moves on to the next
    referrer, as do d and ? once they are done.  When the referrers are
    used up without one having been followed, the referrers of the current
    object are listed again.  The function also returns when the current
    object has no referrers at all.
    """
    import gc

    while True:
        print("Current object: ", repr(ob))
        refs = gc.get_referrers(ob)
        if not refs:
            print("Not referenced -- exiting")
            break

        while refs:
            # referrers are taken from the end of the list; the number
            # printed is how many are still left after this one
            nob = refs.pop()
            print(len(refs), utils.makeEllipsis(repr(nob)))
            res = input("?")

            if res=="d":
                import pdb;pdb.Pdb(nosigint=True).set_trace()
            elif res=="x":
                return
            elif res=="u":
                # restart the outer loop with the referrer as the object
                ob = nob
                break
            elif res=="?":
                print("d, x, u, <empty>")
            elif not refs:
                print("Referrers exhausted, warping")
# NOTE(review): not referenced by any code visible in this module;
# presumably scratch state for memory debugging -- confirm before removing.
NEWIDS = set()
def memdebug(watchClass=base.Structure):
    """a debug method to track memory usage after some code has run.

    This is typically run from ArchiveService.getChild, since request
    processing should be idempotent wrt memory after initial caching.

    On each call, this prints the number of objects the garbage collector
    manages.  From the second call on, it also reports on the objects
    that have appeared since the previous call (using the differ stored
    in base.getNewStructs) and starts debugReferenceChain on the first
    new instance of watchClass, if there is one.

    This is for editing in place by DaCHS plumbers; accordingly, you're
    not supposed to make sense of this.
    """
    import gc
    print(">>>>>> total managed:", len(gc.get_objects()))

    # base.getNewStructs only exists after the first call (see below)
    if hasattr(base, "getNewStructs"):
        ns = base.getNewStructs()
        print(">>>>>> new objects:", len(ns))

        if len(ns)<11000:
            ur = getUnreferenced(ns)
            print(">>>>>> new externally referenced:", len(ur))
            del ur
            print([ob for ob in ns if isinstance(ob, watchClass)])

        # a switch to be flipped in place by whoever is debugging
        if True:
            try:
                debugReferenceChain(
                    [ob for ob in ns if isinstance(ob, watchClass)][0])
            except IndexError:
                # no new instance of watchClass, nothing to chase
                pass

    # take a fresh baseline for the next call
    base.getNewStructs = getMemDiffer(ofClass=watchClass)
@contextlib.contextmanager
def testFile(name,
        content,
        writeGz=False,
        inDir=base.getConfig("tempDir"),
        timestamp=None):
    """a context manager that creates a file name with content in inDir.

    The full path name is returned.

    content can be bytes or str; in the latter case, it's utf-8 encoded
    before writing.

    With writeGz=True, content is gzipped on the fly (don't do this if
    the data already is gzipped).

    You can pass in name=None to get a temporary file name if you don't care
    about the name.

    If timestamp is not None, both the access and the modification time
    of the file are set to it (this includes a timestamp of 0).

    inDir will be created as a side effect if it doesn't exist but (right
    now, at least), not be removed.

    The file is removed when the controlled block is left, and also when
    writing it fails.
    """
    # exist_ok avoids a race between checking for and creating inDir
    os.makedirs(inDir, exist_ok=True)

    if name is None:
        handle, destName = tempfile.mkstemp(dir=inDir)
        os.close(handle)
    else:
        destName = os.path.join(inDir, name)

    try:
        opener = gzip.GzipFile if writeGz else open
        # with makes sure the handle is closed even if writing fails
        with opener(destName, mode="wb") as f:
            f.write(utils.bytify(content))

        if timestamp is not None:
            os.utime(destName, times=(timestamp, timestamp))

        yield destName
    finally:
        try:
            os.unlink(destName)
        except OSError:
            # the file may have been removed by the controlled block (or
            # never have been created); either way, it is gone.
            pass
@contextlib.contextmanager
def collectedEvents(*kinds):
    """a context manager collecting event arguments for a while.

    For each of kinds, a handler is subscribed with base.ui while the
    controlled block runs; the handlers are unsubscribed again when the
    block is left, whether or not it raised.

    The yielded thing is a list that receives one tuple per event, made
    up of the event name followed by the event arguments.
    """
    collected = []

    def recorderFor(evType):
        # a factory is used so each handler is bound to its own evType
        return lambda *args: collected.append((evType,)+args)

    subscriptions = []
    for kind in kinds:
        subscriptions.append((kind, recorderFor(kind)))

    for kind, recorder in subscriptions:
        base.ui.subscribe(kind, recorder)

    try:
        yield collected
    finally:
        for kind, recorder in subscriptions:
            base.ui.unsubscribe(kind, recorder)