001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package org.apache.hadoop.hdfs.web.oauth2;
020
021import com.squareup.okhttp.OkHttpClient;
022import com.squareup.okhttp.Request;
023import com.squareup.okhttp.RequestBody;
024import com.squareup.okhttp.Response;
025import com.squareup.okhttp.MediaType;
026import org.apache.hadoop.classification.InterfaceAudience;
027import org.apache.hadoop.classification.InterfaceStability;
028import org.apache.hadoop.conf.Configuration;
029import org.apache.hadoop.hdfs.web.URLConnectionFactory;
030import org.apache.hadoop.util.Timer;
031import org.apache.http.HttpStatus;
032import org.codehaus.jackson.map.ObjectMapper;
033import org.codehaus.jackson.map.ObjectReader;
034
035import java.io.IOException;
036import java.util.Map;
037import java.util.concurrent.TimeUnit;
038
039import static org.apache.hadoop.hdfs.web.oauth2.Utils.notNull;
040
041
042/**
043 * Obtain an access token via the credential-based OAuth2 workflow.
044 */
045@InterfaceAudience.Public
046@InterfaceStability.Evolving
047public class AzureADClientCredentialBasedAccesTokenProvider
048    extends AccessTokenProvider {
049  private static final ObjectReader READER =
050      new ObjectMapper().reader(Map.class);
051
052  public static final String OAUTH_CREDENTIAL_KEY
053      = "dfs.webhdfs.oauth2.credential";
054
055  public static final String AAD_RESOURCE_KEY
056      = "fs.adls.oauth2.resource";
057
058  public static final String RESOURCE_PARAM_NAME
059      = "resource";
060
061  private static final String OAUTH_CLIENT_ID_KEY
062      = "dfs.webhdfs.oauth2.client.id";
063
064  private static final String OAUTH_REFRESH_URL_KEY
065      = "dfs.webhdfs.oauth2.refresh.url";
066
067
068  public static final String ACCESS_TOKEN = "access_token";
069  public static final String CLIENT_CREDENTIALS = "client_credentials";
070  public static final String CLIENT_ID = "client_id";
071  public static final String CLIENT_SECRET = "client_secret";
072  public static final String EXPIRES_IN = "expires_in";
073  public static final String GRANT_TYPE = "grant_type";
074  public static final MediaType URLENCODED
075          = MediaType.parse("application/x-www-form-urlencoded; charset=utf-8");
076
077
078  private AccessTokenTimer timer;
079
080  private String clientId;
081
082  private String refreshURL;
083
084  private String accessToken;
085
086  private String resource;
087
088  private String credential;
089
090  private boolean initialCredentialObtained = false;
091
092  AzureADClientCredentialBasedAccesTokenProvider() {
093    this.timer = new AccessTokenTimer();
094  }
095
096  AzureADClientCredentialBasedAccesTokenProvider(Timer timer) {
097    this.timer = new AccessTokenTimer(timer);
098  }
099
100  @Override
101  public void setConf(Configuration conf) {
102    super.setConf(conf);
103    clientId = notNull(conf, OAUTH_CLIENT_ID_KEY);
104    refreshURL = notNull(conf, OAUTH_REFRESH_URL_KEY);
105    resource = notNull(conf, AAD_RESOURCE_KEY);
106    credential = notNull(conf, OAUTH_CREDENTIAL_KEY);
107  }
108
109  @Override
110  public String getAccessToken() throws IOException {
111    if(timer.shouldRefresh() || !initialCredentialObtained) {
112      refresh();
113      initialCredentialObtained = true;
114    }
115    return accessToken;
116  }
117
118  void refresh() throws IOException {
119    try {
120      OkHttpClient client = new OkHttpClient();
121      client.setConnectTimeout(URLConnectionFactory.DEFAULT_SOCKET_TIMEOUT,
122          TimeUnit.MILLISECONDS);
123      client.setReadTimeout(URLConnectionFactory.DEFAULT_SOCKET_TIMEOUT,
124          TimeUnit.MILLISECONDS);
125
126      String bodyString = Utils.postBody(CLIENT_SECRET, credential,
127          GRANT_TYPE, CLIENT_CREDENTIALS,
128          RESOURCE_PARAM_NAME, resource,
129          CLIENT_ID, clientId);
130
131      RequestBody body = RequestBody.create(URLENCODED, bodyString);
132
133      Request request = new Request.Builder()
134          .url(refreshURL)
135          .post(body)
136          .build();
137      Response responseBody = client.newCall(request).execute();
138
139      if (responseBody.code() != HttpStatus.SC_OK) {
140        throw new IllegalArgumentException("Received invalid http response: "
141            + responseBody.code() + ", text = " + responseBody.toString());
142      }
143
144      Map<?, ?> response = READER.readValue(responseBody.body().string());
145
146      String newExpiresIn = response.get(EXPIRES_IN).toString();
147      timer.setExpiresIn(newExpiresIn);
148
149      accessToken = response.get(ACCESS_TOKEN).toString();
150
151    } catch (Exception e) {
152      throw new IOException("Unable to obtain access token from credential", e);
153    }
154  }
155}