1 minute read

ParamSpecを使うと、デコレータを定義する際に、デコレートされる関数の引数の型情報を明示的に扱うことができる(昨日のブログ)。 ParamSpecを提案しているPEPでは、単純に型情報を保つだけではなく、引数の型に制約を導入するような例が挙げられている。

例えば、以下のコードでは、printallに渡される関数には「最初の位置引数はlistである」という制約が課されている。

from typing import Callable, Concatenate, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")
Q = TypeVar("Q")

def printall(f: Callable[Concatenate[list[Q], P], R]) -> Callable[Concatenate[list[Q], P], R]:
    def ret(xs: list[Q], *args: P.args, **kwargs: P.kwargs) -> R:
        for x in xs:
            print(x)
        return f(xs, *args, **kwargs)
    return ret

@printall
def mysum(xs: list[int]) -> int:  # type checks
    return sum(xs)

@printall
def mysum2(x: int, y: int) -> int:  # type error
    return x + y

上の例で、第一引数を、listではなくIterableにしたいと思ったとする。 何も考えずに、型アノテーションをlistIterableに変えてみると、うまく行かない。

from typing import Callable, Concatenate, ParamSpec, TypeVar, Iterable

P = ParamSpec("P")
R = TypeVar("R")
Q = TypeVar("Q")

def printall(f: Callable[Concatenate[Iterable[Q], P], R]) -> Callable[Concatenate[Iterable[Q], P], R]:
    def ret(xs: Iterable[Q], *args: P.args, **kwargs: P.kwargs) -> R:
        for x in xs:
            print(x)
        return f(xs, *args, **kwargs)
    return ret

@printall
def mysum(xs: list[int]) -> int:  # type error
    return sum(xs)

これは、Callableの第一引数 = 関数の引数の型はcontravariantだからである。 上のコードでは、printallに渡されてくる関数はIterableなものであれば何でも受け入れてくれることを期待されている。 一方で、実際に渡されているmysumlist型しか受け取れない(と宣言されている)。 つまり、型チェッカの立場から見ると、list型しか受け取れないmysumに一般のIterableが渡される可能性があるように見えるため、エラーになる。

一方で、上のコードのように、「Iterableのサブタイプを受け入れてくれる関数なら何でも受け入れる」ことを型で表現したいこともある。そのような時には、TypeVarboundを利用すれば良い。

from typing import Callable, Concatenate, ParamSpec, TypeVar, Iterable

P = ParamSpec("P")
R = TypeVar("R")
XS = TypeVar("XS", bound=Iterable)

def printall(f: Callable[Concatenate[XS, P], R]) -> Callable[Concatenate[XS, P], R]:
    def ret(xs: XS, *args: P.args, **kwargs: P.kwargs) -> R:
        for x in xs:
            print(x)
        return f(xs, *args, **kwargs)
    return ret

@printall
def mysum(xs: list[int]) -> int:  # type checks
    return sum(xs)

@printall
def mysum2(x: int, y: int) -> int:  # type error
    return x + y

タグ:

カテゴリー:

更新日時: