Taichi源码分析(一)
@ kingkiller | Monday, Feb 6, 2023 | 9 minutes read | Update at Monday, Feb 6, 2023

Taichi编译链路之Python部分的生命周期

Taichi 源码分析(一)

1. Taichi项目结构及源码编译

今天我没来探索一下taichi 程序的整体生命周期,首先我们先看一下taichi的项目结构

文件层级很多,就没有树来进行展示了,直接看下面的截图

project structure

taichi 是一个eDSL语言,嵌入到python语言当中,利用了Python解释器作为了部分运行时,同时通过pybind链接到了底层taichi c++库, taichi语言的后端是利用c++实现高性能计算的。

如上图所示,项目本身是一个CMake项目,在根目录中项目整体CMakeList.txt, README文档和版本文档等其他配置文件,setup.py用于安装taichi的python包。taichi项目根目录有13个文件夹,其中核心程序文件在taichi文件夹和python文件夹中,taichi文件夹内为taichi的c++源码,Python中是上层的python库。taichi文件内又分为了用于即时编译jit文件夹, 用于代码生成的codegen文件夹,用于pybind绑定的python文件夹等。根目录的python文件夹中taichi是Python项目的最基础包,所有库都从这个包中导入,大部分的语言项目的功能都在taichi.lang包中。

现在我们来编译taichi程序,需要注意的是taichi项目中的第三包并没有加入到项目中,大家在克隆时需要试下下面的指令

git clone --recursive https://github.com/taichi-dev/taichi.git

