Fix docstring parameter names in compute_dpo_loss function (#953)

Dawid Woźniak
2026-01-29 23:51:17 +01:00
committed by GitHub
parent e155d1b02c
commit 82010e2c77

@@ -1880,8 +1880,8 @@
" \"\"\"Compute the DPO loss for a batch of policy and reference model log probabilities.\n",
"\n",
" Args:\n",
-" policy_chosen_logprobs: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)\n",
-" policy_rejected_logprobs: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)\n",
+" model_chosen_logprobs: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)\n",
+" model_rejected_logprobs: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)\n",
" reference_chosen_logprobs: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)\n",
" reference_rejected_logprobs: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)\n",
" beta: Temperature parameter for the DPO loss; typically something in the range of 0.1 to 0.5. We ignore the reference model as beta -> 0.\n",
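
For readers of this diff, the way the four documented arguments and beta combine can be sketched as follows. This is a minimal pure-Python illustration of the standard DPO loss, not the notebook's actual implementation: the name dpo_loss_sketch, the use of plain lists of floats instead of tensors, and the default beta=0.1 are assumptions made for the example.

```python
import math


def log_sigmoid(x):
    # Numerically stable log(sigmoid(x)).
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def dpo_loss_sketch(model_chosen_logprobs, model_rejected_logprobs,
                    reference_chosen_logprobs, reference_rejected_logprobs,
                    beta=0.1):
    # Hypothetical helper: each argument is a list of per-example
    # log probabilities, i.e. shape (batch_size,).
    losses = []
    for mc, mr, rc, rr in zip(model_chosen_logprobs, model_rejected_logprobs,
                              reference_chosen_logprobs, reference_rejected_logprobs):
        model_logratio = mc - mr        # how strongly the policy prefers "chosen"
        reference_logratio = rc - rr    # how strongly the reference prefers "chosen"
        logits = model_logratio - reference_logratio
        losses.append(-log_sigmoid(beta * logits))
    # Average over the batch.
    return sum(losses) / len(losses)


# Policy identical to the reference: logits are 0, so the loss is log(2).
same = dpo_loss_sketch([-1.0], [-2.0], [-1.0], [-2.0])

# Policy prefers the chosen response more than the reference does: lower loss.
better = dpo_loss_sketch([-1.0], [-2.0], [-1.5], [-1.5])
```

The argument names follow the corrected docstring (model_* rather than policy_*), so the sketch lines up with the parameter list the commit documents.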