
Comments (2)

Aliktk commented on August 15, 2024

I am writing the code for a single image; here it is below. But I am confused about how to add the distance finding to the image.

Code:

import argparse
import random  # used for the per-class box colours below
import torch
import cv2
import os
from utils.datasets import *
from utils.utils import *  # provides non_max_suppression, scale_coords, plot_one_box, check_img_size, torch_utils, ...

def detect_image(image_path, weights='weights/yolov5s.pt', img_size=640, conf_thres=0.4, iou_thres=0.5, save_img=False):
    device = torch_utils.select_device('')  # '' selects the first GPU if available, otherwise CPU
    half = device.type != 'cpu'  # half precision is only supported on CUDA
    
    # Load model
    model = torch.load(weights, map_location=device)['model'].float()
    model.to(device).eval()
    if half:
        model.half()
    
    # Load image
    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"Failed to load image at {image_path}")
    img = cv2.resize(img, (img_size, img_size))
    # the model expects RGB input; keep the BGR copy for OpenCV drawing/saving
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_tensor = torch.from_numpy(img_rgb).float().to(device).permute(2, 0, 1).unsqueeze(0) / 255.0
    
    # Inference
    pred = model(img_tensor.half() if half else img_tensor)[0]
    
    # Apply NMS
    pred = non_max_suppression(pred, conf_thres, iou_thres)
    
    # Get names and colors
    names = model.names if hasattr(model, 'names') else model.module.names
    colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(names))]

    # Process detections
    for i, det in enumerate(pred):
        if det is not None and len(det):
            det[:, :4] = scale_coords(img_tensor.shape[2:], det[:, :4], img.shape).round()
            for *xyxy, conf, cls in det:
                label = '%s %.2f' % (names[int(cls)], conf)
                plot_one_box(xyxy, img, label=label, color=colors[int(cls)], line_thickness=3)
    
    # Save or display image
    if save_img:
        cv2.imwrite('output_image.jpg', img)
    
    cv2.imshow("Output", img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default='weights/yolov5s.pt', help='model.pt path')
    parser.add_argument('--source', type=str, default='path/to/your/image.jpg', help='source image path')
    parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
    parser.add_argument('--conf-thres', type=float, default=0.4, help='object confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.5, help='IOU threshold for NMS')
    opt = parser.parse_args()
    opt.img_size = check_img_size(opt.img_size)
    print(opt)

    with torch.no_grad():
        detect_image(opt.source, opt.weights, opt.img_size, opt.conf_thres, opt.iou_thres)
# python predict_image.py --source path/to/your/image.jpg --weights weights/yolov5s.pt --img-size 640 --conf-thres 0.4 --iou-thres 0.5

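For the distance part itself, here is a minimal sketch of the idea (the helper name and the 100-pixel threshold are placeholders, purely for illustration): take the centre of each detected 'person' box and flag pairs whose Euclidean pixel distance falls below the threshold.

def too_close(box_a, box_b, min_distance=100):
    # box_a / box_b are (x1, y1, x2, y2) pixel boxes; compare their centres
    ax, ay = (box_a[0] + box_a[2]) / 2, (box_a[1] + box_a[3]) / 2
    bx, by = (box_b[0] + box_b[2]) / 2, (box_b[1] + box_b[3]) / 2
    return ((ax - bx) ** 2 + (ay - by) ** 2) ** 0.5 < min_distance
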
Any help will be appreciated.
Thank you.

from social-distancing-using-yolov5.

Aliktk commented on August 15, 2024

Thanks, here is the code for an image:

import argparse
import random  # used for the per-class box colours below
import torch
import cv2
import os
from pathlib import Path
from utils.utils import *  # provides non_max_suppression, scale_coords, plot_one_box, check_img_size, torch_utils, ...

def check_social_distance(boxes, min_distance=50):
    # boxes: list of (center_x, center_y) pixel coordinates, one per detected person
    # returns index pairs whose centre-to-centre Euclidean distance is below min_distance
    violations = []
    for i, box1 in enumerate(boxes):
        for j, box2 in enumerate(boxes):
            if i >= j:
                continue
            distance = ((box1[0]-box2[0])**2 + (box1[1]-box2[1])**2) ** 0.5
            if distance < min_distance:
                violations.append((i, j))
    return violations

def detect_image(image_path, weights='weights/yolov5s.pt', img_size=640, conf_thres=0.4, iou_thres=0.5, save_img=False):
    device = torch_utils.select_device('')  # '' selects the first GPU if available, otherwise CPU
    half = device.type != 'cpu'  # half precision is only supported on CUDA
    
    model = torch.load(weights, map_location=device)['model'].float()
    model.to(device).eval()
    if half:
        model.half()
    
    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"Failed to load image at {image_path}")
    img = cv2.resize(img, (img_size, img_size))
    # the model expects RGB input; keep the BGR copy for OpenCV drawing/saving
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_tensor = torch.from_numpy(img_rgb).float().to(device).permute(2, 0, 1).unsqueeze(0) / 255.0
    
    pred = model(img_tensor.half() if half else img_tensor)[0]
    pred = non_max_suppression(pred, conf_thres, iou_thres)
    
    names = model.names if hasattr(model, 'names') else model.module.names
    colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(names))]

    people_coords = []
    for i, det in enumerate(pred):
        if det is not None and len(det):
            det[:, :4] = scale_coords(img_tensor.shape[2:], det[:, :4], img.shape).round()
            for *xyxy, conf, cls in det:
                label = '%s %.2f' % (names[int(cls)], conf)
                if names[int(cls)] == 'person':
                    center_x = int((xyxy[0] + xyxy[2]) / 2)
                    center_y = int((xyxy[1] + xyxy[3]) / 2)
                    people_coords.append((center_x, center_y))
                    plot_one_box(xyxy, img, label=label, color=colors[int(cls)], line_thickness=3)
    
    violations = check_social_distance(people_coords, min_distance=100)
    for (i, j) in violations:
        cv2.line(img, people_coords[i], people_coords[j], (0, 0, 255), 2)

    # Save output image in a new folder
    image_name = Path(image_path).stem
    output_folder = f'inference/{image_name}_output'
    Path(output_folder).mkdir(parents=True, exist_ok=True)
    if save_img:
        output_image_path = os.path.join(output_folder, f'{image_name}_output.jpg')
        print(f"Saving image to {output_image_path}...")  # Add this line for debugging
        cv2.imwrite(output_image_path, img)
        print("Image saved.")  # Add this line for debugging

    cv2.imshow("Output", img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default='weights/yolov5s.pt', help='model.pt path')
    parser.add_argument('--source', type=str, default='path/to/your/image.jpg', help='source image path')
    parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
    parser.add_argument('--conf-thres', type=float, default=0.4, help='object confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.5, help='IOU threshold for NMS')
    parser.add_argument('--save-img', action='store_true', help='save output image')  # Add this line
    opt = parser.parse_args()
    opt.img_size = check_img_size(opt.img_size)
    print(opt)

    with torch.no_grad():
        detect_image(opt.source, opt.weights, opt.img_size, opt.conf_thres, opt.iou_thres, opt.save_img)  # Modify this line

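As a quick sanity check of the distance logic (the sample coordinates below are made up, purely for illustration): check_social_distance returns the index pairs of people whose centre-to-centre pixel distance is below min_distance, and those pairs are then connected with red lines in the image. Note the distance is measured in pixels of the resized image, not in real-world units.

# standalone check of check_social_distance with made-up centre coordinates
people_coords = [(100, 200), (140, 220), (500, 400)]
print(check_social_distance(people_coords, min_distance=100))
# -> [(0, 1)]  (the first two centres are ~44.7 px apart; the third is far from both)
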

from social-distancing-using-yolov5.
