Source code for rootpy.tree.tree

import sys
import re
import fnmatch

import ROOT

from ..types import Variable
from ..core import Object, snake_case_methods, RequireFile
from ..plotting.core import Plottable
from ..plotting import Hist, Canvas
from ..registry import register
from ..utils import asrootpy
from .. import rootpy_globals as _globals
from .treeobject import TreeCollection, TreeObject
from .cut import Cut
from .buffer import TreeBuffer
from .model import TreeModel


[docs]class UserData(object): pass
@snake_case_methods @register()
[docs]class Tree(Object, Plottable, RequireFile, ROOT.TTree): """ Inherits from TTree so all regular TTree methods are available but certain methods (i.e. Draw) have been overridden to improve usage in Python """ DRAW_PATTERN = re.compile( '^(?P<branches>.+?)(?P<redirect>\>\>[\+]?(?P<name>[^\(]+).*)?$') def __init__(self, name=None, title=None, model=None, file=None, ignore_unsupported=False): if file: file.cd() RequireFile.__init__(self) Object.__init__(self, name, title) self._ignore_unsupported = ignore_unsupported if model is not None: self.buffer = TreeBuffer(ignore_unsupported=ignore_unsupported) if not issubclass(model, TreeModel): raise TypeError("the model must subclass TreeModel") self.set_buffer(model(ignore_unsupported=ignore_unsupported), create_branches=True) self._post_init(ignore_unsupported=ignore_unsupported) def _post_init(self, ignore_unsupported=False): self._ignore_unsupported = ignore_unsupported if not hasattr(self, "buffer"): self.buffer = TreeBuffer( ignore_unsupported=ignore_unsupported) self.set_buffer(self.create_buffer()) Plottable.__init__(self) self._use_cache = False self._branch_cache = {} self._current_entry = 0 self._always_read = [] self.userdata = UserData() self._inited = True
[docs] def always_read(self, branches): if type(branches) not in (list, tuple): raise TypeError("branches must be a list or tuple") self._always_read = branches
[docs] def use_cache(self, cache=True, cache_size=10000000, learn_entries=1): if cache: self.buffer.set_tree(self) self.SetCacheSize(cache_size) ROOT.TTreeCache.SetLearnEntries(learn_entries) else: self.buffer.set_tree(None) # was the cache previously enabled? if self._use_cache: self.SetCacheSize(-1) self._use_cache = cache
@classmethod
[docs] def branch_type(cls, branch): typename = branch.GetClassName() if not typename: leaf = branch.GetListOfLeaves()[0] typename = leaf.GetTypeName() # check if leaf has multiple elements length = leaf.GetLen() if length > 1: typename = '%s[%d]' % (typename, length) return typename
@classmethod
[docs] def branch_is_supported(cls, branch): """ Currently the branch must only have one leaf but the leaf may have one or multiple elements """ return branch.GetNleaves() == 1
[docs] def create_buffer(self): buffer = [] for branch in self.iterbranches(): if (Tree.branch_is_supported(branch) and self.GetBranchStatus(branch.GetName())): buffer.append((branch.GetName(), Tree.branch_type(branch))) return TreeBuffer(buffer, ignore_unsupported=self._ignore_unsupported)
[docs] def create_branches(self, branches): if not isinstance(branches, TreeBuffer): branches = TreeBuffer(branches, ignore_unsupported=self._ignore_unsupported) self.set_buffer(branches, create_branches=True)
[docs] def update_buffer(self, buffer, transfer_objects=False): if self.buffer is not None: self.buffer.update(buffer) if transfer_objects: self.buffer.set_objects(buffer) else: self.buffer = buffer
[docs] def set_buffer(self, buffer, branches=None, ignore_branches=None, create_branches=False, visible=True, ignore_missing=False, transfer_objects=False): # determine branches to keep all_branches = buffer.keys() if branches is None: branches = all_branches if ignore_branches is None: ignore_branches = [] branches = (set(all_branches) & set(branches)) - set(ignore_branches) if create_branches: for name in branches: value = buffer[name] if self.has_branch(name): raise ValueError( "Attempting to create two branches " "with the same name: %s" % name) if isinstance(value, Variable): self.Branch(name, value, "%s/%s" % (name, value.type)) else: self.Branch(name, value) else: for name in branches: value = buffer[name] if self.has_branch(name): self.SetBranchAddress(name, value) elif not ignore_missing: raise ValueError( "Attempting to set address for " "branch %s which does not exist" % name) if visible: newbuffer = TreeBuffer(ignore_unsupported=self._ignore_unsupported) for branch in branches: if branch in buffer: newbuffer[branch] = buffer[branch] newbuffer.set_objects(buffer) buffer = newbuffer self.update_buffer(buffer, transfer_objects=transfer_objects)
[docs] def activate(self, branches, exclusive=False): if exclusive: self.SetBranchStatus('*', 0) if isinstance(branches, basestring): branches = [branches] for branch in branches: if '*' in branch: matched_branches = self.glob(branch) for b in matched_branches: self.SetBranchStatus(b, 1) elif self.has_branch(branch): self.SetBranchStatus(branch, 1)
[docs] def deactivate(self, branches, exclusive=False): if exclusive: self.SetBranchStatus('*', 1) if isinstance(branches, basestring): branches = [branches] for branch in branches: if '*' in branch: matched_branches = self.glob(branch) for b in matched_branches: self.SetBranchStatus(b, 0) elif self.has_branch(branch): self.SetBranchStatus(branch, 0)
@property
[docs] def branches(self): return [branch for branch in self.GetListOfBranches()]
[docs] def iterbranches(self): for branch in self.GetListOfBranches(): yield branch
@property
[docs] def branchnames(self): return [branch.GetName() for branch in self.GetListOfBranches()]
[docs] def iterbranchnames(self): for branch in self.iterbranches(): yield branch.GetName()
[docs] def glob(self, patterns, prune=None): """ Return a list of branch names that match pattern. Exclude all matched branch names which also match a pattern in prune. prune may be a string or list of strings. """ if isinstance(patterns, basestring): patterns = [patterns] if isinstance(prune, basestring): prune = [prune] matches = [] for pattern in patterns: matches += fnmatch.filter(self.iterbranchnames(), pattern) if prune is not None: for prune_pattern in prune: matches = [match for match in matches if not fnmatch.fnmatch(match, prune_pattern)] return matches
def __getitem__(self, item): if isinstance(item, basestring): return self.buffer[item] if not (0 <= item < len(self)): raise IndexError("entry index out of range") self.GetEntry(item) return self
[docs] def GetEntry(self, entry): self.buffer.reset_collections() return ROOT.TTree.GetEntry(self, entry)
def __iter__(self): if self._use_cache: for i in xrange(self.GetEntries()): self._current_entry = i self.LoadTree(i) for attr in self._always_read: try: self._branch_cache[attr].GetEntry(i) except KeyError: # one-time hit branch = self.GetBranch(attr) if not branch: raise AttributeError( "branch %s specified in " "'always_read' does not exist" % attr) self._branch_cache[attr] = branch branch.GetEntry(i) self.buffer._entry.set(i) yield self.buffer self.buffer.next_entry() self.buffer.reset_collections() else: i = 0 while self.GetEntry(i): self.buffer._entry.set(i) yield self.buffer i += 1 def __setattr__(self, attr, value): if '_inited' not in self.__dict__ or attr in self.__dict__: return super(Tree, self).__setattr__(attr, value) try: return self.buffer.__setattr__(attr, value) except AttributeError: raise AttributeError( "%s instance has no attribute '%s'" % \ (self.__class__.__name__, attr)) def __getattr__(self, attr): if '_inited' not in self.__dict__: raise AttributeError("%s instance has no attribute '%s'" % \ (self.__class__.__name__, attr)) try: return getattr(self.buffer, attr) except AttributeError: raise AttributeError("%s instance has no attribute '%s'" % \ (self.__class__.__name__, attr)) def __setitem__(self, item, value): self.buffer[item] = value def __len__(self): return self.GetEntries() def __contains__(self, branch): return self.has_branch(branch)
[docs] def has_branch(self, branch): return not not self.GetBranch(branch)
[docs] def csv(self, sep=',', branches=None, include_labels=True, limit=None, stream=None): """ Print csv representation of tree only including branches of basic types (no objects, vectors, etc..) """ if stream is None: stream = sys.stdout if branches is None: branches = self.buffer.keys() branches = dict([(name, self.buffer[name]) for name in branches if isinstance(self.buffer[name], Variable)]) if not branches: return if include_labels: print >> stream, sep.join(branches.keys()) # even though 'entry' is not used, enumerate or simply iterating over # self is required to update the buffer with the new branch values at # each tree entry. for i, entry in enumerate(self): print >> stream, sep.join([str(v.value) for v in branches.values()]) if limit is not None and i + 1 == limit: break
[docs] def Scale(self, value): self.SetWeight(self.GetWeight() * value)
[docs] def GetEntries(self, cut=None, weighted_cut=None, weighted=False): if weighted_cut: hist = Hist(1, -1, 2) branch = self.GetListOfBranches()[0].GetName() weight = self.GetWeight() self.SetWeight(1) self.Draw("%s==%s>>%s" % (branch, branch, hist.GetName()), weighted_cut * cut) self.SetWeight(weight) entries = hist.Integral() elif cut: entries = ROOT.TTree.GetEntries(self, str(cut)) else: entries = ROOT.TTree.GetEntries(self) if weighted: entries *= self.GetWeight() return entries
[docs] def GetMaximum(self, expression, cut=None): if cut: self.Draw(expression, cut, "goff") else: self.Draw(expression, "", "goff") vals = self.GetV1() n = self.GetSelectedRows() vals = [vals[i] for i in xrange(min(n, 10000))] return max(vals)
[docs] def GetMinimum(self, expression, cut=None): if cut: self.Draw(expression, cut, "goff") else: self.Draw(expression, "", "goff") vals = self.GetV1() n = self.GetSelectedRows() vals = [vals[i] for i in xrange(min(n, 10000))] return min(vals)
[docs] def CopyTree(self, selection, *args, **kwargs): """ Convert selection (tree.Cut) to string """ return super(Tree, self).CopyTree(str(selection), *args, **kwargs)
[docs] def reset_branch_values(self): self.buffer.reset()
[docs] def Fill(self, reset=False): super(Tree, self).Fill() # reset all branches if reset: self.buffer.reset()
@RequireFile.cd
[docs] def Write(self, *args, **kwargs): ROOT.TTree.Write(self, *args, **kwargs)
[docs] def Draw(self, expression, selection="", options="", hist=None, min=None, max=None, bins=None, **kwargs): """ Draw a TTree with a selection as usual, but return the created histogram. """ if isinstance(expression, (list, tuple)): expressions = expression else: expressions = [expression] if not isinstance(selection, Cut): # let Cut handle any extra processing (i.e. ternary operators) selection = Cut(selection) local_hist = None if hist is not None: # handle graphics ourselves if options: options += ' ' options += 'goff' expressions = ['%s>>+%s' % (expr, hist.GetName()) for expr in expressions] elif min is not None or max is not None: # handle graphics ourselves if options: options += ' ' options += 'goff' if min is None: if max > 0: min = 0 else: raise ValueError('must specify minimum') elif max is None: if min < 0: max = 0 else: raise ValueError('must specify maximum') if bins is None: bins = 100 local_hist = Hist(bins, min, max, **kwargs) expressions = ['%s>>+%s' % (expr, local_hist.GetName()) for expr in expressions] else: if 'goff' not in options: if not _globals.pad: _globals.pad = Canvas() pad = _globals.pad pad.cd() match = re.match(Tree.DRAW_PATTERN, expressions[0]) histname = None if match and match.groupdict()['name']: histname = match.groupdict()['name'] for expr in expressions: match = re.match(Tree.DRAW_PATTERN, expr) if not match: raise ValueError('not a valid draw expression: %s' % expr) # reverse variable order to match order in hist constructor groupdict = match.groupdict() expr = ':'.join(reversed(groupdict['branches'].split(':'))) if groupdict['redirect']: expr += groupdict['redirect'] ROOT.TTree.Draw(self, expr, selection, options) if hist is None and local_hist is None: if histname is not None: hist = asrootpy(ROOT.gDirectory.Get(histname)) else: hist = asrootpy(ROOT.gPad.GetPrimitive("htemp")) if hist: try: hist.decorate(**kwargs) except: pass if 'goff' not in options: pad.Modified() pad.Update() return hist elif local_hist is not None: local_hist.Draw(options) return local_hist
[docs] def ndarray(self, branches=None, dtype=None, include_weight=False, weight_dtype='f4'): """ Convert this tree into a NumPy ndarray """ try: import numpy as np if dtype is None: dtype = np.float32 from .. import root2array return root2array.tree_to_ndarray(self, branches, dtype, include_weight, weight_dtype) except ImportError: raise ImportError('``ndarray`` requires NumPy')
[docs] def recarray(self, branches=None, include_weight=False, weight_name='weight', weight_dtype='f4'): """ Convert this tree into a NumPy recarray """ try: from .. import root2array return root2array.tree_to_recarray(self, branches, include_weight, weight_name, weight_dtype) except ImportError: raise ImportError('``recarray`` requires NumPy')