unseal.hooks package

This package handles the nitty-gritty of attaching hooks to a model.

hooks module

hooks.common_hooks module

unseal.hooks.common_hooks.additive_output_noise(indices: str, mean: Optional[float] = 0, std: Optional[float] = 0.1) → Callable
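
Presumably the returned callable is a hooking function that adds Gaussian noise with the given mean and standard deviation to the module output at the sliced indices; a minimal usage sketch under that assumption (the Hook constructor and the layer name are likewise assumptions, not confirmed on this page):

    from unseal.hooks.commons import Hook
    from unseal.hooks.common_hooks import additive_output_noise

    # Assumed behavior: add N(0, 0.1) noise to the whole output of block 0.
    noise_fn = additive_output_noise(indices=':,:,:', mean=0.0, std=0.1)
    hook = Hook(layer_name='transformer->h->0', func=noise_fn, key='noise_0')
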
unseal.hooks.common_hooks.create_attention_hook(layer: int, key: str, output_idx: Optional[int] = None, attn_name: Optional[str] = 'attn', layer_key_prefix: Optional[str] = None, heads: Optional[Union[int, Iterable[int], str]] = None) → unseal.hooks.commons.Hook

Creates a hook which saves the attention patterns of a given layer.

Parameters
  • layer (int) – The layer to hook.

  • key (str) – The key to use for saving the attention patterns.

  • output_idx (Optional[int], optional) – If the module output is a tuple, index it with this. GPT-like models need this to be 2; defaults to None

  • attn_name (Optional[str], optional) – The name of the attention module in the transformer, defaults to ‘attn’

  • layer_key_prefix (Optional[str], optional) – The prefix in the model structure before the layer idx, e.g. ‘transformer->h’, defaults to None

  • heads (Optional[Union[int, Iterable[int], str]], optional) – Which heads to save the attention pattern for. Can be an int, an iterable of ints, or a slice-like string such as ‘1:3’; defaults to None

Returns

Hook which saves the attention patterns

Return type

Hook
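
A usage sketch for a GPT-2 style model, filling in the GPT-like values named in the parameter descriptions above. The commented lines assume a HookedModel forward that accepts a hooks argument and a save_ctx attribute, which this page does not document:

    from unseal.hooks.common_hooks import create_attention_hook

    # Save the attention patterns of heads 0-2 in layer 5.
    hook = create_attention_hook(
        layer=5,
        key='attn_5',
        output_idx=2,                       # GPT-like models: attention sits at tuple index 2
        attn_name='attn',
        layer_key_prefix='transformer->h',  # GPT-like prefix from the docs above
        heads='0:3',                        # slice-style head selection
    )
    # model = HookedModel(gpt2)             # assumed wrapper around an nn.Module
    # model(input_ids, hooks=[hook])        # assumed forward signature
    # pattern = model.save_ctx['attn_5']    # assumed save location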

unseal.hooks.common_hooks.create_logit_hook(layer: int, model: unseal.hooks.commons.HookedModel, unembedding_key: str, layer_key_prefix: Optional[str] = None, target: Optional[Union[int, List[int]]] = None, position: Optional[Union[int, List[int]]] = None, key: Optional[str] = None, split_heads: Optional[bool] = False, num_heads: Optional[int] = None) → unseal.hooks.commons.Hook

Create a hook that saves the logits of a layer’s output. Outputs are saved to save_ctx[key][‘logits’].

Parameters
  • layer (int) – The index of the layer

  • model (HookedModel) – The model.

  • unembedding_key (str) – The key/name of the unembedding matrix, e.g. ‘lm_head’ for causal LM models

  • layer_key_prefix (str) – The prefix of the key of the layer, e.g. ‘transformer->h’ for GPT like models

  • target (Union[int, List[int]]) – The target token(s) to extract logits for. Defaults to all tokens.

  • position (Union[int, List[int]]) – The position(s) for which to extract logits. Defaults to all positions.

  • key (str) – The key of the hook. Defaults to {layer}_logits.

  • split_heads (bool) – Whether to split the heads. Defaults to False.

  • num_heads (int) – The number of attention heads to split into. Defaults to None.

Returns

The hook.

Return type

Hook
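
A sketch under the assumption that model is an already-constructed HookedModel wrapping a GPT-like network; the save_ctx access follows the save_ctx[key][‘logits’] layout stated above, while the forward call is an assumption:

    from unseal.hooks.common_hooks import create_logit_hook

    hook = create_logit_hook(
        layer=5,
        model=model,                        # assumed: an existing HookedModel
        unembedding_key='lm_head',          # causal LM example from the docs above
        layer_key_prefix='transformer->h',  # GPT-like prefix
        target=None,                        # all tokens (the default)
        position=None,                      # all positions (the default)
        key='layer5_logits',
    )
    # model(input_ids, hooks=[hook])                      # assumed forward signature
    # logits = model.save_ctx['layer5_logits']['logits']  # documented save location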

unseal.hooks.common_hooks.gpt_attn_wrapper(func: Callable, save_ctx: Dict, c_proj: torch.Tensor, vocab_embedding: torch.Tensor, target_ids: torch.Tensor, batch_size: Optional[int] = None) → Tuple[Callable, Callable]

Wraps around the [AttentionBlock]._attn function to save the individual heads’ logits. This is necessary because the individual heads’ logits are not available on a module level and thus not accessible via a hook.

Parameters
  • func (Callable) – original _attn function

  • save_ctx (Dict) – context to which the logits will be saved

  • c_proj (torch.Tensor) – projection matrix; this is W_O in Anthropic’s terminology

  • vocab_embedding (torch.Tensor) – vocabulary/unembedding matrix; this is W_U in Anthropic’s terminology

  • target_ids (torch.Tensor) – indices of the target tokens for which the logits are computed

  • batch_size (Optional[int]) – batch size to reduce memory footprint, defaults to None

Returns

(inner, func): the wrapped function and the original function

Return type

Tuple[Callable, Callable]
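
Because this wraps a method rather than installing a module-level hook, the documented return value (inner, func) suggests a wrap-then-restore pattern. A hypothetical sketch; every attribute path here (model.model, c_proj.weight, lm_head.weight) and the target_ids variable are assumptions about a GPT-2 style checkpoint:

    from unseal.hooks.common_hooks import gpt_attn_wrapper

    block = model.model.transformer.h[5].attn         # assumed module path
    wrapped, original = gpt_attn_wrapper(
        func=block._attn,
        save_ctx=model.save_ctx,                      # assumed context object
        c_proj=block.c_proj.weight,                   # W_O projection
        vocab_embedding=model.model.lm_head.weight,   # unembedding matrix
        target_ids=target_ids,                        # assumed: precomputed token ids
        batch_size=16,                                # chunking to limit memory use
    )
    block._attn = wrapped      # install the wrapper
    # ... run the forward pass ...
    block._attn = original     # restore the original function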

unseal.hooks.common_hooks.hidden_patch_hook_fn(position: int, replacement_tensor: torch.Tensor) → Callable
unseal.hooks.common_hooks.replace_activation(indices: str, replacement_tensor: torch.Tensor, tuple_index: Optional[int] = None) → Callable

Creates a hook which replaces a module’s activation (output) with a replacement tensor. If there is a dimension mismatch, the replacement tensor is copied along the leading dimensions of the output.

Example: If the activation has shape (B, T, D) and the replacement tensor has shape (D,), which you want to plug in at position t in the T dimension for every sequence in the batch, then indices should be ‘:,t,:’.

Parameters
  • indices (str) – Indices at which to insert the replacement tensor

  • replacement_tensor (torch.Tensor) – Tensor that is filled in.

  • tuple_index (int) – Index of the tuple in the output of the module.

Returns

Function that replaces part of a given tensor with replacement_tensor

Return type

Callable
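
A sketch of the (B, T, D) example above; the Hook constructor and the layer name are assumptions:

    import torch
    from unseal.hooks.commons import Hook
    from unseal.hooks.common_hooks import replace_activation

    # Overwrite position t=3 of every sequence in the batch with a fixed
    # D-dimensional vector (D=768 here, matching e.g. GPT-2 small).
    replacement = torch.zeros(768)
    patch_fn = replace_activation(':,3,:', replacement)
    hook = Hook(layer_name='transformer->h->0', func=patch_fn, key='patch_0_t3')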

unseal.hooks.common_hooks.save_output(cpu: bool = True, detach: bool = True) → Callable

Basic hooking function for saving the output of a module to the global context object.

Parameters
  • cpu (bool) – Whether to save the output to cpu.

  • detach (bool) – Whether to detach the output.

Returns

Function that saves the output to the context object.

Return type

Callable
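
A sketch pairing save_output with a Hook; the layer name, the Hook constructor, and the exact save_ctx layout are assumptions:

    from unseal.hooks.commons import Hook
    from unseal.hooks.common_hooks import save_output

    # Save the detached, CPU-resident output of the first MLP block.
    hook = Hook(layer_name='transformer->h->0->mlp', func=save_output(), key='mlp_0')
    # model(input_ids, hooks=[hook])    # assumed forward signature
    # out = model.save_ctx['mlp_0']     # assumed save location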

unseal.hooks.common_hooks.transformers_get_attention(heads: Optional[Union[int, Iterable[int], str]] = None, output_idx: Optional[int] = None) → Callable

Creates a hooking function to get the attention patterns of a given layer.

Parameters
  • heads (Optional[Union[int, Iterable[int], str]], optional) – The heads for which to save the attention, defaults to None

  • output_idx (Optional[int], optional) – If the attention module returns a tuple, use this argument to index it, defaults to None

Returns

func, hooking function that saves attention of the specified heads

Return type

Callable
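
Since this returns the hooking function itself rather than a Hook, it can be paired with a manually specified layer name; a sketch in which the layer name and the Hook constructor are assumptions:

    from unseal.hooks.commons import Hook
    from unseal.hooks.common_hooks import transformers_get_attention

    # Equivalent in spirit to create_attention_hook, but built by hand.
    attn_fn = transformers_get_attention(heads=[0, 1], output_idx=2)
    hook = Hook(layer_name='transformer->h->5->attn', func=attn_fn, key='attn_5')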

hooks.rome_hooks module

hooks.util module

unseal.hooks.util.create_slice_from_str(indices: str) → slice

Creates a slice object from a string representing the slice.

Parameters

indices (str) – String representing the slice, e.g. ‘...,3:5,:’

Returns

Slice object corresponding to the input indices.

Return type

slice
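
For multi-axis strings such as ‘...,3:5,:’ the returned object must behave like a tuple of slices, despite the annotated slice return type; a sketch under that assumption:

    import torch
    from unseal.hooks.util import create_slice_from_str

    s = create_slice_from_str('...,3:5,:')
    x = torch.arange(2 * 8 * 4).reshape(2, 8, 4)
    print(x[s].shape)   # expected: torch.Size([2, 2, 4]), same as x[..., 3:5, :]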

unseal.hooks.util.recursive_to_device(iterable: Union[Iterable, torch.Tensor], device: Union[str, torch.device]) → Iterable

Recursively puts an Iterable of (Iterable of (…)) tensors on the given device.

Parameters
  • iterable (Tensor or Iterable) – Tensor or Iterable of tensors or iterables of …

  • device (Union[str, torch.device]) – Device on which to put the object

Raises

TypeError – Unexpected types

Returns

Nested iterable with the tensors on the new device

Return type

Iterable
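
A minimal re-implementation sketch of the documented behavior (not the library's actual code; it only handles tensors, lists, and tuples):

    from typing import Iterable, Union

    import torch

    def recursive_to_device_sketch(obj: Union[Iterable, torch.Tensor],
                                   device: Union[str, torch.device]) -> Union[Iterable, torch.Tensor]:
        # Tensors move directly; lists/tuples are rebuilt with moved elements.
        if isinstance(obj, torch.Tensor):
            return obj.to(device)
        if isinstance(obj, (list, tuple)):
            return type(obj)(recursive_to_device_sketch(x, device) for x in obj)
        raise TypeError(f'Unexpected type: {type(obj)}')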

hooks.commons module