fix cyfi445 labs 0-3

This commit is contained in:
Frank Xu
2025-09-14 17:02:16 -04:00
parent 85867d93c9
commit e7aca5a24d
6 changed files with 374 additions and 279 deletions

File diff suppressed because one or more lines are too long

View File

@@ -10,7 +10,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 1,
"id": "08a3e0d1",
"metadata": {},
"outputs": [
@@ -65,7 +65,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 2,
"id": "76e04a69",
"metadata": {},
"outputs": [
@@ -90,12 +90,32 @@
"metadata": {},
"source": [
"#### Calcuate Mean Squared Error when slope range between [-5, 9.05] with 0.05 inteval\n",
"- w to reprsent the slope"
"- w to reprsent the slope\n",
"\n",
"\n",
"$$\n",
"\\begin{bmatrix}-5 & -5 & -5 & -5 \\\\\n",
"-4.95 & -4.95 & -4.95 & -4.95 \\\\\n",
"-4.9 & -4.9 & -4.9 & -4.9 \n",
"\\end{bmatrix}\n",
"\\times\n",
"\\begin{bmatrix}\n",
"1 & 2 & 3 & 4 \\\\\n",
"1 & 2 & 3 & 4\\\\\n",
"1 & 2 & 3 & 4\n",
"\\end{bmatrix}\n",
"=\n",
"\\begin{bmatrix}\n",
"1 \\cdot -5 & 2 \\cdot -5 & 3 \\cdot -5 & 4 \\cdot -5\\\\\n",
"1 \\cdot -4.95 & 2 \\cdot -4.95 & 3 \\cdot -4.95 & 4 \\cdot -4.95\\\\\n",
"1 \\cdot -4.9 & 2 \\cdot -4.9 & 3 \\cdot -4.9 & 4 \\cdot -4.9\n",
"\\end{bmatrix}\n",
"$$"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"id": "34821a6b",
"metadata": {},
"outputs": [
@@ -119,12 +139,12 @@
"Y = np.array([2.3, 3.4, 6.5, 6.8], dtype=np.float32)\n",
"\n",
"# Generate weights from -5 to 5\n",
"w_values = np.arange(-5, 9.05, 0.05)\n",
"w_values = np.arange(-5, 9.05, 0.05) # 281 values (281,)\n",
"\n",
"# Calculate MSE for each weight\n",
"y_preds = w_values[:, np.newaxis] * X\n",
"errors = Y - y_preds\n",
"mse_values = np.mean(errors**2, axis=1)\n",
"y_preds = w_values[:, np.newaxis] * X # (281,1) *(4) = (281,1) * (1,4) = (281,4)\n",
"errors = Y - y_preds # (281,4) = (4) - (281,4) = (1,4) - (281,4) \n",
"mse_values = np.mean(errors**2, axis=1) # (281,) mean over columns\n",
"\n",
"# Find optimal weight\n",
"min_idx = np.argmin(mse_values)\n",
@@ -144,6 +164,27 @@
"plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "1870d1e3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(281, 4)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_preds.shape"
]
},
{
"cell_type": "markdown",
"id": "4c4f352e",
@@ -154,7 +195,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 4,
"id": "39d1d373",
"metadata": {},
"outputs": [