Comments (1)
First you have to run train_cnn.py, but you will find some errors in it as well, so I corrected them myself. I am pasting the fixed code for train_cnn.py below; run it and let me know if you face any errors. Most importantly, this code runs with frozen modules enabled; to deactivate them, run it with this command: "python -X frozen_modules=off train_cnn.py"
Code:
import os

import torch
import torch.nn as nn  # explicit import; the original relied on the star imports below to provide nn
import torch.optim as optim
from torch.utils.data import DataLoader

from functions import overlapScore
from cnn_model import *
from training_dataset import *


def train_model(net, dataloader, batchSize, lr_rate, momentum, optimizer, scheduler):
    criterion = nn.MSELoss()
    for epoch in range(10):
        for i, data in enumerate(dataloader):
            optimizer.zero_grad()
            inputs, labels = data
            # Reshape to (batch, channels, height, width) images and 4-value box labels
            inputs, labels = inputs.view(batchSize, 1, 100, 100), labels.view(batchSize, 4)
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Convert predicted and ground-truth boxes to NumPy for scoring
            pbox = outputs.detach().numpy()
            gbox = labels.detach().numpy()
            score, _ = overlapScore(pbox, gbox)
            print('[epoch %5d, step: %d, loss: %f, Average Score = %f' % (epoch+1, i+1, loss.item(), score/batchSize))
        scheduler.step()  # Moved here so the learning rate is updated once per epoch
    print('Finish Training')


if __name__ == '__main__':
    # Hyper parameters
    learning_rate = 0.000001
    momentum = 0.9
    batch = 100
    no_of_workers = 2
    shuffle = True

    trainingdataset = training_dataset()
    dataLoader = DataLoader(
        dataset=trainingdataset,
        batch_size=batch,
        shuffle=shuffle,
        num_workers=no_of_workers
    )

    model = cnn_model()
    model.train()

    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

    train_model(model, dataLoader, batch, learning_rate, momentum, optimizer, scheduler)

    # Ensure the directory exists before saving the model
    model_directory = './Model/'
    os.makedirs(model_directory, exist_ok=True)  # Creates the directory if it does not exist
    torch.save(model.state_dict(), os.path.join(model_directory, 'cnn_model.pth'))
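
One thing that may be worth checking in the script: the scheduler is created with step_size=30 and gamma=0.1, but the training loop runs for only 10 epochs. If StepLR behaves as documented (the learning rate is multiplied by gamma once every step_size epochs), the rate would stay at its initial value for the whole run. Below is a minimal plain-Python sketch of that decay rule, not PyTorch's own implementation; the helper name step_lr is hypothetical and exists only for illustration.

```python
def step_lr(base_lr, epoch, step_size=30, gamma=0.1):
    """Learning rate after `epoch` completed epochs under a step-decay rule.

    Hypothetical helper that mirrors the documented StepLR behaviour:
    the rate is multiplied by `gamma` once every `step_size` epochs.
    """
    return base_lr * gamma ** (epoch // step_size)


if __name__ == "__main__":
    base_lr = 0.000001  # same value as learning_rate in the script
    # Epochs 0 and 9 fall inside the 10-epoch run; 30 and 60 are shown
    # only to illustrate when the decay would start to take effect.
    for epoch in (0, 9, 30, 60):
        print(epoch, step_lr(base_lr, epoch))
```

If a decay during training is actually wanted, either a smaller step_size or more epochs would be needed; which one fits depends on how the loss behaves, so treat this as something to experiment with.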
from object-detection-using-cnn.