diff --git a/src/sshare/main.py b/src/sshare/main.py index f0c63ae..86e97b3 100644 --- a/src/sshare/main.py +++ b/src/sshare/main.py @@ -26,9 +26,7 @@ import sys from pathlib import Path from version import version -from plugins.config import Default from plugins.config import NoDefault -from plugins.config import Flags class Congloggerate: def __init__(self, loggers): @@ -81,12 +79,9 @@ def main(): Plugin("append_type", importlib.import_module("plugins.default.append_type")), Plugin("ssh", importlib.import_module("plugins.default.ssh")), ] - plugins = { - "logger": [], - "source": [], - "name": [], - "upload": [], - } + plugins = {} + for type in [ "logger", "source", "name", "upload" ]: + plugins[type] = { "active": [], "inactive": [] } # Load external plugins sys.dont_write_bytecode = True @@ -102,6 +97,20 @@ def main(): sys.dont_write_bytecode = False # Set plugin configurations from config file + # Load plugin arguments and detect conflicts + error = False + argument_map = {} + used_arguments = {} + parser = argparse.ArgumentParser( + prog = "SSHare", + description = "Upload files to a server via ssh", + ) + parser.add_argument( + "-v", + "--version", + action="version", + version=f"%(prog)s version {version}", + ) if config.get("plugins") == None: config["plugins"] = {} for plugin in plugins_flat: @@ -109,33 +118,34 @@ def main(): plugin_config = config["plugins"].get(plugin.name) if plugin_config != None: for config_entry in plugin_config.items(): - plugin.module.config[config_entry[0]] = Default(config_entry[1]) - - # Flatten plugin configs - class PluginConfig: pass - error = False - for plugin in plugins_flat: - if hasattr(plugin.module, "config"): - config = plugin.module.config - plugin.module.config = PluginConfig() - for config_entry in config.items(): - if isinstance(config_entry[1], NoDefault): - logger.error(f"{plugin.name} > Error: Value '{config_entry[0]}' has no default value and must be specified explicitly") - error = True - elif isinstance(config_entry[1], Default): - setattr(plugin.module.config, config_entry[0], config_entry[1].value) + plugin.module.config[config_entry[0]] = config_entry[1] + else: + setattr(plugin.module, "config", {}) + if hasattr(plugin.module, "args"): + for arg_name, arg in plugin.module.args.items(): + if arg.is_valid(): + arg.set_for_plugin(plugin) + def check_flag(flag): + if flag in used_arguments: + logger.error(f"Error: Argument '{arg_name}' for plugin '{plugin.name}' has conflict. Flag '{flag}' is also used by plugin '{used_arguments[arg.short]}'") + error = True + check_flag(arg.short) + check_flag(arg.long) + arg.add(parser, used_arguments) + argument_map[arg.dest] = plugin, arg_name else: - setattr(plugin.module.config, config_entry[0], config_entry[1]) + logger.error(f"Error: Argument '{arg_name}' must set either one or both of short and long flag parameters") + error = True if error: sys.exit(1) - # Initialise plugins - for plugin in plugins_flat: - setattr(plugin.module, "logger", logger) - if hasattr(plugin.module, "init"): - plugin.module.init() + arguments = parser.parse_args() + for arg, (plugin, config) in list(argument_map.items()): + if getattr(arguments, arg): + plugin.module.config[config] = getattr(arguments, arg) + del argument_map[arg] - # Sort plugins by type + # Sort plugins by type and check activation criteria error = False for plugin in plugins_flat: if isinstance(plugin.module.plugin_type, str): @@ -146,24 +156,73 @@ def main(): logger.error(f"Error: Plugin '{plugin.name}' has an invalid plugin type '{plugin_type}'") error = True else: - plugins_of_type.append(plugin) + active = True + if hasattr(plugin.module, "activate"): + criteria = plugin.module.activate + if isinstance(plugin.module.activate, dict): + criteria = plugin.module.activate.get(plugin_type) + if criteria != None: + for criterion in criteria: + active = not plugin.module.args[criterion].dest in argument_map + if not active: + break + plugins_of_type["active" if active else "inactive"].append(plugin) + for plugin_type, plugins_of_type in plugins.items(): + if len(plugins_of_type["active"]) == 0 and plugin_type != "logger": + if len(plugins_of_type["inactive"]) == 0: + logger.error(f"No '{plugin_type}' plugins available. Atleast one must be provided") + else: + logger.error(f"No '{plugin_type}' plugins activated. Activate at least one of:") + for plugin in plugins_of_type["inactive"]: + logger.error(f"{plugin.name}:") + criteria = plugin.module.activate + if isinstance(plugin.module.activate, dict): + criteria = plugin.module.activate[plugin_type] + for criterion in criteria: + logger.error(f" {plugin.module.args[criterion].pretty()}") + error = True if error: sys.exit(1) - logger = Congloggerate([ logger.module for logger in plugins["logger"] ]) + # Flatten plugin configs + error = False + class PluginConfig: pass + for plugin in plugins_flat: + if hasattr(plugin.module, "config"): + config = plugin.module.config + plugin.module.config = PluginConfig() + for config_entry in config.items(): + if config_entry[1] == NoDefault: + logger.error(f"Error: Value 'plugins.{plugin.name}.{config_entry[0]}' has no default value and must be specified explicitly") + error = True + else: + setattr(plugin.module.config, config_entry[0], config_entry[1]) + if error: + sys.exit(1) + + # Initialise plugins + for plugin in plugins_flat: + setattr(plugin.module, "logger", logger) + if hasattr(plugin.module, "init"): + error = error or plugin.module.init() + + logger = Congloggerate([ logger.module for logger in plugins["logger"]["active"] ]) sources = [] - for plugin in plugins["source"]: + for plugin in plugins["source"]["active"]: sources.append(plugin.module.get()) + if len(sources) == 0: + logger.error("Error: No sources provided. Must activate at least one source plugin") + log_activations(logger, plugins["source"]) for index, source in enumerate(sources): name = "" - for plugin in plugins["name"]: + for plugin in plugins["name"]["active"]: name = plugin.module.name(name, source) sources[index] = name, source for (name, source) in sources: - for plugin in plugins["upload"]: + for plugin in plugins["upload"]["active"]: plugin.module.upload(name, source) sys.exit(0) @@ -174,12 +233,6 @@ def parse_arguments(): description = "Upload files to a server via ssh", ) - parser.add_argument( - "-v", - "--version", - action="version", - version=f"%(prog)s version {version}", - ) parser.add_argument( "-l", "--latest", @@ -206,7 +259,6 @@ def parse_arguments(): const=True, help="Copy the resultant URL to the clipboard", ) - arguments = parser.parse_args() return arguments diff --git a/src/sshare/plugins/config.py b/src/sshare/plugins/config.py index 79199ad..446bca2 100644 --- a/src/sshare/plugins/config.py +++ b/src/sshare/plugins/config.py @@ -12,16 +12,63 @@ # You should have received a copy of the GNU General Public License along with # SSHare. If not, see . -class NoDefault: - def __init__(self, flags=None): - self.flags = flags +class NoDefault: pass -class Default: - def __init__(self, value, flags=None): - self.value = value - self.flags = flags - -class Flags: - def __init__(self, short=None, long=None): - self.short = short - self.long = long +class Argument: + def __init__(self, + short=None, + long=None, + action='store', + nargs=None, + const=None, + default=None, + type=str, + choices=None, + required=False, + help=None, + metavar=None): + self.short = short + self.long = long + self.action = action + self.nargs = nargs + self.const = const + self.default = default + self.type = type + self.choices = choices + self.help = help + self.metavar = metavar or self.long or self.short + + def is_valid(self): + return (self.short != None and self.short != "") or (self.long != None and self.long != "") + + def pretty(self): + if self.short and self.long: + pretty = f"-{self.short}, --{self.long}" + elif self.long: + pretty = f"--{self.long}" + else: + pretty = f"-{self.short}" + return pretty + f" {self.help}" + + def set_for_plugin(self, plugin): + self.plugin = plugin + self.dest = f"{plugin.name}_{self.metavar}" + + def add(self, parser, used_arguments): + parser.add_argument( + f"-{self.short}", + f"--{self.long}", + action=self.action, + nargs=self.nargs, + const=self.const, + default=self.default, + type=self.type, + choices=self.choices, + help=self.help, + metavar=self.metavar, + dest=self.dest, + ) + if self.short: + used_arguments["short"] = self.plugin + if self.long: + used_arguments["long"] = self.plugin diff --git a/src/sshare/plugins/default/current_time.py b/src/sshare/plugins/default/current_time.py index 1c37444..72b5ffe 100644 --- a/src/sshare/plugins/default/current_time.py +++ b/src/sshare/plugins/default/current_time.py @@ -14,8 +14,7 @@ import time -from ..config import Default -from ..config import Flags +from ..config import Argument from ..source import File plugin_type = "name" @@ -23,6 +22,13 @@ plugin_type = "name" config = { "base": 62, } +args = { + "base": Argument( + short="b", + long="base", + help="Set the numeric base to use for the current time" + ) +} def init(): if not isinstance(config.base, int): diff --git a/src/sshare/plugins/default/file.py b/src/sshare/plugins/default/file.py index 36bcc50..9676bef 100644 --- a/src/sshare/plugins/default/file.py +++ b/src/sshare/plugins/default/file.py @@ -12,19 +12,19 @@ # You should have received a copy of the GNU General Public License along with # SSHare. If not, see . +from ..config import Argument from ..config import NoDefault -from ..config import Flags from ..source import File plugin_type = "source" -config = { - "file": NoDefault( - flags=Flags( - short="f", - long="file" - ) - ), +activate = [ "file" ] +args = { + "file": Argument( + short="f", + long="file", + help="Upload a file" + ) } def get(): diff --git a/src/sshare/plugins/default/ssh.py b/src/sshare/plugins/default/ssh.py index b85819b..b06325d 100644 --- a/src/sshare/plugins/default/ssh.py +++ b/src/sshare/plugins/default/ssh.py @@ -15,23 +15,21 @@ import getpass import subprocess -from ..config import Default from ..config import NoDefault -from ..config import Flags from ..source import File from ..source import Raw plugin_type = "upload" config = { - "host": NoDefault(), - "path": NoDefault(), + "host": NoDefault, + "path": NoDefault, "port": 22, "user": getpass.getuser(), } def upload(name, source): - logger.info(f"Uploading to {config.user}@{config.host}/{config.port}:{config.path}/{name}") + logger.info(f"Uploading to {config.user}@{config.host}:{config.path}/{name} on port {config.port}") if isinstance(source, File): command = [ "scp",