diff --git a/test/basetest/exceptions.py b/test/basetest/exceptions.py index a281ecc44..4c7035803 100644 --- a/test/basetest/exceptions.py +++ b/test/basetest/exceptions.py @@ -28,4 +28,8 @@ class CommandError(Exception): def __str__(self): return self.msg.format(self.cmd, self.code, self.out, self.err) + +class HookError(Exception): + pass + # vim: ai sts=4 et sw=4 diff --git a/test/basetest/task.py b/test/basetest/task.py index 6bcd677b4..d35f012e2 100644 --- a/test/basetest/task.py +++ b/test/basetest/task.py @@ -1,12 +1,140 @@ # -*- coding: utf-8 -*- import os +from sys import stderr import tempfile import shutil +import stat import atexit import unittest from .utils import run_cmd_wait, run_cmd_wait_nofail, which, binary_location -from .exceptions import CommandError +from .exceptions import CommandError, HookError + + +class Hooks(object): + """Abstraction to help interact with hooks (add, remove, enable, disable) + during tests + """ + def __init__(self, datadir): + self.hookdir = os.path.join(datadir, "hooks") + self._hooks = {} + + # Check if the hooks dir already exists + if not os.path.isdir(self.hookdir): + os.mkdir(self.hookdir) + + def __repr__(self): + enabled = [] + disabled = [] + + for hook in self._hooks: + if self.isactive(hook): + enabled.append(hook) + else: + disabled.append(hook) + + enabled = ", ".join(enabled) or None + disabled = ", ".join(disabled) or None + + return "".format(enabled, + disabled) + + def get_hookfile(self, hookname): + """Return location of given hookname""" + return os.path.join(self.hookdir, hookname) + + def check_exists(self, hookname): + """Checks if the file pointed to by hookfile exists""" + + hookfile = self.get_hookfile(hookname) + + if not os.path.isfile(hookfile): + raise HookError("Hook {0} doesn't exist.".format(hookfile)) + + return hookfile + + def enable(self, hookname): + """Make hookfile executable to allow triggering + """ + hookfile = self.check_exists(hookname) + os.chmod(hookfile, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC) + + def disable(self, hookname): + """Remove hookfile executable bit to deny triggering + """ + hookfile = self.check_exists(hookname) + os.chmod(hookfile, stat.S_IREAD | stat.S_IWRITE) + + def isactive(self, hookname): + """Check if hook is active by verifying the execute bit + """ + hookfile = self.check_exists(hookname) + return os.access(hookfile, os.X_OK) + + def add(self, hookname, content, overwrite=False): + """Register hook with name 'hookname' and given file content. + + :param hookname: Should be a string starting with one of: + - on-launch + - on-add + - on-exit + - on-modify + + :param content: Content of the file as a (multi-line) string + :param overwrite: What to do if a hook with same name already exists + """ + for hooktype in ("on-launch", "on-add", "on-exit", "on-modify"): + if hookname.startswith(hooktype): + break + else: + stderr.write("WARNING: {0} is not a valid hook type. " + "It will not be triggered\n".format(hookname)) + + if hookname in self._hooks and not overwrite: + raise HookError("Hook with name {0} already exists. " + "Pass overwrite=True if intended or use " + "hooks.remove() before.".format(hookname)) + else: + self._hooks[hookname] = content + + hookfile = self.get_hookfile(hookname) + + # Create the hook on disk + with open(hookfile, 'w') as fh: + fh.write(content) + + # Ensure it's executable + self.enable(hookname) + + def remove(self, hookname): + """Remove the hook matching given hookname""" + try: + del self._hooks[hookname] + except KeyError: + raise HookError("Hook with name {0} in record".format(hookname)) + + try: + os.remove(self.get_hookfile(hookname)) + except OSError as e: + if e.errno == 2: + raise HookError("Hook with name {0} was not found on hooks/ " + "folder".format(hookname)) + else: + raise + + def clear(self): + """Remove all existing hooks and empty the hook registry + """ + self._hooks = {} + + # Remove any existing hooks + try: + shutil.rmtree(self.hookdir) + except OSError as e: + if e.errno != 2: + raise + + os.mkdir(self.hookdir) class Task(object): @@ -55,6 +183,8 @@ class Task(object): if self.taskd is not None: self.bind_taskd_server(self.taskd) + self.hooks = Hooks(self.datadir) + def __repr__(self): txt = super(Task, self).__repr__() return "{0} running from {1}>".format(txt[:-1], self.datadir)