Python编程技巧之register
最近在集成一个 CV 模型时遇到这样一个场景:需要将一系列算法对外封装成一个接口,而内部算法的数量还在不断增加。最原始的方法是每新增一个算法就修改主接口、在其中引入新算法,简单方便,但不够优雅。Python 的 register(注册器)机制提供了一种更优雅的方式:只需在新增的算法中添加注册代码,主接口就无需再逐个引入这些算法。下面以实际代码作为示例。
代码组织如下:
project/
├── algorithms/
│   ├── Classification/
│   │   └── classification.py
│   ├── Detection/
│   │   └── detection.py
│   └── ...
├── demo.py
└── register.py
目的:假设 algorithms 下有几百个算法,我们想在 demo.py 中调用其中的一部分(例如 100 个)。这些算法无法直接通过遍历的方式调用;而如果在 demo.py 中逐个 import,代码未免过于冗杂,而且每次新增算法都得修改 demo.py。
通过 register 机制,我们只需在新增算法的文件中增加两行代码(导入 Registers,并给类加上 @Registers.model.register 装饰器),就能解决以上问题:
classification.py
# -*- coding: UTF-8 -*-
from register import Registers
@Registers.model.register
class Classification:
    """Example classification algorithm; calling an instance prints a marker.

    The bare ``@Registers.model.register`` decorator stores this class in the
    ``model`` registry under its class name, so callers can discover it
    without importing this module explicitly.
    """

    def __init__(self):
        """Nothing to initialise in this example."""

    def __call__(self, *args, **kwargs):
        """Print ``classification``; every argument is accepted and ignored."""
        print('classification')
detection.py
from register import Registers
@Registers.model.register
class Detection:
    """Example detection algorithm; calling an instance prints a marker.

    The bare ``@Registers.model.register`` decorator stores this class in the
    ``model`` registry under its class name.
    """

    def __init__(self):
        """Nothing to initialise in this example."""

    def __call__(self, *args, **kwargs):
        """Print ``detection``; every argument is accepted and ignored."""
        print("detection")
register.py
# -*- coding: UTF-8 -*-
import importlib
import os
import sys
from absl import logging
class Register:
    """A named mapping from keys to callables, filled in via a decorator.

    Usage::

        reg = Register("model")

        @reg.register            # stored under Foo.__name__ ("Foo")
        class Foo: ...

        @reg.register("alias")   # stored under "alias"
        def bar(): ...
    """

    def __init__(self, registry_name):
        self._dict = {}
        self._name = registry_name

    def __setitem__(self, key, value):
        """Store ``value`` under ``key``, or under ``value.__name__`` if key is None.

        Registering an existing key logs a warning and overwrites the entry.

        Raises:
            TypeError: if ``value`` is not callable. (Subclass of Exception,
                so pre-existing ``except Exception`` handlers still catch it.)
        """
        if not callable(value):
            raise TypeError("Value of a Registry must be a callable.")
        if key is None:
            key = value.__name__
        if key in self._dict:
            # Lazy %-args: the message is only formatted if it is emitted.
            logging.warning("Key %s already in registry %s.", key, self._name)
        self._dict[key] = value

    def register(self, param):
        """Decorator to register a function or class.

        Two forms are supported:

        * ``@reg.register`` -- ``param`` is the callable itself and it is
          registered under its ``__name__``.
        * ``@reg.register("alias")`` -- ``param`` is the key; the returned
          decorator registers the decorated callable under it.

        In both cases the decorated object is returned unchanged.
        """
        def decorator(key, value):
            self[key] = value
            return value

        if callable(param):
            # @reg.register
            return decorator(None, param)
        # @reg.register('alias')
        return lambda value: decorator(param, value)

    def __getitem__(self, key):
        """Return the callable registered under ``key``.

        Raises:
            KeyError: if nothing is registered under ``key`` (also logged).
        """
        try:
            return self._dict[key]
        except KeyError as err:
            logging.error("module %s not found: %s", key, err)
            # Bare raise keeps the original traceback intact.
            raise

    def __contains__(self, key):
        return key in self._dict

    def keys(self):
        """Return a view of the registered keys."""
        return self._dict.keys()

    def __repr__(self):
        return f"{type(self).__name__}({self._name!r}, keys={list(self._dict)!r})"
class Registers():  # pylint: disable=invalid-name, too-few-public-methods
    """Namespace holding every registry in the project.

    Only the class attributes (``Registers.model``, ``Registers.serving``)
    are meant to be used; constructing an instance is an error.
    """

    # One Register per category; several registries can be declared here.
    model = Register('model')
    serving = Register('serving')

    def __init__(self):
        msg = "Registries is not intended to be instantiated"
        raise RuntimeError(msg)
