안녕하세요! 얼마 전 AI GDE(Google Developer Expert)로서 중국 상하이에서 열린 Google I/O Connect China 2025에 참석할 좋은 기회가 있었습니다. I/O Connect는 신제품 발표 중심의 5월 I/O와는 달리, 실제 개발자들을 위한 심층 기술 세션과 워크숍이 가득한 행사입니다. 올해 AI GDE로 선정되어 처음 참석하게 된 저로서는 그 기대가 무척 컸습니다.
올해도 역시 현장의 뜨거운 열기 속에서 개발의 미래를 한발 앞서 경험하고 왔습니다. 정말 많은 세션이 있었지만, 제 마음을 가장 설레게 했던 것은 단연 AI와 클라우드 트랙의 기술들이었습니다. 이번 글에서는 그중에서도 앞으로 우리가 AI를 개발하고 활용하는 방식을 크게 바꿀 것이라 느꼈던 JAX, TPU, 그리고 ‘AI 에이전트’라는 새로운 패러다임에 대해 좀 더 깊고 재미있게 이야기해보고자 합니다.
이번 I/O Connect의 AI 세션들은 명확한 메시지를 던지고 있었습니다. 바로 “이제 고성능 AI 연구는 소수의 전유물이 아니다”라는 것이죠. 그 중심에는 구글 AI의 심장이라고 할 수 있는 JAX와 TPU가 있었습니다.
이미 TensorFlow와 PyTorch라는 훌륭한 프레임워크가 있는데 왜 JAX일까요? JAX의 진짜 힘은 단순히 ‘빠르다’는 데 있지 않습니다. 기존 프레임워크들이 미리 만들어진 레고 블록(레이어)을 조립하는 방식에 가깝다면, JAX는 어떤 나무(나의 Python 함수)든 변신시키는 마법 지팡이와 같습니다.
jit
(Just-in-Time Compilation): 내 파이썬 코드를 분석해 XLA(가속 선형대수)라는 최적화된 코드로 바꿔주는 마법입니다. 코드를 거의 수정하지 않고도 엄청난 속도 향상을 얻을 수 있죠.grad
(Gradient): 어떤 함수든 미분 가능한 함수로 만들어 버립니다. 복잡한 물리 시뮬레이션 코드에도 grad
한번 씌우면 바로 경사하강법을 적용해 최적화할 수 있게 됩니다.vmap
(Vectorization): ‘for’ 루프 없이도 데이터를 병렬로 척척 처리해 줍니다. 단일 데이터 처리 로직만 짜면, vmap
이 알아서 배치 데이터 전체에 효율적으로 적용해 줍니다.이런 함수 변환 기능 덕분에 JAX는 기존 프레임워크보다 훨씬 더 유연하고 표현력이 높으며, 특히 연구 단계에서 새로운 아이디어를 빠르게 프로토타이핑하고 검증하는 데 강력한 무기가 됩니다.
JAX의 장점은 TPU를 만났을 때 폭발적인 시너지를 냅니다. TPU는 거대한 행렬 연산에 극도로 최적화된 하드웨어인데, JAX의 pmap
같은 기능은 여러 TPU 코어에 작업을 분산시키는 복잡한 과정을 아주 간단하게 만들어줍니다. 덕분에 개발자는 하드웨어의 복잡한 구조를 깊이 이해하지 않고도, 대규모 모델을 손쉽게 병렬 훈련시킬 수 있습니다. Gemini 같은 초거대 모델이 JAX와 TPU 위에서 탄생한 것은 결코 우연이 아닌 셈이죠.
하지만 JAX는 특유의 함수형 프로그래밍 스타일 때문에 입문자에게는 다소 낯설게 느껴졌습니다. ‘Keras 3에서 model.fit(…)을 사용하여 JAX 모델 훈련하기’ 세션은 바로 이 문제를 해결하는 구글의 현명한 답변이었습니다.
단순함을 넘어선 유연함: fit()
직접 제어하기
Keras의 핵심 철학 중 하나는 ‘progressive disclosure of complexity’입니다. 간단한 작업은 model.fit()
한 줄로 끝나야 하지만, 조금 더 복잡한 로직이 필요하다고 해서 갑자기 모든 것을 밑바닥부터 새로 짜야 하는 ‘절벽’을 만나서는 안 된다는 뜻이죠.
Keras 3와 JAX는 이 철학을 정말 멋지게 구현했습니다. 만약 일반적인 fit()
의 동작 방식만으로는 부족하다면, keras.Model
을 상속받아 train_step
메서드만 살짝 덮어쓰면 됩니다. train_step
은 fit()
이 매 데이터 배치마다 호출하는 ‘심장’과도 같은 함수로, 이 심장의 작동 방식을 우리가 직접 재정의하는 셈이죠.
JAX 백엔드를 사용할 때의 핵심은 ‘stateless’입니다. train_step
함수는 현재 모델의 상태(가중치, 옵티마이저 변수 등)와 데이터 배치를 입력으로 받아서, 계산이 끝난 ‘새로운 상태’와 로그를 반환하는 순수 함수처럼 동작합니다. 모델 내부의 변수를 직접 바꾸는 대신, 업데이트된 버전의 변수들을 새로 만들어서 전달하는 방식이죠.
아래 코드는 train_step
을 직접 구현하는 예시입니다.
compute_loss_and_updates
: 손실과 그라디언트를 계산하는 핵심 로직이며, JAX의 jax.value_and_grad
함수와 함께 사용됩니다.train_step
: fit()
에 의해 배치마다 호출됩니다.
grad_fn
을 통해 손실과 그라디언트를 얻습니다.optimizer.stateless_apply
로 새로운 가중치와 옵티마이저 상태를 계산합니다.metric.stateless_update_state
로 메트릭을 업데이트합니다.trainable_variables
, optimizer_variables
등)와 로그를 반환합니다.이렇게 train_step
만 재정의하면 GAN이나 Contrastive Learning처럼 복잡한 학습 알고리즘을 직접 구현하면서도, 동시에 fit()
이 제공하는 편리한 기능들(콜백, 분산 학습 지원 등)을 그대로 누릴 수 있습니다!
import os
# 이 가이드는 JAX 백엔드에서만 실행할 수 있습니다.
os.environ["KERAS_BACKEND"] = "jax"
import jax
import keras
import numpy as np
class CustomModel(keras.Model):
def compute_loss_and_updates(
self,
trainable_variables,
non_trainable_variables,
x,
y,
training=False,
):
y_pred, non_trainable_variables = self.stateless_call(
trainable_variables,
non_trainable_variables,
x,
training=training,
)
loss = self.compute_loss(x, y, y_pred)
return loss, (y_pred, non_trainable_variables)
def train_step(self, state, data):
(
trainable_variables,
non_trainable_variables,
optimizer_variables,
metrics_variables,
) = state
x, y = data
# 그라디언트 함수를 가져옵니다.
grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)
# 그라디언트를 계산합니다.
(loss, (y_pred, non_trainable_variables)), grads = grad_fn(
trainable_variables,
non_trainable_variables,
x,
y,
training=True,
)
# 훈련 가능한 변수와 옵티마이저 변수를 업데이트합니다.
(
trainable_variables,
optimizer_variables,
) = self.optimizer.stateless_apply(
optimizer_variables, grads, trainable_variables
)
# 메트릭을 업데이트합니다.
new_metrics_vars = []
logs = {}
for metric in self.metrics:
this_metric_vars = metrics_variables[
len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
]
if metric.name == "loss":
this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
else:
this_metric_vars = metric.stateless_update_state(
this_metric_vars, y, y_pred
)
logs[metric.name] = metric.stateless_result(this_metric_vars)
new_metrics_vars += this_metric_vars
# 메트릭 로그와 업데이트된 상태 변수를 반환합니다.
state = (
trainable_variables,
non_trainable_variables,
optimizer_variables,
new_metrics_vars,
)
return logs, state
# CustomModel의 인스턴스를 생성하고 컴파일합니다.
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])
# 평소처럼 `fit`을 사용합니다.
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=3)
이번 행사의 또 다른 주인공은 단연 ‘AI 에이전트’였습니다. AI가 단순히 질문에 답하는 것을 넘어, 복잡한 목표를 주면 스스로 계획을 세우고(Planning), 도구를 사용하며(Action), 기억을 통해 학습하는(Memory) 시대가 열리고 있음을 보여주었습니다. 그리고 구글은 이런 에이전트를 누구나 쉽게 만들 수 있도록 ADK(Agent Development Kit)라는 강력한 ‘레고 블록 세트’를 선보였습니다.
ADK는 크게 세 가지 종류의 에이전트(레고 블록)를 제공하며, 복잡한 AI 애플리케이션은 이 블록들을 조합하여 만드는 방식입니다.
LlmAgent는 우리 애플리케이션의 ‘뇌’ 역할을 합니다. Gemini 같은 LLM을 핵심 엔진으로 사용해서, 자연어를 이해하고, 추론하고, 계획을 세우고, 어떤 도구를 사용할지 동적으로 결정하죠. 제가 행사장에서 “JAX 세션이 몇 시에 시작하나요?”라고 물어보면 답변해주는 ‘행사 안내 에이전트’를 만든다고 상상해볼까요?
name
, description
): 에이전트에게 io_schedule_agent
라는 이름과 “I/O Connect China 2025 행사 세션 정보를 안내합니다.”라는 설명을 붙여줍니다.instruction
): “사용자가 세션 정보를 물어보면, find_session_details
도구를 사용해서 시간과 장소를 찾아 친절하게 답변해줘.” 와 같이 명확한 행동 지침을 줍니다.tools
): 행사 스케줄 정보가 담긴 데이터베이스를 조회하는 find_session_details
함수를 도구로 만들어 에이전트에게 쥐여줍니다. 이제 에이전트는 LLM의 지식만으로는 알 수 없는 ‘오늘의 세션 정보’를 조회할 수 있게 됩니다.# 행사 스케줄 정보를 담은 가상 데이터베이스
IO_CONNECT_SCHEDULE = {
"JAX 실제 활용 사례": {"time": "14:25", "track": "AI", "desc": "Gemini와 같은 모델을 구축하는 데 사용된 JAX와 Flax 라이브러리 소개"},
"Keras 3에서 model.fit(...)을 사용하여 JAX 모델 훈련하기": {"time": "15:20", "track": "AI", "desc": "친숙한 Keras API를 사용하여 JAX로 모델을 훈련하는 방법"},
"필요한 것은 에이전트뿐": {"time": "10:30", "track": "Cloud", "desc": "ADK 및 Agent Engine으로 다중 에이전트 시스템을 구축하는 방법"}
}
# 도구 함수 정의: 세션 제목으로 정보를 찾는 함수
def find_session_details(session_title: str) -> str:
"""Google I/O Connect China 2025 세션 제목으로 시간, 트랙, 설명을 검색합니다."""
if session_title in IO_CONNECT_SCHEDULE:
return str(IO_CONNECT_SCHEDULE[session_title])
return f"'{session_title}' 세션을 찾을 수 없습니다. 제목을 다시 확인해주세요."
# '행사 안내 에이전트' 정의
io_schedule_agent = LlmAgent(
model="gemini-2.0-flash",
name="io_schedule_agent",
description="Google I/O Connect China 2025의 세션 정보를 안내합니다.",
instruction="""당신은 I/O Connect 행사 안내원입니다.
사용자가 세션 정보를 물어보면, 세션의 전체 제목을 파악하여 `find_session_details` 도구를 사용하세요.
그리고 찾은 정보를 바탕으로 사용자에게 친절하게 답변해주세요.
""",
tools=[find_session_details] # 함수를 직접 도구로 제공
)
똑똑한 직원(LlmAgent) 한 명도 좋지만, 이들을 지휘하는 유능한 관리자가 있다면 더 복잡한 일도 해낼 수 있겠죠? Workflow Agent는 LLM처럼 스스로 생각하지는 않지만, 정해진 규칙에 따라 다른 에이전트들의 작업 흐름을 ‘관리’하고 ‘지휘’하는 역할을 합니다.
SequentialAgent
: 에이전트들을 정해진 순서대로, 하나씩 실행시킵니다.LoopAgent
: 정해진 조건이 만족될 때까지 특정 작업을 반복시킵니다.ParallelAgent
: 서로 연관 없는 여러 작업을 동시에 진행시켜 시간을 절약합니다.“A 에이전트의 결과가 ‘성공’이면 B를 실행하고, ‘실패’면 C를 실행해라” 와 같은 조건부 논리가 필요할 때가 있습니다. 이럴 때 사용하는 것이 바로 Custom Agent입니다. 개발자가 직접 파이썬 코드로 if-else
문, for
문 등을 사용해서 아주 복잡하고 독창적인 작업 흐름을 자유롭게 설계할 수 있게 해주는, 그야말로 ‘전문가 모드’ 블록입니다.
ADK의 진정한 힘은 이 세 종류의 에이전트를 조합하여 ‘다중 에이전트 시스템’, 즉 ‘AI 드림팀’을 구성하는 데 있습니다. I/O Connect 참석자가 “AI와 클라우드 트랙에서 JAX나 ADK 관련 세션을 추천해줘” 라고 요청하는 시나리오를 상상해봅시다. 이 복잡한 요청을 해결하기 위해 여러 에이전트가 협력합니다.
# 1. 각 작업을 전담할 하위 에이전트(전문가 팀) 정의
# '세션 검색 담당' 에이전트 (병렬 처리)
ai_track_search_agent = LlmAgent(
name="AITrackSearcher",
model="gemini-2.0-flash",
instruction="""AI 트랙 세션 목록에서 'JAX' 또는 'TPU' 키워드가 포함된 세션을 모두 찾아주세요.
검색 대상: {session_list}""",
output_key="ai_session_results"
)
cloud_track_search_agent = LlmAgent(
name="CloudTrackSearcher",
model="gemini-2.0-flash",
instruction="""Cloud 트랙 세션 목록에서 'ADK' 또는 'Agent' 키워드가 포함된 세션을 모두 찾아주세요.
검색 대상: {session_list}""",
output_key="cloud_session_results"
)
# 검색 작업을 병렬로 실행할 '검색팀' 에이전트
parallel_search_agent = ParallelAgent(
name="ParallelSearchTeam",
sub_agents=[ai_track_search_agent, cloud_track_search_agent],
)
# '추천 목록 생성 담당' 에이전트
recommendation_agent = LlmAgent(
name="RecommendationGenerator",
model="gemini-2.5-pro-preview-03-25", # 추천 요약은 더 강력한 모델 사용
instruction="""당신은 I/O Connect의 맞춤형 가이드입니다.
아래 두 검색 결과를 바탕으로, 사용자에게 추천할 세션 목록을 하나의 깔끔한 목록으로 종합해주세요.
각 세션의 시간과 핵심 내용을 포함하여 흥미롭게 요약해야 합니다.
AI 트랙 검색 결과:
{ai_session_results}
Cloud 트랙 검색 결과:
{cloud_session_results}
""",
output_key="recommended_list",
)
# 2. SequentialAgent를 사용하여 '세션 추천 파이프라인' 생성
session_recommendation_pipeline = SequentialAgent(
name="SessionRecommendationPipeline",
sub_agents=[parallel_search_agent, recommendation_agent],
description="관련 세션을 병렬로 검색한 후, 그 결과를 종합하여 최종 추천 목록을 생성합니다.",
)
Google I/O Connect China 2025를 통해 제가 느낀 가장 큰 변화의 바람을 요약하자면,
AI GDE로서 이런 기술 혁신의 최전선에 있다는 사실에 가슴이 뜁니다. 개발의 모든 과정에 AI가 깊숙이 스며드는 새로운 시대가 정말 코앞으로 다가온 것 같습니다. 지금부터라도 Keras로 JAX 백엔드를 사용해보거나, Gemini API를 이용하여 간단한 에이전트를 만들어보는 등 작은 시도를 시작해보는 건 어떨까요? 앞으로 이 새로운 도구들이 만들어갈 놀라운 미래를 함께 만들어갔으면 좋겠습니다.