Efficient implementation of factorization machine with matrix operations?

Link: https://www.csie.ntu.edu.tw/~r01922136/slides/ffm.pdf (slides 5-6)
Given the following matrices:
X: n x d
W: d x k
Is there an efficient way to calculate the n x 1 matrix using only matrix operations (e.g. NumPy, TensorFlow), where the jth element is:
EDIT:
My current attempt is below, but it is not very space efficient, since it requires storing arrays of size n*d*d:
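import numpy as np

n = 1000
d = 256
k = 32
x = np.random.normal(size=[n,d])
w = np.random.normal(size=[d,k])
# batch of outer products x_i x_i^T, shape (n, d, d): this is what uses the memory
xxt = np.matmul(x.reshape([n,d,1]),x.reshape([n,1,d]))
# W W^T, shape (1, d, d); w.T is needed, since reshaping w to (k, d) is not a transpose
wwt = np.matmul(w.reshape([1,d,k]),w.T.reshape([1,k,d]))
output = xxt*wwt
output = np.sum(output,(1,2))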
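For reference, with wwt taken as W W^T (w_j being row j of W, and x_i row i of X written as a column vector), the attempt above computes for each row a double sum that collapses to the squared norm of a length-k vector. The second line is the standard rewrite for the case where only pairs j < j' are wanted, as in the usual FM pairwise term; whether the slides' formula includes the diagonal j = j' is an assumption either way.

```latex
\hat{y}_i = \sum_{j=1}^{d}\sum_{j'=1}^{d} x_{ij}\,x_{ij'}\,\langle \mathbf{w}_j,\mathbf{w}_{j'}\rangle
          = \mathbf{x}_i^{\top} W W^{\top} \mathbf{x}_i
          = \bigl\lVert W^{\top}\mathbf{x}_i \bigr\rVert_2^{2}
          = \sum_{f=1}^{k}\Bigl(\sum_{j=1}^{d} x_{ij}\,w_{jf}\Bigr)^{2}

\sum_{j<j'} x_{ij}\,x_{ij'}\,\langle \mathbf{w}_j,\mathbf{w}_{j'}\rangle
          = \tfrac{1}{2}\Bigl(\bigl\lVert W^{\top}\mathbf{x}_i \bigr\rVert_2^{2}
            - \sum_{j=1}^{d} x_{ij}^{2}\,\lVert \mathbf{w}_j \rVert_2^{2}\Bigr)
```

Both forms only need the n x k product XW, i.e. O(ndk) work and O(nk) extra memory instead of O(nd^2).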
idownvotedbecau.se/noattempt
– desertnaut
Jul 2 at 19:59
Sorry about that, just added code
– Pian Pawakapan
Jul 2 at 20:56
1 Answer
Not all algorithms are easy or obvious to vectorize. The np.sum(xxt*wwt) can be rewritten using np.einsum. This should be faster than your solution, but it has other limitations (e.g. no multithreading).
I would therefore suggest using a compiler like Numba.
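A minimal sketch of the einsum rewrite, assuming wwt is meant to be W W^T (i.e. w @ w.T); the function name factorization_einsum is hypothetical. With the default optimize=False, np.einsum evaluates the three-operand contraction in a single loop, so no (n, d, d) array is built, but it still does O(n d^2) work on one thread.

```python
import numpy as np


def factorization_einsum(w, x):
    """Row-wise sum_j sum_jj x[i, j] * x[i, jj] * (w @ w.T)[j, jj] via np.einsum.

    Avoids materialising the (n, d, d) array of outer products.
    """
    wwt = w @ w.T  # (d, d): wwt[j, jj] is the dot product of rows j and jj of w
    # 'ij,ik,jk->i': for each row i, contract both feature indices against wwt
    return np.einsum('ij,ik,jk->i', x, x, wwt)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.normal(size=(50, 16))
    w = rng.normal(size=(16, 4))
    # brute-force reference with explicit Python loops
    wwt = w @ w.T
    ref = np.array([sum(x[i, j] * x[i, jj] * wwt[j, jj]
                        for j in range(16) for jj in range(16))
                    for i in range(50)])
    print(np.allclose(factorization_einsum(w, x), ref))  # True
```

Passing optimize=True lets NumPy pick a contraction order, which may be faster at the cost of an (n, d) intermediate.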
Example
import numpy as np
import numba as nb
import time

@nb.njit(fastmath=True, parallel=True)
def factorization_nb(w, x):
    n = x.shape[0]
    d = x.shape[1]
    output = np.empty(n, dtype=w.dtype)
    # W W^T (d x d): wwt[j, jj] is the dot product of rows j and jj of w
    wwt = np.dot(w, np.ascontiguousarray(w.T))
    # rows are independent, so prange distributes them across threads
    for i in nb.prange(n):
        acc = 0.
        for j in range(d):
            for jj in range(d):
                acc += x[i, j] * x[i, jj] * wwt[j, jj]
        output[i] = acc
    return output

def factorization_orig(w, x):
    n = x.shape[0]
    d = x.shape[1]
    k = w.shape[1]
    # (n, d, d) batch of outer products: the memory-hungry part
    xxt = np.matmul(x.reshape([n, d, 1]), x.reshape([n, 1, d]))
    wwt = np.matmul(w.reshape([1, d, k]), w.T.reshape([1, k, d]))
    output = xxt * wwt
    output = np.sum(output, (1, 2))
    return output
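Because wwt is W W^T, the double sum for row i equals the squared norm of the length-k vector x_i W, so the same result can also be obtained with a single (n, d) x (d, k) matrix product, using O(n d k) work and no (n, d, d) or (d, d) array. This is a sketch of an alternative to the Numba loop, not part of the approach above; the function names are hypothetical, and factorization_fm_pairs assumes the intended FM term sums only over pairs j < jj (the loop above also includes j == jj).

```python
import numpy as np


def factorization_matmul(w, x):
    """sum_j sum_jj x[i, j] * x[i, jj] * (w @ w.T)[j, jj] for every row i.

    Uses x_i^T (W W^T) x_i = ||x_i W||^2, so only an (n, k) intermediate is needed.
    """
    xw = x @ w                      # (n, k)
    return np.sum(xw * xw, axis=1)  # (n,)


def factorization_fm_pairs(w, x):
    """Pairwise FM term: sum over j < jj only (diagonal j == jj excluded).

    Identity: sum_{j<jj} = 0.5 * (||x_i W||^2 - sum_j x[i, j]^2 * ||w_j||^2).
    """
    xw = x @ w
    diag = (x * x) @ np.sum(w * w, axis=1)  # (n,): sum_j x[i, j]^2 * ||w_j||^2
    return 0.5 * (np.sum(xw * xw, axis=1) - diag)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.normal(size=(20, 8))
    w = rng.normal(size=(8, 3))
    terms = (x[:, :, None] * x[:, None, :]) * (w @ w.T)  # (n, d, d) reference
    full = np.sum(terms, axis=(1, 2))
    pairs = np.sum(np.triu(terms, k=1), axis=(1, 2))     # strictly upper: j < jj
    print(np.allclose(factorization_matmul(w, x), full))     # True
    print(np.allclose(factorization_fm_pairs(w, x), pairs))  # True
```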
Measuring performance
n = 1000
d = 256
k = 32
x = np.random.normal(size=[n,d])
w = np.random.normal(size=[d,k])

# the first call includes JIT compilation overhead, so keep it out of the timing
res_1 = factorization_nb(w, x)

t1 = time.time()
for i in range(100):
    res_1 = factorization_nb(w, x)
    #res_2 = factorization_orig(w, x)
print(time.time() - t1)
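The same measurement can be written with the standard-library timeit module, which handles the repetition and uses time.perf_counter by default. In this sketch the benchmark helper is hypothetical, factorization_orig is restated with the w.T fix so the block runs without Numba, and n defaults to 200 so the (n, d, d) intermediates stay small. Where Numba is installed, factorization_nb could be added to funcs; the warm-up call absorbs its compilation time.

```python
import timeit

import numpy as np


def factorization_orig(w, x):
    # the question's approach (with the w.T fix): builds (n, d, d) intermediates
    n, d = x.shape
    k = w.shape[1]
    xxt = np.matmul(x.reshape([n, d, 1]), x.reshape([n, 1, d]))
    wwt = np.matmul(w.reshape([1, d, k]), w.T.reshape([1, k, d]))
    return np.sum(xxt * wwt, (1, 2))


def benchmark(funcs, n=200, d=256, k=32, number=5, repeat=3):
    """Return {name: best milliseconds per call} for each callable f(w, x) in funcs."""
    rng = np.random.default_rng(0)
    x = rng.normal(size=(n, d))
    w = rng.normal(size=(d, k))
    results = {}
    for name, f in funcs.items():
        f(w, x)  # warm-up call: for a Numba function this triggers JIT compilation
        times = timeit.repeat(lambda: f(w, x), number=number, repeat=repeat)
        results[name] = 1000 * min(times) / number
    return results


if __name__ == "__main__":
    funcs = {
        "orig": factorization_orig,
        # single (n, d) x (d, k) product: ||x_i W||^2 per row
        "matmul": lambda w, x: np.sum((x @ w) ** 2, axis=1),
    }
    for name, ms in benchmark(funcs).items():
        print(f"{name}: {ms:.2f} ms per call")
```

Taking the minimum over repeats is a common choice because it is the run least disturbed by other processes.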
Timings
factorization_nb: 4.2 ms per iteration
factorization_orig: 460 ms per iteration (110x speedup)
Can you please add the code that you have tried
– Mohammed Kashif
Jul 2 at 19:32