Skip to main content
Overview

PyTorch Hook

August 18, 2021
1 min read

Al liceo facevo cose come hackerare Pikachu Volleyball tramite DLL injection, e le tecniche che usavo erano una forma di hooking. nn.Module di PyTorch supporta ufficialmente questo tipo di meccanismo.

Regole

Gli hook di PyTorch seguono queste regole:

  • Se l’hook restituisce qualcosa, il valore restituito viene applicato all’oggetto originale.
  • Se l’hook non restituisce nulla, l’oggetto si comporta normalmente.
  • La funzione da agganciare viene passata come oggetto, quindi può avere qualsiasi nome.

Gli esempi di codice qui sotto dovrebbero chiarire il tutto.

tensor hook

I tensor supportano hook solo per il backward.

torch.tensor.register_hook(function)

nn.Module hook

Sono supportati quattro hook:

  • register_forward_pre_hook
  • register_forward_hook
  • register_backward_hook (deprecato)
  • register_full_backward_hook

Firma di forward_pre_hook

def pre_hook(module, input) return Anything

Se restituisce qualcosa, l’input del forward viene sostituito con Anything. Se non restituisce nulla, si limita a ispezionare l’input.

Firma di forward_hook

def hook(module, input, output) return Anything

Se restituisce qualcosa, l’output del forward viene sostituito con Anything. Se non restituisce nulla, è solo ispezione.

full_backward_hook

def module_hook(module, grad_input, grad_output)

Se restituisce qualcosa, il grad_output usato durante l’aggiornamento di backward() può essere sostituito. Se non restituisce nulla, è solo ispezione.

Loading comments...