這篇文章主要介紹了如何使用PyTorch實(shí)現(xiàn)目標(biāo)檢測(cè)與跟蹤,具有一定借鑒價(jià)值,需要的朋友可以參考下。希望大家閱讀完這篇文章后大有收獲。下面讓小編帶著大家一起了解一下。

引言
在昨天的文章中,我們介紹了如何在PyTorch中使用您自己的圖像來(lái)訓(xùn)練圖像分類器,然后使用它來(lái)進(jìn)行圖像識(shí)別。本文將展示如何使用預(yù)訓(xùn)練的分類器檢測(cè)圖像中的多個(gè)對(duì)象,并在視頻中跟蹤它們。
圖像中的目標(biāo)檢測(cè)
目標(biāo)檢測(cè)的算法有很多,YOLO跟SSD是現(xiàn)下最流行的算法。在本文中,我們將使用YOLOv3。在這里我們不會(huì)詳細(xì)討論YOLO,如果想對(duì)它有更多了解,可以參考下面的鏈接哦~(https://pjreddie.com/darknet/yolo/)
下面讓我們開始吧,依然從導(dǎo)入模塊開始:
from models import * from utils import * import os, sys, time, datetime, random import torch from torch.utils.data import DataLoader from torchvision import datasets, transforms from torch.autograd import Variable import matplotlib.pyplot as plt import matplotlib.patches as patches from PIL import Image
然后加載預(yù)訓(xùn)練的配置和權(quán)重,以及一些預(yù)定義的值,包括:圖像的尺寸、置信度閾值和非較大抑制閾值。
config_path='config/yolov3.cfg' weights_path='config/yolov3.weights' class_path='config/coco.names' img_size=416 conf_thres=0.8 nms_thres=0.4 # Load model and weights model = Darknet(config_path, img_size=img_size) model.load_weights(weights_path) model.cuda() model.eval() classes = utils.load_classes(class_path) Tensor = torch.cuda.FloatTensor
下面的函數(shù)將返回對(duì)指定圖像的檢測(cè)結(jié)果。
def detect_image(img): # scale and pad image ratio = min(img_size/img.size[0], img_size/img.size[1]) imw = round(img.size[0] * ratio) imh = round(img.size[1] * ratio) img_transforms=transforms.Compose([transforms.Resize((imh,imw)), transforms.Pad((max(int((imh-imw)/2),0), max(int((imw-imh)/2),0), max(int((imh-imw)/2),0), max(int((imw-imh)/2),0)), (128,128,128)), transforms.ToTensor(), ]) # convert image to Tensor image_tensor = img_transforms(img).float() image_tensor = image_tensor.unsqueeze_(0) input_img = Variable(image_tensor.type(Tensor)) # run inference on the model and get detections with torch.no_grad(): detections = model(input_img) detections = utils.non_max_suppression(detections, 80, conf_thres, nms_thres) return detections[0]
最后,讓我們通過加載一個(gè)圖像,獲取檢測(cè)結(jié)果,然后用檢測(cè)到的對(duì)象周圍的包圍框來(lái)顯示它。并為不同的類使用不同的顏色來(lái)區(qū)分。
# load image and get detections
img_path = "images/blueangels.jpg"
prev_time = time.time()
img = Image.open(img_path)
detections = detect_image(img)
inference_time = datetime.timedelta(seconds=time.time() - prev_time)
print ('Inference Time: %s' % (inference_time))
# Get bounding-box colors
cmap = plt.get_cmap('tab20b')
colors = [cmap(i) for i in np.linspace(0, 1, 20)]
img = np.array(img)
plt.figure()
fig, ax = plt.subplots(1, figsize=(12,9))
ax.imshow(img)
pad_x = max(img.shape[0] - img.shape[1], 0) * (img_size / max(img.shape))
pad_y = max(img.shape[1] - img.shape[0], 0) * (img_size / max(img.shape))
unpad_h = img_size - pad_y
unpad_w = img_size - pad_x
if detections is not None:
    unique_labels = detections[:, -1].cpu().unique()
    n_cls_preds = len(unique_labels)
    bbox_colors = random.sample(colors, n_cls_preds)
    # browse detections and draw bounding boxes
    for x1, y1, x2, y2, conf, cls_conf, cls_pred in detections:
        box_h = ((y2 - y1) / unpad_h) * img.shape[0]
        box_w = ((x2 - x1) / unpad_w) * img.shape[1]
        y1 = ((y1 - pad_y // 2) / unpad_h) * img.shape[0]
        x1 = ((x1 - pad_x // 2) / unpad_w) * img.shape[1]
        color = bbox_colors[int(np.where(
             unique_labels == int(cls_pred))[0])]
        bbox = patches.Rectangle((x1, y1), box_w, box_h,
             linewidth=2, edgecolor=color, facecolor='none')
        ax.add_patch(bbox)
        plt.text(x1, y1, s=classes[int(cls_pred)], 
                color='white', verticalalignment='top',
                bbox={'color': color, 'pad': 0})
plt.axis('off')
# save image
plt.savefig(img_path.replace(".jpg", "-det.jpg"),        
                  bbox_inches='tight', pad_inches=0.0)
plt.show()下面是我們的一些檢測(cè)結(jié)果:



視頻中的目標(biāo)跟蹤
現(xiàn)在你知道了如何在圖像中檢測(cè)不同的物體。當(dāng)你在一個(gè)視頻中一幀一幀地看時(shí),你會(huì)看到那些跟蹤框在移動(dòng)。但是如果這些視頻幀中有多個(gè)對(duì)象,你如何知道一個(gè)幀中的對(duì)象是否與前一個(gè)幀中的對(duì)象相同?這被稱為目標(biāo)跟蹤,它使用多次檢測(cè)來(lái)識(shí)別一個(gè)特定的對(duì)象。
有多種算法可以做到這一點(diǎn),在本文中決定使用SORT(Simple Online and Realtime Tracking),它使用Kalman濾波器預(yù)測(cè)先前識(shí)別的目標(biāo)的軌跡,并將其與新的檢測(cè)結(jié)果進(jìn)行匹配,非常方便且速度很快。
現(xiàn)在開始編寫代碼,前3個(gè)代碼段將與單幅圖像檢測(cè)中的代碼段相同,因?yàn)樗鼈兲幚淼氖窃趩螏汐@得 YOLO 檢測(cè)。差異在最后一部分出現(xiàn),對(duì)于每個(gè)檢測(cè),我們調(diào)用 Sort 對(duì)象的 Update 函數(shù),以獲得對(duì)圖像中對(duì)象的引用。因此,與前面示例中的常規(guī)檢測(cè)(包括邊界框的坐標(biāo)和類預(yù)測(cè))不同,我們將獲得跟蹤的對(duì)象,除了上面的參數(shù),還包括一個(gè)對(duì)象 ID。并且需要使用OpenCV來(lái)讀取視頻并顯示視頻幀。
videopath = 'video/interp.mp4'
%pylab inline 
import cv2
from IPython.display import clear_output
cmap = plt.get_cmap('tab20b')
colors = [cmap(i)[:3] for i in np.linspace(0, 1, 20)]
# initialize Sort object and video capture
from sort import *
vid = cv2.VideoCapture(videopath)
mot_tracker = Sort()
#while(True):
for ii in range(40):
    ret, frame = vid.read()
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    pilimg = Image.fromarray(frame)
    detections = detect_image(pilimg)
    img = np.array(pilimg)
    pad_x = max(img.shape[0] - img.shape[1], 0) * 
            (img_size / max(img.shape))
    pad_y = max(img.shape[1] - img.shape[0], 0) * 
            (img_size / max(img.shape))
    unpad_h = img_size - pad_y
    unpad_w = img_size - pad_x
    if detections is not None:
        tracked_objects = mot_tracker.update(detections.cpu())
        unique_labels = detections[:, -1].cpu().unique()
        n_cls_preds = len(unique_labels)
        for x1, y1, x2, y2, obj_id, cls_pred in tracked_objects:
            box_h = int(((y2 - y1) / unpad_h) * img.shape[0])
            box_w = int(((x2 - x1) / unpad_w) * img.shape[1])
            y1 = int(((y1 - pad_y // 2) / unpad_h) * img.shape[0])
            x1 = int(((x1 - pad_x // 2) / unpad_w) * img.shape[1])
            color = colors[int(obj_id) % len(colors)]
            color = [i * 255 for i in color]
            cls = classes[int(cls_pred)]
            cv2.rectangle(frame, (x1, y1), (x1+box_w, y1+box_h),
                         color, 4)
            cv2.rectangle(frame, (x1, y1-35), (x1+len(cls)*19+60,
                         y1), color, -1)
            cv2.putText(frame, cls + "-" + str(int(obj_id)), 
                        (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 
                        1, (255,255,255), 3)
    fig=figure(figsize=(12, 8))
    title("Video Stream")
    imshow(frame)
    show()
    clear_output(wait=True)感謝你能夠認(rèn)真閱讀完這篇文章,希望小編分享如何使用PyTorch實(shí)現(xiàn)目標(biāo)檢測(cè)與跟蹤內(nèi)容對(duì)大家有幫助,同時(shí)也希望大家多多支持創(chuàng)新互聯(lián)網(wǎng)站建設(shè)公司,,關(guān)注創(chuàng)新互聯(lián)行業(yè)資訊頻道,遇到問題就找創(chuàng)新互聯(lián)網(wǎng)站建設(shè)公司,,詳細(xì)的解決方法等著你來(lái)學(xué)習(xí)!
                網(wǎng)站名稱:如何使用PyTorch實(shí)現(xiàn)目標(biāo)檢測(cè)與跟蹤-創(chuàng)新互聯(lián)
                
                本文地址:http://www.chinadenli.net/article16/dgeegg.html
            
成都網(wǎng)站建設(shè)公司_創(chuàng)新互聯(lián),為您提供電子商務(wù)、網(wǎng)站收錄、關(guān)鍵詞優(yōu)化、網(wǎng)站改版、移動(dòng)網(wǎng)站建設(shè)、定制開發(fā)
聲明:本網(wǎng)站發(fā)布的內(nèi)容(圖片、視頻和文字)以用戶投稿、用戶轉(zhuǎn)載內(nèi)容為主,如果涉及侵權(quán)請(qǐng)盡快告知,我們將會(huì)在第一時(shí)間刪除。文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如需處理請(qǐng)聯(lián)系客服。電話:028-86922220;郵箱:631063699@qq.com。內(nèi)容未經(jīng)允許不得轉(zhuǎn)載,或轉(zhuǎn)載時(shí)需注明來(lái)源: 創(chuàng)新互聯(lián)
猜你還喜歡下面的內(nèi)容