def sub_modules(root):
    """Return ``Package.module`` names for the algorithm packages under ``root``.

    Each *directory* directly under ``root`` whose name starts with an
    uppercase letter is treated as an algorithm package. That package is
    expected to contain a module named after it with the first letter
    lower-cased, e.g. ``Classification`` -> ``"Classification.classification"``.

    Plain files (``README.md``, ...) are skipped because they cannot be
    packages; importing them would only produce spurious import errors. The
    result is sorted because ``os.listdir`` order is arbitrary, and a stable
    import order keeps registration deterministic.

    Args:
        root: Directory to scan.

    Returns:
        Sorted list of dotted module names, relative to ``root``.
    """
    return sorted(
        f"{name}.{name[0].lower()}{name[1:]}"
        for name in os.listdir(root)
        if name[:1].isupper() and os.path.isdir(os.path.join(root, name))
    )
# Collect every algorithm module path. The "algorithms" directory is resolved
# relative to this file instead of the current working directory. This makes
# importing `register` independent of where the process was launched.
# MODULES_BASE itself remains the dotted package prefix used for importing.
MODULES_BASE = "algorithms"
SUB_MODULES = sub_modules(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), MODULES_BASE)
)
ALL_MODULES = [(MODULES_BASE, SUB_MODULES)]
def _handle_errors(errors):
"""Log out and possibly reraise errors during import."""
if not errors:
return
for name, err in errors:
logging.warning("Module {} import failed: {}".format(name, err))
logging.fatal("Please check these modules.")
def path_to_module_format(py_path):
    """Transform a python file path to dotted module format.

    ``"pkg/sub/mod.py"`` -> ``"pkg.sub.mod"``. Only a literal trailing
    ``.py`` suffix is removed. ``str.rstrip(".py")`` would be wrong here: it
    strips any run of trailing ``.``/``p``/``y`` characters, turning
    ``copy.py`` into ``co``. Windows-style backslash separators are accepted
    too, since a module name can never contain a backslash.
    """
    suffix = ".py"
    if py_path.endswith(suffix):
        py_path = py_path[:-len(suffix)]
    return py_path.replace("\\", "/").replace("/", ".")
def add_custom_modules(all_modules, config=None):
    """Add user-specified modules from ``config`` to ``all_modules`` in place.

    The current working directory is appended to ``sys.path`` if it is not
    already there. Presumably this lets custom modules given as paths
    relative to it be imported -- confirm with callers.

    Args:
        all_modules: list of ``(base_package, [module_names])`` pairs; it is
            extended in place with ``("", [dotted_name])`` entries.
        config: optional mapping. Its ``"custom_modules"`` entry may be a
            single file path or a list/tuple of paths such as
            ``"my_pkg/my_algo.py"``.
    """
    cwd = os.getcwd()
    if cwd not in sys.path:
        sys.path.append(cwd)
    if config is None or "custom_modules" not in config:
        return
    custom_modules = config["custom_modules"]
    if not isinstance(custom_modules, (list, tuple)):
        # A single path was given; treat it as a one-element list.
        custom_modules = [custom_modules]
    all_modules += [
        ("", [path_to_module_format(module)]) for module in custom_modules
    ]
def import_all_modules_for_register(config=None):
    """Import every algorithm module so that its register decorators run.

    Args:
        config: optional mapping forwarded to ``add_custom_modules``; its
            ``"custom_modules"`` entry names extra modules to import.

    Every module is attempted. Those raising ImportError are collected and
    reported together via ``_handle_errors`` at the end.
    """
    # Copy the global list: add_custom_modules extends its argument in place.
    # Passing ALL_MODULES itself would make custom entries accumulate in the
    # module-level list on every call.
    all_modules = list(ALL_MODULES)
    add_custom_modules(all_modules, config)
    logging.debug(f"All modules: {all_modules}")
    errors = []
    for base_dir, modules in all_modules:
        for name in modules:
            full_name = f"{base_dir}.{name}" if base_dir else name
            try:
                importlib.import_module(full_name)
            except ImportError as error:
                # Report the fully-qualified name so the failing module is unambiguous.
                errors.append((full_name, error))
            else:
                logging.debug(f"{full_name} loaded.")
    _handle_errors(errors)
demo.py
# -*- coding: UTF-8 -*-
from register import import_all_modules_for_register
from register import Registers
# Show the model registry before and after the dynamic import pass. Before
# the import, no algorithm module has run its @Registers.model.register yet.
print("Registers.model._dict before: ", Registers.model._dict)
import_all_modules_for_register()
print("Registers.model._dict after: ", Registers.model._dict)

# Instantiate every registered model class and invoke its __call__.
for name in list(Registers.model.keys()):
    Registers.model[name]()()