Add a Word2Vec Model

Change-Id: I6074eea39c01bec380baca8d297911dcc514ba23
This commit is contained in:
Hiroyuki Eguchi 2016-12-13 08:06:33 +00:00
parent f98521a116
commit dbb8429c8a
1 changed file with 41 additions and 1 deletion

View File

@ -43,6 +43,8 @@ from pyspark.mllib.linalg import SparseVector
from pyspark.mllib.classification import LogisticRegressionWithSGD
from pyspark.mllib.classification import LogisticRegressionModel
from pyspark.mllib.clustering import KMeans, KMeansModel
from pyspark.mllib.feature import Word2Vec
from pyspark.mllib.feature import Word2VecModel
from pyspark.mllib.recommendation import ALS, MatrixFactorizationModel, Rating
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.regression import LinearRegressionWithSGD
@ -246,6 +248,41 @@ class DecisionTreeModelController(ModelController):
return model.predict(parsed_params)
class Word2VecModelController(ModelController):
    """Controller that trains, loads and queries an MLlib Word2Vec model."""

    def __init__(self):
        super(Word2VecModelController, self).__init__()

    def create_model(self, data, params):
        """Fit a Word2Vec model on space-separated text rows.

        :param data: dataset exposing ``map`` (presumably a Spark RDD of
            strings -- confirm against the caller); each row is treated as
            one sentence and tokenized on single spaces.
        :param params: mapping of optional hyperparameters:
            ``learningRate`` (default 0.025), ``numIterations``
            (default 10) and ``minCount`` (default 5).
        :returns: whatever ``Word2Vec.fit`` returns for the tokenized data.
        """
        word2vec = Word2Vec()
        word2vec.setLearningRate(params.get('learningRate', 0.025))
        word2vec.setNumIterations(params.get('numIterations', 10))
        word2vec.setMinCount(params.get('minCount', 5))

        # Splitting on " " yields empty strings for leading, trailing or
        # repeated spaces; drop them so "" is never fed to training as a
        # word.
        sentences = data.map(
            lambda row: [token for token in row.split(" ") if token])
        return word2vec.fit(sentences)

    def load_model(self, context, path):
        """Load a previously saved Word2Vec model.

        :param context: passed straight to ``Word2VecModel.load``
            (presumably the SparkContext -- confirm against the caller).
        :param path: location the model was saved to.
        :returns: the result of ``Word2VecModel.load(context, path)``.
        """
        return Word2VecModel.load(context, path)

    def predict(self, model, params):
        """Print the synonyms of a word, one ``word: score`` pair per line.

        :param model: object exposing ``findSynonyms(word, num)``.
        :param params: string holding a Python dict literal; ``word`` is the
            query term and ``num`` (default 2) is the number of synonyms
            requested.
        :returns: None. Results are written to stdout rather than returned.
        """
        dic_params = literal_eval(params)
        keyword = dic_params.get('word')
        num = dic_params.get('num', 2)

        # NOTE(review): the PySpark docs describe the second element of each
        # pair as a cosine similarity, not a distance -- hence the name.
        for word, similarity in model.findSynonyms(keyword, num):
            print("{}: {}".format(word, similarity))
class MeteosSparkController(object):
def init_context(self):
@ -277,6 +314,8 @@ class MeteosSparkController(object):
self.controller = LinearRegressionModelController()
elif model_type == 'DecisionTreeRegression':
self.controller = DecisionTreeModelController()
elif model_type == 'Word2Vec':
self.controller = Word2VecModelController()
def save_data(self, collect=True):
@ -373,7 +412,8 @@ class MeteosSparkController(object):
else:
self.output = self.controller.predict(self.model, params)
print(self.output)
if self.output:
print(self.output)
if __name__ == '__main__':