diff --git a/src/sshare/config_directory.py b/src/sshare/config_directory.py index 4617622..022480d 100644 --- a/src/sshare/config_directory.py +++ b/src/sshare/config_directory.py @@ -16,11 +16,15 @@ import os from pathlib import Path -LOCATION = Path(os.environ.get("XDG_CONFIG_DIR", f"{os.environ["HOME"]}/.config")) / "sshare" +_LOCATION = Path(os.environ.get("XDG_CONFIG_DIR", f"{os.environ["HOME"]}/.config")) / "sshare" + +def default_config(): + return _LOCATION / "config.toml" + def plugins(): return [ path for path in - (LOCATION / "plugins").iterdir() + (_LOCATION / "plugins").iterdir() if path.is_file() and path.suffix == ".py" ] diff --git a/src/sshare/main.py b/src/sshare/main.py index 8198545..da4369b 100644 --- a/src/sshare/main.py +++ b/src/sshare/main.py @@ -38,7 +38,7 @@ def main(): help="Specify location of config file to use" ) arguments, _ = arg_parser.parse_known_args() - with open(arguments.config or (config_directory.LOCATION / "config.toml"), mode="rb") as file: + with open(arguments.config or config_directory.default_config(), mode="rb") as file: config = tomllib.load(file) config["config"] = config.get("config", {}) config["flags" ] = config.get("flags", {}) @@ -56,41 +56,23 @@ def main(): version=f"%(prog)s version {version}", ) - INTERNAL_PLUGIN_LOCATION = "sshare.plugins" # Load command line early and set it as the active logger # so that it can be used to report errors while loading and # configuring other loggers + logger = Logger() command_line = Plugin.internal( - INTERNAL_PLUGIN_LOCATION, "command_line", + logger, config["config"], config["flags"], ) - logger = Logger(command_line) - plugin_types = [ "logger", "source", "name", "upload", "location", "feedback" ] + logger.add(command_line) plugins = PluginManager( - plugin_types, logger, config["config"], config["flags"], arg_parser, ) - plugins.add_from( - Plugin.internal(INTERNAL_PLUGIN_LOCATION), - "file", - "stdin", - "current_time", - "append_type", - "ssh", - "uri", - "print_location", - ) - sys.dont_write_bytecode = True - plugins.add_from( - Plugin.external, - *config_directory.plugins(), - ) - sys.dont_write_bytecode = False plugins.activate("logger") logger.add(*plugins.logger.active) plugins.activate() diff --git a/src/sshare/plugin/plugin.py b/src/sshare/plugin/plugin.py index 395b060..dde368b 100644 --- a/src/sshare/plugin/plugin.py +++ b/src/sshare/plugin/plugin.py @@ -17,28 +17,65 @@ import importlib.util from sshare.plugin.config import Flag from sshare.plugin.config import NoDefault +from sshare import config_directory + +class PluginLoader: + @staticmethod + def all(command_line=False, logger=None, config=None, flags=None): + return PluginLoader.internal(command_line, logger, config, flags) + PluginLoader.external(logger, config, flags) + + @staticmethod + def internal(command_line=False, logger=None, config=None, flags=None): + return [ + Plugin.internal(plugin, logger, config, flags) + for plugin + in ([ "command_line" ] if command_line else []) + [ + "file", + "stdin", + "current_time", + "append_type", + "ssh", + "uri", + "print_location", + ] + ] + + @staticmethod + def external(logger=None, config=None, flags=None): + return [ + Plugin.external(plugin, logger, config, flags) + for plugin + in config_directory.plugins() + ] + + @staticmethod + def at(logger=None, config=None, flags=None, *args): + return [ + Plugin.external(plugin, logger, config, flags) + for plugin + in args + ] class PluginManager: - def __init__(self, types, logger, config, flags, arg_parser): - self._uninitialized = [] + def __init__(self, logger, config, flags, arg_parser): self._logger = logger - self._config = config - self._flags = flags self._arg_parser = arg_parser class PluginState: def __init__(self): self.active = [] self.inactive = [] - for type in types: + for type in Plugin.types(): setattr(self, type, PluginState()) - def add_from(self, location, *args, **kwargs): - for plugin in args: - plugin = location(plugin, self._config, self._flags) - plugin.set_logger(self._logger) - plugin.add_args(self._arg_parser) - self._uninitialized.append(plugin) + self._uninitialized = PluginLoader.all( + command_line=False, + logger=logger, + config=config, + flags=flags, + ) + for plugin in self._uninitialized: + plugin.add_args(arg_parser) def activate(self, activate_type=None): args = self._arg_parser.parse_args() @@ -53,10 +90,19 @@ class PluginManager: ).append(plugin) class Plugin: - def __init__(self, name, module, external_config, external_flags): + @staticmethod + def types(): + return [ "logger", "source", "name", "upload", "location", "feedback" ] + + def __init__(self, name, module, logger, external_config, external_flags): self.__dict__ = module.__dict__ self.name = name + if logger == None: + return + + self.logger = logger + if not isinstance(self.plugin_type, set): self.plugin_type = { self.plugin_type } @@ -84,7 +130,7 @@ class Plugin: if hasattr(self, "args"): for arg in self.args.items(): arg[1].bind(self, arg[0]) - flags = external_flags.get(arg[0], {}) + flags = external_flags.get(arg[0], dict()) arg[1].set_flags(flags.get("short"), flags.get("long")) value = arg[1].default() if value != NoDefault: @@ -98,9 +144,6 @@ class Plugin: flat_config.__dict__ = config self.config = flat_config - def set_logger(self, logger): - self.logger = logger - def add_args(self, arg_parser): for arg in self.args.values(): arg.add(arg_parser) @@ -126,20 +169,15 @@ class Plugin: return activate @staticmethod - def internal(location, name=None, config=None, flags=None): - def _load_internal(_name, _config, _flags): - return Plugin(_name, importlib.import_module(f"{location}.{_name}"), _config.get(_name, dict()), _flags.get(_name, dict())) - if name == None: - return _load_internal - else: - return _load_internal(name, config, flags) + def internal(name=None, logger=None, config=None, flags=None): + return Plugin(name, importlib.import_module(f"sshare.plugins.{name}"), logger, config.get(name, dict()), flags.get(name, dict())) @staticmethod - def external(path, config, flags): + def external(path, logger=None, config=None, flags=None): module_spec = importlib.util.spec_from_file_location( path.stem, path.as_posix(), ) module = importlib.util.module_from_spec(module_spec) module_spec.loader.exec_module(module) - return Plugin(path.stem, module, config.get(path.stem, dict()), flags.get(path.stem, dict())) + return Plugin(path.stem, module, logger, config.get(path.stem, dict()), flags.get(path.stem, dict()))