diff --git a/PyRDF/__init__.py b/PyRDF/__init__.py
index 64ef292..49f36aa 100644
--- a/PyRDF/__init__.py
+++ b/PyRDF/__init__.py
@@ -50,20 +50,17 @@ def use(backend_name, conf={}):
             necessary configuration parameters. Its default value is an
             empty dictionary {}.
     """
-    future_backends = [
-        "dask"
-    ]
 
     global current_backend
 
-    if backend_name in future_backends:
-        msg = "This backend environment will be considered in the future !"
-        raise NotImplementedError(msg)
-    elif backend_name == "local":
+    if backend_name == "local":
         current_backend = Local(conf)
     elif backend_name == "spark":
         from PyRDF.backend.Spark import Spark
         current_backend = Spark(conf)
+    elif backend_name == "dask":
+        from PyRDF.backend.Dask import Dask
+        current_backend = Dask(conf)
     else:
         msg = "Incorrect backend environment \"{}\"".format(backend_name)
         raise Exception(msg)
diff --git a/PyRDF/backend/Dask.py b/PyRDF/backend/Dask.py
new file mode 100644
index 0000000..3e76cde
--- /dev/null
+++ b/PyRDF/backend/Dask.py
@@ -0,0 +1,125 @@
+from __future__ import print_function
+
+import logging
+from collections import deque
+from pprint import pformat
+
+from PyRDF.backend.Dist import Dist
+
+import dask
+from dask.distributed import Client
+
+logger = logging.getLogger(__name__)
+
+
+class Dask(Dist):
+    """
+    Dask backend for PyRDF.
+
+    Attributes:
+        MIN_NPARTITIONS (int): Number of partitions used when
+            `self.npartitions` is not set to a truthy value.
+    """
+
+    MIN_NPARTITIONS = 2
+
+    def __init__(self, config=None):
+        """
+        Creates an instance of the Dask backend.
+
+        Args:
+            config (dict, optional): Configuration parameters. The keys
+                read by this class are `scheduler_address` and
+                `visualize_dask_graph`. Defaults to an empty dictionary.
+        """
+        # Create a fresh dictionary per call instead of sharing one mutable
+        # default object between all instances.
+        if config is None:
+            config = {}
+
+        super(Dask, self).__init__(config)
+
+        self.config = config
+        self.client = None
+        self.npartitions = self._get_partitions()
+
+        logger.debug("Creating {} instance with {} partitions".format(
+            type(self), self.npartitions))
+        logger.debug("Dask configuration:\n{}".format(
+            pformat(dask.config.config)))
+
+    def _get_partitions(self):
+        """
+        Estimate partitions of the dataset.
+
+        Returns:
+            int: The current value of `self.npartitions` if it is truthy,
+                `MIN_NPARTITIONS` otherwise.
+        """
+        # NOTE(review): `self.npartitions` is read before this class assigns
+        # it, so it is presumably set by `Dist.__init__` - confirm.
+        npartitions = (self.npartitions or Dask.MIN_NPARTITIONS)
+        return int(npartitions)
+
+    def ProcessAndMerge(self, mapper, reducer):
+        """
+        Performs map-reduce using Dask framework.
+
+        Args:
+            mapper (function): A function that runs the computational graph
+                and returns a list of values.
+
+            reducer (function): A function that merges two lists that were
+                returned by the mapper.
+
+        Returns:
+            list: A list representing the values of action nodes returned
+                after computation (Map-Reduce).
+        """
+
+        ranges = self.build_ranges()  # Get range pairs
+
+        # The Dask client has to be initialized inside some context and not on
+        # global scope since it's using Python Multiprocessing and each process
+        # fork needs independent environment (e.g. otherwise each process would
+        # try recreating a connection to the Dask client).
+        if self.client is None:
+            logger.debug("Connecting to Dask client.")
+            if self.config.get("scheduler_address"):
+                self.client = Client(self.config["scheduler_address"])
+            else:
+                # TODO: Investigate the case where processes=True
+                # On my laptop multiprocessing triggers some segfault
+                self.client = Client(processes=False)
+            logger.debug(
+                "Successfully connected to client {}".format(self.client))
+
+        dmapper = dask.delayed(mapper)
+        dreducer = dask.delayed(reducer)
+
+        # A deque gives O(1) removal from the left end of the queue
+        mergeables = deque(dmapper(current_range) for current_range in ranges)
+
+        # Repeatedly merge the two oldest entries and queue the result until
+        # a single delayed object remains.
+        while len(mergeables) > 1:
+            mergeables.append(
+                dreducer(mergeables.popleft(), mergeables.popleft()))
+
+        final_result = mergeables.pop()
+
+        if self.config.get("visualize_dask_graph"):
+            dask.visualize(final_result)
+
+        return final_result.compute()
+
+    def distribute_files(self, includes_list):
+        """
+        TODO: Implement file distribution to Dask workers.
+
+        Args:
+            includes_list (list): A list consisting of all necessary C++
+                files as strings, created by one of the `include` functions
+                of the PyRDF API.
+        """
+        pass
diff --git a/tests/integration/dask/test_histo_write_dask.py b/tests/integration/dask/test_histo_write_dask.py
new file mode 100644
index 0000000..9596c36
--- /dev/null
+++ b/tests/integration/dask/test_histo_write_dask.py
@@ -0,0 +1,89 @@
+import os
+import unittest
+from array import array
+
+import PyRDF
+
+import ROOT
+
+
+class DaskHistoWriteTest(unittest.TestCase):
+    """
+    Integration tests to check writing histograms to a `TFile` distributedly.
+    """
+
+    @classmethod
+    def setUpClass(cls):
+        """
+        Parameter initialization for the histogram.
+        """
+        cls.nentries = 10000  # Number of fills
+        cls.gaus_mean = 10  # Mean of the gaussian distribution
+        cls.gaus_stdev = 1  # Standard deviation of the gaussian distribution
+        cls.delta_equal = 0.01  # Delta to check for float equality
+
+    def tearDown(self):
+        """
+        Removes the .root files created by the test.
+
+        Doing this here instead of at the end of the test method makes sure
+        the files are deleted even when an assertion fails.
+        """
+        for filename in ("tree_gaus.root", "out_file.root"):
+            if os.path.exists(filename):
+                os.remove(filename)
+
+    def create_tree_with_data(self):
+        """Creates a .root file with some data"""
+        f = ROOT.TFile("tree_gaus.root", "recreate")
+        T = ROOT.TTree("Events", "Gaus(10,1)")
+
+        x = array("f", [0])
+        T.Branch("x", x, "x/F")
+
+        r = ROOT.TRandom()
+        # The branch will have a gaussian distribution with mean 10 and
+        # standard deviation 1
+        for _ in range(self.nentries):
+            x[0] = r.Gaus(self.gaus_mean, self.gaus_stdev)
+            T.Fill()
+
+        f.Write()
+        f.Close()
+
+    def test_write_histo(self):
+        """
+        Tests that a histogram is correctly written to a .root file created
+        before the execution of the event loop.
+        """
+        self.create_tree_with_data()
+
+        # Create a new file where the histogram will be written
+        outfile = ROOT.TFile("out_file.root", "recreate")
+
+        # Create a PyRDF RDataFrame on the tree written above, using the
+        # Dask backend
+        PyRDF.use("dask")
+        df = PyRDF.RDataFrame("Events", "tree_gaus.root")
+
+        # Create histogram
+        histo = df.Histo1D(("x", "x", 100, 0, 20), "x")
+
+        # Write histogram to out_file.root and close the file
+        histo.Write()
+        outfile.Close()
+
+        # Reopen file to check that histogram was correctly stored
+        reopen_file = ROOT.TFile("out_file.root", "read")
+        reopen_histo = reopen_file.Get("x")
+
+        # Check histogram statistics
+        self.assertEqual(reopen_histo.GetEntries(), self.nentries)
+        self.assertAlmostEqual(reopen_histo.GetMean(), self.gaus_mean,
+                               delta=self.delta_equal)
+        self.assertAlmostEqual(reopen_histo.GetStdDev(), self.gaus_stdev,
+                               delta=self.delta_equal)
+
+
+if __name__ == "__main__":
+    unittest.main()