qp.capture.register_custom_staging_rule
- register_custom_staging_rule(primitive, get_jaxpr_from_params, setup_env=<function _default_setup_env>)
Register a custom staging rule that can handle dynamic shapes for a higher-order primitive.
- Parameters:
primitive (jax.extend.core.Primitive) – The JAX primitive to register the custom staging rule for.
get_jaxpr_from_params (Callable[[dict], "jax.extend.core.Jaxpr"]) – A function that takes in the equation’s params and returns the target jaxpr.
setup_env (Callable) – A function that sets up a dictionary mapping the inner jaxpr variables to the tracers that are inputs to the equation. It receives the equation's input tracers and the equation's params. By default, it returns an empty dictionary.
For example, cond_prim requests its custom staging rule like:

    register_custom_staging_rule(cond_prim, lambda params: params['jaxpr_branches'][0])
cond cannot support setup_env, because different branches may have different dynamic shapes, so it relies on the default. Compare this to while_loop_prim:

    def setup_env(tracers, params):
        # Reorder the equation inputs to match the body jaxpr's invars followed by its constvars.
        tracers = tracers[slice(*params['args_slice'])] + tracers[slice(*params['consts_slice'])]
        inner_vars = params['jaxpr_body_fn'].invars + params['jaxpr_body_fn'].constvars
        # strict=True belongs to zip(): it raises if the two sequences differ in length.
        return dict(zip(inner_vars, tracers, strict=True))

    register_custom_staging_rule(
        while_loop_prim,
        get_jaxpr_from_params=lambda params: params["jaxpr_body_fn"],
        setup_env=setup_env,
    )
for_loop_prim is more complicated, as we have to slice start, stop, and step out of the tracers, and account for the loop index in the jaxpr_invars.