Comments (6)

LaurentMazare commented on May 30, 2024

I haven't tested, but maybe you can do something like the following?
let grad_h = grad_h * h.ge(0.).to_kind(Kind::Float)
This would compute h >= 0 as a boolean tensor, then convert it to a float (which should be 0 or 1) and do a pointwise multiplication. Probably worth checking the output on a couple of examples, but hopefully that would do the trick.
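
For example, a minimal (untested) check along these lines should show the 0/1 mask and the masked gradient; of_slice and print are the helpers I'd reach for, though the exact names may differ slightly between tch versions:

use tch::{Kind, Tensor};

fn check_relu_mask() {
    // h with a mix of negative, zero and positive entries.
    let h = Tensor::of_slice(&[-1.5f32, 0.0, 2.0, -0.3]);
    let grad_h = Tensor::of_slice(&[10.0f32, 10.0, 10.0, 10.0]);
    // h >= 0 as a boolean tensor, converted to 0./1. floats.
    let mask = h.ge(0.).to_kind(Kind::Float);
    mask.print();            // expected: 0, 1, 1, 0
    (grad_h * mask).print(); // expected: 0, 10, 10, 0
}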

zeroexcuses commented on May 30, 2024

To the extent I have the legal right to do so, I'd like to offer the following (a translation / derivative work of https://pytorch.org/tutorials/beginner/examples_tensor/two_layer_net_tensor.html#sphx-glr-beginner-examples-tensor-two-layer-net-tensor-py ) as additional sample code:

Besides the let grad_h = grad_h * h.ge(0.).to_kind(Kind::Float) trick, it's a direct translation.

fn main(opts: (tch::Kind, tch::Device)) {
    println!("main test");
    let (d_n, d_in, d_h, d_out) = (64, 1000, 100, 10);

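    // Create random input and output data.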
    let x = tch::Tensor::randn(&[d_n, d_in], opts);
    let y = tch::Tensor::randn(&[d_n, d_out], opts);

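    // Randomly initialize the weights.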
    let mut w1 = tch::Tensor::randn(&[d_in, d_h], opts);
    let mut w2 = tch::Tensor::randn(&[d_h, d_out], opts);

    let learning_rate = 1e-6;

    for i in 0..500 {
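        // Forward pass: compute predicted y.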
        let h = x.mm(&w1);
        let h_relu = h.clamp_min(0);
        let y_pred = h_relu.mm(&w2);
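        // Compute and print the loss (sum of squared errors).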
        let loss = (&y_pred - &y).pow(2).sum();
        println!("{}: {:?}", i, loss);

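        // Backprop to compute gradients of w1 and w2 with respect to the loss.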
        let grad_y_pred = 2.0 * (&y_pred - &y);
        let grad_w2 = h_relu.transpose(0, 1).mm(&grad_y_pred);
        let grad_h_relu = grad_y_pred.mm(&w2.transpose(0, 1));
        let grad_h = grad_h_relu * h.ge(0.).to_kind(tch::Kind::Float);
        let grad_w1 = x.transpose(0, 1).mm(&grad_h);

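        // Update the weights using gradient descent.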
        w1 -= learning_rate * grad_w1;
        w2 -= learning_rate * grad_w2;
    }
}

#[test]
fn test_00() {
    let opts = (tch::Kind::Float, tch::Device::Cpu);
    main(opts)
}

The code appears to work; output:

main test
0: [34260284]
1: [32717616]
2: [35069400]
3: [34915180]
4: [29025518]
5: [19116674]
6: [10442851]
7: [5215584]
8: [2716956.5]
9: [1597051]
10: [1077723.5]
11: [806905.6875]
12: [643919.9375]
13: [532315.875]
14: [448856.5625]
15: [383128.1875]
16: [329637.0625]
17: [285379.1875]
18: [248298.359375]
19: [216973.46875]
20: [190334.21875]
21: [167557.953125]
22: [148002.375]
23: [131113.96875]
24: [116475.3671875]
25: [103731.03125]
26: [92597.359375]
27: [82841.921875]
28: [74268.0703125]
29: [66709]
30: [60016.53125]
31: [54088.41015625]
32: [48823.796875]
33: [44134.609375]
34: [39949.82421875]
35: [36208.046875]
36: [32862.96484375]
37: [29861.57421875]
38: [27163.69921875]
39: [24734.12109375]
40: [22543.10546875]
41: [20565.583984375]
42: [18777.408203125]
43: [17159.69921875]
44: [15693.8330078125]
45: [14364.0234375]
46: [13156.2861328125]
47: [12058.017578125]
48: [11058.720703125]
49: [10148.5263671875]
50: [9319.212890625]
51: [8562.736328125]
52: [7871.6298828125]
53: [7240.029296875]
54: [6662.66552734375]
55: [6134.64892578125]
56: [5650.7060546875]
57: [5207.15234375]
58: [4800.38525390625]
59: [4427.21630859375]
60: [4084.65625]
61: [3770.163818359375]
62: [3481.187255859375]
63: [3215.48095703125]
64: [2971.60302734375]
65: [2747.099853515625]
66: [2540.377197265625]
67: [2349.89306640625]
68: [2174.292236328125]
69: [2012.3524169921875]
70: [1862.973388671875]
71: [1725.17919921875]
72: [1597.97509765625]
73: [1480.51025390625]
74: [1372.021484375]
75: [1271.76953125]
76: [1179.1275634765625]
77: [1093.469970703125]
78: [1014.2444458007813]
79: [940.973876953125]
80: [873.1773071289063]
81: [810.4569702148438]
82: [752.4011840820313]
83: [698.6286010742188]
84: [648.8204345703125]
85: [602.6876831054688]
86: [559.94140625]
87: [520.3128662109375]
88: [483.5766296386719]
89: [449.5108947753906]
90: [417.94464111328125]
91: [388.6556396484375]
92: [361.4674987792969]
93: [336.24249267578125]
94: [312.8373718261719]
95: [291.1053161621094]
96: [270.9220886230469]
97: [252.18289184570313]
98: [234.78280639648438]
99: [218.61538696289063]
100: [203.5952911376953]
101: [189.63330078125]
102: [176.65658569335938]
103: [164.59197998046875]
104: [153.37435913085938]
105: [142.94175720214844]
106: [133.24234008789063]
107: [124.21937561035156]
108: [115.82464599609375]
109: [108.01161193847656]
110: [100.74175262451172]
111: [93.97444915771484]
112: [87.67327880859375]
113: [81.80506896972656]
114: [76.3422622680664]
115: [71.25404357910156]
116: [66.51329803466797]
117: [62.09667205810547]
118: [57.98186492919922]
119: [54.14569091796875]
120: [50.57160949707031]
121: [47.237030029296875]
122: [44.12919616699219]
123: [41.23194122314453]
124: [38.53001022338867]
125: [36.009193420410156]
126: [33.65830993652344]
127: [31.464611053466797]
128: [29.418720245361328]
129: [27.50872230529785]
130: [25.725427627563477]
131: [24.061630249023438]
132: [22.508499145507813]
133: [21.058027267456055]
134: [19.703481674194336]
135: [18.439287185668945]
136: [17.257568359375]
137: [16.15477180480957]
138: [15.123416900634766]
139: [14.160041809082031]
140: [13.259425163269043]
141: [12.417962074279785]
142: [11.631413459777832]
143: [10.895730018615723]
144: [10.208312034606934]
145: [9.565279960632324]
146: [8.963581085205078]
147: [8.401830673217773]
148: [7.875828742980957]
149: [7.382791519165039]
150: [6.922338962554932]
151: [6.491796970367432]
152: [6.088073253631592]
153: [5.7103657722473145]
154: [5.356616020202637]
155: [5.0256452560424805]
156: [4.715524673461914]
157: [4.425028324127197]
158: [4.152915000915527]
159: [3.8980093002319336]
160: [3.6594557762145996]
161: [3.435666084289551]
162: [3.225914478302002]
163: [3.0295639038085938]
164: [2.845404624938965]
165: [2.6725549697875977]
166: [2.510826826095581]
167: [2.3589487075805664]
168: [2.2165966033935547]
169: [2.0831785202026367]
170: [1.9577014446258545]
171: [1.8404662609100342]
172: [1.7301545143127441]
173: [1.6266480684280396]
174: [1.5297385454177856]
175: [1.4386940002441406]
176: [1.352889895439148]
177: [1.272759199142456]
178: [1.197373390197754]
179: [1.1265606880187988]
180: [1.0600446462631226]
181: [0.9974740743637085]
182: [0.9388630986213684]
183: [0.8838181495666504]
184: [0.8320485353469849]
185: [0.7833105325698853]
186: [0.7375422716140747]
187: [0.6944899559020996]
188: [0.6540900468826294]
189: [0.6161644458770752]
190: [0.5804455876350403]
191: [0.5468751788139343]
192: [0.5151917934417725]
193: [0.485488623380661]
194: [0.457588791847229]
195: [0.4313516318798065]
196: [0.406649112701416]
197: [0.3833242952823639]
198: [0.36140576004981995]
199: [0.3407328128814697]
200: [0.32142406702041626]
201: [0.303079754114151]
202: [0.2858888804912567]
203: [0.26974642276763916]
204: [0.25449246168136597]
205: [0.24014046788215637]
206: [0.22659197449684143]
207: [0.21385924518108368]
208: [0.2018306404352188]
209: [0.1904999315738678]
210: [0.17985865473747253]
211: [0.16983194649219513]
212: [0.1603270024061203]
213: [0.15140242874622345]
214: [0.1429554522037506]
215: [0.13500405848026276]
216: [0.12753276526927948]
217: [0.12046453356742859]
218: [0.11384101957082748]
219: [0.10755111277103424]
220: [0.101639524102211]

zeroexcuses commented on May 30, 2024

Here's a translation of https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_autograd.html#sphx-glr-beginner-examples-autograd-two-layer-net-autograd-py

use super::*;

fn main(opts: (tch::Kind, tch::Device)) {
    println!("main test");
    let (d_n, d_in, d_h, d_out) = (64, 1000, 100, 10);

    let x = tch::Tensor::randn(&[d_n, d_in], opts);
    let y = tch::Tensor::randn(&[d_n, d_out], opts);

    let mut w1 = tch::Tensor::randn(&[d_in, d_h], opts);
    let mut w2 = tch::Tensor::randn(&[d_h, d_out], opts);

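    // Setting requires_grad tells autograd to track operations on these weights.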
    w1.set_requires_grad(true);
    w2.set_requires_grad(true);

    let learning_rate = 1e-6;
    for i in 0..500 {
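        // Forward pass: compute predicted y using operations on Tensors.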
        let y_pred = x.mm(&w1).clamp_min(0).mm(&w2);
        let loss = (&y_pred - &y).pow(2).sum();
        println!("{}: {:?}", i, loss);

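        // Use autograd to compute the backward pass; gradients end up in w1.grad() and w2.grad().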
        loss.backward();

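        // Manually update the weights inside no_grad so autograd does not track
        // the update, then zero the gradients for the next iteration.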
        tch::no_grad(|| {
            w1 -= learning_rate * w1.grad();
            w2 -= learning_rate * w2.grad();
            w1.zero_grad();
            w2.zero_grad();
        });
    }
}

#[test]
fn test_00() {
    let opts = (tch::Kind::Float, tch::Device::Cpu);
    main(opts)
}

LaurentMazare commented on May 30, 2024

Looks nice! It would indeed be interesting to have these as sample code. Would you mind making a small PR to integrate these, e.g. in examples/basics, pointing at the original source so that people can easily compare? (If you don't have the time to do so, I can have a look at it.)

zeroexcuses commented on May 30, 2024

Long term it makes sense to submit pull requests. Let me translate a few more tutorials and see if I can organize this a bit better.
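
As a rough idea of where this could go next: the later tutorials in that series switch to torch.nn and torch.optim, which would presumably map to tch-rs's nn module. A sketch of what that might look like (untested; nn::VarStore, nn::linear, nn::Sgd and backward_step are my reading of the crate docs, and the learning rate is a guess since the default linear init differs from the manual randn weights):

use tch::{nn, nn::Module, nn::OptimizerConfig, Device, Kind, Tensor};

fn nn_version() {
    let (d_n, d_in, d_h, d_out) = (64, 1000, 100, 10);
    let x = Tensor::randn(&[d_n, d_in], (Kind::Float, Device::Cpu));
    let y = Tensor::randn(&[d_n, d_out], (Kind::Float, Device::Cpu));

    // The VarStore owns the trainable parameters created by nn::linear.
    let vs = nn::VarStore::new(Device::Cpu);
    let root = vs.root();
    let net = nn::seq()
        .add(nn::linear(&root / "w1", d_in, d_h, Default::default()))
        .add_fn(|xs| xs.relu())
        .add(nn::linear(&root / "w2", d_h, d_out, Default::default()));
    let mut opt = nn::Sgd::default().build(&vs, 1e-4).unwrap();

    for i in 0..500 {
        let loss = (net.forward(&x) - &y).pow(2).sum();
        println!("{}: {:?}", i, loss);
        // backward_step zeroes the gradients, runs backward and applies the update.
        opt.backward_step(&loss);
    }
}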

LaurentMazare commented on May 30, 2024

Closing this one now; feel free to re-open if you encounter further similar issues.
