unseal.hooks package
This package handles the nitty-gritty of hooking into a model.
hooks module
hooks.common_hooks module
- unseal.hooks.common_hooks.additive_output_noise(indices: str, mean: Optional[float] = 0, std: Optional[float] = 0.1) Callable
Creates a hook function which adds Gaussian noise (with the given mean and std) to a module's output at the given indices.
- unseal.hooks.common_hooks.create_attention_hook(layer: int, key: str, output_idx: Optional[int] = None, attn_name: Optional[str] = 'attn', layer_key_prefix: Optional[str] = None, heads: Optional[Union[int, Iterable[int], str]] = None) unseal.hooks.commons.Hook
Creates a hook which saves the attention patterns of a given layer.
- Parameters
layer (int) – The layer to hook.
key (str) – The key to use for saving the attention patterns.
output_idx (Optional[int], optional) – If the module output is a tuple, index it with this. GPT-like models need this to be equal to 2. Defaults to None.
attn_name (Optional[str], optional) – The name of the attention module in the transformer. Defaults to ‘attn’.
layer_key_prefix (Optional[str], optional) – The prefix in the model structure before the layer index, e.g. ‘transformer->h’. Defaults to None.
heads (Optional[Union[int, Iterable[int], str]], optional) – Which heads to save the attention pattern for. Can be an int, an iterable of ints, or a slice string like ‘1:3’. Defaults to None.
- Returns
Hook which saves the attention patterns
- Return type
Hook
- unseal.hooks.common_hooks.create_logit_hook(layer: int, model: unseal.hooks.commons.HookedModel, unembedding_key: str, layer_key_prefix: Optional[str] = None, target: Optional[Union[int, List[int]]] = None, position: Optional[Union[int, List[int]]] = None, key: Optional[str] = None, split_heads: Optional[bool] = False, num_heads: Optional[int] = None) unseal.hooks.commons.Hook
Create a hook that saves the logits of a layer’s output. Outputs are saved to save_ctx[key][‘logits’].
- Parameters
layer (int) – The number of the layer
model (HookedModel) – The model.
unembedding_key (str) – The key/name of the embedding matrix, e.g. ‘lm_head’ for causal LM models
layer_key_prefix (str) – The prefix of the key of the layer, e.g. ‘transformer->h’ for GPT like models
target (Union[int, List[int]]) – The target token(s) to extract logits for. Defaults to all tokens.
position (Union[int, List[int]]) – The position for which to extract logits for. Defaults to all positions.
key (str) – The key of the hook. Defaults to {layer}_logits.
split_heads (bool) – Whether to split the heads. Defaults to False.
num_heads (int) – The number of heads to split. Defaults to None.
- Returns
The hook.
- Return type
Hook
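The `target` and `position` arguments above select a subset of the saved logits. The following is a minimal pure-Python sketch of that selection logic (a hypothetical illustration, not unseal's actual implementation; the function name `select_logits` is invented for this example):

```python
def select_logits(logits, position=None, target=None):
    """Select rows (positions) and columns (target tokens) from a
    nested list of shape (positions, vocab). None means 'all'."""
    rows = logits if position is None else [logits[p] for p in position]
    if target is None:
        return rows
    return [[row[t] for t in target] for row in rows]

# Toy "logits": 2 positions, vocab size 3.
logits = [
    [0.1, 0.2, 0.3],  # position 0
    [0.4, 0.5, 0.6],  # position 1
]
selected = select_logits(logits, position=[1], target=[0, 2])
```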
- unseal.hooks.common_hooks.gpt_attn_wrapper(func: Callable, save_ctx: Dict, c_proj: torch.Tensor, vocab_embedding: torch.Tensor, target_ids: torch.Tensor, batch_size: Optional[int] = None) Tuple[Callable, Callable]
Wraps around the [AttentionBlock]._attn function to save the individual heads’ logits. This is necessary because the individual heads’ logits are not available on a module level and thus not accessible via a hook.
- Parameters
func (Callable) – original _attn function
save_ctx (Dict) – context to which the logits will be saved
c_proj (torch.Tensor) – projection matrix, this is W_O in Anthropic’s terminology
vocab_embedding (torch.Tensor) – vocabulary/embedding matrix, this is W_V in Anthropic’s terminology
target_ids (torch.Tensor) – indices of the target tokens for which the logits are computed
batch_size (Optional[int]) – batch size to reduce memory footprint, defaults to None
- Returns
inner and func: the wrapped function and the original function
- Return type
Tuple[Callable, Callable]
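The wrap-and-return-both pattern lets the caller monkey-patch `_attn` and later restore the original. A minimal sketch of the idea (hypothetical names; unseal's real wrapper additionally projects the per-head outputs through `c_proj` and the vocabulary embedding):

```python
def wrap_fn(func, save_ctx):
    """Wrap func so each call also records its result in save_ctx.
    Returns (inner, func): the wrapper and the original, so the
    original can be restored after the forward pass."""
    def inner(*args, **kwargs):
        out = func(*args, **kwargs)
        save_ctx.setdefault("outputs", []).append(out)
        return out
    return inner, func

ctx = {}
wrapped, original = wrap_fn(lambda x: x * 2, ctx)
wrapped(3)  # result is recorded in ctx["outputs"]
```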
- unseal.hooks.common_hooks.replace_activation(indices: str, replacement_tensor: torch.Tensor, tuple_index: Optional[int] = None) Callable
Creates a hook which replaces a module’s activation (output) with a replacement tensor. If there is a dimension mismatch, the replacement tensor is copied along the leading dimensions of the output.
Example: If the activation has shape (B, T, D) and the replacement tensor has shape (D,), and you want to plug it in at position t of the T dimension for every element in the batch, then indices should be ‘:,t,:’.
- Parameters
indices (str) – Indices at which to insert the replacement tensor
replacement_tensor (torch.Tensor) – Tensor that is filled in.
tuple_index (int) – Index of the tuple in the output of the module.
- Returns
Function that replaces part of a given tensor with replacement_tensor
- Return type
Callable
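The replacement boils down to slice assignment on the module's output. A one-dimensional pure-Python sketch of the pattern (hypothetical, simplified; the real hook parses the `indices` string and operates on tensors):

```python
def replace_activation(sl, replacement):
    """Return a hook-style function that overwrites output[sl] in place."""
    def func(output):
        output[sl] = replacement
        return output
    return func

# Replace elements 1:3 of a toy "activation" with the replacement values.
hook = replace_activation(slice(1, 3), [9, 9])
result = hook([0, 1, 2, 3])
```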
- unseal.hooks.common_hooks.save_output(cpu: bool = True, detach: bool = True) Callable
Basic hooking function for saving the output of a module to the global context object
- Parameters
cpu (bool) – Whether to save the output to cpu.
detach (bool) – Whether to detach the output.
- Returns
Function that saves the output to the context object.
- Return type
Callable
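Conceptually, `save_output` returns a closure with the standard PyTorch forward-hook signature that stashes the module output in a shared context dict. A minimal pure-Python sketch (hypothetical names; with real tensors the body would call `output.detach().cpu()` depending on the flags):

```python
def save_output_hook(save_ctx, key):
    """Mimic a PyTorch forward hook: (module, inputs, output) -> None."""
    def func(module, inputs, output):
        # With real tensors this would optionally detach and move to CPU;
        # here we just store the value under the given key.
        save_ctx[key] = output
    return func

ctx = {}
hook = save_output_hook(ctx, "layer0")
hook(None, (1, 2), [0.5, 0.7])  # simulate the module firing
```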
- unseal.hooks.common_hooks.transformers_get_attention(heads: Optional[Union[int, Iterable[int], str]] = None, output_idx: Optional[int] = None) Callable
Creates a hooking function to get the attention patterns of a given layer.
- Parameters
heads (Optional[Union[int, Iterable[int], str]], optional) – The heads for which to save the attention, defaults to None
output_idx (Optional[int], optional) – If the attention module returns a tuple, use this argument to index it, defaults to None
- Returns
func, hooking function that saves attention of the specified heads
- Return type
Callable
hooks.rome_hooks module
hooks.util module
- unseal.hooks.util.create_slice_from_str(indices: str) slice
Creates a slice object from a string representing the slice.
- Parameters
indices (str) – String representing the slice, e.g. ‘...,3:5,:’
- Returns
Slice object corresponding to the input indices.
- Return type
slice
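A plausible reimplementation of this parser is sketched below (not necessarily unseal's exact code; note that for a multi-axis string like ‘...,3:5,:’ the usable return value is a tuple of slice objects suitable for `tensor[sl]`):

```python
def create_slice_from_str(indices):
    """Parse a comma-separated index string like '...,3:5,:' into a
    tuple of Ellipsis/slice/int objects usable for tensor indexing."""
    def parse(part):
        part = part.strip()
        if part == "...":
            return Ellipsis
        if ":" in part:
            return slice(*(int(p) if p else None for p in part.split(":")))
        return int(part)
    return tuple(parse(p) for p in indices.split(","))

sl = create_slice_from_str("...,3:5,:")
```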
- unseal.hooks.util.recursive_to_device(iterable: Union[Iterable, torch.Tensor], device: Union[str, torch.device]) Iterable
Recursively puts an Iterable of (Iterable of (…)) tensors on the given device
- Parameters
iterable (Tensor or Iterable) – Tensor or Iterable of tensors or iterables of …
device (Union[str, torch.device]) – Device on which to put the object
- Raises
TypeError – Raised on unexpected types
- Returns
Nested iterable with the tensors on the new device
- Return type
Iterable