
Five-minute Multimethods in Python

March 30, 2005 — originally posted on artima.com


So what are multimethods? I'll give you my own definition, as I've come to understand them: a function that has multiple versions, distinguished by the type of the arguments. (Some people go beyond this and also allow versions distinguished by the value of the arguments; I'm not addressing this here.)

As a very simple example, let's suppose we have a function that we want to define for two ints, two floats, or two strings. Of course, we could define it as follows:

def foo(a, b):
    if isinstance(a, int) and isinstance(b, int):
        ...  # code for two ints
    elif isinstance(a, float) and isinstance(b, float):
        ...  # code for two floats
    elif isinstance(a, str) and isinstance(b, str):
        ...  # code for two strings
    else:
        raise TypeError("unsupported argument types (%s, %s)" % (type(a), type(b)))

But this pattern gets tedious. (It also isn't very OO, but then, neither are multimethods, despite the name, IMO.) So what could this look like using multimethod dispatch? Decorators are a good match:

from mm import multimethod

@multimethod(int, int)
def foo(a, b):
    ...  # code for two ints

@multimethod(float, float)
def foo(a, b):
    ...  # code for two floats

@multimethod(str, str)
def foo(a, b):
    ...  # code for two strings

The rest of this article will show how we can define the multimethod decorator. It's really pretty simple: there's a global registry indexed by function name ('foo' in this case), pointing to a registry indexed by tuples of type objects corresponding to the arguments passed to the decorator. Like this:

# This is in the 'mm' module

registry = {}

class MultiMethod(object):
    def __init__(self, name):
        self.name = name
        self.typemap = {}
    def __call__(self, *args):
        types = tuple(arg.__class__ for arg in args) # a generator expression!
        function = self.typemap.get(types)
        if function is None:
            raise TypeError("no match")
        return function(*args)
    def register(self, types, function):
        if types in self.typemap:
            raise TypeError("duplicate registration")
        self.typemap[types] = function

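I hope that wasn't too much code at once; it's really very simple so far (please indulge me in using the words 'class' and 'type' interchangeably here). The constructor stores the name and starts with an empty typemap. __call__() makes the object callable: it builds a tuple of the arguments' classes, looks that tuple up in the typemap, and raises TypeError if nothing is registered for it. register() adds a function under a tuple of types and refuses to register the same tuple twice.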
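To make this concrete, here is a small sketch that drives a MultiMethod object by hand, without any decorator. The class is repeated so the snippet runs on its own; the names combine, add_ints and join_strs are made up for the illustration and are not part of the 'mm' module.

```python
class MultiMethod(object):
    # Same class as in the listing above, repeated for a standalone run.
    def __init__(self, name):
        self.name = name
        self.typemap = {}
    def __call__(self, *args):
        types = tuple(arg.__class__ for arg in args)
        function = self.typemap.get(types)
        if function is None:
            raise TypeError("no match")
        return function(*args)
    def register(self, types, function):
        if types in self.typemap:
            raise TypeError("duplicate registration")
        self.typemap[types] = function

# Hypothetical implementations, invented for this example.
def add_ints(a, b):
    return a + b

def join_strs(a, b):
    return a + " " + b

combine = MultiMethod("combine")
combine.register((int, int), add_ints)
combine.register((str, str), join_strs)

print(combine(1, 2))               # dispatches to add_ints
print(combine("five", "minutes"))  # dispatches to join_strs

# The lookup key is the exact tuple of classes, so a mixed call has no entry.
try:
    combine(1, 2.0)
except TypeError as exc:
    print("TypeError:", exc)
```

One consequence worth noticing: because the typemap is keyed on arg.__class__, the match is exact. An instance of a subclass (a bool where an int was registered, say) is not found, unlike the isinstance() chain shown earlier.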

I hope it's clear from this that the @multimethod decorator should return a MultiMethod object and somehow call its register() method. Let's see how to do that:

def multimethod(*types):
    def register(function):
        name = function.__name__
        mm = registry.get(name)
        if mm is None:
            mm = registry[name] = MultiMethod(name)
        mm.register(types, function)
        return mm
    return register
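Putting the two pieces together, the example from the top of the article can be run roughly as follows. This is a sketch with everything in one file instead of a separate 'mm' module, and the function bodies are placeholders I made up so there is something to print.

```python
registry = {}

class MultiMethod(object):
    def __init__(self, name):
        self.name = name
        self.typemap = {}
    def __call__(self, *args):
        types = tuple(arg.__class__ for arg in args)
        function = self.typemap.get(types)
        if function is None:
            raise TypeError("no match")
        return function(*args)
    def register(self, types, function):
        if types in self.typemap:
            raise TypeError("duplicate registration")
        self.typemap[types] = function

def multimethod(*types):
    def register(function):
        name = function.__name__
        mm = registry.get(name)
        if mm is None:
            mm = registry[name] = MultiMethod(name)
        mm.register(types, function)
        return mm
    return register

# Placeholder bodies, invented for this illustration.
@multimethod(int, int)
def foo(a, b):
    return "ints: %d" % (a + b)

@multimethod(float, float)
def foo(a, b):
    return "floats: %.1f" % (a + b)

@multimethod(str, str)
def foo(a, b):
    return "strs: " + a + b

print(foo(1, 2))       # ints: 3
print(foo(1.5, 2.5))   # floats: 4.0
print(foo("a", "b"))   # strs: ab

# After decoration, the module-level name is bound to the MultiMethod object,
# and every 'def foo' was added to that same object via the registry.
print(registry["foo"] is foo)
```

Each decorator call returns the same MultiMethod instance, so rebinding foo three times is harmless: the name always ends up pointing at the one dispatcher.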

That's it! Sparse but it works. Note that only positional parameters are supported; it gets pretty murky if you want to support keyword parameters as well. Default parameter values are somewhat against the nature of multimethods: instead of

@multimethod(int, int)
def foo(a, b=10):
    ...

you'd have to write

@multimethod(int, int)
def foo(a, b):
    ...

@multimethod(int)
def foo(a):
    return foo(a, 10) # This calls the previous foo()!
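Here is a runnable sketch of that workaround, using the first version of the decorator. The multiplication in the two-argument body is an arbitrary placeholder. The last few lines also show what "only positional parameters" means in practice: __call__() accepts *args only, so a keyword call fails before any dispatch happens.

```python
registry = {}

class MultiMethod(object):
    def __init__(self, name):
        self.name = name
        self.typemap = {}
    def __call__(self, *args):
        types = tuple(arg.__class__ for arg in args)
        function = self.typemap.get(types)
        if function is None:
            raise TypeError("no match")
        return function(*args)
    def register(self, types, function):
        if types in self.typemap:
            raise TypeError("duplicate registration")
        self.typemap[types] = function

def multimethod(*types):
    def register(function):
        name = function.__name__
        mm = registry.get(name)
        if mm is None:
            mm = registry[name] = MultiMethod(name)
        mm.register(types, function)
        return mm
    return register

@multimethod(int, int)
def foo(a, b):
    return a * b  # placeholder body for the illustration

@multimethod(int)
def foo(a):
    # At call time 'foo' is the MultiMethod, so this dispatches on (int, int).
    return foo(a, 10)

print(foo(3))      # 30
print(foo(3, 4))   # 12

# Keyword arguments never reach the typemap: __call__ takes *args only.
try:
    foo(a=3)
except TypeError as exc:
    print("TypeError:", exc)
```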

I've got one improvement to make: I imagine that sometimes you'd want to write a single implementation that applies to multiple types. It would be convenient if the @multimethod decorators could be stacked, like this:

@multimethod(int, int)
@multimethod(int)
def foo(a, b=10):
    ...

This can be done by changing the decorator slightly (this is not thread-safe, but I don't think that matters much, since all this is typically happening at import time):

def multimethod(*types):
    def register(function):
        function = getattr(function, "__lastreg__", function)
        name = function.__name__
        mm = registry.get(name)
        if mm is None:
            mm = registry[name] = MultiMethod(name)
        mm.register(types, function)
        mm.__lastreg__ = function
        return mm
    return register
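A quick check that stacking behaves as intended with this revised decorator. As before, this is a one-file sketch and the body of foo is a placeholder I chose for the example.

```python
registry = {}

class MultiMethod(object):
    def __init__(self, name):
        self.name = name
        self.typemap = {}
    def __call__(self, *args):
        types = tuple(arg.__class__ for arg in args)
        function = self.typemap.get(types)
        if function is None:
            raise TypeError("no match")
        return function(*args)
    def register(self, types, function):
        if types in self.typemap:
            raise TypeError("duplicate registration")
        self.typemap[types] = function

def multimethod(*types):
    def register(function):
        # If we were handed a MultiMethod (stacked use), unwrap it first.
        function = getattr(function, "__lastreg__", function)
        name = function.__name__
        mm = registry.get(name)
        if mm is None:
            mm = registry[name] = MultiMethod(name)
        mm.register(types, function)
        mm.__lastreg__ = function
        return mm
    return register

@multimethod(int, int)
@multimethod(int)
def foo(a, b=10):
    return a + b  # placeholder body for the illustration

print(foo(1))      # 11, via the (int,) entry and the default b=10
print(foo(1, 2))   # 3, via the (int, int) entry

# Both signatures map to the very same plain function object.
print(foo.typemap[(int,)] is foo.typemap[(int, int)])
```

The inner decorator runs first and returns the MultiMethod; the outer one then receives that object rather than a function. Without the unwrapping step, the outer call would stumble on function.__name__, since a MultiMethod instance has a name attribute but no __name__.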

Note the three-argument getattr() call, which you may not be familiar with: getattr(x, "y", z) returns x.y if it exists, and z otherwise. So that line is equivalent to

if hasattr(function, "__lastreg__"):
    function = function.__lastreg__
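If you want to see the three-argument form in isolation, a tiny sketch like this will do. The classes Plain and Tagged and the attribute name y are invented purely for the demonstration.

```python
class Plain(object):
    pass

class Tagged(object):
    y = "attribute value"

fallback = "fallback value"

print(getattr(Tagged(), "y", fallback))  # attribute exists: its value is returned
print(getattr(Plain(), "y", fallback))   # attribute missing: the default is returned

# The two-argument form has no default and raises instead.
try:
    getattr(Plain(), "y")
except AttributeError as exc:
    print("AttributeError:", exc)
```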

You could try to put the assignment to mm.__lastreg__ inside the register() method, but that would just add more distance between the code that sets it and the code that uses it, so I like it better this way. In a more static language, of course, there would have to be a declaration of the __lastreg__ attribute; Python doesn't need this. It's important that this isn't a "normal" attribute name, so that other uses of function attributes aren't preempted. (Hm... There are almost no "normal" uses of function attributes; they are mostly used for various "secret" purposes so name conflicts in the __xxx__ namespace are not inconceivable. Oh well, maybe we should use something really long like multimethod_last_registered or even put the whole thing inside the MultiMethod class so we can use a private variable name like __lastreg.)
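For what it's worth, here is one way the "put the whole thing inside the MultiMethod class" idea might look. This is only a sketch of that option, not something the code above does: the classmethod name decorator is my own invention, and I've replaced the getattr() probe with an isinstance() check. Because __lastreg appears textually inside the class body, Python mangles it to _MultiMethod__lastreg, which keeps it out of the __xxx__ namespace.

```python
registry = {}

class MultiMethod(object):
    def __init__(self, name):
        self.name = name
        self.typemap = {}
        self.__lastreg = None  # stored as _MultiMethod__lastreg

    def __call__(self, *args):
        types = tuple(arg.__class__ for arg in args)
        function = self.typemap.get(types)
        if function is None:
            raise TypeError("no match")
        return function(*args)

    def register(self, types, function):
        if types in self.typemap:
            raise TypeError("duplicate registration")
        self.typemap[types] = function
        self.__lastreg = function

    @classmethod
    def decorator(cls, *types):
        # Hypothetical helper name; it plays the role of multimethod().
        def register(function):
            if isinstance(function, cls):
                # Stacked use: unwrap to the function registered last.
                # Name mangling applies here too, since we are inside the class.
                function = function.__lastreg
            name = function.__name__
            mm = registry.get(name)
            if mm is None:
                mm = registry[name] = cls(name)
            mm.register(types, function)
            return mm
        return register

multimethod = MultiMethod.decorator

@multimethod(int, int)
@multimethod(int)
def foo(a, b=10):
    return a + b  # placeholder body for the illustration

print(foo(1))      # 11
print(foo(1, 2))   # 3
print(hasattr(foo, "__lastreg__"))  # False: no dunder attribute is involved
```

The trade-off is the one mentioned above: the assignment now lives in register(), further from the code that reads it, in exchange for a name that cannot collide with other attributes.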

