Jun 02, 2020 at 11:33 AM

How does the Basic Training pipeline work?


Hello,

I am trying to use a basic training pipeline, but when I execute it, it keeps running indefinitely without producing any result. I don't see any messages in either the Wiretap or the trace.
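Since nothing shows up in the Wiretap or the trace, one way to check whether the script itself ever reaches its `api.send` calls is to run it outside the pipeline with a stand-in for the `api` object. The `StubApi` class below is hypothetical: it only mimics the two members the script uses (`api.send(port, data)` and `api.logger`) and is not part of any SAP library.

```python
import logging

logging.basicConfig(level=logging.INFO)


# Hypothetical stand-in for the `api` object that the pipeline's Python
# operator provides at runtime. It mimics only api.send(...) and api.logger.
class StubApi:
    def __init__(self):
        self.sent = []  # (port, data) pairs, in the order they were sent
        self.logger = logging.getLogger("stub_api")

    def send(self, port, data):
        # Record and print the message instead of pushing it to an outport
        self.sent.append((port, data))
        print(f"[{port}] {data}")


api = StubApi()

# With `api` defined, the operator script's calls run as plain Python
api.send("logs", "Start training")
api.logger.info("Start training")
api.send("logs", "Training complete")
api.logger.info("Training complete")
```

Pasting the training script after this stub (with TensorFlow installed) would show whether `train()` completes and how long it takes. Assuming the `sapdi` SDK is not available locally, the `sapdi.create_artifact(...)` call would also have to be replaced by a plain local `model_path`.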

The Python code is very simple (it doesn't even read a dataset), so it should finish within a few seconds:

```python
import numpy as np
import os
import tensorflow as tf
import sapdi
from sapdi.artifact.artifact import Artifact, ArtifactKind, ArtifactFileType


# ==== Create an output artifact for the model to be trained ====
# ==== NOTE: the artifact alias must be the same as the outport name ====
out_artifact = sapdi.create_artifact(
    artifact_alias="model",
    file_type=ArtifactFileType.ZIP,
    artifact_kind=ArtifactKind.MODEL,
    description="Best model ever",
    artifact_name="simple_model"
)
model_path = out_artifact.get_path()


# ======= Train a model on synthetic data ========
# `api` is provided by the Python operator at runtime; "logs" is an outport of that operator
def train():
    api.send("logs", "Start training")
    api.logger.info("Start training")

    # Generate synthetic data for y = 2x + 1 (no input artifact is read)
    X = np.arange(-10.0, 10.0, 1e-2)
    np.random.shuffle(X)
    y = 2 * X + 1

    # Split: first 60% train, next 20% validation, last 20% test
    train_end = int(0.6 * len(X))
    test_start = int(0.8 * len(X))

    X_train, y_train = X[:train_end], y[:train_end]
    X_test, y_test = X[test_start:], y[test_start:]
    X_val, y_val = X[train_end:test_start], y[train_end:test_start]

    # Single-unit Dense layer, i.e. a linear model y = w*x + b
    tf.keras.backend.clear_session()
    linear_model = tf.keras.models.Sequential([
        tf.keras.layers.Dense(units=1, input_shape=[1], name='Single')
    ])
    linear_model.compile(optimizer=tf.keras.optimizers.SGD(), loss=tf.keras.losses.mean_squared_error)
    linear_model.summary()

    linear_model.fit(X_train, y_train, validation_data=(X_val, y_val), epochs=20)

    # Save the trained model to the artifact path
    linear_model.save(model_path)

    api.send("logs", "Training complete")
    api.logger.info("Training complete")


train()
```
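To confirm that the computation itself is as light as described, the data generation and the 60/20/20 split from the script can be reproduced in plain NumPy. This sketch swaps the Keras `Dense(1)` model for a closed-form least-squares line fit (`np.polyfit`), which is not what the script does but fits the same line y = 2x + 1; the seeded random generator is also an addition for reproducibility.

```python
import numpy as np

# Same synthetic data as the training script: y = 2x + 1 on [-10, 10)
rng = np.random.default_rng(0)  # seeded here for reproducibility (an addition)
X = np.arange(-10.0, 10.0, 1e-2)
rng.shuffle(X)
y = 2 * X + 1

# Same split: first 60% train, next 20% validation, last 20% test
train_end = int(0.6 * len(X))
test_start = int(0.8 * len(X))
X_train, y_train = X[:train_end], y[:train_end]
X_val, y_val = X[train_end:test_start], y[train_end:test_start]
X_test, y_test = X[test_start:], y[test_start:]

# Closed-form least-squares fit of a degree-1 polynomial in place of the
# Keras model; np.polyfit returns the coefficients as [slope, intercept]
slope, intercept = np.polyfit(X_train, y_train, 1)
val_mse = np.mean((slope * X_val + intercept - y_val) ** 2)
print(f"slope={slope:.4f}, intercept={intercept:.4f}, val_mse={val_mse:.2e}")
```

Because the data is exactly linear, the fit should recover a slope of about 2 and an intercept of about 1 almost instantly.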

Can somebody tell me what is wrong?

Attachments

pipeline.png (23.8 kB)