From 9bd7840ec03a4b90c0edafed7f0856dfee41ee66 Mon Sep 17 00:00:00 2001 From: Gnarwhal Date: Sun, 8 Sep 2024 20:11:41 +0000 Subject: [PATCH] Add PluginManager class and move logic from main => PluginManager or Plugin as appropriate --- examples | 2 +- src/sshare/logger.py | 6 +- src/sshare/main.py | 181 +++++++---------------------- src/sshare/plugin.py | 116 +++++++++++++++++- src/sshare/plugins/config.py | 115 +++++++++--------- src/sshare/plugins/default/file.py | 3 +- 6 files changed, 215 insertions(+), 208 deletions(-) diff --git a/examples b/examples index f8b077f..1466438 160000 --- a/examples +++ b/examples @@ -1 +1 @@ -Subproject commit f8b077f3764c16935462ffb818bdb5aeda75222b +Subproject commit 1466438938c98442aa6de55065b7e6a06a7e8d50 diff --git a/src/sshare/logger.py b/src/sshare/logger.py index d586028..7901d4c 100644 --- a/src/sshare/logger.py +++ b/src/sshare/logger.py @@ -12,12 +12,14 @@ # You should have received a copy of the GNU General Public License along with # SSHare. If not, see . +import sys + from sshare.plugin import Plugin class Logger: def __init__(self, *args, **kwargs): - if kwargs.get("preload_command_line") == True: - self._loggers = [ Plugin.internal("command_line") ] + if kwargs.get("command_line"): + self._loggers = [ kwargs["command_line"] ] else: self._loggers = [] self.add(*args) diff --git a/src/sshare/main.py b/src/sshare/main.py index 18b0e5d..9999af6 100644 --- a/src/sshare/main.py +++ b/src/sshare/main.py @@ -15,7 +15,6 @@ import argparse import getpass import os -import os.path import time import tomllib import subprocess @@ -24,175 +23,79 @@ from pathlib import Path from sshare.logger import Logger from sshare.plugin import Plugin -from sshare.plugins.config import Flag -from sshare.plugins.config import NoArgument -from sshare.plugins.config import NoDefault +from sshare.plugin import PluginManager from sshare.version import version def main(): - config_directory = Path(os.environ.get("XDG_CONFIG_DIR") or f"{os.environ["HOME"]}/.config") / "sshare" + # TODO: Add --config flag + config_directory = Path(os.environ.get("XDG_CONFIG_DIR", f"{os.environ["HOME"]}/.config")) / "sshare" with open(config_directory / "config.toml", mode="rb") as file: config = tomllib.load(file) + INTERNAL_PLUGIN_LOCATION = "sshare.plugins.default" # Load command line early and set it as the active logger # so that it can be used to report errors while loading and - # configuring plugins - # i.e. before other logging plugins have had a chance to be initialised - logger = Logger(preload_command_line=True) - - # Load inbuilt plugins - plugins_flat = [ - Plugin.internal("file"), - Plugin.internal("current_time"), - Plugin.internal("append_type"), - Plugin.internal("ssh"), - Plugin.internal("log_result"), - ] - plugins = {} - for type in [ "logger", "source", "name", "upload", "result" ]: - plugins[type] = { "active": [], "inactive": [] } - - # Load external plugins - sys.dont_write_bytecode = True - for path in (config_directory / "plugins").iterdir(): - if path.is_file() and path.suffix == ".py": - plugins_flat.append(Plugin.external(path)) - sys.dont_write_bytecode = False - - # Set plugin configurations from config file - # Load plugin arguments and detect conflicts - error = False - argument_map = {} - used_arguments = {} - parser = argparse.ArgumentParser( + # configuring other loggers + command_line = Plugin.internal(INTERNAL_PLUGIN_LOCATION, "command_line", config.get("plugins", dict())) + logger = Logger(command_line=command_line) + arg_parser = argparse.ArgumentParser( prog = "sshare", description = "Upload files to a server via ssh", ) - parser.add_argument( + arg_parser.add_argument( "-v", "--version", action="version", version=f"%(prog)s version {version}", ) - if config.get("plugins") == None: - config["plugins"] = {} - for plugin in plugins_flat: - if hasattr(plugin, "config"): - plugin_config = config["plugins"].get(plugin.name) - if plugin_config != None: - for config_entry in plugin_config.items(): - plugin.config[config_entry[0]] = config_entry[1] - else: - setattr(plugin, "config", {}) - if hasattr(plugin, "args"): - for arg_name, arg in plugin.args.items(): - if arg.is_valid(): - arg.bind(plugin, arg_name) - def check_flag(flag): - if flag in used_arguments: - logger.error(f"Error: Argument '{arg_name}' for plugin '{plugin.name}' has conflict. Flag '{flag}' is also used by plugin '{used_arguments[arg.short]}'") - error = True - check_flag(arg.short) - check_flag(arg.long) - arg.add(parser, used_arguments) - argument_map[arg.dest] = plugin, arg_name - else: - logger.error(f"Error: Argument '{arg_name}' must set either one or both of short and long flag parameters") - error = True - if error: - sys.exit(1) - - arguments = parser.parse_args() - for arg, (plugin, config) in list(argument_map.items()): - value = getattr(arguments, arg) - if value != NoArgument: - if value != Flag: - plugin.config[config] = value - del argument_map[arg] - - # Sort plugins by type and check activation criteria - error = False - for plugin in plugins_flat: - if isinstance(plugin.plugin_type, str): - plugin.plugin_type = [ plugin.plugin_type ] - for plugin_type in plugin.plugin_type: - plugins_of_type = plugins.get(plugin_type) - if plugins_of_type == None: - logger.error(f"Error: Plugin '{plugin.name}' has an invalid plugin type '{plugin_type}'") - error = True - else: - active = True - if hasattr(plugin, "activate"): - criteria = plugin.activate - if isinstance(plugin.activate, dict): - criteria = plugin.activate.get(plugin_type) - if criteria != None: - for criterion in criteria: - active = not plugin.args[criterion].dest in argument_map - if not active: - break - plugins_of_type["active" if active else "inactive"].append(plugin) - if active: - for config_entry in plugin.config.items(): - if config_entry[1] == NoDefault: - logger.error(f"Error: Value 'plugins.{plugin.name}.{config_entry[0]}' has no default value and must be specified explicitly") - error = True - for plugin_type, plugins_of_type in plugins.items(): - if len(plugins_of_type["active"]) == 0 and plugin_type != "logger": - if len(plugins_of_type["inactive"]) == 0: - logger.error(f"No '{plugin_type}' plugins available. Atleast one must be provided") - else: - logger.error(f"No '{plugin_type}' plugins activated. Activate at least one of:") - for plugin in plugins_of_type["inactive"]: - logger.error(f"{plugin.name}:") - criteria = plugin.activate - if isinstance(plugin.activate, dict): - criteria = plugin.activate[plugin_type] - for criterion in criteria: - logger.error(f" {plugin.args[criterion]}") - error = True - if error: - sys.exit(1) - - # Objectify configs - error = False - class PluginConfig: pass - for plugin in plugins_flat: - config = plugin.config - plugin.config = PluginConfig() - for config_entry in config.items(): - setattr(plugin.config, config_entry[0], config_entry[1]) - if error: - sys.exit(1) - - # Initialise plugins - for plugin in plugins_flat: - setattr(plugin, "logger", logger) - if hasattr(plugin, "init"): - error = error or plugin.init() - - logger.add(*plugins["logger"]["active"]) + plugins = PluginManager( + [ "logger", "source", "name", "upload", "result" ], + logger, + config.get("plugins", dict()), + arg_parser, + ) + plugins.add_from( + Plugin.internal(INTERNAL_PLUGIN_LOCATION), + "file", + "current_time", + "append_type", + "ssh", + "log_result", + ) + sys.dont_write_bytecode = True + plugins.add_from( + Plugin.external, + *[ + path for + path in + (config_directory / "plugins").iterdir() + if path.is_file() and path.suffix == ".py" + ] + ) + sys.dont_write_bytecode = False + plugins.activate("logger") + logger.add(*plugins.logger.active) + plugins.activate() sources = [] - for plugin in plugins["source"]["active"]: + for plugin in plugins.source.active: sources.append(plugin.get_source()) if len(sources) == 0: - logger.error("Error: No sources provided. Must activate at least one source plugin") - log_activations(logger, plugins["source"]) + logger.error("Error: No sources provided. Must activate at least one source plugin.") for index, source in enumerate(sources): name = "" - for plugin in plugins["name"]["active"]: + for plugin in plugins.name.active: name = plugin.get_name(name, source) sources[index] = name, source for name, source in sources: - for plugin in plugins["upload"]["active"]: + for plugin in plugins.upload.active: plugin.upload(name, source) for name, _ in sources: - for plugin in plugins["result"]["active"]: + for plugin in plugins.result.active: plugin.result(name) sys.exit(0) diff --git a/src/sshare/plugin.py b/src/sshare/plugin.py index de57c36..ea24721 100644 --- a/src/sshare/plugin.py +++ b/src/sshare/plugin.py @@ -15,21 +15,127 @@ import importlib import importlib.util +from sshare.plugins.config import Flag +from sshare.plugins.config import NoDefault + +class PluginManager: + def __init__(self, types, logger, config, arg_parser): + self._uninitialized = [] + self._logger = logger + self._config = config + self._arg_parser = arg_parser + + class PluginState: + def __init__(self): + self.active = [] + self.inactive = [] + for type in types: + setattr(self, type, PluginState()) + + def add_from(self, location, *args, **kwargs): + for plugin in args: + plugin = location(plugin, self._config) + plugin.set_logger(self._logger) + plugin.add_args(self._arg_parser) + self._uninitialized.append(plugin) + + def activate(self, activate_type=None): + args = self._arg_parser.parse_args() + for plugin in self._uninitialized.copy(): + if activate_type == None or activate_type in plugin.plugin_type: + self._uninitialized.remove(plugin) + active = plugin.load_args_and_activate(args) + for type in plugin.plugin_type: + getattr( + getattr(self, type), + active[type], + ).append(plugin) + class Plugin: - def __init__(self, name, module): + def __init__(self, name, module, external_config): self.__dict__ = module.__dict__ self.name = name + if not isinstance(self.plugin_type, set): + self.plugin_type = { self.plugin_type } + + if hasattr(self, "activate"): + if not isinstance(self.activate, dict): + activate = self.activate + self.activate = dict() + for plugin_type in self.plugin_type: + self.activate[plugin_type] = activate + else: + self.activate = dict() + for plugin_type in self.plugin_type: + self.activate[plugin_type] = set() + + if hasattr(self, "config"): + config = self.config + else: + config = dict() + if external_config == None: + external_config = dict() + for key in config.keys(): + if key in external_config: + config[key] = external_config[key] + + if hasattr(self, "args"): + for arg in self.args.items(): + arg[1].bind(self, arg[0]) + value = arg[1].default() + if value != NoDefault: + config[arg[0]] = value + else: + self.args = dict() + + class Config: pass + flat_config = Config() + flat_config.__dict__ = config + self.config = flat_config + + def set_logger(self, logger): + self.logger = logger + + def add_args(self, arg_parser): + for arg in self.args.values(): + arg.add(arg_parser) + + def load_args_and_activate(self, args): + passed_args = set() + for arg_name, arg in self.args.items(): + was_set, value = arg.extract(args) + if was_set: + if value != Flag: + setattr(self.config, arg_name, value) + passed_args.add(arg_name) + activate = dict() + run_init = False + for type in self.plugin_type: + if self.activate[type] <= passed_args: + activate[type] = "active" + run_init = True + else: + activate[type] = "inactive" + if run_init and hasattr(self, "init"): + self.init() + return activate + @staticmethod - def internal(name): - return Plugin(name, importlib.import_module(f"sshare.plugins.default.{name}")) + def internal(location, name=None, config=None): + def _load_internal(_name, _config): + return Plugin(_name, importlib.import_module(f"{location}.{_name}"), _config.get(_name, dict())) + if name == None: + return _load_internal + else: + return _load_internal(name, config) @staticmethod - def external(path): + def external(path, config): module_spec = importlib.util.spec_from_file_location( path.stem, path.as_posix(), ) module = importlib.util.module_from_spec(module_spec) module_spec.loader.exec_module(module) - return Plugin(path.stem, module) + return Plugin(path.stem, module, config.get(path.stem, dict())) diff --git a/src/sshare/plugins/config.py b/src/sshare/plugins/config.py index 305a546..86167f3 100644 --- a/src/sshare/plugins/config.py +++ b/src/sshare/plugins/config.py @@ -14,71 +14,68 @@ class NoDefault: pass -class NoArgument: pass -class Flag: pass +def Flag(short=None, long=None, help=None): + return Argument( + short, + long, + action="store_const", + const=Flag, + default=False, + help=help, + ) + +class _NoArgument: + def __init__(self, default): + self._default = default class Argument: - def __init__(self, - short=None, - long=None, - action=None, - nargs=None, - const=None, - default=NoArgument, - type=None, - choices=None, - required=None, - help=None): - self.short = short - self.long = long - self.action = action - self.nargs = nargs - self.const = const - self.default = default - self.type = type - self.choices = choices - self.help = help + def __init__(self, short=None, long=None, **kwargs): + class _None: + def __init__(self, default): + self.default = default + self._None = _None - def __str__(self): - if self.short and self.long: - pretty = f"-{self.short}, --{self.long}" - elif self.long: - pretty = f"--{self.long}" - else: - pretty = f"-{self.short}" - return pretty + f" {self.help}" + self.short = short + self.long = long - def is_valid(self): - return (self.short != None and self.short != "") or (self.long != None and self.long != "") + if not "default" in kwargs: + kwargs["default"] = NoDefault + kwargs["default"] = _None(kwargs["default"]) + self._kwargs = kwargs def bind(self, plugin, argument): - self.plugin = plugin - self.metavar = argument - self.dest = f"{plugin.name}_{argument}" + self._plugin = plugin.name + self._argument = argument - def add(self, parser, used_arguments): - keywords = [ - "action", - "nargs", - "const", - "default", - "type", - "choices", - "help", - "metavar", - "dest" - ] - kwargs = {} - for keyword in keywords: - value = getattr(self, keyword) - if value != None: - kwargs[keyword] = value - parser.add_argument( - f"-{self.short}", - f"--{self.long}", + def default(self): + value = self._kwargs["default"] + if isinstance(value, self._None): + value = value.default + return value + + def dest(self): + return f"{self._plugin}_{self._argument}" + + def extract(self, arguments): + value = getattr(arguments, self.dest()) + was_set = True + if isinstance(value, self._None): + was_set = False + value = value.default + return was_set, value + + def add(self, arg_parser): + flags = [] + if self.short: + flags.append(f"-{self.short}") + long = self.long or self._argument + if long: + flags.append(f"--{long}") + kwargs = self._kwargs | { + "metavar": self._argument, + "dest": self.dest() + } + arg_parser.add_argument( + *flags, **kwargs ) - if self.short: - used_arguments["short"] = self.plugin - if self.long: - used_arguments["long"] = self.plugin diff --git a/src/sshare/plugins/default/file.py b/src/sshare/plugins/default/file.py index ef123e6..b51a346 100644 --- a/src/sshare/plugins/default/file.py +++ b/src/sshare/plugins/default/file.py @@ -18,11 +18,10 @@ from ..source import File plugin_type = "source" -activate = [ "file" ] +activate = { "file" } args = { "file": Argument( short="f", - long="file", help="Upload a file" ) }