Python编程技巧之register

最近在集成一个CV模型时遇到一个场景, 需要将一系列算法对外封装成一个借口,同时内部算法数量处于不断增加的阶段。最原始的方法就是在主接口中不断增加新算法,简单、方便,但是不够优雅。Python register机制提供一种优雅的方式,只需在新增的算法中添加注册,主接口就可以不再需要再进行引入了,以下以实际代码作为示例。

代码组织如下:

project/
    algorithms/
        Classification/
            classification.py
        Detection/
            detection.py
        ...
    demo.py
    register.py

目的:假如algorithms 下有几百个算法,我们想通过 demo.py 调用其中的100个算法 下的部分算法,那么我们不能通过遍历的方式调用,但是如果在 demo.py 中不断的 import 未免代码太过于冗杂,而且每次有新增算法都得在 demo.py 中加入.

通过register 机制,我们只需在新增算法中增加两行代码就解决了以上问题

classification.py

# -*- coding: UTF-8 -*-
from register import Registers


@Registers.model.register
class Classification():
    def __init__(self):
        pass

    def __call__(self, *args, **kwargs):
        print('classification')

detection.py

from register import Registers


@Registers.model.register
class Detection():
    def __init__(self):
        pass

    def __call__(self, *args, **kwargs):
        print("detection")

register.py

# -*- coding: UTF-8 -*-
import importlib
import os
import sys
from absl import logging


class Register:
    """Module register"""

    def __init__(self, registry_name):
        self._dict = {}
        self._name = registry_name

    def __setitem__(self, key, value):
        if not callable(value):
            raise Exception("Value of a Registry must be a callable.")
        if key is None:
            key = value.__name__
        if key in self._dict:
            logging.warning("Key %s already in registry %s." % (key, self._name))
        self._dict[key] = value

    def register(self, param):
        """Decorator to register a function or class."""

        def decorator(key, value):
            self[key] = value
            return value

        if callable(param):
            # @reg.register
            return decorator(None, param)
        # @reg.register('alias')
        return lambda x: decorator(param, x)

    def __getitem__(self, key):
        try:
            return self._dict[key]
        except Exception as e:
            logging.error(f"module {key} not found: {e}")
            raise e

    def __contains__(self, key):
        return key in self._dict

    def keys(self):
        """key"""
        return self._dict.keys()


class Registers():  # pylint: disable=invalid-name, too-few-public-methods
    """All module registers."""

    def __init__(self):
        raise RuntimeError("Registries is not intended to be instantiated")

    # 可以初始化多个注册器
    model = Register('model')
    serving = Register('serving')


def sub_modules(root):
    subs = os.listdir(root)
    sub_module = []
    for sub in subs:
        if sub[0].isupper():
            sub_module.append('{}.{}'.format(sub, sub[0].lower() + sub[1:]))
    return sub_module


# 此处获取所有算法路径,
MODULES_BASE = "algorithms"
SUB_MODULES = sub_modules(MODULES_BASE)
ALL_MODULES = [(MODULES_BASE, SUB_MODULES)]


def _handle_errors(errors):
    """Log out and possibly reraise errors during import."""
    if not errors:
        return
    for name, err in errors:
        logging.warning("Module {} import failed: {}".format(name, err))
    logging.fatal("Please check these modules.")


def path_to_module_format(py_path):
    """Transform a python file path to module format."""
    return py_path.replace("/", ".").rstrip(".py")


def add_custom_modules(all_modules, config=None):
    """Add custom modules to all_modules"""
    current_work_dir = os.getcwd()
    if current_work_dir not in sys.path:
        sys.path.append(current_work_dir)
    if config is not None and "custom_modules" in config:
        custom_modules = config["custom_modules"]
        if not isinstance(custom_modules, list):
            custom_modules = [custom_modules]
        all_modules += [
            ("", [path_to_module_format(module)]) for module in custom_modules
        ]


def import_all_modules_for_register(config=None):
    """Import all modules for register."""

    all_modules = ALL_MODULES

    add_custom_modules(all_modules, config)

    logging.debug(f"All modules: {all_modules}")
    errors = []
    for base_dir, modules in all_modules:
        for name in modules:
            try:
                if base_dir != "":
                    full_name = base_dir + "." + name
                else:
                    full_name = name
                importlib.import_module(full_name)
                logging.debug(f"{full_name} loaded.")
            except ImportError as error:
                errors.append((name, error))
    _handle_errors(errors)

demo.py

# -*- coding: UTF-8 -*-

from register import import_all_modules_for_register
from register import Registers

Registers.model._dict
Registers.serving._dict

print("Registers.model._dict before: ", Registers.model._dict)
import_all_modules_for_register()
print("Registers.model._dict after: ", Registers.model._dict)
for _, cls in Registers.safety._dict.items():
    cls()() # 初始化类并调用 __call__