JAX (programmalıq támiynat)

JAX
Baǵdarlamashı(lar) Google, Nvidia[1]
Aldın ala kóriw relizi
v0.4.31 / 30-iyul 2024-jıl
Repozitoriy GitHub-taǵı jax
Jazılǵan tili Python, C++
Operaciyalıq sistema Linux, macOS, Windows
Platforma Python, NumPy
Kólemi 9.0 MB
Túri Mashinalıq oqıtıw
Licenziya Apache 2.0
Veb-sayt jax.readthedocs.io/en/latest/ Edit this on Wikidata

JAX – bul akseleratorǵa-baǵdarlanǵan massiv esaplawları hám programmanı túrlendiriw ushın Python kitapxanası bolıp, ol joqarı ónimli sanlı esaplawlar hám úlken kólemli mashinalıq oqıtıw ushın arnalǵan. Ol Google tárepinen Nvidia hám basqa jámiyetlik úles qosıwshılardıń qatnasıwında islep shıǵılǵan.[2][3]

Ol autogradtıń (funkciyanı differenciallaw arqalı gradient funkciyasın avtomat túrde alıw) modifikaciyalanǵan versiyası menen OpenXLAnıń XLAsın (Tezletilgen Sızıqlı Algebra) birlestiriwshi sıpatında súwretlenedi. Ol NumPy strukturası hám jumıs procesine maksimal dárejede jaqınnan gózler ushın arnalǵan hám TensorFlow hám PyTorch sıyaqlı hár qıylı ámeldegi freymvorklar menen isleydi.[4][5] JAXtıń tiykarǵı ózgeshelikleri:[6]

  1. Jergilikli yamasa tarqatılǵan ortalıqlarda CPU, GPU yamasa TPUde orınlanatuǵın esaplawlar ushın birlesken NumPyǵa uqsas interfeysti usınıw.
  2. Open XLA, yaǵnıy ashıq kodlı mashinalıq oqıtıw kompilyatorlar ekosisteması arqalı ornatılǵan Just-In-Time (JIT) kompilyaciyası.
  3. Onıń avtomat differenciallaw túrlendiriwleri arqalı gradientlerdi nátiyjeli bahalaw.
  4. Kiris toparların kórsetetuǵın massivlerge nátiyjeli sáykeslendiriw ushın avtomat túrde vektorlastırılǵan.

grad

Tómendegi kod grad funkciyasınıń avtomat differenciallawın kórsetedi.

# importlar
from jax import grad
import jax.numpy as jnp

# logistikalıq funkciyanı anıqlaw
def logistic(x):  
    return jnp.exp(x) / (jnp.exp(x) + 1)

# logistikalıq funkciyanıń gradient funkciyasın alıw
grad_logistic = grad(logistic)

# x = 1 bolǵanda logistikalıq funkciyanıń gradientin bahalaw
grad_log_out = grad_logistic(1.0)  
print(grad_log_out)

Sońǵı qatar shıǵarıwı kerek:

0.19661194

jit

Tómendegi kod jit funkciyasınıń kombinaciyalaw (fusion) arqalı optimizaciyasın kórsetedi.

# importlar
from jax import jit
import jax.numpy as jnp

# kub funkciyasın anıqlaw
def cube(x):
    return x * x * x

# maǵlıwmatlardı jaratıw
x = jnp.ones((10000, 10000))

# kub funkciyasınıń jit versiyasın jaratıw
jit_cube = jit(cube)

# tezlik boyınsha salıstırıw ushın cube hám jit_cube funkciyaların birdey maǵlıwmatlarǵa qolllanıw
cube(x)
jit_cube(x)

jit_cube (17-qatar) ushın esaplaw waqtı cube (16-qatar) ushın esaplaw waqtınan aytarlıqtay dárejede qısqa bolıwı kerek. 7-qatardaǵı mánislerdi arttırıw bul ayırmashılıqtı jáne de kúsheytedi.

vmap

Tómendegi kod vmap funkciyasınıń vektorlastırıwın kórsetedi.

