Add a Word2Vec Model
Change-Id: I6074eea39c01bec380baca8d297911dcc514ba23
This commit is contained in:
parent
f98521a116
commit
dbb8429c8a
|
@ -43,6 +43,8 @@ from pyspark.mllib.linalg import SparseVector
|
|||
from pyspark.mllib.classification import LogisticRegressionWithSGD
|
||||
from pyspark.mllib.classification import LogisticRegressionModel
|
||||
from pyspark.mllib.clustering import KMeans, KMeansModel
|
||||
from pyspark.mllib.feature import Word2Vec
|
||||
from pyspark.mllib.feature import Word2VecModel
|
||||
from pyspark.mllib.recommendation import ALS, MatrixFactorizationModel, Rating
|
||||
from pyspark.mllib.regression import LabeledPoint
|
||||
from pyspark.mllib.regression import LinearRegressionWithSGD
|
||||
|
@ -246,6 +248,41 @@ class DecisionTreeModelController(ModelController):
|
|||
return model.predict(parsed_params)
|
||||
|
||||
|
||||
class Word2VecModelController(ModelController):
    """Controller that trains, loads and queries a Word2Vec model."""

    def __init__(self):
        """Run the base ModelController initialisation unchanged."""
        super(Word2VecModelController, self).__init__()

    def create_model(self, data, params):
        """Fit a Word2Vec model on ``data`` and return it.

        :param data: object exposing ``map``; every row is split on a
            single space character to form the token sequence that is
            handed to the trainer (presumably an RDD of text lines --
            confirm against the caller).
        :param params: mapping of optional hyper-parameters. Recognised
            keys and their fallbacks: ``learningRate`` (0.025),
            ``numIterations`` (10) and ``minCount`` (5).
        :returns: whatever ``Word2Vec.fit`` returns for the tokenised
            input.
        """
        # Resolve every hyper-parameter first so the fallbacks are
        # visible in one place.
        learning_rate = params.get('learningRate', 0.025)
        iterations = params.get('numIterations', 10)
        min_count = params.get('minCount', 5)

        trainer = Word2Vec()
        trainer.setLearningRate(learning_rate)
        trainer.setNumIterations(iterations)
        trainer.setMinCount(min_count)

        # Splitting on exactly " " (not arbitrary whitespace) is kept
        # on purpose: consecutive spaces yield empty-string tokens.
        sentences = data.map(lambda line: line.split(" "))
        return trainer.fit(sentences)

    def load_model(self, context, path):
        """Return the model that ``Word2VecModel.load`` reads from ``path``."""
        return Word2VecModel.load(context, path)

    def predict(self, model, params):
        """Print the synonyms of a word, one ``word: score`` line each.

        :param model: object exposing ``findSynonyms(word, num)``.
        :param params: string holding a Python literal that evaluates
            to a dict. ``word`` is the term to look up (``None`` is
            passed on when the key is absent) and ``num`` is the number
            of synonyms requested, 2 when omitted.
        :returns: nothing; the results are written with ``print`` only.
        """
        # literal_eval is assumed to be ast.literal_eval imported at
        # module level -- TODO confirm.
        options = literal_eval(params)

        matches = model.findSynonyms(options.get('word'),
                                     options.get('num', 2))

        for synonym, score in matches:
            print("{}: {}".format(synonym, score))
|
||||
|
||||
|
||||
class MeteosSparkController(object):
|
||||
|
||||
def init_context(self):
|
||||
|
@ -277,6 +314,8 @@ class MeteosSparkController(object):
|
|||
self.controller = LinearRegressionModelController()
|
||||
elif model_type == 'DecisionTreeRegression':
|
||||
self.controller = DecisionTreeModelController()
|
||||
elif model_type == 'Word2Vec':
|
||||
self.controller = Word2VecModelController()
|
||||
|
||||
def save_data(self, collect=True):
|
||||
|
||||
|
@ -373,7 +412,8 @@ class MeteosSparkController(object):
|
|||
else:
|
||||
self.output = self.controller.predict(self.model, params)
|
||||
|
||||
print(self.output)
|
||||
if self.output:
|
||||
print(self.output)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue