ML Community Day is November 9! Join us for updates from TensorFlow, JAX, and more Learn more

安裝 TensorFlow Java

TensorFlow Java 可透過任何 JVM 執行,用於建構、訓練及部署機器學習模型。這個程式支援以圖表模式或 Eager 模式執行 CPU 和 GPU,並讓您可透過功能豐富的 API,在 JVM 環境中使用 TensorFlow。世界各地的大、小型企業經常使用 Java 和其他 JVM 語言 (例如 Scala 和 Kotlin),因此如要大規模採用機器學習技術,TensorFlow Java 是相當具有策略優勢的選項。

需求

TensorFlow Java 必須透過 Java 8 以上版本執行,並且針對下列平台提供原生支援:

  • Ubuntu 16.04 以上版本 (64 位元、x86)
  • macOS 10.12.6 (Sierra) 以上版本 (64 位元、x86)
  • Windows 7 以上版本 (64 位元、x86)

版本

TensorFlow Java 有獨立的發布週期,與 TensorFlow 執行階段無關。因此,TensorFlow Java 的版本與用來執行的 TensorFlow 執行階段版本並不相同。請參考 TensorFlow Java 版本表,當中列出了所有的可用版本以及相對應的 TensorFlow 執行階段。

構件

在專案中新增 TensorFlow Java 的方法有好幾種,其中最簡單的方法就是在 tensorflow-core-platform 構件中新增依附元件;該構件包含 TensorFlow Java Core API,以及在所有支援的平台上執行所需的原生依附元件。

您也可以選取下列其中一個延伸模組,而非純 CPU 的版本:

  • tensorflow-core-platform-mkl:在所有平台上支援 Intel® MKL-DNN
  • tensorflow-core-platform-gpu:在 Linux 和 Windows 平台上支援 CUDA®
  • tensorflow-core-platform-mkl-gpu:在 Linux 和 Windows 平台上支援 Intel® MKL-DNN 和 CUDA®

此外,您還可另外新增 tensorflow-framework 程式庫中的依附元件,以便透過多種公用程式,在 JVM 環境中使用以 TensorFlow 為基礎的機器學習技術。

使用 Maven 進行安裝

如要在 Maven 應用程式中加入 TensorFlow,請在專案 pom.xml 檔案的構件中新增依附元件。例如:

<dependency>
  <groupId>org.tensorflow</groupId>
  <artifactId>tensorflow-core-platform</artifactId>
  <version>0.2.0</version>
</dependency>

減少依附元件的數量

請務必注意,如果您在 tensorflow-core-platform 構件中新增依附元件,系統將匯入所有支援平台的原生程式庫,進而大幅增加專案的檔案大小。

如要指定想加入應用程式的特定可用平台,您可以使用 Maven 依附元件排除功能排除其他平台的不必要構件,

或是在 Maven 指令列或 pom.xml 中設定 JavaCPP 系統屬性。詳情請參閱 JavaCPP 說明文件

使用快照

如需 TensorFlow Java 原始碼存放區中最新的 TensorFlow Java 開發快照,請前往 OSS Sonatype Nexus 存放區。使用這些構件前,請務必在 pom.xml 中設定 OSS 快照存放區。

<repositories>
    <repository>
        <id>tensorflow-snapshots</id>
        <url>https://oss.sonatype.org/content/repositories/snapshots/</url>
        <snapshots>
            <enabled>true</enabled>
        </snapshots>
    </repository>
</repositories>

<dependencies>
    <dependency>
        <groupId>org.tensorflow</groupId>
        <artifactId>tensorflow-core-platform</artifactId>
        <version>0.3.0-SNAPSHOT</version>
    </dependency>
</dependencies>

使用 Gradle 進行安裝

如要在 Gradle 應用程式中加入 TensorFlow,請在專案 build.gradle 檔案的構件中新增依附元件。例如:

repositories {
    mavenCentral()
}

dependencies {
    compile group: 'org.tensorflow', name: 'tensorflow-core-platform', version: '0.2.0'
}

減少依附元件的數量

使用 Gradle 時,沒辦法像使用 Maven 時一樣輕鬆排除 TensorFlow Java 原生構件。建議您利用 Gradle JavaCPP 外掛程式減少依附元件的數量。

詳情請參閱 Gradle JavaCPP 說明文件

使用原始碼進行安裝

如要從原始碼建構 TensorFlow Java,並在可能的情況下進行自訂,請參閱這裡的操作說明。

範例程式

本範例說明如何使用 TensorFlow 建構 Apache Maven 專案。首先,請將 TensorFlow 依附元件加入專案的 pom.xml 檔案:

<project>
    <modelVersion>4.0.0</modelVersion>
    <groupId>org.myorg</groupId>
    <artifactId>hellotensorflow</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <exec.mainClass>HelloTensorFlow</exec.mainClass>
    <!-- Minimal version for compiling TensorFlow Java is JDK 8 -->
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
    </properties>

    <dependencies>
    <!-- Include TensorFlow (pure CPU only) for all supported platforms -->
        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>tensorflow-core-platform</artifactId>
            <version>0.2.0</version>
        </dependency>
    </dependencies>
</project>

建立來源檔案 src/main/java/HelloTensorFlow.java

import org.tensorflow.ConcreteFunction;
import org.tensorflow.Signature;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.math.Add;
import org.tensorflow.types.TInt32;

public class HelloTensorFlow {

  public static void main(String[] args) throws Exception {
    System.out.println("Hello TensorFlow " + TensorFlow.version());

    try (ConcreteFunction dbl = ConcreteFunction.create(HelloTensorFlow::dbl);
        Tensor<TInt32> x = TInt32.scalarOf(10);
        Tensor<TInt32> dblX = dbl.call(x).expect(TInt32.DTYPE)) {
      System.out.println(x.data().getInt() + " doubled is " + dblX.data().getInt());
    }
  }

  private static Signature dbl(Ops tf) {
    Placeholder<TInt32> x = tf.placeholder(TInt32.DTYPE);
    Add<TInt32> dblX = tf.math.add(x, x);
    return Signature.builder().input("x", x).output("dbl", dblX).build();
  }
}

編譯並執行以下指令:

mvn -q compile exec:java

這個指令會輸出:TensorFlow version and a simple calculation.

大功告成!TensorFlow Java 現已設定完畢。