diff --git a/ch07/04_preference-tuning-with-dpo/dpo-from-scratch.ipynb b/ch07/04_preference-tuning-with-dpo/dpo-from-scratch.ipynb index fe5a370..edfd9dc 100644 --- a/ch07/04_preference-tuning-with-dpo/dpo-from-scratch.ipynb +++ b/ch07/04_preference-tuning-with-dpo/dpo-from-scratch.ipynb @@ -1886,8 +1886,9 @@ "reference_logratios = reference_chosen_logprobs - reference_rejected_logprobs\n", "```\n", "\n", - "- These lines above calculate the difference in log probabilities (logits) for the chosen and rejected samples for both the policy model and the reference model:\n", - "$\\log \\left( \\frac{\\pi_\\theta (y_w \\mid x)}{\\pi_\\theta (y_l \\mid x)} \\right) \\quad \\text{and} \\quad \\log \\left( \\frac{\\pi_{\\text{ref}}(y_w \\mid x)}{\\pi_{\\text{ref}}(y_l \\mid x)} \\right)$; this is due to $\\log\\left(\\frac{a}{b}\\right) = \\log a - \\log b$" + "- These lines above calculate the difference in log probabilities (logits) for the chosen and rejected samples for both the policy model and the reference model (this is due to $\\log\\left(\\frac{a}{b}\\right) = \\log a - \\log b$):\n", + "\n", + "$$\\log \\left( \\frac{\\pi_\\theta (y_w \\mid x)}{\\pi_\\theta (y_l \\mid x)} \\right) \\quad \\text{and} \\quad \\log \\left( \\frac{\\pi_{\\text{ref}}(y_w \\mid x)}{\\pi_{\\text{ref}}(y_l \\mid x)} \\right)$$" ] }, { @@ -1897,8 +1898,10 @@ "id": "5458d217-e0ad-40a5-925c-507a8fcf5795" }, "source": [ - "- Next, the code `logits = model_logratios - reference_logratios` computes the difference between the model's log ratios and the reference model's log ratios, i.e., $\\beta \\log \\left( \\frac{\\pi_\\theta (y_w \\mid x)}{\\pi_{\\text{ref}} (y_w \\mid x)} \\right)\n", - "- \\beta \\log \\left( \\frac{\\pi_\\theta (y_l \\mid x)}{\\pi_{\\text{ref}} (y_l \\mid x)} \\right)$\n" + "- Next, the code `logits = model_logratios - reference_logratios` computes the difference between the model's log ratios and the reference model's log ratios, i.e., \n", + "\n", + "$$\\beta \\log \\left( \\frac{\\pi_\\theta (y_w \\mid x)}{\\pi_{\\text{ref}} (y_w \\mid x)} \\right)\n", + "- \\beta \\log \\left( \\frac{\\pi_\\theta (y_l \\mid x)}{\\pi_{\\text{ref}} (y_l \\mid x)} \\right)$$\n" ] }, { @@ -1908,8 +1911,10 @@ "id": "f18e3e36-f5f1-407f-b662-4c20a0ac0354" }, "source": [ - "- Finally, `losses = -F.logsigmoid(beta * logits)` calculates the loss using the log-sigmoid function; in the original equation, the term inside the expectation is $\\log \\sigma \\left( \\beta \\log \\left( \\frac{\\pi_\\theta (y_w \\mid x)}{\\pi_{\\text{ref}} (y_w \\mid x)} \\right)\n", - "- \\beta \\log \\left( \\frac{\\pi_\\theta (y_l \\mid x)}{\\pi_{\\text{ref}} (y_l \\mid x)} \\right) \\right)$" + "- Finally, `losses = -F.logsigmoid(beta * logits)` calculates the loss using the log-sigmoid function; in the original equation, the term inside the expectation is \n", + "\n", + "$$\\log \\sigma \\left( \\beta \\log \\left( \\frac{\\pi_\\theta (y_w \\mid x)}{\\pi_{\\text{ref}} (y_w \\mid x)} \\right)\n", + "- \\beta \\log \\left( \\frac{\\pi_\\theta (y_l \\mid x)}{\\pi_{\\text{ref}} (y_l \\mid x)} \\right) \\right)$$" ] }, { @@ -3089,7 +3094,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.4" + "version": "3.10.6" } }, "nbformat": 4,