Dynamically Loading Models: A Guide to Model Registry Patterns

While reading code in MMDetection, a project from OpenMMLab, I came across numerous decorators like @MODELS.register_module() that initially baffled me. This encounter piqued my curiosity about the underlying mechanism and led me to learn about the Registry Pattern. It turns out this pattern is a cornerstone of the architecture of OpenMMLab projects, including MMDetection, MMCV, and MMOCR. It simplifies managing and loading different model architectures dynamically. In this blog, I'll share insights into the Registry Pattern, its advantages, and how it can be applied in PyTorch to improve code maintainability, readability, and ease of experimenting with different model architectures.

Introduction to Registry Patterns

The Registry Pattern is a design pattern that enables dynamic registration and retrieval of class implementations in a program. It acts as a central lookup table where classes (in our context, model architectures) are registered under a unique key. Later, these classes can be retrieved and instantiated based on configuration files or runtime decisions. This pattern is particularly useful in machine learning and deep learning frameworks, where the ability to experiment with different architectures without altering the core codebase is crucial.

Initial Encounters with Model Loading

Traditionally, loading different model architectures based on a configuration file in PyTorch involved a straightforward but rigid approach. Consider the following example, where we define three variants of the ResNet architecture in a file named resnet.py:

# resnet.py
import torch.nn as nn

class ResNet18(nn.Module):
    def __init__(self):
        super().__init__()
        print("Resnet18")

class ResNet32(nn.Module):
    def __init__(self):
        super().__init__()
        print("Resnet32")

class ResNet50(nn.Module):
    def __init__(self):
        super().__init__()
        print("Resnet50")

With a corresponding configuration in config.yaml:

arch: ResNet18

The Python script to dynamically load the model might look something like this:

import yaml
import importlib

# Load the configuration
with open('config.yaml', 'r') as config_file:
    config = yaml.safe_load(config_file)

# Dynamically import the module and look up the class named in the config
resnet_module = importlib.import_module('resnet')
arch_name = config['arch']
resnet_class = getattr(resnet_module, arch_name)

# Instantiate the selected architecture
model = resnet_class()

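This method, while functional, lacks flexibility and scalability. The class name in the configuration is tied to one hard-coded module, and getattr will return any attribute of that module, whether or not it is a model. As the number of models and files grows, this approach becomes increasingly unwieldy.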
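To make that fragility concrete, here is a small sketch that mimics the getattr-based lookup without requiring PyTorch. The in-memory module fake_resnet, the attribute helper_constant, and the function load_class are hypothetical stand-ins introduced only for this illustration.

```python
import types

# Hypothetical stand-in for resnet.py, built in memory so the sketch runs without PyTorch.
fake_resnet = types.ModuleType("fake_resnet")


class ResNet18:
    pass


fake_resnet.ResNet18 = ResNet18
fake_resnet.helper_constant = 42  # not a model, but still reachable via getattr


def load_class(module, arch_name):
    """Look up arch_name on the module, as the getattr-based script does."""
    return getattr(module, arch_name)


# A correctly spelled name works.
assert load_class(fake_resnet, "ResNet18") is ResNet18

# Any attribute is returned, whether or not it is a model class.
not_a_model = load_class(fake_resnet, "helper_constant")

# A typo in the config fails only with a generic AttributeError.
try:
    load_class(fake_resnet, "Resnet18")
    typo_error = None
except AttributeError as exc:
    typo_error = exc
```

Nothing in this lookup knows which names are valid models, which is the gap a registry is meant to close.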

Better Method: Model Registry

Inspired by patterns seen in OpenMMLab projects, adopting a Model Registry approach offers a more elegant and scalable solution. A Model Registry acts as a centralized repository where each model class is registered with a unique identifier. This allows for dynamic model loading based on runtime decisions or configuration files, greatly simplifying the code and enhancing its maintainability.

Here's a simplified example of implementing a Model Registry:


# Model Registry Implementation
class ModelRegistry:
    # Maps a registered name to its model class
    _registry = {}

    @classmethod
    def register(cls, name, model_class):
        cls._registry[name] = model_class

    @classmethod
    def get_model(cls, name):
        model_class = cls._registry.get(name)
        if model_class:
            return model_class()
        else:
            raise ValueError(f"Model type '{name}' not registered.")

Model classes can then be registered to this registry:

ModelRegistry.register('resnet18', ResNet18)
ModelRegistry.register('resnet32', ResNet32)
ModelRegistry.register('resnet50', ResNet50)

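And later retrieved dynamically. The lookup key must match the registered name exactly, so with the lowercase keys above the configuration should read arch: resnet18:

model = ModelRegistry.get_model(config['arch'])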
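The registry above silently overwrites an entry if the same name is registered twice, even though the pattern relies on unique identifiers. One possible way to enforce uniqueness is sketched below. The class StrictModelRegistry, its available method, and the extended error message are assumptions added for illustration, and plain classes stand in for nn.Module subclasses so the example runs without PyTorch.

```python
class StrictModelRegistry:
    """Hypothetical variant of the registry that refuses duplicate names."""

    _registry = {}

    @classmethod
    def register(cls, name, model_class):
        # Reject a second registration under the same key instead of overwriting.
        if name in cls._registry:
            raise KeyError(f"Model type '{name}' is already registered.")
        cls._registry[name] = model_class

    @classmethod
    def available(cls):
        # Sorted list of registered names, useful in error messages.
        return sorted(cls._registry)

    @classmethod
    def get_model(cls, name):
        if name not in cls._registry:
            raise ValueError(
                f"Model type '{name}' not registered. Available: {cls.available()}"
            )
        return cls._registry[name]()


# Stand-in model classes (assumption: plain classes instead of nn.Module).
class ResNet18:
    pass


class ResNet50:
    pass


StrictModelRegistry.register("resnet18", ResNet18)
StrictModelRegistry.register("resnet50", ResNet50)

model = StrictModelRegistry.get_model("resnet18")

try:
    StrictModelRegistry.register("resnet18", ResNet50)
    duplicate_rejected = False
except KeyError:
    duplicate_rejected = True
```

Listing the available names in the error message makes a misspelled configuration value much easier to diagnose.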

Even Better Method: Decorator-Based Registration

Taking inspiration from the decorators seen in MMDetection, we can further refine the Model Registry pattern by employing Python decorators. This approach simplifies registration by automatically registering each model class upon definition:


# Decorator-based model registration
class ModelRegistry:
    _registry = {}

    @classmethod
    def register(cls, name=None):
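        def decorator(model_class):
            # Use a local key rather than mutating `name`, so a reused
            # decorator does not register every class under the first name.
            key = name if name is not None else model_class.__name__.lower()
            cls._registry[key] = model_class
            return model_class
        return decorator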

    @classmethod
    def get_model(cls, name):
        model_class = cls._registry.get(name)
        if model_class:
            return model_class()
        else:
            raise ValueError(f"Model type '{name}' not registered.")

By decorating model classes with @ModelRegistry.register(...), each class is registered as soon as it is defined:


@ModelRegistry.register('resnet18')
class ResNet18(nn.Module):
    ...

@ModelRegistry.register('resnet32')
class ResNet32(nn.Module):
    ...

@ModelRegistry.register('resnet50')
class ResNet50(nn.Module):
    ...
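Putting the pieces together, the sketch below registers two stand-in classes and builds one from a configuration dictionary. Forwarding constructor arguments through **kwargs, the build_from_config helper, and the num_classes parameter are assumptions added for illustration; the type key follows the convention seen in OpenMMLab-style configs. Plain classes replace nn.Module so the example runs without PyTorch.

```python
class ModelRegistry:
    _registry = {}

    @classmethod
    def register(cls, name=None):
        def decorator(model_class):
            key = name if name is not None else model_class.__name__.lower()
            cls._registry[key] = model_class
            return model_class
        return decorator

    @classmethod
    def get_model(cls, name, **kwargs):
        # Assumption: constructor arguments are forwarded as keyword arguments.
        if name not in cls._registry:
            raise ValueError(f"Model type '{name}' not registered.")
        return cls._registry[name](**kwargs)


@ModelRegistry.register("resnet18")
class ResNet18:
    def __init__(self, num_classes=1000):
        self.num_classes = num_classes


@ModelRegistry.register()  # key defaults to "resnet50"
class ResNet50:
    def __init__(self, num_classes=1000):
        self.num_classes = num_classes


def build_from_config(config):
    """Hypothetical helper: read the 'type' key and pass the rest to the constructor."""
    params = dict(config)  # copy so the caller's dict is not modified
    arch = params.pop("type")
    return ModelRegistry.get_model(arch, **params)


model = build_from_config({"type": "resnet50", "num_classes": 10})
```

One caveat with decorator-based registration: a class is only registered once the module that defines it has been imported, so the file containing the models must be imported before get_model is called.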

Conclusion

The journey from a straightforward implementation to a flexible, decorator-based Model Registry pattern illustrates the power of adopting design patterns from leading projects like MMDetection. This approach not only facilitates easier experimentation with various model architectures but also significantly improves the code's structure and readability. By leveraging the Registry Pattern, developers can maintain a clean and extensible codebase, enabling rapid iteration and innovation in machine learning projects.

Hopefully this helps you understand the codebases of OpenMMLab projects better.

Reference / Links

1. MMDet

2. MMOCR

3. Python Patterns
