![]() JAX logotipi
| |
![]() | |
Baǵdarlamashı(lar) | Google, Nvidia[1] |
---|---|
Aldın ala kóriw relizi | v0.4.31 / 30-iyul 2024-jıl
|
Repozitoriy | GitHub-taǵı jax |
Jazılǵan tili | Python, C++ |
Operaciyalıq sistema | Linux, macOS, Windows |
Platforma | Python, NumPy |
Kólemi | 9.0 MB |
Túri | Mashinalıq oqıtıw |
Licenziya | Apache 2.0 |
Veb-sayt | jax.readthedocs.io/en/latest/ ![]() |
JAX – bul akseleratorǵa-baǵdarlanǵan massiv esaplawları hám programmanı túrlendiriw ushın Python kitapxanası bolıp, ol joqarı ónimli sanlı esaplawlar hám úlken kólemli mashinalıq oqıtıw ushın arnalǵan. Ol Google tárepinen Nvidia hám basqa jámiyetlik úles qosıwshılardıń qatnasıwında islep shıǵılǵan.[2][3]
Ol autogradtıń (funkciyanı differenciallaw arqalı gradient funkciyasın avtomat túrde alıw) modifikaciyalanǵan versiyası menen OpenXLAnıń XLAsın (Tezletilgen Sızıqlı Algebra) birlestiriwshi sıpatında súwretlenedi. Ol NumPy strukturası hám jumıs procesine maksimal dárejede jaqınnan gózler ushın arnalǵan hám TensorFlow hám PyTorch sıyaqlı hár qıylı ámeldegi freymvorklar menen isleydi.[4][5] JAXtıń tiykarǵı ózgeshelikleri:[6]
Tómendegi kod grad funkciyasınıń avtomat differenciallawın kórsetedi.
# importlar
from jax import grad
import jax.numpy as jnp
# logistikalıq funkciyanı anıqlaw
def logistic(x):
return jnp.exp(x) / (jnp.exp(x) + 1)
# logistikalıq funkciyanıń gradient funkciyasın alıw
grad_logistic = grad(logistic)
# x = 1 bolǵanda logistikalıq funkciyanıń gradientin bahalaw
grad_log_out = grad_logistic(1.0)
print(grad_log_out)
Sońǵı qatar shıǵarıwı kerek:
0.19661194
Tómendegi kod jit funkciyasınıń kombinaciyalaw (fusion) arqalı optimizaciyasın kórsetedi.
# importlar
from jax import jit
import jax.numpy as jnp
# kub funkciyasın anıqlaw
def cube(x):
return x * x * x
# maǵlıwmatlardı jaratıw
x = jnp.ones((10000, 10000))
# kub funkciyasınıń jit versiyasın jaratıw
jit_cube = jit(cube)
# tezlik boyınsha salıstırıw ushın cube hám jit_cube funkciyaların birdey maǵlıwmatlarǵa qolllanıw
cube(x)
jit_cube(x)
jit_cube
(17-qatar) ushın esaplaw waqtı cube (16-qatar) ushın esaplaw waqtınan aytarlıqtay dárejede qısqa bolıwı kerek. 7-qatardaǵı mánislerdi arttırıw bul ayırmashılıqtı jáne de kúsheytedi.
Tómendegi kod vmap funkciyasınıń vektorlastırıwın kórsetedi.
# importlar
from jax import vmap, partial # Esletpe: dáslepki tekstte "vmap partial" dep berilgen, "vmap, partial" dep dúzetildi
import jax.numpy as jnp
# import numpy as np # Bul qatar baslanǵısh tekstte joq, biraq tómende np qollanılǵan
# funkciyanı anıqlaw
def grads(self, inputs):
in_grad_partial = jax.partial(self._net_grads, self._net_params)
grad_vmap = jax.vmap(in_grad_partial)
rich_grads = grad_vmap(inputs)
# Esletpe: Tómendegi qatarda "np.asarray" qollanılǵan. Eger tek jax.numpy importlanǵan bolsa, bul jnp.asarray bolıwı múmkin.
# Biraq, baslanǵısh tekstte "np" ushın import kórsetilmegen. Tekst sol turısında qaldırıldı.
flat_grads = np.asarray(self._flatten_batch(rich_grads))
assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
return flat_grads
Bul bólimniń oń tárepindegi GIF vektorlastırılǵan qosıw túsinigin kórsetedi.
Tómendegi kod pmap funkciyasınıń matricanı kóbeytiw ushın parallellestiriwin kórsetedi.
# JAXtan pmap hám randomdı importlaw; JAX NumPydı importlaw
from jax import pmap, random
import jax.numpy as jnp
# hár bir qurılma ushın 5000 x 6000 ólshemindegi 2 tosınnan matricanı jaratıw
random_keys = random.split(random.PRNGKey(0), 2)
matrices = pmap(lambda key: random.normal(key, (5000, 6000)))(random_keys)
# maǵlıwmatlardı ótkermesten, parallel túrde, hár bir CPU/GPUda jergilikli matricanı kóbeytiwdi orınlaw
outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices)
# maǵlıwmatlardı ótkermesten, parallel túrde, hár bir CPU/GPUda eki matrica ushın óz aldına ortasha mánisin alıw
means = pmap(jnp.mean)(outputs)
print(means)
Sońǵı qatar tómendegi mánislerdi shıǵarıwı kerek:
[1.1566595 1.1805978]
{{citation}}
: Unknown parameter |publisher=
ignored (járdem)