"""
An abstract processor and some helping code.
Currently, I assume a plain text interface for those. It might be
a good idea to use the event mechanism here.
"""
#c Copyright 2008-2023, the GAVO project <gavo@ari.uni-heidelberg.de>
#c
#c This program is free software, covered by the GNU GPL. See the
#c COPYING file in the source distribution.

import io
import os
import sys
import textwrap
import threading
import traceback
from contextlib import contextmanager

from PIL import Image

from gavo import base
from gavo import rsc
from gavo import utils
from gavo.helpers import anet
from gavo.helpers import fitstricks
from gavo.utils import fitstools
from gavo.utils import pyfits
# matplotlib is an expensive import. Only do that if we're sure
# we need it.
matplotlib = utils.DeferredImport("matplotlib", "import matplotlib")
figure = utils.DeferredImport("figure", "from matplotlib import figure")
_MPL_LOCK = threading.RLock()

@contextmanager
def matplotlibLock():
    _MPL_LOCK.acquire()
    matplotlib.use("Agg")
    try:
        yield
    finally:
        _MPL_LOCK.release()
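
# Hypothetical usage sketch (not part of this module): processor code
# that renders through matplotlib should hold the lock so concurrent
# workers do not trample matplotlib's global state, e.g.
#
#   with matplotlibLock():
#       fig = figure.Figure(figsize=(4, 2))
#       ...draw on fig and save it...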


class FileProcessor(object):
    """An abstract base for a source file processor.

    In concrete classes, you need to define a ``process(fName)`` method
    receiving a source as returned by the dd (i.e., usually a file name).

    You can override the method ``_createAuxiliaries(dataDesc)`` to compute
    things like source catalogues, etc. Thus, you should not need to
    override the constructor.

    These objects are usually constructed through ``api.procmain`` as
    discussed in :dachsdoc:`processing.html`.
    """
    inputsDir = base.getConfig("inputsDir")

    def __init__(self, opts, dd):
        self.opts, self.dd = opts, dd
        self._createAuxiliaries(dd)

    def _createAuxiliaries(self, dd):
        # There's been a typo here in previous DaCHS versions; try
        # to call old methods if they are still there
        if hasattr(self, "_createAuxillaries"):
            self._createAuxillaries(dd)
    def classify(self, fName):
        return "unknown"

    def process(self, fName):
        pass

    def addClassification(self, fName):
        label = self.classify(fName)
        self.reportDict.setdefault(label, []).append(os.path.basename(fName))

    def printTableSize(self):
        try:
            tableName = self.dd.makes[0].table.getQName()
            with base.AdhocQuerier(base.getAdminConn) as q:
                itemsInDB = list(q.query("SELECT count(*) from %s"%tableName))[0][0]
            print("Items currently in assumed database table: %d\n"%itemsInDB)
        except (base.DBError, IndexError):
            pass
    def printReport(self, processed, ignored):
        print("\n\nProcessor Report\n================\n")
        if ignored:
            print("Warning: There were %d errors during classification"%ignored)
        repData = sorted(zip([len(l) for l in list(self.reportDict.values())],
            list(self.reportDict.keys())))
        print(utils.formatSimpleTable(repData))
        print("\n")
        self.printTableSize()

    def printVerboseReport(self, processed, ignored):
        print("\n\nProcessor Report\n================\n")
        if ignored:
            print("Warning: There were %d errors during classification"%ignored)
        repData = sorted(zip(list(self.reportDict.values()),
                list(self.reportDict.keys())),
            key=lambda v: -len(v[0]))
        print("\n%s\n%s\n"%(repData[0][1], "-"*len(repData[0][1])))
        print("%d items\n"%(len(repData[0][0])))
        for items, label in repData[1:]:
            print("\n%s\n%s\n"%(label, "-"*len(label)))
            items.sort()
            print("%d items:\n"%(len(items)))
            print("\n".join(textwrap.wrap(", ".join(items))))
        print("\n")
        self.printTableSize()
    @staticmethod
    def addOptions(parser):
        parser.add_option("--filter", dest="requireFrag", metavar="STR",
            help="Only process files with names containing STR", default=None)
        parser.add_option("--bail", help="Bail out on a processor error,"
            " dumping a traceback", action="store_true", dest="bailOnError",
            default=False)
        parser.add_option("--report", help="Output a report only",
            action="store_true", dest="doReport", default=False)
        parser.add_option("--verbose", help="Be more talkative",
            action="store_true", dest="beVerbose", default=False)
        parser.add_option("--n-procs", "-j", help="Run NUM processes in"
            " parallel", action="store", dest="nParallel", default=1,
            metavar="NUM", type=int)

    _doneSentinel = ("MAGIC: QUEUE DONE",)
    def iterJobs(self, nParallel):
        """executes process() in parallel for all sources and iterates
        over the results.

        We use this rather than multiprocessing's Pool, as that cannot
        call methods. I'm working around this here.
        """
        import multiprocessing
        taskQueue = multiprocessing.Queue(nParallel*4)
        doneQueue = multiprocessing.Queue()

        def worker(inQueue, outQueue):
            if hasattr(self, "conn"):
                self.conn = base.getDBConnection("trustedquery")
            for srcId in iter(inQueue.get, None):
                if (self.opts.requireFrag is not None
                        and not self.opts.requireFrag in srcId):
                    continue
                try:
                    outQueue.put(self.process(srcId))
                except base.SkipThis:
                    continue
                except Exception as ex:
                    ex.source = srcId
                    if self.opts.bailOnError:
                        sys.stderr.write("*** %s\n"%srcId)
                        traceback.print_exc()
                    outQueue.put(ex)
            outQueue.put(self._doneSentinel)

        # create nParallel workers
        activeWorkers = 0
        # close my connection; it'll be nothing but trouble once
        # the workers see (and close) it, too
        if hasattr(self, "conn"):
            self.conn.close()
        for i in range(nParallel):
            multiprocessing.Process(target=worker,
                args=(taskQueue, doneQueue)).start()
            activeWorkers += 1
        if hasattr(self, "conn"):
            self.conn = base.getDBConnection("trustedquery")

        # feed them their tasks
        toDo = self.iterIdentifiers()
        while True:
            try:
                taskQueue.put(next(toDo))
            except StopIteration:
                break
            while not doneQueue.empty():
                yield doneQueue.get()

        # ask them to quit and wait until all have said they're quitting
        for i in range(nParallel):
            taskQueue.put(None)
        taskQueue.close()
        while activeWorkers:
            item = doneQueue.get()
            if item==self._doneSentinel:
                activeWorkers -= 1
            else:
                yield item
    def _runProcessor(self, procFunc, nParallel=1):
        """calls procFunc for all sources in self.dd.

        For nParallel==1, this is a plain, single-tasking loop; for
        larger values, the work is farmed out to iterJobs.
        """
        processed, ignored = 0, 0
        if nParallel==1:
            def iterProcResults():
                for source in self.iterIdentifiers():
                    if (self.opts.requireFrag is not None
                            and not self.opts.requireFrag in source):
                        continue
                    try:
                        yield procFunc(source)
                    except base.SkipThis:
                        continue
                    except Exception as ex:
                        ex.source = source
                        if self.opts.bailOnError:
                            sys.stderr.write("*** %s\n"%source)
                            traceback.print_exc()
                        yield ex
            resIter = iterProcResults()
        else:
            resIter = self.iterJobs(nParallel)

        while True:
            try:
                res = next(resIter)
                if isinstance(res, Exception):
                    raise res
            except StopIteration:
                break
            except KeyboardInterrupt:
                sys.exit(2)
            except Exception as msg:
                if self.opts.bailOnError:
                    sys.exit(1)
                sys.stderr.write("Skipping source %s: (%s, %s)\n"%(
                    getattr(msg, "source", "(unknown)"), msg.__class__.__name__,
                    repr(msg)))
                ignored += 1
            processed += 1
            sys.stdout.write("%6d (-%5d)\r"%(processed, ignored))
            sys.stdout.flush()
        return processed, ignored
    def iterIdentifiers(self):
        """iterates over all identifiers that should be processed.

        This is usually the paths of the files to be processed.
        You can, however, override it to do something else if that
        fits your problem (example: previews in SSA use the accref).
        """
        return iter(self.dd.sources)
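
    # An override sketch (hypothetical table name): a processor keyed on
    # database rows rather than file paths could yield accrefs instead:
    #
    #   def iterIdentifiers(self):
    #       with base.getTableConn() as conn:
    #           for row in conn.queryToDicts("SELECT accref FROM my.table"):
    #               yield row["accref"]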
    def processAll(self):
        """calls the process method of this processor for all sources of
        the data descriptor dd.
        """
        if self.opts.doReport:
            self.reportDict = {}
            procFunc = self.addClassification
        else:
            procFunc = self.process

        processed, ignored = self._runProcessor(procFunc,
            nParallel=self.opts.nParallel)

        if self.opts.doReport:
            if self.opts.beVerbose:
                self.printVerboseReport(processed, ignored)
            else:
                self.printReport(processed, ignored)
        return processed, ignored
    ######### Utility methods

    def getProductKey(self, srcName):
        return utils.getRelativePath(srcName, self.inputsDir)
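    # For example, assuming the default inputsDir of /var/gavo/inputs
    # and a (hypothetical) srcName of /var/gavo/inputs/myres/data/obj1.fits,
    # this returns the inputs-relative key myres/data/obj1.fits.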


class PreviewMaker(FileProcessor):
    """A file processor for generating previews.

    For these, define a method ``getPreviewData(accref) -> bytes`` returning
    the raw preview data.
    """
    @staticmethod
    def addOptions(optParser):
        FileProcessor.addOptions(optParser)
        optParser.add_option("--force", help="Generate previews even"
            " where they already exist", action="store_true",
            dest="force", default=False)

    def iterIdentifiers(self):
        """iterates over the accrefs in the first table of dd.
        """
        tableId = self.dd.makes[0].table.getQName()
        for r in self.conn.queryToDicts("select accref from %s"%tableId):
            yield r["accref"]
    def _createAuxiliaries(self, dd):
        self.previewDir = dd.rd.getAbsPath(
            dd.getProperty("previewDir"))
        if not os.path.isdir(self.previewDir):
            os.makedirs(self.previewDir)
        self.conn = base.getDBConnection("trustedquery")
        FileProcessor._createAuxiliaries(self, dd)
    def getPreviewPath(self, accref):
        res = list(self.conn.query("select preview from dc.products where"
            " accref=%(accref)s", {"accref": accref}))
        if not res:
            raise IOError("%s is not in the products table. Update/import"
                " the resource?"%accref)
        if res[0][0]=="AUTO":
            raise base.ReportableError("Preview path in the product table is AUTO."
                " Will not write preview there. Make sure you have properly"
                " bound the preview parameter in //products#define and re-import.")
        return os.path.join(self.inputsDir, res[0][0])
    def classify(self, path):
        if not path.startswith("/"):
            # It's probably an accref when we're not called from process
            path = self.getPreviewPath(path)

        if self.opts.force:
            return "without"
        if os.path.exists(path):
            return "with"
        else:
            return "without"

    def process(self, accref):
        path = self.getPreviewPath(accref)
        if self.classify(path)=="with":
            return

        dirPart = os.path.dirname(path)
        if not os.path.exists(dirPart):
            os.makedirs(dirPart)
            os.chmod(dirPart, 0o775)

        with utils.safeReplaced(path) as f:
            f.write(self.getPreviewData(accref))
        os.chmod(path, 0o664)


class SpectralPreviewMaker(PreviewMaker):
    linearFluxes = False
    spectralColumn = "spectral"
    fluxColumn = "flux"
    connectPoints = True

    def _createAuxiliaries(self, dd):
        PreviewMaker._createAuxiliaries(self, dd)
        self.sdmDD = self.dd.rd.getById(self.sdmId)
    @staticmethod
    def get2DPlot(tuples, linear=False, connectPoints=True):
        """returns a png-compressed pixel image for a 2D plot of (x,y)
        tuples.
        """
        with matplotlibLock():
            fig = figure.Figure(figsize=(4,2))
            ax = fig.add_axes([0,0,1,1], frameon=False)

            if linear:
                plotter = ax.plot
            else:
                plotter = ax.semilogy
            if connectPoints:
                linestyle = "-"
            else:
                linestyle = "o"

            plotter(
                [r[0] for r in tuples],
                [r[1] for r in tuples],
                linestyle,
                color="black")
            ax.xaxis.set_major_locator(matplotlib.ticker.NullLocator())
            ax.yaxis.set_major_locator(matplotlib.ticker.NullLocator())
            ax.yaxis.set_minor_locator(matplotlib.ticker.NullLocator())
            rendered = io.BytesIO()
            fig.savefig(rendered, format="png", dpi=50)

        rendered = io.BytesIO(rendered.getvalue())
        im = Image.open(rendered)
        im = im.convert("L")
        im = im.convert("P", palette=Image.ADAPTIVE, colors=8)
        compressed = io.BytesIO()
        im.save(compressed, format="png", bits=3)
        return compressed.getvalue()
    def getPreviewData(self, accref):
        table = rsc.makeData(self.sdmDD, forceSource={
            "accref": accref}).getPrimaryTable()
        data = [(r[self.spectralColumn], r[self.fluxColumn]) for r in table.rows]
        data.sort()
        return self.get2DPlot(data, self.linearFluxes, self.connectPoints)
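

# SpectralPreviewMaker is typically used by subclassing it and setting
# the sdmId class attribute to the id of the data descriptor that builds
# the SDM (spectral data model) table; a sketch with a hypothetical id:
#
#   class SpecPreviews(SpectralPreviewMaker):
#       sdmId = "build_sdm_data"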


def procmain(processorClass, rdId, ddId):
    """The "standard" main function for processor scripts.

    The function returns the instantiated processor so you can communicate
    from your processor back to your own main.

    See :dachsdoc:`processors.html` for details.
    """
    import optparse
    from gavo import rscdesc  #noflake: for registration

    rd = base.caches.getRD(rdId)
    dd = rd.getById(ddId)

    parser = optparse.OptionParser()
    processorClass.addOptions(parser)
    opts, args = parser.parse_args()
    if args:
        parser.print_help(file=sys.stderr)
        sys.exit(1)

    if opts.beVerbose:
        from gavo.user import logui
        logui.LoggingUI(base.ui)
        base.DEBUG = True

    proc = processorClass(opts, dd)
    processed, ignored = proc.processAll()
    print("%s files processed, %s files with errors"%(processed, ignored))
    return proc
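

# A processor script would then typically end like this ("myres/q" and
# "import" are hypothetical rd and data descriptor ids):
#
#   if __name__=="__main__":
#       procmain(MyProcessor, "myres/q", "import")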