Taichi编译链路之Python部分的生命周期
Taichi 源码分析(一)
1. Taichi项目结构及源码编译
今天我没来探索一下taichi 程序的整体生命周期,首先我们先看一下taichi的项目结构
文件层级很多,就没有树来进行展示了,直接看下面的截图
taichi 是一个eDSL语言,嵌入到python语言当中,利用了Python解释器作为了部分运行时,同时通过pybind链接到了底层taichi c++库, taichi语言的后端是利用c++实现高性能计算的。
如上图所示,项目本身是一个CMake项目,在根目录中项目整体CMakeList.txt, README文档和版本文档等其他配置文件,setup.py用于安装taichi的python包。taichi项目根目录有13个文件夹,其中核心程序文件在taichi文件夹和python文件夹中,taichi文件夹内为taichi的c++源码,Python中是上层的python库。taichi文件内又分为了用于即时编译jit文件夹, 用于代码生成的codegen文件夹,用于pybind绑定的python文件夹等。根目录的python文件夹中taichi是Python项目的最基础包,所有库都从这个包中导入,大部分的语言项目的功能都在taichi.lang包中。
现在我们来编译taichi程序,需要注意的是taichi项目中的第三包并没有加入到项目中,大家在克隆时需要试下下面的指令
git clone --recursive https://github.com/taichi-dev/taichi.git
不过Eigen库所引用的仓库已经无法访问了,实测大家可以使用这个仓库的eigen代码(https://github.com/PX4/eigen)。taichi需要用到pybind和llvm等第三方包,大家需要自行安装这些,pybind全部为头文件实现,安装十分方便,从git仓库中下载完成后直接cmake安装即可。llvm的话太极官方为大家提供了他们自定义版LLVM的预编译文件,大家可以在这里下载对应版本https://github.com/taichi-dev/taichi_assets/releases。
这些前置工作完成后直接,cmake + make build即可,将生成的pyd文件放到对应的python文件夹中,执行
python setup.py
到此就源码编译安装完成了。
2. Python部分生命周期
首先需要说明一下,此处虽说是讲Python部分的生命周期,但在初始化时taichi仍然调用了一部分pybind封装的c++代码,其中大部分都是用于初始化配置,我们也会顺带讲解一下,我仍认为这一部分属于Python部分的生命周期,等真正进入JIT编译时,才算是进入到C++部分的生命周期。
2.1 实例代码
import taichi as ti
ti.init(arch=ti.cpu)
x = ti.field(ti.i32, shape=(3, 4))
y = ti.ndarray(int, shape=16)
z = int(100)
@ti.kernel
def play(arg0: int, arg1: int):
tmp = 0
if arg1:
tmp = tmp + 4
for i in range(10):
tmp += 1
print(tmp)
assert z == 100
print(z)
for i in x:
x[i] = 0
for i in y:
ti.atomic_add(y[i], 5)
play(10, True)
下面我们以上面那段代码为例,taichi程序会分为两个部分,一部分是属于Python scope中的,而另一部分被装饰器所修饰的代码是属于taichi scope中的,可以看到taichi scope中的也会调用到Python全局中的field变量,这一部分我们暂时不考虑,taichi中的全局field变量会通过特殊处理链接到taichi程序中去,我们分开两部分来看,我们首先先看不含这些全局变量的taichi程序的生命周期是如何的,即我们这次调试的代码为下面的情况:
import taichi as ti
ti.init(arch=ti.cpu)
# x = ti.field(ti.i32, shape=(3, 4))
# y = ti.ndarray(int, shape=16)
# z = int(100)
@ti.kernel
def play(arg0: int, arg1: int):
tmp = 0
if arg1:
tmp = tmp + 4
for i in range(10):
tmp += 1
print(tmp)
# assert z == 100
# print(z)
#
# for i in x:
# x[i] = 0
#
# for i in y:
# ti.atomic_add(y[i], 5)
play(10, True)
2.2 Tachi Init
可以看到要想加载taichi程序,我们会首先需要初始化taichi包,我这里使用的cpu作为后端进行运行。
def init(arch=None,
default_fp=None,
default_ip=None,
_test_mode=False,
enable_fallback=True,
require_version=None,
**kwargs):
# Check version for users every 7 days if not disabled by users.
_version_check.start_version_check_thread()
在运行init函数之前,taichi包中有很多全局变量,下面一个是会伴随我们整个生命周期的一个Python全局变量PyTaichi,他在taichi.lang包下的impl模块中被定义
class PyTaichi:
def __init__(self, kernels=None):
self.materialized = False
self.prog = None
self.compiled_functions = {}
self.src_info_stack = []
self.inside_kernel = False
self.current_kernel = None
self.global_vars = []
self.grad_vars = []
self.dual_vars = []
self.matrix_fields = []
self.default_fp = f32
self.default_ip = i32
self.default_up = u32
self.target_tape = None
self.fwd_mode_manager = None
self.grad_replaced = False
self.kernels = kernels or []
self._signal_handler_registry = None
PyTaichi的初始化只是单纯的对成员函数赋初值。
而我们之前调用的init函数是一大串初始化配置的代码,在开始时会进行版本的检查
def start_version_check_thread():
skip = os.environ.get("TI_SKIP_VERSION_CHECK")
if skip != 'ON':
# We don't join this thread because we do not wish to block users.
check_version_thread = threading.Thread(target=try_check_version,
daemon=True)
check_version_thread.start()
此处开启多线程进行版本检查,实际运行函数为try_check_version
def try_check_version():
try:
os.makedirs(_ti_core.get_repo_dir(), exist_ok=True)
version_info_path = os.path.join(_ti_core.get_repo_dir(),
'version_info')
cur_date = datetime.date.today()
if os.path.exists(version_info_path):
with open(version_info_path, 'r') as f:
version_info_file = f.readlines()
last_time = version_info_file[0].rstrip()
cur_uuid = version_info_file[2].rstrip()
if cur_date.strftime('%Y-%m-%d') > last_time:
response = check_version(cur_uuid)
write_version_info(response, cur_uuid, version_info_path,
cur_date)
else:
cur_uuid = str(uuid.uuid4())
write_version_info({'status': 0}, cur_uuid, version_info_path,
cur_date)
response = check_version(cur_uuid)
write_version_info(response, cur_uuid, version_info_path, cur_date)
# Wildcard exception to catch potential file writing errors.
except:
pass
在这个函数中,第一次运行taichi程序时会创建了一个taichi_cache文件夹用于放置临时缓存,windows默认在C盘根目录。同样在第一次运行时会在taichi_cache文件夹中创建一个version_info文件,这个文件内容是一传版本字符串,由时期+uuid组成,如果当前日期大于最后一次更新日期,需要对版本进行更新,日期判断只精确到天。
之后回到init函数中。
current_dir = os.getcwd()
if require_version is not None:
check_require_version(require_version)
if "packed" in kwargs:
if kwargs["packed"] is True:
warnings.warn(
"Currently packed=True is the default setting and the switch will be removed in v1.4.0.",
DeprecationWarning)
else:
warnings.warn(
"The automatic padding mode (packed=False) will no longer exist in v1.4.0. The switch will "
"also be removed then. Make sure your code doesn't rely on it.",
DeprecationWarning)
if "default_up" in kwargs:
raise KeyError(
"'default_up' is always the unsigned type of 'default_ip'. Please set 'default_ip' instead."
)
default_fp = deepcopy(default_fp)
default_ip = deepcopy(default_ip)
kwargs = deepcopy(kwargs)
之后获取到了当前运行的py文件所在的文件夹,对参数中进行了检测,如果存在不合法会进行警告和报错,对合法参数进行深拷贝复制。
def reset():
global pytaichi
old_kernels = pytaichi.kernels
pytaichi.clear()
pytaichi = PyTaichi(old_kernels)
for k in old_kernels:
k.reset()
之后进行reset操作,reset的具体实现在Impl中,清空当前全局的pytaichi变量中的kernel集合,并对每一个kernel进行复位。
cfg = impl.default_cfg()
cfg.offline_cache = True # Enable offline cache in frontend instead of C++ side
spec_cfg = _SpecialConfig()
env_comp = _EnvironmentConfigurator(kwargs, cfg)
env_spec = _EnvironmentConfigurator(kwargs, spec_cfg)
# configure default_fp/ip:
# TODO: move these stuff to _SpecialConfig too:
env_default_fp = os.environ.get("TI_DEFAULT_FP")
......
if default_fp is not None:
impl.get_runtime().set_default_fp(default_fp)
if default_ip is not None:
impl.get_runtime().set_default_ip(default_ip)
接下来就是进行配置设置,上面代码全为配置设置,获取的Config类是由C++ Pybind绑定的,如下:
struct CompileConfig {
Arch arch;
bool debug;
bool cfg_optimization;
bool check_out_of_bound;
bool validate_autodiff;
int simd_width;
int opt_level;
int external_optimization_level;
int max_vector_width;
bool packed;
bool print_preprocessed_ir;
......
CompileConfig();
};
默认的配置都是对这些成员变量进行初始化,接下来的是一些特殊环境配置,我初始化时并没有进行设置,这些部分都是默认值,大部分的控制流都不会进入,之后是加入日志等级和gdb等信息。
# compiler configurations (ti.cfg):
for key in dir(cfg):
if key in ['arch', 'default_fp', 'default_ip']:
continue
_cast = type(getattr(cfg, key))
if _cast is bool:
_cast = None
env_comp.add(key, _cast)
unexpected_keys = kwargs.keys()
逐级遍历Config中的属性,将属性名作为键值,类型作为value存到env.comp中。
unexpected_keys = kwargs.keys()
if len(unexpected_keys):
raise KeyError(
f'Unrecognized keyword argument(s) for ti.init: {", ".join(unexpected_keys)}'
)
之后是查看是否有不需要的参数,有则报错
get_default_kernel_profiler().set_kernel_profiler_mode(cfg.kernel_profiler)
# create a new program:
impl.get_runtime().create_program()
_logging.trace('Materializing runtime...')
impl.get_runtime().prog.materialize_runtime()
impl._root_fb = _snode.FieldsBuilder()
if cfg.debug:
impl.get_runtime()._register_signal_handlers()
os.chdir(current_dir)
return None
接下来就是初始化一些核心类用于创建和编译Kernel。这里首先会创建一个Program,这个Program可以视为一个整体的Taichi程序,会存放全部的filed等全局变量和所有的kernel。Program是一个c++类,通过Pybind绑定成Python对象,Program本身并不提供任何实现,具体实现有其子类提供,默认使用的是LLVM版本的Program,这段代码可以在taichi项目中的c++源码文件夹taichi下面的program文件夹中的program.cpp中找到:
Program::Program(Arch desired_arch) : snode_rw_accessors_bank_(this) {
TI_TRACE("Program initializing...");
.......
main_thread_id_ = std::this_thread::get_id();
......
profiler = make_profiler(config.arch, config.kernel_profiler);
if (arch_uses_llvm(config.arch)) {
#ifdef TI_WITH_LLVM
if (config.arch != Arch::dx12) {
program_impl_ = std::make_unique<LlvmProgramImpl>(config, profiler.get());
} else {
// NOTE: use Dx12ProgramImpl to avoid using LlvmRuntimeExecutor for dx12.
#ifdef TI_WITH_DX12
TI_ASSERT(directx12::is_dx12_api_available());
program_impl_ = std::make_unique<Dx12ProgramImpl>(config);
#else
TI_ERROR("This taichi is not compiled with DX12");
#endif
}
#else
TI_ERROR("This taichi is not compiled with LLVM");
#endif
} else if (config.arch == Arch::metal) {
#ifdef TI_WITH_METAL
TI_ASSERT(metal::is_metal_api_available());
program_impl_ = std::make_unique<MetalProgramImpl>(config);
#else
TI_ERROR("This taichi is not compiled with Metal")
#endif
我们下面的也都以默认的LLVM版本的Program实现进行讲解,后续有时间我会讲解其他backend实现。
让我们回到Python的init函数中去,在创建完Program后。调用了materialize_runtime
方法来创建运行环境,我们具体来看一下LLVMProgramImpl的实现:
void materialize_runtime(MemoryPool *memory_pool,
KernelProfilerBase *profiler,
uint64 **result_buffer_ptr) override {
runtime_exec_->materialize_runtime(memory_pool, profiler,
result_buffer_ptr);
}
额,又包了一层,这里调用的是其执行器的初始化方法,继续往下走可以看到这里是对jit进行一些配置。到此为止,超长的Init函数就看完了。
2.3 kernel装饰器
之后我们看一下@ti.kernel
装饰器
def _kernel_impl(_func, level_of_class_stackframe, verbose=False):
is_classkernel = _inside_class(level_of_class_stackframe + 1)
if verbose:
print(f'kernel={_func.__name__} is_classkernel={is_classkernel}')
primal = Kernel(_func,
autodiff_mode=AutodiffMode.NONE,
_classkernel=is_classkernel)
adjoint = Kernel(_func,
autodiff_mode=AutodiffMode.REVERSE,
_classkernel=is_classkernel)
primal.grad = adjoint
if is_classkernel:
@functools.wraps(_func)
def wrapped(*args, **kwargs):
clsobj = type(args[0])
assert not hasattr(clsobj, '_data_oriented')
raise TaichiSyntaxError(
f'Please decorate class {clsobj.__name__} with @ti.data_oriented'
)
else:
@functools.wraps(_func)
def wrapped(*args, **kwargs):
try:
return primal(*args, **kwargs)
except (TaichiCompilationError, TaichiRuntimeError) as e:
raise type(e)('\n' + str(e)) from None
wrapped.grad = adjoint
wrapped._is_wrapped_kernel = True
wrapped._is_classkernel = is_classkernel
wrapped._primal = primal
wrapped._adjoint = adjoint
return wrapped
def kernel(fn):
return _kernel_impl(fn, level_of_class_stackframe=3)
kernel
是一个装饰器函数,其内部调用了_kernel_impl
作为具体实现。这里首先判断了当前的kernel是否在一个类中,是的,taichi允许你在一个类中定义kernel,这给了开发者极大的便利,同时taichi用了非常巧妙的方式来处理在类中的Kernel,这里我们暂时不讨论,我先关注于在类外面的kernel函数。而_inside_class
方法具体是如何判断是否在类中的,则用了我们的老朋友,inspect
模块来获取解释器堆栈上的语句信息,通过正则匹配检测class关键字实现的。
接下来实例化了两个Kernel类,这个Kernel类是taichi程序的编译核心,也是执行入口,具体编译过程在Kernel的__call__
魔法函数中。
之后是一个内部闭包wrapped,functools.wrap装饰器的作用是不改变调用对象,单纯改变所装饰对象的一些属性,也就是这里最后返回的是wrapped函数,但是属性则依然是传入的play函数,所以我们打印play函数的__name__
和__doc__
属性就会发现没有发生改变,但是具体调用逻辑则是调用Kernel的__call__
函数。
那我们接下来就很好看一下Kernel这个类:
class Kernel:
counter = 0
def __init__(self, _func, autodiff_mode, _classkernel=False):
self.func = _func
self.kernel_counter = Kernel.counter
Kernel.counter += 1
assert autodiff_mode in (AutodiffMode.NONE, AutodiffMode.VALIDATION,
AutodiffMode.FORWARD, AutodiffMode.REVERSE)
self.autodiff_mode = autodiff_mode
self.grad = None
self.arguments = []
self.return_type = None
self.classkernel = _classkernel
self.extract_arguments()
self.template_slot_locations = []
for i, arg in enumerate(self.arguments):
if isinstance(arg.annotation, template):
self.template_slot_locations.append(i)
self.mapper = TaichiCallableTemplateMapper(
self.arguments, self.template_slot_locations)
impl.get_runtime().kernels.append(self)
self.reset()
self.kernel_cpp = None
# TODO[#5114]: get rid of compiled_functions and use compiled_kernels instead.
# Main motivation is that compiled_kernels can be potentially serialized in the AOT scenario.
self.compiled_kernels = {}
self.has_print = False
在初始化Kernel时,把所装饰的函数作为了当前func属性,获取当前kernel的计数器,之后静态成员变量counter+1。判断是否开启auto_diff,是否有自动微分等。之后通过extract_argument来获取函数形参参数:
def extract_arguments(self):
sig = inspect.signature(self.func)
if sig.return_annotation not in (inspect._empty, None):
self.return_type = sig.return_annotation
params = sig.parameters
arg_names = params.keys()
for i, arg_name in enumerate(arg_names):
param = params[arg_name]
if param.kind == inspect.Parameter.VAR_KEYWORD:
raise TaichiSyntaxError(
'Taichi kernels do not support variable keyword parameters (i.e., **kwargs)'
)
......
annotation = param.annotation
if param.annotation is inspect.Parameter.empty:
if i == 0 and self.classkernel: # The |self| parameter
annotation = template()
else:
raise TaichiSyntaxError(
'Taichi kernels parameters must be type annotated')
......
self.arguments.append(
KernelArgument(annotation, param.name, param.default))
获取的方式当然又是我们的老熟人inspect模块,通过signature获取到func上的形参和返回值类型属性,之后是一大段参数类型的异常判断,在这里我们看到关键字参数,带初始值的参数都是不支持的。之后将这些参数加入到Kernel类的成员变量arguments中,Taichi在这里对参数又进行了一层包装,不过KernelArgument就是一个非常简单的data类了:
class KernelArgument:
def __init__(self, _annotation, _name, _default=inspect.Parameter.empty):
self.annotation = _annotation
self.name = _name
self.default = _default
让我们回到Kernel的构造函数,获取到函数参数后对这些形参进行了遍历,判断是否是Template类型参数,这里的Template是一个用Python定义的一个Data类,用于taichi的元编程使用,taichi的模板编程实现的方式我们也暂且不表。再往下是初始化了一个TaichiCallbleTemplateMapper
这个map存储了参数了信息为后续调用时使用。随后将现在的kernel加入到我们之前提到的pytaichi全局变量中,那是一个Program类,存储了全部的kernel。再之后就是调用了reset函数:
def reset(self):
self.runtime = impl.get_runtime()
reset函数非常简单,把初始化当前runtime成员变量的方法提取出来以供复用,这里runtime所赋值的依然是pytaichi。到此Kernel的初始化工作基本完成了。
Taichi所采用的对Python实施静态编译的策略与我们之前探讨的pytorch和qcor有所不同,他并不是在装饰器加载时就完成了jit编译,而是在第一次调用时才进行编译,下面我们就探究一下Kernel的__call__
函数:
2.4 kernel的运行入口
@_shell_pop_print
def __call__(self, *args, **kwargs):
args = _process_args(self, args, kwargs)
if self.runtime.fwd_mode_manager and not self.runtime.grad_replaced:
self.runtime.fwd_mode_manager.insert(self)
if self.autodiff_mode in (
AutodiffMode.NONE, AutodiffMode.VALIDATION
) and self.runtime.target_tape and not self.runtime.grad_replaced:
self.runtime.target_tape.insert(self, args)
if self.autodiff_mode != AutodiffMode.NONE and impl.current_cfg(
).opt_level == 0:
_logging.warn(
"""opt_level = 1 is enforced to enable gradient computation."""
)
impl.current_cfg().opt_level = 1
key = self.ensure_compiled(*args)
return self.runtime.compiled_functions[key](*args)
这段程序的代码量并不多,首先_shell_pop_print
装饰器在绝大多数情况下冰不起作用,不会改变调用主体,只有在开启pybuf时会有一个额外处理,但依然会调用__call__
函数,并将其返回值作为新函数返回值,只是会进行一步额外信息打印操作。
之后对实参进行处理,这里的_process_arg
主要做了参数校验的工作,判断了实参和形参的个数是否统一,并将*args
解构赋值到一个list中返回。
之后判断了是否启用了自动微分,在本例中未使用自动微分,这一步也不用考虑。之后就进入到了我们的编译环节,ensure_compiled
函数代码如下:
def ensure_compiled(self, *args):
instance_id, arg_features = self.mapper.lookup(args)
key = (self.func, instance_id, self.autodiff_mode)
self.materialize(key=key, args=args, arg_features=arg_features)
return key
这里调用TaichiCallableTemplateMapper
的loopup
功能获得的instance_id和arg_features分别为0和两个'#'
字符串组成的list。
def materialize(self, key=None, args=None, arg_features=None):
if key is None:
key = (self.func, 0, self.autodiff_mode)
self.runtime.materialize()
if key in self.runtime.compiled_functions:
return
grad_suffix = ""
......
kernel_name = f"{self.func.__name__}_c{self.kernel_counter}_{key[1]}{grad_suffix}"
_logging.trace(f"Compiling kernel {kernel_name}...")
tree, ctx = _get_tree_and_ctx(
self,
args=args,
excluded_parameters=self.template_slot_locations,
arg_features=arg_features)
if self.autodiff_mode != AutodiffMode.NONE:
KernelSimplicityASTChecker(self.func).visit(tree)
def taichi_ast_generator(kernel_cxx):
......
taichi_kernel = impl.get_runtime().prog.create_kernel(
taichi_ast_generator, kernel_name, self.autodiff_mode)
self.kernel_cpp = taichi_kernel
assert key not in self.runtime.compiled_functions
self.runtime.compiled_functions[key] = self.get_function_body(
taichi_kernel)
self.compiled_kernels[key] = taichi_kernel
上面是Kernel中的materialize
函数,可以看到这里又调用了一次ProgramImpl的materialize,这个在之前Init的时候介绍过,主要是Jit相关的设置。之后进行判断是否进行过了编译,jit编译后的二进制代码会存放在内存中的,所以只需要第一次执行时进行编译即可,所以这里如果发现该kernel已经经过编译了,则这一步可以跳过。之后我们可以看到JIT函数的规则,Python本身的函数名+当前kernel序号构成,后面是一些附加信息,这一串规则可以巧妙的避免命名重复导致错误,同时比使用Uuid等方式随机产生的一大段随机字符串在长度方面小很多。我们此处的名称为play_c76_0
这个也是我们之后实际调用JIT的函数名称。接下来这一段是用来获取Python AST树的:
def _get_tree_and_ctx(self,
excluded_parameters=(),
is_kernel=True,
arg_features=None,
args=None,
ast_builder=None,
is_real_function=False):
file = getsourcefile(self.func)
src, start_lineno = getsourcelines(self.func)
src = [textwrap.fill(line, tabsize=4, width=9999) for line in src]
tree = ast.parse(textwrap.dedent("\n".join(src)))
func_body = tree.body[0]
func_body.decorator_list = []
global_vars = _get_global_vars(self.func)
for i, arg in enumerate(func_body.args.args):
anno = arg.annotation
if isinstance(anno, ast.Name):
global_vars[anno.id] = self.arguments[i].annotation
if isinstance(func_body.returns, ast.Name):
global_vars[func_body.returns.id] = self.return_type
if is_kernel or is_real_function:
# inject template parameters into globals
for i in self.template_slot_locations:
template_var_name = self.arguments[i].name
global_vars[template_var_name] = args[i]
return tree, ASTTransformerContext(excluded_parameters=excluded_parameters,
is_kernel=is_kernel,
func=self,
arg_features=arg_features,
global_vars=global_vars,
argument_data=args,
src=src,
start_lineno=start_lineno,
file=file,
ast_builder=ast_builder,
is_real_function=is_real_function)
又是我们的老朋友inspect
模块,其实整体逻辑还是比较简单的,通过inspect模块获取到了所执行的文件路径和所执行的函数源码。之后通过_get_global_vars
获取了全局变量,此处对闭包做了特殊处理。随后将参数和返回值类型加入到全局变量集合中。之后返回了ast树和taichi自己定义的用于做转换的Python类ASTTransformerContext
(代码就不PO了,其实基本上算是一个Data类了,在初始化时没有做特殊操作)。
之后我们回到materialize
函数中,如果此时开启了auto_diff,则会对AST树进行解析,做一些语义检测,这里采取的是访问器模式,之后会详细开一期介绍taichi 在启用不同模式auto_diff情况下的生命周期。
随后调用taichi的c++库来创建C++ Kernel对象并加入到当前pytaichi(上文提到的全局PyTaichi对象)的program中去。我们来看一下create_kernel的c++实现吧,这里其实在c++ pybind文件中,是一个匿名函数
.def(
"create_kernel",
[](Program *program, const std::function<void(Kernel *)> &body,
const std::string &name, AutodiffMode autodiff_mode) -> Kernel * {
py::gil_scoped_release release;
return &program->kernel(body, name, autodiff_mode);
},
Kernel &Program::kernel(const std::function<void(Kernel *)> &body,
const std::string &name = "",
AutodiffMode autodiff_mode = AutodiffMode::kNone) {
// Expr::set_allow_store(true);
auto func = std::make_unique<Kernel>(*this, body, name, autodiff_mode);
// Expr::set_allow_store(false);
kernels.emplace_back(std::move(func));
return *kernels.back();
}
注意这里pybind中的py::gil_scoped_release release;
这句话,这个是为了多线程情况下对GIL锁的处理,Python C API 规定全局解释器锁 (GIL) 必须始终由当前线程持有才能安全访问 Python 对象。因此,当 Python 通过 pybind11 调用 C++ 时,必须持有 GIL,而 pybind11 永远不会隐式释放 GIL,所以我们就需要手动释放和获取锁。
pybind11 需要确保它正在调用 Python 代码时保持 GIL。如上,这里有趣的一点是我们这里传入的其实是一个Python函数给std::function
类型参数作为一个回调,当c++调用该回调函数时需要确保持有GIL,这段会很复杂,我会在后续c++部分的生命周期的时候话一段时间来详解这一块。
回到Python部分,接下来要处理的就是kernel装饰器所装饰的play函数本身了,get_function_body
具体内容简单来说就是非常爽快的返回了一个闭包。这个闭包我们之后会用到,这里就直接将他加入到了pytaichi的compiled_functions
字典中。之后将C++ Kernekl对象加入到当前Python Kernel类中的compiled_kernels
字典中。
在初始化Kernel阶段,会执行传入的回调函数,我们来具体看一下那个回调函数
def taichi_ast_generator(kernel_cxx):
if self.runtime.inside_kernel:
raise TaichiSyntaxError(......)
self.runtime.inside_kernel = True
self.runtime.current_kernel = self
try:
ctx.ast_builder = kernel_cxx.ast_builder()
transform_tree(tree, ctx)
if not ctx.is_real_function:
if self.return_type and ctx.returned != ReturnStatus.ReturnedValue:
raise TaichiSyntaxError(
"Kernel has a return type but does not have a return statement"
)
finally:
self.runtime.inside_kernel = False
self.runtime.current_kernel = None
其实这个回调函数传入的kernel_cxx就是刚刚创建的c++的kernel类,获取的ast_build就FrotentContext默认的ASTBuilder, FrontentContext在创建Kernel示例时会创建。
之后运行transform_tree函数,这一步其实这个回调最关键的时刻:
这一步稍显复杂,这里通过递归来完成了build
至此ensure_compiled
功能就走完了,回到最初的__call__
函数,最后操作即是从pytaichi的compiled_functions
取出之前加入的闭包,并将闭包的结果当做kernel运行的结果返回。到此具体的编译过程正式展开(上面存在一步编译了,利用ast模块对play源码进行了parser,获取到了AST树,这里其实有一定优化空间,Python在对py文件整体编译时保存了AST树,可惜的是无法从Python语言层面去获取到这个内部对象,不过可以拓展cpython来实现这一功能,但是稍显鸡肋,小量代码的parser速度是非常快的,这一部分提升可以忽略不计了)。
2.5 Kernel编译
回到transform_tree这个方法,其实这里在初始化Kernel的时候就已经调用了,transform方法很简单就是实例化了一个ASTTransformer对象,并调用了其__call__
方法
这里call重载在父类Builder中:
def __call__(self, ctx, node):
method = getattr(self, 'build_' + node.__class__.__name__, None)
try:
if method is None:
error_msg = f'Unsupported node "{node.__class__.__name__}"'
raise TaichiSyntaxError(error_msg)
info = ctx.get_pos_info(node) if isinstance(
node, (ast.stmt, ast.expr)) else ""
with impl.get_runtime().src_info_guard(info):
return method(ctx, node)
except Exception as e:
if ctx.raised or not isinstance(node, (ast.stmt, ast.expr)):
raise e.with_traceback(None)
ctx.raised = True
e = handle_exception_from_cpp(e)
if not isinstance(e, TaichiCompilationError):
msg = ctx.get_pos_info(node) + traceback.format_exc()
raise TaichiCompilationError(msg) from None
msg = ctx.get_pos_info(node) + str(e)
raise type(e)(msg) from None
代码量不大,但却很精华,使用了递归的方式去遍历解析tree。
这里需要稍微解释一下with关键字,with结构是python中非常好玩的结构,大致可以理解成一种包围,这里with的是一个SrcInfoGuard
类型,这里会在执行with内部语句前调用所with对象的__enter__
方法,退出with结构时调用__exit__
方法,我们这里看一下SrcInfoGuard结构:
class SrcInfoGuard:
def __init__(self, info_stack, info):
self.info_stack = info_stack
self.info = info
def __enter__(self):
self.info_stack.append(self.info)
def __exit__(self, exc_type, exc_val, exc_tb):
self.info_stack.pop()
所以这里会在执行method方法前先将info加入到info_stack中,往往也是这种需要栈结构的情况下可以使用with语句。这也是Python比较方便的地方。
我们继续走,这里的method是build_Module,来到build_Module函数:
build_stmt = ASTTransformer()
@staticmethod
def build_Module(ctx, node):
with ctx.variable_scope_guard():
# Do NOT use |build_stmts| which inserts 'del' statements to the
# end and deletes parameters passed into the module
for stmt in node.body:
build_stmt(ctx, stmt)
return None
这里的node.body获取的是一个FunctionDef list,当然这里只有一个元素就是我们上面的play函数,这个FunctionDef类其实CPython中的C 结构体,我们在之前的CPython源码中有提及过,这里再次贴出他的结构:
struct {
identifier name;
arguments_ty args;
asdl_seq *body;
asdl_seq *decorator_list;
expr_ty returns;
string type_comment;
} FunctionDef;
之后再次调用了ASTTransformer的__call__
函数,继续经过with语句,顺利的函数info信息加入到info_stack列表中,这次获取的method是 build_FunctionDef
, 进入FunctionDef:
@staticmethod
def build_FunctionDef(ctx, node):
if ctx.visited_funcdef:
raise TaichiSyntaxError(
f"Function definition is not allowed in 'ti.{'kernel' if ctx.is_kernel else 'func'}'."
)
ctx.visited_funcdef = True
args = node.args
assert args.vararg is None
assert args.kwonlyargs == []
assert args.kw_defaults == []
assert args.kwarg is None
开头是一段检验,包括该节点是否访问,参数类型是否合法(这些在之前也进行了检测,taichi的kernel不允许关键字参数,可变参数等)
def transform_as_kernel():
# Treat return type
if node.returns is not None:
kernel_arguments.decl_ret(ctx.func.return_type,
ctx.is_real_function)
impl.get_runtime().prog.finalize_rets()
for i, arg in enumerate(args.args):
if not isinstance(ctx.func.arguments[i].annotation,
primitive_types.RefType):
ctx.kernel_args.append(arg.arg)
if isinstance(ctx.func.arguments[i].annotation,
annotations.template):
ctx.create_variable(arg.arg, ctx.global_vars[arg.arg])
......
else:
ctx.create_variable(
arg.arg,
kernel_arguments.decl_scalar_arg(
ctx.func.arguments[i].annotation))
# remove original args
node.args.args = []
之后是一个闭包函数transform_as_kernel
,这里首先处理返回值,我们的例子中没有返回值,暂时不考虑。之后是处理了函数行参,这里会对taichid几个特殊类型做判断,当然我们的例子中并不包含这些特殊类型,将参数的名称加入到ctx的kernel_args列表中,最后创建variable context:
def decl_scalar_arg(dtype):
is_ref = False
if isinstance(dtype, RefType):
is_ref = True
dtype = dtype.tp
dtype = cook_dtype(dtype)
arg_id = impl.get_runtime().prog.decl_scalar_arg(dtype)
return Expr(_ti_core.make_arg_load_expr(arg_id, dtype, is_ref))
上面的是decl_scalar_arg
具体是创建了一个Taichi Cpp Expr对象。
with ctx.variable_scope_guard():
build_stmts(ctx, node.body)
最后就是除了函数体,具体就不看了,和处理函数定义类似的方式,就是前端遍历中AST树的创建罢了。
2.6 Kernel调用
回到我们之前存储的闭包,这一次我们可以好好的分析一下这段代码:
def func__(*args):
assert len(args) == len(
self.arguments
), f'{len(self.arguments)} arguments needed but {len(args)} provided'
tmps = []
callbacks = []
actual_argument_slot = 0
launch_ctx = t_kernel.make_launch_context()
for i, v in enumerate(args):
needed = self.arguments[i].annotation
if isinstance(needed, template):
continue
provided = type(v)
# Note: do not use sth like "needed == f32". That would be slow.
if id(needed) in primitive_types.real_type_ids:
if not isinstance(v, (float, int)):
raise TaichiRuntimeTypeError.get(
i, needed.to_string(), provided)
launch_ctx.set_arg_float(actual_argument_slot, float(v))
elif id(needed) in primitive_types.integer_type_ids:
if not isinstance(v, int):
raise TaichiRuntimeTypeError.get(
i, needed.to_string(), provided)
if is_signed(cook_dtype(needed)):
launch_ctx.set_arg_int(actual_argument_slot, int(v))
actual_argument_slot += 1
......
try:
t_kernel(launch_ctx)
except Exception as e:
e = handle_exception_from_cpp(e)
raise e from None
这一段代码非常长,我们截取来看,首先检测了形参和实参是否一致(话说,之前ensure的时候不是检测过一次吗?)初始化了两个局部变量tmps、callbacks
之后创造了一个LauchContextBuilder对象,这一步依靠pybind绑定的c++实现。之后就是对实参的处理,处理好的参数全部加载到LaouchContext中,最后调用之前创建好的t_kernel,至此python前置部分运行完毕