不过Eigen库所引用的仓库已经无法访问了,实测大家可以使用这个仓库的eigen代码(https://github.com/PX4/eigen)。taichi需要用到pybind和llvm等第三方包,大家需要自行安装这些,pybind全部为头文件实现,安装十分方便,从git仓库中下载完成后直接cmake安装即可。llvm的话太极官方为大家提供了他们自定义版LLVM的预编译文件,大家可以在这里下载对应版本https://github.com/taichi-dev/taichi_assets/releases。

这些前置工作完成后直接,cmake + make build即可,将生成的pyd文件放到对应的python文件夹中,执行

python setup.py

到此就源码编译安装完成了。

2. Python部分生命周期

首先需要说明一下,此处虽说是讲Python部分的生命周期,但在初始化时taichi仍然调用了一部分pybind封装的c++代码,其中大部分都是用于初始化配置,我们也会顺带讲解一下,我仍认为这一部分属于Python部分的生命周期,等真正进入JIT编译时,才算是进入到C++部分的生命周期。

2.1 实例代码

import taichi as ti

ti.init(arch=ti.cpu)

x = ti.field(ti.i32, shape=(3, 4))
y = ti.ndarray(int, shape=16)
z = int(100)


@ti.kernel
def play(arg0: int, arg1: int):
    tmp = 0

    if arg1:
        tmp = tmp + 4

    for i in range(10):
        tmp += 1

    print(tmp)

    assert z == 100
    print(z)
    
    for i in x:
        x[i] = 0
    
    for i in y:
        ti.atomic_add(y[i], 5)


play(10, True)

下面我们以上面那段代码为例,taichi程序会分为两个部分,一部分是属于Python scope中的,而另一部分被装饰器所修饰的代码是属于taichi scope中的,可以看到taichi scope中的也会调用到Python全局中的field变量,这一部分我们暂时不考虑,taichi中的全局field变量会通过特殊处理链接到taichi程序中去,我们分开两部分来看,我们首先先看不含这些全局变量的taichi程序的生命周期是如何的,即我们这次调试的代码为下面的情况:

import taichi as ti

ti.init(arch=ti.cpu)

# x = ti.field(ti.i32, shape=(3, 4))
# y = ti.ndarray(int, shape=16)
# z = int(100)


@ti.kernel
def play(arg0: int, arg1: int):
    tmp = 0

    if arg1:
        tmp = tmp + 4

    for i in range(10):
        tmp += 1

    print(tmp)

    # assert z == 100
    # print(z)
    #
    # for i in x:
    #     x[i] = 0
    #
    # for i in y:
    #     ti.atomic_add(y[i], 5)


play(10, True)

2.2 Tachi Init

可以看到要想加载taichi程序,我们会首先需要初始化taichi包,我这里使用的cpu作为后端进行运行。

def init(arch=None,
         default_fp=None,
         default_ip=None,
         _test_mode=False,
         enable_fallback=True,
         require_version=None,
         **kwargs):
    # Check version for users every 7 days if not disabled by users.
    _version_check.start_version_check_thread()

在运行init函数之前,taichi包中有很多全局变量,下面一个是会伴随我们整个生命周期的一个Python全局变量PyTaichi,他在taichi.lang包下的impl模块中被定义

class PyTaichi:
    def __init__(self, kernels=None):
        self.materialized = False
        self.prog = None
        self.compiled_functions = {}
        self.src_info_stack = []
        self.inside_kernel = False
        self.current_kernel = None
        self.global_vars = []
        self.grad_vars = []
        self.dual_vars = []
        self.matrix_fields = []
        self.default_fp = f32
        self.default_ip = i32
        self.default_up = u32
        self.target_tape = None
        self.fwd_mode_manager = None
        self.grad_replaced = False
        self.kernels = kernels or []
        self._signal_handler_registry = None

PyTaichi的初始化只是单纯的对成员函数赋初值。

而我们之前调用的init函数是一大串初始化配置的代码,在开始时会进行版本的检查

def start_version_check_thread():
    skip = os.environ.get("TI_SKIP_VERSION_CHECK")
    if skip != 'ON':
        # We don't join this thread because we do not wish to block users.
        check_version_thread = threading.Thread(target=try_check_version,
                                                daemon=True)
        check_version_thread.start()

此处开启多线程进行版本检查,实际运行函数为try_check_version

def try_check_version():
    try:
        os.makedirs(_ti_core.get_repo_dir(), exist_ok=True)
        version_info_path = os.path.join(_ti_core.get_repo_dir(),
                                         'version_info')
        cur_date = datetime.date.today()
        if os.path.exists(version_info_path):
            with open(version_info_path, 'r') as f:
                version_info_file = f.readlines()
                last_time = version_info_file[0].rstrip()
                cur_uuid = version_info_file[2].rstrip()
            if cur_date.strftime('%Y-%m-%d') > last_time:
                response = check_version(cur_uuid)
                write_version_info(response, cur_uuid, version_info_path,
                                   cur_date)
        else:
            cur_uuid = str(uuid.uuid4())
            write_version_info({'status': 0}, cur_uuid, version_info_path,
                               cur_date)
            response = check_version(cur_uuid)
            write_version_info(response, cur_uuid, version_info_path, cur_date)
    # Wildcard exception to catch potential file writing errors.
    except:
        pass

在这个函数中,第一次运行taichi程序时会创建了一个taichi_cache文件夹用于放置临时缓存,windows默认在C盘根目录。同样在第一次运行时会在taichi_cache文件夹中创建一个version_info文件,这个文件内容是一传版本字符串,由时期+uuid组成,如果当前日期大于最后一次更新日期,需要对版本进行更新,日期判断只精确到天。

之后回到init函数中。

current_dir = os.getcwd()

if require_version is not None:
    check_require_version(require_version)

if "packed" in kwargs:
    if kwargs["packed"] is True:
        warnings.warn(
            "Currently packed=True is the default setting and the switch will be removed in v1.4.0.",
            DeprecationWarning)
    else:
        warnings.warn(
            "The automatic padding mode (packed=False) will no longer exist in v1.4.0. The switch will "
            "also be removed then. Make sure your code doesn't rely on it.",
            DeprecationWarning)

if "default_up" in kwargs:
    raise KeyError(
        "'default_up' is always the unsigned type of 'default_ip'. Please set 'default_ip' instead."
    )
default_fp = deepcopy(default_fp)
default_ip = deepcopy(default_ip)
kwargs = deepcopy(kwargs)

之后获取到了当前运行的py文件所在的文件夹,对参数中进行了检测,如果存在不合法会进行警告和报错,对合法参数进行深拷贝复制。

def reset():
    global pytaichi
    old_kernels = pytaichi.kernels
    pytaichi.clear()
    pytaichi = PyTaichi(old_kernels)
    for k in old_kernels:
        k.reset()

之后进行reset操作,reset的具体实现在Impl中,清空当前全局的pytaichi变量中的kernel集合,并对每一个kernel进行复位。

cfg = impl.default_cfg()
cfg.offline_cache = True  # Enable offline cache in frontend instead of C++ side

spec_cfg = _SpecialConfig()
env_comp = _EnvironmentConfigurator(kwargs, cfg)
env_spec = _EnvironmentConfigurator(kwargs, spec_cfg)

# configure default_fp/ip:
# TODO: move these stuff to _SpecialConfig too:
env_default_fp = os.environ.get("TI_DEFAULT_FP")
	......

if default_fp is not None:
    impl.get_runtime().set_default_fp(default_fp)
if default_ip is not None:
    impl.get_runtime().set_default_ip(default_ip)

接下来就是进行配置设置,上面代码全为配置设置,获取的Config类是由C++ Pybind绑定的,如下:

struct CompileConfig {
  Arch arch;
  bool debug;
  bool cfg_optimization;
  bool check_out_of_bound;
  bool validate_autodiff;
  int simd_width;
  int opt_level;
  int external_optimization_level;
  int max_vector_width;
  bool packed;
  bool print_preprocessed_ir;
	 ......

  CompileConfig();
};

默认的配置都是对这些成员变量进行初始化,接下来的是一些特殊环境配置,我初始化时并没有进行设置,这些部分都是默认值,大部分的控制流都不会进入,之后是加入日志等级和gdb等信息。

# compiler configurations (ti.cfg):
for key in dir(cfg):
    if key in ['arch', 'default_fp', 'default_ip']:
        continue
    _cast = type(getattr(cfg, key))
    if _cast is bool:
        _cast = None
    env_comp.add(key, _cast)

unexpected_keys = kwargs.keys()

逐级遍历Config中的属性,将属性名作为键值,类型作为value存到env.comp中。

unexpected_keys = kwargs.keys()

if len(unexpected_keys):
    raise KeyError(
        f'Unrecognized keyword argument(s) for ti.init: {", ".join(unexpected_keys)}'
    )

之后是查看是否有不需要的参数,有则报错

get_default_kernel_profiler().set_kernel_profiler_mode(cfg.kernel_profiler)

# create a new program:
impl.get_runtime().create_program()

_logging.trace('Materializing runtime...')
impl.get_runtime().prog.materialize_runtime()

impl._root_fb = _snode.FieldsBuilder()

if cfg.debug:
    impl.get_runtime()._register_signal_handlers()

os.chdir(current_dir)
return None

接下来就是初始化一些核心类用于创建和编译Kernel。这里首先会创建一个Program,这个Program可以视为一个整体的Taichi程序,会存放全部的filed等全局变量和所有的kernel。Program是一个c++类,通过Pybind绑定成Python对象,Program本身并不提供任何实现,具体实现有其子类提供,默认使用的是LLVM版本的Program,这段代码可以在taichi项目中的c++源码文件夹taichi下面的program文件夹中的program.cpp中找到:

Program::Program(Arch desired_arch) : snode_rw_accessors_bank_(this) {
  TI_TRACE("Program initializing...");
  .......
  main_thread_id_ = std::this_thread::get_id();
  ......

  profiler = make_profiler(config.arch, config.kernel_profiler);
  if (arch_uses_llvm(config.arch)) {
#ifdef TI_WITH_LLVM
    if (config.arch != Arch::dx12) {
      program_impl_ = std::make_unique<LlvmProgramImpl>(config, profiler.get());
    } else {
      // NOTE: use Dx12ProgramImpl to avoid using LlvmRuntimeExecutor for dx12.
#ifdef TI_WITH_DX12
      TI_ASSERT(directx12::is_dx12_api_available());
      program_impl_ = std::make_unique<Dx12ProgramImpl>(config);
#else
      TI_ERROR("This taichi is not compiled with DX12");
#endif
    }
#else
    TI_ERROR("This taichi is not compiled with LLVM");
#endif
  } else if (config.arch == Arch::metal) {
#ifdef TI_WITH_METAL
    TI_ASSERT(metal::is_metal_api_available());
    program_impl_ = std::make_unique<MetalProgramImpl>(config);
#else
    TI_ERROR("This taichi is not compiled with Metal")
#endif

我们下面的也都以默认的LLVM版本的Program实现进行讲解,后续有时间我会讲解其他backend实现。

让我们回到Python的init函数中去,在创建完Program后。调用了materialize_runtime方法来创建运行环境,我们具体来看一下LLVMProgramImpl的实现:

void materialize_runtime(MemoryPool *memory_pool,
                         KernelProfilerBase *profiler,
                         uint64 **result_buffer_ptr) override {
  runtime_exec_->materialize_runtime(memory_pool, profiler,
                                     result_buffer_ptr);
}

额,又包了一层,这里调用的是其执行器的初始化方法,继续往下走可以看到这里是对jit进行一些配置。到此为止,超长的Init函数就看完了。

2.3 kernel装饰器

之后我们看一下@ti.kernel装饰器

def _kernel_impl(_func, level_of_class_stackframe, verbose=False):
    is_classkernel = _inside_class(level_of_class_stackframe + 1)
    if verbose:
        print(f'kernel={_func.__name__} is_classkernel={is_classkernel}')
    primal = Kernel(_func,
                    autodiff_mode=AutodiffMode.NONE,
                    _classkernel=is_classkernel)
    adjoint = Kernel(_func,
                     autodiff_mode=AutodiffMode.REVERSE,
                     _classkernel=is_classkernel)
    primal.grad = adjoint
    if is_classkernel:
        @functools.wraps(_func)
        def wrapped(*args, **kwargs):
            clsobj = type(args[0])
            assert not hasattr(clsobj, '_data_oriented')
            raise TaichiSyntaxError(
                f'Please decorate class {clsobj.__name__} with @ti.data_oriented'
            )
    else:
        @functools.wraps(_func)
        def wrapped(*args, **kwargs):
            try:
                return primal(*args, **kwargs)
            except (TaichiCompilationError, TaichiRuntimeError) as e:
                raise type(e)('\n' + str(e)) from None
        wrapped.grad = adjoint
    wrapped._is_wrapped_kernel = True
    wrapped._is_classkernel = is_classkernel
    wrapped._primal = primal
    wrapped._adjoint = adjoint
    return wrapped

def kernel(fn):
    return _kernel_impl(fn, level_of_class_stackframe=3)

kernel是一个装饰器函数,其内部调用了_kernel_impl作为具体实现。这里首先判断了当前的kernel是否在一个类中,是的,taichi允许你在一个类中定义kernel,这给了开发者极大的便利,同时taichi用了非常巧妙的方式来处理在类中的Kernel,这里我们暂时不讨论,我先关注于在类外面的kernel函数。而_inside_class方法具体是如何判断是否在类中的,则用了我们的老朋友,inspect模块来获取解释器堆栈上的语句信息,通过正则匹配检测class关键字实现的。

接下来实例化了两个Kernel类,这个Kernel类是taichi程序的编译核心,也是执行入口,具体编译过程在Kernel的__call__魔法函数中。

之后是一个内部闭包wrapped,functools.wrap装饰器的作用是不改变调用对象,单纯改变所装饰对象的一些属性,也就是这里最后返回的是wrapped函数,但是属性则依然是传入的play函数,所以我们打印play函数的__name____doc__属性就会发现没有发生改变,但是具体调用逻辑则是调用Kernel的__call__函数。

那我们接下来就很好看一下Kernel这个类:

class Kernel:
    counter = 0

    def __init__(self, _func, autodiff_mode, _classkernel=False):
        self.func = _func
        self.kernel_counter = Kernel.counter
        Kernel.counter += 1
        assert autodiff_mode in (AutodiffMode.NONE, AutodiffMode.VALIDATION,
                                 AutodiffMode.FORWARD, AutodiffMode.REVERSE)
        self.autodiff_mode = autodiff_mode
        self.grad = None
        self.arguments = []
        self.return_type = None
        self.classkernel = _classkernel
        self.extract_arguments()
        self.template_slot_locations = []
        for i, arg in enumerate(self.arguments):
            if isinstance(arg.annotation, template):
                self.template_slot_locations.append(i)
        self.mapper = TaichiCallableTemplateMapper(
            self.arguments, self.template_slot_locations)
        impl.get_runtime().kernels.append(self)
        self.reset()
        self.kernel_cpp = None
        # TODO[#5114]: get rid of compiled_functions and use compiled_kernels instead.
        # Main motivation is that compiled_kernels can be potentially serialized in the AOT scenario.
        self.compiled_kernels = {}
        self.has_print = False

在初始化Kernel时,把所装饰的函数作为了当前func属性,获取当前kernel的计数器,之后静态成员变量counter+1。判断是否开启auto_diff,是否有自动微分等。之后通过extract_argument来获取函数形参参数:

def extract_arguments(self):
    sig = inspect.signature(self.func)
    if sig.return_annotation not in (inspect._empty, None):
        self.return_type = sig.return_annotation
    params = sig.parameters
    arg_names = params.keys()
    for i, arg_name in enumerate(arg_names):
        param = params[arg_name]
        if param.kind == inspect.Parameter.VAR_KEYWORD:
            raise TaichiSyntaxError(
                'Taichi kernels do not support variable keyword parameters (i.e., **kwargs)'
            )
		......
        annotation = param.annotation
        if param.annotation is inspect.Parameter.empty:
            if i == 0 and self.classkernel:  # The |self| parameter
                annotation = template()
            else:
                raise TaichiSyntaxError(
                    'Taichi kernels parameters must be type annotated')
		......
        self.arguments.append(
            KernelArgument(annotation, param.name, param.default))

获取的方式当然又是我们的老熟人inspect模块,通过signature获取到func上的形参和返回值类型属性,之后是一大段参数类型的异常判断,在这里我们看到关键字参数,带初始值的参数都是不支持的。之后将这些参数加入到Kernel类的成员变量arguments中,Taichi在这里对参数又进行了一层包装,不过KernelArgument就是一个非常简单的data类了:

class KernelArgument:
    def __init__(self, _annotation, _name, _default=inspect.Parameter.empty):
        self.annotation = _annotation
        self.name = _name
        self.default = _default

让我们回到Kernel的构造函数,获取到函数参数后对这些形参进行了遍历,判断是否是Template类型参数,这里的Template是一个用Python定义的一个Data类,用于taichi的元编程使用,taichi的模板编程实现的方式我们也暂且不表。再往下是初始化了一个TaichiCallbleTemplateMapper这个map存储了参数了信息为后续调用时使用。随后将现在的kernel加入到我们之前提到的pytaichi全局变量中,那是一个Program类,存储了全部的kernel。再之后就是调用了reset函数:

def reset(self):
    self.runtime = impl.get_runtime()

reset函数非常简单,把初始化当前runtime成员变量的方法提取出来以供复用,这里runtime所赋值的依然是pytaichi。到此Kernel的初始化工作基本完成了。

Taichi所采用的对Python实施静态编译的策略与我们之前探讨的pytorch和qcor有所不同,他并不是在装饰器加载时就完成了jit编译,而是在第一次调用时才进行编译,下面我们就探究一下Kernel的__call__函数:

2.4 kernel的运行入口

@_shell_pop_print
def __call__(self, *args, **kwargs):
    args = _process_args(self, args, kwargs)
    if self.runtime.fwd_mode_manager and not self.runtime.grad_replaced:
        self.runtime.fwd_mode_manager.insert(self)
    if self.autodiff_mode in (
            AutodiffMode.NONE, AutodiffMode.VALIDATION
    ) and self.runtime.target_tape and not self.runtime.grad_replaced:
        self.runtime.target_tape.insert(self, args)
    if self.autodiff_mode != AutodiffMode.NONE and impl.current_cfg(
    ).opt_level == 0:
        _logging.warn(
            """opt_level = 1 is enforced to enable gradient computation."""
        )
        impl.current_cfg().opt_level = 1
    key = self.ensure_compiled(*args)
    return self.runtime.compiled_functions[key](*args)

这段程序的代码量并不多,首先_shell_pop_print装饰器在绝大多数情况下冰不起作用,不会改变调用主体,只有在开启pybuf时会有一个额外处理,但依然会调用__call__函数,并将其返回值作为新函数返回值,只是会进行一步额外信息打印操作。

之后对实参进行处理,这里的_process_arg主要做了参数校验的工作,判断了实参和形参的个数是否统一,并将*args解构赋值到一个list中返回。

之后判断了是否启用了自动微分,在本例中未使用自动微分,这一步也不用考虑。之后就进入到了我们的编译环节,ensure_compiled函数代码如下:

def ensure_compiled(self, *args):
    instance_id, arg_features = self.mapper.lookup(args)
    key = (self.func, instance_id, self.autodiff_mode)
    self.materialize(key=key, args=args, arg_features=arg_features)
    return key

这里调用TaichiCallableTemplateMapperloopup功能获得的instance_id和arg_features分别为0和两个'#'字符串组成的list。

def materialize(self, key=None, args=None, arg_features=None):
    if key is None:
        key = (self.func, 0, self.autodiff_mode)
    self.runtime.materialize()
    if key in self.runtime.compiled_functions:
        return
    grad_suffix = ""
		......
    kernel_name = f"{self.func.__name__}_c{self.kernel_counter}_{key[1]}{grad_suffix}"
    _logging.trace(f"Compiling kernel {kernel_name}...")

    tree, ctx = _get_tree_and_ctx(
        self,
        args=args,
        excluded_parameters=self.template_slot_locations,
        arg_features=arg_features)

    if self.autodiff_mode != AutodiffMode.NONE:
        KernelSimplicityASTChecker(self.func).visit(tree)
    def taichi_ast_generator(kernel_cxx):
		......

    taichi_kernel = impl.get_runtime().prog.create_kernel(
        taichi_ast_generator, kernel_name, self.autodiff_mode)

    self.kernel_cpp = taichi_kernel

    assert key not in self.runtime.compiled_functions
    self.runtime.compiled_functions[key] = self.get_function_body(
        taichi_kernel)
    self.compiled_kernels[key] = taichi_kernel

上面是Kernel中的materialize函数,可以看到这里又调用了一次ProgramImpl的materialize,这个在之前Init的时候介绍过,主要是Jit相关的设置。之后进行判断是否进行过了编译,jit编译后的二进制代码会存放在内存中的,所以只需要第一次执行时进行编译即可,所以这里如果发现该kernel已经经过编译了,则这一步可以跳过。之后我们可以看到JIT函数的规则,Python本身的函数名+当前kernel序号构成,后面是一些附加信息,这一串规则可以巧妙的避免命名重复导致错误,同时比使用Uuid等方式随机产生的一大段随机字符串在长度方面小很多。我们此处的名称为play_c76_0这个也是我们之后实际调用JIT的函数名称。接下来这一段是用来获取Python AST树的:

def _get_tree_and_ctx(self,
                      excluded_parameters=(),
                      is_kernel=True,
                      arg_features=None,
                      args=None,
                      ast_builder=None,
                      is_real_function=False):
    file = getsourcefile(self.func)
    src, start_lineno = getsourcelines(self.func)
    src = [textwrap.fill(line, tabsize=4, width=9999) for line in src]
    tree = ast.parse(textwrap.dedent("\n".join(src)))

    func_body = tree.body[0]
    func_body.decorator_list = []

    global_vars = _get_global_vars(self.func)
    for i, arg in enumerate(func_body.args.args):
        anno = arg.annotation
        if isinstance(anno, ast.Name):
            global_vars[anno.id] = self.arguments[i].annotation

    if isinstance(func_body.returns, ast.Name):
        global_vars[func_body.returns.id] = self.return_type

    if is_kernel or is_real_function:
        # inject template parameters into globals
        for i in self.template_slot_locations:
            template_var_name = self.arguments[i].name
            global_vars[template_var_name] = args[i]

    return tree, ASTTransformerContext(excluded_parameters=excluded_parameters,
                                       is_kernel=is_kernel,
                                       func=self,
                                       arg_features=arg_features,
                                       global_vars=global_vars,
                                       argument_data=args,
                                       src=src,
                                       start_lineno=start_lineno,
                                       file=file,
                                       ast_builder=ast_builder,
                                       is_real_function=is_real_function)

又是我们的老朋友inspect模块,其实整体逻辑还是比较简单的,通过inspect模块获取到了所执行的文件路径和所执行的函数源码。之后通过_get_global_vars获取了全局变量,此处对闭包做了特殊处理。随后将参数和返回值类型加入到全局变量集合中。之后返回了ast树和taichi自己定义的用于做转换的Python类ASTTransformerContext(代码就不PO了,其实基本上算是一个Data类了,在初始化时没有做特殊操作)。

之后我们回到materialize函数中,如果此时开启了auto_diff,则会对AST树进行解析,做一些语义检测,这里采取的是访问器模式,之后会详细开一期介绍taichi 在启用不同模式auto_diff情况下的生命周期。

随后调用taichi的c++库来创建C++ Kernel对象并加入到当前pytaichi(上文提到的全局PyTaichi对象)的program中去。我们来看一下create_kernel的c++实现吧,这里其实在c++ pybind文件中,是一个匿名函数

.def(
    "create_kernel",
    [](Program *program, const std::function<void(Kernel *)> &body,
       const std::string &name, AutodiffMode autodiff_mode) -> Kernel * {
      py::gil_scoped_release release;
      return &program->kernel(body, name, autodiff_mode);
    },
      
Kernel &Program::kernel(const std::function<void(Kernel *)> &body,
                 const std::string &name = "",
                 AutodiffMode autodiff_mode = AutodiffMode::kNone) {
    // Expr::set_allow_store(true);
    auto func = std::make_unique<Kernel>(*this, body, name, autodiff_mode);
    // Expr::set_allow_store(false);
    kernels.emplace_back(std::move(func));
    return *kernels.back();
  }

注意这里pybind中的py::gil_scoped_release release;这句话,这个是为了多线程情况下对GIL锁的处理,Python C API 规定全局解释器锁 (GIL) 必须始终由当前线程持有才能安全访问 Python 对象。因此,当 Python 通过 pybind11 调用 C++ 时,必须持有 GIL,而 pybind11 永远不会隐式释放 GIL,所以我们就需要手动释放和获取锁。

pybind11 需要确保它正在调用 Python 代码时保持 GIL。如上,这里有趣的一点是我们这里传入的其实是一个Python函数给std::function类型参数作为一个回调,当c++调用该回调函数时需要确保持有GIL,这段会很复杂,我会在后续c++部分的生命周期的时候话一段时间来详解这一块。

回到Python部分,接下来要处理的就是kernel装饰器所装饰的play函数本身了,get_function_body具体内容简单来说就是非常爽快的返回了一个闭包。这个闭包我们之后会用到,这里就直接将他加入到了pytaichi的compiled_functions字典中。之后将C++ Kernekl对象加入到当前Python Kernel类中的compiled_kernels字典中。

在初始化Kernel阶段,会执行传入的回调函数,我们来具体看一下那个回调函数

def taichi_ast_generator(kernel_cxx):
    if self.runtime.inside_kernel:
        raise TaichiSyntaxError(......)
    self.runtime.inside_kernel = True
    self.runtime.current_kernel = self
    try:
        ctx.ast_builder = kernel_cxx.ast_builder()
        transform_tree(tree, ctx)
        if not ctx.is_real_function:
            if self.return_type and ctx.returned != ReturnStatus.ReturnedValue:
                raise TaichiSyntaxError(
                    "Kernel has a return type but does not have a return statement"
                )
    finally:
        self.runtime.inside_kernel = False
        self.runtime.current_kernel = None

其实这个回调函数传入的kernel_cxx就是刚刚创建的c++的kernel类,获取的ast_build就FrotentContext默认的ASTBuilder, FrontentContext在创建Kernel示例时会创建。

之后运行transform_tree函数,这一步其实这个回调最关键的时刻:

这一步稍显复杂,这里通过递归来完成了build

至此ensure_compiled功能就走完了,回到最初的__call__函数,最后操作即是从pytaichi的compiled_functions取出之前加入的闭包,并将闭包的结果当做kernel运行的结果返回。到此具体的编译过程正式展开(上面存在一步编译了,利用ast模块对play源码进行了parser,获取到了AST树,这里其实有一定优化空间,Python在对py文件整体编译时保存了AST树,可惜的是无法从Python语言层面去获取到这个内部对象,不过可以拓展cpython来实现这一功能,但是稍显鸡肋,小量代码的parser速度是非常快的,这一部分提升可以忽略不计了)。

2.5 Kernel编译

回到transform_tree这个方法,其实这里在初始化Kernel的时候就已经调用了,transform方法很简单就是实例化了一个ASTTransformer对象,并调用了其__call__方法

这里call重载在父类Builder中:

def __call__(self, ctx, node):
    method = getattr(self, 'build_' + node.__class__.__name__, None)
    try:
        if method is None:
            error_msg = f'Unsupported node "{node.__class__.__name__}"'
            raise TaichiSyntaxError(error_msg)
        info = ctx.get_pos_info(node) if isinstance(
            node, (ast.stmt, ast.expr)) else ""
        with impl.get_runtime().src_info_guard(info):
            return method(ctx, node)
    except Exception as e:
        if ctx.raised or not isinstance(node, (ast.stmt, ast.expr)):
            raise e.with_traceback(None)
        ctx.raised = True
        e = handle_exception_from_cpp(e)
        if not isinstance(e, TaichiCompilationError):
            msg = ctx.get_pos_info(node) + traceback.format_exc()
            raise TaichiCompilationError(msg) from None
        msg = ctx.get_pos_info(node) + str(e)
        raise type(e)(msg) from None

代码量不大,但却很精华,使用了递归的方式去遍历解析tree。

这里需要稍微解释一下with关键字,with结构是python中非常好玩的结构,大致可以理解成一种包围,这里with的是一个SrcInfoGuard类型,这里会在执行with内部语句前调用所with对象的__enter__方法,退出with结构时调用__exit__方法,我们这里看一下SrcInfoGuard结构:

class SrcInfoGuard:
    def __init__(self, info_stack, info):
        self.info_stack = info_stack
        self.info = info

    def __enter__(self):
        self.info_stack.append(self.info)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.info_stack.pop()

所以这里会在执行method方法前先将info加入到info_stack中,往往也是这种需要栈结构的情况下可以使用with语句。这也是Python比较方便的地方。

我们继续走,这里的method是build_Module,来到build_Module函数:

build_stmt = ASTTransformer()

@staticmethod
def build_Module(ctx, node):
    with ctx.variable_scope_guard():
        # Do NOT use |build_stmts| which inserts 'del' statements to the
        # end and deletes parameters passed into the module
        for stmt in node.body:
            build_stmt(ctx, stmt)
    return None

这里的node.body获取的是一个FunctionDef list,当然这里只有一个元素就是我们上面的play函数,这个FunctionDef类其实CPython中的C 结构体,我们在之前的CPython源码中有提及过,这里再次贴出他的结构:

struct {
    identifier name;
    arguments_ty args;
    asdl_seq *body;
    asdl_seq *decorator_list;
    expr_ty returns;
    string type_comment;
} FunctionDef;

之后再次调用了ASTTransformer的__call__函数,继续经过with语句,顺利的函数info信息加入到info_stack列表中,这次获取的method是 build_FunctionDef, 进入FunctionDef:

@staticmethod
def build_FunctionDef(ctx, node):
    if ctx.visited_funcdef:
        raise TaichiSyntaxError(
            f"Function definition is not allowed in 'ti.{'kernel' if ctx.is_kernel else 'func'}'."
        )
    ctx.visited_funcdef = True

    args = node.args
    assert args.vararg is None
    assert args.kwonlyargs == []
    assert args.kw_defaults == []
    assert args.kwarg is None

开头是一段检验,包括该节点是否访问,参数类型是否合法(这些在之前也进行了检测,taichi的kernel不允许关键字参数,可变参数等)

def transform_as_kernel():
    # Treat return type
    if node.returns is not None:
        kernel_arguments.decl_ret(ctx.func.return_type,
                                  ctx.is_real_function)
    impl.get_runtime().prog.finalize_rets()
    for i, arg in enumerate(args.args):
        if not isinstance(ctx.func.arguments[i].annotation,
                          primitive_types.RefType):
            ctx.kernel_args.append(arg.arg)
        if isinstance(ctx.func.arguments[i].annotation,
                      annotations.template):
            ctx.create_variable(arg.arg, ctx.global_vars[arg.arg])
		......
        else:
            ctx.create_variable(
                arg.arg,
                kernel_arguments.decl_scalar_arg(
                    ctx.func.arguments[i].annotation))
    # remove original args
    node.args.args = []

之后是一个闭包函数transform_as_kernel,这里首先处理返回值,我们的例子中没有返回值,暂时不考虑。之后是处理了函数行参,这里会对taichid几个特殊类型做判断,当然我们的例子中并不包含这些特殊类型,将参数的名称加入到ctx的kernel_args列表中,最后创建variable context:

def decl_scalar_arg(dtype):
    is_ref = False
    if isinstance(dtype, RefType):
        is_ref = True
        dtype = dtype.tp
    dtype = cook_dtype(dtype)
    arg_id = impl.get_runtime().prog.decl_scalar_arg(dtype)
    return Expr(_ti_core.make_arg_load_expr(arg_id, dtype, is_ref))

上面的是decl_scalar_arg具体是创建了一个Taichi Cpp Expr对象。

with ctx.variable_scope_guard():
    build_stmts(ctx, node.body)

最后就是除了函数体,具体就不看了,和处理函数定义类似的方式,就是前端遍历中AST树的创建罢了。

2.6 Kernel调用

回到我们之前存储的闭包,这一次我们可以好好的分析一下这段代码:

def func__(*args):
    assert len(args) == len(
        self.arguments
    ), f'{len(self.arguments)} arguments needed but {len(args)} provided'

    tmps = []
    callbacks = []

    actual_argument_slot = 0
    launch_ctx = t_kernel.make_launch_context()
    for i, v in enumerate(args):
        needed = self.arguments[i].annotation
        if isinstance(needed, template):
            continue
            provided = type(v)
            # Note: do not use sth like "needed == f32". That would be slow.
            if id(needed) in primitive_types.real_type_ids:
                if not isinstance(v, (float, int)):
                    raise TaichiRuntimeTypeError.get(
                        i, needed.to_string(), provided)
                    launch_ctx.set_arg_float(actual_argument_slot, float(v))
                 elif id(needed) in primitive_types.integer_type_ids:
                     if not isinstance(v, int):
                         raise TaichiRuntimeTypeError.get(
                             i, needed.to_string(), provided)
                     if is_signed(cook_dtype(needed)):
                                launch_ctx.set_arg_int(actual_argument_slot, int(v))
        actual_argument_slot += 1
     	......
     try:
        t_kernel(launch_ctx)
     except Exception as e:
        e = handle_exception_from_cpp(e)
        raise e from None

这一段代码非常长,我们截取来看,首先检测了形参和实参是否一致(话说,之前ensure的时候不是检测过一次吗?)初始化了两个局部变量tmps、callbacks之后创造了一个LauchContextBuilder对象,这一步依靠pybind绑定的c++实现。之后就是对实参的处理,处理好的参数全部加载到LaouchContext中,最后调用之前创建好的t_kernel,至此python前置部分运行完毕