Commit

init

Weichen Shen committed Apr 6, 2020
1 parent e86fd66 commit 54a266e
Showing 45 changed files with 2,014 additions and 0 deletions.
67 changes: 67 additions & 0 deletions README.md
@@ -0,0 +1,67 @@
# DeepMatch

[![Python Versions](https://img.shields.io/pypi/pyversions/deepmatch.svg)](https://pypi.org/project/deepmatch)
[![TensorFlow Versions](https://img.shields.io/badge/TensorFlow-1.4+/2.0+-blue.svg)](https://pypi.org/project/deepmatch)
[![PyPI Version](https://img.shields.io/pypi/v/deepmatch.svg)](https://pypi.org/project/deepmatch)
[![GitHub Issues](https://img.shields.io/github/issues/shenweichen/deepmatch.svg)](https://github.com/shenweichen/deepmatch/issues)
<!-- [![Activity](https://img.shields.io/github/last-commit/shenweichen/deepmatch.svg)](https://github.com/shenweichen/DeepMatch/commits/master) -->


[![Documentation Status](https://readthedocs.org/projects/deepctrmatch/badge/?version=latest)](https://deepctrmatch.readthedocs.io/)
[![Discussion](https://img.shields.io/badge/chat-wechat-brightgreen?style=flat)](./README.md#discussiongroup)
[![License](https://img.shields.io/github/license/shenweichen/deepmatch.svg)](https://github.com/shenweichen/deepmatch/blob/master/LICENSE)
<!-- [![Gitter](https://badges.gitter.im/DeepCTR/community.svg)](https://gitter.im/DeepCTR/community?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) -->

DeepMatch is an **easy-to-use** deep matching model library for recommendations, advertising, and search. It makes it easy to train models and to **export representation vectors** for users and items, which can then be used for **ANN search**. You can use any complex model with `model.fit()` and `model.predict()`.

Let's [**Get Started!**](https://deepmatch.readthedocs.io/en/latest/Quick-Start.html)
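For example, a typical end-to-end run looks like this — a minimal sketch, assuming the feature columns and model inputs (`train_model_input`, `train_label`, `test_user_model_input`, `all_item_model_input`) are prepared as in the Quick-Start guide:

```python
import tensorflow as tf
from tensorflow.python.keras.models import Model
from deepmatch.models import YoutubeDNN

model = YoutubeDNN(user_feature_columns, item_feature_columns, num_sampled=5,
                   user_dnn_hidden_units=(64, 16))
# The model output is already the per-example sampled-softmax loss,
# so compile with a pass-through loss that ignores y_true.
model.compile(optimizer="adam", loss=lambda y_true, y_pred: tf.reduce_mean(y_pred))
model.fit(train_model_input, train_label, batch_size=256, epochs=1)

# Export representation vectors for users and items, e.g. for ANN search.
user_embedding_model = Model(inputs=model.user_input, outputs=model.user_embedding)
item_embedding_model = Model(inputs=model.item_input, outputs=model.item_embedding)
user_embs = user_embedding_model.predict(test_user_model_input, batch_size=2 ** 12)
item_embs = item_embedding_model.predict(all_item_model_input, batch_size=2 ** 12)
```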


## Models List

| Model | Paper |
| :------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| FM | [ICDM 2010][Factorization Machines](https://www.researchgate.net/publication/220766482_Factorization_Machines) |
| DSSM | [CIKM 2013][Deep Structured Semantic Models for Web Search using Clickthrough Data](https://www.microsoft.com/en-us/research/publication/learning-deep-structured-semantic-models-for-web-search-using-clickthrough-data/) |
| YoutubeDNN | [RecSys 2016][Deep Neural Networks for YouTube Recommendations](https://www.researchgate.net/publication/307573656_Deep_Neural_Networks_for_YouTube_Recommendations) |
| NCF | [WWW 2017][Neural Collaborative Filtering](https://arxiv.org/abs/1708.05031) |
| MIND | [CIKM 2019][Multi-interest network with dynamic routing for recommendation at Tmall](https://arxiv.org/pdf/1904.08030) |

## Contributors ([welcome to join us!](./CONTRIBUTING.md))

<table border="0">
  <tbody>
    <tr align="center">
      <td>
        <a href="https://github.com/shenweichen"><img width="70" height="70" src="https://github.com/shenweichen.png?s=40" alt="pic"></a><br>
        <a href="https://github.com/shenweichen">Shen Weichen</a>
        <p>Alibaba Group</p>
      </td>
      <td>
        <a href="https://github.com/wangzhegeek"><img width="70" height="70" src="https://github.com/wangzhegeek.png?s=40" alt="pic"></a><br>
        <a href="https://github.com/wangzhegeek">Wang Zhe</a>
        <p>Jingdong Group</p>
      </td>
      <td>
        <a href="https://github.com/LeoCai"><img width="70" height="70" src="https://github.com/LeoCai.png?s=40" alt="pic"></a><br>
        <a href="https://github.com/LeoCai">LeoCai</a>
        <p>ByteDance</p>
      </td>
      <td>
        <a href="https://github.com/yangjieyu"><img width="70" height="70" src="https://github.com/yangjieyu.png?s=40" alt="pic"></a><br>
        <a href="https://github.com/yangjieyu">Yang Jieyu</a>
        <p>Zhejiang University</p>
      </td>
    </tr>
  </tbody>
</table>

## DiscussionGroup

Please follow our WeChat official account to join the discussion group:
- WeChat official account: **浅梦的学习笔记**
- WeChat ID: **deepctrbot**

![wechat](./docs/pics/weichennote.png)
4 changes: 4 additions & 0 deletions deepmatch/__init__.py
@@ -0,0 +1,4 @@
from .utils import check_version

__version__ = '0.0.0'
check_version(__version__)
25 changes: 25 additions & 0 deletions deepmatch/inputs.py
@@ -0,0 +1,25 @@
from itertools import chain

from deepctr.inputs import (SparseFeat, VarLenSparseFeat, create_embedding_matrix,
                            embedding_lookup, get_dense_input, varlen_embedding_lookup,
                            get_varlen_pooling_list, mergeDict)


def input_from_feature_columns(features, feature_columns, l2_reg, init_std, seed, prefix='', seq_mask_zero=True,
                               support_dense=True, support_group=False, embedding_matrix_dict=None):
sparse_feature_columns = list(
filter(lambda x: isinstance(x, SparseFeat), feature_columns)) if feature_columns else []
varlen_sparse_feature_columns = list(
filter(lambda x: isinstance(x, VarLenSparseFeat), feature_columns)) if feature_columns else []
if embedding_matrix_dict is None:
embedding_matrix_dict = create_embedding_matrix(feature_columns, l2_reg, init_std, seed, prefix=prefix,
seq_mask_zero=seq_mask_zero)

group_sparse_embedding_dict = embedding_lookup(embedding_matrix_dict, features, sparse_feature_columns)
dense_value_list = get_dense_input(features, feature_columns)
if not support_dense and len(dense_value_list) > 0:
raise ValueError("DenseFeat is not supported in dnn_feature_columns")

sequence_embed_dict = varlen_embedding_lookup(embedding_matrix_dict, features, varlen_sparse_feature_columns)
group_varlen_sparse_embedding_dict = get_varlen_pooling_list(sequence_embed_dict, features,
varlen_sparse_feature_columns)
group_embedding_dict = mergeDict(group_sparse_embedding_dict, group_varlen_sparse_embedding_dict)
if not support_group:
group_embedding_dict = list(chain.from_iterable(group_embedding_dict.values()))
return group_embedding_dict, dense_value_list
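
This wrapper mirrors deepctr's `input_from_feature_columns` but also accepts a pre-built `embedding_matrix_dict`, so the user tower and the item tower can share one set of embedding tables. A minimal usage sketch (the feature columns below are illustrative, not part of this file):

from deepctr.inputs import SparseFeat, build_input_features, create_embedding_matrix

user_cols = [SparseFeat('user_id', vocabulary_size=100, embedding_dim=8)]
item_cols = [SparseFeat('item_id', vocabulary_size=200, embedding_dim=8)]

# Build one shared table dict for both towers, then look up each side from it.
emb_dict = create_embedding_matrix(user_cols + item_cols, 1e-6, 0.0001, 1024)
user_features = build_input_features(user_cols)
user_embs, _ = input_from_feature_columns(user_features, user_cols, 1e-6, 0.0001, 1024,
                                          embedding_matrix_dict=emb_dict)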
Empty file added deepmatch/layers/__init__.py
Empty file.
125 changes: 125 additions & 0 deletions deepmatch/layers/core.py
@@ -0,0 +1,125 @@
import tensorflow as tf
from deepctr.layers.utils import reduce_max, reduce_mean, reduce_sum, concat_func, div, softmax
from tensorflow.python.keras.initializers import RandomNormal, Zeros
from tensorflow.python.keras.layers import Layer


class PoolingLayer(Layer):

def __init__(self, mode='mean', supports_masking=False, **kwargs):

        if mode not in ['sum', 'mean', 'max']:
            raise ValueError("mode must be sum, mean or max")
self.mode = mode
self.eps = tf.constant(1e-8, tf.float32)
super(PoolingLayer, self).__init__(**kwargs)

self.supports_masking = supports_masking

def build(self, input_shape):

super(PoolingLayer, self).build(
input_shape) # Be sure to call this somewhere!

def call(self, seq_value_len_list, mask=None, **kwargs):
if not isinstance(seq_value_len_list, list):
seq_value_len_list = [seq_value_len_list]
if len(seq_value_len_list) == 1:
return seq_value_len_list[0]
expand_seq_value_len_list = list(map(lambda x: tf.expand_dims(x, axis=-1), seq_value_len_list))
a = concat_func(expand_seq_value_len_list)
if self.mode == "mean":
hist = reduce_mean(a, axis=-1, )
if self.mode == "sum":
hist = reduce_sum(a, axis=-1, )
if self.mode == "max":
hist = reduce_max(a, axis=-1, )
return hist


class SampledSoftmaxLayer(Layer):
def __init__(self, item_embedding, num_sampled=5, **kwargs):
self.num_sampled = num_sampled
self.target_song_size = item_embedding.input_dim
self.item_embedding = item_embedding
super(SampledSoftmaxLayer, self).__init__(**kwargs)

def build(self, input_shape):
        self.zero_bias = self.add_weight(shape=[self.target_song_size],
                                         initializer=Zeros(),
                                         dtype=tf.float32,
                                         trainable=False,
                                         name="bias")
if not self.item_embedding.built:
self.item_embedding.build([])
self.trainable_weights.append(self.item_embedding.embeddings)
super(SampledSoftmaxLayer, self).build(input_shape)

def call(self, inputs_with_label_idx, training=None, **kwargs):
"""
The first input should be the model as it were, and the second the
target (i.e., a repeat of the training data) to compute the labels
argument
"""
inputs, label_idx = inputs_with_label_idx

loss = tf.nn.sampled_softmax_loss(weights=self.item_embedding.embeddings,
biases=self.zero_bias,
labels=label_idx,
inputs=inputs,
num_sampled=self.num_sampled,
num_classes=self.target_song_size
)
return tf.expand_dims(loss, axis=1)

def compute_output_shape(self, input_shape):
return (None, 1)

def get_config(self, ):
config = {'item_embedding': self.item_embedding, 'num_sampled': self.num_sampled}
base_config = super(SampledSoftmaxLayer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))


class LabelAwareAttention(Layer):
def __init__(self, k_max, pow_p=1, **kwargs):
self.k_max = k_max
self.pow_p = pow_p
super(LabelAwareAttention, self).__init__(**kwargs)

def build(self, input_shape):
# Be sure to call this somewhere!

self.embedding_size = input_shape[0][-1]
super(LabelAwareAttention, self).build(input_shape)

def call(self, inputs, training=None, **kwargs):
keys = inputs[0]
query = inputs[1]
        weight = reduce_sum(keys * query, axis=-1, keep_dims=True)  # deepctr wrapper, works across TF 1.x/2.x
weight = tf.pow(weight, self.pow_p) # [x,k_max,1]

        if len(inputs) == 3:
            # Adaptive number of interests, as in the MIND paper:
            # k_user = clip(log2(1 + hist_len), 1, k_max)
            k_user = tf.cast(tf.maximum(
                1.,
                tf.minimum(
                    tf.cast(self.k_max, dtype="float32"),  # k_max
                    tf.math.log1p(tf.cast(inputs[2], dtype="float32")) / tf.math.log(2.)  # log2(1 + hist_len)
                )
            ), dtype="int64")
seq_mask = tf.transpose(tf.sequence_mask(k_user, self.k_max), [0, 2, 1])
padding = tf.ones_like(seq_mask, dtype=tf.float32) * (-2 ** 32 + 1) # [x,k_max,1]
weight = tf.where(seq_mask, weight, padding)

weight = softmax(weight, dim=1, name="weight")
output = tf.reduce_sum(keys * weight, axis=1)

return output

def compute_output_shape(self, input_shape):
return (None, self.embedding_size)

def get_config(self, ):
config = {'k_max': self.k_max, 'pow_p': self.pow_p}
base_config = super(LabelAwareAttention, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
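
For intuition, the adaptive interest count masked above follows the MIND paper's rule K_u = max(1, min(K_max, log2(1 + hist_len))). A small standalone numpy sketch of that rule (illustrative only):

import numpy as np

def active_interests(hist_len, k_max):
    # Number of interest capsules kept, mirroring LabelAwareAttention.call
    return np.maximum(1., np.minimum(float(k_max), np.log2(1. + hist_len))).astype('int64')

print(active_interests(np.array([0, 1, 3, 100]), k_max=4))  # -> [1 1 2 4]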
5 changes: 5 additions & 0 deletions deepmatch/models/__init__.py
@@ -0,0 +1,5 @@
from .fm import FM
from .dssm import DSSM
from .youtubednn import YoutubeDNN
from .ncf import NCF
from .mind import MIND
62 changes: 62 additions & 0 deletions deepmatch/models/fm.py
@@ -0,0 +1,62 @@
from deepctr.inputs import build_input_features
from deepctr.layers.core import PredictionLayer
from deepctr.layers.utils import concat_func, reduce_sum
from tensorflow.python.keras.layers import Lambda
from tensorflow.python.keras.models import Model

from ..inputs import create_embedding_matrix, input_from_feature_columns
from ..layers.core import Similarity


def FM(user_feature_columns, item_feature_columns, l2_reg_embedding=1e-6, init_std=0.0001, seed=1024, metric='cos'):
"""Instantiates the FM architecture.
:param user_feature_columns: An iterable containing user's features used by the model.
:param item_feature_columns: An iterable containing item's features used by the model.
    :param l2_reg_embedding: float. L2 regularizer strength applied to the embedding vectors
    :param init_std: float, the standard deviation used to initialize the embedding vectors
    :param seed: integer, to use as a random seed.
:param metric: str, ``"cos"`` for cosine or ``"ip"`` for inner product
:return: A Keras model instance.
"""

embedding_matrix_dict = create_embedding_matrix(user_feature_columns + item_feature_columns, l2_reg_embedding,
init_std, seed,
seq_mask_zero=True)

user_features = build_input_features(user_feature_columns)
user_inputs_list = list(user_features.values())
user_sparse_embedding_list, user_dense_value_list = input_from_feature_columns(user_features,
user_feature_columns,
l2_reg_embedding, init_std, seed,
support_dense=False,
embedding_matrix_dict=embedding_matrix_dict)

item_features = build_input_features(item_feature_columns)
item_inputs_list = list(item_features.values())
item_sparse_embedding_list, item_dense_value_list = input_from_feature_columns(item_features,
item_feature_columns,
l2_reg_embedding, init_std, seed,
support_dense=False,
embedding_matrix_dict=embedding_matrix_dict)

user_dnn_input = concat_func(user_sparse_embedding_list, axis=1)
user_vector_sum = Lambda(lambda x: reduce_sum(x, axis=1, keep_dims=False))(user_dnn_input)

item_dnn_input = concat_func(item_sparse_embedding_list, axis=1)
item_vector_sum = Lambda(lambda x: reduce_sum(x, axis=1, keep_dims=False))(item_dnn_input)

score = Similarity(type=metric)([user_vector_sum, item_vector_sum])

output = PredictionLayer("binary", False)(score)

model = Model(inputs=user_inputs_list + item_inputs_list, outputs=output)

model.__setattr__("user_input", user_inputs_list)
model.__setattr__("user_embedding", user_vector_sum)

model.__setattr__("item_input", item_inputs_list)
model.__setattr__("item_embedding", item_vector_sum)

return model
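
A minimal training sketch for the FM two-tower model — a hedged example where the feature columns and input dicts are stand-ins prepared as in the Quick-Start:

from tensorflow.python.keras.models import Model

model = FM(user_feature_columns, item_feature_columns, metric='cos')
model.compile(optimizer='adam', loss='binary_crossentropy')
model.fit(train_model_input, train_label, batch_size=256, epochs=1)

# The attributes attached above expose each tower for vector export.
user_tower = Model(inputs=model.user_input, outputs=model.user_embedding)
item_tower = Model(inputs=model.item_input, outputs=model.item_embedding)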
70 changes: 70 additions & 0 deletions deepmatch/models/youtubednn.py
@@ -0,0 +1,70 @@
"""
Author:
Weichen Shen, [email protected]
Reference:
Deep Neural Networks for YouTube Recommendations
"""

from deepctr.inputs import build_input_features, combined_dnn_input, create_embedding_matrix
from deepctr.layers.core import DNN
from tensorflow.python.keras.models import Model

from deepmatch.utils import get_item_embedding
from ..inputs import input_from_feature_columns
from ..layers.core import SampledSoftmaxLayer


def YoutubeDNN(user_feature_columns, item_feature_columns, num_sampled=5,
user_dnn_hidden_units=(64, 16),
dnn_activation='relu', dnn_use_bn=False,
l2_reg_dnn=0, l2_reg_embedding=1e-6, dnn_dropout=0, init_std=0.0001, seed=1024, ):
"""Instantiates the YoutubeDNN Model architecture.
:param user_feature_columns: An iterable containing user's features used by the model.
:param item_feature_columns: An iterable containing item's features used by the model.
:param num_sampled: int, the number of classes to randomly sample per batch.
    :param user_dnn_hidden_units: list of positive integers or an empty list, the layer number and units in each layer of the user tower
    :param dnn_activation: Activation function to use in the deep net
    :param dnn_use_bn: bool. Whether to use BatchNormalization before activation in the deep net
    :param l2_reg_dnn: float. L2 regularizer strength applied to DNN
    :param l2_reg_embedding: float. L2 regularizer strength applied to the embedding vectors
    :param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate.
    :param init_std: float, the standard deviation used to initialize the embedding vectors
    :param seed: integer, to use as a random seed.
:return: A Keras model instance.
"""

    if len(item_feature_columns) > 1:
        raise ValueError("Currently, YoutubeDNN only supports a single item feature such as item_id")
item_feature_name = item_feature_columns[0].name

embedding_matrix_dict = create_embedding_matrix(user_feature_columns + item_feature_columns, l2_reg_embedding,
init_std, seed, prefix="")

user_features = build_input_features(user_feature_columns)
user_inputs_list = list(user_features.values())
user_sparse_embedding_list, user_dense_value_list = input_from_feature_columns(user_features,
user_feature_columns,
l2_reg_embedding, init_std, seed,
embedding_matrix_dict=embedding_matrix_dict)
user_dnn_input = combined_dnn_input(user_sparse_embedding_list, user_dense_value_list)

item_features = build_input_features(item_feature_columns)
item_inputs_list = list(item_features.values())
user_dnn_out = DNN(user_dnn_hidden_units, dnn_activation, l2_reg_dnn, dnn_dropout,
dnn_use_bn, seed, )(user_dnn_input)

item_embedding = embedding_matrix_dict[item_feature_name]

output = SampledSoftmaxLayer(item_embedding, num_sampled=num_sampled)(
inputs=(user_dnn_out, item_features[item_feature_name]))
model = Model(inputs=user_inputs_list + item_inputs_list, outputs=output)

model.__setattr__("user_input", user_inputs_list)
model.__setattr__("user_embedding", user_dnn_out)

model.__setattr__("item_input", item_inputs_list)
model.__setattr__("item_embedding", get_item_embedding(item_embedding, item_features[item_feature_name]))

return model
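
Since `model.item_embedding` here is `get_item_embedding(...)`, i.e. an embedding lookup on the single item-id input, predicting over every item id yields the full item matrix for ANN indexing. A hedged sketch (`item_vocab_size` is an assumed stand-in):

import numpy as np
from tensorflow.python.keras.models import Model

item_embedding_model = Model(inputs=model.item_input, outputs=model.item_embedding)
# Look up a vector for every item id (id 0 is often reserved for padding).
all_item_ids = {item_feature_columns[0].name: np.arange(1, item_vocab_size)}
item_embs = item_embedding_model.predict(all_item_ids, batch_size=2 ** 12)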