Giter Club home page Giter Club logo

Comments (2)

Shen-Qiu avatar Shen-Qiu commented on July 30, 2024
# -*- coding: utf-8 -*-
"""
    Plot precision-recall curve on the result of MIR-FLICKR-25K
"""
import numpy as np
import matplotlib.pyplot as plt 

def calc_hammingDist(B1, B2):
    """Return Hamming distances between code(s) B1 and every row of B2.

    For codes in {-1, +1}^q the distance equals (q - <b1, b2>) / 2, where q is
    the code length taken from the second axis of B2.
    """
    code_length = B2.shape[1]
    inner_products = np.dot(B1, B2.T)
    return 0.5 * (code_length - inner_products)


def calc_similarity(label_1, label_2):
    """Return a float32 0/1 array marking label pairs that share any label.

    Entry (i, j) is 1.0 when the inner product of label_1[i] and label_2[j]
    is positive, and 0.0 otherwise.
    """
    overlap = np.dot(label_1, label_2.transpose())
    shares_label = overlap > 0
    return shares_label.astype(np.float32)


def calc_map(qB, rB, query_L, retrieval_L):
    """Compute the mean average precision (mAP) of Hamming-distance ranking.

    For each query, the retrieval set is sorted by ascending Hamming distance
    and the average precision over the relevant items is computed. Queries
    with no relevant item contribute 0 but are still counted in the mean.

    Args:
        qB: query codes, {-1,+1}^{m x q}.
        rB: retrieval codes, {-1,+1}^{n x q}.
        query_L: query labels, {0,1}^{m x l}.
        retrieval_L: retrieval labels, {0,1}^{n x l}.

    Returns:
        The mAP value; 0.0 when there are no queries.
    """
    num_query = query_L.shape[0]
    if num_query == 0:
        # Nothing to average; avoid a ZeroDivisionError.
        return 0.0
    ap_sum = 0.0
    for q_idx in range(num_query):
        # An item is relevant when it shares at least one label with the query.
        gnd = calc_similarity(query_L[q_idx, :], retrieval_L)
        tsum = int(np.sum(gnd))
        if tsum == 0:
            continue
        hamm = calc_hammingDist(qB[q_idx, :], rB)
        gnd = gnd[np.argsort(hamm)]
        # k-th relevant item (k = 1..tsum) divided by its 1-based rank position.
        count = np.arange(1, tsum + 1, dtype=np.float64)
        tindex = np.flatnonzero(gnd == 1) + 1.0
        ap_sum += np.mean(count / tindex)
    return ap_sum / num_query


def cal_Precision_Recall_Curve(qB, rB, query_L, retrieval_L, bits=None):
    """Compute the mean precision-recall curve of Hamming-radius retrieval.

    For every radius r in 0..bits, each query retrieves the items whose
    Hamming distance equals one of 0..r. Precision and recall at r are then
    averaged over queries:

      * Queries with no relevant item are skipped, since their recall is
        undefined (it would divide by zero).
      * Precision at radius r is averaged only over queries that retrieved at
        least one item at r, so empty retrievals do not drag the mean toward 0.
      * A radius at which no query qualifies yields NaN.

    Args:
        qB: query codes (m rows); distances come from calc_hammingDist.
        rB: retrieval codes (n rows).
        query_L: query labels; only the first m rows are used.
        retrieval_L: retrieval labels.
        bits: largest Hamming radius to evaluate; defaults to the code length
            qB.shape[1].

    Returns:
        (recall, precision): float arrays of length bits + 1.
    """
    if bits is None:
        bits = qB.shape[1]
    bits = int(bits)

    num = qB.shape[0]  # the number of input instances
    relevant = calc_similarity(query_L[:num], retrieval_L) == 1
    dist = calc_hammingDist(qB, rB)

    num_relevant = relevant.sum(axis=1)
    keep = num_relevant > 0
    relevant, dist, num_relevant = relevant[keep], dist[keep], num_relevant[keep]

    recall = np.full(bits + 1, np.nan)
    precision = np.full(bits + 1, np.nan)
    retrieved = np.zeros(dist.shape, dtype=bool)
    for radius in range(bits + 1):
        # Grow each query's retrieved set by the items at exactly this distance.
        retrieved |= dist == radius
        num_retrieved = retrieved.sum(axis=1)
        num_hits = (retrieved & relevant).sum(axis=1)
        if num_relevant.size:
            recall[radius] = np.mean(num_hits / num_relevant)
        nonempty = num_retrieved > 0
        if nonempty.any():
            precision[radius] = np.mean(num_hits[nonempty] / num_retrieved[nonempty])
    return recall, precision


# Load saved hash codes, labels and mAP scores; the with-block closes the NpzFile.
with np.load('./result_16bits_VGG19.npz') as result:
    qBX = result['qBX'] # image query
    qBY = result['qBY'] # text query
    rBX = result['rBX'] # image retrieval
    rBY = result['rBY'] # text retrieval
    query_L = result['query_L'] # query label
    retrieval_L = result['retrieval_L'] # retrieval label
    mapi2t = result['mapi2t']
    mapt2i = result['mapt2i']
    # Module-level on purpose: cal_Precision_Recall_Curve may read this global.
    bits = result['bit']

print('mapi2t: {0:.4f}'.format(mapi2t))
print('mapt2i: {0:.4f}'.format(mapt2i))

fig = plt.figure(1)
# One subplot per retrieval direction: (subplot position, title, query codes, retrieval codes).
for position, title, query_codes, retrieval_codes in (
    (121, r'Image->Text', qBX, rBY),
    (122, r'Text->Image', qBY, rBX),
):
    recall, precision = cal_Precision_Recall_Curve(
        query_codes, retrieval_codes, query_L, retrieval_L)
    ax = fig.add_subplot(position)
    ax.scatter(recall, precision)
    ax.plot(recall, precision)
    ax.set(xlim = [0, 1], ylim = [0.5, 1])
    ax.set_title(title)
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
# plt.plot() with no arguments draws nothing; show() actually displays the figure.
plt.show()

from ssah.

anan1030 avatar anan1030 commented on July 30, 2024

Hi, you need to exclude the zero entries in precision (radii at which nothing was retrieved) before calling precision.mean(axis=0). You can use np.average() with weights to compute this as a weighted average.

from ssah.

Related Issues (12)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.