设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 2302|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑 0 S0 ~" B# K& f

    # y8 [8 X0 D8 U为预防老年痴呆,时不时学点新东东玩一玩。
    ' q$ E# n  m' I% t- ~Pytorch 下面的代码做最简单的一元线性回归:
    6 E2 U" I2 G/ @8 b, }+ V----------------------------------------------, x, r0 }( W' l
    import torch6 O! r9 c& C/ Q- Z2 G. ]5 _
    import numpy as np5 S8 j3 B. s" `
    import matplotlib.pyplot as plt
    . T* P' K  j! `0 }0 Wimport random5 V0 X  {9 M) b2 j3 z( m# G) j( l" N

    ; Z5 P  }+ B4 Y  z5 ^( W, Nx = torch.tensor(np.arange(1,100,1))
    + E4 Y* d) M4 z8 Zy = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=15$ S/ O6 Y3 a' C+ S) N

    / v4 o$ g1 G& C; v. G, xw = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b
    8 k6 n# q4 O! p' q8 m9 r3 c! i% Qb = torch.tensor(0.,requires_grad=True)
    - r! o- {" r5 Y& Q& |7 t* ^% |" d0 s
    epochs = 100
    7 j$ [# l& n' `  L- G6 Z. N- C4 u* E
    losses = []
    * Z4 J0 q0 S, p5 T* I$ Gfor i in range(epochs):
    , n: Y0 U; C; l9 w" {1 ?5 u9 s. f  y_pred = (x*w+b)    # 预测4 w$ b8 _6 M2 A3 X6 I( W2 H- L
      y_pred.reshape(-1)& l& V" [5 G& w

    ( R: ~7 [0 \$ q* _0 P+ _  loss = torch.square(y_pred - y).mean()   #计算 loss+ W* C  E3 n1 i5 f4 N' U
      losses.append(loss)
      ~4 J4 p7 R% G2 U2 @' h( s  
    1 ^+ d% T0 |6 c! T8 D  loss.backward() # autograd9 z1 u! `' p6 I9 @
      with torch.no_grad():. R. `$ ?% N/ Z1 W) l3 i' ^5 t
        w  -= w.grad*0.0001   # 回归 w  O) ?& M2 y" k# R/ B1 e# _$ `# K
        b  -= b.grad*0.0001    # 回归 b # X9 y/ b, z- r8 h8 \1 l0 ]
      w.grad.zero_()  
    * z  C! ~- T9 M! L! O& ^7 ~  b.grad.zero_()% |0 _" [9 p- l& S8 o: g% D
    8 h5 u8 z( k( t/ f+ b1 r" r
    print(w.item(),b.item()) #结果- n/ h8 u0 [0 Q% S( l" C
    ) O3 y) C* i. k, m/ {3 j' n
    Output: 27.26387596130371  0.4974517822265625
    5 x: `5 F  T2 i5 \----------------------------------------------
    , B$ P" M$ Z4 L/ p2 n最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。
    6 {/ }! g, y8 Z; x  E0 K2 F" \高手们帮看看是神马原因?' P/ N& D+ M2 K! u

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑   ~3 j! A. M( L  v4 A

    0 B; S* f7 \5 }没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?% \& y4 F. c; I0 d4 H' z6 R  f
    -------; x/ J1 {, w" P' k2 }1 |" ^/ Z/ E
    不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。% P+ C* K4 @8 E1 @
    -------
    8 O6 F! e# n; m算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23+ ?2 l) X0 W; [0 S
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?* G' O! I) G3 |4 v+ x
    -------
    8 s3 Q! ~2 i4 u3 I8 B2 O不好意思, ...
    + K) u8 O6 o+ k5 W7 ~+ `. c; q
    谢谢,算法应该没问题,就是最简单的线性回归。
    . T  Q( ]: x5 R5 `0 s/ L我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑 # z! V+ b- ^, l7 F7 y# L
    雷达 发表于 2023-2-14 21:52
    2 D& {% F1 E7 x2 v谢谢,算法应该没问题,就是最简单的线性回归。
    6 z7 }! k1 H* E1 c+ U我特意没有用现成的工具,就是想从最基本的地方深入理解 ...

    % w, n4 k  N2 n' Y8 {* [4 ~0 U/ O( w1 K4 [
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    6 z" O2 u. M# }( v0 n
    : C' S7 i, i- h7 }4 ]: ]3 j: R( f或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑
    + x- W, o( ?( w9 N3 m) ]. F2 z2 r
    老福 发表于 2023-2-14 22:002 y4 z4 a1 K" K1 C4 n; N
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    ) i6 E  S5 e0 O7 X1 |3 u" L. e0 `1 K/ y0 q
    或者把b但的起点改为1试试。 ...
    2 Q& e" }9 h2 a, Z$ S8 m

    , d1 Y  l  |2 ^2 {你是对的。
    + t: I  {# _+ S' Q/ u去掉了随机部分
    / x. w! C5 z$ {#y = (x*27+15+random.randint(-2,3)).reshape(-1)
    ! M; \0 u+ ^+ r% Gy = (x*27+15).reshape(-1)
    : X0 T( a0 A. n: C( Z
    . ~$ |! M3 T, I8 @) _循环次数加成10倍,就看到 b 收敛了+ o. P5 L5 _( V8 K' q# l4 m
    w , b
    5 [2 `( J* K* t. z- V$ y1 i27.002620697021484 14.826167106628418
    . h* {& S" V+ g4 x0 o  Q+ H/ ?; m7 M! t( T) N9 n8 f; q
    和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2026-1-8 18:40 , Processed in 0.027682 second(s), 18 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表