"""Driver discovery system via plugins"""
from importlib.metadata import EntryPoint, EntryPoints, entry_points
from typing import Callable
from mass_driver.models.forge import Forge
from mass_driver.models.patchdriver import PatchDriver
from mass_driver.models.repository import Source
from mass_driver.models.scan import Scanner
ENTRYPOINT = "massdriver"
"""The entrypoint we discover all types of plugins from"""
DRIVER_ENTRYPOINT = f"{ENTRYPOINT}.drivers"
"""The specific entrypoint for drivers discovery"""
FORGE_ENTRYPOINT = f"{ENTRYPOINT}.forges"
"""The specific entrypoint for Forge discovery"""
SCANNER_ENTRYPOINT = f"{ENTRYPOINT}.scanners"
"""The specific entrypoint for Scanner discovery"""
SOURCE_ENTRYPOINT = f"{ENTRYPOINT}.sources"
"""The specific entrypoint for Source discovery"""
[docs]
def discover_drivers() -> EntryPoints:
"""Discover all Drivers via plugin system"""
return entry_points(group=DRIVER_ENTRYPOINT)
[docs]
def discover_forges() -> EntryPoints:
"""Discover all Forges via plugin system"""
return entry_points(group=FORGE_ENTRYPOINT)
[docs]
def discover_sources() -> EntryPoints:
"""Discover all Sources via plugin system"""
return entry_points(group=SOURCE_ENTRYPOINT)
[docs]
def discover_scanners() -> EntryPoints:
"""Discover all Scanners"""
return entry_points(group=SCANNER_ENTRYPOINT)
[docs]
def get_plugin_entrypoint(
plugin: str, name: str, entrypoint: str, discover: Callable
) -> EntryPoint:
"""Fetch the given plugin's Entrypoint, by name"""
plugin_objs = discover()
if name not in plugin_objs.names:
raise ImportError(f"{plugin} '{name}' not found in '{entrypoint}'")
(plugin_obj,) = plugin_objs.select(name=name)
return plugin_obj
[docs]
def get_driver_entrypoint(driver_name: str) -> EntryPoint:
"""Fetch the given driver Entrypoint, by name"""
return get_plugin_entrypoint(
"driver", driver_name, DRIVER_ENTRYPOINT, discover_drivers
)
[docs]
def get_forge_entrypoint(forge_name: str) -> EntryPoint:
"""Fetch the given forge Entrypoint, by name"""
return get_plugin_entrypoint("forge", forge_name, FORGE_ENTRYPOINT, discover_forges)
[docs]
def get_source_entrypoint(source_name: str) -> EntryPoint:
"""Fetch the given source Entrypoint, by name"""
return get_plugin_entrypoint(
"source", source_name, SOURCE_ENTRYPOINT, discover_sources
)
[docs]
def get_scanner_entrypoint(scanner_name: str) -> EntryPoint:
"""Fetch the given scanner Entrypoint, by name"""
return get_plugin_entrypoint(
"scanner", scanner_name, SCANNER_ENTRYPOINT, discover_scanners
)
[docs]
def get_driver(driver_name: str) -> type[PatchDriver]:
"""Get the given driver Class, by entrypoint name"""
return get_driver_entrypoint(driver_name).load()
[docs]
def get_forge(forge_name: str) -> type[Forge]:
"""Get the given forge Class, by entrypoint name"""
return get_forge_entrypoint(forge_name).load()
[docs]
def get_source(source_name: str) -> type[Source]:
"""Get the given source Class, by entrypoint name"""
return get_source_entrypoint(source_name).load()
[docs]
def get_scanner(scanner_name: str) -> Scanner:
"""Get the given scanner func, by entrypoint name"""
s = get_scanner_entrypoint(scanner_name)
return Scanner(name=s.name, func=s.load())