# importlar
from jax import vmap, partial # Esletpe: dáslepki tekstte "vmap partial" dep berilgen, "vmap, partial" dep dúzetildi
import jax.numpy as jnp
# import numpy as np # Bul qatar baslanǵısh tekstte joq, biraq tómende np qollanılǵan

# funkciyanı anıqlaw
def grads(self, inputs):
    in_grad_partial = jax.partial(self._net_grads, self._net_params)
    grad_vmap = jax.vmap(in_grad_partial)
    rich_grads = grad_vmap(inputs)
    # Esletpe: Tómendegi qatarda "np.asarray" qollanılǵan. Eger tek jax.numpy importlanǵan bolsa, bul jnp.asarray bolıwı múmkin.
    # Biraq, baslanǵısh tekstte "np" ushın import kórsetilmegen. Tekst sol turısında qaldırıldı.
    flat_grads = np.asarray(self._flatten_batch(rich_grads))
    assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
    return flat_grads

Bul bólimniń oń tárepindegi GIF vektorlastırılǵan qosıw túsinigin kórsetedi.

Vektorlastırılǵan qosıwdıń illyustraciyalıq videosı

pmap

Tómendegi kod pmap funkciyasınıń matricanı kóbeytiw ushın parallellestiriwin kórsetedi.

# JAXtan pmap hám randomdı importlaw; JAX NumPydı importlaw
from jax import pmap, random
import jax.numpy as jnp

# hár bir qurılma ushın 5000 x 6000 ólshemindegi 2 tosınnan matricanı jaratıw
random_keys = random.split(random.PRNGKey(0), 2)
matrices = pmap(lambda key: random.normal(key, (5000, 6000)))(random_keys)

# maǵlıwmatlardı ótkermesten, parallel túrde, hár bir CPU/GPUda jergilikli matricanı kóbeytiwdi orınlaw
outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices)

# maǵlıwmatlardı ótkermesten, parallel túrde, hár bir CPU/GPUda eki matrica ushın óz aldına ortasha mánisin alıw
means = pmap(jnp.mean)(outputs)
print(means)

Sońǵı qatar tómendegi mánislerdi shıǵarıwı kerek:

[1.1566595 1.1805978]

Sırtqı siltemeler

  • Hújjetlerː jax.readthedocs.io
  • Colab (Jupyter/iPython) Tez baslaw boyınsha qollanbaː colab.research.google.com/github/google/jax/blob/main/docs/notebooks/quickstart.ipynb
  • TensorFlowtıń XLAsıː www.tensorflow.org/xla (Tezletilgen Sızıqlı Algebra)
  • YouTube TensorFlow Kanalı "JAXke kirisiw: Mashinalıq oqıtıw izertlewlerin tezletiw": www.youtube.com/watch?v=WdTeDXsOSj4
  • Túpnusqa maqalaː mlsys.org/Conferences/doc/2018/146.pdf

Derekler

  1. "jax/AUTHORS at main · jax-ml/jax". GitHub. Retrieved December 21, 2024.
  2. Bradbury, James; Frostig, Roy; Hawkins, Peter; Johnson, Matthew James; Leary, Chris; MacLaurin, Dougal; Necula, George; Paszke, Adam; Vanderplas, Jake (2022-06-18), „JAX: Autograd and XLA“, Astrophysics Source Code Library, Bibcode:2021ascl.soft11002B, 2022-06-18da túp nusqadan arxivlendi, qaraldı: 2022-06-18 {{citation}}: Unknown parameter |publisher= ignored (járdem)
  3. «Using JAX to accelerate our research» (en). www.deepmind.com. 18-iyun 2022-jılda túp nusqadan arxivlendi. Qaraldı: 18-iyun 2022-jıl.
  4. Lynley. «Google is quietly replacing the backbone of its AI product strategy after its last big push for dominance got overshadowed by Meta» (en-US). Business Insider. 21-iyun 2022-jılda túp nusqadan arxivlendi. Qaraldı: 21-iyun 2022-jıl.
  5. «Why is Google's JAX so popular?» (en-US). Analytics India Magazine (25-aprel 2022-jıl). 18-iyun 2022-jılda túp nusqadan arxivlendi. Qaraldı: 18-iyun 2022-jıl.
  6. «Quickstart — JAX documentation».